Files
coder/aibridge/passthrough_test.go
T
Susana Ferreira 0766cc3097 feat: add automatic key failover for AI Bridge passthrough (#24920)
## Description

Adds automatic key failover for passthrough routes for the Anthropic and OpenAI providers. A new `keyFailoverTransport` wraps the reverse-proxy transport: centralized requests walk the configured key pool and retry with the next key on key-specific failures (401/403/429), reusing the same key-marking semantics as the bridged routes.

BYOK (bring-your-own-key) passthrough requests run as a single attempt with no failover.

## Changes

- New `keypool.KeyFailoverConfig` carrying the `Pool` to walk and the provider-specific closures (`IsBYOK`, `InjectAuthKey`, `MarkKey`, `BuildExhaustedResponse`).
- New `keypool.NewKeyFailoverTransport`: wraps an inner `http.RoundTripper`. Returns `inner` unchanged when `Pool` is nil, otherwise produces a transport that buffers the request body once, walks the pool per request, and replays each attempt with the next key.
- New `Provider.KeyFailoverConfig(logger)` interface method. Anthropic injects `X-Api-Key`; OpenAI injects `Authorization: Bearer ...`; Copilot returns an empty config.
- `passthrough.go` wires `NewKeyFailoverTransport` around the existing apidump middleware, so every retry attempt is recorded.

## Related Issues

Related to: https://github.com/coder/internal/issues/1446
Related to: https://linear.app/codercom/issue/AIGOV-197/aibridge-automatic-key-failover-for-bridged-and-passthrough-routes

## Follow-up PRs

- Remove dead `Provider.InjectAuthHeader` method now that all auth is applied per-attempt by `KeyFailoverTransport`.
- Bedrock multi-key support.
- Refactor provider vs interceptor config separation.
- Record the actually-used key in the interception credential hint after failover.

> [!NOTE]
> Initially generated by Claude Opus 4.7, modified and reviewed by @ssncferreira
2026-05-07 15:46:36 +01:00

574 lines
18 KiB
Go

package aibridge //nolint:testpackage // tests unexported newPassthroughRouter
import (
"crypto/tls"
"io"
"maps"
"net"
"net/http"
"net/http/httptest"
"net/http/httputil"
"net/url"
"strings"
"sync/atomic"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel"
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/coder/v2/aibridge/config"
"github.com/coder/coder/v2/aibridge/internal/testutil"
"github.com/coder/coder/v2/aibridge/keypool"
"github.com/coder/coder/v2/aibridge/provider"
"github.com/coder/quartz"
)
// testTracer is the OpenTelemetry tracer passed to newPassthroughRouter by
// the tests below.
var testTracer = otel.Tracer("bridge_test")
func TestPassthroughRoutes(t *testing.T) {
t.Parallel()
upstreamRespBody := "upstream response"
tests := []struct {
name string
baseURLPath string
reqPath string
reqHost string
reqRemoteAddr string
reqHeaders http.Header
expectRequestPath string
expectQuery string
expectHeaders http.Header
expectRespStatus int
expectRespBody string
}{
{
name: "passthrough_route_no_path",
reqPath: "/v1/conversations",
expectRequestPath: "/v1/conversations",
expectRespStatus: http.StatusOK,
expectRespBody: upstreamRespBody,
},
{
name: "base_URL_path_is_preserved_in_passthrough_routes",
baseURLPath: "/api/v2",
reqPath: "/v1/models",
expectRequestPath: "/api/v2/v1/models",
expectRespStatus: http.StatusOK,
expectRespBody: upstreamRespBody,
},
{
name: "passthrough_route_break_parse_base_url",
baseURLPath: "/%zz",
reqPath: "/v1/models/",
expectRespStatus: http.StatusBadGateway,
expectRespBody: "invalid provider base URL",
},
{
name: "passthrough_route_rejects_invalid_base_url_path",
baseURLPath: "/%25",
reqPath: "/v1/models",
expectRespStatus: http.StatusBadGateway,
expectRespBody: "invalid provider base URL",
},
{
name: "proxy_headers_are_set_and_forwarded_chain_is_appended",
reqPath: "/v1/models",
reqHost: "client.example.com",
reqRemoteAddr: "1.1.1.1:1111",
reqHeaders: http.Header{
"X-Forwarded-For": {"2.2.2.2, 3.3.3.3"},
},
expectRequestPath: "/v1/models",
expectRespStatus: http.StatusOK,
expectRespBody: upstreamRespBody,
expectHeaders: http.Header{
"Accept-Encoding": {"gzip"},
"User-Agent": {"aibridge"},
"X-Forwarded-For": {"2.2.2.2, 3.3.3.3, 1.1.1.1"},
"X-Forwarded-Host": {"client.example.com"},
"X-Forwarded-Proto": {"http"},
},
},
{
name: "query_string_is_preserved",
reqPath: "/v1/models?search=gpt&limit=10",
expectRequestPath: "/v1/models",
expectQuery: "search=gpt&limit=10",
expectRespStatus: http.StatusOK,
expectRespBody: upstreamRespBody,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, tc.expectRequestPath, r.URL.Path)
assert.Equal(t, tc.expectQuery, r.URL.RawQuery)
if tc.expectHeaders != nil {
assert.Equal(t, tc.expectHeaders, r.Header)
}
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte(upstreamRespBody))
}))
t.Cleanup(upstream.Close)
prov := &testutil.MockProvider{
URL: upstream.URL + tc.baseURLPath,
}
handler := newPassthroughRouter(prov, logger, nil, testTracer)
req := httptest.NewRequest("", tc.reqPath, nil)
maps.Copy(req.Header, tc.reqHeaders)
req.Host = tc.reqHost
req.RemoteAddr = tc.reqRemoteAddr
resp := httptest.NewRecorder()
handler.ServeHTTP(resp, req)
assert.Equal(t, tc.expectRespStatus, resp.Code)
assert.Contains(t, resp.Body.String(), tc.expectRespBody)
})
}
}
// TestRewritePassthroughRequest calls rewritePassthroughRequest directly on
// an httputil.ProxyRequest and checks the outbound URL, the outbound Host,
// and the complete outbound header set for each case.
func TestRewritePassthroughRequest(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name          string
		reqPath       string
		reqRemoteAddr string
		reqHeaders    http.Header
		// reqTLS marks the inbound request as having arrived over TLS.
		reqTLS        bool
		provider      *testutil.MockProvider
		expectURL     string
		expectHeaders http.Header
	}{
		{
			name:          "sets_upstream_url_and_forwarded_headers_from_client_peer",
			reqPath:       "http://client-host/chat?stream=true",
			reqRemoteAddr: "1.1.1.1:1111",
			provider:      &testutil.MockProvider{URL: "https://upstream-host/base"},
			expectURL:     "https://upstream-host/base/chat?stream=true",
			expectHeaders: http.Header{
				"X-Forwarded-Host":  {"client-host"},
				"X-Forwarded-Proto": {"http"},
				"X-Forwarded-For":   {"1.1.1.1"},
				"User-Agent":        {"aibridge"},
			},
		},
		{
			name:          "preserves_client_user_agent",
			reqPath:       "http://client-host/chat",
			reqRemoteAddr: "1.1.1.1:1111",
			reqHeaders:    http.Header{"User-Agent": {"custom-agent/1.0"}},
			provider:      &testutil.MockProvider{URL: "https://upstream-host/base"},
			expectURL:     "https://upstream-host/base/chat",
			expectHeaders: http.Header{
				"X-Forwarded-Host":  {"client-host"},
				"X-Forwarded-Proto": {"http"},
				"X-Forwarded-For":   {"1.1.1.1"},
				"User-Agent":        {"custom-agent/1.0"},
			},
		},
		{
			name:          "appends_remote_addr_to_existing_forwarded_for_chain",
			reqPath:       "http://client-host/chat",
			reqRemoteAddr: "1.1.1.1:1111",
			reqHeaders: http.Header{
				"X-Forwarded-For": {"2.2.2.2, 3.3.3.3"},
			},
			provider:  &testutil.MockProvider{URL: "https://upstream-host/base"},
			expectURL: "https://upstream-host/base/chat",
			expectHeaders: http.Header{
				"X-Forwarded-Host":  {"client-host"},
				"X-Forwarded-Proto": {"http"},
				"X-Forwarded-For":   {"2.2.2.2, 3.3.3.3, 1.1.1.1"},
				"User-Agent":        {"aibridge"},
			},
		},
		{
			name:          "tls_request_sets_forwarded_proto_to_https",
			reqPath:       "http://client-host/chat",
			reqRemoteAddr: "1.1.1.1:1111",
			reqTLS:        true,
			provider:      &testutil.MockProvider{URL: "https://upstream-host/base"},
			expectURL:     "https://upstream-host/base/chat",
			expectHeaders: http.Header{
				"X-Forwarded-Host":  {"client-host"},
				"X-Forwarded-Proto": {"https"},
				"X-Forwarded-For":   {"1.1.1.1"},
				"User-Agent":        {"aibridge"},
			},
		},
		{
			// Edge case: the whole `X-Forwarded-For` header is dropped
			// when the last hop (remote addr) is not parseable. This is
			// how the library handles the case and is not directly
			// related to our code. It is covered here to verify that we
			// don't accidentally break this behavior.
			name:          "omits_forwarded_for_when_remote_addr_is_not_parseable",
			reqPath:       "http://client-host/chat",
			reqRemoteAddr: "not-a-socket-address",
			reqHeaders: http.Header{
				"X-Forwarded-For": {"1.1.1.1"},
			},
			provider:  &testutil.MockProvider{URL: "https://upstream-host/base"},
			expectURL: "https://upstream-host/base/chat",
			expectHeaders: http.Header{
				"X-Forwarded-Host":  {"client-host"},
				"X-Forwarded-Proto": {"http"},
				"User-Agent":        {"aibridge"},
			},
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			r := httptest.NewRequest(http.MethodGet, tc.reqPath, nil)
			maps.Copy(r.Header, tc.reqHeaders)
			r.RemoteAddr = tc.reqRemoteAddr
			if tc.reqTLS {
				// A non-nil TLS state is what marks the inbound request
				// as HTTPS.
				r.TLS = &tls.ConnectionState{}
			}

			provBaseURL, err := url.Parse(tc.provider.URL)
			// url.Parse returns a nil URL on error, so stop the subtest
			// here rather than handing nil to rewritePassthroughRequest.
			require.NoError(t, err)

			pr := &httputil.ProxyRequest{
				In:  r,
				Out: r.Clone(r.Context()),
			}
			rewritePassthroughRequest(pr, provBaseURL)

			assert.Equal(t, tc.expectURL, pr.Out.URL.String())
			// The outbound request is expected to carry an empty Host.
			assert.Equal(t, "", pr.Out.Host)
			assert.Equal(t, tc.expectHeaders, pr.Out.Header)
		})
	}
}
func TestPassthroughRouterReusesProxyInstance(t *testing.T) {
t.Parallel()
var newConnections atomic.Int32
upstream := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("ok"))
}))
upstream.Config.ConnState = func(_ net.Conn, state http.ConnState) {
if state == http.StateNew {
newConnections.Add(1)
}
}
upstream.Start()
t.Cleanup(upstream.Close)
logger := slogtest.Make(t, nil)
prov := &testutil.MockProvider{URL: upstream.URL}
handler := newPassthroughRouter(prov, logger, nil, testTracer)
for i := range 2 {
req := httptest.NewRequest(http.MethodGet, "http://proxy.example.test/v1/models", nil)
resp := httptest.NewRecorder()
handler.ServeHTTP(resp, req)
assert.Equalf(t, http.StatusOK, resp.Code, "request %d", i+1)
assert.Equal(t, "ok", resp.Body.String())
}
assert.EqualValues(t, 1, newConnections.Load())
}
// TestPassthrough_KeyFailover runs the passthrough router against a scripted
// mock upstream, once per provider (anthropic, openai). Every scenario
// checks how many requests reached the upstream, the status code and
// Retry-After header returned to the caller, and the state of each pool key
// afterwards.
func TestPassthrough_KeyFailover(t *testing.T) {
	t.Parallel()

	// scripted is the reply the mock upstream sends for one API key.
	type scripted struct {
		statusCode int
		body       string
		headers    map[string]string
	}

	const (
		rateLimitBody   = `{"error":"rate"}`
		authErrorBody   = `{"error":"unauthorized"}`
		serverErrorBody = `{"error":"server"}`
		successBody     = `{"data":[]}`
	)

	// Small builders for the replies that recur across scenarios.
	success := scripted{statusCode: http.StatusOK, body: successBody}
	rateLimited := func(retryAfter string) scripted {
		return scripted{
			statusCode: http.StatusTooManyRequests,
			headers:    map[string]string{"Retry-After": retryAfter},
			body:       rateLimitBody,
		}
	}
	authFailure := func(code int) scripted {
		return scripted{statusCode: code, body: authErrorBody}
	}

	// backends lists the two providers that support key failover together
	// with the provider-specific pieces the test needs: how the mock
	// upstream reads the key from a request, how a BYOK request carries
	// its key, and how the provider is built for a given pool.
	backends := []struct {
		name      string
		keyFrom   func(*http.Request) string
		applyBYOK func(*http.Request, string)
		build     func(baseURL string, pool *keypool.Pool) provider.Provider
	}{
		{
			name:    "anthropic",
			keyFrom: func(r *http.Request) string { return r.Header.Get("X-Api-Key") },
			applyBYOK: func(r *http.Request, key string) {
				r.Header.Set("X-Api-Key", key)
			},
			build: func(baseURL string, pool *keypool.Pool) provider.Provider {
				anthropicCfg := config.Anthropic{
					BaseURL: baseURL,
					KeyPool: pool,
				}
				return provider.NewAnthropic(anthropicCfg, nil)
			},
		},
		{
			name: "openai",
			keyFrom: func(r *http.Request) string {
				auth := r.Header.Get("Authorization")
				return strings.TrimPrefix(auth, "Bearer ")
			},
			applyBYOK: func(r *http.Request, key string) {
				r.Header.Set("Authorization", "Bearer "+key)
			},
			build: func(baseURL string, pool *keypool.Pool) provider.Provider {
				openaiCfg := config.OpenAI{BaseURL: baseURL}
				if pool != nil {
					openaiCfg.KeyPool = pool
				}
				return provider.NewOpenAI(openaiCfg)
			},
		},
	}

	scenarios := []struct {
		name string
		// poolKeys are the centralized pool keys. Empty when byok is set.
		poolKeys []string
		// byok is the user-supplied key. Empty when poolKeys is set.
		byok string
		// script maps an API key value to the upstream reply for it.
		script         map[string]scripted
		wantRequests   int32
		wantStatus     int
		wantRetryAfter string
		// wantStates are the key states expected after the request, in
		// poolKeys order.
		wantStates []keypool.KeyState
	}{
		{
			// One valid key answering 200: a single upstream request,
			// 200 to the caller, key still valid.
			name:         "single_valid_key",
			poolKeys:     []string{"k0"},
			script:       map[string]scripted{"k0": success},
			wantRequests: 1,
			wantStatus:   http.StatusOK,
			wantStates:   []keypool.KeyState{keypool.KeyStateValid},
		},
		{
			// k0 answers 429, k1 answers 200: two upstream requests,
			// 200 to the caller, k0 temporary and k1 valid.
			name:     "failover_after_429",
			poolKeys: []string{"k0", "k1"},
			script: map[string]scripted{
				"k0": rateLimited("5"),
				"k1": success,
			},
			wantRequests: 2,
			wantStatus:   http.StatusOK,
			wantStates: []keypool.KeyState{
				keypool.KeyStateTemporary,
				keypool.KeyStateValid,
			},
		},
		{
			// k0 answers 401, k1 answers 200: two upstream requests,
			// 200 to the caller, k0 permanent and k1 valid.
			name:     "failover_after_401",
			poolKeys: []string{"k0", "k1"},
			script: map[string]scripted{
				"k0": authFailure(http.StatusUnauthorized),
				"k1": success,
			},
			wantRequests: 2,
			wantStatus:   http.StatusOK,
			wantStates: []keypool.KeyState{
				keypool.KeyStatePermanent,
				keypool.KeyStateValid,
			},
		},
		{
			// k0 answers 403, k1 answers 200: two upstream requests,
			// 200 to the caller, k0 permanent and k1 valid.
			name:     "failover_after_403",
			poolKeys: []string{"k0", "k1"},
			script: map[string]scripted{
				"k0": authFailure(http.StatusForbidden),
				"k1": success,
			},
			wantRequests: 2,
			wantStatus:   http.StatusOK,
			wantStates: []keypool.KeyState{
				keypool.KeyStatePermanent,
				keypool.KeyStateValid,
			},
		},
		{
			// Three keys all answering 429 with cooldowns 5s, 3s, 10s:
			// three upstream requests, 429 to the caller carrying the
			// smallest Retry-After, every key temporary.
			name:     "all_keys_rate_limited",
			poolKeys: []string{"k0", "k1", "k2"},
			script: map[string]scripted{
				"k0": rateLimited("5"),
				"k1": rateLimited("3"),
				"k2": rateLimited("10"),
			},
			wantRequests:   3,
			wantStatus:     http.StatusTooManyRequests,
			wantRetryAfter: "3",
			wantStates: []keypool.KeyState{
				keypool.KeyStateTemporary,
				keypool.KeyStateTemporary,
				keypool.KeyStateTemporary,
			},
		},
		{
			// Both keys answer 401: two upstream requests, 502 to the
			// caller, both keys permanent.
			name:     "all_keys_unauthorized",
			poolKeys: []string{"k0", "k1"},
			script: map[string]scripted{
				"k0": authFailure(http.StatusUnauthorized),
				"k1": authFailure(http.StatusUnauthorized),
			},
			wantRequests: 2,
			wantStatus:   http.StatusBadGateway,
			wantStates: []keypool.KeyState{
				keypool.KeyStatePermanent,
				keypool.KeyStatePermanent,
			},
		},
		{
			// k0 answers 500: one upstream request, 500 to the caller,
			// both keys stay valid.
			name:     "server_error_no_failover",
			poolKeys: []string{"k0", "k1"},
			script: map[string]scripted{
				"k0": {statusCode: http.StatusInternalServerError, body: serverErrorBody},
			},
			wantRequests: 1,
			wantStatus:   http.StatusInternalServerError,
			wantStates: []keypool.KeyState{
				keypool.KeyStateValid,
				keypool.KeyStateValid,
			},
		},
		{
			// BYOK with one user-supplied key answering 429: one
			// upstream request, the 429 is forwarded as-is, no failover.
			name: "byok_no_failover",
			byok: "user-byok",
			script: map[string]scripted{
				"user-byok": rateLimited("5"),
			},
			wantRequests:   1,
			wantStatus:     http.StatusTooManyRequests,
			wantRetryAfter: "5",
		},
	}

	for _, be := range backends {
		for _, sc := range scenarios {
			t.Run(be.name+"/"+sc.name, func(t *testing.T) {
				t.Parallel()

				// The mock upstream counts requests and replies from the
				// script, keyed by API key. A key with no script entry
				// gets a 500 so a misconfigured scenario shows up in the
				// status assertion.
				var seen atomic.Int32
				srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
					seen.Add(1)
					_, _ = io.Copy(io.Discard, r.Body)

					reply, found := sc.script[be.keyFrom(r)]
					if !found {
						reply = scripted{statusCode: http.StatusInternalServerError}
					}

					out := w.Header()
					out.Set("Content-Type", "application/json")
					for field, value := range reply.headers {
						out.Set(field, value)
					}
					w.WriteHeader(reply.statusCode)
					_, _ = w.Write([]byte(reply.body))
				}))
				t.Cleanup(srv.Close)

				// BYOK scenarios have no pool keys and leave pool nil.
				var pool *keypool.Pool
				if len(sc.poolKeys) > 0 {
					var err error
					pool, err = keypool.New(sc.poolKeys, quartz.NewMock(t))
					require.NoError(t, err)
				}

				// IgnoreErrors: MarkKey logs at ERROR level when a key
				// is marked permanent (401/403); slogtest would
				// otherwise fail those scenarios.
				lg := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
				router := newPassthroughRouter(be.build(srv.URL, pool), lg, nil, testTracer)

				in := httptest.NewRequest(http.MethodGet, "/v1/models", nil)
				if sc.byok != "" {
					be.applyBYOK(in, sc.byok)
				}
				rec := httptest.NewRecorder()
				router.ServeHTTP(rec, in)

				assert.Equal(t, sc.wantRequests, seen.Load(), "upstream request count")
				assert.Equal(t, sc.wantStatus, rec.Code, "response status code")
				assert.Equal(t, sc.wantRetryAfter, rec.Header().Get("Retry-After"), "Retry-After header")
				if pool == nil {
					return
				}
				assert.Equal(t, sc.wantStates, pool.PoolState(), "key states")
			})
		}
	}
}