refactor(aibridge): remove InjectAuthHeader in favor of KeyFailoverConfig (#25618)

## Description

`Provider.InjectAuthHeader` is no longer needed. With the addition of `KeyFailoverConfig` in #24920, authentication is now applied per-attempt by `KeyFailoverTransport` on passthrough routes. This PR removes the dead method from the `Provider` interface, all implementations (`Anthropic`, `OpenAI`, `Copilot`), and the test mock.

The orphaned `InjectAuthHeader` unit tests are replaced with `Test{Anthropic,OpenAI,Copilot}_KeyFailoverConfig`. `TestPassthrough_KeyFailover` is also extended to cover Copilot in the BYOK scenario.

Related to: https://linear.app/codercom/issue/AIGOV-334/aibridge-follow-ups-from-key-failover-prs

> [!NOTE]
> Initially generated by Claude Opus 4.7, modified and reviewed by @ssncferreira
This commit is contained in:
Susana Ferreira
2026-05-25 19:10:38 +01:00
committed by GitHub
parent 22109a54ad
commit 846aac2f74
9 changed files with 285 additions and 169 deletions
@@ -18,7 +18,6 @@ type MockProvider struct {
Bridged []string
Passthrough []string
InterceptorFunc func(w http.ResponseWriter, r *http.Request, tracer trace.Tracer) (intercept.Interceptor, error)
InjectAuthHeaderFunc func(h *http.Header)
}
func (m *MockProvider) Type() string { return m.NameStr }
@@ -28,11 +27,6 @@ func (m *MockProvider) RoutePrefix() string { return fmt.Sprintf("/%s",
func (m *MockProvider) BridgedRoutes() []string { return m.Bridged }
func (m *MockProvider) PassthroughRoutes() []string { return m.Passthrough }
func (*MockProvider) AuthHeader() string { return "Authorization" }
func (m *MockProvider) InjectAuthHeader(h *http.Header) {
if m.InjectAuthHeaderFunc != nil {
m.InjectAuthHeaderFunc(h)
}
}
func (*MockProvider) KeyFailoverConfig(_ slog.Logger) keypool.KeyFailoverConfig {
return keypool.KeyFailoverConfig{}
+23 -2
View File
@@ -311,13 +311,14 @@ func TestPassthrough_KeyFailover(t *testing.T) {
successBody = `{"data":[]}`
)
// providers parameterises the table over the two providers
// that support key failover. Each entry encapsulates the
// providers parameterises the table over the providers exposed
// to the failover transport. Each entry encapsulates the
// provider-specific bits the test needs: how the mock upstream
// extracts the key from the request, how a BYOK request sets
// it, and how the provider is constructed for a given pool.
providers := []struct {
name string
byokOnly bool
extractKey func(*http.Request) string
setBYOK func(*http.Request, string)
newProvider func(baseURL string, pool *keypool.Pool) provider.Provider
@@ -353,6 +354,21 @@ func TestPassthrough_KeyFailover(t *testing.T) {
return provider.NewOpenAI(cfg)
},
},
// Copilot is BYOK-only: its KeyFailoverConfig is zero-value
// so the failover transport short-circuits.
{
name: "copilot",
byokOnly: true,
extractKey: func(r *http.Request) string {
return strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ")
},
setBYOK: func(r *http.Request, key string) {
r.Header.Set("Authorization", "Bearer "+key)
},
newProvider: func(baseURL string, _ *keypool.Pool) provider.Provider {
return provider.NewCopilot(config.Copilot{BaseURL: baseURL})
},
},
}
tests := []struct {
@@ -516,6 +532,11 @@ func TestPassthrough_KeyFailover(t *testing.T) {
for _, prov := range providers {
for _, tc := range tests {
// BYOK-only providers don't use the pool, so pool-based
// cases don't apply.
if prov.byokOnly && tc.byokKey == "" {
continue
}
t.Run(prov.name+"/"+tc.name, func(t *testing.T) {
t.Parallel()
-23
View File
@@ -197,29 +197,6 @@ func (*Anthropic) AuthHeader() string {
return "X-Api-Key"
}
func (p *Anthropic) InjectAuthHeader(headers *http.Header) {
if headers == nil {
headers = &http.Header{}
}
// BYOK: if the request already carries user-supplied credentials,
// do not overwrite them with the centralized key.
if headers.Get("X-Api-Key") != "" || headers.Get("Authorization") != "" {
return
}
// Centralized: pull a single key from the pool. No failover
// or exhaustion handling here.
// TODO(ssncferreira): replace with RoundTripper-based auth
// in the upstack passthrough PR.
if p.cfg.KeyPool == nil {
return
}
if key, err := p.cfg.KeyPool.Walker().Next(); err == nil {
headers.Set(p.AuthHeader(), key.Value())
}
}
func (p *Anthropic) KeyFailoverConfig(logger slog.Logger) keypool.KeyFailoverConfig {
return keypool.KeyFailoverConfig{
Pool: p.cfg.KeyPool,
+117 -35
View File
@@ -5,6 +5,7 @@ import (
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -317,58 +318,139 @@ func TestAnthropic_CreateInterceptor_BYOK(t *testing.T) {
}
}
func TestAnthropic_InjectAuthHeader(t *testing.T) {
func TestAnthropic_KeyFailoverConfig(t *testing.T) {
t.Parallel()
provider := NewAnthropic(config.Anthropic{Key: "centralized-key"}, nil)
pool, err := keypool.New([]string{"k0", "k1"}, quartz.NewMock(t))
require.NoError(t, err)
tests := []struct {
p := NewAnthropic(config.Anthropic{KeyPool: pool}, nil)
cfg := p.KeyFailoverConfig(slog.Make())
assert.Same(t, pool, cfg.Pool, "Pool must be wired from the provider config")
assert.Equal(t, config.ProviderAnthropic, cfg.ProviderName, "ProviderName must match the provider name")
require.NotNil(t, cfg.IsBYOK)
require.NotNil(t, cfg.InjectAuthKey)
require.NotNil(t, cfg.BuildKeyPoolResponse)
t.Run("IsBYOK", func(t *testing.T) {
t.Parallel()
cases := []struct {
name string
presetHeaders map[string]string
wantXApiKey string
headers map[string]string
want bool
}{
{
name: "no_auth_headers",
headers: nil,
want: false,
},
{
name: "non_auth_header",
headers: map[string]string{"Content-Type": "application/json"},
want: false,
},
{
name: "x_api_key_only",
headers: map[string]string{"X-Api-Key": "user-key"},
want: true,
},
{
name: "authorization_only",
headers: map[string]string{"Authorization": "Bearer user-token"},
want: true,
},
{
name: "both_headers_set",
headers: map[string]string{
"X-Api-Key": "user-key",
"Authorization": "Bearer user-token",
},
want: true,
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
r := httptest.NewRequest(http.MethodPost, "/", nil)
for k, v := range tc.headers {
r.Header.Set(k, v)
}
assert.Equal(t, tc.want, cfg.IsBYOK(r))
})
}
})
t.Run("InjectAuthKey", func(t *testing.T) {
t.Parallel()
cases := []struct {
name string
initialHeaders http.Header
key string
wantAuthorization string
}{
{
name: "when no auth headers are provided, inject centralized key",
presetHeaders: map[string]string{},
wantXApiKey: "centralized-key",
name: "writes_key_to_x_api_key",
initialHeaders: http.Header{},
key: "centralized-key",
wantAuthorization: "",
},
{
name: "when X-Api-Key header is provided, use it",
presetHeaders: map[string]string{"X-Api-Key": "user-api-key"},
wantXApiKey: "user-api-key",
},
{
name: "when Authorization header is provided, use it",
presetHeaders: map[string]string{"Authorization": "Bearer user-access-token"},
wantAuthorization: "Bearer user-access-token",
},
{
name: "when both headers are provided, keep both",
presetHeaders: map[string]string{
"Authorization": "Bearer user-access-token",
"X-Api-Key": "user-api-key",
},
wantXApiKey: "user-api-key",
wantAuthorization: "Bearer user-access-token",
name: "overwrites_existing_x_api_key",
initialHeaders: http.Header{"X-Api-Key": {"stale"}, "Authorization": {"Bearer stale"}},
key: "next-key",
wantAuthorization: "Bearer stale",
},
}
for _, tc := range tests {
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
headers := http.Header{}
for k, v := range tc.presetHeaders {
headers.Set(k, v)
}
provider.InjectAuthHeader(&headers)
assert.Equal(t, tc.wantXApiKey, headers.Get("X-Api-Key"))
headers := tc.initialHeaders
cfg.InjectAuthKey(&headers, tc.key)
assert.Equal(t, tc.key, headers.Get("X-Api-Key"))
assert.Equal(t, tc.wantAuthorization, headers.Get("Authorization"))
})
}
})
t.Run("BuildKeyPoolResponse", func(t *testing.T) {
t.Parallel()
cases := []struct {
name string
err *keypool.Error
wantStatus int
wantRetryAfter string
}{
{
name: "permanent_returns_502",
err: &keypool.Error{Kind: keypool.ErrorKindPermanent},
wantStatus: http.StatusBadGateway,
},
{
name: "rate_limited_returns_429_with_retry_after",
err: &keypool.Error{Kind: keypool.ErrorKindRateLimited, RetryAfter: 5 * time.Second},
wantStatus: http.StatusTooManyRequests,
wantRetryAfter: "5",
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
resp := cfg.BuildKeyPoolResponse(tc.err)
require.NotNil(t, resp)
t.Cleanup(func() { _ = resp.Body.Close() })
assert.Equal(t, tc.wantStatus, resp.StatusCode)
assert.Equal(t, tc.wantRetryAfter, resp.Header.Get("Retry-After"))
})
}
})
}
func TestExtractAnthropicHeaders(t *testing.T) {
-6
View File
@@ -107,12 +107,6 @@ func (*Copilot) AuthHeader() string {
return "Authorization"
}
// InjectAuthHeader is a no-op for Copilot.
// Copilot uses per-user tokens passed in the original Authorization header,
// rather than a global key configured at the provider level.
// The original Authorization header flows through untouched from the client.
func (*Copilot) InjectAuthHeader(_ *http.Header) {}
// KeyFailoverConfig returns a config with a nil Pool, which makes
// the KeyFailoverTransport short-circuit. Copilot is always BYOK.
func (*Copilot) KeyFailoverConfig(_ slog.Logger) keypool.KeyFailoverConfig {
+8 -27
View File
@@ -13,6 +13,7 @@ import (
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/aibridge/config"
"github.com/coder/coder/v2/aibridge/internal/testutil"
"github.com/coder/coder/v2/aibridge/keypool"
)
var testTracer = otel.Tracer("copilot_test")
@@ -51,37 +52,17 @@ func TestCopilot_TypeAndName(t *testing.T) {
}
}
func TestCopilot_InjectAuthHeader(t *testing.T) {
// TestCopilot_KeyFailoverConfig verifies that Copilot, being BYOK-only,
// returns a zero-value KeyFailoverConfig so that KeyFailoverTransport
// short-circuits and passes the request through unchanged.
func TestCopilot_KeyFailoverConfig(t *testing.T) {
t.Parallel()
// Copilot uses per-user key passed in the Authorization header,
// so InjectAuthHeader should not modify any headers.
provider := NewCopilot(config.Copilot{})
p := NewCopilot(config.Copilot{})
t.Run("ExistingHeaders_Unchanged", func(t *testing.T) {
t.Parallel()
cfg := p.KeyFailoverConfig(slog.Make())
headers := http.Header{}
headers.Set("Authorization", "Bearer user-token")
headers.Set("X-Custom-Header", "custom-value")
provider.InjectAuthHeader(&headers)
assert.Equal(t, "Bearer user-token", headers.Get("Authorization"),
"Authorization header should remain unchanged")
assert.Equal(t, "custom-value", headers.Get("X-Custom-Header"),
"other headers should remain unchanged")
})
t.Run("EmptyHeaders_NoneAdded", func(t *testing.T) {
t.Parallel()
headers := http.Header{}
provider.InjectAuthHeader(&headers)
assert.Empty(t, headers, "no headers should be added")
})
assert.Equal(t, keypool.KeyFailoverConfig{}, cfg, "Copilot must return a zero-value KeyFailoverConfig to short-circuit the transport")
}
func TestCopilot_CreateInterceptor(t *testing.T) {
-23
View File
@@ -196,29 +196,6 @@ func (*OpenAI) AuthHeader() string {
return "Authorization"
}
func (p *OpenAI) InjectAuthHeader(headers *http.Header) {
if headers == nil {
headers = &http.Header{}
}
// BYOK: if the request already carries user-supplied credentials,
// do not overwrite them with the centralized key.
if headers.Get("Authorization") != "" {
return
}
// Centralized: pull a single key from the pool. No failover
// or exhaustion handling here.
// TODO(ssncferreira): replace with RoundTripper-based auth
// in the upstack passthrough PR.
if p.cfg.KeyPool == nil {
return
}
if key, err := p.cfg.KeyPool.Walker().Next(); err == nil {
headers.Set(p.AuthHeader(), "Bearer "+key.Value())
}
}
func (p *OpenAI) KeyFailoverConfig(logger slog.Logger) keypool.KeyFailoverConfig {
return keypool.KeyFailoverConfig{
Pool: p.cfg.KeyPool,
+115 -21
View File
@@ -8,6 +8,7 @@ import (
"net/http/httptest"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -18,6 +19,8 @@ import (
"github.com/coder/coder/v2/aibridge/config"
"github.com/coder/coder/v2/aibridge/intercept"
"github.com/coder/coder/v2/aibridge/internal/testutil"
"github.com/coder/coder/v2/aibridge/keypool"
"github.com/coder/quartz"
)
const (
@@ -325,42 +328,133 @@ func TestOpenAI_CreateInterceptor(t *testing.T) {
}
}
func TestOpenAI_InjectAuthHeader(t *testing.T) {
func TestOpenAI_KeyFailoverConfig(t *testing.T) {
t.Parallel()
provider := NewOpenAI(config.OpenAI{Key: "centralized-key"})
pool, err := keypool.New([]string{"k0", "k1"}, quartz.NewMock(t))
require.NoError(t, err)
tests := []struct {
p := NewOpenAI(config.OpenAI{KeyPool: pool})
cfg := p.KeyFailoverConfig(slog.Make())
assert.Same(t, pool, cfg.Pool, "Pool must be wired from the provider config")
assert.Equal(t, config.ProviderOpenAI, cfg.ProviderName, "ProviderName must match the provider name")
require.NotNil(t, cfg.IsBYOK)
require.NotNil(t, cfg.InjectAuthKey)
require.NotNil(t, cfg.BuildKeyPoolResponse)
t.Run("IsBYOK", func(t *testing.T) {
t.Parallel()
cases := []struct {
name string
presetHeaders map[string]string
wantAuthorization string
headers map[string]string
want bool
}{
{
name: "when no Authorization header is provided, inject centralized key",
presetHeaders: map[string]string{},
wantAuthorization: "Bearer centralized-key",
name: "no_auth_headers",
headers: nil,
want: false,
},
{
name: "when Authorization header is provided, do not overwrite it",
presetHeaders: map[string]string{"Authorization": "Bearer user-token"},
wantAuthorization: "Bearer user-token",
name: "non_auth_header",
headers: map[string]string{"Content-Type": "application/json"},
want: false,
},
{
name: "authorization_only",
headers: map[string]string{"Authorization": "Bearer user-token"},
want: true,
},
{
name: "x_api_key_only",
headers: map[string]string{"X-Api-Key": "user-key"},
want: false,
},
{
name: "both_headers_set",
headers: map[string]string{
"Authorization": "Bearer user-token",
"X-Api-Key": "user-key",
},
want: true,
},
}
for _, tc := range tests {
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
headers := http.Header{}
for k, v := range tc.presetHeaders {
headers.Set(k, v)
r := httptest.NewRequest(http.MethodPost, "/", nil)
for k, v := range tc.headers {
r.Header.Set(k, v)
}
provider.InjectAuthHeader(&headers)
assert.Equal(t, tc.wantAuthorization, headers.Get("Authorization"))
assert.Equal(t, tc.want, cfg.IsBYOK(r))
})
}
})
t.Run("InjectAuthKey", func(t *testing.T) {
t.Parallel()
cases := []struct {
name string
initialHeaders http.Header
key string
wantAPIKey string
}{
{
name: "writes_bearer_token_to_authorization",
initialHeaders: http.Header{},
key: "centralized-key",
wantAPIKey: "",
},
{
name: "overwrites_existing_authorization",
initialHeaders: http.Header{"Authorization": {"Bearer stale"}, "X-Api-Key": {"stale"}},
key: "next-key",
wantAPIKey: "stale",
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
headers := tc.initialHeaders
cfg.InjectAuthKey(&headers, tc.key)
assert.Equal(t, "Bearer "+tc.key, headers.Get("Authorization"))
assert.Equal(t, tc.wantAPIKey, headers.Get("X-Api-Key"))
})
}
})
t.Run("BuildKeyPoolResponse", func(t *testing.T) {
t.Parallel()
cases := []struct {
name string
err *keypool.Error
wantStatus int
wantRetryAfter string
}{
{
name: "permanent_returns_502",
err: &keypool.Error{Kind: keypool.ErrorKindPermanent},
wantStatus: http.StatusBadGateway,
},
{
name: "rate_limited_returns_429_with_retry_after",
err: &keypool.Error{Kind: keypool.ErrorKindRateLimited, RetryAfter: 5 * time.Second},
wantStatus: http.StatusTooManyRequests,
wantRetryAfter: "5",
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
resp := cfg.BuildKeyPoolResponse(tc.err)
require.NotNil(t, resp)
t.Cleanup(func() { _ = resp.Body.Close() })
assert.Equal(t, tc.wantStatus, resp.StatusCode)
assert.Equal(t, tc.wantRetryAfter, resp.Header.Get("Retry-After"))
})
}
})
}
func BenchmarkOpenAI_CreateInterceptor_ChatCompletions(b *testing.B) {
-4
View File
@@ -77,10 +77,6 @@ type Provider interface {
// AuthHeader returns the name of the header which the provider expects to find its authentication
// token in.
AuthHeader() string
// InjectAuthHeader allows [Provider]s to set its authentication header.
// TODO(ssncferreira): remove. Auth is now applied per-attempt by
// KeyFailoverTransport (see [Provider.KeyFailoverConfig]).
InjectAuthHeader(*http.Header)
// KeyFailoverConfig returns the per-provider configuration for
// automatic key failover on passthrough routes.
KeyFailoverConfig(logger slog.Logger) keypool.KeyFailoverConfig