mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
refactor(aibridge): remove InjectAuthHeader in favor of KeyFailoverConfig (#25618)
## Description `Provider.InjectAuthHeader` is no longer needed. With the addition of `KeyFailoverConfig` in #24920, authentication is now applied per-attempt by `KeyFailoverTransport` on passthrough routes. This PR removes the dead method from the `Provider` interface, all implementations (`Anthropic`, `OpenAI`, `Copilot`), and the test mock. The orphaned `InjectAuthHeader` unit tests are replaced with `Test{Anthropic,OpenAI,Copilot}_KeyFailoverConfig`. `TestPassthrough_KeyFailover` is also extended to cover Copilot in the BYOK scenario. Related to: https://linear.app/codercom/issue/AIGOV-334/aibridge-follow-ups-from-key-failover-prs > [!NOTE] > Initially generated by Claude Opus 4.7, modified and reviewed by @ssncferreira
This commit is contained in:
@@ -13,12 +13,11 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type MockProvider struct {
|
type MockProvider struct {
|
||||||
NameStr string
|
NameStr string
|
||||||
URL string
|
URL string
|
||||||
Bridged []string
|
Bridged []string
|
||||||
Passthrough []string
|
Passthrough []string
|
||||||
InterceptorFunc func(w http.ResponseWriter, r *http.Request, tracer trace.Tracer) (intercept.Interceptor, error)
|
InterceptorFunc func(w http.ResponseWriter, r *http.Request, tracer trace.Tracer) (intercept.Interceptor, error)
|
||||||
InjectAuthHeaderFunc func(h *http.Header)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockProvider) Type() string { return m.NameStr }
|
func (m *MockProvider) Type() string { return m.NameStr }
|
||||||
@@ -28,11 +27,6 @@ func (m *MockProvider) RoutePrefix() string { return fmt.Sprintf("/%s",
|
|||||||
func (m *MockProvider) BridgedRoutes() []string { return m.Bridged }
|
func (m *MockProvider) BridgedRoutes() []string { return m.Bridged }
|
||||||
func (m *MockProvider) PassthroughRoutes() []string { return m.Passthrough }
|
func (m *MockProvider) PassthroughRoutes() []string { return m.Passthrough }
|
||||||
func (*MockProvider) AuthHeader() string { return "Authorization" }
|
func (*MockProvider) AuthHeader() string { return "Authorization" }
|
||||||
func (m *MockProvider) InjectAuthHeader(h *http.Header) {
|
|
||||||
if m.InjectAuthHeaderFunc != nil {
|
|
||||||
m.InjectAuthHeaderFunc(h)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (*MockProvider) KeyFailoverConfig(_ slog.Logger) keypool.KeyFailoverConfig {
|
func (*MockProvider) KeyFailoverConfig(_ slog.Logger) keypool.KeyFailoverConfig {
|
||||||
return keypool.KeyFailoverConfig{}
|
return keypool.KeyFailoverConfig{}
|
||||||
|
|||||||
@@ -311,13 +311,14 @@ func TestPassthrough_KeyFailover(t *testing.T) {
|
|||||||
successBody = `{"data":[]}`
|
successBody = `{"data":[]}`
|
||||||
)
|
)
|
||||||
|
|
||||||
// providers parameterises the table over the two providers
|
// providers parameterises the table over the providers exposed
|
||||||
// that support key failover. Each entry encapsulates the
|
// to the failover transport. Each entry encapsulates the
|
||||||
// provider-specific bits the test needs: how the mock upstream
|
// provider-specific bits the test needs: how the mock upstream
|
||||||
// extracts the key from the request, how a BYOK request sets
|
// extracts the key from the request, how a BYOK request sets
|
||||||
// it, and how the provider is constructed for a given pool.
|
// it, and how the provider is constructed for a given pool.
|
||||||
providers := []struct {
|
providers := []struct {
|
||||||
name string
|
name string
|
||||||
|
byokOnly bool
|
||||||
extractKey func(*http.Request) string
|
extractKey func(*http.Request) string
|
||||||
setBYOK func(*http.Request, string)
|
setBYOK func(*http.Request, string)
|
||||||
newProvider func(baseURL string, pool *keypool.Pool) provider.Provider
|
newProvider func(baseURL string, pool *keypool.Pool) provider.Provider
|
||||||
@@ -353,6 +354,21 @@ func TestPassthrough_KeyFailover(t *testing.T) {
|
|||||||
return provider.NewOpenAI(cfg)
|
return provider.NewOpenAI(cfg)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
// Copilot is BYOK-only: its KeyFailoverConfig is zero-value
|
||||||
|
// so the failover transport short-circuits.
|
||||||
|
{
|
||||||
|
name: "copilot",
|
||||||
|
byokOnly: true,
|
||||||
|
extractKey: func(r *http.Request) string {
|
||||||
|
return strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ")
|
||||||
|
},
|
||||||
|
setBYOK: func(r *http.Request, key string) {
|
||||||
|
r.Header.Set("Authorization", "Bearer "+key)
|
||||||
|
},
|
||||||
|
newProvider: func(baseURL string, _ *keypool.Pool) provider.Provider {
|
||||||
|
return provider.NewCopilot(config.Copilot{BaseURL: baseURL})
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
@@ -516,6 +532,11 @@ func TestPassthrough_KeyFailover(t *testing.T) {
|
|||||||
|
|
||||||
for _, prov := range providers {
|
for _, prov := range providers {
|
||||||
for _, tc := range tests {
|
for _, tc := range tests {
|
||||||
|
// BYOK-only providers don't use the pool, so pool-based
|
||||||
|
// cases don't apply.
|
||||||
|
if prov.byokOnly && tc.byokKey == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
t.Run(prov.name+"/"+tc.name, func(t *testing.T) {
|
t.Run(prov.name+"/"+tc.name, func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
|
|||||||
@@ -197,29 +197,6 @@ func (*Anthropic) AuthHeader() string {
|
|||||||
return "X-Api-Key"
|
return "X-Api-Key"
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Anthropic) InjectAuthHeader(headers *http.Header) {
|
|
||||||
if headers == nil {
|
|
||||||
headers = &http.Header{}
|
|
||||||
}
|
|
||||||
|
|
||||||
// BYOK: if the request already carries user-supplied credentials,
|
|
||||||
// do not overwrite them with the centralized key.
|
|
||||||
if headers.Get("X-Api-Key") != "" || headers.Get("Authorization") != "" {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Centralized: pull a single key from the pool. No failover
|
|
||||||
// or exhaustion handling here.
|
|
||||||
// TODO(ssncferreira): replace with RoundTripper-based auth
|
|
||||||
// in the upstack passthrough PR.
|
|
||||||
if p.cfg.KeyPool == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if key, err := p.cfg.KeyPool.Walker().Next(); err == nil {
|
|
||||||
headers.Set(p.AuthHeader(), key.Value())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *Anthropic) KeyFailoverConfig(logger slog.Logger) keypool.KeyFailoverConfig {
|
func (p *Anthropic) KeyFailoverConfig(logger slog.Logger) keypool.KeyFailoverConfig {
|
||||||
return keypool.KeyFailoverConfig{
|
return keypool.KeyFailoverConfig{
|
||||||
Pool: p.cfg.KeyPool,
|
Pool: p.cfg.KeyPool,
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
@@ -317,58 +318,139 @@ func TestAnthropic_CreateInterceptor_BYOK(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAnthropic_InjectAuthHeader(t *testing.T) {
|
func TestAnthropic_KeyFailoverConfig(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
provider := NewAnthropic(config.Anthropic{Key: "centralized-key"}, nil)
|
pool, err := keypool.New([]string{"k0", "k1"}, quartz.NewMock(t))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
tests := []struct {
|
p := NewAnthropic(config.Anthropic{KeyPool: pool}, nil)
|
||||||
name string
|
|
||||||
presetHeaders map[string]string
|
cfg := p.KeyFailoverConfig(slog.Make())
|
||||||
wantXApiKey string
|
|
||||||
wantAuthorization string
|
assert.Same(t, pool, cfg.Pool, "Pool must be wired from the provider config")
|
||||||
}{
|
assert.Equal(t, config.ProviderAnthropic, cfg.ProviderName, "ProviderName must match the provider name")
|
||||||
{
|
require.NotNil(t, cfg.IsBYOK)
|
||||||
name: "when no auth headers are provided, inject centralized key",
|
require.NotNil(t, cfg.InjectAuthKey)
|
||||||
presetHeaders: map[string]string{},
|
require.NotNil(t, cfg.BuildKeyPoolResponse)
|
||||||
wantXApiKey: "centralized-key",
|
|
||||||
},
|
t.Run("IsBYOK", func(t *testing.T) {
|
||||||
{
|
t.Parallel()
|
||||||
name: "when X-Api-Key header is provided, use it",
|
|
||||||
presetHeaders: map[string]string{"X-Api-Key": "user-api-key"},
|
cases := []struct {
|
||||||
wantXApiKey: "user-api-key",
|
name string
|
||||||
},
|
headers map[string]string
|
||||||
{
|
want bool
|
||||||
name: "when Authorization header is provided, use it",
|
}{
|
||||||
presetHeaders: map[string]string{"Authorization": "Bearer user-access-token"},
|
{
|
||||||
wantAuthorization: "Bearer user-access-token",
|
name: "no_auth_headers",
|
||||||
},
|
headers: nil,
|
||||||
{
|
want: false,
|
||||||
name: "when both headers are provided, keep both",
|
|
||||||
presetHeaders: map[string]string{
|
|
||||||
"Authorization": "Bearer user-access-token",
|
|
||||||
"X-Api-Key": "user-api-key",
|
|
||||||
},
|
},
|
||||||
wantXApiKey: "user-api-key",
|
{
|
||||||
wantAuthorization: "Bearer user-access-token",
|
name: "non_auth_header",
|
||||||
},
|
headers: map[string]string{"Content-Type": "application/json"},
|
||||||
}
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "x_api_key_only",
|
||||||
|
headers: map[string]string{"X-Api-Key": "user-key"},
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "authorization_only",
|
||||||
|
headers: map[string]string{"Authorization": "Bearer user-token"},
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "both_headers_set",
|
||||||
|
headers: map[string]string{
|
||||||
|
"X-Api-Key": "user-key",
|
||||||
|
"Authorization": "Bearer user-token",
|
||||||
|
},
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
for _, tc := range tests {
|
for _, tc := range cases {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
r := httptest.NewRequest(http.MethodPost, "/", nil)
|
||||||
|
for k, v := range tc.headers {
|
||||||
|
r.Header.Set(k, v)
|
||||||
|
}
|
||||||
|
assert.Equal(t, tc.want, cfg.IsBYOK(r))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
headers := http.Header{}
|
t.Run("InjectAuthKey", func(t *testing.T) {
|
||||||
for k, v := range tc.presetHeaders {
|
t.Parallel()
|
||||||
headers.Set(k, v)
|
|
||||||
}
|
|
||||||
|
|
||||||
provider.InjectAuthHeader(&headers)
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
initialHeaders http.Header
|
||||||
|
key string
|
||||||
|
wantAuthorization string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "writes_key_to_x_api_key",
|
||||||
|
initialHeaders: http.Header{},
|
||||||
|
key: "centralized-key",
|
||||||
|
wantAuthorization: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "overwrites_existing_x_api_key",
|
||||||
|
initialHeaders: http.Header{"X-Api-Key": {"stale"}, "Authorization": {"Bearer stale"}},
|
||||||
|
key: "next-key",
|
||||||
|
wantAuthorization: "Bearer stale",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
assert.Equal(t, tc.wantXApiKey, headers.Get("X-Api-Key"))
|
for _, tc := range cases {
|
||||||
assert.Equal(t, tc.wantAuthorization, headers.Get("Authorization"))
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
})
|
t.Parallel()
|
||||||
}
|
headers := tc.initialHeaders
|
||||||
|
cfg.InjectAuthKey(&headers, tc.key)
|
||||||
|
assert.Equal(t, tc.key, headers.Get("X-Api-Key"))
|
||||||
|
assert.Equal(t, tc.wantAuthorization, headers.Get("Authorization"))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("BuildKeyPoolResponse", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
err *keypool.Error
|
||||||
|
wantStatus int
|
||||||
|
wantRetryAfter string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "permanent_returns_502",
|
||||||
|
err: &keypool.Error{Kind: keypool.ErrorKindPermanent},
|
||||||
|
wantStatus: http.StatusBadGateway,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "rate_limited_returns_429_with_retry_after",
|
||||||
|
err: &keypool.Error{Kind: keypool.ErrorKindRateLimited, RetryAfter: 5 * time.Second},
|
||||||
|
wantStatus: http.StatusTooManyRequests,
|
||||||
|
wantRetryAfter: "5",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range cases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
resp := cfg.BuildKeyPoolResponse(tc.err)
|
||||||
|
require.NotNil(t, resp)
|
||||||
|
t.Cleanup(func() { _ = resp.Body.Close() })
|
||||||
|
assert.Equal(t, tc.wantStatus, resp.StatusCode)
|
||||||
|
assert.Equal(t, tc.wantRetryAfter, resp.Header.Get("Retry-After"))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestExtractAnthropicHeaders(t *testing.T) {
|
func TestExtractAnthropicHeaders(t *testing.T) {
|
||||||
|
|||||||
@@ -107,12 +107,6 @@ func (*Copilot) AuthHeader() string {
|
|||||||
return "Authorization"
|
return "Authorization"
|
||||||
}
|
}
|
||||||
|
|
||||||
// InjectAuthHeader is a no-op for Copilot.
|
|
||||||
// Copilot uses per-user tokens passed in the original Authorization header,
|
|
||||||
// rather than a global key configured at the provider level.
|
|
||||||
// The original Authorization header flows through untouched from the client.
|
|
||||||
func (*Copilot) InjectAuthHeader(_ *http.Header) {}
|
|
||||||
|
|
||||||
// KeyFailoverConfig returns a config with a nil Pool, which makes
|
// KeyFailoverConfig returns a config with a nil Pool, which makes
|
||||||
// the KeyFailoverTransport short-circuit. Copilot is always BYOK.
|
// the KeyFailoverTransport short-circuit. Copilot is always BYOK.
|
||||||
func (*Copilot) KeyFailoverConfig(_ slog.Logger) keypool.KeyFailoverConfig {
|
func (*Copilot) KeyFailoverConfig(_ slog.Logger) keypool.KeyFailoverConfig {
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ import (
|
|||||||
"cdr.dev/slog/v3"
|
"cdr.dev/slog/v3"
|
||||||
"github.com/coder/coder/v2/aibridge/config"
|
"github.com/coder/coder/v2/aibridge/config"
|
||||||
"github.com/coder/coder/v2/aibridge/internal/testutil"
|
"github.com/coder/coder/v2/aibridge/internal/testutil"
|
||||||
|
"github.com/coder/coder/v2/aibridge/keypool"
|
||||||
)
|
)
|
||||||
|
|
||||||
var testTracer = otel.Tracer("copilot_test")
|
var testTracer = otel.Tracer("copilot_test")
|
||||||
@@ -51,37 +52,17 @@ func TestCopilot_TypeAndName(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCopilot_InjectAuthHeader(t *testing.T) {
|
// TestCopilot_KeyFailoverConfig verifies that Copilot, being BYOK-only,
|
||||||
|
// returns a zero-value KeyFailoverConfig so that KeyFailoverTransport
|
||||||
|
// short-circuits and passes the request through unchanged.
|
||||||
|
func TestCopilot_KeyFailoverConfig(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
// Copilot uses per-user key passed in the Authorization header,
|
p := NewCopilot(config.Copilot{})
|
||||||
// so InjectAuthHeader should not modify any headers.
|
|
||||||
provider := NewCopilot(config.Copilot{})
|
|
||||||
|
|
||||||
t.Run("ExistingHeaders_Unchanged", func(t *testing.T) {
|
cfg := p.KeyFailoverConfig(slog.Make())
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
headers := http.Header{}
|
assert.Equal(t, keypool.KeyFailoverConfig{}, cfg, "Copilot must return a zero-value KeyFailoverConfig to short-circuit the transport")
|
||||||
headers.Set("Authorization", "Bearer user-token")
|
|
||||||
headers.Set("X-Custom-Header", "custom-value")
|
|
||||||
|
|
||||||
provider.InjectAuthHeader(&headers)
|
|
||||||
|
|
||||||
assert.Equal(t, "Bearer user-token", headers.Get("Authorization"),
|
|
||||||
"Authorization header should remain unchanged")
|
|
||||||
assert.Equal(t, "custom-value", headers.Get("X-Custom-Header"),
|
|
||||||
"other headers should remain unchanged")
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("EmptyHeaders_NoneAdded", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
headers := http.Header{}
|
|
||||||
|
|
||||||
provider.InjectAuthHeader(&headers)
|
|
||||||
|
|
||||||
assert.Empty(t, headers, "no headers should be added")
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCopilot_CreateInterceptor(t *testing.T) {
|
func TestCopilot_CreateInterceptor(t *testing.T) {
|
||||||
|
|||||||
@@ -196,29 +196,6 @@ func (*OpenAI) AuthHeader() string {
|
|||||||
return "Authorization"
|
return "Authorization"
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *OpenAI) InjectAuthHeader(headers *http.Header) {
|
|
||||||
if headers == nil {
|
|
||||||
headers = &http.Header{}
|
|
||||||
}
|
|
||||||
|
|
||||||
// BYOK: if the request already carries user-supplied credentials,
|
|
||||||
// do not overwrite them with the centralized key.
|
|
||||||
if headers.Get("Authorization") != "" {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Centralized: pull a single key from the pool. No failover
|
|
||||||
// or exhaustion handling here.
|
|
||||||
// TODO(ssncferreira): replace with RoundTripper-based auth
|
|
||||||
// in the upstack passthrough PR.
|
|
||||||
if p.cfg.KeyPool == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if key, err := p.cfg.KeyPool.Walker().Next(); err == nil {
|
|
||||||
headers.Set(p.AuthHeader(), "Bearer "+key.Value())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *OpenAI) KeyFailoverConfig(logger slog.Logger) keypool.KeyFailoverConfig {
|
func (p *OpenAI) KeyFailoverConfig(logger slog.Logger) keypool.KeyFailoverConfig {
|
||||||
return keypool.KeyFailoverConfig{
|
return keypool.KeyFailoverConfig{
|
||||||
Pool: p.cfg.KeyPool,
|
Pool: p.cfg.KeyPool,
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
@@ -18,6 +19,8 @@ import (
|
|||||||
"github.com/coder/coder/v2/aibridge/config"
|
"github.com/coder/coder/v2/aibridge/config"
|
||||||
"github.com/coder/coder/v2/aibridge/intercept"
|
"github.com/coder/coder/v2/aibridge/intercept"
|
||||||
"github.com/coder/coder/v2/aibridge/internal/testutil"
|
"github.com/coder/coder/v2/aibridge/internal/testutil"
|
||||||
|
"github.com/coder/coder/v2/aibridge/keypool"
|
||||||
|
"github.com/coder/quartz"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -325,42 +328,133 @@ func TestOpenAI_CreateInterceptor(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestOpenAI_InjectAuthHeader(t *testing.T) {
|
func TestOpenAI_KeyFailoverConfig(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
provider := NewOpenAI(config.OpenAI{Key: "centralized-key"})
|
pool, err := keypool.New([]string{"k0", "k1"}, quartz.NewMock(t))
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
tests := []struct {
|
p := NewOpenAI(config.OpenAI{KeyPool: pool})
|
||||||
name string
|
|
||||||
presetHeaders map[string]string
|
|
||||||
wantAuthorization string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "when no Authorization header is provided, inject centralized key",
|
|
||||||
presetHeaders: map[string]string{},
|
|
||||||
wantAuthorization: "Bearer centralized-key",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "when Authorization header is provided, do not overwrite it",
|
|
||||||
presetHeaders: map[string]string{"Authorization": "Bearer user-token"},
|
|
||||||
wantAuthorization: "Bearer user-token",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range tests {
|
cfg := p.KeyFailoverConfig(slog.Make())
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
headers := http.Header{}
|
assert.Same(t, pool, cfg.Pool, "Pool must be wired from the provider config")
|
||||||
for k, v := range tc.presetHeaders {
|
assert.Equal(t, config.ProviderOpenAI, cfg.ProviderName, "ProviderName must match the provider name")
|
||||||
headers.Set(k, v)
|
require.NotNil(t, cfg.IsBYOK)
|
||||||
}
|
require.NotNil(t, cfg.InjectAuthKey)
|
||||||
|
require.NotNil(t, cfg.BuildKeyPoolResponse)
|
||||||
|
|
||||||
provider.InjectAuthHeader(&headers)
|
t.Run("IsBYOK", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
headers map[string]string
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "no_auth_headers",
|
||||||
|
headers: nil,
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "non_auth_header",
|
||||||
|
headers: map[string]string{"Content-Type": "application/json"},
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "authorization_only",
|
||||||
|
headers: map[string]string{"Authorization": "Bearer user-token"},
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "x_api_key_only",
|
||||||
|
headers: map[string]string{"X-Api-Key": "user-key"},
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "both_headers_set",
|
||||||
|
headers: map[string]string{
|
||||||
|
"Authorization": "Bearer user-token",
|
||||||
|
"X-Api-Key": "user-key",
|
||||||
|
},
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tc := range cases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
r := httptest.NewRequest(http.MethodPost, "/", nil)
|
||||||
|
for k, v := range tc.headers {
|
||||||
|
r.Header.Set(k, v)
|
||||||
|
}
|
||||||
|
assert.Equal(t, tc.want, cfg.IsBYOK(r))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
assert.Equal(t, tc.wantAuthorization, headers.Get("Authorization"))
|
t.Run("InjectAuthKey", func(t *testing.T) {
|
||||||
})
|
t.Parallel()
|
||||||
}
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
initialHeaders http.Header
|
||||||
|
key string
|
||||||
|
wantAPIKey string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "writes_bearer_token_to_authorization",
|
||||||
|
initialHeaders: http.Header{},
|
||||||
|
key: "centralized-key",
|
||||||
|
wantAPIKey: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "overwrites_existing_authorization",
|
||||||
|
initialHeaders: http.Header{"Authorization": {"Bearer stale"}, "X-Api-Key": {"stale"}},
|
||||||
|
key: "next-key",
|
||||||
|
wantAPIKey: "stale",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tc := range cases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
headers := tc.initialHeaders
|
||||||
|
cfg.InjectAuthKey(&headers, tc.key)
|
||||||
|
assert.Equal(t, "Bearer "+tc.key, headers.Get("Authorization"))
|
||||||
|
assert.Equal(t, tc.wantAPIKey, headers.Get("X-Api-Key"))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("BuildKeyPoolResponse", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
cases := []struct {
|
||||||
|
name string
|
||||||
|
err *keypool.Error
|
||||||
|
wantStatus int
|
||||||
|
wantRetryAfter string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "permanent_returns_502",
|
||||||
|
err: &keypool.Error{Kind: keypool.ErrorKindPermanent},
|
||||||
|
wantStatus: http.StatusBadGateway,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "rate_limited_returns_429_with_retry_after",
|
||||||
|
err: &keypool.Error{Kind: keypool.ErrorKindRateLimited, RetryAfter: 5 * time.Second},
|
||||||
|
wantStatus: http.StatusTooManyRequests,
|
||||||
|
wantRetryAfter: "5",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tc := range cases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
resp := cfg.BuildKeyPoolResponse(tc.err)
|
||||||
|
require.NotNil(t, resp)
|
||||||
|
t.Cleanup(func() { _ = resp.Body.Close() })
|
||||||
|
assert.Equal(t, tc.wantStatus, resp.StatusCode)
|
||||||
|
assert.Equal(t, tc.wantRetryAfter, resp.Header.Get("Retry-After"))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func BenchmarkOpenAI_CreateInterceptor_ChatCompletions(b *testing.B) {
|
func BenchmarkOpenAI_CreateInterceptor_ChatCompletions(b *testing.B) {
|
||||||
|
|||||||
@@ -77,10 +77,6 @@ type Provider interface {
|
|||||||
// AuthHeader returns the name of the header which the provider expects to find its authentication
|
// AuthHeader returns the name of the header which the provider expects to find its authentication
|
||||||
// token in.
|
// token in.
|
||||||
AuthHeader() string
|
AuthHeader() string
|
||||||
// InjectAuthHeader allows [Provider]s to set its authentication header.
|
|
||||||
// TODO(ssncferreira): remove. Auth is now applied per-attempt by
|
|
||||||
// KeyFailoverTransport (see [Provider.KeyFailoverConfig]).
|
|
||||||
InjectAuthHeader(*http.Header)
|
|
||||||
// KeyFailoverConfig returns the per-provider configuration for
|
// KeyFailoverConfig returns the per-provider configuration for
|
||||||
// automatic key failover on passthrough routes.
|
// automatic key failover on passthrough routes.
|
||||||
KeyFailoverConfig(logger slog.Logger) keypool.KeyFailoverConfig
|
KeyFailoverConfig(logger slog.Logger) keypool.KeyFailoverConfig
|
||||||
|
|||||||
Reference in New Issue
Block a user