Files
coder/aibridge/intercept/messages/reqpayload.go
T
Paweł Banaszewski e00e85765b chore: move aibridge library code into coder repo (#24190)
This PR merges code from `coder/aibridge` repository into `coder/coder`.
It was split into 4 PRs for easier review but stacked PRs will need to
be merged into this PR so all checks pass.

* https://github.com/coder/coder/pull/24190 -> raw code copy (this PR,
before merging PRs on top of it, it was just 1 commit:
https://github.com/coder/coder/commit/70d33f33200c7e77df910957595715f81f9bec24)
* https://github.com/coder/coder/pull/24570 -> update imports in
`coder/coder` to use copied code
* https://github.com/coder/coder/pull/24586 -> linter fixes and CI
integration (also added README.md)
* https://github.com/coder/coder/pull/24571 -> added exclude to
scripts/check_emdash.sh check

Original PR message (before PR squash):
Moves coder/aibridge code into coder/coder repository.

Omitted files:

- `go.mod`, `go.sum`, `.gitignore`, `.github/workflows/ci.yml`,
`Makefile`, `LICENSE`, `README.md` (modified README.md is added later)
- `.github`, `example`, `buildinfo`, `scripts` directories

Simple verification script (will list omitted files)

```
tmp=$(mktemp -d)
echo "$tmp"
git clone --depth=1 https://github.com/coder/aibridge "$tmp/aibridge"
git clone --depth=1 --branch pb/aibridge-code-move https://github.com/coder/coder "$tmp/coder"
diff -rq --exclude=.git "$tmp/aibridge" "$tmp/coder/aibridge"
# rm -rf "$tmp"
```
2026-04-22 17:01:01 +02:00

413 lines
14 KiB
Go

package messages
import (
"bytes"
"encoding/json"
"net/http"
"slices"
"github.com/anthropics/anthropic-sdk-go"
"github.com/anthropics/anthropic-sdk-go/shared/constant"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"golang.org/x/xerrors"
)
// JSON paths addressed from the root of a Messages API request body. Entries
// containing a dot reach a key nested inside a top-level object.
const (
	messagesReqPathMessages          = "messages"
	messagesReqPathMaxTokens         = "max_tokens"
	messagesReqPathModel             = "model"
	messagesReqPathStream            = "stream"
	messagesReqPathTools             = "tools"
	messagesReqPathMetadata          = "metadata"
	messagesReqPathServiceTier       = "service_tier"
	messagesReqPathContainer         = "container"
	messagesReqPathInferenceGeo      = "inference_geo"
	messagesReqPathContextManagement = "context_management"

	messagesReqPathOutputConfig       = "output_config"
	messagesReqPathOutputConfigEffort = "output_config.effort"

	messagesReqPathThinking             = "thinking"
	messagesReqPathThinkingBudgetTokens = "thinking.budget_tokens"
	messagesReqPathThinkingType         = "thinking.type"

	messagesReqPathToolChoice                = "tool_choice"
	messagesReqPathToolChoiceDisableParallel = "tool_choice.disable_parallel_tool_use"
	messagesReqPathToolChoiceType            = "tool_choice.type"
)

// Key names looked up relative to a nested object (a message or one of its
// content items) rather than from the request root.
const (
	messagesReqFieldContent   = "content"
	messagesReqFieldRole      = "role"
	messagesReqFieldText      = "text"
	messagesReqFieldToolUseID = "tool_use_id"
	messagesReqFieldType      = "type"
)
// Values of the request's thinking.type field that this package compares
// against or writes.
const (
	// constAdaptive is the value rewritten by convertAdaptiveThinkingForBedrock.
	constAdaptive = "adaptive"
	// constDisabled is written when the computed thinking budget is too small.
	constDisabled = "disabled"
	// constEnabled is written together with an explicit budget_tokens value.
	constEnabled = "enabled"
)
// Plain-string forms of the SDK's typed constants, so they can be compared
// directly with string values read out of the raw request JSON.
var (
	constAny        = string(constant.ValueOf[constant.Any]())
	constAuto       = string(constant.ValueOf[constant.Auto]())
	constNone       = string(constant.ValueOf[constant.None]())
	constText       = string(constant.ValueOf[constant.Text]())
	constTool       = string(constant.ValueOf[constant.Tool]())
	constToolResult = string(constant.ValueOf[constant.ToolResult]())
	constUser       = string(anthropic.MessageParamRoleUser)
)

// bedrockUnsupportedFields lists top-level Anthropic Messages API fields that
// the Bedrock request body schema does not define. Forwarding any of them to
// Bedrock is rejected with a 400 "Extra inputs are not permitted" error, so
// they are always removed.
//
// Anthropic API fields: https://platform.claude.com/docs/en/api/messages/create
// Bedrock request body: https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages-request-response.html
var bedrockUnsupportedFields = []string{
	messagesReqPathMetadata,
	messagesReqPathServiceTier,
	messagesReqPathContainer,
	messagesReqPathInferenceGeo,
}

// bedrockBetaGatedFields maps a top-level body field to the beta flag that
// unlocks it. The field survives only when that flag appears in the
// (already-filtered) Anthropic-Beta header; otherwise it is removed.
// Model-specific beta flags must have been dropped from the header beforehand
// (see filterBedrockBetaFlags).
var bedrockBetaGatedFields = map[string]string{
	// output_config needs the effort beta (Opus 4.5 only).
	messagesReqPathOutputConfig: "effort-2025-11-24",
	// context_management needs the context-management beta (Sonnet 4.5, Haiku 4.5).
	messagesReqPathContextManagement: "context-management-2025-06-27",
}
// RequestPayload is the raw JSON bytes of an Anthropic Messages API request.
// Its methods read individual fields from, and apply targeted rewrites to,
// those bytes; each rewrite returns the updated payload as a new
// RequestPayload value so the body can still be passed through upstream.
type RequestPayload []byte
// NewRequestPayload wraps raw as a RequestPayload after checking that it is
// usable as a request body.
//
// It returns an error when raw is empty or whitespace-only, or when it is not
// syntactically valid JSON. On success the payload is a conversion of raw, not
// a copy.
func NewRequestPayload(raw []byte) (RequestPayload, error) {
	switch {
	case len(bytes.TrimSpace(raw)) == 0:
		return nil, xerrors.New("messages empty request body")
	case !json.Valid(raw):
		return nil, xerrors.New("messages invalid JSON request body")
	}
	return RequestPayload(raw), nil
}
// Stream reports whether the request sets the top-level "stream" field to the
// JSON boolean true. A missing field, or one that is not a JSON boolean,
// yields false.
func (p RequestPayload) Stream() bool {
	flag := gjson.GetBytes(p, messagesReqPathStream)
	return flag.IsBool() && flag.Bool()
}
// model returns the string value of the top-level "model" field. It reads the
// result's Str field, which gjson leaves empty for values that are not JSON
// strings, so a missing or non-string model yields "".
func (p RequestPayload) model() string {
	name := gjson.GetBytes(p, messagesReqPathModel)
	return name.Str
}
// correlatingToolCallID returns the tool_use_id of the last tool_result content
// item in the final message of the request, or nil when there is none.
//
// Only the final message is inspected, and only when its content is an array.
// Items are scanned from the end; tool_result items with an empty tool_use_id
// are skipped.
func (p RequestPayload) correlatingToolCallID() *string {
	msgs := gjson.GetBytes(p, messagesReqPathMessages)
	if !msgs.IsArray() {
		return nil
	}

	all := msgs.Array()
	if len(all) == 0 {
		return nil
	}

	body := all[len(all)-1].Get(messagesReqFieldContent)
	if !body.IsArray() {
		return nil
	}

	parts := body.Array()
	for i := len(parts) - 1; i >= 0; i-- {
		part := parts[i]
		if part.Get(messagesReqFieldType).String() == constToolResult {
			if id := part.Get(messagesReqFieldToolUseID).String(); id != "" {
				return &id
			}
		}
	}
	return nil
}
// lastUserPrompt extracts the prompt text from the final message of the
// request when that message has the user role.
//
// The boolean result reports whether a prompt was found. String content is
// returned as-is; for array content the last item of type "text" whose text
// field is a JSON string is returned. A missing/null messages or content
// field, an empty messages array, a non-user final message, or content with no
// usable text item all yield ("", false, nil).
//
// An error is returned only when messages is present but not an array, or when
// content is present but neither a string nor an array.
func (p RequestPayload) lastUserPrompt() (string, bool, error) {
	msgs := gjson.GetBytes(p, messagesReqPathMessages)
	switch {
	case !msgs.Exists(), msgs.Type == gjson.Null:
		return "", false, nil
	case !msgs.IsArray():
		return "", false, xerrors.Errorf("unexpected messages type: %s", msgs.Type)
	}

	all := msgs.Array()
	if len(all) == 0 {
		return "", false, nil
	}

	last := all[len(all)-1]
	if role := last.Get(messagesReqFieldRole).String(); role != constUser {
		return "", false, nil
	}

	body := last.Get(messagesReqFieldContent)
	switch {
	case !body.Exists(), body.Type == gjson.Null:
		return "", false, nil
	case body.Type == gjson.String:
		return body.String(), true, nil
	case !body.IsArray():
		return "", false, xerrors.Errorf("unexpected message content type: %s", body.Type)
	}

	// Walk backwards so the most recent text item wins.
	parts := body.Array()
	for i := len(parts) - 1; i >= 0; i-- {
		part := parts[i]
		if part.Get(messagesReqFieldType).String() != constText {
			continue
		}
		if txt := part.Get(messagesReqFieldText); txt.Type == gjson.String {
			return txt.String(), true, nil
		}
	}
	return "", false, nil
}
// injectTools returns the payload with injected placed ahead of the tools
// already present in the request's "tools" array.
//
// With no injected tools the payload is returned unchanged. An error is
// returned when the existing tools cannot be read or the merged list cannot be
// written back.
func (p RequestPayload) injectTools(injected []anthropic.ToolUnionParam) (RequestPayload, error) {
	if len(injected) == 0 {
		return p, nil
	}

	current, err := p.tools()
	if err != nil {
		return p, xerrors.Errorf("get existing tools: %w", err)
	}

	// Both element kinds are held as json.Marshaler so the differently typed
	// slices can share one list: each element is marshalled once when the
	// list is written, and the json.RawMessage entries are emitted as-is
	// instead of being decoded and re-encoded.
	merged := make([]json.Marshaler, len(injected), len(injected)+len(current))
	for i := range injected {
		merged[i] = injected[i]
	}
	for i := range current {
		merged = append(merged, current[i])
	}
	return p.set(messagesReqPathTools, merged)
}
// disableParallelToolCalls returns the payload with
// tool_choice.disable_parallel_tool_use set to true.
//
// When tool_choice is missing or null, or is an object without a type, the
// type is set to "auto" first. For type "none" the payload is returned
// unchanged. An error is returned when tool_choice is not an object, when its
// type is not a string, when the type holds an unrecognised value, or when a
// write fails.
func (p RequestPayload) disableParallelToolCalls() (RequestPayload, error) {
	choice := gjson.GetBytes(p, messagesReqPathToolChoice)

	// No tool_choice at all is treated as auto.
	// See https://platform.claude.com/docs/en/agents-and-tools/tool-use/implement-tool-use#parallel-tool-use.
	if !choice.Exists() || choice.Type == gjson.Null {
		withType, err := p.set(messagesReqPathToolChoiceType, constAuto)
		if err != nil {
			return p, xerrors.Errorf("set tool choice type: %w", err)
		}
		return withType.set(messagesReqPathToolChoiceDisableParallel, true)
	}

	if !choice.IsObject() {
		return p, xerrors.Errorf("unsupported tool_choice type: %s", choice.Type)
	}

	kind := gjson.GetBytes(p, messagesReqPathToolChoiceType)
	if kind.Exists() && kind.Type != gjson.String {
		return p, xerrors.Errorf("unsupported tool_choice.type type: %s", kind.Type)
	}

	mode := kind.String()
	if mode == "" {
		// An object without a type is defaulted to auto as well.
		withType, err := p.set(messagesReqPathToolChoiceType, constAuto)
		if err != nil {
			return p, xerrors.Errorf("set tool_choice.type: %w", err)
		}
		return withType.set(messagesReqPathToolChoiceDisableParallel, true)
	}
	if mode == constAuto || mode == constAny || mode == constTool {
		return p.set(messagesReqPathToolChoiceDisableParallel, true)
	}
	if mode == constNone {
		return p, nil
	}
	return p, xerrors.Errorf("unsupported tool_choice.type value: %q", mode)
}
// appendedMessages returns the payload with newMessages added after the
// messages already present in the request's "messages" array.
//
// With no new messages the payload is returned unchanged. An error is returned
// when the existing messages cannot be read or the combined list cannot be
// written back.
func (p RequestPayload) appendedMessages(newMessages []anthropic.MessageParam) (RequestPayload, error) {
	if len(newMessages) == 0 {
		return p, nil
	}

	current, err := p.messages()
	if err != nil {
		return p, xerrors.Errorf("get existing messages: %w", err)
	}

	// Both element kinds are held as json.Marshaler so the differently typed
	// slices can share one list: each element is marshalled once when the
	// list is written, and the json.RawMessage entries are emitted as-is
	// instead of being decoded and re-encoded.
	combined := make([]json.Marshaler, len(current), len(current)+len(newMessages))
	for i := range current {
		combined[i] = current[i]
	}
	for i := range newMessages {
		combined = append(combined, newMessages[i])
	}
	return p.set(messagesReqPathMessages, combined)
}
// withModel returns the payload with the top-level "model" field set to model.
// The error, if any, comes from writing the field.
func (p RequestPayload) withModel(model string) (RequestPayload, error) {
	updated, err := p.set(messagesReqPathModel, model)
	return updated, err
}
// messages returns the raw JSON of each element of the request's "messages"
// array. A missing or null field yields (nil, nil); any other non-array value
// is reported as an error.
func (p RequestPayload) messages() ([]json.RawMessage, error) {
	node := gjson.GetBytes(p, messagesReqPathMessages)
	switch {
	case !node.Exists() || node.Type == gjson.Null:
		return nil, nil
	case node.IsArray():
		return p.resultToRawMessage(node.Array()), nil
	default:
		return nil, xerrors.Errorf("unsupported messages type: %s", node.Type)
	}
}
// tools returns the raw JSON of each element of the request's "tools" array.
// A missing or null field yields (nil, nil); any other non-array value is
// reported as an error.
func (p RequestPayload) tools() ([]json.RawMessage, error) {
	node := gjson.GetBytes(p, messagesReqPathTools)
	switch {
	case !node.Exists() || node.Type == gjson.Null:
		return nil, nil
	case node.IsArray():
		return p.resultToRawMessage(node.Array()), nil
	default:
		return nil, xerrors.Errorf("unsupported tools type: %s", node.Type)
	}
}
// resultToRawMessage converts each gjson result into a json.RawMessage holding
// that result's raw JSON text. The returned slice always has len(items)
// elements and is never nil.
func (RequestPayload) resultToRawMessage(items []gjson.Result) []json.RawMessage {
	// The conversion is required because gjson.Result does not implement
	// json.Marshaler: marshalling one directly would serialise its struct
	// fields rather than the JSON it represents.
	out := make([]json.RawMessage, len(items))
	for i := range items {
		out[i] = json.RawMessage(items[i].Raw)
	}
	return out
}
// convertAdaptiveThinkingForBedrock rewrites a request whose thinking.type is
// "adaptive" into one that uses "enabled" thinking with an explicit
// budget_tokens, for Bedrock models that do not support the adaptive type.
//
// Requests with any other thinking.type are returned unchanged. The whole
// "thinking" object is replaced. An error is returned when max_tokens is
// missing or not positive, or when the write fails.
func (p RequestPayload) convertAdaptiveThinkingForBedrock() (RequestPayload, error) {
	if gjson.GetBytes(p, messagesReqPathThinkingType).String() != constAdaptive {
		return p, nil
	}

	limit := gjson.GetBytes(p, messagesReqPathMaxTokens).Int()
	if limit <= 0 {
		// max_tokens is mandatory in the Messages API.
		return p, xerrors.New("max_tokens: field required")
	}

	// Enabled thinking needs a budget_tokens value, so one is derived
	// heuristically as a share of max_tokens chosen by the effort level.
	// Effort-to-share mapping adapted from OpenRouter:
	// https://openrouter.ai/docs/guides/best-practices/reasoning-tokens#reasoning-effort-level
	share := 0.8 // "high", which is also the default when effort is absent
	switch gjson.GetBytes(p, messagesReqPathOutputConfigEffort).String() {
	case "low":
		share = 0.2
	case "medium":
		share = 0.5
	case "max":
		share = 0.95
	}

	// budget_tokens has to be ≥ 1024 and < max_tokens. When the derived budget
	// falls below the minimum, thinking is switched off instead of forcing a
	// budget large enough to starve the output.
	// https://platform.claude.com/docs/en/api/messages/create#create.thinking
	// https://platform.claude.com/docs/en/build-with-claude/extended-thinking#how-to-use-extended-thinking
	budget := int64(float64(limit) * share)

	var thinking any
	if budget < 1024 {
		thinking = map[string]string{"type": constDisabled}
	} else {
		thinking = map[string]any{
			"type":          constEnabled,
			"budget_tokens": budget,
		}
	}
	return p.set(messagesReqPathThinking, thinking)
}
// removeUnsupportedBedrockFields strips top-level fields that Bedrock does not
// support from the payload. Fields listed in bedrockUnsupportedFields are
// always removed; fields in bedrockBetaGatedFields are removed only when their
// beta flag is absent from the Anthropic-Beta header.
//
// Each Anthropic-Beta header value is compared whole against the required
// flag, so model-specific beta flags must already be filtered from the header
// before calling this method (see filterBedrockBetaFlags).
//
// The payload must be a JSON object (or null); otherwise an error is returned
// together with the unmodified payload. The result is re-encoded by
// encoding/json, so top-level keys come out in sorted order.
func (p RequestPayload) removeUnsupportedBedrockFields(headers http.Header) (RequestPayload, error) {
	// Decode only the top level and keep every value as raw JSON. Decoding
	// into map[string]any would turn all numbers in the body into float64,
	// silently changing integers that do not fit in 53 bits when the payload
	// is marshalled again.
	var payloadMap map[string]json.RawMessage
	if err := json.Unmarshal(p, &payloadMap); err != nil {
		return p, xerrors.Errorf("failed to unmarshal request payload when removing unsupported Bedrock fields: %w", err)
	}

	// Always strip unconditionally unsupported fields.
	for _, field := range bedrockUnsupportedFields {
		delete(payloadMap, field)
	}

	// Strip beta-gated fields only when their beta flag is missing.
	betaValues := headers.Values("Anthropic-Beta")
	for field, requiredFlag := range bedrockBetaGatedFields {
		if !slices.Contains(betaValues, requiredFlag) {
			delete(payloadMap, field)
		}
	}

	result, err := json.Marshal(payloadMap)
	if err != nil {
		return p, xerrors.Errorf("failed to marshal request payload when removing unsupported Bedrock fields: %w", err)
	}
	return RequestPayload(result), nil
}
// set writes value at the given sjson path and returns the resulting payload.
// On failure the receiver is returned unchanged together with an error that
// names the path.
func (p RequestPayload) set(path string, value any) (RequestPayload, error) {
	updated, err := sjson.SetBytes(p, path, value)
	if err == nil {
		return RequestPayload(updated), nil
	}
	return p, xerrors.Errorf("set %s: %w", path, err)
}