Files
coder/aibridge/intercept/responses/base_test.go
T
Paweł Banaszewski e00e85765b chore: move aibridge library code into coder repo (#24190)
This PR merges code from `coder/aibridge` repository into `coder/coder`.
It was split into 4 PRs for easier review but stacked PRs will need to
be merged into this PR so all checks pass.

* https://github.com/coder/coder/pull/24190 -> raw code copy (this PR,
before merging PRs on top of it, it was just 1 commit:
https://github.com/coder/coder/commit/70d33f33200c7e77df910957595715f81f9bec24)
* https://github.com/coder/coder/pull/24570 -> update imports in
`coder/coder` to use copied code
* https://github.com/coder/coder/pull/24586 -> linter fixes and CI
integration (also added README.md)
* https://github.com/coder/coder/pull/24571 -> added exclude to
scripts/check_emdash.sh check

Original PR message (before PR squash):
Moves coder/aibridge code into coder/coder repository.

Omitted files:

- `go.mod`, `go.sum`, `.gitignore`, `.github/workflows/ci.yml,`
`Makefile`, `LICENSE`, `README.md` (modified README.md is added later)
- `.github`, `example`, `buildinfo,` `scripts` directories

Simple verification script (will list omitted files)

```
tmp=$(mktemp -d)
echo "$tmp"
git clone --depth=1 https://github.com/coder/aibridge "$tmp/aibridge"
git clone --depth=1 --branch pb/aibridge-code-move https://github.com/coder/coder "$tmp/coder"
diff -rq --exclude=.git "$tmp/aibridge" "$tmp/coder/aibridge"
# rm -rf "$tmp"
```
2026-04-22 17:01:01 +02:00

385 lines
8.5 KiB
Go

package responses //nolint:testpackage // tests unexported internals
import (
"net/http"
"testing"
"time"
"github.com/google/uuid"
oairesponses "github.com/openai/openai-go/v3/responses"
"github.com/stretchr/testify/require"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/aibridge/internal/testutil"
"github.com/coder/coder/v2/aibridge/recorder"
)
func TestRecordPrompt(t *testing.T) {
t.Parallel()
tests := []struct {
name string
promptWasRecorded bool
prompt string
responseID string
wantRecorded bool
wantPrompt string
}{
{
name: "records_prompt_successfully",
prompt: "tell me a joke",
responseID: "resp_123",
wantRecorded: true,
wantPrompt: "tell me a joke",
},
{
name: "records_empty_prompt_successfully",
prompt: "",
responseID: "resp_123",
wantRecorded: true,
wantPrompt: "",
},
{
name: "skips_recording_on_empty_response_id",
prompt: "tell me a joke",
responseID: "",
wantRecorded: false,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
rec := &testutil.MockRecorder{}
id := uuid.New()
base := &responsesInterceptionBase{
id: id,
recorder: rec,
logger: slog.Make(),
}
base.recordUserPrompt(t.Context(), tc.responseID, tc.prompt)
prompts := rec.RecordedPromptUsages()
if tc.wantRecorded {
require.Len(t, prompts, 1)
require.Equal(t, id.String(), prompts[0].InterceptionID)
require.Equal(t, tc.responseID, prompts[0].MsgID)
require.Equal(t, tc.wantPrompt, prompts[0].Prompt)
} else {
require.Empty(t, prompts)
}
})
}
}
func TestRecordToolUsage(t *testing.T) {
t.Parallel()
id := uuid.MustParse("11111111-1111-1111-1111-111111111111")
tests := []struct {
name string
response *oairesponses.Response
expected []*recorder.ToolUsageRecord
}{
{
name: "nil_response",
response: nil,
expected: nil,
},
{
name: "empty_output",
response: &oairesponses.Response{
ID: "resp_123",
},
expected: nil,
},
{
name: "empty_tool_args",
response: &oairesponses.Response{
ID: "resp_456",
Output: []oairesponses.ResponseOutputItemUnion{
{
Type: "function_call",
CallID: "call_abc",
Name: "get_weather",
Arguments: "",
},
},
},
expected: []*recorder.ToolUsageRecord{
{
InterceptionID: id.String(),
MsgID: "resp_456",
ToolCallID: "call_abc",
Tool: "get_weather",
Args: "",
Injected: false,
},
},
},
{
name: "multiple_tool_calls",
response: &oairesponses.Response{
ID: "resp_789",
Output: []oairesponses.ResponseOutputItemUnion{
{
Type: "function_call",
CallID: "call_1",
Name: "get_weather",
Arguments: `{"location": "NYC"}`,
},
{
Type: "function_call",
CallID: "call_2",
Name: "bad_json_args",
Arguments: `{"bad": args`,
},
{
Type: "message",
ID: "msg_1",
Role: "assistant",
},
{
Type: "custom_tool_call",
CallID: "call_3",
Name: "search",
Input: `{\"query\": \"test\"}`,
},
{
Type: "function_call",
CallID: "call_4",
Name: "calculate",
Arguments: `{"a": 1, "b": 2}`,
},
},
},
expected: []*recorder.ToolUsageRecord{
{
InterceptionID: id.String(),
MsgID: "resp_789",
ToolCallID: "call_1",
Tool: "get_weather",
Args: map[string]any{"location": "NYC"},
Injected: false,
},
{
InterceptionID: id.String(),
MsgID: "resp_789",
ToolCallID: "call_2",
Tool: "bad_json_args",
Args: `{"bad": args`,
Injected: false,
},
{
InterceptionID: id.String(),
MsgID: "resp_789",
ToolCallID: "call_3",
Tool: "search",
Args: `{\"query\": \"test\"}`,
Injected: false,
},
{
InterceptionID: id.String(),
MsgID: "resp_789",
ToolCallID: "call_4",
Tool: "calculate",
Args: map[string]any{"a": float64(1), "b": float64(2)},
Injected: false,
},
},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
rec := &testutil.MockRecorder{}
base := &responsesInterceptionBase{
id: id,
recorder: rec,
logger: slog.Make(),
}
base.recordNonInjectedToolUsage(t.Context(), tc.response)
tools := rec.RecordedToolUsages()
require.Len(t, tools, len(tc.expected))
for i, got := range tools {
got.CreatedAt = time.Time{}
require.Equal(t, tc.expected[i], got)
}
})
}
}
func TestParseJSONArgs(t *testing.T) {
t.Parallel()
tests := []struct {
name string
raw string
expected recorder.ToolArgs
}{
{
name: "empty_string",
raw: "",
expected: "",
},
{
name: "whitespace_only",
raw: " \t\n ",
expected: "",
},
{
name: "invalid_json",
raw: "{not valid json}",
expected: "{not valid json}",
},
{
name: "nested_object_with_trailing_spaces",
raw: ` {"user": {"name": "alice", "settings": {"theme": "dark", "notifications": true}}, "count": 42} `,
expected: map[string]any{
"user": map[string]any{
"name": "alice",
"settings": map[string]any{
"theme": "dark",
"notifications": true,
},
},
"count": float64(42),
},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
base := &responsesInterceptionBase{}
result := base.parseFunctionCallJSONArgs(t.Context(), tc.raw)
require.Equal(t, tc.expected, result)
})
}
}
func TestRecordTokenUsage(t *testing.T) {
t.Parallel()
id := uuid.MustParse("22222222-2222-2222-2222-222222222222")
tests := []struct {
name string
response *oairesponses.Response
expected *recorder.TokenUsageRecord
}{
{
name: "nil_response",
response: nil,
expected: nil,
},
{
name: "with_all_token_details",
response: &oairesponses.Response{
ID: "resp_full",
Usage: oairesponses.ResponseUsage{
InputTokens: 10,
OutputTokens: 20,
TotalTokens: 30,
InputTokensDetails: oairesponses.ResponseUsageInputTokensDetails{
CachedTokens: 5,
},
OutputTokensDetails: oairesponses.ResponseUsageOutputTokensDetails{
ReasoningTokens: 5,
},
},
},
expected: &recorder.TokenUsageRecord{
InterceptionID: id.String(),
MsgID: "resp_full",
Input: 5, // 10 input - 5 cached
Output: 20,
CacheReadInputTokens: 5,
ExtraTokenTypes: map[string]int64{
"input_cached": 5,
"output_reasoning": 5,
"total_tokens": 30,
},
},
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
rec := &testutil.MockRecorder{}
base := &responsesInterceptionBase{
id: id,
recorder: rec,
logger: slog.Make(),
}
base.recordTokenUsage(t.Context(), tc.response)
tokens := rec.RecordedTokenUsages()
if tc.expected == nil {
require.Empty(t, tokens)
} else {
require.Len(t, tokens, 1)
got := tokens[0]
got.CreatedAt = time.Time{} // ignore time
require.Equal(t, tc.expected, got)
}
})
}
}
type mockResponseWriter struct {
headerCalled bool
writeCalled bool
writeHeaderCalled bool
}
func (mrw *mockResponseWriter) Header() http.Header {
mrw.headerCalled = true
return http.Header{}
}
func (mrw *mockResponseWriter) Write([]byte) (int, error) {
mrw.writeCalled = true
return 0, nil
}
func (mrw *mockResponseWriter) WriteHeader(statusCode int) {
mrw.writeHeaderCalled = true
}
func TestResponseCopierDoesntSendIfNoResponseReceived(t *testing.T) {
t.Parallel()
mrw := mockResponseWriter{}
respCopy := responseCopier{}
body := "test_body"
_, _ = respCopy.buff.Write([]byte(body)) // bytes.Buffer.Write never fails
err := respCopy.forwardResp(&mrw)
require.NoError(t, err)
require.False(t, mrw.headerCalled)
require.False(t, mrw.writeCalled)
require.False(t, mrw.writeHeaderCalled)
// after response is received data is forwarded
respCopy.responseReceived.Store(true)
err = respCopy.forwardResp(&mrw)
require.NoError(t, err)
require.True(t, mrw.headerCalled)
require.True(t, mrw.writeCalled)
require.True(t, mrw.writeHeaderCalled)
}