Files
Susana Ferreira 7b903cad73 fix: track credential hint across key failover attempts in aibridge (#25735)
## Problem

Centralized requests recorded *the first available key from the pool at
`CreateInterceptor` time* as `credential_hint`, so the interception
could be persisted in the database with a hint that didn't match the key
that actually served the request. The fix stores, at the end of the
interception, the hint of the key that succeeded, or of the last
attempted key if every key was unavailable.

## Changes

- Add `Key.Hint()` and update `credential_hint` on every failover
attempt so it reflects the actually-used key.
- Stop pre-populating `credential_hint` at `CreateInterceptor`. For
centralized credentials the hint starts empty and is updated by the key
failover loop.
- Persist the final hint via `RecordInterceptionEnded`; SQL updates
`credential_hint` only when `credential_kind = 'centralized'` so BYOK
keeps its start-time value.
- Log the actually-used hint on interception end/failure; start log uses
a `<keypool-pending>` placeholder for centralized.

> [!NOTE]
> Initially generated by Claude Opus 4.7, modified and reviewed by
@ssncferreira
2026-05-29 12:01:37 +01:00

202 lines
6.5 KiB
Go

package responses
import (
"context"
"errors"
"net/http"
"time"
"github.com/google/uuid"
"github.com/openai/openai-go/v3/option"
"github.com/openai/openai-go/v3/responses"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
"golang.org/x/xerrors"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/aibridge/config"
aibcontext "github.com/coder/coder/v2/aibridge/context"
"github.com/coder/coder/v2/aibridge/intercept"
"github.com/coder/coder/v2/aibridge/keypool"
"github.com/coder/coder/v2/aibridge/mcp"
"github.com/coder/coder/v2/aibridge/recorder"
"github.com/coder/coder/v2/aibridge/tracing"
)
// BlockingResponsesInterceptor is the non-streaming interceptor for the
// Responses API (its Streaming method reports false). All of its state lives
// in the embedded responsesInterceptionBase.
type BlockingResponsesInterceptor struct {
responsesInterceptionBase
}
// NewBlockingInterceptor builds a BlockingResponsesInterceptor whose embedded
// responsesInterceptionBase is populated from the given arguments. No other
// initialization happens here; the logger, recorder and MCP proxy are supplied
// later through Setup.
func NewBlockingInterceptor(
	id uuid.UUID,
	reqPayload RequestPayload,
	providerName string,
	cfg config.OpenAI,
	clientHeaders http.Header,
	authHeaderName string,
	tracer trace.Tracer,
	cred intercept.CredentialInfo,
) *BlockingResponsesInterceptor {
	interceptor := new(BlockingResponsesInterceptor)

	// Fill the embedded base in place so the struct is never copied.
	base := &interceptor.responsesInterceptionBase
	base.id = id
	base.providerName = providerName
	base.reqPayload = reqPayload
	base.cfg = cfg
	base.clientHeaders = clientHeaders
	base.authHeaderName = authHeaderName
	base.tracer = tracer
	base.credential = cred

	return interceptor
}
// Setup forwards the recorder and MCP proxy to the embedded base's Setup,
// together with a child of logger named "blocking".
func (i *BlockingResponsesInterceptor) Setup(logger slog.Logger, rec recorder.Recorder, mcpProxy mcp.ServerProxier) {
	blockingLogger := logger.Named("blocking")
	i.responsesInterceptionBase.Setup(blockingLogger, rec, mcpProxy)
}
// Streaming reports whether this interceptor streams its response. The
// blocking interceptor never does, so it always returns false.
func (*BlockingResponsesInterceptor) Streaming() bool {
return false
}
// TraceAttributes returns the tracing attributes for r as produced by the
// embedded base's baseTraceAttributes, called with false as its second
// argument (presumably the streaming flag; confirm against
// baseTraceAttributes).
func (i *BlockingResponsesInterceptor) TraceAttributes(r *http.Request) []attribute.KeyValue {
	attrs := i.responsesInterceptionBase.baseTraceAttributes(r, false)
	return attrs
}
// ProcessRequest serves a blocking (non-streaming) Responses API request.
//
// After validating the request and injecting tools, it calls upstream in a
// loop: each iteration issues one upstream request through newResponse and
// repeats for as long as handleInnerAgenticLoop asks it to. When the loop
// ends, the user prompt (if one was found) and the non-injected tool usage of
// the last response are recorded, and the captured upstream response is
// forwarded to the client.
//
// Error handling:
//   - A *keypool.Error from newResponse is rendered with writeUpstreamError
//     and returned immediately.
//   - A failure in handleInnerAgenticLoop is reported to the client with
//     sendCustomErr. The captured upstream response is then NOT forwarded,
//     because a response has already been written to w, and the loop error
//     is returned so it is recorded on the span.
//   - An upstream error for which no response was received is reported with
//     sendCustomErr and returned.
//   - Otherwise the upstream error (if any) is joined with the error from
//     forwarding the response.
//
// The returned error is also handed to tracing.EndSpanErr through outErr.
func (i *BlockingResponsesInterceptor) ProcessRequest(w http.ResponseWriter, r *http.Request) (outErr error) {
	ctx, span := i.tracer.Start(r.Context(), "Intercept.ProcessRequest", trace.WithAttributes(tracing.InterceptionAttributesFromContext(r.Context())...))
	defer tracing.EndSpanErr(span, &outErr)

	if err := i.validateRequest(ctx, w); err != nil {
		return err
	}

	i.injectTools()

	var (
		response        *responses.Response
		upstreamErr     error
		loopErr         error
		respCopy        responseCopier
		firstResponseID string
	)

	// A missing or unreadable prompt is not fatal; it is only logged.
	prompt, promptFound, err := i.reqPayload.lastUserPrompt(ctx, i.logger)
	if err != nil {
		i.logger.Warn(ctx, "failed to get user prompt", slog.Error(err))
	}

	shouldLoop := true
	for shouldLoop {
		srv := i.newResponsesService()
		// Start each iteration with a fresh copier so only the latest
		// upstream response is captured for forwarding.
		respCopy = responseCopier{}

		opts := i.requestOptions(&respCopy)
		opts = append(opts, option.WithRequestTimeout(time.Second*600))

		// TODO(ssncferreira): inject actor headers directly in the client-header
		// middleware instead of using SDK options.
		if actor := aibcontext.ActorFromContext(r.Context()); actor != nil && i.cfg.SendActorHeaders {
			opts = append(opts, intercept.ActorHeadersAsOpenAIOpts(actor)...)
		}

		response, upstreamErr = i.newResponse(ctx, srv, opts)

		// The failover loop may return a keypool exhaustion
		// error. Render it here.
		if upstreamErr != nil {
			var keyPoolErr *keypool.Error
			if errors.As(upstreamErr, &keyPoolErr) {
				i.writeUpstreamError(w, intercept.ResponseErrorFromKeyPool(keyPoolErr))
				return xerrors.Errorf("key pool exhausted: %w", upstreamErr)
			}
		}

		if upstreamErr != nil || response == nil {
			break
		}

		// The prompt is recorded against the first response of the loop.
		if firstResponseID == "" {
			firstResponseID = response.ID
		}

		i.recordTokenUsage(ctx, response)
		i.recordModelThoughts(ctx, response)

		// Check if there any injected tools to invoke.
		pending := i.getPendingInjectedToolCalls(response)
		shouldLoop, loopErr = i.handleInnerAgenticLoop(ctx, pending, response)
		if loopErr != nil {
			i.sendCustomErr(ctx, w, http.StatusInternalServerError, loopErr)
			shouldLoop = false
		}
	}

	if promptFound {
		i.recordUserPrompt(ctx, firstResponseID, prompt)
	}
	i.recordNonInjectedToolUsage(ctx, response)

	if loopErr != nil {
		// An error response has already been written for this failure;
		// forwarding the upstream response as well would write to w twice.
		return xerrors.Errorf("failed to handle inner agentic loop: %w", loopErr)
	}

	if upstreamErr != nil && !respCopy.responseReceived.Load() {
		// no response received from upstream, return custom error
		i.sendCustomErr(ctx, w, http.StatusInternalServerError, upstreamErr)
		return xerrors.Errorf("failed to connect to upstream: %w", upstreamErr)
	}

	return errors.Join(upstreamErr, respCopy.forwardResp(w))
}
// newResponse issues the upstream request, choosing the strategy from the
// configuration: with a key pool configured (centralized credentials) it
// goes through the failover loop; without one (BYOK) it makes exactly one
// attempt.
func (i *BlockingResponsesInterceptor) newResponse(ctx context.Context, srv responses.ResponseService, opts []option.RequestOption) (*responses.Response, error) {
	// Centralized: walk the key pool with failover.
	if i.cfg.KeyPool != nil {
		return i.newResponseWithKeyFailover(ctx, srv, opts)
	}

	// BYOK: single attempt, no failover.
	return i.newResponseWithKey(ctx, srv, opts)
}
// newResponseWithKey makes one upstream call with the given options, wrapped
// in an "Intercept.ProcessRequest.Upstream" span. The span is ended with the
// returned error through outErr. The upstream call itself runs on the
// caller's ctx; the context returned by tracer.Start is discarded.
func (i *BlockingResponsesInterceptor) newResponseWithKey(ctx context.Context, srv responses.ResponseService, opts []option.RequestOption) (_ *responses.Response, outErr error) {
	spanAttrs := tracing.InterceptionAttributesFromContext(ctx)
	_, span := i.tracer.Start(ctx, "Intercept.ProcessRequest.Upstream", trace.WithAttributes(spanAttrs...))
	defer tracing.EndSpanErr(span, &outErr)

	// The params are left empty on purpose: the body is overridden by
	// option.WithRequestBody(reqPayload) in requestOptions.
	var params responses.ResponseNewParams
	return srv.New(ctx, params, opts...)
}
// newResponseWithKeyFailover tries the keys of the centralized key pool one
// after another until a call succeeds, a non-key error occurs, or the walker
// reports an error (such as pool exhaustion), which is returned as-is.
//
// Before every attempt i.credential is replaced with the info of the key
// about to be used, so after the call it describes the last attempted key.
// markKeyOnError decides whether a failure is key-specific (keys are marked
// temporary on 429 and permanent on 401/403); only those failures move on to
// the next key. Any other outcome is returned to the caller unchanged.
func (i *BlockingResponsesInterceptor) newResponseWithKeyFailover(ctx context.Context, srv responses.ResponseService, opts []option.RequestOption) (*responses.Response, error) {
	walker := i.cfg.KeyPool.Walker()

	for {
		key, keyPoolErr := walker.Next()
		if keyPoolErr != nil {
			return nil, keyPoolErr
		}
		apiKey := key.Value()

		// Record the key in use so the hint reflects the last attempted key.
		i.credential = intercept.NewCredentialInfo(intercept.CredentialKindCentralized, apiKey)
		cred := i.Credential()
		i.logger.Debug(ctx, "using centralized api key",
			slog.F("credential_hint", cred.Hint),
			slog.F("credential_length", cred.Length),
		)

		// Build a per-attempt option list in a fresh slice so the caller's
		// opts are never modified between attempts.
		attemptOpts := make([]option.RequestOption, 0, len(opts)+2)
		attemptOpts = append(attemptOpts, opts...)
		attemptOpts = append(attemptOpts,
			option.WithAPIKey(apiKey),
			// Disable SDK retries because the failover loop
			// handles retries via key rotation.
			option.WithMaxRetries(0),
		)

		response, err := i.newResponseWithKey(ctx, srv, attemptOpts)
		if !i.markKeyOnError(ctx, key, err) {
			// Either success (response, nil) or a non-key error
			// (nil, err): nothing to retry, return as-is.
			return response, err
		}
		// Key-specific failure: fall through and try the next key.
	}
}