mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
5b10268827
_Disclosure: created with Coder Agents._ When providers are disabled, we should serve a sentinel error so the requesting client (Claude Code, Coder Agents, etc) is informed. Coder Agents can also conditionalize its display to show a helpful error message. --------- Signed-off-by: Danny Kopping <danny@coder.com> Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
460 lines
15 KiB
Go
460 lines
15 KiB
Go
//go:build !slim
|
|
|
|
package cli
|
|
|
|
import (
|
|
"database/sql"
|
|
"testing"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
|
|
"cdr.dev/slog/v3/sloggers/slogtest"
|
|
"github.com/coder/coder/v2/aibridge"
|
|
"github.com/coder/coder/v2/coderd"
|
|
agplaibridge "github.com/coder/coder/v2/coderd/aibridge"
|
|
"github.com/coder/coder/v2/coderd/aibridged"
|
|
"github.com/coder/coder/v2/coderd/database"
|
|
"github.com/coder/coder/v2/coderd/database/dbgen"
|
|
"github.com/coder/coder/v2/coderd/database/dbtestutil"
|
|
"github.com/coder/coder/v2/codersdk"
|
|
"github.com/coder/coder/v2/testutil"
|
|
"github.com/coder/serpent"
|
|
)
|
|
|
|
// buildFromEnv replays the boot-time flow against a throwaway database. The
// env-derived config is first persisted by coderd.SeedAIProvidersFromEnv and
// then read back through BuildProviders as runtime [aibridge.Provider] values.
// Routing the TestBuildProviders table through this helper means the tests
// assert on what comes out of the database, which is the single source of
// truth after the refactor. A seeding error is returned unchanged. The
// per-row outcomes reported by BuildProviders are discarded.
func buildFromEnv(t *testing.T, cfg codersdk.AIBridgeConfig) ([]aibridge.Provider, error) {
	t.Helper()

	store, _ := dbtestutil.NewDB(t)
	ctx := testutil.Context(t, testutil.WaitShort)
	testLogger := slogtest.Make(t, nil)

	seedErr := coderd.SeedAIProvidersFromEnv(ctx, store, cfg, testLogger)
	if seedErr != nil {
		return nil, seedErr
	}

	built, _, buildErr := BuildProviders(ctx, store, cfg, testLogger)
	return built, buildErr
}
|
|
|
|
// TestBuildProviders checks the env-config -> database -> runtime provider
// flow through buildFromEnv. The trailing subtests exercise
// bedrockConfigFromRow directly, which decides from a row plus its parsed
// settings whether an Anthropic provider gets a Bedrock config.
func TestBuildProviders(t *testing.T) {
	t.Parallel()

	t.Run("EmptyConfig", func(t *testing.T) {
		t.Parallel()
		providers, err := buildFromEnv(t, codersdk.AIBridgeConfig{})
		require.NoError(t, err)
		assert.Empty(t, providers)
	})

	t.Run("LegacyOnly", func(t *testing.T) {
		t.Parallel()
		cfg := codersdk.AIBridgeConfig{}
		cfg.LegacyOpenAI.Key = serpent.String("sk-openai")
		cfg.LegacyAnthropic.Key = serpent.String("sk-anthropic")

		providers, err := buildFromEnv(t, cfg)
		require.NoError(t, err)

		names := providerNames(providers)
		assert.Contains(t, names, aibridge.ProviderOpenAI)
		assert.Contains(t, names, aibridge.ProviderAnthropic)
		assert.Len(t, names, 2)
	})

	t.Run("IndexedOnly", func(t *testing.T) {
		t.Parallel()
		cfg := codersdk.AIBridgeConfig{
			Providers: []codersdk.AIProviderConfig{
				{
					Type: aibridge.ProviderAnthropic,
					Name: "anthropic-zdr",
					Keys: []string{"sk-zdr"},
				},
				{
					Type:    aibridge.ProviderOpenAI,
					Name:    "openai-azure",
					Keys:    []string{"sk-azure"},
					BaseURL: "https://azure.openai.com",
				},
			},
		}

		providers, err := buildFromEnv(t, cfg)
		require.NoError(t, err)
		require.Len(t, providers, 2)

		byName := make(map[string]aibridge.Provider, len(providers))
		for _, p := range providers {
			byName[p.Name()] = p
		}
		require.Contains(t, byName, "anthropic-zdr")
		require.Contains(t, byName, "openai-azure")
	})

	t.Run("LegacyOpenAIConflictsWithIndexed", func(t *testing.T) {
		t.Parallel()
		cfg := codersdk.AIBridgeConfig{
			Providers: []codersdk.AIProviderConfig{
				{Type: aibridge.ProviderOpenAI, Name: aibridge.ProviderOpenAI, Keys: []string{"sk-indexed"}},
			},
		}
		cfg.LegacyOpenAI.Key = serpent.String("sk-legacy")

		_, err := buildFromEnv(t, cfg)
		require.ErrorContains(t, err, "conflicts with the legacy env var")
	})

	t.Run("LegacyAnthropicConflictsWithIndexed", func(t *testing.T) {
		t.Parallel()
		cfg := codersdk.AIBridgeConfig{
			Providers: []codersdk.AIProviderConfig{
				{Type: aibridge.ProviderAnthropic, Name: aibridge.ProviderAnthropic, Keys: []string{"sk-indexed"}},
			},
		}
		cfg.LegacyAnthropic.Key = serpent.String("sk-legacy")

		_, err := buildFromEnv(t, cfg)
		require.ErrorContains(t, err, "conflicts with the legacy env var")
	})

	t.Run("MixedLegacyAndIndexed", func(t *testing.T) {
		t.Parallel()
		cfg := codersdk.AIBridgeConfig{
			Providers: []codersdk.AIProviderConfig{
				{Type: aibridge.ProviderAnthropic, Name: "anthropic-zdr", Keys: []string{"sk-zdr"}},
			},
		}
		cfg.LegacyOpenAI.Key = serpent.String("sk-openai")
		cfg.LegacyAnthropic.Key = serpent.String("sk-anthropic")

		providers, err := buildFromEnv(t, cfg)
		require.NoError(t, err)

		names := providerNames(providers)
		assert.Contains(t, names, aibridge.ProviderOpenAI)
		assert.Contains(t, names, aibridge.ProviderAnthropic)
		assert.Contains(t, names, "anthropic-zdr")
		// Pin the count so duplicated or unexpected extra providers are
		// caught, mirroring the LegacyOnly subtest.
		assert.Len(t, names, 3, "two legacy providers plus one indexed provider")
	})

	t.Run("LegacyAnthropicWithBedrock", func(t *testing.T) {
		t.Parallel()
		cfg := codersdk.AIBridgeConfig{}
		cfg.LegacyAnthropic.Key = serpent.String("sk-anthropic")
		cfg.LegacyBedrock.Region = serpent.String("us-west-2")
		cfg.LegacyBedrock.AccessKey = serpent.String("AKID")
		cfg.LegacyBedrock.AccessKeySecret = serpent.String("secret")

		providers, err := buildFromEnv(t, cfg)
		require.NoError(t, err)

		names := providerNames(providers)
		assert.Equal(t, []string{aibridge.ProviderAnthropic}, names)
	})

	t.Run("LegacyBedrockWithoutAnthropicKey", func(t *testing.T) {
		t.Parallel()
		// Bedrock credentials alone should be enough to create an
		// Anthropic provider — no CODER_AIBRIDGE_ANTHROPIC_KEY needed.
		cfg := codersdk.AIBridgeConfig{}
		cfg.LegacyBedrock.Region = serpent.String("us-west-2")
		cfg.LegacyBedrock.AccessKey = serpent.String("AKID")
		cfg.LegacyBedrock.AccessKeySecret = serpent.String("secret")

		providers, err := buildFromEnv(t, cfg)
		require.NoError(t, err)
		require.Len(t, providers, 1)

		p := providers[0]
		assert.Equal(t, aibridge.ProviderAnthropic, p.Type())
		assert.Equal(t, aibridge.ProviderAnthropic, p.Name())
	})

	t.Run("UnknownType", func(t *testing.T) {
		t.Parallel()
		// Unknown provider types are dropped by the seed step (logged
		// and skipped) so one misconfigured row cannot stop the daemon
		// from starting. The end state is "no providers", not an error.
		cfg := codersdk.AIBridgeConfig{
			Providers: []codersdk.AIProviderConfig{
				{Type: "gemini", Name: "gemini-pro"},
			},
		}

		providers, err := buildFromEnv(t, cfg)
		require.NoError(t, err)
		assert.Empty(t, providers)
	})

	t.Run("CopilotVariants", func(t *testing.T) {
		t.Parallel()
		// Copilot providers can target any of the three GitHub
		// Copilot API hosts via an explicit BASE_URL.
		cfg := codersdk.AIBridgeConfig{
			Providers: []codersdk.AIProviderConfig{
				{Type: aibridge.ProviderCopilot, Name: aibridge.ProviderCopilot},
				{Type: aibridge.ProviderCopilot, Name: agplaibridge.ProviderCopilotBusiness, BaseURL: "https://" + agplaibridge.HostCopilotBusiness},
				{Type: aibridge.ProviderCopilot, Name: agplaibridge.ProviderCopilotEnterprise, BaseURL: "https://" + agplaibridge.HostCopilotEnterprise},
			},
		}

		providers, err := buildFromEnv(t, cfg)
		require.NoError(t, err)
		require.Len(t, providers, 3)

		byName := make(map[string]aibridge.Provider, len(providers))
		for _, p := range providers {
			byName[p.Name()] = p
		}
		require.Contains(t, byName, aibridge.ProviderCopilot)
		require.Contains(t, byName, agplaibridge.ProviderCopilotBusiness)
		require.Contains(t, byName, agplaibridge.ProviderCopilotEnterprise)
		assert.Equal(t, "https://"+agplaibridge.HostCopilotBusiness, byName[agplaibridge.ProviderCopilotBusiness].BaseURL())
		assert.Equal(t, "https://"+agplaibridge.HostCopilotEnterprise, byName[agplaibridge.ProviderCopilotEnterprise].BaseURL())
	})

	t.Run("ChatGPTProvider", func(t *testing.T) {
		t.Parallel()
		// ChatGPT is an OpenAI-compatible provider with a custom
		// base URL. Admins configure it as an indexed openai provider.
		cfg := codersdk.AIBridgeConfig{
			Providers: []codersdk.AIProviderConfig{
				{Type: aibridge.ProviderOpenAI, Name: agplaibridge.ProviderChatGPT, Keys: []string{"sk-chatgpt"}, BaseURL: agplaibridge.BaseURLChatGPT},
			},
		}

		providers, err := buildFromEnv(t, cfg)
		require.NoError(t, err)
		require.Len(t, providers, 1)

		assert.Equal(t, agplaibridge.ProviderChatGPT, providers[0].Name())
		assert.Equal(t, agplaibridge.BaseURLChatGPT, providers[0].BaseURL())
	})

	// The subtests below call bedrockConfigFromRow directly instead of
	// going through the seed/build path.

	t.Run("NativeAnthropicDefaultBaseURL", func(t *testing.T) {
		t.Parallel()
		row := database.AIProvider{
			Type:    database.AiProviderTypeAnthropic,
			Name:    aibridge.ProviderAnthropic,
			BaseUrl: "https://api.anthropic.com/",
		}
		assert.Nil(t, bedrockConfigFromRow(row, codersdk.AIProviderSettings{}))
	})

	t.Run("NativeAnthropicCustomBaseURL", func(t *testing.T) {
		t.Parallel()
		row := database.AIProvider{
			Type:    database.AiProviderTypeAnthropic,
			Name:    "anthropic-proxy",
			BaseUrl: "https://internal-proxy.example.com/anthropic/",
		}
		assert.Nil(t, bedrockConfigFromRow(row, codersdk.AIProviderSettings{}))
	})

	t.Run("BedrockSettingsPresent", func(t *testing.T) {
		t.Parallel()
		accessKey := "AKID"
		secret := "secret"
		model := "anthropic.claude-3-5-sonnet-20241022-v2:0"
		smallModel := "anthropic.claude-3-5-haiku-20241022-v1:0"
		row := database.AIProvider{
			Type:    database.AiProviderTypeAnthropic,
			Name:    "anthropic-bedrock",
			BaseUrl: "https://bedrock-runtime.us-west-2.amazonaws.com/",
		}
		settings := codersdk.AIProviderSettings{
			Bedrock: &codersdk.AIProviderBedrockSettings{
				Region:          "us-west-2",
				AccessKey:       &accessKey,
				AccessKeySecret: &secret,
				Model:           model,
				SmallFastModel:  smallModel,
			},
		}
		got := bedrockConfigFromRow(row, settings)
		require.NotNil(t, got)
		assert.Equal(t, row.BaseUrl, got.BaseURL)
		assert.Equal(t, "us-west-2", got.Region)
		assert.Equal(t, accessKey, got.AccessKey)
		assert.Equal(t, secret, got.AccessKeySecret)
		assert.Equal(t, model, got.Model)
		assert.Equal(t, smallModel, got.SmallFastModel)
	})

	t.Run("BedrockSettingsEmpty", func(t *testing.T) {
		t.Parallel()
		// A non-nil but zero-valued Bedrock settings blob should not
		// produce a Bedrock config; the provider's generic BaseUrl is
		// not a Bedrock detection signal.
		row := database.AIProvider{
			Type:    database.AiProviderTypeAnthropic,
			Name:    "anthropic-empty-bedrock",
			BaseUrl: "https://api.anthropic.com/",
		}
		settings := codersdk.AIProviderSettings{
			Bedrock: &codersdk.AIProviderBedrockSettings{},
		}
		assert.Nil(t, bedrockConfigFromRow(row, settings))
	})
}
|
|
|
|
// TestBuildProvidersSkipsBadRows exercises the skip-and-continue path
// directly. Rows whose settings blob is malformed, or that the runtime
// builder rejects (e.g. an enabled row with no keys), are classified as
// errors and excluded from the returned snapshot. BuildProviders itself
// still returns no top-level error. Disabled rows are the exception: they
// stay in the snapshot with a disabled outcome, so the bridge can answer
// requests for them with a 503 sentinel.
//
// The seed path filters most bad rows out before insert, so these tests
// bypass it and insert rows straight into the database via dbgen.
func TestBuildProvidersSkipsBadRows(t *testing.T) {
	t.Parallel()

	t.Run("CorruptSettings", func(t *testing.T) {
		t.Parallel()
		db, _ := dbtestutil.NewDB(t)
		ctx := testutil.Context(t, testutil.WaitShort)
		logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})

		dbgen.AIProvider(t, db, database.AIProvider{
			Type:     database.AiProviderTypeAnthropic,
			Name:     "anthropic-broken",
			BaseUrl:  "https://api.anthropic.com/",
			Settings: sql.NullString{String: "not-json", Valid: true},
		})

		providers, outcomes, err := BuildProviders(ctx, db, codersdk.AIBridgeConfig{}, logger)
		require.NoError(t, err)
		assert.Empty(t, providers)
		require.Len(t, outcomes, 1)
		assert.Equal(t, "anthropic-broken", outcomes[0].Name)
		assert.Equal(t, aibridged.ProviderStatusError, outcomes[0].Status)
		assert.Error(t, outcomes[0].Err)
	})

	t.Run("EnabledButNoKeys", func(t *testing.T) {
		t.Parallel()
		db, _ := dbtestutil.NewDB(t)
		ctx := testutil.Context(t, testutil.WaitShort)
		logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})

		// Azure routes through the OpenAI-family builder, which rejects
		// rows without keys when BYOK is disabled. The row must be
		// classified as error and excluded from the snapshot.
		dbgen.AIProvider(t, db, database.AIProvider{
			Type:    database.AiProviderTypeAzure,
			Name:    "azure-openai",
			BaseUrl: "https://example.openai.azure.com/",
		})

		providers, outcomes, err := BuildProviders(ctx, db, codersdk.AIBridgeConfig{}, logger)
		require.NoError(t, err)
		assert.Empty(t, providers)
		require.Len(t, outcomes, 1)
		// Same outcome shape as CorruptSettings: the error outcome must
		// name the offending row and carry the builder's error.
		assert.Equal(t, "azure-openai", outcomes[0].Name)
		assert.Equal(t, aibridged.ProviderStatusError, outcomes[0].Status)
		assert.Error(t, outcomes[0].Err)
	})

	t.Run("BadRowDoesNotBlockGoodRow", func(t *testing.T) {
		t.Parallel()
		db, _ := dbtestutil.NewDB(t)
		ctx := testutil.Context(t, testutil.WaitShort)
		logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})

		dbgen.AIProvider(t, db, database.AIProvider{
			Type:     database.AiProviderTypeAnthropic,
			Name:     "anthropic-broken",
			BaseUrl:  "https://api.anthropic.com/",
			Settings: sql.NullString{String: "{not valid json", Valid: true},
		})
		good := dbgen.AIProvider(t, db, database.AIProvider{
			Type:    database.AiProviderTypeOpenai,
			Name:    "openai-good",
			BaseUrl: "https://api.openai.com/",
		})
		dbgen.AIProviderKey(t, db, database.AIProviderKey{
			ProviderID: good.ID,
			APIKey:     "sk-good",
		})

		providers, outcomes, err := BuildProviders(ctx, db, codersdk.AIBridgeConfig{}, logger)
		require.NoError(t, err)
		require.Len(t, providers, 1)
		assert.Equal(t, "openai-good", providers[0].Name())
		require.Len(t, outcomes, 2)
		byName := make(map[string]aibridged.ProviderOutcome, len(outcomes))
		for _, o := range outcomes {
			byName[o.Name] = o
		}
		// Guard the lookups so a missing outcome fails loudly instead of
		// being compared as a zero-valued ProviderOutcome.
		require.Contains(t, byName, "anthropic-broken")
		require.Contains(t, byName, "openai-good")
		assert.Equal(t, aibridged.ProviderStatusError, byName["anthropic-broken"].Status)
		assert.Equal(t, aibridged.ProviderStatusEnabled, byName["openai-good"].Status)
	})

	t.Run("DisabledRowClassifiedAsDisabled", func(t *testing.T) {
		t.Parallel()

		for _, tc := range []struct {
			name string
			row  database.AIProvider
		}{
			{
				name: "OpenAI",
				row: database.AIProvider{
					Type:    database.AiProviderTypeOpenai,
					Name:    "openai-off",
					BaseUrl: "https://api.openai.com/",
				},
			},
			{
				// Anthropic and Bedrock have stricter credential checks
				// than the OpenAI family; the disabled short-circuit
				// must reach them too. No keys, no bedrock settings.
				name: "Anthropic",
				row: database.AIProvider{
					Type:    database.AiProviderTypeAnthropic,
					Name:    "anthropic-off",
					BaseUrl: "https://api.anthropic.com/",
				},
			},
			{
				name: "Bedrock",
				row: database.AIProvider{
					Type:    database.AiProviderTypeBedrock,
					Name:    "bedrock-off",
					BaseUrl: "https://bedrock-runtime.us-east-1.amazonaws.com/",
				},
			},
		} {
			t.Run(tc.name, func(t *testing.T) {
				t.Parallel()
				db, _ := dbtestutil.NewDB(t)
				ctx := testutil.Context(t, testutil.WaitShort)
				logger := slogtest.Make(t, nil)

				dbgen.AIProvider(t, db, tc.row, func(p *database.InsertAIProviderParams) {
					p.Enabled = false
				})

				providers, outcomes, err := BuildProviders(ctx, db, codersdk.AIBridgeConfig{}, logger)
				require.NoError(t, err)
				require.Len(t, providers, 1, "disabled providers stay in the snapshot so the bridge can serve a 503 sentinel")
				assert.Equal(t, tc.row.Name, providers[0].Name())
				assert.False(t, providers[0].Enabled())
				require.Len(t, outcomes, 1)
				assert.Equal(t, tc.row.Name, outcomes[0].Name)
				assert.Equal(t, aibridged.ProviderStatusDisabled, outcomes[0].Status)
				assert.NoError(t, outcomes[0].Err)
			})
		}
	})
}
|
|
|
|
func providerNames(providers []aibridge.Provider) []string {
|
|
names := make([]string, len(providers))
|
|
for i, p := range providers {
|
|
names[i] = p.Name()
|
|
}
|
|
return names
|
|
}
|