From 620bdc097243c1d2ae66cc00f44ac59df9cdc1b6 Mon Sep 17 00:00:00 2001
From: Susana Cardoso Ferreira
Date: Fri, 29 May 2026 19:23:05 +0000
Subject: [PATCH] refactor(aibridge): add shared keypool failover runner for blocking and passthrough

---
 aibridge/intercept/messages/base.go     | 12 +++++
 aibridge/intercept/messages/blocking.go | 49 +++++++++---------
 aibridge/keypool/failover.go            | 66 +++++++++++++------------
 aibridge/keypool/failure.go             | 60 ++++++++++++++++++++++
 aibridge/keypool/runner.go              | 53 ++++++++++++++++++++
 5 files changed, 185 insertions(+), 55 deletions(-)
 create mode 100644 aibridge/keypool/failure.go
 create mode 100644 aibridge/keypool/runner.go

diff --git a/aibridge/intercept/messages/base.go b/aibridge/intercept/messages/base.go
index 1f1f49e744..f337bb1107 100644
--- a/aibridge/intercept/messages/base.go
+++ b/aibridge/intercept/messages/base.go
@@ -589,6 +589,18 @@ func (i *interceptionBase) markKeyOnError(ctx context.Context, key *keypool.Key,
 	)
 }
 
+// classifyError maps a centralized-request error to a *keypool.Failure by
+// extracting the Anthropic SDK error and classifying its HTTP response. A
+// non-SDK error (transport failure, context cancellation) yields nil so the
+// failover loop returns it to the caller instead of retrying.
+func (i *interceptionBase) classifyError(err error) *keypool.Failure {
+	var apiErr *anthropic.Error
+	if !errors.As(err, &apiErr) { // also taken when err is nil: errors.As reports false for a nil err
+		return nil
+	}
+	return keypool.Classify(apiErr.Response)
+}
+
 // ResponseErrorFromKeyPool translates a *keypool.Error into
 // a developer-facing ResponseError shaped for the Anthropic API.
