mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
c650aabbef
My agent added `//nolint:testpackage` to a test file on one of my PRs. Again. This PR cleans it up across the entire repo and updates the in-repo conventions so future agents stop doing it. The repo already has a precedent for white-box tests that need to touch unexported symbols: `*_internal_test.go` (145+ existing files). The `testpackage` linter's default `skip-regexp` exempts that filename suffix, so the `//nolint:testpackage` directive is unnecessary in every case where someone reached for it. This PR renames 51 such files to `*_internal_test.go` via `git mv` so blame and history follow, and strips the dead directive from 2 files that were already correctly named (`coderd/oauth2provider/authorize_internal_test.go`, `coderd/x/chatd/advisor_internal_test.go`). `.claude/docs/TESTING.md` now documents the rule explicitly under *Test Package Naming*, which is imported into the root `AGENTS.md` via `@.claude/docs/TESTING.md`. The rule: prefer `package foo_test`; if you need internal access, rename the file to `*_internal_test.go` rather than adding a nolint directive.
1247 lines
42 KiB
Go
1247 lines
42 KiB
Go
package messages
|
|
|
|
import (
|
|
"context"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/anthropics/anthropic-sdk-go"
|
|
"github.com/anthropics/anthropic-sdk-go/shared/constant"
|
|
mcpgo "github.com/mark3labs/mcp-go/mcp"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
"github.com/tidwall/gjson"
|
|
"golang.org/x/xerrors"
|
|
|
|
"cdr.dev/slog/v3"
|
|
"github.com/coder/coder/v2/aibridge/config"
|
|
"github.com/coder/coder/v2/aibridge/keypool"
|
|
"github.com/coder/coder/v2/aibridge/mcp"
|
|
"github.com/coder/coder/v2/aibridge/utils"
|
|
"github.com/coder/quartz"
|
|
)
|
|
|
|
func TestScanForCorrelatingToolCallID(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
tests := []struct {
|
|
name string
|
|
requestBody string
|
|
expected *string
|
|
}{
|
|
{
|
|
name: "no messages field",
|
|
requestBody: `{}`,
|
|
expected: nil,
|
|
},
|
|
{
|
|
name: "messages string",
|
|
requestBody: `{"messages":"test"}`,
|
|
expected: nil,
|
|
},
|
|
{
|
|
name: "empty messages array",
|
|
requestBody: `{"messages":[]}`,
|
|
expected: nil,
|
|
},
|
|
{
|
|
name: "last message has no tool result blocks",
|
|
requestBody: `{"messages":[{"role":"user","content":"hello"}]}`,
|
|
expected: nil,
|
|
},
|
|
{
|
|
name: "single tool result block",
|
|
requestBody: `{"messages":[{"role":"user","content":[{"type":"tool_result","tool_use_id":"toolu_abc","content":"result"}]}]}`,
|
|
expected: utils.PtrTo("toolu_abc"),
|
|
},
|
|
{
|
|
name: "multiple tool result blocks returns last",
|
|
requestBody: `{"messages":[{"role":"user","content":[{"type":"tool_result","tool_use_id":"toolu_first","content":"first"},{"type":"text","text":"ignored"},{"type":"tool_result","tool_use_id":"toolu_second","content":"second"}]}]}`,
|
|
expected: utils.PtrTo("toolu_second"),
|
|
},
|
|
{
|
|
name: "last message is not a tool result",
|
|
requestBody: `{"messages":[{"role":"user","content":[{"type":"tool_result","tool_use_id":"toolu_first","content":"first"}]},{"role":"user","content":"some text"}]}`,
|
|
expected: nil,
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
base := &interceptionBase{
|
|
reqPayload: mustMessagesPayload(t, tc.requestBody),
|
|
}
|
|
|
|
require.Equal(t, tc.expected, base.CorrelatingToolCallID())
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestAWSBedrockValidation(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
tests := []struct {
|
|
name string
|
|
cfg *config.AWSBedrock
|
|
expectError bool
|
|
errorMsg string
|
|
}{
|
|
// Valid cases: static credentials.
|
|
{
|
|
name: "static credentials with region",
|
|
cfg: &config.AWSBedrock{
|
|
Region: "us-east-1",
|
|
AccessKey: "test-key",
|
|
AccessKeySecret: "test-secret",
|
|
Model: "test-model",
|
|
SmallFastModel: "test-small-model",
|
|
},
|
|
},
|
|
{
|
|
name: "static credentials with base url",
|
|
cfg: &config.AWSBedrock{
|
|
BaseURL: "http://bedrock.internal",
|
|
AccessKey: "test-key",
|
|
AccessKeySecret: "test-secret",
|
|
Model: "test-model",
|
|
SmallFastModel: "test-small-model",
|
|
},
|
|
},
|
|
{
|
|
// There unfortunately isn't a way for us to determine precedence in a unit test,
|
|
// since the produced options take a `requestconfig.RequestConfig` input value
|
|
// which is internal to the anthropic SDK.
|
|
//
|
|
// See TestAWSBedrockIntegration which validates this.
|
|
name: "static credentials with base url & region",
|
|
cfg: &config.AWSBedrock{
|
|
Region: "us-east-1",
|
|
AccessKey: "test-key",
|
|
AccessKeySecret: "test-secret",
|
|
Model: "test-model",
|
|
SmallFastModel: "test-small-model",
|
|
},
|
|
},
|
|
// Invalid cases.
|
|
{
|
|
name: "missing region & base url",
|
|
cfg: &config.AWSBedrock{
|
|
Region: "",
|
|
AccessKey: "test-key",
|
|
AccessKeySecret: "test-secret",
|
|
Model: "test-model",
|
|
SmallFastModel: "test-small-model",
|
|
},
|
|
expectError: true,
|
|
errorMsg: "region or base url required",
|
|
},
|
|
{
|
|
name: "missing access key",
|
|
cfg: &config.AWSBedrock{
|
|
Region: "us-east-1",
|
|
AccessKeySecret: "test-secret",
|
|
Model: "test-model",
|
|
SmallFastModel: "test-small-model",
|
|
},
|
|
expectError: true,
|
|
errorMsg: "both access key and access key secret must be provided together",
|
|
},
|
|
{
|
|
name: "missing access key secret",
|
|
cfg: &config.AWSBedrock{
|
|
Region: "us-east-1",
|
|
AccessKey: "test-key",
|
|
AccessKeySecret: "",
|
|
Model: "test-model",
|
|
SmallFastModel: "test-small-model",
|
|
},
|
|
expectError: true,
|
|
errorMsg: "both access key and access key secret must be provided together",
|
|
},
|
|
{
|
|
name: "missing model",
|
|
cfg: &config.AWSBedrock{
|
|
Region: "us-east-1",
|
|
AccessKey: "test-key",
|
|
AccessKeySecret: "test-secret",
|
|
Model: "",
|
|
SmallFastModel: "test-small-model",
|
|
},
|
|
expectError: true,
|
|
errorMsg: "model required",
|
|
},
|
|
{
|
|
name: "missing small fast model",
|
|
cfg: &config.AWSBedrock{
|
|
Region: "us-east-1",
|
|
AccessKey: "test-key",
|
|
AccessKeySecret: "test-secret",
|
|
Model: "test-model",
|
|
SmallFastModel: "",
|
|
},
|
|
expectError: true,
|
|
errorMsg: "small fast model required",
|
|
},
|
|
{
|
|
name: "all fields empty",
|
|
cfg: &config.AWSBedrock{},
|
|
expectError: true,
|
|
errorMsg: "region or base url required",
|
|
},
|
|
{
|
|
name: "nil config",
|
|
cfg: nil,
|
|
expectError: true,
|
|
errorMsg: "nil config given",
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
base := &interceptionBase{}
|
|
opts, err := base.withAWSBedrockOptions(context.Background(), tt.cfg)
|
|
|
|
if tt.expectError {
|
|
require.Error(t, err)
|
|
require.Contains(t, err.Error(), tt.errorMsg)
|
|
} else {
|
|
require.NotEmpty(t, opts)
|
|
require.NoError(t, err)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// TestAWSBedrockCredentialChain tests credential resolution via the AWS SDK default credential chain.
|
|
// NOTE: Cannot use t.Parallel() here because subtests use t.Setenv which requires sequential execution.
|
|
func TestAWSBedrockCredentialChain(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
cfg *config.AWSBedrock
|
|
envVars map[string]string
|
|
expectError bool
|
|
errorMsg string
|
|
}{
|
|
{
|
|
name: "temporary credentials via env",
|
|
cfg: &config.AWSBedrock{
|
|
Region: "us-east-1",
|
|
Model: "test-model",
|
|
SmallFastModel: "test-small-model",
|
|
},
|
|
envVars: map[string]string{
|
|
"AWS_ACCESS_KEY_ID": "test-key",
|
|
"AWS_SECRET_ACCESS_KEY": "test-secret",
|
|
},
|
|
},
|
|
{
|
|
name: "temporary credentials with session token via env",
|
|
cfg: &config.AWSBedrock{
|
|
Region: "us-east-1",
|
|
Model: "test-model",
|
|
SmallFastModel: "test-small-model",
|
|
},
|
|
envVars: map[string]string{
|
|
"AWS_ACCESS_KEY_ID": "test-key",
|
|
"AWS_SECRET_ACCESS_KEY": "test-secret",
|
|
"AWS_SESSION_TOKEN": "test-session-token",
|
|
},
|
|
},
|
|
{
|
|
// When static credentials are not provided and no environment credentials are set,
|
|
// the SDK default credential chain fails to resolve credentials.
|
|
name: "error when no credential source is configured",
|
|
cfg: &config.AWSBedrock{
|
|
Region: "us-east-1",
|
|
Model: "test-model",
|
|
SmallFastModel: "test-small-model",
|
|
},
|
|
envVars: map[string]string{
|
|
"AWS_ACCESS_KEY_ID": "",
|
|
"AWS_SECRET_ACCESS_KEY": "",
|
|
"AWS_SESSION_TOKEN": "",
|
|
"AWS_PROFILE": "",
|
|
"AWS_SHARED_CREDENTIALS_FILE": "/dev/null",
|
|
"AWS_CONFIG_FILE": "/dev/null",
|
|
"AWS_WEB_IDENTITY_TOKEN_FILE": "",
|
|
"AWS_ROLE_ARN": "",
|
|
"AWS_ROLE_SESSION_NAME": "",
|
|
"AWS_CONTAINER_CREDENTIALS_RELATIVE_URI": "",
|
|
"AWS_CONTAINER_CREDENTIALS_FULL_URI": "",
|
|
"AWS_CONTAINER_AUTHORIZATION_TOKEN": "",
|
|
"AWS_EC2_METADATA_DISABLED": "true",
|
|
},
|
|
expectError: true,
|
|
errorMsg: "no AWS credentials found",
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
for key, val := range tt.envVars {
|
|
t.Setenv(key, val)
|
|
}
|
|
base := &interceptionBase{}
|
|
opts, err := base.withAWSBedrockOptions(context.Background(), tt.cfg)
|
|
|
|
if tt.expectError {
|
|
require.Error(t, err)
|
|
require.Contains(t, err.Error(), tt.errorMsg)
|
|
} else {
|
|
require.NotEmpty(t, opts)
|
|
require.NoError(t, err)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestAccumulateUsage(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
t.Run("Usage to Usage", func(t *testing.T) {
|
|
t.Parallel()
|
|
dest := &anthropic.Usage{
|
|
InputTokens: 10,
|
|
OutputTokens: 20,
|
|
CacheCreationInputTokens: 5,
|
|
CacheReadInputTokens: 3,
|
|
CacheCreation: anthropic.CacheCreation{
|
|
Ephemeral1hInputTokens: 2,
|
|
Ephemeral5mInputTokens: 1,
|
|
},
|
|
ServerToolUse: anthropic.ServerToolUsage{
|
|
WebSearchRequests: 1,
|
|
},
|
|
}
|
|
|
|
source := anthropic.Usage{
|
|
InputTokens: 15,
|
|
OutputTokens: 25,
|
|
CacheCreationInputTokens: 8,
|
|
CacheReadInputTokens: 4,
|
|
CacheCreation: anthropic.CacheCreation{
|
|
Ephemeral1hInputTokens: 3,
|
|
Ephemeral5mInputTokens: 2,
|
|
},
|
|
ServerToolUse: anthropic.ServerToolUsage{
|
|
WebSearchRequests: 2,
|
|
},
|
|
}
|
|
|
|
accumulateUsage(dest, source)
|
|
|
|
require.EqualValues(t, 25, dest.InputTokens)
|
|
require.EqualValues(t, 45, dest.OutputTokens)
|
|
require.EqualValues(t, 13, dest.CacheCreationInputTokens)
|
|
require.EqualValues(t, 7, dest.CacheReadInputTokens)
|
|
require.EqualValues(t, 5, dest.CacheCreation.Ephemeral1hInputTokens)
|
|
require.EqualValues(t, 3, dest.CacheCreation.Ephemeral5mInputTokens)
|
|
require.EqualValues(t, 3, dest.ServerToolUse.WebSearchRequests)
|
|
})
|
|
|
|
t.Run("MessageDeltaUsage to MessageDeltaUsage", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
dest := &anthropic.MessageDeltaUsage{
|
|
InputTokens: 10,
|
|
OutputTokens: 20,
|
|
CacheCreationInputTokens: 5,
|
|
CacheReadInputTokens: 3,
|
|
ServerToolUse: anthropic.ServerToolUsage{
|
|
WebSearchRequests: 1,
|
|
},
|
|
}
|
|
|
|
source := anthropic.MessageDeltaUsage{
|
|
InputTokens: 15,
|
|
OutputTokens: 25,
|
|
CacheCreationInputTokens: 8,
|
|
CacheReadInputTokens: 4,
|
|
ServerToolUse: anthropic.ServerToolUsage{
|
|
WebSearchRequests: 2,
|
|
},
|
|
}
|
|
|
|
accumulateUsage(dest, source)
|
|
|
|
require.EqualValues(t, 25, dest.InputTokens)
|
|
require.EqualValues(t, 45, dest.OutputTokens)
|
|
require.EqualValues(t, 13, dest.CacheCreationInputTokens)
|
|
require.EqualValues(t, 7, dest.CacheReadInputTokens)
|
|
require.EqualValues(t, 3, dest.ServerToolUse.WebSearchRequests)
|
|
})
|
|
|
|
t.Run("Usage to MessageDeltaUsage", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
dest := &anthropic.MessageDeltaUsage{
|
|
InputTokens: 10,
|
|
OutputTokens: 20,
|
|
CacheCreationInputTokens: 5,
|
|
CacheReadInputTokens: 3,
|
|
ServerToolUse: anthropic.ServerToolUsage{
|
|
WebSearchRequests: 1,
|
|
},
|
|
}
|
|
|
|
source := anthropic.Usage{
|
|
InputTokens: 15,
|
|
OutputTokens: 25,
|
|
CacheCreationInputTokens: 8,
|
|
CacheReadInputTokens: 4,
|
|
CacheCreation: anthropic.CacheCreation{
|
|
Ephemeral1hInputTokens: 3, // These won't be accumulated to MessageDeltaUsage
|
|
Ephemeral5mInputTokens: 2,
|
|
},
|
|
ServerToolUse: anthropic.ServerToolUsage{
|
|
WebSearchRequests: 2,
|
|
},
|
|
}
|
|
|
|
accumulateUsage(dest, source)
|
|
|
|
require.EqualValues(t, 25, dest.InputTokens)
|
|
require.EqualValues(t, 45, dest.OutputTokens)
|
|
require.EqualValues(t, 13, dest.CacheCreationInputTokens)
|
|
require.EqualValues(t, 7, dest.CacheReadInputTokens)
|
|
require.EqualValues(t, 3, dest.ServerToolUse.WebSearchRequests)
|
|
})
|
|
|
|
t.Run("MessageDeltaUsage to Usage", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
dest := &anthropic.Usage{
|
|
InputTokens: 10,
|
|
OutputTokens: 20,
|
|
CacheCreationInputTokens: 5,
|
|
CacheReadInputTokens: 3,
|
|
CacheCreation: anthropic.CacheCreation{
|
|
Ephemeral1hInputTokens: 2,
|
|
Ephemeral5mInputTokens: 1,
|
|
},
|
|
ServerToolUse: anthropic.ServerToolUsage{
|
|
WebSearchRequests: 1,
|
|
},
|
|
}
|
|
|
|
source := anthropic.MessageDeltaUsage{
|
|
InputTokens: 15,
|
|
OutputTokens: 25,
|
|
CacheCreationInputTokens: 8,
|
|
CacheReadInputTokens: 4,
|
|
ServerToolUse: anthropic.ServerToolUsage{
|
|
WebSearchRequests: 2,
|
|
},
|
|
}
|
|
|
|
accumulateUsage(dest, source)
|
|
|
|
require.EqualValues(t, 25, dest.InputTokens)
|
|
require.EqualValues(t, 45, dest.OutputTokens)
|
|
require.EqualValues(t, 13, dest.CacheCreationInputTokens)
|
|
require.EqualValues(t, 7, dest.CacheReadInputTokens)
|
|
// Ephemeral tokens remain unchanged since MessageDeltaUsage doesn't have them
|
|
require.EqualValues(t, 2, dest.CacheCreation.Ephemeral1hInputTokens)
|
|
require.EqualValues(t, 1, dest.CacheCreation.Ephemeral5mInputTokens)
|
|
require.EqualValues(t, 3, dest.ServerToolUse.WebSearchRequests)
|
|
})
|
|
|
|
t.Run("Nil or unsupported types", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
// Test with nil dest
|
|
var nilUsage *anthropic.Usage
|
|
source := anthropic.Usage{InputTokens: 10}
|
|
accumulateUsage(nilUsage, source) // Should not panic
|
|
|
|
// Test with unsupported types
|
|
var unsupported string
|
|
accumulateUsage(&unsupported, source) // Should not panic, just do nothing
|
|
})
|
|
}
|
|
|
|
func TestInjectTools_CacheBreakpoints(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
t.Run("cache control preserved when no tools to inject", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
// Request has existing tool with cache control, but no tools to inject.
|
|
i := &interceptionBase{
|
|
reqPayload: mustMessagesPayload(t, `{"tools":[`+
|
|
`{"name":"existing_tool","type":"custom","input_schema":{"type":"object","properties":{}},"cache_control":{"type":"ephemeral"}}]}`),
|
|
mcpProxy: &mockServerProxier{tools: nil},
|
|
logger: slog.Make(),
|
|
}
|
|
|
|
i.injectTools()
|
|
|
|
// Cache control should remain untouched since no tools were injected.
|
|
toolItems := gjson.GetBytes(i.reqPayload, "tools").Array()
|
|
require.Len(t, toolItems, 1)
|
|
require.Equal(t, "existing_tool", toolItems[0].Get("name").String())
|
|
require.Equal(t, string(constant.ValueOf[constant.Ephemeral]()), toolItems[0].Get("cache_control.type").String())
|
|
})
|
|
|
|
t.Run("cache control breakpoint is preserved by prepending injected tools", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
// Request has existing tool with cache control.
|
|
i := &interceptionBase{
|
|
reqPayload: mustMessagesPayload(t, `{"tools":[`+
|
|
`{"name":"existing_tool","type":"custom","input_schema":{"type":"object","properties":{}},"cache_control":{"type":"ephemeral"}}]}`),
|
|
mcpProxy: &mockServerProxier{
|
|
tools: []*mcp.Tool{
|
|
{ID: "injected_tool", Name: "injected", Description: "Injected tool"},
|
|
},
|
|
},
|
|
logger: slog.Make(),
|
|
}
|
|
|
|
i.injectTools()
|
|
|
|
toolItems := gjson.GetBytes(i.reqPayload, "tools").Array()
|
|
require.Len(t, toolItems, 2)
|
|
// Injected tools are prepended.
|
|
require.Equal(t, "injected_tool", toolItems[0].Get("name").String())
|
|
require.Empty(t, toolItems[0].Get("cache_control.type").String())
|
|
// Original tool's cache control should be preserved at the end.
|
|
require.Equal(t, "existing_tool", toolItems[1].Get("name").String())
|
|
require.Equal(t, string(constant.ValueOf[constant.Ephemeral]()), toolItems[1].Get("cache_control.type").String())
|
|
})
|
|
|
|
// The cache breakpoint SHOULD be on the final tool, but may not be; we must preserve that intention.
|
|
t.Run("cache control breakpoint in non-standard location is preserved", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
// Request has multiple tools with cache control breakpoints.
|
|
i := &interceptionBase{
|
|
reqPayload: mustMessagesPayload(t, `{"tools":[`+
|
|
`{"name":"tool_with_cache_1","type":"custom","input_schema":{"type":"object","properties":{}},"cache_control":{"type":"ephemeral"}},`+
|
|
`{"name":"tool_with_cache_2","type":"custom","input_schema":{"type":"object","properties":{}}}]}`),
|
|
mcpProxy: &mockServerProxier{
|
|
tools: []*mcp.Tool{
|
|
{ID: "injected_tool", Name: "injected", Description: "Injected tool"},
|
|
},
|
|
},
|
|
logger: slog.Make(),
|
|
}
|
|
|
|
i.injectTools()
|
|
|
|
toolItems := gjson.GetBytes(i.reqPayload, "tools").Array()
|
|
require.Len(t, toolItems, 3)
|
|
// Injected tool is prepended without cache control.
|
|
require.Equal(t, "injected_tool", toolItems[0].Get("name").String())
|
|
require.Empty(t, toolItems[0].Get("cache_control.type").String())
|
|
// Both original tools' cache controls should remain.
|
|
require.Equal(t, "tool_with_cache_1", toolItems[1].Get("name").String())
|
|
require.Equal(t, string(constant.ValueOf[constant.Ephemeral]()), toolItems[1].Get("cache_control.type").String())
|
|
require.Equal(t, "tool_with_cache_2", toolItems[2].Get("name").String())
|
|
require.Empty(t, toolItems[2].Get("cache_control.type").String())
|
|
})
|
|
|
|
t.Run("no cache control added when none originally set", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
// Request has tools but none with cache control.
|
|
i := &interceptionBase{
|
|
reqPayload: mustMessagesPayload(t, `{"tools":[`+
|
|
`{"name":"existing_tool_no_cache","type":"custom","input_schema":{"type":"object","properties":{}}}]}`),
|
|
mcpProxy: &mockServerProxier{
|
|
tools: []*mcp.Tool{
|
|
{ID: "injected_tool", Name: "injected", Description: "Injected tool"},
|
|
},
|
|
},
|
|
logger: slog.Make(),
|
|
}
|
|
|
|
i.injectTools()
|
|
|
|
toolItems := gjson.GetBytes(i.reqPayload, "tools").Array()
|
|
require.Len(t, toolItems, 2)
|
|
// Injected tool is prepended without cache control.
|
|
require.Equal(t, "injected_tool", toolItems[0].Get("name").String())
|
|
require.Empty(t, toolItems[0].Get("cache_control.type").String())
|
|
// Original tool remains at the end without cache control.
|
|
require.Equal(t, "existing_tool_no_cache", toolItems[1].Get("name").String())
|
|
require.Empty(t, toolItems[1].Get("cache_control.type").String())
|
|
})
|
|
}
|
|
|
|
func TestInjectTools_ParallelToolCalls(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
t.Run("does not modify tool choice when no tools to inject", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
i := &interceptionBase{
|
|
reqPayload: mustMessagesPayload(t, `{"tool_choice":{"type":"auto"}}`),
|
|
mcpProxy: &mockServerProxier{tools: nil}, // No tools to inject.
|
|
logger: slog.Make(),
|
|
}
|
|
|
|
i.injectTools()
|
|
|
|
// Tool choice should remain unchanged - DisableParallelToolUse should not be set.
|
|
toolChoice := gjson.GetBytes(i.reqPayload, "tool_choice")
|
|
require.Equal(t, string(constant.ValueOf[constant.Auto]()), toolChoice.Get("type").String())
|
|
require.False(t, toolChoice.Get("disable_parallel_tool_use").Exists())
|
|
})
|
|
|
|
t.Run("disables parallel tool use for empty tool choice (default)", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
i := &interceptionBase{
|
|
reqPayload: mustMessagesPayload(t, `{}`),
|
|
mcpProxy: &mockServerProxier{
|
|
tools: []*mcp.Tool{{ID: "test_tool", Name: "test", Description: "Test"}},
|
|
},
|
|
logger: slog.Make(),
|
|
}
|
|
|
|
i.injectTools()
|
|
|
|
toolChoice := gjson.GetBytes(i.reqPayload, "tool_choice")
|
|
require.Equal(t, string(constant.ValueOf[constant.Auto]()), toolChoice.Get("type").String())
|
|
require.True(t, toolChoice.Get("disable_parallel_tool_use").Exists())
|
|
require.True(t, toolChoice.Get("disable_parallel_tool_use").Bool())
|
|
})
|
|
|
|
t.Run("disables parallel tool use for explicit auto tool choice", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
i := &interceptionBase{
|
|
reqPayload: mustMessagesPayload(t, `{"tool_choice":{"type":"auto"}}`),
|
|
mcpProxy: &mockServerProxier{
|
|
tools: []*mcp.Tool{{ID: "test_tool", Name: "test", Description: "Test"}},
|
|
},
|
|
logger: slog.Make(),
|
|
}
|
|
|
|
i.injectTools()
|
|
|
|
toolChoice := gjson.GetBytes(i.reqPayload, "tool_choice")
|
|
require.Equal(t, string(constant.ValueOf[constant.Auto]()), toolChoice.Get("type").String())
|
|
require.True(t, toolChoice.Get("disable_parallel_tool_use").Exists())
|
|
require.True(t, toolChoice.Get("disable_parallel_tool_use").Bool())
|
|
})
|
|
|
|
t.Run("disables parallel tool use for any tool choice", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
i := &interceptionBase{
|
|
reqPayload: mustMessagesPayload(t, `{"tool_choice":{"type":"any"}}`),
|
|
mcpProxy: &mockServerProxier{
|
|
tools: []*mcp.Tool{{ID: "test_tool", Name: "test", Description: "Test"}},
|
|
},
|
|
logger: slog.Make(),
|
|
}
|
|
|
|
i.injectTools()
|
|
|
|
toolChoice := gjson.GetBytes(i.reqPayload, "tool_choice")
|
|
require.Equal(t, string(constant.ValueOf[constant.Any]()), toolChoice.Get("type").String())
|
|
require.True(t, toolChoice.Get("disable_parallel_tool_use").Exists())
|
|
require.True(t, toolChoice.Get("disable_parallel_tool_use").Bool())
|
|
})
|
|
|
|
t.Run("disables parallel tool use for tool choice type", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
i := &interceptionBase{
|
|
reqPayload: mustMessagesPayload(t, `{"tool_choice":{"type":"tool","name":"specific_tool"}}`),
|
|
mcpProxy: &mockServerProxier{
|
|
tools: []*mcp.Tool{{ID: "test_tool", Name: "test", Description: "Test"}},
|
|
},
|
|
logger: slog.Make(),
|
|
}
|
|
|
|
i.injectTools()
|
|
|
|
toolChoice := gjson.GetBytes(i.reqPayload, "tool_choice")
|
|
require.Equal(t, string(constant.ValueOf[constant.Tool]()), toolChoice.Get("type").String())
|
|
require.True(t, toolChoice.Get("disable_parallel_tool_use").Exists())
|
|
require.True(t, toolChoice.Get("disable_parallel_tool_use").Bool())
|
|
})
|
|
|
|
t.Run("no-op for none tool choice type", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
i := &interceptionBase{
|
|
reqPayload: mustMessagesPayload(t, `{"tool_choice":{"type":"none"}}`),
|
|
mcpProxy: &mockServerProxier{
|
|
tools: []*mcp.Tool{{ID: "test_tool", Name: "test", Description: "Test"}},
|
|
},
|
|
logger: slog.Make(),
|
|
}
|
|
|
|
i.injectTools()
|
|
|
|
// Tools are still injected.
|
|
require.Len(t, gjson.GetBytes(i.reqPayload, "tools").Array(), 1)
|
|
// But no parallel tool use modification for "none" type.
|
|
toolChoice := gjson.GetBytes(i.reqPayload, "tool_choice")
|
|
require.Equal(t, string(constant.ValueOf[constant.None]()), toolChoice.Get("type").String())
|
|
require.False(t, toolChoice.Get("disable_parallel_tool_use").Exists())
|
|
})
|
|
}
|
|
|
|
func TestAugmentRequestForBedrock_AdaptiveThinking(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
tests := []struct {
|
|
name string
|
|
|
|
bedrockModel string
|
|
requestBody string
|
|
clientBetaFlags string
|
|
|
|
expectThinkingType string
|
|
expectBudgetTokens int64 // 0 means budget_tokens should not be present
|
|
expectEffort string // expected output_config.effort; "" means must not be present
|
|
expectRemovedFields []string
|
|
expectKeptFields []string
|
|
expectBetaValues []string // expected separate Anthropic-Beta header values
|
|
}{
|
|
{
|
|
name: "non_4_6_model_with_adaptive_thinking_gets_converted",
|
|
bedrockModel: "anthropic.claude-sonnet-4-5-20250929-v1:0",
|
|
requestBody: `{"max_tokens":10000,"thinking":{"type":"adaptive"}}`,
|
|
expectThinkingType: "enabled",
|
|
expectBudgetTokens: 8000, // 10000 * 0.8 (default/high effort)
|
|
},
|
|
{
|
|
name: "non_4_6_model_with_adaptive_thinking_and_small_max_tokens_disables_thinking",
|
|
bedrockModel: "anthropic.claude-sonnet-4-5-20250929-v1:0",
|
|
requestBody: `{"max_tokens":1000,"thinking":{"type":"adaptive"}}`,
|
|
expectThinkingType: "disabled",
|
|
},
|
|
{
|
|
name: "opus_4_6_model_with_adaptive_thinking_is_not_converted",
|
|
bedrockModel: "anthropic.claude-opus-4-6-v1",
|
|
requestBody: `{"max_tokens":10000,"thinking":{"type":"adaptive"}}`,
|
|
expectThinkingType: "adaptive",
|
|
},
|
|
{
|
|
name: "sonnet_4_6_model_with_adaptive_thinking_is_not_converted",
|
|
bedrockModel: "anthropic.claude-sonnet-4-6",
|
|
requestBody: `{"max_tokens":10000,"thinking":{"type":"adaptive"}}`,
|
|
expectThinkingType: "adaptive",
|
|
},
|
|
{
|
|
name: "non_4_6_model_with_no_thinking_field_is_unchanged",
|
|
bedrockModel: "anthropic.claude-sonnet-4-5-20250929-v1:0",
|
|
requestBody: `{"max_tokens":10000}`,
|
|
},
|
|
{
|
|
name: "non_4_6_model_with_enabled_thinking_is_unchanged",
|
|
bedrockModel: "anthropic.claude-sonnet-4-5-20250929-v1:0",
|
|
requestBody: `{"max_tokens":10000,"thinking":{"type":"enabled","budget_tokens":5000}}`,
|
|
expectThinkingType: "enabled",
|
|
expectBudgetTokens: 5000,
|
|
},
|
|
{
|
|
name: "output_config_stripped_without_beta_flag_and_effort_used_for_budget",
|
|
bedrockModel: "anthropic.claude-sonnet-4-5-20250929-v1:0",
|
|
requestBody: `{"max_tokens":10000,"thinking":{"type":"adaptive"},"output_config":{"effort":"low"}}`,
|
|
expectThinkingType: "enabled",
|
|
expectBudgetTokens: 2000, // 10000 * 0.2 (low effort)
|
|
expectRemovedFields: []string{"output_config"},
|
|
},
|
|
{
|
|
name: "output_config_kept_when_effort_beta_flag_present_on_opus_4_5",
|
|
bedrockModel: "anthropic.claude-opus-4-5-20250929-v1:0",
|
|
clientBetaFlags: "effort-2025-11-24,interleaved-thinking-2025-05-14",
|
|
requestBody: `{"max_tokens":10000,"output_config":{"effort":"high"}}`,
|
|
expectEffort: "high",
|
|
expectKeptFields: []string{"output_config"},
|
|
expectBetaValues: []string{"effort-2025-11-24", "interleaved-thinking-2025-05-14"},
|
|
},
|
|
{
|
|
name: "output_config_stripped_for_non_opus_4_5_even_with_effort_beta_flag",
|
|
bedrockModel: "anthropic.claude-sonnet-4-5-20250929-v1:0",
|
|
clientBetaFlags: "effort-2025-11-24,interleaved-thinking-2025-05-14",
|
|
requestBody: `{"max_tokens":10000,"output_config":{"effort":"high"}}`,
|
|
expectRemovedFields: []string{"output_config"},
|
|
expectBetaValues: []string{"interleaved-thinking-2025-05-14"},
|
|
},
|
|
{
|
|
name: "context_management_kept_when_beta_flag_present",
|
|
bedrockModel: "anthropic.claude-sonnet-4-5-20250929-v1:0",
|
|
clientBetaFlags: "context-management-2025-06-27",
|
|
requestBody: `{"max_tokens":10000,"context_management":{"type":"auto"}}`,
|
|
expectKeptFields: []string{"context_management"},
|
|
expectBetaValues: []string{"context-management-2025-06-27"},
|
|
},
|
|
{
|
|
name: "context_management_stripped_without_beta_flag",
|
|
bedrockModel: "anthropic.claude-sonnet-4-5-20250929-v1:0",
|
|
requestBody: `{"max_tokens":10000,"context_management":{"type":"auto"}}`,
|
|
expectRemovedFields: []string{"context_management"},
|
|
},
|
|
{
|
|
name: "context_management_stripped_for_unsupported_model_even_with_beta_flag",
|
|
bedrockModel: "anthropic.claude-opus-4-6-v1",
|
|
clientBetaFlags: "context-management-2025-06-27",
|
|
requestBody: `{"max_tokens":10000,"thinking":{"type":"adaptive"},"context_management":{"type":"auto"}}`,
|
|
expectThinkingType: "adaptive",
|
|
expectRemovedFields: []string{"context_management"},
|
|
},
|
|
{
|
|
name: "unsupported_beta_flags_are_filtered_out",
|
|
bedrockModel: "anthropic.claude-sonnet-4-5-20250929-v1:0",
|
|
clientBetaFlags: "claude-code-20250219,interleaved-thinking-2025-05-14,prompt-caching-scope-2026-01-05",
|
|
requestBody: `{"max_tokens":10000}`,
|
|
expectBetaValues: []string{"interleaved-thinking-2025-05-14"},
|
|
},
|
|
{
|
|
name: "all_unsupported_fields_stripped_and_beta_flags_filtered",
|
|
bedrockModel: "anthropic.claude-sonnet-4-5-20250929-v1:0",
|
|
clientBetaFlags: "claude-code-20250219,prompt-caching-scope-2026-01-05",
|
|
requestBody: `{"max_tokens":10000,"output_config":{"effort":"high"},"metadata":{"user_id":"u123"},"service_tier":"auto","container":"ctr_abc","inference_geo":"us","context_management":{"type":"auto"}}`,
|
|
expectRemovedFields: []string{"output_config", "metadata", "service_tier", "container", "inference_geo", "context_management"},
|
|
},
|
|
|
|
// Adaptive-only models (Opus 4.7+), see coder/aibridge#280. The
|
|
// conversion drops budget_tokens and flips the type; an explicit
|
|
// output_config.effort from the caller is preserved, but none is
|
|
// fabricated when absent.
|
|
{
|
|
name: "opus_4_7_model_with_enabled_thinking_is_converted_to_adaptive_and_drops_budget",
|
|
bedrockModel: "us.anthropic.claude-opus-4-7",
|
|
requestBody: `{"max_tokens":10000,"thinking":{"type":"enabled","budget_tokens":5000}}`,
|
|
expectThinkingType: "adaptive",
|
|
},
|
|
{
|
|
name: "opus_4_7_model_with_adaptive_thinking_is_unchanged",
|
|
bedrockModel: "us.anthropic.claude-opus-4-7",
|
|
requestBody: `{"max_tokens":10000,"thinking":{"type":"adaptive"}}`,
|
|
expectThinkingType: "adaptive",
|
|
},
|
|
{
|
|
name: "opus_4_7_model_without_thinking_field_is_unchanged",
|
|
bedrockModel: "us.anthropic.claude-opus-4-7",
|
|
requestBody: `{"max_tokens":10000}`,
|
|
},
|
|
{
|
|
name: "opus_4_7_model_preserves_explicit_output_config_effort",
|
|
bedrockModel: "us.anthropic.claude-opus-4-7",
|
|
requestBody: `{"max_tokens":10000,"thinking":{"type":"enabled","budget_tokens":2000},"output_config":{"effort":"max"}}`,
|
|
expectThinkingType: "adaptive",
|
|
expectEffort: "max",
|
|
expectKeptFields: []string{"output_config"},
|
|
},
|
|
{
|
|
name: "opus_4_7_model_keeps_output_config_without_effort_beta_flag",
|
|
bedrockModel: "us.anthropic.claude-opus-4-7",
|
|
requestBody: `{"max_tokens":10000,"thinking":{"type":"adaptive"},"output_config":{"effort":"high"}}`,
|
|
expectThinkingType: "adaptive",
|
|
expectEffort: "high",
|
|
expectKeptFields: []string{"output_config"},
|
|
},
|
|
{
|
|
name: "arn_style_opus_4_7_application_inference_profile_is_treated_as_adaptive_only",
|
|
bedrockModel: "arn:aws:bedrock:us-east-1:123:application-inference-profile/global.anthropic.claude-opus-4-7",
|
|
requestBody: `{"max_tokens":10000,"thinking":{"type":"enabled","budget_tokens":8000}}`,
|
|
expectThinkingType: "adaptive",
|
|
},
|
|
{
|
|
// Opus 4.7 on Bedrock rejects output_config.format (structured
|
|
// outputs) with a 400 even though it accepts output_config.effort.
|
|
name: "opus_4_7_model_strips_output_config_format_but_keeps_effort",
|
|
bedrockModel: "us.anthropic.claude-opus-4-7",
|
|
requestBody: `{"max_tokens":10000,"output_config":{"effort":"high","format":{"type":"json_schema","schema":{"type":"object"}}}}`,
|
|
expectEffort: "high",
|
|
expectKeptFields: []string{"output_config", "output_config.effort"},
|
|
expectRemovedFields: []string{"output_config.format"},
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
var clientHeaders http.Header
|
|
if tc.clientBetaFlags != "" {
|
|
clientHeaders = http.Header{
|
|
"Anthropic-Beta": {tc.clientBetaFlags},
|
|
}
|
|
}
|
|
|
|
i := &interceptionBase{
|
|
reqPayload: mustMessagesPayload(t, tc.requestBody),
|
|
bedrockCfg: &config.AWSBedrock{
|
|
Model: tc.bedrockModel,
|
|
SmallFastModel: "anthropic.claude-haiku-3-5",
|
|
},
|
|
clientHeaders: clientHeaders,
|
|
logger: slog.Make(),
|
|
}
|
|
|
|
i.augmentRequestForBedrock()
|
|
|
|
thinkingType := gjson.GetBytes(i.reqPayload, "thinking.type")
|
|
if tc.expectThinkingType == "" {
|
|
require.False(t, thinkingType.Exists())
|
|
} else {
|
|
require.Equal(t, tc.expectThinkingType, thinkingType.String())
|
|
}
|
|
|
|
budgetTokens := gjson.GetBytes(i.reqPayload, "thinking.budget_tokens")
|
|
if tc.expectBudgetTokens == 0 {
|
|
require.False(t, budgetTokens.Exists(), "budget_tokens should not be set")
|
|
} else {
|
|
require.Equal(t, tc.expectBudgetTokens, budgetTokens.Int())
|
|
}
|
|
|
|
// Model should always be set to the bedrock model.
|
|
require.Equal(t, tc.bedrockModel, gjson.GetBytes(i.reqPayload, "model").String())
|
|
|
|
// Verify expected fields are removed.
|
|
for _, field := range tc.expectRemovedFields {
|
|
require.False(t, gjson.GetBytes(i.reqPayload, field).Exists(), "%s should be removed", field)
|
|
}
|
|
|
|
// Verify expected fields are kept.
|
|
for _, field := range tc.expectKeptFields {
|
|
require.True(t, gjson.GetBytes(i.reqPayload, field).Exists(), "%s should be kept", field)
|
|
}
|
|
|
|
effort := gjson.GetBytes(i.reqPayload, "output_config.effort")
|
|
if tc.expectEffort == "" {
|
|
require.False(t, effort.Exists(), "output_config.effort should not be set")
|
|
} else {
|
|
require.Equal(t, tc.expectEffort, effort.String())
|
|
}
|
|
|
|
got := clientHeaders.Values("Anthropic-Beta")
|
|
require.Equal(t, tc.expectBetaValues, got)
|
|
})
|
|
}
|
|
}
|
|
|
|
func mustMessagesPayload(t *testing.T, requestBody string) RequestPayload {
|
|
t.Helper()
|
|
|
|
payload, err := NewRequestPayload([]byte(requestBody))
|
|
require.NoError(t, err)
|
|
|
|
return payload
|
|
}
|
|
|
|
// mockServerProxier is a test implementation of mcp.ServerProxier.
|
|
type mockServerProxier struct {
|
|
tools []*mcp.Tool
|
|
}
|
|
|
|
func (*mockServerProxier) Init(context.Context) error {
|
|
return nil
|
|
}
|
|
|
|
func (*mockServerProxier) Shutdown(context.Context) error {
|
|
return nil
|
|
}
|
|
|
|
func (m *mockServerProxier) ListTools() []*mcp.Tool {
|
|
return m.tools
|
|
}
|
|
|
|
func (m *mockServerProxier) GetTool(id string) *mcp.Tool {
|
|
for _, t := range m.tools {
|
|
if t.ID == id {
|
|
return t
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (*mockServerProxier) CallTool(context.Context, string, any) (*mcpgo.CallToolResult, error) {
|
|
return nil, nil //nolint:nilnil // mock: no-op implementation
|
|
}
|
|
|
|
func TestFilterBedrockBetaFlags(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
tests := []struct {
|
|
name string
|
|
model string
|
|
inputValues []string // header values to set (each element is a separate header value)
|
|
expectValues []string // expected separate header values after filtering
|
|
}{
|
|
{
|
|
name: "empty header",
|
|
model: "anthropic.claude-sonnet-4-5-20250929-v1:0",
|
|
inputValues: nil,
|
|
expectValues: nil,
|
|
},
|
|
{
|
|
name: "all supported flags kept",
|
|
model: "anthropic.claude-opus-4-5-20250929-v1:0",
|
|
inputValues: []string{"interleaved-thinking-2025-05-14,effort-2025-11-24"},
|
|
expectValues: []string{"interleaved-thinking-2025-05-14", "effort-2025-11-24"},
|
|
},
|
|
{
|
|
name: "unsupported flags removed",
|
|
model: "anthropic.claude-sonnet-4-5-20250929-v1:0",
|
|
inputValues: []string{"claude-code-20250219,interleaved-thinking-2025-05-14,prompt-caching-scope-2026-01-05"},
|
|
expectValues: []string{"interleaved-thinking-2025-05-14"},
|
|
},
|
|
{
|
|
name: "header removed when all flags unsupported",
|
|
model: "anthropic.claude-sonnet-4-5-20250929-v1:0",
|
|
inputValues: []string{"claude-code-20250219,prompt-caching-scope-2026-01-05"},
|
|
expectValues: nil,
|
|
},
|
|
{
|
|
name: "effort flag removed for non opus 4.5 model",
|
|
model: "anthropic.claude-sonnet-4-5-20250929-v1:0",
|
|
inputValues: []string{"effort-2025-11-24,interleaved-thinking-2025-05-14"},
|
|
expectValues: []string{"interleaved-thinking-2025-05-14"},
|
|
},
|
|
{
|
|
name: "effort flag kept for opus 4.5 model",
|
|
model: "anthropic.claude-opus-4-5-20250929-v1:0",
|
|
inputValues: []string{"effort-2025-11-24,interleaved-thinking-2025-05-14"},
|
|
expectValues: []string{"effort-2025-11-24", "interleaved-thinking-2025-05-14"},
|
|
},
|
|
{
|
|
name: "context management kept for sonnet 4.5",
|
|
model: "anthropic.claude-sonnet-4-5-20250929-v1:0",
|
|
inputValues: []string{"context-management-2025-06-27"},
|
|
expectValues: []string{"context-management-2025-06-27"},
|
|
},
|
|
{
|
|
name: "context management kept for haiku 4.5",
|
|
model: "anthropic.claude-haiku-4-5-20250929-v1:0",
|
|
inputValues: []string{"context-management-2025-06-27"},
|
|
expectValues: []string{"context-management-2025-06-27"},
|
|
},
|
|
{
|
|
name: "context management removed for unsupported model",
|
|
model: "anthropic.claude-opus-4-6-v1",
|
|
inputValues: []string{"context-management-2025-06-27,interleaved-thinking-2025-05-14"},
|
|
expectValues: []string{"interleaved-thinking-2025-05-14"},
|
|
},
|
|
{
|
|
name: "separate header values are handled correctly",
|
|
model: "anthropic.claude-sonnet-4-5-20250929-v1:0",
|
|
inputValues: []string{"interleaved-thinking-2025-05-14", "context-management-2025-06-27"},
|
|
expectValues: []string{"interleaved-thinking-2025-05-14", "context-management-2025-06-27"},
|
|
},
|
|
{
|
|
name: "mixed comma-joined and separate header values",
|
|
model: "anthropic.claude-opus-4-5-20250929-v1:0",
|
|
inputValues: []string{"interleaved-thinking-2025-05-14,effort-2025-11-24", "token-efficient-tools-2025-02-19"},
|
|
expectValues: []string{"interleaved-thinking-2025-05-14", "effort-2025-11-24", "token-efficient-tools-2025-02-19"},
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
headers := http.Header{}
|
|
for _, v := range tc.inputValues {
|
|
headers.Add("Anthropic-Beta", v)
|
|
}
|
|
|
|
filterBedrockBetaFlags(headers, tc.model)
|
|
|
|
// Each kept flag should be a separate header value.
|
|
got := headers.Values("Anthropic-Beta")
|
|
require.Equal(t, tc.expectValues, got)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestProcessKeyPoolError(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
tests := []struct {
|
|
name string
|
|
err error
|
|
expectedNil bool
|
|
expectedStatus int
|
|
expectedRetryAfter time.Duration
|
|
}{
|
|
{
|
|
// Transient with valid keys present: 429, no Retry-After.
|
|
name: "transient_zero_retry_after",
|
|
err: &keypool.TransientKeyPoolError{},
|
|
expectedStatus: http.StatusTooManyRequests,
|
|
expectedRetryAfter: 0,
|
|
},
|
|
{
|
|
// Transient with cooldown: 429, Retry-After set.
|
|
name: "transient_with_retry_after",
|
|
err: &keypool.TransientKeyPoolError{RetryAfter: 5 * time.Second},
|
|
expectedStatus: http.StatusTooManyRequests,
|
|
expectedRetryAfter: 5 * time.Second,
|
|
},
|
|
{
|
|
// Permanent: 502 api_error.
|
|
name: "permanent_returns_502",
|
|
err: keypool.ErrPermanentKeyPool,
|
|
expectedStatus: http.StatusBadGateway,
|
|
},
|
|
{
|
|
// Anything else: not a pool-exhaustion error.
|
|
name: "non_pool_exhaustion_error_returns_nil",
|
|
err: xerrors.New("some other error"),
|
|
expectedNil: true,
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
got := ProcessKeyPoolError(tc.err)
|
|
if tc.expectedNil {
|
|
require.Nil(t, got)
|
|
return
|
|
}
|
|
require.NotNil(t, got)
|
|
assert.Equal(t, tc.expectedStatus, got.StatusCode)
|
|
assert.Equal(t, tc.expectedRetryAfter, got.RetryAfter)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestMarkKeyOnError(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
tests := []struct {
|
|
name string
|
|
err error
|
|
expectedReturn bool
|
|
expectedState keypool.KeyState
|
|
}{
|
|
{
|
|
// Not an *anthropic.Error: no status code to act on.
|
|
name: "non_api_error_returns_false",
|
|
err: xerrors.New("network failure"),
|
|
expectedReturn: false,
|
|
expectedState: keypool.KeyStateValid,
|
|
},
|
|
{
|
|
// Rate-limited: temporary cooldown.
|
|
name: "429_marks_temporary",
|
|
err: &anthropic.Error{StatusCode: http.StatusTooManyRequests, Response: &http.Response{StatusCode: http.StatusTooManyRequests}},
|
|
expectedReturn: true,
|
|
expectedState: keypool.KeyStateTemporary,
|
|
},
|
|
{
|
|
// Auth failure: mark permanent.
|
|
name: "401_marks_permanent",
|
|
err: &anthropic.Error{StatusCode: http.StatusUnauthorized, Response: &http.Response{StatusCode: http.StatusUnauthorized}},
|
|
expectedReturn: true,
|
|
expectedState: keypool.KeyStatePermanent,
|
|
},
|
|
{
|
|
// Auth forbidden: mark permanent.
|
|
name: "403_marks_permanent",
|
|
err: &anthropic.Error{StatusCode: http.StatusForbidden, Response: &http.Response{StatusCode: http.StatusForbidden}},
|
|
expectedReturn: true,
|
|
expectedState: keypool.KeyStatePermanent,
|
|
},
|
|
{
|
|
// Server errors are not key-specific.
|
|
name: "500_does_not_mark",
|
|
err: &anthropic.Error{StatusCode: http.StatusInternalServerError, Response: &http.Response{StatusCode: http.StatusInternalServerError}},
|
|
expectedReturn: false,
|
|
expectedState: keypool.KeyStateValid,
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
pool, err := keypool.New([]string{"key-0"}, quartz.NewMock(t))
|
|
require.NoError(t, err)
|
|
key, err := pool.Walker().Next()
|
|
require.NoError(t, err)
|
|
|
|
base := &interceptionBase{cfg: config.Anthropic{KeyPool: pool}, logger: slog.Make()}
|
|
|
|
got := base.markKeyOnError(context.Background(), key, tc.err)
|
|
assert.Equal(t, tc.expectedReturn, got)
|
|
assert.Equal(t, tc.expectedState, key.State())
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestWriteUpstreamError(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
tests := []struct {
|
|
name string
|
|
respErr *ResponseError
|
|
expectStatus int
|
|
// Empty string means the header should be absent.
|
|
expectRetryAfter string
|
|
// Substring expected in the marshaled body. Empty means no body check.
|
|
expectBodyContains string
|
|
}{
|
|
{
|
|
// Standard error: status and JSON body written.
|
|
name: "writes_status_and_body",
|
|
respErr: newErrorResponse("upstream failed", "api_error", http.StatusBadGateway, 0),
|
|
expectStatus: http.StatusBadGateway,
|
|
expectBodyContains: `"upstream failed"`,
|
|
},
|
|
{
|
|
// Whole-second retryAfter: emitted as integer seconds.
|
|
name: "retry_after_in_seconds",
|
|
respErr: newErrorResponse("rate limited", "rate_limit_error", http.StatusTooManyRequests, 60*time.Second),
|
|
expectStatus: http.StatusTooManyRequests,
|
|
expectRetryAfter: "60",
|
|
},
|
|
{
|
|
// 500ms rounds up to Retry-After: 1.
|
|
name: "retry_after_500ms_rounds_up_to_one",
|
|
respErr: newErrorResponse("rate limited", "rate_limit_error", http.StatusTooManyRequests, 500*time.Millisecond),
|
|
expectStatus: http.StatusTooManyRequests,
|
|
expectRetryAfter: "1",
|
|
},
|
|
{
|
|
// 200ms rounds up to Retry-After: 1.
|
|
name: "retry_after_200ms_rounds_up_to_one",
|
|
respErr: newErrorResponse("rate limited", "rate_limit_error", http.StatusTooManyRequests, 200*time.Millisecond),
|
|
expectStatus: http.StatusTooManyRequests,
|
|
expectRetryAfter: "1",
|
|
},
|
|
{
|
|
// Negative retryAfter: header omitted.
|
|
name: "negative_retry_after_omits_header",
|
|
respErr: newErrorResponse("rate limited", "rate_limit_error", http.StatusTooManyRequests, -1*time.Second),
|
|
expectStatus: http.StatusTooManyRequests,
|
|
expectRetryAfter: "",
|
|
},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
base := &interceptionBase{logger: slog.Make()}
|
|
|
|
w := httptest.NewRecorder()
|
|
base.writeUpstreamError(w, tc.respErr)
|
|
|
|
assert.Equal(t, tc.expectStatus, w.Code, "status code")
|
|
assert.Equal(t, "application/json", w.Header().Get("Content-Type"), "Content-Type header")
|
|
assert.Equal(t, tc.expectRetryAfter, w.Header().Get("Retry-After"), "Retry-After header")
|
|
assert.Contains(t, w.Body.String(), `"type":"error"`, "outer error envelope")
|
|
if tc.expectBodyContains != "" {
|
|
assert.Contains(t, w.Body.String(), tc.expectBodyContains, "response body")
|
|
}
|
|
})
|
|
}
|
|
}
|