Files
Danny Kopping 5b10268827 feat: serve 503 sentinel for disabled providers (#25794)
_Disclosure: created with Coder Agents._

When a provider is disabled, we should serve a sentinel error so the
requesting client (Claude Code, Coder Agents, etc.) is informed. Coder
Agents can also detect this sentinel and adapt its display to show a
helpful error message.

---------

Signed-off-by: Danny Kopping <danny@coder.com>
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-29 10:24:16 +02:00

263 lines
8.4 KiB
Go

package aibridge_test
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel"
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/coder/v2/aibridge"
"github.com/coder/coder/v2/aibridge/config"
"github.com/coder/coder/v2/aibridge/internal/testutil"
"github.com/coder/coder/v2/aibridge/provider"
)
// bridgeTestTracer is the OpenTelemetry tracer handed to
// aibridge.NewRequestBridge by the tests in this file.
var bridgeTestTracer = otel.Tracer("bridge_test")
func TestValidateProviders(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
tests := []struct {
name string
providers []provider.Provider
expectErr string
}{
{
name: "all_supported_providers",
providers: []provider.Provider{
aibridge.NewOpenAIProvider(config.OpenAI{Name: "openai", BaseURL: "https://api.openai.com/v1/"}),
aibridge.NewAnthropicProvider(config.Anthropic{Name: "anthropic", BaseURL: "https://api.anthropic.com/"}, nil),
aibridge.NewCopilotProvider(config.Copilot{Name: "copilot", BaseURL: "https://api.individual.githubcopilot.com"}),
aibridge.NewCopilotProvider(config.Copilot{Name: "copilot-business", BaseURL: "https://api.business.githubcopilot.com"}),
aibridge.NewCopilotProvider(config.Copilot{Name: "copilot-enterprise", BaseURL: "https://api.enterprise.githubcopilot.com"}),
},
},
{
name: "default_names_and_base_urls",
providers: []provider.Provider{
aibridge.NewOpenAIProvider(config.OpenAI{}),
aibridge.NewAnthropicProvider(config.Anthropic{}, nil),
aibridge.NewCopilotProvider(config.Copilot{}),
},
},
{
name: "multiple_copilot_instances",
providers: []provider.Provider{
aibridge.NewCopilotProvider(config.Copilot{}),
aibridge.NewCopilotProvider(config.Copilot{Name: "copilot-business", BaseURL: "https://api.business.githubcopilot.com"}),
aibridge.NewCopilotProvider(config.Copilot{Name: "copilot-enterprise", BaseURL: "https://api.enterprise.githubcopilot.com"}),
},
},
{
name: "name_with_slashes",
providers: []provider.Provider{
aibridge.NewCopilotProvider(config.Copilot{Name: "copilot/business", BaseURL: "https://api.business.githubcopilot.com"}),
},
expectErr: "invalid provider name",
},
{
name: "name_with_spaces",
providers: []provider.Provider{
aibridge.NewCopilotProvider(config.Copilot{Name: "copilot business", BaseURL: "https://api.business.githubcopilot.com"}),
},
expectErr: "invalid provider name",
},
{
name: "name_with_uppercase",
providers: []provider.Provider{
aibridge.NewCopilotProvider(config.Copilot{Name: "Copilot", BaseURL: "https://api.business.githubcopilot.com"}),
},
expectErr: "invalid provider name",
},
{
name: "unique_names",
providers: []provider.Provider{
aibridge.NewCopilotProvider(config.Copilot{Name: "copilot", BaseURL: "https://api.individual.githubcopilot.com"}),
aibridge.NewCopilotProvider(config.Copilot{Name: "copilot-business", BaseURL: "https://api.business.githubcopilot.com"}),
},
},
{
name: "duplicate_base_url_different_names",
providers: []provider.Provider{
aibridge.NewCopilotProvider(config.Copilot{Name: "copilot", BaseURL: "https://api.individual.githubcopilot.com"}),
aibridge.NewCopilotProvider(config.Copilot{Name: "copilot-business", BaseURL: "https://api.individual.githubcopilot.com"}),
},
},
{
name: "duplicate_name",
providers: []provider.Provider{
aibridge.NewCopilotProvider(config.Copilot{Name: "copilot", BaseURL: "https://api.individual.githubcopilot.com"}),
aibridge.NewCopilotProvider(config.Copilot{Name: "copilot", BaseURL: "https://api.business.githubcopilot.com"}),
},
expectErr: "duplicate provider name",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
_, err := aibridge.NewRequestBridge(t.Context(), tc.providers, nil, nil, logger, nil, bridgeTestTracer)
if tc.expectErr != "" {
require.Error(t, err)
assert.Contains(t, err.Error(), tc.expectErr)
} else {
require.NoError(t, err)
}
})
}
}
// TestPassthroughRoutesForProviders sends a request through the bridge
// for each provider, with and without a path on the provider's base URL,
// and asserts the path the upstream server receives as well as the
// status and body relayed back to the caller.
func TestPassthroughRoutesForProviders(t *testing.T) {
	t.Parallel()

	const upstreamBody = "upstream response"

	// Provider constructors shared by the cases below; each takes the base
	// URL of the per-test upstream server.
	newOpenAI := func(baseURL string) provider.Provider {
		return aibridge.NewOpenAIProvider(config.OpenAI{BaseURL: baseURL})
	}
	newAnthropic := func(baseURL string) provider.Provider {
		return aibridge.NewAnthropicProvider(config.Anthropic{BaseURL: baseURL}, nil)
	}
	newCopilot := func(baseURL string) provider.Provider {
		return aibridge.NewCopilotProvider(config.Copilot{BaseURL: baseURL})
	}

	type testCase struct {
		name string
		// basePath is appended to the upstream server URL to form the
		// provider's base URL.
		basePath string
		// reqPath is the path requested from the bridge.
		reqPath     string
		newProvider func(string) provider.Provider
		// wantUpstreamPath is the path the upstream server must observe.
		wantUpstreamPath string
	}

	cases := []testCase{
		{
			name:             "openAI_no_base_path",
			reqPath:          "/openai/v1/conversations",
			newProvider:      newOpenAI,
			wantUpstreamPath: "/conversations",
		},
		{
			name:             "openAI_with_base_path",
			basePath:         "/v1",
			reqPath:          "/openai/v1/conversations",
			newProvider:      newOpenAI,
			wantUpstreamPath: "/v1/conversations",
		},
		{
			name:             "anthropic_no_base_path",
			reqPath:          "/anthropic/v1/models",
			newProvider:      newAnthropic,
			wantUpstreamPath: "/v1/models",
		},
		{
			name:             "anthropic_with_base_path",
			basePath:         "/v1",
			reqPath:          "/anthropic/v1/models",
			newProvider:      newAnthropic,
			wantUpstreamPath: "/v1/v1/models",
		},
		{
			name:             "copilot_no_base_path",
			reqPath:          "/copilot/models",
			newProvider:      newCopilot,
			wantUpstreamPath: "/models",
		},
		{
			name:             "copilot_with_base_path",
			basePath:         "/v1",
			reqPath:          "/copilot/models",
			newProvider:      newCopilot,
			wantUpstreamPath: "/v1/models",
		},
	}

	for _, tt := range cases {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()

			logger := slogtest.Make(t, nil)

			// The upstream checks the path it was called with and replies
			// with a fixed body the test can recognize in the response.
			handler := func(w http.ResponseWriter, r *http.Request) {
				assert.Equal(t, tt.wantUpstreamPath, r.URL.Path)
				w.WriteHeader(http.StatusOK)
				_, _ = w.Write([]byte(upstreamBody))
			}
			srv := httptest.NewServer(http.HandlerFunc(handler))
			t.Cleanup(srv.Close)

			recorder := &testutil.MockRecorder{}
			providers := []provider.Provider{tt.newProvider(srv.URL + tt.basePath)}
			bridge, err := aibridge.NewRequestBridge(t.Context(), providers, recorder, nil, logger, nil, bridgeTestTracer)
			require.NoError(t, err)

			out := httptest.NewRecorder()
			bridge.ServeHTTP(out, httptest.NewRequest("", tt.reqPath, nil))

			assert.Equal(t, http.StatusOK, out.Code)
			assert.Contains(t, out.Body.String(), upstreamBody)
		})
	}
}
// TestDisabledProviderHandler registers one enabled provider and one
// disabled provider stub on the same bridge. Requests addressed to the
// disabled provider must yield a 503 whose body carries
// ErrorCodeProviderDisabled and the provider name, while requests to the
// enabled provider must still reach the upstream server.
func TestDisabledProviderHandler(t *testing.T) {
	t.Parallel()

	logger := slogtest.Make(t, nil)

	// The upstream answers every request with a marker body, so a response
	// containing it shows the request was forwarded.
	upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		_, _ = w.Write([]byte("upstream-reached"))
	}))
	t.Cleanup(upstream.Close)

	providers := []provider.Provider{
		aibridge.NewOpenAIProvider(config.OpenAI{Name: "enabled-openai", BaseURL: upstream.URL}),
		aibridge.NewDisabledProviderStub("disabled-openai", "openai"),
	}
	bridge, err := aibridge.NewRequestBridge(t.Context(), providers, nil, nil, logger, nil, bridgeTestTracer)
	require.NoError(t, err)

	// serve runs a single body-less request through the bridge and returns
	// the recorded response.
	serve := func(method, path string) *httptest.ResponseRecorder {
		req := httptest.NewRequest(method, path, nil)
		out := httptest.NewRecorder()
		bridge.ServeHTTP(out, req)
		return out
	}

	type disabledCase struct {
		label string
		route string
	}
	disabledCases := []disabledCase{
		{label: "Bridged", route: "/disabled-openai/v1/chat/completions"},
		{label: "Passthrough", route: "/disabled-openai/v1/models"},
		{label: "Unknown", route: "/disabled-openai/anything/else"},
	}
	for _, dc := range disabledCases {
		t.Run("DisabledProviderReturnsSentinel/"+dc.label, func(t *testing.T) {
			t.Parallel()

			out := serve(http.MethodPost, dc.route)
			body := out.Body.String()

			assert.Equal(t, http.StatusServiceUnavailable, out.Code)
			assert.Contains(t, body, aibridge.ErrorCodeProviderDisabled)
			assert.Contains(t, body, "disabled-openai")
		})
	}

	t.Run("EnabledProviderUnaffected", func(t *testing.T) {
		t.Parallel()

		out := serve(http.MethodGet, "/enabled-openai/v1/models")

		assert.Equal(t, http.StatusOK, out.Code)
		assert.Equal(t, "upstream-reached", out.Body.String())
	})
}