mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
feat: add automatic key failover for AI Bridge Anthropic (#24836)
## Description

Adds automatic key failover for the centralized Anthropic provider. When a key pool is configured, each upstream call walks the pool and tries keys in order until one succeeds or the pool is exhausted. Keys are marked **temporary** on 429 (with cooldown from `Retry-After`) and **permanent** on 401/403. Errors that aren't key-specific don't trigger failover. Each agentic-loop iteration gets its own fresh walker, so a tool-call continuation can fail over independently of the initial request. BYOK is unchanged: BYOK requests run as a single attempt with no failover.

## Changes

- `config.Anthropic` carries a `KeyPool`. `Key` remains for BYOK X-Api-Key set per interception.
- Blocking interceptor: walks the pool, marks keys on key-specific failures, returns on first success or non-failover error.
- Streaming interceptor: per-iteration walker. Pre-stream failures fail over to the next key; mid-stream errors are relayed as SSE events.
- New `keypool` error types: `TransientKeyPoolError` (carries the soonest cooldown) and `ErrPermanentKeyPool`. These replace the prior `ErrAllKeysExhausted`.
- Error responses now consistently include the outer `"type": "error"` field.

## Related Issues

Related to: https://github.com/coder/internal/issues/1446
Related to: https://linear.app/codercom/issue/AIGOV-197/aibridge-automatic-key-failover-for-bridged-and-passthrough-routes

## Follow-up PRs

- Bedrock multi-key support.
- Refactor provider vs interceptor config separation.
- Record the actually-used key in the interception credential hint after failover.

> [!NOTE]
> Initially generated by Claude Opus 4.7, modified and reviewed by @ssncferreira
This commit is contained in:
@@ -1,6 +1,10 @@
|
||||
package config
|
||||
|
||||
import "time"
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/coder/coder/v2/aibridge/keypool"
|
||||
)
|
||||
|
||||
const (
|
||||
ProviderAnthropic = "anthropic"
|
||||
@@ -8,11 +12,24 @@ const (
|
||||
ProviderCopilot = "copilot"
|
||||
)
|
||||
|
||||
// Anthropic carries configuration for an Anthropic provider.
|
||||
//
|
||||
// Authentication is mutually exclusive across these three fields,
|
||||
// set per interception in the provider's CreateInterceptor:
|
||||
// - KeyPool: centralized requests with automatic key failover.
|
||||
// - Key: BYOK with X-Api-Key (single attempt, no failover).
|
||||
// - BYOKBearerToken: BYOK with Authorization Bearer (single
|
||||
// attempt, no failover).
|
||||
//
|
||||
// TODO(ssncferreira): consolidate the three authentication
|
||||
// fields into a single abstraction per
|
||||
// https://github.com/coder/aibridge/issues/266.
|
||||
type Anthropic struct {
|
||||
// Name is the provider instance name. If empty, defaults to "anthropic".
|
||||
Name string
|
||||
BaseURL string
|
||||
Key string
|
||||
KeyPool *keypool.Pool
|
||||
APIDumpDir string
|
||||
CircuitBreaker *CircuitBreaker
|
||||
SendActorHeaders bool
|
||||
|
||||
@@ -5,7 +5,9 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -26,6 +28,7 @@ import (
|
||||
aibcontext "github.com/coder/coder/v2/aibridge/context"
|
||||
"github.com/coder/coder/v2/aibridge/intercept"
|
||||
"github.com/coder/coder/v2/aibridge/intercept/apidump"
|
||||
"github.com/coder/coder/v2/aibridge/keypool"
|
||||
"github.com/coder/coder/v2/aibridge/mcp"
|
||||
"github.com/coder/coder/v2/aibridge/recorder"
|
||||
"github.com/coder/coder/v2/aibridge/tracing"
|
||||
@@ -202,19 +205,28 @@ func (i *interceptionBase) isSmallFastModel() bool {
|
||||
return strings.Contains(i.reqPayload.model(), "haiku")
|
||||
}
|
||||
|
||||
// newMessagesService builds the SDK service used for upstream
|
||||
// calls. BYOK auth is set here. Centralized auth is set
|
||||
// per-attempt by the failover loop.
|
||||
func (i *interceptionBase) newMessagesService(ctx context.Context, opts ...option.RequestOption) (anthropic.MessageService, error) {
|
||||
// BYOK with access token uses Authorization: Bearer.
|
||||
// Otherwise use X-Api-Key (centralized or BYOK with personal API key).
|
||||
if i.cfg.BYOKBearerToken != "" {
|
||||
i.logger.Debug(ctx, "using byok access token auth",
|
||||
slog.F("bearer_hint", utils.MaskSecret(i.cfg.BYOKBearerToken)),
|
||||
)
|
||||
opts = append(opts, option.WithAuthToken(i.cfg.BYOKBearerToken))
|
||||
} else {
|
||||
i.logger.Debug(ctx, "using api key auth",
|
||||
slog.F("api_key_hint", utils.MaskSecret(i.cfg.Key)),
|
||||
)
|
||||
opts = append(opts, option.WithAPIKey(i.cfg.Key))
|
||||
// TODO(ssncferreira): validate auth is configured per
|
||||
// https://github.com/coder/aibridge/issues/266.
|
||||
|
||||
// BYOK auth.
|
||||
if i.cfg.KeyPool == nil {
|
||||
if i.cfg.BYOKBearerToken != "" {
|
||||
// BYOK Bearer: Authorization header.
|
||||
i.logger.Debug(ctx, "using byok access token auth",
|
||||
slog.F("bearer_hint", utils.MaskSecret(i.cfg.BYOKBearerToken)),
|
||||
)
|
||||
opts = append(opts, option.WithAuthToken(i.cfg.BYOKBearerToken))
|
||||
} else {
|
||||
// BYOK X-Api-Key.
|
||||
i.logger.Debug(ctx, "using api key auth",
|
||||
slog.F("api_key_hint", utils.MaskSecret(i.cfg.Key)),
|
||||
)
|
||||
opts = append(opts, option.WithAPIKey(i.cfg.Key))
|
||||
}
|
||||
}
|
||||
opts = append(opts, option.WithBaseURL(i.cfg.BaseURL))
|
||||
|
||||
@@ -427,6 +439,10 @@ func (i *interceptionBase) writeUpstreamError(w http.ResponseWriter, antErr *res
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
// Set Retry-After when a cooldown is configured.
|
||||
if antErr.RetryAfter > 0 {
|
||||
w.Header().Set("Retry-After", strconv.Itoa(int(math.Ceil(antErr.RetryAfter.Seconds()))))
|
||||
}
|
||||
w.WriteHeader(antErr.StatusCode)
|
||||
|
||||
out, err := json.Marshal(antErr)
|
||||
@@ -503,6 +519,49 @@ func accumulateUsage(dest, src any) {
|
||||
}
|
||||
}
|
||||
|
||||
// For centralized requests, markKeyOnError extracts an
|
||||
// Anthropic SDK error from err and marks the key based on
|
||||
// its status code. Returns true if the status was a key-specific
|
||||
// failover trigger so callers can retry with the next key.
|
||||
func (i *interceptionBase) markKeyOnError(ctx context.Context, key *keypool.Key, err error) bool {
|
||||
if i.cfg.KeyPool == nil {
|
||||
return false
|
||||
}
|
||||
var apiErr *anthropic.Error
|
||||
if !errors.As(err, &apiErr) {
|
||||
return false
|
||||
}
|
||||
return keypool.MarkKeyOnStatus(
|
||||
ctx, key, apiErr.Response,
|
||||
i.logger, i.providerName,
|
||||
)
|
||||
}
|
||||
|
||||
// processKeyPoolError translates a keypool exhaustion error
|
||||
// into a developer-facing responseError shaped for the Anthropic
|
||||
// API. Returns nil if err is not an exhaustion error.
|
||||
func processKeyPoolError(err error) *responseError {
|
||||
var transient *keypool.TransientKeyPoolError
|
||||
switch {
|
||||
case errors.As(err, &transient):
|
||||
return newErrorResponse(
|
||||
"all configured keys are rate-limited",
|
||||
string(constant.ValueOf[constant.RateLimitError]()),
|
||||
http.StatusTooManyRequests,
|
||||
transient.RetryAfter,
|
||||
)
|
||||
case errors.Is(err, keypool.ErrPermanentKeyPool):
|
||||
return newErrorResponse(
|
||||
"all configured keys failed authentication",
|
||||
string(constant.ValueOf[constant.APIError]()),
|
||||
http.StatusBadGateway,
|
||||
0,
|
||||
)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func getErrorResponse(err error) *responseError {
|
||||
var apierr *anthropic.Error
|
||||
if !errors.As(err, &apierr) {
|
||||
@@ -510,7 +569,7 @@ func getErrorResponse(err error) *responseError {
|
||||
}
|
||||
|
||||
msg := apierr.Error()
|
||||
typ := string(constant.ValueOf[constant.APIError]())
|
||||
errType := string(constant.ValueOf[constant.APIError]())
|
||||
|
||||
var detail *anthropic.APIErrorObject
|
||||
if field, ok := apierr.JSON.ExtraFields["error"]; ok {
|
||||
@@ -518,19 +577,10 @@ func getErrorResponse(err error) *responseError {
|
||||
}
|
||||
if detail != nil {
|
||||
msg = detail.Message
|
||||
typ = string(detail.Type)
|
||||
errType = string(detail.Type)
|
||||
}
|
||||
|
||||
return &responseError{
|
||||
ErrorResponse: &anthropic.ErrorResponse{
|
||||
Error: anthropic.ErrorObjectUnion{
|
||||
Message: msg,
|
||||
Type: typ,
|
||||
},
|
||||
Type: constant.ValueOf[constant.Error](),
|
||||
},
|
||||
StatusCode: apierr.StatusCode,
|
||||
}
|
||||
return newErrorResponse(msg, errType, apierr.StatusCode, keypool.ParseRetryAfter(apierr.Response))
|
||||
}
|
||||
|
||||
var _ error = &responseError{}
|
||||
@@ -538,17 +588,21 @@ var _ error = &responseError{}
|
||||
type responseError struct {
|
||||
*anthropic.ErrorResponse
|
||||
|
||||
StatusCode int `json:"-"`
|
||||
StatusCode int `json:"-"`
|
||||
RetryAfter time.Duration `json:"-"`
|
||||
}
|
||||
|
||||
func newErrorResponse(msg error) *responseError {
|
||||
func newErrorResponse(msg, errType string, status int, retryAfter time.Duration) *responseError {
|
||||
return &responseError{
|
||||
ErrorResponse: &shared.ErrorResponse{
|
||||
Error: shared.ErrorObjectUnion{
|
||||
Message: msg.Error(),
|
||||
Type: "error",
|
||||
Message: msg,
|
||||
Type: errType,
|
||||
},
|
||||
Type: constant.ValueOf[constant.Error](),
|
||||
},
|
||||
StatusCode: status,
|
||||
RetryAfter: retryAfter,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -3,18 +3,24 @@ package messages //nolint:testpackage // tests unexported internals
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/anthropics/anthropic-sdk-go"
|
||||
"github.com/anthropics/anthropic-sdk-go/shared/constant"
|
||||
mcpgo "github.com/mark3labs/mcp-go/mcp"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tidwall/gjson"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/aibridge/config"
|
||||
"github.com/coder/coder/v2/aibridge/keypool"
|
||||
"github.com/coder/coder/v2/aibridge/mcp"
|
||||
"github.com/coder/coder/v2/aibridge/utils"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
func TestScanForCorrelatingToolCallID(t *testing.T) {
|
||||
@@ -991,3 +997,187 @@ func TestFilterBedrockBetaFlags(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessKeyPoolError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
expectedNil bool
|
||||
expectedStatus int
|
||||
expectedRetryAfter time.Duration
|
||||
}{
|
||||
{
|
||||
// Transient with valid keys present: 429, no Retry-After.
|
||||
name: "transient_zero_retry_after",
|
||||
err: &keypool.TransientKeyPoolError{},
|
||||
expectedStatus: http.StatusTooManyRequests,
|
||||
expectedRetryAfter: 0,
|
||||
},
|
||||
{
|
||||
// Transient with cooldown: 429, Retry-After set.
|
||||
name: "transient_with_retry_after",
|
||||
err: &keypool.TransientKeyPoolError{RetryAfter: 5 * time.Second},
|
||||
expectedStatus: http.StatusTooManyRequests,
|
||||
expectedRetryAfter: 5 * time.Second,
|
||||
},
|
||||
{
|
||||
// Permanent: 502 api_error.
|
||||
name: "permanent_returns_502",
|
||||
err: keypool.ErrPermanentKeyPool,
|
||||
expectedStatus: http.StatusBadGateway,
|
||||
},
|
||||
{
|
||||
// Anything else: not a pool-exhaustion error.
|
||||
name: "non_pool_exhaustion_error_returns_nil",
|
||||
err: xerrors.New("some other error"),
|
||||
expectedNil: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := processKeyPoolError(tc.err)
|
||||
if tc.expectedNil {
|
||||
require.Nil(t, got)
|
||||
return
|
||||
}
|
||||
require.NotNil(t, got)
|
||||
assert.Equal(t, tc.expectedStatus, got.StatusCode)
|
||||
assert.Equal(t, tc.expectedRetryAfter, got.RetryAfter)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarkKeyOnError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
expectedReturn bool
|
||||
expectedState keypool.KeyState
|
||||
}{
|
||||
{
|
||||
// Not an *anthropic.Error: no status code to act on.
|
||||
name: "non_api_error_returns_false",
|
||||
err: xerrors.New("network failure"),
|
||||
expectedReturn: false,
|
||||
expectedState: keypool.KeyStateValid,
|
||||
},
|
||||
{
|
||||
// Rate-limited: temporary cooldown.
|
||||
name: "429_marks_temporary",
|
||||
err: &anthropic.Error{StatusCode: http.StatusTooManyRequests, Response: &http.Response{StatusCode: http.StatusTooManyRequests}},
|
||||
expectedReturn: true,
|
||||
expectedState: keypool.KeyStateTemporary,
|
||||
},
|
||||
{
|
||||
// Auth failure: mark permanent.
|
||||
name: "401_marks_permanent",
|
||||
err: &anthropic.Error{StatusCode: http.StatusUnauthorized, Response: &http.Response{StatusCode: http.StatusUnauthorized}},
|
||||
expectedReturn: true,
|
||||
expectedState: keypool.KeyStatePermanent,
|
||||
},
|
||||
{
|
||||
// Auth forbidden: mark permanent.
|
||||
name: "403_marks_permanent",
|
||||
err: &anthropic.Error{StatusCode: http.StatusForbidden, Response: &http.Response{StatusCode: http.StatusForbidden}},
|
||||
expectedReturn: true,
|
||||
expectedState: keypool.KeyStatePermanent,
|
||||
},
|
||||
{
|
||||
// Server errors are not key-specific.
|
||||
name: "500_does_not_mark",
|
||||
err: &anthropic.Error{StatusCode: http.StatusInternalServerError, Response: &http.Response{StatusCode: http.StatusInternalServerError}},
|
||||
expectedReturn: false,
|
||||
expectedState: keypool.KeyStateValid,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
pool, err := keypool.New([]string{"key-0"}, quartz.NewMock(t))
|
||||
require.NoError(t, err)
|
||||
key, err := pool.Walker().Next()
|
||||
require.NoError(t, err)
|
||||
|
||||
base := &interceptionBase{cfg: config.Anthropic{KeyPool: pool}, logger: slog.Make()}
|
||||
|
||||
got := base.markKeyOnError(context.Background(), key, tc.err)
|
||||
assert.Equal(t, tc.expectedReturn, got)
|
||||
assert.Equal(t, tc.expectedState, key.State())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteUpstreamError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
respErr *responseError
|
||||
expectStatus int
|
||||
// Empty string means the header should be absent.
|
||||
expectRetryAfter string
|
||||
// Substring expected in the marshaled body. Empty means no body check.
|
||||
expectBodyContains string
|
||||
}{
|
||||
{
|
||||
// Standard error: status and JSON body written.
|
||||
name: "writes_status_and_body",
|
||||
respErr: newErrorResponse("upstream failed", "api_error", http.StatusBadGateway, 0),
|
||||
expectStatus: http.StatusBadGateway,
|
||||
expectBodyContains: `"upstream failed"`,
|
||||
},
|
||||
{
|
||||
// Whole-second retryAfter: emitted as integer seconds.
|
||||
name: "retry_after_in_seconds",
|
||||
respErr: newErrorResponse("rate limited", "rate_limit_error", http.StatusTooManyRequests, 60*time.Second),
|
||||
expectStatus: http.StatusTooManyRequests,
|
||||
expectRetryAfter: "60",
|
||||
},
|
||||
{
|
||||
// 500ms rounds up to Retry-After: 1.
|
||||
name: "retry_after_500ms_rounds_up_to_one",
|
||||
respErr: newErrorResponse("rate limited", "rate_limit_error", http.StatusTooManyRequests, 500*time.Millisecond),
|
||||
expectStatus: http.StatusTooManyRequests,
|
||||
expectRetryAfter: "1",
|
||||
},
|
||||
{
|
||||
// 200ms rounds up to Retry-After: 1.
|
||||
name: "retry_after_200ms_rounds_up_to_one",
|
||||
respErr: newErrorResponse("rate limited", "rate_limit_error", http.StatusTooManyRequests, 200*time.Millisecond),
|
||||
expectStatus: http.StatusTooManyRequests,
|
||||
expectRetryAfter: "1",
|
||||
},
|
||||
{
|
||||
// Negative retryAfter: header omitted.
|
||||
name: "negative_retry_after_omits_header",
|
||||
respErr: newErrorResponse("rate limited", "rate_limit_error", http.StatusTooManyRequests, -1*time.Second),
|
||||
expectStatus: http.StatusTooManyRequests,
|
||||
expectRetryAfter: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
base := &interceptionBase{logger: slog.Make()}
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
base.writeUpstreamError(w, tc.respErr)
|
||||
|
||||
assert.Equal(t, tc.expectStatus, w.Code, "status code")
|
||||
assert.Equal(t, "application/json", w.Header().Get("Content-Type"), "Content-Type header")
|
||||
assert.Equal(t, tc.expectRetryAfter, w.Header().Get("Retry-After"), "Retry-After header")
|
||||
assert.Contains(t, w.Body.String(), `"type":"error"`, "outer error envelope")
|
||||
if tc.expectBodyContains != "" {
|
||||
assert.Contains(t, w.Body.String(), tc.expectBodyContains, "response body")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -112,6 +112,13 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req
|
||||
return xerrors.Errorf("upstream connection closed: %w", err)
|
||||
}
|
||||
|
||||
// The failover loop may return a keypool exhaustion
|
||||
// error. Check before the SDK-error path.
|
||||
if keyErr := processKeyPoolError(err); keyErr != nil {
|
||||
i.writeUpstreamError(w, keyErr)
|
||||
return xerrors.Errorf("key pool exhausted: %w", err)
|
||||
}
|
||||
|
||||
if antErr := getErrorResponse(err); antErr != nil {
|
||||
i.writeUpstreamError(w, antErr)
|
||||
return xerrors.Errorf("anthropic API error: %w", err)
|
||||
@@ -334,9 +341,53 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req
|
||||
return nil
|
||||
}
|
||||
|
||||
func (i *BlockingInterception) newMessage(ctx context.Context, svc anthropic.MessageService) (_ *anthropic.Message, outErr error) {
|
||||
ctx, span := i.tracer.Start(ctx, "Intercept.ProcessRequest.Upstream", trace.WithAttributes(tracing.InterceptionAttributesFromContext(ctx)...))
|
||||
// newMessage routes between BYOK (single attempt) and centralized
|
||||
// failover.
|
||||
func (i *BlockingInterception) newMessage(ctx context.Context, svc anthropic.MessageService) (*anthropic.Message, error) {
|
||||
// BYOK: single attempt, no failover.
|
||||
if i.cfg.KeyPool == nil {
|
||||
return i.newMessageWithKey(ctx, svc)
|
||||
}
|
||||
return i.newMessageWithKeyFailover(ctx, svc)
|
||||
}
|
||||
|
||||
// newMessageWithKey performs a single upstream call.
|
||||
func (i *BlockingInterception) newMessageWithKey(ctx context.Context, svc anthropic.MessageService, extraOpts ...option.RequestOption) (_ *anthropic.Message, outErr error) {
|
||||
_, span := i.tracer.Start(ctx, "Intercept.ProcessRequest.Upstream", trace.WithAttributes(tracing.InterceptionAttributesFromContext(ctx)...))
|
||||
defer tracing.EndSpanErr(span, &outErr)
|
||||
|
||||
return svc.New(ctx, anthropic.MessageNewParams{}, i.withBody())
|
||||
opts := append([]option.RequestOption{i.withBody()}, extraOpts...)
|
||||
return svc.New(ctx, anthropic.MessageNewParams{}, opts...)
|
||||
}
|
||||
|
||||
// newMessageWithKeyFailover walks the centralized key pool,
|
||||
// trying each key until one succeeds or the pool is exhausted.
|
||||
// Keys are marked temporary on 429 and permanent on 401/403.
|
||||
// Errors that aren't key-specific don't trigger failover and
|
||||
// are returned to the caller.
|
||||
func (i *BlockingInterception) newMessageWithKeyFailover(ctx context.Context, svc anthropic.MessageService) (*anthropic.Message, error) {
|
||||
// TODO(ssncferreira): update the interception's credential
|
||||
// hint with the actually-used key (the successful key on
|
||||
// success, the last tried key on failure) in the upstack PR.
|
||||
walker := i.cfg.KeyPool.Walker()
|
||||
for {
|
||||
key, err := walker.Next()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
msg, err := i.newMessageWithKey(ctx, svc,
|
||||
option.WithAPIKey(key.Value()),
|
||||
// Disable SDK retries because the failover loop
|
||||
// handles retries via key rotation.
|
||||
option.WithMaxRetries(0),
|
||||
)
|
||||
// Key-specific failure: try the next key.
|
||||
if i.markKeyOnError(ctx, key, err) {
|
||||
continue
|
||||
}
|
||||
// Either success (msg, nil) or a non-key error (nil, err):
|
||||
// nothing to retry, return as-is.
|
||||
return msg, err
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,457 @@
|
||||
package messages //nolint:testpackage // tests unexported internals
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.opentelemetry.io/otel"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/aibridge/config"
|
||||
"github.com/coder/coder/v2/aibridge/intercept"
|
||||
"github.com/coder/coder/v2/aibridge/internal/testutil"
|
||||
"github.com/coder/coder/v2/aibridge/keypool"
|
||||
"github.com/coder/coder/v2/aibridge/mcp"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
// Shared fixtures: the client request payload and the
// Anthropic-shaped upstream response bodies used by the failover
// tests.
const (
	// requestBody is a minimal Messages API request.
	requestBody = `{"model":"claude-opus-4-5","max_tokens":1024,"messages":[{"role":"user","content":"hi"}]}`
	// successBody is a final assistant text reply (stop_reason end_turn).
	successBody = `{"id":"msg_01","type":"message","role":"assistant","content":[{"type":"text","text":"Hello!"}],"model":"claude-opus-4-5","stop_reason":"end_turn","usage":{"input_tokens":10,"output_tokens":5}}`
	// toolUseBody is an assistant reply requesting a tool call (stop_reason tool_use).
	toolUseBody = `{"id":"msg_01","type":"message","role":"assistant","content":[{"type":"tool_use","id":"toolu_01","name":"test_tool","input":{}}],"model":"claude-opus-4-5","stop_reason":"tool_use","usage":{"input_tokens":10,"output_tokens":5}}`
	// rateLimitBody is the error envelope for a rate_limit_error.
	rateLimitBody = `{"type":"error","error":{"type":"rate_limit_error","message":"rate limited"}}`
	// authErrorBody is the error envelope for an authentication_error.
	authErrorBody = `{"type":"error","error":{"type":"authentication_error","message":"invalid key"}}`
	// serverErrorBody is the error envelope for an api_error.
	serverErrorBody = `{"type":"error","error":{"type":"api_error","message":"server error"}}`
)
|
||||
|
||||
// upstreamResponse is one scripted reply served by the mock upstream
// in the failover tests.
type upstreamResponse struct {
	// statusCode is the HTTP status written for the reply.
	statusCode int
	// body is the raw response payload.
	body string
	// headers are extra response headers set before the status is
	// written; nil means none.
	headers map[string]string
}
|
||||
|
||||
func TestBlockingInterception_KeyFailover(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
// Centralized pool keys. Empty when byokKey is set.
|
||||
keys []string
|
||||
// BYOK key. Empty when keys is set.
|
||||
byokKey string
|
||||
// Scripted upstream responses keyed by X-Api-Key.
|
||||
responses map[string]upstreamResponse
|
||||
expectedRequestCount int32
|
||||
expectedStatusCode int
|
||||
expectedRetryAfter string
|
||||
// Expected key states after the request, by index in keys.
|
||||
expectedKeyStates []keypool.KeyState
|
||||
}{
|
||||
{
|
||||
// Given: 1 valid key returning 200.
|
||||
// Then: 1 request, 200 response, key remains valid.
|
||||
name: "single_valid_key",
|
||||
keys: []string{"k0"},
|
||||
responses: map[string]upstreamResponse{
|
||||
"k0": {statusCode: http.StatusOK, body: successBody},
|
||||
},
|
||||
expectedRequestCount: 1,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
expectedKeyStates: []keypool.KeyState{keypool.KeyStateValid},
|
||||
},
|
||||
{
|
||||
// Given: 2 keys; key-0 returns 429, key-1 returns 200.
|
||||
// Then: 2 requests, 200 response, key-0 temporary, key-1 valid.
|
||||
name: "failover_after_429",
|
||||
keys: []string{"k0", "k1"},
|
||||
responses: map[string]upstreamResponse{
|
||||
"k0": {
|
||||
statusCode: http.StatusTooManyRequests,
|
||||
headers: map[string]string{"Retry-After": "5"},
|
||||
body: rateLimitBody,
|
||||
},
|
||||
"k1": {statusCode: http.StatusOK, body: successBody},
|
||||
},
|
||||
expectedRequestCount: 2,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
expectedKeyStates: []keypool.KeyState{
|
||||
keypool.KeyStateTemporary,
|
||||
keypool.KeyStateValid,
|
||||
},
|
||||
},
|
||||
{
|
||||
// Given: 2 keys; key-0 returns 401, key-1 returns 200.
|
||||
// Then: 2 requests, 200 response, key-0 permanent, key-1 valid.
|
||||
name: "failover_after_401",
|
||||
keys: []string{"k0", "k1"},
|
||||
responses: map[string]upstreamResponse{
|
||||
"k0": {statusCode: http.StatusUnauthorized, body: authErrorBody},
|
||||
"k1": {statusCode: http.StatusOK, body: successBody},
|
||||
},
|
||||
expectedRequestCount: 2,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
expectedKeyStates: []keypool.KeyState{
|
||||
keypool.KeyStatePermanent,
|
||||
keypool.KeyStateValid,
|
||||
},
|
||||
},
|
||||
{
|
||||
// Given: 2 keys; key-0 returns 403, key-1 returns 200.
|
||||
// Then: 2 requests, 200 response, key-0 permanent, key-1 valid.
|
||||
name: "failover_after_403",
|
||||
keys: []string{"k0", "k1"},
|
||||
responses: map[string]upstreamResponse{
|
||||
"k0": {statusCode: http.StatusForbidden, body: authErrorBody},
|
||||
"k1": {statusCode: http.StatusOK, body: successBody},
|
||||
},
|
||||
expectedRequestCount: 2,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
expectedKeyStates: []keypool.KeyState{
|
||||
keypool.KeyStatePermanent,
|
||||
keypool.KeyStateValid,
|
||||
},
|
||||
},
|
||||
{
|
||||
// Given: 3 keys; all return 429 with cooldowns 5s, 3s, 10s.
|
||||
// Then: 3 requests, 429 response with smallest Retry-After,
|
||||
// all keys temporary.
|
||||
name: "all_keys_rate_limited",
|
||||
keys: []string{"k0", "k1", "k2"},
|
||||
responses: map[string]upstreamResponse{
|
||||
"k0": {
|
||||
statusCode: http.StatusTooManyRequests,
|
||||
headers: map[string]string{"Retry-After": "5"},
|
||||
body: rateLimitBody,
|
||||
},
|
||||
"k1": {
|
||||
statusCode: http.StatusTooManyRequests,
|
||||
headers: map[string]string{"Retry-After": "3"},
|
||||
body: rateLimitBody,
|
||||
},
|
||||
"k2": {
|
||||
statusCode: http.StatusTooManyRequests,
|
||||
headers: map[string]string{"Retry-After": "10"},
|
||||
body: rateLimitBody,
|
||||
},
|
||||
},
|
||||
expectedRequestCount: 3,
|
||||
expectedStatusCode: http.StatusTooManyRequests,
|
||||
expectedRetryAfter: "3",
|
||||
expectedKeyStates: []keypool.KeyState{
|
||||
keypool.KeyStateTemporary,
|
||||
keypool.KeyStateTemporary,
|
||||
keypool.KeyStateTemporary,
|
||||
},
|
||||
},
|
||||
{
|
||||
// Given: 2 keys; both return 401.
|
||||
// Then: 2 requests, 502 api_error response, both keys permanent.
|
||||
name: "all_keys_unauthorized",
|
||||
keys: []string{"k0", "k1"},
|
||||
responses: map[string]upstreamResponse{
|
||||
"k0": {statusCode: http.StatusUnauthorized, body: authErrorBody},
|
||||
"k1": {statusCode: http.StatusUnauthorized, body: authErrorBody},
|
||||
},
|
||||
expectedRequestCount: 2,
|
||||
expectedStatusCode: http.StatusBadGateway,
|
||||
expectedKeyStates: []keypool.KeyState{
|
||||
keypool.KeyStatePermanent,
|
||||
keypool.KeyStatePermanent,
|
||||
},
|
||||
},
|
||||
{
|
||||
// Given: 2 keys; key-0 returns 500.
|
||||
// Then: 1 request, 500 response, both keys remain valid.
|
||||
name: "server_error_no_failover",
|
||||
keys: []string{"k0", "k1"},
|
||||
responses: map[string]upstreamResponse{
|
||||
"k0": {statusCode: http.StatusInternalServerError, body: serverErrorBody},
|
||||
},
|
||||
expectedRequestCount: 1,
|
||||
expectedStatusCode: http.StatusInternalServerError,
|
||||
expectedKeyStates: []keypool.KeyState{
|
||||
keypool.KeyStateValid,
|
||||
keypool.KeyStateValid,
|
||||
},
|
||||
},
|
||||
{
|
||||
// Given: BYOK with a single key returning 429.
|
||||
// Then: 1 request, 429 response, no failover, upstream
|
||||
// Retry-After propagated to the client.
|
||||
name: "byok_no_failover",
|
||||
byokKey: "user-byok",
|
||||
responses: map[string]upstreamResponse{
|
||||
"user-byok": {
|
||||
statusCode: http.StatusTooManyRequests,
|
||||
headers: map[string]string{
|
||||
"Retry-After": "5",
|
||||
// BYOK doesn't set MaxRetries(0);
|
||||
// suppress SDK retries to test a
|
||||
// single attempt.
|
||||
"x-should-retry": "false",
|
||||
},
|
||||
body: rateLimitBody,
|
||||
},
|
||||
},
|
||||
expectedRequestCount: 1,
|
||||
expectedStatusCode: http.StatusTooManyRequests,
|
||||
expectedRetryAfter: "5",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Mock upstream: counts requests and returns
|
||||
// scripted responses keyed by X-Api-Key. An unmapped
|
||||
// key falls through to 500 so misconfigured cases
|
||||
// surface via the status assertion.
|
||||
var requestCount atomic.Int32
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
requestCount.Add(1)
|
||||
_, _ = io.Copy(io.Discard, r.Body)
|
||||
resp, ok := tc.responses[r.Header.Get("X-Api-Key")]
|
||||
if !ok {
|
||||
resp = upstreamResponse{statusCode: http.StatusInternalServerError}
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
for hk, hv := range resp.headers {
|
||||
w.Header().Set(hk, hv)
|
||||
}
|
||||
w.WriteHeader(resp.statusCode)
|
||||
_, _ = w.Write([]byte(resp.body))
|
||||
}))
|
||||
t.Cleanup(upstream.Close)
|
||||
|
||||
cfg := config.Anthropic{BaseURL: upstream.URL + "/"}
|
||||
var pool *keypool.Pool
|
||||
if len(tc.keys) > 0 {
|
||||
var err error
|
||||
pool, err = keypool.New(tc.keys, quartz.NewMock(t))
|
||||
require.NoError(t, err)
|
||||
cfg.KeyPool = pool
|
||||
} else if tc.byokKey != "" {
|
||||
cfg.Key = tc.byokKey
|
||||
}
|
||||
|
||||
payload, err := NewRequestPayload([]byte(requestBody))
|
||||
require.NoError(t, err)
|
||||
|
||||
interceptor := NewBlockingInterceptor(
|
||||
uuid.New(),
|
||||
payload,
|
||||
config.ProviderAnthropic,
|
||||
cfg,
|
||||
nil,
|
||||
http.Header{},
|
||||
"X-Api-Key",
|
||||
otel.Tracer("blocking_test"),
|
||||
intercept.NewCredentialInfo(intercept.CredentialKindCentralized, ""),
|
||||
)
|
||||
interceptor.Setup(slog.Make(), &testutil.MockRecorder{}, nil)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
w := httptest.NewRecorder()
|
||||
err = interceptor.ProcessRequest(w, req)
|
||||
if tc.expectedStatusCode == http.StatusOK {
|
||||
require.NoError(t, err)
|
||||
} else {
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
assert.Equal(t, tc.expectedRequestCount, requestCount.Load(), "upstream request count")
|
||||
assert.Equal(t, tc.expectedStatusCode, w.Code, "response status code")
|
||||
assert.Equal(t, tc.expectedRetryAfter, w.Header().Get("Retry-After"), "Retry-After header")
|
||||
if pool != nil {
|
||||
assert.Equal(t, tc.expectedKeyStates, pool.PoolState(), "key states")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestBlockingInterception_AgenticLoopFailover covers the
|
||||
// scenarios that span an agentic-loop continuation: the initial
|
||||
// client request and the subsequent tool-call continuation can
|
||||
// each fail over independently. Each iteration gets its own
|
||||
// walker.
|
||||
func TestBlockingInterception_AgenticLoopFailover(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
// Scripted upstream responses consumed in order of
|
||||
// upstream request.
|
||||
responses []upstreamResponse
|
||||
expectedRequestCount int32
|
||||
expectedSeenKeys []string
|
||||
expectedStatusCode int
|
||||
expectedRetryAfter string
|
||||
expectedKeyStates []keypool.KeyState
|
||||
}{
|
||||
{
|
||||
// Given: 2 keys; both upstream calls succeed on key-0.
|
||||
// Then: 2 requests, 200 response, both keys remain valid.
|
||||
name: "happy_path",
|
||||
responses: []upstreamResponse{
|
||||
{statusCode: http.StatusOK, body: toolUseBody},
|
||||
{statusCode: http.StatusOK, body: successBody},
|
||||
},
|
||||
expectedRequestCount: 2,
|
||||
expectedSeenKeys: []string{"k0", "k0"},
|
||||
expectedStatusCode: http.StatusOK,
|
||||
expectedKeyStates: []keypool.KeyState{
|
||||
keypool.KeyStateValid,
|
||||
keypool.KeyStateValid,
|
||||
},
|
||||
},
|
||||
{
|
||||
// Given: 2 keys; key-0 succeeds initially, then 429s
|
||||
// during the agentic continuation, key-1 succeeds.
|
||||
// Then: 3 requests, 200 response, key-0 temporary,
|
||||
// key-1 valid.
|
||||
name: "agentic_failover_to_k1",
|
||||
responses: []upstreamResponse{
|
||||
{statusCode: http.StatusOK, body: toolUseBody},
|
||||
{
|
||||
statusCode: http.StatusTooManyRequests,
|
||||
headers: map[string]string{"Retry-After": "5"},
|
||||
body: rateLimitBody,
|
||||
},
|
||||
{statusCode: http.StatusOK, body: successBody},
|
||||
},
|
||||
expectedRequestCount: 3,
|
||||
expectedSeenKeys: []string{"k0", "k0", "k1"},
|
||||
expectedStatusCode: http.StatusOK,
|
||||
expectedKeyStates: []keypool.KeyState{
|
||||
keypool.KeyStateTemporary,
|
||||
keypool.KeyStateValid,
|
||||
},
|
||||
},
|
||||
{
|
||||
// Given: 2 keys; key-0 succeeds initially, then both
|
||||
// keys 429 during the agentic continuation.
|
||||
// Then: 3 requests, 429 response with smallest
|
||||
// Retry-After, both keys temporary.
|
||||
name: "agentic_all_keys_fail",
|
||||
responses: []upstreamResponse{
|
||||
{statusCode: http.StatusOK, body: toolUseBody},
|
||||
{
|
||||
statusCode: http.StatusTooManyRequests,
|
||||
headers: map[string]string{"Retry-After": "5"},
|
||||
body: rateLimitBody,
|
||||
},
|
||||
{
|
||||
statusCode: http.StatusTooManyRequests,
|
||||
headers: map[string]string{"Retry-After": "3"},
|
||||
body: rateLimitBody,
|
||||
},
|
||||
},
|
||||
expectedRequestCount: 3,
|
||||
expectedSeenKeys: []string{"k0", "k0", "k1"},
|
||||
expectedStatusCode: http.StatusTooManyRequests,
|
||||
expectedRetryAfter: "3",
|
||||
expectedKeyStates: []keypool.KeyState{
|
||||
keypool.KeyStateTemporary,
|
||||
keypool.KeyStateTemporary,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var requestCount atomic.Int32
|
||||
var seenKeysMu sync.Mutex
|
||||
var seenKeys []string
|
||||
|
||||
// Mock upstream: returns scripted responses in order,
|
||||
// records each request's X-Api-Key for assertions.
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
idx := int(requestCount.Add(1)) - 1
|
||||
seenKeysMu.Lock()
|
||||
seenKeys = append(seenKeys, r.Header.Get("X-Api-Key"))
|
||||
seenKeysMu.Unlock()
|
||||
_, _ = io.Copy(io.Discard, r.Body)
|
||||
|
||||
if idx >= len(tc.responses) {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
resp := tc.responses[idx]
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
for hk, hv := range resp.headers {
|
||||
w.Header().Set(hk, hv)
|
||||
}
|
||||
w.WriteHeader(resp.statusCode)
|
||||
_, _ = w.Write([]byte(resp.body))
|
||||
}))
|
||||
t.Cleanup(upstream.Close)
|
||||
|
||||
pool, err := keypool.New([]string{"k0", "k1"}, quartz.NewMock(t))
|
||||
require.NoError(t, err)
|
||||
|
||||
cfg := config.Anthropic{
|
||||
BaseURL: upstream.URL + "/",
|
||||
KeyPool: pool,
|
||||
}
|
||||
|
||||
payload, err := NewRequestPayload([]byte(requestBody))
|
||||
require.NoError(t, err)
|
||||
|
||||
interceptor := NewBlockingInterceptor(
|
||||
uuid.New(),
|
||||
payload,
|
||||
config.ProviderAnthropic,
|
||||
cfg,
|
||||
nil,
|
||||
http.Header{},
|
||||
"X-Api-Key",
|
||||
otel.Tracer("blocking_test"),
|
||||
intercept.NewCredentialInfo(intercept.CredentialKindCentralized, ""),
|
||||
)
|
||||
|
||||
// Mock proxy with a tool the upstream's tool_use
|
||||
// response will reference.
|
||||
proxy := &mockServerProxier{
|
||||
tools: []*mcp.Tool{
|
||||
{
|
||||
Client: stubToolCaller{},
|
||||
ID: "test_tool",
|
||||
Name: "test_tool",
|
||||
ServerName: "coder",
|
||||
Logger: slog.Make(),
|
||||
},
|
||||
},
|
||||
}
|
||||
interceptor.Setup(slog.Make(), &testutil.MockRecorder{}, proxy)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
w := httptest.NewRecorder()
|
||||
err = interceptor.ProcessRequest(w, req)
|
||||
if tc.expectedStatusCode == http.StatusOK {
|
||||
require.NoError(t, err)
|
||||
} else {
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
assert.Equal(t, tc.expectedRequestCount, requestCount.Load(), "upstream request count")
|
||||
assert.Equal(t, tc.expectedStatusCode, w.Code, "response status code")
|
||||
assert.Equal(t, tc.expectedRetryAfter, w.Header().Get("Retry-After"), "Retry-After header")
|
||||
|
||||
seenKeysMu.Lock()
|
||||
defer seenKeysMu.Unlock()
|
||||
assert.Equal(t, tc.expectedSeenKeys, seenKeys, "seen keys")
|
||||
assert.Equal(t, tc.expectedKeyStates, pool.PoolState(), "key states")
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -25,6 +25,7 @@ import (
|
||||
aibcontext "github.com/coder/coder/v2/aibridge/context"
|
||||
"github.com/coder/coder/v2/aibridge/intercept"
|
||||
"github.com/coder/coder/v2/aibridge/intercept/eventstream"
|
||||
"github.com/coder/coder/v2/aibridge/keypool"
|
||||
"github.com/coder/coder/v2/aibridge/mcp"
|
||||
"github.com/coder/coder/v2/aibridge/recorder"
|
||||
"github.com/coder/coder/v2/aibridge/tracing"
|
||||
@@ -161,14 +162,64 @@ newStream:
|
||||
break
|
||||
}
|
||||
|
||||
stream := i.newStream(streamCtx, svc)
|
||||
// Per-iteration walker. An iteration is either an agentic
|
||||
// continuation (sending a tool result back in a new
|
||||
// stream) or a failover retry (previous key marked, try
|
||||
// the next one).
|
||||
var walker *keypool.Walker
|
||||
if i.cfg.KeyPool != nil {
|
||||
walker = i.cfg.KeyPool.Walker()
|
||||
}
|
||||
|
||||
var streamOpts []option.RequestOption
|
||||
var currentKey *keypool.Key
|
||||
if walker != nil {
|
||||
key, err := walker.Next()
|
||||
if respErr := processKeyPoolError(err); respErr != nil {
|
||||
// Pool exhausted in this iteration. Relay the
|
||||
// error to the client: as an SSE event if events
|
||||
// have already been sent, or by direct write
|
||||
// otherwise.
|
||||
interceptionErr = respErr
|
||||
if events.IsStreaming() {
|
||||
payload, mErr := i.marshal(respErr)
|
||||
if mErr != nil {
|
||||
logger.Warn(ctx, "failed to marshal exhaustion error", slog.Error(mErr))
|
||||
} else if sErr := events.Send(streamCtx, payload); sErr != nil {
|
||||
logger.Warn(ctx, "failed to relay exhaustion error", slog.Error(sErr))
|
||||
}
|
||||
} else {
|
||||
i.writeUpstreamError(w, respErr)
|
||||
}
|
||||
break
|
||||
}
|
||||
currentKey = key
|
||||
streamOpts = append(streamOpts,
|
||||
option.WithAPIKey(key.Value()),
|
||||
// Disable SDK retries because the failover
|
||||
// loop handles retries via key rotation.
|
||||
option.WithMaxRetries(0),
|
||||
)
|
||||
}
|
||||
|
||||
stream := i.newStream(streamCtx, svc, streamOpts...)
|
||||
|
||||
var message anthropic.Message
|
||||
var lastToolName string
|
||||
|
||||
pendingToolCalls := make(map[string]string)
|
||||
|
||||
// iterationStarted is per-iteration (reset on every
|
||||
// newStream loop): true once the upstream call has
|
||||
// produced any events for this iteration. While false,
|
||||
// a key-specific failure can still fail over to the
|
||||
// next key. Distinct from events.IsStreaming(), which
|
||||
// is stream-wide and stays true once iteration 1 has
|
||||
// sent any event downstream.
|
||||
var iterationStarted bool
|
||||
|
||||
for stream.Next() {
|
||||
iterationStarted = true
|
||||
event := stream.Current()
|
||||
if err := message.Accumulate(event); err != nil {
|
||||
logger.Warn(ctx, "failed to accumulate streaming events", slog.Error(err), slog.F("event", event), slog.F("msg", message.RawJSON()))
|
||||
@@ -478,40 +529,53 @@ newStream:
|
||||
promptFound = false //nolint:ineffassign // reset to prevent double-recording across newStream iterations
|
||||
}
|
||||
|
||||
if events.IsStreaming() {
|
||||
// Check if the stream encountered any errors.
|
||||
if streamErr := stream.Err(); streamErr != nil {
|
||||
if eventstream.IsUnrecoverableError(streamErr) {
|
||||
logger.Debug(ctx, "stream terminated", slog.Error(streamErr))
|
||||
// We can't reflect an error back if there's a connection error or the request context was canceled.
|
||||
} else if antErr := getErrorResponse(streamErr); antErr != nil {
|
||||
logger.Warn(ctx, "anthropic stream error", slog.Error(streamErr))
|
||||
interceptionErr = antErr
|
||||
} else {
|
||||
logger.Warn(ctx, "unknown stream error", slog.Error(streamErr))
|
||||
// Unfortunately, the Anthropic SDK does not support parsing errors received in the stream
|
||||
// into known types (i.e. [shared.OverloadedError]).
|
||||
// See https://github.com/anthropics/anthropic-sdk-go/blob/v1.12.0/packages/ssestream/ssestream.go#L172-L174
|
||||
// All it does is wrap the payload in an error - which is all we can return, currently.
|
||||
interceptionErr = newErrorResponse(xerrors.Errorf("unknown stream error: %w", streamErr))
|
||||
}
|
||||
} else if lastErr != nil {
|
||||
// Otherwise check if any logical errors occurred during processing.
|
||||
logger.Warn(ctx, "stream processing failed", slog.Error(lastErr))
|
||||
interceptionErr = newErrorResponse(xerrors.Errorf("processing error: %w", lastErr))
|
||||
}
|
||||
|
||||
if interceptionErr != nil {
|
||||
payload, err := i.marshal(interceptionErr)
|
||||
if iterationStarted {
|
||||
// Mid-stream error or logical error: events have
|
||||
// already streamed for this iteration, so the
|
||||
// error is relayed as an SSE event.
|
||||
streamErr := stream.Err()
|
||||
if respErr := i.mapStreamError(ctx, logger, streamErr, lastErr); respErr != nil {
|
||||
interceptionErr = respErr
|
||||
payload, err := i.marshal(respErr)
|
||||
if err != nil {
|
||||
logger.Warn(ctx, "failed to marshal error", slog.Error(err), slog.F("error_payload", fmt.Sprintf("%+v", interceptionErr)))
|
||||
logger.Warn(ctx, "failed to marshal error", slog.Error(err), slog.F("error_payload", fmt.Sprintf("%+v", respErr)))
|
||||
} else if err := events.Send(streamCtx, payload); err != nil {
|
||||
logger.Warn(ctx, "failed to relay error", slog.Error(err), slog.F("payload", payload))
|
||||
}
|
||||
} else if streamErr != nil {
|
||||
// Unrecoverable (e.g., broken pipe, context
|
||||
// canceled): can't relay to the client, but record
|
||||
// the error so it isn't silently swallowed.
|
||||
interceptionErr = streamErr
|
||||
}
|
||||
} else {
|
||||
// Stream has not started yet; write to response if present.
|
||||
i.writeUpstreamError(w, getErrorResponse(stream.Err()))
|
||||
// Pre-stream failure of this iteration. For
|
||||
// centralized requests, mark the key and retry with
|
||||
// the next one.
|
||||
if currentKey != nil && i.markKeyOnError(ctx, currentKey, stream.Err()) {
|
||||
continue newStream
|
||||
}
|
||||
// Non-key error: relay it. Use mapStreamError so that
|
||||
// unknown upstream errors (TCP reset, DNS failure, TLS
|
||||
// error, deadline exceeded) are wrapped in a generic
|
||||
// response instead of producing a silent HTTP 200.
|
||||
respErr := i.mapStreamError(ctx, logger, stream.Err(), lastErr)
|
||||
if respErr != nil {
|
||||
interceptionErr = respErr
|
||||
if events.IsStreaming() {
|
||||
// Prior iterations have streamed, so the SSE
|
||||
// connection is open: inject as an SSE event.
|
||||
payload, mErr := i.marshal(respErr)
|
||||
if mErr != nil {
|
||||
logger.Warn(ctx, "failed to marshal error", slog.Error(mErr))
|
||||
} else if sErr := events.Send(streamCtx, payload); sErr != nil {
|
||||
logger.Warn(ctx, "failed to relay error", slog.Error(sErr))
|
||||
}
|
||||
} else {
|
||||
// No events streamed yet, write the response directly.
|
||||
i.writeUpstreamError(w, respErr)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
shutdownCtx, shutdownCancel := context.WithTimeout(ctx, time.Second*30)
|
||||
@@ -534,6 +598,35 @@ newStream:
|
||||
return interceptionErr
|
||||
}
|
||||
|
||||
// mapStreamError converts a mid-stream upstream error or
|
||||
// processing error into a relayable responseError. Returns nil
|
||||
// when the error is unrecoverable, in which case nothing can be
|
||||
// relayed back.
|
||||
func (*StreamingInterception) mapStreamError(ctx context.Context, logger slog.Logger, streamErr, lastErr error) *responseError {
|
||||
if streamErr != nil {
|
||||
if eventstream.IsUnrecoverableError(streamErr) {
|
||||
logger.Debug(ctx, "stream terminated", slog.Error(streamErr))
|
||||
// We can't reflect an error back if there's a connection error or the request context was canceled.
|
||||
return nil
|
||||
}
|
||||
if antErr := getErrorResponse(streamErr); antErr != nil {
|
||||
logger.Warn(ctx, "anthropic stream error", slog.Error(streamErr))
|
||||
return antErr
|
||||
}
|
||||
logger.Warn(ctx, "unknown stream error", slog.Error(streamErr))
|
||||
// Unfortunately, the Anthropic SDK does not support parsing errors received in the stream
|
||||
// into known types (i.e. [shared.OverloadedError]).
|
||||
// See https://github.com/anthropics/anthropic-sdk-go/blob/v1.12.0/packages/ssestream/ssestream.go#L172-L174
|
||||
// All it does is wrap the payload in an error - which is all we can return, currently.
|
||||
return newErrorResponse(fmt.Sprintf("unknown stream error: %s", streamErr), string(constant.ValueOf[constant.Error]()), http.StatusBadGateway, 0)
|
||||
}
|
||||
if lastErr != nil {
|
||||
logger.Warn(ctx, "stream processing failed", slog.Error(lastErr))
|
||||
return newErrorResponse(fmt.Sprintf("processing error: %s", lastErr), string(constant.ValueOf[constant.Error]()), http.StatusBadGateway, 0)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (i *StreamingInterception) marshalEvent(event anthropic.MessageStreamEventUnion) ([]byte, error) {
|
||||
sj, err := sjson.Set(event.RawJSON(), "message.id", i.ID().String())
|
||||
if err != nil {
|
||||
@@ -585,9 +678,10 @@ func (*StreamingInterception) encodeForStream(payload []byte, typ string) []byte
|
||||
}
|
||||
|
||||
// newStream traces svc.NewStreaming() call.
|
||||
func (i *StreamingInterception) newStream(ctx context.Context, svc anthropic.MessageService) *ssestream.Stream[anthropic.MessageStreamEventUnion] {
|
||||
func (i *StreamingInterception) newStream(ctx context.Context, svc anthropic.MessageService, extraOpts ...option.RequestOption) *ssestream.Stream[anthropic.MessageStreamEventUnion] {
|
||||
_, span := i.tracer.Start(ctx, "Intercept.ProcessRequest.Upstream", trace.WithAttributes(tracing.InterceptionAttributesFromContext(ctx)...))
|
||||
defer span.End()
|
||||
|
||||
return svc.NewStreaming(ctx, anthropic.MessageNewParams{}, i.withBody())
|
||||
opts := append([]option.RequestOption{i.withBody()}, extraOpts...)
|
||||
return svc.NewStreaming(ctx, anthropic.MessageNewParams{}, opts...)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,554 @@
|
||||
package messages //nolint:testpackage // tests unexported internals
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
mcplib "github.com/mark3labs/mcp-go/mcp"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.opentelemetry.io/otel"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/aibridge/config"
|
||||
"github.com/coder/coder/v2/aibridge/intercept"
|
||||
"github.com/coder/coder/v2/aibridge/internal/testutil"
|
||||
"github.com/coder/coder/v2/aibridge/keypool"
|
||||
"github.com/coder/coder/v2/aibridge/mcp"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
// Anthropic-shaped SSE body for a successful streaming response.
|
||||
const streamingSuccessBody = `event: message_start
|
||||
data: {"type":"message_start","message":{"id":"msg_01","type":"message","role":"assistant","model":"claude-opus-4-5","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":10,"output_tokens":1}}}
|
||||
|
||||
event: content_block_start
|
||||
data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}
|
||||
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"hi"}}
|
||||
|
||||
event: content_block_stop
|
||||
data: {"type":"content_block_stop","index":0}
|
||||
|
||||
event: message_delta
|
||||
data: {"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"output_tokens":5}}
|
||||
|
||||
event: message_stop
|
||||
data: {"type":"message_stop"}
|
||||
`
|
||||
|
||||
func TestStreamingInterception_KeyFailover(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
// Centralized pool keys. Empty when byokKey is set.
|
||||
keys []string
|
||||
// BYOK key. Empty when keys is set.
|
||||
byokKey string
|
||||
// Scripted upstream responses keyed by X-Api-Key.
|
||||
responses map[string]upstreamResponse
|
||||
expectedRequestCount int32
|
||||
expectedStatusCode int
|
||||
expectedRetryAfter string
|
||||
// Expected key states after the request, by index in keys.
|
||||
expectedKeyStates []keypool.KeyState
|
||||
}{
|
||||
{
|
||||
// Given: 1 valid key returning a successful stream.
|
||||
// Then: 1 request, 200 response, key remains valid.
|
||||
name: "single_valid_key",
|
||||
keys: []string{"k0"},
|
||||
responses: map[string]upstreamResponse{
|
||||
"k0": {
|
||||
statusCode: http.StatusOK,
|
||||
headers: map[string]string{"Content-Type": "text/event-stream"},
|
||||
body: streamingSuccessBody,
|
||||
},
|
||||
},
|
||||
expectedRequestCount: 1,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
expectedKeyStates: []keypool.KeyState{keypool.KeyStateValid},
|
||||
},
|
||||
{
|
||||
// Given: 2 keys; key-0 returns 429 pre-stream, key-1
|
||||
// streams successfully.
|
||||
// Then: 2 requests, 200 response, key-0 temporary, key-1 valid.
|
||||
name: "failover_after_429",
|
||||
keys: []string{"k0", "k1"},
|
||||
responses: map[string]upstreamResponse{
|
||||
"k0": {
|
||||
statusCode: http.StatusTooManyRequests,
|
||||
headers: map[string]string{"Retry-After": "5"},
|
||||
body: rateLimitBody,
|
||||
},
|
||||
"k1": {
|
||||
statusCode: http.StatusOK,
|
||||
headers: map[string]string{"Content-Type": "text/event-stream"},
|
||||
body: streamingSuccessBody,
|
||||
},
|
||||
},
|
||||
expectedRequestCount: 2,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
expectedKeyStates: []keypool.KeyState{
|
||||
keypool.KeyStateTemporary,
|
||||
keypool.KeyStateValid,
|
||||
},
|
||||
},
|
||||
{
|
||||
// Given: 2 keys; key-0 returns 401 pre-stream, key-1
|
||||
// streams successfully.
|
||||
// Then: 2 requests, 200 response, key-0 permanent, key-1 valid.
|
||||
name: "failover_after_401",
|
||||
keys: []string{"k0", "k1"},
|
||||
responses: map[string]upstreamResponse{
|
||||
"k0": {statusCode: http.StatusUnauthorized, body: authErrorBody},
|
||||
"k1": {
|
||||
statusCode: http.StatusOK,
|
||||
headers: map[string]string{"Content-Type": "text/event-stream"},
|
||||
body: streamingSuccessBody,
|
||||
},
|
||||
},
|
||||
expectedRequestCount: 2,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
expectedKeyStates: []keypool.KeyState{
|
||||
keypool.KeyStatePermanent,
|
||||
keypool.KeyStateValid,
|
||||
},
|
||||
},
|
||||
{
|
||||
// Given: 2 keys; key-0 returns 403 pre-stream, key-1 streams.
|
||||
// Then: 2 requests, 200 response, key-0 permanent, key-1 valid.
|
||||
name: "failover_after_403",
|
||||
keys: []string{"k0", "k1"},
|
||||
responses: map[string]upstreamResponse{
|
||||
"k0": {statusCode: http.StatusForbidden, body: authErrorBody},
|
||||
"k1": {
|
||||
statusCode: http.StatusOK,
|
||||
headers: map[string]string{"Content-Type": "text/event-stream"},
|
||||
body: streamingSuccessBody,
|
||||
},
|
||||
},
|
||||
expectedRequestCount: 2,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
expectedKeyStates: []keypool.KeyState{
|
||||
keypool.KeyStatePermanent,
|
||||
keypool.KeyStateValid,
|
||||
},
|
||||
},
|
||||
{
|
||||
// Given: 3 keys; all return 429 pre-stream with
|
||||
// cooldowns 5s, 3s, 10s.
|
||||
// Then: 3 requests, 429 response with smallest
|
||||
// Retry-After, all keys temporary.
|
||||
name: "all_keys_rate_limited",
|
||||
keys: []string{"k0", "k1", "k2"},
|
||||
responses: map[string]upstreamResponse{
|
||||
"k0": {
|
||||
statusCode: http.StatusTooManyRequests,
|
||||
headers: map[string]string{"Retry-After": "5"},
|
||||
body: rateLimitBody,
|
||||
},
|
||||
"k1": {
|
||||
statusCode: http.StatusTooManyRequests,
|
||||
headers: map[string]string{"Retry-After": "3"},
|
||||
body: rateLimitBody,
|
||||
},
|
||||
"k2": {
|
||||
statusCode: http.StatusTooManyRequests,
|
||||
headers: map[string]string{"Retry-After": "10"},
|
||||
body: rateLimitBody,
|
||||
},
|
||||
},
|
||||
expectedRequestCount: 3,
|
||||
expectedStatusCode: http.StatusTooManyRequests,
|
||||
expectedRetryAfter: "3",
|
||||
expectedKeyStates: []keypool.KeyState{
|
||||
keypool.KeyStateTemporary,
|
||||
keypool.KeyStateTemporary,
|
||||
keypool.KeyStateTemporary,
|
||||
},
|
||||
},
|
||||
{
|
||||
// Given: 2 keys; both return 401 pre-stream.
|
||||
// Then: 2 requests, 502 api_error response, both keys permanent.
|
||||
name: "all_keys_unauthorized",
|
||||
keys: []string{"k0", "k1"},
|
||||
responses: map[string]upstreamResponse{
|
||||
"k0": {statusCode: http.StatusUnauthorized, body: authErrorBody},
|
||||
"k1": {statusCode: http.StatusUnauthorized, body: authErrorBody},
|
||||
},
|
||||
expectedRequestCount: 2,
|
||||
expectedStatusCode: http.StatusBadGateway,
|
||||
expectedKeyStates: []keypool.KeyState{
|
||||
keypool.KeyStatePermanent,
|
||||
keypool.KeyStatePermanent,
|
||||
},
|
||||
},
|
||||
{
|
||||
// Given: 2 keys; key-0 returns 500 pre-stream.
|
||||
// Then: 1 request, 500 response, both keys remain valid.
|
||||
name: "server_error_no_failover",
|
||||
keys: []string{"k0", "k1"},
|
||||
responses: map[string]upstreamResponse{
|
||||
"k0": {statusCode: http.StatusInternalServerError, body: serverErrorBody},
|
||||
},
|
||||
expectedRequestCount: 1,
|
||||
expectedStatusCode: http.StatusInternalServerError,
|
||||
expectedKeyStates: []keypool.KeyState{
|
||||
keypool.KeyStateValid,
|
||||
keypool.KeyStateValid,
|
||||
},
|
||||
},
|
||||
{
|
||||
// Given: BYOK with a single key returning 429.
|
||||
// Then: 1 request, 429 response, no failover, upstream
|
||||
// Retry-After propagated to the client.
|
||||
name: "byok_no_failover",
|
||||
byokKey: "user-byok",
|
||||
responses: map[string]upstreamResponse{
|
||||
"user-byok": {
|
||||
statusCode: http.StatusTooManyRequests,
|
||||
headers: map[string]string{
|
||||
"Retry-After": "5",
|
||||
// BYOK doesn't set MaxRetries(0);
|
||||
// suppress SDK retries to test a
|
||||
// single attempt.
|
||||
"x-should-retry": "false",
|
||||
},
|
||||
body: rateLimitBody,
|
||||
},
|
||||
},
|
||||
expectedRequestCount: 1,
|
||||
expectedStatusCode: http.StatusTooManyRequests,
|
||||
expectedRetryAfter: "5",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Mock upstream: counts requests and returns
|
||||
// scripted responses keyed by X-Api-Key. An unmapped
|
||||
// key falls through to 500 so misconfigured cases
|
||||
// surface via the status assertion.
|
||||
var requestCount atomic.Int32
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
requestCount.Add(1)
|
||||
_, _ = io.Copy(io.Discard, r.Body)
|
||||
resp, ok := tc.responses[r.Header.Get("X-Api-Key")]
|
||||
if !ok {
|
||||
resp = upstreamResponse{statusCode: http.StatusInternalServerError}
|
||||
}
|
||||
for hk, hv := range resp.headers {
|
||||
w.Header().Set(hk, hv)
|
||||
}
|
||||
w.WriteHeader(resp.statusCode)
|
||||
_, _ = w.Write([]byte(resp.body))
|
||||
}))
|
||||
t.Cleanup(upstream.Close)
|
||||
|
||||
cfg := config.Anthropic{BaseURL: upstream.URL + "/"}
|
||||
var pool *keypool.Pool
|
||||
if len(tc.keys) > 0 {
|
||||
var err error
|
||||
pool, err = keypool.New(tc.keys, quartz.NewMock(t))
|
||||
require.NoError(t, err)
|
||||
cfg.KeyPool = pool
|
||||
} else if tc.byokKey != "" {
|
||||
cfg.Key = tc.byokKey
|
||||
}
|
||||
|
||||
payload, err := NewRequestPayload([]byte(requestBody))
|
||||
require.NoError(t, err)
|
||||
|
||||
interceptor := NewStreamingInterceptor(
|
||||
uuid.New(),
|
||||
payload,
|
||||
config.ProviderAnthropic,
|
||||
cfg,
|
||||
nil,
|
||||
http.Header{},
|
||||
"X-Api-Key",
|
||||
otel.Tracer("streaming_test"),
|
||||
intercept.NewCredentialInfo(intercept.CredentialKindCentralized, ""),
|
||||
)
|
||||
interceptor.Setup(slog.Make(), &testutil.MockRecorder{}, nil)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
w := httptest.NewRecorder()
|
||||
err = interceptor.ProcessRequest(w, req)
|
||||
if tc.expectedStatusCode == http.StatusOK {
|
||||
require.NoError(t, err)
|
||||
} else {
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
assert.Equal(t, tc.expectedRequestCount, requestCount.Load(), "upstream request count")
|
||||
assert.Equal(t, tc.expectedStatusCode, w.Code, "response status code")
|
||||
assert.Equal(t, tc.expectedRetryAfter, w.Header().Get("Retry-After"), "Retry-After header")
|
||||
// No prior iteration streamed, so errors must be a
|
||||
// direct HTTP response, not an SSE event.
|
||||
assert.NotContains(t, w.Body.String(), "event: error", "error must not be relayed as an SSE event")
|
||||
if pool != nil {
|
||||
assert.Equal(t, tc.expectedKeyStates, pool.PoolState(), "key states")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// SSE bodies covering an agentic-continuation flow.
|
||||
const (
|
||||
// First response: a tool_use block referencing the injected
|
||||
// "test_tool". Triggers the agentic continuation loop.
|
||||
toolUseStreamBody = `event: message_start
|
||||
data: {"type":"message_start","message":{"id":"msg_01","type":"message","role":"assistant","model":"claude-opus-4-5","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":10,"output_tokens":1}}}
|
||||
|
||||
event: content_block_start
|
||||
data: {"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"toolu_01","name":"test_tool","input":{}}}
|
||||
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":"{}"}}
|
||||
|
||||
event: content_block_stop
|
||||
data: {"type":"content_block_stop","index":0}
|
||||
|
||||
event: message_delta
|
||||
data: {"type":"message_delta","delta":{"stop_reason":"tool_use","stop_sequence":null},"usage":{"output_tokens":5}}
|
||||
|
||||
event: message_stop
|
||||
data: {"type":"message_stop"}
|
||||
|
||||
`
|
||||
|
||||
// Second response (after the tool result is sent back):
|
||||
// a plain text completion that ends the loop.
|
||||
textStreamBody = `event: message_start
|
||||
data: {"type":"message_start","message":{"id":"msg_02","type":"message","role":"assistant","model":"claude-opus-4-5","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":15,"output_tokens":1}}}
|
||||
|
||||
event: content_block_start
|
||||
data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}
|
||||
|
||||
event: content_block_delta
|
||||
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"done"}}
|
||||
|
||||
event: content_block_stop
|
||||
data: {"type":"content_block_stop","index":0}
|
||||
|
||||
event: message_delta
|
||||
data: {"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"output_tokens":3}}
|
||||
|
||||
event: message_stop
|
||||
data: {"type":"message_stop"}
|
||||
|
||||
`
|
||||
)
|
||||
|
||||
// stubToolCaller is a minimal mcp.ToolCaller that returns a fixed
|
||||
// text result, so the agentic continuation can proceed.
|
||||
type stubToolCaller struct{}
|
||||
|
||||
func (stubToolCaller) CallTool(_ context.Context, _ mcplib.CallToolRequest) (*mcplib.CallToolResult, error) {
|
||||
return mcplib.NewToolResultText("tool result"), nil
|
||||
}
|
||||
|
||||
// TestStreamingInterception_AgenticLoopFailover covers the
|
||||
// scenarios that span an agentic-loop continuation: the initial
|
||||
// client request and the subsequent tool-call continuation can
|
||||
// each fail over independently. Each iteration gets its own
|
||||
// walker.
|
||||
func TestStreamingInterception_AgenticLoopFailover(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
sseHeaders := map[string]string{"Content-Type": "text/event-stream"}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
// Scripted upstream responses consumed in order of
|
||||
// upstream request.
|
||||
responses []upstreamResponse
|
||||
expectedRequestCount int32
|
||||
expectedSeenKeys []string
|
||||
// Substring expected in the response body. Either a
|
||||
// success marker (e.g. "done") or an error marker
|
||||
// (e.g. "rate_limit_error").
|
||||
expectedBodyContains string
|
||||
// True when the error must be relayed as an SSE event.
|
||||
expectErrorAsSSEEvent bool
|
||||
// True when ProcessRequest is expected to return an
|
||||
// error (e.g. all keys exhausted).
|
||||
expectedErr bool
|
||||
expectedKeyStates []keypool.KeyState
|
||||
}{
|
||||
{
|
||||
// Given: 2 keys; both upstream calls succeed on key-0.
|
||||
// Then: 2 requests, success body, both keys remain valid.
|
||||
name: "happy_path",
|
||||
responses: []upstreamResponse{
|
||||
{statusCode: http.StatusOK, headers: sseHeaders, body: toolUseStreamBody},
|
||||
{statusCode: http.StatusOK, headers: sseHeaders, body: textStreamBody},
|
||||
},
|
||||
expectedRequestCount: 2,
|
||||
expectedSeenKeys: []string{"k0", "k0"},
|
||||
expectedBodyContains: "done",
|
||||
expectErrorAsSSEEvent: false,
|
||||
expectedKeyStates: []keypool.KeyState{
|
||||
keypool.KeyStateValid,
|
||||
keypool.KeyStateValid,
|
||||
},
|
||||
},
|
||||
{
|
||||
// Given: 2 keys; key-0 succeeds initially, then 429s
|
||||
// during the agentic continuation, key-1 succeeds.
|
||||
// Then: 3 requests, success body, key-0 temporary,
|
||||
// key-1 valid.
|
||||
name: "agentic_failover_to_k1",
|
||||
responses: []upstreamResponse{
|
||||
{statusCode: http.StatusOK, headers: sseHeaders, body: toolUseStreamBody},
|
||||
{
|
||||
statusCode: http.StatusTooManyRequests,
|
||||
headers: map[string]string{"Retry-After": "5"},
|
||||
body: rateLimitBody,
|
||||
},
|
||||
{statusCode: http.StatusOK, headers: sseHeaders, body: textStreamBody},
|
||||
},
|
||||
expectedRequestCount: 3,
|
||||
expectedSeenKeys: []string{"k0", "k0", "k1"},
|
||||
expectedBodyContains: "done",
|
||||
expectErrorAsSSEEvent: false,
|
||||
expectedKeyStates: []keypool.KeyState{
|
||||
keypool.KeyStateTemporary,
|
||||
keypool.KeyStateValid,
|
||||
},
|
||||
},
|
||||
{
|
||||
// Given: 2 keys; key-0 succeeds initially, then both
|
||||
// keys 429 during the agentic continuation.
|
||||
// Then: 3 requests, error injected as SSE event, both
|
||||
// keys temporary.
|
||||
name: "agentic_all_keys_fail",
|
||||
responses: []upstreamResponse{
|
||||
{statusCode: http.StatusOK, headers: sseHeaders, body: toolUseStreamBody},
|
||||
{
|
||||
statusCode: http.StatusTooManyRequests,
|
||||
headers: map[string]string{"Retry-After": "5"},
|
||||
body: rateLimitBody,
|
||||
},
|
||||
{
|
||||
statusCode: http.StatusTooManyRequests,
|
||||
headers: map[string]string{"Retry-After": "3"},
|
||||
body: rateLimitBody,
|
||||
},
|
||||
},
|
||||
expectedRequestCount: 3,
|
||||
expectedSeenKeys: []string{"k0", "k0", "k1"},
|
||||
expectedBodyContains: "all configured keys are rate-limited",
|
||||
expectErrorAsSSEEvent: true,
|
||||
expectedErr: true,
|
||||
expectedKeyStates: []keypool.KeyState{
|
||||
keypool.KeyStateTemporary,
|
||||
keypool.KeyStateTemporary,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var requestCount atomic.Int32
|
||||
var seenKeysMu sync.Mutex
|
||||
var seenKeys []string
|
||||
|
||||
// Mock upstream: returns scripted responses in order,
|
||||
// records each request's X-Api-Key for assertions.
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
idx := int(requestCount.Add(1)) - 1
|
||||
seenKeysMu.Lock()
|
||||
seenKeys = append(seenKeys, r.Header.Get("X-Api-Key"))
|
||||
seenKeysMu.Unlock()
|
||||
_, _ = io.Copy(io.Discard, r.Body)
|
||||
|
||||
if idx >= len(tc.responses) {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
resp := tc.responses[idx]
|
||||
for hk, hv := range resp.headers {
|
||||
w.Header().Set(hk, hv)
|
||||
}
|
||||
w.WriteHeader(resp.statusCode)
|
||||
_, _ = w.Write([]byte(resp.body))
|
||||
}))
|
||||
t.Cleanup(upstream.Close)
|
||||
|
||||
pool, err := keypool.New([]string{"k0", "k1"}, quartz.NewMock(t))
|
||||
require.NoError(t, err)
|
||||
|
||||
cfg := config.Anthropic{
|
||||
BaseURL: upstream.URL + "/",
|
||||
KeyPool: pool,
|
||||
}
|
||||
|
||||
payload, err := NewRequestPayload([]byte(requestBody))
|
||||
require.NoError(t, err)
|
||||
|
||||
interceptor := NewStreamingInterceptor(
|
||||
uuid.New(),
|
||||
payload,
|
||||
config.ProviderAnthropic,
|
||||
cfg,
|
||||
nil,
|
||||
http.Header{},
|
||||
"X-Api-Key",
|
||||
otel.Tracer("streaming_test"),
|
||||
intercept.NewCredentialInfo(intercept.CredentialKindCentralized, ""),
|
||||
)
|
||||
|
||||
// Mock proxy with a tool the upstream's tool_use event
|
||||
// will reference. The stub caller returns a fixed
|
||||
// text result.
|
||||
proxy := &mockServerProxier{
|
||||
tools: []*mcp.Tool{
|
||||
{
|
||||
Client: stubToolCaller{},
|
||||
ID: "test_tool",
|
||||
Name: "test_tool",
|
||||
ServerName: "coder",
|
||||
Logger: slog.Make(),
|
||||
},
|
||||
},
|
||||
}
|
||||
interceptor.Setup(slog.Make(), &testutil.MockRecorder{}, proxy)
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, "/v1/messages", nil)
|
||||
w := httptest.NewRecorder()
|
||||
err = interceptor.ProcessRequest(w, req)
|
||||
if tc.expectedErr {
|
||||
require.Error(t, err)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
assert.Equal(t, tc.expectedRequestCount, requestCount.Load(), "upstream request count")
|
||||
body := w.Body.String()
|
||||
assert.Contains(t, body, tc.expectedBodyContains, "response body")
|
||||
if tc.expectErrorAsSSEEvent {
|
||||
assert.Contains(t, body, "event: error", "error must be relayed as an SSE event")
|
||||
}
|
||||
|
||||
seenKeysMu.Lock()
|
||||
defer seenKeysMu.Unlock()
|
||||
assert.Equal(t, tc.expectedSeenKeys, seenKeys, "seen keys")
|
||||
assert.Equal(t, tc.expectedKeyStates, pool.PoolState(), "key states")
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,128 @@
|
||||
package integrationtest //nolint:testpackage // tests unexported internals
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tidwall/sjson"
|
||||
|
||||
"github.com/coder/coder/v2/aibridge/config"
|
||||
"github.com/coder/coder/v2/aibridge/fixtures"
|
||||
"github.com/coder/coder/v2/aibridge/keypool"
|
||||
"github.com/coder/coder/v2/aibridge/provider"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
// TestAnthropic_KeyFailover verifies that a pool's key state
|
||||
// persists across distinct client requests: a key marked
|
||||
// temporary on request 1 is still skipped on request 2 without
|
||||
// a wasted upstream attempt.
|
||||
func TestAnthropic_KeyFailover(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
fix := fixtures.Parse(t, fixtures.AntSimple)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
streaming bool
|
||||
successBody []byte
|
||||
successCType string
|
||||
}{
|
||||
{
|
||||
name: "blocking",
|
||||
streaming: false,
|
||||
successBody: fix.NonStreaming(),
|
||||
successCType: "application/json",
|
||||
},
|
||||
{
|
||||
name: "streaming",
|
||||
streaming: true,
|
||||
successBody: fix.Streaming(),
|
||||
successCType: "text/event-stream",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pool, err := keypool.New([]string{"k0", "k1"}, quartz.NewMock(t))
|
||||
require.NoError(t, err)
|
||||
|
||||
var requestCount atomic.Int32
|
||||
var seenKeysMu sync.Mutex
|
||||
var seenKeys []string
|
||||
|
||||
// Mock upstream: k0 always returns 429, k1 returns
|
||||
// the per-test success body.
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
requestCount.Add(1)
|
||||
key := r.Header.Get("X-Api-Key")
|
||||
seenKeysMu.Lock()
|
||||
seenKeys = append(seenKeys, key)
|
||||
seenKeysMu.Unlock()
|
||||
_, _ = io.Copy(io.Discard, r.Body)
|
||||
|
||||
switch key {
|
||||
case "k0":
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Retry-After", "60")
|
||||
w.WriteHeader(http.StatusTooManyRequests)
|
||||
_, _ = fmt.Fprint(w, `{"type":"error","error":{"type":"rate_limit_error","message":"rate limited"}}`)
|
||||
case "k1":
|
||||
w.Header().Set("Content-Type", tc.successCType)
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write(tc.successBody)
|
||||
default:
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
}
|
||||
}))
|
||||
t.Cleanup(upstream.Close)
|
||||
|
||||
bridgeServer := newBridgeTestServer(t.Context(), t, upstream.URL,
|
||||
withCustomProvider(provider.NewAnthropic(config.Anthropic{
|
||||
BaseURL: upstream.URL,
|
||||
KeyPool: pool,
|
||||
}, nil)),
|
||||
)
|
||||
|
||||
requestBody, err := sjson.SetBytes(fix.Request(), "stream", tc.streaming)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Request 1: walker starts at k0, fails over to k1
|
||||
// after 429.
|
||||
resp, err := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, requestBody)
|
||||
require.NoError(t, err)
|
||||
_, _ = io.Copy(io.Discard, resp.Body)
|
||||
require.NoError(t, resp.Body.Close())
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
// Request 2: walker skips the now-temporary k0 and
|
||||
// goes straight to k1 (1 upstream call, not 2).
|
||||
resp, err = bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, requestBody)
|
||||
require.NoError(t, err)
|
||||
_, _ = io.Copy(io.Discard, resp.Body)
|
||||
require.NoError(t, resp.Body.Close())
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
|
||||
seenKeysMu.Lock()
|
||||
defer seenKeysMu.Unlock()
|
||||
// Request 1: 2 calls (k0 then k1). Request 2: 1 call (k1).
|
||||
assert.Equal(t, int32(3), requestCount.Load(), "upstream request count")
|
||||
assert.Equal(t, []string{"k0", "k1", "k1"}, seenKeys, "seen keys")
|
||||
|
||||
// Pool state persists: k0 temporary, k1 valid.
|
||||
assert.Equal(t, []keypool.KeyState{
|
||||
keypool.KeyStateTemporary,
|
||||
keypool.KeyStateValid,
|
||||
}, pool.PoolState(), "key states")
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,37 @@
|
||||
package keypool
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ParseRetryAfter extracts the cooldown duration from response
// headers. It prefers the OpenAI-specific "retry-after-ms"
// header (milliseconds, fractional values allowed) over the
// standard "Retry-After" header (whole seconds).
//
// A header value counts as unparseable, and is skipped, when it is
// not a number, is not strictly positive, is NaN or infinite, or
// does not fit in a positive time.Duration (including values that
// round down to zero nanoseconds). An unusable "retry-after-ms"
// falls back to "Retry-After".
//
// Returns zero if resp is nil or if neither header is present and
// parseable. The HTTP-date form of "Retry-After" is not parsed.
func ParseRetryAfter(resp *http.Response) time.Duration {
	// maxDuration is the largest value a time.Duration can hold.
	const maxDuration = time.Duration(1<<63 - 1)

	if resp == nil {
		return 0
	}

	// OpenAI convention: millisecond precision.
	if val := resp.Header.Get("retry-after-ms"); val != "" {
		ms, err := strconv.ParseFloat(strings.TrimSpace(val), 64)
		// ParseFloat accepts "Inf" and "NaN" without error. NaN
		// fails the ms > 0 comparison; +Inf and oversized values
		// are rejected by the range check below, because
		// converting an out-of-range float to an integer type is
		// implementation-defined in Go.
		if err == nil && ms > 0 {
			ns := ms * float64(time.Millisecond)
			if ns < float64(maxDuration) {
				// Sub-nanosecond values truncate to zero; treat
				// them like an explicit zero and fall through.
				if d := time.Duration(ns); d > 0 {
					return d
				}
			}
		}
	}

	// Standard header: seconds.
	if val := resp.Header.Get("Retry-After"); val != "" {
		seconds, err := strconv.Atoi(strings.TrimSpace(val))
		// The upper bound keeps the multiplication below from
		// silently wrapping around.
		if err == nil && seconds > 0 && time.Duration(seconds) <= maxDuration/time.Second {
			return time.Duration(seconds) * time.Second
		}
	}

	return 0
}
|
||||
@@ -0,0 +1,110 @@
|
||||
package keypool_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/coder/coder/v2/aibridge/keypool"
|
||||
)
|
||||
|
||||
func TestParseRetryAfter(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
headers map[string]string
|
||||
nilResponse bool
|
||||
expected time.Duration
|
||||
}{
|
||||
// nil response.
|
||||
{
|
||||
name: "nil_response",
|
||||
nilResponse: true,
|
||||
expected: 0,
|
||||
},
|
||||
// No headers set.
|
||||
{
|
||||
name: "no_headers",
|
||||
headers: nil,
|
||||
expected: 0,
|
||||
},
|
||||
// retry-after-ms (OpenAI, preferred).
|
||||
{
|
||||
name: "openai_retry_after_ms",
|
||||
headers: map[string]string{"retry-after-ms": "2500"},
|
||||
expected: 2500 * time.Millisecond,
|
||||
},
|
||||
{
|
||||
name: "whitespace_trimmed_ms",
|
||||
headers: map[string]string{"retry-after-ms": " 1500 "},
|
||||
expected: 1500 * time.Millisecond,
|
||||
},
|
||||
{
|
||||
name: "negative_ms_returns_zero",
|
||||
headers: map[string]string{"retry-after-ms": "-100"},
|
||||
expected: 0,
|
||||
},
|
||||
// Retry-After (standard, seconds).
|
||||
{
|
||||
name: "standard_retry_after_seconds",
|
||||
headers: map[string]string{"Retry-After": "60"},
|
||||
expected: 60 * time.Second,
|
||||
},
|
||||
{
|
||||
name: "whitespace_trimmed_seconds",
|
||||
headers: map[string]string{"Retry-After": " 30 "},
|
||||
expected: 30 * time.Second,
|
||||
},
|
||||
{
|
||||
name: "zero_seconds_returns_zero",
|
||||
headers: map[string]string{"Retry-After": "0"},
|
||||
expected: 0,
|
||||
},
|
||||
{
|
||||
name: "negative_seconds_returns_zero",
|
||||
headers: map[string]string{"Retry-After": "-5"},
|
||||
expected: 0,
|
||||
},
|
||||
// Both headers set: precedence and fallback.
|
||||
{
|
||||
name: "prefers_retry_after_ms_over_standard",
|
||||
headers: map[string]string{
|
||||
"retry-after-ms": "1500",
|
||||
"Retry-After": "30",
|
||||
},
|
||||
expected: 1500 * time.Millisecond,
|
||||
},
|
||||
{
|
||||
name: "falls_back_to_standard_when_ms_invalid",
|
||||
headers: map[string]string{"retry-after-ms": "invalid", "Retry-After": "10"},
|
||||
expected: 10 * time.Second,
|
||||
},
|
||||
{
|
||||
name: "zero_ms_falls_back_to_standard",
|
||||
headers: map[string]string{"retry-after-ms": "0", "Retry-After": "5"},
|
||||
expected: 5 * time.Second,
|
||||
},
|
||||
{
|
||||
name: "zero_ms_and_zero_seconds_return_zero",
|
||||
headers: map[string]string{"retry-after-ms": "0", "Retry-After": "0"},
|
||||
expected: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
var resp *http.Response
|
||||
if !tc.nilResponse {
|
||||
resp = &http.Response{Header: make(http.Header)}
|
||||
for key, val := range tc.headers {
|
||||
resp.Header.Set(key, val)
|
||||
}
|
||||
}
|
||||
assert.Equal(t, tc.expected, keypool.ParseRetryAfter(resp))
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,54 @@
|
||||
package keypool
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/aibridge/utils"
|
||||
)
|
||||
|
||||
// MarkKeyOnStatus marks key based on a key-specific HTTP
|
||||
// status code from resp (429 for temporary, 401 or 403 for
|
||||
// permanent). Returns true if the status was a key-specific
|
||||
// failover trigger so callers can retry with the next key.
|
||||
func MarkKeyOnStatus(
|
||||
ctx context.Context,
|
||||
key *Key,
|
||||
resp *http.Response,
|
||||
logger slog.Logger,
|
||||
providerName string,
|
||||
) bool {
|
||||
if resp == nil {
|
||||
return false
|
||||
}
|
||||
statusCode := resp.StatusCode
|
||||
switch statusCode {
|
||||
case http.StatusTooManyRequests:
|
||||
cooldown := ParseRetryAfter(resp)
|
||||
if cooldown <= 0 {
|
||||
cooldown = defaultCooldown
|
||||
}
|
||||
if key.MarkTemporary(cooldown) {
|
||||
logger.Info(ctx, "key marked temporary",
|
||||
slog.F("provider", providerName),
|
||||
slog.F("api_key_hint", utils.MaskSecret(key.Value())),
|
||||
slog.F("status", statusCode),
|
||||
slog.F("cooldown", cooldown))
|
||||
}
|
||||
return true
|
||||
case http.StatusUnauthorized, http.StatusForbidden:
|
||||
if key.MarkPermanent() {
|
||||
logger.Warn(ctx, "key marked permanent",
|
||||
slog.F("provider", providerName),
|
||||
slog.F("api_key_hint", utils.MaskSecret(key.Value())),
|
||||
slog.F("status", statusCode))
|
||||
}
|
||||
return true
|
||||
default:
|
||||
logger.Debug(ctx, "status is not a key failover trigger",
|
||||
slog.F("provider", providerName),
|
||||
slog.F("status", statusCode))
|
||||
return false
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,127 @@
|
||||
package keypool_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"cdr.dev/slog/v3/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/aibridge/keypool"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
func TestMarkKeyOnStatus(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
statusCode int
|
||||
headers map[string]string
|
||||
expectedReturn bool
|
||||
expectedState keypool.KeyState
|
||||
expectedCooldown time.Duration
|
||||
}{
|
||||
{
|
||||
// 429 with standard Retry-After header (seconds).
|
||||
name: "429_with_retry_after_seconds",
|
||||
statusCode: http.StatusTooManyRequests,
|
||||
headers: map[string]string{"Retry-After": "5"},
|
||||
expectedReturn: true,
|
||||
expectedState: keypool.KeyStateTemporary,
|
||||
expectedCooldown: 5 * time.Second,
|
||||
},
|
||||
{
|
||||
// 429 with retry-after-ms header (milliseconds).
|
||||
name: "429_with_retry_after_ms",
|
||||
statusCode: http.StatusTooManyRequests,
|
||||
headers: map[string]string{"retry-after-ms": "1500"},
|
||||
expectedReturn: true,
|
||||
expectedState: keypool.KeyStateTemporary,
|
||||
expectedCooldown: 1500 * time.Millisecond,
|
||||
},
|
||||
{
|
||||
// 429 without headers falls back to default cooldown.
|
||||
name: "429_no_headers_uses_default",
|
||||
statusCode: http.StatusTooManyRequests,
|
||||
expectedReturn: true,
|
||||
expectedState: keypool.KeyStateTemporary,
|
||||
expectedCooldown: 60 * time.Second,
|
||||
},
|
||||
{
|
||||
name: "401_marks_permanent",
|
||||
statusCode: http.StatusUnauthorized,
|
||||
expectedReturn: true,
|
||||
expectedState: keypool.KeyStatePermanent,
|
||||
},
|
||||
{
|
||||
name: "403_marks_permanent",
|
||||
statusCode: http.StatusForbidden,
|
||||
expectedReturn: true,
|
||||
expectedState: keypool.KeyStatePermanent,
|
||||
},
|
||||
{
|
||||
name: "200_does_not_mark",
|
||||
statusCode: http.StatusOK,
|
||||
expectedReturn: false,
|
||||
expectedState: keypool.KeyStateValid,
|
||||
},
|
||||
{
|
||||
name: "500_does_not_mark",
|
||||
statusCode: http.StatusInternalServerError,
|
||||
expectedReturn: false,
|
||||
expectedState: keypool.KeyStateValid,
|
||||
},
|
||||
{
|
||||
// 529 is the Anthropic overloaded status, handled by
|
||||
// the circuit breaker, not key failover.
|
||||
name: "529_does_not_mark",
|
||||
statusCode: 529,
|
||||
expectedReturn: false,
|
||||
expectedState: keypool.KeyStateValid,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
clk := quartz.NewMock(t)
|
||||
pool, err := keypool.New([]string{"key-0"}, clk)
|
||||
require.NoError(t, err)
|
||||
key, err := pool.Walker().Next()
|
||||
require.NoError(t, err)
|
||||
|
||||
resp := &http.Response{
|
||||
StatusCode: tc.statusCode,
|
||||
Header: make(http.Header),
|
||||
}
|
||||
for k, v := range tc.headers {
|
||||
resp.Header.Set(k, v)
|
||||
}
|
||||
|
||||
got := keypool.MarkKeyOnStatus(
|
||||
context.Background(),
|
||||
key,
|
||||
resp,
|
||||
// 401 and 403 cases legitimately log at error
|
||||
// level when marking a key permanent.
|
||||
slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}),
|
||||
"test",
|
||||
)
|
||||
|
||||
assert.Equal(t, tc.expectedReturn, got)
|
||||
assert.Equal(t, tc.expectedState, key.State())
|
||||
|
||||
// Verify cooldown was set to the expected duration:
|
||||
// advancing by exactly that amount returns the key
|
||||
// to valid.
|
||||
if tc.expectedCooldown > 0 {
|
||||
clk.Advance(tc.expectedCooldown)
|
||||
assert.Equal(t, keypool.KeyStateValid, key.State())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
package keypool
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -15,11 +16,24 @@ var (
|
||||
// ErrDuplicateKey is returned when the input contains
|
||||
// duplicate key values.
|
||||
ErrDuplicateKey = xerrors.New("duplicate key")
|
||||
// ErrAllKeysExhausted is returned when the walker has visited
|
||||
// every key in the pool and none are available.
|
||||
ErrAllKeysExhausted = xerrors.New("all keys exhausted")
|
||||
)
|
||||
|
||||
// ErrPermanentKeyPool is returned when every key in the
|
||||
// pool has been permanently marked unavailable.
|
||||
var ErrPermanentKeyPool = xerrors.New("all keys permanently unavailable")
|
||||
|
||||
// TransientKeyPoolError reports that no key can be handed out
// right now but at least one is expected to recover. RetryAfter
// carries the soonest remaining cooldown across the pool; it is 0
// when a key turned valid again while the pool was being walked.
type TransientKeyPoolError struct {
	RetryAfter time.Duration
}

// Error implements the error interface and includes the
// retry-after duration in the message.
func (e *TransientKeyPoolError) Error() string {
	wait := e.RetryAfter
	return fmt.Sprintf("all keys exhausted (retry after %s)", wait)
}
|
||||
|
||||
// KeyState represents the current state of a key in the pool.
|
||||
type KeyState int
|
||||
|
||||
@@ -101,6 +115,22 @@ func (k *Key) State() KeyState {
|
||||
return KeyStateValid
|
||||
}
|
||||
|
||||
// stateAndCooldown returns the key's state and remaining
|
||||
// cooldown as a single atomic snapshot.
|
||||
func (k *Key) stateAndCooldown() (KeyState, time.Duration) {
|
||||
k.mu.RLock()
|
||||
defer k.mu.RUnlock()
|
||||
|
||||
if k.permanent {
|
||||
return KeyStatePermanent, 0
|
||||
}
|
||||
now := k.clock.Now()
|
||||
if now.Before(k.cooldownUntil) {
|
||||
return KeyStateTemporary, k.cooldownUntil.Sub(now)
|
||||
}
|
||||
return KeyStateValid, 0
|
||||
}
|
||||
|
||||
// MarkTemporary marks the key as temporarily unavailable with
|
||||
// the specified cooldown duration. Returns true if this call
|
||||
// transitions the key to temporary.
|
||||
@@ -146,6 +176,47 @@ func (k *Key) MarkPermanent() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// keyPoolError returns ErrPermanentKeyPool if every key
|
||||
// is permanently unavailable, or *TransientKeyPoolError if
|
||||
// at least one key is temporarily unavailable. When multiple
|
||||
// keys are temporary, the smallest remaining cooldown is used
|
||||
// as the retry-after.
|
||||
func (p *Pool) keyPoolError() error {
|
||||
var retryAfter time.Duration
|
||||
var hasCooldown bool
|
||||
for i := range p.keys {
|
||||
state, cooldown := p.keys[i].stateAndCooldown()
|
||||
switch state {
|
||||
// Recoverable now: signal transient with zero retry-after.
|
||||
case KeyStateValid:
|
||||
return &TransientKeyPoolError{}
|
||||
// Recoverable later: track soonest remaining cooldown.
|
||||
case KeyStateTemporary:
|
||||
if !hasCooldown || cooldown < retryAfter {
|
||||
retryAfter = cooldown
|
||||
hasCooldown = true
|
||||
}
|
||||
// Permanent: keep walking to confirm error type.
|
||||
default:
|
||||
}
|
||||
}
|
||||
if hasCooldown {
|
||||
return &TransientKeyPoolError{RetryAfter: retryAfter}
|
||||
}
|
||||
return ErrPermanentKeyPool
|
||||
}
|
||||
|
||||
// PoolState returns a snapshot of each key's state in the pool's
|
||||
// original order, used by tests and other diagnostic callers. Use
|
||||
// Walker for the failover iteration path.
|
||||
func (p *Pool) PoolState() []KeyState {
|
||||
states := make([]KeyState, len(p.keys))
|
||||
for i := range p.keys {
|
||||
states[i] = p.keys[i].State()
|
||||
}
|
||||
return states
|
||||
}
|
||||
|
||||
// Walker traverses a Pool for a single request. Each request
|
||||
// creates its own walker so that it can independently iterate
|
||||
// through keys without interfering with other requests.
|
||||
@@ -162,14 +233,15 @@ func (p *Pool) Walker() *Walker {
|
||||
return &Walker{pool: p, pos: 0}
|
||||
}
|
||||
|
||||
// Next returns a Key handle for the next available key. This is
|
||||
// a read-only operation; it does not modify the pool state.
|
||||
// Next returns a Key handle for the next available key without
|
||||
// modifying the pool state.
|
||||
//
|
||||
// Returns ErrAllKeysExhausted when no more keys are available.
|
||||
// Returns *TransientKeyPoolError or ErrPermanentKeyPool
|
||||
// when no more keys are available.
|
||||
func (w *Walker) Next() (*Key, error) {
|
||||
pool := w.pool
|
||||
if pool == nil {
|
||||
return nil, ErrAllKeysExhausted
|
||||
return nil, ErrPermanentKeyPool
|
||||
}
|
||||
|
||||
for i := w.pos; i < len(pool.keys); i++ {
|
||||
@@ -183,5 +255,5 @@ func (w *Walker) Next() (*Key, error) {
|
||||
}
|
||||
|
||||
// No keys available.
|
||||
return nil, ErrAllKeysExhausted
|
||||
return nil, pool.keyPoolError()
|
||||
}
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package keypool_test
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -49,7 +51,8 @@ func TestNewKeyPool(t *testing.T) {
|
||||
|
||||
// No more keys available.
|
||||
_, err = walker.Next()
|
||||
require.ErrorIs(t, err, keypool.ErrAllKeysExhausted)
|
||||
var transient *keypool.TransientKeyPoolError
|
||||
require.ErrorAs(t, err, &transient, "expected transient exhaustion: walker returned all valid keys, none marked permanent")
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -282,19 +285,21 @@ func TestWalkerNext(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
keys []string
|
||||
setup func(t *testing.T, pool *keypool.Pool)
|
||||
advance time.Duration
|
||||
expectValid []string
|
||||
name string
|
||||
keys []string
|
||||
setup func(t *testing.T, pool *keypool.Pool)
|
||||
advance time.Duration
|
||||
expectedValid []string
|
||||
expectedErr error
|
||||
}{
|
||||
{
|
||||
// Given: key-0: valid, key-1: valid, key-2: valid.
|
||||
// Then: key-0: valid, key-1: valid, key-2: valid.
|
||||
name: "all_keys_valid",
|
||||
keys: []string{"key-0", "key-1", "key-2"},
|
||||
setup: func(_ *testing.T, _ *keypool.Pool) {},
|
||||
expectValid: []string{"key-0", "key-1", "key-2"},
|
||||
name: "all_keys_valid",
|
||||
keys: []string{"key-0", "key-1", "key-2"},
|
||||
setup: func(_ *testing.T, _ *keypool.Pool) {},
|
||||
expectedValid: []string{"key-0", "key-1", "key-2"},
|
||||
expectedErr: &keypool.TransientKeyPoolError{},
|
||||
},
|
||||
{
|
||||
// Given: key-0: temporary, key-1: valid, key-2: valid.
|
||||
@@ -306,7 +311,8 @@ func TestWalkerNext(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
key.MarkTemporary(60 * time.Second)
|
||||
},
|
||||
expectValid: []string{"key-1", "key-2"},
|
||||
expectedValid: []string{"key-1", "key-2"},
|
||||
expectedErr: &keypool.TransientKeyPoolError{},
|
||||
},
|
||||
{
|
||||
// Given: key-0: permanent, key-1: permanent, key-2: valid.
|
||||
@@ -322,7 +328,8 @@ func TestWalkerNext(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
key1.MarkPermanent()
|
||||
},
|
||||
expectValid: []string{"key-2"},
|
||||
expectedValid: []string{"key-2"},
|
||||
expectedErr: &keypool.TransientKeyPoolError{},
|
||||
},
|
||||
{
|
||||
// Given: key-0: temporary (30s), key-1: valid.
|
||||
@@ -335,8 +342,9 @@ func TestWalkerNext(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
key.MarkTemporary(30 * time.Second)
|
||||
},
|
||||
advance: 35 * time.Second,
|
||||
expectValid: []string{"key-0", "key-1"},
|
||||
advance: 35 * time.Second,
|
||||
expectedValid: []string{"key-0", "key-1"},
|
||||
expectedErr: &keypool.TransientKeyPoolError{},
|
||||
},
|
||||
{
|
||||
// Given: key-0: temporary (zero, default 60s), key-1: valid.
|
||||
@@ -349,8 +357,9 @@ func TestWalkerNext(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
key.MarkTemporary(0)
|
||||
},
|
||||
advance: 50 * time.Second,
|
||||
expectValid: []string{"key-1"},
|
||||
advance: 50 * time.Second,
|
||||
expectedValid: []string{"key-1"},
|
||||
expectedErr: &keypool.TransientKeyPoolError{},
|
||||
},
|
||||
{
|
||||
// Given: key-0: temporary (zero, default 60s), key-1: valid.
|
||||
@@ -363,8 +372,9 @@ func TestWalkerNext(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
key.MarkTemporary(0)
|
||||
},
|
||||
advance: 65 * time.Second,
|
||||
expectValid: []string{"key-0", "key-1"},
|
||||
advance: 65 * time.Second,
|
||||
expectedValid: []string{"key-0", "key-1"},
|
||||
expectedErr: &keypool.TransientKeyPoolError{},
|
||||
},
|
||||
{
|
||||
// Given: key-0: temporary (negative, default 60s), key-1: valid.
|
||||
@@ -377,13 +387,14 @@ func TestWalkerNext(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
key.MarkTemporary(-10 * time.Second)
|
||||
},
|
||||
advance: 65 * time.Second,
|
||||
expectValid: []string{"key-0", "key-1"},
|
||||
advance: 65 * time.Second,
|
||||
expectedValid: []string{"key-0", "key-1"},
|
||||
expectedErr: &keypool.TransientKeyPoolError{},
|
||||
},
|
||||
{
|
||||
// Given: key-0: temporary (60s), then marked again with shorter cooldown (10s).
|
||||
// When: 15s pass (past 10s, but not 60s).
|
||||
// Then: key-0: temporary.
|
||||
// Then: key-0: temporary, 45s remaining.
|
||||
name: "shorter_cooldown_preserves_longer_not_expired",
|
||||
keys: []string{"key-0"},
|
||||
setup: func(t *testing.T, pool *keypool.Pool) {
|
||||
@@ -392,8 +403,9 @@ func TestWalkerNext(t *testing.T) {
|
||||
key.MarkTemporary(60 * time.Second)
|
||||
key.MarkTemporary(10 * time.Second)
|
||||
},
|
||||
advance: 15 * time.Second,
|
||||
expectValid: []string{},
|
||||
advance: 15 * time.Second,
|
||||
expectedValid: []string{},
|
||||
expectedErr: &keypool.TransientKeyPoolError{RetryAfter: 45 * time.Second},
|
||||
},
|
||||
{
|
||||
// Given: key-0: temporary (60s), then marked again with shorter cooldown (10s).
|
||||
@@ -407,8 +419,30 @@ func TestWalkerNext(t *testing.T) {
|
||||
key.MarkTemporary(60 * time.Second)
|
||||
key.MarkTemporary(10 * time.Second)
|
||||
},
|
||||
advance: 65 * time.Second,
|
||||
expectValid: []string{"key-0"},
|
||||
advance: 65 * time.Second,
|
||||
expectedValid: []string{"key-0"},
|
||||
expectedErr: &keypool.TransientKeyPoolError{},
|
||||
},
|
||||
{
|
||||
// Given: key-0: temporary (60s), key-1: temporary (10s), key-2: temporary (30s).
|
||||
// Then: key-0: temporary, key-1: temporary, key-2: temporary.
|
||||
// Smallest remaining cooldown is reported on exhaustion.
|
||||
name: "smallest_cooldown_across_temporary_keys",
|
||||
keys: []string{"key-0", "key-1", "key-2"},
|
||||
setup: func(t *testing.T, pool *keypool.Pool) {
|
||||
walker := pool.Walker()
|
||||
key0, err := walker.Next()
|
||||
require.NoError(t, err)
|
||||
key0.MarkTemporary(60 * time.Second)
|
||||
key1, err := walker.Next()
|
||||
require.NoError(t, err)
|
||||
key1.MarkTemporary(10 * time.Second)
|
||||
key2, err := walker.Next()
|
||||
require.NoError(t, err)
|
||||
key2.MarkTemporary(30 * time.Second)
|
||||
},
|
||||
expectedValid: []string{},
|
||||
expectedErr: &keypool.TransientKeyPoolError{RetryAfter: 10 * time.Second},
|
||||
},
|
||||
{
|
||||
// Given: key-0: temporary, key-1: temporary.
|
||||
@@ -424,7 +458,8 @@ func TestWalkerNext(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
key1.MarkTemporary(60 * time.Second)
|
||||
},
|
||||
expectValid: []string{},
|
||||
expectedValid: []string{},
|
||||
expectedErr: &keypool.TransientKeyPoolError{RetryAfter: 60 * time.Second},
|
||||
},
|
||||
{
|
||||
// Given: key-0: permanent, key-1: permanent.
|
||||
@@ -440,7 +475,8 @@ func TestWalkerNext(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
key1.MarkPermanent()
|
||||
},
|
||||
expectValid: []string{},
|
||||
expectedValid: []string{},
|
||||
expectedErr: keypool.ErrPermanentKeyPool,
|
||||
},
|
||||
{
|
||||
// Given: key-0: permanent, key-1: temporary, key-2: permanent.
|
||||
@@ -459,7 +495,8 @@ func TestWalkerNext(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
key2.MarkPermanent()
|
||||
},
|
||||
expectValid: []string{},
|
||||
expectedValid: []string{},
|
||||
expectedErr: &keypool.TransientKeyPoolError{RetryAfter: 60 * time.Second},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -478,7 +515,7 @@ func TestWalkerNext(t *testing.T) {
|
||||
}
|
||||
|
||||
walker := pool.Walker()
|
||||
for _, expectedKey := range tc.expectValid {
|
||||
for _, expectedKey := range tc.expectedValid {
|
||||
key, err := walker.Next()
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, expectedKey, key.Value())
|
||||
@@ -486,7 +523,93 @@ func TestWalkerNext(t *testing.T) {
|
||||
|
||||
// After all expected keys, the walker should be exhausted.
|
||||
_, err = walker.Next()
|
||||
require.ErrorIs(t, err, keypool.ErrAllKeysExhausted)
|
||||
var wantTransient *keypool.TransientKeyPoolError
|
||||
if errors.As(tc.expectedErr, &wantTransient) {
|
||||
var got *keypool.TransientKeyPoolError
|
||||
require.ErrorAs(t, err, &got)
|
||||
assert.Equal(t, wantTransient.RetryAfter, got.RetryAfter)
|
||||
} else {
|
||||
require.ErrorIs(t, err, tc.expectedErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestKeyConcurrent exercises the documented concurrent-safety
|
||||
// contract by hammering a single key with concurrent Mark calls
|
||||
// and asserting the resulting state honors the pool's invariants.
|
||||
func TestKeyConcurrent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
// run is called concurrently from numGoroutines, each
|
||||
// with its own index.
|
||||
run func(idx int, key *keypool.Key)
|
||||
// verify asserts the final state. May advance the clock.
|
||||
verify func(t *testing.T, key *keypool.Key, clk *quartz.Mock)
|
||||
}{
|
||||
{
|
||||
// Half of the goroutines mark the key as temporary
|
||||
// with 60s, the other half with 10s. The longer
|
||||
// cooldown must win regardless of ordering.
|
||||
name: "longer_cooldown_wins",
|
||||
run: func(idx int, key *keypool.Key) {
|
||||
if idx%2 == 0 {
|
||||
key.MarkTemporary(60 * time.Second)
|
||||
} else {
|
||||
key.MarkTemporary(10 * time.Second)
|
||||
}
|
||||
},
|
||||
verify: func(t *testing.T, key *keypool.Key, clk *quartz.Mock) {
|
||||
// At 50s the 60s cooldown is still active.
|
||||
clk.Advance(50 * time.Second)
|
||||
assert.Equal(t, keypool.KeyStateTemporary, key.State())
|
||||
// At 65s the 60s cooldown has expired.
|
||||
clk.Advance(15 * time.Second)
|
||||
assert.Equal(t, keypool.KeyStateValid, key.State())
|
||||
},
|
||||
},
|
||||
{
|
||||
// Half of the goroutines mark the key as permanent,
|
||||
// the other half mark it as temporary. Permanent is
|
||||
// terminal: any permanent call wins.
|
||||
name: "permanent_wins_over_temporary",
|
||||
run: func(idx int, key *keypool.Key) {
|
||||
if idx%2 == 0 {
|
||||
key.MarkPermanent()
|
||||
} else {
|
||||
key.MarkTemporary(60 * time.Second)
|
||||
}
|
||||
},
|
||||
verify: func(t *testing.T, key *keypool.Key, _ *quartz.Mock) {
|
||||
assert.Equal(t, keypool.KeyStatePermanent, key.State())
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
clk := quartz.NewMock(t)
|
||||
pool, err := keypool.New([]string{"key-0"}, clk)
|
||||
require.NoError(t, err)
|
||||
key, err := pool.Walker().Next()
|
||||
require.NoError(t, err)
|
||||
|
||||
const numGoroutines = 10
|
||||
var wg sync.WaitGroup
|
||||
for r := range numGoroutines {
|
||||
wg.Add(1)
|
||||
go func(r int) {
|
||||
defer wg.Done()
|
||||
tc.run(r, key)
|
||||
}(r)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
tc.verify(t, key, clk)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -15,8 +15,10 @@ import (
|
||||
"github.com/coder/coder/v2/aibridge/config"
|
||||
"github.com/coder/coder/v2/aibridge/intercept"
|
||||
"github.com/coder/coder/v2/aibridge/intercept/messages"
|
||||
"github.com/coder/coder/v2/aibridge/keypool"
|
||||
"github.com/coder/coder/v2/aibridge/tracing"
|
||||
"github.com/coder/coder/v2/aibridge/utils"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
// anthropicForwardHeaders lists headers from incoming requests that should be
|
||||
@@ -55,6 +57,24 @@ func NewAnthropic(cfg config.Anthropic, bedrockCfg *config.AWSBedrock) *Anthropi
|
||||
if cfg.BaseURL == "" {
|
||||
cfg.BaseURL = "https://api.anthropic.com/"
|
||||
}
|
||||
// Resolve centralized key configuration into KeyPool.
|
||||
// Precedence:
|
||||
// 1. cfg.KeyPool (explicit, highest priority).
|
||||
// 2. cfg.Key (legacy single key).
|
||||
// After this block cfg.Key is empty so it can only carry a
|
||||
// BYOK X-Api-Key set per interception in CreateInterceptor.
|
||||
// TODO(ssncferreira): simplify auth field resolution per
|
||||
// https://github.com/coder/aibridge/issues/266.
|
||||
if cfg.KeyPool == nil && cfg.Key != "" {
|
||||
// keypool.New only fails on empty or duplicate keys,
|
||||
// neither possible with a single non-empty key.
|
||||
pool, err := keypool.New([]string{cfg.Key}, quartz.NewReal())
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("anthropic provider: build single-key pool: %s", err))
|
||||
}
|
||||
cfg.KeyPool = pool
|
||||
}
|
||||
cfg.Key = ""
|
||||
if cfg.CircuitBreaker != nil {
|
||||
cfg.CircuitBreaker.IsFailure = anthropicIsFailure
|
||||
cfg.CircuitBreaker.OpenErrorResponse = anthropicOpenErrorResponse
|
||||
@@ -119,29 +139,41 @@ func (p *Anthropic) CreateInterceptor(_ http.ResponseWriter, r *http.Request, tr
|
||||
// Any Coder-specific authentication has already been stripped.
|
||||
//
|
||||
// In centralized mode neither Authorization nor X-Api-Key is
|
||||
// present, so cfg keeps the centralized key unchanged.
|
||||
// present, so cfg keeps the KeyPool from provider construction
|
||||
// and the failover loop walks it.
|
||||
//
|
||||
// In BYOK mode the user's LLM credentials survive intact.
|
||||
// If X-Api-Key is present the user has a personal API key;
|
||||
// overwrite the centralized key with it. If Authorization is
|
||||
// present the user authenticated directly with provider;
|
||||
// set BYOKBearerToken and clear the centralized key.
|
||||
// When both are present, X-Api-Key takes priority to match
|
||||
// claude-code behavior.
|
||||
// In BYOK mode the user's LLM credentials survive intact and
|
||||
// failover is disabled by clearing cfg.KeyPool. If X-Api-Key is
|
||||
// present the user has a personal API key, populate cfg.Key.
|
||||
// If Authorization is present the user authenticated directly
|
||||
// with the provider, populate cfg.BYOKBearerToken. When both
|
||||
// are present, X-Api-Key takes priority to match claude-code
|
||||
// behavior.
|
||||
//
|
||||
// TODO(ssncferreira): consolidate auth field handling per
|
||||
// https://github.com/coder/aibridge/issues/266.
|
||||
credKind := intercept.CredentialKindCentralized
|
||||
credSecret := cfg.Key
|
||||
var credSecret string
|
||||
authHeaderName := p.AuthHeader()
|
||||
if apiKey := r.Header.Get("X-Api-Key"); apiKey != "" {
|
||||
cfg.Key = apiKey
|
||||
cfg.KeyPool = nil
|
||||
authHeaderName = "X-Api-Key"
|
||||
credKind = intercept.CredentialKindBYOK
|
||||
credSecret = apiKey
|
||||
} else if token := utils.ExtractBearerToken(r.Header.Get("Authorization")); token != "" {
|
||||
cfg.BYOKBearerToken = token
|
||||
cfg.Key = ""
|
||||
cfg.KeyPool = nil
|
||||
authHeaderName = "Authorization"
|
||||
credKind = intercept.CredentialKindBYOK
|
||||
credSecret = token
|
||||
} else if cfg.KeyPool != nil {
|
||||
// Centralized: use the first key as a placeholder hint.
|
||||
// TODO(ssncferreira): record the actually-used key in
|
||||
// the interception record to reflect failover.
|
||||
if k, err := cfg.KeyPool.Walker().Next(); err == nil {
|
||||
credSecret = k.Value()
|
||||
}
|
||||
}
|
||||
|
||||
cred := intercept.NewCredentialInfo(credKind, credSecret)
|
||||
@@ -175,7 +207,16 @@ func (p *Anthropic) InjectAuthHeader(headers *http.Header) {
|
||||
return
|
||||
}
|
||||
|
||||
headers.Set(p.AuthHeader(), p.cfg.Key)
|
||||
// Centralized: pull a single key from the pool. No failover
|
||||
// or exhaustion handling here.
|
||||
// TODO(ssncferreira): replace with RoundTripper-based auth
|
||||
// in the upstack passthrough PR.
|
||||
if p.cfg.KeyPool == nil {
|
||||
return
|
||||
}
|
||||
if key, err := p.cfg.KeyPool.Walker().Next(); err == nil {
|
||||
headers.Set(p.AuthHeader(), key.Value())
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Anthropic) CircuitBreakerConfig() *config.CircuitBreaker {
|
||||
|
||||
@@ -13,6 +13,8 @@ import (
|
||||
"github.com/coder/coder/v2/aibridge/config"
|
||||
"github.com/coder/coder/v2/aibridge/intercept"
|
||||
"github.com/coder/coder/v2/aibridge/internal/testutil"
|
||||
"github.com/coder/coder/v2/aibridge/keypool"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
func TestAnthropic_TypeAndName(t *testing.T) {
|
||||
@@ -49,6 +51,70 @@ func TestAnthropic_TypeAndName(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewAnthropic_KeyResolution(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pool, err := keypool.New([]string{"pool-key-0", "pool-key-1"}, quartz.NewMock(t))
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
cfg config.Anthropic
|
||||
expectedKeys []string
|
||||
}{
|
||||
{
|
||||
// Legacy single-key path: NewAnthropic builds a
|
||||
// pool containing just that key.
|
||||
name: "key_creates_keypool",
|
||||
cfg: config.Anthropic{Key: "legacy-key"},
|
||||
expectedKeys: []string{"legacy-key"},
|
||||
},
|
||||
{
|
||||
// Caller supplies the pool directly.
|
||||
name: "keypool_passed_directly",
|
||||
cfg: config.Anthropic{KeyPool: pool},
|
||||
expectedKeys: []string{"pool-key-0", "pool-key-1"},
|
||||
},
|
||||
{
|
||||
// Both set: KeyPool wins, Key is ignored.
|
||||
name: "keypool_takes_precedence_over_key",
|
||||
cfg: config.Anthropic{Key: "legacy-key", KeyPool: pool},
|
||||
expectedKeys: []string{"pool-key-0", "pool-key-1"},
|
||||
},
|
||||
{
|
||||
// Neither set: no centralized auth available. BYOK
|
||||
// auth is set per-request in CreateInterceptor.
|
||||
name: "neither_set_no_centralized_auth",
|
||||
cfg: config.Anthropic{},
|
||||
expectedKeys: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
p := NewAnthropic(tc.cfg, nil)
|
||||
|
||||
if tc.expectedKeys == nil {
|
||||
assert.Nil(t, p.cfg.KeyPool, "expected no KeyPool")
|
||||
return
|
||||
}
|
||||
|
||||
require.NotNil(t, p.cfg.KeyPool)
|
||||
walker := p.cfg.KeyPool.Walker()
|
||||
var got []string
|
||||
for {
|
||||
key, err := walker.Next()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
got = append(got, key.Value())
|
||||
}
|
||||
assert.Equal(t, tc.expectedKeys, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAnthropic_CreateInterceptor(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user