diff --git a/aibridge/internal/testutil/mockprovider.go b/aibridge/internal/testutil/mockprovider.go index 0c56cf2c9e..0fd85d2863 100644 --- a/aibridge/internal/testutil/mockprovider.go +++ b/aibridge/internal/testutil/mockprovider.go @@ -13,12 +13,11 @@ import ( ) type MockProvider struct { - NameStr string - URL string - Bridged []string - Passthrough []string - InterceptorFunc func(w http.ResponseWriter, r *http.Request, tracer trace.Tracer) (intercept.Interceptor, error) - InjectAuthHeaderFunc func(h *http.Header) + NameStr string + URL string + Bridged []string + Passthrough []string + InterceptorFunc func(w http.ResponseWriter, r *http.Request, tracer trace.Tracer) (intercept.Interceptor, error) } func (m *MockProvider) Type() string { return m.NameStr } @@ -28,11 +27,6 @@ func (m *MockProvider) RoutePrefix() string { return fmt.Sprintf("/%s", func (m *MockProvider) BridgedRoutes() []string { return m.Bridged } func (m *MockProvider) PassthroughRoutes() []string { return m.Passthrough } func (*MockProvider) AuthHeader() string { return "Authorization" } -func (m *MockProvider) InjectAuthHeader(h *http.Header) { - if m.InjectAuthHeaderFunc != nil { - m.InjectAuthHeaderFunc(h) - } -} func (*MockProvider) KeyFailoverConfig(_ slog.Logger) keypool.KeyFailoverConfig { return keypool.KeyFailoverConfig{} diff --git a/aibridge/passthrough_internal_test.go b/aibridge/passthrough_internal_test.go index 79290e05a1..0cfeb00f63 100644 --- a/aibridge/passthrough_internal_test.go +++ b/aibridge/passthrough_internal_test.go @@ -311,13 +311,14 @@ func TestPassthrough_KeyFailover(t *testing.T) { successBody = `{"data":[]}` ) - // providers parameterises the table over the two providers - // that support key failover. Each entry encapsulates the + // providers parameterises the table over the providers exposed + // to the failover transport. Each entry encapsulates the // provider-specific bits the test needs: how the mock upstream // extracts the key from the request, how a BYOK request sets // it, and how the provider is constructed for a given pool. providers := []struct { name string + byokOnly bool extractKey func(*http.Request) string setBYOK func(*http.Request, string) newProvider func(baseURL string, pool *keypool.Pool) provider.Provider @@ -353,6 +354,21 @@ func TestPassthrough_KeyFailover(t *testing.T) { return provider.NewOpenAI(cfg) }, }, + // Copilot is BYOK-only: its KeyFailoverConfig is zero-value + // so the failover transport short-circuits. + { + name: "copilot", + byokOnly: true, + extractKey: func(r *http.Request) string { + return strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ") + }, + setBYOK: func(r *http.Request, key string) { + r.Header.Set("Authorization", "Bearer "+key) + }, + newProvider: func(baseURL string, _ *keypool.Pool) provider.Provider { + return provider.NewCopilot(config.Copilot{BaseURL: baseURL}) + }, + }, } tests := []struct { @@ -516,6 +532,11 @@ func TestPassthrough_KeyFailover(t *testing.T) { for _, prov := range providers { for _, tc := range tests { + // BYOK-only providers don't use the pool, so pool-based + // cases don't apply. + if prov.byokOnly && tc.byokKey == "" { + continue + } t.Run(prov.name+"/"+tc.name, func(t *testing.T) { t.Parallel() diff --git a/aibridge/provider/anthropic.go b/aibridge/provider/anthropic.go index 5fdbebae16..eb50a3b296 100644 --- a/aibridge/provider/anthropic.go +++ b/aibridge/provider/anthropic.go @@ -197,29 +197,6 @@ func (*Anthropic) AuthHeader() string { return "X-Api-Key" } -func (p *Anthropic) InjectAuthHeader(headers *http.Header) { - if headers == nil { - headers = &http.Header{} - } - - // BYOK: if the request already carries user-supplied credentials, - // do not overwrite them with the centralized key. - if headers.Get("X-Api-Key") != "" || headers.Get("Authorization") != "" { - return - } - - // Centralized: pull a single key from the pool. No failover - // or exhaustion handling here. - // TODO(ssncferreira): replace with RoundTripper-based auth - // in the upstack passthrough PR. - if p.cfg.KeyPool == nil { - return - } - if key, err := p.cfg.KeyPool.Walker().Next(); err == nil { - headers.Set(p.AuthHeader(), key.Value()) - } -} - func (p *Anthropic) KeyFailoverConfig(logger slog.Logger) keypool.KeyFailoverConfig { return keypool.KeyFailoverConfig{ Pool: p.cfg.KeyPool, diff --git a/aibridge/provider/anthropic_internal_test.go b/aibridge/provider/anthropic_internal_test.go index fe375d4617..b3d89556a8 100644 --- a/aibridge/provider/anthropic_internal_test.go +++ b/aibridge/provider/anthropic_internal_test.go @@ -5,6 +5,7 @@ import ( "net/http" "net/http/httptest" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -317,58 +318,139 @@ func TestAnthropic_CreateInterceptor_BYOK(t *testing.T) { } } -func TestAnthropic_InjectAuthHeader(t *testing.T) { +func TestAnthropic_KeyFailoverConfig(t *testing.T) { t.Parallel() - provider := NewAnthropic(config.Anthropic{Key: "centralized-key"}, nil) + pool, err := keypool.New([]string{"k0", "k1"}, quartz.NewMock(t)) + require.NoError(t, err) - tests := []struct { - name string - presetHeaders map[string]string - wantXApiKey string - wantAuthorization string - }{ - { - name: "when no auth headers are provided, inject centralized key", - presetHeaders: map[string]string{}, - wantXApiKey: "centralized-key", - }, - { - name: "when X-Api-Key header is provided, use it", - presetHeaders: map[string]string{"X-Api-Key": "user-api-key"}, - wantXApiKey: "user-api-key", - }, - { - name: "when Authorization header is provided, use it", - presetHeaders: map[string]string{"Authorization": "Bearer user-access-token"}, - wantAuthorization: "Bearer user-access-token", - }, - { - name: "when both headers are provided, keep both", - presetHeaders: map[string]string{ - "Authorization": "Bearer user-access-token", - "X-Api-Key": "user-api-key", + p := NewAnthropic(config.Anthropic{KeyPool: pool}, nil) + + cfg := p.KeyFailoverConfig(slog.Make()) + + assert.Same(t, pool, cfg.Pool, "Pool must be wired from the provider config") + assert.Equal(t, config.ProviderAnthropic, cfg.ProviderName, "ProviderName must match the provider name") + require.NotNil(t, cfg.IsBYOK) + require.NotNil(t, cfg.InjectAuthKey) + require.NotNil(t, cfg.BuildKeyPoolResponse) + + t.Run("IsBYOK", func(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + headers map[string]string + want bool + }{ + { + name: "no_auth_headers", + headers: nil, + want: false, }, - wantXApiKey: "user-api-key", - wantAuthorization: "Bearer user-access-token", - }, - } + { + name: "non_auth_header", + headers: map[string]string{"Content-Type": "application/json"}, + want: false, + }, + { + name: "x_api_key_only", + headers: map[string]string{"X-Api-Key": "user-key"}, + want: true, + }, + { + name: "authorization_only", + headers: map[string]string{"Authorization": "Bearer user-token"}, + want: true, + }, + { + name: "both_headers_set", + headers: map[string]string{ + "X-Api-Key": "user-key", + "Authorization": "Bearer user-token", + }, + want: true, + }, + } - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - t.Parallel() + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + r := httptest.NewRequest(http.MethodPost, "/", nil) + for k, v := range tc.headers { + r.Header.Set(k, v) + } + assert.Equal(t, tc.want, cfg.IsBYOK(r)) + }) + } + }) - headers := http.Header{} - for k, v := range tc.presetHeaders { - headers.Set(k, v) - } + t.Run("InjectAuthKey", func(t *testing.T) { + t.Parallel() - provider.InjectAuthHeader(&headers) + cases := []struct { + name string + initialHeaders http.Header + key string + wantAuthorization string + }{ + { + name: "writes_key_to_x_api_key", + initialHeaders: http.Header{}, + key: "centralized-key", + wantAuthorization: "", + }, + { + name: "overwrites_existing_x_api_key", + initialHeaders: http.Header{"X-Api-Key": {"stale"}, "Authorization": {"Bearer stale"}}, + key: "next-key", + wantAuthorization: "Bearer stale", + }, + } - assert.Equal(t, tc.wantXApiKey, headers.Get("X-Api-Key")) - assert.Equal(t, tc.wantAuthorization, headers.Get("Authorization")) - }) - } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + headers := tc.initialHeaders + cfg.InjectAuthKey(&headers, tc.key) + assert.Equal(t, tc.key, headers.Get("X-Api-Key")) + assert.Equal(t, tc.wantAuthorization, headers.Get("Authorization")) + }) + } + }) + + t.Run("BuildKeyPoolResponse", func(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + err *keypool.Error + wantStatus int + wantRetryAfter string + }{ + { + name: "permanent_returns_502", + err: &keypool.Error{Kind: keypool.ErrorKindPermanent}, + wantStatus: http.StatusBadGateway, + }, + { + name: "rate_limited_returns_429_with_retry_after", + err: &keypool.Error{Kind: keypool.ErrorKindRateLimited, RetryAfter: 5 * time.Second}, + wantStatus: http.StatusTooManyRequests, + wantRetryAfter: "5", + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + resp := cfg.BuildKeyPoolResponse(tc.err) + require.NotNil(t, resp) + t.Cleanup(func() { _ = resp.Body.Close() }) + assert.Equal(t, tc.wantStatus, resp.StatusCode) + assert.Equal(t, tc.wantRetryAfter, resp.Header.Get("Retry-After")) + }) + } + }) } func TestExtractAnthropicHeaders(t *testing.T) { diff --git a/aibridge/provider/copilot.go b/aibridge/provider/copilot.go index b68513ecec..1186e8b253 100644 --- a/aibridge/provider/copilot.go +++ b/aibridge/provider/copilot.go @@ -107,12 +107,6 @@ func (*Copilot) AuthHeader() string { return "Authorization" } -// InjectAuthHeader is a no-op for Copilot. -// Copilot uses per-user tokens passed in the original Authorization header, -// rather than a global key configured at the provider level. -// The original Authorization header flows through untouched from the client. -func (*Copilot) InjectAuthHeader(_ *http.Header) {} - // KeyFailoverConfig returns a config with a nil Pool, which makes // the KeyFailoverTransport short-circuit. Copilot is always BYOK. func (*Copilot) KeyFailoverConfig(_ slog.Logger) keypool.KeyFailoverConfig { diff --git a/aibridge/provider/copilot_internal_test.go b/aibridge/provider/copilot_internal_test.go index 836dde5b7b..49cb582b78 100644 --- a/aibridge/provider/copilot_internal_test.go +++ b/aibridge/provider/copilot_internal_test.go @@ -13,6 +13,7 @@ import ( "cdr.dev/slog/v3" "github.com/coder/coder/v2/aibridge/config" "github.com/coder/coder/v2/aibridge/internal/testutil" + "github.com/coder/coder/v2/aibridge/keypool" ) var testTracer = otel.Tracer("copilot_test") @@ -51,37 +52,17 @@ func TestCopilot_TypeAndName(t *testing.T) { } } -func TestCopilot_InjectAuthHeader(t *testing.T) { +// TestCopilot_KeyFailoverConfig verifies that Copilot, being BYOK-only, +// returns a zero-value KeyFailoverConfig so that KeyFailoverTransport +// short-circuits and passes the request through unchanged. +func TestCopilot_KeyFailoverConfig(t *testing.T) { t.Parallel() - // Copilot uses per-user key passed in the Authorization header, - // so InjectAuthHeader should not modify any headers. - provider := NewCopilot(config.Copilot{}) + p := NewCopilot(config.Copilot{}) - t.Run("ExistingHeaders_Unchanged", func(t *testing.T) { - t.Parallel() + cfg := p.KeyFailoverConfig(slog.Make()) - headers := http.Header{} - headers.Set("Authorization", "Bearer user-token") - headers.Set("X-Custom-Header", "custom-value") - - provider.InjectAuthHeader(&headers) - - assert.Equal(t, "Bearer user-token", headers.Get("Authorization"), - "Authorization header should remain unchanged") - assert.Equal(t, "custom-value", headers.Get("X-Custom-Header"), - "other headers should remain unchanged") - }) - - t.Run("EmptyHeaders_NoneAdded", func(t *testing.T) { - t.Parallel() - - headers := http.Header{} - - provider.InjectAuthHeader(&headers) - - assert.Empty(t, headers, "no headers should be added") - }) + assert.Equal(t, keypool.KeyFailoverConfig{}, cfg, "Copilot must return a zero-value KeyFailoverConfig to short-circuit the transport") } func TestCopilot_CreateInterceptor(t *testing.T) { diff --git a/aibridge/provider/openai.go b/aibridge/provider/openai.go index be53e612d1..177ae03409 100644 --- a/aibridge/provider/openai.go +++ b/aibridge/provider/openai.go @@ -196,29 +196,6 @@ func (*OpenAI) AuthHeader() string { return "Authorization" } -func (p *OpenAI) InjectAuthHeader(headers *http.Header) { - if headers == nil { - headers = &http.Header{} - } - - // BYOK: if the request already carries user-supplied credentials, - // do not overwrite them with the centralized key. - if headers.Get("Authorization") != "" { - return - } - - // Centralized: pull a single key from the pool. No failover - // or exhaustion handling here. - // TODO(ssncferreira): replace with RoundTripper-based auth - // in the upstack passthrough PR. - if p.cfg.KeyPool == nil { - return - } - if key, err := p.cfg.KeyPool.Walker().Next(); err == nil { - headers.Set(p.AuthHeader(), "Bearer "+key.Value()) - } -} - func (p *OpenAI) KeyFailoverConfig(logger slog.Logger) keypool.KeyFailoverConfig { return keypool.KeyFailoverConfig{ Pool: p.cfg.KeyPool, diff --git a/aibridge/provider/openai_internal_test.go b/aibridge/provider/openai_internal_test.go index 96695a2c2f..e1afcc872c 100644 --- a/aibridge/provider/openai_internal_test.go +++ b/aibridge/provider/openai_internal_test.go @@ -8,6 +8,7 @@ import ( "net/http/httptest" "strings" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -18,6 +19,8 @@ import ( "github.com/coder/coder/v2/aibridge/config" "github.com/coder/coder/v2/aibridge/intercept" "github.com/coder/coder/v2/aibridge/internal/testutil" + "github.com/coder/coder/v2/aibridge/keypool" + "github.com/coder/quartz" ) const ( @@ -325,42 +328,133 @@ func TestOpenAI_CreateInterceptor(t *testing.T) { } } -func TestOpenAI_InjectAuthHeader(t *testing.T) { +func TestOpenAI_KeyFailoverConfig(t *testing.T) { t.Parallel() - provider := NewOpenAI(config.OpenAI{Key: "centralized-key"}) + pool, err := keypool.New([]string{"k0", "k1"}, quartz.NewMock(t)) + require.NoError(t, err) - tests := []struct { - name string - presetHeaders map[string]string - wantAuthorization string - }{ - { - name: "when no Authorization header is provided, inject centralized key", - presetHeaders: map[string]string{}, - wantAuthorization: "Bearer centralized-key", - }, - { - name: "when Authorization header is provided, do not overwrite it", - presetHeaders: map[string]string{"Authorization": "Bearer user-token"}, - wantAuthorization: "Bearer user-token", - }, - } + p := NewOpenAI(config.OpenAI{KeyPool: pool}) - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - t.Parallel() + cfg := p.KeyFailoverConfig(slog.Make()) - headers := http.Header{} - for k, v := range tc.presetHeaders { - headers.Set(k, v) - } + assert.Same(t, pool, cfg.Pool, "Pool must be wired from the provider config") + assert.Equal(t, config.ProviderOpenAI, cfg.ProviderName, "ProviderName must match the provider name") + require.NotNil(t, cfg.IsBYOK) + require.NotNil(t, cfg.InjectAuthKey) + require.NotNil(t, cfg.BuildKeyPoolResponse) - provider.InjectAuthHeader(&headers) + t.Run("IsBYOK", func(t *testing.T) { + t.Parallel() + cases := []struct { + name string + headers map[string]string + want bool + }{ + { + name: "no_auth_headers", + headers: nil, + want: false, + }, + { + name: "non_auth_header", + headers: map[string]string{"Content-Type": "application/json"}, + want: false, + }, + { + name: "authorization_only", + headers: map[string]string{"Authorization": "Bearer user-token"}, + want: true, + }, + { + name: "x_api_key_only", + headers: map[string]string{"X-Api-Key": "user-key"}, + want: false, + }, + { + name: "both_headers_set", + headers: map[string]string{ + "Authorization": "Bearer user-token", + "X-Api-Key": "user-key", + }, + want: true, + }, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + r := httptest.NewRequest(http.MethodPost, "/", nil) + for k, v := range tc.headers { + r.Header.Set(k, v) + } + assert.Equal(t, tc.want, cfg.IsBYOK(r)) + }) + } + }) - assert.Equal(t, tc.wantAuthorization, headers.Get("Authorization")) - }) - } + t.Run("InjectAuthKey", func(t *testing.T) { + t.Parallel() + cases := []struct { + name string + initialHeaders http.Header + key string + wantAPIKey string + }{ + { + name: "writes_bearer_token_to_authorization", + initialHeaders: http.Header{}, + key: "centralized-key", + wantAPIKey: "", + }, + { + name: "overwrites_existing_authorization", + initialHeaders: http.Header{"Authorization": {"Bearer stale"}, "X-Api-Key": {"stale"}}, + key: "next-key", + wantAPIKey: "stale", + }, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + headers := tc.initialHeaders + cfg.InjectAuthKey(&headers, tc.key) + assert.Equal(t, "Bearer "+tc.key, headers.Get("Authorization")) + assert.Equal(t, tc.wantAPIKey, headers.Get("X-Api-Key")) + }) + } + }) + + t.Run("BuildKeyPoolResponse", func(t *testing.T) { + t.Parallel() + cases := []struct { + name string + err *keypool.Error + wantStatus int + wantRetryAfter string + }{ + { + name: "permanent_returns_502", + err: &keypool.Error{Kind: keypool.ErrorKindPermanent}, + wantStatus: http.StatusBadGateway, + }, + { + name: "rate_limited_returns_429_with_retry_after", + err: &keypool.Error{Kind: keypool.ErrorKindRateLimited, RetryAfter: 5 * time.Second}, + wantStatus: http.StatusTooManyRequests, + wantRetryAfter: "5", + }, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + resp := cfg.BuildKeyPoolResponse(tc.err) + require.NotNil(t, resp) + t.Cleanup(func() { _ = resp.Body.Close() }) + assert.Equal(t, tc.wantStatus, resp.StatusCode) + assert.Equal(t, tc.wantRetryAfter, resp.Header.Get("Retry-After")) + }) + } + }) } func BenchmarkOpenAI_CreateInterceptor_ChatCompletions(b *testing.B) { diff --git a/aibridge/provider/provider.go b/aibridge/provider/provider.go index 587dfd85ce..7520333b53 100644 --- a/aibridge/provider/provider.go +++ b/aibridge/provider/provider.go @@ -77,10 +77,6 @@ type Provider interface { // AuthHeader returns the name of the header which the provider expects to find its authentication // token in. AuthHeader() string - // InjectAuthHeader allows [Provider]s to set its authentication header. - // TODO(ssncferreira): remove. Auth is now applied per-attempt by - // KeyFailoverTransport (see [Provider.KeyFailoverConfig]). - InjectAuthHeader(*http.Header) // KeyFailoverConfig returns the per-provider configuration for // automatic key failover on passthrough routes. KeyFailoverConfig(logger slog.Logger) keypool.KeyFailoverConfig