refactor: load AI providers from the database at startup (#25672)

Replace the env-based `BuildProviders` with a DB-backed loader. The database is now the single source of truth for runtime provider configuration; env config arrives via `SeedAIProvidersFromEnv` (run at boot) and `BuildProviders` reads it back as `aibridge.Provider` instances. `cli/server.go` and `enterprise/cli/server.go` both call the same path, so aibridged and aibridgeproxyd see the same provider set.

Per-provider `DumpDir` is replaced by a top-level `CODER_AI_GATEWAY_DUMP_DIR` base; each provider's effective dump path is `<base>/<provider name>`.
This commit is contained in:
Danny Kopping
2026-05-26 15:57:01 +02:00
committed by GitHub
parent dfd7ca3b98
commit 282ab7de34
19 changed files with 570 additions and 258 deletions
+168 -145
View File
@@ -5,15 +5,21 @@ package cli
import (
"context"
"github.com/google/uuid"
"github.com/prometheus/client_golang/prometheus"
"golang.org/x/xerrors"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/aibridge"
"github.com/coder/coder/v2/aibridge/config"
"github.com/coder/coder/v2/aibridge/keypool"
"github.com/coder/coder/v2/coderd"
"github.com/coder/coder/v2/coderd/aibridged"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/db2sdk"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/tracing"
"github.com/coder/coder/v2/coderd/util/ptr"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/quartz"
)
@@ -44,183 +50,200 @@ func newAIBridgeDaemon(coderAPI *coderd.API, providers []aibridge.Provider) (*ai
return srv, nil
}
// BuildProviders constructs the list of AI providers from config.
// It merges legacy single-provider env vars and indexed provider configs:
// 1. Legacy providers (from CODER_AI_GATEWAY_OPENAI_KEY, etc.) are added first.
// If a legacy name conflicts with an indexed provider, startup fails with
// a clear error asking the admin to remove one or the other.
// 2. Indexed providers (from CODER_AI_GATEWAY_PROVIDER_<N>_*) are added next.
func BuildProviders(cfg codersdk.AIBridgeConfig) ([]aibridge.Provider, error) {
var cbConfig *config.CircuitBreaker
if cfg.CircuitBreakerEnabled.Value() {
cbConfig = &config.CircuitBreaker{
FailureThreshold: uint32(cfg.CircuitBreakerFailureThreshold.Value()), //nolint:gosec // Validated by serpent.Validate in deployment options.
Interval: cfg.CircuitBreakerInterval.Value(),
Timeout: cfg.CircuitBreakerTimeout.Value(),
MaxRequests: uint32(cfg.CircuitBreakerMaxRequests.Value()), //nolint:gosec // Validated by serpent.Validate in deployment options.
// BuildProviders loads every enabled ai_providers row, attaches its
// keys, and constructs the equivalent [aibridge.Provider] instances.
// The database is the single source of truth for runtime provider
// configuration.
//
// Per-provider construction errors are logged and the offending row is
// excluded from the returned snapshot; only a failure of the DB query
// itself is propagated. This keeps a single misconfigured row from
// taking the whole daemon down.
func BuildProviders(ctx context.Context, db database.Store, cfg codersdk.AIBridgeConfig, logger slog.Logger) ([]aibridge.Provider, error) {
//nolint:gocritic // AsAIBridged has a minimal permission set for this purpose.
authCtx := dbauthz.AsAIBridged(ctx)
var rows []database.AIProvider
keysByProvider := make(map[uuid.UUID][]database.AIProviderKey)
// Wrap both queries in a read-only transaction so the provider list
// and the key list are consistent with each other.
err := db.InTx(func(tx database.Store) error {
var err error
rows, err = tx.GetAIProviders(authCtx, database.GetAIProvidersParams{
IncludeDisabled: false,
})
if err != nil {
return xerrors.Errorf("load ai providers: %w", err)
}
if len(rows) == 0 {
return nil
}
// Load keys only for the enabled providers to avoid materializing
// secrets for disabled rows.
ids := make([]uuid.UUID, len(rows))
for i, r := range rows {
ids[i] = r.ID
}
keyRows, err := tx.GetAIProviderKeysByProviderIDs(authCtx, ids)
if err != nil {
return xerrors.Errorf("load ai provider keys: %w", err)
}
for _, k := range keyRows {
keysByProvider[k.ProviderID] = append(keysByProvider[k.ProviderID], k)
}
return nil
}, &database.TxOptions{ReadOnly: true, TxIdentifier: "build_ai_providers"})
if err != nil {
return nil, err
}
var providers []aibridge.Provider
usedNames := make(map[string]struct{})
// Collect names from indexed providers so we can detect conflicts
// with legacy providers.
for _, p := range cfg.Providers {
name := p.Name
if name == "" {
name = p.Type
out := make([]aibridge.Provider, 0, len(rows))
for _, row := range rows {
prov, err := buildAIProviderFromRow(row, keysByProvider[row.ID], cfg)
if err != nil {
logger.Error(ctx, "skipping misconfigured ai provider",
slog.F("provider_id", row.ID),
slog.F("provider_name", row.Name),
slog.F("provider_type", string(row.Type)),
slog.Error(err),
)
continue
}
usedNames[name] = struct{}{}
out = append(out, prov)
}
// Add legacy OpenAI provider if configured.
if cfg.LegacyOpenAI.Key.String() != "" {
if _, conflict := usedNames[aibridge.ProviderOpenAI]; conflict {
return nil, xerrors.Errorf("legacy CODER_AI_GATEWAY_OPENAI_KEY (or CODER_AIBRIDGE_OPENAI_KEY) conflicts with indexed provider named %q; remove one or the other", aibridge.ProviderOpenAI)
}
providers = append(providers, aibridge.NewOpenAIProvider(aibridge.OpenAIConfig{
Name: aibridge.ProviderOpenAI,
BaseURL: cfg.LegacyOpenAI.BaseURL.String(),
Key: cfg.LegacyOpenAI.Key.String(),
CircuitBreaker: cbConfig,
SendActorHeaders: cfg.SendActorHeaders.Value(),
}))
usedNames[aibridge.ProviderOpenAI] = struct{}{}
if len(rows) > 0 && len(out) == 0 {
logger.Warn(ctx, "all enabled ai providers failed to build; daemon will start with zero providers")
}
// Add legacy Anthropic provider if configured. Bedrock credentials
// alone are sufficient, an Anthropic API key is not required when
// using AWS Bedrock.
if cfg.LegacyAnthropic.Key.String() != "" || getBedrockConfig(cfg.LegacyBedrock) != nil {
if _, conflict := usedNames[aibridge.ProviderAnthropic]; conflict {
return nil, xerrors.Errorf("legacy CODER_AI_GATEWAY_ANTHROPIC_KEY (or CODER_AIBRIDGE_ANTHROPIC_KEY) conflicts with indexed provider named %q; remove one or the other", aibridge.ProviderAnthropic)
return out, nil
}
// buildAIProviderFromRow decodes the settings blob and constructs the
// appropriate [aibridge.Provider] for a single ai_providers row.
func buildAIProviderFromRow(
row database.AIProvider,
keys []database.AIProviderKey,
cfg codersdk.AIBridgeConfig,
) (aibridge.Provider, error) {
settings, err := db2sdk.AIProviderSettings(row.Settings)
if err != nil {
return nil, xerrors.Errorf("decode settings: %w", err)
}
cbCfg := circuitBreakerConfig(cfg)
sendActorHeaders := cfg.SendActorHeaders.Value()
dumpDir := cfg.APIDumpDir.Value()
switch row.Type {
case database.AiProviderTypeOpenai:
if len(keys) == 0 && !cfg.AllowBYOK.Value() {
return nil, xerrors.New("openai provider has no api keys configured and BYOK is not enabled")
}
var pool *keypool.Pool
if key := cfg.LegacyAnthropic.Key.String(); key != "" {
if len(keys) > 0 {
var err error
pool, err = keypool.New([]string{key}, quartz.NewReal())
pool, err = buildAIProviderKeyPool(keys)
if err != nil {
return nil, xerrors.Errorf("create legacy anthropic key pool: %w", err)
return nil, xerrors.Errorf("openai key pool: %w", err)
}
}
providers = append(providers, aibridge.NewAnthropicProvider(aibridge.AnthropicConfig{
Name: aibridge.ProviderAnthropic,
BaseURL: cfg.LegacyAnthropic.BaseURL.String(),
return aibridge.NewOpenAIProvider(aibridge.OpenAIConfig{
Name: row.Name,
BaseURL: row.BaseUrl,
KeyPool: pool,
CircuitBreaker: cbConfig,
SendActorHeaders: cfg.SendActorHeaders.Value(),
}, getBedrockConfig(cfg.LegacyBedrock)))
usedNames[aibridge.ProviderAnthropic] = struct{}{}
}
APIDumpDir: dumpDir,
CircuitBreaker: cbCfg,
SendActorHeaders: sendActorHeaders,
}), nil
// Add indexed providers.
for _, p := range cfg.Providers {
name := p.Name
if name == "" {
name = p.Type
case database.AiProviderTypeAnthropic:
bedrock := bedrockConfigFromRow(row, settings)
// Bedrock-backed Anthropic authenticates via AWS credentials in
// the settings blob, not the api_keys table. A bearer-token
// Anthropic without any key cannot make upstream calls.
if bedrock == nil && len(keys) == 0 && !cfg.AllowBYOK.Value() {
return nil, xerrors.New("anthropic provider has no api keys, no bedrock credentials, and BYOK is not enabled")
}
switch p.Type {
case aibridge.ProviderOpenAI:
var pool *keypool.Pool
if len(p.Keys) > 0 {
var err error
pool, err = keypool.New(p.Keys, quartz.NewReal())
if err != nil {
return nil, xerrors.Errorf("create openai key pool for provider %q: %w", name, err)
}
var pool *keypool.Pool
if len(keys) > 0 {
var err error
pool, err = buildAIProviderKeyPool(keys)
if err != nil {
return nil, xerrors.Errorf("anthropic key pool: %w", err)
}
providers = append(providers, aibridge.NewOpenAIProvider(aibridge.OpenAIConfig{
Name: name,
BaseURL: p.BaseURL,
KeyPool: pool,
APIDumpDir: p.DumpDir,
CircuitBreaker: cbConfig,
SendActorHeaders: cfg.SendActorHeaders.Value(),
}))
case aibridge.ProviderAnthropic:
var pool *keypool.Pool
if len(p.Keys) > 0 {
var err error
pool, err = keypool.New(p.Keys, quartz.NewReal())
if err != nil {
return nil, xerrors.Errorf("create anthropic key pool for provider %q: %w", name, err)
}
}
providers = append(providers, aibridge.NewAnthropicProvider(aibridge.AnthropicConfig{
Name: name,
BaseURL: p.BaseURL,
KeyPool: pool,
APIDumpDir: p.DumpDir,
CircuitBreaker: cbConfig,
SendActorHeaders: cfg.SendActorHeaders.Value(),
}, bedrockConfigFromProvider(p)))
case aibridge.ProviderCopilot:
providers = append(providers, aibridge.NewCopilotProvider(aibridge.CopilotConfig{
Name: name,
BaseURL: p.BaseURL,
APIDumpDir: p.DumpDir,
CircuitBreaker: cbConfig,
}))
default:
return nil, xerrors.Errorf("unknown provider type %q for provider %q", p.Type, name)
}
}
return aibridge.NewAnthropicProvider(aibridge.AnthropicConfig{
Name: row.Name,
BaseURL: row.BaseUrl,
KeyPool: pool,
APIDumpDir: dumpDir,
CircuitBreaker: cbCfg,
SendActorHeaders: sendActorHeaders,
}, bedrock), nil
return providers, nil
case database.AiProviderTypeCopilot:
// Copilot is always BYOK; the per-user token is supplied on each
// request via the Authorization header, so no keypool is built.
return aibridge.NewCopilotProvider(aibridge.CopilotConfig{
Name: row.Name,
BaseURL: row.BaseUrl,
APIDumpDir: dumpDir,
CircuitBreaker: cbCfg,
}), nil
default:
return nil, xerrors.Errorf("unsupported provider type: %q", row.Type)
}
}
// bedrockConfigFromProvider converts Bedrock fields from an indexed
// AIProviderConfig into an aibridge AWSBedrockConfig.
// Returns nil if no Bedrock fields are set.
func bedrockConfigFromProvider(p codersdk.AIProviderConfig) *aibridge.AWSBedrockConfig {
// Currently, only the first key pair is used, if any.
// TODO(ssncferreira): pass a keypool.Pool instead.
var accessKey, accessKeySecret string
if len(p.BedrockAccessKeys) > 0 {
accessKey = p.BedrockAccessKeys[0]
// buildAIProviderKeyPool builds a [keypool.Pool]. Callers must check
// len(keys) > 0 first; keypool.New rejects empty input.
func buildAIProviderKeyPool(keys []database.AIProviderKey) (*keypool.Pool, error) {
raw := make([]string, 0, len(keys))
for _, k := range keys {
raw = append(raw, k.APIKey)
}
if len(p.BedrockAccessKeySecrets) > 0 {
accessKeySecret = p.BedrockAccessKeySecrets[0]
}
settings := codersdk.NewAIProviderBedrockSettings(
p.BedrockRegion, accessKey, accessKeySecret,
p.BedrockModel, p.BedrockSmallFastModel,
)
if !codersdk.IsBedrockConfigured(p.BedrockBaseURL, settings) {
return keypool.New(raw, quartz.NewReal())
}
// bedrockConfigFromRow returns nil when the settings have no Bedrock
// discriminator or when the Bedrock fields are not actually configured.
// The provider row's BaseUrl is the generic upstream endpoint and is
// always non-empty, so it cannot serve as a Bedrock detection signal;
// gate on the settings blob alone via [codersdk.AIProviderBedrockSettings.IsConfigured].
func bedrockConfigFromRow(row database.AIProvider, settings codersdk.AIProviderSettings) *aibridge.AWSBedrockConfig {
if settings.Bedrock == nil {
return nil
}
bedrockSettings := *settings.Bedrock
if !bedrockSettings.IsConfigured() {
return nil
}
accessKey := ptr.NilToEmpty(bedrockSettings.AccessKey)
accessKeySecret := ptr.NilToEmpty(bedrockSettings.AccessKeySecret)
return &aibridge.AWSBedrockConfig{
BaseURL: p.BedrockBaseURL,
Region: p.BedrockRegion,
BaseURL: row.BaseUrl,
Region: bedrockSettings.Region,
AccessKey: accessKey,
AccessKeySecret: accessKeySecret,
Model: p.BedrockModel,
SmallFastModel: p.BedrockSmallFastModel,
Model: bedrockSettings.Model,
SmallFastModel: bedrockSettings.SmallFastModel,
}
}
func getBedrockConfig(cfg codersdk.AIBridgeBedrockConfig) *aibridge.AWSBedrockConfig {
// codersdk.IsBedrockConfigured decides what counts as Bedrock; when
// it returns false, the AWS SDK default credential chain (env vars,
// shared config, IAM roles, etc.) is left to resolve credentials.
settings := codersdk.NewAIProviderBedrockSettings(
cfg.Region.String(),
cfg.AccessKey.String(),
cfg.AccessKeySecret.String(),
cfg.Model.String(),
cfg.SmallFastModel.String(),
)
if !codersdk.IsBedrockConfigured(cfg.BaseURL.String(), settings) {
// circuitBreakerConfig returns nil when the breaker is disabled.
func circuitBreakerConfig(cfg codersdk.AIBridgeConfig) *config.CircuitBreaker {
if !cfg.CircuitBreakerEnabled.Value() {
return nil
}
return &aibridge.AWSBedrockConfig{
BaseURL: cfg.BaseURL.String(),
Region: cfg.Region.String(),
AccessKey: cfg.AccessKey.String(),
AccessKeySecret: cfg.AccessKeySecret.String(),
Model: cfg.Model.String(),
SmallFastModel: cfg.SmallFastModel.String(),
return &config.CircuitBreaker{
FailureThreshold: uint32(cfg.CircuitBreakerFailureThreshold.Value()), //nolint:gosec // Validated by serpent.Validate in deployment options.
Interval: cfg.CircuitBreakerInterval.Value(),
Timeout: cfg.CircuitBreakerTimeout.Value(),
MaxRequests: uint32(cfg.CircuitBreakerMaxRequests.Value()), //nolint:gosec // Validated by serpent.Validate in deployment options.
}
}
+204 -32
View File
@@ -3,23 +3,47 @@
package cli
import (
"database/sql"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/coder/v2/aibridge"
"github.com/coder/coder/v2/coderd"
agplaibridge "github.com/coder/coder/v2/coderd/aibridge"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbgen"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/testutil"
"github.com/coder/serpent"
)
// buildFromEnv exercises the same env-config-in/providers-out path that
// production uses on boot: SeedAIProvidersFromEnv writes the env-derived
// rows to the database, and BuildProviders reads them back as runtime
// [aibridge.Provider] instances. This keeps the existing TestBuildProviders
// table intact while reflecting the post-refactor flow where the database
// is the single source of truth.
func buildFromEnv(t *testing.T, cfg codersdk.AIBridgeConfig) ([]aibridge.Provider, error) {
t.Helper()
db, _ := dbtestutil.NewDB(t)
ctx := testutil.Context(t, testutil.WaitShort)
logger := slogtest.Make(t, nil)
if err := coderd.SeedAIProvidersFromEnv(ctx, db, cfg, logger); err != nil {
return nil, err
}
return BuildProviders(ctx, db, cfg, logger)
}
func TestBuildProviders(t *testing.T) {
t.Parallel()
t.Run("EmptyConfig", func(t *testing.T) {
t.Parallel()
providers, err := BuildProviders(codersdk.AIBridgeConfig{})
providers, err := buildFromEnv(t, codersdk.AIBridgeConfig{})
require.NoError(t, err)
assert.Empty(t, providers)
})
@@ -30,7 +54,7 @@ func TestBuildProviders(t *testing.T) {
cfg.LegacyOpenAI.Key = serpent.String("sk-openai")
cfg.LegacyAnthropic.Key = serpent.String("sk-anthropic")
providers, err := BuildProviders(cfg)
providers, err := buildFromEnv(t, cfg)
require.NoError(t, err)
names := providerNames(providers)
@@ -44,28 +68,29 @@ func TestBuildProviders(t *testing.T) {
cfg := codersdk.AIBridgeConfig{
Providers: []codersdk.AIProviderConfig{
{
Type: aibridge.ProviderAnthropic,
Name: "anthropic-zdr",
Keys: []string{"sk-zdr"},
DumpDir: "/tmp/anthropic-dump",
Type: aibridge.ProviderAnthropic,
Name: "anthropic-zdr",
Keys: []string{"sk-zdr"},
},
{
Type: aibridge.ProviderOpenAI,
Name: "openai-azure",
Keys: []string{"sk-azure"},
BaseURL: "https://azure.openai.com",
DumpDir: "/tmp/openai-dump",
},
},
}
providers, err := BuildProviders(cfg)
providers, err := buildFromEnv(t, cfg)
require.NoError(t, err)
require.Len(t, providers, 2)
names := providerNames(providers)
assert.Equal(t, []string{"anthropic-zdr", "openai-azure"}, names)
assert.Equal(t, "/tmp/anthropic-dump", providers[0].APIDumpDir())
assert.Equal(t, "/tmp/openai-dump", providers[1].APIDumpDir())
byName := make(map[string]aibridge.Provider, len(providers))
for _, p := range providers {
byName[p.Name()] = p
}
require.Contains(t, byName, "anthropic-zdr")
require.Contains(t, byName, "openai-azure")
})
t.Run("LegacyOpenAIConflictsWithIndexed", func(t *testing.T) {
@@ -77,9 +102,9 @@ func TestBuildProviders(t *testing.T) {
}
cfg.LegacyOpenAI.Key = serpent.String("sk-legacy")
_, err := BuildProviders(cfg)
_, err := buildFromEnv(t, cfg)
require.Error(t, err)
assert.Contains(t, err.Error(), "conflicts with indexed provider")
assert.Contains(t, err.Error(), "conflicts with the legacy env var")
})
t.Run("LegacyAnthropicConflictsWithIndexed", func(t *testing.T) {
@@ -91,9 +116,9 @@ func TestBuildProviders(t *testing.T) {
}
cfg.LegacyAnthropic.Key = serpent.String("sk-legacy")
_, err := BuildProviders(cfg)
_, err := buildFromEnv(t, cfg)
require.Error(t, err)
assert.Contains(t, err.Error(), "conflicts with indexed provider")
assert.Contains(t, err.Error(), "conflicts with the legacy env var")
})
t.Run("MixedLegacyAndIndexed", func(t *testing.T) {
@@ -106,7 +131,7 @@ func TestBuildProviders(t *testing.T) {
cfg.LegacyOpenAI.Key = serpent.String("sk-openai")
cfg.LegacyAnthropic.Key = serpent.String("sk-anthropic")
providers, err := BuildProviders(cfg)
providers, err := buildFromEnv(t, cfg)
require.NoError(t, err)
names := providerNames(providers)
@@ -123,7 +148,7 @@ func TestBuildProviders(t *testing.T) {
cfg.LegacyBedrock.AccessKey = serpent.String("AKID")
cfg.LegacyBedrock.AccessKeySecret = serpent.String("secret")
providers, err := BuildProviders(cfg)
providers, err := buildFromEnv(t, cfg)
require.NoError(t, err)
names := providerNames(providers)
@@ -139,7 +164,7 @@ func TestBuildProviders(t *testing.T) {
cfg.LegacyBedrock.AccessKey = serpent.String("AKID")
cfg.LegacyBedrock.AccessKeySecret = serpent.String("secret")
providers, err := BuildProviders(cfg)
providers, err := buildFromEnv(t, cfg)
require.NoError(t, err)
require.Len(t, providers, 1)
@@ -150,15 +175,18 @@ func TestBuildProviders(t *testing.T) {
t.Run("UnknownType", func(t *testing.T) {
t.Parallel()
// Unknown provider types are dropped by the seed step (logged
// and skipped) so one misconfigured row cannot stop the daemon
// from starting. The end state is "no providers", not an error.
cfg := codersdk.AIBridgeConfig{
Providers: []codersdk.AIProviderConfig{
{Type: "gemini", Name: "gemini-pro"},
},
}
_, err := BuildProviders(cfg)
require.Error(t, err)
assert.Contains(t, err.Error(), "unknown provider type")
providers, err := buildFromEnv(t, cfg)
require.NoError(t, err)
assert.Empty(t, providers)
})
t.Run("CopilotVariants", func(t *testing.T) {
@@ -167,22 +195,25 @@ func TestBuildProviders(t *testing.T) {
// Copilot API hosts via an explicit BASE_URL.
cfg := codersdk.AIBridgeConfig{
Providers: []codersdk.AIProviderConfig{
{Type: aibridge.ProviderCopilot, Name: aibridge.ProviderCopilot, DumpDir: "/tmp/copilot-dump"},
{Type: aibridge.ProviderCopilot, Name: aibridge.ProviderCopilot},
{Type: aibridge.ProviderCopilot, Name: agplaibridge.ProviderCopilotBusiness, BaseURL: "https://" + agplaibridge.HostCopilotBusiness},
{Type: aibridge.ProviderCopilot, Name: agplaibridge.ProviderCopilotEnterprise, BaseURL: "https://" + agplaibridge.HostCopilotEnterprise},
},
}
providers, err := BuildProviders(cfg)
providers, err := buildFromEnv(t, cfg)
require.NoError(t, err)
require.Len(t, providers, 3)
assert.Equal(t, aibridge.ProviderCopilot, providers[0].Name())
assert.Equal(t, "/tmp/copilot-dump", providers[0].APIDumpDir())
assert.Equal(t, agplaibridge.ProviderCopilotBusiness, providers[1].Name())
assert.Equal(t, "https://"+agplaibridge.HostCopilotBusiness, providers[1].BaseURL())
assert.Equal(t, agplaibridge.ProviderCopilotEnterprise, providers[2].Name())
assert.Equal(t, "https://"+agplaibridge.HostCopilotEnterprise, providers[2].BaseURL())
byName := make(map[string]aibridge.Provider, len(providers))
for _, p := range providers {
byName[p.Name()] = p
}
require.Contains(t, byName, aibridge.ProviderCopilot)
require.Contains(t, byName, agplaibridge.ProviderCopilotBusiness)
require.Contains(t, byName, agplaibridge.ProviderCopilotEnterprise)
assert.Equal(t, "https://"+agplaibridge.HostCopilotBusiness, byName[agplaibridge.ProviderCopilotBusiness].BaseURL())
assert.Equal(t, "https://"+agplaibridge.HostCopilotEnterprise, byName[agplaibridge.ProviderCopilotEnterprise].BaseURL())
})
t.Run("ChatGPTProvider", func(t *testing.T) {
@@ -191,17 +222,158 @@ func TestBuildProviders(t *testing.T) {
// base URL. Admins configure it as an indexed openai provider.
cfg := codersdk.AIBridgeConfig{
Providers: []codersdk.AIProviderConfig{
{Type: aibridge.ProviderOpenAI, Name: agplaibridge.ProviderChatGPT, BaseURL: agplaibridge.BaseURLChatGPT},
{Type: aibridge.ProviderOpenAI, Name: agplaibridge.ProviderChatGPT, Keys: []string{"sk-chatgpt"}, BaseURL: agplaibridge.BaseURLChatGPT},
},
}
providers, err := BuildProviders(cfg)
providers, err := buildFromEnv(t, cfg)
require.NoError(t, err)
require.Len(t, providers, 1)
assert.Equal(t, agplaibridge.ProviderChatGPT, providers[0].Name())
assert.Equal(t, agplaibridge.BaseURLChatGPT, providers[0].BaseURL())
})
t.Run("NativeAnthropicDefaultBaseURL", func(t *testing.T) {
t.Parallel()
row := database.AIProvider{
Type: database.AiProviderTypeAnthropic,
Name: aibridge.ProviderAnthropic,
BaseUrl: "https://api.anthropic.com/",
}
assert.Nil(t, bedrockConfigFromRow(row, codersdk.AIProviderSettings{}))
})
t.Run("NativeAnthropicCustomBaseURL", func(t *testing.T) {
t.Parallel()
row := database.AIProvider{
Type: database.AiProviderTypeAnthropic,
Name: "anthropic-proxy",
BaseUrl: "https://internal-proxy.example.com/anthropic/",
}
assert.Nil(t, bedrockConfigFromRow(row, codersdk.AIProviderSettings{}))
})
t.Run("BedrockSettingsPresent", func(t *testing.T) {
t.Parallel()
accessKey := "AKID"
secret := "secret"
model := "anthropic.claude-3-5-sonnet-20241022-v2:0"
smallModel := "anthropic.claude-3-5-haiku-20241022-v1:0"
row := database.AIProvider{
Type: database.AiProviderTypeAnthropic,
Name: "anthropic-bedrock",
BaseUrl: "https://bedrock-runtime.us-west-2.amazonaws.com/",
}
settings := codersdk.AIProviderSettings{
Bedrock: &codersdk.AIProviderBedrockSettings{
Region: "us-west-2",
AccessKey: &accessKey,
AccessKeySecret: &secret,
Model: model,
SmallFastModel: smallModel,
},
}
got := bedrockConfigFromRow(row, settings)
require.NotNil(t, got)
assert.Equal(t, row.BaseUrl, got.BaseURL)
assert.Equal(t, "us-west-2", got.Region)
assert.Equal(t, accessKey, got.AccessKey)
assert.Equal(t, secret, got.AccessKeySecret)
assert.Equal(t, model, got.Model)
assert.Equal(t, smallModel, got.SmallFastModel)
})
t.Run("BedrockSettingsEmpty", func(t *testing.T) {
t.Parallel()
// A non-nil but zero-valued Bedrock settings blob should not
// produce a Bedrock config; the provider's generic BaseUrl is
// not a Bedrock detection signal.
row := database.AIProvider{
Type: database.AiProviderTypeAnthropic,
Name: "anthropic-empty-bedrock",
BaseUrl: "https://api.anthropic.com/",
}
settings := codersdk.AIProviderSettings{
Bedrock: &codersdk.AIProviderBedrockSettings{},
}
assert.Nil(t, bedrockConfigFromRow(row, settings))
})
}
// TestBuildProvidersSkipsBadRows exercises the skip-and-continue path
// directly: rows whose settings blob is malformed or whose type is not
// supported by the runtime builder are logged and excluded from the
// returned snapshot without surfacing a top-level error. The seed path
// filters most of these out before insert, so we bypass it and insert
// rows straight into the database via dbgen.
func TestBuildProvidersSkipsBadRows(t *testing.T) {
t.Parallel()
t.Run("CorruptSettings", func(t *testing.T) {
t.Parallel()
db, _ := dbtestutil.NewDB(t)
ctx := testutil.Context(t, testutil.WaitShort)
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
dbgen.AIProvider(t, db, database.AIProvider{
Type: database.AiProviderTypeAnthropic,
Name: "anthropic-broken",
BaseUrl: "https://api.anthropic.com/",
Settings: sql.NullString{String: "not-json", Valid: true},
})
providers, err := BuildProviders(ctx, db, codersdk.AIBridgeConfig{}, logger)
require.NoError(t, err)
assert.Empty(t, providers)
})
t.Run("UnsupportedType", func(t *testing.T) {
t.Parallel()
db, _ := dbtestutil.NewDB(t)
ctx := testutil.Context(t, testutil.WaitShort)
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
// Azure is a valid DB-level provider type but has no runtime
// builder yet; it must hit the default branch and be skipped.
dbgen.AIProvider(t, db, database.AIProvider{
Type: database.AiProviderTypeAzure,
Name: "azure-openai",
BaseUrl: "https://example.openai.azure.com/",
})
providers, err := BuildProviders(ctx, db, codersdk.AIBridgeConfig{}, logger)
require.NoError(t, err)
assert.Empty(t, providers)
})
t.Run("BadRowDoesNotBlockGoodRow", func(t *testing.T) {
t.Parallel()
db, _ := dbtestutil.NewDB(t)
ctx := testutil.Context(t, testutil.WaitShort)
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
dbgen.AIProvider(t, db, database.AIProvider{
Type: database.AiProviderTypeAnthropic,
Name: "anthropic-broken",
BaseUrl: "https://api.anthropic.com/",
Settings: sql.NullString{String: "{not valid json", Valid: true},
})
good := dbgen.AIProvider(t, db, database.AIProvider{
Type: database.AiProviderTypeOpenai,
Name: "openai-good",
BaseUrl: "https://api.openai.com/",
})
dbgen.AIProviderKey(t, db, database.AIProviderKey{
ProviderID: good.ID,
APIKey: "sk-good",
})
providers, err := BuildProviders(ctx, db, codersdk.AIBridgeConfig{}, logger)
require.NoError(t, err)
require.Len(t, providers, 1)
assert.Equal(t, "openai-good", providers[0].Name())
})
}
func providerNames(providers []aibridge.Provider) []string {
+27 -19
View File
@@ -899,6 +899,32 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
if err != nil {
return xerrors.Errorf("remove secrets from deployment values: %w", err)
}
// AI provider DB initialization runs synchronously here so
// authorized reads complete before any background goroutine
// starts. Otherwise a mid-startup cancellation can interrupt
// them and fail startup. Seeding must also happen before
// newAPI so the aibridgeproxyd in the enterprise closure
// observes env-configured providers.
//
// This is a once-off operation; once completed, all providers
// will be sourced from the database.
if err := coderd.SeedAIProvidersFromEnv(
ctx,
options.Database,
vals.AI.BridgeConfig,
logger.Named("aibridge.envseed"),
); err != nil {
return xerrors.Errorf("seed ai providers from env: %w", err)
}
var aibridgeProviders []aibridge.Provider
if vals.AI.BridgeConfig.Enabled.Value() {
aibridgeProviders, err = BuildProviders(ctx, options.Database, vals.AI.BridgeConfig, logger.Named("aibridge.providers"))
if err != nil {
return xerrors.Errorf("build AI providers: %w", err)
}
}
telemetryReporter, err := telemetry.New(telemetry.Options{
Disabled: !vals.Telemetry.Enable.Value(),
BuiltinPostgres: builtinPostgres,
@@ -1006,18 +1032,6 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
notificationReportGenerator := reports.NewReportGenerator(ctx, logger.Named("notifications.report_generator"), options.Database, options.NotificationsEnqueuer, quartz.NewReal())
defer notificationReportGenerator.Close()
// Seed providers before newAPI so the aibridgeproxyd inside
// the enterprise closure observes env-configured providers
// at init.
if err := coderd.SeedAIProvidersFromEnv(
ctx,
options.Database,
vals.AI.BridgeConfig,
logger.Named("aibridge.envseed"),
); err != nil {
return xerrors.Errorf("seed ai providers from env: %w", err)
}
// We use a separate coderAPICloser so the Enterprise API
// can have its own close functions. This is cleaner
// than abstracting the Coder API itself.
@@ -1034,11 +1048,7 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
// unconditionally when the bridge feature is enabled by config so
// chatd can use it regardless of license entitlement.
if vals.AI.BridgeConfig.Enabled.Value() {
providers, err := BuildProviders(vals.AI.BridgeConfig)
if err != nil {
return xerrors.Errorf("build AI providers: %w", err)
}
aibridgeDaemon, err := newAIBridgeDaemon(coderAPI, providers)
aibridgeDaemon, err := newAIBridgeDaemon(coderAPI, aibridgeProviders)
if err != nil {
return xerrors.Errorf("create aibridged: %w", err)
}
@@ -3114,8 +3124,6 @@ func readAIProvidersForPrefix(logger slog.Logger, environ []string, prefix strin
}
case "BASE_URL":
provider.BaseURL = v.Value
case "DUMP_DIR":
provider.DumpDir = v.Value
case "BEDROCK_BASE_URL":
provider.BedrockBaseURL = v.Value
case "BEDROCK_REGION":
+51 -2
View File
@@ -10,8 +10,10 @@ import (
"cdr.dev/slog/v3"
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/coder/v2/aibridge"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/testutil"
"github.com/coder/serpent"
)
func TestReadAIProvidersFromEnv(t *testing.T) {
@@ -34,7 +36,6 @@ func TestReadAIProvidersFromEnv(t *testing.T) {
"CODER_AIBRIDGE_PROVIDER_0_NAME=anthropic-zdr",
"CODER_AIBRIDGE_PROVIDER_0_KEY=sk-ant-xxx",
"CODER_AIBRIDGE_PROVIDER_0_BASE_URL=https://api.anthropic.com/",
"CODER_AIBRIDGE_PROVIDER_0_DUMP_DIR=/tmp/aibridge-dump",
},
expected: []codersdk.AIProviderConfig{
{
@@ -42,7 +43,6 @@ func TestReadAIProvidersFromEnv(t *testing.T) {
Name: "anthropic-zdr",
Keys: []string{"sk-ant-xxx"},
BaseURL: "https://api.anthropic.com/",
DumpDir: "/tmp/aibridge-dump",
},
},
},
@@ -537,3 +537,52 @@ func TestValidateLegacyAIBridgeConfig(t *testing.T) {
})
}
}
func TestBuildAIProviderFromRowSetsAPIDumpDir(t *testing.T) {
t.Parallel()
const dumpDir = "/tmp/coder-aibridge-dumps"
tests := []struct {
name string
row database.AIProvider
}{
{
name: "OpenAI",
row: database.AIProvider{
Type: database.AiProviderTypeOpenai,
Name: "openai",
BaseUrl: "https://api.openai.com/",
},
},
{
name: "Anthropic",
row: database.AIProvider{
Type: database.AiProviderTypeAnthropic,
Name: "anthropic",
BaseUrl: "https://api.anthropic.com/",
},
},
{
name: "Copilot",
row: database.AIProvider{
Type: database.AiProviderTypeCopilot,
Name: "copilot",
BaseUrl: "https://api.githubcopilot.com/",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
provider, err := buildAIProviderFromRow(tt.row, nil, codersdk.AIBridgeConfig{
AllowBYOK: serpent.Bool(true),
APIDumpDir: serpent.String(dumpDir),
})
require.NoError(t, err)
assert.Equal(t, dumpDir, provider.APIDumpDir())
})
}
}
+6
View File
@@ -113,6 +113,12 @@ AI GATEWAY OPTIONS:
with AI budgets. "highest" selects the group with the largest spend
limit, and is currently the only supported value.
--ai-gateway-dump-dir string, $CODER_AI_GATEWAY_DUMP_DIR
Base directory for dumping AI Bridge request/response pairs to disk
for debugging. When set, each provider writes under a subdirectory
named after the provider. Sensitive headers are redacted. Leave empty
to disable.
--ai-gateway-allow-byok bool, $CODER_AI_GATEWAY_ALLOW_BYOK (default: true)
Allow users to provide their own LLM API keys or subscriptions. When
disabled, only centralized key authentication is permitted.
+5
View File
@@ -920,6 +920,11 @@ ai_gateway:
# X-Ai-Bridge-Actor-Metadata-Username (their username).
# (default: false, type: bool)
send_actor_headers: false
# Base directory for dumping AI Bridge request/response pairs to disk for
# debugging. When set, each provider writes under a subdirectory named after the
# provider. Sensitive headers are redacted. Leave empty to disable.
# (default: <unset>, type: string)
api_dump_dir: ""
# Allow users to provide their own LLM API keys or subscriptions. When disabled,
# only centralized key authentication is permitted.
# (default: true, type: bool)