mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
feat: add automatic key failover for AI Bridge passthrough (#24920)
## Description Adds automatic key failover for passthrough routes for the Anthropic and OpenAI providers. A new `keyFailoverTransport` wraps the reverse-proxy transport: centralized requests walk the configured key pool and retry with the next key on key-specific failures (401/403/429), reusing the same key-marking semantics as the bridged routes. BYOK passthrough requests run as a single attempt with no failover. ## Changes - New `keypool.KeyFailoverConfig` carrying the `Pool` to walk and the provider-specific closures (`IsBYOK`, `InjectAuthKey`, `MarkKey`, `BuildExhaustedResponse`). - New `keypool.NewKeyFailoverTransport`: wraps an inner `http.RoundTripper`. Returns `inner` unchanged when `Pool` is nil, otherwise produces a transport that buffers the request body once, walks the pool per request, and replays each attempt with the next key. - New `Provider.KeyFailoverConfig(logger)` interface method. Anthropic injects `X-Api-Key`; OpenAI injects `Authorization: Bearer ...`; Copilot returns an empty config. - `passthrough.go` wires `NewKeyFailoverTransport` around the existing apidump middleware, so every retry attempt is recorded. ## Related Issues Related to: https://github.com/coder/internal/issues/1446 Related to: https://linear.app/codercom/issue/AIGOV-197/aibridge-automatic-key-failover-for-bridged-and-passthrough-routes ## Follow-up PRs - Remove dead `Provider.InjectAuthHeader` method now that all auth is applied per-attempt by `KeyFailoverTransport`. - Bedrock multi-key support. - Refactor provider vs interceptor config separation. - Record the actually-used key in the interception credential hint after failover. > [!NOTE] > Initially generated by Claude Opus 4.7, modified and reviewed by @ssncferreira
This commit is contained in:
@@ -27,6 +27,7 @@ import (
|
||||
"github.com/coder/coder/v2/aibridge/mcp"
|
||||
"github.com/coder/coder/v2/aibridge/recorder"
|
||||
"github.com/coder/coder/v2/aibridge/tracing"
|
||||
"github.com/coder/coder/v2/aibridge/utils"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
@@ -188,7 +189,7 @@ func (i *interceptionBase) unmarshalArgs(in string) (args recorder.ToolArgs) {
|
||||
}
|
||||
|
||||
// writeUpstreamError marshals and writes a given error.
|
||||
func (i *interceptionBase) writeUpstreamError(w http.ResponseWriter, oaiErr *responseError) {
|
||||
func (i *interceptionBase) writeUpstreamError(w http.ResponseWriter, oaiErr *ResponseError) {
|
||||
if oaiErr == nil {
|
||||
return
|
||||
}
|
||||
@@ -234,10 +235,10 @@ func (i *interceptionBase) markKeyOnError(ctx context.Context, key *keypool.Key,
|
||||
)
|
||||
}
|
||||
|
||||
// processKeyPoolError translates a keypool exhaustion error
|
||||
// ProcessKeyPoolError translates a keypool exhaustion error
|
||||
// into a developer-facing responseError shaped for the OpenAI
|
||||
// API. Returns nil if err is not an exhaustion error.
|
||||
func processKeyPoolError(err error) *responseError {
|
||||
func ProcessKeyPoolError(err error) *ResponseError {
|
||||
var transient *keypool.TransientKeyPoolError
|
||||
switch {
|
||||
case errors.As(err, &transient):
|
||||
@@ -292,7 +293,7 @@ func calculateActualInputTokenUsage(in openai.CompletionUsage) int64 {
|
||||
in.PromptTokensDetails.CachedTokens /* The aggregated number of text input tokens that has been cached from previous requests. */
|
||||
}
|
||||
|
||||
func getErrorResponse(err error) *responseError {
|
||||
func getErrorResponse(err error) *ResponseError {
|
||||
var apiErr *openai.Error
|
||||
if !errors.As(err, &apiErr) {
|
||||
return nil
|
||||
@@ -300,16 +301,16 @@ func getErrorResponse(err error) *responseError {
|
||||
return newErrorResponse(apiErr.Message, apiErr.Type, apiErr.Code, apiErr.StatusCode, keypool.ParseRetryAfter(apiErr.Response))
|
||||
}
|
||||
|
||||
var _ error = &responseError{}
|
||||
var _ error = &ResponseError{}
|
||||
|
||||
type responseError struct {
|
||||
type ResponseError struct {
|
||||
ErrorObject *shared.ErrorObject `json:"error"`
|
||||
StatusCode int `json:"-"`
|
||||
RetryAfter time.Duration `json:"-"`
|
||||
}
|
||||
|
||||
func newErrorResponse(msg, errType, code string, status int, retryAfter time.Duration) *responseError {
|
||||
return &responseError{
|
||||
func newErrorResponse(msg, errType, code string, status int, retryAfter time.Duration) *ResponseError {
|
||||
return &ResponseError{
|
||||
ErrorObject: &shared.ErrorObject{
|
||||
Code: code,
|
||||
Message: msg,
|
||||
@@ -320,9 +321,19 @@ func newErrorResponse(msg, errType, code string, status int, retryAfter time.Dur
|
||||
}
|
||||
}
|
||||
|
||||
func (a *responseError) Error() string {
|
||||
if a.ErrorObject == nil {
|
||||
func (e *ResponseError) Error() string {
|
||||
if e.ErrorObject == nil {
|
||||
return ""
|
||||
}
|
||||
return a.ErrorObject.Message
|
||||
return e.ErrorObject.Message
|
||||
}
|
||||
|
||||
// ToResponse marshals e into an *http.Response shaped for the
|
||||
// OpenAI API.
|
||||
func (e *ResponseError) ToResponse() *http.Response {
|
||||
body, err := json.Marshal(e)
|
||||
if err != nil {
|
||||
body = []byte(`{"error":{"type":"error","message":"error marshaling upstream error","code":"server_error"}}`)
|
||||
}
|
||||
return utils.NewJSONErrorResponse(e.StatusCode, e.RetryAfter, body)
|
||||
}
|
||||
|
||||
@@ -127,7 +127,7 @@ func TestProcessKeyPoolError(t *testing.T) {
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := processKeyPoolError(tc.err)
|
||||
got := ProcessKeyPoolError(tc.err)
|
||||
if tc.expectedNil {
|
||||
require.Nil(t, got)
|
||||
return
|
||||
@@ -207,7 +207,7 @@ func TestWriteUpstreamError(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
respErr *responseError
|
||||
respErr *ResponseError
|
||||
expectStatus int
|
||||
// Empty string means the header should be absent.
|
||||
expectRetryAfter string
|
||||
|
||||
@@ -225,7 +225,7 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req
|
||||
|
||||
// The failover loop may return a keypool exhaustion
|
||||
// error. Check before the SDK-error path.
|
||||
if keyErr := processKeyPoolError(err); keyErr != nil {
|
||||
if keyErr := ProcessKeyPoolError(err); keyErr != nil {
|
||||
i.writeUpstreamError(w, keyErr)
|
||||
return xerrors.Errorf("key pool exhausted: %w", err)
|
||||
}
|
||||
|
||||
@@ -144,7 +144,7 @@ func (i *StreamingInterception) ProcessRequest(w http.ResponseWriter, r *http.Re
|
||||
var currentKey *keypool.Key
|
||||
if walker != nil {
|
||||
key, err := walker.Next()
|
||||
if respErr := processKeyPoolError(err); respErr != nil {
|
||||
if respErr := ProcessKeyPoolError(err); respErr != nil {
|
||||
// Pool exhausted in this iteration. Relay the
|
||||
// error to the client: as an SSE event if events
|
||||
// have already been sent, or by direct write
|
||||
@@ -474,7 +474,7 @@ func (i *StreamingInterception) newStream(ctx context.Context, svc openai.ChatCo
|
||||
// processing error into a relayable responseError. Returns nil
|
||||
// when the error is unrecoverable, in which case nothing can be
|
||||
// relayed back.
|
||||
func (*StreamingInterception) mapStreamError(ctx context.Context, logger slog.Logger, streamErr, lastErr error) *responseError {
|
||||
func (*StreamingInterception) mapStreamError(ctx context.Context, logger slog.Logger, streamErr, lastErr error) *ResponseError {
|
||||
if streamErr != nil {
|
||||
if eventstream.IsUnrecoverableError(streamErr) {
|
||||
logger.Debug(ctx, "stream terminated", slog.Error(streamErr))
|
||||
|
||||
@@ -433,7 +433,7 @@ func filterBedrockBetaFlags(headers http.Header, model string) {
|
||||
}
|
||||
|
||||
// writeUpstreamError marshals and writes a given error.
|
||||
func (i *interceptionBase) writeUpstreamError(w http.ResponseWriter, antErr *responseError) {
|
||||
func (i *interceptionBase) writeUpstreamError(w http.ResponseWriter, antErr *ResponseError) {
|
||||
if antErr == nil {
|
||||
return
|
||||
}
|
||||
@@ -537,10 +537,10 @@ func (i *interceptionBase) markKeyOnError(ctx context.Context, key *keypool.Key,
|
||||
)
|
||||
}
|
||||
|
||||
// processKeyPoolError translates a keypool exhaustion error
|
||||
// ProcessKeyPoolError translates a keypool exhaustion error
|
||||
// into a developer-facing responseError shaped for the Anthropic
|
||||
// API. Returns nil if err is not an exhaustion error.
|
||||
func processKeyPoolError(err error) *responseError {
|
||||
func ProcessKeyPoolError(err error) *ResponseError {
|
||||
var transient *keypool.TransientKeyPoolError
|
||||
switch {
|
||||
case errors.As(err, &transient):
|
||||
@@ -562,7 +562,7 @@ func processKeyPoolError(err error) *responseError {
|
||||
}
|
||||
}
|
||||
|
||||
func getErrorResponse(err error) *responseError {
|
||||
func getErrorResponse(err error) *ResponseError {
|
||||
var apierr *anthropic.Error
|
||||
if !errors.As(err, &apierr) {
|
||||
return nil
|
||||
@@ -583,17 +583,17 @@ func getErrorResponse(err error) *responseError {
|
||||
return newErrorResponse(msg, errType, apierr.StatusCode, keypool.ParseRetryAfter(apierr.Response))
|
||||
}
|
||||
|
||||
var _ error = &responseError{}
|
||||
var _ error = &ResponseError{}
|
||||
|
||||
type responseError struct {
|
||||
type ResponseError struct {
|
||||
*anthropic.ErrorResponse
|
||||
|
||||
StatusCode int `json:"-"`
|
||||
RetryAfter time.Duration `json:"-"`
|
||||
}
|
||||
|
||||
func newErrorResponse(msg, errType string, status int, retryAfter time.Duration) *responseError {
|
||||
return &responseError{
|
||||
func newErrorResponse(msg, errType string, status int, retryAfter time.Duration) *ResponseError {
|
||||
return &ResponseError{
|
||||
ErrorResponse: &shared.ErrorResponse{
|
||||
Error: shared.ErrorObjectUnion{
|
||||
Message: msg,
|
||||
@@ -606,9 +606,19 @@ func newErrorResponse(msg, errType string, status int, retryAfter time.Duration)
|
||||
}
|
||||
}
|
||||
|
||||
func (a *responseError) Error() string {
|
||||
if a.ErrorResponse == nil {
|
||||
func (e *ResponseError) Error() string {
|
||||
if e.ErrorResponse == nil {
|
||||
return ""
|
||||
}
|
||||
return a.ErrorResponse.Error.Message
|
||||
return e.ErrorResponse.Error.Message
|
||||
}
|
||||
|
||||
// ToResponse marshals e into an *http.Response shaped for the
|
||||
// Anthropic API.
|
||||
func (e *ResponseError) ToResponse() *http.Response {
|
||||
body, err := json.Marshal(e)
|
||||
if err != nil {
|
||||
body = []byte(`{"type":"error","error":{"type":"error","message":"error marshaling upstream error"}}`)
|
||||
}
|
||||
return utils.NewJSONErrorResponse(e.StatusCode, e.RetryAfter, body)
|
||||
}
|
||||
|
||||
@@ -1039,7 +1039,7 @@ func TestProcessKeyPoolError(t *testing.T) {
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := processKeyPoolError(tc.err)
|
||||
got := ProcessKeyPoolError(tc.err)
|
||||
if tc.expectedNil {
|
||||
require.Nil(t, got)
|
||||
return
|
||||
@@ -1119,7 +1119,7 @@ func TestWriteUpstreamError(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
respErr *responseError
|
||||
respErr *ResponseError
|
||||
expectStatus int
|
||||
// Empty string means the header should be absent.
|
||||
expectRetryAfter string
|
||||
|
||||
@@ -114,7 +114,7 @@ func (i *BlockingInterception) ProcessRequest(w http.ResponseWriter, r *http.Req
|
||||
|
||||
// The failover loop may return a keypool exhaustion
|
||||
// error. Check before the SDK-error path.
|
||||
if keyErr := processKeyPoolError(err); keyErr != nil {
|
||||
if keyErr := ProcessKeyPoolError(err); keyErr != nil {
|
||||
i.writeUpstreamError(w, keyErr)
|
||||
return xerrors.Errorf("key pool exhausted: %w", err)
|
||||
}
|
||||
|
||||
@@ -175,7 +175,7 @@ newStream:
|
||||
var currentKey *keypool.Key
|
||||
if walker != nil {
|
||||
key, err := walker.Next()
|
||||
if respErr := processKeyPoolError(err); respErr != nil {
|
||||
if respErr := ProcessKeyPoolError(err); respErr != nil {
|
||||
// Pool exhausted in this iteration. Relay the
|
||||
// error to the client: as an SSE event if events
|
||||
// have already been sent, or by direct write
|
||||
@@ -599,10 +599,10 @@ newStream:
|
||||
}
|
||||
|
||||
// mapStreamError converts a mid-stream upstream error or
|
||||
// processing error into a relayable responseError. Returns nil
|
||||
// processing error into a relayable ResponseError. Returns nil
|
||||
// when the error is unrecoverable, in which case nothing can be
|
||||
// relayed back.
|
||||
func (*StreamingInterception) mapStreamError(ctx context.Context, logger slog.Logger, streamErr, lastErr error) *responseError {
|
||||
func (*StreamingInterception) mapStreamError(ctx context.Context, logger slog.Logger, streamErr, lastErr error) *ResponseError {
|
||||
if streamErr != nil {
|
||||
if eventstream.IsUnrecoverableError(streamErr) {
|
||||
logger.Debug(ctx, "stream terminated", slog.Error(streamErr))
|
||||
|
||||
@@ -35,6 +35,7 @@ import (
|
||||
"github.com/coder/coder/v2/aibridge/mcp"
|
||||
"github.com/coder/coder/v2/aibridge/recorder"
|
||||
"github.com/coder/coder/v2/aibridge/tracing"
|
||||
"github.com/coder/coder/v2/aibridge/utils"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
@@ -142,7 +143,7 @@ func (i *responsesInterceptionBase) validateRequest(ctx context.Context, w http.
|
||||
}
|
||||
|
||||
// writeUpstreamError marshals and writes a given error.
|
||||
func (i *responsesInterceptionBase) writeUpstreamError(w http.ResponseWriter, oaiErr *responseError) {
|
||||
func (i *responsesInterceptionBase) writeUpstreamError(w http.ResponseWriter, oaiErr *ResponseError) {
|
||||
if oaiErr == nil {
|
||||
return
|
||||
}
|
||||
@@ -188,10 +189,10 @@ func (i *responsesInterceptionBase) markKeyOnError(ctx context.Context, key *key
|
||||
)
|
||||
}
|
||||
|
||||
// processKeyPoolError translates a keypool exhaustion error
|
||||
// into a developer-facing responseError shaped for the OpenAI
|
||||
// ProcessKeyPoolError translates a keypool exhaustion error
|
||||
// into a developer-facing ResponseError shaped for the OpenAI
|
||||
// API. Returns nil if err is not an exhaustion error.
|
||||
func processKeyPoolError(err error) *responseError {
|
||||
func ProcessKeyPoolError(err error) *ResponseError {
|
||||
var transient *keypool.TransientKeyPoolError
|
||||
switch {
|
||||
case errors.As(err, &transient):
|
||||
@@ -215,8 +216,8 @@ func processKeyPoolError(err error) *responseError {
|
||||
}
|
||||
}
|
||||
|
||||
func newErrorResponse(msg, errType, code string, status int, retryAfter time.Duration) *responseError {
|
||||
return &responseError{
|
||||
func newErrorResponse(msg, errType, code string, status int, retryAfter time.Duration) *ResponseError {
|
||||
return &ResponseError{
|
||||
ErrorObject: &shared.ErrorObject{
|
||||
Code: code,
|
||||
Message: msg,
|
||||
@@ -227,19 +228,29 @@ func newErrorResponse(msg, errType, code string, status int, retryAfter time.Dur
|
||||
}
|
||||
}
|
||||
|
||||
var _ error = &responseError{}
|
||||
var _ error = &ResponseError{}
|
||||
|
||||
type responseError struct {
|
||||
type ResponseError struct {
|
||||
ErrorObject *shared.ErrorObject `json:"error"`
|
||||
StatusCode int `json:"-"`
|
||||
RetryAfter time.Duration `json:"-"`
|
||||
}
|
||||
|
||||
func (a *responseError) Error() string {
|
||||
if a.ErrorObject == nil {
|
||||
func (e *ResponseError) Error() string {
|
||||
if e.ErrorObject == nil {
|
||||
return ""
|
||||
}
|
||||
return a.ErrorObject.Message
|
||||
return e.ErrorObject.Message
|
||||
}
|
||||
|
||||
// ToResponse marshals e into an *http.Response shaped for the
|
||||
// OpenAI API.
|
||||
func (e *ResponseError) ToResponse() *http.Response {
|
||||
body, err := json.Marshal(e)
|
||||
if err != nil {
|
||||
body = []byte(`{"error":{"type":"error","message":"error marshaling upstream error","code":"server_error"}}`)
|
||||
}
|
||||
return utils.NewJSONErrorResponse(e.StatusCode, e.RetryAfter, body)
|
||||
}
|
||||
|
||||
// sendCustomErr sends custom responses.Error error to the client
|
||||
|
||||
@@ -432,7 +432,7 @@ func TestProcessKeyPoolError(t *testing.T) {
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
got := processKeyPoolError(tc.err)
|
||||
got := ProcessKeyPoolError(tc.err)
|
||||
if tc.expectedNil {
|
||||
require.Nil(t, got)
|
||||
return
|
||||
@@ -512,7 +512,7 @@ func TestWriteUpstreamError(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
respErr *responseError
|
||||
respErr *ResponseError
|
||||
expectStatus int
|
||||
// Empty string means the header should be absent.
|
||||
expectRetryAfter string
|
||||
|
||||
@@ -103,7 +103,7 @@ func (i *BlockingResponsesInterceptor) ProcessRequest(w http.ResponseWriter, r *
|
||||
// The failover loop may return a keypool exhaustion
|
||||
// error. Render it here.
|
||||
if upstreamErr != nil {
|
||||
if keyErr := processKeyPoolError(upstreamErr); keyErr != nil {
|
||||
if keyErr := ProcessKeyPoolError(upstreamErr); keyErr != nil {
|
||||
i.writeUpstreamError(w, keyErr)
|
||||
return xerrors.Errorf("key pool exhausted: %w", upstreamErr)
|
||||
}
|
||||
|
||||
@@ -135,7 +135,7 @@ func (i *StreamingResponsesInterceptor) ProcessRequest(w http.ResponseWriter, r
|
||||
var currentKey *keypool.Key
|
||||
if walker != nil {
|
||||
key, err := walker.Next()
|
||||
if respErr := processKeyPoolError(err); respErr != nil {
|
||||
if respErr := ProcessKeyPoolError(err); respErr != nil {
|
||||
// Pool exhausted: write the error directly. In
|
||||
// agentic mode the inner loop buffers events
|
||||
// instead of streaming them downstream, so the
|
||||
|
||||
@@ -6,8 +6,10 @@ import (
|
||||
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/aibridge/config"
|
||||
"github.com/coder/coder/v2/aibridge/intercept"
|
||||
"github.com/coder/coder/v2/aibridge/keypool"
|
||||
)
|
||||
|
||||
type MockProvider struct {
|
||||
@@ -31,6 +33,10 @@ func (m *MockProvider) InjectAuthHeader(h *http.Header) {
|
||||
m.InjectAuthHeaderFunc(h)
|
||||
}
|
||||
}
|
||||
|
||||
func (*MockProvider) KeyFailoverConfig(_ slog.Logger) keypool.KeyFailoverConfig {
|
||||
return keypool.KeyFailoverConfig{}
|
||||
}
|
||||
func (*MockProvider) CircuitBreakerConfig() *config.CircuitBreaker { return nil }
|
||||
func (*MockProvider) APIDumpDir() string { return "" }
|
||||
func (m *MockProvider) CreateInterceptor(w http.ResponseWriter, r *http.Request, tracer trace.Tracer) (intercept.Interceptor, error) {
|
||||
|
||||
@@ -0,0 +1,120 @@
|
||||
package keypool
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/coder/coder/v2/aibridge/utils"
|
||||
)
|
||||
|
||||
// KeyFailoverConfig is the per-provider configuration consumed by
|
||||
// NewKeyFailoverTransport.
|
||||
type KeyFailoverConfig struct {
|
||||
// Pool is the key pool to walk. Nil disables key failover.
|
||||
Pool *Pool
|
||||
|
||||
// IsBYOK returns true when the request already carries
|
||||
// user-supplied auth. BYOK requests skip key failover.
|
||||
IsBYOK func(*http.Request) bool
|
||||
|
||||
// InjectAuthKey writes the key value into the outbound headers
|
||||
// in the format the provider expects.
|
||||
InjectAuthKey func(*http.Header, string)
|
||||
|
||||
// MarkKey marks the key based on the upstream response.
|
||||
// Returns true when the response is a key-specific error,
|
||||
// causing the walker to advance and retry with the next key.
|
||||
MarkKey func(ctx context.Context, key *Key, resp *http.Response) bool
|
||||
|
||||
// BuildExhaustedResponse returns the response sent to the
|
||||
// client when the walker has no more keys to try.
|
||||
BuildExhaustedResponse func(err error) *http.Response
|
||||
}
|
||||
|
||||
// keyFailoverTransport retries inner across the key pool on
|
||||
// key-specific failures.
|
||||
type keyFailoverTransport struct {
|
||||
inner http.RoundTripper
|
||||
config KeyFailoverConfig
|
||||
}
|
||||
|
||||
// NewKeyFailoverTransport returns an http.RoundTripper backed by
|
||||
// keyFailoverTransport. If config.Pool is nil, inner is returned
|
||||
// unchanged.
|
||||
func NewKeyFailoverTransport(inner http.RoundTripper, config KeyFailoverConfig) http.RoundTripper {
|
||||
if config.Pool == nil {
|
||||
return inner
|
||||
}
|
||||
return &keyFailoverTransport{
|
||||
inner: inner,
|
||||
config: config,
|
||||
}
|
||||
}
|
||||
|
||||
// RoundTrip is invoked by the proxy once per outer client request,
|
||||
// after Rewrite has applied proxy headers.
|
||||
//
|
||||
// For centralized requests it walks the key pool, retrying on
|
||||
// key-specific failures until one key succeeds or the pool is
|
||||
// exhausted. BYOK requests skip the failover loop.
|
||||
func (t *keyFailoverTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
if t.config.IsBYOK(req) {
|
||||
return t.inner.RoundTrip(req)
|
||||
}
|
||||
|
||||
// Buffer once so retries can replay the body.
|
||||
body, err := bufferBody(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Fresh walker per request, independent of other inflight requests.
|
||||
walker := t.config.Pool.Walker()
|
||||
for {
|
||||
key, err := walker.Next()
|
||||
if err != nil {
|
||||
resp := t.config.BuildExhaustedResponse(err)
|
||||
if resp == nil {
|
||||
// Fallback if BuildExhaustedResponse returns nil.
|
||||
body := []byte(fmt.Sprintf(`{"error":"key pool exhausted: %s"}`, err))
|
||||
resp = utils.NewJSONErrorResponse(http.StatusBadGateway, 0, body)
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// Clone per attempt so the original request isn't mutated.
|
||||
outReq := req.Clone(req.Context())
|
||||
if body != nil {
|
||||
outReq.Body = io.NopCloser(bytes.NewReader(body))
|
||||
}
|
||||
t.config.InjectAuthKey(&outReq.Header, key.Value())
|
||||
|
||||
resp, rtErr := t.inner.RoundTrip(outReq)
|
||||
if rtErr != nil {
|
||||
// Transport-level error, not a key issue.
|
||||
return resp, rtErr
|
||||
}
|
||||
// MarkKey returns true on key-specific failures (e.g. 401/403/429).
|
||||
if t.config.MarkKey(req.Context(), key, resp) {
|
||||
// Drain and retry with the next key.
|
||||
_, _ = io.Copy(io.Discard, resp.Body)
|
||||
_ = resp.Body.Close()
|
||||
continue
|
||||
}
|
||||
// Success or non-key error, forward as-is.
|
||||
return resp, nil
|
||||
}
|
||||
}
|
||||
|
||||
// bufferBody reads the request body fully so it can be replayed
|
||||
// across key-failover retries. Returns nil for a nil body.
|
||||
func bufferBody(req *http.Request) ([]byte, error) {
|
||||
if req.Body == nil {
|
||||
return nil, nil
|
||||
}
|
||||
defer req.Body.Close()
|
||||
return io.ReadAll(req.Body)
|
||||
}
|
||||
@@ -0,0 +1,69 @@
|
||||
package keypool_test
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/aibridge/keypool"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
// errFakeRoundTripperCalled is returned by fakeRoundTripper if it
|
||||
// ever gets invoked. The constructor identity tests should never
|
||||
// trigger a RoundTrip call.
|
||||
var errFakeRoundTripperCalled = xerrors.New("fakeRoundTripper should not be invoked")
|
||||
|
||||
// fakeRoundTripper is a no-op http.RoundTripper used to check
|
||||
// constructor identity in tests.
|
||||
type fakeRoundTripper struct{}
|
||||
|
||||
func (*fakeRoundTripper) RoundTrip(*http.Request) (*http.Response, error) {
|
||||
return nil, errFakeRoundTripperCalled
|
||||
}
|
||||
|
||||
func TestNewKeyFailoverTransport(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
pool, err := keypool.New([]string{"k0"}, quartz.NewMock(t))
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
// Constructor input.
|
||||
config keypool.KeyFailoverConfig
|
||||
// Whether the constructor returns inner unchanged.
|
||||
expectSame bool
|
||||
}{
|
||||
{
|
||||
// Pool is nil: failover is disabled, inner is returned unchanged.
|
||||
name: "pool_nil_returns_inner",
|
||||
config: keypool.KeyFailoverConfig{},
|
||||
expectSame: true,
|
||||
},
|
||||
{
|
||||
// Pool is set: inner is wrapped in a key-failover transport.
|
||||
name: "pool_set_returns_wrapper",
|
||||
config: keypool.KeyFailoverConfig{Pool: pool},
|
||||
expectSame: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
inner := &fakeRoundTripper{}
|
||||
got := keypool.NewKeyFailoverTransport(inner, tc.config)
|
||||
|
||||
if tc.expectSame {
|
||||
assert.Same(t, inner, got)
|
||||
} else {
|
||||
assert.NotSame(t, inner, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
+11
-9
@@ -13,6 +13,7 @@ import (
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/aibridge/intercept/apidump"
|
||||
"github.com/coder/coder/v2/aibridge/keypool"
|
||||
"github.com/coder/coder/v2/aibridge/metrics"
|
||||
"github.com/coder/coder/v2/aibridge/provider"
|
||||
"github.com/coder/coder/v2/aibridge/tracing"
|
||||
@@ -41,13 +42,17 @@ func newPassthroughRouter(prov provider.Provider, logger slog.Logger, m *metrics
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
}
|
||||
|
||||
// Build a reverse proxy to the upstream, reused across all requests for this provider.
|
||||
// All request modifications happen in Rewrite.
|
||||
// Build the passthrough proxy, reused across all requests for this provider.
|
||||
// Rewrite sets proxy headers. For centralized requests, KeyFailoverTransport
|
||||
// handles auth and failover. BYOK requests pass through.
|
||||
proxy := &httputil.ReverseProxy{
|
||||
Rewrite: func(pr *httputil.ProxyRequest) {
|
||||
rewritePassthroughRequest(pr, provBaseURL, prov)
|
||||
rewritePassthroughRequest(pr, provBaseURL)
|
||||
},
|
||||
Transport: apidump.NewPassthroughMiddleware(t, prov.APIDumpDir(), prov.Name(), logger, quartz.NewReal()),
|
||||
Transport: keypool.NewKeyFailoverTransport(
|
||||
apidump.NewPassthroughMiddleware(t, prov.APIDumpDir(), prov.Name(), logger, quartz.NewReal()),
|
||||
prov.KeyFailoverConfig(logger),
|
||||
),
|
||||
ErrorHandler: func(rw http.ResponseWriter, req *http.Request, e error) {
|
||||
logger.Warn(req.Context(), "reverse proxy error", slog.Error(e), slog.F("path", req.URL.Path))
|
||||
http.Error(rw, "upstream proxy error", http.StatusBadGateway)
|
||||
@@ -67,8 +72,8 @@ func newPassthroughRouter(prov provider.Provider, logger slog.Logger, m *metrics
|
||||
}
|
||||
|
||||
// rewritePassthroughRequest configures the outbound request for the upstream and
|
||||
// applies proxy headers and provider auth.
|
||||
func rewritePassthroughRequest(pr *httputil.ProxyRequest, provBaseURL *url.URL, prov provider.Provider) {
|
||||
// applies proxy headers.
|
||||
func rewritePassthroughRequest(pr *httputil.ProxyRequest, provBaseURL *url.URL) {
|
||||
pr.SetURL(provBaseURL)
|
||||
|
||||
// Rewrite sets "X-Forwarded-For" to just last hop (clients IP address).
|
||||
@@ -87,9 +92,6 @@ func rewritePassthroughRequest(pr *httputil.ProxyRequest, provBaseURL *url.URL,
|
||||
if _, ok := pr.Out.Header["User-Agent"]; !ok {
|
||||
pr.Out.Header.Set("User-Agent", "aibridge") // TODO: use build tag.
|
||||
}
|
||||
|
||||
// Inject provider auth.
|
||||
prov.InjectAuthHeader(&pr.Out.Header)
|
||||
}
|
||||
|
||||
// newInvalidBaseURLHandler returns a handler that always returns 502
|
||||
|
||||
+290
-20
@@ -2,20 +2,27 @@ package aibridge //nolint:testpackage // tests unexported newPassthroughRouter
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"io"
|
||||
"maps"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.opentelemetry.io/otel"
|
||||
|
||||
"cdr.dev/slog/v3/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/aibridge/config"
|
||||
"github.com/coder/coder/v2/aibridge/internal/testutil"
|
||||
"github.com/coder/coder/v2/aibridge/keypool"
|
||||
"github.com/coder/coder/v2/aibridge/provider"
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
var testTracer = otel.Tracer("bridge_test")
|
||||
@@ -171,25 +178,6 @@ func TestRewritePassthroughRequest(t *testing.T) {
|
||||
"User-Agent": {"custom-agent/1.0"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "injects_auth_header",
|
||||
reqPath: "http://client-host/chat",
|
||||
reqRemoteAddr: "1.1.1.1:1111",
|
||||
provider: &testutil.MockProvider{
|
||||
URL: "https://upstream-host/base",
|
||||
InjectAuthHeaderFunc: func(h *http.Header) {
|
||||
h.Set("Authorization", "Bearer test-token")
|
||||
},
|
||||
},
|
||||
expectURL: "https://upstream-host/base/chat",
|
||||
expectHeaders: http.Header{
|
||||
"X-Forwarded-Host": {"client-host"},
|
||||
"X-Forwarded-Proto": {"http"},
|
||||
"X-Forwarded-For": {"1.1.1.1"},
|
||||
"User-Agent": {"aibridge"},
|
||||
"Authorization": {"Bearer test-token"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "appends_remote_addr_to_existing_forwarded_for_chain",
|
||||
reqPath: "http://client-host/chat",
|
||||
@@ -260,7 +248,7 @@ func TestRewritePassthroughRequest(t *testing.T) {
|
||||
Out: r.Clone(r.Context()),
|
||||
}
|
||||
|
||||
rewritePassthroughRequest(pr, provBaseURL, tc.provider)
|
||||
rewritePassthroughRequest(pr, provBaseURL)
|
||||
|
||||
assert.Equal(t, tc.expectURL, pr.Out.URL.String())
|
||||
assert.Equal(t, "", pr.Out.Host)
|
||||
@@ -301,3 +289,285 @@ func TestPassthroughRouterReusesProxyInstance(t *testing.T) {
|
||||
|
||||
assert.EqualValues(t, 1, newConnections.Load())
|
||||
}
|
||||
|
||||
// TestPassthrough_KeyFailover exercises the KeyFailoverTransport
|
||||
// end-to-end through the passthrough proxy, parameterised over
|
||||
// providers (anthropic, openai). Each scenario asserts the upstream
|
||||
// request count, the response status and Retry-After, and the final
|
||||
// pool state.
|
||||
func TestPassthrough_KeyFailover(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
type upstreamResponse struct {
|
||||
statusCode int
|
||||
body string
|
||||
headers map[string]string
|
||||
}
|
||||
|
||||
const (
|
||||
rateLimitBody = `{"error":"rate"}`
|
||||
authErrorBody = `{"error":"unauthorized"}`
|
||||
serverErrorBody = `{"error":"server"}`
|
||||
successBody = `{"data":[]}`
|
||||
)
|
||||
|
||||
// providers parameterises the table over the two providers
|
||||
// that support key failover. Each entry encapsulates the
|
||||
// provider-specific bits the test needs: how the mock upstream
|
||||
// extracts the key from the request, how a BYOK request sets
|
||||
// it, and how the provider is constructed for a given pool.
|
||||
providers := []struct {
|
||||
name string
|
||||
extractKey func(*http.Request) string
|
||||
setBYOK func(*http.Request, string)
|
||||
newProvider func(baseURL string, pool *keypool.Pool) provider.Provider
|
||||
}{
|
||||
{
|
||||
name: "anthropic",
|
||||
extractKey: func(r *http.Request) string {
|
||||
return r.Header.Get("X-Api-Key")
|
||||
},
|
||||
setBYOK: func(r *http.Request, key string) {
|
||||
r.Header.Set("X-Api-Key", key)
|
||||
},
|
||||
newProvider: func(baseURL string, pool *keypool.Pool) provider.Provider {
|
||||
return provider.NewAnthropic(config.Anthropic{
|
||||
BaseURL: baseURL,
|
||||
KeyPool: pool,
|
||||
}, nil)
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "openai",
|
||||
extractKey: func(r *http.Request) string {
|
||||
return strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ")
|
||||
},
|
||||
setBYOK: func(r *http.Request, key string) {
|
||||
r.Header.Set("Authorization", "Bearer "+key)
|
||||
},
|
||||
newProvider: func(baseURL string, pool *keypool.Pool) provider.Provider {
|
||||
cfg := config.OpenAI{BaseURL: baseURL}
|
||||
if pool != nil {
|
||||
cfg.KeyPool = pool
|
||||
}
|
||||
return provider.NewOpenAI(cfg)
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
// Centralized pool keys. Empty when byokKey is set.
|
||||
keys []string
|
||||
// BYOK key. Empty when keys is set.
|
||||
byokKey string
|
||||
// Scripted upstream responses keyed by API key value.
|
||||
responses map[string]upstreamResponse
|
||||
expectedRequestCount int32
|
||||
expectedStatusCode int
|
||||
expectedRetryAfter string
|
||||
// Expected key states after the request, by index in keys.
|
||||
expectedKeyStates []keypool.KeyState
|
||||
}{
|
||||
{
|
||||
// Given: 1 valid key returning 200.
|
||||
// Then: 1 request, 200 response, key remains valid.
|
||||
name: "single_valid_key",
|
||||
keys: []string{"k0"},
|
||||
responses: map[string]upstreamResponse{
|
||||
"k0": {statusCode: http.StatusOK, body: successBody},
|
||||
},
|
||||
expectedRequestCount: 1,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
expectedKeyStates: []keypool.KeyState{keypool.KeyStateValid},
|
||||
},
|
||||
{
|
||||
// Given: 2 keys; key-0 returns 429, key-1 returns 200.
|
||||
// Then: 2 requests, 200 response, key-0 temporary, key-1 valid.
|
||||
name: "failover_after_429",
|
||||
keys: []string{"k0", "k1"},
|
||||
responses: map[string]upstreamResponse{
|
||||
"k0": {
|
||||
statusCode: http.StatusTooManyRequests,
|
||||
headers: map[string]string{"Retry-After": "5"},
|
||||
body: rateLimitBody,
|
||||
},
|
||||
"k1": {statusCode: http.StatusOK, body: successBody},
|
||||
},
|
||||
expectedRequestCount: 2,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
expectedKeyStates: []keypool.KeyState{
|
||||
keypool.KeyStateTemporary,
|
||||
keypool.KeyStateValid,
|
||||
},
|
||||
},
|
||||
{
|
||||
// Given: 2 keys; key-0 returns 401, key-1 returns 200.
|
||||
// Then: 2 requests, 200 response, key-0 permanent, key-1 valid.
|
||||
name: "failover_after_401",
|
||||
keys: []string{"k0", "k1"},
|
||||
responses: map[string]upstreamResponse{
|
||||
"k0": {statusCode: http.StatusUnauthorized, body: authErrorBody},
|
||||
"k1": {statusCode: http.StatusOK, body: successBody},
|
||||
},
|
||||
expectedRequestCount: 2,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
expectedKeyStates: []keypool.KeyState{
|
||||
keypool.KeyStatePermanent,
|
||||
keypool.KeyStateValid,
|
||||
},
|
||||
},
|
||||
{
|
||||
// Given: 2 keys; key-0 returns 403, key-1 returns 200.
|
||||
// Then: 2 requests, 200 response, key-0 permanent, key-1 valid.
|
||||
name: "failover_after_403",
|
||||
keys: []string{"k0", "k1"},
|
||||
responses: map[string]upstreamResponse{
|
||||
"k0": {statusCode: http.StatusForbidden, body: authErrorBody},
|
||||
"k1": {statusCode: http.StatusOK, body: successBody},
|
||||
},
|
||||
expectedRequestCount: 2,
|
||||
expectedStatusCode: http.StatusOK,
|
||||
expectedKeyStates: []keypool.KeyState{
|
||||
keypool.KeyStatePermanent,
|
||||
keypool.KeyStateValid,
|
||||
},
|
||||
},
|
||||
{
|
||||
// Given: 3 keys; all return 429 with cooldowns 5s, 3s, 10s.
|
||||
// Then: 3 requests, 429 response with smallest Retry-After,
|
||||
// all keys temporary.
|
||||
name: "all_keys_rate_limited",
|
||||
keys: []string{"k0", "k1", "k2"},
|
||||
responses: map[string]upstreamResponse{
|
||||
"k0": {
|
||||
statusCode: http.StatusTooManyRequests,
|
||||
headers: map[string]string{"Retry-After": "5"},
|
||||
body: rateLimitBody,
|
||||
},
|
||||
"k1": {
|
||||
statusCode: http.StatusTooManyRequests,
|
||||
headers: map[string]string{"Retry-After": "3"},
|
||||
body: rateLimitBody,
|
||||
},
|
||||
"k2": {
|
||||
statusCode: http.StatusTooManyRequests,
|
||||
headers: map[string]string{"Retry-After": "10"},
|
||||
body: rateLimitBody,
|
||||
},
|
||||
},
|
||||
expectedRequestCount: 3,
|
||||
expectedStatusCode: http.StatusTooManyRequests,
|
||||
expectedRetryAfter: "3",
|
||||
expectedKeyStates: []keypool.KeyState{
|
||||
keypool.KeyStateTemporary,
|
||||
keypool.KeyStateTemporary,
|
||||
keypool.KeyStateTemporary,
|
||||
},
|
||||
},
|
||||
{
|
||||
// Given: 2 keys; both return 401.
|
||||
// Then: 2 requests, 502 response, both keys permanent.
|
||||
name: "all_keys_unauthorized",
|
||||
keys: []string{"k0", "k1"},
|
||||
responses: map[string]upstreamResponse{
|
||||
"k0": {statusCode: http.StatusUnauthorized, body: authErrorBody},
|
||||
"k1": {statusCode: http.StatusUnauthorized, body: authErrorBody},
|
||||
},
|
||||
expectedRequestCount: 2,
|
||||
expectedStatusCode: http.StatusBadGateway,
|
||||
expectedKeyStates: []keypool.KeyState{
|
||||
keypool.KeyStatePermanent,
|
||||
keypool.KeyStatePermanent,
|
||||
},
|
||||
},
|
||||
{
|
||||
// Given: 2 keys; key-0 returns 500.
|
||||
// Then: 1 request, 500 response, both keys remain valid.
|
||||
name: "server_error_no_failover",
|
||||
keys: []string{"k0", "k1"},
|
||||
responses: map[string]upstreamResponse{
|
||||
"k0": {statusCode: http.StatusInternalServerError, body: serverErrorBody},
|
||||
},
|
||||
expectedRequestCount: 1,
|
||||
expectedStatusCode: http.StatusInternalServerError,
|
||||
expectedKeyStates: []keypool.KeyState{
|
||||
keypool.KeyStateValid,
|
||||
keypool.KeyStateValid,
|
||||
},
|
||||
},
|
||||
{
|
||||
// Given: BYOK with a single user-supplied key returning 429.
|
||||
// Then: 1 request, 429 forwarded as-is, no failover.
|
||||
name: "byok_no_failover",
|
||||
byokKey: "user-byok",
|
||||
responses: map[string]upstreamResponse{
|
||||
"user-byok": {
|
||||
statusCode: http.StatusTooManyRequests,
|
||||
headers: map[string]string{"Retry-After": "5"},
|
||||
body: rateLimitBody,
|
||||
},
|
||||
},
|
||||
expectedRequestCount: 1,
|
||||
expectedStatusCode: http.StatusTooManyRequests,
|
||||
expectedRetryAfter: "5",
|
||||
},
|
||||
}
|
||||
|
||||
for _, prov := range providers {
|
||||
for _, tc := range tests {
|
||||
t.Run(prov.name+"/"+tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Mock upstream: counts requests and returns
|
||||
// scripted responses keyed by API key. An unmapped
|
||||
// key falls through to 500 so misconfigured cases
|
||||
// surface via the status assertion.
|
||||
var requestCount atomic.Int32
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
requestCount.Add(1)
|
||||
_, _ = io.Copy(io.Discard, r.Body)
|
||||
resp, ok := tc.responses[prov.extractKey(r)]
|
||||
if !ok {
|
||||
resp = upstreamResponse{statusCode: http.StatusInternalServerError}
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
for hk, hv := range resp.headers {
|
||||
w.Header().Set(hk, hv)
|
||||
}
|
||||
w.WriteHeader(resp.statusCode)
|
||||
_, _ = w.Write([]byte(resp.body))
|
||||
}))
|
||||
t.Cleanup(upstream.Close)
|
||||
|
||||
var pool *keypool.Pool
|
||||
if len(tc.keys) > 0 {
|
||||
var err error
|
||||
pool, err = keypool.New(tc.keys, quartz.NewMock(t))
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
p := prov.newProvider(upstream.URL, pool)
|
||||
// IgnoreErrors: MarkKey logs at ERROR level when a
|
||||
// key is marked permanent (401/403); slogtest would
|
||||
// otherwise fail those scenarios.
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
handler := newPassthroughRouter(p, logger, nil, testTracer)
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/v1/models", nil)
|
||||
if tc.byokKey != "" {
|
||||
prov.setBYOK(req, tc.byokKey)
|
||||
}
|
||||
w := httptest.NewRecorder()
|
||||
handler.ServeHTTP(w, req)
|
||||
|
||||
assert.Equal(t, tc.expectedRequestCount, requestCount.Load(), "upstream request count")
|
||||
assert.Equal(t, tc.expectedStatusCode, w.Code, "response status code")
|
||||
assert.Equal(t, tc.expectedRetryAfter, w.Header().Get("Retry-After"), "Retry-After header")
|
||||
if pool != nil {
|
||||
assert.Equal(t, tc.expectedKeyStates, pool.PoolState(), "key states")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -11,6 +12,7 @@ import (
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/aibridge/circuitbreaker"
|
||||
"github.com/coder/coder/v2/aibridge/config"
|
||||
"github.com/coder/coder/v2/aibridge/intercept"
|
||||
@@ -219,6 +221,25 @@ func (p *Anthropic) InjectAuthHeader(headers *http.Header) {
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Anthropic) KeyFailoverConfig(logger slog.Logger) keypool.KeyFailoverConfig {
|
||||
name := p.Name()
|
||||
return keypool.KeyFailoverConfig{
|
||||
Pool: p.cfg.KeyPool,
|
||||
IsBYOK: func(r *http.Request) bool {
|
||||
return r.Header.Get("X-Api-Key") != "" || r.Header.Get("Authorization") != ""
|
||||
},
|
||||
InjectAuthKey: func(h *http.Header, key string) {
|
||||
h.Set("X-Api-Key", key)
|
||||
},
|
||||
MarkKey: func(ctx context.Context, key *keypool.Key, resp *http.Response) bool {
|
||||
return keypool.MarkKeyOnStatus(ctx, key, resp, logger, name)
|
||||
},
|
||||
BuildExhaustedResponse: func(err error) *http.Response {
|
||||
return messages.ProcessKeyPoolError(err).ToResponse()
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Anthropic) CircuitBreakerConfig() *config.CircuitBreaker {
|
||||
return p.cfg.CircuitBreaker
|
||||
}
|
||||
|
||||
@@ -12,10 +12,12 @@ import (
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/aibridge/config"
|
||||
"github.com/coder/coder/v2/aibridge/intercept"
|
||||
"github.com/coder/coder/v2/aibridge/intercept/chatcompletions"
|
||||
"github.com/coder/coder/v2/aibridge/intercept/responses"
|
||||
"github.com/coder/coder/v2/aibridge/keypool"
|
||||
"github.com/coder/coder/v2/aibridge/tracing"
|
||||
"github.com/coder/coder/v2/aibridge/utils"
|
||||
)
|
||||
@@ -111,6 +113,12 @@ func (*Copilot) AuthHeader() string {
|
||||
// The original Authorization header flows through untouched from the client.
|
||||
func (*Copilot) InjectAuthHeader(_ *http.Header) {}
|
||||
|
||||
// KeyFailoverConfig returns a config with a nil Pool, which makes
|
||||
// the KeyFailoverTransport short-circuit. Copilot is always BYOK.
|
||||
func (*Copilot) KeyFailoverConfig(_ slog.Logger) keypool.KeyFailoverConfig {
|
||||
return keypool.KeyFailoverConfig{}
|
||||
}
|
||||
|
||||
func (p *Copilot) CircuitBreakerConfig() *config.CircuitBreaker {
|
||||
return p.circuitBreaker
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -12,6 +13,7 @@ import (
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/aibridge/config"
|
||||
"github.com/coder/coder/v2/aibridge/intercept"
|
||||
"github.com/coder/coder/v2/aibridge/intercept/chatcompletions"
|
||||
@@ -218,6 +220,25 @@ func (p *OpenAI) InjectAuthHeader(headers *http.Header) {
|
||||
}
|
||||
}
|
||||
|
||||
func (p *OpenAI) KeyFailoverConfig(logger slog.Logger) keypool.KeyFailoverConfig {
|
||||
name := p.Name()
|
||||
return keypool.KeyFailoverConfig{
|
||||
Pool: p.cfg.KeyPool,
|
||||
IsBYOK: func(r *http.Request) bool {
|
||||
return r.Header.Get("Authorization") != ""
|
||||
},
|
||||
InjectAuthKey: func(h *http.Header, key string) {
|
||||
h.Set("Authorization", "Bearer "+key)
|
||||
},
|
||||
MarkKey: func(ctx context.Context, key *keypool.Key, resp *http.Response) bool {
|
||||
return keypool.MarkKeyOnStatus(ctx, key, resp, logger, name)
|
||||
},
|
||||
BuildExhaustedResponse: func(err error) *http.Response {
|
||||
return chatcompletions.ProcessKeyPoolError(err).ToResponse()
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (p *OpenAI) CircuitBreakerConfig() *config.CircuitBreaker {
|
||||
return p.circuitBreaker
|
||||
}
|
||||
|
||||
@@ -6,8 +6,10 @@ import (
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/aibridge/config"
|
||||
"github.com/coder/coder/v2/aibridge/intercept"
|
||||
"github.com/coder/coder/v2/aibridge/keypool"
|
||||
)
|
||||
|
||||
var ErrUnknownRoute = xerrors.New("unknown route")
|
||||
@@ -76,7 +78,12 @@ type Provider interface {
|
||||
// token in.
|
||||
AuthHeader() string
|
||||
// InjectAuthHeader allows [Provider]s to set its authentication header.
|
||||
// TODO(ssncferreira): remove. Auth is now applied per-attempt by
|
||||
// KeyFailoverTransport (see [Provider.KeyFailoverConfig]).
|
||||
InjectAuthHeader(*http.Header)
|
||||
// KeyFailoverConfig returns the per-provider configuration for
|
||||
// automatic key failover on passthrough routes.
|
||||
KeyFailoverConfig(logger slog.Logger) keypool.KeyFailoverConfig
|
||||
|
||||
// CircuitBreakerConfig returns the circuit breaker configuration for the provider.
|
||||
CircuitBreakerConfig() *config.CircuitBreaker
|
||||
|
||||
@@ -0,0 +1,34 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
// NewJSONErrorResponse builds an *http.Response with a JSON body
|
||||
// and optional Retry-After header. Used to synthesize bridge-side
|
||||
// error responses (e.g. key-pool exhaustion, marshaling
|
||||
// fallbacks). Retry-After is set to whole seconds (rounded up)
|
||||
// when retryAfter is positive, and omitted otherwise.
|
||||
func NewJSONErrorResponse(status int, retryAfter time.Duration, body []byte) *http.Response {
|
||||
h := http.Header{}
|
||||
h.Set("Content-Type", "application/json")
|
||||
if retryAfter > 0 {
|
||||
h.Set("Retry-After", strconv.Itoa(int(math.Ceil(retryAfter.Seconds()))))
|
||||
}
|
||||
return &http.Response{
|
||||
Status: fmt.Sprintf("%d %s", status, http.StatusText(status)),
|
||||
StatusCode: status,
|
||||
Proto: "HTTP/1.1",
|
||||
ProtoMajor: 1,
|
||||
ProtoMinor: 1,
|
||||
Header: h,
|
||||
Body: io.NopCloser(bytes.NewReader(body)),
|
||||
ContentLength: int64(len(body)),
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,91 @@
|
||||
package utils_test
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/aibridge/utils"
|
||||
)
|
||||
|
||||
func TestNewJSONErrorResponse(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
status int
|
||||
retryAfter time.Duration
|
||||
body []byte
|
||||
// Empty string means the header should be absent.
|
||||
expectRetryAfter string
|
||||
}{
|
||||
{
|
||||
// Permanent exhaustion: 502 with no Retry-After.
|
||||
name: "permanent_no_retry_after",
|
||||
status: http.StatusBadGateway,
|
||||
retryAfter: 0,
|
||||
body: []byte(`{"error":"permanent"}`),
|
||||
expectRetryAfter: "",
|
||||
},
|
||||
{
|
||||
// Transient exhaustion with zero retryAfter: no Retry-After.
|
||||
name: "transient_no_retry_after",
|
||||
status: http.StatusTooManyRequests,
|
||||
retryAfter: 0,
|
||||
body: []byte(`{"error":"rate"}`),
|
||||
expectRetryAfter: "",
|
||||
},
|
||||
{
|
||||
// Transient exhaustion: 429 with Retry-After in seconds.
|
||||
name: "transient_with_retry_after",
|
||||
status: http.StatusTooManyRequests,
|
||||
retryAfter: 60 * time.Second,
|
||||
body: []byte(`{"error":"rate"}`),
|
||||
expectRetryAfter: "60",
|
||||
},
|
||||
{
|
||||
// Transient exhaustion with negative retryAfter: Retry-After header omitted.
|
||||
name: "transient_negative_retry_after",
|
||||
status: http.StatusTooManyRequests,
|
||||
retryAfter: -1 * time.Second,
|
||||
body: []byte(`{"error":"rate"}`),
|
||||
expectRetryAfter: "",
|
||||
},
|
||||
{
|
||||
// Transient exhaustion with 500ms retryAfter rounds up to Retry-After: 1.
|
||||
name: "transient_under_one_second_rounds_up",
|
||||
status: http.StatusTooManyRequests,
|
||||
retryAfter: 500 * time.Millisecond,
|
||||
body: []byte(`{"error":"rate"}`),
|
||||
expectRetryAfter: "1",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
resp := utils.NewJSONErrorResponse(tc.status, tc.retryAfter, tc.body)
|
||||
require.NotNil(t, resp)
|
||||
|
||||
assert.Equal(t, tc.status, resp.StatusCode)
|
||||
assert.Equal(t, "application/json", resp.Header.Get("Content-Type"))
|
||||
assert.Equal(t, int64(len(tc.body)), resp.ContentLength)
|
||||
|
||||
if tc.expectRetryAfter == "" {
|
||||
assert.Empty(t, resp.Header.Get("Retry-After"))
|
||||
} else {
|
||||
assert.Equal(t, tc.expectRetryAfter, resp.Header.Get("Retry-After"))
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, resp.Body.Close())
|
||||
assert.Equal(t, tc.body, body)
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user