feat: hot-reload aibridged and aibridgeproxyd providers on DB changes (#25673)

Previously the in-process aibridge daemon and the enterprise aibridgeproxy daemon both snapshotted their provider routing once at boot. Any `ai_providers` or `ai_provider_keys` mutation required a restart for either to pick it up.

Add an `ai_providers_changed` pubsub channel that the CRUD handlers publish on after Create / Update / Delete. Both daemons subscribe:

- **aibridged** rebuilds its `[]aibridge.Provider` snapshot via `BuildProviders` and swaps it into the pool atomically. In-flight requests keep serving against the bridge they already acquired; new acquires build against the new snapshot. Per-provider construction errors stay scoped to the offending row.
- **aibridgeproxyd** rebuilds its routing snapshot from `GetAIProviders` and swaps the host→provider map atomically. The MITM listener picks up new providers without restart.

DB read for aibridgeproxyd uses the existing `AsAIProviderMetadataReader` subject for routing-only access.
This commit is contained in:
Danny Kopping
2026-05-27 11:58:43 +02:00
committed by GitHub
parent 6acfe6c835
commit 79e007cf30
19 changed files with 1678 additions and 227 deletions
+19
View File
@@ -21,6 +21,7 @@ import (
"github.com/coder/coder/v2/coderd/database/dbtime"
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/coderd/httpmw"
coderpubsub "github.com/coder/coder/v2/coderd/pubsub"
"github.com/coder/coder/v2/coderd/util/ptr"
"github.com/coder/coder/v2/codersdk"
)
@@ -235,6 +236,7 @@ func (api *API) aiProvidersCreate(rw http.ResponseWriter, r *http.Request) {
aReq.New = row
auditAIProviderKeyChanges(ctx, r, *auditor, api.Logger, aiProviderKeyChanges{Added: keys})
api.publishAIProvidersChanged(ctx)
sdk, err := db2sdk.AIProvider(row, keys)
if err != nil {
@@ -400,6 +402,7 @@ func (api *API) aiProvidersUpdate(rw http.ResponseWriter, r *http.Request) {
}
auditAIProviderKeyChanges(ctx, r, *auditor, api.Logger, keyChanges)
api.publishAIProvidersChanged(ctx)
sdk, err := db2sdk.AIProvider(updated, keys)
if err != nil {
@@ -453,9 +456,25 @@ func (api *API) aiProvidersDelete(rw http.ResponseWriter, r *http.Request) {
return
}
api.publishAIProvidersChanged(ctx)
rw.WriteHeader(http.StatusNoContent)
}
// publishAIProvidersChanged emits an empty message on
// AIProvidersChangedChannel so that subscribers (aibridged,
// aibridgeproxyd) know the provider set changed and re-read it from the
// database.
//
// The publish is best-effort: when no pubsub is configured nothing is
// sent, and a publish failure is only logged as a warning. Subscribers
// treat the database as authoritative, so a lost notification delays
// convergence but is never surfaced to the API caller.
func (api *API) publishAIProvidersChanged(ctx context.Context) {
	ps := api.Pubsub
	if ps == nil {
		return
	}

	err := ps.Publish(coderpubsub.AIProvidersChangedChannel, nil)
	if err == nil {
		return
	}
	api.Logger.Warn(ctx, "publish ai providers changed event", slog.Error(err))
}
// errBedrockRejectsAPIKeys is the sentinel returned from inside the
// update transaction when a caller attempts to attach api_keys to a
// Bedrock-typed provider; the outer handler translates it into a 400.
+62
View File
@@ -0,0 +1,62 @@
package coderd_test
import (
"context"
"sync/atomic"
"testing"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/coderd/coderdtest"
coderpubsub "github.com/coder/coder/v2/coderd/pubsub"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/testutil"
)
// TestAIProvidersChangedPubsub verifies that each mutating AI provider
// endpoint (create, update, delete) publishes on
// AIProvidersChangedChannel. Subscribers (aibridged, aibridgeproxyd)
// rely on these notifications to trigger a reload.
//
// Publishing is best-effort and the payload is empty, so the test only
// counts deliveries and requires the running total to reach at least one
// per mutation performed so far.
func TestAIProvidersChangedPubsub(t *testing.T) {
	t.Parallel()

	client, _, api := coderdtest.NewWithAPI(t, nil)
	_ = coderdtest.CreateFirstUser(t, client)
	ctx := testutil.Context(t, testutil.WaitLong)

	var events atomic.Int64
	onEvent := func(_ context.Context, _ []byte) {
		events.Add(1)
	}
	unsubscribe, err := api.Pubsub.Subscribe(coderpubsub.AIProvidersChangedChannel, onEvent)
	require.NoError(t, err)
	t.Cleanup(unsubscribe)

	// awaitEvents blocks until at least want notifications were observed.
	awaitEvents := func(want int64) {
		t.Helper()
		testutil.Eventually(ctx, t, func(_ context.Context) bool {
			return events.Load() >= want
		}, testutil.IntervalFast)
	}

	// Create.
	createReq := codersdk.CreateAIProviderRequest{
		Type:    codersdk.AIProviderTypeOpenAI,
		Name:    "pubsub-openai",
		Enabled: true,
		BaseURL: "https://api.openai.com/v1/",
		APIKeys: []string{"k1"},
	}
	//nolint:gocritic // Owner role is the audience for this endpoint.
	created, err := client.CreateAIProvider(ctx, createReq)
	require.NoError(t, err)
	awaitEvents(1)

	// Update.
	rotatedKey := "k2"
	updateReq := codersdk.UpdateAIProviderRequest{
		APIKeys: &[]codersdk.AIProviderKeyMutation{{APIKey: &rotatedKey}},
	}
	_, err = client.UpdateAIProvider(ctx, created.ID.String(), updateReq)
	require.NoError(t, err)
	awaitEvents(2)

	// Delete.
	err = client.DeleteAIProvider(ctx, created.ID.String())
	require.NoError(t, err)
	awaitEvents(3)
}
@@ -14,6 +14,7 @@ import (
http "net/http"
reflect "reflect"
aibridge "github.com/coder/coder/v2/aibridge"
aibridged "github.com/coder/coder/v2/coderd/aibridged"
gomock "go.uber.org/mock/gomock"
)
@@ -57,6 +58,18 @@ func (mr *MockPoolerMockRecorder) Acquire(ctx, req, clientFn, mcpBootstrapper an
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Acquire", reflect.TypeOf((*MockPooler)(nil).Acquire), ctx, req, clientFn, mcpBootstrapper)
}
// ReplaceProviders mocks base method.
//
// The call is forwarded to the gomock controller together with the
// providers argument; the method has no return value to unpack.
func (m *MockPooler) ReplaceProviders(providers []aibridge.Provider) {
	m.ctrl.T.Helper()
	m.ctrl.Call(m, "ReplaceProviders", providers)
}
// ReplaceProviders indicates an expected call of ReplaceProviders.
//
// providers may be a concrete value or a gomock matcher; the returned
// *gomock.Call can be used to further constrain the expectation.
func (mr *MockPoolerMockRecorder) ReplaceProviders(providers any) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReplaceProviders", reflect.TypeOf((*MockPooler)(nil).ReplaceProviders), providers)
}
// Shutdown mocks base method.
func (m *MockPooler) Shutdown(ctx context.Context) error {
m.ctrl.T.Helper()
+71 -21
View File
@@ -3,7 +3,10 @@ package aibridged
import (
"context"
"net/http"
"slices"
"strconv"
"sync"
"sync/atomic"
"time"
"github.com/dgraph-io/ristretto/v2"
@@ -26,6 +29,9 @@ const (
// One [*aibridge.RequestBridge] instance is created per given key.
type Pooler interface {
	// Acquire retrieves or creates the handler that serves req. clientFn
	// and mcpBootstrapper are supplied for the case where a new instance
	// has to be constructed.
	Acquire(ctx context.Context, req Request, clientFn ClientFunc, mcpBootstrapper MCPProxyBuilder) (http.Handler, error)
	// ReplaceProviders swaps the providers used to construct future
	// RequestBridge instances and clears the cache.
	ReplaceProviders(providers []aibridge.Provider)
	// Shutdown stops the pool.
	// NOTE(review): the exact teardown semantics are defined by the
	// implementation and are not visible here; confirm before relying on
	// them.
	Shutdown(ctx context.Context) error
}
@@ -46,10 +52,12 @@ var DefaultPoolOptions = PoolOptions{MaxItems: 5000, TTL: time.Minute * 15}
var _ Pooler = &CachedBridgePool{}
type CachedBridgePool struct {
cache *ristretto.Cache[string, *aibridge.RequestBridge]
providers []aibridge.Provider
logger slog.Logger
options PoolOptions
cache *ristretto.Cache[string, *aibridge.RequestBridge]
// providers is the live provider set used by new RequestBridge instances.
providers atomic.Pointer[[]aibridge.Provider]
providerVersion atomic.Int64
logger slog.Logger
options PoolOptions
singleflight *singleflight.Group[string, *aibridge.RequestBridge]
@@ -71,13 +79,16 @@ func NewCachedBridgePool(options PoolOptions, providers []aibridge.Provider, log
if item == nil || item.Value == nil {
return
}
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), time.Second*5)
defer shutdownCancel()
// Run the eviction in the background since ristretto blocks sets until a free slot is available.
// Capture the value synchronously: ristretto reuses the
// item slot after OnEvict returns, so reading item.Value
// from the goroutine below races with the caller of
// Clear/Set. The shutdown still runs in the background to
// avoid blocking ristretto's eviction loop.
bridge := item.Value
go func() {
_ = item.Value.Shutdown(shutdownCtx)
shutdownCtx, cancel := context.WithTimeout(context.Background(), time.Second*5)
defer cancel()
_ = bridge.Shutdown(shutdownCtx)
}()
},
})
@@ -85,18 +96,53 @@ func NewCachedBridgePool(options PoolOptions, providers []aibridge.Provider, log
return nil, xerrors.Errorf("create cache: %w", err)
}
return &CachedBridgePool{
cache: cache,
providers: providers,
options: options,
metrics: metrics,
tracer: tracer,
logger: logger,
pool := &CachedBridgePool{
cache: cache,
options: options,
metrics: metrics,
tracer: tracer,
logger: logger,
singleflight: &singleflight.Group[string, *aibridge.RequestBridge]{},
shuttingDownCh: make(chan struct{}),
}, nil
}
initial := slices.Clone(providers)
pool.providers.Store(&initial)
return pool, nil
}
// ReplaceProviders swaps the provider snapshot used by future Acquires
// and evicts every cached bridge. It is safe to call concurrently with
// Acquire and is a no-op after Shutdown.
//
// providers is cloned before it is stored, so the caller keeps ownership
// of the slice it passed in.
func (p *CachedBridgePool) ReplaceProviders(providers []aibridge.Provider) {
	select {
	case <-p.shuttingDownCh:
		return
	default:
	}

	snapshot := slices.Clone(providers)
	p.providers.Store(&snapshot)

	// The version must be unique per replacement: Acquire folds it into
	// its singleflight key and compares it before caching a freshly built
	// bridge. A wall-clock timestamp gives no such guarantee (two calls
	// can land on the same clock tick, and the wall clock can step
	// backwards), so use a monotonically increasing counter instead.
	version := p.providerVersion.Add(1)

	// Clear evicts every cached bridge; OnEvict shuts each one down in
	// the background. Wait for buffered writes to drain so a replacement
	// immediately followed by an Acquire always sees the cleared cache.
	p.cache.Clear()
	p.cache.Wait()

	p.logger.Info(context.Background(), "request bridge pool reloaded",
		slog.F("provider_count", len(snapshot)),
		slog.F("provider_version", version),
	)
}
// loadProviders returns the provider snapshot that is currently
// published, or nil when none has been stored yet. Callers share the
// returned slice with every other reader and must treat it as read-only.
func (p *CachedBridgePool) loadProviders() []aibridge.Provider {
	snapshot := p.providers.Load()
	if snapshot == nil {
		return nil
	}
	return *snapshot
}
// Acquire retrieves or creates a [*aibridge.RequestBridge] instance per given key.
@@ -140,6 +186,7 @@ func (p *CachedBridgePool) Acquire(ctx context.Context, req Request, clientFn Cl
}
span.AddEvent("cache_miss")
providerVersion := p.providerVersion.Load()
recorder := aibridge.NewRecorder(p.logger.Named("recorder"), p.tracer, func() (aibridge.Recorder, error) {
client, err := clientFn()
if err != nil {
@@ -152,7 +199,8 @@ func (p *CachedBridgePool) Acquire(ctx context.Context, req Request, clientFn Cl
// Slow path.
// Creating an *aibridge.RequestBridge may take some time, so gate all subsequent callers behind the initial request and return the resulting value.
// TODO: track startup time since it adds latency to first request (histogram count will also help us see how often this occurs).
instance, err, _ := p.singleflight.Do(req.InitiatorID.String(), func() (*aibridge.RequestBridge, error) {
singleflightKey := cacheKey + "|" + strconv.FormatInt(providerVersion, 10)
instance, err, _ := p.singleflight.Do(singleflightKey, func() (*aibridge.RequestBridge, error) {
var (
mcpServers mcp.ServerProxier
err error
@@ -171,12 +219,14 @@ func (p *CachedBridgePool) Acquire(ctx context.Context, req Request, clientFn Cl
}
}
bridge, err := aibridge.NewRequestBridge(ctx, p.providers, recorder, mcpServers, p.logger, p.metrics, p.tracer)
bridge, err := aibridge.NewRequestBridge(ctx, p.loadProviders(), recorder, mcpServers, p.logger, p.metrics, p.tracer)
if err != nil {
return nil, xerrors.Errorf("create new request bridge: %w", err)
}
p.cache.SetWithTTL(cacheKey, bridge, cacheCost, p.options.TTL)
if p.providerVersion.Load() == providerVersion {
p.cache.SetWithTTL(cacheKey, bridge, cacheCost, p.options.TTL)
}
return bridge, nil
})
+215
View File
@@ -2,6 +2,10 @@ package aibridged_test
import (
"context"
"io"
"net/http"
"net/http/httptest"
"sync/atomic"
"testing"
"testing/synctest"
"time"
@@ -12,10 +16,13 @@ import (
"go.uber.org/mock/gomock"
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/coder/v2/aibridge"
"github.com/coder/coder/v2/aibridge/config"
"github.com/coder/coder/v2/aibridge/mcp"
"github.com/coder/coder/v2/aibridge/mcpmock"
"github.com/coder/coder/v2/coderd/aibridged"
mock "github.com/coder/coder/v2/coderd/aibridged/aibridgedmock"
"github.com/coder/coder/v2/testutil"
)
// TestPool validates the published behavior of [aibridged.CachedBridgePool].
@@ -107,6 +114,160 @@ func TestPool(t *testing.T) {
require.EqualValues(t, 3, cacheMetrics.Misses())
}
// TestPoolReplaceProvidersClearsCacheAndUsesNewProviders acquires a
// bridge, replaces the provider set, and acquires again with the same
// request. The second acquire must return a different handler, and that
// handler must route to the replacement provider's upstream.
func TestPoolReplaceProvidersClearsCacheAndUsesNewProviders(t *testing.T) {
	t.Parallel()

	// startUpstream runs a fake upstream that answers every request with
	// reply and is closed when the test ends.
	startUpstream := func(reply string) *httptest.Server {
		srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
			_, _ = io.WriteString(w, reply)
		}))
		t.Cleanup(srv.Close)
		return srv
	}
	oldUpstream := startUpstream("old")
	newUpstream := startUpstream("new")

	logger := slogtest.Make(t, nil)
	ctrl := gomock.NewController(t)
	client := mock.NewMockDRPCClient(ctrl)
	mcpProxy := mcpmock.NewMockServerProxier(ctrl)
	mcpProxy.EXPECT().Init(gomock.Any()).AnyTimes().Return(nil)
	mcpProxy.EXPECT().Shutdown(gomock.Any()).AnyTimes().Return(nil)

	initialProviders := []aibridge.Provider{
		aibridge.NewOpenAIProvider(config.OpenAI{Name: "old", BaseURL: oldUpstream.URL}),
	}
	pool, err := aibridged.NewCachedBridgePool(
		aibridged.PoolOptions{MaxItems: 1, TTL: time.Minute},
		initialProviders,
		logger,
		nil,
		testTracer,
	)
	require.NoError(t, err)
	t.Cleanup(func() { _ = pool.Shutdown(context.Background()) })

	req := aibridged.Request{
		SessionKey:  "key",
		InitiatorID: uuid.New(),
		APIKeyID:    uuid.New().String(),
	}
	clientFn := func() (aibridged.DRPCClient, error) { return client, nil }

	before, err := pool.Acquire(t.Context(), req, clientFn, newMockMCPFactory(mcpProxy))
	require.NoError(t, err)
	assertHandlerBody(t, before, "/old/v1/models", "old")

	replacement := []aibridge.Provider{
		aibridge.NewOpenAIProvider(config.OpenAI{Name: "new", BaseURL: newUpstream.URL}),
	}
	pool.ReplaceProviders(replacement)

	after, err := pool.Acquire(t.Context(), req, clientFn, newMockMCPFactory(mcpProxy))
	require.NoError(t, err)
	require.NotSame(t, before, after)
	assertHandlerBody(t, after, "/new/v1/models", "new")
}
// TestPoolReplaceProvidersDoesNotJoinStaleSingleflight checks that an
// Acquire issued after ReplaceProviders does not wait on, or reuse, a
// bridge build that started before the replacement.
//
// The sequence is order-sensitive:
//  1. A first Acquire is parked inside the MCP factory's Build.
//  2. The providers are replaced while that build is still in flight.
//  3. A second Acquire for the same request must complete on its own and
//     serve the new upstream.
//  4. The first build is released and must finish without error.
//  5. A third Acquire must return the handler from step 3, i.e. the late
//     first build did not displace it.
func TestPoolReplaceProvidersDoesNotJoinStaleSingleflight(t *testing.T) {
	t.Parallel()

	oldUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		_, _ = io.WriteString(w, "old")
	}))
	t.Cleanup(oldUpstream.Close)
	newUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		_, _ = io.WriteString(w, "new")
	}))
	t.Cleanup(newUpstream.Close)

	logger := slogtest.Make(t, nil)
	ctrl := gomock.NewController(t)
	client := mock.NewMockDRPCClient(ctrl)
	opts := aibridged.PoolOptions{MaxItems: 1, TTL: time.Minute}
	pool, err := aibridged.NewCachedBridgePool(opts, []aibridge.Provider{
		aibridge.NewOpenAIProvider(config.OpenAI{Name: "old", BaseURL: oldUpstream.URL}),
	}, logger, nil, testTracer)
	require.NoError(t, err)
	t.Cleanup(func() { _ = pool.Shutdown(context.Background()) })

	req := aibridged.Request{
		SessionKey:  "key",
		InitiatorID: uuid.New(),
		APIKeyID:    uuid.New().String(),
	}
	clientFn := func() (aibridged.DRPCClient, error) {
		return client, nil
	}

	// The factory blocks its first Build call until releaseFirst is
	// closed, which keeps the first Acquire in flight.
	factory := newBlockingMCPFactory()
	firstDone := make(chan acquireResult, 1)
	go func() {
		handler, err := pool.Acquire(t.Context(), req, clientFn, factory)
		firstDone <- acquireResult{handler: handler, err: err}
	}()
	// Do not replace providers until the first build is actually parked.
	require.Eventually(t, factory.firstBuildStarted, testutil.WaitShort, testutil.IntervalFast)

	pool.ReplaceProviders([]aibridge.Provider{
		aibridge.NewOpenAIProvider(config.OpenAI{Name: "new", BaseURL: newUpstream.URL}),
	})

	secondDone := make(chan acquireResult, 1)
	go func() {
		handler, err := pool.Acquire(t.Context(), req, clientFn, factory)
		secondDone <- acquireResult{handler: handler, err: err}
	}()

	// The first build is still blocked here, so the second Acquire can
	// only finish if it did not join the first one's in-flight call.
	var second acquireResult
	require.Eventually(t, func() bool {
		select {
		case second = <-secondDone:
			return true
		default:
			return false
		}
	}, testutil.WaitShort, testutil.IntervalFast)
	require.NoError(t, second.err)
	assertHandlerBody(t, second.handler, "/new/v1/models", "new")

	// Let the stale build run to completion.
	close(factory.releaseFirst)
	var first acquireResult
	require.Eventually(t, func() bool {
		select {
		case first = <-firstDone:
			return true
		default:
			return false
		}
	}, testutil.WaitShort, testutil.IntervalFast)
	require.NoError(t, first.err)

	// The handler built against the new providers must still be the one
	// handed out after the stale build finished.
	third, err := pool.Acquire(t.Context(), req, clientFn, factory)
	require.NoError(t, err)
	require.Same(t, second.handler, third)
}
func TestPoolReplaceProvidersAfterShutdownIsNoop(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
opts := aibridged.PoolOptions{MaxItems: 1, TTL: time.Minute}
pool, err := aibridged.NewCachedBridgePool(opts, nil, logger, nil, testTracer)
require.NoError(t, err)
require.NoError(t, pool.Shutdown(t.Context()))
require.NotPanics(t, func() {
pool.ReplaceProviders([]aibridge.Provider{
aibridge.NewOpenAIProvider(config.OpenAI{Name: "new", BaseURL: "https://example.com"}),
})
})
_, err = pool.Acquire(t.Context(), aibridged.Request{
SessionKey: "key",
InitiatorID: uuid.New(),
APIKeyID: uuid.New().String(),
}, func() (aibridged.DRPCClient, error) {
return nil, context.Canceled
}, newMockMCPFactory(nil))
require.ErrorContains(t, err, "pool shutting down")
}
func TestPool_Expiry(t *testing.T) {
t.Parallel()
@@ -166,6 +327,21 @@ func TestPool_Expiry(t *testing.T) {
})
}
// assertHandlerBody sends a GET for path straight into handler (no
// network involved) and requires a 200 response whose body equals body.
func assertHandlerBody(t *testing.T, handler http.Handler, path string, body string) {
	t.Helper()

	recorder := httptest.NewRecorder()
	handler.ServeHTTP(recorder, httptest.NewRequest(http.MethodGet, path, nil))

	result := recorder.Result()
	defer result.Body.Close()

	payload, readErr := io.ReadAll(result.Body)
	require.NoError(t, readErr)
	require.Equal(t, http.StatusOK, result.StatusCode)
	require.Equal(t, body, string(payload))
}
var _ aibridged.MCPProxyBuilder = &mockMCPFactory{}
type mockMCPFactory struct {
@@ -179,3 +355,42 @@ func newMockMCPFactory(proxy *mcpmock.MockServerProxier) *mockMCPFactory {
// Build returns the proxy the factory was constructed with and never
// fails; ctx, req and tracer are ignored.
//
// NOTE(review): if m.proxy is a nil pointer of a concrete type, the
// returned interface value is non-nil even though it wraps nil. Confirm
// that callers passing a nil proxy never reach a method call on it.
func (m *mockMCPFactory) Build(ctx context.Context, req aibridged.Request, tracer trace.Tracer) (mcp.ServerProxier, error) {
	return m.proxy, nil
}
// acquireResult carries the outcome of a pool.Acquire call made in a
// goroutine back to the test over a channel.
type acquireResult struct {
	// handler is the handler returned by Acquire.
	handler http.Handler
	// err is the error returned by Acquire.
	err error
}
// blockingMCPFactory is a test MCP factory whose first Build call blocks
// until the test releases it, so a bridge construction can be held in
// flight while the test does other work.
type blockingMCPFactory struct {
	// calls counts Build invocations; only the first one blocks.
	calls atomic.Int32
	// firstStarted is closed when the first Build call begins.
	firstStarted chan struct{}
	// releaseFirst is closed by the test to unblock the first Build call.
	releaseFirst chan struct{}
}
// newBlockingMCPFactory returns a factory with both signalling channels
// allocated (unbuffered, still open) and a zero call counter.
func newBlockingMCPFactory() *blockingMCPFactory {
	factory := new(blockingMCPFactory)
	factory.firstStarted = make(chan struct{})
	factory.releaseFirst = make(chan struct{})
	return factory
}
// firstBuildStarted reports, without blocking, whether the first Build
// call has begun (i.e. whether firstStarted has been closed).
func (m *blockingMCPFactory) firstBuildStarted() bool {
	started := false
	select {
	case <-m.firstStarted:
		started = true
	default:
	}
	return started
}
// Build never produces a proxier. The first call closes firstStarted and
// then parks until releaseFirst is closed or ctx is done (returning
// ctx.Err() in the latter case); every other path returns
// context.Canceled immediately.
//
// NOTE(review): the tests using this factory expect Acquire to succeed,
// so Acquire presumably tolerates a failed MCP build; confirm against
// Acquire.
func (m *blockingMCPFactory) Build(ctx context.Context, _ aibridged.Request, _ trace.Tracer) (mcp.ServerProxier, error) {
	if m.calls.Add(1) != 1 {
		return nil, context.Canceled
	}

	close(m.firstStarted)
	select {
	case <-ctx.Done():
		return nil, ctx.Err()
	case <-m.releaseFirst:
	}
	return nil, context.Canceled
}
+50
View File
@@ -0,0 +1,50 @@
package aibridged
import (
"context"
"golang.org/x/xerrors"
"cdr.dev/slog/v3"
dbpubsub "github.com/coder/coder/v2/coderd/database/pubsub"
"github.com/coder/coder/v2/coderd/pubsub"
)
// ProviderReloader refreshes a component's provider snapshot.
type ProviderReloader interface {
	// Reload refreshes the snapshot. SubscribeProviderReload calls it once
	// up front and again for every change notification; a returned error
	// is logged there and does not stop later reloads.
	Reload(ctx context.Context) error
}
// SubscribeProviderReload wires reloader to AI provider change
// notifications: it subscribes to AIProvidersChangedChannel and then
// performs one reload immediately.
//
// The subscription is established before the initial reload. Events that
// are delivered with an error are logged and skipped. Reload failures,
// including a failure of the initial reload, are logged as warnings and
// never returned.
//
// It returns a function that cancels the subscription. An error is
// returned only when ps or reloader is nil or when subscribing fails.
func SubscribeProviderReload(
	ctx context.Context,
	ps dbpubsub.Pubsub,
	reloader ProviderReloader,
	logger slog.Logger,
) (func(), error) {
	switch {
	case ps == nil:
		return nil, xerrors.New("pubsub is required")
	case reloader == nil:
		return nil, xerrors.New("reloader is required")
	}

	onChange := func(cbCtx context.Context, _ []byte, deliveryErr error) {
		if deliveryErr != nil {
			logger.Warn(cbCtx, "ai providers changed event delivered with error", slog.Error(deliveryErr))
			return
		}
		reloadErr := reloader.Reload(cbCtx)
		if reloadErr != nil {
			logger.Warn(cbCtx, "reload ai provider snapshot from pubsub event", slog.Error(reloadErr))
			return
		}
		logger.Debug(cbCtx, "reloaded ai provider snapshot from pubsub event")
	}

	cancel, subErr := ps.SubscribeWithErr(pubsub.AIProvidersChangedChannel, onChange)
	if subErr != nil {
		return nil, xerrors.Errorf("subscribe to %s: %w", pubsub.AIProvidersChangedChannel, subErr)
	}

	initialErr := reloader.Reload(ctx)
	if initialErr != nil {
		logger.Warn(ctx, "initial ai provider reload", slog.Error(initialErr))
	}
	return cancel, nil
}
+133
View File
@@ -0,0 +1,133 @@
package aibridged_test
import (
"context"
"sync/atomic"
"testing"
"github.com/stretchr/testify/require"
"golang.org/x/xerrors"
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/coder/v2/coderd/aibridged"
dbpubsub "github.com/coder/coder/v2/coderd/database/pubsub"
"github.com/coder/coder/v2/coderd/pubsub"
"github.com/coder/coder/v2/testutil"
)
// TestSubscribeProviderReload checks the happy path: exactly one reload
// has happened by the time SubscribeProviderReload returns, and a publish
// on the change channel triggers another.
func TestSubscribeProviderReload(t *testing.T) {
	t.Parallel()

	ctx := testutil.Context(t, testutil.WaitMedium)
	logger := slogtest.Make(t, nil)
	bus := dbpubsub.NewInMemory()
	t.Cleanup(func() { _ = bus.Close() })

	reloader := &recordingReloader{}
	cancel, err := aibridged.SubscribeProviderReload(ctx, bus, reloader, logger)
	require.NoError(t, err)
	t.Cleanup(cancel)
	require.Equal(t, 1, reloader.count())

	require.NoError(t, bus.Publish(pubsub.AIProvidersChangedChannel, nil))

	reloadedAgain := func() bool { return reloader.count() >= 2 }
	require.Eventually(t, reloadedAgain, testutil.WaitShort, testutil.IntervalFast,
		"Reload must fire again after a pubsub notification")
}
// TestSubscribeProviderReloadSurfacesReloadError uses a reloader that
// always fails and checks that SubscribeProviderReload still succeeds and
// that later notifications keep invoking Reload.
func TestSubscribeProviderReloadSurfacesReloadError(t *testing.T) {
	t.Parallel()

	ctx := testutil.Context(t, testutil.WaitMedium)
	logger := slogtest.Make(t, nil)
	bus := dbpubsub.NewInMemory()
	t.Cleanup(func() { _ = bus.Close() })

	failing := &recordingReloader{returnErr: true}
	cancel, err := aibridged.SubscribeProviderReload(ctx, bus, failing, logger)
	require.NoError(t, err)
	t.Cleanup(cancel)
	require.Equal(t, 1, failing.count())

	require.NoError(t, bus.Publish(pubsub.AIProvidersChangedChannel, nil))

	reloadedAgain := func() bool { return failing.count() >= 2 }
	require.Eventually(t, reloadedAgain, testutil.WaitShort, testutil.IntervalFast,
		"Reload must keep firing even after a previous Reload returned an error")
}
// TestSubscribeProviderReloadIgnoresEventError drives the registered
// listener directly: an event delivered with an error must not trigger a
// reload, while a clean event must.
func TestSubscribeProviderReloadIgnoresEventError(t *testing.T) {
	t.Parallel()

	ctx := testutil.Context(t, testutil.WaitMedium)
	logger := slogtest.Make(t, nil)
	bus := &errInjectingPubsub{}
	reloader := &recordingReloader{}

	cancel, err := aibridged.SubscribeProviderReload(ctx, bus, reloader, logger)
	require.NoError(t, err)
	t.Cleanup(cancel)
	require.Equal(t, 1, reloader.count())

	// deliver invokes the captured listener synchronously.
	deliver := func(deliveryErr error) {
		bus.listener(ctx, nil, deliveryErr)
	}

	deliver(errPubsubDelivery)
	require.Equal(t, 1, reloader.count())

	deliver(nil)
	require.Equal(t, 2, reloader.count())
}
// recordingReloader is a minimal [aibridged.ProviderReloader] that
// counts calls and can be configured to fail every one of them.
type recordingReloader struct {
	// n is the number of Reload calls seen so far.
	n atomic.Int32
	// returnErr makes Reload return errReloadFailed when true.
	returnErr bool
}
// Reload records the call and then returns errReloadFailed when the
// reloader is configured to fail, or nil otherwise. The call is counted
// in both cases.
func (r *recordingReloader) Reload(_ context.Context) error {
	r.n.Add(1)
	if !r.returnErr {
		return nil
	}
	return errReloadFailed
}
// count returns how many times Reload has been called so far.
func (r *recordingReloader) count() int {
	calls := r.n.Load()
	return int(calls)
}
var (
	// errReloadFailed is what recordingReloader.Reload returns when it is
	// configured to fail.
	errReloadFailed = stubError("reload failed")
	// errPubsubDelivery is passed to the listener to simulate an event
	// that was delivered with an error.
	errPubsubDelivery = stubError("pubsub delivery failed")
)
// stubError is a string-backed error type, which lets the test errors
// above be built by a plain conversion from a string literal.
type stubError string

// Error returns the underlying string.
func (s stubError) Error() string { return string(s) }
// Compile-time check that errInjectingPubsub implements dbpubsub.Pubsub.
var _ dbpubsub.Pubsub = &errInjectingPubsub{}

// errInjectingPubsub is a fake pubsub that captures the listener passed
// to SubscribeWithErr so a test can invoke it directly with any error.
type errInjectingPubsub struct {
	// listener is the most recently registered SubscribeWithErr callback.
	listener dbpubsub.ListenerWithErr
}
// Subscribe is not supported by this fake and always returns an error.
func (*errInjectingPubsub) Subscribe(string, dbpubsub.Listener) (func(), error) {
	return nil, xerrors.New("Subscribe not implemented")
}
// SubscribeWithErr remembers listener, replacing any previously stored
// one, and ignores the channel name. The returned cancel function does
// nothing.
func (p *errInjectingPubsub) SubscribeWithErr(_ string, listener dbpubsub.ListenerWithErr) (func(), error) {
	noop := func() {}
	p.listener = listener
	return noop, nil
}
// Publish is not supported by this fake and always returns an error;
// tests deliver events by calling the captured listener instead.
func (*errInjectingPubsub) Publish(string, []byte) error {
	return xerrors.New("Publish not implemented")
}
// Close does nothing and always succeeds.
func (*errInjectingPubsub) Close() error {
	return nil
}
+11
View File
@@ -0,0 +1,11 @@
package pubsub
// AIProvidersChangedChannel is the pubsub channel that carries AI
// provider lifecycle events: provider create / update / soft-delete
// and key insert / delete. Subscribers (aibridged, aibridgeproxyd)
// reload their in-memory provider snapshot on receipt.
//
// The payload is an empty invalidation hint; subscribers refetch the
// authoritative state from the database, so a dropped message only
// delays convergence and cannot cause state to diverge.
const AIProvidersChangedChannel = "ai_providers_changed"