Files
coder/aibridge/intercept/responses/base.go
T
Paweł Banaszewski e00e85765b chore: move aibridge library code into coder repo (#24190)
This PR merges code from `coder/aibridge` repository into `coder/coder`.
It was split into 4 PRs for easier review but stacked PRs will need to
be merged into this PR so all checks pass.

* https://github.com/coder/coder/pull/24190 -> raw code copy (this PR,
before merging PRs on top of it, it was just 1 commit:
https://github.com/coder/coder/commit/70d33f33200c7e77df910957595715f81f9bec24)
* https://github.com/coder/coder/pull/24570 -> update imports in
`coder/coder` to use copied code
* https://github.com/coder/coder/pull/24586 -> linter fixes and CI
integration (also added README.md)
* https://github.com/coder/coder/pull/24571 -> added exclude to
scripts/check_emdash.sh check

Original PR message (before PR squash):
Moves coder/aibridge code into coder/coder repository.

Omitted files:

- `go.mod`, `go.sum`, `.gitignore`, `.github/workflows/ci.yml`,
`Makefile`, `LICENSE`, `README.md` (modified README.md is added later)
- `.github`, `example`, `buildinfo`, `scripts` directories

Simple verification script (will list omitted files)

```
tmp=$(mktemp -d)
echo "$tmp"
git clone --depth=1 https://github.com/coder/aibridge "$tmp/aibridge"
git clone --depth=1 --branch pb/aibridge-code-move https://github.com/coder/coder "$tmp/coder"
diff -rq --exclude=.git "$tmp/aibridge" "$tmp/coder/aibridge"
# rm -rf "$tmp"
```
2026-04-22 17:01:01 +02:00

417 lines
14 KiB
Go

package responses
import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/google/uuid"
"github.com/openai/openai-go/v3/option"
"github.com/openai/openai-go/v3/responses"
"github.com/openai/openai-go/v3/shared/constant"
"github.com/tidwall/gjson"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
"golang.org/x/xerrors"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/aibridge/config"
aibcontext "github.com/coder/coder/v2/aibridge/context"
"github.com/coder/coder/v2/aibridge/intercept"
"github.com/coder/coder/v2/aibridge/intercept/apidump"
"github.com/coder/coder/v2/aibridge/mcp"
"github.com/coder/coder/v2/aibridge/recorder"
"github.com/coder/coder/v2/aibridge/tracing"
"github.com/coder/quartz"
)
// requestTimeout is the per-request timeout (600 seconds) that requestOptions
// attaches to non-streaming upstream requests.
const requestTimeout = 600 * time.Second
// responsesInterceptionBase bundles the state and helper methods shared by
// the code in this file: building the upstream SDK service, validating the
// request, and recording prompt, tool, token and model-thought usage.
type responsesInterceptionBase struct {
// id identifies this interception. It is returned by ID and attached to
// recorded usage and to trace attributes.
id uuid.UUID
// providerName is reported in trace attributes and passed to the API dump
// middleware.
providerName string
// clientHeaders are the original HTTP headers from the client request.
clientHeaders http.Header
// authHeaderName is passed to intercept.BuildUpstreamHeaders when the
// client headers are forwarded upstream.
authHeaderName string
// reqPayload is the request payload; requestOptions sends it upstream as
// the request body without re-encoding.
reqPayload RequestPayload
// cfg supplies the upstream base URL, API key, optional retry limit,
// extra headers and API dump directory.
cfg config.OpenAI
// recorder, mcpProxy and logger are assigned by Setup.
recorder recorder.Recorder
mcpProxy mcp.ServerProxier
logger slog.Logger
tracer trace.Tracer
// credential is returned by Credential.
credential intercept.CredentialInfo
}
// newResponsesService builds the SDK responses service used to call the
// upstream provider. The service is configured with the base URL and API key
// from cfg, plus, when present: the retry limit, extra headers, a middleware
// that forwards client headers, and the API dump middleware.
func (i *responsesInterceptionBase) newResponsesService() responses.ResponseService {
	opts := []option.RequestOption{
		option.WithBaseURL(i.cfg.BaseURL),
		option.WithAPIKey(i.cfg.Key),
	}

	if retries := i.cfg.MaxRetries; retries != nil {
		opts = append(opts, option.WithMaxRetries(*retries))
	}

	// Some providers need additional headers that the SDK does not add itself.
	// TODO(ssncferreira): remove as part of https://github.com/coder/aibridge/issues/192
	for name, val := range i.cfg.ExtraHeaders {
		opts = append(opts, option.WithHeader(name, val))
	}

	// Forward client headers upstream. The middleware runs once the SDK has
	// built the request and swaps the outgoing headers for the sanitized
	// client headers plus provider auth.
	if i.clientHeaders != nil {
		forwardHeaders := func(req *http.Request, next option.MiddlewareNext) (*http.Response, error) {
			req.Header = intercept.BuildUpstreamHeaders(req.Header, i.clientHeaders, i.authHeaderName)
			return next(req)
		}
		opts = append(opts, option.WithMiddleware(forwardHeaders))
	}

	// The API dump middleware is only installed when the constructor returns one.
	dumpMW := apidump.NewBridgeMiddleware(i.cfg.APIDumpDir, i.providerName, i.Model(), i.id, i.logger, quartz.NewReal())
	if dumpMW != nil {
		opts = append(opts, option.WithMiddleware(dumpMW))
	}

	return responses.NewResponseService(opts...)
}
// ID returns the identifier of this interception.
func (i *responsesInterceptionBase) ID() uuid.UUID {
	return i.id
}
// Credential returns the credential information stored on this interception.
func (i *responsesInterceptionBase) Credential() intercept.CredentialInfo {
	return i.credential
}
// Setup stores the recorder and MCP proxy on the interception and installs a
// logger that carries the request's model as a "model" field.
func (i *responsesInterceptionBase) Setup(logger slog.Logger, rec recorder.Recorder, mcpProxy mcp.ServerProxier) {
	i.recorder, i.mcpProxy = rec, mcpProxy
	i.logger = logger.With(slog.F("model", i.Model()))
}
// Model returns the model reported by the request payload.
func (i *responsesInterceptionBase) Model() string {
	return i.reqPayload.model()
}
// CorrelatingToolCallID returns the correlating tool call ID reported by the
// request payload. The result is a pointer and may presumably be nil when the
// payload carries no such ID.
func (i *responsesInterceptionBase) CorrelatingToolCallID() *string {
	return i.reqPayload.correlatingToolCallID()
}
// baseTraceAttributes returns the trace attributes common to every request
// handled by this interception: request path, interception ID, initiator ID,
// provider, model, and whether the request is streaming.
func (i *responsesInterceptionBase) baseTraceAttributes(r *http.Request, streaming bool) []attribute.KeyValue {
	initiator := aibcontext.ActorIDFromContext(r.Context())

	attrs := make([]attribute.KeyValue, 0, 6)
	attrs = append(attrs,
		attribute.String(tracing.RequestPath, r.URL.Path),
		attribute.String(tracing.InterceptionID, i.id.String()),
		attribute.String(tracing.InitiatorID, initiator),
		attribute.String(tracing.Provider, i.providerName),
		attribute.String(tracing.Model, i.Model()),
		attribute.Bool(tracing.Streaming, streaming),
	)
	return attrs
}
// validateRequest rejects requests that AI Bridge cannot handle. Background
// requests are answered with a 501 Not Implemented custom error, and that same
// error is returned to the caller. All other requests yield nil.
func (i *responsesInterceptionBase) validateRequest(ctx context.Context, w http.ResponseWriter) error {
	if !i.reqPayload.background() {
		return nil
	}

	err := xerrors.New("background requests are currently not supported by AI Bridge")
	i.sendCustomErr(ctx, w, http.StatusNotImplemented, err)
	return err
}
// sendCustomErr writes a custom JSON error to the client using the given HTTP
// status code. The body has a "code" field (the status code as a string) and a
// "message" field (err's text).
//
// It should only be called before any data is sent back to the client.
// Marshal and write failures are logged as warnings and are not returned.
func (i *responsesInterceptionBase) sendCustomErr(ctx context.Context, w http.ResponseWriter, code int, err error) {
	// Same JSON shape as responses.Error, but as a plain struct:
	// responses.Error embeds *http.Request, whose GetBody func field
	// is not JSON-marshalable (SA1026).
	type errorBody struct {
		Code    string `json:"code"`
		Message string `json:"message"`
	}

	payload, marshalErr := json.Marshal(errorBody{
		Code:    strconv.Itoa(code),
		Message: err.Error(),
	})
	if marshalErr != nil {
		i.logger.Warn(ctx, "failed to marshal custom error: ", slog.Error(marshalErr))
		return
	}

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(code)
	if _, writeErr := w.Write(payload); writeErr != nil {
		i.logger.Warn(ctx, "failed to send custom error: ", slog.Error(writeErr))
	}
}
// requestOptions returns the per-request SDK options: the original request
// payload as the body, the response-copying middleware of respCopy, and, for
// non-streaming requests only, the request timeout.
func (i *responsesInterceptionBase) requestOptions(respCopy *responseCopier) []option.RequestOption {
	// Send the original payload to avoid JSON re-encoding issues.
	// E.g. Codex CLI produces requests without an ID set in reasoning items: https://platform.openai.com/docs/api-reference/responses/create#responses_create-input-input_item_list-item-reasoning-id
	// When re-encoded, the ID field becomes an empty string, which results
	// in a bad request, while omitting the ID field entirely somehow works.
	rawBody := option.WithRequestBody("application/json", []byte(i.reqPayload))

	// copyMiddleware copies the upstream response body into the buffer of
	// responseCopier and keeps references to the headers and status code.
	// Interceptors use responseCopier to forward the response as received,
	// eliminating any possibility of JSON re-encoding issues.
	copyResponse := option.WithMiddleware(respCopy.copyMiddleware)

	if i.reqPayload.Stream() {
		return []option.RequestOption{rawBody, copyResponse}
	}
	return []option.RequestOption{rawBody, copyResponse, option.WithRequestTimeout(requestTimeout)}
}
// recordUserPrompt records prompt as prompt usage for this interception,
// keyed by responseID. An empty responseID skips recording with a warning.
// A recorder failure is logged as a warning and not returned.
func (i *responsesInterceptionBase) recordUserPrompt(ctx context.Context, responseID string, prompt string) {
	if len(responseID) == 0 {
		i.logger.Warn(ctx, "got empty response ID, skipping prompt recording")
		return
	}

	err := i.recorder.RecordPromptUsage(ctx, &recorder.PromptUsageRecord{
		InterceptionID: i.ID().String(),
		MsgID:          responseID,
		Prompt:         prompt,
	})
	if err != nil {
		i.logger.Warn(ctx, "failed to record prompt usage", slog.Error(err))
	}
}
// recordModelThoughts records every model thought extracted from response
// (see extractModelThoughts) against this interception.
//
// Recording is best-effort: a recorder failure is logged as a warning, in line
// with the other record* helpers in this file, and the remaining thoughts are
// still attempted. A nil response yields no thoughts and records nothing.
func (i *responsesInterceptionBase) recordModelThoughts(ctx context.Context, response *responses.Response) {
	for _, thought := range i.extractModelThoughts(response) {
		err := i.recorder.RecordModelThought(ctx, &recorder.ModelThoughtRecord{
			InterceptionID: i.ID().String(),
			Content:        thought.Content,
			Metadata:       thought.Metadata,
		})
		if err != nil {
			i.logger.Warn(ctx, "failed to record model thought", slog.Error(err))
		}
	}
}
// recordNonInjectedToolUsage records a non-injected tool usage entry for each
// function call and custom tool call found in the response output. Output
// items of any other type are skipped. A nil response skips recording with a
// warning, and recorder failures are logged as warnings.
func (i *responsesInterceptionBase) recordNonInjectedToolUsage(ctx context.Context, response *responses.Response) {
	if response == nil {
		i.logger.Warn(ctx, "got empty response, skipping tool usage recording")
		return
	}

	functionCallType := string(constant.ValueOf[constant.FunctionCall]())
	customToolCallType := string(constant.ValueOf[constant.CustomToolCall]())

	for _, out := range response.Output {
		var toolArgs recorder.ToolArgs

		// recording other function types to be considered: https://github.com/coder/aibridge/issues/121
		switch out.Type {
		case functionCallType:
			toolArgs = i.parseFunctionCallJSONArgs(ctx, out.Arguments)
		case customToolCallType:
			toolArgs = out.Input
		default:
			continue
		}

		record := &recorder.ToolUsageRecord{
			InterceptionID: i.ID().String(),
			MsgID:          response.ID,
			ToolCallID:     out.CallID,
			Tool:           out.Name,
			Args:           toolArgs,
			Injected:       false,
		}
		if err := i.recorder.RecordToolUsage(ctx, record); err != nil {
			i.logger.Warn(ctx, "failed to record tool usage", slog.Error(err), slog.F("tool", out.Name))
		}
	}
}
// parseFunctionCallJSONArgs decodes the JSON arguments of a function call.
// Surrounding whitespace is trimmed first. An empty string is returned as-is.
// If decoding fails, a warning is logged and the trimmed raw string is
// returned instead of the decoded value.
func (i *responsesInterceptionBase) parseFunctionCallJSONArgs(ctx context.Context, raw string) recorder.ToolArgs {
	payload := strings.TrimSpace(raw)
	if len(payload) == 0 {
		return payload
	}

	var decoded recorder.ToolArgs
	err := json.Unmarshal([]byte(payload), &decoded)
	if err == nil {
		return decoded
	}

	i.logger.Warn(ctx, "failed to unmarshal tool args", slog.Error(err))
	return payload
}
// recordTokenUsage records the token usage reported by response for this
// interception. A nil response skips recording with a warning, and a recorder
// failure is logged as a warning.
func (i *responsesInterceptionBase) recordTokenUsage(ctx context.Context, response *responses.Response) {
	if response == nil {
		i.logger.Warn(ctx, "got empty response, skipping token usage recording")
		return
	}

	usage := response.Usage
	cachedTokens := usage.InputTokensDetails.CachedTokens

	record := &recorder.TokenUsageRecord{
		InterceptionID: i.ID().String(),
		MsgID:          response.ID,
		// Kept consistent with chat completions: the reported input *includes*
		// the cached tokens, so they are subtracted to reflect actual input
		// token usage.
		Input:                usage.InputTokens - cachedTokens,
		Output:               usage.OutputTokens,
		CacheReadInputTokens: cachedTokens,
		ExtraTokenTypes: map[string]int64{
			"input_cached":     cachedTokens, // TODO: remove from ExtraTokenTypes (https://github.com/coder/aibridge/issues/243)
			"output_reasoning": usage.OutputTokensDetails.ReasoningTokens,
			"total_tokens":     usage.TotalTokens,
		},
	}

	if err := i.recorder.RecordTokenUsage(ctx, record); err != nil {
		i.logger.Warn(ctx, "failed to record token usage", slog.Error(err))
	}
}
// extractModelThoughts collects model thoughts from the response output items.
// Two sources are captured: non-empty reasoning summaries, and non-empty
// output-text parts of commentary messages (assistant message items whose raw
// JSON has "phase": "commentary"). A nil response yields nil.
func (*responsesInterceptionBase) extractModelThoughts(response *responses.Response) []*recorder.ModelThoughtRecord {
	if response == nil {
		return nil
	}

	var (
		reasoningType  = string(constant.ValueOf[constant.Reasoning]())
		messageType    = string(constant.ValueOf[constant.Message]())
		assistantRole  = string(constant.ValueOf[constant.Assistant]())
		outputTextType = string(constant.ValueOf[constant.OutputText]())
	)

	var collected []*recorder.ModelThoughtRecord
	for _, item := range response.Output {
		switch item.Type {
		case reasoningType:
			for _, summary := range item.AsReasoning().Summary {
				if summary.Text != "" {
					collected = append(collected, &recorder.ModelThoughtRecord{
						Content:  summary.Text,
						Metadata: recorder.Metadata{"source": recorder.ThoughtSourceReasoningSummary},
					})
				}
			}

		case messageType:
			// The API sometimes returns commentary messages instead of reasoning
			// summaries: assistant message output items with "phase": "commentary".
			// The SDK doesn't expose a Phase field, so it is read from the raw JSON.
			// TODO: revisit when the OpenAI SDK adds a proper Phase field.
			raw := item.RawJSON()
			isCommentary := gjson.Get(raw, "role").String() == assistantRole &&
				gjson.Get(raw, "phase").String() == "commentary"
			if !isCommentary {
				continue
			}

			for _, part := range item.AsMessage().Content {
				if part.Type != outputTextType || part.Text == "" {
					continue
				}
				collected = append(collected, &recorder.ModelThoughtRecord{
					Content:  part.Text,
					Metadata: recorder.Metadata{"source": recorder.ThoughtSourceCommentary},
				})
			}
		}
	}
	return collected
}
// hasInjectableTools reports whether an MCP proxy is configured and lists at
// least one tool.
func (i *responsesInterceptionBase) hasInjectableTools() bool {
	if i.mcpProxy == nil {
		return false
	}
	return len(i.mcpProxy.ListTools()) > 0
}
// responseCopier is a helper struct used to send the original upstream
// response to the client.
type responseCopier struct {
// buff receives the response body bytes mirrored by the TeeReader that
// copyMiddleware installs.
buff deltaBuffer
// responseStatus and responseHeaders hold the status code and headers of
// the upstream response, as captured by copyMiddleware.
responseStatus int
responseHeaders http.Header
// responseBody keeps a reference to the TeeReader-wrapped response body
// that copyMiddleware hands back to the SDK. The TeeReader copies the
// bytes read from the response body (by the SDK) into the buffer. In case
// the SDK doesn't read everything, the readAll method reads from this
// closer to make sure the whole response body ends up in the buffer.
responseBody io.ReadCloser
// responseReceived flag is used to determine if AI Bridge needs to write custom error:
// - If responseReceived is true, the upstream response is forwarded as-is.
// - If responseReceived is false, no response was returned and there is nothing to forward (eg. connection/client error). Custom error will be returned.
responseReceived atomic.Bool
}
// copyMiddleware is an SDK middleware that captures the upstream response so
// it can later be forwarded to the client exactly as received.
//
// If next returns an error or a nil response, both values are passed through
// untouched and nothing is recorded. Otherwise the status code and headers are
// stored, and the response body is replaced with a TeeReader that mirrors
// every byte read from it into r.buff.
//
// The replacement body is wrapped in io.NopCloser, so calling Close on it does
// not close the upstream body. Presumably this keeps r.responseBody readable
// by readAll after the SDK is done with the response.
// NOTE(review): confirm the upstream body is drained or closed elsewhere.
func (r *responseCopier) copyMiddleware(req *http.Request, next option.MiddlewareNext) (*http.Response, error) {
	resp, err := next(req)
	if err != nil || resp == nil {
		return resp, err
	}

	r.responseStatus = resp.StatusCode
	r.responseHeaders = resp.Header
	resp.Body = io.NopCloser(io.TeeReader(resp.Body, &r.buff))
	r.responseBody = resp.Body

	// Publish the flag last. A reader that observes responseReceived == true
	// is then guaranteed to also observe the status, headers and body set
	// above; storing it first would let a concurrent reader see the flag
	// before those fields are populated.
	r.responseReceived.Store(true)
	return resp, nil
}
// readAll drains whatever is left of the captured response body and returns
// the buffered bytes not yet handed out by an earlier readDelta call.
//
// r.responseBody is the TeeReader-wrapped body installed by copyMiddleware, so
// every byte read here is appended to r.buff. The drained bytes themselves are
// discarded rather than collected into a second, throwaway copy.
//
// When no response body was captured it returns an empty slice and a nil
// error. If draining fails, the data buffered so far is returned together with
// the read error.
func (r *responseCopier) readAll() ([]byte, error) {
	if r.responseBody == nil {
		return []byte{}, nil
	}
	_, err := io.Copy(io.Discard, r.responseBody)
	return r.buff.readDelta(), err
}
// forwardResp writes the captured upstream response to w: the upstream
// Content-Type header, the upstream status code, and then the full body.
// If no response was received there is nothing to forward and nil is returned.
// Errors from reading the body or writing it to w are wrapped and returned.
func (r *responseCopier) forwardResp(w http.ResponseWriter) error {
	if received := r.responseReceived.Load(); !received {
		// no response was received, nothing to forward
		return nil
	}

	contentType := r.responseHeaders.Get("Content-Type")
	w.Header().Set("Content-Type", contentType)
	w.WriteHeader(r.responseStatus)

	body, readErr := r.readAll()
	if readErr != nil {
		return xerrors.Errorf("failed to read response body: %w", readErr)
	}

	_, writeErr := w.Write(body)
	if writeErr != nil {
		return xerrors.Errorf("failed to write response body: %w", writeErr)
	}
	return nil
}
// deltaBuffer is a mutex-guarded byte buffer that supports incremental reads:
// each readDelta call returns only the data written since the previous call.
type deltaBuffer struct {
	mu  sync.Mutex
	buf bytes.Buffer
}

// Write appends p to the buffer while holding the lock, reporting the number
// of bytes written and any error from the underlying bytes.Buffer.
func (d *deltaBuffer) Write(p []byte) (int, error) {
	d.mu.Lock()
	defer d.mu.Unlock()

	written, err := d.buf.Write(p)
	return written, err
}

// readDelta returns a copy of the bytes appended since the last readDelta
// call and then empties the buffer. The returned slice does not alias the
// buffer's storage, so later writes cannot modify it.
func (d *deltaBuffer) readDelta() []byte {
	d.mu.Lock()
	defer d.mu.Unlock()

	pending := bytes.Clone(d.buf.Bytes())
	d.buf.Reset()
	return pending
}