mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
feat: hot-reload aibridged and aibridgeproxyd providers on DB changes (#25673)
Previously the in-process aibridge daemon and the enterprise aibridgeproxy daemon both snapshotted their provider routing once at boot. Any `ai_providers` or `ai_provider_keys` mutation required a restart for either to pick it up.

Add an `ai_providers_changed` pubsub channel that the CRUD handlers publish on after Create / Update / Delete. Both daemons subscribe:

- **aibridged** rebuilds its `[]aibridge.Provider` snapshot via `BuildProviders` and swaps it into the pool atomically. In-flight requests keep serving against the bridge they already acquired; new acquires build against the new snapshot. Per-provider construction errors stay scoped to the offending row.
- **aibridgeproxyd** rebuilds its routing snapshot from `GetAIProviders` and swaps the host→provider map atomically. The MITM listener picks up new providers without restart.

The DB read for aibridgeproxyd uses the existing `AsAIProviderMetadataReader` subject for routing-only access.
This commit is contained in:
@@ -21,6 +21,7 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/database/dbtime"
|
||||
"github.com/coder/coder/v2/coderd/httpapi"
|
||||
"github.com/coder/coder/v2/coderd/httpmw"
|
||||
coderpubsub "github.com/coder/coder/v2/coderd/pubsub"
|
||||
"github.com/coder/coder/v2/coderd/util/ptr"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
@@ -235,6 +236,7 @@ func (api *API) aiProvidersCreate(rw http.ResponseWriter, r *http.Request) {
|
||||
aReq.New = row
|
||||
|
||||
auditAIProviderKeyChanges(ctx, r, *auditor, api.Logger, aiProviderKeyChanges{Added: keys})
|
||||
api.publishAIProvidersChanged(ctx)
|
||||
|
||||
sdk, err := db2sdk.AIProvider(row, keys)
|
||||
if err != nil {
|
||||
@@ -400,6 +402,7 @@ func (api *API) aiProvidersUpdate(rw http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
auditAIProviderKeyChanges(ctx, r, *auditor, api.Logger, keyChanges)
|
||||
api.publishAIProvidersChanged(ctx)
|
||||
|
||||
sdk, err := db2sdk.AIProvider(updated, keys)
|
||||
if err != nil {
|
||||
@@ -453,9 +456,25 @@ func (api *API) aiProvidersDelete(rw http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
api.publishAIProvidersChanged(ctx)
|
||||
|
||||
rw.WriteHeader(http.StatusNoContent)
|
||||
}
|
||||
|
||||
// publishAIProvidersChanged notifies subscribers (aibridged,
|
||||
// aibridgeproxyd) that the live provider set changed and they should
|
||||
// refetch from the database. Pubsub failures are logged but not
|
||||
// propagated: subscribers refresh authoritatively from the DB, so a
|
||||
// dropped notification only delays convergence.
|
||||
func (api *API) publishAIProvidersChanged(ctx context.Context) {
|
||||
if api.Pubsub == nil {
|
||||
return
|
||||
}
|
||||
if err := api.Pubsub.Publish(coderpubsub.AIProvidersChangedChannel, nil); err != nil {
|
||||
api.Logger.Warn(ctx, "publish ai providers changed event", slog.Error(err))
|
||||
}
|
||||
}
|
||||
|
||||
// errBedrockRejectsAPIKeys is the sentinel returned from inside the
|
||||
// update transaction when a caller attempts to attach api_keys to a
|
||||
// Bedrock-typed provider; the outer handler translates it into a 400.
|
||||
|
||||
@@ -0,0 +1,62 @@
|
||||
package coderd_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/coderdtest"
|
||||
coderpubsub "github.com/coder/coder/v2/coderd/pubsub"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
// TestAIProvidersChangedPubsub asserts that the CRUD handlers publish
|
||||
// on AIProvidersChangedChannel for the operations that affect the
|
||||
// runtime provider set. Subscribers (aibridged, aibridgeproxyd) depend
|
||||
// on these notifications to trigger their pool reload.
|
||||
//
|
||||
// The handlers publish best-effort and the payload is empty, so we
|
||||
// assert "at least one event per mutation" via a counter.
|
||||
func TestAIProvidersChangedPubsub(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
client, _, api := coderdtest.NewWithAPI(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
var count atomic.Int64
|
||||
unsubscribe, err := api.Pubsub.Subscribe(coderpubsub.AIProvidersChangedChannel, func(_ context.Context, _ []byte) {
|
||||
count.Add(1)
|
||||
})
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(unsubscribe)
|
||||
|
||||
// Create.
|
||||
req := codersdk.CreateAIProviderRequest{
|
||||
Type: codersdk.AIProviderTypeOpenAI,
|
||||
Name: "pubsub-openai",
|
||||
Enabled: true,
|
||||
BaseURL: "https://api.openai.com/v1/",
|
||||
APIKeys: []string{"k1"},
|
||||
}
|
||||
//nolint:gocritic // Owner role is the audience for this endpoint.
|
||||
created, err := client.CreateAIProvider(ctx, req)
|
||||
require.NoError(t, err)
|
||||
testutil.Eventually(ctx, t, func(_ context.Context) bool { return count.Load() >= 1 }, testutil.IntervalFast)
|
||||
|
||||
// Update.
|
||||
newKey := "k2"
|
||||
_, err = client.UpdateAIProvider(ctx, created.ID.String(), codersdk.UpdateAIProviderRequest{
|
||||
APIKeys: &[]codersdk.AIProviderKeyMutation{{APIKey: &newKey}},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
testutil.Eventually(ctx, t, func(_ context.Context) bool { return count.Load() >= 2 }, testutil.IntervalFast)
|
||||
|
||||
// Delete.
|
||||
err = client.DeleteAIProvider(ctx, created.ID.String())
|
||||
require.NoError(t, err)
|
||||
testutil.Eventually(ctx, t, func(_ context.Context) bool { return count.Load() >= 3 }, testutil.IntervalFast)
|
||||
}
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
http "net/http"
|
||||
reflect "reflect"
|
||||
|
||||
aibridge "github.com/coder/coder/v2/aibridge"
|
||||
aibridged "github.com/coder/coder/v2/coderd/aibridged"
|
||||
gomock "go.uber.org/mock/gomock"
|
||||
)
|
||||
@@ -57,6 +58,18 @@ func (mr *MockPoolerMockRecorder) Acquire(ctx, req, clientFn, mcpBootstrapper an
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Acquire", reflect.TypeOf((*MockPooler)(nil).Acquire), ctx, req, clientFn, mcpBootstrapper)
|
||||
}
|
||||
|
||||
// ReplaceProviders mocks base method. The call and its providers
// argument are handed to the gomock controller; nothing is returned.
func (m *MockPooler) ReplaceProviders(providers []aibridge.Provider) {
	m.ctrl.T.Helper()
	m.ctrl.Call(m, "ReplaceProviders", providers)
}

// ReplaceProviders indicates an expected call of ReplaceProviders.
// providers may be a concrete value or a gomock matcher.
func (mr *MockPoolerMockRecorder) ReplaceProviders(providers any) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReplaceProviders", reflect.TypeOf((*MockPooler)(nil).ReplaceProviders), providers)
}
|
||||
|
||||
// Shutdown mocks base method.
|
||||
func (m *MockPooler) Shutdown(ctx context.Context) error {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
+71
-21
@@ -3,7 +3,10 @@ package aibridged
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"slices"
|
||||
"strconv"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/dgraph-io/ristretto/v2"
|
||||
@@ -26,6 +29,9 @@ const (
|
||||
// One [*aibridge.RequestBridge] instance is created per given key.
|
||||
type Pooler interface {
	// Acquire retrieves or creates the handler that serves req.
	// NOTE(review): caching and keying rules are defined by the
	// implementation — confirm there before relying on them.
	Acquire(ctx context.Context, req Request, clientFn ClientFunc, mcpBootstrapper MCPProxyBuilder) (http.Handler, error)
	// ReplaceProviders swaps the providers used to construct future
	// RequestBridge instances and clears the cache.
	ReplaceProviders(providers []aibridge.Provider)
	// Shutdown stops the pool.
	// NOTE(review): drain and timeout semantics are implementation-defined — confirm.
	Shutdown(ctx context.Context) error
}
|
||||
|
||||
@@ -46,10 +52,12 @@ var DefaultPoolOptions = PoolOptions{MaxItems: 5000, TTL: time.Minute * 15}
|
||||
var _ Pooler = &CachedBridgePool{}
|
||||
|
||||
type CachedBridgePool struct {
|
||||
cache *ristretto.Cache[string, *aibridge.RequestBridge]
|
||||
providers []aibridge.Provider
|
||||
logger slog.Logger
|
||||
options PoolOptions
|
||||
cache *ristretto.Cache[string, *aibridge.RequestBridge]
|
||||
// providers is the live provider set used by new RequestBridge instances.
|
||||
providers atomic.Pointer[[]aibridge.Provider]
|
||||
providerVersion atomic.Int64
|
||||
logger slog.Logger
|
||||
options PoolOptions
|
||||
|
||||
singleflight *singleflight.Group[string, *aibridge.RequestBridge]
|
||||
|
||||
@@ -71,13 +79,16 @@ func NewCachedBridgePool(options PoolOptions, providers []aibridge.Provider, log
|
||||
if item == nil || item.Value == nil {
|
||||
return
|
||||
}
|
||||
|
||||
shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), time.Second*5)
|
||||
defer shutdownCancel()
|
||||
|
||||
// Run the eviction in the background since ristretto blocks sets until a free slot is available.
|
||||
// Capture the value synchronously: ristretto reuses the
|
||||
// item slot after OnEvict returns, so reading item.Value
|
||||
// from the goroutine below races with the caller of
|
||||
// Clear/Set. The shutdown still runs in the background to
|
||||
// avoid blocking ristretto's eviction loop.
|
||||
bridge := item.Value
|
||||
go func() {
|
||||
_ = item.Value.Shutdown(shutdownCtx)
|
||||
shutdownCtx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
||||
defer cancel()
|
||||
_ = bridge.Shutdown(shutdownCtx)
|
||||
}()
|
||||
},
|
||||
})
|
||||
@@ -85,18 +96,53 @@ func NewCachedBridgePool(options PoolOptions, providers []aibridge.Provider, log
|
||||
return nil, xerrors.Errorf("create cache: %w", err)
|
||||
}
|
||||
|
||||
return &CachedBridgePool{
|
||||
cache: cache,
|
||||
providers: providers,
|
||||
options: options,
|
||||
metrics: metrics,
|
||||
tracer: tracer,
|
||||
logger: logger,
|
||||
pool := &CachedBridgePool{
|
||||
cache: cache,
|
||||
options: options,
|
||||
metrics: metrics,
|
||||
tracer: tracer,
|
||||
logger: logger,
|
||||
|
||||
singleflight: &singleflight.Group[string, *aibridge.RequestBridge]{},
|
||||
|
||||
shuttingDownCh: make(chan struct{}),
|
||||
}, nil
|
||||
}
|
||||
initial := slices.Clone(providers)
|
||||
pool.providers.Store(&initial)
|
||||
return pool, nil
|
||||
}
|
||||
|
||||
// ReplaceProviders swaps the provider snapshot used by future Acquires.
|
||||
// It is safe to call concurrently with Acquire and is a no-op after
|
||||
// Shutdown.
|
||||
func (p *CachedBridgePool) ReplaceProviders(providers []aibridge.Provider) {
|
||||
select {
|
||||
case <-p.shuttingDownCh:
|
||||
return
|
||||
default:
|
||||
}
|
||||
snapshot := slices.Clone(providers)
|
||||
p.providers.Store(&snapshot)
|
||||
version := time.Now().UnixNano()
|
||||
p.providerVersion.Store(version)
|
||||
// Clear evicts every cached bridge; OnEvict shuts each one down in
|
||||
// the background. Wait for buffered writes to drain so a replacement
|
||||
// immediately followed by an Acquire always sees the cleared cache.
|
||||
p.cache.Clear()
|
||||
p.cache.Wait()
|
||||
p.logger.Info(context.Background(), "request bridge pool reloaded",
|
||||
slog.F("provider_count", len(snapshot)),
|
||||
slog.F("provider_version", version),
|
||||
)
|
||||
}
|
||||
|
||||
// loadProviders returns the current providers snapshot. The returned
|
||||
// slice must not be mutated.
|
||||
func (p *CachedBridgePool) loadProviders() []aibridge.Provider {
|
||||
if ptr := p.providers.Load(); ptr != nil {
|
||||
return *ptr
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Acquire retrieves or creates a [*aibridge.RequestBridge] instance per given key.
|
||||
@@ -140,6 +186,7 @@ func (p *CachedBridgePool) Acquire(ctx context.Context, req Request, clientFn Cl
|
||||
}
|
||||
|
||||
span.AddEvent("cache_miss")
|
||||
providerVersion := p.providerVersion.Load()
|
||||
recorder := aibridge.NewRecorder(p.logger.Named("recorder"), p.tracer, func() (aibridge.Recorder, error) {
|
||||
client, err := clientFn()
|
||||
if err != nil {
|
||||
@@ -152,7 +199,8 @@ func (p *CachedBridgePool) Acquire(ctx context.Context, req Request, clientFn Cl
|
||||
// Slow path.
|
||||
// Creating an *aibridge.RequestBridge may take some time, so gate all subsequent callers behind the initial request and return the resulting value.
|
||||
// TODO: track startup time since it adds latency to first request (histogram count will also help us see how often this occurs).
|
||||
instance, err, _ := p.singleflight.Do(req.InitiatorID.String(), func() (*aibridge.RequestBridge, error) {
|
||||
singleflightKey := cacheKey + "|" + strconv.FormatInt(providerVersion, 10)
|
||||
instance, err, _ := p.singleflight.Do(singleflightKey, func() (*aibridge.RequestBridge, error) {
|
||||
var (
|
||||
mcpServers mcp.ServerProxier
|
||||
err error
|
||||
@@ -171,12 +219,14 @@ func (p *CachedBridgePool) Acquire(ctx context.Context, req Request, clientFn Cl
|
||||
}
|
||||
}
|
||||
|
||||
bridge, err := aibridge.NewRequestBridge(ctx, p.providers, recorder, mcpServers, p.logger, p.metrics, p.tracer)
|
||||
bridge, err := aibridge.NewRequestBridge(ctx, p.loadProviders(), recorder, mcpServers, p.logger, p.metrics, p.tracer)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("create new request bridge: %w", err)
|
||||
}
|
||||
|
||||
p.cache.SetWithTTL(cacheKey, bridge, cacheCost, p.options.TTL)
|
||||
if p.providerVersion.Load() == providerVersion {
|
||||
p.cache.SetWithTTL(cacheKey, bridge, cacheCost, p.options.TTL)
|
||||
}
|
||||
|
||||
return bridge, nil
|
||||
})
|
||||
|
||||
@@ -2,6 +2,10 @@ package aibridged_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"testing/synctest"
|
||||
"time"
|
||||
@@ -12,10 +16,13 @@ import (
|
||||
"go.uber.org/mock/gomock"
|
||||
|
||||
"cdr.dev/slog/v3/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/aibridge"
|
||||
"github.com/coder/coder/v2/aibridge/config"
|
||||
"github.com/coder/coder/v2/aibridge/mcp"
|
||||
"github.com/coder/coder/v2/aibridge/mcpmock"
|
||||
"github.com/coder/coder/v2/coderd/aibridged"
|
||||
mock "github.com/coder/coder/v2/coderd/aibridged/aibridgedmock"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
// TestPool validates the published behavior of [aibridged.CachedBridgePool].
|
||||
@@ -107,6 +114,160 @@ func TestPool(t *testing.T) {
|
||||
require.EqualValues(t, 3, cacheMetrics.Misses())
|
||||
}
|
||||
|
||||
func TestPoolReplaceProvidersClearsCacheAndUsesNewProviders(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
oldUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
_, _ = io.WriteString(w, "old")
|
||||
}))
|
||||
t.Cleanup(oldUpstream.Close)
|
||||
newUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
_, _ = io.WriteString(w, "new")
|
||||
}))
|
||||
t.Cleanup(newUpstream.Close)
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
ctrl := gomock.NewController(t)
|
||||
client := mock.NewMockDRPCClient(ctrl)
|
||||
mcpProxy := mcpmock.NewMockServerProxier(ctrl)
|
||||
mcpProxy.EXPECT().Init(gomock.Any()).AnyTimes().Return(nil)
|
||||
mcpProxy.EXPECT().Shutdown(gomock.Any()).AnyTimes().Return(nil)
|
||||
|
||||
opts := aibridged.PoolOptions{MaxItems: 1, TTL: time.Minute}
|
||||
pool, err := aibridged.NewCachedBridgePool(opts, []aibridge.Provider{
|
||||
aibridge.NewOpenAIProvider(config.OpenAI{Name: "old", BaseURL: oldUpstream.URL}),
|
||||
}, logger, nil, testTracer)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { _ = pool.Shutdown(context.Background()) })
|
||||
|
||||
req := aibridged.Request{
|
||||
SessionKey: "key",
|
||||
InitiatorID: uuid.New(),
|
||||
APIKeyID: uuid.New().String(),
|
||||
}
|
||||
clientFn := func() (aibridged.DRPCClient, error) {
|
||||
return client, nil
|
||||
}
|
||||
|
||||
inst, err := pool.Acquire(t.Context(), req, clientFn, newMockMCPFactory(mcpProxy))
|
||||
require.NoError(t, err)
|
||||
assertHandlerBody(t, inst, "/old/v1/models", "old")
|
||||
|
||||
pool.ReplaceProviders([]aibridge.Provider{
|
||||
aibridge.NewOpenAIProvider(config.OpenAI{Name: "new", BaseURL: newUpstream.URL}),
|
||||
})
|
||||
|
||||
instAfterReload, err := pool.Acquire(t.Context(), req, clientFn, newMockMCPFactory(mcpProxy))
|
||||
require.NoError(t, err)
|
||||
require.NotSame(t, inst, instAfterReload)
|
||||
assertHandlerBody(t, instAfterReload, "/new/v1/models", "new")
|
||||
}
|
||||
|
||||
func TestPoolReplaceProvidersDoesNotJoinStaleSingleflight(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
oldUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
_, _ = io.WriteString(w, "old")
|
||||
}))
|
||||
t.Cleanup(oldUpstream.Close)
|
||||
newUpstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
|
||||
_, _ = io.WriteString(w, "new")
|
||||
}))
|
||||
t.Cleanup(newUpstream.Close)
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
ctrl := gomock.NewController(t)
|
||||
client := mock.NewMockDRPCClient(ctrl)
|
||||
|
||||
opts := aibridged.PoolOptions{MaxItems: 1, TTL: time.Minute}
|
||||
pool, err := aibridged.NewCachedBridgePool(opts, []aibridge.Provider{
|
||||
aibridge.NewOpenAIProvider(config.OpenAI{Name: "old", BaseURL: oldUpstream.URL}),
|
||||
}, logger, nil, testTracer)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() { _ = pool.Shutdown(context.Background()) })
|
||||
|
||||
req := aibridged.Request{
|
||||
SessionKey: "key",
|
||||
InitiatorID: uuid.New(),
|
||||
APIKeyID: uuid.New().String(),
|
||||
}
|
||||
clientFn := func() (aibridged.DRPCClient, error) {
|
||||
return client, nil
|
||||
}
|
||||
|
||||
factory := newBlockingMCPFactory()
|
||||
firstDone := make(chan acquireResult, 1)
|
||||
go func() {
|
||||
handler, err := pool.Acquire(t.Context(), req, clientFn, factory)
|
||||
firstDone <- acquireResult{handler: handler, err: err}
|
||||
}()
|
||||
|
||||
require.Eventually(t, factory.firstBuildStarted, testutil.WaitShort, testutil.IntervalFast)
|
||||
|
||||
pool.ReplaceProviders([]aibridge.Provider{
|
||||
aibridge.NewOpenAIProvider(config.OpenAI{Name: "new", BaseURL: newUpstream.URL}),
|
||||
})
|
||||
|
||||
secondDone := make(chan acquireResult, 1)
|
||||
go func() {
|
||||
handler, err := pool.Acquire(t.Context(), req, clientFn, factory)
|
||||
secondDone <- acquireResult{handler: handler, err: err}
|
||||
}()
|
||||
|
||||
var second acquireResult
|
||||
require.Eventually(t, func() bool {
|
||||
select {
|
||||
case second = <-secondDone:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}, testutil.WaitShort, testutil.IntervalFast)
|
||||
require.NoError(t, second.err)
|
||||
assertHandlerBody(t, second.handler, "/new/v1/models", "new")
|
||||
|
||||
close(factory.releaseFirst)
|
||||
var first acquireResult
|
||||
require.Eventually(t, func() bool {
|
||||
select {
|
||||
case first = <-firstDone:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}, testutil.WaitShort, testutil.IntervalFast)
|
||||
require.NoError(t, first.err)
|
||||
|
||||
third, err := pool.Acquire(t.Context(), req, clientFn, factory)
|
||||
require.NoError(t, err)
|
||||
require.Same(t, second.handler, third)
|
||||
}
|
||||
|
||||
func TestPoolReplaceProvidersAfterShutdownIsNoop(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
opts := aibridged.PoolOptions{MaxItems: 1, TTL: time.Minute}
|
||||
pool, err := aibridged.NewCachedBridgePool(opts, nil, logger, nil, testTracer)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, pool.Shutdown(t.Context()))
|
||||
require.NotPanics(t, func() {
|
||||
pool.ReplaceProviders([]aibridge.Provider{
|
||||
aibridge.NewOpenAIProvider(config.OpenAI{Name: "new", BaseURL: "https://example.com"}),
|
||||
})
|
||||
})
|
||||
|
||||
_, err = pool.Acquire(t.Context(), aibridged.Request{
|
||||
SessionKey: "key",
|
||||
InitiatorID: uuid.New(),
|
||||
APIKeyID: uuid.New().String(),
|
||||
}, func() (aibridged.DRPCClient, error) {
|
||||
return nil, context.Canceled
|
||||
}, newMockMCPFactory(nil))
|
||||
require.ErrorContains(t, err, "pool shutting down")
|
||||
}
|
||||
|
||||
func TestPool_Expiry(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -166,6 +327,21 @@ func TestPool_Expiry(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func assertHandlerBody(t *testing.T, handler http.Handler, path string, body string) {
|
||||
t.Helper()
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, path, nil)
|
||||
rw := httptest.NewRecorder()
|
||||
handler.ServeHTTP(rw, req)
|
||||
resp := rw.Result()
|
||||
defer resp.Body.Close()
|
||||
|
||||
got, err := io.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, http.StatusOK, resp.StatusCode)
|
||||
require.Equal(t, body, string(got))
|
||||
}
|
||||
|
||||
var _ aibridged.MCPProxyBuilder = &mockMCPFactory{}
|
||||
|
||||
type mockMCPFactory struct {
|
||||
@@ -179,3 +355,42 @@ func newMockMCPFactory(proxy *mcpmock.MockServerProxier) *mockMCPFactory {
|
||||
// Build returns the factory's preconfigured proxy and never reports an
// error; ctx, req, and tracer are ignored.
// NOTE(review): when the factory was created with a nil proxy the
// returned interface may wrap a nil pointer rather than be nil —
// confirm against the declared type of the proxy field.
func (m *mockMCPFactory) Build(ctx context.Context, req aibridged.Request, tracer trace.Tracer) (mcp.ServerProxier, error) {
	return m.proxy, nil
}
|
||||
|
||||
// acquireResult bundles the two return values of a pool Acquire call so
// they can be sent over a channel from a helper goroutine.
type acquireResult struct {
	// handler is the handler Acquire returned.
	handler http.Handler
	// err is the error Acquire returned.
	err error
}
|
||||
|
||||
type blockingMCPFactory struct {
|
||||
calls atomic.Int32
|
||||
firstStarted chan struct{}
|
||||
releaseFirst chan struct{}
|
||||
}
|
||||
|
||||
func newBlockingMCPFactory() *blockingMCPFactory {
|
||||
return &blockingMCPFactory{
|
||||
firstStarted: make(chan struct{}),
|
||||
releaseFirst: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *blockingMCPFactory) firstBuildStarted() bool {
|
||||
select {
|
||||
case <-m.firstStarted:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func (m *blockingMCPFactory) Build(ctx context.Context, _ aibridged.Request, _ trace.Tracer) (mcp.ServerProxier, error) {
|
||||
if m.calls.Add(1) == 1 {
|
||||
close(m.firstStarted)
|
||||
select {
|
||||
case <-m.releaseFirst:
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
}
|
||||
return nil, context.Canceled
|
||||
}
|
||||
|
||||
@@ -0,0 +1,50 @@
|
||||
package aibridged
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
dbpubsub "github.com/coder/coder/v2/coderd/database/pubsub"
|
||||
"github.com/coder/coder/v2/coderd/pubsub"
|
||||
)
|
||||
|
||||
// ProviderReloader is implemented by components that keep an in-memory
// snapshot of AI providers and can rebuild it on demand.
type ProviderReloader interface {
	// Reload refreshes the snapshot and reports any failure to do so.
	Reload(ctx context.Context) error
}
|
||||
|
||||
// SubscribeProviderReload refreshes once, then on AI provider changes.
|
||||
func SubscribeProviderReload(
|
||||
ctx context.Context,
|
||||
ps dbpubsub.Pubsub,
|
||||
reloader ProviderReloader,
|
||||
logger slog.Logger,
|
||||
) (func(), error) {
|
||||
if ps == nil {
|
||||
return nil, xerrors.New("pubsub is required")
|
||||
}
|
||||
if reloader == nil {
|
||||
return nil, xerrors.New("reloader is required")
|
||||
}
|
||||
|
||||
unsubscribe, err := ps.SubscribeWithErr(pubsub.AIProvidersChangedChannel, func(cbCtx context.Context, _ []byte, err error) {
|
||||
if err != nil {
|
||||
logger.Warn(cbCtx, "ai providers changed event delivered with error", slog.Error(err))
|
||||
return
|
||||
}
|
||||
if err := reloader.Reload(cbCtx); err != nil {
|
||||
logger.Warn(cbCtx, "reload ai provider snapshot from pubsub event", slog.Error(err))
|
||||
return
|
||||
}
|
||||
logger.Debug(cbCtx, "reloaded ai provider snapshot from pubsub event")
|
||||
})
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("subscribe to %s: %w", pubsub.AIProvidersChangedChannel, err)
|
||||
}
|
||||
if err := reloader.Reload(ctx); err != nil {
|
||||
logger.Warn(ctx, "initial ai provider reload", slog.Error(err))
|
||||
}
|
||||
return unsubscribe, nil
|
||||
}
|
||||
@@ -0,0 +1,133 @@
|
||||
package aibridged_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog/v3/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/coderd/aibridged"
|
||||
dbpubsub "github.com/coder/coder/v2/coderd/database/pubsub"
|
||||
"github.com/coder/coder/v2/coderd/pubsub"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
func TestSubscribeProviderReload(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
ps := dbpubsub.NewInMemory()
|
||||
t.Cleanup(func() { _ = ps.Close() })
|
||||
|
||||
calls := &recordingReloader{}
|
||||
|
||||
unsub, err := aibridged.SubscribeProviderReload(ctx, ps, calls, logger)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(unsub)
|
||||
|
||||
require.Equal(t, 1, calls.count())
|
||||
|
||||
require.NoError(t, ps.Publish(pubsub.AIProvidersChangedChannel, nil))
|
||||
|
||||
require.Eventually(t, func() bool { return calls.count() >= 2 }, testutil.WaitShort, testutil.IntervalFast,
|
||||
"Reload must fire again after a pubsub notification")
|
||||
}
|
||||
|
||||
func TestSubscribeProviderReloadSurfacesReloadError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
ps := dbpubsub.NewInMemory()
|
||||
t.Cleanup(func() { _ = ps.Close() })
|
||||
|
||||
calls := &recordingReloader{returnErr: true}
|
||||
|
||||
unsub, err := aibridged.SubscribeProviderReload(ctx, ps, calls, logger)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(unsub)
|
||||
|
||||
require.Equal(t, 1, calls.count())
|
||||
require.NoError(t, ps.Publish(pubsub.AIProvidersChangedChannel, nil))
|
||||
require.Eventually(t, func() bool { return calls.count() >= 2 }, testutil.WaitShort, testutil.IntervalFast,
|
||||
"Reload must keep firing even after a previous Reload returned an error")
|
||||
}
|
||||
|
||||
func TestSubscribeProviderReloadIgnoresEventError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
logger := slogtest.Make(t, nil)
|
||||
ps := &errInjectingPubsub{}
|
||||
|
||||
calls := &recordingReloader{}
|
||||
unsub, err := aibridged.SubscribeProviderReload(ctx, ps, calls, logger)
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(unsub)
|
||||
|
||||
require.Equal(t, 1, calls.count())
|
||||
|
||||
ps.listener(ctx, nil, errPubsubDelivery)
|
||||
require.Equal(t, 1, calls.count())
|
||||
|
||||
ps.listener(ctx, nil, nil)
|
||||
require.Equal(t, 2, calls.count())
|
||||
}
|
||||
|
||||
// recordingReloader is a minimal [aibridged.ProviderReloader] that
|
||||
// counts calls.
|
||||
type recordingReloader struct {
|
||||
n atomic.Int32
|
||||
returnErr bool
|
||||
}
|
||||
|
||||
func (r *recordingReloader) Reload(_ context.Context) error {
|
||||
r.n.Add(1)
|
||||
if r.returnErr {
|
||||
return errReloadFailed
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *recordingReloader) count() int {
|
||||
return int(r.n.Load())
|
||||
}
|
||||
|
||||
// stubError is a string-backed error type for fixed test errors.
type stubError string

// Error returns the underlying string unchanged.
func (s stubError) Error() string { return string(s) }

// Fixed errors used by the reload tests.
var (
	errReloadFailed   stubError = "reload failed"
	errPubsubDelivery stubError = "pubsub delivery failed"
)
|
||||
|
||||
var _ dbpubsub.Pubsub = &errInjectingPubsub{}
|
||||
|
||||
type errInjectingPubsub struct {
|
||||
listener dbpubsub.ListenerWithErr
|
||||
}
|
||||
|
||||
func (*errInjectingPubsub) Subscribe(string, dbpubsub.Listener) (func(), error) {
|
||||
return nil, xerrors.New("Subscribe not implemented")
|
||||
}
|
||||
|
||||
func (p *errInjectingPubsub) SubscribeWithErr(_ string, listener dbpubsub.ListenerWithErr) (func(), error) {
|
||||
p.listener = listener
|
||||
return func() {}, nil
|
||||
}
|
||||
|
||||
func (*errInjectingPubsub) Publish(string, []byte) error {
|
||||
return xerrors.New("Publish not implemented")
|
||||
}
|
||||
|
||||
func (*errInjectingPubsub) Close() error {
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,11 @@
|
||||
package pubsub
|
||||
|
||||
// AIProvidersChangedChannel is the pubsub channel that carries AI
// provider lifecycle events: provider create / update / soft-delete
// and key insert / delete. Subscribers (aibridged, aibridgeproxyd)
// reload their in-memory provider snapshot on receipt.
//
// The payload is empty and serves only as an invalidation hint.
// Subscribers refetch the authoritative state from the database, so a
// dropped message delays convergence but does not cause state to
// diverge.
const AIProvidersChangedChannel = "ai_providers_changed"
|
||||
Reference in New Issue
Block a user