Files
coder/aibridge/intercept/chatcompletions/streaming_test.go
T
Paweł Banaszewski e00e85765b chore: move aibridge library code into coder repo (#24190)
This PR merges code from `coder/aibridge` repository into `coder/coder`.
It was split into 4 PRs for easier review but stacked PRs will need to
be merged into this PR so all checks pass.

* https://github.com/coder/coder/pull/24190 -> raw code copy (this PR;
before the other PRs were merged on top of it, it was a single commit:
https://github.com/coder/coder/commit/70d33f33200c7e77df910957595715f81f9bec24)
* https://github.com/coder/coder/pull/24570 -> update imports in
`coder/coder` to use copied code
* https://github.com/coder/coder/pull/24586 -> linter fixes and CI
integration (also added README.md)
* https://github.com/coder/coder/pull/24571 -> added exclude to
scripts/check_emdash.sh check

Original PR message (before PR squash):
Moves coder/aibridge code into coder/coder repository.

Omitted files:

- `go.mod`, `go.sum`, `.gitignore`, `.github/workflows/ci.yml`,
`Makefile`, `LICENSE`, `README.md` (a modified README.md is added later)
- `.github`, `example`, `buildinfo`, `scripts` directories

Simple verification script (will list omitted files)

```
tmp=$(mktemp -d)
echo "$tmp"
git clone --depth=1 https://github.com/coder/aibridge "$tmp/aibridge"
git clone --depth=1 --branch pb/aibridge-code-move https://github.com/coder/coder "$tmp/coder"
diff -rq --exclude=.git "$tmp/aibridge" "$tmp/coder/aibridge"
# rm -rf "$tmp"
```
2026-04-22 17:01:01 +02:00

113 lines
3.7 KiB
Go

package chatcompletions_test
import (
"net/http"
"net/http/httptest"
"strconv"
"testing"
"github.com/google/uuid"
"github.com/openai/openai-go/v3"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel"
"cdr.dev/slog/v3"
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/coder/v2/aibridge/config"
"github.com/coder/coder/v2/aibridge/intercept"
"github.com/coder/coder/v2/aibridge/intercept/chatcompletions"
"github.com/coder/coder/v2/aibridge/internal/testutil"
)
// TestStreamingInterception_RelaysUpstreamErrorToClient checks that an error
// response sent by the upstream provider before any streaming begins is passed
// through to the client: ProcessRequest must return an error mentioning the
// upstream status code, and the recorded response must carry that status code
// together with the upstream error body.
func TestStreamingInterception_RelaysUpstreamErrorToClient(t *testing.T) {
	t.Parallel()

	// upstreamFailure describes one canned error response from the fake
	// provider and the fragment of it that must show up in the client body.
	type upstreamFailure struct {
		name         string
		statusCode   int
		responseBody string
		expectedBody string
	}

	cases := []upstreamFailure{
		{
			name:         "bad request error",
			statusCode:   http.StatusBadRequest,
			responseBody: `{"error":{"message":"Invalid request","type":"invalid_request_error","code":"invalid_request"}}`,
			expectedBody: "invalid_request",
		},
		{
			name:         "rate limit error",
			statusCode:   http.StatusTooManyRequests,
			responseBody: `{"error":{"message":"Rate limit exceeded","type":"rate_limit_error","code":"rate_limit_exceeded"}}`,
			expectedBody: "rate_limit",
		},
		{
			name:         "internal server error",
			statusCode:   http.StatusInternalServerError,
			responseBody: `{"error":{"message":"Internal server error","type":"server_error","code":"internal_error"}}`,
			expectedBody: "server_error",
		},
	}

	for _, tt := range cases {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			// Fake upstream: every request is answered right away with the
			// case's status code and JSON error body, so no stream is opened.
			upstream := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, _ *http.Request) {
				hdr := rw.Header()
				hdr.Set("Content-Type", "application/json")
				// NOTE(review): presumably stops the OpenAI client from
				// retrying the failed call; confirm against the SDK.
				hdr.Set("x-should-retry", "false")
				rw.WriteHeader(tt.statusCode)
				_, _ = rw.Write([]byte(tt.responseBody))
			}))
			t.Cleanup(upstream.Close)

			// Point the provider configuration at the fake upstream.
			providerCfg := config.OpenAI{
				BaseURL: upstream.URL,
				Key:     "test-key",
			}

			// Minimal streaming chat completion request.
			params := &chatcompletions.ChatCompletionNewParamsWrapper{
				ChatCompletionNewParams: openai.ChatCompletionNewParams{
					Model: "gpt-4",
					Messages: []openai.ChatCompletionMessageParamUnion{
						openai.UserMessage("hello"),
					},
				},
				Stream: true,
			}

			// Client-facing side: the recorder captures what the interceptor
			// writes back to the caller.
			rec := httptest.NewRecorder()
			clientReq := httptest.NewRequest(http.MethodPost, "/chat/completions", nil)

			interceptor := chatcompletions.NewStreamingInterceptor(
				uuid.New(),
				params,
				config.ProviderOpenAI,
				providerCfg,
				clientReq.Header,
				"Authorization",
				otel.Tracer("test"),
				intercept.CredentialInfo{},
			)
			interceptor.Setup(
				slogtest.Make(t, &slogtest.Options{IgnoreErrors: false}).Leveled(slog.LevelDebug),
				&testutil.MockRecorder{},
				nil,
			)

			err := interceptor.ProcessRequest(rec, clientReq)

			// The returned error must mention the upstream status code.
			require.Error(t, err)
			assert.Contains(t, err.Error(), strconv.Itoa(tt.statusCode))

			// The same status code and the upstream error payload must have
			// been written to the client response.
			assert.Equal(t, tt.statusCode, rec.Code, "expected status code to be relayed to client")
			assert.Contains(t, rec.Body.String(), tt.expectedBody, "expected error type in response body")
		})
	}
}