mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
feat: serve 503 sentinel for disabled providers (#25794)
_Disclosure: created with Coder Agents._ When providers are disabled, we should serve a sentinel error so the requesting client (Claude Code, Coder Agents, etc) is informed. Coder Agents can also conditionalize its display to show a helpful error message. --------- Signed-off-by: Danny Kopping <danny@coder.com> Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -57,6 +57,14 @@ func NewCopilotProvider(cfg config.Copilot) provider.Provider {
|
||||
return provider.NewCopilot(cfg)
|
||||
}
|
||||
|
||||
// NewDisabledProviderStub returns a Provider that reports Enabled() ==
|
||||
// false and has no-op implementations for all other methods. Use this
|
||||
// instead of constructing a concrete provider for disabled rows so that
|
||||
// adding a new provider type does not require updating a switch here.
|
||||
func NewDisabledProviderStub(name, providerType string) provider.Provider {
|
||||
return provider.NewDisabledStub(name, providerType)
|
||||
}
|
||||
|
||||
func NewMetrics(reg prometheus.Registerer) *metrics.Metrics {
|
||||
return metrics.NewMetrics(reg)
|
||||
}
|
||||
|
||||
@@ -30,6 +30,11 @@ import (
|
||||
const (
	// recordingTimeout is the duration after which an async recording
	// will be aborted.
	recordingTimeout = 5 * time.Second

	// ErrorCodeProviderDisabled is the sentinel code written in the
	// response body when a request targets a provider that is
	// configured but disabled. It is served together with HTTP 503.
	ErrorCodeProviderDisabled = "provider_disabled"
)
|
||||
|
||||
// RequestBridge is an [http.Handler] which is capable of masquerading as AI providers' APIs;
|
||||
@@ -96,6 +101,14 @@ func NewRequestBridge(ctx context.Context, providers []provider.Provider, rec re
|
||||
mux := http.NewServeMux()
|
||||
|
||||
for _, prov := range providers {
|
||||
// Disabled providers serve a 503 sentinel on every path under
|
||||
// "/<name>/". Bound to the bare name (not RoutePrefix) so paths
|
||||
// outside the provider's normal "/v1" subtree are also caught.
|
||||
if !prov.Enabled() {
|
||||
prefix := fmt.Sprintf("/%s/", prov.Name())
|
||||
mux.HandleFunc(prefix, disabledProviderHandler(prov.Name(), logger))
|
||||
continue
|
||||
}
|
||||
// Create per-provider circuit breaker if configured
|
||||
cfg := prov.CircuitBreakerConfig()
|
||||
providerName := prov.Name()
|
||||
@@ -170,6 +183,20 @@ func NewRequestBridge(ctx context.Context, providers []provider.Provider, rec re
|
||||
}, nil
|
||||
}
|
||||
|
||||
// disabledProviderHandler returns 503 with a body containing
|
||||
// [ErrorCodeProviderDisabled] and the provider name for every request
|
||||
// targeting name.
|
||||
func disabledProviderHandler(name string, logger slog.Logger) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
logger.Debug(r.Context(), "refusing request for disabled ai provider",
|
||||
slog.F("provider", name),
|
||||
slog.F("path", r.URL.Path),
|
||||
slog.F("method", r.Method),
|
||||
)
|
||||
http.Error(w, fmt.Sprintf("%s: AI provider %q is disabled", ErrorCodeProviderDisabled, name), http.StatusServiceUnavailable)
|
||||
}
|
||||
}
|
||||
|
||||
// newInterceptionProcessor returns an [http.HandlerFunc] which is capable of creating a new interceptor and processing a given request
|
||||
// using [Provider] p, recording all usage events using [Recorder] rec.
|
||||
// If cbs is non-nil, circuit breaker protection is applied per endpoint/model tuple.
|
||||
|
||||
@@ -205,3 +205,58 @@ func TestPassthroughRoutesForProviders(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestDisabledProviderHandler asserts that requests to a disabled
|
||||
// provider return a 503 with an ErrorCodeProviderDisabled body and
|
||||
// that a sibling enabled provider keeps routing normally.
|
||||
func TestDisabledProviderHandler(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
_, _ = w.Write([]byte("upstream-reached"))
|
||||
}))
|
||||
t.Cleanup(upstream.Close)
|
||||
|
||||
enabled := aibridge.NewOpenAIProvider(config.OpenAI{Name: "enabled-openai", BaseURL: upstream.URL})
|
||||
disabled := aibridge.NewDisabledProviderStub("disabled-openai", "openai")
|
||||
bridge, err := aibridge.NewRequestBridge(
|
||||
t.Context(),
|
||||
[]provider.Provider{enabled, disabled},
|
||||
nil, nil, logger, nil, bridgeTestTracer,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
for _, tc := range []struct {
|
||||
name string
|
||||
path string
|
||||
}{
|
||||
{name: "Bridged", path: "/disabled-openai/v1/chat/completions"},
|
||||
{name: "Passthrough", path: "/disabled-openai/v1/models"},
|
||||
{name: "Unknown", path: "/disabled-openai/anything/else"},
|
||||
} {
|
||||
t.Run("DisabledProviderReturnsSentinel/"+tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
req := httptest.NewRequest(http.MethodPost, tc.path, nil)
|
||||
resp := httptest.NewRecorder()
|
||||
bridge.ServeHTTP(resp, req)
|
||||
|
||||
assert.Equal(t, http.StatusServiceUnavailable, resp.Code)
|
||||
assert.Contains(t, resp.Body.String(), aibridge.ErrorCodeProviderDisabled)
|
||||
assert.Contains(t, resp.Body.String(), "disabled-openai")
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("EnabledProviderUnaffected", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/enabled-openai/v1/models", nil)
|
||||
resp := httptest.NewRecorder()
|
||||
bridge.ServeHTTP(resp, req)
|
||||
|
||||
assert.Equal(t, http.StatusOK, resp.Code)
|
||||
assert.Equal(t, "upstream-reached", resp.Body.String())
|
||||
})
|
||||
}
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
type MockProvider struct {
|
||||
NameStr string
|
||||
URL string
|
||||
Disabled bool
|
||||
Bridged []string
|
||||
Passthrough []string
|
||||
InterceptorFunc func(w http.ResponseWriter, r *http.Request, tracer trace.Tracer) (intercept.Interceptor, error)
|
||||
@@ -22,6 +23,7 @@ type MockProvider struct {
|
||||
|
||||
func (m *MockProvider) Type() string { return m.NameStr }
|
||||
func (m *MockProvider) Name() string { return m.NameStr }
|
||||
func (m *MockProvider) Enabled() bool { return !m.Disabled }
|
||||
func (m *MockProvider) BaseURL() string { return m.URL }
|
||||
func (m *MockProvider) RoutePrefix() string { return fmt.Sprintf("/%s", m.NameStr) }
|
||||
func (m *MockProvider) BridgedRoutes() []string { return m.Bridged }
|
||||
|
||||
@@ -95,6 +95,8 @@ func (p *Anthropic) Name() string {
|
||||
return p.cfg.Name
|
||||
}
|
||||
|
||||
// Enabled reports whether the provider should serve requests; an
// Anthropic provider always reports true.
func (*Anthropic) Enabled() bool { return true }
|
||||
|
||||
func (p *Anthropic) RoutePrefix() string {
|
||||
return fmt.Sprintf("/%s", p.Name())
|
||||
}
|
||||
|
||||
@@ -78,6 +78,8 @@ func (p *Copilot) Name() string {
|
||||
return p.cfg.Name
|
||||
}
|
||||
|
||||
// Enabled reports whether the provider should serve requests; a Copilot
// provider always reports true.
func (*Copilot) Enabled() bool { return true }
|
||||
|
||||
func (p *Copilot) BaseURL() string {
|
||||
return p.cfg.BaseURL
|
||||
}
|
||||
|
||||
@@ -0,0 +1,47 @@
|
||||
package provider
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/aibridge/config"
|
||||
"github.com/coder/coder/v2/aibridge/intercept"
|
||||
"github.com/coder/coder/v2/aibridge/keypool"
|
||||
)
|
||||
|
||||
// DisabledStub is the Provider placeholder for a provider that is
// configured but disabled. It stores only the instance name and the
// provider type: Name, Type, and RoutePrefix are derived from those
// values, Enabled always reports false, and every other method returns
// an empty or nil value.
type DisabledStub struct {
	name, providerType string
}
|
||||
|
||||
// NewDisabledStub returns a Provider stub that reports Enabled() == false.
|
||||
// The type string is preserved so callers can distinguish provider families.
|
||||
func NewDisabledStub(name, providerType string) *DisabledStub {
|
||||
return &DisabledStub{name: name, providerType: providerType}
|
||||
}
|
||||
|
||||
func (d *DisabledStub) Type() string { return d.providerType }
|
||||
func (d *DisabledStub) Name() string { return d.name }
|
||||
func (*DisabledStub) Enabled() bool { return false }
|
||||
func (*DisabledStub) BaseURL() string { return "" }
|
||||
func (d *DisabledStub) RoutePrefix() string {
|
||||
return fmt.Sprintf("/%s", d.name)
|
||||
}
|
||||
func (*DisabledStub) BridgedRoutes() []string { return nil }
|
||||
func (*DisabledStub) PassthroughRoutes() []string { return nil }
|
||||
func (*DisabledStub) AuthHeader() string { return "" }
|
||||
func (*DisabledStub) KeyFailoverConfig(_ slog.Logger) keypool.KeyFailoverConfig {
|
||||
return keypool.KeyFailoverConfig{}
|
||||
}
|
||||
func (*DisabledStub) CircuitBreakerConfig() *config.CircuitBreaker { return nil }
|
||||
func (*DisabledStub) APIDumpDir() string { return "" }
|
||||
func (*DisabledStub) CreateInterceptor(_ http.ResponseWriter, _ *http.Request, _ trace.Tracer) (intercept.Interceptor, error) {
|
||||
//nolint:nilnil // disabled providers never reach the interceptor.
|
||||
return nil, nil
|
||||
}
|
||||
@@ -84,6 +84,8 @@ func (p *OpenAI) Name() string {
|
||||
return p.cfg.Name
|
||||
}
|
||||
|
||||
// Enabled reports whether the provider should serve requests; an OpenAI
// provider always reports true.
func (*OpenAI) Enabled() bool { return true }
|
||||
|
||||
func (p *OpenAI) RoutePrefix() string {
|
||||
// Route prefix includes version to match default OpenAI base URL.
|
||||
// More detailed explanation: https://github.com/coder/aibridge/pull/174#discussion_r2782320152
|
||||
|
||||
@@ -53,6 +53,8 @@ type Provider interface {
|
||||
// Name returns the provider instance name.
|
||||
// Defaults to Type() when not explicitly configured.
|
||||
Name() string
|
||||
// Enabled reports whether the provider should serve requests.
|
||||
Enabled() bool
|
||||
// BaseURL defines the base URL endpoint for this provider's API.
|
||||
BaseURL() string
|
||||
|
||||
|
||||
+36
-11
@@ -4,6 +4,7 @@ package cli
|
||||
|
||||
import (
|
||||
"context"
|
||||
"slices"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
@@ -101,10 +102,18 @@ func (r *poolDBReloader) Reload(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// BuildProviders loads every ai_providers row (including disabled)
|
||||
// and returns the active provider list plus per-row outcomes. Per-row
|
||||
// build errors are logged and excluded from providers but recorded in
|
||||
// outcomes; only DB query failures propagate.
|
||||
// BuildProviders loads all ai_providers rows (enabled and disabled),
|
||||
// attaches keys to enabled rows, and constructs the equivalent
|
||||
// [aibridge.Provider] instances. The database is the single source of
|
||||
// truth for runtime provider configuration.
|
||||
//
|
||||
// Disabled rows produce a Provider stub with Enabled() == false so the
|
||||
// bridge can answer requests targeting them with a 503 sentinel.
|
||||
//
|
||||
// Per-provider construction errors are logged and the offending row is
|
||||
// excluded from the returned snapshot; only a failure of the DB query
|
||||
// itself is propagated. This keeps a single misconfigured row from
|
||||
// taking the whole daemon down.
|
||||
func BuildProviders(ctx context.Context, db database.Store, cfg codersdk.AIBridgeConfig, logger slog.Logger) ([]aibridge.Provider, []aibridged.ProviderOutcome, error) {
|
||||
//nolint:gocritic // AsAIBridged has a minimal permission set for this purpose.
|
||||
authCtx := dbauthz.AsAIBridged(ctx)
|
||||
@@ -160,12 +169,9 @@ func BuildProviders(ctx context.Context, db database.Store, cfg codersdk.AIBridg
|
||||
Name: row.Name,
|
||||
Type: string(row.Type),
|
||||
}
|
||||
if !row.Enabled {
|
||||
outcome.Status = aibridged.ProviderStatusDisabled
|
||||
outcomes = append(outcomes, outcome)
|
||||
continue
|
||||
}
|
||||
if row.Enabled {
|
||||
enabledCount++
|
||||
}
|
||||
prov, err := buildAIProviderFromRow(row, keysByProvider[row.ID], cfg)
|
||||
if err != nil {
|
||||
outcome.Status = aibridged.ProviderStatusError
|
||||
@@ -179,13 +185,17 @@ func BuildProviders(ctx context.Context, db database.Store, cfg codersdk.AIBridg
|
||||
)
|
||||
continue
|
||||
}
|
||||
if row.Enabled {
|
||||
outcome.Status = aibridged.ProviderStatusEnabled
|
||||
} else {
|
||||
outcome.Status = aibridged.ProviderStatusDisabled
|
||||
}
|
||||
outcomes = append(outcomes, outcome)
|
||||
providers = append(providers, prov)
|
||||
}
|
||||
|
||||
if enabledCount > 0 && len(providers) == 0 {
|
||||
logger.Warn(ctx, "all enabled ai providers failed to build; daemon will start with zero providers")
|
||||
if enabledCount > 0 && !slices.ContainsFunc(providers, func(p aibridge.Provider) bool { return p.Enabled() }) {
|
||||
logger.Warn(ctx, "all enabled ai providers failed to build; only disabled providers remain")
|
||||
}
|
||||
|
||||
return providers, outcomes, nil
|
||||
@@ -193,11 +203,18 @@ func BuildProviders(ctx context.Context, db database.Store, cfg codersdk.AIBridg
|
||||
|
||||
// buildAIProviderFromRow decodes the settings blob and constructs the
|
||||
// appropriate [aibridge.Provider] for a single ai_providers row.
|
||||
// Disabled rows return a Provider stub carrying only Name and Type
|
||||
// (Enabled() == false); settings decode, key loading, and credential checks
|
||||
// are skipped because the provider will never call upstream.
|
||||
func buildAIProviderFromRow(
|
||||
row database.AIProvider,
|
||||
keys []database.AIProviderKey,
|
||||
cfg codersdk.AIBridgeConfig,
|
||||
) (aibridge.Provider, error) {
|
||||
if !row.Enabled {
|
||||
return disabledProviderFromRow(row)
|
||||
}
|
||||
|
||||
settings, err := db2sdk.AIProviderSettings(row.Settings)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("decode settings: %w", err)
|
||||
@@ -287,6 +304,14 @@ func buildAIProviderFromRow(
|
||||
}
|
||||
}
|
||||
|
||||
// disabledProviderFromRow builds a Provider stub for a disabled row.
|
||||
// Using provider.DisabledStub rather than a concrete provider avoids
|
||||
// duplicating the row.Type switch and ensures that a new AiProviderType
|
||||
// value is automatically handled without requiring a matching case here.
|
||||
func disabledProviderFromRow(row database.AIProvider) (aibridge.Provider, error) {
|
||||
return aibridge.NewDisabledProviderStub(row.Name, string(row.Type)), nil
|
||||
}
|
||||
|
||||
// buildAIProviderKeyPool builds a [keypool.Pool]. Callers must check
|
||||
// len(keys) > 0 first; keypool.New rejects empty input.
|
||||
func buildAIProviderKeyPool(keys []database.AIProviderKey) (*keypool.Pool, error) {
|
||||
|
||||
@@ -393,26 +393,61 @@ func TestBuildProvidersSkipsBadRows(t *testing.T) {
|
||||
|
||||
t.Run("DisabledRowClassifiedAsDisabled", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
for _, tc := range []struct {
|
||||
name string
|
||||
row database.AIProvider
|
||||
}{
|
||||
{
|
||||
name: "OpenAI",
|
||||
row: database.AIProvider{
|
||||
Type: database.AiProviderTypeOpenai,
|
||||
Name: "openai-off",
|
||||
BaseUrl: "https://api.openai.com/",
|
||||
},
|
||||
},
|
||||
{
|
||||
// Anthropic and Bedrock have stricter credential checks
|
||||
// than the OpenAI family; the disabled short-circuit
|
||||
// must reach them too. No keys, no bedrock settings.
|
||||
name: "Anthropic",
|
||||
row: database.AIProvider{
|
||||
Type: database.AiProviderTypeAnthropic,
|
||||
Name: "anthropic-off",
|
||||
BaseUrl: "https://api.anthropic.com/",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Bedrock",
|
||||
row: database.AIProvider{
|
||||
Type: database.AiProviderTypeBedrock,
|
||||
Name: "bedrock-off",
|
||||
BaseUrl: "https://bedrock-runtime.us-east-1.amazonaws.com/",
|
||||
},
|
||||
},
|
||||
} {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
logger := slogtest.Make(t, nil)
|
||||
|
||||
dbgen.AIProvider(t, db, database.AIProvider{
|
||||
Type: database.AiProviderTypeOpenai,
|
||||
Name: "openai-off",
|
||||
BaseUrl: "https://api.openai.com/",
|
||||
}, func(p *database.InsertAIProviderParams) {
|
||||
dbgen.AIProvider(t, db, tc.row, func(p *database.InsertAIProviderParams) {
|
||||
p.Enabled = false
|
||||
})
|
||||
|
||||
providers, outcomes, err := BuildProviders(ctx, db, codersdk.AIBridgeConfig{}, logger)
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, providers, "disabled providers must not be in the active snapshot")
|
||||
require.Len(t, providers, 1, "disabled providers stay in the snapshot so the bridge can serve a 503 sentinel")
|
||||
assert.Equal(t, tc.row.Name, providers[0].Name())
|
||||
assert.False(t, providers[0].Enabled())
|
||||
require.Len(t, outcomes, 1)
|
||||
assert.Equal(t, "openai-off", outcomes[0].Name)
|
||||
assert.Equal(t, tc.row.Name, outcomes[0].Name)
|
||||
assert.Equal(t, aibridged.ProviderStatusDisabled, outcomes[0].Status)
|
||||
assert.NoError(t, outcomes[0].Err)
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func providerNames(providers []aibridge.Provider) []string {
|
||||
|
||||
@@ -588,6 +588,7 @@ func TestBuildAIProviderFromRowSetsAPIDumpDir(t *testing.T) {
|
||||
{
|
||||
name: "OpenAI",
|
||||
row: database.AIProvider{
|
||||
Enabled: true,
|
||||
Type: database.AiProviderTypeOpenai,
|
||||
Name: "openai",
|
||||
BaseUrl: "https://api.openai.com/",
|
||||
@@ -597,6 +598,7 @@ func TestBuildAIProviderFromRowSetsAPIDumpDir(t *testing.T) {
|
||||
{
|
||||
name: "Anthropic",
|
||||
row: database.AIProvider{
|
||||
Enabled: true,
|
||||
Type: database.AiProviderTypeAnthropic,
|
||||
Name: "anthropic",
|
||||
BaseUrl: "https://api.anthropic.com/",
|
||||
@@ -606,6 +608,7 @@ func TestBuildAIProviderFromRowSetsAPIDumpDir(t *testing.T) {
|
||||
{
|
||||
name: "Copilot",
|
||||
row: database.AIProvider{
|
||||
Enabled: true,
|
||||
Type: database.AiProviderTypeCopilot,
|
||||
Name: "copilot",
|
||||
BaseUrl: "https://api.githubcopilot.com/",
|
||||
@@ -615,6 +618,7 @@ func TestBuildAIProviderFromRowSetsAPIDumpDir(t *testing.T) {
|
||||
{
|
||||
name: "Azure",
|
||||
row: database.AIProvider{
|
||||
Enabled: true,
|
||||
Type: database.AiProviderTypeAzure,
|
||||
Name: "azure",
|
||||
BaseUrl: "https://example.openai.azure.com/",
|
||||
@@ -624,6 +628,7 @@ func TestBuildAIProviderFromRowSetsAPIDumpDir(t *testing.T) {
|
||||
{
|
||||
name: "Google",
|
||||
row: database.AIProvider{
|
||||
Enabled: true,
|
||||
Type: database.AiProviderTypeGoogle,
|
||||
Name: "google",
|
||||
BaseUrl: "https://generativelanguage.googleapis.com/v1beta/openai/",
|
||||
@@ -633,6 +638,7 @@ func TestBuildAIProviderFromRowSetsAPIDumpDir(t *testing.T) {
|
||||
{
|
||||
name: "OpenAICompat",
|
||||
row: database.AIProvider{
|
||||
Enabled: true,
|
||||
Type: database.AiProviderTypeOpenaiCompat,
|
||||
Name: "openai-compat",
|
||||
BaseUrl: "https://compat.example.com/v1/",
|
||||
@@ -642,6 +648,7 @@ func TestBuildAIProviderFromRowSetsAPIDumpDir(t *testing.T) {
|
||||
{
|
||||
name: "OpenRouter",
|
||||
row: database.AIProvider{
|
||||
Enabled: true,
|
||||
Type: database.AiProviderTypeOpenrouter,
|
||||
Name: "openrouter",
|
||||
BaseUrl: "https://openrouter.ai/api/v1/",
|
||||
@@ -651,6 +658,7 @@ func TestBuildAIProviderFromRowSetsAPIDumpDir(t *testing.T) {
|
||||
{
|
||||
name: "Vercel",
|
||||
row: database.AIProvider{
|
||||
Enabled: true,
|
||||
Type: database.AiProviderTypeVercel,
|
||||
Name: "vercel",
|
||||
BaseUrl: "https://api.v0.dev/v1/",
|
||||
@@ -660,6 +668,7 @@ func TestBuildAIProviderFromRowSetsAPIDumpDir(t *testing.T) {
|
||||
{
|
||||
name: "Bedrock",
|
||||
row: database.AIProvider{
|
||||
Enabled: true,
|
||||
Type: database.AiProviderTypeBedrock,
|
||||
Name: "bedrock",
|
||||
BaseUrl: "https://bedrock-runtime.us-east-1.amazonaws.com/",
|
||||
@@ -694,6 +703,7 @@ func TestBuildAIProviderFromRowBedrockWithoutSettings(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
_, err := buildAIProviderFromRow(database.AIProvider{
|
||||
Enabled: true,
|
||||
Type: database.AiProviderTypeBedrock,
|
||||
Name: "bedrock-no-settings",
|
||||
BaseUrl: "https://bedrock-runtime.us-east-1.amazonaws.com/",
|
||||
|
||||
@@ -30,7 +30,9 @@ const (
|
||||
type Pooler interface {
|
||||
Acquire(ctx context.Context, req Request, clientFn ClientFunc, mcpBootstrapper MCPProxyBuilder) (http.Handler, error)
|
||||
// ReplaceProviders swaps the providers used to construct future
|
||||
// RequestBridge instances and clears the cache. Disabled providers
|
||||
// must be included; the bridge serves a 503 sentinel on their
|
||||
// routes.
|
||||
|
||||
ReplaceProviders(providers []aibridge.Provider)
|
||||
Shutdown(ctx context.Context) error
|
||||
}
|
||||
@@ -53,7 +55,8 @@ var _ Pooler = &CachedBridgePool{}
|
||||
|
||||
type CachedBridgePool struct {
|
||||
cache *ristretto.Cache[string, *aibridge.RequestBridge]
|
||||
// providers is the live provider set used by new RequestBridge instances.
|
||||
// providers is the live provider set used by new RequestBridge
|
||||
// instances. Includes disabled providers.
|
||||
providers atomic.Pointer[[]aibridge.Provider]
|
||||
providerVersion atomic.Int64
|
||||
logger slog.Logger
|
||||
|
||||
@@ -17,9 +17,9 @@ const (
|
||||
)
|
||||
|
||||
// ProviderOutcome classifies one ai_providers row, including disabled
|
||||
// and errored rows the pool excludes. Err is populated only when
|
||||
// Status == ProviderStatusError; the build error is already logged at
|
||||
// the call site.
|
||||
// rows (which the pool keeps as 503 stubs) and errored rows (which the
|
||||
// pool excludes). Err is populated only when Status == ProviderStatusError;
|
||||
// the build error is already logged at the call site.
|
||||
type ProviderOutcome struct {
|
||||
Name string
|
||||
Type string
|
||||
|
||||
@@ -211,6 +211,18 @@ func TestAIBridgeProviderHotReload(t *testing.T) {
|
||||
"expected provider %q to stop routing", providerName)
|
||||
}
|
||||
|
||||
// requireDisabledSentinel polls until a request for the provider name
|
||||
// yields a 503, indicating the disabled handler is wired up for the
|
||||
// row. Only the status is checked; sendRequest's second value is ignored.
|
||||
requireDisabledSentinel := func(t *testing.T, providerName string) {
|
||||
t.Helper()
|
||||
require.Eventuallyf(t, func() bool {
|
||||
status, _ := sendRequest(providerName)
|
||||
return status == http.StatusServiceUnavailable
|
||||
}, testutil.WaitShort, testutil.IntervalFast,
|
||||
"expected provider %q to serve the disabled sentinel", providerName)
|
||||
}
|
||||
|
||||
// 1. Create: provider points at upstream A.
|
||||
created, err := client.CreateAIProvider(ctx, codersdk.CreateAIProviderRequest{
|
||||
Type: codersdk.AIProviderTypeOpenAI,
|
||||
@@ -233,14 +245,14 @@ func TestAIBridgeProviderHotReload(t *testing.T) {
|
||||
requireRoutesTo(t, "primary", upstreamB)
|
||||
requireProviderStatus(t, "primary", "enabled")
|
||||
|
||||
// 3. Disable: the provider drops out of the snapshot, requests
|
||||
// stop reaching any upstream. The metric flips to "disabled".
|
||||
// 3. Disable: requests stop reaching upstream and the bridge
|
||||
// answers with the 503 sentinel. The metric flips to "disabled".
|
||||
disabled := false
|
||||
_, err = client.UpdateAIProvider(ctx, "primary", codersdk.UpdateAIProviderRequest{
|
||||
Enabled: &disabled,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
requireRoutingGone(t, "primary")
|
||||
requireDisabledSentinel(t, "primary")
|
||||
requireProviderStatus(t, "primary", "disabled")
|
||||
|
||||
// 4. Re-enable: routing comes back at the most recent BaseURL.
|
||||
|
||||
Reference in New Issue
Block a user