mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
refactor(aibridge): add shared keypool failover runner for blocking and passthrough
This commit is contained in:
@@ -589,6 +589,18 @@ func (i *interceptionBase) markKeyOnError(ctx context.Context, key *keypool.Key,
|
||||
)
|
||||
}
|
||||
|
||||
// classifyError maps a centralized-request error to a *keypool.Failure by
|
||||
// extracting the Anthropic SDK error and classifying its HTTP response. A
|
||||
// non-SDK error (transport failure, context cancellation) yields nil so the
|
||||
// failover loop returns it to the caller instead of retrying.
|
||||
func (i *interceptionBase) classifyError(err error) *keypool.Failure {
|
||||
var apiErr *anthropic.Error
|
||||
if !errors.As(err, &apiErr) {
|
||||
return nil
|
||||
}
|
||||
return keypool.Classify(apiErr.Response)
|
||||
}
|
||||
|
||||
// ResponseErrorFromKeyPool translates a *keypool.Error into
|
||||
// a developer-facing ResponseError shaped for the Anthropic API.
|
||||
func ResponseErrorFromKeyPool(keyPoolErr *keypool.Error) *ResponseError {
|
||||
|
||||
@@ -367,29 +367,30 @@ func (i *BlockingInterception) newMessageWithKey(ctx context.Context, svc anthro
|
||||
// Errors that aren't key-specific don't trigger failover and
|
||||
// are returned to the caller.
|
||||
func (i *BlockingInterception) newMessageWithKeyFailover(ctx context.Context, svc anthropic.MessageService) (*anthropic.Message, error) {
|
||||
walker := i.cfg.KeyPool.Walker()
|
||||
for {
|
||||
key, keyPoolErr := walker.Next()
|
||||
if keyPoolErr != nil {
|
||||
return nil, keyPoolErr
|
||||
}
|
||||
// Record the key in use so the hint reflects the last attempted key.
|
||||
i.credential = intercept.NewCredentialInfo(intercept.CredentialKindCentralized, key.Value())
|
||||
i.logger.Debug(ctx, "using centralized api key",
|
||||
slog.F("credential_hint", i.Credential().Hint), slog.F("credential_length", i.Credential().Length))
|
||||
|
||||
msg, err := i.newMessageWithKey(ctx, svc,
|
||||
option.WithAPIKey(key.Value()),
|
||||
// Disable SDK retries because the failover loop
|
||||
// handles retries via key rotation.
|
||||
option.WithMaxRetries(0),
|
||||
)
|
||||
// Key-specific failure: try the next key.
|
||||
if i.markKeyOnError(ctx, key, err) {
|
||||
continue
|
||||
}
|
||||
// Either success (msg, nil) or a non-key error (nil, err):
|
||||
// nothing to retry, return as-is.
|
||||
return msg, err
|
||||
// result carries both return values of one attempt so the failover loop
|
||||
// hands back a success or a non-key error as a single atomic payload.
|
||||
type result struct {
|
||||
msg *anthropic.Message
|
||||
err error
|
||||
}
|
||||
res, keyPoolErr := keypool.Failover(ctx, i.cfg.KeyPool, i.logger, i.providerName,
|
||||
func(ctx context.Context, key *keypool.Key) (result, *keypool.Failure) {
|
||||
// Record the key in use so the hint reflects the last attempted key.
|
||||
i.credential = intercept.NewCredentialInfo(intercept.CredentialKindCentralized, key.Value())
|
||||
i.logger.Debug(ctx, "using centralized api key",
|
||||
slog.F("credential_hint", i.Credential().Hint), slog.F("credential_length", i.Credential().Length))
|
||||
|
||||
msg, err := i.newMessageWithKey(ctx, svc,
|
||||
option.WithAPIKey(key.Value()),
|
||||
// Disable SDK retries because the failover loop
|
||||
// handles retries via key rotation.
|
||||
option.WithMaxRetries(0),
|
||||
)
|
||||
return result{msg, err}, i.classifyError(err)
|
||||
})
|
||||
if keyPoolErr != nil {
|
||||
return nil, keyPoolErr
|
||||
}
|
||||
// Either success (msg, nil) or a non-key error (nil, err): return as-is.
|
||||
return res.msg, res.err
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package keypool
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
@@ -68,42 +69,45 @@ func (t *keyFailoverTransport) RoundTrip(req *http.Request) (*http.Response, err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Fresh walker per request, independent of other inflight requests.
|
||||
walker := t.config.Pool.Walker()
|
||||
for {
|
||||
key, keyPoolErr := walker.Next()
|
||||
if keyPoolErr != nil {
|
||||
resp := t.config.BuildKeyPoolResponse(keyPoolErr)
|
||||
if resp == nil {
|
||||
// Fallback if BuildKeyPoolResponse returns nil.
|
||||
body := []byte(`{"error":"key pool unavailable"}`)
|
||||
resp = utils.NewJSONErrorResponse(http.StatusBadGateway, 0, body)
|
||||
// result carries both return values of one attempt so the failover loop
|
||||
// hands back a success or a transport error as a single atomic payload.
|
||||
type result struct {
|
||||
resp *http.Response
|
||||
err error
|
||||
}
|
||||
res, keyPoolErr := Failover(req.Context(), t.config.Pool, t.config.Logger, t.config.ProviderName,
|
||||
func(ctx context.Context, key *Key) (result, *Failure) {
|
||||
// Clone per attempt so the original request isn't mutated.
|
||||
outReq := req.Clone(ctx)
|
||||
if body != nil {
|
||||
outReq.Body = io.NopCloser(bytes.NewReader(body))
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
t.config.InjectAuthKey(&outReq.Header, key.Value())
|
||||
|
||||
// Clone per attempt so the original request isn't mutated.
|
||||
outReq := req.Clone(req.Context())
|
||||
if body != nil {
|
||||
outReq.Body = io.NopCloser(bytes.NewReader(body))
|
||||
resp, rtErr := t.inner.RoundTrip(outReq)
|
||||
if rtErr != nil {
|
||||
// Transport-level error, not a key issue: stop and return.
|
||||
return result{resp, rtErr}, nil
|
||||
}
|
||||
failure := Classify(resp)
|
||||
if failure != nil {
|
||||
// Drain the discarded response before retrying with the next key.
|
||||
_, _ = io.Copy(io.Discard, resp.Body)
|
||||
_ = resp.Body.Close()
|
||||
}
|
||||
return result{resp, nil}, failure
|
||||
})
|
||||
if keyPoolErr != nil {
|
||||
resp := t.config.BuildKeyPoolResponse(keyPoolErr)
|
||||
if resp == nil {
|
||||
// Fallback if BuildKeyPoolResponse returns nil.
|
||||
body := []byte(`{"error":"key pool unavailable"}`)
|
||||
resp = utils.NewJSONErrorResponse(http.StatusBadGateway, 0, body)
|
||||
}
|
||||
t.config.InjectAuthKey(&outReq.Header, key.Value())
|
||||
|
||||
resp, rtErr := t.inner.RoundTrip(outReq)
|
||||
if rtErr != nil {
|
||||
// Transport-level error, not a key issue.
|
||||
return resp, rtErr
|
||||
}
|
||||
// MarkKeyOnStatus returns true on key-specific failures (e.g. 401/403/429).
|
||||
if MarkKeyOnStatus(req.Context(), key, resp, t.config.Logger, t.config.ProviderName) {
|
||||
// Drain and retry with the next key.
|
||||
_, _ = io.Copy(io.Discard, resp.Body)
|
||||
_ = resp.Body.Close()
|
||||
continue
|
||||
}
|
||||
// Success or non-key error, forward as-is.
|
||||
return resp, nil
|
||||
}
|
||||
// Success or non-key transport error, forward as-is.
|
||||
return res.resp, res.err
|
||||
}
|
||||
|
||||
// bufferBody reads the request body fully so it can be replayed
|
||||
|
||||
@@ -0,0 +1,60 @@
|
||||
package keypool
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
)
|
||||
|
||||
// FailoverReason explains why a key attempt failed in a way that should move
// the failover loop to the next key.
type FailoverReason int

const (
	// FailoverRateLimited reports an HTTP 429: the key is marked temporary
	// and the next key is tried.
	FailoverRateLimited FailoverReason = iota
	// FailoverUnauthorized reports an HTTP 401: the key is marked permanent
	// and the next key is tried.
	FailoverUnauthorized
	// FailoverForbidden reports an HTTP 403: the key is marked permanent and
	// the next key is tried.
	FailoverForbidden
)
|
||||
|
||||
// Failure describes a key-specific attempt failure that triggers failover. A
|
||||
// nil *Failure means no key failure: the attempt produced a result the caller
|
||||
// should keep (a success, a non-key error, a transport error, or a streaming
|
||||
// attempt that already committed).
|
||||
type Failure struct {
|
||||
Reason FailoverReason
|
||||
// Cooldown is honored only for FailoverRateLimited.
|
||||
Cooldown time.Duration
|
||||
}
|
||||
|
||||
// Classify maps a key-specific HTTP response to a *Failure. A nil response or
|
||||
// any non-failover status yields nil. 429 yields FailoverRateLimited carrying
|
||||
// the parsed Retry-After (or defaultCooldown when absent), 401 yields
|
||||
// FailoverUnauthorized, and 403 yields FailoverForbidden.
|
||||
//
|
||||
// Classify intentionally takes an *http.Response, not a provider error, so
|
||||
// the pool stays SDK-agnostic. Callers unwrap the response from their SDK's
|
||||
// error type (e.g. errors.As(err, &apiErr); apiErr.Response) before calling.
|
||||
func Classify(resp *http.Response) *Failure {
|
||||
if resp == nil {
|
||||
return nil
|
||||
}
|
||||
switch resp.StatusCode {
|
||||
case http.StatusTooManyRequests:
|
||||
cooldown := ParseRetryAfter(resp)
|
||||
if cooldown <= 0 {
|
||||
cooldown = defaultCooldown
|
||||
}
|
||||
return &Failure{Reason: FailoverRateLimited, Cooldown: cooldown}
|
||||
case http.StatusUnauthorized:
|
||||
return &Failure{Reason: FailoverUnauthorized}
|
||||
case http.StatusForbidden:
|
||||
return &Failure{Reason: FailoverForbidden}
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,53 @@
|
||||
package keypool
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
)
|
||||
|
||||
// Failover walks pool, invoking attempt with each candidate key until the
|
||||
// attempt reports no failure (a nil *Failure) or the pool is exhausted. It
|
||||
// owns key marking, the retry decision, and exhaustion.
|
||||
//
|
||||
// When an attempt returns a non-nil *Failure the chosen key is marked
|
||||
// (temporary or permanent) and the next key is tried. The discarded payload
|
||||
// is the closure's responsibility to clean up before reporting a failure. On
|
||||
// exhaustion Failover returns the zero value of T and the pool's *Error.
|
||||
func Failover[T any](
|
||||
ctx context.Context,
|
||||
pool *Pool,
|
||||
logger slog.Logger,
|
||||
providerName string,
|
||||
attempt func(ctx context.Context, key *Key) (T, *Failure),
|
||||
) (T, *Error) {
|
||||
var zero T
|
||||
walker := pool.Walker()
|
||||
for {
|
||||
key, kpErr := walker.Next()
|
||||
if kpErr != nil {
|
||||
return zero, kpErr
|
||||
}
|
||||
|
||||
payload, failure := attempt(ctx, key)
|
||||
if failure == nil {
|
||||
return payload, nil
|
||||
}
|
||||
|
||||
switch failure.Reason {
|
||||
case FailoverRateLimited:
|
||||
if key.MarkTemporary(failure.Cooldown) {
|
||||
logger.Info(ctx, "key marked temporary",
|
||||
slog.F("provider", providerName),
|
||||
slog.F("api_key_hint", key.Hint()),
|
||||
slog.F("cooldown", failure.Cooldown))
|
||||
}
|
||||
case FailoverUnauthorized, FailoverForbidden:
|
||||
if key.MarkPermanent() {
|
||||
logger.Warn(ctx, "key marked permanent",
|
||||
slog.F("provider", providerName),
|
||||
slog.F("api_key_hint", key.Hint()))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user