mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
refactor(aibridge): clean up keypool and provider error handling (#25609)
## Description Cleans up how key pool errors are represented and how they get turned into HTTP responses. Consolidates two error types into a single type with a kind tag, and gives the response helpers in both providers consistent names. ## Changes - Replaced the keypool sentinel and transient error struct with one error type that carries a kind and a retry-after duration. - Updated `KeyFailoverConfig.BuildKeyPoolResponse` to take the typed key pool error, so each provider can shape the exhaustion response in its own format. - Removed the per-provider `MarkKey` callback from `KeyFailoverConfig` since providers can rely on the shared `MarkKeyOnStatus` helper. - Renamed the response-error helpers so OpenAI and Anthropic use the same naming. Related to: https://linear.app/codercom/issue/AIGOV-334/aibridge-follow-ups-from-key-failover-prs > [!NOTE] > Initially generated by Claude Opus 4.7, modified and reviewed by @ssncferreira
This commit is contained in:
@@ -9,12 +9,10 @@ import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/openai/openai-go/v3"
|
||||
"github.com/openai/openai-go/v3/option"
|
||||
"github.com/openai/openai-go/v3/shared"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
|
||||
@@ -27,7 +25,6 @@ import (
|
||||
"github.com/coder/coder/v2/aibridge/mcp"
|
||||
"github.com/coder/coder/v2/aibridge/recorder"
|
||||
"github.com/coder/coder/v2/aibridge/tracing"
|
||||
"github.com/coder/coder/v2/aibridge/utils"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
@@ -189,7 +186,7 @@ func (i *interceptionBase) unmarshalArgs(in string) (args recorder.ToolArgs) {
|
||||
}
|
||||
|
||||
// writeUpstreamError marshals and writes a given error.
|
||||
func (i *interceptionBase) writeUpstreamError(w http.ResponseWriter, oaiErr *ResponseError) {
|
||||
func (i *interceptionBase) writeUpstreamError(w http.ResponseWriter, oaiErr *intercept.ResponseError) {
|
||||
if oaiErr == nil {
|
||||
return
|
||||
}
|
||||
@@ -235,33 +232,6 @@ func (i *interceptionBase) markKeyOnError(ctx context.Context, key *keypool.Key,
|
||||
)
|
||||
}
|
||||
|
||||
// ProcessKeyPoolError translates a keypool exhaustion error
|
||||
// into a developer-facing responseError shaped for the OpenAI
|
||||
// API. Returns nil if err is not an exhaustion error.
|
||||
func ProcessKeyPoolError(err error) *ResponseError {
|
||||
var transient *keypool.TransientKeyPoolError
|
||||
switch {
|
||||
case errors.As(err, &transient):
|
||||
return newErrorResponse(
|
||||
"all configured keys are rate-limited",
|
||||
intercept.OpenAIErrTypeRateLimit,
|
||||
intercept.OpenAIErrCodeRateLimit,
|
||||
http.StatusTooManyRequests,
|
||||
transient.RetryAfter,
|
||||
)
|
||||
case errors.Is(err, keypool.ErrPermanentKeyPool):
|
||||
return newErrorResponse(
|
||||
"all configured keys failed authentication",
|
||||
intercept.OpenAIErrTypeAPI,
|
||||
intercept.OpenAIErrCodeServer,
|
||||
http.StatusBadGateway,
|
||||
0,
|
||||
)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (i *interceptionBase) hasInjectableTools() bool {
|
||||
return i.mcpProxy != nil && len(i.mcpProxy.ListTools()) > 0
|
||||
}
|
||||
@@ -292,48 +262,3 @@ func calculateActualInputTokenUsage(in openai.CompletionUsage) int64 {
|
||||
return in.PromptTokens /* The aggregated number of text input tokens used, including cached tokens. */ -
|
||||
in.PromptTokensDetails.CachedTokens /* The aggregated number of text input tokens that has been cached from previous requests. */
|
||||
}
|
||||
|
||||
func getErrorResponse(err error) *ResponseError {
|
||||
var apiErr *openai.Error
|
||||
if !errors.As(err, &apiErr) {
|
||||
return nil
|
||||
}
|
||||
return newErrorResponse(apiErr.Message, apiErr.Type, apiErr.Code, apiErr.StatusCode, keypool.ParseRetryAfter(apiErr.Response))
|
||||
}
|
||||
|
||||
var _ error = &ResponseError{}
|
||||
|
||||
type ResponseError struct {
|
||||
ErrorObject *shared.ErrorObject `json:"error"`
|
||||
StatusCode int `json:"-"`
|
||||
RetryAfter time.Duration `json:"-"`
|
||||
}
|
||||
|
||||
func newErrorResponse(msg, errType, code string, status int, retryAfter time.Duration) *ResponseError {
|
||||
return &ResponseError{
|
||||
ErrorObject: &shared.ErrorObject{
|
||||
Code: code,
|
||||
Message: msg,
|
||||
Type: errType,
|
||||
},
|
||||
StatusCode: status,
|
||||
RetryAfter: retryAfter,
|
||||
}
|
||||
}
|
||||
|
||||
func (e *ResponseError) Error() string {
|
||||
if e.ErrorObject == nil {
|
||||
return ""
|
||||
}
|
||||
return e.ErrorObject.Message
|
||||
}
|
||||
|
||||
// ToResponse marshals e into an *http.Response shaped for the
|
||||
// OpenAI API.
|
||||
func (e *ResponseError) ToResponse() *http.Response {
|
||||
body, err := json.Marshal(e)
|
||||
if err != nil {
|
||||
body = []byte(`{"error":{"type":"error","message":"error marshaling upstream error","code":"server_error"}}`)
|
||||
}
|
||||
return utils.NewJSONErrorResponse(e.StatusCode, e.RetryAfter, body)
|
||||
}
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/aibridge/config"
|
||||
"github.com/coder/coder/v2/aibridge/intercept"
|
||||
"github.com/coder/coder/v2/aibridge/keypool"
|
||||
"github.com/coder/coder/v2/aibridge/utils"
|
||||
"github.com/coder/quartz"
|
||||
@@ -86,59 +87,6 @@ func TestScanForCorrelatingToolCallID(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessKeyPoolError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
expectedNil bool
|
||||
expectedStatus int
|
||||
expectedRetryAfter time.Duration
|
||||
}{
|
||||
{
|
||||
// Transient with valid keys present: 429, no Retry-After.
|
||||
name: "transient_zero_retry_after",
|
||||
err: &keypool.TransientKeyPoolError{},
|
||||
expectedStatus: http.StatusTooManyRequests,
|
||||
expectedRetryAfter: 0,
|
||||
},
|
||||
{
|
||||
// Transient with cooldown: 429, Retry-After set.
|
||||
name: "transient_with_retry_after",
|
||||
err: &keypool.TransientKeyPoolError{RetryAfter: 5 * time.Second},
|
||||
expectedStatus: http.StatusTooManyRequests,
|
||||
expectedRetryAfter: 5 * time.Second,
|
||||
},
|
||||
{
|
||||
// Permanent: 502 api_error.
|
||||
name: "permanent_returns_502",
|
||||
err: keypool.ErrPermanentKeyPool,
|
||||
expectedStatus: http.StatusBadGateway,
|
||||
},
|
||||
{
|
||||
// Anything else: not a pool-exhaustion error.
|
||||
name: "non_pool_exhaustion_error_returns_nil",
|
||||
err: xerrors.New("some other error"),
|
||||
expectedNil: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := ProcessKeyPoolError(tc.err)
|
||||
if tc.expectedNil {
|
||||
require.Nil(t, got)
|
||||
return
|
||||
}
|
||||
require.NotNil(t, got)
|
||||
assert.Equal(t, tc.expectedStatus, got.StatusCode)
|
||||
assert.Equal(t, tc.expectedRetryAfter, got.RetryAfter)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarkKeyOnError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -190,8 +138,8 @@ func TestMarkKeyOnError(t *testing.T) {
|
||||
t.Parallel()
|
||||
pool, err := keypool.New([]string{"key-0"}, quartz.NewMock(t))
|
||||
require.NoError(t, err)
|
||||
key, err := pool.Walker().Next()
|
||||
require.NoError(t, err)
|
||||
key, keyPoolErr := pool.Walker().Next()
|
||||
require.Nil(t, keyPoolErr)
|
||||
|
||||
base := &interceptionBase{cfg: config.OpenAI{KeyPool: pool}, logger: slog.Make()}
|
||||
|
||||
@@ -207,7 +155,7 @@ func TestWriteUpstreamError(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
respErr *ResponseError
|
||||
respErr *intercept.ResponseError
|
||||
expectStatus int
|
||||
// Empty string means the header should be absent.
|
||||
expectRetryAfter string
|
||||
@@ -217,42 +165,42 @@ func TestWriteUpstreamError(t *testing.T) {
|
||||
{
|
||||
// Standard error: status, code, and JSON body written.
|
||||
name: "writes_status_and_body",
|
||||
respErr: newErrorResponse("upstream failed", "api_error", "server_error", http.StatusBadGateway, 0),
|
||||
respErr: intercept.NewResponseError("upstream failed", "api_error", "server_error", http.StatusBadGateway, 0),
|
||||
expectStatus: http.StatusBadGateway,
|
||||
expectBodyContains: `"upstream failed"`,
|
||||
},
|
||||
{
|
||||
// OpenAI envelope: the code field round-trips into the body.
|
||||
name: "writes_code_field",
|
||||
respErr: newErrorResponse("rate limited", "rate_limit_error", "rate_limit_exceeded", http.StatusTooManyRequests, 0),
|
||||
respErr: intercept.NewResponseError("rate limited", "rate_limit_error", "rate_limit_exceeded", http.StatusTooManyRequests, 0),
|
||||
expectStatus: http.StatusTooManyRequests,
|
||||
expectBodyContains: `"rate_limit_exceeded"`,
|
||||
},
|
||||
{
|
||||
// Whole-second retryAfter: emitted as integer seconds.
|
||||
name: "retry_after_in_seconds",
|
||||
respErr: newErrorResponse("rate limited", "rate_limit_error", "rate_limit_exceeded", http.StatusTooManyRequests, 60*time.Second),
|
||||
respErr: intercept.NewResponseError("rate limited", "rate_limit_error", "rate_limit_exceeded", http.StatusTooManyRequests, 60*time.Second),
|
||||
expectStatus: http.StatusTooManyRequests,
|
||||
expectRetryAfter: "60",
|
||||
},
|
||||
{
|
||||
// 500ms rounds up to Retry-After: 1.
|
||||
name: "retry_after_500ms_rounds_up_to_one",
|
||||
respErr: newErrorResponse("rate limited", "rate_limit_error", "rate_limit_exceeded", http.StatusTooManyRequests, 500*time.Millisecond),
|
||||
respErr: intercept.NewResponseError("rate limited", "rate_limit_error", "rate_limit_exceeded", http.StatusTooManyRequests, 500*time.Millisecond),
|
||||
expectStatus: http.StatusTooManyRequests,
|
||||
expectRetryAfter: "1",
|
||||
},
|
||||
{
|
||||
// 200ms rounds up to Retry-After: 1.
|
||||
name: "retry_after_200ms_rounds_up_to_one",
|
||||
respErr: newErrorResponse("rate limited", "rate_limit_error", "rate_limit_exceeded", http.StatusTooManyRequests, 200*time.Millisecond),
|
||||
respErr: intercept.NewResponseError("rate limited", "rate_limit_error", "rate_limit_exceeded", http.StatusTooManyRequests, 200*time.Millisecond),
|
||||
expectStatus: http.StatusTooManyRequests,
|
||||
expectRetryAfter: "1",
|
||||
},
|
||||
{
|
||||
// Negative retryAfter: header omitted.
|
||||
name: "negative_retry_after_omits_header",
|
||||
respErr: newErrorResponse("rate limited", "rate_limit_error", "rate_limit_exceeded", http.StatusTooManyRequests, -1*time.Second),
|
||||
respErr: intercept.NewResponseError("rate limited", "rate_limit_error", "rate_limit_exceeded", http.StatusTooManyRequests, -1*time.Second),
|
||||
expectStatus: http.StatusTooManyRequests,
|
||||
expectRetryAfter: "",
|
||||
},
|
||||
|
||||
@@ -3,6 +3,7 @@ package chatcompletions
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -19,6 +20,7 @@ import (
|
||||
aibcontext "github.com/coder/coder/v2/aibridge/context"
|
||||
"github.com/coder/coder/v2/aibridge/intercept"
|
||||
"github.com/coder/coder/v2/aibridge/intercept/eventstream"
|
||||
"github.com/coder/coder/v2/aibridge/keypool"
|
||||
"github.com/coder/coder/v2/aibridge/mcp"
|
||||
"github.com/coder/coder/v2/aibridge/recorder"
|
||||
"github.com/coder/coder/v2/aibridge/tracing"
|
||||
@@ -224,12 +226,13 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req
|
||||
|
||||
// The failover loop may return a keypool exhaustion
|
||||
// error. Check before the SDK-error path.
|
||||
if keyErr := ProcessKeyPoolError(err); keyErr != nil {
|
||||
i.writeUpstreamError(w, keyErr)
|
||||
var keyPoolErr *keypool.Error
|
||||
if errors.As(err, &keyPoolErr) {
|
||||
i.writeUpstreamError(w, intercept.ResponseErrorFromKeyPool(keyPoolErr))
|
||||
return xerrors.Errorf("key pool exhausted: %w", err)
|
||||
}
|
||||
|
||||
if apiErr := getErrorResponse(err); apiErr != nil {
|
||||
if apiErr := intercept.ResponseErrorFromAPIError(err); apiErr != nil {
|
||||
i.writeUpstreamError(w, apiErr)
|
||||
return xerrors.Errorf("openai API error: %w", err)
|
||||
}
|
||||
@@ -293,9 +296,9 @@ func (i *BlockingInterception) newChatCompletionWithKeyFailover(ctx context.Cont
|
||||
// success, the last tried key on failure) in the upstack PR.
|
||||
walker := i.cfg.KeyPool.Walker()
|
||||
for {
|
||||
key, err := walker.Next()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
key, keyPoolErr := walker.Next()
|
||||
if keyPoolErr != nil {
|
||||
return nil, keyPoolErr
|
||||
}
|
||||
|
||||
requestOpts := append([]option.RequestOption{}, opts...)
|
||||
|
||||
@@ -143,8 +143,9 @@ func (i *StreamingInterception) ProcessRequest(w http.ResponseWriter, r *http.Re
|
||||
var opts []option.RequestOption
|
||||
var currentKey *keypool.Key
|
||||
if walker != nil {
|
||||
key, err := walker.Next()
|
||||
if respErr := ProcessKeyPoolError(err); respErr != nil {
|
||||
key, keyPoolErr := walker.Next()
|
||||
if keyPoolErr != nil {
|
||||
respErr := intercept.ResponseErrorFromKeyPool(keyPoolErr)
|
||||
// Pool exhausted in this iteration. Relay the
|
||||
// error to the client: as an SSE event if events
|
||||
// have already been sent, or by direct write
|
||||
@@ -470,17 +471,17 @@ func (i *StreamingInterception) newStream(ctx context.Context, svc openai.ChatCo
|
||||
}
|
||||
|
||||
// mapStreamError converts a mid-stream upstream error or
|
||||
// processing error into a relayable responseError. Returns nil
|
||||
// processing error into a relayable ResponseError. Returns nil
|
||||
// when the error is unrecoverable, in which case nothing can be
|
||||
// relayed back.
|
||||
func (*StreamingInterception) mapStreamError(ctx context.Context, logger slog.Logger, streamErr, lastErr error) *ResponseError {
|
||||
func (*StreamingInterception) mapStreamError(ctx context.Context, logger slog.Logger, streamErr, lastErr error) *intercept.ResponseError {
|
||||
if streamErr != nil {
|
||||
if eventstream.IsUnrecoverableError(streamErr) {
|
||||
logger.Debug(ctx, "stream terminated", slog.Error(streamErr))
|
||||
// We can't reflect an error back if there's a connection error or the request context was canceled.
|
||||
return nil
|
||||
}
|
||||
if oaiErr := getErrorResponse(streamErr); oaiErr != nil {
|
||||
if oaiErr := intercept.ResponseErrorFromAPIError(streamErr); oaiErr != nil {
|
||||
logger.Warn(ctx, "openai stream error", slog.Error(streamErr))
|
||||
return oaiErr
|
||||
}
|
||||
@@ -489,11 +490,11 @@ func (*StreamingInterception) mapStreamError(ctx context.Context, logger slog.Lo
|
||||
// into known types (i.e. [shared.OverloadedError]).
|
||||
// See https://github.com/openai/openai-go/blob/v2.7.0/packages/ssestream/ssestream.go#L171
|
||||
// All it does is wrap the payload in an error - which is all we can return, currently.
|
||||
return newErrorResponse(fmt.Sprintf("unknown stream error: %s", streamErr), intercept.OpenAIErrTypeError, intercept.OpenAIErrTypeError, http.StatusBadGateway, 0)
|
||||
return intercept.NewResponseError(fmt.Sprintf("unknown stream error: %s", streamErr), intercept.OpenAIErrTypeError, intercept.OpenAIErrTypeError, http.StatusBadGateway, 0)
|
||||
}
|
||||
if lastErr != nil {
|
||||
logger.Warn(ctx, "stream processing failed", slog.Error(lastErr))
|
||||
return newErrorResponse(fmt.Sprintf("processing error: %s", lastErr), intercept.OpenAIErrTypeError, intercept.OpenAIErrTypeError, http.StatusBadGateway, 0)
|
||||
return intercept.NewResponseError(fmt.Sprintf("processing error: %s", lastErr), intercept.OpenAIErrTypeError, intercept.OpenAIErrTypeError, http.StatusBadGateway, 0)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -583,32 +583,36 @@ func (i *interceptionBase) markKeyOnError(ctx context.Context, key *keypool.Key,
|
||||
)
|
||||
}
|
||||
|
||||
// ProcessKeyPoolError translates a keypool exhaustion error
|
||||
// into a developer-facing responseError shaped for the Anthropic
|
||||
// API. Returns nil if err is not an exhaustion error.
|
||||
func ProcessKeyPoolError(err error) *ResponseError {
|
||||
var transient *keypool.TransientKeyPoolError
|
||||
switch {
|
||||
case errors.As(err, &transient):
|
||||
return newErrorResponse(
|
||||
"all configured keys are rate-limited",
|
||||
string(constant.ValueOf[constant.RateLimitError]()),
|
||||
http.StatusTooManyRequests,
|
||||
transient.RetryAfter,
|
||||
)
|
||||
case errors.Is(err, keypool.ErrPermanentKeyPool):
|
||||
return newErrorResponse(
|
||||
"all configured keys failed authentication",
|
||||
// ResponseErrorFromKeyPool translates a *keypool.Error into
|
||||
// a developer-facing ResponseError shaped for the Anthropic API.
|
||||
func ResponseErrorFromKeyPool(keyPoolErr *keypool.Error) *ResponseError {
|
||||
switch keyPoolErr.Kind {
|
||||
case keypool.ErrorKindPermanent:
|
||||
return newResponseError(
|
||||
keyPoolErr.Error(),
|
||||
string(constant.ValueOf[constant.APIError]()),
|
||||
http.StatusBadGateway,
|
||||
0,
|
||||
keyPoolErr.RetryAfter,
|
||||
)
|
||||
case keypool.ErrorKindRateLimited:
|
||||
return newResponseError(
|
||||
keyPoolErr.Error(),
|
||||
string(constant.ValueOf[constant.RateLimitError]()),
|
||||
http.StatusTooManyRequests,
|
||||
keyPoolErr.RetryAfter,
|
||||
)
|
||||
default:
|
||||
return nil
|
||||
// Fall back to a generic 502.
|
||||
return newResponseError(
|
||||
keyPoolErr.Error(),
|
||||
string(constant.ValueOf[constant.APIError]()),
|
||||
http.StatusBadGateway,
|
||||
keyPoolErr.RetryAfter,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func getErrorResponse(err error) *ResponseError {
|
||||
func responseErrorFromAPIError(err error) *ResponseError {
|
||||
var apierr *anthropic.Error
|
||||
if !errors.As(err, &apierr) {
|
||||
return nil
|
||||
@@ -626,7 +630,7 @@ func getErrorResponse(err error) *ResponseError {
|
||||
errType = string(detail.Type)
|
||||
}
|
||||
|
||||
return newErrorResponse(msg, errType, apierr.StatusCode, keypool.ParseRetryAfter(apierr.Response))
|
||||
return newResponseError(msg, errType, apierr.StatusCode, keypool.ParseRetryAfter(apierr.Response))
|
||||
}
|
||||
|
||||
var _ error = &ResponseError{}
|
||||
@@ -638,7 +642,7 @@ type ResponseError struct {
|
||||
RetryAfter time.Duration `json:"-"`
|
||||
}
|
||||
|
||||
func newErrorResponse(msg, errType string, status int, retryAfter time.Duration) *ResponseError {
|
||||
func newResponseError(msg, errType string, status int, retryAfter time.Duration) *ResponseError {
|
||||
return &ResponseError{
|
||||
ErrorResponse: &shared.ErrorResponse{
|
||||
Error: shared.ErrorObjectUnion{
|
||||
|
||||
@@ -1061,52 +1061,41 @@ func TestFilterBedrockBetaFlags(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessKeyPoolError(t *testing.T) {
|
||||
func TestResponseErrorFromKeyPool(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
expectedNil bool
|
||||
keyPoolErr *keypool.Error
|
||||
expectedStatus int
|
||||
expectedRetryAfter time.Duration
|
||||
}{
|
||||
{
|
||||
// Transient with valid keys present: 429, no Retry-After.
|
||||
name: "transient_zero_retry_after",
|
||||
err: &keypool.TransientKeyPoolError{},
|
||||
// Rate-limited with no cooldown: 429, no Retry-After.
|
||||
name: "rate_limited_zero_retry_after",
|
||||
keyPoolErr: &keypool.Error{Kind: keypool.ErrorKindRateLimited},
|
||||
expectedStatus: http.StatusTooManyRequests,
|
||||
expectedRetryAfter: 0,
|
||||
},
|
||||
{
|
||||
// Transient with cooldown: 429, Retry-After set.
|
||||
name: "transient_with_retry_after",
|
||||
err: &keypool.TransientKeyPoolError{RetryAfter: 5 * time.Second},
|
||||
// Rate-limited with cooldown: 429, Retry-After set.
|
||||
name: "rate_limited_with_retry_after",
|
||||
keyPoolErr: &keypool.Error{Kind: keypool.ErrorKindRateLimited, RetryAfter: 5 * time.Second},
|
||||
expectedStatus: http.StatusTooManyRequests,
|
||||
expectedRetryAfter: 5 * time.Second,
|
||||
},
|
||||
{
|
||||
// Permanent: 502 api_error.
|
||||
name: "permanent_returns_502",
|
||||
err: keypool.ErrPermanentKeyPool,
|
||||
keyPoolErr: &keypool.Error{Kind: keypool.ErrorKindPermanent},
|
||||
expectedStatus: http.StatusBadGateway,
|
||||
},
|
||||
{
|
||||
// Anything else: not a pool-exhaustion error.
|
||||
name: "non_pool_exhaustion_error_returns_nil",
|
||||
err: xerrors.New("some other error"),
|
||||
expectedNil: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := ProcessKeyPoolError(tc.err)
|
||||
if tc.expectedNil {
|
||||
require.Nil(t, got)
|
||||
return
|
||||
}
|
||||
got := ResponseErrorFromKeyPool(tc.keyPoolErr)
|
||||
require.NotNil(t, got)
|
||||
assert.Equal(t, tc.expectedStatus, got.StatusCode)
|
||||
assert.Equal(t, tc.expectedRetryAfter, got.RetryAfter)
|
||||
@@ -1165,8 +1154,8 @@ func TestMarkKeyOnError(t *testing.T) {
|
||||
t.Parallel()
|
||||
pool, err := keypool.New([]string{"key-0"}, quartz.NewMock(t))
|
||||
require.NoError(t, err)
|
||||
key, err := pool.Walker().Next()
|
||||
require.NoError(t, err)
|
||||
key, keyPoolErr := pool.Walker().Next()
|
||||
require.Nil(t, keyPoolErr)
|
||||
|
||||
base := &interceptionBase{cfg: config.Anthropic{KeyPool: pool}, logger: slog.Make()}
|
||||
|
||||
@@ -1192,35 +1181,35 @@ func TestWriteUpstreamError(t *testing.T) {
|
||||
{
|
||||
// Standard error: status and JSON body written.
|
||||
name: "writes_status_and_body",
|
||||
respErr: newErrorResponse("upstream failed", "api_error", http.StatusBadGateway, 0),
|
||||
respErr: newResponseError("upstream failed", "api_error", http.StatusBadGateway, 0),
|
||||
expectStatus: http.StatusBadGateway,
|
||||
expectBodyContains: `"upstream failed"`,
|
||||
},
|
||||
{
|
||||
// Whole-second retryAfter: emitted as integer seconds.
|
||||
name: "retry_after_in_seconds",
|
||||
respErr: newErrorResponse("rate limited", "rate_limit_error", http.StatusTooManyRequests, 60*time.Second),
|
||||
respErr: newResponseError("rate limited", "rate_limit_error", http.StatusTooManyRequests, 60*time.Second),
|
||||
expectStatus: http.StatusTooManyRequests,
|
||||
expectRetryAfter: "60",
|
||||
},
|
||||
{
|
||||
// 500ms rounds up to Retry-After: 1.
|
||||
name: "retry_after_500ms_rounds_up_to_one",
|
||||
respErr: newErrorResponse("rate limited", "rate_limit_error", http.StatusTooManyRequests, 500*time.Millisecond),
|
||||
respErr: newResponseError("rate limited", "rate_limit_error", http.StatusTooManyRequests, 500*time.Millisecond),
|
||||
expectStatus: http.StatusTooManyRequests,
|
||||
expectRetryAfter: "1",
|
||||
},
|
||||
{
|
||||
// 200ms rounds up to Retry-After: 1.
|
||||
name: "retry_after_200ms_rounds_up_to_one",
|
||||
respErr: newErrorResponse("rate limited", "rate_limit_error", http.StatusTooManyRequests, 200*time.Millisecond),
|
||||
respErr: newResponseError("rate limited", "rate_limit_error", http.StatusTooManyRequests, 200*time.Millisecond),
|
||||
expectStatus: http.StatusTooManyRequests,
|
||||
expectRetryAfter: "1",
|
||||
},
|
||||
{
|
||||
// Negative retryAfter: header omitted.
|
||||
name: "negative_retry_after_omits_header",
|
||||
respErr: newErrorResponse("rate limited", "rate_limit_error", http.StatusTooManyRequests, -1*time.Second),
|
||||
respErr: newResponseError("rate limited", "rate_limit_error", http.StatusTooManyRequests, -1*time.Second),
|
||||
expectStatus: http.StatusTooManyRequests,
|
||||
expectRetryAfter: "",
|
||||
},
|
||||
|
||||
@@ -2,6 +2,7 @@ package messages
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
@@ -20,6 +21,7 @@ import (
|
||||
aibcontext "github.com/coder/coder/v2/aibridge/context"
|
||||
"github.com/coder/coder/v2/aibridge/intercept"
|
||||
"github.com/coder/coder/v2/aibridge/intercept/eventstream"
|
||||
"github.com/coder/coder/v2/aibridge/keypool"
|
||||
"github.com/coder/coder/v2/aibridge/mcp"
|
||||
"github.com/coder/coder/v2/aibridge/recorder"
|
||||
"github.com/coder/coder/v2/aibridge/tracing"
|
||||
@@ -114,12 +116,13 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req
|
||||
|
||||
// The failover loop may return a keypool exhaustion
|
||||
// error. Check before the SDK-error path.
|
||||
if keyErr := ProcessKeyPoolError(err); keyErr != nil {
|
||||
i.writeUpstreamError(w, keyErr)
|
||||
var keyPoolErr *keypool.Error
|
||||
if errors.As(err, &keyPoolErr) {
|
||||
i.writeUpstreamError(w, ResponseErrorFromKeyPool(keyPoolErr))
|
||||
return xerrors.Errorf("key pool exhausted: %w", err)
|
||||
}
|
||||
|
||||
if antErr := getErrorResponse(err); antErr != nil {
|
||||
if antErr := responseErrorFromAPIError(err); antErr != nil {
|
||||
i.writeUpstreamError(w, antErr)
|
||||
return xerrors.Errorf("anthropic API error: %w", err)
|
||||
}
|
||||
@@ -369,9 +372,9 @@ func (i *BlockingInterception) newMessageWithKeyFailover(ctx context.Context, sv
|
||||
// success, the last tried key on failure) in the upstack PR.
|
||||
walker := i.cfg.KeyPool.Walker()
|
||||
for {
|
||||
key, err := walker.Next()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
key, keyPoolErr := walker.Next()
|
||||
if keyPoolErr != nil {
|
||||
return nil, keyPoolErr
|
||||
}
|
||||
|
||||
msg, err := i.newMessageWithKey(ctx, svc,
|
||||
|
||||
@@ -174,12 +174,13 @@ newStream:
|
||||
var streamOpts []option.RequestOption
|
||||
var currentKey *keypool.Key
|
||||
if walker != nil {
|
||||
key, err := walker.Next()
|
||||
if respErr := ProcessKeyPoolError(err); respErr != nil {
|
||||
key, keyPoolErr := walker.Next()
|
||||
if keyPoolErr != nil {
|
||||
// Pool exhausted in this iteration. Relay the
|
||||
// error to the client: as an SSE event if events
|
||||
// have already been sent, or by direct write
|
||||
// otherwise.
|
||||
respErr := ResponseErrorFromKeyPool(keyPoolErr)
|
||||
interceptionErr = respErr
|
||||
if events.IsStreaming() {
|
||||
payload, mErr := i.marshal(respErr)
|
||||
@@ -607,7 +608,7 @@ func (*StreamingInterception) mapStreamError(ctx context.Context, logger slog.Lo
|
||||
// We can't reflect an error back if there's a connection error or the request context was canceled.
|
||||
return nil
|
||||
}
|
||||
if antErr := getErrorResponse(streamErr); antErr != nil {
|
||||
if antErr := responseErrorFromAPIError(streamErr); antErr != nil {
|
||||
logger.Warn(ctx, "anthropic stream error", slog.Error(streamErr))
|
||||
return antErr
|
||||
}
|
||||
@@ -616,11 +617,11 @@ func (*StreamingInterception) mapStreamError(ctx context.Context, logger slog.Lo
|
||||
// into known types (i.e. [shared.OverloadedError]).
|
||||
// See https://github.com/anthropics/anthropic-sdk-go/blob/v1.12.0/packages/ssestream/ssestream.go#L172-L174
|
||||
// All it does is wrap the payload in an error - which is all we can return, currently.
|
||||
return newErrorResponse(fmt.Sprintf("unknown stream error: %s", streamErr), string(constant.ValueOf[constant.Error]()), http.StatusBadGateway, 0)
|
||||
return newResponseError(fmt.Sprintf("unknown stream error: %s", streamErr), string(constant.ValueOf[constant.Error]()), http.StatusBadGateway, 0)
|
||||
}
|
||||
if lastErr != nil {
|
||||
logger.Warn(ctx, "stream processing failed", slog.Error(lastErr))
|
||||
return newErrorResponse(fmt.Sprintf("processing error: %s", lastErr), string(constant.ValueOf[constant.Error]()), http.StatusBadGateway, 0)
|
||||
return newResponseError(fmt.Sprintf("processing error: %s", lastErr), string(constant.ValueOf[constant.Error]()), http.StatusBadGateway, 0)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,5 +1,18 @@
|
||||
package intercept
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/openai/openai-go/v3"
|
||||
"github.com/openai/openai-go/v3/shared"
|
||||
|
||||
"github.com/coder/coder/v2/aibridge/keypool"
|
||||
"github.com/coder/coder/v2/aibridge/utils"
|
||||
)
|
||||
|
||||
// OpenAI error type and code constants used by the chatcompletions
|
||||
// and responses interceptors. The OpenAI Go SDK does not expose
|
||||
// these as typed constants, so we define our own.
|
||||
@@ -12,3 +25,89 @@ const (
|
||||
OpenAIErrCodeServer = "server_error"
|
||||
OpenAIErrCodeRateLimit = "rate_limit_exceeded"
|
||||
)
|
||||
|
||||
var _ error = &ResponseError{}
|
||||
|
||||
// ResponseError is the OpenAI-shaped error envelope returned to
|
||||
// clients. StatusCode and RetryAfter map to HTTP headers, not JSON
|
||||
// fields. The chatcompletions and responses interceptors both
|
||||
// use the same response error format.
|
||||
type ResponseError struct {
|
||||
ErrorObject *shared.ErrorObject `json:"error"`
|
||||
StatusCode int `json:"-"`
|
||||
RetryAfter time.Duration `json:"-"`
|
||||
}
|
||||
|
||||
// NewResponseError builds a ResponseError with the OpenAI-shaped
|
||||
// envelope. errType and code should be one of the OpenAIErrType*
|
||||
// and OpenAIErrCode* constants defined above.
|
||||
func NewResponseError(msg, errType, code string, status int, retryAfter time.Duration) *ResponseError {
|
||||
return &ResponseError{
|
||||
ErrorObject: &shared.ErrorObject{
|
||||
Code: code,
|
||||
Message: msg,
|
||||
Type: errType,
|
||||
},
|
||||
StatusCode: status,
|
||||
RetryAfter: retryAfter,
|
||||
}
|
||||
}
|
||||
|
||||
func (e *ResponseError) Error() string {
|
||||
if e.ErrorObject == nil {
|
||||
return ""
|
||||
}
|
||||
return e.ErrorObject.Message
|
||||
}
|
||||
|
||||
// ToResponse marshals e into an *http.Response shaped for the
|
||||
// OpenAI API.
|
||||
func (e *ResponseError) ToResponse() *http.Response {
|
||||
body, err := json.Marshal(e)
|
||||
if err != nil {
|
||||
body = []byte(`{"error":{"type":"error","message":"error marshaling upstream error","code":"server_error"}}`)
|
||||
}
|
||||
return utils.NewJSONErrorResponse(e.StatusCode, e.RetryAfter, body)
|
||||
}
|
||||
|
||||
// ResponseErrorFromKeyPool translates a *keypool.Error into
|
||||
// a developer-facing ResponseError shaped for the OpenAI API.
|
||||
func ResponseErrorFromKeyPool(keyPoolErr *keypool.Error) *ResponseError {
|
||||
switch keyPoolErr.Kind {
|
||||
case keypool.ErrorKindPermanent:
|
||||
return NewResponseError(
|
||||
keyPoolErr.Error(),
|
||||
OpenAIErrTypeAPI,
|
||||
OpenAIErrCodeServer,
|
||||
http.StatusBadGateway,
|
||||
keyPoolErr.RetryAfter,
|
||||
)
|
||||
case keypool.ErrorKindRateLimited:
|
||||
return NewResponseError(
|
||||
keyPoolErr.Error(),
|
||||
OpenAIErrTypeRateLimit,
|
||||
OpenAIErrCodeRateLimit,
|
||||
http.StatusTooManyRequests,
|
||||
keyPoolErr.RetryAfter,
|
||||
)
|
||||
default:
|
||||
// Fall back to a generic 502.
|
||||
return NewResponseError(
|
||||
keyPoolErr.Error(),
|
||||
OpenAIErrTypeAPI,
|
||||
OpenAIErrCodeServer,
|
||||
http.StatusBadGateway,
|
||||
keyPoolErr.RetryAfter,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// ResponseErrorFromAPIError converts an OpenAI SDK error into a
|
||||
// ResponseError. Returns nil if err is not an *openai.Error.
|
||||
func ResponseErrorFromAPIError(err error) *ResponseError {
|
||||
var apiErr *openai.Error
|
||||
if !errors.As(err, &apiErr) {
|
||||
return nil
|
||||
}
|
||||
return NewResponseError(apiErr.Message, apiErr.Type, apiErr.Code, apiErr.StatusCode, keypool.ParseRetryAfter(apiErr.Response))
|
||||
}
|
||||
|
||||
@@ -0,0 +1,55 @@
|
||||
package intercept_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/aibridge/intercept"
|
||||
"github.com/coder/coder/v2/aibridge/keypool"
|
||||
)
|
||||
|
||||
func TestResponseErrorFromKeyPool(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
keyPoolErr *keypool.Error
|
||||
expectedStatus int
|
||||
expectedRetryAfter time.Duration
|
||||
}{
|
||||
{
|
||||
// Rate-limited with no cooldown: 429, no Retry-After.
|
||||
name: "rate_limited_zero_retry_after",
|
||||
keyPoolErr: &keypool.Error{Kind: keypool.ErrorKindRateLimited},
|
||||
expectedStatus: http.StatusTooManyRequests,
|
||||
expectedRetryAfter: 0,
|
||||
},
|
||||
{
|
||||
// Rate-limited with cooldown: 429, Retry-After set.
|
||||
name: "rate_limited_with_retry_after",
|
||||
keyPoolErr: &keypool.Error{Kind: keypool.ErrorKindRateLimited, RetryAfter: 5 * time.Second},
|
||||
expectedStatus: http.StatusTooManyRequests,
|
||||
expectedRetryAfter: 5 * time.Second,
|
||||
},
|
||||
{
|
||||
// Permanent: 502 api_error.
|
||||
name: "permanent_returns_502",
|
||||
keyPoolErr: &keypool.Error{Kind: keypool.ErrorKindPermanent},
|
||||
expectedStatus: http.StatusBadGateway,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := intercept.ResponseErrorFromKeyPool(tc.keyPoolErr)
|
||||
require.NotNil(t, got)
|
||||
assert.Equal(t, tc.expectedStatus, got.StatusCode)
|
||||
assert.Equal(t, tc.expectedRetryAfter, got.RetryAfter)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -19,7 +19,6 @@ import (
|
||||
"github.com/openai/openai-go/v3"
|
||||
"github.com/openai/openai-go/v3/option"
|
||||
"github.com/openai/openai-go/v3/responses"
|
||||
"github.com/openai/openai-go/v3/shared"
|
||||
"github.com/openai/openai-go/v3/shared/constant"
|
||||
"github.com/tidwall/gjson"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
@@ -35,7 +34,6 @@ import (
|
||||
"github.com/coder/coder/v2/aibridge/mcp"
|
||||
"github.com/coder/coder/v2/aibridge/recorder"
|
||||
"github.com/coder/coder/v2/aibridge/tracing"
|
||||
"github.com/coder/coder/v2/aibridge/utils"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
@@ -143,7 +141,7 @@ func (i *responsesInterceptionBase) validateRequest(ctx context.Context, w http.
|
||||
}
|
||||
|
||||
// writeUpstreamError marshals and writes a given error.
|
||||
func (i *responsesInterceptionBase) writeUpstreamError(w http.ResponseWriter, oaiErr *ResponseError) {
|
||||
func (i *responsesInterceptionBase) writeUpstreamError(w http.ResponseWriter, oaiErr *intercept.ResponseError) {
|
||||
if oaiErr == nil {
|
||||
return
|
||||
}
|
||||
@@ -189,70 +187,6 @@ func (i *responsesInterceptionBase) markKeyOnError(ctx context.Context, key *key
|
||||
)
|
||||
}
|
||||
|
||||
// ProcessKeyPoolError translates a keypool exhaustion error
|
||||
// into a developer-facing ResponseError shaped for the OpenAI
|
||||
// API. Returns nil if err is not an exhaustion error.
|
||||
func ProcessKeyPoolError(err error) *ResponseError {
|
||||
var transient *keypool.TransientKeyPoolError
|
||||
switch {
|
||||
case errors.As(err, &transient):
|
||||
return newErrorResponse(
|
||||
"all configured keys are rate-limited",
|
||||
intercept.OpenAIErrTypeRateLimit,
|
||||
intercept.OpenAIErrCodeRateLimit,
|
||||
http.StatusTooManyRequests,
|
||||
transient.RetryAfter,
|
||||
)
|
||||
case errors.Is(err, keypool.ErrPermanentKeyPool):
|
||||
return newErrorResponse(
|
||||
"all configured keys failed authentication",
|
||||
intercept.OpenAIErrTypeAPI,
|
||||
intercept.OpenAIErrCodeServer,
|
||||
http.StatusBadGateway,
|
||||
0,
|
||||
)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func newErrorResponse(msg, errType, code string, status int, retryAfter time.Duration) *ResponseError {
|
||||
return &ResponseError{
|
||||
ErrorObject: &shared.ErrorObject{
|
||||
Code: code,
|
||||
Message: msg,
|
||||
Type: errType,
|
||||
},
|
||||
StatusCode: status,
|
||||
RetryAfter: retryAfter,
|
||||
}
|
||||
}
|
||||
|
||||
var _ error = &ResponseError{}
|
||||
|
||||
type ResponseError struct {
|
||||
ErrorObject *shared.ErrorObject `json:"error"`
|
||||
StatusCode int `json:"-"`
|
||||
RetryAfter time.Duration `json:"-"`
|
||||
}
|
||||
|
||||
func (e *ResponseError) Error() string {
|
||||
if e.ErrorObject == nil {
|
||||
return ""
|
||||
}
|
||||
return e.ErrorObject.Message
|
||||
}
|
||||
|
||||
// ToResponse marshals e into an *http.Response shaped for the
|
||||
// OpenAI API.
|
||||
func (e *ResponseError) ToResponse() *http.Response {
|
||||
body, err := json.Marshal(e)
|
||||
if err != nil {
|
||||
body = []byte(`{"error":{"type":"error","message":"error marshaling upstream error","code":"server_error"}}`)
|
||||
}
|
||||
return utils.NewJSONErrorResponse(e.StatusCode, e.RetryAfter, body)
|
||||
}
|
||||
|
||||
// sendCustomErr sends custom responses.Error error to the client
|
||||
// it should only be called before any data is sent back to the client
|
||||
func (i *responsesInterceptionBase) sendCustomErr(ctx context.Context, w http.ResponseWriter, code int, err error) {
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/aibridge/config"
|
||||
"github.com/coder/coder/v2/aibridge/intercept"
|
||||
"github.com/coder/coder/v2/aibridge/internal/testutil"
|
||||
"github.com/coder/coder/v2/aibridge/keypool"
|
||||
"github.com/coder/coder/v2/aibridge/recorder"
|
||||
@@ -390,59 +391,6 @@ func TestResponseCopierDoesntSendIfNoResponseReceived(t *testing.T) {
|
||||
require.True(t, mrw.writeHeaderCalled)
|
||||
}
|
||||
|
||||
func TestProcessKeyPoolError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
err error
|
||||
expectedNil bool
|
||||
expectedStatus int
|
||||
expectedRetryAfter time.Duration
|
||||
}{
|
||||
{
|
||||
// Transient with valid keys present: 429, no Retry-After.
|
||||
name: "transient_zero_retry_after",
|
||||
err: &keypool.TransientKeyPoolError{},
|
||||
expectedStatus: http.StatusTooManyRequests,
|
||||
expectedRetryAfter: 0,
|
||||
},
|
||||
{
|
||||
// Transient with cooldown: 429, Retry-After set.
|
||||
name: "transient_with_retry_after",
|
||||
err: &keypool.TransientKeyPoolError{RetryAfter: 5 * time.Second},
|
||||
expectedStatus: http.StatusTooManyRequests,
|
||||
expectedRetryAfter: 5 * time.Second,
|
||||
},
|
||||
{
|
||||
// Permanent: 502 api_error.
|
||||
name: "permanent_returns_502",
|
||||
err: keypool.ErrPermanentKeyPool,
|
||||
expectedStatus: http.StatusBadGateway,
|
||||
},
|
||||
{
|
||||
// Anything else: not a pool-exhaustion error.
|
||||
name: "non_pool_exhaustion_error_returns_nil",
|
||||
err: xerrors.New("some other error"),
|
||||
expectedNil: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := ProcessKeyPoolError(tc.err)
|
||||
if tc.expectedNil {
|
||||
require.Nil(t, got)
|
||||
return
|
||||
}
|
||||
require.NotNil(t, got)
|
||||
assert.Equal(t, tc.expectedStatus, got.StatusCode)
|
||||
assert.Equal(t, tc.expectedRetryAfter, got.RetryAfter)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMarkKeyOnError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -494,8 +442,8 @@ func TestMarkKeyOnError(t *testing.T) {
|
||||
t.Parallel()
|
||||
pool, err := keypool.New([]string{"key-0"}, quartz.NewMock(t))
|
||||
require.NoError(t, err)
|
||||
key, err := pool.Walker().Next()
|
||||
require.NoError(t, err)
|
||||
key, keyPoolErr := pool.Walker().Next()
|
||||
require.Nil(t, keyPoolErr)
|
||||
|
||||
base := &responsesInterceptionBase{cfg: config.OpenAI{KeyPool: pool}, logger: slog.Make()}
|
||||
|
||||
@@ -511,7 +459,7 @@ func TestWriteUpstreamError(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
respErr *ResponseError
|
||||
respErr *intercept.ResponseError
|
||||
expectStatus int
|
||||
// Empty string means the header should be absent.
|
||||
expectRetryAfter string
|
||||
@@ -521,42 +469,42 @@ func TestWriteUpstreamError(t *testing.T) {
|
||||
{
|
||||
// Standard error: status, code, and JSON body written.
|
||||
name: "writes_status_and_body",
|
||||
respErr: newErrorResponse("upstream failed", "api_error", "server_error", http.StatusBadGateway, 0),
|
||||
respErr: intercept.NewResponseError("upstream failed", "api_error", "server_error", http.StatusBadGateway, 0),
|
||||
expectStatus: http.StatusBadGateway,
|
||||
expectBodyContains: `"upstream failed"`,
|
||||
},
|
||||
{
|
||||
// OpenAI envelope: the code field round-trips into the body.
|
||||
name: "writes_code_field",
|
||||
respErr: newErrorResponse("rate limited", "rate_limit_error", "rate_limit_exceeded", http.StatusTooManyRequests, 0),
|
||||
respErr: intercept.NewResponseError("rate limited", "rate_limit_error", "rate_limit_exceeded", http.StatusTooManyRequests, 0),
|
||||
expectStatus: http.StatusTooManyRequests,
|
||||
expectBodyContains: `"rate_limit_exceeded"`,
|
||||
},
|
||||
{
|
||||
// Whole-second retryAfter: emitted as integer seconds.
|
||||
name: "retry_after_in_seconds",
|
||||
respErr: newErrorResponse("rate limited", "rate_limit_error", "rate_limit_exceeded", http.StatusTooManyRequests, 60*time.Second),
|
||||
respErr: intercept.NewResponseError("rate limited", "rate_limit_error", "rate_limit_exceeded", http.StatusTooManyRequests, 60*time.Second),
|
||||
expectStatus: http.StatusTooManyRequests,
|
||||
expectRetryAfter: "60",
|
||||
},
|
||||
{
|
||||
// 500ms rounds up to Retry-After: 1.
|
||||
name: "retry_after_500ms_rounds_up_to_one",
|
||||
respErr: newErrorResponse("rate limited", "rate_limit_error", "rate_limit_exceeded", http.StatusTooManyRequests, 500*time.Millisecond),
|
||||
respErr: intercept.NewResponseError("rate limited", "rate_limit_error", "rate_limit_exceeded", http.StatusTooManyRequests, 500*time.Millisecond),
|
||||
expectStatus: http.StatusTooManyRequests,
|
||||
expectRetryAfter: "1",
|
||||
},
|
||||
{
|
||||
// 200ms rounds up to Retry-After: 1.
|
||||
name: "retry_after_200ms_rounds_up_to_one",
|
||||
respErr: newErrorResponse("rate limited", "rate_limit_error", "rate_limit_exceeded", http.StatusTooManyRequests, 200*time.Millisecond),
|
||||
respErr: intercept.NewResponseError("rate limited", "rate_limit_error", "rate_limit_exceeded", http.StatusTooManyRequests, 200*time.Millisecond),
|
||||
expectStatus: http.StatusTooManyRequests,
|
||||
expectRetryAfter: "1",
|
||||
},
|
||||
{
|
||||
// Negative retryAfter: header omitted.
|
||||
name: "negative_retry_after_omits_header",
|
||||
respErr: newErrorResponse("rate limited", "rate_limit_error", "rate_limit_exceeded", http.StatusTooManyRequests, -1*time.Second),
|
||||
respErr: intercept.NewResponseError("rate limited", "rate_limit_error", "rate_limit_exceeded", http.StatusTooManyRequests, -1*time.Second),
|
||||
expectStatus: http.StatusTooManyRequests,
|
||||
expectRetryAfter: "",
|
||||
},
|
||||
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
"github.com/coder/coder/v2/aibridge/config"
|
||||
aibcontext "github.com/coder/coder/v2/aibridge/context"
|
||||
"github.com/coder/coder/v2/aibridge/intercept"
|
||||
"github.com/coder/coder/v2/aibridge/keypool"
|
||||
"github.com/coder/coder/v2/aibridge/mcp"
|
||||
"github.com/coder/coder/v2/aibridge/recorder"
|
||||
"github.com/coder/coder/v2/aibridge/tracing"
|
||||
@@ -103,8 +104,9 @@ func (i *BlockingResponsesInterceptor) ProcessRequest(w http.ResponseWriter, r *
|
||||
// The failover loop may return a keypool exhaustion
|
||||
// error. Render it here.
|
||||
if upstreamErr != nil {
|
||||
if keyErr := ProcessKeyPoolError(upstreamErr); keyErr != nil {
|
||||
i.writeUpstreamError(w, keyErr)
|
||||
var keyPoolErr *keypool.Error
|
||||
if errors.As(upstreamErr, &keyPoolErr) {
|
||||
i.writeUpstreamError(w, intercept.ResponseErrorFromKeyPool(keyPoolErr))
|
||||
return xerrors.Errorf("key pool exhausted: %w", upstreamErr)
|
||||
}
|
||||
}
|
||||
@@ -174,9 +176,9 @@ func (i *BlockingResponsesInterceptor) newResponseWithKeyFailover(ctx context.Co
|
||||
// success, the last tried key on failure) in the upstack PR.
|
||||
walker := i.cfg.KeyPool.Walker()
|
||||
for {
|
||||
key, err := walker.Next()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
key, keyPoolErr := walker.Next()
|
||||
if keyPoolErr != nil {
|
||||
return nil, keyPoolErr
|
||||
}
|
||||
|
||||
requestOpts := append([]option.RequestOption{}, opts...)
|
||||
|
||||
@@ -134,14 +134,14 @@ func (i *StreamingResponsesInterceptor) ProcessRequest(w http.ResponseWriter, r
|
||||
|
||||
var currentKey *keypool.Key
|
||||
if walker != nil {
|
||||
key, err := walker.Next()
|
||||
if respErr := ProcessKeyPoolError(err); respErr != nil {
|
||||
key, keyPoolErr := walker.Next()
|
||||
if keyPoolErr != nil {
|
||||
// Pool exhausted: write the error directly. In
|
||||
// agentic mode the inner loop buffers events
|
||||
// instead of streaming them downstream, so the
|
||||
// SSE connection has not been opened yet.
|
||||
i.writeUpstreamError(w, respErr)
|
||||
return xerrors.Errorf("key pool exhausted: %w", err)
|
||||
i.writeUpstreamError(w, intercept.ResponseErrorFromKeyPool(keyPoolErr))
|
||||
return xerrors.Errorf("key pool exhausted: %w", keyPoolErr)
|
||||
}
|
||||
currentKey = key
|
||||
opts = append(opts,
|
||||
|
||||
@@ -2,11 +2,10 @@ package keypool
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/aibridge/utils"
|
||||
)
|
||||
|
||||
@@ -16,6 +15,9 @@ type KeyFailoverConfig struct {
|
||||
// Pool is the key pool to walk. Nil disables key failover.
|
||||
Pool *Pool
|
||||
|
||||
ProviderName string
|
||||
Logger slog.Logger
|
||||
|
||||
// IsBYOK returns true when the request already carries
|
||||
// user-supplied auth. BYOK requests skip key failover.
|
||||
IsBYOK func(*http.Request) bool
|
||||
@@ -24,14 +26,9 @@ type KeyFailoverConfig struct {
|
||||
// in the format the provider expects.
|
||||
InjectAuthKey func(*http.Header, string)
|
||||
|
||||
// MarkKey marks the key based on the upstream response.
|
||||
// Returns true when the response is a key-specific error,
|
||||
// causing the walker to advance and retry with the next key.
|
||||
MarkKey func(ctx context.Context, key *Key, resp *http.Response) bool
|
||||
|
||||
// BuildExhaustedResponse returns the response sent to the
|
||||
// client when the walker has no more keys to try.
|
||||
BuildExhaustedResponse func(err error) *http.Response
|
||||
// BuildKeyPoolResponse renders the response sent to the client
|
||||
// when the walker has no more keys to try.
|
||||
BuildKeyPoolResponse func(*Error) *http.Response
|
||||
}
|
||||
|
||||
// keyFailoverTransport retries inner across the key pool on
|
||||
@@ -74,12 +71,12 @@ func (t *keyFailoverTransport) RoundTrip(req *http.Request) (*http.Response, err
|
||||
// Fresh walker per request, independent of other inflight requests.
|
||||
walker := t.config.Pool.Walker()
|
||||
for {
|
||||
key, err := walker.Next()
|
||||
if err != nil {
|
||||
resp := t.config.BuildExhaustedResponse(err)
|
||||
key, keyPoolErr := walker.Next()
|
||||
if keyPoolErr != nil {
|
||||
resp := t.config.BuildKeyPoolResponse(keyPoolErr)
|
||||
if resp == nil {
|
||||
// Fallback if BuildExhaustedResponse returns nil.
|
||||
body := []byte(fmt.Sprintf(`{"error":"key pool exhausted: %s"}`, err))
|
||||
// Fallback if BuildKeyPoolResponse returns nil.
|
||||
body := []byte(`{"error":"key pool unavailable"}`)
|
||||
resp = utils.NewJSONErrorResponse(http.StatusBadGateway, 0, body)
|
||||
}
|
||||
return resp, nil
|
||||
@@ -97,8 +94,8 @@ func (t *keyFailoverTransport) RoundTrip(req *http.Request) (*http.Response, err
|
||||
// Transport-level error, not a key issue.
|
||||
return resp, rtErr
|
||||
}
|
||||
// MarkKey returns true on key-specific failures (e.g. 401/403/429).
|
||||
if t.config.MarkKey(req.Context(), key, resp) {
|
||||
// MarkKeyOnStatus returns true on key-specific failures (e.g. 401/403/429).
|
||||
if MarkKeyOnStatus(req.Context(), key, resp, t.config.Logger, t.config.ProviderName) {
|
||||
// Drain and retry with the next key.
|
||||
_, _ = io.Copy(io.Discard, resp.Body)
|
||||
_ = resp.Body.Close()
|
||||
|
||||
@@ -91,8 +91,8 @@ func TestMarkKeyOnStatus(t *testing.T) {
|
||||
clk := quartz.NewMock(t)
|
||||
pool, err := keypool.New([]string{"key-0"}, clk)
|
||||
require.NoError(t, err)
|
||||
key, err := pool.Walker().Next()
|
||||
require.NoError(t, err)
|
||||
key, keyPoolErr := pool.Walker().Next()
|
||||
require.Nil(t, keyPoolErr)
|
||||
|
||||
resp := &http.Response{
|
||||
StatusCode: tc.statusCode,
|
||||
|
||||
+43
-31
@@ -10,6 +10,8 @@ import (
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
// Configuration validation type errors. These surface when the
|
||||
// pool is built from invalid input.
|
||||
var (
|
||||
// ErrNoKeys is returned when the input is empty.
|
||||
ErrNoKeys = xerrors.New("no keys provided")
|
||||
@@ -18,20 +20,35 @@ var (
|
||||
ErrDuplicateKey = xerrors.New("duplicate key")
|
||||
)
|
||||
|
||||
// ErrPermanentKeyPool is returned when every key in the
|
||||
// pool has been permanently marked unavailable.
|
||||
var ErrPermanentKeyPool = xerrors.New("all keys permanently unavailable")
|
||||
// ErrorKind classifies a runtime key-pool failure.
|
||||
type ErrorKind int
|
||||
|
||||
// TransientKeyPoolError is returned when no key is currently
|
||||
// available but at least one will recover. RetryAfter is the
|
||||
// soonest remaining cooldown across the pool, or 0 if a key
|
||||
// just became valid mid-walk.
|
||||
type TransientKeyPoolError struct {
|
||||
const (
|
||||
// ErrorKindRateLimited means no key is currently available
|
||||
// but at least one key will recover after a cooldown.
|
||||
ErrorKindRateLimited ErrorKind = iota
|
||||
// ErrorKindPermanent means every key is permanently marked
|
||||
// and no key can satisfy the request.
|
||||
ErrorKindPermanent
|
||||
)
|
||||
|
||||
// Error is returned when no key is available for the
|
||||
// current attempt. RetryAfter is the soonest remaining
|
||||
// cooldown across the pool.
|
||||
type Error struct {
|
||||
Kind ErrorKind
|
||||
RetryAfter time.Duration
|
||||
}
|
||||
|
||||
func (e *TransientKeyPoolError) Error() string {
|
||||
return fmt.Sprintf("all keys exhausted (retry after %s)", e.RetryAfter)
|
||||
func (e *Error) Error() string {
|
||||
switch e.Kind {
|
||||
case ErrorKindPermanent:
|
||||
return "all configured keys failed authentication"
|
||||
case ErrorKindRateLimited:
|
||||
return fmt.Sprintf("all configured keys are rate-limited (retry after %s)", e.RetryAfter)
|
||||
default:
|
||||
return "key pool error"
|
||||
}
|
||||
}
|
||||
|
||||
// KeyState represents the current state of a key in the pool.
|
||||
@@ -176,20 +193,21 @@ func (k *Key) MarkPermanent() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// keyPoolError returns ErrPermanentKeyPool if every key
|
||||
// is permanently unavailable, or *TransientKeyPoolError if
|
||||
// at least one key is temporarily unavailable. When multiple
|
||||
// keys are temporary, the smallest remaining cooldown is used
|
||||
// as the retry-after.
|
||||
func (p *Pool) keyPoolError() error {
|
||||
// keyPoolError returns an Error summarizing why no
|
||||
// key is currently available. When at least one key is
|
||||
// temporary, the smallest remaining cooldown is used as the
|
||||
// retry-after.
|
||||
func (p *Pool) keyPoolError() *Error {
|
||||
var retryAfter time.Duration
|
||||
var hasCooldown bool
|
||||
for i := range p.keys {
|
||||
state, cooldown := p.keys[i].stateAndCooldown()
|
||||
switch state {
|
||||
// Recoverable now: signal transient with zero retry-after.
|
||||
// Recoverable now: a key's cooldown expired between the walker's
|
||||
// check and this scan. Return Retry-After: 0 to indicate that
|
||||
// an immediate retry will succeed.
|
||||
case KeyStateValid:
|
||||
return &TransientKeyPoolError{}
|
||||
return &Error{Kind: ErrorKindRateLimited}
|
||||
// Recoverable later: track soonest remaining cooldown.
|
||||
case KeyStateTemporary:
|
||||
if !hasCooldown || cooldown < retryAfter {
|
||||
@@ -201,9 +219,9 @@ func (p *Pool) keyPoolError() error {
|
||||
}
|
||||
}
|
||||
if hasCooldown {
|
||||
return &TransientKeyPoolError{RetryAfter: retryAfter}
|
||||
return &Error{Kind: ErrorKindRateLimited, RetryAfter: retryAfter}
|
||||
}
|
||||
return ErrPermanentKeyPool
|
||||
return &Error{Kind: ErrorKindPermanent}
|
||||
}
|
||||
|
||||
// PoolState returns a snapshot of each key's state in the pool's
|
||||
@@ -236,16 +254,10 @@ func (p *Pool) Walker() *Walker {
|
||||
// Next returns a Key handle for the next available key without
|
||||
// modifying the pool state.
|
||||
//
|
||||
// Returns *TransientKeyPoolError or ErrPermanentKeyPool
|
||||
// when no more keys are available.
|
||||
func (w *Walker) Next() (*Key, error) {
|
||||
pool := w.pool
|
||||
if pool == nil {
|
||||
return nil, ErrPermanentKeyPool
|
||||
}
|
||||
|
||||
for i := w.pos; i < len(pool.keys); i++ {
|
||||
key := &pool.keys[i]
|
||||
// Returns *Error when no more keys are available.
|
||||
func (w *Walker) Next() (*Key, *Error) {
|
||||
for i := w.pos; i < len(w.pool.keys); i++ {
|
||||
key := &w.pool.keys[i]
|
||||
if key.State() != KeyStateValid {
|
||||
continue
|
||||
}
|
||||
@@ -255,5 +267,5 @@ func (w *Walker) Next() (*Key, error) {
|
||||
}
|
||||
|
||||
// No keys available.
|
||||
return nil, pool.keyPoolError()
|
||||
return nil, w.pool.keyPoolError()
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package keypool_test
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
@@ -43,16 +42,15 @@ func TestNewKeyPool(t *testing.T) {
|
||||
// Verify all keys are returned in order and valid.
|
||||
walker := pool.Walker()
|
||||
for _, expected := range tc.expectedKeys {
|
||||
key, err := walker.Next()
|
||||
require.NoError(t, err)
|
||||
key, keyPoolErr := walker.Next()
|
||||
require.Nil(t, keyPoolErr)
|
||||
assert.Equal(t, expected, key.Value())
|
||||
assert.Equal(t, keypool.KeyStateValid, key.State())
|
||||
}
|
||||
|
||||
// No more keys available.
|
||||
_, err = walker.Next()
|
||||
var transient *keypool.TransientKeyPoolError
|
||||
require.ErrorAs(t, err, &transient, "expected transient exhaustion: walker returned all valid keys, none marked permanent")
|
||||
_, keyPoolErr := walker.Next()
|
||||
require.Equal(t, &keypool.Error{Kind: keypool.ErrorKindRateLimited}, keyPoolErr, "expected rate-limited exhaustion: walker returned all valid keys, none marked permanent")
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -69,8 +67,8 @@ func TestState(t *testing.T) {
|
||||
// Fresh key is valid.
|
||||
name: "fresh_key_is_valid",
|
||||
setup: func(t *testing.T, pool *keypool.Pool, _ *quartz.Mock) *keypool.Key {
|
||||
key, err := pool.Walker().Next()
|
||||
require.NoError(t, err)
|
||||
key, keyPoolErr := pool.Walker().Next()
|
||||
require.Nil(t, keyPoolErr)
|
||||
return key
|
||||
},
|
||||
expectedState: keypool.KeyStateValid,
|
||||
@@ -79,8 +77,8 @@ func TestState(t *testing.T) {
|
||||
// Active cooldown makes the key temporary.
|
||||
name: "active_cooldown_is_temporary",
|
||||
setup: func(t *testing.T, pool *keypool.Pool, _ *quartz.Mock) *keypool.Key {
|
||||
key, err := pool.Walker().Next()
|
||||
require.NoError(t, err)
|
||||
key, keyPoolErr := pool.Walker().Next()
|
||||
require.Nil(t, keyPoolErr)
|
||||
key.MarkTemporary(60 * time.Second)
|
||||
return key
|
||||
},
|
||||
@@ -90,8 +88,8 @@ func TestState(t *testing.T) {
|
||||
// Expired cooldown returns the key to valid.
|
||||
name: "expired_cooldown_is_valid",
|
||||
setup: func(t *testing.T, pool *keypool.Pool, clk *quartz.Mock) *keypool.Key {
|
||||
key, err := pool.Walker().Next()
|
||||
require.NoError(t, err)
|
||||
key, keyPoolErr := pool.Walker().Next()
|
||||
require.Nil(t, keyPoolErr)
|
||||
key.MarkTemporary(30 * time.Second)
|
||||
clk.Advance(35 * time.Second)
|
||||
return key
|
||||
@@ -102,8 +100,8 @@ func TestState(t *testing.T) {
|
||||
// Permanent key is permanent.
|
||||
name: "permanent_key",
|
||||
setup: func(t *testing.T, pool *keypool.Pool, _ *quartz.Mock) *keypool.Key {
|
||||
key, err := pool.Walker().Next()
|
||||
require.NoError(t, err)
|
||||
key, keyPoolErr := pool.Walker().Next()
|
||||
require.Nil(t, keyPoolErr)
|
||||
key.MarkPermanent()
|
||||
return key
|
||||
},
|
||||
@@ -113,8 +111,8 @@ func TestState(t *testing.T) {
|
||||
// Permanent takes precedence over active cooldown.
|
||||
name: "permanent_with_cooldown_is_permanent",
|
||||
setup: func(t *testing.T, pool *keypool.Pool, _ *quartz.Mock) *keypool.Key {
|
||||
key, err := pool.Walker().Next()
|
||||
require.NoError(t, err)
|
||||
key, keyPoolErr := pool.Walker().Next()
|
||||
require.Nil(t, keyPoolErr)
|
||||
key.MarkTemporary(60 * time.Second)
|
||||
key.MarkPermanent()
|
||||
return key
|
||||
@@ -152,8 +150,8 @@ func TestMarkTemporary(t *testing.T) {
|
||||
name: "valid_to_temporary",
|
||||
cooldown: 60 * time.Second,
|
||||
setup: func(t *testing.T, pool *keypool.Pool, _ *quartz.Mock) *keypool.Key {
|
||||
key, err := pool.Walker().Next()
|
||||
require.NoError(t, err)
|
||||
key, keyPoolErr := pool.Walker().Next()
|
||||
require.Nil(t, keyPoolErr)
|
||||
return key
|
||||
},
|
||||
expectedState: keypool.KeyStateTemporary,
|
||||
@@ -165,8 +163,8 @@ func TestMarkTemporary(t *testing.T) {
|
||||
name: "temporary_to_temporary_extends_cooldown",
|
||||
cooldown: 60 * time.Second,
|
||||
setup: func(t *testing.T, pool *keypool.Pool, _ *quartz.Mock) *keypool.Key {
|
||||
key, err := pool.Walker().Next()
|
||||
require.NoError(t, err)
|
||||
key, keyPoolErr := pool.Walker().Next()
|
||||
require.Nil(t, keyPoolErr)
|
||||
key.MarkTemporary(10 * time.Second)
|
||||
return key
|
||||
},
|
||||
@@ -179,8 +177,8 @@ func TestMarkTemporary(t *testing.T) {
|
||||
name: "temporary_to_temporary_keeps_longer_cooldown",
|
||||
cooldown: 10 * time.Second,
|
||||
setup: func(t *testing.T, pool *keypool.Pool, _ *quartz.Mock) *keypool.Key {
|
||||
key, err := pool.Walker().Next()
|
||||
require.NoError(t, err)
|
||||
key, keyPoolErr := pool.Walker().Next()
|
||||
require.Nil(t, keyPoolErr)
|
||||
key.MarkTemporary(60 * time.Second)
|
||||
return key
|
||||
},
|
||||
@@ -192,8 +190,8 @@ func TestMarkTemporary(t *testing.T) {
|
||||
name: "permanent_to_temporary_is_no_op",
|
||||
cooldown: 60 * time.Second,
|
||||
setup: func(t *testing.T, pool *keypool.Pool, _ *quartz.Mock) *keypool.Key {
|
||||
key, err := pool.Walker().Next()
|
||||
require.NoError(t, err)
|
||||
key, keyPoolErr := pool.Walker().Next()
|
||||
require.Nil(t, keyPoolErr)
|
||||
key.MarkPermanent()
|
||||
return key
|
||||
},
|
||||
@@ -231,8 +229,8 @@ func TestMarkPermanent(t *testing.T) {
|
||||
// valid -> permanent: key becomes permanently unavailable.
|
||||
name: "valid_to_permanent",
|
||||
setup: func(t *testing.T, pool *keypool.Pool) *keypool.Key {
|
||||
key, err := pool.Walker().Next()
|
||||
require.NoError(t, err)
|
||||
key, keyPoolErr := pool.Walker().Next()
|
||||
require.Nil(t, keyPoolErr)
|
||||
return key
|
||||
},
|
||||
expectedState: keypool.KeyStatePermanent,
|
||||
@@ -243,8 +241,8 @@ func TestMarkPermanent(t *testing.T) {
|
||||
// to auth failure.
|
||||
name: "temporary_to_permanent",
|
||||
setup: func(t *testing.T, pool *keypool.Pool) *keypool.Key {
|
||||
key, err := pool.Walker().Next()
|
||||
require.NoError(t, err)
|
||||
key, keyPoolErr := pool.Walker().Next()
|
||||
require.Nil(t, keyPoolErr)
|
||||
key.MarkTemporary(60 * time.Second)
|
||||
return key
|
||||
},
|
||||
@@ -255,8 +253,8 @@ func TestMarkPermanent(t *testing.T) {
|
||||
// permanent -> permanent: no-op, already permanent.
|
||||
name: "permanent_to_permanent",
|
||||
setup: func(t *testing.T, pool *keypool.Pool) *keypool.Key {
|
||||
key, err := pool.Walker().Next()
|
||||
require.NoError(t, err)
|
||||
key, keyPoolErr := pool.Walker().Next()
|
||||
require.Nil(t, keyPoolErr)
|
||||
key.MarkPermanent()
|
||||
return key
|
||||
},
|
||||
@@ -290,7 +288,7 @@ func TestWalkerNext(t *testing.T) {
|
||||
setup func(t *testing.T, pool *keypool.Pool)
|
||||
advance time.Duration
|
||||
expectedValid []string
|
||||
expectedErr error
|
||||
expectedErr *keypool.Error
|
||||
}{
|
||||
{
|
||||
// Given: key-0: valid, key-1: valid, key-2: valid.
|
||||
@@ -299,7 +297,7 @@ func TestWalkerNext(t *testing.T) {
|
||||
keys: []string{"key-0", "key-1", "key-2"},
|
||||
setup: func(_ *testing.T, _ *keypool.Pool) {},
|
||||
expectedValid: []string{"key-0", "key-1", "key-2"},
|
||||
expectedErr: &keypool.TransientKeyPoolError{},
|
||||
expectedErr: &keypool.Error{Kind: keypool.ErrorKindRateLimited},
|
||||
},
|
||||
{
|
||||
// Given: key-0: temporary, key-1: valid, key-2: valid.
|
||||
@@ -307,12 +305,12 @@ func TestWalkerNext(t *testing.T) {
|
||||
name: "skips_temporary_keys",
|
||||
keys: []string{"key-0", "key-1", "key-2"},
|
||||
setup: func(t *testing.T, pool *keypool.Pool) {
|
||||
key, err := pool.Walker().Next()
|
||||
require.NoError(t, err)
|
||||
key, keyPoolErr := pool.Walker().Next()
|
||||
require.Nil(t, keyPoolErr)
|
||||
key.MarkTemporary(60 * time.Second)
|
||||
},
|
||||
expectedValid: []string{"key-1", "key-2"},
|
||||
expectedErr: &keypool.TransientKeyPoolError{},
|
||||
expectedErr: &keypool.Error{Kind: keypool.ErrorKindRateLimited},
|
||||
},
|
||||
{
|
||||
// Given: key-0: permanent, key-1: permanent, key-2: valid.
|
||||
@@ -321,15 +319,15 @@ func TestWalkerNext(t *testing.T) {
|
||||
keys: []string{"key-0", "key-1", "key-2"},
|
||||
setup: func(t *testing.T, pool *keypool.Pool) {
|
||||
walker := pool.Walker()
|
||||
key0, err := walker.Next()
|
||||
require.NoError(t, err)
|
||||
key0, keyPoolErr := walker.Next()
|
||||
require.Nil(t, keyPoolErr)
|
||||
key0.MarkPermanent()
|
||||
key1, err := walker.Next()
|
||||
require.NoError(t, err)
|
||||
key1, keyPoolErr := walker.Next()
|
||||
require.Nil(t, keyPoolErr)
|
||||
key1.MarkPermanent()
|
||||
},
|
||||
expectedValid: []string{"key-2"},
|
||||
expectedErr: &keypool.TransientKeyPoolError{},
|
||||
expectedErr: &keypool.Error{Kind: keypool.ErrorKindRateLimited},
|
||||
},
|
||||
{
|
||||
// Given: key-0: temporary (30s), key-1: valid.
|
||||
@@ -338,13 +336,13 @@ func TestWalkerNext(t *testing.T) {
|
||||
name: "expired_temporary_is_available",
|
||||
keys: []string{"key-0", "key-1"},
|
||||
setup: func(t *testing.T, pool *keypool.Pool) {
|
||||
key, err := pool.Walker().Next()
|
||||
require.NoError(t, err)
|
||||
key, keyPoolErr := pool.Walker().Next()
|
||||
require.Nil(t, keyPoolErr)
|
||||
key.MarkTemporary(30 * time.Second)
|
||||
},
|
||||
advance: 35 * time.Second,
|
||||
expectedValid: []string{"key-0", "key-1"},
|
||||
expectedErr: &keypool.TransientKeyPoolError{},
|
||||
expectedErr: &keypool.Error{Kind: keypool.ErrorKindRateLimited},
|
||||
},
|
||||
{
|
||||
// Given: key-0: temporary (zero, default 60s), key-1: valid.
|
||||
@@ -353,13 +351,13 @@ func TestWalkerNext(t *testing.T) {
|
||||
name: "default_cooldown_not_expired",
|
||||
keys: []string{"key-0", "key-1"},
|
||||
setup: func(t *testing.T, pool *keypool.Pool) {
|
||||
key, err := pool.Walker().Next()
|
||||
require.NoError(t, err)
|
||||
key, keyPoolErr := pool.Walker().Next()
|
||||
require.Nil(t, keyPoolErr)
|
||||
key.MarkTemporary(0)
|
||||
},
|
||||
advance: 50 * time.Second,
|
||||
expectedValid: []string{"key-1"},
|
||||
expectedErr: &keypool.TransientKeyPoolError{},
|
||||
expectedErr: &keypool.Error{Kind: keypool.ErrorKindRateLimited},
|
||||
},
|
||||
{
|
||||
// Given: key-0: temporary (zero, default 60s), key-1: valid.
|
||||
@@ -368,13 +366,13 @@ func TestWalkerNext(t *testing.T) {
|
||||
name: "default_cooldown_expired",
|
||||
keys: []string{"key-0", "key-1"},
|
||||
setup: func(t *testing.T, pool *keypool.Pool) {
|
||||
key, err := pool.Walker().Next()
|
||||
require.NoError(t, err)
|
||||
key, keyPoolErr := pool.Walker().Next()
|
||||
require.Nil(t, keyPoolErr)
|
||||
key.MarkTemporary(0)
|
||||
},
|
||||
advance: 65 * time.Second,
|
||||
expectedValid: []string{"key-0", "key-1"},
|
||||
expectedErr: &keypool.TransientKeyPoolError{},
|
||||
expectedErr: &keypool.Error{Kind: keypool.ErrorKindRateLimited},
|
||||
},
|
||||
{
|
||||
// Given: key-0: temporary (negative, default 60s), key-1: valid.
|
||||
@@ -383,13 +381,13 @@ func TestWalkerNext(t *testing.T) {
|
||||
name: "negative_cooldown_uses_default",
|
||||
keys: []string{"key-0", "key-1"},
|
||||
setup: func(t *testing.T, pool *keypool.Pool) {
|
||||
key, err := pool.Walker().Next()
|
||||
require.NoError(t, err)
|
||||
key, keyPoolErr := pool.Walker().Next()
|
||||
require.Nil(t, keyPoolErr)
|
||||
key.MarkTemporary(-10 * time.Second)
|
||||
},
|
||||
advance: 65 * time.Second,
|
||||
expectedValid: []string{"key-0", "key-1"},
|
||||
expectedErr: &keypool.TransientKeyPoolError{},
|
||||
expectedErr: &keypool.Error{Kind: keypool.ErrorKindRateLimited},
|
||||
},
|
||||
{
|
||||
// Given: key-0: temporary (60s), then marked again with shorter cooldown (10s).
|
||||
@@ -398,14 +396,14 @@ func TestWalkerNext(t *testing.T) {
|
||||
name: "shorter_cooldown_preserves_longer_not_expired",
|
||||
keys: []string{"key-0"},
|
||||
setup: func(t *testing.T, pool *keypool.Pool) {
|
||||
key, err := pool.Walker().Next()
|
||||
require.NoError(t, err)
|
||||
key, keyPoolErr := pool.Walker().Next()
|
||||
require.Nil(t, keyPoolErr)
|
||||
key.MarkTemporary(60 * time.Second)
|
||||
key.MarkTemporary(10 * time.Second)
|
||||
},
|
||||
advance: 15 * time.Second,
|
||||
expectedValid: []string{},
|
||||
expectedErr: &keypool.TransientKeyPoolError{RetryAfter: 45 * time.Second},
|
||||
expectedErr: &keypool.Error{Kind: keypool.ErrorKindRateLimited, RetryAfter: 45 * time.Second},
|
||||
},
|
||||
{
|
||||
// Given: key-0: temporary (60s), then marked again with shorter cooldown (10s).
|
||||
@@ -414,14 +412,14 @@ func TestWalkerNext(t *testing.T) {
|
||||
name: "shorter_cooldown_preserves_longer_expired",
|
||||
keys: []string{"key-0"},
|
||||
setup: func(t *testing.T, pool *keypool.Pool) {
|
||||
key, err := pool.Walker().Next()
|
||||
require.NoError(t, err)
|
||||
key, keyPoolErr := pool.Walker().Next()
|
||||
require.Nil(t, keyPoolErr)
|
||||
key.MarkTemporary(60 * time.Second)
|
||||
key.MarkTemporary(10 * time.Second)
|
||||
},
|
||||
advance: 65 * time.Second,
|
||||
expectedValid: []string{"key-0"},
|
||||
expectedErr: &keypool.TransientKeyPoolError{},
|
||||
expectedErr: &keypool.Error{Kind: keypool.ErrorKindRateLimited},
|
||||
},
|
||||
{
|
||||
// Given: key-0: temporary (60s), key-1: temporary (10s), key-2: temporary (30s).
|
||||
@@ -431,18 +429,18 @@ func TestWalkerNext(t *testing.T) {
|
||||
keys: []string{"key-0", "key-1", "key-2"},
|
||||
setup: func(t *testing.T, pool *keypool.Pool) {
|
||||
walker := pool.Walker()
|
||||
key0, err := walker.Next()
|
||||
require.NoError(t, err)
|
||||
key0, keyPoolErr := walker.Next()
|
||||
require.Nil(t, keyPoolErr)
|
||||
key0.MarkTemporary(60 * time.Second)
|
||||
key1, err := walker.Next()
|
||||
require.NoError(t, err)
|
||||
key1, keyPoolErr := walker.Next()
|
||||
require.Nil(t, keyPoolErr)
|
||||
key1.MarkTemporary(10 * time.Second)
|
||||
key2, err := walker.Next()
|
||||
require.NoError(t, err)
|
||||
key2, keyPoolErr := walker.Next()
|
||||
require.Nil(t, keyPoolErr)
|
||||
key2.MarkTemporary(30 * time.Second)
|
||||
},
|
||||
expectedValid: []string{},
|
||||
expectedErr: &keypool.TransientKeyPoolError{RetryAfter: 10 * time.Second},
|
||||
expectedErr: &keypool.Error{Kind: keypool.ErrorKindRateLimited, RetryAfter: 10 * time.Second},
|
||||
},
|
||||
{
|
||||
// Given: key-0: temporary, key-1: temporary.
|
||||
@@ -451,15 +449,15 @@ func TestWalkerNext(t *testing.T) {
|
||||
keys: []string{"key-0", "key-1"},
|
||||
setup: func(t *testing.T, pool *keypool.Pool) {
|
||||
walker := pool.Walker()
|
||||
key0, err := walker.Next()
|
||||
require.NoError(t, err)
|
||||
key0, keyPoolErr := walker.Next()
|
||||
require.Nil(t, keyPoolErr)
|
||||
key0.MarkTemporary(60 * time.Second)
|
||||
key1, err := walker.Next()
|
||||
require.NoError(t, err)
|
||||
key1, keyPoolErr := walker.Next()
|
||||
require.Nil(t, keyPoolErr)
|
||||
key1.MarkTemporary(60 * time.Second)
|
||||
},
|
||||
expectedValid: []string{},
|
||||
expectedErr: &keypool.TransientKeyPoolError{RetryAfter: 60 * time.Second},
|
||||
expectedErr: &keypool.Error{Kind: keypool.ErrorKindRateLimited, RetryAfter: 60 * time.Second},
|
||||
},
|
||||
{
|
||||
// Given: key-0: permanent, key-1: permanent.
|
||||
@@ -468,15 +466,15 @@ func TestWalkerNext(t *testing.T) {
|
||||
keys: []string{"key-0", "key-1"},
|
||||
setup: func(t *testing.T, pool *keypool.Pool) {
|
||||
walker := pool.Walker()
|
||||
key0, err := walker.Next()
|
||||
require.NoError(t, err)
|
||||
key0, keyPoolErr := walker.Next()
|
||||
require.Nil(t, keyPoolErr)
|
||||
key0.MarkPermanent()
|
||||
key1, err := walker.Next()
|
||||
require.NoError(t, err)
|
||||
key1, keyPoolErr := walker.Next()
|
||||
require.Nil(t, keyPoolErr)
|
||||
key1.MarkPermanent()
|
||||
},
|
||||
expectedValid: []string{},
|
||||
expectedErr: keypool.ErrPermanentKeyPool,
|
||||
expectedErr: &keypool.Error{Kind: keypool.ErrorKindPermanent},
|
||||
},
|
||||
{
|
||||
// Given: key-0: permanent, key-1: temporary, key-2: permanent.
|
||||
@@ -485,18 +483,18 @@ func TestWalkerNext(t *testing.T) {
|
||||
keys: []string{"key-0", "key-1", "key-2"},
|
||||
setup: func(t *testing.T, pool *keypool.Pool) {
|
||||
walker := pool.Walker()
|
||||
key0, err := walker.Next()
|
||||
require.NoError(t, err)
|
||||
key0, keyPoolErr := walker.Next()
|
||||
require.Nil(t, keyPoolErr)
|
||||
key0.MarkPermanent()
|
||||
key1, err := walker.Next()
|
||||
require.NoError(t, err)
|
||||
key1, keyPoolErr := walker.Next()
|
||||
require.Nil(t, keyPoolErr)
|
||||
key1.MarkTemporary(60 * time.Second)
|
||||
key2, err := walker.Next()
|
||||
require.NoError(t, err)
|
||||
key2, keyPoolErr := walker.Next()
|
||||
require.Nil(t, keyPoolErr)
|
||||
key2.MarkPermanent()
|
||||
},
|
||||
expectedValid: []string{},
|
||||
expectedErr: &keypool.TransientKeyPoolError{RetryAfter: 60 * time.Second},
|
||||
expectedErr: &keypool.Error{Kind: keypool.ErrorKindRateLimited, RetryAfter: 60 * time.Second},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -516,21 +514,14 @@ func TestWalkerNext(t *testing.T) {
|
||||
|
||||
walker := pool.Walker()
|
||||
for _, expectedKey := range tc.expectedValid {
|
||||
key, err := walker.Next()
|
||||
require.NoError(t, err)
|
||||
key, keyPoolErr := walker.Next()
|
||||
require.Nil(t, keyPoolErr)
|
||||
assert.Equal(t, expectedKey, key.Value())
|
||||
}
|
||||
|
||||
// After all expected keys, the walker should be exhausted.
|
||||
_, err = walker.Next()
|
||||
var wantTransient *keypool.TransientKeyPoolError
|
||||
if errors.As(tc.expectedErr, &wantTransient) {
|
||||
var got *keypool.TransientKeyPoolError
|
||||
require.ErrorAs(t, err, &got)
|
||||
assert.Equal(t, wantTransient.RetryAfter, got.RetryAfter)
|
||||
} else {
|
||||
require.ErrorIs(t, err, tc.expectedErr)
|
||||
}
|
||||
_, keyPoolErr := walker.Next()
|
||||
require.Equal(t, tc.expectedErr, keyPoolErr)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -595,8 +586,8 @@ func TestKeyConcurrent(t *testing.T) {
|
||||
clk := quartz.NewMock(t)
|
||||
pool, err := keypool.New([]string{"key-0"}, clk)
|
||||
require.NoError(t, err)
|
||||
key, err := pool.Walker().Next()
|
||||
require.NoError(t, err)
|
||||
key, keyPoolErr := pool.Walker().Next()
|
||||
require.Nil(t, keyPoolErr)
|
||||
|
||||
const numGoroutines = 10
|
||||
var wg sync.WaitGroup
|
||||
@@ -628,29 +619,29 @@ func TestWalkerIndependence(t *testing.T) {
|
||||
walker := pool.Walker()
|
||||
|
||||
// First attempt: get key-0.
|
||||
key, err := walker.Next()
|
||||
require.NoError(t, err)
|
||||
key, keyPoolErr := walker.Next()
|
||||
require.Nil(t, keyPoolErr)
|
||||
assert.Equal(t, "key-0", key.Value())
|
||||
|
||||
// Simulate 429: mark key-0 temporary.
|
||||
key.MarkTemporary(60 * time.Second)
|
||||
|
||||
// Second attempt: walker advances to key-1.
|
||||
key, err = walker.Next()
|
||||
require.NoError(t, err)
|
||||
key, keyPoolErr = walker.Next()
|
||||
require.Nil(t, keyPoolErr)
|
||||
assert.Equal(t, "key-1", key.Value())
|
||||
|
||||
// Simulate 401: mark key-1 permanent.
|
||||
key.MarkPermanent()
|
||||
|
||||
// Third attempt: walker advances to key-2.
|
||||
key, err = walker.Next()
|
||||
require.NoError(t, err)
|
||||
key, keyPoolErr = walker.Next()
|
||||
require.Nil(t, keyPoolErr)
|
||||
assert.Equal(t, "key-2", key.Value())
|
||||
|
||||
// A new walker should skip key-0 (temporary) and key-1
|
||||
// (permanent), and return key-2.
|
||||
key2, err := pool.Walker().Next()
|
||||
require.NoError(t, err)
|
||||
key2, keyPoolErr := pool.Walker().Next()
|
||||
require.Nil(t, keyPoolErr)
|
||||
assert.Equal(t, "key-2", key2.Value())
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -173,8 +172,8 @@ func (p *Anthropic) CreateInterceptor(_ http.ResponseWriter, r *http.Request, tr
|
||||
// Centralized: use the first key as a placeholder hint.
|
||||
// TODO(ssncferreira): record the actually-used key in
|
||||
// the interception record to reflect failover.
|
||||
if k, err := cfg.KeyPool.Walker().Next(); err == nil {
|
||||
credSecret = k.Value()
|
||||
if key, keyPoolErr := cfg.KeyPool.Walker().Next(); keyPoolErr == nil {
|
||||
credSecret = key.Value()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -222,20 +221,18 @@ func (p *Anthropic) InjectAuthHeader(headers *http.Header) {
|
||||
}
|
||||
|
||||
func (p *Anthropic) KeyFailoverConfig(logger slog.Logger) keypool.KeyFailoverConfig {
|
||||
name := p.Name()
|
||||
return keypool.KeyFailoverConfig{
|
||||
Pool: p.cfg.KeyPool,
|
||||
Pool: p.cfg.KeyPool,
|
||||
ProviderName: p.Name(),
|
||||
Logger: logger,
|
||||
IsBYOK: func(r *http.Request) bool {
|
||||
return r.Header.Get("X-Api-Key") != "" || r.Header.Get("Authorization") != ""
|
||||
},
|
||||
InjectAuthKey: func(h *http.Header, key string) {
|
||||
h.Set("X-Api-Key", key)
|
||||
},
|
||||
MarkKey: func(ctx context.Context, key *keypool.Key, resp *http.Response) bool {
|
||||
return keypool.MarkKeyOnStatus(ctx, key, resp, logger, name)
|
||||
},
|
||||
BuildExhaustedResponse: func(err error) *http.Response {
|
||||
return messages.ProcessKeyPoolError(err).ToResponse()
|
||||
BuildKeyPoolResponse: func(keyPoolErr *keypool.Error) *http.Response {
|
||||
return messages.ResponseErrorFromKeyPool(keyPoolErr).ToResponse()
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -146,8 +145,8 @@ func (p *OpenAI) CreateInterceptor(_ http.ResponseWriter, r *http.Request, trace
|
||||
// Centralized: use the first key as a placeholder hint.
|
||||
// TODO(ssncferreira): record the actually-used key in
|
||||
// the interception record to reflect failover.
|
||||
if k, err := cfg.KeyPool.Walker().Next(); err == nil {
|
||||
credSecret = k.Value()
|
||||
if key, keyPoolErr := cfg.KeyPool.Walker().Next(); keyPoolErr == nil {
|
||||
credSecret = key.Value()
|
||||
}
|
||||
}
|
||||
cred := intercept.NewCredentialInfo(credKind, credSecret)
|
||||
@@ -221,20 +220,18 @@ func (p *OpenAI) InjectAuthHeader(headers *http.Header) {
|
||||
}
|
||||
|
||||
func (p *OpenAI) KeyFailoverConfig(logger slog.Logger) keypool.KeyFailoverConfig {
|
||||
name := p.Name()
|
||||
return keypool.KeyFailoverConfig{
|
||||
Pool: p.cfg.KeyPool,
|
||||
Pool: p.cfg.KeyPool,
|
||||
ProviderName: p.Name(),
|
||||
Logger: logger,
|
||||
IsBYOK: func(r *http.Request) bool {
|
||||
return r.Header.Get("Authorization") != ""
|
||||
},
|
||||
InjectAuthKey: func(h *http.Header, key string) {
|
||||
h.Set("Authorization", "Bearer "+key)
|
||||
},
|
||||
MarkKey: func(ctx context.Context, key *keypool.Key, resp *http.Response) bool {
|
||||
return keypool.MarkKeyOnStatus(ctx, key, resp, logger, name)
|
||||
},
|
||||
BuildExhaustedResponse: func(err error) *http.Response {
|
||||
return chatcompletions.ProcessKeyPoolError(err).ToResponse()
|
||||
BuildKeyPoolResponse: func(keyPoolErr *keypool.Error) *http.Response {
|
||||
return intercept.ResponseErrorFromKeyPool(keyPoolErr).ToResponse()
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user