feat: add automatic key failover for AI Bridge Anthropic (#24836)

## Description

Adds automatic key failover for the centralized Anthropic provider. When a key pool is configured, each upstream call walks the pool and tries keys in order until one succeeds or the pool is exhausted. Keys are marked **temporary** on 429 (with cooldown from `Retry-After`) and **permanent** on 401/403. Errors that aren't key-specific don't trigger failover. Each agentic-loop iteration gets its own fresh walker, so a tool-call continuation can fail over independently of the initial request.

BYOK is unchanged: BYOK requests run as a single attempt with no failover.

## Changes

- `config.Anthropic` carries a `KeyPool`. `Key` remains for BYOK X-Api-Key set per interception.
- Blocking interceptor: walks the pool, marks keys on key-specific failures, returns on first success or non-failover error.
- Streaming interceptor: per-iteration walker. Pre-stream failures fail over to the next key; mid-stream errors are relayed as SSE events.
- New `keypool` error types: `TransientKeyPoolError` (carries the soonest cooldown) and `ErrPermanentKeyPool`. They replace the prior `ErrAllKeysExhausted`.
- Error responses now consistently include the outer `"type": "error"` field.

## Related Issues

Related to: https://github.com/coder/internal/issues/1446
Related to: https://linear.app/codercom/issue/AIGOV-197/aibridge-automatic-key-failover-for-bridged-and-passthrough-routes

## Follow-up PRs

- Bedrock multi-key support.
- Refactor provider vs interceptor config separation.
- Record the actually-used key in the interception credential hint after failover.

> [!NOTE]
> Initially generated by Claude Opus 4.7, modified and reviewed by @ssncferreira
This commit is contained in:
Susana Ferreira
2026-05-07 14:57:44 +01:00
committed by GitHub
parent 273e828442
commit f1155ac4d7
17 changed files with 2313 additions and 121 deletions
+18 -1
View File
@@ -1,6 +1,10 @@
package config
import "time"
import (
"time"
"github.com/coder/coder/v2/aibridge/keypool"
)
const (
ProviderAnthropic = "anthropic"
@@ -8,11 +12,24 @@ const (
ProviderCopilot = "copilot"
)
// Anthropic carries configuration for an Anthropic provider.
//
// Authentication is mutually exclusive across these three fields,
// set per interception in the provider's CreateInterceptor:
// - KeyPool: centralized requests with automatic key failover.
// - Key: BYOK with X-Api-Key (single attempt, no failover).
// - BYOKBearerToken: BYOK with Authorization Bearer (single
// attempt, no failover).
//
// TODO(ssncferreira): consolidate the three authentication
// fields into a single abstraction per
// https://github.com/coder/aibridge/issues/266.
type Anthropic struct {
// Name is the provider instance name. If empty, defaults to "anthropic".
Name string
BaseURL string
Key string
KeyPool *keypool.Pool
APIDumpDir string
CircuitBreaker *CircuitBreaker
SendActorHeaders bool
+82 -28
View File
@@ -5,7 +5,9 @@ import (
"encoding/json"
"errors"
"fmt"
"math"
"net/http"
"strconv"
"strings"
"time"
@@ -26,6 +28,7 @@ import (
aibcontext "github.com/coder/coder/v2/aibridge/context"
"github.com/coder/coder/v2/aibridge/intercept"
"github.com/coder/coder/v2/aibridge/intercept/apidump"
"github.com/coder/coder/v2/aibridge/keypool"
"github.com/coder/coder/v2/aibridge/mcp"
"github.com/coder/coder/v2/aibridge/recorder"
"github.com/coder/coder/v2/aibridge/tracing"
@@ -202,19 +205,28 @@ func (i *interceptionBase) isSmallFastModel() bool {
return strings.Contains(i.reqPayload.model(), "haiku")
}
// newMessagesService builds the SDK service used for upstream
// calls. BYOK auth is set here. Centralized auth is set
// per-attempt by the failover loop.
func (i *interceptionBase) newMessagesService(ctx context.Context, opts ...option.RequestOption) (anthropic.MessageService, error) {
// BYOK with access token uses Authorization: Bearer.
// Otherwise use X-Api-Key (centralized or BYOK with personal API key).
if i.cfg.BYOKBearerToken != "" {
i.logger.Debug(ctx, "using byok access token auth",
slog.F("bearer_hint", utils.MaskSecret(i.cfg.BYOKBearerToken)),
)
opts = append(opts, option.WithAuthToken(i.cfg.BYOKBearerToken))
} else {
i.logger.Debug(ctx, "using api key auth",
slog.F("api_key_hint", utils.MaskSecret(i.cfg.Key)),
)
opts = append(opts, option.WithAPIKey(i.cfg.Key))
// TODO(ssncferreira): validate auth is configured per
// https://github.com/coder/aibridge/issues/266.
// BYOK auth.
if i.cfg.KeyPool == nil {
if i.cfg.BYOKBearerToken != "" {
// BYOK Bearer: Authorization header.
i.logger.Debug(ctx, "using byok access token auth",
slog.F("bearer_hint", utils.MaskSecret(i.cfg.BYOKBearerToken)),
)
opts = append(opts, option.WithAuthToken(i.cfg.BYOKBearerToken))
} else {
// BYOK X-Api-Key.
i.logger.Debug(ctx, "using api key auth",
slog.F("api_key_hint", utils.MaskSecret(i.cfg.Key)),
)
opts = append(opts, option.WithAPIKey(i.cfg.Key))
}
}
opts = append(opts, option.WithBaseURL(i.cfg.BaseURL))
@@ -427,6 +439,10 @@ func (i *interceptionBase) writeUpstreamError(w http.ResponseWriter, antErr *res
}
w.Header().Set("Content-Type", "application/json")
// Set Retry-After when a cooldown is configured.
if antErr.RetryAfter > 0 {
w.Header().Set("Retry-After", strconv.Itoa(int(math.Ceil(antErr.RetryAfter.Seconds()))))
}
w.WriteHeader(antErr.StatusCode)
out, err := json.Marshal(antErr)
@@ -503,6 +519,49 @@ func accumulateUsage(dest, src any) {
}
}
// markKeyOnError applies only to centralized requests (those with a
// key pool configured). It unwraps an Anthropic SDK error from err
// and hands that error's HTTP response to keypool.MarkKeyOnStatus,
// which marks the key based on the status code. It reports true when
// the status was a key-specific failover trigger, so callers can
// retry with the next key. BYOK requests and errors that are not SDK
// errors always report false and leave the key untouched.
func (i *interceptionBase) markKeyOnError(ctx context.Context, key *keypool.Key, err error) bool {
	var sdkErr *anthropic.Error
	// The pool check comes first, so BYOK requests never inspect err.
	if i.cfg.KeyPool != nil && errors.As(err, &sdkErr) {
		return keypool.MarkKeyOnStatus(ctx, key, sdkErr.Response, i.logger, i.providerName)
	}
	return false
}
// processKeyPoolError converts a keypool exhaustion error into a
// developer-facing responseError shaped for the Anthropic API:
//   - a transient pool error becomes a 429 rate_limit_error that
//     carries the error's RetryAfter;
//   - the permanent pool error becomes a 502 api_error with no
//     RetryAfter.
//
// Any other error (including nil) is not an exhaustion error and
// yields nil.
func processKeyPoolError(err error) *responseError {
	var rateLimited *keypool.TransientKeyPoolError
	if errors.As(err, &rateLimited) {
		return newErrorResponse(
			"all configured keys are rate-limited",
			string(constant.ValueOf[constant.RateLimitError]()),
			http.StatusTooManyRequests,
			rateLimited.RetryAfter,
		)
	}

	if errors.Is(err, keypool.ErrPermanentKeyPool) {
		return newErrorResponse(
			"all configured keys failed authentication",
			string(constant.ValueOf[constant.APIError]()),
			http.StatusBadGateway,
			0,
		)
	}

	return nil
}
func getErrorResponse(err error) *responseError {
var apierr *anthropic.Error
if !errors.As(err, &apierr) {
@@ -510,7 +569,7 @@ func getErrorResponse(err error) *responseError {
}
msg := apierr.Error()
typ := string(constant.ValueOf[constant.APIError]())
errType := string(constant.ValueOf[constant.APIError]())
var detail *anthropic.APIErrorObject
if field, ok := apierr.JSON.ExtraFields["error"]; ok {
@@ -518,19 +577,10 @@ func getErrorResponse(err error) *responseError {
}
if detail != nil {
msg = detail.Message
typ = string(detail.Type)
errType = string(detail.Type)
}
return &responseError{
ErrorResponse: &anthropic.ErrorResponse{
Error: anthropic.ErrorObjectUnion{
Message: msg,
Type: typ,
},
Type: constant.ValueOf[constant.Error](),
},
StatusCode: apierr.StatusCode,
}
return newErrorResponse(msg, errType, apierr.StatusCode, keypool.ParseRetryAfter(apierr.Response))
}
var _ error = &responseError{}
@@ -538,17 +588,21 @@ var _ error = &responseError{}
type responseError struct {
*anthropic.ErrorResponse
StatusCode int `json:"-"`
StatusCode int `json:"-"`
RetryAfter time.Duration `json:"-"`
}
func newErrorResponse(msg error) *responseError {
func newErrorResponse(msg, errType string, status int, retryAfter time.Duration) *responseError {
return &responseError{
ErrorResponse: &shared.ErrorResponse{
Error: shared.ErrorObjectUnion{
Message: msg.Error(),
Type: "error",
Message: msg,
Type: errType,
},
Type: constant.ValueOf[constant.Error](),
},
StatusCode: status,
RetryAfter: retryAfter,
}
}
+190
View File
@@ -3,18 +3,24 @@ package messages //nolint:testpackage // tests unexported internals
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/anthropics/anthropic-sdk-go"
"github.com/anthropics/anthropic-sdk-go/shared/constant"
mcpgo "github.com/mark3labs/mcp-go/mcp"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
"golang.org/x/xerrors"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/aibridge/config"
"github.com/coder/coder/v2/aibridge/keypool"
"github.com/coder/coder/v2/aibridge/mcp"
"github.com/coder/coder/v2/aibridge/utils"
"github.com/coder/quartz"
)
func TestScanForCorrelatingToolCallID(t *testing.T) {
@@ -991,3 +997,187 @@ func TestFilterBedrockBetaFlags(t *testing.T) {
})
}
}
// TestProcessKeyPoolError checks how processKeyPoolError maps errors
// to responses: transient pool errors become 429 carrying the error's
// RetryAfter, the permanent pool error becomes 502, and any other
// error yields nil.
func TestProcessKeyPoolError(t *testing.T) {
t.Parallel()
tests := []struct {
name string
err error
expectedNil bool
expectedStatus int
expectedRetryAfter time.Duration
}{
{
// Transient with valid keys present: 429, no Retry-After.
name: "transient_zero_retry_after",
err: &keypool.TransientKeyPoolError{},
expectedStatus: http.StatusTooManyRequests,
expectedRetryAfter: 0,
},
{
// Transient with cooldown: 429, Retry-After set.
name: "transient_with_retry_after",
err: &keypool.TransientKeyPoolError{RetryAfter: 5 * time.Second},
expectedStatus: http.StatusTooManyRequests,
expectedRetryAfter: 5 * time.Second,
},
{
// Permanent: 502 api_error.
name: "permanent_returns_502",
err: keypool.ErrPermanentKeyPool,
expectedStatus: http.StatusBadGateway,
},
{
// Anything else: not a pool-exhaustion error.
name: "non_pool_exhaustion_error_returns_nil",
err: xerrors.New("some other error"),
expectedNil: true,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
got := processKeyPoolError(tc.err)
// A nil result has no status or RetryAfter to compare.
if tc.expectedNil {
require.Nil(t, got)
return
}
require.NotNil(t, got)
assert.Equal(t, tc.expectedStatus, got.StatusCode)
assert.Equal(t, tc.expectedRetryAfter, got.RetryAfter)
})
}
}
// TestMarkKeyOnError checks markKeyOnError's return value and the
// resulting key state for a non-SDK error and for SDK errors with
// 429, 401, 403, and 500 statuses. Each case uses a fresh single-key
// pool.
func TestMarkKeyOnError(t *testing.T) {
t.Parallel()
tests := []struct {
name string
err error
expectedReturn bool
expectedState keypool.KeyState
}{
{
// Not an *anthropic.Error: no status code to act on.
name: "non_api_error_returns_false",
err: xerrors.New("network failure"),
expectedReturn: false,
expectedState: keypool.KeyStateValid,
},
{
// Rate-limited: temporary cooldown.
name: "429_marks_temporary",
err: &anthropic.Error{StatusCode: http.StatusTooManyRequests, Response: &http.Response{StatusCode: http.StatusTooManyRequests}},
expectedReturn: true,
expectedState: keypool.KeyStateTemporary,
},
{
// Auth failure: mark permanent.
name: "401_marks_permanent",
err: &anthropic.Error{StatusCode: http.StatusUnauthorized, Response: &http.Response{StatusCode: http.StatusUnauthorized}},
expectedReturn: true,
expectedState: keypool.KeyStatePermanent,
},
{
// Auth forbidden: mark permanent.
name: "403_marks_permanent",
err: &anthropic.Error{StatusCode: http.StatusForbidden, Response: &http.Response{StatusCode: http.StatusForbidden}},
expectedReturn: true,
expectedState: keypool.KeyStatePermanent,
},
{
// Server errors are not key-specific.
name: "500_does_not_mark",
err: &anthropic.Error{StatusCode: http.StatusInternalServerError, Response: &http.Response{StatusCode: http.StatusInternalServerError}},
expectedReturn: false,
expectedState: keypool.KeyStateValid,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
pool, err := keypool.New([]string{"key-0"}, quartz.NewMock(t))
require.NoError(t, err)
// Take the pool's only key; this is the key the call under test may mark.
key, err := pool.Walker().Next()
require.NoError(t, err)
// KeyPool is set so markKeyOnError does not take its no-pool early return.
base := &interceptionBase{cfg: config.Anthropic{KeyPool: pool}, logger: slog.Make()}
got := base.markKeyOnError(context.Background(), key, tc.err)
assert.Equal(t, tc.expectedReturn, got)
assert.Equal(t, tc.expectedState, key.State())
})
}
}
// TestWriteUpstreamError checks the status code, Content-Type header,
// Retry-After header, and JSON body that writeUpstreamError writes for
// a responseError, including sub-second and negative RetryAfter
// values.
func TestWriteUpstreamError(t *testing.T) {
t.Parallel()
tests := []struct {
name string
respErr *responseError
expectStatus int
// Empty string means the header should be absent.
expectRetryAfter string
// Substring expected in the marshaled body. Empty means no body check.
expectBodyContains string
}{
{
// Standard error: status and JSON body written.
name: "writes_status_and_body",
respErr: newErrorResponse("upstream failed", "api_error", http.StatusBadGateway, 0),
expectStatus: http.StatusBadGateway,
expectBodyContains: `"upstream failed"`,
},
{
// Whole-second retryAfter: emitted as integer seconds.
name: "retry_after_in_seconds",
respErr: newErrorResponse("rate limited", "rate_limit_error", http.StatusTooManyRequests, 60*time.Second),
expectStatus: http.StatusTooManyRequests,
expectRetryAfter: "60",
},
{
// 500ms rounds up to Retry-After: 1.
name: "retry_after_500ms_rounds_up_to_one",
respErr: newErrorResponse("rate limited", "rate_limit_error", http.StatusTooManyRequests, 500*time.Millisecond),
expectStatus: http.StatusTooManyRequests,
expectRetryAfter: "1",
},
{
// 200ms rounds up to Retry-After: 1.
name: "retry_after_200ms_rounds_up_to_one",
respErr: newErrorResponse("rate limited", "rate_limit_error", http.StatusTooManyRequests, 200*time.Millisecond),
expectStatus: http.StatusTooManyRequests,
expectRetryAfter: "1",
},
{
// Negative retryAfter: header omitted.
name: "negative_retry_after_omits_header",
respErr: newErrorResponse("rate limited", "rate_limit_error", http.StatusTooManyRequests, -1*time.Second),
expectStatus: http.StatusTooManyRequests,
expectRetryAfter: "",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
base := &interceptionBase{logger: slog.Make()}
w := httptest.NewRecorder()
base.writeUpstreamError(w, tc.respErr)
assert.Equal(t, tc.expectStatus, w.Code, "status code")
assert.Equal(t, "application/json", w.Header().Get("Content-Type"), "Content-Type header")
// Header.Get returns "" for an absent header, matching the "omitted" cases.
assert.Equal(t, tc.expectRetryAfter, w.Header().Get("Retry-After"), "Retry-After header")
// Every case must carry the outer "type":"error" envelope field.
assert.Contains(t, w.Body.String(), `"type":"error"`, "outer error envelope")
if tc.expectBodyContains != "" {
assert.Contains(t, w.Body.String(), tc.expectBodyContains, "response body")
}
})
}
}
+54 -3
View File
@@ -112,6 +112,13 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req
return xerrors.Errorf("upstream connection closed: %w", err)
}
// The failover loop may return a keypool exhaustion
// error. Check before the SDK-error path.
if keyErr := processKeyPoolError(err); keyErr != nil {
i.writeUpstreamError(w, keyErr)
return xerrors.Errorf("key pool exhausted: %w", err)
}
if antErr := getErrorResponse(err); antErr != nil {
i.writeUpstreamError(w, antErr)
return xerrors.Errorf("anthropic API error: %w", err)
@@ -334,9 +341,53 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req
return nil
}
func (i *BlockingInterception) newMessage(ctx context.Context, svc anthropic.MessageService) (_ *anthropic.Message, outErr error) {
ctx, span := i.tracer.Start(ctx, "Intercept.ProcessRequest.Upstream", trace.WithAttributes(tracing.InterceptionAttributesFromContext(ctx)...))
// newMessage dispatches the upstream call. Centralized requests (a
// key pool is configured) go through the failover loop; BYOK requests
// make exactly one attempt.
func (i *BlockingInterception) newMessage(ctx context.Context, svc anthropic.MessageService) (*anthropic.Message, error) {
	if i.cfg.KeyPool != nil {
		return i.newMessageWithKeyFailover(ctx, svc)
	}
	// No pool means BYOK: single attempt, no failover.
	return i.newMessageWithKey(ctx, svc)
}
// newMessageWithKey performs a single upstream call inside its own
// "Intercept.ProcessRequest.Upstream" span. The request-body option is
// always applied first; extraOpts are appended after it (the failover
// loop uses them to supply the per-attempt API key and to disable SDK
// retries). The span is ended with the returned error via outErr.
func (i *BlockingInterception) newMessageWithKey(ctx context.Context, svc anthropic.MessageService, extraOpts ...option.RequestOption) (_ *anthropic.Message, outErr error) {
	// Keep the context returned by Start so the upstream call runs
	// under the new span rather than under its parent.
	ctx, span := i.tracer.Start(ctx, "Intercept.ProcessRequest.Upstream", trace.WithAttributes(tracing.InterceptionAttributesFromContext(ctx)...))
	defer tracing.EndSpanErr(span, &outErr)

	opts := append([]option.RequestOption{i.withBody()}, extraOpts...)
	return svc.New(ctx, anthropic.MessageNewParams{}, opts...)
}
// newMessageWithKeyFailover issues the upstream call once per key of
// the centralized pool, in walker order, until an attempt either
// succeeds or fails for a reason that is not key-specific. Keys are
// marked temporary on 429 and permanent on 401/403; only those
// key-specific failures advance to the next key. When the walker has
// no further key to offer, its error is returned unchanged.
func (i *BlockingInterception) newMessageWithKeyFailover(ctx context.Context, svc anthropic.MessageService) (*anthropic.Message, error) {
	// TODO(ssncferreira): update the interception's credential
	// hint with the actually-used key (the successful key on
	// success, the last tried key on failure) in the upstack PR.
	walker := i.cfg.KeyPool.Walker()
	for {
		key, nextErr := walker.Next()
		if nextErr != nil {
			return nil, nextErr
		}

		// SDK retries are disabled: rotating to the next key is the
		// retry mechanism here.
		msg, callErr := i.newMessageWithKey(ctx, svc,
			option.WithAPIKey(key.Value()),
			option.WithMaxRetries(0),
		)

		// Anything other than a key-specific failure ends the walk:
		// either success (msg, nil) or a non-key error (nil, err).
		if !i.markKeyOnError(ctx, key, callErr) {
			return msg, callErr
		}
	}
}
@@ -0,0 +1,457 @@
package messages //nolint:testpackage // tests unexported internals
import (
"io"
"net/http"
"net/http/httptest"
"sync"
"sync/atomic"
"testing"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/aibridge/config"
"github.com/coder/coder/v2/aibridge/intercept"
"github.com/coder/coder/v2/aibridge/internal/testutil"
"github.com/coder/coder/v2/aibridge/keypool"
"github.com/coder/coder/v2/aibridge/mcp"
"github.com/coder/quartz"
)
// Common request and Anthropic-shaped response bodies.
const (
requestBody = `{"model":"claude-opus-4-5","max_tokens":1024,"messages":[{"role":"user","content":"hi"}]}`
successBody = `{"id":"msg_01","type":"message","role":"assistant","content":[{"type":"text","text":"Hello!"}],"model":"claude-opus-4-5","stop_reason":"end_turn","usage":{"input_tokens":10,"output_tokens":5}}`
toolUseBody = `{"id":"msg_01","type":"message","role":"assistant","content":[{"type":"tool_use","id":"toolu_01","name":"test_tool","input":{}}],"model":"claude-opus-4-5","stop_reason":"tool_use","usage":{"input_tokens":10,"output_tokens":5}}`
rateLimitBody = `{"type":"error","error":{"type":"rate_limit_error","message":"rate limited"}}`
authErrorBody = `{"type":"error","error":{"type":"authentication_error","message":"invalid key"}}`
serverErrorBody = `{"type":"error","error":{"type":"api_error","message":"server error"}}`
)
// upstreamResponse scripts one reply from the mock upstream servers
// used by the tests in this file.
type upstreamResponse struct {
// statusCode is the HTTP status the mock upstream writes.
statusCode int
// body is written verbatim as the response body.
body string
// headers are set on the response before the status is written.
headers map[string]string
}
// TestBlockingInterception_KeyFailover drives ProcessRequest against a
// mock upstream that scripts its responses per X-Api-Key. Each case
// checks the number of upstream requests, the client-facing status
// code and Retry-After header, and (for pooled cases) the final state
// of every key in the pool.
func TestBlockingInterception_KeyFailover(t *testing.T) {
t.Parallel()
tests := []struct {
name string
// Centralized pool keys. Empty when byokKey is set.
keys []string
// BYOK key. Empty when keys is set.
byokKey string
// Scripted upstream responses keyed by X-Api-Key.
responses map[string]upstreamResponse
expectedRequestCount int32
expectedStatusCode int
expectedRetryAfter string
// Expected key states after the request, by index in keys.
expectedKeyStates []keypool.KeyState
}{
{
// Given: 1 valid key returning 200.
// Then: 1 request, 200 response, key remains valid.
name: "single_valid_key",
keys: []string{"k0"},
responses: map[string]upstreamResponse{
"k0": {statusCode: http.StatusOK, body: successBody},
},
expectedRequestCount: 1,
expectedStatusCode: http.StatusOK,
expectedKeyStates: []keypool.KeyState{keypool.KeyStateValid},
},
{
// Given: 2 keys; key-0 returns 429, key-1 returns 200.
// Then: 2 requests, 200 response, key-0 temporary, key-1 valid.
name: "failover_after_429",
keys: []string{"k0", "k1"},
responses: map[string]upstreamResponse{
"k0": {
statusCode: http.StatusTooManyRequests,
headers: map[string]string{"Retry-After": "5"},
body: rateLimitBody,
},
"k1": {statusCode: http.StatusOK, body: successBody},
},
expectedRequestCount: 2,
expectedStatusCode: http.StatusOK,
expectedKeyStates: []keypool.KeyState{
keypool.KeyStateTemporary,
keypool.KeyStateValid,
},
},
{
// Given: 2 keys; key-0 returns 401, key-1 returns 200.
// Then: 2 requests, 200 response, key-0 permanent, key-1 valid.
name: "failover_after_401",
keys: []string{"k0", "k1"},
responses: map[string]upstreamResponse{
"k0": {statusCode: http.StatusUnauthorized, body: authErrorBody},
"k1": {statusCode: http.StatusOK, body: successBody},
},
expectedRequestCount: 2,
expectedStatusCode: http.StatusOK,
expectedKeyStates: []keypool.KeyState{
keypool.KeyStatePermanent,
keypool.KeyStateValid,
},
},
{
// Given: 2 keys; key-0 returns 403, key-1 returns 200.
// Then: 2 requests, 200 response, key-0 permanent, key-1 valid.
name: "failover_after_403",
keys: []string{"k0", "k1"},
responses: map[string]upstreamResponse{
"k0": {statusCode: http.StatusForbidden, body: authErrorBody},
"k1": {statusCode: http.StatusOK, body: successBody},
},
expectedRequestCount: 2,
expectedStatusCode: http.StatusOK,
expectedKeyStates: []keypool.KeyState{
keypool.KeyStatePermanent,
keypool.KeyStateValid,
},
},
{
// Given: 3 keys; all return 429 with cooldowns 5s, 3s, 10s.
// Then: 3 requests, 429 response with smallest Retry-After,
// all keys temporary.
name: "all_keys_rate_limited",
keys: []string{"k0", "k1", "k2"},
responses: map[string]upstreamResponse{
"k0": {
statusCode: http.StatusTooManyRequests,
headers: map[string]string{"Retry-After": "5"},
body: rateLimitBody,
},
"k1": {
statusCode: http.StatusTooManyRequests,
headers: map[string]string{"Retry-After": "3"},
body: rateLimitBody,
},
"k2": {
statusCode: http.StatusTooManyRequests,
headers: map[string]string{"Retry-After": "10"},
body: rateLimitBody,
},
},
expectedRequestCount: 3,
expectedStatusCode: http.StatusTooManyRequests,
expectedRetryAfter: "3",
expectedKeyStates: []keypool.KeyState{
keypool.KeyStateTemporary,
keypool.KeyStateTemporary,
keypool.KeyStateTemporary,
},
},
{
// Given: 2 keys; both return 401.
// Then: 2 requests, 502 api_error response, both keys permanent.
name: "all_keys_unauthorized",
keys: []string{"k0", "k1"},
responses: map[string]upstreamResponse{
"k0": {statusCode: http.StatusUnauthorized, body: authErrorBody},
"k1": {statusCode: http.StatusUnauthorized, body: authErrorBody},
},
expectedRequestCount: 2,
expectedStatusCode: http.StatusBadGateway,
expectedKeyStates: []keypool.KeyState{
keypool.KeyStatePermanent,
keypool.KeyStatePermanent,
},
},
{
// Given: 2 keys; key-0 returns 500.
// Then: 1 request, 500 response, both keys remain valid.
name: "server_error_no_failover",
keys: []string{"k0", "k1"},
responses: map[string]upstreamResponse{
"k0": {statusCode: http.StatusInternalServerError, body: serverErrorBody},
},
expectedRequestCount: 1,
expectedStatusCode: http.StatusInternalServerError,
expectedKeyStates: []keypool.KeyState{
keypool.KeyStateValid,
keypool.KeyStateValid,
},
},
{
// Given: BYOK with a single key returning 429.
// Then: 1 request, 429 response, no failover, upstream
// Retry-After propagated to the client.
name: "byok_no_failover",
byokKey: "user-byok",
responses: map[string]upstreamResponse{
"user-byok": {
statusCode: http.StatusTooManyRequests,
headers: map[string]string{
"Retry-After": "5",
// BYOK doesn't set MaxRetries(0);
// suppress SDK retries to test a
// single attempt.
"x-should-retry": "false",
},
body: rateLimitBody,
},
},
expectedRequestCount: 1,
expectedStatusCode: http.StatusTooManyRequests,
expectedRetryAfter: "5",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
// Mock upstream: counts requests and returns
// scripted responses keyed by X-Api-Key. An unmapped
// key falls through to 500 so misconfigured cases
// surface via the status assertion.
var requestCount atomic.Int32
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
requestCount.Add(1)
_, _ = io.Copy(io.Discard, r.Body)
resp, ok := tc.responses[r.Header.Get("X-Api-Key")]
if !ok {
resp = upstreamResponse{statusCode: http.StatusInternalServerError}
}
w.Header().Set("Content-Type", "application/json")
for hk, hv := range resp.headers {
w.Header().Set(hk, hv)
}
w.WriteHeader(resp.statusCode)
_, _ = w.Write([]byte(resp.body))
}))
t.Cleanup(upstream.Close)
cfg := config.Anthropic{BaseURL: upstream.URL + "/"}
// Pooled cases set cfg.KeyPool; the BYOK case sets cfg.Key and
// leaves pool nil, which also skips the key-state assertion below.
var pool *keypool.Pool
if len(tc.keys) > 0 {
var err error
pool, err = keypool.New(tc.keys, quartz.NewMock(t))
require.NoError(t, err)
cfg.KeyPool = pool
} else if tc.byokKey != "" {
cfg.Key = tc.byokKey
}
payload, err := NewRequestPayload([]byte(requestBody))
require.NoError(t, err)
interceptor := NewBlockingInterceptor(
uuid.New(),
payload,
config.ProviderAnthropic,
cfg,
nil,
http.Header{},
"X-Api-Key",
otel.Tracer("blocking_test"),
intercept.NewCredentialInfo(intercept.CredentialKindCentralized, ""),
)
interceptor.Setup(slog.Make(), &testutil.MockRecorder{}, nil)
req := httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
w := httptest.NewRecorder()
err = interceptor.ProcessRequest(w, req)
// ProcessRequest is expected to return an error exactly when
// the client-facing status is not 200.
if tc.expectedStatusCode == http.StatusOK {
require.NoError(t, err)
} else {
require.Error(t, err)
}
assert.Equal(t, tc.expectedRequestCount, requestCount.Load(), "upstream request count")
assert.Equal(t, tc.expectedStatusCode, w.Code, "response status code")
assert.Equal(t, tc.expectedRetryAfter, w.Header().Get("Retry-After"), "Retry-After header")
if pool != nil {
assert.Equal(t, tc.expectedKeyStates, pool.PoolState(), "key states")
}
})
}
}
// TestBlockingInterception_AgenticLoopFailover covers the
// scenarios that span an agentic-loop continuation: the initial
// client request and the subsequent tool-call continuation can
// each fail over independently. Each iteration gets its own
// walker.
//
// Every case uses a two-key pool ("k0", "k1"). The mock upstream
// replies with the scripted responses in request order and records
// each request's X-Api-Key, so the cases can assert which key served
// which upstream call.
func TestBlockingInterception_AgenticLoopFailover(t *testing.T) {
t.Parallel()
tests := []struct {
name string
// Scripted upstream responses consumed in order of
// upstream request.
responses []upstreamResponse
expectedRequestCount int32
expectedSeenKeys []string
expectedStatusCode int
expectedRetryAfter string
expectedKeyStates []keypool.KeyState
}{
{
// Given: 2 keys; both upstream calls succeed on key-0.
// Then: 2 requests, 200 response, both keys remain valid.
name: "happy_path",
responses: []upstreamResponse{
{statusCode: http.StatusOK, body: toolUseBody},
{statusCode: http.StatusOK, body: successBody},
},
expectedRequestCount: 2,
expectedSeenKeys: []string{"k0", "k0"},
expectedStatusCode: http.StatusOK,
expectedKeyStates: []keypool.KeyState{
keypool.KeyStateValid,
keypool.KeyStateValid,
},
},
{
// Given: 2 keys; key-0 succeeds initially, then 429s
// during the agentic continuation, key-1 succeeds.
// Then: 3 requests, 200 response, key-0 temporary,
// key-1 valid.
name: "agentic_failover_to_k1",
responses: []upstreamResponse{
{statusCode: http.StatusOK, body: toolUseBody},
{
statusCode: http.StatusTooManyRequests,
headers: map[string]string{"Retry-After": "5"},
body: rateLimitBody,
},
{statusCode: http.StatusOK, body: successBody},
},
expectedRequestCount: 3,
expectedSeenKeys: []string{"k0", "k0", "k1"},
expectedStatusCode: http.StatusOK,
expectedKeyStates: []keypool.KeyState{
keypool.KeyStateTemporary,
keypool.KeyStateValid,
},
},
{
// Given: 2 keys; key-0 succeeds initially, then both
// keys 429 during the agentic continuation.
// Then: 3 requests, 429 response with smallest
// Retry-After, both keys temporary.
name: "agentic_all_keys_fail",
responses: []upstreamResponse{
{statusCode: http.StatusOK, body: toolUseBody},
{
statusCode: http.StatusTooManyRequests,
headers: map[string]string{"Retry-After": "5"},
body: rateLimitBody,
},
{
statusCode: http.StatusTooManyRequests,
headers: map[string]string{"Retry-After": "3"},
body: rateLimitBody,
},
},
expectedRequestCount: 3,
expectedSeenKeys: []string{"k0", "k0", "k1"},
expectedStatusCode: http.StatusTooManyRequests,
expectedRetryAfter: "3",
expectedKeyStates: []keypool.KeyState{
keypool.KeyStateTemporary,
keypool.KeyStateTemporary,
},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
var requestCount atomic.Int32
var seenKeysMu sync.Mutex
var seenKeys []string
// Mock upstream: returns scripted responses in order,
// records each request's X-Api-Key for assertions.
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Add returns the incremented count; subtract one for a
// zero-based index into tc.responses.
idx := int(requestCount.Add(1)) - 1
seenKeysMu.Lock()
seenKeys = append(seenKeys, r.Header.Get("X-Api-Key"))
seenKeysMu.Unlock()
_, _ = io.Copy(io.Discard, r.Body)
// More requests than scripted responses: answer 500.
if idx >= len(tc.responses) {
w.WriteHeader(http.StatusInternalServerError)
return
}
resp := tc.responses[idx]
w.Header().Set("Content-Type", "application/json")
for hk, hv := range resp.headers {
w.Header().Set(hk, hv)
}
w.WriteHeader(resp.statusCode)
_, _ = w.Write([]byte(resp.body))
}))
t.Cleanup(upstream.Close)
pool, err := keypool.New([]string{"k0", "k1"}, quartz.NewMock(t))
require.NoError(t, err)
cfg := config.Anthropic{
BaseURL: upstream.URL + "/",
KeyPool: pool,
}
payload, err := NewRequestPayload([]byte(requestBody))
require.NoError(t, err)
interceptor := NewBlockingInterceptor(
uuid.New(),
payload,
config.ProviderAnthropic,
cfg,
nil,
http.Header{},
"X-Api-Key",
otel.Tracer("blocking_test"),
intercept.NewCredentialInfo(intercept.CredentialKindCentralized, ""),
)
// Mock proxy with a tool the upstream's tool_use
// response will reference.
proxy := &mockServerProxier{
tools: []*mcp.Tool{
{
Client: stubToolCaller{},
ID: "test_tool",
Name: "test_tool",
ServerName: "coder",
Logger: slog.Make(),
},
},
}
interceptor.Setup(slog.Make(), &testutil.MockRecorder{}, proxy)
req := httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
w := httptest.NewRecorder()
err = interceptor.ProcessRequest(w, req)
// ProcessRequest is expected to return an error exactly when
// the client-facing status is not 200.
if tc.expectedStatusCode == http.StatusOK {
require.NoError(t, err)
} else {
require.Error(t, err)
}
assert.Equal(t, tc.expectedRequestCount, requestCount.Load(), "upstream request count")
assert.Equal(t, tc.expectedStatusCode, w.Code, "response status code")
assert.Equal(t, tc.expectedRetryAfter, w.Header().Get("Retry-After"), "Retry-After header")
// seenKeys is written by the handler under seenKeysMu, so read it
// under the same lock.
seenKeysMu.Lock()
defer seenKeysMu.Unlock()
assert.Equal(t, tc.expectedSeenKeys, seenKeys, "seen keys")
assert.Equal(t, tc.expectedKeyStates, pool.PoolState(), "key states")
})
}
}
+125 -31
View File
@@ -25,6 +25,7 @@ import (
aibcontext "github.com/coder/coder/v2/aibridge/context"
"github.com/coder/coder/v2/aibridge/intercept"
"github.com/coder/coder/v2/aibridge/intercept/eventstream"
"github.com/coder/coder/v2/aibridge/keypool"
"github.com/coder/coder/v2/aibridge/mcp"
"github.com/coder/coder/v2/aibridge/recorder"
"github.com/coder/coder/v2/aibridge/tracing"
@@ -161,14 +162,64 @@ newStream:
break
}
stream := i.newStream(streamCtx, svc)
// Per-iteration walker. An iteration is either an agentic
// continuation (sending a tool result back in a new
// stream) or a failover retry (previous key marked, try
// the next one).
var walker *keypool.Walker
if i.cfg.KeyPool != nil {
walker = i.cfg.KeyPool.Walker()
}
var streamOpts []option.RequestOption
var currentKey *keypool.Key
if walker != nil {
key, err := walker.Next()
if respErr := processKeyPoolError(err); respErr != nil {
// Pool exhausted in this iteration. Relay the
// error to the client: as an SSE event if events
// have already been sent, or by direct write
// otherwise.
interceptionErr = respErr
if events.IsStreaming() {
payload, mErr := i.marshal(respErr)
if mErr != nil {
logger.Warn(ctx, "failed to marshal exhaustion error", slog.Error(mErr))
} else if sErr := events.Send(streamCtx, payload); sErr != nil {
logger.Warn(ctx, "failed to relay exhaustion error", slog.Error(sErr))
}
} else {
i.writeUpstreamError(w, respErr)
}
break
}
currentKey = key
streamOpts = append(streamOpts,
option.WithAPIKey(key.Value()),
// Disable SDK retries because the failover
// loop handles retries via key rotation.
option.WithMaxRetries(0),
)
}
stream := i.newStream(streamCtx, svc, streamOpts...)
var message anthropic.Message
var lastToolName string
pendingToolCalls := make(map[string]string)
// iterationStarted is per-iteration (reset on every
// newStream loop): true once the upstream call has
// produced any events for this iteration. While false,
// a key-specific failure can still fail over to the
// next key. Distinct from events.IsStreaming(), which
// is stream-wide and stays true once iteration 1 has
// sent any event downstream.
var iterationStarted bool
for stream.Next() {
iterationStarted = true
event := stream.Current()
if err := message.Accumulate(event); err != nil {
logger.Warn(ctx, "failed to accumulate streaming events", slog.Error(err), slog.F("event", event), slog.F("msg", message.RawJSON()))
@@ -478,40 +529,53 @@ newStream:
promptFound = false //nolint:ineffassign // reset to prevent double-recording across newStream iterations
}
if events.IsStreaming() {
// Check if the stream encountered any errors.
if streamErr := stream.Err(); streamErr != nil {
if eventstream.IsUnrecoverableError(streamErr) {
logger.Debug(ctx, "stream terminated", slog.Error(streamErr))
// We can't reflect an error back if there's a connection error or the request context was canceled.
} else if antErr := getErrorResponse(streamErr); antErr != nil {
logger.Warn(ctx, "anthropic stream error", slog.Error(streamErr))
interceptionErr = antErr
} else {
logger.Warn(ctx, "unknown stream error", slog.Error(streamErr))
// Unfortunately, the Anthropic SDK does not support parsing errors received in the stream
// into known types (i.e. [shared.OverloadedError]).
// See https://github.com/anthropics/anthropic-sdk-go/blob/v1.12.0/packages/ssestream/ssestream.go#L172-L174
// All it does is wrap the payload in an error - which is all we can return, currently.
interceptionErr = newErrorResponse(xerrors.Errorf("unknown stream error: %w", streamErr))
}
} else if lastErr != nil {
// Otherwise check if any logical errors occurred during processing.
logger.Warn(ctx, "stream processing failed", slog.Error(lastErr))
interceptionErr = newErrorResponse(xerrors.Errorf("processing error: %w", lastErr))
}
if interceptionErr != nil {
payload, err := i.marshal(interceptionErr)
if iterationStarted {
// Mid-stream error or logical error: events have
// already streamed for this iteration, so the
// error is relayed as an SSE event.
streamErr := stream.Err()
if respErr := i.mapStreamError(ctx, logger, streamErr, lastErr); respErr != nil {
interceptionErr = respErr
payload, err := i.marshal(respErr)
if err != nil {
logger.Warn(ctx, "failed to marshal error", slog.Error(err), slog.F("error_payload", fmt.Sprintf("%+v", interceptionErr)))
logger.Warn(ctx, "failed to marshal error", slog.Error(err), slog.F("error_payload", fmt.Sprintf("%+v", respErr)))
} else if err := events.Send(streamCtx, payload); err != nil {
logger.Warn(ctx, "failed to relay error", slog.Error(err), slog.F("payload", payload))
}
} else if streamErr != nil {
// Unrecoverable (e.g., broken pipe, context
// canceled): can't relay to the client, but record
// the error so it isn't silently swallowed.
interceptionErr = streamErr
}
} else {
// Stream has not started yet; write to response if present.
i.writeUpstreamError(w, getErrorResponse(stream.Err()))
// Pre-stream failure of this iteration. For
// centralized requests, mark the key and retry with
// the next one.
if currentKey != nil && i.markKeyOnError(ctx, currentKey, stream.Err()) {
continue newStream
}
// Non-key error: relay it. Use mapStreamError so that
// unknown upstream errors (TCP reset, DNS failure, TLS
// error, deadline exceeded) are wrapped in a generic
// response instead of producing a silent HTTP 200.
respErr := i.mapStreamError(ctx, logger, stream.Err(), lastErr)
if respErr != nil {
interceptionErr = respErr
if events.IsStreaming() {
// Prior iterations have streamed, so the SSE
// connection is open: inject as an SSE event.
payload, mErr := i.marshal(respErr)
if mErr != nil {
logger.Warn(ctx, "failed to marshal error", slog.Error(mErr))
} else if sErr := events.Send(streamCtx, payload); sErr != nil {
logger.Warn(ctx, "failed to relay error", slog.Error(sErr))
}
} else {
// No events streamed yet, write the response directly.
i.writeUpstreamError(w, respErr)
}
}
}
shutdownCtx, shutdownCancel := context.WithTimeout(ctx, time.Second*30)
@@ -534,6 +598,35 @@ newStream:
return interceptionErr
}
// mapStreamError turns the outcome of an upstream stream into a
// responseError that can be relayed to the client. A stream error
// takes precedence over a local processing error. The result is nil
// when there is nothing to relay: both errors are nil, or the stream
// error is unrecoverable.
func (*StreamingInterception) mapStreamError(ctx context.Context, logger slog.Logger, streamErr, lastErr error) *responseError {
	switch {
	case streamErr == nil && lastErr == nil:
		// Nothing went wrong.
		return nil
	case streamErr == nil:
		// The upstream stream was clean, but handling its events failed.
		logger.Warn(ctx, "stream processing failed", slog.Error(lastErr))
		return newErrorResponse(fmt.Sprintf("processing error: %s", lastErr), string(constant.ValueOf[constant.Error]()), http.StatusBadGateway, 0)
	case eventstream.IsUnrecoverableError(streamErr):
		// A connection error or a canceled request context cannot be
		// reflected back to the client.
		logger.Debug(ctx, "stream terminated", slog.Error(streamErr))
		return nil
	}

	// Prefer the structured Anthropic error when one can be extracted.
	if known := getErrorResponse(streamErr); known != nil {
		logger.Warn(ctx, "anthropic stream error", slog.Error(streamErr))
		return known
	}

	// Unfortunately, the Anthropic SDK does not support parsing errors received in the stream
	// into known types (i.e. [shared.OverloadedError]).
	// See https://github.com/anthropics/anthropic-sdk-go/blob/v1.12.0/packages/ssestream/ssestream.go#L172-L174
	// All it does is wrap the payload in an error - which is all we can return, currently.
	logger.Warn(ctx, "unknown stream error", slog.Error(streamErr))
	return newErrorResponse(fmt.Sprintf("unknown stream error: %s", streamErr), string(constant.ValueOf[constant.Error]()), http.StatusBadGateway, 0)
}
func (i *StreamingInterception) marshalEvent(event anthropic.MessageStreamEventUnion) ([]byte, error) {
sj, err := sjson.Set(event.RawJSON(), "message.id", i.ID().String())
if err != nil {
@@ -585,9 +678,10 @@ func (*StreamingInterception) encodeForStream(payload []byte, typ string) []byte
}
// newStream traces svc.NewStreaming() call.
func (i *StreamingInterception) newStream(ctx context.Context, svc anthropic.MessageService) *ssestream.Stream[anthropic.MessageStreamEventUnion] {
func (i *StreamingInterception) newStream(ctx context.Context, svc anthropic.MessageService, extraOpts ...option.RequestOption) *ssestream.Stream[anthropic.MessageStreamEventUnion] {
_, span := i.tracer.Start(ctx, "Intercept.ProcessRequest.Upstream", trace.WithAttributes(tracing.InterceptionAttributesFromContext(ctx)...))
defer span.End()
return svc.NewStreaming(ctx, anthropic.MessageNewParams{}, i.withBody())
opts := append([]option.RequestOption{i.withBody()}, extraOpts...)
return svc.NewStreaming(ctx, anthropic.MessageNewParams{}, opts...)
}
@@ -0,0 +1,554 @@
package messages //nolint:testpackage // tests unexported internals
import (
"context"
"io"
"net/http"
"net/http/httptest"
"sync"
"sync/atomic"
"testing"
"github.com/google/uuid"
mcplib "github.com/mark3labs/mcp-go/mcp"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/aibridge/config"
"github.com/coder/coder/v2/aibridge/intercept"
"github.com/coder/coder/v2/aibridge/internal/testutil"
"github.com/coder/coder/v2/aibridge/keypool"
"github.com/coder/coder/v2/aibridge/mcp"
"github.com/coder/quartz"
)
// Anthropic-shaped SSE body for a successful streaming response.
const streamingSuccessBody = `event: message_start
data: {"type":"message_start","message":{"id":"msg_01","type":"message","role":"assistant","model":"claude-opus-4-5","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":10,"output_tokens":1}}}
event: content_block_start
data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"hi"}}
event: content_block_stop
data: {"type":"content_block_stop","index":0}
event: message_delta
data: {"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"output_tokens":5}}
event: message_stop
data: {"type":"message_stop"}
`
// TestStreamingInterception_KeyFailover runs one streaming interception
// per case against a mock upstream whose response is scripted per
// X-Api-Key value. Each case asserts the number of upstream requests,
// the status code and Retry-After header written to the client, that
// no error was relayed as an SSE event, and (when a pool is used) the
// state of every key after the request.
func TestStreamingInterception_KeyFailover(t *testing.T) {
	t.Parallel()
	tests := []struct {
		name string
		// Centralized pool keys. Empty when byokKey is set.
		keys []string
		// BYOK key. Empty when keys is set.
		byokKey string
		// Scripted upstream responses keyed by X-Api-Key.
		responses map[string]upstreamResponse
		expectedRequestCount int32
		expectedStatusCode int
		expectedRetryAfter string
		// Expected key states after the request, by index in keys.
		expectedKeyStates []keypool.KeyState
	}{
		{
			// Given: 1 valid key returning a successful stream.
			// Then: 1 request, 200 response, key remains valid.
			name: "single_valid_key",
			keys: []string{"k0"},
			responses: map[string]upstreamResponse{
				"k0": {
					statusCode: http.StatusOK,
					headers: map[string]string{"Content-Type": "text/event-stream"},
					body: streamingSuccessBody,
				},
			},
			expectedRequestCount: 1,
			expectedStatusCode: http.StatusOK,
			expectedKeyStates: []keypool.KeyState{keypool.KeyStateValid},
		},
		{
			// Given: 2 keys; key-0 returns 429 pre-stream, key-1
			// streams successfully.
			// Then: 2 requests, 200 response, key-0 temporary, key-1 valid.
			name: "failover_after_429",
			keys: []string{"k0", "k1"},
			responses: map[string]upstreamResponse{
				"k0": {
					statusCode: http.StatusTooManyRequests,
					headers: map[string]string{"Retry-After": "5"},
					body: rateLimitBody,
				},
				"k1": {
					statusCode: http.StatusOK,
					headers: map[string]string{"Content-Type": "text/event-stream"},
					body: streamingSuccessBody,
				},
			},
			expectedRequestCount: 2,
			expectedStatusCode: http.StatusOK,
			expectedKeyStates: []keypool.KeyState{
				keypool.KeyStateTemporary,
				keypool.KeyStateValid,
			},
		},
		{
			// Given: 2 keys; key-0 returns 401 pre-stream, key-1
			// streams successfully.
			// Then: 2 requests, 200 response, key-0 permanent, key-1 valid.
			name: "failover_after_401",
			keys: []string{"k0", "k1"},
			responses: map[string]upstreamResponse{
				"k0": {statusCode: http.StatusUnauthorized, body: authErrorBody},
				"k1": {
					statusCode: http.StatusOK,
					headers: map[string]string{"Content-Type": "text/event-stream"},
					body: streamingSuccessBody,
				},
			},
			expectedRequestCount: 2,
			expectedStatusCode: http.StatusOK,
			expectedKeyStates: []keypool.KeyState{
				keypool.KeyStatePermanent,
				keypool.KeyStateValid,
			},
		},
		{
			// Given: 2 keys; key-0 returns 403 pre-stream, key-1 streams.
			// Then: 2 requests, 200 response, key-0 permanent, key-1 valid.
			name: "failover_after_403",
			keys: []string{"k0", "k1"},
			responses: map[string]upstreamResponse{
				"k0": {statusCode: http.StatusForbidden, body: authErrorBody},
				"k1": {
					statusCode: http.StatusOK,
					headers: map[string]string{"Content-Type": "text/event-stream"},
					body: streamingSuccessBody,
				},
			},
			expectedRequestCount: 2,
			expectedStatusCode: http.StatusOK,
			expectedKeyStates: []keypool.KeyState{
				keypool.KeyStatePermanent,
				keypool.KeyStateValid,
			},
		},
		{
			// Given: 3 keys; all return 429 pre-stream with
			// cooldowns 5s, 3s, 10s.
			// Then: 3 requests, 429 response with smallest
			// Retry-After, all keys temporary.
			name: "all_keys_rate_limited",
			keys: []string{"k0", "k1", "k2"},
			responses: map[string]upstreamResponse{
				"k0": {
					statusCode: http.StatusTooManyRequests,
					headers: map[string]string{"Retry-After": "5"},
					body: rateLimitBody,
				},
				"k1": {
					statusCode: http.StatusTooManyRequests,
					headers: map[string]string{"Retry-After": "3"},
					body: rateLimitBody,
				},
				"k2": {
					statusCode: http.StatusTooManyRequests,
					headers: map[string]string{"Retry-After": "10"},
					body: rateLimitBody,
				},
			},
			expectedRequestCount: 3,
			expectedStatusCode: http.StatusTooManyRequests,
			expectedRetryAfter: "3",
			expectedKeyStates: []keypool.KeyState{
				keypool.KeyStateTemporary,
				keypool.KeyStateTemporary,
				keypool.KeyStateTemporary,
			},
		},
		{
			// Given: 2 keys; both return 401 pre-stream.
			// Then: 2 requests, 502 api_error response, both keys permanent.
			name: "all_keys_unauthorized",
			keys: []string{"k0", "k1"},
			responses: map[string]upstreamResponse{
				"k0": {statusCode: http.StatusUnauthorized, body: authErrorBody},
				"k1": {statusCode: http.StatusUnauthorized, body: authErrorBody},
			},
			expectedRequestCount: 2,
			expectedStatusCode: http.StatusBadGateway,
			expectedKeyStates: []keypool.KeyState{
				keypool.KeyStatePermanent,
				keypool.KeyStatePermanent,
			},
		},
		{
			// Given: 2 keys; key-0 returns 500 pre-stream.
			// Then: 1 request, 500 response, both keys remain valid.
			name: "server_error_no_failover",
			keys: []string{"k0", "k1"},
			responses: map[string]upstreamResponse{
				"k0": {statusCode: http.StatusInternalServerError, body: serverErrorBody},
			},
			expectedRequestCount: 1,
			expectedStatusCode: http.StatusInternalServerError,
			expectedKeyStates: []keypool.KeyState{
				keypool.KeyStateValid,
				keypool.KeyStateValid,
			},
		},
		{
			// Given: BYOK with a single key returning 429.
			// Then: 1 request, 429 response, no failover, upstream
			// Retry-After propagated to the client.
			name: "byok_no_failover",
			byokKey: "user-byok",
			responses: map[string]upstreamResponse{
				"user-byok": {
					statusCode: http.StatusTooManyRequests,
					headers: map[string]string{
						"Retry-After": "5",
						// BYOK doesn't set MaxRetries(0);
						// suppress SDK retries to test a
						// single attempt.
						"x-should-retry": "false",
					},
					body: rateLimitBody,
				},
			},
			expectedRequestCount: 1,
			expectedStatusCode: http.StatusTooManyRequests,
			expectedRetryAfter: "5",
		},
	}
	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			// Mock upstream: counts requests and returns
			// scripted responses keyed by X-Api-Key. An unmapped
			// key falls through to 500 so misconfigured cases
			// surface via the status assertion.
			var requestCount atomic.Int32
			upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				requestCount.Add(1)
				// Drain the request body before responding.
				_, _ = io.Copy(io.Discard, r.Body)
				resp, ok := tc.responses[r.Header.Get("X-Api-Key")]
				if !ok {
					resp = upstreamResponse{statusCode: http.StatusInternalServerError}
				}
				for hk, hv := range resp.headers {
					w.Header().Set(hk, hv)
				}
				w.WriteHeader(resp.statusCode)
				_, _ = w.Write([]byte(resp.body))
			}))
			t.Cleanup(upstream.Close)
			// Centralized cases configure a key pool; the BYOK case
			// leaves the pool nil and sets the single key on cfg.Key.
			cfg := config.Anthropic{BaseURL: upstream.URL + "/"}
			var pool *keypool.Pool
			if len(tc.keys) > 0 {
				var err error
				pool, err = keypool.New(tc.keys, quartz.NewMock(t))
				require.NoError(t, err)
				cfg.KeyPool = pool
			} else if tc.byokKey != "" {
				cfg.Key = tc.byokKey
			}
			payload, err := NewRequestPayload([]byte(requestBody))
			require.NoError(t, err)
			// NOTE(review): the credential kind is Centralized for every
			// case, including byok_no_failover; presumably it does not
			// affect the behavior under test. Confirm.
			interceptor := NewStreamingInterceptor(
				uuid.New(),
				payload,
				config.ProviderAnthropic,
				cfg,
				nil,
				http.Header{},
				"X-Api-Key",
				otel.Tracer("streaming_test"),
				intercept.NewCredentialInfo(intercept.CredentialKindCentralized, ""),
			)
			interceptor.Setup(slog.Make(), &testutil.MockRecorder{}, nil)
			req := httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
			w := httptest.NewRecorder()
			err = interceptor.ProcessRequest(w, req)
			// Every non-200 case is expected to surface as an error
			// from ProcessRequest.
			if tc.expectedStatusCode == http.StatusOK {
				require.NoError(t, err)
			} else {
				require.Error(t, err)
			}
			assert.Equal(t, tc.expectedRequestCount, requestCount.Load(), "upstream request count")
			assert.Equal(t, tc.expectedStatusCode, w.Code, "response status code")
			assert.Equal(t, tc.expectedRetryAfter, w.Header().Get("Retry-After"), "Retry-After header")
			// No prior iteration streamed, so errors must be a
			// direct HTTP response, not an SSE event.
			assert.NotContains(t, w.Body.String(), "event: error", "error must not be relayed as an SSE event")
			// Key states are only meaningful when a pool was configured.
			if pool != nil {
				assert.Equal(t, tc.expectedKeyStates, pool.PoolState(), "key states")
			}
		})
	}
}
// SSE bodies covering an agentic-continuation flow.
const (
// First response: a tool_use block referencing the injected
// "test_tool". Triggers the agentic continuation loop.
toolUseStreamBody = `event: message_start
data: {"type":"message_start","message":{"id":"msg_01","type":"message","role":"assistant","model":"claude-opus-4-5","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":10,"output_tokens":1}}}
event: content_block_start
data: {"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"toolu_01","name":"test_tool","input":{}}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":"{}"}}
event: content_block_stop
data: {"type":"content_block_stop","index":0}
event: message_delta
data: {"type":"message_delta","delta":{"stop_reason":"tool_use","stop_sequence":null},"usage":{"output_tokens":5}}
event: message_stop
data: {"type":"message_stop"}
`
// Second response (after the tool result is sent back):
// a plain text completion that ends the loop.
textStreamBody = `event: message_start
data: {"type":"message_start","message":{"id":"msg_02","type":"message","role":"assistant","model":"claude-opus-4-5","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":15,"output_tokens":1}}}
event: content_block_start
data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}
event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"done"}}
event: content_block_stop
data: {"type":"content_block_stop","index":0}
event: message_delta
data: {"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"output_tokens":3}}
event: message_stop
data: {"type":"message_stop"}
`
)
// stubToolCaller is a test double for mcp.ToolCaller. Every call
// succeeds with the same canned text result, which lets the agentic
// continuation proceed.
type stubToolCaller struct{}

