Files
coder/aibridge/intercept/chatcompletions/blocking.go
T
Susana Ferreira 7b903cad73 fix: track credential hint across key failover attempts in aibridge (#25735)
## Problem

Centralized requests recorded the hint of *the first available key from the pool at
`CreateInterceptor` time* as `credential_hint`, so the interception
could be persisted in the database with a hint that didn't match the key
that actually served the request. The fix is to store, at the end of the
interception, the hint of the key that succeeded, or of the last
attempted key if all keys are unavailable.

## Changes

- Add `Key.Hint()` and update `credential_hint` on every failover
attempt so it reflects the actually-used key.
- Stop pre-populating `credential_hint` at `CreateInterceptor`.
Centralized interceptions start with an empty hint, which the key failover loop then updates.
- Persist the final hint via `RecordInterceptionEnded`; SQL updates
`credential_hint` only when `credential_kind = 'centralized'` so BYOK
keeps its start-time value.
- Log the actually-used hint on interception end/failure; start log uses
a `<keypool-pending>` placeholder for centralized.

> [!NOTE]
> Initially generated by Claude Opus 4.7, modified and reviewed by
@ssncferreira
2026-05-29 12:01:37 +01:00

322 lines
11 KiB
Go

package chatcompletions
import (
"context"
"encoding/json"
"errors"
"net/http"
"strings"
"time"
"github.com/google/uuid"
"github.com/openai/openai-go/v3"
"github.com/openai/openai-go/v3/option"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
"golang.org/x/xerrors"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/aibridge/config"
aibcontext "github.com/coder/coder/v2/aibridge/context"
"github.com/coder/coder/v2/aibridge/intercept"
"github.com/coder/coder/v2/aibridge/intercept/eventstream"
"github.com/coder/coder/v2/aibridge/keypool"
"github.com/coder/coder/v2/aibridge/mcp"
"github.com/coder/coder/v2/aibridge/recorder"
"github.com/coder/coder/v2/aibridge/tracing"
)
// BlockingInterception serves non-streaming (blocking) chat completion
// requests. Shared state and helpers come from the embedded
// interceptionBase.
type BlockingInterception struct {
	interceptionBase
}

// NewBlockingInterceptor builds a BlockingInterception for a single request.
//
// cred is the credential information known when the interceptor is created.
// When the configuration carries a centralized key pool, the failover loop
// overwrites it with the key that is actually attempted.
func NewBlockingInterceptor(
	id uuid.UUID,
	req *ChatCompletionNewParamsWrapper,
	providerName string,
	cfg config.OpenAI,
	clientHeaders http.Header,
	authHeaderName string,
	tracer trace.Tracer,
	cred intercept.CredentialInfo,
) *BlockingInterception {
	interception := &BlockingInterception{}

	// Populate the embedded base in place rather than copying a value in.
	base := &interception.interceptionBase
	base.id = id
	base.providerName = providerName
	base.req = req
	base.cfg = cfg
	base.clientHeaders = clientHeaders
	base.authHeaderName = authHeaderName
	base.tracer = tracer
	base.credential = cred

	return interception
}
// Setup wires the logger (scoped under "blocking"), the recorder and the MCP
// proxy into the shared interception base.
func (i *BlockingInterception) Setup(logger slog.Logger, rec recorder.Recorder, mcpProxy mcp.ServerProxier) {
	blockingLogger := logger.Named("blocking")
	i.interceptionBase.Setup(blockingLogger, rec, mcpProxy)
}

// Streaming reports whether this interception streams its response; a
// blocking interception never does.
func (*BlockingInterception) Streaming() bool {
	const streaming = false
	return streaming
}

// TraceAttributes returns the trace attributes for r, computed by the shared
// base implementation with streaming set to false.
func (i *BlockingInterception) TraceAttributes(r *http.Request) []attribute.KeyValue {
	streaming := i.Streaming()
	return i.interceptionBase.baseTraceAttributes(r, streaming)
}
// ProcessRequest serves a non-streaming chat completion request.
//
// It sends the request upstream. Whenever the response contains calls to
// tools provided by the MCP proxy (injected tools), it does three things:
// invokes those tools itself, appends the assistant message and the tool
// results to the conversation, and sends the request again. The loop ends
// when a response contains no injected tool calls.
//
// The final completion is then written to w as JSON. Its ID is replaced by
// the interception ID. When upstream reported completion tokens, its usage
// is replaced by the usage summed over every upstream call made.
//
// Prompt, token and tool usage are recorded along the way; recording errors
// are deliberately ignored. Upstream failures are turned into an HTTP error
// response on w and returned wrapped.
func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Request) (outErr error) {
	if i.req == nil {
		return xerrors.New("developer error: req is nil")
	}

	ctx, span := i.tracer.Start(r.Context(), "Intercept.ProcessRequest", trace.WithAttributes(tracing.InterceptionAttributesFromContext(r.Context())...))
	defer tracing.EndSpanErr(span, &outErr)

	svc := i.newCompletionsService()
	logger := i.logger.With(slog.F("model", i.req.Model))

	var (
		cumulativeUsage openai.CompletionUsage
		completion      *openai.ChatCompletion
		err             error
	)

	i.injectTools()

	prompt, err := i.req.lastUserPrompt()
	if err != nil {
		logger.Warn(ctx, "failed to retrieve last user prompt", slog.Error(err))
	}

	// The request options are the same for every upstream call made while
	// serving this request, so build them once instead of on every loop
	// iteration.
	// TODO(ssncferreira): inject actor headers directly in the client-header
	// middleware instead of using SDK options.
	opts := []option.RequestOption{option.WithRequestTimeout(10 * time.Minute)}
	if actor := aibcontext.ActorFromContext(r.Context()); actor != nil && i.cfg.SendActorHeaders {
		opts = append(opts, intercept.ActorHeadersAsOpenAIOpts(actor)...)
	}

	// appendToolError appends a tool result describing toolErr for the given
	// tool call. Every injected tool call always receives a result, even on
	// failure.
	appendToolError := func(toolCallID string, toolErr error) {
		// TODO: interception ID?
		// Marshalling a map holding only a bool and a string cannot fail.
		errorJSON, _ := json.Marshal(map[string]any{
			"error":   true,
			"message": toolErr.Error(),
		})
		i.req.Messages = append(i.req.Messages, openai.ToolMessage(string(errorJSON), toolCallID))
	}

	for {
		// TODO add outer loop span (https://github.com/coder/aibridge/issues/67)
		completion, err = i.newChatCompletion(ctx, svc, opts)
		if err != nil {
			break
		}

		// The prompt is recorded once, against the first successful completion.
		if prompt != nil {
			_ = i.recorder.RecordPromptUsage(ctx, &recorder.PromptUsageRecord{
				InterceptionID: i.ID().String(),
				MsgID:          completion.ID,
				Prompt:         *prompt,
			})
			prompt = nil
		}

		lastUsage := completion.Usage
		cumulativeUsage = sumUsage(cumulativeUsage, lastUsage)
		_ = i.recorder.RecordTokenUsage(ctx, &recorder.TokenUsageRecord{
			InterceptionID:       i.ID().String(),
			MsgID:                completion.ID,
			Input:                calculateActualInputTokenUsage(lastUsage),
			Output:               lastUsage.CompletionTokens,
			CacheReadInputTokens: lastUsage.PromptTokensDetails.CachedTokens,
			ExtraTokenTypes: map[string]int64{
				"prompt_audio":                   lastUsage.PromptTokensDetails.AudioTokens,
				"completion_accepted_prediction": lastUsage.CompletionTokensDetails.AcceptedPredictionTokens,
				"completion_rejected_prediction": lastUsage.CompletionTokensDetails.RejectedPredictionTokens,
				"completion_audio":               lastUsage.CompletionTokensDetails.AudioTokens,
				"completion_reasoning":           lastUsage.CompletionTokensDetails.ReasoningTokens,
			},
		})

		// Split tool calls into injected ones (known to the MCP proxy, handled
		// below) and all others, which are only recorded here.
		var pendingToolCalls []openai.ChatCompletionMessageToolCallUnion
		if len(completion.Choices) > 0 && completion.Choices[0].Message.ToolCalls != nil {
			for _, toolCall := range completion.Choices[0].Message.ToolCalls {
				if i.mcpProxy != nil && i.mcpProxy.GetTool(toolCall.Function.Name) != nil {
					pendingToolCalls = append(pendingToolCalls, toolCall)
					continue
				}
				_ = i.recorder.RecordToolUsage(ctx, &recorder.ToolUsageRecord{
					InterceptionID: i.ID().String(),
					MsgID:          completion.ID,
					ToolCallID:     toolCall.ID,
					Tool:           toolCall.Function.Name,
					Args:           i.unmarshalArgs(toolCall.Function.Arguments),
					Injected:       false,
				})
			}
		}

		// If no injected tool calls, we're done.
		if len(pendingToolCalls) == 0 {
			break
		}

		appendedPrevMsg := false
		for _, tc := range pendingToolCalls {
			if i.mcpProxy == nil {
				continue
			}
			tool := i.mcpProxy.GetTool(tc.Function.Name)
			if tool == nil {
				// Not a known tool, don't do anything.
				logger.Warn(ctx, "pending tool call for non-managed tool, skipping", slog.F("tool", tc.Function.Name))
				continue
			}

			// The assistant message must precede its tool results in the
			// follow-up request; append it once, before the first result.
			if !appendedPrevMsg {
				i.req.Messages = append(i.req.Messages, completion.Choices[0].Message.ToParam())
				appendedPrevMsg = true
			}

			args := i.unmarshalArgs(tc.Function.Arguments)
			res, callErr := tool.Call(ctx, args, i.tracer)
			_ = i.recorder.RecordToolUsage(ctx, &recorder.ToolUsageRecord{
				InterceptionID:  i.ID().String(),
				MsgID:           completion.ID,
				ToolCallID:      tc.ID,
				ServerURL:       &tool.ServerURL,
				Tool:            tool.Name,
				Args:            args,
				Injected:        true,
				InvocationError: callErr,
			})
			if callErr != nil {
				appendToolError(tc.ID, callErr)
				continue
			}

			var out strings.Builder
			if encErr := json.NewEncoder(&out).Encode(res); encErr != nil {
				logger.Warn(ctx, "failed to encode tool response", slog.Error(encErr))
				appendToolError(tc.ID, encErr)
				continue
			}
			i.req.Messages = append(i.req.Messages, openai.ToolMessage(out.String(), tc.ID))
		}

		if !appendedPrevMsg {
			// None of the pending tool calls could be resolved (e.g. the tool
			// vanished from the proxy between lookups), so the conversation is
			// unchanged. Re-sending the identical request would loop without
			// progress; return this completion as-is instead.
			break
		}
	}

	if err != nil {
		if eventstream.IsConnError(err) {
			http.Error(w, err.Error(), http.StatusInternalServerError)
			return xerrors.Errorf("upstream connection closed: %w", err)
		}
		// The failover loop may return a keypool exhaustion
		// error. Check before the SDK-error path.
		var keyPoolErr *keypool.Error
		if errors.As(err, &keyPoolErr) {
			i.writeUpstreamError(w, intercept.ResponseErrorFromKeyPool(keyPoolErr))
			return xerrors.Errorf("key pool exhausted: %w", err)
		}
		if apiErr := intercept.ResponseErrorFromAPIError(err); apiErr != nil {
			i.writeUpstreamError(w, apiErr)
			return xerrors.Errorf("openai API error: %w", err)
		}
		http.Error(w, err.Error(), http.StatusInternalServerError)
		return xerrors.Errorf("chat completion failed: %w", err)
	}

	if completion == nil {
		return nil
	}

	// Overwrite response identifier since proxy obscures injected tool call invocations.
	completion.ID = i.ID().String()
	// Update the cumulative usage in the final response.
	if completion.Usage.CompletionTokens > 0 {
		completion.Usage = cumulativeUsage
	}

	w.Header().Set("Content-Type", "application/json")
	body, marshalErr := json.Marshal(completion)
	if marshalErr != nil {
		body, _ = json.Marshal(i.newErrorResponse(xerrors.Errorf("failed to marshal response: %w", marshalErr)))
		w.WriteHeader(http.StatusInternalServerError)
	} else {
		w.WriteHeader(http.StatusOK)
	}
	_, _ = w.Write(body)
	return nil
}
// newChatCompletion issues one logical upstream request.
//
// Without a centralized key pool (BYOK), it makes exactly one attempt using
// whatever credentials svc and opts already carry. With a pool, it delegates
// to the failover loop, which rotates through the pool's keys.
func (i *BlockingInterception) newChatCompletion(ctx context.Context, svc openai.ChatCompletionService, opts []option.RequestOption) (*openai.ChatCompletion, error) {
	centralized := i.cfg.KeyPool != nil
	if centralized {
		return i.newChatCompletionWithKeyFailover(ctx, svc, opts)
	}
	// BYOK: one attempt, no key rotation.
	return i.newChatCompletionWithKey(ctx, svc, opts)
}
// newChatCompletionWithKey performs a single upstream chat completion call
// inside an "Intercept.ProcessRequest.Upstream" span.
//
// It does not retry or fail over itself. Per-attempt settings, such as the
// API key chosen by the failover loop, are supplied by the caller through
// opts.
func (i *BlockingInterception) newChatCompletionWithKey(ctx context.Context, svc openai.ChatCompletionService, opts []option.RequestOption) (_ *openai.ChatCompletion, outErr error) {
	// Keep the context returned by Start and use it for the upstream call, so
	// that any spans created while serving the call are parented to this
	// Upstream span. Discarding it would attach them to the caller's span.
	ctx, span := i.tracer.Start(ctx, "Intercept.ProcessRequest.Upstream", trace.WithAttributes(tracing.InterceptionAttributesFromContext(ctx)...))
	defer tracing.EndSpanErr(span, &outErr)
	return svc.New(ctx, i.req.ChatCompletionNewParams, opts...)
}
// newChatCompletionWithKeyFailover sends the request with one centralized
// key at a time, moving to the next key only on key-specific failures.
//
// Classification is delegated to markKeyOnError. Per the original design,
// keys are marked temporarily unavailable on 429 and permanently on 401/403.
// Success, or any error that is not key-specific, is returned immediately.
// If the pool runs out of keys, the walker's error is returned.
//
// Before each attempt, the interception's credential is replaced with the
// key being tried. The recorded hint therefore reflects the key that
// succeeded, or the last one attempted.
func (i *BlockingInterception) newChatCompletionWithKeyFailover(ctx context.Context, svc openai.ChatCompletionService, opts []option.RequestOption) (*openai.ChatCompletion, error) {
	keys := i.cfg.KeyPool.Walker()
	for {
		key, nextErr := keys.Next()
		if nextErr != nil {
			return nil, nextErr
		}

		i.credential = intercept.NewCredentialInfo(intercept.CredentialKindCentralized, key.Value())
		i.logger.Debug(ctx, "using centralized api key",
			slog.F("credential_hint", i.Credential().Hint),
			slog.F("credential_length", i.Credential().Length),
		)

		// Copy the shared options so per-attempt ones never leak between keys.
		attemptOpts := make([]option.RequestOption, 0, len(opts)+2)
		attemptOpts = append(attemptOpts, opts...)
		attemptOpts = append(attemptOpts,
			option.WithAPIKey(key.Value()),
			// SDK retries stay off: rotating keys is this loop's retry strategy.
			option.WithMaxRetries(0),
		)

		completion, callErr := i.newChatCompletionWithKey(ctx, svc, attemptOpts)
		if !i.markKeyOnError(ctx, key, callErr) {
			// Success or a non-key error: nothing to rotate past.
			return completion, callErr
		}
		// Key-specific failure: fall through to the next key.
	}
}