mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
0766cc3097
## Description Adds automatic key failover for passthrough routes for the Anthropic and OpenAI providers. A new `keyFailoverTransport` wraps the reverse-proxy transport: centralized requests walk the configured key pool and retry with the next key on key-specific failures (401/403/429), reusing the same key-marking semantics as the bridged routes. BYOK passthrough requests run as a single attempt with no failover. ## Changes - New `keypool.KeyFailoverConfig` carrying the `Pool` to walk and the provider-specific closures (`IsBYOK`, `InjectAuthKey`, `MarkKey`, `BuildExhaustedResponse`). - New `keypool.NewKeyFailoverTransport`: wraps an inner `http.RoundTripper`. Returns `inner` unchanged when `Pool` is nil, otherwise produces a transport that buffers the request body once, walks the pool per request, and replays each attempt with the next key. - New `Provider.KeyFailoverConfig(logger)` interface method. Anthropic injects `X-Api-Key`; OpenAI injects `Authorization: Bearer ...`; Copilot returns an empty config. - `passthrough.go` wires `NewKeyFailoverTransport` around the existing apidump middleware, so every retry attempt is recorded. ## Related Issues Related to: https://github.com/coder/internal/issues/1446 Related to: https://linear.app/codercom/issue/AIGOV-197/aibridge-automatic-key-failover-for-bridged-and-passthrough-routes ## Follow-up PRs - Remove dead `Provider.InjectAuthHeader` method now that all auth is applied per-attempt by `KeyFailoverTransport`. - Bedrock multi-key support. - Refactor provider vs interceptor config separation. - Record the actually-used key in the interception credential hint after failover. > [!NOTE] > Initially generated by Claude Opus 4.7, modified and reviewed by @ssncferreira
340 lines
11 KiB
Go
340 lines
11 KiB
Go
package chatcompletions
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"math"
|
|
"net/http"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/openai/openai-go/v3"
|
|
"github.com/openai/openai-go/v3/option"
|
|
"github.com/openai/openai-go/v3/shared"
|
|
"go.opentelemetry.io/otel/attribute"
|
|
"go.opentelemetry.io/otel/trace"
|
|
|
|
"cdr.dev/slog/v3"
|
|
"github.com/coder/coder/v2/aibridge/config"
|
|
aibcontext "github.com/coder/coder/v2/aibridge/context"
|
|
"github.com/coder/coder/v2/aibridge/intercept"
|
|
"github.com/coder/coder/v2/aibridge/intercept/apidump"
|
|
"github.com/coder/coder/v2/aibridge/keypool"
|
|
"github.com/coder/coder/v2/aibridge/mcp"
|
|
"github.com/coder/coder/v2/aibridge/recorder"
|
|
"github.com/coder/coder/v2/aibridge/tracing"
|
|
"github.com/coder/coder/v2/aibridge/utils"
|
|
"github.com/coder/quartz"
|
|
)
|
|
|
|
// interceptionBase holds the state shared by the OpenAI chat completions
// interceptors in this package: the parsed request, the provider
// configuration, the original client headers, and the collaborators
// (logger, recorder, MCP proxy) assigned by Setup.
type interceptionBase struct {
	// id uniquely identifies this interception (see ID).
	id uuid.UUID
	// providerName is reported in trace attributes, key-marking logs and
	// API dumps.
	providerName string
	// req is the parsed client request. It may be nil; Model reports
	// "coder-aibridge-unknown" in that case.
	req *ChatCompletionNewParamsWrapper
	// cfg is the OpenAI provider configuration (key or key pool, base URL,
	// extra headers, API dump directory).
	cfg config.OpenAI

	// clientHeaders are the original HTTP headers from the client request.
	clientHeaders http.Header
	// authHeaderName is passed to intercept.BuildUpstreamHeaders when the
	// client headers are forwarded upstream.
	authHeaderName string

	logger slog.Logger
	tracer trace.Tracer

	// recorder and mcpProxy are assigned by Setup; mcpProxy may be nil.
	recorder recorder.Recorder
	mcpProxy mcp.ServerProxier
	// credential is returned as-is by Credential.
	credential intercept.CredentialInfo
}
|
|
|
|
// newCompletionsService builds the SDK service used for upstream
|
|
// calls. BYOK auth is set here. Centralized auth is set
|
|
// per-attempt by the failover loop.
|
|
func (i *interceptionBase) newCompletionsService() openai.ChatCompletionService {
|
|
// TODO(ssncferreira): validate auth is configured per
|
|
// https://github.com/coder/aibridge/issues/266.
|
|
|
|
var opts []option.RequestOption
|
|
// BYOK auth.
|
|
if i.cfg.KeyPool == nil {
|
|
opts = append(opts, option.WithAPIKey(i.cfg.Key))
|
|
}
|
|
opts = append(opts, option.WithBaseURL(i.cfg.BaseURL))
|
|
|
|
// Add extra headers if configured.
|
|
// Some providers require additional headers that are not added by the SDK.
|
|
// TODO(ssncferreira): remove as part of https://github.com/coder/aibridge/issues/192
|
|
for key, value := range i.cfg.ExtraHeaders {
|
|
opts = append(opts, option.WithHeader(key, value))
|
|
}
|
|
|
|
// Forward client headers to upstream. This middleware runs after the SDK
|
|
// has built the request, and replaces the outgoing headers with the sanitized
|
|
// client headers plus provider auth.
|
|
if i.clientHeaders != nil {
|
|
opts = append(opts, option.WithMiddleware(func(req *http.Request, next option.MiddlewareNext) (*http.Response, error) {
|
|
req.Header = intercept.BuildUpstreamHeaders(req.Header, i.clientHeaders, i.authHeaderName)
|
|
return next(req)
|
|
}))
|
|
}
|
|
|
|
// Add API dump middleware if configured
|
|
if mw := apidump.NewBridgeMiddleware(i.cfg.APIDumpDir, i.providerName, i.Model(), i.id, i.logger, quartz.NewReal()); mw != nil {
|
|
opts = append(opts, option.WithMiddleware(mw))
|
|
}
|
|
|
|
return openai.NewChatCompletionService(opts...)
|
|
}
|
|
|
|
func (i *interceptionBase) ID() uuid.UUID {
|
|
return i.id
|
|
}
|
|
|
|
func (i *interceptionBase) Credential() intercept.CredentialInfo {
|
|
return i.credential
|
|
}
|
|
|
|
func (i *interceptionBase) Setup(logger slog.Logger, rec recorder.Recorder, mcpProxy mcp.ServerProxier) {
|
|
i.logger = logger
|
|
i.recorder = rec
|
|
i.mcpProxy = mcpProxy
|
|
}
|
|
|
|
func (i *interceptionBase) CorrelatingToolCallID() *string {
|
|
if len(i.req.Messages) == 0 {
|
|
return nil
|
|
}
|
|
|
|
// The tool result should be the last input message.
|
|
msg := i.req.Messages[len(i.req.Messages)-1]
|
|
if msg.OfTool == nil {
|
|
return nil
|
|
}
|
|
return &msg.OfTool.ToolCallID
|
|
}
|
|
|
|
func (i *interceptionBase) baseTraceAttributes(r *http.Request, streaming bool) []attribute.KeyValue {
|
|
return []attribute.KeyValue{
|
|
attribute.String(tracing.RequestPath, r.URL.Path),
|
|
attribute.String(tracing.InterceptionID, i.id.String()),
|
|
attribute.String(tracing.InitiatorID, aibcontext.ActorIDFromContext(r.Context())),
|
|
attribute.String(tracing.Provider, i.providerName),
|
|
attribute.String(tracing.Model, i.Model()),
|
|
attribute.Bool(tracing.Streaming, streaming),
|
|
}
|
|
}
|
|
|
|
func (i *interceptionBase) Model() string {
|
|
if i.req == nil {
|
|
return "coder-aibridge-unknown"
|
|
}
|
|
|
|
return i.req.Model
|
|
}
|
|
|
|
func (*interceptionBase) newErrorResponse(err error) map[string]any {
|
|
return map[string]any{
|
|
"error": true,
|
|
"message": err.Error(),
|
|
}
|
|
}
|
|
|
|
func (i *interceptionBase) injectTools() {
|
|
if i.req == nil || i.mcpProxy == nil || !i.hasInjectableTools() {
|
|
return
|
|
}
|
|
|
|
// Disable parallel tool calls when injectable tools are present to simplify the inner agentic loop.
|
|
i.req.ParallelToolCalls = openai.Bool(false)
|
|
|
|
// Inject tools.
|
|
for _, tool := range i.mcpProxy.ListTools() {
|
|
fn := openai.ChatCompletionToolUnionParam{
|
|
OfFunction: &openai.ChatCompletionFunctionToolParam{
|
|
Function: openai.FunctionDefinitionParam{
|
|
Name: tool.ID,
|
|
Strict: openai.Bool(false), // TODO: configurable.
|
|
Description: openai.String(tool.Description),
|
|
Parameters: openai.FunctionParameters{
|
|
"type": "object",
|
|
"properties": tool.Params,
|
|
// "additionalProperties": false, // Only relevant when strict=true.
|
|
},
|
|
},
|
|
},
|
|
}
|
|
|
|
// Otherwise the request fails with "None is not of type 'array'" if a nil slice is given.
|
|
if len(tool.Required) > 0 {
|
|
// Must list ALL properties when strict=true.
|
|
fn.OfFunction.Function.Parameters["required"] = tool.Required
|
|
}
|
|
|
|
i.req.Tools = append(i.req.Tools, fn)
|
|
}
|
|
}
|
|
|
|
func (i *interceptionBase) unmarshalArgs(in string) (args recorder.ToolArgs) {
|
|
if len(strings.TrimSpace(in)) == 0 {
|
|
return args // An empty string will fail JSON unmarshaling.
|
|
}
|
|
|
|
if err := json.Unmarshal([]byte(in), &args); err != nil {
|
|
i.logger.Warn(context.Background(), "failed to unmarshal tool args", slog.Error(err))
|
|
}
|
|
|
|
return args
|
|
}
|
|
|
|
// writeUpstreamError marshals and writes a given error.
|
|
func (i *interceptionBase) writeUpstreamError(w http.ResponseWriter, oaiErr *ResponseError) {
|
|
if oaiErr == nil {
|
|
return
|
|
}
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
// Set Retry-After when a cooldown is configured.
|
|
if oaiErr.RetryAfter > 0 {
|
|
w.Header().Set("Retry-After", strconv.Itoa(int(math.Ceil(oaiErr.RetryAfter.Seconds()))))
|
|
}
|
|
w.WriteHeader(oaiErr.StatusCode)
|
|
|
|
out, err := json.Marshal(oaiErr)
|
|
if err != nil {
|
|
i.logger.Warn(context.Background(), "failed to marshal upstream error", slog.Error(err), slog.F("error_payload", fmt.Sprintf("%+v", oaiErr)))
|
|
// Response has to match expected format.
|
|
_, _ = w.Write([]byte(`{
|
|
"error": {
|
|
"type": "error",
|
|
"message":"error marshaling upstream error",
|
|
"code": "server_error"
|
|
}
|
|
}`))
|
|
} else {
|
|
_, _ = w.Write(out)
|
|
}
|
|
}
|
|
|
|
// For centralized requests, markKeyOnError extracts an OpenAI
|
|
// SDK error from err and marks the key based on its status
|
|
// code. Returns true if the status was a key-specific failover
|
|
// trigger so callers can retry with the next key.
|
|
func (i *interceptionBase) markKeyOnError(ctx context.Context, key *keypool.Key, err error) bool {
|
|
if i.cfg.KeyPool == nil {
|
|
return false
|
|
}
|
|
var apiErr *openai.Error
|
|
if !errors.As(err, &apiErr) {
|
|
return false
|
|
}
|
|
return keypool.MarkKeyOnStatus(
|
|
ctx, key, apiErr.Response,
|
|
i.logger, i.providerName,
|
|
)
|
|
}
|
|
|
|
// ProcessKeyPoolError translates a keypool exhaustion error
|
|
// into a developer-facing responseError shaped for the OpenAI
|
|
// API. Returns nil if err is not an exhaustion error.
|
|
func ProcessKeyPoolError(err error) *ResponseError {
|
|
var transient *keypool.TransientKeyPoolError
|
|
switch {
|
|
case errors.As(err, &transient):
|
|
return newErrorResponse(
|
|
"all configured keys are rate-limited",
|
|
intercept.OpenAIErrTypeRateLimit,
|
|
intercept.OpenAIErrCodeRateLimit,
|
|
http.StatusTooManyRequests,
|
|
transient.RetryAfter,
|
|
)
|
|
case errors.Is(err, keypool.ErrPermanentKeyPool):
|
|
return newErrorResponse(
|
|
"all configured keys failed authentication",
|
|
intercept.OpenAIErrTypeAPI,
|
|
intercept.OpenAIErrCodeServer,
|
|
http.StatusBadGateway,
|
|
0,
|
|
)
|
|
default:
|
|
return nil
|
|
}
|
|
}
|
|
|
|
func (i *interceptionBase) hasInjectableTools() bool {
|
|
return i.mcpProxy != nil && len(i.mcpProxy.ListTools()) > 0
|
|
}
|
|
|
|
func sumUsage(ref, in openai.CompletionUsage) openai.CompletionUsage {
|
|
return openai.CompletionUsage{
|
|
CompletionTokens: ref.CompletionTokens + in.CompletionTokens,
|
|
PromptTokens: ref.PromptTokens + in.PromptTokens,
|
|
TotalTokens: ref.TotalTokens + in.TotalTokens,
|
|
CompletionTokensDetails: openai.CompletionUsageCompletionTokensDetails{
|
|
AcceptedPredictionTokens: ref.CompletionTokensDetails.AcceptedPredictionTokens + in.CompletionTokensDetails.AcceptedPredictionTokens,
|
|
AudioTokens: ref.CompletionTokensDetails.AudioTokens + in.CompletionTokensDetails.AudioTokens,
|
|
ReasoningTokens: ref.CompletionTokensDetails.ReasoningTokens + in.CompletionTokensDetails.ReasoningTokens,
|
|
RejectedPredictionTokens: ref.CompletionTokensDetails.RejectedPredictionTokens + in.CompletionTokensDetails.RejectedPredictionTokens,
|
|
},
|
|
PromptTokensDetails: openai.CompletionUsagePromptTokensDetails{
|
|
AudioTokens: ref.PromptTokensDetails.AudioTokens + in.PromptTokensDetails.AudioTokens,
|
|
CachedTokens: ref.PromptTokensDetails.CachedTokens + in.PromptTokensDetails.CachedTokens,
|
|
},
|
|
}
|
|
}
|
|
|
|
// calculateActualInputTokenUsage accounts for cached tokens which are included in [openai.CompletionUsage].PromptTokens.
|
|
func calculateActualInputTokenUsage(in openai.CompletionUsage) int64 {
|
|
// Input *includes* the cached tokens, so we subtract them here to reflect actual input token usage.
|
|
// The original value can be reconstructed by adding CachedTokens back to Input.
|
|
// See https://platform.openai.com/docs/api-reference/usage/completions_object#usage/completions_object-input_tokens.
|
|
return in.PromptTokens /* The aggregated number of text input tokens used, including cached tokens. */ -
|
|
in.PromptTokensDetails.CachedTokens /* The aggregated number of text input tokens that has been cached from previous requests. */
|
|
}
|
|
|
|
func getErrorResponse(err error) *ResponseError {
|
|
var apiErr *openai.Error
|
|
if !errors.As(err, &apiErr) {
|
|
return nil
|
|
}
|
|
return newErrorResponse(apiErr.Message, apiErr.Type, apiErr.Code, apiErr.StatusCode, keypool.ParseRetryAfter(apiErr.Response))
|
|
}
|
|
|
|
var _ error = &ResponseError{}
|
|
|
|
type ResponseError struct {
|
|
ErrorObject *shared.ErrorObject `json:"error"`
|
|
StatusCode int `json:"-"`
|
|
RetryAfter time.Duration `json:"-"`
|
|
}
|
|
|
|
func newErrorResponse(msg, errType, code string, status int, retryAfter time.Duration) *ResponseError {
|
|
return &ResponseError{
|
|
ErrorObject: &shared.ErrorObject{
|
|
Code: code,
|
|
Message: msg,
|
|
Type: errType,
|
|
},
|
|
StatusCode: status,
|
|
RetryAfter: retryAfter,
|
|
}
|
|
}
|
|
|
|
func (e *ResponseError) Error() string {
|
|
if e.ErrorObject == nil {
|
|
return ""
|
|
}
|
|
return e.ErrorObject.Message
|
|
}
|
|
|
|
// ToResponse marshals e into an *http.Response shaped for the
|
|
// OpenAI API.
|
|
func (e *ResponseError) ToResponse() *http.Response {
|
|
body, err := json.Marshal(e)
|
|
if err != nil {
|
|
body = []byte(`{"error":{"type":"error","message":"error marshaling upstream error","code":"server_error"}}`)
|
|
}
|
|
return utils.NewJSONErrorResponse(e.StatusCode, e.RetryAfter, body)
|
|
}
|