Files
coder/aibridge/intercept/messages/base_internal_test.go
T
Ethan c650aabbef chore: standardize on *_internal_test.go for white-box tests (#25601)
My agent added `//nolint:testpackage` to a test file on one of my PRs —
again. This PR cleans it up across the entire repo and updates the
in-repo conventions so future agents stop doing it.

The repo already has a precedent for white-box tests that need to touch
unexported symbols: `*_internal_test.go` (145+ existing files). The
`testpackage` linter's default `skip-regexp` exempts that filename
suffix, so the `//nolint:testpackage` directive is unnecessary in every
case where someone reached for it. This PR renames 51 such files to
`*_internal_test.go` via `git mv` so blame and history follow, and
strips the dead directive from 2 files that were already correctly named
(`coderd/oauth2provider/authorize_internal_test.go`,
`coderd/x/chatd/advisor_internal_test.go`).

`.claude/docs/TESTING.md` now documents the rule explicitly under *Test
Package Naming*, which is imported into the root `AGENTS.md` via
`@.claude/docs/TESTING.md`. The rule: prefer `package foo_test`; if you
need internal access, rename the file to `*_internal_test.go` rather
than adding a nolint directive.
2026-05-22 20:24:38 +10:00

1247 lines
42 KiB
Go

package messages
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/anthropics/anthropic-sdk-go"
"github.com/anthropics/anthropic-sdk-go/shared/constant"
mcpgo "github.com/mark3labs/mcp-go/mcp"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
"golang.org/x/xerrors"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/aibridge/config"
"github.com/coder/coder/v2/aibridge/keypool"
"github.com/coder/coder/v2/aibridge/mcp"
"github.com/coder/coder/v2/aibridge/utils"
"github.com/coder/quartz"
)
// TestScanForCorrelatingToolCallID verifies that CorrelatingToolCallID yields
// the tool_use_id of the final tool_result block in the last message, and nil
// when the payload has no usable messages or the last message carries no
// tool_result blocks.
func TestScanForCorrelatingToolCallID(t *testing.T) {
	t.Parallel()

	type testCase struct {
		name        string
		requestBody string
		expected    *string // nil means no correlating ID should be found
	}

	cases := []testCase{
		{name: "no messages field", requestBody: `{}`},
		{name: "messages string", requestBody: `{"messages":"test"}`},
		{name: "empty messages array", requestBody: `{"messages":[]}`},
		{
			name:        "last message has no tool result blocks",
			requestBody: `{"messages":[{"role":"user","content":"hello"}]}`,
		},
		{
			name:        "single tool result block",
			requestBody: `{"messages":[{"role":"user","content":[{"type":"tool_result","tool_use_id":"toolu_abc","content":"result"}]}]}`,
			expected:    utils.PtrTo("toolu_abc"),
		},
		{
			name:        "multiple tool result blocks returns last",
			requestBody: `{"messages":[{"role":"user","content":[{"type":"tool_result","tool_use_id":"toolu_first","content":"first"},{"type":"text","text":"ignored"},{"type":"tool_result","tool_use_id":"toolu_second","content":"second"}]}]}`,
			expected:    utils.PtrTo("toolu_second"),
		},
		{
			name:        "last message is not a tool result",
			requestBody: `{"messages":[{"role":"user","content":[{"type":"tool_result","tool_use_id":"toolu_first","content":"first"}]},{"role":"user","content":"some text"}]}`,
		},
	}

	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			t.Parallel()

			ib := &interceptionBase{reqPayload: mustMessagesPayload(t, c.requestBody)}
			got := ib.CorrelatingToolCallID()
			require.Equal(t, c.expected, got)
		})
	}
}
// TestAWSBedrockValidation checks that withAWSBedrockOptions accepts
// configurations with static credentials plus a region and/or base URL, and
// rejects incomplete configurations with a descriptive error message.
func TestAWSBedrockValidation(t *testing.T) {
	t.Parallel()

	tests := []struct {
		name        string
		cfg         *config.AWSBedrock
		expectError bool
		errorMsg    string
	}{
		// Valid cases: static credentials.
		{
			name: "static credentials with region",
			cfg: &config.AWSBedrock{
				Region:          "us-east-1",
				AccessKey:       "test-key",
				AccessKeySecret: "test-secret",
				Model:           "test-model",
				SmallFastModel:  "test-small-model",
			},
		},
		{
			name: "static credentials with base url",
			cfg: &config.AWSBedrock{
				BaseURL:         "http://bedrock.internal",
				AccessKey:       "test-key",
				AccessKeySecret: "test-secret",
				Model:           "test-model",
				SmallFastModel:  "test-small-model",
			},
		},
		{
			// There unfortunately isn't a way for us to determine precedence in a unit test,
			// since the produced options take a `requestconfig.RequestConfig` input value
			// which is internal to the anthropic SDK.
			//
			// See TestAWSBedrockIntegration which validates this.
			//
			// Both BaseURL and Region are set so this case actually exercises
			// the combination its name describes; with Region alone it would
			// duplicate "static credentials with region".
			name: "static credentials with base url & region",
			cfg: &config.AWSBedrock{
				Region:          "us-east-1",
				BaseURL:         "http://bedrock.internal",
				AccessKey:       "test-key",
				AccessKeySecret: "test-secret",
				Model:           "test-model",
				SmallFastModel:  "test-small-model",
			},
		},
		// Invalid cases.
		{
			name: "missing region & base url",
			cfg: &config.AWSBedrock{
				Region:          "",
				AccessKey:       "test-key",
				AccessKeySecret: "test-secret",
				Model:           "test-model",
				SmallFastModel:  "test-small-model",
			},
			expectError: true,
			errorMsg:    "region or base url required",
		},
		{
			name: "missing access key",
			cfg: &config.AWSBedrock{
				Region:          "us-east-1",
				AccessKeySecret: "test-secret",
				Model:           "test-model",
				SmallFastModel:  "test-small-model",
			},
			expectError: true,
			errorMsg:    "both access key and access key secret must be provided together",
		},
		{
			name: "missing access key secret",
			cfg: &config.AWSBedrock{
				Region:          "us-east-1",
				AccessKey:       "test-key",
				AccessKeySecret: "",
				Model:           "test-model",
				SmallFastModel:  "test-small-model",
			},
			expectError: true,
			errorMsg:    "both access key and access key secret must be provided together",
		},
		{
			name: "missing model",
			cfg: &config.AWSBedrock{
				Region:          "us-east-1",
				AccessKey:       "test-key",
				AccessKeySecret: "test-secret",
				Model:           "",
				SmallFastModel:  "test-small-model",
			},
			expectError: true,
			errorMsg:    "model required",
		},
		{
			name: "missing small fast model",
			cfg: &config.AWSBedrock{
				Region:          "us-east-1",
				AccessKey:       "test-key",
				AccessKeySecret: "test-secret",
				Model:           "test-model",
				SmallFastModel:  "",
			},
			expectError: true,
			errorMsg:    "small fast model required",
		},
		{
			name:        "all fields empty",
			cfg:         &config.AWSBedrock{},
			expectError: true,
			errorMsg:    "region or base url required",
		},
		{
			name:        "nil config",
			cfg:         nil,
			expectError: true,
			errorMsg:    "nil config given",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			base := &interceptionBase{}
			opts, err := base.withAWSBedrockOptions(context.Background(), tt.cfg)
			if tt.expectError {
				require.Error(t, err)
				require.Contains(t, err.Error(), tt.errorMsg)
				return
			}
			// Check the error first so an unexpected failure reports its
			// cause instead of only "opts should not be empty".
			require.NoError(t, err)
			require.NotEmpty(t, opts)
		})
	}
}
// TestAWSBedrockCredentialChain tests credential resolution via the AWS SDK default credential chain.
// NOTE: Cannot use t.Parallel() here because subtests use t.Setenv which requires sequential execution.
func TestAWSBedrockCredentialChain(t *testing.T) {
	tests := []struct {
		name        string
		cfg         *config.AWSBedrock
		envVars     map[string]string
		expectError bool
		errorMsg    string
	}{
		{
			name: "temporary credentials via env",
			cfg: &config.AWSBedrock{
				Region:         "us-east-1",
				Model:          "test-model",
				SmallFastModel: "test-small-model",
			},
			envVars: map[string]string{
				"AWS_ACCESS_KEY_ID":     "test-key",
				"AWS_SECRET_ACCESS_KEY": "test-secret",
			},
		},
		{
			name: "temporary credentials with session token via env",
			cfg: &config.AWSBedrock{
				Region:         "us-east-1",
				Model:          "test-model",
				SmallFastModel: "test-small-model",
			},
			envVars: map[string]string{
				"AWS_ACCESS_KEY_ID":     "test-key",
				"AWS_SECRET_ACCESS_KEY": "test-secret",
				"AWS_SESSION_TOKEN":     "test-session-token",
			},
		},
		{
			// When static credentials are not provided and no environment credentials are set,
			// the SDK default credential chain fails to resolve credentials.
			name: "error when no credential source is configured",
			cfg: &config.AWSBedrock{
				Region:         "us-east-1",
				Model:          "test-model",
				SmallFastModel: "test-small-model",
			},
			envVars: map[string]string{
				"AWS_ACCESS_KEY_ID":                      "",
				"AWS_SECRET_ACCESS_KEY":                  "",
				"AWS_SESSION_TOKEN":                      "",
				"AWS_PROFILE":                            "",
				"AWS_SHARED_CREDENTIALS_FILE":            "/dev/null",
				"AWS_CONFIG_FILE":                        "/dev/null",
				"AWS_WEB_IDENTITY_TOKEN_FILE":            "",
				"AWS_ROLE_ARN":                           "",
				"AWS_ROLE_SESSION_NAME":                  "",
				"AWS_CONTAINER_CREDENTIALS_RELATIVE_URI": "",
				"AWS_CONTAINER_CREDENTIALS_FULL_URI":     "",
				"AWS_CONTAINER_AUTHORIZATION_TOKEN":      "",
				"AWS_EC2_METADATA_DISABLED":              "true",
			},
			expectError: true,
			errorMsg:    "no AWS credentials found",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// t.Setenv restores each variable when the subtest ends.
			for key, val := range tt.envVars {
				t.Setenv(key, val)
			}

			base := &interceptionBase{}
			opts, err := base.withAWSBedrockOptions(context.Background(), tt.cfg)
			if tt.expectError {
				require.Error(t, err)
				require.Contains(t, err.Error(), tt.errorMsg)
				return
			}
			// Check the error first so an unexpected credential-chain failure
			// is reported directly rather than as "opts should not be empty".
			require.NoError(t, err)
			require.NotEmpty(t, opts)
		})
	}
}
// TestAccumulateUsage verifies that accumulateUsage adds every counter from the
// source usage into the destination for each supported pairing of
// anthropic.Usage and anthropic.MessageDeltaUsage. It also checks that nil or
// unsupported destinations are tolerated without panicking.
func TestAccumulateUsage(t *testing.T) {
	t.Parallel()

	t.Run("Usage to Usage", func(t *testing.T) {
		t.Parallel()
		dest := &anthropic.Usage{
			InputTokens:              10,
			OutputTokens:             20,
			CacheCreationInputTokens: 5,
			CacheReadInputTokens:     3,
			CacheCreation: anthropic.CacheCreation{
				Ephemeral1hInputTokens: 2,
				Ephemeral5mInputTokens: 1,
			},
			ServerToolUse: anthropic.ServerToolUsage{
				WebSearchRequests: 1,
			},
		}
		source := anthropic.Usage{
			InputTokens:              15,
			OutputTokens:             25,
			CacheCreationInputTokens: 8,
			CacheReadInputTokens:     4,
			CacheCreation: anthropic.CacheCreation{
				Ephemeral1hInputTokens: 3,
				Ephemeral5mInputTokens: 2,
			},
			ServerToolUse: anthropic.ServerToolUsage{
				WebSearchRequests: 2,
			},
		}
		accumulateUsage(dest, source)
		require.EqualValues(t, 25, dest.InputTokens)
		require.EqualValues(t, 45, dest.OutputTokens)
		require.EqualValues(t, 13, dest.CacheCreationInputTokens)
		require.EqualValues(t, 7, dest.CacheReadInputTokens)
		require.EqualValues(t, 5, dest.CacheCreation.Ephemeral1hInputTokens)
		require.EqualValues(t, 3, dest.CacheCreation.Ephemeral5mInputTokens)
		require.EqualValues(t, 3, dest.ServerToolUse.WebSearchRequests)
	})

	t.Run("MessageDeltaUsage to MessageDeltaUsage", func(t *testing.T) {
		t.Parallel()
		dest := &anthropic.MessageDeltaUsage{
			InputTokens:              10,
			OutputTokens:             20,
			CacheCreationInputTokens: 5,
			CacheReadInputTokens:     3,
			ServerToolUse: anthropic.ServerToolUsage{
				WebSearchRequests: 1,
			},
		}
		source := anthropic.MessageDeltaUsage{
			InputTokens:              15,
			OutputTokens:             25,
			CacheCreationInputTokens: 8,
			CacheReadInputTokens:     4,
			ServerToolUse: anthropic.ServerToolUsage{
				WebSearchRequests: 2,
			},
		}
		accumulateUsage(dest, source)
		require.EqualValues(t, 25, dest.InputTokens)
		require.EqualValues(t, 45, dest.OutputTokens)
		require.EqualValues(t, 13, dest.CacheCreationInputTokens)
		require.EqualValues(t, 7, dest.CacheReadInputTokens)
		require.EqualValues(t, 3, dest.ServerToolUse.WebSearchRequests)
	})

	t.Run("Usage to MessageDeltaUsage", func(t *testing.T) {
		t.Parallel()
		dest := &anthropic.MessageDeltaUsage{
			InputTokens:              10,
			OutputTokens:             20,
			CacheCreationInputTokens: 5,
			CacheReadInputTokens:     3,
			ServerToolUse: anthropic.ServerToolUsage{
				WebSearchRequests: 1,
			},
		}
		source := anthropic.Usage{
			InputTokens:              15,
			OutputTokens:             25,
			CacheCreationInputTokens: 8,
			CacheReadInputTokens:     4,
			CacheCreation: anthropic.CacheCreation{
				Ephemeral1hInputTokens: 3, // These won't be accumulated to MessageDeltaUsage
				Ephemeral5mInputTokens: 2,
			},
			ServerToolUse: anthropic.ServerToolUsage{
				WebSearchRequests: 2,
			},
		}
		accumulateUsage(dest, source)
		require.EqualValues(t, 25, dest.InputTokens)
		require.EqualValues(t, 45, dest.OutputTokens)
		require.EqualValues(t, 13, dest.CacheCreationInputTokens)
		require.EqualValues(t, 7, dest.CacheReadInputTokens)
		require.EqualValues(t, 3, dest.ServerToolUse.WebSearchRequests)
	})

	t.Run("MessageDeltaUsage to Usage", func(t *testing.T) {
		t.Parallel()
		dest := &anthropic.Usage{
			InputTokens:              10,
			OutputTokens:             20,
			CacheCreationInputTokens: 5,
			CacheReadInputTokens:     3,
			CacheCreation: anthropic.CacheCreation{
				Ephemeral1hInputTokens: 2,
				Ephemeral5mInputTokens: 1,
			},
			ServerToolUse: anthropic.ServerToolUsage{
				WebSearchRequests: 1,
			},
		}
		source := anthropic.MessageDeltaUsage{
			InputTokens:              15,
			OutputTokens:             25,
			CacheCreationInputTokens: 8,
			CacheReadInputTokens:     4,
			ServerToolUse: anthropic.ServerToolUsage{
				WebSearchRequests: 2,
			},
		}
		accumulateUsage(dest, source)
		require.EqualValues(t, 25, dest.InputTokens)
		require.EqualValues(t, 45, dest.OutputTokens)
		require.EqualValues(t, 13, dest.CacheCreationInputTokens)
		require.EqualValues(t, 7, dest.CacheReadInputTokens)
		// Ephemeral tokens remain unchanged since MessageDeltaUsage doesn't have them
		require.EqualValues(t, 2, dest.CacheCreation.Ephemeral1hInputTokens)
		require.EqualValues(t, 1, dest.CacheCreation.Ephemeral5mInputTokens)
		require.EqualValues(t, 3, dest.ServerToolUse.WebSearchRequests)
	})

	t.Run("Nil or unsupported types", func(t *testing.T) {
		t.Parallel()
		source := anthropic.Usage{InputTokens: 10}

		// A nil destination must be tolerated. NotPanics makes that contract
		// explicit and gives a clear failure message if it regresses.
		var nilUsage *anthropic.Usage
		require.NotPanics(t, func() { accumulateUsage(nilUsage, source) })

		// An unsupported destination type is a no-op: no panic and the value
		// is left untouched.
		var unsupported string
		require.NotPanics(t, func() { accumulateUsage(&unsupported, source) })
		require.Empty(t, unsupported)
	})
}
func TestInjectTools_CacheBreakpoints(t *testing.T) {
t.Parallel()
t.Run("cache control preserved when no tools to inject", func(t *testing.T) {
t.Parallel()
// Request has existing tool with cache control, but no tools to inject.
i := &interceptionBase{
reqPayload: mustMessagesPayload(t, `{"tools":[`+
`{"name":"existing_tool","type":"custom","input_schema":{"type":"object","properties":{}},"cache_control":{"type":"ephemeral"}}]}`),
mcpProxy: &mockServerProxier{tools: nil},
logger: slog.Make(),
}
i.injectTools()
// Cache control should remain untouched since no tools were injected.
toolItems := gjson.GetBytes(i.reqPayload, "tools").Array()
require.Len(t, toolItems, 1)
require.Equal(t, "existing_tool", toolItems[0].Get("name").String())
require.Equal(t, string(constant.ValueOf[constant.Ephemeral]()), toolItems[0].Get("cache_control.type").String())
})
t.Run("cache control breakpoint is preserved by prepending injected tools", func(t *testing.T) {
t.Parallel()
// Request has existing tool with cache control.
i := &interceptionBase{
reqPayload: mustMessagesPayload(t, `{"tools":[`+
`{"name":"existing_tool","type":"custom","input_schema":{"type":"object","properties":{}},"cache_control":{"type":"ephemeral"}}]}`),
mcpProxy: &mockServerProxier{
tools: []*mcp.Tool{
{ID: "injected_tool", Name: "injected", Description: "Injected tool"},
},
},
logger: slog.Make(),
}
i.injectTools()
toolItems := gjson.GetBytes(i.reqPayload, "tools").Array()
require.Len(t, toolItems, 2)
// Injected tools are prepended.
require.Equal(t, "injected_tool", toolItems[0].Get("name").String())
require.Empty(t, toolItems[0].Get("cache_control.type").String())
// Original tool's cache control should be preserved at the end.
require.Equal(t, "existing_tool", toolItems[1].Get("name").String())
require.Equal(t, string(constant.ValueOf[constant.Ephemeral]()), toolItems[1].Get("cache_control.type").String())
})
// The cache breakpoint SHOULD be on the final tool, but may not be; we must preserve that intention.
t.Run("cache control breakpoint in non-standard location is preserved", func(t *testing.T) {
t.Parallel()
// Request has multiple tools with cache control breakpoints.
i := &interceptionBase{
reqPayload: mustMessagesPayload(t, `{"tools":[`+
`{"name":"tool_with_cache_1","type":"custom","input_schema":{"type":"object","properties":{}},"cache_control":{"type":"ephemeral"}},`+
`{"name":"tool_with_cache_2","type":"custom","input_schema":{"type":"object","properties":{}}}]}`),
mcpProxy: &mockServerProxier{
tools: []*mcp.Tool{
{ID: "injected_tool", Name: "injected", Description: "Injected tool"},
},
},
logger: slog.Make(),
}
i.injectTools()
toolItems := gjson.GetBytes(i.reqPayload, "tools").Array()
require.Len(t, toolItems, 3)
// Injected tool is prepended without cache control.
require.Equal(t, "injected_tool", toolItems[0].Get("name").String())
require.Empty(t, toolItems[0].Get("cache_control.type").String())
// Both original tools' cache controls should remain.
require.Equal(t, "tool_with_cache_1", toolItems[1].Get("name").String())
require.Equal(t, string(constant.ValueOf[constant.Ephemeral]()), toolItems[1].Get("cache_control.type").String())
require.Equal(t, "tool_with_cache_2", toolItems[2].Get("name").String())
require.Empty(t, toolItems[2].Get("cache_control.type").String())
})
t.Run("no cache control added when none originally set", func(t *testing.T) {
t.Parallel()
// Request has tools but none with cache control.
i := &interceptionBase{
reqPayload: mustMessagesPayload(t, `{"tools":[`+
`{"name":"existing_tool_no_cache","type":"custom","input_schema":{"type":"object","properties":{}}}]}`),
mcpProxy: &mockServerProxier{
tools: []*mcp.Tool{
{ID: "injected_tool", Name: "injected", Description: "Injected tool"},
},
},
logger: slog.Make(),
}
i.injectTools()
toolItems := gjson.GetBytes(i.reqPayload, "tools").Array()
require.Len(t, toolItems, 2)
// Injected tool is prepended without cache control.
require.Equal(t, "injected_tool", toolItems[0].Get("name").String())
require.Empty(t, toolItems[0].Get("cache_control.type").String())
// Original tool remains at the end without cache control.
require.Equal(t, "existing_tool_no_cache", toolItems[1].Get("name").String())
require.Empty(t, toolItems[1].Get("cache_control.type").String())
})
}
func TestInjectTools_ParallelToolCalls(t *testing.T) {
t.Parallel()
t.Run("does not modify tool choice when no tools to inject", func(t *testing.T) {
t.Parallel()
i := &interceptionBase{
reqPayload: mustMessagesPayload(t, `{"tool_choice":{"type":"auto"}}`),
mcpProxy: &mockServerProxier{tools: nil}, // No tools to inject.
logger: slog.Make(),
}
i.injectTools()
// Tool choice should remain unchanged - DisableParallelToolUse should not be set.
toolChoice := gjson.GetBytes(i.reqPayload, "tool_choice")
require.Equal(t, string(constant.ValueOf[constant.Auto]()), toolChoice.Get("type").String())
require.False(t, toolChoice.Get("disable_parallel_tool_use").Exists())
})
t.Run("disables parallel tool use for empty tool choice (default)", func(t *testing.T) {
t.Parallel()
i := &interceptionBase{
reqPayload: mustMessagesPayload(t, `{}`),
mcpProxy: &mockServerProxier{
tools: []*mcp.Tool{{ID: "test_tool", Name: "test", Description: "Test"}},
},
logger: slog.Make(),
}
i.injectTools()
toolChoice := gjson.GetBytes(i.reqPayload, "tool_choice")
require.Equal(t, string(constant.ValueOf[constant.Auto]()), toolChoice.Get("type").String())
require.True(t, toolChoice.Get("disable_parallel_tool_use").Exists())
require.True(t, toolChoice.Get("disable_parallel_tool_use").Bool())
})
t.Run("disables parallel tool use for explicit auto tool choice", func(t *testing.T) {
t.Parallel()
i := &interceptionBase{
reqPayload: mustMessagesPayload(t, `{"tool_choice":{"type":"auto"}}`),
mcpProxy: &mockServerProxier{
tools: []*mcp.Tool{{ID: "test_tool", Name: "test", Description: "Test"}},
},
logger: slog.Make(),
}
i.injectTools()
toolChoice := gjson.GetBytes(i.reqPayload, "tool_choice")
require.Equal(t, string(constant.ValueOf[constant.Auto]()), toolChoice.Get("type").String())
require.True(t, toolChoice.Get("disable_parallel_tool_use").Exists())
require.True(t, toolChoice.Get("disable_parallel_tool_use").Bool())
})
t.Run("disables parallel tool use for any tool choice", func(t *testing.T) {
t.Parallel()
i := &interceptionBase{
reqPayload: mustMessagesPayload(t, `{"tool_choice":{"type":"any"}}`),
mcpProxy: &mockServerProxier{
tools: []*mcp.Tool{{ID: "test_tool", Name: "test", Description: "Test"}},
},
logger: slog.Make(),
}
i.injectTools()
toolChoice := gjson.GetBytes(i.reqPayload, "tool_choice")
require.Equal(t, string(constant.ValueOf[constant.Any]()), toolChoice.Get("type").String())
require.True(t, toolChoice.Get("disable_parallel_tool_use").Exists())
require.True(t, toolChoice.Get("disable_parallel_tool_use").Bool())
})
t.Run("disables parallel tool use for tool choice type", func(t *testing.T) {
t.Parallel()
i := &interceptionBase{
reqPayload: mustMessagesPayload(t, `{"tool_choice":{"type":"tool","name":"specific_tool"}}`),
mcpProxy: &mockServerProxier{
tools: []*mcp.Tool{{ID: "test_tool", Name: "test", Description: "Test"}},
},
logger: slog.Make(),
}
i.injectTools()
toolChoice := gjson.GetBytes(i.reqPayload, "tool_choice")
require.Equal(t, string(constant.ValueOf[constant.Tool]()), toolChoice.Get("type").String())
require.True(t, toolChoice.Get("disable_parallel_tool_use").Exists())
require.True(t, toolChoice.Get("disable_parallel_tool_use").Bool())
})
t.Run("no-op for none tool choice type", func(t *testing.T) {
t.Parallel()
i := &interceptionBase{
reqPayload: mustMessagesPayload(t, `{"tool_choice":{"type":"none"}}`),
mcpProxy: &mockServerProxier{
tools: []*mcp.Tool{{ID: "test_tool", Name: "test", Description: "Test"}},
},
logger: slog.Make(),
}
i.injectTools()
// Tools are still injected.
require.Len(t, gjson.GetBytes(i.reqPayload, "tools").Array(), 1)
// But no parallel tool use modification for "none" type.
toolChoice := gjson.GetBytes(i.reqPayload, "tool_choice")
require.Equal(t, string(constant.ValueOf[constant.None]()), toolChoice.Get("type").String())
require.False(t, toolChoice.Get("disable_parallel_tool_use").Exists())
})
}
// TestAugmentRequestForBedrock_AdaptiveThinking drives augmentRequestForBedrock
// over a table of Bedrock model IDs, request bodies and client Anthropic-Beta
// headers. It checks the resulting thinking config, output_config.effort,
// stripped or retained top-level fields, the forced model ID, and the filtered
// beta header values.
func TestAugmentRequestForBedrock_AdaptiveThinking(t *testing.T) {
	t.Parallel()

	// Bedrock model IDs reused across cases.
	const (
		sonnet45 = "anthropic.claude-sonnet-4-5-20250929-v1:0"
		opus45   = "anthropic.claude-opus-4-5-20250929-v1:0"
		opus46   = "anthropic.claude-opus-4-6-v1"
		opus47   = "us.anthropic.claude-opus-4-7"
	)

	cases := []struct {
		name                string
		bedrockModel        string
		requestBody         string
		clientBetaFlags     string
		expectThinkingType  string   // "" means thinking.type must be absent
		expectBudgetTokens  int64    // 0 means budget_tokens should not be present
		expectEffort        string   // expected output_config.effort; "" means must not be present
		expectRemovedFields []string // gjson paths that must be absent afterwards
		expectKeptFields    []string // gjson paths that must still exist afterwards
		expectBetaValues    []string // expected separate Anthropic-Beta header values
	}{
		{
			name:               "non_4_6_model_with_adaptive_thinking_gets_converted",
			bedrockModel:       sonnet45,
			requestBody:        `{"max_tokens":10000,"thinking":{"type":"adaptive"}}`,
			expectThinkingType: "enabled",
			expectBudgetTokens: 8000, // 10000 * 0.8 (default/high effort)
		},
		{
			name:               "non_4_6_model_with_adaptive_thinking_and_small_max_tokens_disables_thinking",
			bedrockModel:       sonnet45,
			requestBody:        `{"max_tokens":1000,"thinking":{"type":"adaptive"}}`,
			expectThinkingType: "disabled",
		},
		{
			name:               "opus_4_6_model_with_adaptive_thinking_is_not_converted",
			bedrockModel:       opus46,
			requestBody:        `{"max_tokens":10000,"thinking":{"type":"adaptive"}}`,
			expectThinkingType: "adaptive",
		},
		{
			name:               "sonnet_4_6_model_with_adaptive_thinking_is_not_converted",
			bedrockModel:       "anthropic.claude-sonnet-4-6",
			requestBody:        `{"max_tokens":10000,"thinking":{"type":"adaptive"}}`,
			expectThinkingType: "adaptive",
		},
		{
			name:         "non_4_6_model_with_no_thinking_field_is_unchanged",
			bedrockModel: sonnet45,
			requestBody:  `{"max_tokens":10000}`,
		},
		{
			name:               "non_4_6_model_with_enabled_thinking_is_unchanged",
			bedrockModel:       sonnet45,
			requestBody:        `{"max_tokens":10000,"thinking":{"type":"enabled","budget_tokens":5000}}`,
			expectThinkingType: "enabled",
			expectBudgetTokens: 5000,
		},
		{
			name:                "output_config_stripped_without_beta_flag_and_effort_used_for_budget",
			bedrockModel:        sonnet45,
			requestBody:         `{"max_tokens":10000,"thinking":{"type":"adaptive"},"output_config":{"effort":"low"}}`,
			expectThinkingType:  "enabled",
			expectBudgetTokens:  2000, // 10000 * 0.2 (low effort)
			expectRemovedFields: []string{"output_config"},
		},
		{
			name:             "output_config_kept_when_effort_beta_flag_present_on_opus_4_5",
			bedrockModel:     opus45,
			clientBetaFlags:  "effort-2025-11-24,interleaved-thinking-2025-05-14",
			requestBody:      `{"max_tokens":10000,"output_config":{"effort":"high"}}`,
			expectEffort:     "high",
			expectKeptFields: []string{"output_config"},
			expectBetaValues: []string{"effort-2025-11-24", "interleaved-thinking-2025-05-14"},
		},
		{
			name:                "output_config_stripped_for_non_opus_4_5_even_with_effort_beta_flag",
			bedrockModel:        sonnet45,
			clientBetaFlags:     "effort-2025-11-24,interleaved-thinking-2025-05-14",
			requestBody:         `{"max_tokens":10000,"output_config":{"effort":"high"}}`,
			expectRemovedFields: []string{"output_config"},
			expectBetaValues:    []string{"interleaved-thinking-2025-05-14"},
		},
		{
			name:             "context_management_kept_when_beta_flag_present",
			bedrockModel:     sonnet45,
			clientBetaFlags:  "context-management-2025-06-27",
			requestBody:      `{"max_tokens":10000,"context_management":{"type":"auto"}}`,
			expectKeptFields: []string{"context_management"},
			expectBetaValues: []string{"context-management-2025-06-27"},
		},
		{
			name:                "context_management_stripped_without_beta_flag",
			bedrockModel:        sonnet45,
			requestBody:         `{"max_tokens":10000,"context_management":{"type":"auto"}}`,
			expectRemovedFields: []string{"context_management"},
		},
		{
			name:                "context_management_stripped_for_unsupported_model_even_with_beta_flag",
			bedrockModel:        opus46,
			clientBetaFlags:     "context-management-2025-06-27",
			requestBody:         `{"max_tokens":10000,"thinking":{"type":"adaptive"},"context_management":{"type":"auto"}}`,
			expectThinkingType:  "adaptive",
			expectRemovedFields: []string{"context_management"},
		},
		{
			name:             "unsupported_beta_flags_are_filtered_out",
			bedrockModel:     sonnet45,
			clientBetaFlags:  "claude-code-20250219,interleaved-thinking-2025-05-14,prompt-caching-scope-2026-01-05",
			requestBody:      `{"max_tokens":10000}`,
			expectBetaValues: []string{"interleaved-thinking-2025-05-14"},
		},
		{
			name:                "all_unsupported_fields_stripped_and_beta_flags_filtered",
			bedrockModel:        sonnet45,
			clientBetaFlags:     "claude-code-20250219,prompt-caching-scope-2026-01-05",
			requestBody:         `{"max_tokens":10000,"output_config":{"effort":"high"},"metadata":{"user_id":"u123"},"service_tier":"auto","container":"ctr_abc","inference_geo":"us","context_management":{"type":"auto"}}`,
			expectRemovedFields: []string{"output_config", "metadata", "service_tier", "container", "inference_geo", "context_management"},
		},
		// Adaptive-only models (Opus 4.7+), see coder/aibridge#280. The
		// conversion drops budget_tokens and flips the type; an explicit
		// output_config.effort from the caller is preserved, but none is
		// fabricated when absent.
		{
			name:               "opus_4_7_model_with_enabled_thinking_is_converted_to_adaptive_and_drops_budget",
			bedrockModel:       opus47,
			requestBody:        `{"max_tokens":10000,"thinking":{"type":"enabled","budget_tokens":5000}}`,
			expectThinkingType: "adaptive",
		},
		{
			name:               "opus_4_7_model_with_adaptive_thinking_is_unchanged",
			bedrockModel:       opus47,
			requestBody:        `{"max_tokens":10000,"thinking":{"type":"adaptive"}}`,
			expectThinkingType: "adaptive",
		},
		{
			name:         "opus_4_7_model_without_thinking_field_is_unchanged",
			bedrockModel: opus47,
			requestBody:  `{"max_tokens":10000}`,
		},
		{
			name:               "opus_4_7_model_preserves_explicit_output_config_effort",
			bedrockModel:       opus47,
			requestBody:        `{"max_tokens":10000,"thinking":{"type":"enabled","budget_tokens":2000},"output_config":{"effort":"max"}}`,
			expectThinkingType: "adaptive",
			expectEffort:       "max",
			expectKeptFields:   []string{"output_config"},
		},
		{
			name:               "opus_4_7_model_keeps_output_config_without_effort_beta_flag",
			bedrockModel:       opus47,
			requestBody:        `{"max_tokens":10000,"thinking":{"type":"adaptive"},"output_config":{"effort":"high"}}`,
			expectThinkingType: "adaptive",
			expectEffort:       "high",
			expectKeptFields:   []string{"output_config"},
		},
		{
			name:               "arn_style_opus_4_7_application_inference_profile_is_treated_as_adaptive_only",
			bedrockModel:       "arn:aws:bedrock:us-east-1:123:application-inference-profile/global.anthropic.claude-opus-4-7",
			requestBody:        `{"max_tokens":10000,"thinking":{"type":"enabled","budget_tokens":8000}}`,
			expectThinkingType: "adaptive",
		},
		{
			// Opus 4.7 on Bedrock rejects output_config.format (structured
			// outputs) with a 400 even though it accepts output_config.effort.
			name:                "opus_4_7_model_strips_output_config_format_but_keeps_effort",
			bedrockModel:        opus47,
			requestBody:         `{"max_tokens":10000,"output_config":{"effort":"high","format":{"type":"json_schema","schema":{"type":"object"}}}}`,
			expectEffort:        "high",
			expectKeptFields:    []string{"output_config", "output_config.effort"},
			expectRemovedFields: []string{"output_config.format"},
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			// A nil header map models a client that sent no Anthropic-Beta header.
			var headers http.Header
			if tc.clientBetaFlags != "" {
				headers = http.Header{"Anthropic-Beta": {tc.clientBetaFlags}}
			}

			ib := &interceptionBase{
				reqPayload: mustMessagesPayload(t, tc.requestBody),
				bedrockCfg: &config.AWSBedrock{
					Model:          tc.bedrockModel,
					SmallFastModel: "anthropic.claude-haiku-3-5",
				},
				clientHeaders: headers,
				logger:        slog.Make(),
			}
			ib.augmentRequestForBedrock()

			// field reads a gjson path from the payload as it stands after augmentation.
			field := func(path string) gjson.Result {
				return gjson.GetBytes(ib.reqPayload, path)
			}

			if thinking := field("thinking.type"); tc.expectThinkingType != "" {
				require.Equal(t, tc.expectThinkingType, thinking.String())
			} else {
				require.False(t, thinking.Exists())
			}

			if budget := field("thinking.budget_tokens"); tc.expectBudgetTokens != 0 {
				require.Equal(t, tc.expectBudgetTokens, budget.Int())
			} else {
				require.False(t, budget.Exists(), "budget_tokens should not be set")
			}

			// The model is always overwritten with the configured Bedrock model.
			require.Equal(t, tc.bedrockModel, field("model").String())

			for _, path := range tc.expectRemovedFields {
				require.False(t, field(path).Exists(), "%s should be removed", path)
			}
			for _, path := range tc.expectKeptFields {
				require.True(t, field(path).Exists(), "%s should be kept", path)
			}

			if effort := field("output_config.effort"); tc.expectEffort != "" {
				require.Equal(t, tc.expectEffort, effort.String())
			} else {
				require.False(t, effort.Exists(), "output_config.effort should not be set")
			}

			require.Equal(t, tc.expectBetaValues, headers.Values("Anthropic-Beta"))
		})
	}
}
// mustMessagesPayload parses requestBody with NewRequestPayload and fails the
// test immediately if parsing returns an error.
func mustMessagesPayload(t *testing.T, requestBody string) RequestPayload {
	t.Helper()

	raw := []byte(requestBody)
	parsed, err := NewRequestPayload(raw)
	require.NoError(t, err)

	return parsed
}
// mockServerProxier is a test implementation of mcp.ServerProxier that serves a
// fixed, in-memory tool list. Init, Shutdown and CallTool are no-ops.
type mockServerProxier struct {
	tools []*mcp.Tool
}

