Files
coder/aibridge/intercept/chatcompletions/base.go
T
Susana Ferreira 0766cc3097 feat: add automatic key failover for AI Bridge passthrough (#24920)
## Description

Adds automatic key failover for passthrough routes for the Anthropic and OpenAI providers. A new `keyFailoverTransport` wraps the reverse-proxy transport: centralized requests walk the configured key pool and retry with the next key on key-specific failures (401/403/429), reusing the same key-marking semantics as the bridged routes.

BYOK passthrough requests run as a single attempt with no failover.

## Changes

- New `keypool.KeyFailoverConfig` carrying the `Pool` to walk and the provider-specific closures (`IsBYOK`, `InjectAuthKey`, `MarkKey`, `BuildExhaustedResponse`).
- New `keypool.NewKeyFailoverTransport`: wraps an inner `http.RoundTripper`. It returns `inner` unchanged when `Pool` is nil; otherwise it produces a transport that buffers the request body once, walks the pool per request, and replays each attempt with the next key.
- New `Provider.KeyFailoverConfig(logger)` interface method. Anthropic injects `X-Api-Key`; OpenAI injects `Authorization: Bearer ...`; Copilot returns an empty config.
- `passthrough.go` wires `NewKeyFailoverTransport` around the existing apidump middleware, so every retry attempt is recorded.

## Related Issues

Related to: https://github.com/coder/internal/issues/1446
Related to: https://linear.app/codercom/issue/AIGOV-197/aibridge-automatic-key-failover-for-bridged-and-passthrough-routes

## Follow-up PRs

- Remove dead `Provider.InjectAuthHeader` method now that all auth is applied per-attempt by `KeyFailoverTransport`.
- Bedrock multi-key support.
- Refactor provider vs interceptor config separation.
- Record the actually-used key in the interception credential hint after failover.

> [!NOTE]
> Initially generated by Claude Opus 4.7, modified and reviewed by @ssncferreira
2026-05-07 15:46:36 +01:00

340 lines
11 KiB
Go

package chatcompletions
import (
"context"
"encoding/json"
"errors"
"fmt"
"math"
"net/http"
"strconv"
"strings"
"time"
"github.com/google/uuid"
"github.com/openai/openai-go/v3"
"github.com/openai/openai-go/v3/option"
"github.com/openai/openai-go/v3/shared"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/aibridge/config"
aibcontext "github.com/coder/coder/v2/aibridge/context"
"github.com/coder/coder/v2/aibridge/intercept"
"github.com/coder/coder/v2/aibridge/intercept/apidump"
"github.com/coder/coder/v2/aibridge/keypool"
"github.com/coder/coder/v2/aibridge/mcp"
"github.com/coder/coder/v2/aibridge/recorder"
"github.com/coder/coder/v2/aibridge/tracing"
"github.com/coder/coder/v2/aibridge/utils"
"github.com/coder/quartz"
)
// interceptionBase bundles the per-interception state read by the methods in
// this file: the intercepted request, the OpenAI provider configuration, and
// the collaborators (logger, recorder, MCP proxy) installed via Setup.
type interceptionBase struct {
// id identifies this interception; it is returned by ID and attached to
// trace attributes and the API dump middleware.
id uuid.UUID
// providerName labels the upstream provider in traces, API dumps, and key marking.
providerName string
// req is the wrapped chat completion request; methods treat nil as "no request".
req *ChatCompletionNewParamsWrapper
// cfg supplies the base URL, key or key pool, extra headers, and API dump directory.
cfg config.OpenAI
// clientHeaders are the original HTTP headers from the client request.
clientHeaders http.Header
// authHeaderName is passed to intercept.BuildUpstreamHeaders when forwarding client headers.
authHeaderName string
// logger is installed by Setup.
logger slog.Logger
// tracer is not referenced by the methods visible in this part of the file.
tracer trace.Tracer
// recorder is installed by Setup.
recorder recorder.Recorder
// mcpProxy is installed by Setup; it lists the tools that injectTools adds to the request.
mcpProxy mcp.ServerProxier
// credential is returned as-is by Credential.
credential intercept.CredentialInfo
}
// newCompletionsService assembles the OpenAI SDK chat completion service used
// for upstream calls. When no key pool is configured (BYOK) the API key is
// attached here; with a key pool, centralized auth is applied per attempt by
// the failover loop instead.
//
// Options are appended in a fixed order: auth, base URL, extra headers, the
// client-header forwarding middleware, and finally the API dump middleware.
func (i *interceptionBase) newCompletionsService() openai.ChatCompletionService {
	// TODO(ssncferreira): validate auth is configured per
	//   https://github.com/coder/aibridge/issues/266.
	var reqOpts []option.RequestOption

	// BYOK auth: without a key pool the single configured key is used.
	if i.cfg.KeyPool == nil {
		reqOpts = append(reqOpts, option.WithAPIKey(i.cfg.Key))
	}
	reqOpts = append(reqOpts, option.WithBaseURL(i.cfg.BaseURL))

	// Some providers require additional headers that the SDK does not add.
	// TODO(ssncferreira): remove as part of https://github.com/coder/aibridge/issues/192
	for name, val := range i.cfg.ExtraHeaders {
		reqOpts = append(reqOpts, option.WithHeader(name, val))
	}

	// Forward client headers upstream. The middleware runs after the SDK has
	// built the request and swaps the outgoing headers for the sanitized
	// client headers plus provider auth.
	if i.clientHeaders != nil {
		forwardHeaders := func(req *http.Request, next option.MiddlewareNext) (*http.Response, error) {
			req.Header = intercept.BuildUpstreamHeaders(req.Header, i.clientHeaders, i.authHeaderName)
			return next(req)
		}
		reqOpts = append(reqOpts, option.WithMiddleware(forwardHeaders))
	}

	// The API dump middleware is only present when dumping is configured.
	dumpMW := apidump.NewBridgeMiddleware(i.cfg.APIDumpDir, i.providerName, i.Model(), i.id, i.logger, quartz.NewReal())
	if dumpMW != nil {
		reqOpts = append(reqOpts, option.WithMiddleware(dumpMW))
	}

	return openai.NewChatCompletionService(reqOpts...)
}
// ID returns the identifier assigned to this interception.
func (i *interceptionBase) ID() uuid.UUID {
	return i.id
}
// Credential returns the credential information stored on this interception.
func (i *interceptionBase) Credential() intercept.CredentialInfo {
	return i.credential
}
// Setup installs the collaborators this interception works with: the logger,
// the recorder, and the MCP server proxy. Each argument overwrites the
// corresponding field unconditionally.
func (i *interceptionBase) Setup(logger slog.Logger, rec recorder.Recorder, mcpProxy mcp.ServerProxier) {
	i.mcpProxy = mcpProxy
	i.recorder = rec
	i.logger = logger
}
// CorrelatingToolCallID returns the tool call ID carried by the final request
// message when that message is a tool message. It returns nil when there is
// no request, the request has no messages, or the last message is not a tool
// message.
func (i *interceptionBase) CorrelatingToolCallID() *string {
	// Guard against a missing request, as Model and injectTools do, instead
	// of dereferencing a nil pointer.
	if i.req == nil || len(i.req.Messages) == 0 {
		return nil
	}

	// The tool result should be the last input message.
	last := i.req.Messages[len(i.req.Messages)-1]
	if last.OfTool == nil {
		return nil
	}
	return &last.OfTool.ToolCallID
}
// baseTraceAttributes returns the trace attributes describing this
// interception for request r: request path, interception ID, initiator ID
// (taken from the request context), provider, model, and whether the request
// is streaming.
func (i *interceptionBase) baseTraceAttributes(r *http.Request, streaming bool) []attribute.KeyValue {
	attrs := make([]attribute.KeyValue, 0, 6)
	attrs = append(attrs,
		attribute.String(tracing.RequestPath, r.URL.Path),
		attribute.String(tracing.InterceptionID, i.id.String()),
		attribute.String(tracing.InitiatorID, aibcontext.ActorIDFromContext(r.Context())),
		attribute.String(tracing.Provider, i.providerName),
		attribute.String(tracing.Model, i.Model()),
		attribute.Bool(tracing.Streaming, streaming),
	)
	return attrs
}
// Model returns the model named in the intercepted request, or the
// placeholder "coder-aibridge-unknown" when no request is attached.
func (i *interceptionBase) Model() string {
	if i.req != nil {
		return i.req.Model
	}
	return "coder-aibridge-unknown"
}
// newErrorResponse wraps err in a generic payload with two keys: "error",
// always true, and "message", holding err.Error(). err must be non-nil.
func (*interceptionBase) newErrorResponse(err error) map[string]any {
	payload := make(map[string]any, 2)
	payload["error"] = true
	payload["message"] = err.Error()
	return payload
}
// injectTools appends every tool listed by the MCP proxy to the request as a
// function tool and turns off parallel tool calls. It does nothing when there
// is no request, no proxy, or the proxy lists no tools.
func (i *interceptionBase) injectTools() {
	if i.req == nil || i.mcpProxy == nil {
		return
	}
	if !i.hasInjectableTools() {
		return
	}

	// Disable parallel tool calls when injectable tools are present to simplify the inner agentic loop.
	i.req.ParallelToolCalls = openai.Bool(false)

	// Inject tools.
	for _, tool := range i.mcpProxy.ListTools() {
		schema := openai.FunctionParameters{
			"type":       "object",
			"properties": tool.Params,
			// "additionalProperties": false, // Only relevant when strict=true.
		}
		// Only set "required" when it is non-empty; otherwise the request
		// fails with "None is not of type 'array'" if a nil slice is given.
		if len(tool.Required) > 0 {
			// Must list ALL properties when strict=true.
			schema["required"] = tool.Required
		}

		definition := openai.FunctionDefinitionParam{
			Name:        tool.ID,
			Strict:      openai.Bool(false), // TODO: configurable.
			Description: openai.String(tool.Description),
			Parameters:  schema,
		}
		i.req.Tools = append(i.req.Tools, openai.ChatCompletionToolUnionParam{
			OfFunction: &openai.ChatCompletionFunctionToolParam{Function: definition},
		})
	}
}
// unmarshalArgs decodes the JSON-encoded tool arguments in in. Blank input
// yields the zero value without attempting to decode. A decode failure is
// logged as a warning and whatever was decoded so far is returned; no error
// is surfaced to the caller.
func (i *interceptionBase) unmarshalArgs(in string) (args recorder.ToolArgs) {
	// An empty string will fail JSON unmarshaling, so skip it.
	if strings.TrimSpace(in) == "" {
		return args
	}

	err := json.Unmarshal([]byte(in), &args)
	if err != nil {
		i.logger.Warn(context.Background(), "failed to unmarshal tool args", slog.Error(err))
	}
	return args
}
// writeUpstreamError writes oaiErr to w as a JSON response using
// oaiErr.StatusCode as the HTTP status. A positive RetryAfter is sent as a
// Retry-After header, rounded up to whole seconds. If oaiErr cannot be
// marshaled, a warning is logged and a fixed fallback error body is written
// instead. A nil oaiErr writes nothing.
func (i *interceptionBase) writeUpstreamError(w http.ResponseWriter, oaiErr *ResponseError) {
	if oaiErr == nil {
		return
	}

	w.Header().Set("Content-Type", "application/json")
	// Set Retry-After when a cooldown is configured.
	if oaiErr.RetryAfter > 0 {
		seconds := int(math.Ceil(oaiErr.RetryAfter.Seconds()))
		w.Header().Set("Retry-After", strconv.Itoa(seconds))
	}
	w.WriteHeader(oaiErr.StatusCode)

	payload, err := json.Marshal(oaiErr)
	if err != nil {
		i.logger.Warn(context.Background(), "failed to marshal upstream error", slog.Error(err), slog.F("error_payload", fmt.Sprintf("%+v", oaiErr)))
		// Response has to match expected format.
		payload = []byte(`{
"error": {
"type": "error",
"message":"error marshaling upstream error",
"code": "server_error"
}
}`)
	}
	// Write errors are deliberately ignored: the status line is already sent
	// and there is no further way to report a failure to the client.
	_, _ = w.Write(payload)
}
// markKeyOnError marks key according to the upstream status carried by err.
// It applies to centralized requests only: when no key pool is configured, or
// when err does not wrap an *openai.Error, it returns false without touching
// the key. Otherwise it delegates to keypool.MarkKeyOnStatus and returns its
// result, which is true when the status was a key-specific failover trigger
// so callers can retry with the next key.
func (i *interceptionBase) markKeyOnError(ctx context.Context, key *keypool.Key, err error) bool {
	if i.cfg.KeyPool == nil {
		return false
	}

	var sdkErr *openai.Error
	if errors.As(err, &sdkErr) {
		return keypool.MarkKeyOnStatus(ctx, key, sdkErr.Response, i.logger, i.providerName)
	}
	return false
}
// ProcessKeyPoolError translates a key pool exhaustion error into a
// developer-facing *ResponseError shaped for the OpenAI API:
//
//   - a *keypool.TransientKeyPoolError becomes a 429 rate-limit error that
//     carries the error's RetryAfter;
//   - keypool.ErrPermanentKeyPool becomes a 502 server error.
//
// It returns nil if err is not an exhaustion error.
func ProcessKeyPoolError(err error) *ResponseError {
	var transient *keypool.TransientKeyPoolError
	if errors.As(err, &transient) {
		return newErrorResponse(
			"all configured keys are rate-limited",
			intercept.OpenAIErrTypeRateLimit,
			intercept.OpenAIErrCodeRateLimit,
			http.StatusTooManyRequests,
			transient.RetryAfter,
		)
	}

	if errors.Is(err, keypool.ErrPermanentKeyPool) {
		return newErrorResponse(
			"all configured keys failed authentication",
			intercept.OpenAIErrTypeAPI,
			intercept.OpenAIErrCodeServer,
			http.StatusBadGateway,
			0,
		)
	}

	return nil
}
// hasInjectableTools reports whether an MCP proxy is set and lists at least
// one tool.
func (i *interceptionBase) hasInjectableTools() bool {
	if i.mcpProxy == nil {
		return false
	}
	return len(i.mcpProxy.ListTools()) > 0
}
// sumUsage returns a new CompletionUsage whose token counters are the
// field-wise sums of ref and in, including the completion and prompt detail
// counters. Only the counters are populated; every other field of the result
// is left at its zero value.
func sumUsage(ref, in openai.CompletionUsage) openai.CompletionUsage {
	refCompletion, inCompletion := ref.CompletionTokensDetails, in.CompletionTokensDetails
	refPrompt, inPrompt := ref.PromptTokensDetails, in.PromptTokensDetails

	// Start from the zero value so nothing but the counters is carried over.
	var total openai.CompletionUsage
	total.CompletionTokens = ref.CompletionTokens + in.CompletionTokens
	total.PromptTokens = ref.PromptTokens + in.PromptTokens
	total.TotalTokens = ref.TotalTokens + in.TotalTokens

	total.CompletionTokensDetails.AcceptedPredictionTokens = refCompletion.AcceptedPredictionTokens + inCompletion.AcceptedPredictionTokens
	total.CompletionTokensDetails.AudioTokens = refCompletion.AudioTokens + inCompletion.AudioTokens
	total.CompletionTokensDetails.ReasoningTokens = refCompletion.ReasoningTokens + inCompletion.ReasoningTokens
	total.CompletionTokensDetails.RejectedPredictionTokens = refCompletion.RejectedPredictionTokens + inCompletion.RejectedPredictionTokens

	total.PromptTokensDetails.AudioTokens = refPrompt.AudioTokens + inPrompt.AudioTokens
	total.PromptTokensDetails.CachedTokens = refPrompt.CachedTokens + inPrompt.CachedTokens

	return total
}
// calculateActualInputTokenUsage returns the number of prompt tokens that
// were not served from cache. [openai.CompletionUsage].PromptTokens includes
// cached tokens, so they are subtracted here; the original figure can be
// recovered by adding PromptTokensDetails.CachedTokens back to the result.
// See https://platform.openai.com/docs/api-reference/usage/completions_object#usage/completions_object-input_tokens.
func calculateActualInputTokenUsage(in openai.CompletionUsage) int64 {
	// The aggregated number of text input tokens used, including cached tokens.
	promptTokens := in.PromptTokens
	// The aggregated number of text input tokens that has been cached from previous requests.
	cachedTokens := in.PromptTokensDetails.CachedTokens
	return promptTokens - cachedTokens
}
// getErrorResponse converts an OpenAI SDK error into a *ResponseError that
// carries the upstream message, type, code, and status, plus the retry delay
// parsed from the upstream response. It returns nil when err does not wrap an
// *openai.Error.
func getErrorResponse(err error) *ResponseError {
	var sdkErr *openai.Error
	if errors.As(err, &sdkErr) {
		retryAfter := keypool.ParseRetryAfter(sdkErr.Response)
		return newErrorResponse(sdkErr.Message, sdkErr.Type, sdkErr.Code, sdkErr.StatusCode, retryAfter)
	}
	return nil
}
// Compile-time assertion that *ResponseError satisfies the error interface.
var _ error = &ResponseError{}
// ResponseError is an error whose JSON form is an object with a single
// "error" key holding ErrorObject. StatusCode and RetryAfter are excluded
// from the JSON body and travel alongside it as HTTP response metadata.
type ResponseError struct {
// ErrorObject is serialized under the "error" key; Error returns its Message.
ErrorObject *shared.ErrorObject `json:"error"`
// StatusCode is the HTTP status used by writeUpstreamError and ToResponse.
StatusCode int `json:"-"`
// RetryAfter, when positive, is emitted by writeUpstreamError as a
// Retry-After header in whole seconds; it is also handed to ToResponse's
// response builder.
RetryAfter time.Duration `json:"-"`
}
// newErrorResponse builds a *ResponseError from its parts: msg, errType, and
// code populate the embedded error object, status becomes the HTTP status
// code, and retryAfter is the delay advertised to the client (zero for none).
func newErrorResponse(msg, errType, code string, status int, retryAfter time.Duration) *ResponseError {
	detail := &shared.ErrorObject{
		Code:    code,
		Message: msg,
		Type:    errType,
	}
	return &ResponseError{
		ErrorObject: detail,
		StatusCode:  status,
		RetryAfter:  retryAfter,
	}
}
// Error implements the error interface by returning the message of the
// embedded error object. It returns the empty string when the receiver or
// its ErrorObject is nil.
func (e *ResponseError) Error() string {
	// A nil *ResponseError stored in an error interface is still a non-nil
	// error value, so Error can be invoked on a nil receiver; guard against
	// that instead of panicking.
	if e == nil || e.ErrorObject == nil {
		return ""
	}
	return e.ErrorObject.Message
}
// ToResponse renders e as an *http.Response shaped for the OpenAI API, using
// e.StatusCode and e.RetryAfter for the response metadata. If e cannot be
// marshaled to JSON, a fixed fallback error body is used instead.
func (e *ResponseError) ToResponse() *http.Response {
	payload, marshalErr := json.Marshal(e)
	if marshalErr != nil {
		payload = []byte(`{"error":{"type":"error","message":"error marshaling upstream error","code":"server_error"}}`)
	}
	return utils.NewJSONErrorResponse(e.StatusCode, e.RetryAfter, payload)
}