Files
coder/aibridge/intercept/responses/reqpayload_test.go
T
Paweł Banaszewski e00e85765b chore: move aibridge library code into coder repo (#24190)
This PR merges code from `coder/aibridge` repository into `coder/coder`.
It was split into 4 PRs for easier review but stacked PRs will need to
be merged into this PR so all checks pass.

* https://github.com/coder/coder/pull/24190 -> raw code copy (this PR,
before merging PRs on top of it, it was just 1 commit:
https://github.com/coder/coder/commit/70d33f33200c7e77df910957595715f81f9bec24)
* https://github.com/coder/coder/pull/24570 -> update imports in
`coder/coder` to use copied code
* https://github.com/coder/coder/pull/24586 -> linter fixes and CI
integration (also added README.md)
* https://github.com/coder/coder/pull/24571 -> added exclude to
scripts/check_emdash.sh check

Original PR message (before PR squash):
Moves coder/aibridge code into coder/coder repository.

Omitted files:

- `go.mod`, `go.sum`, `.gitignore`, `.github/workflows/ci.yml,`
`Makefile`, `LICENSE`, `README.md` (modified README.md is added later)
- `.github`, `example`, `buildinfo,` `scripts` directories

Simple verification script (will list omitted files)

```
tmp=$(mktemp -d)
echo "$tmp"
git clone --depth=1 https://github.com/coder/aibridge "$tmp/aibridge"
git clone --depth=1 --branch pb/aibridge-code-move https://github.com/coder/coder "$tmp/coder"
diff -rq --exclude=.git "$tmp/aibridge" "$tmp/coder/aibridge"
# rm -rf "$tmp"
```
2026-04-22 17:01:01 +02:00

528 lines
15 KiB
Go

package responses //nolint:testpackage // tests unexported internals
import (
"encoding/json"
"fmt"
"testing"
"github.com/openai/openai-go/v3"
"github.com/openai/openai-go/v3/responses"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/aibridge/fixtures"
"github.com/coder/coder/v2/aibridge/utils"
)
func TestNewRequestPayload(t *testing.T) {
t.Parallel()
payloadWithWrongTypes := []byte(`{"model":123,"stream":"yes","input":42,"background":"nope"}`)
tests := []struct {
name string
raw []byte
want []byte
model string
stream bool
background bool
err string
}{
{
name: "empty payload",
raw: nil,
want: nil,
err: "empty request body",
},
{
name: "invalid json",
raw: []byte(`{broken`),
want: nil,
err: "invalid JSON payload",
},
{
// RequestPayload just checks for JSON validity,
// schema errors are not surfaced here and
// the original body is preserved for upstream handling
// similar to how reverse proxy would behave.
name: "wrong field types still wrap",
raw: payloadWithWrongTypes,
want: payloadWithWrongTypes,
model: "123",
stream: false,
background: false,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
payload, err := NewRequestPayload(tc.raw)
if tc.err != "" {
require.ErrorContains(t, err, tc.err)
assert.Nil(t, payload)
return
}
require.NoError(t, err)
require.NotNil(t, payload)
assert.EqualValues(t, tc.want, payload)
assert.Equal(t, tc.model, payload.model())
assert.Equal(t, tc.stream, payload.Stream())
assert.Equal(t, tc.background, payload.background())
})
}
}
func TestCorrelatingToolCallID(t *testing.T) {
t.Parallel()
tests := []struct {
name string
payload []byte
wantCall *string
}{
{
name: "no input items",
payload: []byte(`{"model":"gpt-4o"}`),
},
{
name: "empty input array",
payload: []byte(`{"model":"gpt-4o","input":[]}`),
},
{
name: "no function_call_output items",
payload: []byte(`{"model":"gpt-4o","input":[{"role":"user","content":"hi"}]}`),
},
{
name: "single function_call_output",
payload: []byte(`{"model":"gpt-4o","input":[{"role":"user","content":"hi"},{"type":"function_call_output","call_id":"call_abc","output":"result"}]}`),
wantCall: utils.PtrTo("call_abc"),
},
{
name: "multiple function_call_outputs returns last",
payload: []byte(`{"model":"gpt-4o","input":[{"type":"function_call_output","call_id":"call_first","output":"r1"},{"role":"user","content":"hi"},{"type":"function_call_output","call_id":"call_second","output":"r2"}]}`),
wantCall: utils.PtrTo("call_second"),
},
{
name: "last input is not a tool result",
payload: []byte(`{"model":"gpt-4o","input":[{"type":"function_call_output","call_id":"call_first","output":"r1"},{"role":"user","content":"hi"}]}`),
},
{
name: "missing call id",
payload: []byte(`{"input":[{"type":"function_call_output","output":"ok"}]}`),
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
callID := mustPayload(t, tc.payload).correlatingToolCallID()
assert.Equal(t, tc.wantCall, callID)
})
}
}
func TestLastUserPrompt(t *testing.T) {
t.Parallel()
tests := []struct {
name string
reqPayload []byte
expect string
found bool
expectErr string
}{
{
name: "no input",
reqPayload: []byte(`{}`),
found: false,
},
{
name: "input null",
reqPayload: []byte(`{"input": null}`),
found: false,
},
{
name: "empty input array",
reqPayload: []byte(`{"input": []}`),
found: false,
},
{
name: "input empty string",
reqPayload: []byte(`{"input": ""}`),
expect: "",
found: true,
},
{
name: "input array content empty string",
reqPayload: []byte(`{"input": [{"role": "user", "content": ""}]}`),
expect: "",
found: true,
},
{
name: "input array content array empty string",
reqPayload: []byte(`{"input": [ { "role": "user", "content": [{"type": "input_text", "text": ""}] } ] }`),
expect: "",
found: true,
},
{
name: "input array content array multiple inputs",
reqPayload: []byte(`{"input": [ { "role": "user", "content": [{"type": "input_text", "text": "a"}, {"type": "input_text", "text": "b"}] } ] }`),
expect: "a\nb",
found: true,
},
{
name: "simple string input",
reqPayload: fixtures.Request(t, fixtures.OaiResponsesBlockingSimple),
expect: "tell me a joke",
found: true,
},
{
name: "array single input string",
reqPayload: fixtures.Request(t, fixtures.OaiResponsesBlockingSingleBuiltinTool),
expect: "Is 3 + 5 a prime number? Use the add function to calculate the sum.",
found: true,
},
{
name: "array multiple items content objects",
reqPayload: fixtures.Request(t, fixtures.OaiResponsesStreamingCodex),
expect: "hello",
found: true,
},
{
name: "input integer",
reqPayload: []byte(`{"input": 123}`),
expectErr: "unexpected input type",
},
{
name: "no user role",
reqPayload: []byte(`{"input": [{"role": "assistant", "content": "hello"}]}`),
found: false,
},
{
name: "user with empty content array",
reqPayload: []byte(`{"input": [{"role": "user", "content": []}]}`),
found: false,
},
{
name: "user content missing",
reqPayload: []byte(`{"input": [{"role": "user"}]}`),
found: false,
},
{
name: "user content null",
reqPayload: []byte(`{"input": [{"role": "user", "content": null}]}`),
found: false,
},
{
name: "input array integer",
reqPayload: []byte(`{"input": [{"role": "user", "content": 123}]}`),
expectErr: "unexpected input content type",
},
{
name: "user with non input_text content",
reqPayload: []byte(`{"input": [{"role": "user", "content": [{"type": "input_image", "url": "http://example.com/img.png"}]}]}`),
found: false,
},
{
name: "user content not last",
reqPayload: []byte(`{"input": [ {"role": "user", "content":"input"}, {"role": "assistant", "content": "hello"} ]}`),
found: false,
},
{
name: "input array content array integer",
reqPayload: []byte(`{"input": [ { "role": "user", "content": [{"type": "input_text", "text": 123}] } ] }`),
found: false,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
prompt, promptFound, err := mustPayload(t, tc.reqPayload).lastUserPrompt(t.Context(), slog.Make())
if tc.expectErr != "" {
require.ErrorContains(t, err, tc.expectErr)
return
}
require.NoError(t, err)
require.Equal(t, tc.expect, prompt)
require.Equal(t, tc.found, promptFound)
})
}
}
func TestInjectTools(t *testing.T) {
t.Parallel()
tests := []struct {
name string
raw []byte
injected []responses.ToolUnionParam
wantNames []string
wantErr string
wantSame bool
}{
{
name: "appends to existing tools",
raw: []byte(`{"model":"gpt-4o","input":"hello","tools":[{"type":"function","name":"existing"}]}`),
injected: []responses.ToolUnionParam{injectedFunctionTool("injected")},
wantNames: []string{"existing", "injected"},
},
{
name: "adds tools when none exist",
raw: []byte(`{"model":"gpt-4o","input":"hello"}`),
injected: []responses.ToolUnionParam{injectedFunctionTool("injected")},
wantNames: []string{"injected"},
},
{
name: "adds to empty tools array",
raw: []byte(`{"model":"gpt-4o","input":"hello","tools":[]}`),
injected: []responses.ToolUnionParam{injectedFunctionTool("injected")},
wantNames: []string{"injected"},
},
{
name: "appends multiple injected tools",
raw: []byte(`{"model":"gpt-4o","input":"hello","tools":[{"type":"function","name":"existing"}]}`),
injected: []responses.ToolUnionParam{
injectedFunctionTool("injected-one"),
injectedFunctionTool("injected-two"),
},
wantNames: []string{"existing", "injected-one", "injected-two"},
},
{
name: "empty injected tools is no op",
raw: []byte(`{"model":"gpt-4o","input":"hello","tools":[{"type":"function","name":"existing"}]}`),
wantSame: true,
},
{
name: "errors on unsupported tools shape",
raw: []byte(`{"model":"gpt-4o","input":"hello","tools":"bad"}`),
injected: []responses.ToolUnionParam{injectedFunctionTool("injected")},
wantErr: "failed to get existing tools: unsupported 'tools' type: String",
wantSame: true,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
p := mustPayload(t, tc.raw)
updated, err := p.injectTools(tc.injected)
if tc.wantErr != "" {
require.EqualError(t, err, tc.wantErr)
} else {
require.NoError(t, err)
}
if tc.wantSame {
require.EqualValues(t, tc.raw, updated)
}
for i, wantName := range tc.wantNames {
path := fmt.Sprintf("tools.%d.name", i) // name of the i-th element in tools array
require.Equal(t, wantName, gjson.GetBytes(updated, path).String())
}
})
}
}
func TestDisableParallelToolCalls(t *testing.T) {
t.Parallel()
tests := []struct {
name string
raw []byte
}{
{
name: "sets flag when not present",
raw: []byte(`{"model":"gpt-4o"}`),
},
{
name: "overrides when already true",
raw: []byte(`{"model":"gpt-4o","parallel_tool_calls":true}`),
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
p := mustPayload(t, tc.raw)
updated, err := p.disableParallelToolCalls()
require.NoError(t, err)
assert.False(t, gjson.GetBytes(updated, "parallel_tool_calls").Bool())
})
}
}
func TestAppendInputItems(t *testing.T) {
t.Parallel()
tests := []struct {
name string
raw []byte
items []responses.ResponseInputItemUnionParam
wantErr string
wantSame bool
wantPaths map[string]string
}{
{
name: "string input becomes user message",
raw: []byte(`{"model":"gpt-4o","input":"hello"}`),
items: []responses.ResponseInputItemUnionParam{responses.ResponseInputItemParamOfFunctionCallOutput("call_123", "done")},
wantPaths: map[string]string{
"input.0.role": "user",
"input.0.content": "hello",
"input.1.type": "function_call_output",
"input.1.call_id": "call_123",
},
},
{
name: "array input is preserved and appended",
raw: []byte(`{"model":"gpt-4o","input":[{"role":"user","content":"hello"}]}`),
items: []responses.ResponseInputItemUnionParam{responses.ResponseInputItemParamOfFunctionCallOutput("call_123", "done")},
wantPaths: map[string]string{
"input.0.content": "hello",
"input.1.call_id": "call_123",
},
},
{
name: "unsupported input shape errors during rewrite",
raw: []byte(`{"model":"gpt-4o","input":123}`),
items: []responses.ResponseInputItemUnionParam{responses.ResponseInputItemParamOfFunctionCallOutput("call_123", "done")},
wantErr: "failed to get existing 'input' items: unsupported 'input' type: Number",
wantSame: true,
},
{
name: "missing input creates appended input",
raw: []byte(`{"model":"gpt-4o"}`),
items: []responses.ResponseInputItemUnionParam{responses.ResponseInputItemParamOfFunctionCallOutput("call_123", "done")},
wantPaths: map[string]string{
"input.0.type": "function_call_output",
"input.0.call_id": "call_123",
},
},
{
name: "null input creates appended input",
raw: []byte(`{"model":"gpt-4o","input":null}`),
items: []responses.ResponseInputItemUnionParam{responses.ResponseInputItemParamOfFunctionCallOutput("call_123", "done")},
wantPaths: map[string]string{
"input.0.type": "function_call_output",
"input.0.call_id": "call_123",
},
},
{
name: "multiple output item types are appended in order",
raw: []byte(`{"model":"gpt-4o","input":[{"role":"user","content":"hello"}]}`),
items: []responses.ResponseInputItemUnionParam{
responses.ResponseInputItemParamOfCompaction("encrypted-content"),
responses.ResponseInputItemParamOfOutputMessage([]responses.ResponseOutputMessageContentUnionParam{
{
OfOutputText: &responses.ResponseOutputTextParam{
Annotations: []responses.ResponseOutputTextAnnotationUnionParam{},
Text: "assistant text",
},
},
}, "msg_123", responses.ResponseOutputMessageStatusCompleted),
responses.ResponseInputItemParamOfFileSearchCall("fs_123", []string{"hello"}, "completed"),
responses.ResponseInputItemParamOfImageGenerationCall("img_123", "base64-image", "completed"),
},
wantPaths: map[string]string{
"input.0.content": "hello",
"input.1.type": "compaction",
"input.2.type": "message",
"input.2.id": "msg_123",
"input.2.content.0.type": "output_text",
"input.2.content.0.text": "assistant text",
"input.3.type": "file_search_call",
"input.3.id": "fs_123",
"input.4.type": "image_generation_call",
"input.4.id": "img_123",
},
},
{
name: "empty appended items is no op",
raw: []byte(`{"model":"gpt-4o","input":"hello"}`),
wantSame: true,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
p := mustPayload(t, tc.raw)
updated, err := p.appendInputItems(tc.items)
if tc.wantErr != "" {
require.EqualError(t, err, tc.wantErr)
} else {
require.NoError(t, err)
}
if tc.wantSame {
require.EqualValues(t, tc.raw, updated)
}
for path, want := range tc.wantPaths {
require.Equal(t, want, gjson.GetBytes(updated, path).String())
}
})
}
}
func TestChainedRewritesProduceValidJSON(t *testing.T) {
t.Parallel()
p := mustPayload(t, []byte(`{"model":"gpt-4o","input":"hello"}`))
p, err := p.injectTools([]responses.ToolUnionParam{{
OfFunction: &responses.FunctionToolParam{
Name: "tool_a",
Description: openai.String("tool"),
Strict: openai.Bool(false),
Parameters: map[string]any{
"type": "object",
},
},
}})
require.NoError(t, err)
p, err = p.disableParallelToolCalls()
require.NoError(t, err)
p, err = p.appendInputItems([]responses.ResponseInputItemUnionParam{
responses.ResponseInputItemParamOfFunctionCallOutput("call_123", "done"),
})
require.NoError(t, err)
assert.True(t, json.Valid(p), "chained rewrites should produce valid JSON")
assert.Equal(t, "tool_a", gjson.GetBytes(p, "tools.0.name").String())
assert.Equal(t, "call_123", gjson.GetBytes(p, "input.1.call_id").String())
assert.False(t, gjson.GetBytes(p, "parallel_tool_calls").Bool())
}
func injectedFunctionTool(name string) responses.ToolUnionParam {
return responses.ToolUnionParam{
OfFunction: &responses.FunctionToolParam{
Name: name,
Description: openai.String("tool"),
Strict: openai.Bool(false),
Parameters: map[string]any{
"type": "object",
},
},
}
}
func mustPayload(t *testing.T, raw []byte) RequestPayload {
t.Helper()
payload, err := NewRequestPayload(raw)
require.NoError(t, err)
return payload
}