Files
coder/aibridge/internal/integrationtest/keypool_failover_internal_test.go
T
Ethan c650aabbef chore: standardize on *_internal_test.go for white-box tests (#25601)
My agent added `//nolint:testpackage` to a test file on one of my PRs.
Again. This PR cleans it up across the entire repo and updates the
in-repo conventions so future agents stop doing it.

The repo already has a precedent for white-box tests that need to touch
unexported symbols: `*_internal_test.go` (145+ existing files). The
`testpackage` linter's default `skip-regexp` exempts that filename
suffix, so the `//nolint:testpackage` directive is unnecessary in every
case where someone reached for it. This PR renames 51 such files to
`*_internal_test.go` via `git mv` so blame and history follow, and
strips the dead directive from 2 files that were already correctly named
(`coderd/oauth2provider/authorize_internal_test.go`,
`coderd/x/chatd/advisor_internal_test.go`).

`.claude/docs/TESTING.md` now documents the rule explicitly under *Test
Package Naming*, which is imported into the root `AGENTS.md` via
`@.claude/docs/TESTING.md`. The rule: prefer `package foo_test`; if you
need internal access, rename the file to `*_internal_test.go` rather
than adding a nolint directive.
2026-05-22 20:24:38 +10:00

261 lines
7.7 KiB
Go

package integrationtest
import (
"fmt"
"io"
"net/http"
"net/http/httptest"
"sync"
"sync/atomic"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tidwall/sjson"
"github.com/coder/coder/v2/aibridge/config"
"github.com/coder/coder/v2/aibridge/fixtures"
"github.com/coder/coder/v2/aibridge/keypool"
"github.com/coder/coder/v2/aibridge/provider"
"github.com/coder/coder/v2/aibridge/utils"
"github.com/coder/quartz"
)
// TestOpenAI_KeyFailover checks that key-pool state survives from one
// client request to the next for both OpenAI APIs (chat completions and
// responses), in blocking as well as streaming mode. Once request 1 has
// caused a key to be marked temporary, request 2 must skip that key
// rather than spend an upstream attempt on it.
func TestOpenAI_KeyFailover(t *testing.T) {
	t.Parallel()

	const (
		contentTypeJSON = "application/json"
		contentTypeSSE  = "text/event-stream"
	)

	type scenario struct {
		name         string
		fixture      []byte
		path         string
		streaming    bool
		successCType string
	}

	scenarios := []scenario{
		{
			name:         "chatcompletions_blocking",
			fixture:      fixtures.OaiChatSimple,
			path:         pathOpenAIChatCompletions,
			streaming:    false,
			successCType: contentTypeJSON,
		},
		{
			name:         "chatcompletions_streaming",
			fixture:      fixtures.OaiChatSimple,
			path:         pathOpenAIChatCompletions,
			streaming:    true,
			successCType: contentTypeSSE,
		},
		{
			name:         "responses_blocking",
			fixture:      fixtures.OaiResponsesBlockingSimple,
			path:         pathOpenAIResponses,
			streaming:    false,
			successCType: contentTypeJSON,
		},
		{
			name:         "responses_streaming",
			fixture:      fixtures.OaiResponsesStreamingSimple,
			path:         pathOpenAIResponses,
			streaming:    true,
			successCType: contentTypeSSE,
		},
	}

	for _, sc := range scenarios {
		t.Run(sc.name, func(t *testing.T) {
			t.Parallel()

			const (
				throttledKey = "k0"
				healthyKey   = "k1"
			)

			fix := fixtures.Parse(t, sc.fixture)

			// Select the body the mock upstream replies with on success;
			// only the variant matching the mode under test is read.
			successBody := func() []byte {
				if sc.streaming {
					return fix.Streaming()
				}
				return fix.NonStreaming()
			}()

			pool, err := keypool.New([]string{throttledKey, healthyKey}, quartz.NewMock(t))
			require.NoError(t, err)

			var (
				upstreamCalls atomic.Int32
				keysMu        sync.Mutex
				keysInOrder   []string
			)

			// Mock upstream. Each call is counted and the key taken from
			// its Authorization header is recorded. k0 is always answered
			// with a 429, k1 with the scenario's success body, and any
			// other key with a 500.
			handler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
				upstreamCalls.Add(1)
				token := utils.ExtractBearerToken(req.Header.Get("Authorization"))

				keysMu.Lock()
				keysInOrder = append(keysInOrder, token)
				keysMu.Unlock()

				_, _ = io.Copy(io.Discard, req.Body)

				hdr := rw.Header()
				switch token {
				case throttledKey:
					hdr.Set("Content-Type", contentTypeJSON)
					hdr.Set("Retry-After", "60")
					rw.WriteHeader(http.StatusTooManyRequests)
					_, _ = fmt.Fprint(rw, `{"error":{"type":"rate_limit_error","message":"rate limited","code":"rate_limit_exceeded"}}`)
				case healthyKey:
					hdr.Set("Content-Type", sc.successCType)
					rw.WriteHeader(http.StatusOK)
					_, _ = rw.Write(successBody)
				default:
					rw.WriteHeader(http.StatusInternalServerError)
				}
			})
			upstream := httptest.NewServer(handler)
			t.Cleanup(upstream.Close)

			openAICfg := config.OpenAI{
				BaseURL: upstream.URL,
				KeyPool: pool,
			}
			bridgeServer := newBridgeTestServer(
				t.Context(), t, upstream.URL,
				withCustomProvider(provider.NewOpenAI(openAICfg)),
			)

			requestBody, err := sjson.SetBytes(fix.Request(), "stream", sc.streaming)
			require.NoError(t, err)

			// sendExpectOK issues one client request through the bridge,
			// drains and closes the response, and requires a 200.
			sendExpectOK := func() {
				resp, err := bridgeServer.makeRequest(t, http.MethodPost, sc.path, requestBody)
				require.NoError(t, err)
				_, _ = io.Copy(io.Discard, resp.Body)
				require.NoError(t, resp.Body.Close())
				require.Equal(t, http.StatusOK, resp.StatusCode)
			}

			// Request 1: the walker begins at k0 and, after the 429,
			// fails over to k1.
			sendExpectOK()

			// Request 2: k0 is temporary by now, so the walker goes
			// directly to k1 (one upstream call instead of two).
			sendExpectOK()

			// Snapshot the recorded keys under the lock before asserting.
			keysMu.Lock()
			gotKeys := append([]string(nil), keysInOrder...)
			keysMu.Unlock()

			// Two calls for request 1 (k0, k1) plus one for request 2 (k1).
			assert.Equal(t, int32(3), upstreamCalls.Load(), "upstream request count")
			assert.Equal(t, []string{throttledKey, healthyKey, healthyKey}, gotKeys, "seen keys")

			// The pool still remembers: k0 temporary, k1 valid.
			wantStates := []keypool.KeyState{
				keypool.KeyStateTemporary,
				keypool.KeyStateValid,
			}
			assert.Equal(t, wantStates, pool.PoolState(), "key states")
		})
	}
}
// TestAnthropic_KeyFailover checks that key-pool state survives from one
// client request to the next: a key that request 1 caused to be marked
// temporary must still be skipped by request 2, without spending an
// upstream attempt on it.
func TestAnthropic_KeyFailover(t *testing.T) {
	t.Parallel()

	const (
		contentTypeJSON = "application/json"
		contentTypeSSE  = "text/event-stream"
		throttledKey    = "k0"
		healthyKey      = "k1"
	)

	fix := fixtures.Parse(t, fixtures.AntSimple)

	type variant struct {
		name         string
		streaming    bool
		successBody  []byte
		successCType string
	}

	variants := []variant{
		{
			name:         "blocking",
			streaming:    false,
			successBody:  fix.NonStreaming(),
			successCType: contentTypeJSON,
		},
		{
			name:         "streaming",
			streaming:    true,
			successBody:  fix.Streaming(),
			successCType: contentTypeSSE,
		},
	}

	for _, v := range variants {
		t.Run(v.name, func(t *testing.T) {
			t.Parallel()

			pool, err := keypool.New([]string{throttledKey, healthyKey}, quartz.NewMock(t))
			require.NoError(t, err)

			var (
				upstreamCalls atomic.Int32
				keysMu        sync.Mutex
				keysInOrder   []string
			)

			// Mock upstream. Each call is counted and the value of its
			// X-Api-Key header is recorded. k0 is always answered with a
			// 429, k1 with the variant's success body, and any other key
			// with a 500.
			handler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
				upstreamCalls.Add(1)
				apiKey := req.Header.Get("X-Api-Key")

				keysMu.Lock()
				keysInOrder = append(keysInOrder, apiKey)
				keysMu.Unlock()

				_, _ = io.Copy(io.Discard, req.Body)

				hdr := rw.Header()
				switch apiKey {
				case throttledKey:
					hdr.Set("Content-Type", contentTypeJSON)
					hdr.Set("Retry-After", "60")
					rw.WriteHeader(http.StatusTooManyRequests)
					_, _ = fmt.Fprint(rw, `{"type":"error","error":{"type":"rate_limit_error","message":"rate limited"}}`)
				case healthyKey:
					hdr.Set("Content-Type", v.successCType)
					rw.WriteHeader(http.StatusOK)
					_, _ = rw.Write(v.successBody)
				default:
					rw.WriteHeader(http.StatusInternalServerError)
				}
			})
			upstream := httptest.NewServer(handler)
			t.Cleanup(upstream.Close)

			anthropicCfg := config.Anthropic{
				BaseURL: upstream.URL,
				KeyPool: pool,
			}
			bridgeServer := newBridgeTestServer(
				t.Context(), t, upstream.URL,
				withCustomProvider(provider.NewAnthropic(anthropicCfg, nil)),
			)

			requestBody, err := sjson.SetBytes(fix.Request(), "stream", v.streaming)
			require.NoError(t, err)

			// sendExpectOK issues one client request through the bridge,
			// drains and closes the response, and requires a 200.
			sendExpectOK := func() {
				resp, err := bridgeServer.makeRequest(t, http.MethodPost, pathAnthropicMessages, requestBody)
				require.NoError(t, err)
				_, _ = io.Copy(io.Discard, resp.Body)
				require.NoError(t, resp.Body.Close())
				require.Equal(t, http.StatusOK, resp.StatusCode)
			}

			// Request 1: the walker begins at k0 and, after the 429,
			// fails over to k1.
			sendExpectOK()

			// Request 2: k0 is temporary by now, so the walker goes
			// directly to k1 (one upstream call instead of two).
			sendExpectOK()

			// Snapshot the recorded keys under the lock before asserting.
			keysMu.Lock()
			gotKeys := append([]string(nil), keysInOrder...)
			keysMu.Unlock()

			// Two calls for request 1 (k0, k1) plus one for request 2 (k1).
			assert.Equal(t, int32(3), upstreamCalls.Load(), "upstream request count")
			assert.Equal(t, []string{throttledKey, healthyKey, healthyKey}, gotKeys, "seen keys")

			// The pool still remembers: k0 temporary, k1 valid.
			wantStates := []keypool.KeyState{
				keypool.KeyStateTemporary,
				keypool.KeyStateValid,
			}
			assert.Equal(t, wantStates, pool.PoolState(), "key states")
		})
	}
}