// CallTool ignores its arguments and returns the fixed "tool result"
// text with a nil error.
func (stubToolCaller) CallTool(_ context.Context, _ mcplib.CallToolRequest) (*mcplib.CallToolResult, error) {
	canned := mcplib.NewToolResultText("tool result")
	return canned, nil
}
// TestStreamingInterception_AgenticLoopFailover covers the
// scenarios that span an agentic-loop continuation: the initial
// client request and the subsequent tool-call continuation can
// each fail over independently. Each iteration gets its own
// walker.
//
// The mock upstream replays tc.responses in request order and records
// the X-Api-Key of every request. Each case asserts the request count,
// the order of keys seen, a marker substring in the response body,
// whether an error was relayed as an SSE event, and the final key
// states of the two-key pool.
func TestStreamingInterception_AgenticLoopFailover(t *testing.T) {
	t.Parallel()
	sseHeaders := map[string]string{"Content-Type": "text/event-stream"}
	tests := []struct {
		name string
		// Scripted upstream responses consumed in order of
		// upstream request.
		responses []upstreamResponse
		expectedRequestCount int32
		expectedSeenKeys []string
		// Substring expected in the response body. Either a
		// success marker (e.g. "done") or an error marker
		// (e.g. "rate_limit_error").
		expectedBodyContains string
		// True when the error must be relayed as an SSE event.
		expectErrorAsSSEEvent bool
		// True when ProcessRequest is expected to return an
		// error (e.g. all keys exhausted).
		expectedErr bool
		expectedKeyStates []keypool.KeyState
	}{
		{
			// Given: 2 keys; both upstream calls succeed on key-0.
			// Then: 2 requests, success body, both keys remain valid.
			name: "happy_path",
			responses: []upstreamResponse{
				{statusCode: http.StatusOK, headers: sseHeaders, body: toolUseStreamBody},
				{statusCode: http.StatusOK, headers: sseHeaders, body: textStreamBody},
			},
			expectedRequestCount: 2,
			expectedSeenKeys: []string{"k0", "k0"},
			expectedBodyContains: "done",
			expectErrorAsSSEEvent: false,
			expectedKeyStates: []keypool.KeyState{
				keypool.KeyStateValid,
				keypool.KeyStateValid,
			},
		},
		{
			// Given: 2 keys; key-0 succeeds initially, then 429s
			// during the agentic continuation, key-1 succeeds.
			// Then: 3 requests, success body, key-0 temporary,
			// key-1 valid.
			// The continuation uses a fresh walker, so it tries k0
			// again before failing over to k1.
			name: "agentic_failover_to_k1",
			responses: []upstreamResponse{
				{statusCode: http.StatusOK, headers: sseHeaders, body: toolUseStreamBody},
				{
					statusCode: http.StatusTooManyRequests,
					headers: map[string]string{"Retry-After": "5"},
					body: rateLimitBody,
				},
				{statusCode: http.StatusOK, headers: sseHeaders, body: textStreamBody},
			},
			expectedRequestCount: 3,
			expectedSeenKeys: []string{"k0", "k0", "k1"},
			expectedBodyContains: "done",
			expectErrorAsSSEEvent: false,
			expectedKeyStates: []keypool.KeyState{
				keypool.KeyStateTemporary,
				keypool.KeyStateValid,
			},
		},
		{
			// Given: 2 keys; key-0 succeeds initially, then both
			// keys 429 during the agentic continuation.
			// Then: 3 requests, error injected as SSE event, both
			// keys temporary.
			name: "agentic_all_keys_fail",
			responses: []upstreamResponse{
				{statusCode: http.StatusOK, headers: sseHeaders, body: toolUseStreamBody},
				{
					statusCode: http.StatusTooManyRequests,
					headers: map[string]string{"Retry-After": "5"},
					body: rateLimitBody,
				},
				{
					statusCode: http.StatusTooManyRequests,
					headers: map[string]string{"Retry-After": "3"},
					body: rateLimitBody,
				},
			},
			expectedRequestCount: 3,
			expectedSeenKeys: []string{"k0", "k0", "k1"},
			expectedBodyContains: "all configured keys are rate-limited",
			expectErrorAsSSEEvent: true,
			expectedErr: true,
			expectedKeyStates: []keypool.KeyState{
				keypool.KeyStateTemporary,
				keypool.KeyStateTemporary,
			},
		},
	}
	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			var requestCount atomic.Int32
			// seenKeys is appended to from the upstream handler, so it
			// is guarded by seenKeysMu.
			var seenKeysMu sync.Mutex
			var seenKeys []string
			// Mock upstream: returns scripted responses in order,
			// records each request's X-Api-Key for assertions.
			upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				// Zero-based index of this request into tc.responses.
				idx := int(requestCount.Add(1)) - 1
				seenKeysMu.Lock()
				seenKeys = append(seenKeys, r.Header.Get("X-Api-Key"))
				seenKeysMu.Unlock()
				_, _ = io.Copy(io.Discard, r.Body)
				// More requests than scripted responses: answer 500.
				if idx >= len(tc.responses) {
					w.WriteHeader(http.StatusInternalServerError)
					return
				}
				resp := tc.responses[idx]
				for hk, hv := range resp.headers {
					w.Header().Set(hk, hv)
				}
				w.WriteHeader(resp.statusCode)
				_, _ = w.Write([]byte(resp.body))
			}))
			t.Cleanup(upstream.Close)
			pool, err := keypool.New([]string{"k0", "k1"}, quartz.NewMock(t))
			require.NoError(t, err)
			cfg := config.Anthropic{
				BaseURL: upstream.URL + "/",
				KeyPool: pool,
			}
			payload, err := NewRequestPayload([]byte(requestBody))
			require.NoError(t, err)
			interceptor := NewStreamingInterceptor(
				uuid.New(),
				payload,
				config.ProviderAnthropic,
				cfg,
				nil,
				http.Header{},
				"X-Api-Key",
				otel.Tracer("streaming_test"),
				intercept.NewCredentialInfo(intercept.CredentialKindCentralized, ""),
			)
			// Mock proxy with a tool the upstream's tool_use event
			// will reference. The stub caller returns a fixed
			// text result.
			proxy := &mockServerProxier{
				tools: []*mcp.Tool{
					{
						Client: stubToolCaller{},
						ID: "test_tool",
						Name: "test_tool",
						ServerName: "coder",
						Logger: slog.Make(),
					},
				},
			}
			interceptor.Setup(slog.Make(), &testutil.MockRecorder{}, proxy)
			req := httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
			w := httptest.NewRecorder()
			err = interceptor.ProcessRequest(w, req)
			if tc.expectedErr {
				require.Error(t, err)
			} else {
				require.NoError(t, err)
			}
			assert.Equal(t, tc.expectedRequestCount, requestCount.Load(), "upstream request count")
			body := w.Body.String()
			assert.Contains(t, body, tc.expectedBodyContains, "response body")
			if tc.expectErrorAsSSEEvent {
				assert.Contains(t, body, "event: error", "error must be relayed as an SSE event")
			}
			// Hold the lock while reading seenKeys for the assertions.
			seenKeysMu.Lock()
			defer seenKeysMu.Unlock()
			assert.Equal(t, tc.expectedSeenKeys, seenKeys, "seen keys")
			assert.Equal(t, tc.expectedKeyStates, pool.PoolState(), "key states")
		})
	}
}
@@ -0,0 +1,128 @@
package integrationtest //nolint:testpackage // tests unexported internals
import (
"fmt"
"io"
"net/http"
"net/http/httptest"
"sync"
"sync/atomic"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tidwall/sjson"
"github.com/coder/coder/v2/aibridge/config"
"github.com/coder/coder/v2/aibridge/fixtures"
"github.com/coder/coder/v2/aibridge/keypool"
"github.com/coder/coder/v2/aibridge/provider"
"github.com/coder/quartz"
)
// TestAnthropic_KeyFailover verifies that a pool's key state
// persists across distinct client requests: a key marked
// temporary on request 1 is still skipped on request 2 without
// a wasted upstream attempt.
//
// The scenario runs once through the blocking path and once through
// the streaming path; the request's "stream" field and the upstream's
// success body/content type are the only differences between the two.
func TestAnthropic_KeyFailover(t *testing.T) {
	t.Parallel()
	fix := fixtures.Parse(t, fixtures.AntSimple)
	tests := []struct {
		name string
		streaming bool
		successBody []byte
		successCType string
	}{
		{
			name: "blocking",
			streaming: false,
			successBody: fix.NonStreaming(),
			successCType: "application/json",
		},
		{
			name: "streaming",
			streaming: true,
			successBody: fix.Streaming(),
			successCType: "text/event-stream",
		},
	}
	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			// The pool uses a mock clock that this test never advances.
			pool, err := keypool.New([]string{"k0", "k1"}, quartz.NewMock(t))
			require.NoError(t, err)
			var requestCount atomic.Int32
			// seenKeys is appended to from the upstream handler, so it
			// is guarded by seenKeysMu.
			var seenKeysMu sync.Mutex
			var seenKeys []string
			// Mock upstream: k0 always returns 429, k1 returns
			// the per-test success body.
			upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
				requestCount.Add(1)
				key := r.Header.Get("X-Api-Key")
				seenKeysMu.Lock()
				seenKeys = append(seenKeys, key)
				seenKeysMu.Unlock()
				_, _ = io.Copy(io.Discard, r.Body)
				switch key {
				case "k0":
					w.Header().Set("Content-Type", "application/json")
					w.Header().Set("Retry-After", "60")
					w.WriteHeader(http.StatusTooManyRequests)
					_, _ = fmt.Fprint(w, `{"type":"error","error":{"type":"rate_limit_error","message":"rate limited"}}`)
				case "k1":
					w.Header().Set("Content-Type", tc.successCType)
					w.WriteHeader(http.StatusOK)
					_, _ = w.Write(tc.successBody)
				default:
					// Any other key is unexpected in this test.
					w.WriteHeader(http.StatusInternalServerError)
				}
			}))
			t.Cleanup(upstream.Close)
			bridgeServer := newBridgeTestServer(t.Context(), t, upstream.URL,
				withCustomProvider(provider.NewAnthropic(config.Anthropic{
					BaseURL: upstream.URL,
					KeyPool: pool,
				}, nil)),
			)
			// Same fixture request for both modes; only "stream" varies.
			requestBody, err := sjson.SetBytes(fix.Request(), "stream", tc.streaming)
			require.NoError(t, err)
			// Request 1: walker starts at k0, fails over to k1
			// after 429.
			resp, err := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, requestBody)
			require.NoError(t, err)
			_, _ = io.Copy(io.Discard, resp.Body)
			require.NoError(t, resp.Body.Close())
			require.Equal(t, http.StatusOK, resp.StatusCode)
			// Request 2: walker skips the now-temporary k0 and
			// goes straight to k1 (1 upstream call, not 2).
			resp, err = bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, requestBody)
			require.NoError(t, err)
			_, _ = io.Copy(io.Discard, resp.Body)
			require.NoError(t, resp.Body.Close())
			require.Equal(t, http.StatusOK, resp.StatusCode)
			seenKeysMu.Lock()
			defer seenKeysMu.Unlock()
			// Request 1: 2 calls (k0 then k1). Request 2: 1 call (k1).
			assert.Equal(t, int32(3), requestCount.Load(), "upstream request count")
			assert.Equal(t, []string{"k0", "k1", "k1"}, seenKeys, "seen keys")
			// Pool state persists: k0 temporary, k1 valid.
			assert.Equal(t, []keypool.KeyState{
				keypool.KeyStateTemporary,
				keypool.KeyStateValid,
			}, pool.PoolState(), "key states")
		})
	}
}
+37
View File
@@ -0,0 +1,37 @@
package keypool
import (
"net/http"
"strconv"
"strings"
"time"
)
// ParseRetryAfter extracts the cooldown duration from response
// headers. It prefers the OpenAI-specific "retry-after-ms"
// header (milliseconds) over the standard "Retry-After" header
// (seconds). Returns zero if resp is nil or neither header yields
// a positive duration. The HTTP-date form of "Retry-After" is not
// parsed.
//
// The result is never negative. Values too large to represent as a
// time.Duration saturate at the maximum duration instead of
// overflowing. A positive "retry-after-ms" value that truncates to
// zero nanoseconds is treated as unusable, so the standard header is
// consulted next.
func ParseRetryAfter(resp *http.Response) time.Duration {
	// Largest representable duration, used to saturate on overflow.
	const maxDuration time.Duration = 1<<63 - 1

	if resp == nil {
		return 0
	}

	// OpenAI convention: millisecond precision.
	if val := resp.Header.Get("retry-after-ms"); val != "" {
		ms, err := strconv.ParseFloat(strings.TrimSpace(val), 64)
		// NaN fails the ms > 0 comparison and falls through.
		if err == nil && ms > 0 {
			nanos := ms * float64(time.Millisecond)
			// Converting an out-of-range float (including +Inf) to an
			// integer is implementation-dependent, so clamp first.
			if nanos >= float64(maxDuration) {
				return maxDuration
			}
			if d := time.Duration(nanos); d > 0 {
				return d
			}
		}
	}

	// Standard header: seconds.
	if val := resp.Header.Get("Retry-After"); val != "" {
		seconds, err := strconv.Atoi(strings.TrimSpace(val))
		if err == nil && seconds > 0 {
			// Guard the multiplication below against int64 overflow.
			if time.Duration(seconds) > maxDuration/time.Second {
				return maxDuration
			}
			return time.Duration(seconds) * time.Second
		}
	}

	return 0
}
+110
View File
@@ -0,0 +1,110 @@
package keypool_test
import (
"net/http"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/coder/coder/v2/aibridge/keypool"
)
// TestParseRetryAfter is a table-driven test of keypool.ParseRetryAfter.
// It covers a nil response, absent headers, each header on its own
// (including whitespace, zero and negative values), and the precedence
// and fallback rules when both headers are present.
func TestParseRetryAfter(t *testing.T) {
	t.Parallel()
	tests := []struct {
		name string
		// Headers to set on the response; ignored when nilResponse is true.
		headers map[string]string
		// When true, ParseRetryAfter is called with a nil *http.Response.
		nilResponse bool
		expected time.Duration
	}{
		// nil response.
		{
			name: "nil_response",
			nilResponse: true,
			expected: 0,
		},
		// No headers set.
		{
			name: "no_headers",
			headers: nil,
			expected: 0,
		},
		// retry-after-ms (OpenAI, preferred).
		{
			name: "openai_retry_after_ms",
			headers: map[string]string{"retry-after-ms": "2500"},
			expected: 2500 * time.Millisecond,
		},
		{
			name: "whitespace_trimmed_ms",
			headers: map[string]string{"retry-after-ms": " 1500 "},
			expected: 1500 * time.Millisecond,
		},
		{
			name: "negative_ms_returns_zero",
			headers: map[string]string{"retry-after-ms": "-100"},
			expected: 0,
		},
		// Retry-After (standard, seconds).
		{
			name: "standard_retry_after_seconds",
			headers: map[string]string{"Retry-After": "60"},
			expected: 60 * time.Second,
		},
		{
			name: "whitespace_trimmed_seconds",
			headers: map[string]string{"Retry-After": " 30 "},
			expected: 30 * time.Second,
		},
		{
			name: "zero_seconds_returns_zero",
			headers: map[string]string{"Retry-After": "0"},
			expected: 0,
		},
		{
			name: "negative_seconds_returns_zero",
			headers: map[string]string{"Retry-After": "-5"},
			expected: 0,
		},
		// Both headers set: precedence and fallback.
		{
			name: "prefers_retry_after_ms_over_standard",
			headers: map[string]string{
				"retry-after-ms": "1500",
				"Retry-After": "30",
			},
			expected: 1500 * time.Millisecond,
		},
		{
			name: "falls_back_to_standard_when_ms_invalid",
			headers: map[string]string{"retry-after-ms": "invalid", "Retry-After": "10"},
			expected: 10 * time.Second,
		},
		{
			name: "zero_ms_falls_back_to_standard",
			headers: map[string]string{"retry-after-ms": "0", "Retry-After": "5"},
			expected: 5 * time.Second,
		},
		{
			name: "zero_ms_and_zero_seconds_return_zero",
			headers: map[string]string{"retry-after-ms": "0", "Retry-After": "0"},
			expected: 0,
		},
	}
	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			// resp stays nil for the nil-response case; otherwise it is
			// a bare response carrying only the case's headers.
			var resp *http.Response
			if !tc.nilResponse {
				resp = &http.Response{Header: make(http.Header)}
				for key, val := range tc.headers {
					resp.Header.Set(key, val)
				}
			}
			assert.Equal(t, tc.expected, keypool.ParseRetryAfter(resp))
		})
	}
}
+54
View File
@@ -0,0 +1,54 @@
package keypool
import (
"context"
"net/http"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/aibridge/utils"
)
// MarkKeyOnStatus inspects the HTTP status of resp and, when it is a
// key-specific failure, updates the state of key accordingly:
//
//   - 429 marks the key temporary, using the cooldown from
//     ParseRetryAfter or defaultCooldown when none is provided.
//   - 401 and 403 mark the key permanent.
//
// It returns true for those statuses so the caller can retry with the
// next key, and false for a nil resp or any other status. A state
// change is logged only when the Mark call reports true.
func MarkKeyOnStatus(
	ctx context.Context,
	key *Key,
	resp *http.Response,
	logger slog.Logger,
	providerName string,
) bool {
	if resp == nil {
		return false
	}

	code := resp.StatusCode

	// Rate limited: take the key out of rotation for the cooldown.
	if code == http.StatusTooManyRequests {
		wait := ParseRetryAfter(resp)
		if wait <= 0 {
			wait = defaultCooldown
		}
		if marked := key.MarkTemporary(wait); marked {
			logger.Info(ctx, "key marked temporary",
				slog.F("provider", providerName),
				slog.F("api_key_hint", utils.MaskSecret(key.Value())),
				slog.F("status", code),
				slog.F("cooldown", wait))
		}
		return true
	}

	// Rejected credentials: mark the key permanent.
	if code == http.StatusUnauthorized || code == http.StatusForbidden {
		if marked := key.MarkPermanent(); marked {
			logger.Warn(ctx, "key marked permanent",
				slog.F("provider", providerName),
				slog.F("api_key_hint", utils.MaskSecret(key.Value())),
				slog.F("status", code))
		}
		return true
	}

	// Anything else is not attributable to the key.
	logger.Debug(ctx, "status is not a key failover trigger",
		slog.F("provider", providerName),
		slog.F("status", code))
	return false
}
+127
View File
@@ -0,0 +1,127 @@
package keypool_test
import (
"context"
"net/http"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/coder/v2/aibridge/keypool"
"github.com/coder/quartz"
)
// TestMarkKeyOnStatus drives keypool.MarkKeyOnStatus with a range
// of HTTP statuses and checks three things per case: the returned
// failover flag, the resulting key state, and, for temporary
// marks, the exact cooldown that was applied.
func TestMarkKeyOnStatus(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name             string
		statusCode       int
		headers          map[string]string
		expectedReturn   bool
		expectedState    keypool.KeyState
		expectedCooldown time.Duration
	}{
		{
			// 429 with standard Retry-After header (seconds).
			name:             "429_with_retry_after_seconds",
			statusCode:       http.StatusTooManyRequests,
			headers:          map[string]string{"Retry-After": "5"},
			expectedReturn:   true,
			expectedState:    keypool.KeyStateTemporary,
			expectedCooldown: 5 * time.Second,
		},
		{
			// 429 with retry-after-ms header (milliseconds).
			name:             "429_with_retry_after_ms",
			statusCode:       http.StatusTooManyRequests,
			headers:          map[string]string{"retry-after-ms": "1500"},
			expectedReturn:   true,
			expectedState:    keypool.KeyStateTemporary,
			expectedCooldown: 1500 * time.Millisecond,
		},
		{
			// 429 without headers falls back to default cooldown.
			name:             "429_no_headers_uses_default",
			statusCode:       http.StatusTooManyRequests,
			expectedReturn:   true,
			expectedState:    keypool.KeyStateTemporary,
			expectedCooldown: 60 * time.Second,
		},
		{
			name:           "401_marks_permanent",
			statusCode:     http.StatusUnauthorized,
			expectedReturn: true,
			expectedState:  keypool.KeyStatePermanent,
		},
		{
			name:           "403_marks_permanent",
			statusCode:     http.StatusForbidden,
			expectedReturn: true,
			expectedState:  keypool.KeyStatePermanent,
		},
		{
			name:           "200_does_not_mark",
			statusCode:     http.StatusOK,
			expectedReturn: false,
			expectedState:  keypool.KeyStateValid,
		},
		{
			name:           "500_does_not_mark",
			statusCode:     http.StatusInternalServerError,
			expectedReturn: false,
			expectedState:  keypool.KeyStateValid,
		},
		{
			// 529 is the Anthropic overloaded status, handled by
			// the circuit breaker, not key failover.
			name:           "529_does_not_mark",
			statusCode:     529,
			expectedReturn: false,
			expectedState:  keypool.KeyStateValid,
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			clk := quartz.NewMock(t)
			pool, err := keypool.New([]string{"key-0"}, clk)
			require.NoError(t, err)
			key, err := pool.Walker().Next()
			require.NoError(t, err)

			resp := &http.Response{
				StatusCode: tc.statusCode,
				Header:     make(http.Header),
			}
			for k, v := range tc.headers {
				resp.Header.Set(k, v)
			}

			got := keypool.MarkKeyOnStatus(
				context.Background(),
				key,
				resp,
				// IgnoreErrors keeps error-level log output from
				// failing the test.
				// NOTE(review): MarkKeyOnStatus currently logs
				// permanent marks at warn level, so this option
				// may no longer be required; confirm before
				// removing it.
				slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}),
				"test",
			)
			assert.Equal(t, tc.expectedReturn, got)
			assert.Equal(t, tc.expectedState, key.State())

			// Pin the cooldown from both sides. One nanosecond
			// before the expected expiry the key must still be
			// temporary (the cooldown is not shorter than
			// expected); at the expected expiry it must be valid
			// again (the cooldown is not longer than expected).
			if tc.expectedCooldown > 0 {
				clk.Advance(tc.expectedCooldown - time.Nanosecond)
				assert.Equal(t, keypool.KeyStateTemporary, key.State())

				clk.Advance(time.Nanosecond)
				assert.Equal(t, keypool.KeyStateValid, key.State())
			}
		})
	}
}
+80 -8
View File
@@ -1,6 +1,7 @@
package keypool
import (
"fmt"
"sync"
"time"
@@ -15,11 +16,24 @@ var (
// ErrDuplicateKey is returned when the input contains
// duplicate key values.
ErrDuplicateKey = xerrors.New("duplicate key")
// ErrAllKeysExhausted is returned when the walker has visited
// every key in the pool and none are available.
ErrAllKeysExhausted = xerrors.New("all keys exhausted")
)
// ErrPermanentKeyPool is the sentinel error returned when every
// key in the pool has been permanently marked unavailable, so
// waiting will not make any key usable again. keyPoolError
// returns it when it finds no key that is valid or cooling down.
var ErrPermanentKeyPool = xerrors.New("all keys permanently unavailable")
// TransientKeyPoolError reports that the pool has no key to hand
// out right now even though at least one key is expected to
// recover. RetryAfter holds the shortest remaining cooldown found
// across the pool; it is 0 when a key turned valid again while
// the walk was in progress.
type TransientKeyPoolError struct {
	RetryAfter time.Duration
}

// Error implements the error interface and embeds RetryAfter in
// the message.
func (e *TransientKeyPoolError) Error() string {
	const format = "all keys exhausted (retry after %s)"
	return fmt.Sprintf(format, e.RetryAfter)
}
// KeyState represents the current state of a key in the pool.
// Values used in this package include KeyStateValid,
// KeyStateTemporary, and KeyStatePermanent.
type KeyState int
@@ -101,6 +115,22 @@ func (k *Key) State() KeyState {
return KeyStateValid
}
// stateAndCooldown reports the key's state together with the time
// left on its cooldown. Both values are read under a single read
// lock so they describe the same instant. The duration is
// non-zero only for KeyStateTemporary.
func (k *Key) stateAndCooldown() (KeyState, time.Duration) {
	k.mu.RLock()
	defer k.mu.RUnlock()

	// Permanent takes precedence over any cooldown bookkeeping.
	if k.permanent {
		return KeyStatePermanent, 0
	}

	now := k.clock.Now()
	if !now.Before(k.cooldownUntil) {
		// Cooldown has elapsed, or none is in effect.
		return KeyStateValid, 0
	}
	return KeyStateTemporary, k.cooldownUntil.Sub(now)
}
// MarkTemporary marks the key as temporarily unavailable with
// the specified cooldown duration. Returns true if this call
// transitions the key to temporary.
@@ -146,6 +176,47 @@ func (k *Key) MarkPermanent() bool {
return true
}
// keyPoolError classifies why the pool could not supply a key.
//
//   - If any key is currently valid, it returns a
//     *TransientKeyPoolError with a zero RetryAfter.
//   - Otherwise, if any key is temporarily unavailable, it returns
//     a *TransientKeyPoolError whose RetryAfter is the smallest
//     remaining cooldown among those keys.
//   - Otherwise (every key permanent, or no keys at all) it
//     returns ErrPermanentKeyPool.
func (p *Pool) keyPoolError() error {
	var (
		soonest time.Duration
		found   bool
	)

	for i := range p.keys {
		state, remaining := p.keys[i].stateAndCooldown()

		// Recoverable right now: report transient immediately.
		if state == KeyStateValid {
			return &TransientKeyPoolError{}
		}
		// Permanent keys contribute nothing; keep scanning.
		if state != KeyStateTemporary {
			continue
		}
		// Recoverable later: remember the shortest wait.
		if !found || remaining < soonest {
			soonest, found = remaining, true
		}
	}

	if !found {
		return ErrPermanentKeyPool
	}
	return &TransientKeyPoolError{RetryAfter: soonest}
}
// PoolState reports the state of every key, in the pool's
// original order, as a point-in-time snapshot. It is intended for
// tests and other diagnostic callers; the failover path should
// iterate with Walker instead.
func (p *Pool) PoolState() []KeyState {
	snapshot := make([]KeyState, 0, len(p.keys))
	for i := range p.keys {
		snapshot = append(snapshot, p.keys[i].State())
	}
	return snapshot
}
// Walker traverses a Pool for a single request. Each request
// creates its own walker so that it can independently iterate
// through keys without interfering with other requests.
@@ -162,14 +233,15 @@ func (p *Pool) Walker() *Walker {
return &Walker{pool: p, pos: 0}
}
// Next returns a Key handle for the next available key. This is
// a read-only operation; it does not modify the pool state.
// Next returns a Key handle for the next available key without
// modifying the pool state.
//
// Returns ErrAllKeysExhausted when no more keys are available.
// Returns *TransientKeyPoolError or ErrPermanentKeyPool
// when no more keys are available.
func (w *Walker) Next() (*Key, error) {
pool := w.pool
if pool == nil {
return nil, ErrAllKeysExhausted
return nil, ErrPermanentKeyPool
}
for i := w.pos; i < len(pool.keys); i++ {
@@ -183,5 +255,5 @@ func (w *Walker) Next() (*Key, error) {
}
// No keys available.
return nil, ErrAllKeysExhausted
return nil, pool.keyPoolError()
}
+153 -30
View File
@@ -1,6 +1,8 @@
package keypool_test
import (
"errors"
"sync"
"testing"
"time"
@@ -49,7 +51,8 @@ func TestNewKeyPool(t *testing.T) {
// No more keys available.
_, err = walker.Next()
require.ErrorIs(t, err, keypool.ErrAllKeysExhausted)
var transient *keypool.TransientKeyPoolError
require.ErrorAs(t, err, &transient, "expected transient exhaustion: walker returned all valid keys, none marked permanent")
})
}
}
@@ -282,19 +285,21 @@ func TestWalkerNext(t *testing.T) {
t.Parallel()
tests := []struct {
name string
keys []string
setup func(t *testing.T, pool *keypool.Pool)
advance time.Duration
expectValid []string
name string
keys []string
setup func(t *testing.T, pool *keypool.Pool)
advance time.Duration
expectedValid []string
expectedErr error
}{
{
// Given: key-0: valid, key-1: valid, key-2: valid.
// Then: key-0: valid, key-1: valid, key-2: valid.
name: "all_keys_valid",
keys: []string{"key-0", "key-1", "key-2"},
setup: func(_ *testing.T, _ *keypool.Pool) {},
expectValid: []string{"key-0", "key-1", "key-2"},
name: "all_keys_valid",
keys: []string{"key-0", "key-1", "key-2"},
setup: func(_ *testing.T, _ *keypool.Pool) {},
expectedValid: []string{"key-0", "key-1", "key-2"},
expectedErr: &keypool.TransientKeyPoolError{},
},
{
// Given: key-0: temporary, key-1: valid, key-2: valid.
@@ -306,7 +311,8 @@ func TestWalkerNext(t *testing.T) {
require.NoError(t, err)
key.MarkTemporary(60 * time.Second)
},
expectValid: []string{"key-1", "key-2"},
expectedValid: []string{"key-1", "key-2"},
expectedErr: &keypool.TransientKeyPoolError{},
},
{
// Given: key-0: permanent, key-1: permanent, key-2: valid.
@@ -322,7 +328,8 @@ func TestWalkerNext(t *testing.T) {
require.NoError(t, err)
key1.MarkPermanent()
},
expectValid: []string{"key-2"},
expectedValid: []string{"key-2"},
expectedErr: &keypool.TransientKeyPoolError{},
},
{
// Given: key-0: temporary (30s), key-1: valid.
@@ -335,8 +342,9 @@ func TestWalkerNext(t *testing.T) {
require.NoError(t, err)
key.MarkTemporary(30 * time.Second)
},
advance: 35 * time.Second,
expectValid: []string{"key-0", "key-1"},
advance: 35 * time.Second,
expectedValid: []string{"key-0", "key-1"},
expectedErr: &keypool.TransientKeyPoolError{},
},
{
// Given: key-0: temporary (zero, default 60s), key-1: valid.
@@ -349,8 +357,9 @@ func TestWalkerNext(t *testing.T) {
require.NoError(t, err)
key.MarkTemporary(0)
},
advance: 50 * time.Second,
expectValid: []string{"key-1"},
advance: 50 * time.Second,
expectedValid: []string{"key-1"},
expectedErr: &keypool.TransientKeyPoolError{},
},
{
// Given: key-0: temporary (zero, default 60s), key-1: valid.
@@ -363,8 +372,9 @@ func TestWalkerNext(t *testing.T) {
require.NoError(t, err)
key.MarkTemporary(0)
},
advance: 65 * time.Second,
expectValid: []string{"key-0", "key-1"},
advance: 65 * time.Second,
expectedValid: []string{"key-0", "key-1"},
expectedErr: &keypool.TransientKeyPoolError{},
},
{
// Given: key-0: temporary (negative, default 60s), key-1: valid.
@@ -377,13 +387,14 @@ func TestWalkerNext(t *testing.T) {
require.NoError(t, err)
key.MarkTemporary(-10 * time.Second)
},
advance: 65 * time.Second,
expectValid: []string{"key-0", "key-1"},
advance: 65 * time.Second,
expectedValid: []string{"key-0", "key-1"},
expectedErr: &keypool.TransientKeyPoolError{},
},
{
// Given: key-0: temporary (60s), then marked again with shorter cooldown (10s).
// When: 15s pass (past 10s, but not 60s).
// Then: key-0: temporary.
// Then: key-0: temporary, 45s remaining.
name: "shorter_cooldown_preserves_longer_not_expired",
keys: []string{"key-0"},
setup: func(t *testing.T, pool *keypool.Pool) {
@@ -392,8 +403,9 @@ func TestWalkerNext(t *testing.T) {
key.MarkTemporary(60 * time.Second)
key.MarkTemporary(10 * time.Second)
},
advance: 15 * time.Second,
expectValid: []string{},
advance: 15 * time.Second,
expectedValid: []string{},
expectedErr: &keypool.TransientKeyPoolError{RetryAfter: 45 * time.Second},
},
{
// Given: key-0: temporary (60s), then marked again with shorter cooldown (10s).
@@ -407,8 +419,30 @@ func TestWalkerNext(t *testing.T) {
key.MarkTemporary(60 * time.Second)
key.MarkTemporary(10 * time.Second)
},
advance: 65 * time.Second,
expectValid: []string{"key-0"},
advance: 65 * time.Second,
expectedValid: []string{"key-0"},
expectedErr: &keypool.TransientKeyPoolError{},
},
{
// Given: key-0: temporary (60s), key-1: temporary (10s), key-2: temporary (30s).
// Then: key-0: temporary, key-1: temporary, key-2: temporary.
// Smallest remaining cooldown is reported on exhaustion.
name: "smallest_cooldown_across_temporary_keys",
keys: []string{"key-0", "key-1", "key-2"},
setup: func(t *testing.T, pool *keypool.Pool) {
walker := pool.Walker()
key0, err := walker.Next()
require.NoError(t, err)
key0.MarkTemporary(60 * time.Second)
key1, err := walker.Next()
require.NoError(t, err)
key1.MarkTemporary(10 * time.Second)
key2, err := walker.Next()
require.NoError(t, err)
key2.MarkTemporary(30 * time.Second)
},
expectedValid: []string{},
expectedErr: &keypool.TransientKeyPoolError{RetryAfter: 10 * time.Second},
},
{
// Given: key-0: temporary, key-1: temporary.
@@ -424,7 +458,8 @@ func TestWalkerNext(t *testing.T) {
require.NoError(t, err)
key1.MarkTemporary(60 * time.Second)
},
expectValid: []string{},
expectedValid: []string{},
expectedErr: &keypool.TransientKeyPoolError{RetryAfter: 60 * time.Second},
},
{
// Given: key-0: permanent, key-1: permanent.
@@ -440,7 +475,8 @@ func TestWalkerNext(t *testing.T) {
require.NoError(t, err)
key1.MarkPermanent()
},
expectValid: []string{},
expectedValid: []string{},
expectedErr: keypool.ErrPermanentKeyPool,
},
{
// Given: key-0: permanent, key-1: temporary, key-2: permanent.
@@ -459,7 +495,8 @@ func TestWalkerNext(t *testing.T) {
require.NoError(t, err)
key2.MarkPermanent()
},
expectValid: []string{},
expectedValid: []string{},
expectedErr: &keypool.TransientKeyPoolError{RetryAfter: 60 * time.Second},
},
}
@@ -478,7 +515,7 @@ func TestWalkerNext(t *testing.T) {
}
walker := pool.Walker()
for _, expectedKey := range tc.expectValid {
for _, expectedKey := range tc.expectedValid {
key, err := walker.Next()
require.NoError(t, err)
assert.Equal(t, expectedKey, key.Value())
@@ -486,7 +523,93 @@ func TestWalkerNext(t *testing.T) {
// After all expected keys, the walker should be exhausted.
_, err = walker.Next()
require.ErrorIs(t, err, keypool.ErrAllKeysExhausted)
var wantTransient *keypool.TransientKeyPoolError
if errors.As(tc.expectedErr, &wantTransient) {
var got *keypool.TransientKeyPoolError
require.ErrorAs(t, err, &got)
assert.Equal(t, wantTransient.RetryAfter, got.RetryAfter)
} else {
require.ErrorIs(t, err, tc.expectedErr)
}
})
}
}
// TestKeyConcurrent checks the concurrency contract of Key by
// issuing Mark calls against a single key from several goroutines
// at once and then validating the state the key settles in.
func TestKeyConcurrent(t *testing.T) {
	t.Parallel()

	type scenario struct {
		name string
		// mark runs once in each goroutine and receives that
		// goroutine's index.
		mark func(idx int, key *keypool.Key)
		// check inspects the settled state. It is allowed to
		// move the mock clock forward.
		check func(t *testing.T, key *keypool.Key, clk *quartz.Mock)
	}

	scenarios := []scenario{
		{
			// Even goroutines request a 60s cooldown, odd ones
			// a 10s cooldown. Whatever the interleaving, the
			// longer cooldown has to be the one in effect.
			name: "longer_cooldown_wins",
			mark: func(idx int, key *keypool.Key) {
				cooldown := 10 * time.Second
				if idx%2 == 0 {
					cooldown = 60 * time.Second
				}
				key.MarkTemporary(cooldown)
			},
			check: func(t *testing.T, key *keypool.Key, clk *quartz.Mock) {
				// 50s in: the 60s cooldown has not run out.
				clk.Advance(50 * time.Second)
				assert.Equal(t, keypool.KeyStateTemporary, key.State())

				// 65s in: the 60s cooldown is over.
				clk.Advance(15 * time.Second)
				assert.Equal(t, keypool.KeyStateValid, key.State())
			},
		},
		{
			// Even goroutines mark permanent, odd ones mark
			// temporary. Permanent is terminal, so a single
			// permanent mark decides the outcome.
			name: "permanent_wins_over_temporary",
			mark: func(idx int, key *keypool.Key) {
				if idx%2 != 0 {
					key.MarkTemporary(60 * time.Second)
					return
				}
				key.MarkPermanent()
			},
			check: func(t *testing.T, key *keypool.Key, _ *quartz.Mock) {
				assert.Equal(t, keypool.KeyStatePermanent, key.State())
			},
		},
	}

	for _, sc := range scenarios {
		t.Run(sc.name, func(t *testing.T) {
			t.Parallel()

			mockClock := quartz.NewMock(t)
			pool, err := keypool.New([]string{"key-0"}, mockClock)
			require.NoError(t, err)
			key, err := pool.Walker().Next()
			require.NoError(t, err)

			const workers = 10
			var wg sync.WaitGroup
			wg.Add(workers)
			for w := range workers {
				go func(idx int) {
					defer wg.Done()
					sc.mark(idx, key)
				}(w)
			}
			wg.Wait()

			sc.check(t, key, mockClock)
		})
	}
}
+52 -11
View File
@@ -15,8 +15,10 @@ import (
"github.com/coder/coder/v2/aibridge/config"
"github.com/coder/coder/v2/aibridge/intercept"
"github.com/coder/coder/v2/aibridge/intercept/messages"
"github.com/coder/coder/v2/aibridge/keypool"
"github.com/coder/coder/v2/aibridge/tracing"
"github.com/coder/coder/v2/aibridge/utils"
"github.com/coder/quartz"
)
// anthropicForwardHeaders lists headers from incoming requests that should be
@@ -55,6 +57,24 @@ func NewAnthropic(cfg config.Anthropic, bedrockCfg *config.AWSBedrock) *Anthropi
if cfg.BaseURL == "" {
cfg.BaseURL = "https://api.anthropic.com/"
}
// Resolve centralized key configuration into KeyPool.
// Precedence:
// 1. cfg.KeyPool (explicit, highest priority).
// 2. cfg.Key (legacy single key).
// After this block cfg.Key is empty so it can only carry a
// BYOK X-Api-Key set per interception in CreateInterceptor.
// TODO(ssncferreira): simplify auth field resolution per
// https://github.com/coder/aibridge/issues/266.
if cfg.KeyPool == nil && cfg.Key != "" {
// keypool.New only fails on empty or duplicate keys,
// neither possible with a single non-empty key.
pool, err := keypool.New([]string{cfg.Key}, quartz.NewReal())
if err != nil {
panic(fmt.Sprintf("anthropic provider: build single-key pool: %s", err))
}
cfg.KeyPool = pool
}
cfg.Key = ""
if cfg.CircuitBreaker != nil {
cfg.CircuitBreaker.IsFailure = anthropicIsFailure
cfg.CircuitBreaker.OpenErrorResponse = anthropicOpenErrorResponse
@@ -119,29 +139,41 @@ func (p *Anthropic) CreateInterceptor(_ http.ResponseWriter, r *http.Request, tr
// Any Coder-specific authentication has already been stripped.
//
// In centralized mode neither Authorization nor X-Api-Key is
// present, so cfg keeps the centralized key unchanged.
// present, so cfg keeps the KeyPool from provider construction
// and the failover loop walks it.
//
// In BYOK mode the user's LLM credentials survive intact.
// If X-Api-Key is present the user has a personal API key;
// overwrite the centralized key with it. If Authorization is
// present the user authenticated directly with provider;
// set BYOKBearerToken and clear the centralized key.
// When both are present, X-Api-Key takes priority to match
// claude-code behavior.
// In BYOK mode the user's LLM credentials survive intact and
// failover is disabled by clearing cfg.KeyPool. If X-Api-Key is
// present the user has a personal API key, populate cfg.Key.
// If Authorization is present the user authenticated directly
// with the provider, populate cfg.BYOKBearerToken. When both
// are present, X-Api-Key takes priority to match claude-code
// behavior.
//
// TODO(ssncferreira): consolidate auth field handling per
// https://github.com/coder/aibridge/issues/266.
credKind := intercept.CredentialKindCentralized
credSecret := cfg.Key
var credSecret string
authHeaderName := p.AuthHeader()
if apiKey := r.Header.Get("X-Api-Key"); apiKey != "" {
cfg.Key = apiKey
cfg.KeyPool = nil
authHeaderName = "X-Api-Key"
credKind = intercept.CredentialKindBYOK
credSecret = apiKey
} else if token := utils.ExtractBearerToken(r.Header.Get("Authorization")); token != "" {
cfg.BYOKBearerToken = token
cfg.Key = ""
cfg.KeyPool = nil
authHeaderName = "Authorization"
credKind = intercept.CredentialKindBYOK
credSecret = token
} else if cfg.KeyPool != nil {
// Centralized: use the first key as a placeholder hint.
// TODO(ssncferreira): record the actually-used key in
// the interception record to reflect failover.
if k, err := cfg.KeyPool.Walker().Next(); err == nil {
credSecret = k.Value()
}
}
cred := intercept.NewCredentialInfo(credKind, credSecret)
@@ -175,7 +207,16 @@ func (p *Anthropic) InjectAuthHeader(headers *http.Header) {
return
}
headers.Set(p.AuthHeader(), p.cfg.Key)
// Centralized: pull a single key from the pool. No failover
// or exhaustion handling here.
// TODO(ssncferreira): replace with RoundTripper-based auth
// in the upstack passthrough PR.
if p.cfg.KeyPool == nil {
return
}
if key, err := p.cfg.KeyPool.Walker().Next(); err == nil {
headers.Set(p.AuthHeader(), key.Value())
}
}
func (p *Anthropic) CircuitBreakerConfig() *config.CircuitBreaker {
+66
View File
@@ -13,6 +13,8 @@ import (
"github.com/coder/coder/v2/aibridge/config"
"github.com/coder/coder/v2/aibridge/intercept"
"github.com/coder/coder/v2/aibridge/internal/testutil"
"github.com/coder/coder/v2/aibridge/keypool"
"github.com/coder/quartz"
)
func TestAnthropic_TypeAndName(t *testing.T) {
@@ -49,6 +51,70 @@ func TestAnthropic_TypeAndName(t *testing.T) {
}
}
// TestNewAnthropic_KeyResolution verifies how NewAnthropic turns
// the Key and KeyPool config fields into the provider's key pool
// by comparing the keys the resulting pool yields, in order.
func TestNewAnthropic_KeyResolution(t *testing.T) {
	t.Parallel()

	sharedPool, err := keypool.New([]string{"pool-key-0", "pool-key-1"}, quartz.NewMock(t))
	require.NoError(t, err)

	// drain collects every key value the walker yields, stopping
	// at the first error it reports.
	drain := func(w *keypool.Walker) []string {
		var values []string
		for key, err := w.Next(); err == nil; key, err = w.Next() {
			values = append(values, key.Value())
		}
		return values
	}

	cases := []struct {
		name         string
		cfg          config.Anthropic
		expectedKeys []string
	}{
		{
			// Only the legacy single key is configured:
			// NewAnthropic wraps it in a one-key pool.
			name:         "key_creates_keypool",
			cfg:          config.Anthropic{Key: "legacy-key"},
			expectedKeys: []string{"legacy-key"},
		},
		{
			// A ready-made pool is handed in by the caller.
			name:         "keypool_passed_directly",
			cfg:          config.Anthropic{KeyPool: sharedPool},
			expectedKeys: []string{"pool-key-0", "pool-key-1"},
		},
		{
			// Key and KeyPool both present: the pool is used
			// and the single key is disregarded.
			name:         "keypool_takes_precedence_over_key",
			cfg:          config.Anthropic{Key: "legacy-key", KeyPool: sharedPool},
			expectedKeys: []string{"pool-key-0", "pool-key-1"},
		},
		{
			// Nothing configured: there is no centralized
			// auth. BYOK auth is supplied per request in
			// CreateInterceptor.
			name:         "neither_set_no_centralized_auth",
			cfg:          config.Anthropic{},
			expectedKeys: nil,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			provider := NewAnthropic(tc.cfg, nil)
			resolved := provider.cfg.KeyPool

			if tc.expectedKeys == nil {
				assert.Nil(t, resolved, "expected no KeyPool")
				return
			}

			require.NotNil(t, resolved)
			assert.Equal(t, tc.expectedKeys, drain(resolved.Walker()))
		})
	}
}
func TestAnthropic_CreateInterceptor(t *testing.T) {
t.Parallel()