mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
c650aabbef
My agent added `//nolint:testpackage` to a test file on one of my PRs. Again. This PR cleans it up across the entire repo and updates the in-repo conventions so future agents stop doing it. The repo already has a precedent for white-box tests that need to touch unexported symbols: `*_internal_test.go` (145+ existing files). The `testpackage` linter's default `skip-regexp` exempts that filename suffix, so the `//nolint:testpackage` directive is unnecessary in every case where someone reached for it. This PR renames 51 such files to `*_internal_test.go` via `git mv` so blame and history follow, and strips the dead directive from 2 files that were already correctly named (`coderd/oauth2provider/authorize_internal_test.go`, `coderd/x/chatd/advisor_internal_test.go`). `.claude/docs/TESTING.md` now documents the rule explicitly under *Test Package Naming*, which is imported into the root `AGENTS.md` via `@.claude/docs/TESTING.md`. The rule: prefer `package foo_test`; if you need internal access, rename the file to `*_internal_test.go` rather than adding a nolint directive.
582 lines
15 KiB
Go
582 lines
15 KiB
Go
package responses
|
|
|
|
import (
|
|
"context"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/openai/openai-go/v3"
|
|
oairesponses "github.com/openai/openai-go/v3/responses"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
"golang.org/x/xerrors"
|
|
|
|
"cdr.dev/slog/v3"
|
|
"github.com/coder/coder/v2/aibridge/config"
|
|
"github.com/coder/coder/v2/aibridge/internal/testutil"
|
|
"github.com/coder/coder/v2/aibridge/keypool"
|
|
"github.com/coder/coder/v2/aibridge/recorder"
|
|
"github.com/coder/quartz"
|
|
)
|
|
|
|
func TestRecordPrompt(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
tests := []struct {
|
|
name string
|
|
promptWasRecorded bool
|
|
prompt string
|
|
responseID string
|
|
wantRecorded bool
|
|
wantPrompt string
|
|
}{
|
|
{
|
|
name: "records_prompt_successfully",
|
|
prompt: "tell me a joke",
|
|
responseID: "resp_123",
|
|
wantRecorded: true,
|
|
wantPrompt: "tell me a joke",
|
|
},
|
|
{
|
|
name: "records_empty_prompt_successfully",
|
|
prompt: "",
|
|
responseID: "resp_123",
|
|
wantRecorded: true,
|
|
wantPrompt: "",
|
|
},
|
|
{
|
|
name: "skips_recording_on_empty_response_id",
|
|
prompt: "tell me a joke",
|
|
responseID: "",
|
|
wantRecorded: false,
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
rec := &testutil.MockRecorder{}
|
|
id := uuid.New()
|
|
base := &responsesInterceptionBase{
|
|
id: id,
|
|
recorder: rec,
|
|
logger: slog.Make(),
|
|
}
|
|
|
|
base.recordUserPrompt(t.Context(), tc.responseID, tc.prompt)
|
|
|
|
prompts := rec.RecordedPromptUsages()
|
|
if tc.wantRecorded {
|
|
require.Len(t, prompts, 1)
|
|
require.Equal(t, id.String(), prompts[0].InterceptionID)
|
|
require.Equal(t, tc.responseID, prompts[0].MsgID)
|
|
require.Equal(t, tc.wantPrompt, prompts[0].Prompt)
|
|
} else {
|
|
require.Empty(t, prompts)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestRecordToolUsage(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
id := uuid.MustParse("11111111-1111-1111-1111-111111111111")
|
|
|
|
tests := []struct {
|
|
name string
|
|
response *oairesponses.Response
|
|
expected []*recorder.ToolUsageRecord
|
|
}{
|
|
{
|
|
name: "nil_response",
|
|
response: nil,
|
|
expected: nil,
|
|
},
|
|
{
|
|
name: "empty_output",
|
|
response: &oairesponses.Response{
|
|
ID: "resp_123",
|
|
},
|
|
expected: nil,
|
|
},
|
|
{
|
|
name: "empty_tool_args",
|
|
response: &oairesponses.Response{
|
|
ID: "resp_456",
|
|
Output: []oairesponses.ResponseOutputItemUnion{
|
|
{
|
|
Type: "function_call",
|
|
CallID: "call_abc",
|
|
Name: "get_weather",
|
|
Arguments: "",
|
|
},
|
|
},
|
|
},
|
|
expected: []*recorder.ToolUsageRecord{
|
|
{
|
|
InterceptionID: id.String(),
|
|
MsgID: "resp_456",
|
|
ToolCallID: "call_abc",
|
|
Tool: "get_weather",
|
|
Args: "",
|
|
Injected: false,
|
|
},
|
|
},
|
|
},
|
|
{
|
|
name: "multiple_tool_calls",
|
|
response: &oairesponses.Response{
|
|
ID: "resp_789",
|
|
Output: []oairesponses.ResponseOutputItemUnion{
|
|
{
|
|
Type: "function_call",
|
|
CallID: "call_1",
|
|
Name: "get_weather",
|
|
Arguments: `{"location": "NYC"}`,
|
|
},
|
|
{
|
|
Type: "function_call",
|
|
CallID: "call_2",
|
|
Name: "bad_json_args",
|
|
Arguments: `{"bad": args`,
|
|
},
|
|
{
|
|
Type: "message",
|
|
ID: "msg_1",
|
|
Role: "assistant",
|
|
},
|
|
{
|
|
Type: "custom_tool_call",
|
|
CallID: "call_3",
|
|
Name: "search",
|
|
Input: `{\"query\": \"test\"}`,
|
|
},
|
|
{
|
|
Type: "function_call",
|
|
CallID: "call_4",
|
|
Name: "calculate",
|
|
Arguments: `{"a": 1, "b": 2}`,
|
|
},
|
|
},
|
|
},
|
|
expected: []*recorder.ToolUsageRecord{
|
|
{
|
|
InterceptionID: id.String(),
|
|
MsgID: "resp_789",
|
|
ToolCallID: "call_1",
|
|
Tool: "get_weather",
|
|
Args: map[string]any{"location": "NYC"},
|
|
Injected: false,
|
|
},
|
|
{
|
|
InterceptionID: id.String(),
|
|
MsgID: "resp_789",
|
|
ToolCallID: "call_2",
|
|
Tool: "bad_json_args",
|
|
Args: `{"bad": args`,
|
|
Injected: false,
|
|
},
|
|
{
|
|
InterceptionID: id.String(),
|
|
MsgID: "resp_789",
|
|
ToolCallID: "call_3",
|
|
Tool: "search",
|
|
Args: `{\"query\": \"test\"}`,
|
|
Injected: false,
|
|
},
|
|
{
|
|
InterceptionID: id.String(),
|
|
MsgID: "resp_789",
|
|
ToolCallID: "call_4",
|
|
Tool: "calculate",
|
|
Args: map[string]any{"a": float64(1), "b": float64(2)},
|
|
Injected: false,
|
|
},
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
rec := &testutil.MockRecorder{}
|
|
base := &responsesInterceptionBase{
|
|
id: id,
|
|
recorder: rec,
|
|
logger: slog.Make(),
|
|
}
|
|
|
|
base.recordNonInjectedToolUsage(t.Context(), tc.response)
|
|
|
|
tools := rec.RecordedToolUsages()
|
|
require.Len(t, tools, len(tc.expected))
|
|
for i, got := range tools {
|
|
got.CreatedAt = time.Time{}
|
|
require.Equal(t, tc.expected[i], got)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestParseJSONArgs(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
tests := []struct {
|
|
name string
|
|
raw string
|
|
expected recorder.ToolArgs
|
|
}{
|
|
{
|
|
name: "empty_string",
|
|
raw: "",
|
|
expected: "",
|
|
},
|
|
{
|
|
name: "whitespace_only",
|
|
raw: " \t\n ",
|
|
expected: "",
|
|
},
|
|
{
|
|
name: "invalid_json",
|
|
raw: "{not valid json}",
|
|
expected: "{not valid json}",
|
|
},
|
|
{
|
|
name: "nested_object_with_trailing_spaces",
|
|
raw: ` {"user": {"name": "alice", "settings": {"theme": "dark", "notifications": true}}, "count": 42} `,
|
|
expected: map[string]any{
|
|
"user": map[string]any{
|
|
"name": "alice",
|
|
"settings": map[string]any{
|
|
"theme": "dark",
|
|
"notifications": true,
|
|
},
|
|
},
|
|
"count": float64(42),
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
base := &responsesInterceptionBase{}
|
|
result := base.parseFunctionCallJSONArgs(t.Context(), tc.raw)
|
|
require.Equal(t, tc.expected, result)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestRecordTokenUsage(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
id := uuid.MustParse("22222222-2222-2222-2222-222222222222")
|
|
|
|
tests := []struct {
|
|
name string
|
|
response *oairesponses.Response
|
|
expected *recorder.TokenUsageRecord
|
|
}{
|
|
{
|
|
name: "nil_response",
|
|
response: nil,
|
|
expected: nil,
|
|
},
|
|
{
|
|
name: "with_all_token_details",
|
|
response: &oairesponses.Response{
|
|
ID: "resp_full",
|
|
Usage: oairesponses.ResponseUsage{
|
|
InputTokens: 10,
|
|
OutputTokens: 20,
|
|
TotalTokens: 30,
|
|
InputTokensDetails: oairesponses.ResponseUsageInputTokensDetails{
|
|
CachedTokens: 5,
|
|
},
|
|
OutputTokensDetails: oairesponses.ResponseUsageOutputTokensDetails{
|
|
ReasoningTokens: 5,
|
|
},
|
|
},
|
|
},
|
|
expected: &recorder.TokenUsageRecord{
|
|
InterceptionID: id.String(),
|
|
MsgID: "resp_full",
|
|
Input: 5, // 10 input - 5 cached
|
|
Output: 20,
|
|
CacheReadInputTokens: 5,
|
|
ExtraTokenTypes: map[string]int64{
|
|
"output_reasoning": 5,
|
|
"total_tokens": 30,
|
|
},
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
rec := &testutil.MockRecorder{}
|
|
base := &responsesInterceptionBase{
|
|
id: id,
|
|
recorder: rec,
|
|
logger: slog.Make(),
|
|
}
|
|
|
|
base.recordTokenUsage(t.Context(), tc.response)
|
|
|
|
tokens := rec.RecordedTokenUsages()
|
|
if tc.expected == nil {
|
|
require.Empty(t, tokens)
|
|
} else {
|
|
require.Len(t, tokens, 1)
|
|
got := tokens[0]
|
|
got.CreatedAt = time.Time{} // ignore time
|
|
require.Equal(t, tc.expected, got)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// mockResponseWriter is a minimal http.ResponseWriter that only remembers
// which of its methods have been invoked. It keeps no headers, body, or status
// code.
type mockResponseWriter struct {
	headerCalled      bool
	writeCalled       bool
	writeHeaderCalled bool
}

// Compile-time check that the mock satisfies http.ResponseWriter.
var _ http.ResponseWriter = (*mockResponseWriter)(nil)

// Header marks the call and returns a fresh, empty header map; values set on
// the returned map are not retained by the mock.
func (w *mockResponseWriter) Header() http.Header {
	w.headerCalled = true
	return make(http.Header)
}

// Write marks the call and discards the payload, reporting zero bytes written
// and a nil error.
func (w *mockResponseWriter) Write(_ []byte) (int, error) {
	w.writeCalled = true
	return 0, nil
}

// WriteHeader marks the call and ignores the status code.
func (w *mockResponseWriter) WriteHeader(_ int) {
	w.writeHeaderCalled = true
}
|
|
|
|
func TestResponseCopierDoesntSendIfNoResponseReceived(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
mrw := mockResponseWriter{}
|
|
|
|
respCopy := responseCopier{}
|
|
body := "test_body"
|
|
_, _ = respCopy.buff.Write([]byte(body)) // bytes.Buffer.Write never fails
|
|
|
|
err := respCopy.forwardResp(&mrw)
|
|
require.NoError(t, err)
|
|
require.False(t, mrw.headerCalled)
|
|
require.False(t, mrw.writeCalled)
|
|
require.False(t, mrw.writeHeaderCalled)
|
|
|
|
// after response is received data is forwarded
|
|
respCopy.responseReceived.Store(true)
|
|
|
|
err = respCopy.forwardResp(&mrw)
|
|
require.NoError(t, err)
|
|
require.True(t, mrw.headerCalled)
|
|
require.True(t, mrw.writeCalled)
|
|
require.True(t, mrw.writeHeaderCalled)
|
|
}
|
|
|
|
func TestProcessKeyPoolError(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
tests := []struct {
|
|
name string
|
|
err error
|
|
expectedNil bool
|
|
expectedStatus int
|
|
expectedRetryAfter time.Duration
|
|
}{
|
|
{
|
|
// Transient with valid keys present: 429, no Retry-After.
|
|
name: "transient_zero_retry_after",
|
|
err: &keypool.TransientKeyPoolError{},
|
|
expectedStatus: http.StatusTooManyRequests,
|
|
expectedRetryAfter: 0,
|
|
},
|
|
{
|
|
// Transient with cooldown: 429, Retry-After set.
|
|
name: "transient_with_retry_after",
|
|
err: &keypool.TransientKeyPoolError{RetryAfter: 5 * time.Second},
|
|
expectedStatus: http.StatusTooManyRequests,
|
|
expectedRetryAfter: 5 * time.Second,
|
|
},
|
|
{
|
|
// Permanent: 502 api_error.
|
|
name: "permanent_returns_502",
|
|
err: keypool.ErrPermanentKeyPool,
|
|
expectedStatus: http.StatusBadGateway,
|
|
},
|
|
{
|
|
// Anything else: not a pool-exhaustion error.
|
|
name: "non_pool_exhaustion_error_returns_nil",
|
|
err: xerrors.New("some other error"),
|
|
expectedNil: true,
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
got := ProcessKeyPoolError(tc.err)
|
|
if tc.expectedNil {
|
|
require.Nil(t, got)
|
|
return
|
|
}
|
|
require.NotNil(t, got)
|
|
assert.Equal(t, tc.expectedStatus, got.StatusCode)
|
|
assert.Equal(t, tc.expectedRetryAfter, got.RetryAfter)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestMarkKeyOnError(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
tests := []struct {
|
|
name string
|
|
err error
|
|
expectedReturn bool
|
|
expectedState keypool.KeyState
|
|
}{
|
|
{
|
|
// Not an *openai.Error: no status code to act on.
|
|
name: "non_api_error_returns_false",
|
|
err: xerrors.New("network failure"),
|
|
expectedReturn: false,
|
|
expectedState: keypool.KeyStateValid,
|
|
},
|
|
{
|
|
// Rate-limited: temporary cooldown.
|
|
name: "429_marks_temporary",
|
|
err: &openai.Error{StatusCode: http.StatusTooManyRequests, Response: &http.Response{StatusCode: http.StatusTooManyRequests}},
|
|
expectedReturn: true,
|
|
expectedState: keypool.KeyStateTemporary,
|
|
},
|
|
{
|
|
// Auth failure: mark permanent.
|
|
name: "401_marks_permanent",
|
|
err: &openai.Error{StatusCode: http.StatusUnauthorized, Response: &http.Response{StatusCode: http.StatusUnauthorized}},
|
|
expectedReturn: true,
|
|
expectedState: keypool.KeyStatePermanent,
|
|
},
|
|
{
|
|
// Auth forbidden: mark permanent.
|
|
name: "403_marks_permanent",
|
|
err: &openai.Error{StatusCode: http.StatusForbidden, Response: &http.Response{StatusCode: http.StatusForbidden}},
|
|
expectedReturn: true,
|
|
expectedState: keypool.KeyStatePermanent,
|
|
},
|
|
{
|
|
// Server errors are not key-specific.
|
|
name: "500_does_not_mark",
|
|
err: &openai.Error{StatusCode: http.StatusInternalServerError, Response: &http.Response{StatusCode: http.StatusInternalServerError}},
|
|
expectedReturn: false,
|
|
expectedState: keypool.KeyStateValid,
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
pool, err := keypool.New([]string{"key-0"}, quartz.NewMock(t))
|
|
require.NoError(t, err)
|
|
key, err := pool.Walker().Next()
|
|
require.NoError(t, err)
|
|
|
|
base := &responsesInterceptionBase{cfg: config.OpenAI{KeyPool: pool}, logger: slog.Make()}
|
|
|
|
got := base.markKeyOnError(context.Background(), key, tc.err)
|
|
assert.Equal(t, tc.expectedReturn, got)
|
|
assert.Equal(t, tc.expectedState, key.State())
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestWriteUpstreamError(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
tests := []struct {
|
|
name string
|
|
respErr *ResponseError
|
|
expectStatus int
|
|
// Empty string means the header should be absent.
|
|
expectRetryAfter string
|
|
// Substring expected in the marshaled body. Empty means no body check.
|
|
expectBodyContains string
|
|
}{
|
|
{
|
|
// Standard error: status, code, and JSON body written.
|
|
name: "writes_status_and_body",
|
|
respErr: newErrorResponse("upstream failed", "api_error", "server_error", http.StatusBadGateway, 0),
|
|
expectStatus: http.StatusBadGateway,
|
|
expectBodyContains: `"upstream failed"`,
|
|
},
|
|
{
|
|
// OpenAI envelope: the code field round-trips into the body.
|
|
name: "writes_code_field",
|
|
respErr: newErrorResponse("rate limited", "rate_limit_error", "rate_limit_exceeded", http.StatusTooManyRequests, 0),
|
|
expectStatus: http.StatusTooManyRequests,
|
|
expectBodyContains: `"rate_limit_exceeded"`,
|
|
},
|
|
{
|
|
// Whole-second retryAfter: emitted as integer seconds.
|
|
name: "retry_after_in_seconds",
|
|
respErr: newErrorResponse("rate limited", "rate_limit_error", "rate_limit_exceeded", http.StatusTooManyRequests, 60*time.Second),
|
|
expectStatus: http.StatusTooManyRequests,
|
|
expectRetryAfter: "60",
|
|
},
|
|
{
|
|
// 500ms rounds up to Retry-After: 1.
|
|
name: "retry_after_500ms_rounds_up_to_one",
|
|
respErr: newErrorResponse("rate limited", "rate_limit_error", "rate_limit_exceeded", http.StatusTooManyRequests, 500*time.Millisecond),
|
|
expectStatus: http.StatusTooManyRequests,
|
|
expectRetryAfter: "1",
|
|
},
|
|
{
|
|
// 200ms rounds up to Retry-After: 1.
|
|
name: "retry_after_200ms_rounds_up_to_one",
|
|
respErr: newErrorResponse("rate limited", "rate_limit_error", "rate_limit_exceeded", http.StatusTooManyRequests, 200*time.Millisecond),
|
|
expectStatus: http.StatusTooManyRequests,
|
|
expectRetryAfter: "1",
|
|
},
|
|
{
|
|
// Negative retryAfter: header omitted.
|
|
name: "negative_retry_after_omits_header",
|
|
respErr: newErrorResponse("rate limited", "rate_limit_error", "rate_limit_exceeded", http.StatusTooManyRequests, -1*time.Second),
|
|
expectStatus: http.StatusTooManyRequests,
|
|
expectRetryAfter: "",
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
base := &responsesInterceptionBase{logger: slog.Make()}
|
|
|
|
w := httptest.NewRecorder()
|
|
base.writeUpstreamError(w, tc.respErr)
|
|
|
|
assert.Equal(t, tc.expectStatus, w.Code, "status code")
|
|
assert.Equal(t, "application/json", w.Header().Get("Content-Type"), "Content-Type header")
|
|
assert.Equal(t, tc.expectRetryAfter, w.Header().Get("Retry-After"), "Retry-After header")
|
|
if tc.expectBodyContains != "" {
|
|
assert.Contains(t, w.Body.String(), tc.expectBodyContains, "response body")
|
|
}
|
|
})
|
|
}
|
|
}
|