From ef6ee2af68b24e3a30ec4f242add6aed7af16065 Mon Sep 17 00:00:00 2001 From: Danny Kopping Date: Fri, 22 May 2026 12:45:14 +0200 Subject: [PATCH] chore: tolerate empty providers at startup and log env seeds (#25605) Since AI Gateway is now enabled by default, and if the AI Gateway Proxy is enabled too it's possible the server can start without any configured providers. This would previously block startup, which is unacceptable. In an upstack PR we will handle reloading the providers at runtime, so the server needs to be able to start up even if it can't handle any proxy requests to AI Gateway. This change was necessitated because if there are providers configured in the environment they need to be seeded _before_ the proxy starts. --- cli/server.go | 24 ++--- coderd/ai_providers_migrate.go | 30 +++--- coderd/ai_providers_migrate_test.go | 92 +++++++++---------- enterprise/aibridgeproxyd/aibridgeproxyd.go | 16 ++-- .../aibridgeproxyd/aibridgeproxyd_test.go | 32 ++++--- 5 files changed, 91 insertions(+), 103 deletions(-) diff --git a/cli/server.go b/cli/server.go index c0f3d7db47..1b2350d931 100644 --- a/cli/server.go +++ b/cli/server.go @@ -1006,6 +1006,18 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd. notificationReportGenerator := reports.NewReportGenerator(ctx, logger.Named("notifications.report_generator"), options.Database, options.NotificationsEnqueuer, quartz.NewReal()) defer notificationReportGenerator.Close() + // Seed providers before newAPI so the aibridgeproxyd inside + // the enterprise closure observes env-configured providers + // at init. + if err := coderd.SeedAIProvidersFromEnv( + ctx, + options.Database, + vals.AI.BridgeConfig, + logger.Named("aibridge.envseed"), + ); err != nil { + return xerrors.Errorf("seed ai providers from env: %w", err) + } + // We use a separate coderAPICloser so the Enterprise API // can have its own close functions. This is cleaner // than abstracting the Coder API itself. @@ -1014,18 +1026,6 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd. return xerrors.Errorf("create coder API: %w", err) } - // Runs unconditionally so operators can seed providers via - // env without enabling the bridge or proxy features. - if err := coderd.SeedAIProvidersFromEnv( - ctx, - options.Database, - vals.AI.BridgeConfig, - options.Auditor, - logger.Named("aibridge.envseed"), - ); err != nil { - return xerrors.Errorf("seed ai providers from env: %w", err) - } - // In-memory aibridge daemon. Registered on coderd so chatd can // dispatch LLM requests via the in-process transport without // crossing the gated /api/v2/aibridge HTTP route. The HTTP route diff --git a/coderd/ai_providers_migrate.go b/coderd/ai_providers_migrate.go index 783c062986..aedd8855e7 100644 --- a/coderd/ai_providers_migrate.go +++ b/coderd/ai_providers_migrate.go @@ -15,7 +15,6 @@ import ( "cdr.dev/slog/v3" "github.com/coder/coder/v2/aibridge" aibridgeutils "github.com/coder/coder/v2/aibridge/utils" - "github.com/coder/coder/v2/coderd/audit" "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/db2sdk" "github.com/coder/coder/v2/coderd/database/dbauthz" @@ -45,7 +44,6 @@ func SeedAIProvidersFromEnv( ctx context.Context, db database.Store, cfg codersdk.AIBridgeConfig, - auditor audit.Auditor, logger slog.Logger, ) error { desired, err := providersFromEnv(ctx, cfg, logger) @@ -178,25 +176,19 @@ func SeedAIProvidersFromEnv( } for _, row := range insertedProviders { - audit.BackgroundAudit(sysCtx, &audit.BackgroundAuditParams[database.AIProvider]{ - Audit: auditor, - Log: logger, - Action: database.AuditActionCreate, - New: row, - }) + logger.Info(sysCtx, "env-seeded ai provider", + slog.F("provider_id", row.ID), + slog.F("name", row.Name), + slog.F("type", row.Type), + slog.F("base_url", row.BaseUrl), + ) } for _, keyRow := range insertedKeys { - // Mask the plaintext key before it enters the audit pipeline; - // the audit policy on api_key relies on the masked rendering - // so plaintext never reaches a backend. - auditRow := keyRow - auditRow.APIKey = aibridgeutils.MaskSecret(auditRow.APIKey) - audit.BackgroundAudit(sysCtx, &audit.BackgroundAuditParams[database.AIProviderKey]{ - Audit: auditor, - Log: logger, - Action: database.AuditActionCreate, - New: auditRow, - }) + logger.Info(sysCtx, "env-seeded ai provider key", + slog.F("key_id", keyRow.ID), + slog.F("provider_id", keyRow.ProviderID), + slog.F("api_key", aibridgeutils.MaskSecret(keyRow.APIKey)), + ) } return nil } diff --git a/coderd/ai_providers_migrate_test.go b/coderd/ai_providers_migrate_test.go index 2599b6dffb..a989d126dc 100644 --- a/coderd/ai_providers_migrate_test.go +++ b/coderd/ai_providers_migrate_test.go @@ -1,14 +1,15 @@ package coderd_test import ( + "bytes" "testing" "github.com/stretchr/testify/require" "cdr.dev/slog/v3" + "cdr.dev/slog/v3/sloggers/sloghuman" "cdr.dev/slog/v3/sloggers/slogtest" "github.com/coder/coder/v2/coderd" - "github.com/coder/coder/v2/coderd/audit" "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/dbtestutil" "github.com/coder/coder/v2/codersdk" @@ -23,17 +24,14 @@ func TestSeedAIProvidersFromEnv(t *testing.T) { t.Parallel() db, _ := dbtestutil.NewDB(t) ctx := testutil.Context(t, testutil.WaitShort) - auditor := audit.NewMock() - err := coderd.SeedAIProvidersFromEnv(ctx, db, codersdk.AIBridgeConfig{}, auditor, testLogger(t)) + err := coderd.SeedAIProvidersFromEnv(ctx, db, codersdk.AIBridgeConfig{}, testLogger(t)) require.NoError(t, err) - require.Empty(t, auditor.AuditLogs()) }) t.Run("LegacyOpenAI", func(t *testing.T) { t.Parallel() db, _ := dbtestutil.NewDB(t) ctx := testutil.Context(t, testutil.WaitShort) - auditor := audit.NewMock() cfg := codersdk.AIBridgeConfig{ LegacyOpenAI: codersdk.AIBridgeOpenAIConfig{ @@ -41,7 +39,8 @@ func TestSeedAIProvidersFromEnv(t *testing.T) { Key: serpent.String("sk-legacy"), }, } - err := coderd.SeedAIProvidersFromEnv(ctx, db, cfg, auditor, testLogger(t)) + var firstSeedLogs bytes.Buffer + err := coderd.SeedAIProvidersFromEnv(ctx, db, cfg, capturedLogger(&firstSeedLogs)) require.NoError(t, err) // One row exists for "openai". @@ -57,16 +56,18 @@ func TestSeedAIProvidersFromEnv(t *testing.T) { require.Len(t, keys, 1) require.Equal(t, "sk-legacy", keys[0].APIKey) - // The first seed must have emitted audit entries for the new - // provider row and the key row. - require.GreaterOrEqual(t, len(auditor.AuditLogs()), 2) + // The seed emits one info line per inserted provider and one per + // inserted key, replacing the audit entries that used to record + // the same events. + require.Contains(t, firstSeedLogs.String(), "env-seeded ai provider") + require.Contains(t, firstSeedLogs.String(), "env-seeded ai provider key") - // Re-running with the same config is a no-op (no errors, no - // new audit logs because the row matches). - auditor.ResetLogs() - err = coderd.SeedAIProvidersFromEnv(ctx, db, cfg, auditor, testLogger(t)) + // Re-running with the same config is a no-op and emits no new + // env-seed log lines. + var rerunLogs bytes.Buffer + err = coderd.SeedAIProvidersFromEnv(ctx, db, cfg, capturedLogger(&rerunLogs)) require.NoError(t, err) - require.Empty(t, auditor.AuditLogs()) + require.NotContains(t, rerunLogs.String(), "env-seeded ai provider") // Verify there's still only one row and one key. all, err := db.GetAIProviders(ctx, database.GetAIProvidersParams{}) @@ -81,7 +82,6 @@ func TestSeedAIProvidersFromEnv(t *testing.T) { t.Parallel() db, _ := dbtestutil.NewDB(t) ctx := testutil.Context(t, testutil.WaitShort) - auditor := audit.NewMock() cfg := codersdk.AIBridgeConfig{ LegacyOpenAI: codersdk.AIBridgeOpenAIConfig{ @@ -89,18 +89,18 @@ func TestSeedAIProvidersFromEnv(t *testing.T) { Key: serpent.String("sk-original"), }, } - require.NoError(t, coderd.SeedAIProvidersFromEnv(ctx, db, cfg, auditor, testLogger(t))) + require.NoError(t, coderd.SeedAIProvidersFromEnv(ctx, db, cfg, testLogger(t))) // Changing the API key alone does NOT count as drift: keys // live in a separate table and operators rotate them via the // API. Only changes to non-credential provider-level fields // (base_url, type, Bedrock region/model) trip the drift check. cfg.LegacyOpenAI.Key = serpent.String("sk-rotated") - require.NoError(t, coderd.SeedAIProvidersFromEnv(ctx, db, cfg, auditor, testLogger(t))) + require.NoError(t, coderd.SeedAIProvidersFromEnv(ctx, db, cfg, testLogger(t))) // Changing the base URL is real drift. cfg.LegacyOpenAI.BaseURL = serpent.String("https://api.openai.com/v2") - err := coderd.SeedAIProvidersFromEnv(ctx, db, cfg, auditor, testLogger(t)) + err := coderd.SeedAIProvidersFromEnv(ctx, db, cfg, testLogger(t)) require.Error(t, err) require.Contains(t, err.Error(), "differs from the current environment configuration") }) @@ -109,7 +109,6 @@ func TestSeedAIProvidersFromEnv(t *testing.T) { t.Parallel() db, _ := dbtestutil.NewDB(t) ctx := testutil.Context(t, testutil.WaitShort) - auditor := audit.NewMock() cfg := codersdk.AIBridgeConfig{ LegacyBedrock: codersdk.AIBridgeBedrockConfig{ @@ -119,19 +118,19 @@ func TestSeedAIProvidersFromEnv(t *testing.T) { Model: serpent.String("anthropic.claude-3-5-sonnet"), }, } - require.NoError(t, coderd.SeedAIProvidersFromEnv(ctx, db, cfg, auditor, testLogger(t))) + require.NoError(t, coderd.SeedAIProvidersFromEnv(ctx, db, cfg, testLogger(t))) // Rotating the Bedrock access key and secret in env must NOT // trip the drift check: they're credentials, equivalent to // bearer API keys, and operators rotate them via the API. cfg.LegacyBedrock.AccessKey = serpent.String("AKIA-rotated") cfg.LegacyBedrock.AccessKeySecret = serpent.String("secret-rotated") - require.NoError(t, coderd.SeedAIProvidersFromEnv(ctx, db, cfg, auditor, testLogger(t))) + require.NoError(t, coderd.SeedAIProvidersFromEnv(ctx, db, cfg, testLogger(t))) // Changing the Bedrock region (a non-credential field) is // real drift. cfg.LegacyBedrock.Region = serpent.String("us-west-2") - err := coderd.SeedAIProvidersFromEnv(ctx, db, cfg, auditor, testLogger(t)) + err := coderd.SeedAIProvidersFromEnv(ctx, db, cfg, testLogger(t)) require.Error(t, err) require.Contains(t, err.Error(), "differs from the current environment configuration") }) @@ -140,7 +139,6 @@ func TestSeedAIProvidersFromEnv(t *testing.T) { t.Parallel() db, _ := dbtestutil.NewDB(t) ctx := testutil.Context(t, testutil.WaitShort) - auditor := audit.NewMock() // Bedrock fields without an Anthropic key produce a Bedrock- // authenticated Anthropic provider with no bearer keys. @@ -153,7 +151,7 @@ func TestSeedAIProvidersFromEnv(t *testing.T) { SmallFastModel: serpent.String("anthropic.claude-3-5-haiku"), }, } - require.NoError(t, coderd.SeedAIProvidersFromEnv(ctx, db, cfg, auditor, testLogger(t))) + require.NoError(t, coderd.SeedAIProvidersFromEnv(ctx, db, cfg, testLogger(t))) row, err := db.GetAIProviderByName(ctx, "anthropic") require.NoError(t, err) @@ -172,7 +170,6 @@ func TestSeedAIProvidersFromEnv(t *testing.T) { t.Parallel() db, _ := dbtestutil.NewDB(t) ctx := testutil.Context(t, testutil.WaitShort) - auditor := audit.NewMock() // LegacyBedrock.Model and LegacyBedrock.SmallFastModel both // have serpent-level defaults that are always populated in a @@ -189,7 +186,7 @@ func TestSeedAIProvidersFromEnv(t *testing.T) { cfg := dv.AI.BridgeConfig cfg.LegacyAnthropic.Key = serpent.String("sk-ant-only") - require.NoError(t, coderd.SeedAIProvidersFromEnv(ctx, db, cfg, auditor, testLogger(t))) + require.NoError(t, coderd.SeedAIProvidersFromEnv(ctx, db, cfg, testLogger(t))) row, err := db.GetAIProviderByName(ctx, "anthropic") require.NoError(t, err) @@ -204,7 +201,6 @@ func TestSeedAIProvidersFromEnv(t *testing.T) { t.Parallel() db, _ := dbtestutil.NewDB(t) ctx := testutil.Context(t, testutil.WaitShort) - auditor := audit.NewMock() // Any non-empty Bedrock field signals Bedrock auth. AWS // credentials are optional because Bedrock can authenticate @@ -215,7 +211,7 @@ func TestSeedAIProvidersFromEnv(t *testing.T) { Model: serpent.String("anthropic.claude-3-5-sonnet"), }, } - require.NoError(t, coderd.SeedAIProvidersFromEnv(ctx, db, cfg, auditor, testLogger(t))) + require.NoError(t, coderd.SeedAIProvidersFromEnv(ctx, db, cfg, testLogger(t))) row, err := db.GetAIProviderByName(ctx, "anthropic") require.NoError(t, err) @@ -231,7 +227,6 @@ func TestSeedAIProvidersFromEnv(t *testing.T) { t.Parallel() db, _ := dbtestutil.NewDB(t) ctx := testutil.Context(t, testutil.WaitShort) - auditor := audit.NewMock() cfg := codersdk.AIBridgeConfig{ LegacyBedrock: codersdk.AIBridgeBedrockConfig{ @@ -241,7 +236,7 @@ func TestSeedAIProvidersFromEnv(t *testing.T) { Model: serpent.String("anthropic.claude-3-5-sonnet"), }, } - require.NoError(t, coderd.SeedAIProvidersFromEnv(ctx, db, cfg, auditor, testLogger(t))) + require.NoError(t, coderd.SeedAIProvidersFromEnv(ctx, db, cfg, testLogger(t))) row, err := db.GetAIProviderByName(ctx, "anthropic") require.NoError(t, err) require.Contains(t, row.Settings.String, "us-east-1") @@ -258,7 +253,6 @@ func TestSeedAIProvidersFromEnv(t *testing.T) { t.Parallel() db, _ := dbtestutil.NewDB(t) ctx := testutil.Context(t, testutil.WaitShort) - auditor := audit.NewMock() cfg := codersdk.AIBridgeConfig{ Providers: []codersdk.AIProviderConfig{ @@ -276,7 +270,7 @@ func TestSeedAIProvidersFromEnv(t *testing.T) { }, }, } - require.NoError(t, coderd.SeedAIProvidersFromEnv(ctx, db, cfg, auditor, testLogger(t))) + require.NoError(t, coderd.SeedAIProvidersFromEnv(ctx, db, cfg, testLogger(t))) oa, err := db.GetAIProviderByName(ctx, "primary-openai") require.NoError(t, err) @@ -303,7 +297,6 @@ func TestSeedAIProvidersFromEnv(t *testing.T) { t.Parallel() db, _ := dbtestutil.NewDB(t) ctx := testutil.Context(t, testutil.WaitShort) - auditor := audit.NewMock() cfg := codersdk.AIBridgeConfig{ Providers: []codersdk.AIProviderConfig{ @@ -318,7 +311,7 @@ func TestSeedAIProvidersFromEnv(t *testing.T) { }, }, } - require.NoError(t, coderd.SeedAIProvidersFromEnv(ctx, db, cfg, auditor, testLogger(t))) + require.NoError(t, coderd.SeedAIProvidersFromEnv(ctx, db, cfg, testLogger(t))) row, err := db.GetAIProviderByName(ctx, "bedrock-anthropic") require.NoError(t, err) @@ -334,7 +327,6 @@ func TestSeedAIProvidersFromEnv(t *testing.T) { t.Parallel() db, _ := dbtestutil.NewDB(t) ctx := testutil.Context(t, testutil.WaitShort) - auditor := audit.NewMock() cfg := codersdk.AIBridgeConfig{ LegacyOpenAI: codersdk.AIBridgeOpenAIConfig{ @@ -350,7 +342,7 @@ func TestSeedAIProvidersFromEnv(t *testing.T) { }, }, } - err := coderd.SeedAIProvidersFromEnv(ctx, db, cfg, auditor, testLogger(t)) + err := coderd.SeedAIProvidersFromEnv(ctx, db, cfg, testLogger(t)) require.Error(t, err) require.Contains(t, err.Error(), "conflicts") }) @@ -359,7 +351,6 @@ func TestSeedAIProvidersFromEnv(t *testing.T) { t.Parallel() db, _ := dbtestutil.NewDB(t) ctx := testutil.Context(t, testutil.WaitShort) - auditor := audit.NewMock() cfg := codersdk.AIBridgeConfig{ Providers: []codersdk.AIProviderConfig{ @@ -370,7 +361,7 @@ func TestSeedAIProvidersFromEnv(t *testing.T) { }, }, } - err := coderd.SeedAIProvidersFromEnv(ctx, db, cfg, auditor, testLogger(t)) + err := coderd.SeedAIProvidersFromEnv(ctx, db, cfg, testLogger(t)) require.Error(t, err) require.Contains(t, err.Error(), "invalid AI provider name") }) @@ -379,7 +370,6 @@ func TestSeedAIProvidersFromEnv(t *testing.T) { t.Parallel() db, _ := dbtestutil.NewDB(t) ctx := testutil.Context(t, testutil.WaitShort) - auditor := audit.NewMock() cfg := codersdk.AIBridgeConfig{ Providers: []codersdk.AIProviderConfig{ @@ -396,7 +386,7 @@ func TestSeedAIProvidersFromEnv(t *testing.T) { }, }, } - require.NoError(t, coderd.SeedAIProvidersFromEnv(ctx, db, cfg, auditor, testLogger(t))) + require.NoError(t, coderd.SeedAIProvidersFromEnv(ctx, db, cfg, testLogger(t))) all, err := db.GetAIProviders(ctx, database.GetAIProvidersParams{}) require.NoError(t, err) @@ -408,7 +398,6 @@ func TestSeedAIProvidersFromEnv(t *testing.T) { t.Parallel() db, _ := dbtestutil.NewDB(t) ctx := testutil.Context(t, testutil.WaitShort) - auditor := audit.NewMock() cfg := codersdk.AIBridgeConfig{ LegacyOpenAI: codersdk.AIBridgeOpenAIConfig{ @@ -416,7 +405,7 @@ func TestSeedAIProvidersFromEnv(t *testing.T) { Key: serpent.String("sk-original"), }, } - require.NoError(t, coderd.SeedAIProvidersFromEnv(ctx, db, cfg, auditor, testLogger(t))) + require.NoError(t, coderd.SeedAIProvidersFromEnv(ctx, db, cfg, testLogger(t))) row, err := db.GetAIProviderByName(ctx, "openai") require.NoError(t, err) @@ -424,7 +413,7 @@ func TestSeedAIProvidersFromEnv(t *testing.T) { // Re-run seed; the soft-deleted row should remain soft-deleted // and no new row should be created. - require.NoError(t, coderd.SeedAIProvidersFromEnv(ctx, db, cfg, auditor, testLogger(t))) + require.NoError(t, coderd.SeedAIProvidersFromEnv(ctx, db, cfg, testLogger(t))) all, err := db.GetAIProviders(ctx, database.GetAIProvidersParams{}) require.NoError(t, err) @@ -435,7 +424,6 @@ func TestSeedAIProvidersFromEnv(t *testing.T) { t.Parallel() db, _ := dbtestutil.NewDB(t) ctx := testutil.Context(t, testutil.WaitShort) - auditor := audit.NewMock() cfg := codersdk.AIBridgeConfig{ LegacyOpenAI: codersdk.AIBridgeOpenAIConfig{ @@ -443,7 +431,7 @@ func TestSeedAIProvidersFromEnv(t *testing.T) { Key: serpent.String("sk-original"), }, } - require.NoError(t, coderd.SeedAIProvidersFromEnv(ctx, db, cfg, auditor, testLogger(t))) + require.NoError(t, coderd.SeedAIProvidersFromEnv(ctx, db, cfg, testLogger(t))) row, err := db.GetAIProviderByName(ctx, "openai") require.NoError(t, err) @@ -452,7 +440,7 @@ func TestSeedAIProvidersFromEnv(t *testing.T) { // keys on a row that already exists; the new key is only // installed via the API/CRUD layer in this flow. cfg.LegacyOpenAI.Key = serpent.String("sk-rotated") - require.NoError(t, coderd.SeedAIProvidersFromEnv(ctx, db, cfg, auditor, testLogger(t))) + require.NoError(t, coderd.SeedAIProvidersFromEnv(ctx, db, cfg, testLogger(t))) keys, err := db.GetAIProviderKeysByProviderID(ctx, row.ID) require.NoError(t, err) @@ -464,7 +452,6 @@ func TestSeedAIProvidersFromEnv(t *testing.T) { t.Parallel() db, _ := dbtestutil.NewDB(t) ctx := testutil.Context(t, testutil.WaitShort) - auditor := audit.NewMock() // Two entries under the same name with identical canonical // fields are deduplicated silently. @@ -484,7 +471,7 @@ func TestSeedAIProvidersFromEnv(t *testing.T) { }, }, } - require.NoError(t, coderd.SeedAIProvidersFromEnv(ctx, db, cfg, auditor, testLogger(t))) + require.NoError(t, coderd.SeedAIProvidersFromEnv(ctx, db, cfg, testLogger(t))) all, err := db.GetAIProviders(ctx, database.GetAIProvidersParams{}) require.NoError(t, err) @@ -495,7 +482,6 @@ func TestSeedAIProvidersFromEnv(t *testing.T) { t.Parallel() db, _ := dbtestutil.NewDB(t) ctx := testutil.Context(t, testutil.WaitShort) - auditor := audit.NewMock() // Same name, different canonical fields: must be rejected. cfg := codersdk.AIBridgeConfig{ @@ -514,7 +500,7 @@ func TestSeedAIProvidersFromEnv(t *testing.T) { }, }, } - err := coderd.SeedAIProvidersFromEnv(ctx, db, cfg, auditor, testLogger(t)) + err := coderd.SeedAIProvidersFromEnv(ctx, db, cfg, testLogger(t)) require.Error(t, err) require.Contains(t, err.Error(), "conflicting fields") }) @@ -524,3 +510,9 @@ func testLogger(t *testing.T) slog.Logger { t.Helper() return slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) } + +// capturedLogger returns a logger that writes structured records to buf, +// for tests that assert on log output instead of audit-table emissions. +func capturedLogger(buf *bytes.Buffer) slog.Logger { + return slog.Make(sloghuman.Sink(buf)).Leveled(slog.LevelDebug) +} diff --git a/enterprise/aibridgeproxyd/aibridgeproxyd.go b/enterprise/aibridgeproxyd/aibridgeproxyd.go index d796d36dbb..c32c5c41c5 100644 --- a/enterprise/aibridgeproxyd/aibridgeproxyd.go +++ b/enterprise/aibridgeproxyd/aibridgeproxyd.go @@ -258,24 +258,26 @@ func New(ctx context.Context, logger slog.Logger, opts Options) (*Server, error) allowedPorts = []string{"80", "443"} } - if len(opts.DomainAllowlist) == 0 { - return nil, xerrors.New("domain allow list is required") - } + // An empty allowlist is permitted so the server can boot before any + // ai_providers row exists; every intercept attempt is then rejected + // until providers are configured. + // TODO: refresh the allowlist when ai_providers changes so a restart + // is not required after the first provider is configured. mitmHosts, err := convertDomainsToHosts(opts.DomainAllowlist, allowedPorts) if err != nil { return nil, xerrors.Errorf("invalid domain allowlist: %w", err) } - if len(mitmHosts) == 0 { - return nil, xerrors.New("domain allowlist is empty, at least one domain is required") - } if opts.AIBridgeProviderFromHost == nil { return nil, xerrors.New("AIBridgeProviderFromHost is required") } aibridgeProviderFromHost := opts.AIBridgeProviderFromHost - // Validate that all allowlisted domains have correct aibridge provider mappings. for _, domain := range opts.DomainAllowlist { + domain = strings.TrimSpace(strings.ToLower(domain)) + if domain == "" { + continue + } if aibridgeProviderFromHost(domain) == "" { return nil, xerrors.Errorf("domain %q is in allowlist but has no provider mapping", domain) } diff --git a/enterprise/aibridgeproxyd/aibridgeproxyd_test.go b/enterprise/aibridgeproxyd/aibridgeproxyd_test.go index bd4ac07089..334fd289f8 100644 --- a/enterprise/aibridgeproxyd/aibridgeproxyd_test.go +++ b/enterprise/aibridgeproxyd/aibridgeproxyd_test.go @@ -678,14 +678,15 @@ func TestNew(t *testing.T) { mitmCertFile, mitmKeyFile := getSharedTestMITMCert(t) logger := slogtest.Make(t, nil) - _, err := aibridgeproxyd.New(t.Context(), logger, aibridgeproxyd.Options{ - ListenAddr: ":0", - CoderAccessURL: "http://localhost:3000", - MITMCertFile: mitmCertFile, - MITMKeyFile: mitmKeyFile, + srv, err := aibridgeproxyd.New(t.Context(), logger, aibridgeproxyd.Options{ + ListenAddr: ":0", + CoderAccessURL: "http://localhost:3000", + MITMCertFile: mitmCertFile, + MITMKeyFile: mitmKeyFile, + AIBridgeProviderFromHost: testProviderFromHost, }) - require.Error(t, err) - require.Contains(t, err.Error(), "domain allow list is required") + require.NoError(t, err) + t.Cleanup(func() { _ = srv.Close() }) }) t.Run("EmptyDomainAllowlist", func(t *testing.T) { @@ -694,15 +695,16 @@ func TestNew(t *testing.T) { mitmCertFile, mitmKeyFile := getSharedTestMITMCert(t) logger := slogtest.Make(t, nil) - _, err := aibridgeproxyd.New(t.Context(), logger, aibridgeproxyd.Options{ - ListenAddr: ":0", - CoderAccessURL: "http://localhost:3000", - MITMCertFile: mitmCertFile, - MITMKeyFile: mitmKeyFile, - DomainAllowlist: []string{""}, + srv, err := aibridgeproxyd.New(t.Context(), logger, aibridgeproxyd.Options{ + ListenAddr: ":0", + CoderAccessURL: "http://localhost:3000", + MITMCertFile: mitmCertFile, + MITMKeyFile: mitmKeyFile, + DomainAllowlist: []string{""}, + AIBridgeProviderFromHost: testProviderFromHost, }) - require.Error(t, err) - require.Contains(t, err.Error(), "domain allowlist is empty, at least one domain is required") + require.NoError(t, err) + t.Cleanup(func() { _ = srv.Close() }) }) t.Run("InvalidDomainAllowlist", func(t *testing.T) {