feat: add automatic key failover for AI Bridge passthrough (#24920)

## Description

Adds automatic key failover for passthrough routes for the Anthropic and OpenAI providers. A new `keyFailoverTransport` wraps the reverse-proxy transport: centralized requests walk the configured key pool and retry with the next key on key-specific failures (401/403/429), reusing the same key-marking semantics as the bridged routes.

BYOK passthrough requests run as a single attempt with no failover.

## Changes

- New `keypool.KeyFailoverConfig` carrying the `Pool` to walk and the provider-specific closures (`IsBYOK`, `InjectAuthKey`, `MarkKey`, `BuildExhaustedResponse`).
- New `keypool.NewKeyFailoverTransport`: wraps an inner `http.RoundTripper`. Returns `inner` unchanged when `Pool` is nil, otherwise produces a transport that buffers the request body once, walks the pool per request, and replays each attempt with the next key.
- New `Provider.KeyFailoverConfig(logger)` interface method. Anthropic injects `X-Api-Key`; OpenAI injects `Authorization: Bearer ...`; Copilot returns an empty config.
- `passthrough.go` wires `NewKeyFailoverTransport` around the existing apidump middleware, so every retry attempt is recorded.

## Related Issues

Related to: https://github.com/coder/internal/issues/1446
Related to: https://linear.app/codercom/issue/AIGOV-197/aibridge-automatic-key-failover-for-bridged-and-passthrough-routes

## Follow-up PRs

- Remove dead `Provider.InjectAuthHeader` method now that all auth is applied per-attempt by `KeyFailoverTransport`.
- Bedrock multi-key support.
- Refactor provider vs interceptor config separation.
- Record the actually-used key in the interception credential hint after failover.

> [!NOTE]
> Initially generated by Claude Opus 4.7, modified and reviewed by @ssncferreira
This commit is contained in:
Susana Ferreira
2026-05-07 15:46:36 +01:00
committed by GitHub
parent b94a0aebcd
commit 0766cc3097
23 changed files with 758 additions and 77 deletions
+22 -11
View File
@@ -27,6 +27,7 @@ import (
"github.com/coder/coder/v2/aibridge/mcp"
"github.com/coder/coder/v2/aibridge/recorder"
"github.com/coder/coder/v2/aibridge/tracing"
"github.com/coder/coder/v2/aibridge/utils"
"github.com/coder/quartz"
)
@@ -188,7 +189,7 @@ func (i *interceptionBase) unmarshalArgs(in string) (args recorder.ToolArgs) {
}
// writeUpstreamError marshals and writes a given error.
func (i *interceptionBase) writeUpstreamError(w http.ResponseWriter, oaiErr *responseError) {
func (i *interceptionBase) writeUpstreamError(w http.ResponseWriter, oaiErr *ResponseError) {
if oaiErr == nil {
return
}
@@ -234,10 +235,10 @@ func (i *interceptionBase) markKeyOnError(ctx context.Context, key *keypool.Key,
)
}
// processKeyPoolError translates a keypool exhaustion error
// ProcessKeyPoolError translates a keypool exhaustion error
// into a developer-facing ResponseError shaped for the OpenAI
// API. Returns nil if err is not an exhaustion error.
func processKeyPoolError(err error) *responseError {
func ProcessKeyPoolError(err error) *ResponseError {
var transient *keypool.TransientKeyPoolError
switch {
case errors.As(err, &transient):
@@ -292,7 +293,7 @@ func calculateActualInputTokenUsage(in openai.CompletionUsage) int64 {
in.PromptTokensDetails.CachedTokens /* The aggregated number of text input tokens that has been cached from previous requests. */
}
func getErrorResponse(err error) *responseError {
func getErrorResponse(err error) *ResponseError {
var apiErr *openai.Error
if !errors.As(err, &apiErr) {
return nil
@@ -300,16 +301,16 @@ func getErrorResponse(err error) *responseError {
return newErrorResponse(apiErr.Message, apiErr.Type, apiErr.Code, apiErr.StatusCode, keypool.ParseRetryAfter(apiErr.Response))
}
var _ error = &responseError{}
var _ error = &ResponseError{}
type responseError struct {
type ResponseError struct {
ErrorObject *shared.ErrorObject `json:"error"`
StatusCode int `json:"-"`
RetryAfter time.Duration `json:"-"`
}
func newErrorResponse(msg, errType, code string, status int, retryAfter time.Duration) *responseError {
return &responseError{
func newErrorResponse(msg, errType, code string, status int, retryAfter time.Duration) *ResponseError {
return &ResponseError{
ErrorObject: &shared.ErrorObject{
Code: code,
Message: msg,
@@ -320,9 +321,19 @@ func newErrorResponse(msg, errType, code string, status int, retryAfter time.Dur
}
}
func (a *responseError) Error() string {
if a.ErrorObject == nil {
func (e *ResponseError) Error() string {
if e.ErrorObject == nil {
return ""
}
return a.ErrorObject.Message
return e.ErrorObject.Message
}
// ToResponse marshals e into an *http.Response shaped for the
// OpenAI API. The JSON-encoded error becomes the body, and
// e.StatusCode and e.RetryAfter are passed to
// utils.NewJSONErrorResponse.
//
// A nil receiver yields a nil response instead of panicking, so the
// method can be chained after a lookup that may return nil (such as
// ProcessKeyPoolError).
func (e *ResponseError) ToResponse() *http.Response {
	if e == nil {
		// Guards the e.StatusCode / e.RetryAfter dereferences below.
		return nil
	}
	body, err := json.Marshal(e)
	if err != nil {
		// Marshaling failed: fall back to a static, pre-encoded error
		// body so the client still receives well-formed JSON.
		body = []byte(`{"error":{"type":"error","message":"error marshaling upstream error","code":"server_error"}}`)
	}
	return utils.NewJSONErrorResponse(e.StatusCode, e.RetryAfter, body)
}
@@ -127,7 +127,7 @@ func TestProcessKeyPoolError(t *testing.T) {
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
got := processKeyPoolError(tc.err)
got := ProcessKeyPoolError(tc.err)
if tc.expectedNil {
require.Nil(t, got)
return
@@ -207,7 +207,7 @@ func TestWriteUpstreamError(t *testing.T) {
tests := []struct {
name string
respErr *responseError
respErr *ResponseError
expectStatus int
// Empty string means the header should be absent.
expectRetryAfter string
@@ -225,7 +225,7 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req
// The failover loop may return a keypool exhaustion
// error. Check before the SDK-error path.
if keyErr := processKeyPoolError(err); keyErr != nil {
if keyErr := ProcessKeyPoolError(err); keyErr != nil {
i.writeUpstreamError(w, keyErr)
return xerrors.Errorf("key pool exhausted: %w", err)
}
@@ -144,7 +144,7 @@ func (i *StreamingInterception) ProcessRequest(w http.ResponseWriter, r *http.Re
var currentKey *keypool.Key
if walker != nil {
key, err := walker.Next()
if respErr := processKeyPoolError(err); respErr != nil {
if respErr := ProcessKeyPoolError(err); respErr != nil {
// Pool exhausted in this iteration. Relay the
// error to the client: as an SSE event if events
// have already been sent, or by direct write
@@ -474,7 +474,7 @@ func (i *StreamingInterception) newStream(ctx context.Context, svc openai.ChatCo
// processing error into a relayable ResponseError. Returns nil
// when the error is unrecoverable, in which case nothing can be
// relayed back.
func (*StreamingInterception) mapStreamError(ctx context.Context, logger slog.Logger, streamErr, lastErr error) *responseError {
func (*StreamingInterception) mapStreamError(ctx context.Context, logger slog.Logger, streamErr, lastErr error) *ResponseError {
if streamErr != nil {
if eventstream.IsUnrecoverableError(streamErr) {
logger.Debug(ctx, "stream terminated", slog.Error(streamErr))
+21 -11
View File
@@ -433,7 +433,7 @@ func filterBedrockBetaFlags(headers http.Header, model string) {
}
// writeUpstreamError marshals and writes a given error.
func (i *interceptionBase) writeUpstreamError(w http.ResponseWriter, antErr *responseError) {
func (i *interceptionBase) writeUpstreamError(w http.ResponseWriter, antErr *ResponseError) {
if antErr == nil {
return
}
@@ -537,10 +537,10 @@ func (i *interceptionBase) markKeyOnError(ctx context.Context, key *keypool.Key,
)
}
// processKeyPoolError translates a keypool exhaustion error
// ProcessKeyPoolError translates a keypool exhaustion error
// into a developer-facing ResponseError shaped for the Anthropic
// API. Returns nil if err is not an exhaustion error.
func processKeyPoolError(err error) *responseError {
func ProcessKeyPoolError(err error) *ResponseError {
var transient *keypool.TransientKeyPoolError
switch {
case errors.As(err, &transient):
@@ -562,7 +562,7 @@ func processKeyPoolError(err error) *responseError {
}
}
func getErrorResponse(err error) *responseError {
func getErrorResponse(err error) *ResponseError {
var apierr *anthropic.Error
if !errors.As(err, &apierr) {
return nil
@@ -583,17 +583,17 @@ func getErrorResponse(err error) *responseError {
return newErrorResponse(msg, errType, apierr.StatusCode, keypool.ParseRetryAfter(apierr.Response))
}
var _ error = &responseError{}
var _ error = &ResponseError{}
type responseError struct {
type ResponseError struct {
*anthropic.ErrorResponse
StatusCode int `json:"-"`
RetryAfter time.Duration `json:"-"`
}
func newErrorResponse(msg, errType string, status int, retryAfter time.Duration) *responseError {
return &responseError{
func newErrorResponse(msg, errType string, status int, retryAfter time.Duration) *ResponseError {
return &ResponseError{
ErrorResponse: &shared.ErrorResponse{
Error: shared.ErrorObjectUnion{
Message: msg,
@@ -606,9 +606,19 @@ func newErrorResponse(msg, errType string, status int, retryAfter time.Duration)
}
}
func (a *responseError) Error() string {
if a.ErrorResponse == nil {
func (e *ResponseError) Error() string {
if e.ErrorResponse == nil {
return ""
}
return a.ErrorResponse.Error.Message
return e.ErrorResponse.Error.Message
}
// ToResponse marshals e into an *http.Response shaped for the
// Anthropic API. The JSON-encoded error becomes the body, and
// e.StatusCode and e.RetryAfter are passed to
// utils.NewJSONErrorResponse.
//
// A nil receiver yields a nil response instead of panicking, so the
// method can be chained after a lookup that may return nil (such as
// ProcessKeyPoolError).
func (e *ResponseError) ToResponse() *http.Response {
	if e == nil {
		// Guards the e.StatusCode / e.RetryAfter dereferences below.
		return nil
	}
	body, err := json.Marshal(e)
	if err != nil {
		// Marshaling failed: fall back to a static, pre-encoded error
		// body so the client still receives well-formed JSON.
		body = []byte(`{"type":"error","error":{"type":"error","message":"error marshaling upstream error"}}`)
	}
	return utils.NewJSONErrorResponse(e.StatusCode, e.RetryAfter, body)
}
+2 -2
View File
@@ -1039,7 +1039,7 @@ func TestProcessKeyPoolError(t *testing.T) {
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
got := processKeyPoolError(tc.err)
got := ProcessKeyPoolError(tc.err)
if tc.expectedNil {
require.Nil(t, got)
return
@@ -1119,7 +1119,7 @@ func TestWriteUpstreamError(t *testing.T) {
tests := []struct {
name string
respErr *responseError
respErr *ResponseError
expectStatus int
// Empty string means the header should be absent.
expectRetryAfter string
+1 -1
View File
@@ -114,7 +114,7 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req
// The failover loop may return a keypool exhaustion
// error. Check before the SDK-error path.
if keyErr := processKeyPoolError(err); keyErr != nil {
if keyErr := ProcessKeyPoolError(err); keyErr != nil {
i.writeUpstreamError(w, keyErr)
return xerrors.Errorf("key pool exhausted: %w", err)
}
+3 -3
View File
@@ -175,7 +175,7 @@ newStream:
var currentKey *keypool.Key
if walker != nil {
key, err := walker.Next()
if respErr := processKeyPoolError(err); respErr != nil {
if respErr := ProcessKeyPoolError(err); respErr != nil {
// Pool exhausted in this iteration. Relay the
// error to the client: as an SSE event if events
// have already been sent, or by direct write
@@ -599,10 +599,10 @@ newStream:
}
// mapStreamError converts a mid-stream upstream error or
// processing error into a relayable responseError. Returns nil
// processing error into a relayable ResponseError. Returns nil
// when the error is unrecoverable, in which case nothing can be
// relayed back.
func (*StreamingInterception) mapStreamError(ctx context.Context, logger slog.Logger, streamErr, lastErr error) *responseError {
func (*StreamingInterception) mapStreamError(ctx context.Context, logger slog.Logger, streamErr, lastErr error) *ResponseError {
if streamErr != nil {
if eventstream.IsUnrecoverableError(streamErr) {
logger.Debug(ctx, "stream terminated", slog.Error(streamErr))
+22 -11
View File
@@ -35,6 +35,7 @@ import (
"github.com/coder/coder/v2/aibridge/mcp"
"github.com/coder/coder/v2/aibridge/recorder"
"github.com/coder/coder/v2/aibridge/tracing"
"github.com/coder/coder/v2/aibridge/utils"
"github.com/coder/quartz"
)
@@ -142,7 +143,7 @@ func (i *responsesInterceptionBase) validateRequest(ctx context.Context, w http.
}
// writeUpstreamError marshals and writes a given error.
func (i *responsesInterceptionBase) writeUpstreamError(w http.ResponseWriter, oaiErr *responseError) {
func (i *responsesInterceptionBase) writeUpstreamError(w http.ResponseWriter, oaiErr *ResponseError) {
if oaiErr == nil {
return
}
@@ -188,10 +189,10 @@ func (i *responsesInterceptionBase) markKeyOnError(ctx context.Context, key *key
)
}
// processKeyPoolError translates a keypool exhaustion error
// into a developer-facing responseError shaped for the OpenAI
// ProcessKeyPoolError translates a keypool exhaustion error
// into a developer-facing ResponseError shaped for the OpenAI
// API. Returns nil if err is not an exhaustion error.
func processKeyPoolError(err error) *responseError {
func ProcessKeyPoolError(err error) *ResponseError {
var transient *keypool.TransientKeyPoolError
switch {
case errors.As(err, &transient):
@@ -215,8 +216,8 @@ func processKeyPoolError(err error) *responseError {
}
}
func newErrorResponse(msg, errType, code string, status int, retryAfter time.Duration) *responseError {
return &responseError{
func newErrorResponse(msg, errType, code string, status int, retryAfter time.Duration) *ResponseError {
return &ResponseError{
ErrorObject: &shared.ErrorObject{
Code: code,
Message: msg,
@@ -227,19 +228,29 @@ func newErrorResponse(msg, errType, code string, status int, retryAfter time.Dur
}
}
var _ error = &responseError{}
var _ error = &ResponseError{}
type responseError struct {
type ResponseError struct {
ErrorObject *shared.ErrorObject `json:"error"`
StatusCode int `json:"-"`
RetryAfter time.Duration `json:"-"`
}
func (a *responseError) Error() string {
if a.ErrorObject == nil {
func (e *ResponseError) Error() string {
if e.ErrorObject == nil {
return ""
}
return a.ErrorObject.Message
return e.ErrorObject.Message
}
// ToResponse marshals e into an *http.Response shaped for the
// OpenAI API. The JSON-encoded error becomes the body, and
// e.StatusCode and e.RetryAfter are passed to
// utils.NewJSONErrorResponse.
//
// A nil receiver yields a nil response instead of panicking, so the
// method can be chained after a lookup that may return nil (such as
// ProcessKeyPoolError).
func (e *ResponseError) ToResponse() *http.Response {
	if e == nil {
		// Guards the e.StatusCode / e.RetryAfter dereferences below.
		return nil
	}
	body, err := json.Marshal(e)
	if err != nil {
		// Marshaling failed: fall back to a static, pre-encoded error
		// body so the client still receives well-formed JSON.
		body = []byte(`{"error":{"type":"error","message":"error marshaling upstream error","code":"server_error"}}`)
	}
	return utils.NewJSONErrorResponse(e.StatusCode, e.RetryAfter, body)
}
// sendCustomErr sends custom responses.Error error to the client
+2 -2
View File
@@ -432,7 +432,7 @@ func TestProcessKeyPoolError(t *testing.T) {
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
got := processKeyPoolError(tc.err)
got := ProcessKeyPoolError(tc.err)
if tc.expectedNil {
require.Nil(t, got)
return
@@ -512,7 +512,7 @@ func TestWriteUpstreamError(t *testing.T) {
tests := []struct {
name string
respErr *responseError
respErr *ResponseError
expectStatus int
// Empty string means the header should be absent.
expectRetryAfter string
+1 -1
View File
@@ -103,7 +103,7 @@ func (i *BlockingResponsesInterceptor) ProcessRequest(w http.ResponseWriter, r *
// The failover loop may return a keypool exhaustion
// error. Render it here.
if upstreamErr != nil {
if keyErr := processKeyPoolError(upstreamErr); keyErr != nil {
if keyErr := ProcessKeyPoolError(upstreamErr); keyErr != nil {
i.writeUpstreamError(w, keyErr)
return xerrors.Errorf("key pool exhausted: %w", upstreamErr)
}
+1 -1
View File
@@ -135,7 +135,7 @@ func (i *StreamingResponsesInterceptor) ProcessRequest(w http.ResponseWriter, r
var currentKey *keypool.Key
if walker != nil {
key, err := walker.Next()
if respErr := processKeyPoolError(err); respErr != nil {
if respErr := ProcessKeyPoolError(err); respErr != nil {
// Pool exhausted: write the error directly. In
// agentic mode the inner loop buffers events
// instead of streaming them downstream, so the
@@ -6,8 +6,10 @@ import (
"go.opentelemetry.io/otel/trace"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/aibridge/config"
"github.com/coder/coder/v2/aibridge/intercept"
"github.com/coder/coder/v2/aibridge/keypool"
)
type MockProvider struct {
@@ -31,6 +33,10 @@ func (m *MockProvider) InjectAuthHeader(h *http.Header) {
m.InjectAuthHeaderFunc(h)
}
}
// KeyFailoverConfig returns a zero-value keypool.KeyFailoverConfig.
// The logger argument is ignored. The returned config has a nil Pool.
func (*MockProvider) KeyFailoverConfig(_ slog.Logger) keypool.KeyFailoverConfig {
return keypool.KeyFailoverConfig{}
}
func (*MockProvider) CircuitBreakerConfig() *config.CircuitBreaker { return nil }
func (*MockProvider) APIDumpDir() string { return "" }
func (m *MockProvider) CreateInterceptor(w http.ResponseWriter, r *http.Request, tracer trace.Tracer) (intercept.Interceptor, error) {
+120
View File
@@ -0,0 +1,120 @@
package keypool
import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"

	"github.com/coder/coder/v2/aibridge/utils"
)
// KeyFailoverConfig is the per-provider configuration consumed by
// NewKeyFailoverTransport.
//
// When Pool is non-nil the resulting transport calls IsBYOK,
// InjectAuthKey and MarkKey without nil checks, so those three closures
// must be set alongside Pool.
type KeyFailoverConfig struct {
// Pool is the key pool to walk. Nil disables key failover:
// NewKeyFailoverTransport then returns the inner transport unchanged.
Pool *Pool
// IsBYOK returns true when the request already carries
// user-supplied auth. BYOK requests skip key failover and are
// forwarded to the inner transport as a single attempt.
IsBYOK func(*http.Request) bool
// InjectAuthKey writes the key value into the outbound headers
// in the format the provider expects. It is called once per attempt
// on the headers of a per-attempt clone of the request.
InjectAuthKey func(*http.Header, string)
// MarkKey marks the key based on the upstream response.
// Returns true when the response is a key-specific error,
// causing the walker to advance and retry with the next key.
// On a true result the transport itself drains and closes resp.Body
// before the next attempt.
MarkKey func(ctx context.Context, key *Key, resp *http.Response) bool
// BuildExhaustedResponse returns the response sent to the
// client when the walker has no more keys to try. err is the error
// returned by the walker. If it returns nil, the transport falls
// back to a generic 502 JSON error response.
BuildExhaustedResponse func(err error) *http.Response
}
// keyFailoverTransport retries inner across the key pool on
// key-specific failures. It is constructed by NewKeyFailoverTransport,
// which only builds one when config.Pool is non-nil.
type keyFailoverTransport struct {
// inner performs each individual upstream attempt.
inner http.RoundTripper
// config supplies the pool and the provider-specific closures.
config KeyFailoverConfig
}
// NewKeyFailoverTransport wraps inner in a keyFailoverTransport driven
// by config. Failover is opt-in: when config.Pool is nil there is
// nothing to walk, so inner itself is handed back without a wrapper.
func NewKeyFailoverTransport(inner http.RoundTripper, config KeyFailoverConfig) http.RoundTripper {
	if config.Pool != nil {
		return &keyFailoverTransport{inner: inner, config: config}
	}
	return inner
}
// RoundTrip is invoked by the proxy once per outer client request,
// after Rewrite has applied proxy headers.
//
// BYOK requests (config.IsBYOK reports true) are handed to inner as a
// single attempt with no failover. All other requests walk the key
// pool: the body is buffered once, and each attempt sends a clone of
// req carrying the next key. An attempt is retried with the following
// key only when config.MarkKey reports a key-specific failure; any
// other response, and any transport-level error, is returned as-is.
// When the walker has no more keys, the exhausted response is returned
// with a nil error.
func (t *keyFailoverTransport) RoundTrip(req *http.Request) (*http.Response, error) {
	if t.config.IsBYOK(req) {
		return t.inner.RoundTrip(req)
	}

	// Buffer once so retries can replay the body.
	body, err := bufferBody(req)
	if err != nil {
		return nil, err
	}

	// Fresh walker per request, independent of other inflight requests.
	walker := t.config.Pool.Walker()
	for {
		key, nextErr := walker.Next()
		if nextErr != nil {
			return t.exhaustedResponse(nextErr), nil
		}

		// Clone per attempt so the original request isn't mutated.
		outReq := req.Clone(req.Context())
		if body != nil {
			outReq.Body = io.NopCloser(bytes.NewReader(body))
		}
		t.config.InjectAuthKey(&outReq.Header, key.Value())

		resp, rtErr := t.inner.RoundTrip(outReq)
		if rtErr != nil {
			// Transport-level error, not a key issue.
			return resp, rtErr
		}

		// MarkKey returns true on key-specific failures (e.g. 401/403/429).
		if !t.config.MarkKey(req.Context(), key, resp) {
			// Success or non-key error, forward as-is.
			return resp, nil
		}

		// Drain and close the failed attempt's body before retrying
		// with the next key. Errors are ignored on purpose: the
		// response is being discarded either way.
		_, _ = io.Copy(io.Discard, resp.Body)
		_ = resp.Body.Close()
	}
}

// exhaustedResponse builds the response returned once the walker has no
// more keys. It prefers config.BuildExhaustedResponse and falls back to
// a generic 502 JSON error when that closure is unset or returns nil.
func (t *keyFailoverTransport) exhaustedResponse(err error) *http.Response {
	if build := t.config.BuildExhaustedResponse; build != nil {
		if resp := build(err); resp != nil {
			return resp
		}
	}

	// Marshal instead of interpolating the error into a JSON template,
	// so quotes or control characters in the error text cannot produce
	// an invalid body.
	msg := fmt.Sprintf("key pool exhausted: %s", err)
	payload, marshalErr := json.Marshal(map[string]string{"error": msg})
	if marshalErr != nil {
		payload = []byte(`{"error":"key pool exhausted"}`)
	}
	return utils.NewJSONErrorResponse(http.StatusBadGateway, 0, payload)
}
// bufferBody reads the request body fully so it can be replayed
// across key-failover retries. The body is closed before returning.
//
// Returns nil for a nil body and for http.NoBody, so callers leave the
// "no body" sentinel in place instead of swapping it for a generic
// empty reader. A read failure is wrapped and returned with a nil
// slice.
func bufferBody(req *http.Request) ([]byte, error) {
	if req.Body == nil || req.Body == http.NoBody {
		return nil, nil
	}
	defer req.Body.Close()

	body, err := io.ReadAll(req.Body)
	if err != nil {
		return nil, fmt.Errorf("buffer request body: %w", err)
	}
	return body, nil
}
+69
View File
@@ -0,0 +1,69 @@
package keypool_test
import (
"net/http"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/aibridge/keypool"
"github.com/coder/quartz"
)
// errFakeRoundTripperCalled is returned by fakeRoundTripper if it
// ever gets invoked. The constructor identity tests should never
// trigger a RoundTrip call.
var errFakeRoundTripperCalled = xerrors.New("fakeRoundTripper should not be invoked")
// fakeRoundTripper is a stub http.RoundTripper used to check
// constructor identity in tests. It performs no I/O.
type fakeRoundTripper struct{}
// RoundTrip always fails with errFakeRoundTripperCalled and a nil
// response.
func (*fakeRoundTripper) RoundTrip(*http.Request) (*http.Response, error) {
return nil, errFakeRoundTripperCalled
}
// TestNewKeyFailoverTransport checks the constructor's identity
// contract: with no pool the inner transport comes back untouched, and
// with a pool a distinct wrapper is returned. RoundTrip is never
// invoked.
func TestNewKeyFailoverTransport(t *testing.T) {
	t.Parallel()

	pool, err := keypool.New([]string{"k0"}, quartz.NewMock(t))
	require.NoError(t, err)

	type testCase struct {
		name string
		// Constructor input.
		config keypool.KeyFailoverConfig
		// Whether the constructor returns inner unchanged.
		expectSame bool
	}
	cases := []testCase{
		// Pool is nil: failover is disabled, inner is returned unchanged.
		{name: "pool_nil_returns_inner", config: keypool.KeyFailoverConfig{}, expectSame: true},
		// Pool is set: inner is wrapped in a key-failover transport.
		{name: "pool_set_returns_wrapper", config: keypool.KeyFailoverConfig{Pool: pool}, expectSame: false},
	}

	for _, tt := range cases {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			base := &fakeRoundTripper{}
			built := keypool.NewKeyFailoverTransport(base, tt.config)

			if !tt.expectSame {
				assert.NotSame(t, base, built)
				return
			}
			assert.Same(t, base, built)
		})
	}
}
+11 -9
View File
@@ -13,6 +13,7 @@ import (
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/aibridge/intercept/apidump"
"github.com/coder/coder/v2/aibridge/keypool"
"github.com/coder/coder/v2/aibridge/metrics"
"github.com/coder/coder/v2/aibridge/provider"
"github.com/coder/coder/v2/aibridge/tracing"
@@ -41,13 +42,17 @@ func newPassthroughRouter(prov provider.Provider, logger slog.Logger, m *metrics
ExpectContinueTimeout: 1 * time.Second,
}
// Build a reverse proxy to the upstream, reused across all requests for this provider.
// All request modifications happen in Rewrite.
// Build the passthrough proxy, reused across all requests for this provider.
// Rewrite sets proxy headers. For centralized requests, KeyFailoverTransport
// handles auth and failover. BYOK requests pass through.
proxy := &httputil.ReverseProxy{
Rewrite: func(pr *httputil.ProxyRequest) {
rewritePassthroughRequest(pr, provBaseURL, prov)
rewritePassthroughRequest(pr, provBaseURL)
},
Transport: apidump.NewPassthroughMiddleware(t, prov.APIDumpDir(), prov.Name(), logger, quartz.NewReal()),
Transport: keypool.NewKeyFailoverTransport(
apidump.NewPassthroughMiddleware(t, prov.APIDumpDir(), prov.Name(), logger, quartz.NewReal()),
prov.KeyFailoverConfig(logger),
),
ErrorHandler: func(rw http.ResponseWriter, req *http.Request, e error) {
logger.Warn(req.Context(), "reverse proxy error", slog.Error(e), slog.F("path", req.URL.Path))
http.Error(rw, "upstream proxy error", http.StatusBadGateway)
@@ -67,8 +72,8 @@ func newPassthroughRouter(prov provider.Provider, logger slog.Logger, m *metrics
}
// rewritePassthroughRequest configures the outbound request for the upstream and
// applies proxy headers and provider auth.
func rewritePassthroughRequest(pr *httputil.ProxyRequest, provBaseURL *url.URL, prov provider.Provider) {
// applies proxy headers.
func rewritePassthroughRequest(pr *httputil.ProxyRequest, provBaseURL *url.URL) {
pr.SetURL(provBaseURL)
// Rewrite sets "X-Forwarded-For" to just last hop (clients IP address).
@@ -87,9 +92,6 @@ func rewritePassthroughRequest(pr *httputil.ProxyRequest, provBaseURL *url.URL,
if _, ok := pr.Out.Header["User-Agent"]; !ok {
pr.Out.Header.Set("User-Agent", "aibridge") // TODO: use build tag.
}
// Inject provider auth.
prov.InjectAuthHeader(&pr.Out.Header)
}
// newInvalidBaseURLHandler returns a handler that always returns 502
+290 -20
View File
@@ -2,20 +2,27 @@ package aibridge //nolint:testpackage // tests unexported newPassthroughRouter
import (
"crypto/tls"
"io"
"maps"
"net"
"net/http"
"net/http/httptest"
"net/http/httputil"
"net/url"
"strings"
"sync/atomic"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel"
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/coder/v2/aibridge/config"
"github.com/coder/coder/v2/aibridge/internal/testutil"
"github.com/coder/coder/v2/aibridge/keypool"
"github.com/coder/coder/v2/aibridge/provider"
"github.com/coder/quartz"
)
var testTracer = otel.Tracer("bridge_test")
@@ -171,25 +178,6 @@ func TestRewritePassthroughRequest(t *testing.T) {
"User-Agent": {"custom-agent/1.0"},
},
},
{
name: "injects_auth_header",
reqPath: "http://client-host/chat",
reqRemoteAddr: "1.1.1.1:1111",
provider: &testutil.MockProvider{
URL: "https://upstream-host/base",
InjectAuthHeaderFunc: func(h *http.Header) {
h.Set("Authorization", "Bearer test-token")
},
},
expectURL: "https://upstream-host/base/chat",
expectHeaders: http.Header{
"X-Forwarded-Host": {"client-host"},
"X-Forwarded-Proto": {"http"},
"X-Forwarded-For": {"1.1.1.1"},
"User-Agent": {"aibridge"},
"Authorization": {"Bearer test-token"},
},
},
{
name: "appends_remote_addr_to_existing_forwarded_for_chain",
reqPath: "http://client-host/chat",
@@ -260,7 +248,7 @@ func TestRewritePassthroughRequest(t *testing.T) {
Out: r.Clone(r.Context()),
}
rewritePassthroughRequest(pr, provBaseURL, tc.provider)
rewritePassthroughRequest(pr, provBaseURL)
assert.Equal(t, tc.expectURL, pr.Out.URL.String())
assert.Equal(t, "", pr.Out.Host)
@@ -301,3 +289,285 @@ func TestPassthroughRouterReusesProxyInstance(t *testing.T) {
assert.EqualValues(t, 1, newConnections.Load())
}
// TestPassthrough_KeyFailover exercises the KeyFailoverTransport
// end-to-end through the passthrough proxy, parameterised over
// providers (anthropic, openai). Each scenario asserts the upstream
// request count, the response status and Retry-After, and the final
// pool state.
//
// Every provider/scenario pair runs as its own parallel subtest with a
// dedicated mock upstream and, for centralized cases, a dedicated pool.
func TestPassthrough_KeyFailover(t *testing.T) {
t.Parallel()
// upstreamResponse is one scripted reply of the mock upstream.
type upstreamResponse struct {
statusCode int
body string
headers map[string]string
}
const (
rateLimitBody = `{"error":"rate"}`
authErrorBody = `{"error":"unauthorized"}`
serverErrorBody = `{"error":"server"}`
successBody = `{"data":[]}`
)
// providers parameterises the table over the two providers
// that support key failover. Each entry encapsulates the
// provider-specific bits the test needs: how the mock upstream
// extracts the key from the request, how a BYOK request sets
// it, and how the provider is constructed for a given pool.
providers := []struct {
name string
extractKey func(*http.Request) string
setBYOK func(*http.Request, string)
newProvider func(baseURL string, pool *keypool.Pool) provider.Provider
}{
{
name: "anthropic",
extractKey: func(r *http.Request) string {
return r.Header.Get("X-Api-Key")
},
setBYOK: func(r *http.Request, key string) {
r.Header.Set("X-Api-Key", key)
},
newProvider: func(baseURL string, pool *keypool.Pool) provider.Provider {
return provider.NewAnthropic(config.Anthropic{
BaseURL: baseURL,
KeyPool: pool,
}, nil)
},
},
{
name: "openai",
extractKey: func(r *http.Request) string {
return strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ")
},
setBYOK: func(r *http.Request, key string) {
r.Header.Set("Authorization", "Bearer "+key)
},
newProvider: func(baseURL string, pool *keypool.Pool) provider.Provider {
cfg := config.OpenAI{BaseURL: baseURL}
if pool != nil {
cfg.KeyPool = pool
}
return provider.NewOpenAI(cfg)
},
},
}
tests := []struct {
name string
// Centralized pool keys. Empty when byokKey is set.
keys []string
// BYOK key. Empty when keys is set.
byokKey string
// Scripted upstream responses keyed by API key value.
responses map[string]upstreamResponse
expectedRequestCount int32
expectedStatusCode int
// Expected Retry-After header on the client response. Empty
// string means the header should be absent.
expectedRetryAfter string
// Expected key states after the request, by index in keys.
// Only checked when a pool was created (i.e. not for BYOK cases).
expectedKeyStates []keypool.KeyState
}{
{
// Given: 1 valid key returning 200.
// Then: 1 request, 200 response, key remains valid.
name: "single_valid_key",
keys: []string{"k0"},
responses: map[string]upstreamResponse{
"k0": {statusCode: http.StatusOK, body: successBody},
},
expectedRequestCount: 1,
expectedStatusCode: http.StatusOK,
expectedKeyStates: []keypool.KeyState{keypool.KeyStateValid},
},
{
// Given: 2 keys; key-0 returns 429, key-1 returns 200.
// Then: 2 requests, 200 response, key-0 temporary, key-1 valid.
name: "failover_after_429",
keys: []string{"k0", "k1"},
responses: map[string]upstreamResponse{
"k0": {
statusCode: http.StatusTooManyRequests,
headers: map[string]string{"Retry-After": "5"},
body: rateLimitBody,
},
"k1": {statusCode: http.StatusOK, body: successBody},
},
expectedRequestCount: 2,
expectedStatusCode: http.StatusOK,
expectedKeyStates: []keypool.KeyState{
keypool.KeyStateTemporary,
keypool.KeyStateValid,
},
},
{
// Given: 2 keys; key-0 returns 401, key-1 returns 200.
// Then: 2 requests, 200 response, key-0 permanent, key-1 valid.
name: "failover_after_401",
keys: []string{"k0", "k1"},
responses: map[string]upstreamResponse{
"k0": {statusCode: http.StatusUnauthorized, body: authErrorBody},
"k1": {statusCode: http.StatusOK, body: successBody},
},
expectedRequestCount: 2,
expectedStatusCode: http.StatusOK,
expectedKeyStates: []keypool.KeyState{
keypool.KeyStatePermanent,
keypool.KeyStateValid,
},
},
{
// Given: 2 keys; key-0 returns 403, key-1 returns 200.
// Then: 2 requests, 200 response, key-0 permanent, key-1 valid.
name: "failover_after_403",
keys: []string{"k0", "k1"},
responses: map[string]upstreamResponse{
"k0": {statusCode: http.StatusForbidden, body: authErrorBody},
"k1": {statusCode: http.StatusOK, body: successBody},
},
expectedRequestCount: 2,
expectedStatusCode: http.StatusOK,
expectedKeyStates: []keypool.KeyState{
keypool.KeyStatePermanent,
keypool.KeyStateValid,
},
},
{
// Given: 3 keys; all return 429 with cooldowns 5s, 3s, 10s.
// Then: 3 requests, 429 response with smallest Retry-After,
// all keys temporary.
name: "all_keys_rate_limited",
keys: []string{"k0", "k1", "k2"},
responses: map[string]upstreamResponse{
"k0": {
statusCode: http.StatusTooManyRequests,
headers: map[string]string{"Retry-After": "5"},
body: rateLimitBody,
},
"k1": {
statusCode: http.StatusTooManyRequests,
headers: map[string]string{"Retry-After": "3"},
body: rateLimitBody,
},
"k2": {
statusCode: http.StatusTooManyRequests,
headers: map[string]string{"Retry-After": "10"},
body: rateLimitBody,
},
},
expectedRequestCount: 3,
expectedStatusCode: http.StatusTooManyRequests,
expectedRetryAfter: "3",
expectedKeyStates: []keypool.KeyState{
keypool.KeyStateTemporary,
keypool.KeyStateTemporary,
keypool.KeyStateTemporary,
},
},
{
// Given: 2 keys; both return 401.
// Then: 2 requests, 502 response, both keys permanent.
name: "all_keys_unauthorized",
keys: []string{"k0", "k1"},
responses: map[string]upstreamResponse{
"k0": {statusCode: http.StatusUnauthorized, body: authErrorBody},
"k1": {statusCode: http.StatusUnauthorized, body: authErrorBody},
},
expectedRequestCount: 2,
expectedStatusCode: http.StatusBadGateway,
expectedKeyStates: []keypool.KeyState{
keypool.KeyStatePermanent,
keypool.KeyStatePermanent,
},
},
{
// Given: 2 keys; key-0 returns 500.
// Then: 1 request, 500 response, both keys remain valid.
name: "server_error_no_failover",
keys: []string{"k0", "k1"},
responses: map[string]upstreamResponse{
"k0": {statusCode: http.StatusInternalServerError, body: serverErrorBody},
},
expectedRequestCount: 1,
expectedStatusCode: http.StatusInternalServerError,
expectedKeyStates: []keypool.KeyState{
keypool.KeyStateValid,
keypool.KeyStateValid,
},
},
{
// Given: BYOK with a single user-supplied key returning 429.
// Then: 1 request, 429 forwarded as-is, no failover.
name: "byok_no_failover",
byokKey: "user-byok",
responses: map[string]upstreamResponse{
"user-byok": {
statusCode: http.StatusTooManyRequests,
headers: map[string]string{"Retry-After": "5"},
body: rateLimitBody,
},
},
expectedRequestCount: 1,
expectedStatusCode: http.StatusTooManyRequests,
expectedRetryAfter: "5",
},
}
for _, prov := range providers {
for _, tc := range tests {
t.Run(prov.name+"/"+tc.name, func(t *testing.T) {
t.Parallel()
// Mock upstream: counts requests and returns
// scripted responses keyed by API key. An unmapped
// key falls through to 500 so misconfigured cases
// surface via the status assertion.
var requestCount atomic.Int32
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
requestCount.Add(1)
_, _ = io.Copy(io.Discard, r.Body)
resp, ok := tc.responses[prov.extractKey(r)]
if !ok {
resp = upstreamResponse{statusCode: http.StatusInternalServerError}
}
w.Header().Set("Content-Type", "application/json")
for hk, hv := range resp.headers {
w.Header().Set(hk, hv)
}
w.WriteHeader(resp.statusCode)
_, _ = w.Write([]byte(resp.body))
}))
t.Cleanup(upstream.Close)
// Only centralized scenarios get a pool; BYOK scenarios
// leave it nil.
var pool *keypool.Pool
if len(tc.keys) > 0 {
var err error
pool, err = keypool.New(tc.keys, quartz.NewMock(t))
require.NoError(t, err)
}
p := prov.newProvider(upstream.URL, pool)
// IgnoreErrors: MarkKey logs at ERROR level when a
// key is marked permanent (401/403); slogtest would
// otherwise fail those scenarios.
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
handler := newPassthroughRouter(p, logger, nil, testTracer)
// Bodyless GET: these scenarios exercise key selection and
// marking, not request-body replay.
req := httptest.NewRequest(http.MethodGet, "/v1/models", nil)
if tc.byokKey != "" {
prov.setBYOK(req, tc.byokKey)
}
w := httptest.NewRecorder()
handler.ServeHTTP(w, req)
assert.Equal(t, tc.expectedRequestCount, requestCount.Load(), "upstream request count")
assert.Equal(t, tc.expectedStatusCode, w.Code, "response status code")
assert.Equal(t, tc.expectedRetryAfter, w.Header().Get("Retry-After"), "Retry-After header")
if pool != nil {
assert.Equal(t, tc.expectedKeyStates, pool.PoolState(), "key states")
}
})
}
}
}
+21
View File
@@ -1,6 +1,7 @@
package provider
import (
"context"
"fmt"
"io"
"net/http"
@@ -11,6 +12,7 @@ import (
"go.opentelemetry.io/otel/trace"
"golang.org/x/xerrors"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/aibridge/circuitbreaker"
"github.com/coder/coder/v2/aibridge/config"
"github.com/coder/coder/v2/aibridge/intercept"
@@ -219,6 +221,25 @@ func (p *Anthropic) InjectAuthHeader(headers *http.Header) {
}
}
// KeyFailoverConfig returns the key failover configuration used for
// Anthropic passthrough routes.
//
// The returned config walks the provider's configured key pool. A request
// is treated as BYOK when it already carries a non-empty X-Api-Key or
// Authorization header. Pool keys are injected via the X-Api-Key header,
// key marking is delegated to keypool.MarkKeyOnStatus, and the exhausted
// response is built from messages.ProcessKeyPoolError.
func (p *Anthropic) KeyFailoverConfig(logger slog.Logger) keypool.KeyFailoverConfig {
	// Resolve the provider name once so the MarkKey closure does not
	// call back into the provider on every attempt.
	providerName := p.Name()

	var cfg keypool.KeyFailoverConfig
	cfg.Pool = p.cfg.KeyPool
	cfg.IsBYOK = func(req *http.Request) bool {
		if req.Header.Get("X-Api-Key") != "" {
			return true
		}
		return req.Header.Get("Authorization") != ""
	}
	cfg.InjectAuthKey = func(headers *http.Header, apiKey string) {
		headers.Set("X-Api-Key", apiKey)
	}
	cfg.MarkKey = func(ctx context.Context, k *keypool.Key, upstream *http.Response) bool {
		return keypool.MarkKeyOnStatus(ctx, k, upstream, logger, providerName)
	}
	cfg.BuildExhaustedResponse = func(poolErr error) *http.Response {
		return messages.ProcessKeyPoolError(poolErr).ToResponse()
	}
	return cfg
}
// CircuitBreakerConfig returns the circuit breaker settings stored in the
// provider's configuration.
func (p *Anthropic) CircuitBreakerConfig() *config.CircuitBreaker {
	breaker := p.cfg.CircuitBreaker
	return breaker
}
+8
View File
@@ -12,10 +12,12 @@ import (
"go.opentelemetry.io/otel/trace"
"golang.org/x/xerrors"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/aibridge/config"
"github.com/coder/coder/v2/aibridge/intercept"
"github.com/coder/coder/v2/aibridge/intercept/chatcompletions"
"github.com/coder/coder/v2/aibridge/intercept/responses"
"github.com/coder/coder/v2/aibridge/keypool"
"github.com/coder/coder/v2/aibridge/tracing"
"github.com/coder/coder/v2/aibridge/utils"
)
@@ -111,6 +113,12 @@ func (*Copilot) AuthHeader() string {
// The original Authorization header flows through untouched from the client.
func (*Copilot) InjectAuthHeader(_ *http.Header) {}
// KeyFailoverConfig returns the zero-value failover config. Its Pool is
// nil, which makes the KeyFailoverTransport short-circuit: Copilot is
// always BYOK, so there is no centralized key pool to walk.
func (*Copilot) KeyFailoverConfig(_ slog.Logger) keypool.KeyFailoverConfig {
	var disabled keypool.KeyFailoverConfig
	return disabled
}
// CircuitBreakerConfig returns the circuit breaker settings held by the
// provider.
func (p *Copilot) CircuitBreakerConfig() *config.CircuitBreaker {
	breaker := p.circuitBreaker
	return breaker
}
+21
View File
@@ -1,6 +1,7 @@
package provider
import (
"context"
"encoding/json"
"fmt"
"io"
@@ -12,6 +13,7 @@ import (
"go.opentelemetry.io/otel/trace"
"golang.org/x/xerrors"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/aibridge/config"
"github.com/coder/coder/v2/aibridge/intercept"
"github.com/coder/coder/v2/aibridge/intercept/chatcompletions"
@@ -218,6 +220,25 @@ func (p *OpenAI) InjectAuthHeader(headers *http.Header) {
}
}
// KeyFailoverConfig returns the key failover configuration used for
// OpenAI passthrough routes.
//
// The returned config walks the provider's configured key pool. A request
// is treated as BYOK when it already carries a non-empty Authorization
// header. Pool keys are injected as "Authorization: Bearer <key>", key
// marking is delegated to keypool.MarkKeyOnStatus, and the exhausted
// response is built from chatcompletions.ProcessKeyPoolError.
func (p *OpenAI) KeyFailoverConfig(logger slog.Logger) keypool.KeyFailoverConfig {
	// Resolve the provider name once so the MarkKey closure does not
	// call back into the provider on every attempt.
	providerName := p.Name()

	var cfg keypool.KeyFailoverConfig
	cfg.Pool = p.cfg.KeyPool
	cfg.IsBYOK = func(req *http.Request) bool {
		userAuth := req.Header.Get("Authorization")
		return userAuth != ""
	}
	cfg.InjectAuthKey = func(headers *http.Header, apiKey string) {
		headers.Set("Authorization", "Bearer "+apiKey)
	}
	cfg.MarkKey = func(ctx context.Context, k *keypool.Key, upstream *http.Response) bool {
		return keypool.MarkKeyOnStatus(ctx, k, upstream, logger, providerName)
	}
	cfg.BuildExhaustedResponse = func(poolErr error) *http.Response {
		return chatcompletions.ProcessKeyPoolError(poolErr).ToResponse()
	}
	return cfg
}
// CircuitBreakerConfig returns the circuit breaker settings held by the
// provider.
func (p *OpenAI) CircuitBreakerConfig() *config.CircuitBreaker {
	breaker := p.circuitBreaker
	return breaker
}
+7
View File
@@ -6,8 +6,10 @@ import (
"go.opentelemetry.io/otel/trace"
"golang.org/x/xerrors"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/aibridge/config"
"github.com/coder/coder/v2/aibridge/intercept"
"github.com/coder/coder/v2/aibridge/keypool"
)
var ErrUnknownRoute = xerrors.New("unknown route")
@@ -76,7 +78,12 @@ type Provider interface {
// token in.
AuthHeader() string
// InjectAuthHeader allows a [Provider] to set its authentication header.
// TODO(ssncferreira): remove. Auth is now applied per-attempt by
// KeyFailoverTransport (see [Provider.KeyFailoverConfig]).
InjectAuthHeader(*http.Header)
// KeyFailoverConfig returns the per-provider configuration for
// automatic key failover on passthrough routes.
KeyFailoverConfig(logger slog.Logger) keypool.KeyFailoverConfig
// CircuitBreakerConfig returns the circuit breaker configuration for the provider.
CircuitBreakerConfig() *config.CircuitBreaker
+34
View File
@@ -0,0 +1,34 @@
package utils
import (
"bytes"
"fmt"
"io"
"math"
"net/http"
"strconv"
"time"
)
// NewJSONErrorResponse synthesizes an HTTP/1.1 *http.Response that carries
// body with a "Content-Type: application/json" header. It is used for
// bridge-side error responses (e.g. key-pool exhaustion, marshaling
// fallbacks).
//
// status is used both as StatusCode and, together with its standard text,
// as the Status line. When retryAfter is positive a Retry-After header is
// added, expressed in whole seconds rounded up; for zero or negative
// durations the header is omitted. body is passed through unchanged (it is
// not validated as JSON) and ContentLength is set to its length.
func NewJSONErrorResponse(status int, retryAfter time.Duration, body []byte) *http.Response {
	header := make(http.Header)
	header.Set("Content-Type", "application/json")
	if retryAfter > 0 {
		// Round up so a sub-second delay is never advertised as "0".
		wholeSeconds := int(math.Ceil(retryAfter.Seconds()))
		header.Set("Retry-After", strconv.Itoa(wholeSeconds))
	}

	statusLine := fmt.Sprintf("%d %s", status, http.StatusText(status))
	payload := io.NopCloser(bytes.NewReader(body))

	resp := new(http.Response)
	resp.Status = statusLine
	resp.StatusCode = status
	resp.Proto = "HTTP/1.1"
	resp.ProtoMajor = 1
	resp.ProtoMinor = 1
	resp.Header = header
	resp.Body = payload
	resp.ContentLength = int64(len(body))
	return resp
}
+91
View File
@@ -0,0 +1,91 @@
package utils_test
import (
"io"
"net/http"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/aibridge/utils"
)
// TestNewJSONErrorResponse verifies the status code, Content-Type,
// Content-Length, Retry-After header, and body of responses built by
// utils.NewJSONErrorResponse across permanent and transient scenarios.
func TestNewJSONErrorResponse(t *testing.T) {
	t.Parallel()

	type testCase struct {
		name       string
		status     int
		retryAfter time.Duration
		body       []byte
		// Empty string means the header should be absent.
		expectRetryAfter string
	}

	cases := []testCase{
		// Permanent exhaustion: 502 with no Retry-After.
		{
			name:             "permanent_no_retry_after",
			status:           http.StatusBadGateway,
			retryAfter:       0,
			body:             []byte(`{"error":"permanent"}`),
			expectRetryAfter: "",
		},
		// Transient exhaustion with zero retryAfter: no Retry-After.
		{
			name:             "transient_no_retry_after",
			status:           http.StatusTooManyRequests,
			retryAfter:       0,
			body:             []byte(`{"error":"rate"}`),
			expectRetryAfter: "",
		},
		// Transient exhaustion: 429 with Retry-After in seconds.
		{
			name:             "transient_with_retry_after",
			status:           http.StatusTooManyRequests,
			retryAfter:       60 * time.Second,
			body:             []byte(`{"error":"rate"}`),
			expectRetryAfter: "60",
		},
		// Transient exhaustion with negative retryAfter: header omitted.
		{
			name:             "transient_negative_retry_after",
			status:           http.StatusTooManyRequests,
			retryAfter:       -1 * time.Second,
			body:             []byte(`{"error":"rate"}`),
			expectRetryAfter: "",
		},
		// Transient exhaustion with 500ms retryAfter rounds up to 1.
		{
			name:             "transient_under_one_second_rounds_up",
			status:           http.StatusTooManyRequests,
			retryAfter:       500 * time.Millisecond,
			body:             []byte(`{"error":"rate"}`),
			expectRetryAfter: "1",
		},
	}

	for _, tt := range cases {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			got := utils.NewJSONErrorResponse(tt.status, tt.retryAfter, tt.body)
			require.NotNil(t, got)

			// Status line and fixed headers.
			assert.Equal(t, tt.status, got.StatusCode)
			assert.Equal(t, "application/json", got.Header.Get("Content-Type"))
			assert.Equal(t, int64(len(tt.body)), got.ContentLength)

			// Retry-After is either absent or matches the expected
			// whole-second value.
			retryAfterHeader := got.Header.Get("Retry-After")
			if tt.expectRetryAfter != "" {
				assert.Equal(t, tt.expectRetryAfter, retryAfterHeader)
			} else {
				assert.Empty(t, retryAfterHeader)
			}

			// The body must round-trip unchanged and close cleanly.
			payload, err := io.ReadAll(got.Body)
			require.NoError(t, err)
			require.NoError(t, got.Body.Close())
			assert.Equal(t, tt.body, payload)
		})
	}
}