func ResponseErrorFromKeyPool(keyPoolErr *keypool.Error) *ResponseError { diff --git a/aibridge/intercept/messages/blocking.go b/aibridge/intercept/messages/blocking.go index bf74885b2b..b6aa455be7 100644 --- a/aibridge/intercept/messages/blocking.go +++ b/aibridge/intercept/messages/blocking.go @@ -367,29 +367,30 @@ func (i *BlockingInterception) newMessageWithKey(ctx context.Context, svc anthro // Errors that aren't key-specific don't trigger failover and // are returned to the caller. func (i *BlockingInterception) newMessageWithKeyFailover(ctx context.Context, svc anthropic.MessageService) (*anthropic.Message, error) { - walker := i.cfg.KeyPool.Walker() - for { - key, keyPoolErr := walker.Next() - if keyPoolErr != nil { - return nil, keyPoolErr - } - // Record the key in use so the hint reflects the last attempted key. - i.credential = intercept.NewCredentialInfo(intercept.CredentialKindCentralized, key.Value()) - i.logger.Debug(ctx, "using centralized api key", - slog.F("credential_hint", i.Credential().Hint), slog.F("credential_length", i.Credential().Length)) - - msg, err := i.newMessageWithKey(ctx, svc, - option.WithAPIKey(key.Value()), - // Disable SDK retries because the failover loop - // handles retries via key rotation. - option.WithMaxRetries(0), - ) - // Key-specific failure: try the next key. - if i.markKeyOnError(ctx, key, err) { - continue - } - // Either success (msg, nil) or a non-key error (nil, err): - // nothing to retry, return as-is. - return msg, err + // result carries both return values of one attempt so the failover loop + // hands back a success or a non-key error as a single atomic payload. + type result struct { + msg *anthropic.Message + err error } + res, keyPoolErr := keypool.Failover(ctx, i.cfg.KeyPool, i.logger, i.providerName, + func(ctx context.Context, key *keypool.Key) (result, *keypool.Failure) { + // Record the key in use so the hint reflects the last attempted key. 
+ i.credential = intercept.NewCredentialInfo(intercept.CredentialKindCentralized, key.Value()) + i.logger.Debug(ctx, "using centralized api key", + slog.F("credential_hint", i.Credential().Hint), slog.F("credential_length", i.Credential().Length)) + + msg, err := i.newMessageWithKey(ctx, svc, + option.WithAPIKey(key.Value()), + // Disable SDK retries because the failover loop + // handles retries via key rotation. + option.WithMaxRetries(0), + ) + return result{msg, err}, i.classifyError(err) + }) + if keyPoolErr != nil { + return nil, keyPoolErr + } + // Either success (msg, nil) or a non-key error (nil, err): return as-is. + return res.msg, res.err } diff --git a/aibridge/keypool/failover.go b/aibridge/keypool/failover.go index 38dcd3b972..4ba4df212f 100644 --- a/aibridge/keypool/failover.go +++ b/aibridge/keypool/failover.go @@ -2,6 +2,7 @@ package keypool import ( "bytes" + "context" "io" "net/http" @@ -68,42 +69,45 @@ func (t *keyFailoverTransport) RoundTrip(req *http.Request) (*http.Response, err return nil, err } - // Fresh walker per request, independent of other inflight requests. - walker := t.config.Pool.Walker() - for { - key, keyPoolErr := walker.Next() - if keyPoolErr != nil { - resp := t.config.BuildKeyPoolResponse(keyPoolErr) - if resp == nil { - // Fallback if BuildKeyPoolResponse returns nil. - body := []byte(`{"error":"key pool unavailable"}`) - resp = utils.NewJSONErrorResponse(http.StatusBadGateway, 0, body) + // result carries both return values of one attempt so the failover loop + // hands back a success or a transport error as a single atomic payload. + type result struct { + resp *http.Response + err error + } + res, keyPoolErr := Failover(req.Context(), t.config.Pool, t.config.Logger, t.config.ProviderName, + func(ctx context.Context, key *Key) (result, *Failure) { + // Clone per attempt so the original request isn't mutated. 
+ outReq := req.Clone(ctx) + if body != nil { + outReq.Body = io.NopCloser(bytes.NewReader(body)) } - return resp, nil - } + t.config.InjectAuthKey(&outReq.Header, key.Value()) - // Clone per attempt so the original request isn't mutated. - outReq := req.Clone(req.Context()) - if body != nil { - outReq.Body = io.NopCloser(bytes.NewReader(body)) + resp, rtErr := t.inner.RoundTrip(outReq) + if rtErr != nil { + // Transport-level error, not a key issue: stop and return. + return result{resp, rtErr}, nil + } + failure := Classify(resp) + if failure != nil { + // Drain the discarded response before retrying with the next key. + _, _ = io.Copy(io.Discard, resp.Body) + _ = resp.Body.Close() + } + return result{resp, nil}, failure + }) + if keyPoolErr != nil { + resp := t.config.BuildKeyPoolResponse(keyPoolErr) + if resp == nil { + // Fallback if BuildKeyPoolResponse returns nil. + body := []byte(`{"error":"key pool unavailable"}`) + resp = utils.NewJSONErrorResponse(http.StatusBadGateway, 0, body) } - t.config.InjectAuthKey(&outReq.Header, key.Value()) - - resp, rtErr := t.inner.RoundTrip(outReq) - if rtErr != nil { - // Transport-level error, not a key issue. - return resp, rtErr - } - // MarkKeyOnStatus returns true on key-specific failures (e.g. 401/403/429). - if MarkKeyOnStatus(req.Context(), key, resp, t.config.Logger, t.config.ProviderName) { - // Drain and retry with the next key. - _, _ = io.Copy(io.Discard, resp.Body) - _ = resp.Body.Close() - continue - } - // Success or non-key error, forward as-is. return resp, nil } + // Success or non-key transport error, forward as-is. 
+ return res.resp, res.err } // bufferBody reads the request body fully so it can be replayed diff --git a/aibridge/keypool/failure.go b/aibridge/keypool/failure.go new file mode 100644 index 0000000000..c74f17541d --- /dev/null +++ b/aibridge/keypool/failure.go @@ -0,0 +1,60 @@ +package keypool + +import ( + "net/http" + "time" +) + +// FailoverReason explains why a key attempt failed in a way that should move +// the failover loop to the next key. +type FailoverReason int + +const ( + // FailoverRateLimited marks the key temporary and retries with the next + // key (HTTP 429). + FailoverRateLimited FailoverReason = iota + // FailoverUnauthorized marks the key permanent and retries with the next + // key (HTTP 401). + FailoverUnauthorized + // FailoverForbidden marks the key permanent and retries with the next key + // (HTTP 403). + FailoverForbidden +) + +// Failure describes a key-specific attempt failure that triggers failover. A +// nil *Failure means no key failure: the attempt produced a result the caller +// should keep (a success, a non-key error, a transport error, or a streaming +// attempt that already committed). +type Failure struct { + Reason FailoverReason + // Cooldown is honored only for FailoverRateLimited. + Cooldown time.Duration +} + +// Classify maps a key-specific HTTP response to a *Failure. A nil response or +// any non-failover status yields nil. 429 yields FailoverRateLimited carrying +// the parsed Retry-After (or defaultCooldown when absent), 401 yields +// FailoverUnauthorized, and 403 yields FailoverForbidden. +// +// Classify intentionally takes an *http.Response, not a provider error, so +// the pool stays SDK-agnostic. Callers unwrap the response from their SDK's +// error type (e.g. errors.As(err, &apiErr); apiErr.Response) before calling. 
+func Classify(resp *http.Response) *Failure {
+	if resp == nil { // nothing to inspect, so no failure is reported
+		return nil
+	}
+	switch resp.StatusCode {
+	case http.StatusTooManyRequests: // 429
+		cooldown := ParseRetryAfter(resp)
+		if cooldown <= 0 { // ParseRetryAfter gave no positive duration: use defaultCooldown
+			cooldown = defaultCooldown
+		}
+		return &Failure{Reason: FailoverRateLimited, Cooldown: cooldown}
+	case http.StatusUnauthorized: // 401
+		return &Failure{Reason: FailoverUnauthorized}
+	case http.StatusForbidden: // 403
+		return &Failure{Reason: FailoverForbidden}
+	default: // any other status is not classified as a key failure
+		return nil
+	}
+}
diff --git a/aibridge/keypool/runner.go b/aibridge/keypool/runner.go
new file mode 100644
index 0000000000..7cacaf8208
--- /dev/null
+++ b/aibridge/keypool/runner.go
@@ -0,0 +1,53 @@
+package keypool
+
+import (
+	"context"
+
+	"cdr.dev/slog/v3"
+)
+
+// Failover walks pool, invoking attempt with each candidate key until the
+// attempt reports no failure (a nil *Failure) or the pool is exhausted. It
+// owns key marking, the retry decision, and exhaustion.
+//
+// When an attempt returns a non-nil *Failure the chosen key is marked
+// (temporary or permanent) and the next key is tried. The discarded payload
+// is the closure's responsibility to clean up before reporting a failure. On
+// exhaustion Failover returns the zero value of T and the pool's *Error.
+func Failover[T any](
+	ctx context.Context,
+	pool *Pool,
+	logger slog.Logger,
+	providerName string, // used only as the "provider" field of the log entries below
+	attempt func(ctx context.Context, key *Key) (T, *Failure),
+) (T, *Error) {
+	var zero T // handed back together with the walker's error when no key can be tried
+	walker := pool.Walker()
+	for {
+		key, kpErr := walker.Next()
+		if kpErr != nil { // Next produced an error instead of a key: stop and return it unchanged
+			return zero, kpErr
+		}
+
+		payload, failure := attempt(ctx, key)
+		if failure == nil { // no key failure reported: return this attempt's payload as-is
+			return payload, nil
+		}
+
+		switch failure.Reason { // no default case: an unlisted reason marks nothing and moves to the next key
+		case FailoverRateLimited:
+			if key.MarkTemporary(failure.Cooldown) { // log only when MarkTemporary returns true
+				logger.Info(ctx, "key marked temporary",
+					slog.F("provider", providerName),
+					slog.F("api_key_hint", key.Hint()),
+					slog.F("cooldown", failure.Cooldown))
+			}
+		case FailoverUnauthorized, FailoverForbidden:
+			if key.MarkPermanent() { // log only when MarkPermanent returns true
+				logger.Warn(ctx, "key marked permanent",
+					slog.F("provider", providerName),
+					slog.F("api_key_hint", key.Hint()))
+			}
+		}
+	}
+}