feat: hot-reload aibridged and aibridgeproxyd providers on DB changes (#25673)

Previously the in-process aibridge daemon and the enterprise aibridgeproxy daemon both snapshotted their provider routing once at boot. Any `ai_providers` or `ai_provider_keys` mutation required a restart for either to pick it up.

Add an `ai_providers_changed` pubsub channel that the CRUD handlers publish on after Create / Update / Delete. Both daemons subscribe:

- **aibridged** rebuilds its `[]aibridge.Provider` snapshot via `BuildProviders` and swaps it into the pool atomically. Inflight requests keep serving against the bridge they already acquired; new acquires build against the new snapshot. Per-provider construction errors stay scoped to the offending row.
- **aibridgeproxyd** rebuilds its routing snapshot from `GetAIProviders` and swaps the host→provider map atomically. The MITM listener picks up new providers without restart.

DB read for aibridgeproxyd uses the existing `AsAIProviderMetadataReader` subject for routing-only access.
This commit is contained in:
Danny Kopping
2026-05-27 11:58:43 +02:00
committed by GitHub
parent 6acfe6c835
commit 79e007cf30
19 changed files with 1678 additions and 227 deletions
+50 -4
View File
@@ -24,7 +24,12 @@ import (
"github.com/coder/quartz"
)
func newAIBridgeDaemon(coderAPI *coderd.API, providers []aibridge.Provider) (*aibridged.Server, error) {
// newAIBridgeDaemon constructs the in-memory aibridge daemon and wires
// up a subscription that hot-reloads the provider pool from the
// database on every ai_providers change event. The returned unsubscribe
// function tears down the subscription; callers must invoke it
// alongside Server.Close on shutdown.
func newAIBridgeDaemon(coderAPI *coderd.API, providers []aibridge.Provider, cfg codersdk.AIBridgeConfig) (*aibridged.Server, func(), error) {
ctx := context.Background()
coderAPI.Logger.Debug(ctx, "starting in-memory aibridge daemon")
@@ -37,7 +42,25 @@ func newAIBridgeDaemon(coderAPI *coderd.API, providers []aibridge.Provider) (*ai
// Create pool for reusable stateful [aibridge.RequestBridge] instances (one per user).
pool, err := aibridged.NewCachedBridgePool(aibridged.DefaultPoolOptions, providers, logger.Named("pool"), metrics, tracer) // TODO: configurable size.
if err != nil {
return nil, xerrors.Errorf("create request pool: %w", err)
return nil, nil, xerrors.Errorf("create request pool: %w", err)
}
// Subscribe to ai_providers change events so the pool tracks the
// database without a restart. The boot-time `providers` snapshot
// derives from env config and serves as a fallback if the database
// load fails inside the reloader.
reloader := &poolDBReloader{
pool: pool,
db: coderAPI.Database,
cfg: cfg,
logger: logger.Named("provider-loader"),
}
unsubscribe, err := aibridged.SubscribeProviderReload(ctx, coderAPI.Pubsub, reloader, logger.Named("provider-reload"))
if err != nil {
// Pool is still usable with the boot-time snapshot; subscription
// failure is logged but not fatal so the daemon still serves.
logger.Warn(ctx, "subscribe to ai providers change channel", slog.Error(err))
unsubscribe = func() {}
}
// Create daemon.
@@ -45,9 +68,32 @@ func newAIBridgeDaemon(coderAPI *coderd.API, providers []aibridge.Provider) (*ai
return coderAPI.CreateInMemoryAIBridgeServer(dialCtx)
}, logger, tracer)
if err != nil {
return nil, xerrors.Errorf("start in-memory aibridge daemon: %w", err)
unsubscribe()
return nil, nil, xerrors.Errorf("start in-memory aibridge daemon: %w", err)
}
return srv, nil
return srv, unsubscribe, nil
}
// poolDBReloader implements [aibridged.ProviderReloader] by loading
// the live provider set from the database and forwarding it to the
// pool.
type poolDBReloader struct {
pool *aibridged.CachedBridgePool
db database.Store
cfg codersdk.AIBridgeConfig
logger slog.Logger
}
func (r *poolDBReloader) Reload(ctx context.Context) error {
providers, err := BuildProviders(ctx, r.db, r.cfg, r.logger)
if err != nil {
// Keep the previous snapshot in place: dropping all providers
// because the DB read failed would compound the visible failure
// mode beyond the operator's actual misconfiguration.
return xerrors.Errorf("load ai providers from database: %w", err)
}
r.pool.ReplaceProviders(providers)
return nil
}
// BuildProviders loads every enabled ai_providers row, attaches its
+3 -1
View File
@@ -1046,7 +1046,8 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
if err != nil {
return xerrors.Errorf("build AI providers: %w", err)
}
aibridgeDaemon, err = newAIBridgeDaemon(coderAPI, aibridgeProviders)
var unsubscribeProviderReload func()
aibridgeDaemon, unsubscribeProviderReload, err = newAIBridgeDaemon(coderAPI, aibridgeProviders, vals.AI.BridgeConfig)
if err != nil {
return xerrors.Errorf("create aibridged: %w", err)
}
@@ -1055,6 +1056,7 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd.
// daemon does not affect in-flight requests but is needed to
// release pool/recorder resources at shutdown.
defer aibridgeDaemon.Close()
defer unsubscribeProviderReload()
}
if vals.Prometheus.Enable {
+19
View File
@@ -21,6 +21,7 @@ import (
"github.com/coder/coder/v2/coderd/database/dbtime"
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/coderd/httpmw"
coderpubsub "github.com/coder/coder/v2/coderd/pubsub"
"github.com/coder/coder/v2/coderd/util/ptr"
"github.com/coder/coder/v2/codersdk"
)
@@ -235,6 +236,7 @@ func (api *API) aiProvidersCreate(rw http.ResponseWriter, r *http.Request) {
aReq.New = row
auditAIProviderKeyChanges(ctx, r, *auditor, api.Logger, aiProviderKeyChanges{Added: keys})
api.publishAIProvidersChanged(ctx)
sdk, err := db2sdk.AIProvider(row, keys)
if err != nil {
@@ -400,6 +402,7 @@ func (api *API) aiProvidersUpdate(rw http.ResponseWriter, r *http.Request) {
}
auditAIProviderKeyChanges(ctx, r, *auditor, api.Logger, keyChanges)
api.publishAIProvidersChanged(ctx)
sdk, err := db2sdk.AIProvider(updated, keys)
if err != nil {
@@ -453,9 +456,25 @@ func (api *API) aiProvidersDelete(rw http.ResponseWriter, r *http.Request) {
return
}
api.publishAIProvidersChanged(ctx)
rw.WriteHeader(http.StatusNoContent)
}
// publishAIProvidersChanged notifies subscribers (aibridged,
// aibridgeproxyd) that the live provider set changed and they should
// refetch from the database. Pubsub failures are logged but not
// propagated: subscribers refresh authoritatively from the DB, so a
// dropped notification only delays convergence.
func (api *API) publishAIProvidersChanged(ctx context.Context) {
if api.Pubsub == nil {
return
}
if err := api.Pubsub.Publish(coderpubsub.AIProvidersChangedChannel, nil); err != nil {
api.Logger.Warn(ctx, "publish ai providers changed event", slog.Error(err))
}
}
// errBedrockRejectsAPIKeys is the sentinel returned from inside the
// update transaction when a caller attempts to attach api_keys to a
// Bedrock-typed provider; the outer handler translates it into a 400.
+62
View File
@@ -0,0 +1,62 @@
package coderd_test
import (
"context"
"sync/atomic"
"testing"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/coderd/coderdtest"
coderpubsub "github.com/coder/coder/v2/coderd/pubsub"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/testutil"
)
// TestAIProvidersChangedPubsub asserts that the CRUD handlers publish
// on AIProvidersChangedChannel for the operations that affect the
// runtime provider set. Subscribers (aibridged, aibridgeproxyd) depend
// on these notifications to trigger their pool reload.
//
// The handlers publish best-effort and the payload is empty, so we
// assert "at least one event per mutation" via a counter.
func TestAIProvidersChangedPubsub(t *testing.T) {
t.Parallel()
client, _, api := coderdtest.NewWithAPI(t, nil)
_ = coderdtest.CreateFirstUser(t, client)
ctx := testutil.Context(t, testutil.WaitLong)
var count atomic.Int64
unsubscribe, err := api.Pubsub.Subscribe(coderpubsub.AIProvidersChangedChannel, func(_ context.Context, _ []byte) {
count.Add(1)
})
require.NoError(t, err)
t.Cleanup(unsubscribe)
// Create.
req := codersdk.CreateAIProviderRequest{
Type: codersdk.AIProviderTypeOpenAI,
Name: "pubsub-openai",
Enabled: true,
BaseURL: "https://api.openai.com/v1/",
APIKeys: []string{"k1"},
}
//nolint:gocritic // Owner role is the audience for this endpoint.
created, err := client.CreateAIProvider(ctx, req)
require.NoError(t, err)
testutil.Eventually(ctx, t, func(_ context.Context) bool { return count.Load() >= 1 }, testutil.IntervalFast)
// Update.
newKey := "k2"
_, err = client.UpdateAIProvider(ctx, created.ID.String(), codersdk.UpdateAIProviderRequest{
APIKeys: &[]codersdk.AIProviderKeyMutation{{APIKey: &newKey}},
})
require.NoError(t, err)
testutil.Eventually(ctx, t, func(_ context.Context) bool { return count.Load() >= 2 }, testutil.IntervalFast)
// Delete.
err = client.DeleteAIProvider(ctx, created.ID.String())
require.NoError(t, err)
testutil.Eventually(ctx, t, func(_ context.Context) bool { return count.Load() >= 3 }, testutil.IntervalFast)
}
@@ -14,6 +14,7 @@ import (
http "net/http"
reflect "reflect"
aibridge "github.com/coder/coder/v2/aibridge"
aibridged "github.com/coder/coder/v2/coderd/aibridged"
gomock "go.uber.org/mock/gomock"
)
@@ -57,6 +58,18 @@ func (mr *MockPoolerMockRecorder) Acquire(ctx, req, clientFn, mcpBootstrapper an
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Acquire", reflect.TypeOf((*MockPooler)(nil).Acquire), ctx, req, clientFn, mcpBootstrapper)
}
// ReplaceProviders mocks base method.
func (m *MockPooler) ReplaceProviders(providers []aibridge.Provider) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "ReplaceProviders", providers)
}
// ReplaceProviders indicates an expected call of ReplaceProviders.
func (mr *MockPoolerMockRecorder) ReplaceProviders(providers any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReplaceProviders", reflect.TypeOf((*MockPooler)(nil).ReplaceProviders), providers)
}
// Shutdown mocks base method.
func (m *MockPooler) Shutdown(ctx context.Context) error {
m.ctrl.T.Helper()
+71 -21
View File
@@ -3,7 +3,10 @@ package aibridged
import (
"context"
"net/http"
"slices"
"strconv"
"sync"
"sync/atomic"
"time"
"github.com/dgraph-io/ristretto/v2"
@@ -26,6 +29,9 @@ const (
// One [*aibridge.RequestBridge] instance is created per given key.
type Pooler interface {
Acquire(ctx context.Context, req Request, clientFn ClientFunc, mcpBootstrapper MCPProxyBuilder) (http.Handler, error)
// ReplaceProviders swaps the providers used to construct future
// RequestBridge instances and clears the cache.
ReplaceProviders(providers []aibridge.Provider)
Shutdown(ctx context.Context) error
}
@@ -46,10 +52,12 @@ var DefaultPoolOptions = PoolOptions{MaxItems: 5000, TTL: time.Minute * 15}
var _ Pooler = &CachedBridgePool{}
type CachedBridgePool struct {
cache *ristretto.Cache[string, *aibridge.RequestBridge]
providers []aibridge.Provider
logger slog.Logger
options PoolOptions
cache *ristretto.Cache[string, *aibridge.RequestBridge]
// providers is the live provider set used by new RequestBridge instances.
providers atomic.Pointer[[]aibridge.Provider]
providerVersion atomic.Int64
logger slog.Logger
options PoolOptions
singleflight *singleflight.Group[string, *aibridge.RequestBridge]
@@ -71,13 +79,16 @@ func NewCachedBridgePool(options PoolOptions, providers []aibridge.Provider, log
if item == nil || item.Value == nil {
return
}
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), time.Second*5)
defer shutdownCancel()
// Run the eviction in the background since ristretto blocks sets until a free slot is available.
// Capture the value synchronously: ristretto reuses the
// item slot after OnEvict returns, so reading item.Value
// from the goroutine below races with the caller of
// Clear/Set. The shutdown still runs in the background to
// avoid blocking ristretto's eviction loop.
bridge := item.Value
go func() {
_ = item.Value.Shutdown(shutdownCtx)
shutdownCtx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
_ = bridge.Shutdown(shutdownCtx)
}()
},
})
@@ -85,18 +96,53 @@ func NewCachedBridgePool(options PoolOptions, providers []aibridge.Provider, log
return nil, xerrors.Errorf("create cache: %w", err)
}
return &CachedBridgePool{
cache: cache,
providers: providers,
options: options,
metrics: metrics,
tracer: tracer,
logger: logger,
pool := &CachedBridgePool{
cache: cache,
options: options,
metrics: metrics,
tracer: tracer,
logger: logger,
singleflight: &singleflight.Group[string, *aibridge.RequestBridge]{},
shuttingDownCh: make(chan struct{}),
}, nil
}
initial := slices.Clone(providers)
pool.providers.Store(&initial)
return pool, nil
}
// ReplaceProviders swaps the provider snapshot used by future Acquires.
// It is safe to call concurrently with Acquire and is a no-op after
// Shutdown.
func (p *CachedBridgePool) ReplaceProviders(providers []aibridge.Provider) {
select {
case <-p.shuttingDownCh:
return
default:
}
snapshot := slices.Clone(providers)
p.providers.Store(&snapshot)
version := time.Now().UnixNano()
p.providerVersion.Store(version)
// Clear evicts every cached bridge; OnEvict shuts each one down in
// the background. Wait for buffered writes to drain so a replacement
// immediately followed by an Acquire always sees the cleared cache.
p.cache.Clear()
p.cache.Wait()
p.logger.Info(context.Background(), "request bridge pool reloaded",
slog.F("provider_count", len(snapshot)),
slog.F("provider_version", version),
)
}
// loadProviders returns the current providers snapshot. The returned
// slice must not be mutated.
func (p *CachedBridgePool) loadProviders() []aibridge.Provider {
if ptr := p.providers.Load(); ptr != nil {
return *ptr
}
return nil
}
// Acquire retrieves or creates a [*aibridge.RequestBridge] instance per given key.
@@ -140,6 +186,7 @@ func (p *CachedBridgePool) Acquire(ctx context.Context, req Request, clientFn Cl
}
span.AddEvent("cache_miss")
providerVersion := p.providerVersion.Load()
recorder := aibridge.NewRecorder(p.logger.Named("recorder"), p.tracer, func() (aibridge.Recorder, error) {
client, err := clientFn()
if err != nil {
@@ -152,7 +199,8 @@ func (p *CachedBridgePool) Acquire(ctx context.Context, req Request, clientFn Cl
// Slow path.
// Creating an *aibridge.RequestBridge may take some time, so gate all subsequent callers behind the initial request and return the resulting value.
// TODO: track startup time since it adds latency to first request (histogram count will also help us see how often this occurs).
instance, err, _ := p.singleflight.Do(req.InitiatorID.String(), func() (*aibridge.RequestBridge, error) {
singleflightKey := cacheKey + "|" + strconv.FormatInt(providerVersion, 10)
instance, err, _ := p.singleflight.Do(singleflightKey, func() (*aibridge.RequestBridge, error) {
var (
mcpServers mcp.ServerProxier
err error
@@ -171,12 +219,14 @@ func (p *CachedBridgePool) Acquire(ctx context.Context, req Request, clientFn Cl
}
}
bridge, err := aibridge.NewRequestBridge(ctx, p.providers, recorder, mcpServers, p.logger, p.metrics, p.tracer)
bridge, err := aibridge.NewRequestBridge(ctx, p.loadProviders(), recorder, mcpServers, p.logger, p.metrics, p.tracer)
if err != nil {
return nil, xerrors.Errorf("create new request bridge: %w", err)
}
p.cache.SetWithTTL(cacheKey, bridge, cacheCost, p.options.TTL)
if p.providerVersion.Load() == providerVersion {
p.cache.SetWithTTL(cacheKey, bridge, cacheCost, p.options.TTL)
}
return bridge, nil
})
+215
View File
@@ -2,6 +2,10 @@ package aibridged_test
import (
"context"
"io"
"net/http"
"net/http/httptest"
"sync/atomic"
"testing"
"testing/synctest"
"time"
@@ -12,10 +16,13 @@ import (
"go.uber.org/mock/gomock"
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/coder/v2/aibridge"
"github.com/coder/coder/v2/aibridge/config"
"github.com/coder/coder/v2/aibridge/mcp"
"github.com/coder/coder/v2/aibridge/mcpmock"
"github.com/coder/coder/v2/coderd/aibridged"
mock "github.com/coder/coder/v2/coderd/aibridged/aibridgedmock"
"github.com/coder/coder/v2/testutil"
)
// TestPool validates the published behavior of [aibridged.CachedBridgePool].
@@ -107,6 +114,160 @@ func TestPool(t *testing.T) {
require.EqualValues(t, 3, cacheMetrics.Misses())
}
func TestPoolReplaceProvidersClearsCacheAndUsesNewProviders(t *testing.T) {
t.Parallel()
oldUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
_, _ = io.WriteString(w, "old")
}))
t.Cleanup(oldUpstream.Close)
newUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
_, _ = io.WriteString(w, "new")
}))
t.Cleanup(newUpstream.Close)
logger := slogtest.Make(t, nil)
ctrl := gomock.NewController(t)
client := mock.NewMockDRPCClient(ctrl)
mcpProxy := mcpmock.NewMockServerProxier(ctrl)
mcpProxy.EXPECT().Init(gomock.Any()).AnyTimes().Return(nil)
mcpProxy.EXPECT().Shutdown(gomock.Any()).AnyTimes().Return(nil)
opts := aibridged.PoolOptions{MaxItems: 1, TTL: time.Minute}
pool, err := aibridged.NewCachedBridgePool(opts, []aibridge.Provider{
aibridge.NewOpenAIProvider(config.OpenAI{Name: "old", BaseURL: oldUpstream.URL}),
}, logger, nil, testTracer)
require.NoError(t, err)
t.Cleanup(func() { _ = pool.Shutdown(context.Background()) })
req := aibridged.Request{
SessionKey: "key",
InitiatorID: uuid.New(),
APIKeyID: uuid.New().String(),
}
clientFn := func() (aibridged.DRPCClient, error) {
return client, nil
}
inst, err := pool.Acquire(t.Context(), req, clientFn, newMockMCPFactory(mcpProxy))
require.NoError(t, err)
assertHandlerBody(t, inst, "/old/v1/models", "old")
pool.ReplaceProviders([]aibridge.Provider{
aibridge.NewOpenAIProvider(config.OpenAI{Name: "new", BaseURL: newUpstream.URL}),
})
instAfterReload, err := pool.Acquire(t.Context(), req, clientFn, newMockMCPFactory(mcpProxy))
require.NoError(t, err)
require.NotSame(t, inst, instAfterReload)
assertHandlerBody(t, instAfterReload, "/new/v1/models", "new")
}
func TestPoolReplaceProvidersDoesNotJoinStaleSingleflight(t *testing.T) {
t.Parallel()
oldUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
_, _ = io.WriteString(w, "old")
}))
t.Cleanup(oldUpstream.Close)
newUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
_, _ = io.WriteString(w, "new")
}))
t.Cleanup(newUpstream.Close)
logger := slogtest.Make(t, nil)
ctrl := gomock.NewController(t)
client := mock.NewMockDRPCClient(ctrl)
opts := aibridged.PoolOptions{MaxItems: 1, TTL: time.Minute}
pool, err := aibridged.NewCachedBridgePool(opts, []aibridge.Provider{
aibridge.NewOpenAIProvider(config.OpenAI{Name: "old", BaseURL: oldUpstream.URL}),
}, logger, nil, testTracer)
require.NoError(t, err)
t.Cleanup(func() { _ = pool.Shutdown(context.Background()) })
req := aibridged.Request{
SessionKey: "key",
InitiatorID: uuid.New(),
APIKeyID: uuid.New().String(),
}
clientFn := func() (aibridged.DRPCClient, error) {
return client, nil
}
factory := newBlockingMCPFactory()
firstDone := make(chan acquireResult, 1)
go func() {
handler, err := pool.Acquire(t.Context(), req, clientFn, factory)
firstDone <- acquireResult{handler: handler, err: err}
}()
require.Eventually(t, factory.firstBuildStarted, testutil.WaitShort, testutil.IntervalFast)
pool.ReplaceProviders([]aibridge.Provider{
aibridge.NewOpenAIProvider(config.OpenAI{Name: "new", BaseURL: newUpstream.URL}),
})
secondDone := make(chan acquireResult, 1)
go func() {
handler, err := pool.Acquire(t.Context(), req, clientFn, factory)
secondDone <- acquireResult{handler: handler, err: err}
}()
var second acquireResult
require.Eventually(t, func() bool {
select {
case second = <-secondDone:
return true
default:
return false
}
}, testutil.WaitShort, testutil.IntervalFast)
require.NoError(t, second.err)
assertHandlerBody(t, second.handler, "/new/v1/models", "new")
close(factory.releaseFirst)
var first acquireResult
require.Eventually(t, func() bool {
select {
case first = <-firstDone:
return true
default:
return false
}
}, testutil.WaitShort, testutil.IntervalFast)
require.NoError(t, first.err)
third, err := pool.Acquire(t.Context(), req, clientFn, factory)
require.NoError(t, err)
require.Same(t, second.handler, third)
}
func TestPoolReplaceProvidersAfterShutdownIsNoop(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
opts := aibridged.PoolOptions{MaxItems: 1, TTL: time.Minute}
pool, err := aibridged.NewCachedBridgePool(opts, nil, logger, nil, testTracer)
require.NoError(t, err)
require.NoError(t, pool.Shutdown(t.Context()))
require.NotPanics(t, func() {
pool.ReplaceProviders([]aibridge.Provider{
aibridge.NewOpenAIProvider(config.OpenAI{Name: "new", BaseURL: "https://example.com"}),
})
})
_, err = pool.Acquire(t.Context(), aibridged.Request{
SessionKey: "key",
InitiatorID: uuid.New(),
APIKeyID: uuid.New().String(),
}, func() (aibridged.DRPCClient, error) {
return nil, context.Canceled
}, newMockMCPFactory(nil))
require.ErrorContains(t, err, "pool shutting down")
}
func TestPool_Expiry(t *testing.T) {
t.Parallel()
@@ -166,6 +327,21 @@ func TestPool_Expiry(t *testing.T) {
})
}
func assertHandlerBody(t *testing.T, handler http.Handler, path string, body string) {
t.Helper()
req := httptest.NewRequest(http.MethodGet, path, nil)
rw := httptest.NewRecorder()
handler.ServeHTTP(rw, req)
resp := rw.Result()
defer resp.Body.Close()
got, err := io.ReadAll(resp.Body)
require.NoError(t, err)
require.Equal(t, http.StatusOK, resp.StatusCode)
require.Equal(t, body, string(got))
}
var _ aibridged.MCPProxyBuilder = &mockMCPFactory{}
type mockMCPFactory struct {
@@ -179,3 +355,42 @@ func newMockMCPFactory(proxy *mcpmock.MockServerProxier) *mockMCPFactory {
func (m *mockMCPFactory) Build(ctx context.Context, req aibridged.Request, tracer trace.Tracer) (mcp.ServerProxier, error) {
return m.proxy, nil
}
type acquireResult struct {
handler http.Handler
err error
}
type blockingMCPFactory struct {
calls atomic.Int32
firstStarted chan struct{}
releaseFirst chan struct{}
}
func newBlockingMCPFactory() *blockingMCPFactory {
return &blockingMCPFactory{
firstStarted: make(chan struct{}),
releaseFirst: make(chan struct{}),
}
}
func (m *blockingMCPFactory) firstBuildStarted() bool {
select {
case <-m.firstStarted:
return true
default:
return false
}
}
func (m *blockingMCPFactory) Build(ctx context.Context, _ aibridged.Request, _ trace.Tracer) (mcp.ServerProxier, error) {
if m.calls.Add(1) == 1 {
close(m.firstStarted)
select {
case <-m.releaseFirst:
case <-ctx.Done():
return nil, ctx.Err()
}
}
return nil, context.Canceled
}
+50
View File
@@ -0,0 +1,50 @@
package aibridged
import (
"context"
"golang.org/x/xerrors"
"cdr.dev/slog/v3"
dbpubsub "github.com/coder/coder/v2/coderd/database/pubsub"
"github.com/coder/coder/v2/coderd/pubsub"
)
// ProviderReloader refreshes a component's provider snapshot.
type ProviderReloader interface {
Reload(ctx context.Context) error
}
// SubscribeProviderReload refreshes once, then on AI provider changes.
func SubscribeProviderReload(
ctx context.Context,
ps dbpubsub.Pubsub,
reloader ProviderReloader,
logger slog.Logger,
) (func(), error) {
if ps == nil {
return nil, xerrors.New("pubsub is required")
}
if reloader == nil {
return nil, xerrors.New("reloader is required")
}
unsubscribe, err := ps.SubscribeWithErr(pubsub.AIProvidersChangedChannel, func(cbCtx context.Context, _ []byte, err error) {
if err != nil {
logger.Warn(cbCtx, "ai providers changed event delivered with error", slog.Error(err))
return
}
if err := reloader.Reload(cbCtx); err != nil {
logger.Warn(cbCtx, "reload ai provider snapshot from pubsub event", slog.Error(err))
return
}
logger.Debug(cbCtx, "reloaded ai provider snapshot from pubsub event")
})
if err != nil {
return nil, xerrors.Errorf("subscribe to %s: %w", pubsub.AIProvidersChangedChannel, err)
}
if err := reloader.Reload(ctx); err != nil {
logger.Warn(ctx, "initial ai provider reload", slog.Error(err))
}
return unsubscribe, nil
}
+133
View File
@@ -0,0 +1,133 @@
package aibridged_test
import (
"context"
"sync/atomic"
"testing"
"github.com/stretchr/testify/require"
"golang.org/x/xerrors"
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/coder/v2/coderd/aibridged"
dbpubsub "github.com/coder/coder/v2/coderd/database/pubsub"
"github.com/coder/coder/v2/coderd/pubsub"
"github.com/coder/coder/v2/testutil"
)
func TestSubscribeProviderReload(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitMedium)
logger := slogtest.Make(t, nil)
ps := dbpubsub.NewInMemory()
t.Cleanup(func() { _ = ps.Close() })
calls := &recordingReloader{}
unsub, err := aibridged.SubscribeProviderReload(ctx, ps, calls, logger)
require.NoError(t, err)
t.Cleanup(unsub)
require.Equal(t, 1, calls.count())
require.NoError(t, ps.Publish(pubsub.AIProvidersChangedChannel, nil))
require.Eventually(t, func() bool { return calls.count() >= 2 }, testutil.WaitShort, testutil.IntervalFast,
"Reload must fire again after a pubsub notification")
}
func TestSubscribeProviderReloadSurfacesReloadError(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitMedium)
logger := slogtest.Make(t, nil)
ps := dbpubsub.NewInMemory()
t.Cleanup(func() { _ = ps.Close() })
calls := &recordingReloader{returnErr: true}
unsub, err := aibridged.SubscribeProviderReload(ctx, ps, calls, logger)
require.NoError(t, err)
t.Cleanup(unsub)
require.Equal(t, 1, calls.count())
require.NoError(t, ps.Publish(pubsub.AIProvidersChangedChannel, nil))
require.Eventually(t, func() bool { return calls.count() >= 2 }, testutil.WaitShort, testutil.IntervalFast,
"Reload must keep firing even after a previous Reload returned an error")
}
func TestSubscribeProviderReloadIgnoresEventError(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitMedium)
logger := slogtest.Make(t, nil)
ps := &errInjectingPubsub{}
calls := &recordingReloader{}
unsub, err := aibridged.SubscribeProviderReload(ctx, ps, calls, logger)
require.NoError(t, err)
t.Cleanup(unsub)
require.Equal(t, 1, calls.count())
ps.listener(ctx, nil, errPubsubDelivery)
require.Equal(t, 1, calls.count())
ps.listener(ctx, nil, nil)
require.Equal(t, 2, calls.count())
}
// recordingReloader is a minimal [aibridged.ProviderReloader] that
// counts calls.
type recordingReloader struct {
n atomic.Int32
returnErr bool
}
func (r *recordingReloader) Reload(_ context.Context) error {
r.n.Add(1)
if r.returnErr {
return errReloadFailed
}
return nil
}
func (r *recordingReloader) count() int {
return int(r.n.Load())
}
var (
errReloadFailed = stubError("reload failed")
errPubsubDelivery = stubError("pubsub delivery failed")
)
type stubError string
func (s stubError) Error() string { return string(s) }
var _ dbpubsub.Pubsub = &errInjectingPubsub{}
type errInjectingPubsub struct {
listener dbpubsub.ListenerWithErr
}
func (*errInjectingPubsub) Subscribe(string, dbpubsub.Listener) (func(), error) {
return nil, xerrors.New("Subscribe not implemented")
}
func (p *errInjectingPubsub) SubscribeWithErr(_ string, listener dbpubsub.ListenerWithErr) (func(), error) {
p.listener = listener
return func() {}, nil
}
func (*errInjectingPubsub) Publish(string, []byte) error {
return xerrors.New("Publish not implemented")
}
func (*errInjectingPubsub) Close() error {
return nil
}
+11
View File
@@ -0,0 +1,11 @@
package pubsub
// AIProvidersChangedChannel is the pubsub channel that carries AI
// provider lifecycle events: provider create / update / soft-delete
// and key insert / delete. Subscribers (aibridged, aibridgeproxyd)
// reload their in-memory provider snapshot on receipt.
//
// The payload is an empty invalidation hint; subscribers refetch the
// authoritative state from the database, so dropped messages only
// delay convergence rather than diverge state.
const AIProvidersChangedChannel = "ai_providers_changed"
+87 -63
View File
@@ -18,6 +18,7 @@ import (
"strconv"
"strings"
"sync"
"sync/atomic"
"syscall"
"time"
@@ -30,6 +31,18 @@ import (
agplaibridge "github.com/coder/coder/v2/coderd/aibridge"
)
// ProviderRoute is the routing entry for a single AI provider: the
// instance name (the routing key) and the upstream base URL (the
// source of the MITM allowlist host).
type ProviderRoute struct {
Name string
BaseURL string
}
// RefreshProvidersFunc returns the live provider set used by Reload to
// rebuild the proxy's routing snapshot.
type RefreshProvidersFunc func(ctx context.Context) ([]ProviderRoute, error)
// Known AI provider hosts.
const (
HostAnthropic = "api.anthropic.com"
@@ -119,14 +132,21 @@ var blockedIPRanges = func() []net.IPNet {
// - decrypting requests using the configured MITM CA certificate
// - forwarding requests to aibridged for processing
type Server struct {
ctx context.Context
logger slog.Logger
proxy *goproxy.ProxyHttpServer
httpServer *http.Server
listener net.Listener
tlsEnabled bool
coderAccessURL *url.URL
aibridgeProviderFromHost func(host string) string
ctx context.Context
logger slog.Logger
proxy *goproxy.ProxyHttpServer
httpServer *http.Server
listener net.Listener
tlsEnabled bool
coderAccessURL *url.URL
// refreshProviders fetches the live provider snapshot on Reload.
// Nil disables hot-reload.
refreshProviders RefreshProvidersFunc
// providerRouter holds the live (mitmHosts, nameByHost) pair.
providerRouter atomic.Pointer[providerRouter]
// allowedPorts is the port allowlist for CONNECT requests. Fixed at
// construction; not reloadable.
allowedPorts []string
// caCert is the PEM-encoded MITM CA certificate loaded during initialization.
// This is served to clients who need to trust the proxy's generated certificates.
caCert []byte
@@ -139,6 +159,21 @@ type Server struct {
metrics *Metrics
}
// providerRouter keeps CONNECT matching and provider lookup in sync.
type providerRouter struct {
mitmHosts []string // host:port allowlist for the goproxy condition.
nameByHost map[string]string // lowercase hostname -> provider name.
}
// emptyProviderRouter is used before the first Reload (or when the
// operator deconfigures every provider) so handlers can safely call
// loadProviderRouter without a nil check.
var emptyProviderRouter = &providerRouter{nameByHost: map[string]string{}}
func (r *providerRouter) providerFromHost(host string) string {
return r.nameByHost[strings.ToLower(host)]
}
// requestContext holds metadata propagated through the proxy request/response chain.
// It is stored in goproxy's ProxyCtx.UserData and enriched as the request progresses
// through the proxy handlers.
@@ -183,13 +218,12 @@ type Options struct {
// CertStore is an optional certificate cache for MITM. If nil, a default
// cache is created. Exposed for testing.
CertStore goproxy.CertStorage
// DomainAllowlist is the list of domains to intercept and route through AI Bridge.
// Only requests to these domains will be MITM'd and forwarded to aibridged.
// Requests to other domains will be tunneled directly without decryption.
// DomainAllowlist seeds the boot-time MITM allowlist. Production
// callers should leave this empty and rely on RefreshProviders;
// tests use it to skip the refresh round-trip.
DomainAllowlist []string
// AIBridgeProviderFromHost maps a hostname to a known aibridge provider
// name. Must be non-nil; the caller derives it from the configured
// provider list.
// AIBridgeProviderFromHost seeds the boot-time host -> provider
// name mapping. Required iff DomainAllowlist is non-empty.
AIBridgeProviderFromHost func(host string) string
// UpstreamProxy is the URL of an upstream HTTP proxy to chain tunneled
// (non-allowlisted) requests through. If empty, tunneled requests connect
@@ -214,6 +248,10 @@ type Options struct {
// Metrics is the prometheus metrics instance for recording proxy metrics.
// If nil, metrics will not be recorded.
Metrics *Metrics
// RefreshProviders, when set, is invoked by Server.Reload to fetch
// the live provider snapshot used to derive the MITM allowlist and
// host -> provider-name routing. Nil disables hot-reload.
RefreshProviders RefreshProvidersFunc
}
func New(ctx context.Context, logger slog.Logger, opts Options) (*Server, error) {
@@ -258,29 +296,12 @@ func New(ctx context.Context, logger slog.Logger, opts Options) (*Server, error)
allowedPorts = []string{"80", "443"}
}
// An empty allowlist is permitted so the server can boot before any
// ai_providers row exists; every intercept attempt is then rejected
// until providers are configured.
// TODO: refresh the allowlist when ai_providers changes so a restart
// is not required after the first provider is configured.
mitmHosts, err := convertDomainsToHosts(opts.DomainAllowlist, allowedPorts)
// Build the boot-time router from DomainAllowlist + the lookup fn.
// Both empty is fine: the server fails closed (no MITM until
// Reload populates the router from the database).
bootRouter, err := buildBootRouter(opts.DomainAllowlist, opts.AIBridgeProviderFromHost, allowedPorts)
if err != nil {
return nil, xerrors.Errorf("invalid domain allowlist: %w", err)
}
if opts.AIBridgeProviderFromHost == nil {
return nil, xerrors.New("AIBridgeProviderFromHost is required")
}
aibridgeProviderFromHost := opts.AIBridgeProviderFromHost
for _, domain := range opts.DomainAllowlist {
domain = strings.TrimSpace(strings.ToLower(domain))
if domain == "" {
continue
}
if aibridgeProviderFromHost(domain) == "" {
return nil, xerrors.Errorf("domain %q is in allowlist but has no provider mapping", domain)
}
return nil, err
}
// Parse configured exceptions to the blocked IP ranges.
@@ -319,17 +340,22 @@ func New(ctx context.Context, logger slog.Logger, opts Options) (*Server, error)
}
srv := &Server{
ctx: ctx,
logger: logger,
proxy: proxy,
tlsEnabled: opts.TLSCertFile != "",
coderAccessURL: coderAccessURL,
aibridgeProviderFromHost: aibridgeProviderFromHost,
caCert: certPEM,
allowedPrivateRanges: allowedPrivateRanges,
newDumper: opts.NewDumper,
metrics: opts.Metrics,
ctx: ctx,
logger: logger,
proxy: proxy,
tlsEnabled: opts.TLSCertFile != "",
coderAccessURL: coderAccessURL,
refreshProviders: opts.RefreshProviders,
allowedPorts: allowedPorts,
caCert: certPEM,
allowedPrivateRanges: allowedPrivateRanges,
newDumper: opts.NewDumper,
metrics: opts.Metrics,
}
// Seed the boot-time router from the constructor inputs so the
// proxy can serve immediately. Reload may swap this snapshot at any
// point after construction.
srv.providerRouter.Store(bootRouter)
// Configure upstream proxy for tunneled (non-allowlisted) CONNECT requests.
// Allowlisted domains are MITM'd and forwarded to aibridge directly,
@@ -417,12 +443,11 @@ func New(ctx context.Context, logger slog.Logger, opts Options) (*Server, error)
// Reject CONNECT requests to non-standard ports.
proxy.OnRequest().HandleConnectFunc(srv.portMiddleware(allowedPorts))
// Apply MITM with authentication only to allowlisted hosts.
proxy.OnRequest(
// Only CONNECT requests to these hosts will be intercepted and decrypted.
// All other requests will be tunneled directly to their destination.
goproxy.ReqHostIs(mitmHosts...),
).HandleConnectFunc(
// Apply MITM with authentication only to allowlisted hosts. The host
// list is loaded from the atomic router on every CONNECT so a
// Reload while inflight requests are in progress takes effect on
// the next CONNECT without touching the already-MITM'd ones.
proxy.OnRequest(srv.mitmHostsCondition()).HandleConnectFunc(
// Extract Coder token from proxy authentication to forward to aibridged.
srv.authMiddleware,
)
@@ -470,7 +495,7 @@ func New(ctx context.Context, logger slog.Logger, opts Options) (*Server, error)
slog.F("listen_addr", listener.Addr().String()),
slog.F("tls_listener_enabled", srv.tlsEnabled),
slog.F("coder_access_url", coderAccessURL.String()),
slog.F("domain_allowlist", mitmHosts),
slog.F("domain_allowlist", bootRouter.mitmHosts),
slog.F("upstream_proxy", opts.UpstreamProxy),
slog.F("allowed_private_cidrs", opts.AllowedPrivateCIDRs),
slog.F("api_dump_enabled", opts.NewDumper != nil),
@@ -651,11 +676,11 @@ func (s *Server) authMiddleware(host string, ctx *goproxy.ProxyCtx) (*goproxy.Co
)
// Determine the provider from the request hostname.
provider := s.aibridgeProviderFromHost(ctx.Req.URL.Hostname())
// This should never happen: startup validation ensures all allowlisted
// domains have known aibridge provider mappings.
provider := s.loadProviderRouter().providerFromHost(ctx.Req.URL.Hostname())
// A concurrent Reload can swap the router between CONNECT matching
// and provider lookup, so treat a missing mapping as a runtime miss.
if provider == "" {
logger.Error(s.ctx, "rejecting CONNECT request with no provider mapping")
logger.Warn(s.ctx, "rejecting CONNECT request with no provider mapping")
return goproxy.RejectConnect, host
}
@@ -922,12 +947,11 @@ func (s *Server) handleRequest(req *http.Request, ctx *goproxy.ProxyCtx) (*http.
}
if reqCtx.Provider == "" {
// This should never happen: startup validation ensures all allowlisted
// domains have known aibridge provider mappings.
// The request is MITM'd (decrypted) but since there is no mapping,
// there is no known route to aibridge.
// Log error and forward to the original destination as a fallback.
s.logger.Error(s.ctx, "decrypted request has no provider mapping, passing through",
// A concurrent Reload can remove the provider after CONNECT
// authentication. The request is MITM'd (decrypted), but without a
// mapping there is no known route to aibridge. Log and forward
// to the original destination as a fallback.
s.logger.Warn(s.ctx, "decrypted request has no provider mapping, passing through",
slog.F("connect_id", reqCtx.ConnectSessionID.String()),
slog.F("host", req.Host),
slog.F("method", req.Method),
@@ -158,6 +158,7 @@ type testProxyConfig struct {
allowedPrivateCIDRs []string
newDumper func(string, string) aibridgeproxyd.RoundTripDumper
metrics *aibridgeproxyd.Metrics
refreshProviders aibridgeproxyd.RefreshProvidersFunc
}
type testProxyOption func(*testProxyConfig)
@@ -250,6 +251,12 @@ func withListenerTLS(certFile, keyFile string) testProxyOption {
}
}
func withRefreshProviders(fn aibridgeproxyd.RefreshProvidersFunc) testProxyOption {
return func(cfg *testProxyConfig) {
cfg.refreshProviders = fn
}
}
// newTestProxy creates a new AI Bridge Proxy server for testing.
// It uses the shared MITM certificate and registers cleanup automatically.
// It waits for the proxy server to be ready before returning.
@@ -289,6 +296,7 @@ func newTestProxy(t *testing.T, opts ...testProxyOption) *aibridgeproxyd.Server
AllowedPrivateCIDRs: cfg.allowedPrivateCIDRs,
NewDumper: cfg.newDumper,
Metrics: cfg.metrics,
RefreshProviders: cfg.refreshProviders,
}
if cfg.certStore != nil {
aibridgeOpts.CertStore = cfg.certStore
+124
View File
@@ -0,0 +1,124 @@
package aibridgeproxyd
import (
"context"
"net/http"
"net/url"
"slices"
"strings"
"github.com/elazarl/goproxy"
"golang.org/x/xerrors"
"cdr.dev/slog/v3"
)
// Reload refreshes proxy routing from the configured provider source.
// A refresh failure leaves the previous snapshot in place.
func (s *Server) Reload(ctx context.Context) error {
if s.refreshProviders == nil {
return nil
}
providers, err := s.refreshProviders(ctx)
if err != nil {
return xerrors.Errorf("refresh ai providers for proxy routing: %w", err)
}
router, err := buildProviderRouter(ctx, s.logger, providers, s.allowedPorts)
if err != nil {
return xerrors.Errorf("build provider router (provider_count=%d): %w", len(providers), err)
}
s.providerRouter.Store(router)
s.logger.Debug(s.ctx, "aibridgeproxyd router reloaded",
slog.F("mitm_host_count", len(router.mitmHosts)),
)
return nil
}
func (s *Server) loadProviderRouter() *providerRouter {
if p := s.providerRouter.Load(); p != nil {
return p
}
return emptyProviderRouter
}
// mitmHostsCondition returns a goproxy ReqConditionFunc that reads the
// allowlist from the atomic router on every match. Using a closure
// instead of goproxy.ReqHostIs(...) lets Reload affect every later
// CONNECT without re-registering handlers.
func (s *Server) mitmHostsCondition() goproxy.ReqConditionFunc {
return func(req *http.Request, _ *goproxy.ProxyCtx) bool {
if req == nil {
return false
}
return slices.Contains(s.loadProviderRouter().mitmHosts, strings.ToLower(req.URL.Host))
}
}
// buildProviderRouter constructs a router snapshot from a refreshed
// provider list. First provider wins on duplicate hostnames.
func buildProviderRouter(ctx context.Context, logger slog.Logger, providers []ProviderRoute, allowedPorts []string) (*providerRouter, error) {
nameByHost := make(map[string]string, len(providers))
var domains []string
for _, p := range providers {
if p.BaseURL == "" {
logger.Warn(ctx, "skipping ai provider without base url",
slog.F("provider_name", p.Name),
)
continue
}
u, err := url.Parse(p.BaseURL)
if err != nil {
logger.Warn(ctx, "skipping ai provider with invalid base url",
slog.F("provider_name", p.Name),
slog.F("base_url", p.BaseURL),
slog.Error(err),
)
continue
}
if u.Hostname() == "" {
logger.Warn(ctx, "skipping ai provider base url without hostname",
slog.F("provider_name", p.Name),
slog.F("base_url", p.BaseURL),
)
continue
}
host := strings.ToLower(u.Hostname())
if _, exists := nameByHost[host]; exists {
continue
}
nameByHost[host] = p.Name
domains = append(domains, host)
}
mitmHosts, err := convertDomainsToHosts(domains, allowedPorts)
if err != nil {
return nil, err
}
return &providerRouter{mitmHosts: mitmHosts, nameByHost: nameByHost}, nil
}
// buildBootRouter seeds the providerRouter from the boot-time inputs.
// The lookup function is consulted only for hosts in the allowlist; a
// nil function with an empty allowlist is fine and yields an empty
// router (the proxy fails closed until Reload populates it).
func buildBootRouter(domainAllowlist []string, providerFromHost func(string) string, allowedPorts []string) (*providerRouter, error) {
mitmHosts, err := convertDomainsToHosts(domainAllowlist, allowedPorts)
if err != nil {
return nil, xerrors.Errorf("invalid domain allowlist: %w", err)
}
nameByHost := make(map[string]string, len(domainAllowlist))
for _, domain := range domainAllowlist {
domain = strings.TrimSpace(strings.ToLower(domain))
if domain == "" {
continue
}
var name string
if providerFromHost != nil {
name = providerFromHost(domain)
}
if name == "" {
return nil, xerrors.Errorf("domain %q is in allowlist but has no provider mapping", domain)
}
nameByHost[domain] = name
}
return &providerRouter{mitmHosts: mitmHosts, nameByHost: nameByHost}, nil
}
@@ -0,0 +1,147 @@
package aibridgeproxyd
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/xerrors"
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/coder/v2/testutil"
)
func TestServerReloadSwapsProviderRouter(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
providers := []ProviderRoute{{Name: "old", BaseURL: "https://old.example.com/"}}
srv := &Server{
ctx: ctx,
logger: slogtest.Make(t, nil),
allowedPorts: []string{"443"},
refreshProviders: func(context.Context) ([]ProviderRoute, error) {
return providers, nil
},
}
srv.providerRouter.Store(emptyProviderRouter)
require.NoError(t, srv.Reload(ctx))
assert.Equal(t, "old", srv.loadProviderRouter().providerFromHost("old.example.com"))
assert.Empty(t, srv.loadProviderRouter().providerFromHost("new.example.com"))
providers = []ProviderRoute{{Name: "new", BaseURL: "https://new.example.com/"}}
require.NoError(t, srv.Reload(ctx))
router := srv.loadProviderRouter()
assert.Empty(t, router.providerFromHost("old.example.com"))
assert.Equal(t, "new", router.providerFromHost("new.example.com"))
assert.Equal(t, []string{"new.example.com:443"}, router.mitmHosts)
}
func TestServerReloadPreservesProviderRouterOnRefreshError(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
refreshErr := xerrors.New("refresh failed")
providers := []ProviderRoute{{Name: "old", BaseURL: "https://old.example.com/"}}
failRefresh := false
srv := &Server{
ctx: ctx,
logger: slogtest.Make(t, nil),
allowedPorts: []string{"443"},
refreshProviders: func(context.Context) ([]ProviderRoute, error) {
if failRefresh {
return nil, refreshErr
}
return providers, nil
},
}
srv.providerRouter.Store(emptyProviderRouter)
require.NoError(t, srv.Reload(ctx))
before := srv.loadProviderRouter()
assert.Equal(t, "old", before.providerFromHost("old.example.com"))
failRefresh = true
require.ErrorIs(t, srv.Reload(ctx), refreshErr)
after := srv.loadProviderRouter()
assert.Same(t, before, after)
assert.Equal(t, "old", after.providerFromHost("old.example.com"))
assert.Equal(t, []string{"old.example.com:443"}, after.mitmHosts)
}
// TestBuildProviderRouter covers the host-and-routing derivation that
// Reload feeds into the providerRouter.
func TestBuildProviderRouter(t *testing.T) {
t.Parallel()
t.Run("ExtractsHostnames", func(t *testing.T) {
t.Parallel()
providers := []ProviderRoute{
{Name: "openai", BaseURL: "https://api.openai.com/v1/"},
{Name: "anthropic", BaseURL: "https://api.anthropic.com/"},
{Name: "custom", BaseURL: "https://custom-llm.example.com:8443/api"},
}
router, err := buildProviderRouter(testutil.Context(t, testutil.WaitShort), slogtest.Make(t, nil), providers, []string{"443"})
require.NoError(t, err)
assert.Equal(t, "openai", router.providerFromHost("api.openai.com"))
assert.Equal(t, "anthropic", router.providerFromHost("api.anthropic.com"))
assert.Equal(t, "custom", router.providerFromHost("custom-llm.example.com"))
assert.Empty(t, router.providerFromHost("unknown.com"))
assert.Contains(t, router.mitmHosts, "api.openai.com:443")
assert.Contains(t, router.mitmHosts, "api.anthropic.com:443")
})
t.Run("DeduplicatesSameHost", func(t *testing.T) {
t.Parallel()
providers := []ProviderRoute{
{Name: "first", BaseURL: "https://api.example.com/v1"},
{Name: "second", BaseURL: "https://api.example.com/v2"},
}
router, err := buildProviderRouter(testutil.Context(t, testutil.WaitShort), slogtest.Make(t, nil), providers, []string{"443"})
require.NoError(t, err)
// First provider wins on duplicate host.
assert.Equal(t, "first", router.providerFromHost("api.example.com"))
})
t.Run("CaseInsensitive", func(t *testing.T) {
t.Parallel()
providers := []ProviderRoute{
{Name: "provider", BaseURL: "https://API.Example.COM/v1"},
}
router, err := buildProviderRouter(testutil.Context(t, testutil.WaitShort), slogtest.Make(t, nil), providers, []string{"443"})
require.NoError(t, err)
assert.Equal(t, "provider", router.providerFromHost("API.Example.COM"))
assert.Equal(t, "provider", router.providerFromHost("api.example.com"))
})
t.Run("SkipsEmptyOrMalformedBaseURL", func(t *testing.T) {
t.Parallel()
providers := []ProviderRoute{
{Name: "no-url"},
{Name: "scheme-only", BaseURL: "https://"},
{Name: "good", BaseURL: "https://api.good.example.com/"},
}
router, err := buildProviderRouter(testutil.Context(t, testutil.WaitShort), slogtest.Make(t, nil), providers, []string{"443"})
require.NoError(t, err)
assert.Equal(t, "good", router.providerFromHost("api.good.example.com"))
assert.Empty(t, router.providerFromHost("scheme-only"))
assert.Equal(t, []string{"api.good.example.com:443"}, router.mitmHosts)
})
}
+379
View File
@@ -0,0 +1,379 @@
package aibridgeproxyd_test
import (
"context"
"io"
"net/http"
"net/http/httptest"
"slices"
"strings"
"sync"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/enterprise/aibridgeproxyd"
"github.com/coder/coder/v2/testutil"
)
// reloadTestHarness wires a real proxy server to a mutable provider
// store and a mock aibridged backend so tests can drive Reload through
// a CRUD-style sequence and observe routing via real proxy requests.
type reloadTestHarness struct {
srv *aibridgeproxyd.Server
store *providerStore
client *http.Client
bridged *httptest.Server
recorder *aibridgedRecorder
}
// aibridgedRecorder captures the path of the last request received by
// the mock aibridged backend. Access is mutex-guarded so the test
// goroutine and the proxy's response goroutine can read/write safely.
type aibridgedRecorder struct {
mu sync.Mutex
path string
}
func (r *aibridgedRecorder) record(path string) {
r.mu.Lock()
defer r.mu.Unlock()
r.path = path
}
func (r *aibridgedRecorder) load() string {
r.mu.Lock()
defer r.mu.Unlock()
return r.path
}
func (r *aibridgedRecorder) reset() {
r.mu.Lock()
defer r.mu.Unlock()
r.path = ""
}
// providerStore is a mutable [aibridgeproxyd.RefreshProvidersFunc]
// backing for integration tests. set / setErr mutate the snapshot
// returned by the next Reload, mimicking CRUD against the database.
type providerStore struct {
mu sync.Mutex
providers []aibridgeproxyd.ProviderRoute
err error
}
func (s *providerStore) set(providers []aibridgeproxyd.ProviderRoute) {
s.mu.Lock()
defer s.mu.Unlock()
s.providers = providers
s.err = nil
}
func (s *providerStore) setErr(err error) {
s.mu.Lock()
defer s.mu.Unlock()
s.err = err
}
func (s *providerStore) refresh(context.Context) ([]aibridgeproxyd.ProviderRoute, error) {
s.mu.Lock()
defer s.mu.Unlock()
if s.err != nil {
return nil, s.err
}
// Return a copy so callers can't mutate our internal snapshot.
return slices.Clone(s.providers), nil
}
// newReloadTestHarness boots a proxy with an empty boot allowlist and a
// store-backed RefreshProviders. Production wiring is identical: the
// daemon constructs the proxy without a static allowlist and lets
// Reload populate the router from the database.
func newReloadTestHarness(t *testing.T) *reloadTestHarness {
t.Helper()
recorder := &aibridgedRecorder{}
bridged := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
recorder.record(r.URL.Path)
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("aibridged"))
}))
t.Cleanup(bridged.Close)
store := &providerStore{}
srv := newTestProxy(t,
withCoderAccessURL(bridged.URL),
withAllowedPorts("443"),
// Empty boot allowlist: the router must be populated by Reload,
// matching the production daemon's behavior.
withDomainAllowlist(),
withAIBridgeProviderFromHost(nil),
withRefreshProviders(store.refresh),
)
certPool := getProxyCertPool(t)
client := newProxyClient(t, srv, makeProxyAuthHeader("coder-token"), certPool, false)
// Disable keep-alives so each request opens a fresh CONNECT through
// the proxy. Per the Reload contract, already-MITM'd tunnels keep
// the provider name they captured at CONNECT time; only new
// connections see the post-Reload snapshot. Tests need a fresh
// CONNECT between phases to assert on the new routing.
client.Transport.(*http.Transport).DisableKeepAlives = true
return &reloadTestHarness{
srv: srv,
store: store,
client: client,
bridged: bridged,
recorder: recorder,
}
}
// requestResult is the outcome of sending a request through the proxy.
// Either err is set (CONNECT failed for a non-MITM'd host whose dial
// fell through to the tunneled path and could not be resolved) or
// status/body carry the MITM'd response from the mock aibridged.
type requestResult struct {
status int
body string
err error
}
// sendRequest issues a single POST through the proxy. It returns rather
// than asserting so callers can branch on whether the host is currently
// routed (MITM'd to aibridged) or not (tunneled, dial of an unresolvable
// host fails).
func (h *reloadTestHarness) sendRequest(t *testing.T, targetURL string) requestResult {
t.Helper()
ctx, cancel := context.WithTimeout(t.Context(), testutil.WaitShort)
defer cancel()
req, err := http.NewRequestWithContext(ctx, http.MethodPost, targetURL, strings.NewReader(`{}`))
require.NoError(t, err)
req.Header.Set("Content-Type", "application/json")
resp, err := h.client.Do(req)
if err != nil {
return requestResult{err: err}
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
require.NoError(t, err)
return requestResult{status: resp.StatusCode, body: string(body)}
}
// expectRoutedTo asserts the proxy MITM'd the request and forwarded it
// to aibridged with the expected /api/v2/aibridge/<name>/<path>.
func (h *reloadTestHarness) expectRoutedTo(t *testing.T, targetURL, expectedPath string) {
t.Helper()
h.recorder.reset()
res := h.sendRequest(t, targetURL)
require.NoError(t, res.err, "request to routed host must succeed")
require.Equal(t, http.StatusOK, res.status)
require.Equal(t, "aibridged", res.body)
require.Equal(t, expectedPath, h.recorder.load(),
"aibridged must observe the rewritten path for %s", targetURL)
}
// expectNotRouted asserts the proxy did not MITM the request for the
// given host. The CONNECT either falls through to the tunneled path
// (where the .invalid hostname fails to dial) or to a 502 from the
// proxy. Either way, aibridged never sees the request.
func (h *reloadTestHarness) expectNotRouted(t *testing.T, targetURL string) {
t.Helper()
h.recorder.reset()
_ = h.sendRequest(t, targetURL)
require.Empty(t, h.recorder.load(),
"aibridged must not be reached for non-routed host %s", targetURL)
}
// TestProxy_HotReloadRoutingCRUD drives the proxy through a CRUD-style
// sequence of provider changes and asserts on routing after each
// Reload via real HTTPS requests. Each sub-test mutates the store and
// validates that:
// - newly created providers are MITM'd to aibridged with the right
// /api/v2/aibridge/<name>/<path>
// - renamed providers route under the new name
// - providers whose BaseURL host changes route the new host and stop
// MITM'ing the old host
// - deleted providers stop being MITM'd; aibridged sees nothing
//
// Hostnames are .invalid (RFC 2606) so a request that escapes the MITM
// path fails fast via DNS rather than reaching a real upstream.
func TestProxy_HotReloadRoutingCRUD(t *testing.T) {
t.Parallel()
h := newReloadTestHarness(t)
// InitialEmptyRouter: no Reload has been called and the boot
// allowlist is empty, so any host falls through to the tunneled
// middleware.
h.expectNotRouted(t, "https://alpha.invalid/v1/messages")
// CreateProvider.
h.store.set([]aibridgeproxyd.ProviderRoute{
{Name: "alpha", BaseURL: "https://alpha.invalid/v1"},
})
require.NoError(t, h.srv.Reload(t.Context()))
h.expectRoutedTo(t, "https://alpha.invalid/v1/messages", "/api/v2/aibridge/alpha/v1/messages")
// UpdateProviderName: the same BaseURL with a new name must route
// under the new name on the next Reload.
h.store.set([]aibridgeproxyd.ProviderRoute{
{Name: "alpha-v2", BaseURL: "https://alpha.invalid/v1"},
})
require.NoError(t, h.srv.Reload(t.Context()))
h.expectRoutedTo(t, "https://alpha.invalid/v1/messages", "/api/v2/aibridge/alpha-v2/v1/messages")
// UpdateProviderBaseURLHost: moving the provider to a new host must
// start MITM'ing the new host and stop MITM'ing the old one.
h.store.set([]aibridgeproxyd.ProviderRoute{
{Name: "alpha-v2", BaseURL: "https://alpha-new.invalid/v1"},
})
require.NoError(t, h.srv.Reload(t.Context()))
h.expectRoutedTo(t, "https://alpha-new.invalid/v1/messages", "/api/v2/aibridge/alpha-v2/v1/messages")
h.expectNotRouted(t, "https://alpha.invalid/v1/messages")
// AddSecondProvider: a second provider added in the same Reload must
// route independently from the first.
h.store.set([]aibridgeproxyd.ProviderRoute{
{Name: "alpha-v2", BaseURL: "https://alpha-new.invalid/v1"},
{Name: "beta", BaseURL: "https://beta.invalid/v1"},
})
require.NoError(t, h.srv.Reload(t.Context()))
h.expectRoutedTo(t, "https://alpha-new.invalid/v1/messages", "/api/v2/aibridge/alpha-v2/v1/messages")
h.expectRoutedTo(t, "https://beta.invalid/v1/chat/completions", "/api/v2/aibridge/beta/v1/chat/completions")
// DeleteOneProvider: removing alpha must keep beta routed and stop
// routing alpha.
h.store.set([]aibridgeproxyd.ProviderRoute{
{Name: "beta", BaseURL: "https://beta.invalid/v1"},
})
require.NoError(t, h.srv.Reload(t.Context()))
h.expectRoutedTo(t, "https://beta.invalid/v1/chat/completions", "/api/v2/aibridge/beta/v1/chat/completions")
h.expectNotRouted(t, "https://alpha-new.invalid/v1/messages")
// DeleteAllProviders: an empty Reload must collapse the router to
// the fail-closed state with no host MITM'd.
h.store.set(nil)
require.NoError(t, h.srv.Reload(t.Context()))
h.expectNotRouted(t, "https://beta.invalid/v1/chat/completions")
h.expectNotRouted(t, "https://alpha-new.invalid/v1/messages")
// RecreateAfterDelete: reintroducing a previously-deleted provider
// must route again without restart, confirming the swap is
// symmetric.
h.store.set([]aibridgeproxyd.ProviderRoute{
{Name: "alpha", BaseURL: "https://alpha.invalid/v1"},
})
require.NoError(t, h.srv.Reload(t.Context()))
h.expectRoutedTo(t, "https://alpha.invalid/v1/messages", "/api/v2/aibridge/alpha/v1/messages")
}
// TestProxy_HotReloadRoutingInvalidProviders covers the resilience
// requirements stated in the [aibridgeproxyd.Server.Reload] contract:
// individual invalid provider entries do not poison the snapshot, and
// a refresh-level error does not collapse the previous snapshot to
// empty.
func TestProxy_HotReloadRoutingInvalidProviders(t *testing.T) {
t.Parallel()
t.Run("EmptyBaseURLSkipped", func(t *testing.T) {
t.Parallel()
h := newReloadTestHarness(t)
// One valid provider and one with an empty BaseURL. The empty
// entry must be silently dropped; the valid one must still
// route.
h.store.set([]aibridgeproxyd.ProviderRoute{
{Name: "no-url"},
{Name: "valid", BaseURL: "https://valid.invalid/v1"},
})
require.NoError(t, h.srv.Reload(t.Context()))
h.expectRoutedTo(t, "https://valid.invalid/v1/messages", "/api/v2/aibridge/valid/v1/messages")
})
t.Run("MalformedBaseURLSkipped", func(t *testing.T) {
t.Parallel()
h := newReloadTestHarness(t)
// A BaseURL that fails url.Parse and one whose Hostname() is
// empty must both be dropped. Mixed with a valid entry, only
// the valid one routes.
h.store.set([]aibridgeproxyd.ProviderRoute{
{Name: "malformed", BaseURL: "://not-a-url"},
{Name: "no-host", BaseURL: "https://"},
{Name: "valid", BaseURL: "https://valid.invalid/v1"},
})
require.NoError(t, h.srv.Reload(t.Context()))
h.expectRoutedTo(t, "https://valid.invalid/v1/messages", "/api/v2/aibridge/valid/v1/messages")
})
t.Run("DuplicateHostFirstWins", func(t *testing.T) {
t.Parallel()
h := newReloadTestHarness(t)
// Two providers with the same BaseURL host: the first one wins,
// matching buildProviderRouter's documented contract.
h.store.set([]aibridgeproxyd.ProviderRoute{
{Name: "first", BaseURL: "https://shared.invalid/v1"},
{Name: "second", BaseURL: "https://shared.invalid/v2"},
})
require.NoError(t, h.srv.Reload(t.Context()))
h.expectRoutedTo(t, "https://shared.invalid/v1/messages", "/api/v2/aibridge/first/v1/messages")
})
t.Run("AllInvalidYieldsEmptyRouter", func(t *testing.T) {
t.Parallel()
h := newReloadTestHarness(t)
// When every provider is invalid, the router contains no
// entries and the proxy fails closed: no host is MITM'd.
h.store.set([]aibridgeproxyd.ProviderRoute{
{Name: "no-url"},
{Name: "malformed", BaseURL: "://not-a-url"},
{Name: "no-host", BaseURL: "https://"},
})
require.NoError(t, h.srv.Reload(t.Context()))
h.expectNotRouted(t, "https://anything.invalid/v1/messages")
})
t.Run("RefreshErrorPreservesPreviousSnapshot", func(t *testing.T) {
t.Parallel()
h := newReloadTestHarness(t)
// Seed a valid snapshot so we have something to preserve.
h.store.set([]aibridgeproxyd.ProviderRoute{
{Name: "alpha", BaseURL: "https://alpha.invalid/v1"},
})
require.NoError(t, h.srv.Reload(t.Context()))
h.expectRoutedTo(t, "https://alpha.invalid/v1/messages", "/api/v2/aibridge/alpha/v1/messages")
// A refresh error must NOT clear the router: dropping the
// allowlist on every transient DB hiccup would amplify the
// fault into a denial of service.
h.store.setErr(xerrors.New("simulated db failure"))
err := h.srv.Reload(t.Context())
require.Error(t, err)
assert.Contains(t, err.Error(), "refresh ai providers for proxy routing")
h.expectRoutedTo(t, "https://alpha.invalid/v1/messages", "/api/v2/aibridge/alpha/v1/messages")
// Recovery: once the store returns providers again, the next
// Reload applies the new snapshot.
h.store.set([]aibridgeproxyd.ProviderRoute{
{Name: "beta", BaseURL: "https://beta.invalid/v1"},
})
require.NoError(t, h.srv.Reload(t.Context()))
h.expectRoutedTo(t, "https://beta.invalid/v1/messages", "/api/v2/aibridge/beta/v1/messages")
h.expectNotRouted(t, "https://alpha.invalid/v1/messages")
})
}
+63 -49
View File
@@ -4,27 +4,45 @@ package cli
import (
"context"
"net/url"
"io"
"path/filepath"
"strings"
"github.com/prometheus/client_golang/prometheus"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/aibridge"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/aibridge/intercept/apidump"
"github.com/coder/coder/v2/coderd/aibridged"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/enterprise/aibridgeproxyd"
"github.com/coder/coder/v2/enterprise/coderd"
)
func newAIBridgeProxyDaemon(coderAPI *coderd.API, providers []aibridge.Provider) (*aibridgeproxyd.Server, error) {
// aiBridgeProxyDaemon bundles the proxy server and its pubsub
// subscription so both are torn down by a single Close call.
type aiBridgeProxyDaemon struct {
server *aibridgeproxyd.Server
unsubscribe func()
}
func (d *aiBridgeProxyDaemon) Close() error {
if d.unsubscribe != nil {
d.unsubscribe()
}
return d.server.Close()
}
// newAIBridgeProxyDaemon starts the enterprise aibridge proxy daemon,
// subscribes to ai_providers changes so the proxy's routing snapshot
// tracks the database, and registers the HTTP handler on the API.
// The returned io.Closer tears down both the subscription and server.
func newAIBridgeProxyDaemon(coderAPI *coderd.API) (io.Closer, error) {
ctx := context.Background()
coderAPI.Logger.Debug(ctx, "starting in-memory aibridgeproxy daemon")
logger := coderAPI.Logger.Named("aibridgeproxyd")
domains, providerFromHost := domainsFromProviders(providers)
reg := prometheus.WrapRegistererWithPrefix("coder_aibridgeproxyd_", coderAPI.PrometheusRegistry)
metrics := aibridgeproxyd.NewMetrics(reg)
@@ -36,55 +54,51 @@ func newAIBridgeProxyDaemon(coderAPI *coderd.API, providers []aibridge.Provider)
}
srv, err := aibridgeproxyd.New(ctx, logger, aibridgeproxyd.Options{
ListenAddr: coderAPI.DeploymentValues.AI.BridgeProxyConfig.ListenAddr.String(),
TLSCertFile: coderAPI.DeploymentValues.AI.BridgeProxyConfig.TLSCertFile.String(),
TLSKeyFile: coderAPI.DeploymentValues.AI.BridgeProxyConfig.TLSKeyFile.String(),
CoderAccessURL: coderAPI.AccessURL.String(),
MITMCertFile: coderAPI.DeploymentValues.AI.BridgeProxyConfig.MITMCertFile.String(),
MITMKeyFile: coderAPI.DeploymentValues.AI.BridgeProxyConfig.MITMKeyFile.String(),
DomainAllowlist: domains,
AIBridgeProviderFromHost: providerFromHost,
UpstreamProxy: coderAPI.DeploymentValues.AI.BridgeProxyConfig.UpstreamProxy.String(),
UpstreamProxyCA: coderAPI.DeploymentValues.AI.BridgeProxyConfig.UpstreamProxyCA.String(),
AllowedPrivateCIDRs: coderAPI.DeploymentValues.AI.BridgeProxyConfig.AllowedPrivateCIDRs.Value(),
NewDumper: newDumper,
Metrics: metrics,
ListenAddr: coderAPI.DeploymentValues.AI.BridgeProxyConfig.ListenAddr.String(),
TLSCertFile: coderAPI.DeploymentValues.AI.BridgeProxyConfig.TLSCertFile.String(),
TLSKeyFile: coderAPI.DeploymentValues.AI.BridgeProxyConfig.TLSKeyFile.String(),
CoderAccessURL: coderAPI.AccessURL.String(),
MITMCertFile: coderAPI.DeploymentValues.AI.BridgeProxyConfig.MITMCertFile.String(),
MITMKeyFile: coderAPI.DeploymentValues.AI.BridgeProxyConfig.MITMKeyFile.String(),
UpstreamProxy: coderAPI.DeploymentValues.AI.BridgeProxyConfig.UpstreamProxy.String(),
UpstreamProxyCA: coderAPI.DeploymentValues.AI.BridgeProxyConfig.UpstreamProxyCA.String(),
AllowedPrivateCIDRs: coderAPI.DeploymentValues.AI.BridgeProxyConfig.AllowedPrivateCIDRs.Value(),
NewDumper: newDumper,
Metrics: metrics,
RefreshProviders: refreshProxyProviders(coderAPI.Database),
})
if err != nil {
return nil, xerrors.Errorf("failed to start in-memory aibridgeproxy daemon: %w", err)
}
return srv, nil
}
// domainsFromProviders extracts distinct hostnames from providers' base
// URLs and builds a host-to-provider-name mapping function. The returned
// domain list is suitable for use as DomainAllowlist and the mapping
// function is suitable for use as AIBridgeProviderFromHost.
func domainsFromProviders(providers []aibridge.Provider) ([]string, func(string) string) {
hostToProvider := make(map[string]string, len(providers))
var domains []string
for _, p := range providers {
raw := p.BaseURL()
if raw == "" {
continue
}
u, err := url.Parse(raw)
if err != nil || u.Hostname() == "" {
continue
}
host := strings.ToLower(u.Hostname())
if _, exists := hostToProvider[host]; exists {
// First provider wins; duplicates are expected when
// multiple providers share a base URL host (e.g. two
// OpenAI providers using the same proxy).
continue
}
hostToProvider[host] = p.Name()
domains = append(domains, host)
unsubscribe, err := aibridged.SubscribeProviderReload(ctx, coderAPI.Pubsub, srv, logger.Named("provider-reload"))
if err != nil {
logger.Warn(ctx, "subscribe aibridgeproxyd to ai providers change channel", slog.Error(err))
unsubscribe = func() {}
}
return domains, func(host string) string {
return hostToProvider[strings.ToLower(host)]
// Register the handler so coderd can serve the proxy endpoints.
coderAPI.RegisterInMemoryAIBridgeProxydHTTPHandler(srv.Handler())
return &aiBridgeProxyDaemon{
server: srv,
unsubscribe: unsubscribe,
}, nil
}
func refreshProxyProviders(db database.Store) aibridgeproxyd.RefreshProvidersFunc {
return func(ctx context.Context) ([]aibridgeproxyd.ProviderRoute, error) {
//nolint:gocritic // AsAIProviderMetadataReader is the correct subject for routing-only access.
rows, err := db.GetAIProviders(dbauthz.AsAIProviderMetadataReader(ctx), database.GetAIProvidersParams{
IncludeDisabled: false,
})
if err != nil {
return nil, xerrors.Errorf("load ai providers: %w", err)
}
out := make([]aibridgeproxyd.ProviderRoute, 0, len(rows))
for _, row := range rows {
out = append(out, aibridgeproxyd.ProviderRoute{Name: row.Name, BaseURL: row.BaseUrl})
}
return out, nil
}
}
@@ -1,72 +0,0 @@
//go:build !slim
package cli
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/coder/coder/v2/aibridge"
)
func TestDomainsFromProviders(t *testing.T) {
t.Parallel()
t.Run("ExtractsHostnames", func(t *testing.T) {
t.Parallel()
providers := []aibridge.Provider{
aibridge.NewOpenAIProvider(aibridge.OpenAIConfig{Name: "openai", BaseURL: "https://api.openai.com/v1/"}),
aibridge.NewAnthropicProvider(aibridge.AnthropicConfig{Name: "anthropic", BaseURL: "https://api.anthropic.com/"}, nil),
aibridge.NewOpenAIProvider(aibridge.OpenAIConfig{Name: "custom", BaseURL: "https://custom-llm.example.com:8443/api"}),
}
domains, mapping := domainsFromProviders(providers)
assert.Contains(t, domains, "api.openai.com")
assert.Contains(t, domains, "api.anthropic.com")
assert.Contains(t, domains, "custom-llm.example.com")
assert.Equal(t, "openai", mapping("api.openai.com"))
assert.Equal(t, "anthropic", mapping("api.anthropic.com"))
assert.Equal(t, "custom", mapping("custom-llm.example.com"))
assert.Empty(t, mapping("unknown.com"))
})
t.Run("DeduplicatesSameHost", func(t *testing.T) {
t.Parallel()
providers := []aibridge.Provider{
aibridge.NewOpenAIProvider(aibridge.OpenAIConfig{Name: "first", BaseURL: "https://api.example.com/v1"}),
aibridge.NewOpenAIProvider(aibridge.OpenAIConfig{Name: "second", BaseURL: "https://api.example.com/v2"}),
}
domains, mapping := domainsFromProviders(providers)
// Count occurrences of api.example.com.
count := 0
for _, d := range domains {
if d == "api.example.com" {
count++
}
}
assert.Equal(t, 1, count)
// First provider wins.
assert.Equal(t, "first", mapping("api.example.com"))
})
t.Run("CaseInsensitive", func(t *testing.T) {
t.Parallel()
providers := []aibridge.Provider{
aibridge.NewOpenAIProvider(aibridge.OpenAIConfig{Name: "provider", BaseURL: "https://API.Example.COM/v1"}),
}
domains, mapping := domainsFromProviders(providers)
assert.Contains(t, domains, "api.example.com")
assert.Equal(t, "provider", mapping("API.Example.COM"))
assert.Equal(t, "provider", mapping("api.example.com"))
})
}
+10 -17
View File
@@ -15,7 +15,6 @@ import (
"tailscale.com/derp"
"tailscale.com/types/key"
agplcli "github.com/coder/coder/v2/cli"
agplcoderd "github.com/coder/coder/v2/coderd"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/cryptorand"
@@ -167,13 +166,14 @@ func (r *RootCmd) Server(_ func()) *serpent.Command {
// in-memory roundtripper regardless of license); only the proxy
// daemon remains enterprise-gated by config.
if options.DeploymentValues.AI.BridgeProxyConfig.Enabled.Value() {
// Seed env-derived providers before reading them back so the
// proxy observes them on first startup. options.Database is
// dbcrypt-wrapped at this point (set by coderd.New above),
// so env-seeded keys are also written encrypted. Detached
// ctx for the same reason as in agplcli below: an early
// return would orphan newAPI's goroutines. Seeding is
// idempotent; the agplcli path seeds again post-newAPI.
// Seed env-derived providers before the proxy daemon's reloader
// reads them back so the proxy observes them on first startup.
// options.Database is dbcrypt-wrapped at this point (set by
// coderd.New above), so env-seeded keys are also written
// encrypted. Detached ctx for the same reason as in agplcli
// below: an early return would orphan newAPI's goroutines.
// Seeding is idempotent; the agplcli path seeds again
// post-newAPI.
//nolint:gocritic // Production timeout, not a test wait.
aibridgeInitCtx, aibridgeInitCancel := context.WithTimeout(context.WithoutCancel(ctx), 30*time.Second)
defer aibridgeInitCancel()
@@ -185,19 +185,12 @@ func (r *RootCmd) Server(_ func()) *serpent.Command {
); err != nil {
return nil, nil, xerrors.Errorf("seed ai providers from env: %w", err)
}
providers, err := agplcli.BuildProviders(aibridgeInitCtx, options.Database, options.DeploymentValues.AI.BridgeConfig, options.Logger.Named("aibridge.providers"))
if err != nil {
return nil, nil, xerrors.Errorf("build AI providers: %w", err)
}
aiBridgeProxyServer, err := newAIBridgeProxyDaemon(api, providers)
aiBridgeProxyCloser, err := newAIBridgeProxyDaemon(api)
if err != nil {
_ = closers.Close()
return nil, nil, xerrors.Errorf("create aibridgeproxyd: %w", err)
}
closers.Add(aiBridgeProxyServer)
// Register the handler so coderd can serve the proxy endpoints.
api.RegisterInMemoryAIBridgeProxydHTTPHandler(aiBridgeProxyServer.Handler())
closers.Add(aiBridgeProxyCloser)
}
return api.AGPL, closers, nil
+233
View File
@@ -0,0 +1,233 @@
package coderd_test
import (
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"sync/atomic"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel"
"cdr.dev/slog/v3"
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/coder/v2/cli"
"github.com/coder/coder/v2/coderd"
"github.com/coder/coder/v2/coderd/aibridged"
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/enterprise/coderd/coderdenttest"
"github.com/coder/coder/v2/enterprise/coderd/license"
"github.com/coder/coder/v2/testutil"
"github.com/coder/serpent"
)
// mockUpstream is a single httptest server identified by a unique
// marker that it echoes in every response body, so callers can verify
// which upstream a proxied request actually reached. The hit counter
// supports asserting the upstream was touched at all.
type mockUpstream struct {
server *httptest.Server
name string
hits atomic.Int32
}
func newMockUpstream(t *testing.T, name string) *mockUpstream {
t.Helper()
m := &mockUpstream{name: name}
m.server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
m.hits.Add(1)
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
assert.NoError(t, json.NewEncoder(w).Encode(map[string]string{"upstream": name}))
}))
t.Cleanup(m.server.Close)
return m
}
// startTestAIBridgeDaemon wires an in-process aibridged daemon onto
// the supplied API and subscribes it to ai_providers change events.
// This mirrors what cli/server.go does in production so /api/v2/aibridge
// requests dispatch through the real pool and reloader.
func startTestAIBridgeDaemon(t *testing.T, api *coderd.API) {
t.Helper()
ctx := context.Background()
logger := slogtest.Make(t, nil).Named("aibridged").Leveled(slog.LevelDebug)
cfg := api.DeploymentValues.AI.BridgeConfig
tracer := otel.Tracer("aibridge-reload-test")
providers, err := cli.BuildProviders(ctx, api.Database, cfg, logger)
require.NoError(t, err)
pool, err := aibridged.NewCachedBridgePool(aibridged.DefaultPoolOptions, providers, logger.Named("pool"), nil, tracer)
require.NoError(t, err)
t.Cleanup(func() { _ = pool.Shutdown(context.Background()) })
reloader := &testPoolReloader{pool: pool, db: api.Database, cfg: cfg, logger: logger.Named("reloader")}
unsubscribe, err := aibridged.SubscribeProviderReload(ctx, api.Pubsub, reloader, logger.Named("subscriber"))
require.NoError(t, err)
t.Cleanup(unsubscribe)
srv, err := aibridged.New(ctx, pool, func(dialCtx context.Context) (aibridged.DRPCClient, error) {
return api.CreateInMemoryAIBridgeServer(dialCtx)
}, logger, tracer)
require.NoError(t, err)
t.Cleanup(func() { _ = srv.Close() })
api.RegisterInMemoryAIBridgedHTTPHandler(srv)
}
type testPoolReloader struct {
pool *aibridged.CachedBridgePool
db database.Store
cfg codersdk.AIBridgeConfig
logger slog.Logger
}
func (r *testPoolReloader) Reload(ctx context.Context) error {
providers, err := cli.BuildProviders(ctx, r.db, r.cfg, r.logger)
if err != nil {
return err
}
r.pool.ReplaceProviders(providers)
return nil
}
// TestAIBridgeProviderHotReload exercises the end-to-end CRUD ->
// reload -> routing path: every provider mutation made through codersdk
// must, within a short window, change the routing observed at
// /api/v2/aibridge/{name}/v1/models. The OpenAI passthrough route
// /v1/models reverse-proxies to BaseURL, so the upstream that responds
// identifies which provider the daemon's mux dispatched to.
func TestAIBridgeProviderHotReload(t *testing.T) {
t.Parallel()
// Two distinct upstreams so an Update that swings the BaseURL is
// observable: which upstream answers tells us which BaseURL the
// freshly-built provider is pointed at.
upstreamA := newMockUpstream(t, "a")
upstreamB := newMockUpstream(t, "b")
dv := coderdtest.DeploymentValues(t)
dv.AI.BridgeConfig.Enabled = serpent.Bool(true)
client, _, api, _ := coderdenttest.NewWithAPI(t, &coderdenttest.Options{
Options: &coderdtest.Options{DeploymentValues: dv},
LicenseOptions: &coderdenttest.LicenseOptions{
Features: license.Features{codersdk.FeatureAIBridge: 1},
},
})
startTestAIBridgeDaemon(t, api.AGPL)
ctx := testutil.Context(t, testutil.WaitLong)
// sendRequest issues GET /api/v2/aibridge/{name}/v1/models and
// returns the status and the upstream marker decoded from the
// JSON body (empty if the body was not the marker JSON).
sendRequest := func(providerName string) (int, string) {
url := client.URL.String() + "/api/v2/aibridge/" + providerName + "/v1/models"
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
require.NoError(t, err)
req.Header.Set("Authorization", "Bearer "+client.SessionToken())
resp, err := http.DefaultClient.Do(req)
require.NoError(t, err)
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
return resp.StatusCode, ""
}
var decoded map[string]string
_ = json.Unmarshal(body, &decoded)
return resp.StatusCode, decoded["upstream"]
}
// requireRoutesTo polls until the routing reflects the expected
// upstream. The pool reloads asynchronously from a pubsub event;
// require.Eventually is the natural fit.
requireRoutesTo := func(t *testing.T, providerName string, upstream *mockUpstream) {
t.Helper()
before := upstream.hits.Load()
require.Eventuallyf(t, func() bool {
status, marker := sendRequest(providerName)
return status == http.StatusOK && marker == upstream.name
}, testutil.WaitShort, testutil.IntervalFast,
"expected provider %q to route to upstream %q", providerName, upstream.name)
require.Greater(t, upstream.hits.Load(), before,
"upstream %q must have observed at least one request", upstream.name)
}
// requireRoutingGone polls until the provider name yields a 404
// from the aibridge mux's catch-all, indicating the provider has
// been removed from the pool snapshot.
requireRoutingGone := func(t *testing.T, providerName string) {
t.Helper()
require.Eventuallyf(t, func() bool {
status, _ := sendRequest(providerName)
return status == http.StatusNotFound
}, testutil.WaitShort, testutil.IntervalFast,
"expected provider %q to stop routing", providerName)
}
// 1. Create: provider points at upstream A.
created, err := client.CreateAIProvider(ctx, codersdk.CreateAIProviderRequest{
Type: codersdk.AIProviderTypeOpenAI,
Name: "primary",
Enabled: true,
BaseURL: upstreamA.server.URL,
APIKeys: []string{"sk-primary-key"},
})
require.NoError(t, err)
require.Equal(t, "primary", created.Name)
requireRoutesTo(t, "primary", upstreamA)
// 2. Update BaseURL: same name, now points at upstream B.
newBaseURL := upstreamB.server.URL
_, err = client.UpdateAIProvider(ctx, "primary", codersdk.UpdateAIProviderRequest{
BaseURL: &newBaseURL,
})
require.NoError(t, err)
requireRoutesTo(t, "primary", upstreamB)
// 3. Disable: the provider drops out of the snapshot, requests
// stop reaching any upstream.
disabled := false
_, err = client.UpdateAIProvider(ctx, "primary", codersdk.UpdateAIProviderRequest{
Enabled: &disabled,
})
require.NoError(t, err)
requireRoutingGone(t, "primary")
// 4. Re-enable: routing comes back at the most recent BaseURL.
enabled := true
_, err = client.UpdateAIProvider(ctx, "primary", codersdk.UpdateAIProviderRequest{
Enabled: &enabled,
})
require.NoError(t, err)
requireRoutesTo(t, "primary", upstreamB)
// 5. Add a second provider; both names must route independently.
_, err = client.CreateAIProvider(ctx, codersdk.CreateAIProviderRequest{
Type: codersdk.AIProviderTypeOpenAI,
Name: "secondary",
Enabled: true,
BaseURL: upstreamA.server.URL,
APIKeys: []string{"sk-secondary-key"},
})
require.NoError(t, err)
requireRoutesTo(t, "primary", upstreamB)
requireRoutesTo(t, "secondary", upstreamA)
// 6. Delete primary: only secondary remains routable.
require.NoError(t, client.DeleteAIProvider(ctx, "primary"))
requireRoutingGone(t, "primary")
requireRoutesTo(t, "secondary", upstreamA)
}