From 79e007cf30d6a95d2d216206c5100108bcaf0440 Mon Sep 17 00:00:00 2001 From: Danny Kopping Date: Wed, 27 May 2026 11:58:43 +0200 Subject: [PATCH] feat: hot-reload aibridged and aibridgeproxyd providers on DB changes (#25673) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Previously the in-process aibridge daemon and the enterprise aibridgeproxy daemon both snapshotted their provider routing once at boot. Any `ai_providers` or `ai_provider_keys` mutation required a restart for either to pick it up. Add an `ai_providers_changed` pubsub channel that the CRUD handlers publish on after Create / Update / Delete. Both daemons subscribe: - **aibridged** rebuilds its `[]aibridge.Provider` snapshot via `BuildProviders` and swaps it into the pool atomically. Inflight requests keep serving against the bridge they already acquired; new acquires build against the new snapshot. Per-provider construction errors stay scoped to the offending row. - **aibridgeproxyd** rebuilds its routing snapshot from `GetAIProviders` and swaps the host→provider map atomically. The MITM listener picks up new providers without restart. DB read for aibridgeproxyd uses the existing `AsAIProviderMetadataReader` subject for routing-only access. --- cli/aibridged.go | 54 ++- cli/server.go | 4 +- coderd/ai_providers.go | 19 + coderd/ai_providers_pubsub_test.go | 62 +++ coderd/aibridged/aibridgedmock/poolmock.go | 13 + coderd/aibridged/pool.go | 92 ++++- coderd/aibridged/pool_test.go | 215 ++++++++++ coderd/aibridged/reload.go | 50 +++ coderd/aibridged/reload_test.go | 133 ++++++ coderd/pubsub/aiproviderschangedevent.go | 11 + enterprise/aibridgeproxyd/aibridgeproxyd.go | 150 ++++--- .../aibridgeproxyd/aibridgeproxyd_test.go | 8 + enterprise/aibridgeproxyd/reload.go | 124 ++++++ .../aibridgeproxyd/reload_internal_test.go | 147 +++++++ enterprise/aibridgeproxyd/reload_test.go | 379 ++++++++++++++++++ enterprise/cli/aibridgeproxyd.go | 112 +++--- .../cli/aibridgeproxyd_internal_test.go | 72 ---- enterprise/cli/server.go | 27 +- enterprise/coderd/aibridge_reload_test.go | 233 +++++++++++ 19 files changed, 1678 insertions(+), 227 deletions(-) create mode 100644 coderd/ai_providers_pubsub_test.go create mode 100644 coderd/aibridged/reload.go create mode 100644 coderd/aibridged/reload_test.go create mode 100644 coderd/pubsub/aiproviderschangedevent.go create mode 100644 enterprise/aibridgeproxyd/reload.go create mode 100644 enterprise/aibridgeproxyd/reload_internal_test.go create mode 100644 enterprise/aibridgeproxyd/reload_test.go delete mode 100644 enterprise/cli/aibridgeproxyd_internal_test.go create mode 100644 enterprise/coderd/aibridge_reload_test.go diff --git a/cli/aibridged.go b/cli/aibridged.go index 8bb21a8cbf..50e2c35d5c 100644 --- a/cli/aibridged.go +++ b/cli/aibridged.go @@ -24,7 +24,12 @@ import ( "github.com/coder/quartz" ) -func newAIBridgeDaemon(coderAPI *coderd.API, providers []aibridge.Provider) (*aibridged.Server, error) { +// newAIBridgeDaemon constructs the in-memory aibridge daemon and wires +// up a subscription that hot-reloads the provider pool from the +// database on every ai_providers change event. The returned unsubscribe +// function tears down the subscription; callers must invoke it +// alongside Server.Close on shutdown. +func newAIBridgeDaemon(coderAPI *coderd.API, providers []aibridge.Provider, cfg codersdk.AIBridgeConfig) (*aibridged.Server, func(), error) { ctx := context.Background() coderAPI.Logger.Debug(ctx, "starting in-memory aibridge daemon") @@ -37,7 +42,25 @@ func newAIBridgeDaemon(coderAPI *coderd.API, providers []aibridge.Provider) (*ai // Create pool for reusable stateful [aibridge.RequestBridge] instances (one per user). pool, err := aibridged.NewCachedBridgePool(aibridged.DefaultPoolOptions, providers, logger.Named("pool"), metrics, tracer) // TODO: configurable size. if err != nil { - return nil, xerrors.Errorf("create request pool: %w", err) + return nil, nil, xerrors.Errorf("create request pool: %w", err) + } + + // Subscribe to ai_providers change events so the pool tracks the + // database without a restart. The boot-time `providers` snapshot + // derives from env config and serves as a fallback if the database + // load fails inside the reloader. + reloader := &poolDBReloader{ + pool: pool, + db: coderAPI.Database, + cfg: cfg, + logger: logger.Named("provider-loader"), + } + unsubscribe, err := aibridged.SubscribeProviderReload(ctx, coderAPI.Pubsub, reloader, logger.Named("provider-reload")) + if err != nil { + // Pool is still usable with the boot-time snapshot; subscription + // failure is logged but not fatal so the daemon still serves. + logger.Warn(ctx, "subscribe to ai providers change channel", slog.Error(err)) + unsubscribe = func() {} } // Create daemon. @@ -45,9 +68,32 @@ func newAIBridgeDaemon(coderAPI *coderd.API, providers []aibridge.Provider) (*ai return coderAPI.CreateInMemoryAIBridgeServer(dialCtx) }, logger, tracer) if err != nil { - return nil, xerrors.Errorf("start in-memory aibridge daemon: %w", err) + unsubscribe() + return nil, nil, xerrors.Errorf("start in-memory aibridge daemon: %w", err) } - return srv, nil + return srv, unsubscribe, nil +} + +// poolDBReloader implements [aibridged.ProviderReloader] by loading +// the live provider set from the database and forwarding it to the +// pool. +type poolDBReloader struct { + pool *aibridged.CachedBridgePool + db database.Store + cfg codersdk.AIBridgeConfig + logger slog.Logger +} + +func (r *poolDBReloader) Reload(ctx context.Context) error { + providers, err := BuildProviders(ctx, r.db, r.cfg, r.logger) + if err != nil { + // Keep the previous snapshot in place: dropping all providers + // because the DB read failed would compound the visible failure + // mode beyond the operator's actual misconfiguration. + return xerrors.Errorf("load ai providers from database: %w", err) + } + r.pool.ReplaceProviders(providers) + return nil } // BuildProviders loads every enabled ai_providers row, attaches its diff --git a/cli/server.go b/cli/server.go index 3fed22aeb5..3e3dd0c643 100644 --- a/cli/server.go +++ b/cli/server.go @@ -1046,7 +1046,8 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd. if err != nil { return xerrors.Errorf("build AI providers: %w", err) } - aibridgeDaemon, err = newAIBridgeDaemon(coderAPI, aibridgeProviders) + var unsubscribeProviderReload func() + aibridgeDaemon, unsubscribeProviderReload, err = newAIBridgeDaemon(coderAPI, aibridgeProviders, vals.AI.BridgeConfig) if err != nil { return xerrors.Errorf("create aibridged: %w", err) } @@ -1055,6 +1056,7 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd. // daemon does not affect in-flight requests but is needed to // release pool/recorder resources at shutdown. defer aibridgeDaemon.Close() + defer unsubscribeProviderReload() } if vals.Prometheus.Enable { diff --git a/coderd/ai_providers.go b/coderd/ai_providers.go index d791cf9470..dd8e4c00d3 100644 --- a/coderd/ai_providers.go +++ b/coderd/ai_providers.go @@ -21,6 +21,7 @@ import ( "github.com/coder/coder/v2/coderd/database/dbtime" "github.com/coder/coder/v2/coderd/httpapi" "github.com/coder/coder/v2/coderd/httpmw" + coderpubsub "github.com/coder/coder/v2/coderd/pubsub" "github.com/coder/coder/v2/coderd/util/ptr" "github.com/coder/coder/v2/codersdk" ) @@ -235,6 +236,7 @@ func (api *API) aiProvidersCreate(rw http.ResponseWriter, r *http.Request) { aReq.New = row auditAIProviderKeyChanges(ctx, r, *auditor, api.Logger, aiProviderKeyChanges{Added: keys}) + api.publishAIProvidersChanged(ctx) sdk, err := db2sdk.AIProvider(row, keys) if err != nil { @@ -400,6 +402,7 @@ func (api *API) aiProvidersUpdate(rw http.ResponseWriter, r *http.Request) { } auditAIProviderKeyChanges(ctx, r, *auditor, api.Logger, keyChanges) + api.publishAIProvidersChanged(ctx) sdk, err := db2sdk.AIProvider(updated, keys) if err != nil { @@ -453,9 +456,25 @@ func (api *API) aiProvidersDelete(rw http.ResponseWriter, r *http.Request) { return } + api.publishAIProvidersChanged(ctx) + rw.WriteHeader(http.StatusNoContent) } +// publishAIProvidersChanged notifies subscribers (aibridged, +// aibridgeproxyd) that the live provider set changed and they should +// refetch from the database. Pubsub failures are logged but not +// propagated: subscribers refresh authoritatively from the DB, so a +// dropped notification only delays convergence. +func (api *API) publishAIProvidersChanged(ctx context.Context) { + if api.Pubsub == nil { + return + } + if err := api.Pubsub.Publish(coderpubsub.AIProvidersChangedChannel, nil); err != nil { + api.Logger.Warn(ctx, "publish ai providers changed event", slog.Error(err)) + } +} + // errBedrockRejectsAPIKeys is the sentinel returned from inside the // update transaction when a caller attempts to attach api_keys to a // Bedrock-typed provider; the outer handler translates it into a 400. diff --git a/coderd/ai_providers_pubsub_test.go b/coderd/ai_providers_pubsub_test.go new file mode 100644 index 0000000000..808ac29c7c --- /dev/null +++ b/coderd/ai_providers_pubsub_test.go @@ -0,0 +1,62 @@ +package coderd_test + +import ( + "context" + "sync/atomic" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/coder/coder/v2/coderd/coderdtest" + coderpubsub "github.com/coder/coder/v2/coderd/pubsub" + "github.com/coder/coder/v2/codersdk" + "github.com/coder/coder/v2/testutil" +) + +// TestAIProvidersChangedPubsub asserts that the CRUD handlers publish +// on AIProvidersChangedChannel for the operations that affect the +// runtime provider set. Subscribers (aibridged, aibridgeproxyd) depend +// on these notifications to trigger their pool reload. +// +// The handlers publish best-effort and the payload is empty, so we +// assert "at least one event per mutation" via a counter. +func TestAIProvidersChangedPubsub(t *testing.T) { + t.Parallel() + + client, _, api := coderdtest.NewWithAPI(t, nil) + _ = coderdtest.CreateFirstUser(t, client) + ctx := testutil.Context(t, testutil.WaitLong) + + var count atomic.Int64 + unsubscribe, err := api.Pubsub.Subscribe(coderpubsub.AIProvidersChangedChannel, func(_ context.Context, _ []byte) { + count.Add(1) + }) + require.NoError(t, err) + t.Cleanup(unsubscribe) + + // Create. + req := codersdk.CreateAIProviderRequest{ + Type: codersdk.AIProviderTypeOpenAI, + Name: "pubsub-openai", + Enabled: true, + BaseURL: "https://api.openai.com/v1/", + APIKeys: []string{"k1"}, + } + //nolint:gocritic // Owner role is the audience for this endpoint. + created, err := client.CreateAIProvider(ctx, req) + require.NoError(t, err) + testutil.Eventually(ctx, t, func(_ context.Context) bool { return count.Load() >= 1 }, testutil.IntervalFast) + + // Update. + newKey := "k2" + _, err = client.UpdateAIProvider(ctx, created.ID.String(), codersdk.UpdateAIProviderRequest{ + APIKeys: &[]codersdk.AIProviderKeyMutation{{APIKey: &newKey}}, + }) + require.NoError(t, err) + testutil.Eventually(ctx, t, func(_ context.Context) bool { return count.Load() >= 2 }, testutil.IntervalFast) + + // Delete. + err = client.DeleteAIProvider(ctx, created.ID.String()) + require.NoError(t, err) + testutil.Eventually(ctx, t, func(_ context.Context) bool { return count.Load() >= 3 }, testutil.IntervalFast) +} diff --git a/coderd/aibridged/aibridgedmock/poolmock.go b/coderd/aibridged/aibridgedmock/poolmock.go index ac3562b795..36c4d4775c 100644 --- a/coderd/aibridged/aibridgedmock/poolmock.go +++ b/coderd/aibridged/aibridgedmock/poolmock.go @@ -14,6 +14,7 @@ import ( http "net/http" reflect "reflect" + aibridge "github.com/coder/coder/v2/aibridge" aibridged "github.com/coder/coder/v2/coderd/aibridged" gomock "go.uber.org/mock/gomock" ) @@ -57,6 +58,18 @@ func (mr *MockPoolerMockRecorder) Acquire(ctx, req, clientFn, mcpBootstrapper an return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Acquire", reflect.TypeOf((*MockPooler)(nil).Acquire), ctx, req, clientFn, mcpBootstrapper) } +// ReplaceProviders mocks base method. +func (m *MockPooler) ReplaceProviders(providers []aibridge.Provider) { + m.ctrl.T.Helper() + m.ctrl.Call(m, "ReplaceProviders", providers) +} + +// ReplaceProviders indicates an expected call of ReplaceProviders. +func (mr *MockPoolerMockRecorder) ReplaceProviders(providers any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReplaceProviders", reflect.TypeOf((*MockPooler)(nil).ReplaceProviders), providers) +} + // Shutdown mocks base method. func (m *MockPooler) Shutdown(ctx context.Context) error { m.ctrl.T.Helper() diff --git a/coderd/aibridged/pool.go b/coderd/aibridged/pool.go index 0468acb582..3b7e60955c 100644 --- a/coderd/aibridged/pool.go +++ b/coderd/aibridged/pool.go @@ -3,7 +3,10 @@ package aibridged import ( "context" "net/http" + "slices" + "strconv" "sync" + "sync/atomic" "time" "github.com/dgraph-io/ristretto/v2" @@ -26,6 +29,9 @@ const ( // One [*aibridge.RequestBridge] instance is created per given key. type Pooler interface { Acquire(ctx context.Context, req Request, clientFn ClientFunc, mcpBootstrapper MCPProxyBuilder) (http.Handler, error) + // ReplaceProviders swaps the providers used to construct future + // RequestBridge instances and clears the cache. + ReplaceProviders(providers []aibridge.Provider) Shutdown(ctx context.Context) error } @@ -46,10 +52,12 @@ var DefaultPoolOptions = PoolOptions{MaxItems: 5000, TTL: time.Minute * 15} var _ Pooler = &CachedBridgePool{} type CachedBridgePool struct { - cache *ristretto.Cache[string, *aibridge.RequestBridge] - providers []aibridge.Provider - logger slog.Logger - options PoolOptions + cache *ristretto.Cache[string, *aibridge.RequestBridge] + // providers is the live provider set used by new RequestBridge instances. + providers atomic.Pointer[[]aibridge.Provider] + providerVersion atomic.Int64 + logger slog.Logger + options PoolOptions singleflight *singleflight.Group[string, *aibridge.RequestBridge] @@ -71,13 +79,16 @@ func NewCachedBridgePool(options PoolOptions, providers []aibridge.Provider, log if item == nil || item.Value == nil { return } - - shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), time.Second*5) - defer shutdownCancel() - - // Run the eviction in the background since ristretto blocks sets until a free slot is available. + // Capture the value synchronously: ristretto reuses the + // item slot after OnEvict returns, so reading item.Value + // from the goroutine below races with the caller of + // Clear/Set. The shutdown still runs in the background to + // avoid blocking ristretto's eviction loop. + bridge := item.Value go func() { - _ = item.Value.Shutdown(shutdownCtx) + shutdownCtx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + _ = bridge.Shutdown(shutdownCtx) }() }, }) @@ -85,18 +96,53 @@ func NewCachedBridgePool(options PoolOptions, providers []aibridge.Provider, log return nil, xerrors.Errorf("create cache: %w", err) } - return &CachedBridgePool{ - cache: cache, - providers: providers, - options: options, - metrics: metrics, - tracer: tracer, - logger: logger, + pool := &CachedBridgePool{ + cache: cache, + options: options, + metrics: metrics, + tracer: tracer, + logger: logger, singleflight: &singleflight.Group[string, *aibridge.RequestBridge]{}, shuttingDownCh: make(chan struct{}), - }, nil + } + initial := slices.Clone(providers) + pool.providers.Store(&initial) + return pool, nil +} + +// ReplaceProviders swaps the provider snapshot used by future Acquires. +// It is safe to call concurrently with Acquire and is a no-op after +// Shutdown. +func (p *CachedBridgePool) ReplaceProviders(providers []aibridge.Provider) { + select { + case <-p.shuttingDownCh: + return + default: + } + snapshot := slices.Clone(providers) + p.providers.Store(&snapshot) + version := time.Now().UnixNano() + p.providerVersion.Store(version) + // Clear evicts every cached bridge; OnEvict shuts each one down in + // the background. Wait for buffered writes to drain so a replacement + // immediately followed by an Acquire always sees the cleared cache. + p.cache.Clear() + p.cache.Wait() + p.logger.Info(context.Background(), "request bridge pool reloaded", + slog.F("provider_count", len(snapshot)), + slog.F("provider_version", version), + ) +} + +// loadProviders returns the current providers snapshot. The returned +// slice must not be mutated. +func (p *CachedBridgePool) loadProviders() []aibridge.Provider { + if ptr := p.providers.Load(); ptr != nil { + return *ptr + } + return nil } // Acquire retrieves or creates a [*aibridge.RequestBridge] instance per given key. @@ -140,6 +186,7 @@ func (p *CachedBridgePool) Acquire(ctx context.Context, req Request, clientFn Cl } span.AddEvent("cache_miss") + providerVersion := p.providerVersion.Load() recorder := aibridge.NewRecorder(p.logger.Named("recorder"), p.tracer, func() (aibridge.Recorder, error) { client, err := clientFn() if err != nil { @@ -152,7 +199,8 @@ func (p *CachedBridgePool) Acquire(ctx context.Context, req Request, clientFn Cl // Slow path. // Creating an *aibridge.RequestBridge may take some time, so gate all subsequent callers behind the initial request and return the resulting value. // TODO: track startup time since it adds latency to first request (histogram count will also help us see how often this occurs). - instance, err, _ := p.singleflight.Do(req.InitiatorID.String(), func() (*aibridge.RequestBridge, error) { + singleflightKey := cacheKey + "|" + strconv.FormatInt(providerVersion, 10) + instance, err, _ := p.singleflight.Do(singleflightKey, func() (*aibridge.RequestBridge, error) { var ( mcpServers mcp.ServerProxier err error @@ -171,12 +219,14 @@ func (p *CachedBridgePool) Acquire(ctx context.Context, req Request, clientFn Cl } } - bridge, err := aibridge.NewRequestBridge(ctx, p.providers, recorder, mcpServers, p.logger, p.metrics, p.tracer) + bridge, err := aibridge.NewRequestBridge(ctx, p.loadProviders(), recorder, mcpServers, p.logger, p.metrics, p.tracer) if err != nil { return nil, xerrors.Errorf("create new request bridge: %w", err) } - p.cache.SetWithTTL(cacheKey, bridge, cacheCost, p.options.TTL) + if p.providerVersion.Load() == providerVersion { + p.cache.SetWithTTL(cacheKey, bridge, cacheCost, p.options.TTL) + } return bridge, nil }) diff --git a/coderd/aibridged/pool_test.go b/coderd/aibridged/pool_test.go index f5153fe4d9..bb42c4c256 100644 --- a/coderd/aibridged/pool_test.go +++ b/coderd/aibridged/pool_test.go @@ -2,6 +2,10 @@ package aibridged_test import ( "context" + "io" + "net/http" + "net/http/httptest" + "sync/atomic" "testing" "testing/synctest" "time" @@ -12,10 +16,13 @@ import ( "go.uber.org/mock/gomock" "cdr.dev/slog/v3/sloggers/slogtest" + "github.com/coder/coder/v2/aibridge" + "github.com/coder/coder/v2/aibridge/config" "github.com/coder/coder/v2/aibridge/mcp" "github.com/coder/coder/v2/aibridge/mcpmock" "github.com/coder/coder/v2/coderd/aibridged" mock "github.com/coder/coder/v2/coderd/aibridged/aibridgedmock" + "github.com/coder/coder/v2/testutil" ) // TestPool validates the published behavior of [aibridged.CachedBridgePool]. @@ -107,6 +114,160 @@ func TestPool(t *testing.T) { require.EqualValues(t, 3, cacheMetrics.Misses()) } +func TestPoolReplaceProvidersClearsCacheAndUsesNewProviders(t *testing.T) { + t.Parallel() + + oldUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _, _ = io.WriteString(w, "old") + })) + t.Cleanup(oldUpstream.Close) + newUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _, _ = io.WriteString(w, "new") + })) + t.Cleanup(newUpstream.Close) + + logger := slogtest.Make(t, nil) + ctrl := gomock.NewController(t) + client := mock.NewMockDRPCClient(ctrl) + mcpProxy := mcpmock.NewMockServerProxier(ctrl) + mcpProxy.EXPECT().Init(gomock.Any()).AnyTimes().Return(nil) + mcpProxy.EXPECT().Shutdown(gomock.Any()).AnyTimes().Return(nil) + + opts := aibridged.PoolOptions{MaxItems: 1, TTL: time.Minute} + pool, err := aibridged.NewCachedBridgePool(opts, []aibridge.Provider{ + aibridge.NewOpenAIProvider(config.OpenAI{Name: "old", BaseURL: oldUpstream.URL}), + }, logger, nil, testTracer) + require.NoError(t, err) + t.Cleanup(func() { _ = pool.Shutdown(context.Background()) }) + + req := aibridged.Request{ + SessionKey: "key", + InitiatorID: uuid.New(), + APIKeyID: uuid.New().String(), + } + clientFn := func() (aibridged.DRPCClient, error) { + return client, nil + } + + inst, err := pool.Acquire(t.Context(), req, clientFn, newMockMCPFactory(mcpProxy)) + require.NoError(t, err) + assertHandlerBody(t, inst, "/old/v1/models", "old") + + pool.ReplaceProviders([]aibridge.Provider{ + aibridge.NewOpenAIProvider(config.OpenAI{Name: "new", BaseURL: newUpstream.URL}), + }) + + instAfterReload, err := pool.Acquire(t.Context(), req, clientFn, newMockMCPFactory(mcpProxy)) + require.NoError(t, err) + require.NotSame(t, inst, instAfterReload) + assertHandlerBody(t, instAfterReload, "/new/v1/models", "new") +} + +func TestPoolReplaceProvidersDoesNotJoinStaleSingleflight(t *testing.T) { + t.Parallel() + + oldUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _, _ = io.WriteString(w, "old") + })) + t.Cleanup(oldUpstream.Close) + newUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _, _ = io.WriteString(w, "new") + })) + t.Cleanup(newUpstream.Close) + + logger := slogtest.Make(t, nil) + ctrl := gomock.NewController(t) + client := mock.NewMockDRPCClient(ctrl) + + opts := aibridged.PoolOptions{MaxItems: 1, TTL: time.Minute} + pool, err := aibridged.NewCachedBridgePool(opts, []aibridge.Provider{ + aibridge.NewOpenAIProvider(config.OpenAI{Name: "old", BaseURL: oldUpstream.URL}), + }, logger, nil, testTracer) + require.NoError(t, err) + t.Cleanup(func() { _ = pool.Shutdown(context.Background()) }) + + req := aibridged.Request{ + SessionKey: "key", + InitiatorID: uuid.New(), + APIKeyID: uuid.New().String(), + } + clientFn := func() (aibridged.DRPCClient, error) { + return client, nil + } + + factory := newBlockingMCPFactory() + firstDone := make(chan acquireResult, 1) + go func() { + handler, err := pool.Acquire(t.Context(), req, clientFn, factory) + firstDone <- acquireResult{handler: handler, err: err} + }() + + require.Eventually(t, factory.firstBuildStarted, testutil.WaitShort, testutil.IntervalFast) + + pool.ReplaceProviders([]aibridge.Provider{ + aibridge.NewOpenAIProvider(config.OpenAI{Name: "new", BaseURL: newUpstream.URL}), + }) + + secondDone := make(chan acquireResult, 1) + go func() { + handler, err := pool.Acquire(t.Context(), req, clientFn, factory) + secondDone <- acquireResult{handler: handler, err: err} + }() + + var second acquireResult + require.Eventually(t, func() bool { + select { + case second = <-secondDone: + return true + default: + return false + } + }, testutil.WaitShort, testutil.IntervalFast) + require.NoError(t, second.err) + assertHandlerBody(t, second.handler, "/new/v1/models", "new") + + close(factory.releaseFirst) + var first acquireResult + require.Eventually(t, func() bool { + select { + case first = <-firstDone: + return true + default: + return false + } + }, testutil.WaitShort, testutil.IntervalFast) + require.NoError(t, first.err) + + third, err := pool.Acquire(t.Context(), req, clientFn, factory) + require.NoError(t, err) + require.Same(t, second.handler, third) +} + +func TestPoolReplaceProvidersAfterShutdownIsNoop(t *testing.T) { + t.Parallel() + + logger := slogtest.Make(t, nil) + opts := aibridged.PoolOptions{MaxItems: 1, TTL: time.Minute} + pool, err := aibridged.NewCachedBridgePool(opts, nil, logger, nil, testTracer) + require.NoError(t, err) + + require.NoError(t, pool.Shutdown(t.Context())) + require.NotPanics(t, func() { + pool.ReplaceProviders([]aibridge.Provider{ + aibridge.NewOpenAIProvider(config.OpenAI{Name: "new", BaseURL: "https://example.com"}), + }) + }) + + _, err = pool.Acquire(t.Context(), aibridged.Request{ + SessionKey: "key", + InitiatorID: uuid.New(), + APIKeyID: uuid.New().String(), + }, func() (aibridged.DRPCClient, error) { + return nil, context.Canceled + }, newMockMCPFactory(nil)) + require.ErrorContains(t, err, "pool shutting down") +} + func TestPool_Expiry(t *testing.T) { t.Parallel() @@ -166,6 +327,21 @@ func TestPool_Expiry(t *testing.T) { }) } +func assertHandlerBody(t *testing.T, handler http.Handler, path string, body string) { + t.Helper() + + req := httptest.NewRequest(http.MethodGet, path, nil) + rw := httptest.NewRecorder() + handler.ServeHTTP(rw, req) + resp := rw.Result() + defer resp.Body.Close() + + got, err := io.ReadAll(resp.Body) + require.NoError(t, err) + require.Equal(t, http.StatusOK, resp.StatusCode) + require.Equal(t, body, string(got)) +} + var _ aibridged.MCPProxyBuilder = &mockMCPFactory{} type mockMCPFactory struct { @@ -179,3 +355,42 @@ func newMockMCPFactory(proxy *mcpmock.MockServerProxier) *mockMCPFactory { func (m *mockMCPFactory) Build(ctx context.Context, req aibridged.Request, tracer trace.Tracer) (mcp.ServerProxier, error) { return m.proxy, nil } + +type acquireResult struct { + handler http.Handler + err error +} + +type blockingMCPFactory struct { + calls atomic.Int32 + firstStarted chan struct{} + releaseFirst chan struct{} +} + +func newBlockingMCPFactory() *blockingMCPFactory { + return &blockingMCPFactory{ + firstStarted: make(chan struct{}), + releaseFirst: make(chan struct{}), + } +} + +func (m *blockingMCPFactory) firstBuildStarted() bool { + select { + case <-m.firstStarted: + return true + default: + return false + } +} + +func (m *blockingMCPFactory) Build(ctx context.Context, _ aibridged.Request, _ trace.Tracer) (mcp.ServerProxier, error) { + if m.calls.Add(1) == 1 { + close(m.firstStarted) + select { + case <-m.releaseFirst: + case <-ctx.Done(): + return nil, ctx.Err() + } + } + return nil, context.Canceled +} diff --git a/coderd/aibridged/reload.go b/coderd/aibridged/reload.go new file mode 100644 index 0000000000..9909d3de0c --- /dev/null +++ b/coderd/aibridged/reload.go @@ -0,0 +1,50 @@ +package aibridged + +import ( + "context" + + "golang.org/x/xerrors" + + "cdr.dev/slog/v3" + dbpubsub "github.com/coder/coder/v2/coderd/database/pubsub" + "github.com/coder/coder/v2/coderd/pubsub" +) + +// ProviderReloader refreshes a component's provider snapshot. +type ProviderReloader interface { + Reload(ctx context.Context) error +} + +// SubscribeProviderReload refreshes once, then on AI provider changes. +func SubscribeProviderReload( + ctx context.Context, + ps dbpubsub.Pubsub, + reloader ProviderReloader, + logger slog.Logger, +) (func(), error) { + if ps == nil { + return nil, xerrors.New("pubsub is required") + } + if reloader == nil { + return nil, xerrors.New("reloader is required") + } + + unsubscribe, err := ps.SubscribeWithErr(pubsub.AIProvidersChangedChannel, func(cbCtx context.Context, _ []byte, err error) { + if err != nil { + logger.Warn(cbCtx, "ai providers changed event delivered with error", slog.Error(err)) + return + } + if err := reloader.Reload(cbCtx); err != nil { + logger.Warn(cbCtx, "reload ai provider snapshot from pubsub event", slog.Error(err)) + return + } + logger.Debug(cbCtx, "reloaded ai provider snapshot from pubsub event") + }) + if err != nil { + return nil, xerrors.Errorf("subscribe to %s: %w", pubsub.AIProvidersChangedChannel, err) + } + if err := reloader.Reload(ctx); err != nil { + logger.Warn(ctx, "initial ai provider reload", slog.Error(err)) + } + return unsubscribe, nil +} diff --git a/coderd/aibridged/reload_test.go b/coderd/aibridged/reload_test.go new file mode 100644 index 0000000000..e73489ba83 --- /dev/null +++ b/coderd/aibridged/reload_test.go @@ -0,0 +1,133 @@ +package aibridged_test + +import ( + "context" + "sync/atomic" + "testing" + + "github.com/stretchr/testify/require" + "golang.org/x/xerrors" + + "cdr.dev/slog/v3/sloggers/slogtest" + "github.com/coder/coder/v2/coderd/aibridged" + dbpubsub "github.com/coder/coder/v2/coderd/database/pubsub" + "github.com/coder/coder/v2/coderd/pubsub" + "github.com/coder/coder/v2/testutil" +) + +func TestSubscribeProviderReload(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitMedium) + + logger := slogtest.Make(t, nil) + ps := dbpubsub.NewInMemory() + t.Cleanup(func() { _ = ps.Close() }) + + calls := &recordingReloader{} + + unsub, err := aibridged.SubscribeProviderReload(ctx, ps, calls, logger) + require.NoError(t, err) + t.Cleanup(unsub) + + require.Equal(t, 1, calls.count()) + + require.NoError(t, ps.Publish(pubsub.AIProvidersChangedChannel, nil)) + + require.Eventually(t, func() bool { return calls.count() >= 2 }, testutil.WaitShort, testutil.IntervalFast, + "Reload must fire again after a pubsub notification") +} + +func TestSubscribeProviderReloadSurfacesReloadError(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitMedium) + + logger := slogtest.Make(t, nil) + ps := dbpubsub.NewInMemory() + t.Cleanup(func() { _ = ps.Close() }) + + calls := &recordingReloader{returnErr: true} + + unsub, err := aibridged.SubscribeProviderReload(ctx, ps, calls, logger) + require.NoError(t, err) + t.Cleanup(unsub) + + require.Equal(t, 1, calls.count()) + require.NoError(t, ps.Publish(pubsub.AIProvidersChangedChannel, nil)) + require.Eventually(t, func() bool { return calls.count() >= 2 }, testutil.WaitShort, testutil.IntervalFast, + "Reload must keep firing even after a previous Reload returned an error") +} + +func TestSubscribeProviderReloadIgnoresEventError(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitMedium) + + logger := slogtest.Make(t, nil) + ps := &errInjectingPubsub{} + + calls := &recordingReloader{} + unsub, err := aibridged.SubscribeProviderReload(ctx, ps, calls, logger) + require.NoError(t, err) + t.Cleanup(unsub) + + require.Equal(t, 1, calls.count()) + + ps.listener(ctx, nil, errPubsubDelivery) + require.Equal(t, 1, calls.count()) + + ps.listener(ctx, nil, nil) + require.Equal(t, 2, calls.count()) +} + +// recordingReloader is a minimal [aibridged.ProviderReloader] that +// counts calls. +type recordingReloader struct { + n atomic.Int32 + returnErr bool +} + +func (r *recordingReloader) Reload(_ context.Context) error { + r.n.Add(1) + if r.returnErr { + return errReloadFailed + } + return nil +} + +func (r *recordingReloader) count() int { + return int(r.n.Load()) +} + +var ( + errReloadFailed = stubError("reload failed") + errPubsubDelivery = stubError("pubsub delivery failed") +) + +type stubError string + +func (s stubError) Error() string { return string(s) } + +var _ dbpubsub.Pubsub = &errInjectingPubsub{} + +type errInjectingPubsub struct { + listener dbpubsub.ListenerWithErr +} + +func (*errInjectingPubsub) Subscribe(string, dbpubsub.Listener) (func(), error) { + return nil, xerrors.New("Subscribe not implemented") +} + +func (p *errInjectingPubsub) SubscribeWithErr(_ string, listener dbpubsub.ListenerWithErr) (func(), error) { + p.listener = listener + return func() {}, nil +} + +func (*errInjectingPubsub) Publish(string, []byte) error { + return xerrors.New("Publish not implemented") +} + +func (*errInjectingPubsub) Close() error { + return nil +} diff --git a/coderd/pubsub/aiproviderschangedevent.go b/coderd/pubsub/aiproviderschangedevent.go new file mode 100644 index 0000000000..a0ff20f960 --- /dev/null +++ b/coderd/pubsub/aiproviderschangedevent.go @@ -0,0 +1,11 @@ +package pubsub + +// AIProvidersChangedChannel is the pubsub channel that carries AI +// provider lifecycle events: provider create / update / soft-delete +// and key insert / delete. Subscribers (aibridged, aibridgeproxyd) +// reload their in-memory provider snapshot on receipt. +// +// The payload is an empty invalidation hint; subscribers refetch the +// authoritative state from the database, so dropped messages only +// delay convergence rather than diverge state. +const AIProvidersChangedChannel = "ai_providers_changed" diff --git a/enterprise/aibridgeproxyd/aibridgeproxyd.go b/enterprise/aibridgeproxyd/aibridgeproxyd.go index c32c5c41c5..19d05ca511 100644 --- a/enterprise/aibridgeproxyd/aibridgeproxyd.go +++ b/enterprise/aibridgeproxyd/aibridgeproxyd.go @@ -18,6 +18,7 @@ import ( "strconv" "strings" "sync" + "sync/atomic" "syscall" "time" @@ -30,6 +31,18 @@ import ( agplaibridge "github.com/coder/coder/v2/coderd/aibridge" ) +// ProviderRoute is the routing entry for a single AI provider: the +// instance name (the routing key) and the upstream base URL (the +// source of the MITM allowlist host). +type ProviderRoute struct { + Name string + BaseURL string +} + +// RefreshProvidersFunc returns the live provider set used by Reload to +// rebuild the proxy's routing snapshot. +type RefreshProvidersFunc func(ctx context.Context) ([]ProviderRoute, error) + // Known AI provider hosts. const ( HostAnthropic = "api.anthropic.com" @@ -119,14 +132,21 @@ var blockedIPRanges = func() []net.IPNet { // - decrypting requests using the configured MITM CA certificate // - forwarding requests to aibridged for processing type Server struct { - ctx context.Context - logger slog.Logger - proxy *goproxy.ProxyHttpServer - httpServer *http.Server - listener net.Listener - tlsEnabled bool - coderAccessURL *url.URL - aibridgeProviderFromHost func(host string) string + ctx context.Context + logger slog.Logger + proxy *goproxy.ProxyHttpServer + httpServer *http.Server + listener net.Listener + tlsEnabled bool + coderAccessURL *url.URL + // refreshProviders fetches the live provider snapshot on Reload. + // Nil disables hot-reload. + refreshProviders RefreshProvidersFunc + // providerRouter holds the live (mitmHosts, nameByHost) pair. + providerRouter atomic.Pointer[providerRouter] + // allowedPorts is the port allowlist for CONNECT requests. Fixed at + // construction; not reloadable. + allowedPorts []string // caCert is the PEM-encoded MITM CA certificate loaded during initialization. // This is served to clients who need to trust the proxy's generated certificates. caCert []byte @@ -139,6 +159,21 @@ type Server struct { metrics *Metrics } +// providerRouter keeps CONNECT matching and provider lookup in sync. +type providerRouter struct { + mitmHosts []string // host:port allowlist for the goproxy condition. + nameByHost map[string]string // lowercase hostname -> provider name. +} + +// emptyProviderRouter is used before the first Reload (or when the +// operator deconfigures every provider) so handlers can safely call +// loadProviderRouter without a nil check. +var emptyProviderRouter = &providerRouter{nameByHost: map[string]string{}} + +func (r *providerRouter) providerFromHost(host string) string { + return r.nameByHost[strings.ToLower(host)] +} + // requestContext holds metadata propagated through the proxy request/response chain. // It is stored in goproxy's ProxyCtx.UserData and enriched as the request progresses // through the proxy handlers. @@ -183,13 +218,12 @@ type Options struct { // CertStore is an optional certificate cache for MITM. If nil, a default // cache is created. Exposed for testing. CertStore goproxy.CertStorage - // DomainAllowlist is the list of domains to intercept and route through AI Bridge. - // Only requests to these domains will be MITM'd and forwarded to aibridged. - // Requests to other domains will be tunneled directly without decryption. + // DomainAllowlist seeds the boot-time MITM allowlist. Production + // callers should leave this empty and rely on RefreshProviders; + // tests use it to skip the refresh round-trip. DomainAllowlist []string - // AIBridgeProviderFromHost maps a hostname to a known aibridge provider - // name. Must be non-nil; the caller derives it from the configured - // provider list. + // AIBridgeProviderFromHost seeds the boot-time host -> provider + // name mapping. Required iff DomainAllowlist is non-empty. AIBridgeProviderFromHost func(host string) string // UpstreamProxy is the URL of an upstream HTTP proxy to chain tunneled // (non-allowlisted) requests through. If empty, tunneled requests connect @@ -214,6 +248,10 @@ type Options struct { // Metrics is the prometheus metrics instance for recording proxy metrics. // If nil, metrics will not be recorded. Metrics *Metrics + // RefreshProviders, when set, is invoked by Server.Reload to fetch + // the live provider snapshot used to derive the MITM allowlist and + // host -> provider-name routing. Nil disables hot-reload. + RefreshProviders RefreshProvidersFunc } func New(ctx context.Context, logger slog.Logger, opts Options) (*Server, error) { @@ -258,29 +296,12 @@ func New(ctx context.Context, logger slog.Logger, opts Options) (*Server, error) allowedPorts = []string{"80", "443"} } - // An empty allowlist is permitted so the server can boot before any - // ai_providers row exists; every intercept attempt is then rejected - // until providers are configured. - // TODO: refresh the allowlist when ai_providers changes so a restart - // is not required after the first provider is configured. - mitmHosts, err := convertDomainsToHosts(opts.DomainAllowlist, allowedPorts) + // Build the boot-time router from DomainAllowlist + the lookup fn. + // Both empty is fine: the server fails closed (no MITM until + // Reload populates the router from the database). + bootRouter, err := buildBootRouter(opts.DomainAllowlist, opts.AIBridgeProviderFromHost, allowedPorts) if err != nil { - return nil, xerrors.Errorf("invalid domain allowlist: %w", err) - } - - if opts.AIBridgeProviderFromHost == nil { - return nil, xerrors.New("AIBridgeProviderFromHost is required") - } - aibridgeProviderFromHost := opts.AIBridgeProviderFromHost - - for _, domain := range opts.DomainAllowlist { - domain = strings.TrimSpace(strings.ToLower(domain)) - if domain == "" { - continue - } - if aibridgeProviderFromHost(domain) == "" { - return nil, xerrors.Errorf("domain %q is in allowlist but has no provider mapping", domain) - } + return nil, err } // Parse configured exceptions to the blocked IP ranges. @@ -319,17 +340,22 @@ func New(ctx context.Context, logger slog.Logger, opts Options) (*Server, error) } srv := &Server{ - ctx: ctx, - logger: logger, - proxy: proxy, - tlsEnabled: opts.TLSCertFile != "", - coderAccessURL: coderAccessURL, - aibridgeProviderFromHost: aibridgeProviderFromHost, - caCert: certPEM, - allowedPrivateRanges: allowedPrivateRanges, - newDumper: opts.NewDumper, - metrics: opts.Metrics, + ctx: ctx, + logger: logger, + proxy: proxy, + tlsEnabled: opts.TLSCertFile != "", + coderAccessURL: coderAccessURL, + refreshProviders: opts.RefreshProviders, + allowedPorts: allowedPorts, + caCert: certPEM, + allowedPrivateRanges: allowedPrivateRanges, + newDumper: opts.NewDumper, + metrics: opts.Metrics, } + // Seed the boot-time router from the constructor inputs so the + // proxy can serve immediately. Reload may swap this snapshot at any + // point after construction. + srv.providerRouter.Store(bootRouter) // Configure upstream proxy for tunneled (non-allowlisted) CONNECT requests. // Allowlisted domains are MITM'd and forwarded to aibridge directly, @@ -417,12 +443,11 @@ func New(ctx context.Context, logger slog.Logger, opts Options) (*Server, error) // Reject CONNECT requests to non-standard ports. proxy.OnRequest().HandleConnectFunc(srv.portMiddleware(allowedPorts)) - // Apply MITM with authentication only to allowlisted hosts. - proxy.OnRequest( - // Only CONNECT requests to these hosts will be intercepted and decrypted. - // All other requests will be tunneled directly to their destination. - goproxy.ReqHostIs(mitmHosts...), - ).HandleConnectFunc( + // Apply MITM with authentication only to allowlisted hosts. The host + // list is loaded from the atomic router on every CONNECT so a + // Reload while inflight requests are in progress takes effect on + // the next CONNECT without touching the already-MITM'd ones. + proxy.OnRequest(srv.mitmHostsCondition()).HandleConnectFunc( // Extract Coder token from proxy authentication to forward to aibridged. srv.authMiddleware, ) @@ -470,7 +495,7 @@ func New(ctx context.Context, logger slog.Logger, opts Options) (*Server, error) slog.F("listen_addr", listener.Addr().String()), slog.F("tls_listener_enabled", srv.tlsEnabled), slog.F("coder_access_url", coderAccessURL.String()), - slog.F("domain_allowlist", mitmHosts), + slog.F("domain_allowlist", bootRouter.mitmHosts), slog.F("upstream_proxy", opts.UpstreamProxy), slog.F("allowed_private_cidrs", opts.AllowedPrivateCIDRs), slog.F("api_dump_enabled", opts.NewDumper != nil), @@ -651,11 +676,11 @@ func (s *Server) authMiddleware(host string, ctx *goproxy.ProxyCtx) (*goproxy.Co ) // Determine the provider from the request hostname. - provider := s.aibridgeProviderFromHost(ctx.Req.URL.Hostname()) - // This should never happen: startup validation ensures all allowlisted - // domains have known aibridge provider mappings. + provider := s.loadProviderRouter().providerFromHost(ctx.Req.URL.Hostname()) + // A concurrent Reload can swap the router between CONNECT matching + // and provider lookup, so treat a missing mapping as a runtime miss. if provider == "" { - logger.Error(s.ctx, "rejecting CONNECT request with no provider mapping") + logger.Warn(s.ctx, "rejecting CONNECT request with no provider mapping") return goproxy.RejectConnect, host } @@ -922,12 +947,11 @@ func (s *Server) handleRequest(req *http.Request, ctx *goproxy.ProxyCtx) (*http. } if reqCtx.Provider == "" { - // This should never happen: startup validation ensures all allowlisted - // domains have known aibridge provider mappings. - // The request is MITM'd (decrypted) but since there is no mapping, - // there is no known route to aibridge. - // Log error and forward to the original destination as a fallback. - s.logger.Error(s.ctx, "decrypted request has no provider mapping, passing through", + // A concurrent Reload can remove the provider after CONNECT + // authentication. The request is MITM'd (decrypted), but without a + // mapping there is no known route to aibridge. Log and forward + // to the original destination as a fallback. + s.logger.Warn(s.ctx, "decrypted request has no provider mapping, passing through", slog.F("connect_id", reqCtx.ConnectSessionID.String()), slog.F("host", req.Host), slog.F("method", req.Method), diff --git a/enterprise/aibridgeproxyd/aibridgeproxyd_test.go b/enterprise/aibridgeproxyd/aibridgeproxyd_test.go index 334fd289f8..6b843d8b14 100644 --- a/enterprise/aibridgeproxyd/aibridgeproxyd_test.go +++ b/enterprise/aibridgeproxyd/aibridgeproxyd_test.go @@ -158,6 +158,7 @@ type testProxyConfig struct { allowedPrivateCIDRs []string newDumper func(string, string) aibridgeproxyd.RoundTripDumper metrics *aibridgeproxyd.Metrics + refreshProviders aibridgeproxyd.RefreshProvidersFunc } type testProxyOption func(*testProxyConfig) @@ -250,6 +251,12 @@ func withListenerTLS(certFile, keyFile string) testProxyOption { } } +func withRefreshProviders(fn aibridgeproxyd.RefreshProvidersFunc) testProxyOption { + return func(cfg *testProxyConfig) { + cfg.refreshProviders = fn + } +} + // newTestProxy creates a new AI Bridge Proxy server for testing. // It uses the shared MITM certificate and registers cleanup automatically. // It waits for the proxy server to be ready before returning. @@ -289,6 +296,7 @@ func newTestProxy(t *testing.T, opts ...testProxyOption) *aibridgeproxyd.Server AllowedPrivateCIDRs: cfg.allowedPrivateCIDRs, NewDumper: cfg.newDumper, Metrics: cfg.metrics, + RefreshProviders: cfg.refreshProviders, } if cfg.certStore != nil { aibridgeOpts.CertStore = cfg.certStore diff --git a/enterprise/aibridgeproxyd/reload.go b/enterprise/aibridgeproxyd/reload.go new file mode 100644 index 0000000000..d235a9be3b --- /dev/null +++ b/enterprise/aibridgeproxyd/reload.go @@ -0,0 +1,124 @@ +package aibridgeproxyd + +import ( + "context" + "net/http" + "net/url" + "slices" + "strings" + + "github.com/elazarl/goproxy" + "golang.org/x/xerrors" + + "cdr.dev/slog/v3" +) + +// Reload refreshes proxy routing from the configured provider source. +// A refresh failure leaves the previous snapshot in place. +func (s *Server) Reload(ctx context.Context) error { + if s.refreshProviders == nil { + return nil + } + providers, err := s.refreshProviders(ctx) + if err != nil { + return xerrors.Errorf("refresh ai providers for proxy routing: %w", err) + } + router, err := buildProviderRouter(ctx, s.logger, providers, s.allowedPorts) + if err != nil { + return xerrors.Errorf("build provider router (provider_count=%d): %w", len(providers), err) + } + s.providerRouter.Store(router) + s.logger.Debug(s.ctx, "aibridgeproxyd router reloaded", + slog.F("mitm_host_count", len(router.mitmHosts)), + ) + return nil +} + +func (s *Server) loadProviderRouter() *providerRouter { + if p := s.providerRouter.Load(); p != nil { + return p + } + return emptyProviderRouter +} + +// mitmHostsCondition returns a goproxy ReqConditionFunc that reads the +// allowlist from the atomic router on every match. Using a closure +// instead of goproxy.ReqHostIs(...) lets Reload affect every later +// CONNECT without re-registering handlers. +func (s *Server) mitmHostsCondition() goproxy.ReqConditionFunc { + return func(req *http.Request, _ *goproxy.ProxyCtx) bool { + if req == nil { + return false + } + return slices.Contains(s.loadProviderRouter().mitmHosts, strings.ToLower(req.URL.Host)) + } +} + +// buildProviderRouter constructs a router snapshot from a refreshed +// provider list. First provider wins on duplicate hostnames. +func buildProviderRouter(ctx context.Context, logger slog.Logger, providers []ProviderRoute, allowedPorts []string) (*providerRouter, error) { + nameByHost := make(map[string]string, len(providers)) + var domains []string + for _, p := range providers { + if p.BaseURL == "" { + logger.Warn(ctx, "skipping ai provider without base url", + slog.F("provider_name", p.Name), + ) + continue + } + u, err := url.Parse(p.BaseURL) + if err != nil { + logger.Warn(ctx, "skipping ai provider with invalid base url", + slog.F("provider_name", p.Name), + slog.F("base_url", p.BaseURL), + slog.Error(err), + ) + continue + } + if u.Hostname() == "" { + logger.Warn(ctx, "skipping ai provider base url without hostname", + slog.F("provider_name", p.Name), + slog.F("base_url", p.BaseURL), + ) + continue + } + host := strings.ToLower(u.Hostname()) + if _, exists := nameByHost[host]; exists { + continue + } + nameByHost[host] = p.Name + domains = append(domains, host) + } + mitmHosts, err := convertDomainsToHosts(domains, allowedPorts) + if err != nil { + return nil, err + } + return &providerRouter{mitmHosts: mitmHosts, nameByHost: nameByHost}, nil +} + +// buildBootRouter seeds the providerRouter from the boot-time inputs. +// The lookup function is consulted only for hosts in the allowlist; a +// nil function with an empty allowlist is fine and yields an empty +// router (the proxy fails closed until Reload populates it). +func buildBootRouter(domainAllowlist []string, providerFromHost func(string) string, allowedPorts []string) (*providerRouter, error) { + mitmHosts, err := convertDomainsToHosts(domainAllowlist, allowedPorts) + if err != nil { + return nil, xerrors.Errorf("invalid domain allowlist: %w", err) + } + nameByHost := make(map[string]string, len(domainAllowlist)) + for _, domain := range domainAllowlist { + domain = strings.TrimSpace(strings.ToLower(domain)) + if domain == "" { + continue + } + var name string + if providerFromHost != nil { + name = providerFromHost(domain) + } + if name == "" { + return nil, xerrors.Errorf("domain %q is in allowlist but has no provider mapping", domain) + } + nameByHost[domain] = name + } + return &providerRouter{mitmHosts: mitmHosts, nameByHost: nameByHost}, nil +} diff --git a/enterprise/aibridgeproxyd/reload_internal_test.go b/enterprise/aibridgeproxyd/reload_internal_test.go new file mode 100644 index 0000000000..7a8b0f9fae --- /dev/null +++ b/enterprise/aibridgeproxyd/reload_internal_test.go @@ -0,0 +1,147 @@ +package aibridgeproxyd + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/xerrors" + + "cdr.dev/slog/v3/sloggers/slogtest" + "github.com/coder/coder/v2/testutil" +) + +func TestServerReloadSwapsProviderRouter(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitShort) + providers := []ProviderRoute{{Name: "old", BaseURL: "https://old.example.com/"}} + srv := &Server{ + ctx: ctx, + logger: slogtest.Make(t, nil), + allowedPorts: []string{"443"}, + refreshProviders: func(context.Context) ([]ProviderRoute, error) { + return providers, nil + }, + } + srv.providerRouter.Store(emptyProviderRouter) + + require.NoError(t, srv.Reload(ctx)) + assert.Equal(t, "old", srv.loadProviderRouter().providerFromHost("old.example.com")) + assert.Empty(t, srv.loadProviderRouter().providerFromHost("new.example.com")) + + providers = []ProviderRoute{{Name: "new", BaseURL: "https://new.example.com/"}} + require.NoError(t, srv.Reload(ctx)) + + router := srv.loadProviderRouter() + assert.Empty(t, router.providerFromHost("old.example.com")) + assert.Equal(t, "new", router.providerFromHost("new.example.com")) + assert.Equal(t, []string{"new.example.com:443"}, router.mitmHosts) +} + +func TestServerReloadPreservesProviderRouterOnRefreshError(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitShort) + refreshErr := xerrors.New("refresh failed") + providers := []ProviderRoute{{Name: "old", BaseURL: "https://old.example.com/"}} + failRefresh := false + srv := &Server{ + ctx: ctx, + logger: slogtest.Make(t, nil), + allowedPorts: []string{"443"}, + refreshProviders: func(context.Context) ([]ProviderRoute, error) { + if failRefresh { + return nil, refreshErr + } + return providers, nil + }, + } + srv.providerRouter.Store(emptyProviderRouter) + + require.NoError(t, srv.Reload(ctx)) + before := srv.loadProviderRouter() + assert.Equal(t, "old", before.providerFromHost("old.example.com")) + + failRefresh = true + require.ErrorIs(t, srv.Reload(ctx), refreshErr) + + after := srv.loadProviderRouter() + assert.Same(t, before, after) + assert.Equal(t, "old", after.providerFromHost("old.example.com")) + assert.Equal(t, []string{"old.example.com:443"}, after.mitmHosts) +} + +// TestBuildProviderRouter covers the host-and-routing derivation that +// Reload feeds into the providerRouter. +func TestBuildProviderRouter(t *testing.T) { + t.Parallel() + + t.Run("ExtractsHostnames", func(t *testing.T) { + t.Parallel() + + providers := []ProviderRoute{ + {Name: "openai", BaseURL: "https://api.openai.com/v1/"}, + {Name: "anthropic", BaseURL: "https://api.anthropic.com/"}, + {Name: "custom", BaseURL: "https://custom-llm.example.com:8443/api"}, + } + + router, err := buildProviderRouter(testutil.Context(t, testutil.WaitShort), slogtest.Make(t, nil), providers, []string{"443"}) + require.NoError(t, err) + + assert.Equal(t, "openai", router.providerFromHost("api.openai.com")) + assert.Equal(t, "anthropic", router.providerFromHost("api.anthropic.com")) + assert.Equal(t, "custom", router.providerFromHost("custom-llm.example.com")) + assert.Empty(t, router.providerFromHost("unknown.com")) + + assert.Contains(t, router.mitmHosts, "api.openai.com:443") + assert.Contains(t, router.mitmHosts, "api.anthropic.com:443") + }) + + t.Run("DeduplicatesSameHost", func(t *testing.T) { + t.Parallel() + + providers := []ProviderRoute{ + {Name: "first", BaseURL: "https://api.example.com/v1"}, + {Name: "second", BaseURL: "https://api.example.com/v2"}, + } + + router, err := buildProviderRouter(testutil.Context(t, testutil.WaitShort), slogtest.Make(t, nil), providers, []string{"443"}) + require.NoError(t, err) + + // First provider wins on duplicate host. + assert.Equal(t, "first", router.providerFromHost("api.example.com")) + }) + + t.Run("CaseInsensitive", func(t *testing.T) { + t.Parallel() + + providers := []ProviderRoute{ + {Name: "provider", BaseURL: "https://API.Example.COM/v1"}, + } + + router, err := buildProviderRouter(testutil.Context(t, testutil.WaitShort), slogtest.Make(t, nil), providers, []string{"443"}) + require.NoError(t, err) + + assert.Equal(t, "provider", router.providerFromHost("API.Example.COM")) + assert.Equal(t, "provider", router.providerFromHost("api.example.com")) + }) + + t.Run("SkipsEmptyOrMalformedBaseURL", func(t *testing.T) { + t.Parallel() + + providers := []ProviderRoute{ + {Name: "no-url"}, + {Name: "scheme-only", BaseURL: "https://"}, + {Name: "good", BaseURL: "https://api.good.example.com/"}, + } + + router, err := buildProviderRouter(testutil.Context(t, testutil.WaitShort), slogtest.Make(t, nil), providers, []string{"443"}) + require.NoError(t, err) + + assert.Equal(t, "good", router.providerFromHost("api.good.example.com")) + assert.Empty(t, router.providerFromHost("scheme-only")) + assert.Equal(t, []string{"api.good.example.com:443"}, router.mitmHosts) + }) +} diff --git a/enterprise/aibridgeproxyd/reload_test.go b/enterprise/aibridgeproxyd/reload_test.go new file mode 100644 index 0000000000..d51aa5bc98 --- /dev/null +++ b/enterprise/aibridgeproxyd/reload_test.go @@ -0,0 +1,379 @@ +package aibridgeproxyd_test + +import ( + "context" + "io" + "net/http" + "net/http/httptest" + "slices" + "strings" + "sync" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/xerrors" + + "github.com/coder/coder/v2/enterprise/aibridgeproxyd" + "github.com/coder/coder/v2/testutil" +) + +// reloadTestHarness wires a real proxy server to a mutable provider +// store and a mock aibridged backend so tests can drive Reload through +// a CRUD-style sequence and observe routing via real proxy requests. +type reloadTestHarness struct { + srv *aibridgeproxyd.Server + store *providerStore + client *http.Client + bridged *httptest.Server + recorder *aibridgedRecorder +} + +// aibridgedRecorder captures the path of the last request received by +// the mock aibridged backend. Access is mutex-guarded so the test +// goroutine and the proxy's response goroutine can read/write safely. +type aibridgedRecorder struct { + mu sync.Mutex + path string +} + +func (r *aibridgedRecorder) record(path string) { + r.mu.Lock() + defer r.mu.Unlock() + r.path = path +} + +func (r *aibridgedRecorder) load() string { + r.mu.Lock() + defer r.mu.Unlock() + return r.path +} + +func (r *aibridgedRecorder) reset() { + r.mu.Lock() + defer r.mu.Unlock() + r.path = "" +} + +// providerStore is a mutable [aibridgeproxyd.RefreshProvidersFunc] +// backing for integration tests. set / setErr mutate the snapshot +// returned by the next Reload, mimicking CRUD against the database. +type providerStore struct { + mu sync.Mutex + providers []aibridgeproxyd.ProviderRoute + err error +} + +func (s *providerStore) set(providers []aibridgeproxyd.ProviderRoute) { + s.mu.Lock() + defer s.mu.Unlock() + s.providers = providers + s.err = nil +} + +func (s *providerStore) setErr(err error) { + s.mu.Lock() + defer s.mu.Unlock() + s.err = err +} + +func (s *providerStore) refresh(context.Context) ([]aibridgeproxyd.ProviderRoute, error) { + s.mu.Lock() + defer s.mu.Unlock() + if s.err != nil { + return nil, s.err + } + // Return a copy so callers can't mutate our internal snapshot. + return slices.Clone(s.providers), nil +} + +// newReloadTestHarness boots a proxy with an empty boot allowlist and a +// store-backed RefreshProviders. Production wiring is identical: the +// daemon constructs the proxy without a static allowlist and lets +// Reload populate the router from the database. +func newReloadTestHarness(t *testing.T) *reloadTestHarness { + t.Helper() + + recorder := &aibridgedRecorder{} + bridged := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + recorder.record(r.URL.Path) + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("aibridged")) + })) + t.Cleanup(bridged.Close) + + store := &providerStore{} + srv := newTestProxy(t, + withCoderAccessURL(bridged.URL), + withAllowedPorts("443"), + // Empty boot allowlist: the router must be populated by Reload, + // matching the production daemon's behavior. + withDomainAllowlist(), + withAIBridgeProviderFromHost(nil), + withRefreshProviders(store.refresh), + ) + + certPool := getProxyCertPool(t) + client := newProxyClient(t, srv, makeProxyAuthHeader("coder-token"), certPool, false) + // Disable keep-alives so each request opens a fresh CONNECT through + // the proxy. Per the Reload contract, already-MITM'd tunnels keep + // the provider name they captured at CONNECT time; only new + // connections see the post-Reload snapshot. Tests need a fresh + // CONNECT between phases to assert on the new routing. + client.Transport.(*http.Transport).DisableKeepAlives = true + + return &reloadTestHarness{ + srv: srv, + store: store, + client: client, + bridged: bridged, + recorder: recorder, + } +} + +// requestResult is the outcome of sending a request through the proxy. +// Either err is set (CONNECT failed for a non-MITM'd host whose dial +// fell through to the tunneled path and could not be resolved) or +// status/body carry the MITM'd response from the mock aibridged. +type requestResult struct { + status int + body string + err error +} + +// sendRequest issues a single POST through the proxy. It returns rather +// than asserting so callers can branch on whether the host is currently +// routed (MITM'd to aibridged) or not (tunneled, dial of an unresolvable +// host fails). +func (h *reloadTestHarness) sendRequest(t *testing.T, targetURL string) requestResult { + t.Helper() + + ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitShort) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, strings.NewReader(`{}`)) + require.NoError(t, err) + req.Header.Set("Content-Type", "application/json") + + resp, err := h.client.Do(req) + if err != nil { + return requestResult{err: err} + } + defer resp.Body.Close() + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + return requestResult{status: resp.StatusCode, body: string(body)} +} + +// expectRoutedTo asserts the proxy MITM'd the request and forwarded it +// to aibridged with the expected /api/v2/aibridge//. +func (h *reloadTestHarness) expectRoutedTo(t *testing.T, targetURL, expectedPath string) { + t.Helper() + + h.recorder.reset() + res := h.sendRequest(t, targetURL) + require.NoError(t, res.err, "request to routed host must succeed") + require.Equal(t, http.StatusOK, res.status) + require.Equal(t, "aibridged", res.body) + require.Equal(t, expectedPath, h.recorder.load(), + "aibridged must observe the rewritten path for %s", targetURL) +} + +// expectNotRouted asserts the proxy did not MITM the request for the +// given host. The CONNECT either falls through to the tunneled path +// (where the .invalid hostname fails to dial) or to a 502 from the +// proxy. Either way, aibridged never sees the request. +func (h *reloadTestHarness) expectNotRouted(t *testing.T, targetURL string) { + t.Helper() + + h.recorder.reset() + _ = h.sendRequest(t, targetURL) + require.Empty(t, h.recorder.load(), + "aibridged must not be reached for non-routed host %s", targetURL) +} + +// TestProxy_HotReloadRoutingCRUD drives the proxy through a CRUD-style +// sequence of provider changes and asserts on routing after each +// Reload via real HTTPS requests. Each sub-test mutates the store and +// validates that: +// - newly created providers are MITM'd to aibridged with the right +// /api/v2/aibridge// +// - renamed providers route under the new name +// - providers whose BaseURL host changes route the new host and stop +// MITM'ing the old host +// - deleted providers stop being MITM'd; aibridged sees nothing +// +// Hostnames are .invalid (RFC 2606) so a request that escapes the MITM +// path fails fast via DNS rather than reaching a real upstream. +func TestProxy_HotReloadRoutingCRUD(t *testing.T) { + t.Parallel() + + h := newReloadTestHarness(t) + + // InitialEmptyRouter: no Reload has been called and the boot + // allowlist is empty, so any host falls through to the tunneled + // middleware. + h.expectNotRouted(t, "https://alpha.invalid/v1/messages") + + // CreateProvider. + h.store.set([]aibridgeproxyd.ProviderRoute{ + {Name: "alpha", BaseURL: "https://alpha.invalid/v1"}, + }) + require.NoError(t, h.srv.Reload(t.Context())) + h.expectRoutedTo(t, "https://alpha.invalid/v1/messages", "/api/v2/aibridge/alpha/v1/messages") + + // UpdateProviderName: the same BaseURL with a new name must route + // under the new name on the next Reload. + h.store.set([]aibridgeproxyd.ProviderRoute{ + {Name: "alpha-v2", BaseURL: "https://alpha.invalid/v1"}, + }) + require.NoError(t, h.srv.Reload(t.Context())) + h.expectRoutedTo(t, "https://alpha.invalid/v1/messages", "/api/v2/aibridge/alpha-v2/v1/messages") + + // UpdateProviderBaseURLHost: moving the provider to a new host must + // start MITM'ing the new host and stop MITM'ing the old one. + h.store.set([]aibridgeproxyd.ProviderRoute{ + {Name: "alpha-v2", BaseURL: "https://alpha-new.invalid/v1"}, + }) + require.NoError(t, h.srv.Reload(t.Context())) + h.expectRoutedTo(t, "https://alpha-new.invalid/v1/messages", "/api/v2/aibridge/alpha-v2/v1/messages") + h.expectNotRouted(t, "https://alpha.invalid/v1/messages") + + // AddSecondProvider: a second provider added in the same Reload must + // route independently from the first. + h.store.set([]aibridgeproxyd.ProviderRoute{ + {Name: "alpha-v2", BaseURL: "https://alpha-new.invalid/v1"}, + {Name: "beta", BaseURL: "https://beta.invalid/v1"}, + }) + require.NoError(t, h.srv.Reload(t.Context())) + h.expectRoutedTo(t, "https://alpha-new.invalid/v1/messages", "/api/v2/aibridge/alpha-v2/v1/messages") + h.expectRoutedTo(t, "https://beta.invalid/v1/chat/completions", "/api/v2/aibridge/beta/v1/chat/completions") + + // DeleteOneProvider: removing alpha must keep beta routed and stop + // routing alpha. + h.store.set([]aibridgeproxyd.ProviderRoute{ + {Name: "beta", BaseURL: "https://beta.invalid/v1"}, + }) + require.NoError(t, h.srv.Reload(t.Context())) + h.expectRoutedTo(t, "https://beta.invalid/v1/chat/completions", "/api/v2/aibridge/beta/v1/chat/completions") + h.expectNotRouted(t, "https://alpha-new.invalid/v1/messages") + + // DeleteAllProviders: an empty Reload must collapse the router to + // the fail-closed state with no host MITM'd. + h.store.set(nil) + require.NoError(t, h.srv.Reload(t.Context())) + h.expectNotRouted(t, "https://beta.invalid/v1/chat/completions") + h.expectNotRouted(t, "https://alpha-new.invalid/v1/messages") + + // RecreateAfterDelete: reintroducing a previously-deleted provider + // must route again without restart, confirming the swap is + // symmetric. + h.store.set([]aibridgeproxyd.ProviderRoute{ + {Name: "alpha", BaseURL: "https://alpha.invalid/v1"}, + }) + require.NoError(t, h.srv.Reload(t.Context())) + h.expectRoutedTo(t, "https://alpha.invalid/v1/messages", "/api/v2/aibridge/alpha/v1/messages") +} + +// TestProxy_HotReloadRoutingInvalidProviders covers the resilience +// requirements stated in the [aibridgeproxyd.Server.Reload] contract: +// individual invalid provider entries do not poison the snapshot, and +// a refresh-level error does not collapse the previous snapshot to +// empty. +func TestProxy_HotReloadRoutingInvalidProviders(t *testing.T) { + t.Parallel() + + t.Run("EmptyBaseURLSkipped", func(t *testing.T) { + t.Parallel() + + h := newReloadTestHarness(t) + // One valid provider and one with an empty BaseURL. The empty + // entry must be silently dropped; the valid one must still + // route. + h.store.set([]aibridgeproxyd.ProviderRoute{ + {Name: "no-url"}, + {Name: "valid", BaseURL: "https://valid.invalid/v1"}, + }) + require.NoError(t, h.srv.Reload(t.Context())) + + h.expectRoutedTo(t, "https://valid.invalid/v1/messages", "/api/v2/aibridge/valid/v1/messages") + }) + + t.Run("MalformedBaseURLSkipped", func(t *testing.T) { + t.Parallel() + + h := newReloadTestHarness(t) + // A BaseURL that fails url.Parse and one whose Hostname() is + // empty must both be dropped. Mixed with a valid entry, only + // the valid one routes. + h.store.set([]aibridgeproxyd.ProviderRoute{ + {Name: "malformed", BaseURL: "://not-a-url"}, + {Name: "no-host", BaseURL: "https://"}, + {Name: "valid", BaseURL: "https://valid.invalid/v1"}, + }) + require.NoError(t, h.srv.Reload(t.Context())) + + h.expectRoutedTo(t, "https://valid.invalid/v1/messages", "/api/v2/aibridge/valid/v1/messages") + }) + + t.Run("DuplicateHostFirstWins", func(t *testing.T) { + t.Parallel() + + h := newReloadTestHarness(t) + // Two providers with the same BaseURL host: the first one wins, + // matching buildProviderRouter's documented contract. + h.store.set([]aibridgeproxyd.ProviderRoute{ + {Name: "first", BaseURL: "https://shared.invalid/v1"}, + {Name: "second", BaseURL: "https://shared.invalid/v2"}, + }) + require.NoError(t, h.srv.Reload(t.Context())) + + h.expectRoutedTo(t, "https://shared.invalid/v1/messages", "/api/v2/aibridge/first/v1/messages") + }) + + t.Run("AllInvalidYieldsEmptyRouter", func(t *testing.T) { + t.Parallel() + + h := newReloadTestHarness(t) + // When every provider is invalid, the router contains no + // entries and the proxy fails closed: no host is MITM'd. + h.store.set([]aibridgeproxyd.ProviderRoute{ + {Name: "no-url"}, + {Name: "malformed", BaseURL: "://not-a-url"}, + {Name: "no-host", BaseURL: "https://"}, + }) + require.NoError(t, h.srv.Reload(t.Context())) + + h.expectNotRouted(t, "https://anything.invalid/v1/messages") + }) + + t.Run("RefreshErrorPreservesPreviousSnapshot", func(t *testing.T) { + t.Parallel() + + h := newReloadTestHarness(t) + // Seed a valid snapshot so we have something to preserve. + h.store.set([]aibridgeproxyd.ProviderRoute{ + {Name: "alpha", BaseURL: "https://alpha.invalid/v1"}, + }) + require.NoError(t, h.srv.Reload(t.Context())) + h.expectRoutedTo(t, "https://alpha.invalid/v1/messages", "/api/v2/aibridge/alpha/v1/messages") + + // A refresh error must NOT clear the router: dropping the + // allowlist on every transient DB hiccup would amplify the + // fault into a denial of service. + h.store.setErr(xerrors.New("simulated db failure")) + err := h.srv.Reload(t.Context()) + require.Error(t, err) + assert.Contains(t, err.Error(), "refresh ai providers for proxy routing") + h.expectRoutedTo(t, "https://alpha.invalid/v1/messages", "/api/v2/aibridge/alpha/v1/messages") + + // Recovery: once the store returns providers again, the next + // Reload applies the new snapshot. + h.store.set([]aibridgeproxyd.ProviderRoute{ + {Name: "beta", BaseURL: "https://beta.invalid/v1"}, + }) + require.NoError(t, h.srv.Reload(t.Context())) + h.expectRoutedTo(t, "https://beta.invalid/v1/messages", "/api/v2/aibridge/beta/v1/messages") + h.expectNotRouted(t, "https://alpha.invalid/v1/messages") + }) +} diff --git a/enterprise/cli/aibridgeproxyd.go b/enterprise/cli/aibridgeproxyd.go index abc5320e92..0f7ba976a5 100644 --- a/enterprise/cli/aibridgeproxyd.go +++ b/enterprise/cli/aibridgeproxyd.go @@ -4,27 +4,45 @@ package cli import ( "context" - "net/url" + "io" "path/filepath" - "strings" "github.com/prometheus/client_golang/prometheus" "golang.org/x/xerrors" - "github.com/coder/coder/v2/aibridge" + "cdr.dev/slog/v3" "github.com/coder/coder/v2/aibridge/intercept/apidump" + "github.com/coder/coder/v2/coderd/aibridged" + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbauthz" "github.com/coder/coder/v2/enterprise/aibridgeproxyd" "github.com/coder/coder/v2/enterprise/coderd" ) -func newAIBridgeProxyDaemon(coderAPI *coderd.API, providers []aibridge.Provider) (*aibridgeproxyd.Server, error) { +// aiBridgeProxyDaemon bundles the proxy server and its pubsub +// subscription so both are torn down by a single Close call. +type aiBridgeProxyDaemon struct { + server *aibridgeproxyd.Server + unsubscribe func() +} + +func (d *aiBridgeProxyDaemon) Close() error { + if d.unsubscribe != nil { + d.unsubscribe() + } + return d.server.Close() +} + +// newAIBridgeProxyDaemon starts the enterprise aibridge proxy daemon, +// subscribes to ai_providers changes so the proxy's routing snapshot +// tracks the database, and registers the HTTP handler on the API. +// The returned io.Closer tears down both the subscription and server. +func newAIBridgeProxyDaemon(coderAPI *coderd.API) (io.Closer, error) { ctx := context.Background() coderAPI.Logger.Debug(ctx, "starting in-memory aibridgeproxy daemon") logger := coderAPI.Logger.Named("aibridgeproxyd") - domains, providerFromHost := domainsFromProviders(providers) - reg := prometheus.WrapRegistererWithPrefix("coder_aibridgeproxyd_", coderAPI.PrometheusRegistry) metrics := aibridgeproxyd.NewMetrics(reg) @@ -36,55 +54,51 @@ func newAIBridgeProxyDaemon(coderAPI *coderd.API, providers []aibridge.Provider) } srv, err := aibridgeproxyd.New(ctx, logger, aibridgeproxyd.Options{ - ListenAddr: coderAPI.DeploymentValues.AI.BridgeProxyConfig.ListenAddr.String(), - TLSCertFile: coderAPI.DeploymentValues.AI.BridgeProxyConfig.TLSCertFile.String(), - TLSKeyFile: coderAPI.DeploymentValues.AI.BridgeProxyConfig.TLSKeyFile.String(), - CoderAccessURL: coderAPI.AccessURL.String(), - MITMCertFile: coderAPI.DeploymentValues.AI.BridgeProxyConfig.MITMCertFile.String(), - MITMKeyFile: coderAPI.DeploymentValues.AI.BridgeProxyConfig.MITMKeyFile.String(), - DomainAllowlist: domains, - AIBridgeProviderFromHost: providerFromHost, - UpstreamProxy: coderAPI.DeploymentValues.AI.BridgeProxyConfig.UpstreamProxy.String(), - UpstreamProxyCA: coderAPI.DeploymentValues.AI.BridgeProxyConfig.UpstreamProxyCA.String(), - AllowedPrivateCIDRs: coderAPI.DeploymentValues.AI.BridgeProxyConfig.AllowedPrivateCIDRs.Value(), - NewDumper: newDumper, - Metrics: metrics, + ListenAddr: coderAPI.DeploymentValues.AI.BridgeProxyConfig.ListenAddr.String(), + TLSCertFile: coderAPI.DeploymentValues.AI.BridgeProxyConfig.TLSCertFile.String(), + TLSKeyFile: coderAPI.DeploymentValues.AI.BridgeProxyConfig.TLSKeyFile.String(), + CoderAccessURL: coderAPI.AccessURL.String(), + MITMCertFile: coderAPI.DeploymentValues.AI.BridgeProxyConfig.MITMCertFile.String(), + MITMKeyFile: coderAPI.DeploymentValues.AI.BridgeProxyConfig.MITMKeyFile.String(), + UpstreamProxy: coderAPI.DeploymentValues.AI.BridgeProxyConfig.UpstreamProxy.String(), + UpstreamProxyCA: coderAPI.DeploymentValues.AI.BridgeProxyConfig.UpstreamProxyCA.String(), + AllowedPrivateCIDRs: coderAPI.DeploymentValues.AI.BridgeProxyConfig.AllowedPrivateCIDRs.Value(), + NewDumper: newDumper, + Metrics: metrics, + RefreshProviders: refreshProxyProviders(coderAPI.Database), }) if err != nil { return nil, xerrors.Errorf("failed to start in-memory aibridgeproxy daemon: %w", err) } - return srv, nil -} - -// domainsFromProviders extracts distinct hostnames from providers' base -// URLs and builds a host-to-provider-name mapping function. The returned -// domain list is suitable for use as DomainAllowlist and the mapping -// function is suitable for use as AIBridgeProviderFromHost. -func domainsFromProviders(providers []aibridge.Provider) ([]string, func(string) string) { - hostToProvider := make(map[string]string, len(providers)) - var domains []string - for _, p := range providers { - raw := p.BaseURL() - if raw == "" { - continue - } - u, err := url.Parse(raw) - if err != nil || u.Hostname() == "" { - continue - } - host := strings.ToLower(u.Hostname()) - if _, exists := hostToProvider[host]; exists { - // First provider wins; duplicates are expected when - // multiple providers share a base URL host (e.g. two - // OpenAI providers using the same proxy). - continue - } - hostToProvider[host] = p.Name() - domains = append(domains, host) + unsubscribe, err := aibridged.SubscribeProviderReload(ctx, coderAPI.Pubsub, srv, logger.Named("provider-reload")) + if err != nil { + logger.Warn(ctx, "subscribe aibridgeproxyd to ai providers change channel", slog.Error(err)) + unsubscribe = func() {} } - return domains, func(host string) string { - return hostToProvider[strings.ToLower(host)] + // Register the handler so coderd can serve the proxy endpoints. + coderAPI.RegisterInMemoryAIBridgeProxydHTTPHandler(srv.Handler()) + + return &aiBridgeProxyDaemon{ + server: srv, + unsubscribe: unsubscribe, + }, nil +} + +func refreshProxyProviders(db database.Store) aibridgeproxyd.RefreshProvidersFunc { + return func(ctx context.Context) ([]aibridgeproxyd.ProviderRoute, error) { + //nolint:gocritic // AsAIProviderMetadataReader is the correct subject for routing-only access. + rows, err := db.GetAIProviders(dbauthz.AsAIProviderMetadataReader(ctx), database.GetAIProvidersParams{ + IncludeDisabled: false, + }) + if err != nil { + return nil, xerrors.Errorf("load ai providers: %w", err) + } + out := make([]aibridgeproxyd.ProviderRoute, 0, len(rows)) + for _, row := range rows { + out = append(out, aibridgeproxyd.ProviderRoute{Name: row.Name, BaseURL: row.BaseUrl}) + } + return out, nil } } diff --git a/enterprise/cli/aibridgeproxyd_internal_test.go b/enterprise/cli/aibridgeproxyd_internal_test.go deleted file mode 100644 index f7b25a549c..0000000000 --- a/enterprise/cli/aibridgeproxyd_internal_test.go +++ /dev/null @@ -1,72 +0,0 @@ -//go:build !slim - -package cli - -import ( - "testing" - - "github.com/stretchr/testify/assert" - - "github.com/coder/coder/v2/aibridge" -) - -func TestDomainsFromProviders(t *testing.T) { - t.Parallel() - - t.Run("ExtractsHostnames", func(t *testing.T) { - t.Parallel() - - providers := []aibridge.Provider{ - aibridge.NewOpenAIProvider(aibridge.OpenAIConfig{Name: "openai", BaseURL: "https://api.openai.com/v1/"}), - aibridge.NewAnthropicProvider(aibridge.AnthropicConfig{Name: "anthropic", BaseURL: "https://api.anthropic.com/"}, nil), - aibridge.NewOpenAIProvider(aibridge.OpenAIConfig{Name: "custom", BaseURL: "https://custom-llm.example.com:8443/api"}), - } - - domains, mapping := domainsFromProviders(providers) - - assert.Contains(t, domains, "api.openai.com") - assert.Contains(t, domains, "api.anthropic.com") - assert.Contains(t, domains, "custom-llm.example.com") - - assert.Equal(t, "openai", mapping("api.openai.com")) - assert.Equal(t, "anthropic", mapping("api.anthropic.com")) - assert.Equal(t, "custom", mapping("custom-llm.example.com")) - assert.Empty(t, mapping("unknown.com")) - }) - - t.Run("DeduplicatesSameHost", func(t *testing.T) { - t.Parallel() - - providers := []aibridge.Provider{ - aibridge.NewOpenAIProvider(aibridge.OpenAIConfig{Name: "first", BaseURL: "https://api.example.com/v1"}), - aibridge.NewOpenAIProvider(aibridge.OpenAIConfig{Name: "second", BaseURL: "https://api.example.com/v2"}), - } - - domains, mapping := domainsFromProviders(providers) - - // Count occurrences of api.example.com. - count := 0 - for _, d := range domains { - if d == "api.example.com" { - count++ - } - } - assert.Equal(t, 1, count) - // First provider wins. - assert.Equal(t, "first", mapping("api.example.com")) - }) - - t.Run("CaseInsensitive", func(t *testing.T) { - t.Parallel() - - providers := []aibridge.Provider{ - aibridge.NewOpenAIProvider(aibridge.OpenAIConfig{Name: "provider", BaseURL: "https://API.Example.COM/v1"}), - } - - domains, mapping := domainsFromProviders(providers) - - assert.Contains(t, domains, "api.example.com") - assert.Equal(t, "provider", mapping("API.Example.COM")) - assert.Equal(t, "provider", mapping("api.example.com")) - }) -} diff --git a/enterprise/cli/server.go b/enterprise/cli/server.go index 5ef7db1b2f..c77a03a0b4 100644 --- a/enterprise/cli/server.go +++ b/enterprise/cli/server.go @@ -15,7 +15,6 @@ import ( "tailscale.com/derp" "tailscale.com/types/key" - agplcli "github.com/coder/coder/v2/cli" agplcoderd "github.com/coder/coder/v2/coderd" "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/cryptorand" @@ -167,13 +166,14 @@ func (r *RootCmd) Server(_ func()) *serpent.Command { // in-memory roundtripper regardless of license); only the proxy // daemon remains enterprise-gated by config. if options.DeploymentValues.AI.BridgeProxyConfig.Enabled.Value() { - // Seed env-derived providers before reading them back so the - // proxy observes them on first startup. options.Database is - // dbcrypt-wrapped at this point (set by coderd.New above), - // so env-seeded keys are also written encrypted. Detached - // ctx for the same reason as in agplcli below: an early - // return would orphan newAPI's goroutines. Seeding is - // idempotent; the agplcli path seeds again post-newAPI. + // Seed env-derived providers before the proxy daemon's reloader + // reads them back so the proxy observes them on first startup. + // options.Database is dbcrypt-wrapped at this point (set by + // coderd.New above), so env-seeded keys are also written + // encrypted. Detached ctx for the same reason as in agplcli + // below: an early return would orphan newAPI's goroutines. + // Seeding is idempotent; the agplcli path seeds again + // post-newAPI. //nolint:gocritic // Production timeout, not a test wait. aibridgeInitCtx, aibridgeInitCancel := context.WithTimeout(context.WithoutCancel(ctx), 30*time.Second) defer aibridgeInitCancel() @@ -185,19 +185,12 @@ func (r *RootCmd) Server(_ func()) *serpent.Command { ); err != nil { return nil, nil, xerrors.Errorf("seed ai providers from env: %w", err) } - providers, err := agplcli.BuildProviders(aibridgeInitCtx, options.Database, options.DeploymentValues.AI.BridgeConfig, options.Logger.Named("aibridge.providers")) - if err != nil { - return nil, nil, xerrors.Errorf("build AI providers: %w", err) - } - aiBridgeProxyServer, err := newAIBridgeProxyDaemon(api, providers) + aiBridgeProxyCloser, err := newAIBridgeProxyDaemon(api) if err != nil { _ = closers.Close() return nil, nil, xerrors.Errorf("create aibridgeproxyd: %w", err) } - closers.Add(aiBridgeProxyServer) - - // Register the handler so coderd can serve the proxy endpoints. - api.RegisterInMemoryAIBridgeProxydHTTPHandler(aiBridgeProxyServer.Handler()) + closers.Add(aiBridgeProxyCloser) } return api.AGPL, closers, nil diff --git a/enterprise/coderd/aibridge_reload_test.go b/enterprise/coderd/aibridge_reload_test.go new file mode 100644 index 0000000000..6aee3afa91 --- /dev/null +++ b/enterprise/coderd/aibridge_reload_test.go @@ -0,0 +1,233 @@ +package coderd_test + +import ( + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "sync/atomic" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel" + + "cdr.dev/slog/v3" + "cdr.dev/slog/v3/sloggers/slogtest" + "github.com/coder/coder/v2/cli" + "github.com/coder/coder/v2/coderd" + "github.com/coder/coder/v2/coderd/aibridged" + "github.com/coder/coder/v2/coderd/coderdtest" + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/codersdk" + "github.com/coder/coder/v2/enterprise/coderd/coderdenttest" + "github.com/coder/coder/v2/enterprise/coderd/license" + "github.com/coder/coder/v2/testutil" + "github.com/coder/serpent" +) + +// mockUpstream is a single httptest server identified by a unique +// marker that it echoes in every response body, so callers can verify +// which upstream a proxied request actually reached. The hit counter +// supports asserting the upstream was touched at all. +type mockUpstream struct { + server *httptest.Server + name string + hits atomic.Int32 +} + +func newMockUpstream(t *testing.T, name string) *mockUpstream { + t.Helper() + m := &mockUpstream{name: name} + m.server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + m.hits.Add(1) + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + assert.NoError(t, json.NewEncoder(w).Encode(map[string]string{"upstream": name})) + })) + t.Cleanup(m.server.Close) + return m +} + +// startTestAIBridgeDaemon wires an in-process aibridged daemon onto +// the supplied API and subscribes it to ai_providers change events. +// This mirrors what cli/server.go does in production so /api/v2/aibridge +// requests dispatch through the real pool and reloader. +func startTestAIBridgeDaemon(t *testing.T, api *coderd.API) { + t.Helper() + + ctx := context.Background() + logger := slogtest.Make(t, nil).Named("aibridged").Leveled(slog.LevelDebug) + cfg := api.DeploymentValues.AI.BridgeConfig + tracer := otel.Tracer("aibridge-reload-test") + + providers, err := cli.BuildProviders(ctx, api.Database, cfg, logger) + require.NoError(t, err) + + pool, err := aibridged.NewCachedBridgePool(aibridged.DefaultPoolOptions, providers, logger.Named("pool"), nil, tracer) + require.NoError(t, err) + t.Cleanup(func() { _ = pool.Shutdown(context.Background()) }) + + reloader := &testPoolReloader{pool: pool, db: api.Database, cfg: cfg, logger: logger.Named("reloader")} + unsubscribe, err := aibridged.SubscribeProviderReload(ctx, api.Pubsub, reloader, logger.Named("subscriber")) + require.NoError(t, err) + t.Cleanup(unsubscribe) + + srv, err := aibridged.New(ctx, pool, func(dialCtx context.Context) (aibridged.DRPCClient, error) { + return api.CreateInMemoryAIBridgeServer(dialCtx) + }, logger, tracer) + require.NoError(t, err) + t.Cleanup(func() { _ = srv.Close() }) + + api.RegisterInMemoryAIBridgedHTTPHandler(srv) +} + +type testPoolReloader struct { + pool *aibridged.CachedBridgePool + db database.Store + cfg codersdk.AIBridgeConfig + logger slog.Logger +} + +func (r *testPoolReloader) Reload(ctx context.Context) error { + providers, err := cli.BuildProviders(ctx, r.db, r.cfg, r.logger) + if err != nil { + return err + } + r.pool.ReplaceProviders(providers) + return nil +} + +// TestAIBridgeProviderHotReload exercises the end-to-end CRUD -> +// reload -> routing path: every provider mutation made through codersdk +// must, within a short window, change the routing observed at +// /api/v2/aibridge/{name}/v1/models. The OpenAI passthrough route +// /v1/models reverse-proxies to BaseURL, so the upstream that responds +// identifies which provider the daemon's mux dispatched to. +func TestAIBridgeProviderHotReload(t *testing.T) { + t.Parallel() + + // Two distinct upstreams so an Update that swings the BaseURL is + // observable: which upstream answers tells us which BaseURL the + // freshly-built provider is pointed at. + upstreamA := newMockUpstream(t, "a") + upstreamB := newMockUpstream(t, "b") + + dv := coderdtest.DeploymentValues(t) + dv.AI.BridgeConfig.Enabled = serpent.Bool(true) + + client, _, api, _ := coderdenttest.NewWithAPI(t, &coderdenttest.Options{ + Options: &coderdtest.Options{DeploymentValues: dv}, + LicenseOptions: &coderdenttest.LicenseOptions{ + Features: license.Features{codersdk.FeatureAIBridge: 1}, + }, + }) + + startTestAIBridgeDaemon(t, api.AGPL) + + ctx := testutil.Context(t, testutil.WaitLong) + + // sendRequest issues GET /api/v2/aibridge/{name}/v1/models and + // returns the status and the upstream marker decoded from the + // JSON body (empty if the body was not the marker JSON). + sendRequest := func(providerName string) (int, string) { + url := client.URL.String() + "/api/v2/aibridge/" + providerName + "/v1/models" + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + require.NoError(t, err) + req.Header.Set("Authorization", "Bearer "+client.SessionToken()) + + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + body, _ := io.ReadAll(resp.Body) + if resp.StatusCode != http.StatusOK { + return resp.StatusCode, "" + } + var decoded map[string]string + _ = json.Unmarshal(body, &decoded) + return resp.StatusCode, decoded["upstream"] + } + + // requireRoutesTo polls until the routing reflects the expected + // upstream. The pool reloads asynchronously from a pubsub event; + // require.Eventually is the natural fit. + requireRoutesTo := func(t *testing.T, providerName string, upstream *mockUpstream) { + t.Helper() + before := upstream.hits.Load() + require.Eventuallyf(t, func() bool { + status, marker := sendRequest(providerName) + return status == http.StatusOK && marker == upstream.name + }, testutil.WaitShort, testutil.IntervalFast, + "expected provider %q to route to upstream %q", providerName, upstream.name) + require.Greater(t, upstream.hits.Load(), before, + "upstream %q must have observed at least one request", upstream.name) + } + + // requireRoutingGone polls until the provider name yields a 404 + // from the aibridge mux's catch-all, indicating the provider has + // been removed from the pool snapshot. + requireRoutingGone := func(t *testing.T, providerName string) { + t.Helper() + require.Eventuallyf(t, func() bool { + status, _ := sendRequest(providerName) + return status == http.StatusNotFound + }, testutil.WaitShort, testutil.IntervalFast, + "expected provider %q to stop routing", providerName) + } + + // 1. Create: provider points at upstream A. + created, err := client.CreateAIProvider(ctx, codersdk.CreateAIProviderRequest{ + Type: codersdk.AIProviderTypeOpenAI, + Name: "primary", + Enabled: true, + BaseURL: upstreamA.server.URL, + APIKeys: []string{"sk-primary-key"}, + }) + require.NoError(t, err) + require.Equal(t, "primary", created.Name) + requireRoutesTo(t, "primary", upstreamA) + + // 2. Update BaseURL: same name, now points at upstream B. + newBaseURL := upstreamB.server.URL + _, err = client.UpdateAIProvider(ctx, "primary", codersdk.UpdateAIProviderRequest{ + BaseURL: &newBaseURL, + }) + require.NoError(t, err) + requireRoutesTo(t, "primary", upstreamB) + + // 3. Disable: the provider drops out of the snapshot, requests + // stop reaching any upstream. + disabled := false + _, err = client.UpdateAIProvider(ctx, "primary", codersdk.UpdateAIProviderRequest{ + Enabled: &disabled, + }) + require.NoError(t, err) + requireRoutingGone(t, "primary") + + // 4. Re-enable: routing comes back at the most recent BaseURL. + enabled := true + _, err = client.UpdateAIProvider(ctx, "primary", codersdk.UpdateAIProviderRequest{ + Enabled: &enabled, + }) + require.NoError(t, err) + requireRoutesTo(t, "primary", upstreamB) + + // 5. Add a second provider; both names must route independently. + _, err = client.CreateAIProvider(ctx, codersdk.CreateAIProviderRequest{ + Type: codersdk.AIProviderTypeOpenAI, + Name: "secondary", + Enabled: true, + BaseURL: upstreamA.server.URL, + APIKeys: []string{"sk-secondary-key"}, + }) + require.NoError(t, err) + requireRoutesTo(t, "primary", upstreamB) + requireRoutesTo(t, "secondary", upstreamA) + + // 6. Delete primary: only secondary remains routable. + require.NoError(t, client.DeleteAIProvider(ctx, "primary")) + requireRoutingGone(t, "primary") + requireRoutesTo(t, "secondary", upstreamA) +}