mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
refactor(aibridge): remove InjectAuthHeader in favor of KeyFailoverConfig (#25618)
## Description `Provider.InjectAuthHeader` is no longer needed. With the addition of `KeyFailoverConfig` in #24920, authentication is now applied per-attempt by `KeyFailoverTransport` on passthrough routes. This PR removes the dead method from the `Provider` interface, all implementations (`Anthropic`, `OpenAI`, `Copilot`), and the test mock. The orphaned `InjectAuthHeader` unit tests are replaced with `Test{Anthropic,OpenAI,Copilot}_KeyFailoverConfig`. `TestPassthrough_KeyFailover` is also extended to cover Copilot in the BYOK scenario. Related to: https://linear.app/codercom/issue/AIGOV-334/aibridge-follow-ups-from-key-failover-prs > [!NOTE] > Initially generated by Claude Opus 4.7, modified and reviewed by @ssncferreira
This commit is contained in:
@@ -13,12 +13,11 @@ import (
|
||||
)
|
||||
|
||||
type MockProvider struct {
|
||||
NameStr string
|
||||
URL string
|
||||
Bridged []string
|
||||
Passthrough []string
|
||||
InterceptorFunc func(w http.ResponseWriter, r *http.Request, tracer trace.Tracer) (intercept.Interceptor, error)
|
||||
InjectAuthHeaderFunc func(h *http.Header)
|
||||
NameStr string
|
||||
URL string
|
||||
Bridged []string
|
||||
Passthrough []string
|
||||
InterceptorFunc func(w http.ResponseWriter, r *http.Request, tracer trace.Tracer) (intercept.Interceptor, error)
|
||||
}
|
||||
|
||||
func (m *MockProvider) Type() string { return m.NameStr }
|
||||
@@ -28,11 +27,6 @@ func (m *MockProvider) RoutePrefix() string { return fmt.Sprintf("/%s",
|
||||
func (m *MockProvider) BridgedRoutes() []string { return m.Bridged }
|
||||
func (m *MockProvider) PassthroughRoutes() []string { return m.Passthrough }
|
||||
func (*MockProvider) AuthHeader() string { return "Authorization" }
|
||||
func (m *MockProvider) InjectAuthHeader(h *http.Header) {
|
||||
if m.InjectAuthHeaderFunc != nil {
|
||||
m.InjectAuthHeaderFunc(h)
|
||||
}
|
||||
}
|
||||
|
||||
func (*MockProvider) KeyFailoverConfig(_ slog.Logger) keypool.KeyFailoverConfig {
|
||||
return keypool.KeyFailoverConfig{}
|
||||
|
||||
@@ -311,13 +311,14 @@ func TestPassthrough_KeyFailover(t *testing.T) {
|
||||
successBody = `{"data":[]}`
|
||||
)
|
||||
|
||||
// providers parameterises the table over the two providers
|
||||
// that support key failover. Each entry encapsulates the
|
||||
// providers parameterises the table over the providers exposed
|
||||
// to the failover transport. Each entry encapsulates the
|
||||
// provider-specific bits the test needs: how the mock upstream
|
||||
// extracts the key from the request, how a BYOK request sets
|
||||
// it, and how the provider is constructed for a given pool.
|
||||
providers := []struct {
|
||||
name string
|
||||
byokOnly bool
|
||||
extractKey func(*http.Request) string
|
||||
setBYOK func(*http.Request, string)
|
||||
newProvider func(baseURL string, pool *keypool.Pool) provider.Provider
|
||||
@@ -353,6 +354,21 @@ func TestPassthrough_KeyFailover(t *testing.T) {
|
||||
return provider.NewOpenAI(cfg)
|
||||
},
|
||||
},
|
||||
// Copilot is BYOK-only: its KeyFailoverConfig is zero-value
|
||||
// so the failover transport short-circuits.
|
||||
{
|
||||
name: "copilot",
|
||||
byokOnly: true,
|
||||
extractKey: func(r *http.Request) string {
|
||||
return strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ")
|
||||
},
|
||||
setBYOK: func(r *http.Request, key string) {
|
||||
r.Header.Set("Authorization", "Bearer "+key)
|
||||
},
|
||||
newProvider: func(baseURL string, _ *keypool.Pool) provider.Provider {
|
||||
return provider.NewCopilot(config.Copilot{BaseURL: baseURL})
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
@@ -516,6 +532,11 @@ func TestPassthrough_KeyFailover(t *testing.T) {
|
||||
|
||||
for _, prov := range providers {
|
||||
for _, tc := range tests {
|
||||
// BYOK-only providers don't use the pool, so pool-based
|
||||
// cases don't apply.
|
||||
if prov.byokOnly && tc.byokKey == "" {
|
||||
continue
|
||||
}
|
||||
t.Run(prov.name+"/"+tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
@@ -197,29 +197,6 @@ func (*Anthropic) AuthHeader() string {
|
||||
return "X-Api-Key"
|
||||
}
|
||||
|
||||
func (p *Anthropic) InjectAuthHeader(headers *http.Header) {
|
||||
if headers == nil {
|
||||
headers = &http.Header{}
|
||||
}
|
||||
|
||||
// BYOK: if the request already carries user-supplied credentials,
|
||||
// do not overwrite them with the centralized key.
|
||||
if headers.Get("X-Api-Key") != "" || headers.Get("Authorization") != "" {
|
||||
return
|
||||
}
|
||||
|
||||
// Centralized: pull a single key from the pool. No failover
|
||||
// or exhaustion handling here.
|
||||
// TODO(ssncferreira): replace with RoundTripper-based auth
|
||||
// in the upstack passthrough PR.
|
||||
if p.cfg.KeyPool == nil {
|
||||
return
|
||||
}
|
||||
if key, err := p.cfg.KeyPool.Walker().Next(); err == nil {
|
||||
headers.Set(p.AuthHeader(), key.Value())
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Anthropic) KeyFailoverConfig(logger slog.Logger) keypool.KeyFailoverConfig {
|
||||
return keypool.KeyFailoverConfig{
|
||||
Pool: p.cfg.KeyPool,
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -317,58 +318,139 @@ func TestAnthropic_CreateInterceptor_BYOK(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestAnthropic_InjectAuthHeader(t *testing.T) {
|
||||
func TestAnthropic_KeyFailoverConfig(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
provider := NewAnthropic(config.Anthropic{Key: "centralized-key"}, nil)
|
||||
pool, err := keypool.New([]string{"k0", "k1"}, quartz.NewMock(t))
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
presetHeaders map[string]string
|
||||
wantXApiKey string
|
||||
wantAuthorization string
|
||||
}{
|
||||
{
|
||||
name: "when no auth headers are provided, inject centralized key",
|
||||
presetHeaders: map[string]string{},
|
||||
wantXApiKey: "centralized-key",
|
||||
},
|
||||
{
|
||||
name: "when X-Api-Key header is provided, use it",
|
||||
presetHeaders: map[string]string{"X-Api-Key": "user-api-key"},
|
||||
wantXApiKey: "user-api-key",
|
||||
},
|
||||
{
|
||||
name: "when Authorization header is provided, use it",
|
||||
presetHeaders: map[string]string{"Authorization": "Bearer user-access-token"},
|
||||
wantAuthorization: "Bearer user-access-token",
|
||||
},
|
||||
{
|
||||
name: "when both headers are provided, keep both",
|
||||
presetHeaders: map[string]string{
|
||||
"Authorization": "Bearer user-access-token",
|
||||
"X-Api-Key": "user-api-key",
|
||||
p := NewAnthropic(config.Anthropic{KeyPool: pool}, nil)
|
||||
|
||||
cfg := p.KeyFailoverConfig(slog.Make())
|
||||
|
||||
assert.Same(t, pool, cfg.Pool, "Pool must be wired from the provider config")
|
||||
assert.Equal(t, config.ProviderAnthropic, cfg.ProviderName, "ProviderName must match the provider name")
|
||||
require.NotNil(t, cfg.IsBYOK)
|
||||
require.NotNil(t, cfg.InjectAuthKey)
|
||||
require.NotNil(t, cfg.BuildKeyPoolResponse)
|
||||
|
||||
t.Run("IsBYOK", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
headers map[string]string
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "no_auth_headers",
|
||||
headers: nil,
|
||||
want: false,
|
||||
},
|
||||
wantXApiKey: "user-api-key",
|
||||
wantAuthorization: "Bearer user-access-token",
|
||||
},
|
||||
}
|
||||
{
|
||||
name: "non_auth_header",
|
||||
headers: map[string]string{"Content-Type": "application/json"},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "x_api_key_only",
|
||||
headers: map[string]string{"X-Api-Key": "user-key"},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "authorization_only",
|
||||
headers: map[string]string{"Authorization": "Bearer user-token"},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "both_headers_set",
|
||||
headers: map[string]string{
|
||||
"X-Api-Key": "user-key",
|
||||
"Authorization": "Bearer user-token",
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
r := httptest.NewRequest(http.MethodPost, "/", nil)
|
||||
for k, v := range tc.headers {
|
||||
r.Header.Set(k, v)
|
||||
}
|
||||
assert.Equal(t, tc.want, cfg.IsBYOK(r))
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
headers := http.Header{}
|
||||
for k, v := range tc.presetHeaders {
|
||||
headers.Set(k, v)
|
||||
}
|
||||
t.Run("InjectAuthKey", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
provider.InjectAuthHeader(&headers)
|
||||
cases := []struct {
|
||||
name string
|
||||
initialHeaders http.Header
|
||||
key string
|
||||
wantAuthorization string
|
||||
}{
|
||||
{
|
||||
name: "writes_key_to_x_api_key",
|
||||
initialHeaders: http.Header{},
|
||||
key: "centralized-key",
|
||||
wantAuthorization: "",
|
||||
},
|
||||
{
|
||||
name: "overwrites_existing_x_api_key",
|
||||
initialHeaders: http.Header{"X-Api-Key": {"stale"}, "Authorization": {"Bearer stale"}},
|
||||
key: "next-key",
|
||||
wantAuthorization: "Bearer stale",
|
||||
},
|
||||
}
|
||||
|
||||
assert.Equal(t, tc.wantXApiKey, headers.Get("X-Api-Key"))
|
||||
assert.Equal(t, tc.wantAuthorization, headers.Get("Authorization"))
|
||||
})
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
headers := tc.initialHeaders
|
||||
cfg.InjectAuthKey(&headers, tc.key)
|
||||
assert.Equal(t, tc.key, headers.Get("X-Api-Key"))
|
||||
assert.Equal(t, tc.wantAuthorization, headers.Get("Authorization"))
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("BuildKeyPoolResponse", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
err *keypool.Error
|
||||
wantStatus int
|
||||
wantRetryAfter string
|
||||
}{
|
||||
{
|
||||
name: "permanent_returns_502",
|
||||
err: &keypool.Error{Kind: keypool.ErrorKindPermanent},
|
||||
wantStatus: http.StatusBadGateway,
|
||||
},
|
||||
{
|
||||
name: "rate_limited_returns_429_with_retry_after",
|
||||
err: &keypool.Error{Kind: keypool.ErrorKindRateLimited, RetryAfter: 5 * time.Second},
|
||||
wantStatus: http.StatusTooManyRequests,
|
||||
wantRetryAfter: "5",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
resp := cfg.BuildKeyPoolResponse(tc.err)
|
||||
require.NotNil(t, resp)
|
||||
t.Cleanup(func() { _ = resp.Body.Close() })
|
||||
assert.Equal(t, tc.wantStatus, resp.StatusCode)
|
||||
assert.Equal(t, tc.wantRetryAfter, resp.Header.Get("Retry-After"))
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestExtractAnthropicHeaders(t *testing.T) {
|
||||
|
||||
@@ -107,12 +107,6 @@ func (*Copilot) AuthHeader() string {
|
||||
return "Authorization"
|
||||
}
|
||||
|
||||
// InjectAuthHeader is a no-op for Copilot.
|
||||
// Copilot uses per-user tokens passed in the original Authorization header,
|
||||
// rather than a global key configured at the provider level.
|
||||
// The original Authorization header flows through untouched from the client.
|
||||
func (*Copilot) InjectAuthHeader(_ *http.Header) {}
|
||||
|
||||
// KeyFailoverConfig returns a config with a nil Pool, which makes
|
||||
// the KeyFailoverTransport short-circuit. Copilot is always BYOK.
|
||||
func (*Copilot) KeyFailoverConfig(_ slog.Logger) keypool.KeyFailoverConfig {
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/aibridge/config"
|
||||
"github.com/coder/coder/v2/aibridge/internal/testutil"
|
||||
"github.com/coder/coder/v2/aibridge/keypool"
|
||||
)
|
||||
|
||||
var testTracer = otel.Tracer("copilot_test")
|
||||
@@ -51,37 +52,17 @@ func TestCopilot_TypeAndName(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestCopilot_InjectAuthHeader(t *testing.T) {
|
||||
// TestCopilot_KeyFailoverConfig verifies that Copilot, being BYOK-only,
|
||||
// returns a zero-value KeyFailoverConfig so that KeyFailoverTransport
|
||||
// short-circuits and passes the request through unchanged.
|
||||
func TestCopilot_KeyFailoverConfig(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Copilot uses per-user key passed in the Authorization header,
|
||||
// so InjectAuthHeader should not modify any headers.
|
||||
provider := NewCopilot(config.Copilot{})
|
||||
p := NewCopilot(config.Copilot{})
|
||||
|
||||
t.Run("ExistingHeaders_Unchanged", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
cfg := p.KeyFailoverConfig(slog.Make())
|
||||
|
||||
headers := http.Header{}
|
||||
headers.Set("Authorization", "Bearer user-token")
|
||||
headers.Set("X-Custom-Header", "custom-value")
|
||||
|
||||
provider.InjectAuthHeader(&headers)
|
||||
|
||||
assert.Equal(t, "Bearer user-token", headers.Get("Authorization"),
|
||||
"Authorization header should remain unchanged")
|
||||
assert.Equal(t, "custom-value", headers.Get("X-Custom-Header"),
|
||||
"other headers should remain unchanged")
|
||||
})
|
||||
|
||||
t.Run("EmptyHeaders_NoneAdded", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
headers := http.Header{}
|
||||
|
||||
provider.InjectAuthHeader(&headers)
|
||||
|
||||
assert.Empty(t, headers, "no headers should be added")
|
||||
})
|
||||
assert.Equal(t, keypool.KeyFailoverConfig{}, cfg, "Copilot must return a zero-value KeyFailoverConfig to short-circuit the transport")
|
||||
}
|
||||
|
||||
func TestCopilot_CreateInterceptor(t *testing.T) {
|
||||
|
||||
@@ -196,29 +196,6 @@ func (*OpenAI) AuthHeader() string {
|
||||
return "Authorization"
|
||||
}
|
||||
|
||||
func (p *OpenAI) InjectAuthHeader(headers *http.Header) {
|
||||
if headers == nil {
|
||||
headers = &http.Header{}
|
||||
}
|
||||
|
||||
// BYOK: if the request already carries user-supplied credentials,
|
||||
// do not overwrite them with the centralized key.
|
||||
if headers.Get("Authorization") != "" {
|
||||
return
|
||||
}
|
||||
|
||||
// Centralized: pull a single key from the pool. No failover
|
||||
// or exhaustion handling here.
|
||||
// TODO(ssncferreira): replace with RoundTripper-based auth
|
||||
// in the upstack passthrough PR.
|
||||
if p.cfg.KeyPool == nil {
|
||||
return
|
||||
}
|
||||
if key, err := p.cfg.KeyPool.Walker().Next(); err == nil {
|
||||
headers.Set(p.AuthHeader(), "Bearer "+key.Value())
|
||||
}
|
||||
}
|
||||
|
||||
func (p *OpenAI) KeyFailoverConfig(logger slog.Logger) keypool.KeyFailoverConfig {
|
||||
return keypool.KeyFailoverConfig{
|
||||
Pool: p.cfg.KeyPool,
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -18,6 +19,8 @@ import (
|
||||
"github.com/coder/coder/v2/aibridge/config"
|
||||
"github.com/coder/coder/v2/aibridge/intercept"
|
||||
"github.com/coder/coder/v2/aibridge/internal/testutil"
|
||||
"github.com/coder/coder/v2/aibridge/keypool"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -325,42 +328,133 @@ func TestOpenAI_CreateInterceptor(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAI_InjectAuthHeader(t *testing.T) {
|
||||
func TestOpenAI_KeyFailoverConfig(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
provider := NewOpenAI(config.OpenAI{Key: "centralized-key"})
|
||||
pool, err := keypool.New([]string{"k0", "k1"}, quartz.NewMock(t))
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
presetHeaders map[string]string
|
||||
wantAuthorization string
|
||||
}{
|
||||
{
|
||||
name: "when no Authorization header is provided, inject centralized key",
|
||||
presetHeaders: map[string]string{},
|
||||
wantAuthorization: "Bearer centralized-key",
|
||||
},
|
||||
{
|
||||
name: "when Authorization header is provided, do not overwrite it",
|
||||
presetHeaders: map[string]string{"Authorization": "Bearer user-token"},
|
||||
wantAuthorization: "Bearer user-token",
|
||||
},
|
||||
}
|
||||
p := NewOpenAI(config.OpenAI{KeyPool: pool})
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
cfg := p.KeyFailoverConfig(slog.Make())
|
||||
|
||||
headers := http.Header{}
|
||||
for k, v := range tc.presetHeaders {
|
||||
headers.Set(k, v)
|
||||
}
|
||||
assert.Same(t, pool, cfg.Pool, "Pool must be wired from the provider config")
|
||||
assert.Equal(t, config.ProviderOpenAI, cfg.ProviderName, "ProviderName must match the provider name")
|
||||
require.NotNil(t, cfg.IsBYOK)
|
||||
require.NotNil(t, cfg.InjectAuthKey)
|
||||
require.NotNil(t, cfg.BuildKeyPoolResponse)
|
||||
|
||||
provider.InjectAuthHeader(&headers)
|
||||
t.Run("IsBYOK", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
cases := []struct {
|
||||
name string
|
||||
headers map[string]string
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "no_auth_headers",
|
||||
headers: nil,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "non_auth_header",
|
||||
headers: map[string]string{"Content-Type": "application/json"},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "authorization_only",
|
||||
headers: map[string]string{"Authorization": "Bearer user-token"},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "x_api_key_only",
|
||||
headers: map[string]string{"X-Api-Key": "user-key"},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "both_headers_set",
|
||||
headers: map[string]string{
|
||||
"Authorization": "Bearer user-token",
|
||||
"X-Api-Key": "user-key",
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
r := httptest.NewRequest(http.MethodPost, "/", nil)
|
||||
for k, v := range tc.headers {
|
||||
r.Header.Set(k, v)
|
||||
}
|
||||
assert.Equal(t, tc.want, cfg.IsBYOK(r))
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
assert.Equal(t, tc.wantAuthorization, headers.Get("Authorization"))
|
||||
})
|
||||
}
|
||||
t.Run("InjectAuthKey", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
cases := []struct {
|
||||
name string
|
||||
initialHeaders http.Header
|
||||
key string
|
||||
wantAPIKey string
|
||||
}{
|
||||
{
|
||||
name: "writes_bearer_token_to_authorization",
|
||||
initialHeaders: http.Header{},
|
||||
key: "centralized-key",
|
||||
wantAPIKey: "",
|
||||
},
|
||||
{
|
||||
name: "overwrites_existing_authorization",
|
||||
initialHeaders: http.Header{"Authorization": {"Bearer stale"}, "X-Api-Key": {"stale"}},
|
||||
key: "next-key",
|
||||
wantAPIKey: "stale",
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
headers := tc.initialHeaders
|
||||
cfg.InjectAuthKey(&headers, tc.key)
|
||||
assert.Equal(t, "Bearer "+tc.key, headers.Get("Authorization"))
|
||||
assert.Equal(t, tc.wantAPIKey, headers.Get("X-Api-Key"))
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("BuildKeyPoolResponse", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
cases := []struct {
|
||||
name string
|
||||
err *keypool.Error
|
||||
wantStatus int
|
||||
wantRetryAfter string
|
||||
}{
|
||||
{
|
||||
name: "permanent_returns_502",
|
||||
err: &keypool.Error{Kind: keypool.ErrorKindPermanent},
|
||||
wantStatus: http.StatusBadGateway,
|
||||
},
|
||||
{
|
||||
name: "rate_limited_returns_429_with_retry_after",
|
||||
err: &keypool.Error{Kind: keypool.ErrorKindRateLimited, RetryAfter: 5 * time.Second},
|
||||
wantStatus: http.StatusTooManyRequests,
|
||||
wantRetryAfter: "5",
|
||||
},
|
||||
}
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
resp := cfg.BuildKeyPoolResponse(tc.err)
|
||||
require.NotNil(t, resp)
|
||||
t.Cleanup(func() { _ = resp.Body.Close() })
|
||||
assert.Equal(t, tc.wantStatus, resp.StatusCode)
|
||||
assert.Equal(t, tc.wantRetryAfter, resp.Header.Get("Retry-After"))
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkOpenAI_CreateInterceptor_ChatCompletions(b *testing.B) {
|
||||
|
||||
@@ -77,10 +77,6 @@ type Provider interface {
|
||||
// AuthHeader returns the name of the header which the provider expects to find its authentication
|
||||
// token in.
|
||||
AuthHeader() string
|
||||
// InjectAuthHeader allows [Provider]s to set its authentication header.
|
||||
// TODO(ssncferreira): remove. Auth is now applied per-attempt by
|
||||
// KeyFailoverTransport (see [Provider.KeyFailoverConfig]).
|
||||
InjectAuthHeader(*http.Header)
|
||||
// KeyFailoverConfig returns the per-provider configuration for
|
||||
// automatic key failover on passthrough routes.
|
||||
KeyFailoverConfig(logger slog.Logger) keypool.KeyFailoverConfig
|
||||
|
||||
Reference in New Issue
Block a user