Files
coder/aibridge/passthrough.go
T
Susana Ferreira 0766cc3097 feat: add automatic key failover for AI Bridge passthrough (#24920)
## Description

Adds automatic key failover for passthrough routes for the Anthropic and OpenAI providers. A new `keyFailoverTransport` wraps the reverse-proxy transport: centralized requests walk the configured key pool and retry with the next key on key-specific failures (401/403/429), reusing the same key-marking semantics as the bridged routes.

BYOK passthrough requests run as a single attempt with no failover.

## Changes

- New `keypool.KeyFailoverConfig` carrying the `Pool` to walk and the provider-specific closures (`IsBYOK`, `InjectAuthKey`, `MarkKey`, `BuildExhaustedResponse`).
- New `keypool.NewKeyFailoverTransport`: wraps an inner `http.RoundTripper`. Returns `inner` unchanged when `Pool` is nil, otherwise produces a transport that buffers the request body once, walks the pool per request, and replays each attempt with the next key.
- New `Provider.KeyFailoverConfig(logger)` interface method. Anthropic injects `X-Api-Key`; OpenAI injects `Authorization: Bearer ...`; Copilot returns an empty config.
- `passthrough.go` wires `NewKeyFailoverTransport` around the existing apidump middleware, so every retry attempt is recorded.

## Related Issues

Related to: https://github.com/coder/internal/issues/1446
Related to: https://linear.app/codercom/issue/AIGOV-197/aibridge-automatic-key-failover-for-bridged-and-passthrough-routes

## Follow-up PRs

- Remove dead `Provider.InjectAuthHeader` method now that all auth is applied per-attempt by `KeyFailoverTransport`.
- Bedrock multi-key support.
- Refactor provider vs interceptor config separation.
- Record the actually-used key in the interception credential hint after failover.

> [!NOTE]
> Initially generated by Claude Opus 4.7, modified and reviewed by @ssncferreira
2026-05-07 15:46:36 +01:00

120 lines
4.3 KiB
Go

package aibridge
import (
"context"
"net/http"
"net/http/httputil"
"net/url"
"time"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/codes"
"go.opentelemetry.io/otel/trace"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/aibridge/intercept/apidump"
"github.com/coder/coder/v2/aibridge/keypool"
"github.com/coder/coder/v2/aibridge/metrics"
"github.com/coder/coder/v2/aibridge/provider"
"github.com/coder/coder/v2/aibridge/tracing"
"github.com/coder/quartz"
)
// newPassthroughRouter returns a reverse-proxy handler which is used when a
// route is not handled specifically by the [provider.Provider].
//
// A single reverse proxy (and its underlying transport) is created per
// provider and reused across all requests. If the provider's base URL cannot
// be parsed, or its path cannot be joined, the handler produced by
// newInvalidBaseURLHandler is returned instead.
//
// m may be nil, in which case the passthrough counter is not incremented.
func newPassthroughRouter(prov provider.Provider, logger slog.Logger, m *metrics.Metrics, tracer trace.Tracer) http.HandlerFunc {
	provBaseURL, err := url.Parse(prov.BaseURL())
	if err != nil {
		return newInvalidBaseURLHandler(prov, logger, m, tracer, err)
	}
	// Validate the base path once, up front, rather than on every request.
	if _, err := url.JoinPath(provBaseURL.Path, "/"); err != nil {
		return newInvalidBaseURLHandler(prov, logger, m, tracer, err)
	}

	// Transport tuned for streaming (no response header timeout).
	t := &http.Transport{
		Proxy:                 http.ProxyFromEnvironment,
		ForceAttemptHTTP2:     true,
		MaxIdleConns:          100,
		IdleConnTimeout:       90 * time.Second,
		TLSHandshakeTimeout:   10 * time.Second,
		ExpectContinueTimeout: 1 * time.Second,
	}

	// Build the passthrough proxy, reused across all requests for this provider.
	// Rewrite sets proxy headers. For centralized requests, KeyFailoverTransport
	// handles auth and failover. BYOK requests pass through.
	proxy := &httputil.ReverseProxy{
		Rewrite: func(pr *httputil.ProxyRequest) {
			rewritePassthroughRequest(pr, provBaseURL)
		},
		// The key-failover transport wraps the apidump middleware, which in
		// turn wraps the streaming transport above.
		Transport: keypool.NewKeyFailoverTransport(
			apidump.NewPassthroughMiddleware(t, prov.APIDumpDir(), prov.Name(), logger, quartz.NewReal()),
			prov.KeyFailoverConfig(logger),
		),
		ErrorHandler: func(rw http.ResponseWriter, req *http.Request, e error) {
			ctx := req.Context()
			logger.Warn(ctx, "reverse proxy error", slog.Error(e), slog.F("path", req.URL.Path))
			http.Error(rw, "upstream proxy error", http.StatusBadGateway)
			// Mark the request's span as failed, consistent with the
			// invalid-base-URL handler. The span started by the returned
			// handler is expected to travel on the request context; if none
			// is present, SpanFromContext yields a no-op span.
			trace.SpanFromContext(ctx).SetStatus(codes.Error, "upstream proxy error: "+e.Error())
		},
	}

	return func(w http.ResponseWriter, r *http.Request) {
		if m != nil {
			m.PassthroughCount.WithLabelValues(prov.Name(), r.URL.Path, r.Method).Add(1)
		}

		// The span covers the whole proxied exchange; it ends once ServeHTTP returns.
		ctx, span := startSpan(r, tracer)
		defer span.End()

		proxy.ServeHTTP(w, r.WithContext(ctx))
	}
}
// rewritePassthroughRequest configures the outbound request for the upstream and
// applies proxy headers.
func rewritePassthroughRequest(pr *httputil.ProxyRequest, provBaseURL *url.URL) {
pr.SetURL(provBaseURL)
// Rewrite sets "X-Forwarded-For" to just last hop (clients IP address).
// To preserve old Director behavior pr.In "X-Forwarded-For" header
// values need to be copied manually.
// https://pkg.go.dev/net/http/httputil#ProxyRequest.SetXForwarded
if prior, ok := pr.In.Header["X-Forwarded-For"]; ok {
pr.Out.Header["X-Forwarded-For"] = append([]string(nil), prior...)
}
pr.SetXForwarded()
span := trace.SpanFromContext(pr.Out.Context())
span.SetAttributes(attribute.String(tracing.PassthroughUpstreamURL, pr.Out.URL.String()))
// Avoid default Go user-agent if none provided.
if _, ok := pr.Out.Header["User-Agent"]; !ok {
pr.Out.Header.Set("User-Agent", "aibridge") // TODO: use build tag.
}
}
// newInvalidBaseURLHandler builds the handler used when the provider's base
// URL is invalid. Every request is answered with 502 Bad Gateway; the handler
// also starts a span, increments the passthrough counter (when m is non-nil),
// logs baseURLErr as a warning, and sets the span status to an error.
func newInvalidBaseURLHandler(prov provider.Provider, logger slog.Logger, m *metrics.Metrics, tracer trace.Tracer, baseURLErr error) http.HandlerFunc {
	const msg = "invalid provider base URL"

	handler := func(w http.ResponseWriter, r *http.Request) {
		ctx, span := startSpan(r, tracer)
		defer span.End()

		if m != nil {
			counter := m.PassthroughCount.WithLabelValues(prov.Name(), r.URL.Path, r.Method)
			counter.Add(1)
		}

		logger.Warn(ctx, msg, slog.Error(baseURLErr))
		http.Error(w, msg, http.StatusBadGateway)
		span.SetStatus(codes.Error, msg+": "+baseURLErr.Error())
	}
	return handler
}
// startSpan starts a span named "Passthrough" on tracer using r's context,
// tagging it with the request URL and HTTP method. It returns the context and
// span produced by tracer.Start; the caller is responsible for ending the span.
func startSpan(r *http.Request, tracer trace.Tracer) (context.Context, trace.Span) {
	attrs := []attribute.KeyValue{
		attribute.String(tracing.PassthroughURL, r.URL.String()),
		attribute.String(tracing.PassthroughMethod, r.Method),
	}
	return tracer.Start(r.Context(), "Passthrough", trace.WithAttributes(attrs...))
}