mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
0766cc3097
## Description

Adds automatic key failover to the passthrough routes of the Anthropic and OpenAI providers. A new `keyFailoverTransport` wraps the reverse-proxy transport: centralized requests walk the configured key pool and retry with the next key on key-specific failures (401/403/429), reusing the same key-marking semantics as the bridged routes. BYOK passthrough requests run as a single attempt with no failover.

## Changes

- New `keypool.KeyFailoverConfig` carrying the `Pool` to walk and the provider-specific closures (`IsBYOK`, `InjectAuthKey`, `MarkKey`, `BuildExhaustedResponse`).
- New `keypool.NewKeyFailoverTransport`: wraps an inner `http.RoundTripper`. Returns `inner` unchanged when `Pool` is nil, otherwise produces a transport that buffers the request body once, walks the pool per request, and replays each attempt with the next key.
- New `Provider.KeyFailoverConfig(logger)` interface method. Anthropic injects `X-Api-Key`; OpenAI injects `Authorization: Bearer ...`; Copilot returns an empty config.
- `passthrough.go` wires `NewKeyFailoverTransport` around the existing apidump middleware, so every retry attempt is recorded.

## Related Issues

Related to: https://github.com/coder/internal/issues/1446
Related to: https://linear.app/codercom/issue/AIGOV-197/aibridge-automatic-key-failover-for-bridged-and-passthrough-routes

## Follow-up PRs

- Remove dead `Provider.InjectAuthHeader` method now that all auth is applied per-attempt by `KeyFailoverTransport`.
- Bedrock multi-key support.
- Refactor provider vs interceptor config separation.
- Record the actually-used key in the interception credential hint after failover.