// Compile-time check that the mock still satisfies mcp.ServerProxier. If that
// interface changes, this line fails with a clear message instead of every
// test construction site failing.
var _ mcp.ServerProxier = (*mockServerProxier)(nil)

// Init is a no-op.
func (*mockServerProxier) Init(context.Context) error {
	return nil
}

// Shutdown is a no-op.
func (*mockServerProxier) Shutdown(context.Context) error {
	return nil
}

// ListTools returns the configured tool list as-is (possibly nil).
func (m *mockServerProxier) ListTools() []*mcp.Tool {
	return m.tools
}

// GetTool returns the configured tool whose ID equals id, or nil if none matches.
func (m *mockServerProxier) GetTool(id string) *mcp.Tool {
	// Named "tool" rather than "t" to avoid confusion with *testing.T.
	for _, tool := range m.tools {
		if tool.ID == id {
			return tool
		}
	}
	return nil
}

// CallTool is a no-op that returns a nil result and a nil error.
func (*mockServerProxier) CallTool(context.Context, string, any) (*mcpgo.CallToolResult, error) {
	return nil, nil //nolint:nilnil // mock: no-op implementation
}
// TestFilterBedrockBetaFlags verifies that filterBedrockBetaFlags keeps only the
// Anthropic-Beta flags supported for the given Bedrock model. The checks cover
// both comma-joined and separate header values, and require that every
// surviving flag ends up as its own header value.
func TestFilterBedrockBetaFlags(t *testing.T) {
	t.Parallel()

	// Bedrock model IDs reused across cases.
	const (
		sonnet45 = "anthropic.claude-sonnet-4-5-20250929-v1:0"
		opus45   = "anthropic.claude-opus-4-5-20250929-v1:0"
		haiku45  = "anthropic.claude-haiku-4-5-20250929-v1:0"
		opus46   = "anthropic.claude-opus-4-6-v1"
	)

	cases := []struct {
		name  string
		model string
		in    []string // header values to add; each element is a separate value
		want  []string // expected separate header values after filtering
	}{
		{
			name:  "empty header",
			model: sonnet45,
		},
		{
			name:  "all supported flags kept",
			model: opus45,
			in:    []string{"interleaved-thinking-2025-05-14,effort-2025-11-24"},
			want:  []string{"interleaved-thinking-2025-05-14", "effort-2025-11-24"},
		},
		{
			name:  "unsupported flags removed",
			model: sonnet45,
			in:    []string{"claude-code-20250219,interleaved-thinking-2025-05-14,prompt-caching-scope-2026-01-05"},
			want:  []string{"interleaved-thinking-2025-05-14"},
		},
		{
			name:  "header removed when all flags unsupported",
			model: sonnet45,
			in:    []string{"claude-code-20250219,prompt-caching-scope-2026-01-05"},
		},
		{
			name:  "effort flag removed for non opus 4.5 model",
			model: sonnet45,
			in:    []string{"effort-2025-11-24,interleaved-thinking-2025-05-14"},
			want:  []string{"interleaved-thinking-2025-05-14"},
		},
		{
			name:  "effort flag kept for opus 4.5 model",
			model: opus45,
			in:    []string{"effort-2025-11-24,interleaved-thinking-2025-05-14"},
			want:  []string{"effort-2025-11-24", "interleaved-thinking-2025-05-14"},
		},
		{
			name:  "context management kept for sonnet 4.5",
			model: sonnet45,
			in:    []string{"context-management-2025-06-27"},
			want:  []string{"context-management-2025-06-27"},
		},
		{
			name:  "context management kept for haiku 4.5",
			model: haiku45,
			in:    []string{"context-management-2025-06-27"},
			want:  []string{"context-management-2025-06-27"},
		},
		{
			name:  "context management removed for unsupported model",
			model: opus46,
			in:    []string{"context-management-2025-06-27,interleaved-thinking-2025-05-14"},
			want:  []string{"interleaved-thinking-2025-05-14"},
		},
		{
			name:  "separate header values are handled correctly",
			model: sonnet45,
			in:    []string{"interleaved-thinking-2025-05-14", "context-management-2025-06-27"},
			want:  []string{"interleaved-thinking-2025-05-14", "context-management-2025-06-27"},
		},
		{
			name:  "mixed comma-joined and separate header values",
			model: opus45,
			in:    []string{"interleaved-thinking-2025-05-14,effort-2025-11-24", "token-efficient-tools-2025-02-19"},
			want:  []string{"interleaved-thinking-2025-05-14", "effort-2025-11-24", "token-efficient-tools-2025-02-19"},
		},
	}

	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			t.Parallel()

			hdr := make(http.Header)
			for _, value := range c.in {
				hdr.Add("Anthropic-Beta", value)
			}

			filterBedrockBetaFlags(hdr, c.model)

			// Values is nil when the header was removed entirely.
			require.Equal(t, c.want, hdr.Values("Anthropic-Beta"))
		})
	}
}
// TestProcessKeyPoolError checks how ProcessKeyPoolError maps key-pool
// exhaustion errors onto HTTP responses. Transient exhaustion becomes a 429
// (optionally carrying a retry hint), permanent exhaustion becomes a 502, and
// unrelated errors are not translated at all (nil result).
func TestProcessKeyPoolError(t *testing.T) {
	t.Parallel()

	// expectation describes the response ProcessKeyPoolError should produce.
	type expectation struct {
		status     int
		retryAfter time.Duration
	}

	cases := []struct {
		name  string
		input error
		// want is nil when the error must not be recognised as pool exhaustion.
		want *expectation
	}{
		{
			// Transient with valid keys present: 429, no Retry-After.
			name:  "transient_zero_retry_after",
			input: &keypool.TransientKeyPoolError{},
			want:  &expectation{status: http.StatusTooManyRequests},
		},
		{
			// Transient with cooldown: 429, Retry-After set.
			name:  "transient_with_retry_after",
			input: &keypool.TransientKeyPoolError{RetryAfter: 5 * time.Second},
			want: &expectation{
				status:     http.StatusTooManyRequests,
				retryAfter: 5 * time.Second,
			},
		},
		{
			// Permanent: 502 api_error.
			name:  "permanent_returns_502",
			input: keypool.ErrPermanentKeyPool,
			want:  &expectation{status: http.StatusBadGateway},
		},
		{
			// Anything else: not a pool-exhaustion error.
			name:  "non_pool_exhaustion_error_returns_nil",
			input: xerrors.New("some other error"),
		},
	}

	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			t.Parallel()

			respErr := ProcessKeyPoolError(c.input)
			if c.want == nil {
				require.Nil(t, respErr)
				return
			}

			require.NotNil(t, respErr)
			assert.Equal(t, c.want.status, respErr.StatusCode)
			assert.Equal(t, c.want.retryAfter, respErr.RetryAfter)
		})
	}
}
func TestMarkKeyOnError(t *testing.T) {
t.Parallel()
tests := []struct {
name string
err error
expectedReturn bool
expectedState keypool.KeyState
}{
{
// Not an *anthropic.Error: no status code to act on.
name: "non_api_error_returns_false",
err: xerrors.New("network failure"),
expectedReturn: false,
expectedState: keypool.KeyStateValid,
},
{
// Rate-limited: temporary cooldown.
name: "429_marks_temporary",
err: &anthropic.Error{StatusCode: http.StatusTooManyRequests, Response: &http.Response{StatusCode: http.StatusTooManyRequests}},
expectedReturn: true,
expectedState: keypool.KeyStateTemporary,
},
{
// Auth failure: mark permanent.
name: "401_marks_permanent",
err: &anthropic.Error{StatusCode: http.StatusUnauthorized, Response: &http.Response{StatusCode: http.StatusUnauthorized}},
expectedReturn: true,
expectedState: keypool.KeyStatePermanent,
},
{
// Auth forbidden: mark permanent.
name: "403_marks_permanent",
err: &anthropic.Error{StatusCode: http.StatusForbidden, Response: &http.Response{StatusCode: http.StatusForbidden}},
expectedReturn: true,
expectedState: keypool.KeyStatePermanent,
},
{
// Server errors are not key-specific.
name: "500_does_not_mark",
err: &anthropic.Error{StatusCode: http.StatusInternalServerError, Response: &http.Response{StatusCode: http.StatusInternalServerError}},
expectedReturn: false,
expectedState: keypool.KeyStateValid,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
pool, err := keypool.New([]string{"key-0"}, quartz.NewMock(t))
require.NoError(t, err)
key, err := pool.Walker().Next()
require.NoError(t, err)
base := &interceptionBase{cfg: config.Anthropic{KeyPool: pool}, logger: slog.Make()}
got := base.markKeyOnError(context.Background(), key, tc.err)
assert.Equal(t, tc.expectedReturn, got)
assert.Equal(t, tc.expectedState, key.State())
})
}
}
// TestWriteUpstreamError verifies that writeUpstreamError writes the status
// code and a JSON error envelope. It also checks that a Retry-After header
// (whole seconds, sub-second values rounded up) is emitted only when the
// error carries a positive retry hint, and is fully omitted otherwise.
func TestWriteUpstreamError(t *testing.T) {
	t.Parallel()
	tests := []struct {
		name         string
		respErr      *ResponseError
		expectStatus int
		// Empty string means the header should be absent.
		expectRetryAfter string
		// Substring expected in the marshaled body. Empty means no body check.
		expectBodyContains string
	}{
		{
			// Standard error: status and JSON body written.
			name:               "writes_status_and_body",
			respErr:            newErrorResponse("upstream failed", "api_error", http.StatusBadGateway, 0),
			expectStatus:       http.StatusBadGateway,
			expectBodyContains: `"upstream failed"`,
		},
		{
			// Whole-second retryAfter: emitted as integer seconds.
			name:             "retry_after_in_seconds",
			respErr:          newErrorResponse("rate limited", "rate_limit_error", http.StatusTooManyRequests, 60*time.Second),
			expectStatus:     http.StatusTooManyRequests,
			expectRetryAfter: "60",
		},
		{
			// 500ms rounds up to Retry-After: 1.
			name:             "retry_after_500ms_rounds_up_to_one",
			respErr:          newErrorResponse("rate limited", "rate_limit_error", http.StatusTooManyRequests, 500*time.Millisecond),
			expectStatus:     http.StatusTooManyRequests,
			expectRetryAfter: "1",
		},
		{
			// 200ms rounds up to Retry-After: 1.
			name:             "retry_after_200ms_rounds_up_to_one",
			respErr:          newErrorResponse("rate limited", "rate_limit_error", http.StatusTooManyRequests, 200*time.Millisecond),
			expectStatus:     http.StatusTooManyRequests,
			expectRetryAfter: "1",
		},
		{
			// Negative retryAfter: header omitted.
			name:             "negative_retry_after_omits_header",
			respErr:          newErrorResponse("rate limited", "rate_limit_error", http.StatusTooManyRequests, -1*time.Second),
			expectStatus:     http.StatusTooManyRequests,
			expectRetryAfter: "",
		},
	}
	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			base := &interceptionBase{logger: slog.Make()}
			w := httptest.NewRecorder()
			base.writeUpstreamError(w, tc.respErr)
			assert.Equal(t, tc.expectStatus, w.Code, "status code")
			assert.Equal(t, "application/json", w.Header().Get("Content-Type"), "Content-Type header")
			if tc.expectRetryAfter == "" {
				// Header().Get returns "" both when the header is missing and when
				// it is present with an empty value. Checking Values ensures the
				// header is truly absent, as the table contract promises.
				assert.Empty(t, w.Header().Values("Retry-After"), "Retry-After header should be absent")
			} else {
				assert.Equal(t, tc.expectRetryAfter, w.Header().Get("Retry-After"), "Retry-After header")
			}
			assert.Contains(t, w.Body.String(), `"type":"error"`, "outer error envelope")
			if tc.expectBodyContains != "" {
				assert.Contains(t, w.Body.String(), tc.expectBodyContains, "response body")
			}
		})
	}
}