From 5b10268827e8c3a8ac8e0ad563d042e66c7003c7 Mon Sep 17 00:00:00 2001 From: Danny Kopping Date: Fri, 29 May 2026 10:24:16 +0200 Subject: [PATCH] feat: serve 503 sentinel for disabled providers (#25794) _Disclosure: created with Coder Agents._ When providers are disabled, we should serve a sentinel error so the requesting client (Claude Code, Coder Agents, etc) is informed. Coder Agents can also conditionalize its display to show a helpful error message. --------- Signed-off-by: Danny Kopping Co-authored-by: Claude Opus 4.7 (1M context) --- aibridge/api.go | 8 +++ aibridge/bridge.go | 27 +++++++++ aibridge/bridge_test.go | 55 +++++++++++++++++ aibridge/internal/testutil/mockprovider.go | 2 + aibridge/provider/anthropic.go | 2 + aibridge/provider/copilot.go | 2 + aibridge/provider/disabled.go | 47 +++++++++++++++ aibridge/provider/openai.go | 2 + aibridge/provider/provider.go | 2 + cli/aibridged.go | 49 +++++++++++---- cli/aibridged_internal_test.go | 69 ++++++++++++++++------ cli/server_aibridge_internal_test.go | 10 ++++ coderd/aibridged/pool.go | 7 ++- coderd/aibridged/provider.go | 6 +- enterprise/coderd/aibridge_reload_test.go | 18 +++++- 15 files changed, 269 insertions(+), 37 deletions(-) create mode 100644 aibridge/provider/disabled.go diff --git a/aibridge/api.go b/aibridge/api.go index 809d452fe9..34dce84ef8 100644 --- a/aibridge/api.go +++ b/aibridge/api.go @@ -57,6 +57,14 @@ func NewCopilotProvider(cfg config.Copilot) provider.Provider { return provider.NewCopilot(cfg) } +// NewDisabledProviderStub returns a Provider that reports Enabled() == +// false and has no-op implementations for all other methods. Use this +// instead of constructing a concrete provider for disabled rows so that +// adding a new provider type does not require updating a switch here. +func NewDisabledProviderStub(name, providerType string) provider.Provider { + return provider.NewDisabledStub(name, providerType) +} + func NewMetrics(reg prometheus.Registerer) *metrics.Metrics { return metrics.NewMetrics(reg) } diff --git a/aibridge/bridge.go b/aibridge/bridge.go index f604d0a38a..79715af880 100644 --- a/aibridge/bridge.go +++ b/aibridge/bridge.go @@ -30,6 +30,11 @@ import ( const ( // The duration after which an async recording will be aborted. recordingTimeout = time.Second * 5 + + // ErrorCodeProviderDisabled is the code written in the response + // body when a request targets a configured-but-disabled provider. + // Paired with HTTP 503. + ErrorCodeProviderDisabled = "provider_disabled" ) // RequestBridge is an [http.Handler] which is capable of masquerading as AI providers' APIs; @@ -96,6 +101,14 @@ func NewRequestBridge(ctx context.Context, providers []provider.Provider, rec re mux := http.NewServeMux() for _, prov := range providers { + // Disabled providers serve a 503 sentinel on every path under + // "//". Bound to the bare name (not RoutePrefix) so paths + // outside the provider's normal "/v1" subtree are also caught. + if !prov.Enabled() { + prefix := fmt.Sprintf("/%s/", prov.Name()) + mux.HandleFunc(prefix, disabledProviderHandler(prov.Name(), logger)) + continue + } // Create per-provider circuit breaker if configured cfg := prov.CircuitBreakerConfig() providerName := prov.Name() @@ -170,6 +183,20 @@ func NewRequestBridge(ctx context.Context, providers []provider.Provider, rec re }, nil } +// disabledProviderHandler returns 503 with a body containing +// [ErrorCodeProviderDisabled] and the provider name for every request +// targeting name. +func disabledProviderHandler(name string, logger slog.Logger) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + logger.Debug(r.Context(), "refusing request for disabled ai provider", + slog.F("provider", name), + slog.F("path", r.URL.Path), + slog.F("method", r.Method), + ) + http.Error(w, fmt.Sprintf("%s: AI provider %q is disabled", ErrorCodeProviderDisabled, name), http.StatusServiceUnavailable) + } +} + // newInterceptionProcessor returns an [http.HandlerFunc] which is capable of creating a new interceptor and processing a given request // using [Provider] p, recording all usage events using [Recorder] rec. // If cbs is non-nil, circuit breaker protection is applied per endpoint/model tuple. diff --git a/aibridge/bridge_test.go b/aibridge/bridge_test.go index f2657ab80f..93beb82de9 100644 --- a/aibridge/bridge_test.go +++ b/aibridge/bridge_test.go @@ -205,3 +205,58 @@ func TestPassthroughRoutesForProviders(t *testing.T) { }) } } + +// TestDisabledProviderHandler asserts that requests to a disabled +// provider return a 503 with an ErrorCodeProviderDisabled body and +// that a sibling enabled provider keeps routing normally. +func TestDisabledProviderHandler(t *testing.T) { + t.Parallel() + + logger := slogtest.Make(t, nil) + + upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _, _ = w.Write([]byte("upstream-reached")) + })) + t.Cleanup(upstream.Close) + + enabled := aibridge.NewOpenAIProvider(config.OpenAI{Name: "enabled-openai", BaseURL: upstream.URL}) + disabled := aibridge.NewDisabledProviderStub("disabled-openai", "openai") + bridge, err := aibridge.NewRequestBridge( + t.Context(), + []provider.Provider{enabled, disabled}, + nil, nil, logger, nil, bridgeTestTracer, + ) + require.NoError(t, err) + + for _, tc := range []struct { + name string + path string + }{ + {name: "Bridged", path: "/disabled-openai/v1/chat/completions"}, + {name: "Passthrough", path: "/disabled-openai/v1/models"}, + {name: "Unknown", path: "/disabled-openai/anything/else"}, + } { + t.Run("DisabledProviderReturnsSentinel/"+tc.name, func(t *testing.T) { + t.Parallel() + + req := httptest.NewRequest(http.MethodPost, tc.path, nil) + resp := httptest.NewRecorder() + bridge.ServeHTTP(resp, req) + + assert.Equal(t, http.StatusServiceUnavailable, resp.Code) + assert.Contains(t, resp.Body.String(), aibridge.ErrorCodeProviderDisabled) + assert.Contains(t, resp.Body.String(), "disabled-openai") + }) + } + + t.Run("EnabledProviderUnaffected", func(t *testing.T) { + t.Parallel() + + req := httptest.NewRequest(http.MethodGet, "/enabled-openai/v1/models", nil) + resp := httptest.NewRecorder() + bridge.ServeHTTP(resp, req) + + assert.Equal(t, http.StatusOK, resp.Code) + assert.Equal(t, "upstream-reached", resp.Body.String()) + }) +} diff --git a/aibridge/internal/testutil/mockprovider.go b/aibridge/internal/testutil/mockprovider.go index 0fd85d2863..e5015cd870 100644 --- a/aibridge/internal/testutil/mockprovider.go +++ b/aibridge/internal/testutil/mockprovider.go @@ -15,6 +15,7 @@ import ( type MockProvider struct { NameStr string URL string + Disabled bool Bridged []string Passthrough []string InterceptorFunc func(w http.ResponseWriter, r *http.Request, tracer trace.Tracer) (intercept.Interceptor, error) @@ -22,6 +23,7 @@ type MockProvider struct { func (m *MockProvider) Type() string { return m.NameStr } func (m *MockProvider) Name() string { return m.NameStr } +func (m *MockProvider) Enabled() bool { return !m.Disabled } func (m *MockProvider) BaseURL() string { return m.URL } func (m *MockProvider) RoutePrefix() string { return fmt.Sprintf("/%s", m.NameStr) } func (m *MockProvider) BridgedRoutes() []string { return m.Bridged } diff --git a/aibridge/provider/anthropic.go b/aibridge/provider/anthropic.go index eb50a3b296..664183f82f 100644 --- a/aibridge/provider/anthropic.go +++ b/aibridge/provider/anthropic.go @@ -95,6 +95,8 @@ func (p *Anthropic) Name() string { return p.cfg.Name } +func (*Anthropic) Enabled() bool { return true } + func (p *Anthropic) RoutePrefix() string { return fmt.Sprintf("/%s", p.Name()) } diff --git a/aibridge/provider/copilot.go b/aibridge/provider/copilot.go index 1186e8b253..fd317aadab 100644 --- a/aibridge/provider/copilot.go +++ b/aibridge/provider/copilot.go @@ -78,6 +78,8 @@ func (p *Copilot) Name() string { return p.cfg.Name } +func (*Copilot) Enabled() bool { return true } + func (p *Copilot) BaseURL() string { return p.cfg.BaseURL } diff --git a/aibridge/provider/disabled.go b/aibridge/provider/disabled.go new file mode 100644 index 0000000000..95384b4952 --- /dev/null +++ b/aibridge/provider/disabled.go @@ -0,0 +1,47 @@ +package provider + +import ( + "fmt" + "net/http" + + "go.opentelemetry.io/otel/trace" + + "cdr.dev/slog/v3" + "github.com/coder/coder/v2/aibridge/config" + "github.com/coder/coder/v2/aibridge/intercept" + "github.com/coder/coder/v2/aibridge/keypool" +) + +// DisabledStub is a Provider placeholder for a configured-but-disabled +// provider. Only Name and Enabled return meaningful values; all other +// methods return empty/nil so the stub never influences routing. +type DisabledStub struct { + name string + providerType string +} + +// NewDisabledStub returns a Provider stub that reports Enabled() == false. +// The type string is preserved so callers can distinguish provider families. +func NewDisabledStub(name, providerType string) *DisabledStub { + return &DisabledStub{name: name, providerType: providerType} +} + +func (d *DisabledStub) Type() string { return d.providerType } +func (d *DisabledStub) Name() string { return d.name } +func (*DisabledStub) Enabled() bool { return false } +func (*DisabledStub) BaseURL() string { return "" } +func (d *DisabledStub) RoutePrefix() string { + return fmt.Sprintf("/%s", d.name) +} +func (*DisabledStub) BridgedRoutes() []string { return nil } +func (*DisabledStub) PassthroughRoutes() []string { return nil } +func (*DisabledStub) AuthHeader() string { return "" } +func (*DisabledStub) KeyFailoverConfig(_ slog.Logger) keypool.KeyFailoverConfig { + return keypool.KeyFailoverConfig{} +} +func (*DisabledStub) CircuitBreakerConfig() *config.CircuitBreaker { return nil } +func (*DisabledStub) APIDumpDir() string { return "" } +func (*DisabledStub) CreateInterceptor(_ http.ResponseWriter, _ *http.Request, _ trace.Tracer) (intercept.Interceptor, error) { + //nolint:nilnil // disabled providers never reach the interceptor. + return nil, nil +} diff --git a/aibridge/provider/openai.go b/aibridge/provider/openai.go index 177ae03409..7a78069df6 100644 --- a/aibridge/provider/openai.go +++ b/aibridge/provider/openai.go @@ -84,6 +84,8 @@ func (p *OpenAI) Name() string { return p.cfg.Name } +func (*OpenAI) Enabled() bool { return true } + func (p *OpenAI) RoutePrefix() string { // Route prefix includes version to match default OpenAI base URL. // More detailed explanation: https://github.com/coder/aibridge/pull/174#discussion_r2782320152 diff --git a/aibridge/provider/provider.go b/aibridge/provider/provider.go index 7520333b53..6f21d7290d 100644 --- a/aibridge/provider/provider.go +++ b/aibridge/provider/provider.go @@ -53,6 +53,8 @@ type Provider interface { // Name returns the provider instance name. // Defaults to Type() when not explicitly configured. Name() string + // Enabled reports whether the provider should serve requests. + Enabled() bool // BaseURL defines the base URL endpoint for this provider's API. BaseURL() string diff --git a/cli/aibridged.go b/cli/aibridged.go index caf67082fc..a890488a10 100644 --- a/cli/aibridged.go +++ b/cli/aibridged.go @@ -4,6 +4,7 @@ package cli import ( "context" + "slices" "github.com/google/uuid" "github.com/prometheus/client_golang/prometheus" @@ -101,10 +102,18 @@ func (r *poolDBReloader) Reload(ctx context.Context) error { return nil } -// BuildProviders loads every ai_providers row (including disabled) -// and returns the active provider list plus per-row outcomes. Per-row -// build errors are logged and excluded from providers but recorded in -// outcomes; only DB query failures propagate. +// BuildProviders loads all ai_providers rows (enabled and disabled), +// attaches keys to enabled rows, and constructs the equivalent +// [aibridge.Provider] instances. The database is the single source of +// truth for runtime provider configuration. +// +// Disabled rows produce a Provider stub with Enabled() == false so the +// bridge can answer requests targeting them with a 503 sentinel. +// +// Per-provider construction errors are logged and the offending row is +// excluded from the returned snapshot; only a failure of the DB query +// itself is propagated. This keeps a single misconfigured row from +// taking the whole daemon down. func BuildProviders(ctx context.Context, db database.Store, cfg codersdk.AIBridgeConfig, logger slog.Logger) ([]aibridge.Provider, []aibridged.ProviderOutcome, error) { //nolint:gocritic // AsAIBridged has a minimal permission set for this purpose. authCtx := dbauthz.AsAIBridged(ctx) @@ -160,12 +169,9 @@ func BuildProviders(ctx context.Context, db database.Store, cfg codersdk.AIBridg Name: row.Name, Type: string(row.Type), } - if !row.Enabled { - outcome.Status = aibridged.ProviderStatusDisabled - outcomes = append(outcomes, outcome) - continue + if row.Enabled { + enabledCount++ } - enabledCount++ prov, err := buildAIProviderFromRow(row, keysByProvider[row.ID], cfg) if err != nil { outcome.Status = aibridged.ProviderStatusError @@ -179,13 +185,17 @@ func BuildProviders(ctx context.Context, db database.Store, cfg codersdk.AIBridg ) continue } - outcome.Status = aibridged.ProviderStatusEnabled + if row.Enabled { + outcome.Status = aibridged.ProviderStatusEnabled + } else { + outcome.Status = aibridged.ProviderStatusDisabled + } outcomes = append(outcomes, outcome) providers = append(providers, prov) } - if enabledCount > 0 && len(providers) == 0 { - logger.Warn(ctx, "all enabled ai providers failed to build; daemon will start with zero providers") + if enabledCount > 0 && !slices.ContainsFunc(providers, func(p aibridge.Provider) bool { return p.Enabled() }) { + logger.Warn(ctx, "all enabled ai providers failed to build; only disabled providers remain") } return providers, outcomes, nil @@ -193,11 +203,18 @@ func BuildProviders(ctx context.Context, db database.Store, cfg codersdk.AIBridg // buildAIProviderFromRow decodes the settings blob and constructs the // appropriate [aibridge.Provider] for a single ai_providers row. +// Disabled rows return a Provider stub carrying only Name and +// Disabled: true; settings decode, key loading, and credential checks +// are skipped because the provider will never call upstream. func buildAIProviderFromRow( row database.AIProvider, keys []database.AIProviderKey, cfg codersdk.AIBridgeConfig, ) (aibridge.Provider, error) { + if !row.Enabled { + return disabledProviderFromRow(row) + } + settings, err := db2sdk.AIProviderSettings(row.Settings) if err != nil { return nil, xerrors.Errorf("decode settings: %w", err) @@ -287,6 +304,14 @@ func buildAIProviderFromRow( } } +// disabledProviderFromRow builds a Provider stub for a disabled row. +// Using provider.DisabledStub rather than a concrete provider avoids +// duplicating the row.Type switch and ensures that a new AiProviderType +// value is automatically handled without requiring a matching case here. +func disabledProviderFromRow(row database.AIProvider) (aibridge.Provider, error) { + return aibridge.NewDisabledProviderStub(row.Name, string(row.Type)), nil +} + // buildAIProviderKeyPool builds a [keypool.Pool]. Callers must check // len(keys) > 0 first; keypool.New rejects empty input. func buildAIProviderKeyPool(keys []database.AIProviderKey) (*keypool.Pool, error) { diff --git a/cli/aibridged_internal_test.go b/cli/aibridged_internal_test.go index 0226974520..6b3e1eb7ac 100644 --- a/cli/aibridged_internal_test.go +++ b/cli/aibridged_internal_test.go @@ -393,25 +393,60 @@ func TestBuildProvidersSkipsBadRows(t *testing.T) { t.Run("DisabledRowClassifiedAsDisabled", func(t *testing.T) { t.Parallel() - db, _ := dbtestutil.NewDB(t) - ctx := testutil.Context(t, testutil.WaitShort) - logger := slogtest.Make(t, nil) - dbgen.AIProvider(t, db, database.AIProvider{ - Type: database.AiProviderTypeOpenai, - Name: "openai-off", - BaseUrl: "https://api.openai.com/", - }, func(p *database.InsertAIProviderParams) { - p.Enabled = false - }) + for _, tc := range []struct { + name string + row database.AIProvider + }{ + { + name: "OpenAI", + row: database.AIProvider{ + Type: database.AiProviderTypeOpenai, + Name: "openai-off", + BaseUrl: "https://api.openai.com/", + }, + }, + { + // Anthropic and Bedrock have stricter credential checks + // than the OpenAI family; the disabled short-circuit + // must reach them too. No keys, no bedrock settings. + name: "Anthropic", + row: database.AIProvider{ + Type: database.AiProviderTypeAnthropic, + Name: "anthropic-off", + BaseUrl: "https://api.anthropic.com/", + }, + }, + { + name: "Bedrock", + row: database.AIProvider{ + Type: database.AiProviderTypeBedrock, + Name: "bedrock-off", + BaseUrl: "https://bedrock-runtime.us-east-1.amazonaws.com/", + }, + }, + } { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + db, _ := dbtestutil.NewDB(t) + ctx := testutil.Context(t, testutil.WaitShort) + logger := slogtest.Make(t, nil) - providers, outcomes, err := BuildProviders(ctx, db, codersdk.AIBridgeConfig{}, logger) - require.NoError(t, err) - assert.Empty(t, providers, "disabled providers must not be in the active snapshot") - require.Len(t, outcomes, 1) - assert.Equal(t, "openai-off", outcomes[0].Name) - assert.Equal(t, aibridged.ProviderStatusDisabled, outcomes[0].Status) - assert.NoError(t, outcomes[0].Err) + dbgen.AIProvider(t, db, tc.row, func(p *database.InsertAIProviderParams) { + p.Enabled = false + }) + + providers, outcomes, err := BuildProviders(ctx, db, codersdk.AIBridgeConfig{}, logger) + require.NoError(t, err) + require.Len(t, providers, 1, "disabled providers stay in the snapshot so the bridge can serve a 503 sentinel") + assert.Equal(t, tc.row.Name, providers[0].Name()) + assert.False(t, providers[0].Enabled()) + require.Len(t, outcomes, 1) + assert.Equal(t, tc.row.Name, outcomes[0].Name) + assert.Equal(t, aibridged.ProviderStatusDisabled, outcomes[0].Status) + assert.NoError(t, outcomes[0].Err) + }) + } }) } diff --git a/cli/server_aibridge_internal_test.go b/cli/server_aibridge_internal_test.go index 6b712a6352..a91e5b51d2 100644 --- a/cli/server_aibridge_internal_test.go +++ b/cli/server_aibridge_internal_test.go @@ -588,6 +588,7 @@ func TestBuildAIProviderFromRowSetsAPIDumpDir(t *testing.T) { { name: "OpenAI", row: database.AIProvider{ + Enabled: true, Type: database.AiProviderTypeOpenai, Name: "openai", BaseUrl: "https://api.openai.com/", @@ -597,6 +598,7 @@ func TestBuildAIProviderFromRowSetsAPIDumpDir(t *testing.T) { { name: "Anthropic", row: database.AIProvider{ + Enabled: true, Type: database.AiProviderTypeAnthropic, Name: "anthropic", BaseUrl: "https://api.anthropic.com/", @@ -606,6 +608,7 @@ func TestBuildAIProviderFromRowSetsAPIDumpDir(t *testing.T) { { name: "Copilot", row: database.AIProvider{ + Enabled: true, Type: database.AiProviderTypeCopilot, Name: "copilot", BaseUrl: "https://api.githubcopilot.com/", @@ -615,6 +618,7 @@ func TestBuildAIProviderFromRowSetsAPIDumpDir(t *testing.T) { { name: "Azure", row: database.AIProvider{ + Enabled: true, Type: database.AiProviderTypeAzure, Name: "azure", BaseUrl: "https://example.openai.azure.com/", @@ -624,6 +628,7 @@ func TestBuildAIProviderFromRowSetsAPIDumpDir(t *testing.T) { { name: "Google", row: database.AIProvider{ + Enabled: true, Type: database.AiProviderTypeGoogle, Name: "google", BaseUrl: "https://generativelanguage.googleapis.com/v1beta/openai/", @@ -633,6 +638,7 @@ func TestBuildAIProviderFromRowSetsAPIDumpDir(t *testing.T) { { name: "OpenAICompat", row: database.AIProvider{ + Enabled: true, Type: database.AiProviderTypeOpenaiCompat, Name: "openai-compat", BaseUrl: "https://compat.example.com/v1/", @@ -642,6 +648,7 @@ func TestBuildAIProviderFromRowSetsAPIDumpDir(t *testing.T) { { name: "OpenRouter", row: database.AIProvider{ + Enabled: true, Type: database.AiProviderTypeOpenrouter, Name: "openrouter", BaseUrl: "https://openrouter.ai/api/v1/", @@ -651,6 +658,7 @@ func TestBuildAIProviderFromRowSetsAPIDumpDir(t *testing.T) { { name: "Vercel", row: database.AIProvider{ + Enabled: true, Type: database.AiProviderTypeVercel, Name: "vercel", BaseUrl: "https://api.v0.dev/v1/", @@ -660,6 +668,7 @@ func TestBuildAIProviderFromRowSetsAPIDumpDir(t *testing.T) { { name: "Bedrock", row: database.AIProvider{ + Enabled: true, Type: database.AiProviderTypeBedrock, Name: "bedrock", BaseUrl: "https://bedrock-runtime.us-east-1.amazonaws.com/", @@ -694,6 +703,7 @@ func TestBuildAIProviderFromRowBedrockWithoutSettings(t *testing.T) { t.Parallel() _, err := buildAIProviderFromRow(database.AIProvider{ + Enabled: true, Type: database.AiProviderTypeBedrock, Name: "bedrock-no-settings", BaseUrl: "https://bedrock-runtime.us-east-1.amazonaws.com/", diff --git a/coderd/aibridged/pool.go b/coderd/aibridged/pool.go index 3b7e60955c..b86cefe00a 100644 --- a/coderd/aibridged/pool.go +++ b/coderd/aibridged/pool.go @@ -30,7 +30,9 @@ const ( type Pooler interface { Acquire(ctx context.Context, req Request, clientFn ClientFunc, mcpBootstrapper MCPProxyBuilder) (http.Handler, error) // ReplaceProviders swaps the providers used to construct future - // RequestBridge instances and clears the cache. + // RequestBridge instances and clears the cache. Disabled providers + // must be included; the bridge serves a 503 sentinel on their + // routes. ReplaceProviders(providers []aibridge.Provider) Shutdown(ctx context.Context) error } @@ -53,7 +55,8 @@ var _ Pooler = &CachedBridgePool{} type CachedBridgePool struct { cache *ristretto.Cache[string, *aibridge.RequestBridge] - // providers is the live provider set used by new RequestBridge instances. + // providers is the live provider set used by new RequestBridge + // instances. Includes disabled providers. providers atomic.Pointer[[]aibridge.Provider] providerVersion atomic.Int64 logger slog.Logger diff --git a/coderd/aibridged/provider.go b/coderd/aibridged/provider.go index 6fb53e1a93..9d2faa030b 100644 --- a/coderd/aibridged/provider.go +++ b/coderd/aibridged/provider.go @@ -17,9 +17,9 @@ const ( ) // ProviderOutcome classifies one ai_providers row, including disabled -// and errored rows the pool excludes. Err is populated only when -// Status == ProviderStatusError; the build error is already logged at -// the call site. +// rows (which the pool keeps as 503 stubs) and errored rows (which the +// pool excludes). Err is populated only when Status == ProviderStatusError; +// the build error is already logged at the call site. type ProviderOutcome struct { Name string Type string diff --git a/enterprise/coderd/aibridge_reload_test.go b/enterprise/coderd/aibridge_reload_test.go index 65f678df6e..e3370c8f7d 100644 --- a/enterprise/coderd/aibridge_reload_test.go +++ b/enterprise/coderd/aibridge_reload_test.go @@ -211,6 +211,18 @@ func TestAIBridgeProviderHotReload(t *testing.T) { "expected provider %q to stop routing", providerName) } + // requireDisabledSentinel polls until the provider name yields a + // 503 with the provider_disabled body, indicating the disabled + // handler is wired up for the row. + requireDisabledSentinel := func(t *testing.T, providerName string) { + t.Helper() + require.Eventuallyf(t, func() bool { + status, _ := sendRequest(providerName) + return status == http.StatusServiceUnavailable + }, testutil.WaitShort, testutil.IntervalFast, + "expected provider %q to serve the disabled sentinel", providerName) + } + // 1. Create: provider points at upstream A. created, err := client.CreateAIProvider(ctx, codersdk.CreateAIProviderRequest{ Type: codersdk.AIProviderTypeOpenAI, @@ -233,14 +245,14 @@ func TestAIBridgeProviderHotReload(t *testing.T) { requireRoutesTo(t, "primary", upstreamB) requireProviderStatus(t, "primary", "enabled") - // 3. Disable: the provider drops out of the snapshot, requests - // stop reaching any upstream. The metric flips to "disabled". + // 3. Disable: requests stop reaching upstream and the bridge + // answers with the 503 sentinel. The metric flips to "disabled". disabled := false _, err = client.UpdateAIProvider(ctx, "primary", codersdk.UpdateAIProviderRequest{ Enabled: &disabled, }) require.NoError(t, err) - requireRoutingGone(t, "primary") + requireDisabledSentinel(t, "primary") requireProviderStatus(t, "primary", "disabled") // 4. Re-enable: routing comes back at the most recent BaseURL.