> [!NOTE]
> Initially generated by Claude Opus 4.7, modified and reviewed by @ssncferreira
574 lines
18 KiB
Go
574 lines
18 KiB
Go
package aibridge //nolint:testpackage // tests unexported newPassthroughRouter
|
|
|
|
import (
|
|
"crypto/tls"
|
|
"io"
|
|
"maps"
|
|
"net"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"net/http/httputil"
|
|
"net/url"
|
|
"strings"
|
|
"sync/atomic"
|
|
"testing"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
"go.opentelemetry.io/otel"
|
|
|
|
"cdr.dev/slog/v3/sloggers/slogtest"
|
|
"github.com/coder/coder/v2/aibridge/config"
|
|
"github.com/coder/coder/v2/aibridge/internal/testutil"
|
|
"github.com/coder/coder/v2/aibridge/keypool"
|
|
"github.com/coder/coder/v2/aibridge/provider"
|
|
"github.com/coder/quartz"
|
|
)
|
|
|
|
// testTracer is the OpenTelemetry tracer handed to newPassthroughRouter by the
// passthrough tests. It is obtained from the global otel tracer provider under
// the name "bridge_test".
var (
	testTracer = otel.Tracer("bridge_test")
)
|
|
|
|
func TestPassthroughRoutes(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
upstreamRespBody := "upstream response"
|
|
tests := []struct {
|
|
name string
|
|
baseURLPath string
|
|
reqPath string
|
|
reqHost string
|
|
reqRemoteAddr string
|
|
reqHeaders http.Header
|
|
expectRequestPath string
|
|
expectQuery string
|
|
expectHeaders http.Header
|
|
expectRespStatus int
|
|
expectRespBody string
|
|
}{
|
|
{
|
|
name: "passthrough_route_no_path",
|
|
reqPath: "/v1/conversations",
|
|
expectRequestPath: "/v1/conversations",
|
|
expectRespStatus: http.StatusOK,
|
|
expectRespBody: upstreamRespBody,
|
|
},
|
|
{
|
|
name: "base_URL_path_is_preserved_in_passthrough_routes",
|
|
baseURLPath: "/api/v2",
|
|
reqPath: "/v1/models",
|
|
expectRequestPath: "/api/v2/v1/models",
|
|
expectRespStatus: http.StatusOK,
|
|
expectRespBody: upstreamRespBody,
|
|
},
|
|
{
|
|
name: "passthrough_route_break_parse_base_url",
|
|
baseURLPath: "/%zz",
|
|
reqPath: "/v1/models/",
|
|
expectRespStatus: http.StatusBadGateway,
|
|
expectRespBody: "invalid provider base URL",
|
|
},
|
|
{
|
|
name: "passthrough_route_rejects_invalid_base_url_path",
|
|
baseURLPath: "/%25",
|
|
reqPath: "/v1/models",
|
|
expectRespStatus: http.StatusBadGateway,
|
|
expectRespBody: "invalid provider base URL",
|
|
},
|
|
{
|
|
name: "proxy_headers_are_set_and_forwarded_chain_is_appended",
|
|
reqPath: "/v1/models",
|
|
reqHost: "client.example.com",
|
|
reqRemoteAddr: "1.1.1.1:1111",
|
|
reqHeaders: http.Header{
|
|
"X-Forwarded-For": {"2.2.2.2, 3.3.3.3"},
|
|
},
|
|
expectRequestPath: "/v1/models",
|
|
expectRespStatus: http.StatusOK,
|
|
expectRespBody: upstreamRespBody,
|
|
expectHeaders: http.Header{
|
|
"Accept-Encoding": {"gzip"},
|
|
"User-Agent": {"aibridge"},
|
|
"X-Forwarded-For": {"2.2.2.2, 3.3.3.3, 1.1.1.1"},
|
|
"X-Forwarded-Host": {"client.example.com"},
|
|
"X-Forwarded-Proto": {"http"},
|
|
},
|
|
},
|
|
{
|
|
name: "query_string_is_preserved",
|
|
reqPath: "/v1/models?search=gpt&limit=10",
|
|
expectRequestPath: "/v1/models",
|
|
expectQuery: "search=gpt&limit=10",
|
|
expectRespStatus: http.StatusOK,
|
|
expectRespBody: upstreamRespBody,
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
logger := slogtest.Make(t, nil)
|
|
|
|
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
assert.Equal(t, tc.expectRequestPath, r.URL.Path)
|
|
assert.Equal(t, tc.expectQuery, r.URL.RawQuery)
|
|
if tc.expectHeaders != nil {
|
|
assert.Equal(t, tc.expectHeaders, r.Header)
|
|
}
|
|
w.WriteHeader(http.StatusOK)
|
|
_, _ = w.Write([]byte(upstreamRespBody))
|
|
}))
|
|
t.Cleanup(upstream.Close)
|
|
|
|
prov := &testutil.MockProvider{
|
|
URL: upstream.URL + tc.baseURLPath,
|
|
}
|
|
|
|
handler := newPassthroughRouter(prov, logger, nil, testTracer)
|
|
|
|
req := httptest.NewRequest("", tc.reqPath, nil)
|
|
maps.Copy(req.Header, tc.reqHeaders)
|
|
req.Host = tc.reqHost
|
|
req.RemoteAddr = tc.reqRemoteAddr
|
|
resp := httptest.NewRecorder()
|
|
handler.ServeHTTP(resp, req)
|
|
|
|
assert.Equal(t, tc.expectRespStatus, resp.Code)
|
|
assert.Contains(t, resp.Body.String(), tc.expectRespBody)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestRewritePassthroughRequest(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
tests := []struct {
|
|
name string
|
|
reqPath string
|
|
reqRemoteAddr string
|
|
reqHeaders http.Header
|
|
reqTLS bool
|
|
provider *testutil.MockProvider
|
|
expectURL string
|
|
expectHeaders http.Header
|
|
}{
|
|
{
|
|
name: "sets_upstream_url_and_forwarded_headers_from_client_peer",
|
|
reqPath: "http://client-host/chat?stream=true",
|
|
reqRemoteAddr: "1.1.1.1:1111",
|
|
provider: &testutil.MockProvider{URL: "https://upstream-host/base"},
|
|
expectURL: "https://upstream-host/base/chat?stream=true",
|
|
expectHeaders: http.Header{
|
|
"X-Forwarded-Host": {"client-host"},
|
|
"X-Forwarded-Proto": {"http"},
|
|
"X-Forwarded-For": {"1.1.1.1"},
|
|
"User-Agent": {"aibridge"},
|
|
},
|
|
},
|
|
{
|
|
name: "preserves_client_user_agent",
|
|
reqPath: "http://client-host/chat",
|
|
reqRemoteAddr: "1.1.1.1:1111",
|
|
reqHeaders: http.Header{"User-Agent": {"custom-agent/1.0"}},
|
|
provider: &testutil.MockProvider{URL: "https://upstream-host/base"},
|
|
expectURL: "https://upstream-host/base/chat",
|
|
expectHeaders: http.Header{
|
|
"X-Forwarded-Host": {"client-host"},
|
|
"X-Forwarded-Proto": {"http"},
|
|
"X-Forwarded-For": {"1.1.1.1"},
|
|
"User-Agent": {"custom-agent/1.0"},
|
|
},
|
|
},
|
|
{
|
|
name: "appends_remote_addr_to_existing_forwarded_for_chain",
|
|
reqPath: "http://client-host/chat",
|
|
reqRemoteAddr: "1.1.1.1:1111",
|
|
reqHeaders: http.Header{
|
|
"X-Forwarded-For": {"2.2.2.2, 3.3.3.3"},
|
|
},
|
|
provider: &testutil.MockProvider{URL: "https://upstream-host/base"},
|
|
expectURL: "https://upstream-host/base/chat",
|
|
expectHeaders: http.Header{
|
|
"X-Forwarded-Host": {"client-host"},
|
|
"X-Forwarded-Proto": {"http"},
|
|
"X-Forwarded-For": {"2.2.2.2, 3.3.3.3, 1.1.1.1"},
|
|
"User-Agent": {"aibridge"},
|
|
},
|
|
},
|
|
{
|
|
name: "tls_request_sets_forwarded_proto_to_https",
|
|
reqPath: "http://client-host/chat",
|
|
reqRemoteAddr: "1.1.1.1:1111",
|
|
reqTLS: true,
|
|
provider: &testutil.MockProvider{URL: "https://upstream-host/base"},
|
|
expectURL: "https://upstream-host/base/chat",
|
|
expectHeaders: http.Header{
|
|
"X-Forwarded-Host": {"client-host"},
|
|
"X-Forwarded-Proto": {"https"},
|
|
"X-Forwarded-For": {"1.1.1.1"},
|
|
"User-Agent": {"aibridge"},
|
|
},
|
|
},
|
|
{
|
|
// This is an edge case where whole `X-Forwarded-For` header
|
|
// is dropped if last hop (remote addr) is not parseable.
|
|
// This is how library handles this case and is not directly
|
|
// related to our code. Added it to verify that we
|
|
// don't accidentally break this behavior.
|
|
name: "omits_forwarded_for_when_remote_addr_is_not_parseable",
|
|
reqPath: "http://client-host/chat",
|
|
reqRemoteAddr: "not-a-socket-address",
|
|
reqHeaders: http.Header{
|
|
"X-Forwarded-For": {"1.1.1.1"},
|
|
},
|
|
provider: &testutil.MockProvider{URL: "https://upstream-host/base"},
|
|
expectURL: "https://upstream-host/base/chat",
|
|
expectHeaders: http.Header{
|
|
"X-Forwarded-Host": {"client-host"},
|
|
"X-Forwarded-Proto": {"http"},
|
|
"User-Agent": {"aibridge"},
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
r := httptest.NewRequest(http.MethodGet, tc.reqPath, nil)
|
|
maps.Copy(r.Header, tc.reqHeaders)
|
|
r.RemoteAddr = tc.reqRemoteAddr
|
|
if tc.reqTLS {
|
|
r.TLS = &tls.ConnectionState{}
|
|
}
|
|
provBaseURL, err := url.Parse(tc.provider.URL)
|
|
assert.NoError(t, err)
|
|
|
|
pr := &httputil.ProxyRequest{
|
|
In: r,
|
|
Out: r.Clone(r.Context()),
|
|
}
|
|
|
|
rewritePassthroughRequest(pr, provBaseURL)
|
|
|
|
assert.Equal(t, tc.expectURL, pr.Out.URL.String())
|
|
assert.Equal(t, "", pr.Out.Host)
|
|
assert.Equal(t, tc.expectHeaders, pr.Out.Header)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestPassthroughRouterReusesProxyInstance(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
var newConnections atomic.Int32
|
|
upstream := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
_, _ = w.Write([]byte("ok"))
|
|
}))
|
|
upstream.Config.ConnState = func(_ net.Conn, state http.ConnState) {
|
|
if state == http.StateNew {
|
|
newConnections.Add(1)
|
|
}
|
|
}
|
|
upstream.Start()
|
|
t.Cleanup(upstream.Close)
|
|
|
|
logger := slogtest.Make(t, nil)
|
|
prov := &testutil.MockProvider{URL: upstream.URL}
|
|
handler := newPassthroughRouter(prov, logger, nil, testTracer)
|
|
|
|
for i := range 2 {
|
|
req := httptest.NewRequest(http.MethodGet, "http://proxy.example.test/v1/models", nil)
|
|
resp := httptest.NewRecorder()
|
|
|
|
handler.ServeHTTP(resp, req)
|
|
|
|
assert.Equalf(t, http.StatusOK, resp.Code, "request %d", i+1)
|
|
assert.Equal(t, "ok", resp.Body.String())
|
|
}
|
|
|
|
assert.EqualValues(t, 1, newConnections.Load())
|
|
}
|
|
|
|
// TestPassthrough_KeyFailover exercises the KeyFailoverTransport
|
|
// end-to-end through the passthrough proxy, parameterised over
|
|
// providers (anthropic, openai). Each scenario asserts the upstream
|
|
// request count, the response status and Retry-After, and the final
|
|
// pool state.
|
|
func TestPassthrough_KeyFailover(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
type upstreamResponse struct {
|
|
statusCode int
|
|
body string
|
|
headers map[string]string
|
|
}
|
|
|
|
const (
|
|
rateLimitBody = `{"error":"rate"}`
|
|
authErrorBody = `{"error":"unauthorized"}`
|
|
serverErrorBody = `{"error":"server"}`
|
|
successBody = `{"data":[]}`
|
|
)
|
|
|
|
// providers parameterises the table over the two providers
|
|
// that support key failover. Each entry encapsulates the
|
|
// provider-specific bits the test needs: how the mock upstream
|
|
// extracts the key from the request, how a BYOK request sets
|
|
// it, and how the provider is constructed for a given pool.
|
|
providers := []struct {
|
|
name string
|
|
extractKey func(*http.Request) string
|
|
setBYOK func(*http.Request, string)
|
|
newProvider func(baseURL string, pool *keypool.Pool) provider.Provider
|
|
}{
|
|
{
|
|
name: "anthropic",
|
|
extractKey: func(r *http.Request) string {
|
|
return r.Header.Get("X-Api-Key")
|
|
},
|
|
setBYOK: func(r *http.Request, key string) {
|
|
r.Header.Set("X-Api-Key", key)
|
|
},
|
|
newProvider: func(baseURL string, pool *keypool.Pool) provider.Provider {
|
|
return provider.NewAnthropic(config.Anthropic{
|
|
BaseURL: baseURL,
|
|
KeyPool: pool,
|
|
}, nil)
|
|
},
|
|
},
|
|
{
|
|
name: "openai",
|
|
extractKey: func(r *http.Request) string {
|
|
return strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ")
|
|
},
|
|
setBYOK: func(r *http.Request, key string) {
|
|
r.Header.Set("Authorization", "Bearer "+key)
|
|
},
|
|
newProvider: func(baseURL string, pool *keypool.Pool) provider.Provider {
|
|
cfg := config.OpenAI{BaseURL: baseURL}
|
|
if pool != nil {
|
|
cfg.KeyPool = pool
|
|
}
|
|
return provider.NewOpenAI(cfg)
|
|
},
|
|
},
|
|
}
|
|
|
|
tests := []struct {
|
|
name string
|
|
// Centralized pool keys. Empty when byokKey is set.
|
|
keys []string
|
|
// BYOK key. Empty when keys is set.
|
|
byokKey string
|
|
// Scripted upstream responses keyed by API key value.
|
|
responses map[string]upstreamResponse
|
|
expectedRequestCount int32
|
|
expectedStatusCode int
|
|
expectedRetryAfter string
|
|
// Expected key states after the request, by index in keys.
|
|
expectedKeyStates []keypool.KeyState
|
|
}{
|
|
{
|
|
// Given: 1 valid key returning 200.
|
|
// Then: 1 request, 200 response, key remains valid.
|
|
name: "single_valid_key",
|
|
keys: []string{"k0"},
|
|
responses: map[string]upstreamResponse{
|
|
"k0": {statusCode: http.StatusOK, body: successBody},
|
|
},
|
|
expectedRequestCount: 1,
|
|
expectedStatusCode: http.StatusOK,
|
|
expectedKeyStates: []keypool.KeyState{keypool.KeyStateValid},
|
|
},
|
|
{
|
|
// Given: 2 keys; key-0 returns 429, key-1 returns 200.
|
|
// Then: 2 requests, 200 response, key-0 temporary, key-1 valid.
|
|
name: "failover_after_429",
|
|
keys: []string{"k0", "k1"},
|
|
responses: map[string]upstreamResponse{
|
|
"k0": {
|
|
statusCode: http.StatusTooManyRequests,
|
|
headers: map[string]string{"Retry-After": "5"},
|
|
body: rateLimitBody,
|
|
},
|
|
"k1": {statusCode: http.StatusOK, body: successBody},
|
|
},
|
|
expectedRequestCount: 2,
|
|
expectedStatusCode: http.StatusOK,
|
|
expectedKeyStates: []keypool.KeyState{
|
|
keypool.KeyStateTemporary,
|
|
keypool.KeyStateValid,
|
|
},
|
|
},
|
|
{
|
|
// Given: 2 keys; key-0 returns 401, key-1 returns 200.
|
|
// Then: 2 requests, 200 response, key-0 permanent, key-1 valid.
|
|
name: "failover_after_401",
|
|
keys: []string{"k0", "k1"},
|
|
responses: map[string]upstreamResponse{
|
|
"k0": {statusCode: http.StatusUnauthorized, body: authErrorBody},
|
|
"k1": {statusCode: http.StatusOK, body: successBody},
|
|
},
|
|
expectedRequestCount: 2,
|
|
expectedStatusCode: http.StatusOK,
|
|
expectedKeyStates: []keypool.KeyState{
|
|
keypool.KeyStatePermanent,
|
|
keypool.KeyStateValid,
|
|
},
|
|
},
|
|
{
|
|
// Given: 2 keys; key-0 returns 403, key-1 returns 200.
|
|
// Then: 2 requests, 200 response, key-0 permanent, key-1 valid.
|
|
name: "failover_after_403",
|
|
keys: []string{"k0", "k1"},
|
|
responses: map[string]upstreamResponse{
|
|
"k0": {statusCode: http.StatusForbidden, body: authErrorBody},
|
|
"k1": {statusCode: http.StatusOK, body: successBody},
|
|
},
|
|
expectedRequestCount: 2,
|
|
expectedStatusCode: http.StatusOK,
|
|
expectedKeyStates: []keypool.KeyState{
|
|
keypool.KeyStatePermanent,
|
|
keypool.KeyStateValid,
|
|
},
|
|
},
|
|
{
|
|
// Given: 3 keys; all return 429 with cooldowns 5s, 3s, 10s.
|
|
// Then: 3 requests, 429 response with smallest Retry-After,
|
|
// all keys temporary.
|
|
name: "all_keys_rate_limited",
|
|
keys: []string{"k0", "k1", "k2"},
|
|
responses: map[string]upstreamResponse{
|
|
"k0": {
|
|
statusCode: http.StatusTooManyRequests,
|
|
headers: map[string]string{"Retry-After": "5"},
|
|
body: rateLimitBody,
|
|
},
|
|
"k1": {
|
|
statusCode: http.StatusTooManyRequests,
|
|
headers: map[string]string{"Retry-After": "3"},
|
|
body: rateLimitBody,
|
|
},
|
|
"k2": {
|
|
statusCode: http.StatusTooManyRequests,
|
|
headers: map[string]string{"Retry-After": "10"},
|
|
body: rateLimitBody,
|
|
},
|
|
},
|
|
expectedRequestCount: 3,
|
|
expectedStatusCode: http.StatusTooManyRequests,
|
|
expectedRetryAfter: "3",
|
|
expectedKeyStates: []keypool.KeyState{
|
|
keypool.KeyStateTemporary,
|
|
keypool.KeyStateTemporary,
|
|
keypool.KeyStateTemporary,
|
|
},
|
|
},
|
|
{
|
|
// Given: 2 keys; both return 401.
|
|
// Then: 2 requests, 502 response, both keys permanent.
|
|
name: "all_keys_unauthorized",
|
|
keys: []string{"k0", "k1"},
|
|
responses: map[string]upstreamResponse{
|
|
"k0": {statusCode: http.StatusUnauthorized, body: authErrorBody},
|
|
"k1": {statusCode: http.StatusUnauthorized, body: authErrorBody},
|
|
},
|
|
expectedRequestCount: 2,
|
|
expectedStatusCode: http.StatusBadGateway,
|
|
expectedKeyStates: []keypool.KeyState{
|
|
keypool.KeyStatePermanent,
|
|
keypool.KeyStatePermanent,
|
|
},
|
|
},
|
|
{
|
|
// Given: 2 keys; key-0 returns 500.
|
|
// Then: 1 request, 500 response, both keys remain valid.
|
|
name: "server_error_no_failover",
|
|
keys: []string{"k0", "k1"},
|
|
responses: map[string]upstreamResponse{
|
|
"k0": {statusCode: http.StatusInternalServerError, body: serverErrorBody},
|
|
},
|
|
expectedRequestCount: 1,
|
|
expectedStatusCode: http.StatusInternalServerError,
|
|
expectedKeyStates: []keypool.KeyState{
|
|
keypool.KeyStateValid,
|
|
keypool.KeyStateValid,
|
|
},
|
|
},
|
|
{
|
|
// Given: BYOK with a single user-supplied key returning 429.
|
|
// Then: 1 request, 429 forwarded as-is, no failover.
|
|
name: "byok_no_failover",
|
|
byokKey: "user-byok",
|
|
responses: map[string]upstreamResponse{
|
|
"user-byok": {
|
|
statusCode: http.StatusTooManyRequests,
|
|
headers: map[string]string{"Retry-After": "5"},
|
|
body: rateLimitBody,
|
|
},
|
|
},
|
|
expectedRequestCount: 1,
|
|
expectedStatusCode: http.StatusTooManyRequests,
|
|
expectedRetryAfter: "5",
|
|
},
|
|
}
|
|
|
|
for _, prov := range providers {
|
|
for _, tc := range tests {
|
|
t.Run(prov.name+"/"+tc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
// Mock upstream: counts requests and returns
|
|
// scripted responses keyed by API key. An unmapped
|
|
// key falls through to 500 so misconfigured cases
|
|
// surface via the status assertion.
|
|
var requestCount atomic.Int32
|
|
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
requestCount.Add(1)
|
|
_, _ = io.Copy(io.Discard, r.Body)
|
|
resp, ok := tc.responses[prov.extractKey(r)]
|
|
if !ok {
|
|
resp = upstreamResponse{statusCode: http.StatusInternalServerError}
|
|
}
|
|
w.Header().Set("Content-Type", "application/json")
|
|
for hk, hv := range resp.headers {
|
|
w.Header().Set(hk, hv)
|
|
}
|
|
w.WriteHeader(resp.statusCode)
|
|
_, _ = w.Write([]byte(resp.body))
|
|
}))
|
|
t.Cleanup(upstream.Close)
|
|
|
|
var pool *keypool.Pool
|
|
if len(tc.keys) > 0 {
|
|
var err error
|
|
pool, err = keypool.New(tc.keys, quartz.NewMock(t))
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
p := prov.newProvider(upstream.URL, pool)
|
|
// IgnoreErrors: MarkKey logs at ERROR level when a
|
|
// key is marked permanent (401/403); slogtest would
|
|
// otherwise fail those scenarios.
|
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
|
handler := newPassthroughRouter(p, logger, nil, testTracer)
|
|
|
|
req := httptest.NewRequest(http.MethodGet, "/v1/models", nil)
|
|
if tc.byokKey != "" {
|
|
prov.setBYOK(req, tc.byokKey)
|
|
}
|
|
w := httptest.NewRecorder()
|
|
handler.ServeHTTP(w, req)
|
|
|
|
assert.Equal(t, tc.expectedRequestCount, requestCount.Load(), "upstream request count")
|
|
assert.Equal(t, tc.expectedStatusCode, w.Code, "response status code")
|
|
assert.Equal(t, tc.expectedRetryAfter, w.Header().Get("Retry-After"), "Retry-After header")
|
|
if pool != nil {
|
|
assert.Equal(t, tc.expectedKeyStates, pool.PoolState(), "key states")
|
|
}
|
|
})
|
|
}
|
|
}
|
|
}
|