Files
coder/coderd/aibridged/pool.go
T
Danny Kopping 5b10268827 feat: serve 503 sentinel for disabled providers (#25794)
_Disclosure: created with Coder Agents._

When providers are disabled, we should serve a sentinel error so the
requesting client (Claude Code, Coder Agents, etc.) is informed. Coder
Agents can also adapt its display based on this error to show a helpful
error message.

---------

Signed-off-by: Danny Kopping <danny@coder.com>
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-29 10:24:16 +02:00

259 lines
8.5 KiB
Go

package aibridged
import (
"context"
"net/http"
"slices"
"strconv"
"sync"
"sync/atomic"
"time"
"github.com/dgraph-io/ristretto/v2"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
"golang.org/x/xerrors"
"tailscale.com/util/singleflight"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/aibridge"
"github.com/coder/coder/v2/aibridge/mcp"
"github.com/coder/coder/v2/aibridge/tracing"
)
// cacheCost is the fixed cost charged to the cache for every stored entry.
// We can't know the actual size in bytes of a value (it'll change over time),
// so each entry counts as a single unit and the cache's maximum cost is
// effectively a maximum number of entries.
const cacheCost = 1
// Pooler describes a pool of [*aibridge.RequestBridge] instances from which instances can be retrieved.
// One [*aibridge.RequestBridge] instance is created per given key.
type Pooler interface {
// Acquire returns the handler that should serve req, or an error if
// no handler can be provided. clientFn and mcpBootstrapper supply the
// dependencies an implementation needs when it has to construct a new
// instance for req.
Acquire(ctx context.Context, req Request, clientFn ClientFunc, mcpBootstrapper MCPProxyBuilder) (http.Handler, error)
// ReplaceProviders swaps the providers used to construct future
// RequestBridge instances and clears the cache. Disabled providers
// must be included; the bridge serves a 503 sentinel on their
// routes.
ReplaceProviders(providers []aibridge.Provider)
// Shutdown stops the pool and reports any error encountered while
// doing so.
Shutdown(ctx context.Context) error
}
// PoolMetrics is a read-only view of a pool's cache counters.
// [CachedBridgePool.CacheMetrics] returns the metrics object of its
// underlying cache through this interface.
//
// NOTE(review): the exact meaning of each counter is defined by the backing
// cache implementation, not by this package; confirm against its docs.
type PoolMetrics interface {
Hits() uint64
Misses() uint64
KeysAdded() uint64
KeysEvicted() uint64
}
// PoolOptions configures a CachedBridgePool.
type PoolOptions struct {
	// MaxItems is used to size the cache: it determines both the maximum
	// total cost and the number of counters the cache is created with.
	MaxItems int64
	// TTL is the time-to-live given to each entry stored in the cache.
	TTL time.Duration
}

// DefaultPoolOptions is a ready-made configuration: 5000 items with a
// 15-minute TTL.
var DefaultPoolOptions = PoolOptions{
	MaxItems: 5000,
	TTL:      15 * time.Minute,
}
// Compile-time check that *CachedBridgePool implements Pooler.
var _ Pooler = (*CachedBridgePool)(nil)
// CachedBridgePool is a [Pooler] that keeps constructed
// [*aibridge.RequestBridge] instances in a TTL-bounded cache and builds
// missing ones on demand.
type CachedBridgePool struct {
// cache holds constructed bridges, keyed by the request's initiator
// ID and API key ID.
cache *ristretto.Cache[string, *aibridge.RequestBridge]
// providers is the live provider set used by new RequestBridge
// instances. Includes disabled providers.
providers atomic.Pointer[[]aibridge.Provider]
// providerVersion is updated by every ReplaceProviders call. Acquire
// reads it before and after building a bridge so that a bridge built
// while the provider set was being replaced is not written to the
// cache; it is also part of the singleflight key.
providerVersion atomic.Int64
logger slog.Logger
// options carries the sizing and TTL settings the pool was created with.
options PoolOptions
// singleflight collapses concurrent constructions for the same key
// into a single build whose result is shared by all waiting callers.
singleflight *singleflight.Group[string, *aibridge.RequestBridge]
// metrics and tracer are handed to every RequestBridge the pool builds.
metrics *aibridge.Metrics
tracer trace.Tracer
// shutDownOnce guarantees the shutdown sequence runs at most once.
shutDownOnce sync.Once
// shuttingDownCh is closed when Shutdown starts; Acquire and
// ReplaceProviders poll it to refuse work afterwards.
shuttingDownCh chan struct{}
}
// NewCachedBridgePool builds a CachedBridgePool whose cache is sized from
// options and whose initial provider snapshot is a copy of providers (the
// caller's slice is cloned, so later mutations by the caller do not affect
// the pool).
//
// logger, metrics and tracer are stored on the pool and passed on to every
// RequestBridge it constructs. An error is returned only when the
// underlying cache cannot be created.
func NewCachedBridgePool(options PoolOptions, providers []aibridge.Provider, logger slog.Logger, metrics *aibridge.Metrics, tracer trace.Tracer) (*CachedBridgePool, error) {
	// shutdownEvicted is the cache's eviction hook: it shuts down the
	// bridge that is leaving the cache.
	shutdownEvicted := func(item *ristretto.Item[*aibridge.RequestBridge]) {
		if item == nil || item.Value == nil {
			return
		}
		// Capture the value synchronously: ristretto reuses the item
		// slot after OnEvict returns, so reading item.Value from the
		// goroutine below races with the caller of Clear/Set.
		evicted := item.Value
		// The shutdown itself runs in the background to avoid blocking
		// ristretto's eviction loop.
		go func() {
			ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
			defer cancel()
			// Best-effort: the shutdown error is discarded.
			_ = evicted.Shutdown(ctx)
		}()
	}

	cfg := &ristretto.Config[string, *aibridge.RequestBridge]{
		// Docs suggest setting this 10x number of keys.
		NumCounters: options.MaxItems * 10,
		// Up to n instances.
		MaxCost: options.MaxItems * cacheCost,
		// Don't try estimate cost using bytes (ristretto does this
		// naïvely anyway, just using the size of the value struct not
		// the REAL memory usage).
		IgnoreInternalCost: true,
		// Sticking with recommendation from docs.
		BufferItems: 64,
		// Collect metrics (only used in tests, for now).
		Metrics: true,
		OnEvict: shutdownEvicted,
	}

	cache, err := ristretto.NewCache(cfg)
	if err != nil {
		return nil, xerrors.Errorf("create cache: %w", err)
	}

	pool := &CachedBridgePool{
		cache:          cache,
		options:        options,
		metrics:        metrics,
		tracer:         tracer,
		logger:         logger,
		singleflight:   &singleflight.Group[string, *aibridge.RequestBridge]{},
		shuttingDownCh: make(chan struct{}),
	}

	snapshot := slices.Clone(providers)
	pool.providers.Store(&snapshot)
	return pool, nil
}
// ReplaceProviders swaps the provider snapshot used by future Acquires.
// It is safe to call concurrently with Acquire and is a no-op after
// Shutdown.
//
// The caller's slice is cloned before being published, the provider
// version is advanced, and the cache is cleared so that bridges built from
// the previous provider set are no longer served.
func (p *CachedBridgePool) ReplaceProviders(providers []aibridge.Provider) {
	select {
	case <-p.shuttingDownCh:
		return
	default:
	}

	// Publish the providers before the version: an Acquire that observes
	// the new version is then guaranteed to load the new providers too.
	snapshot := slices.Clone(providers)
	p.providers.Store(&snapshot)

	// The version must differ from every earlier one, because Acquire
	// relies on inequality to detect that a replacement happened while it
	// was building a bridge. A wall-clock timestamp cannot guarantee that
	// (coarse timer resolution, clock adjustments), so use an atomic
	// counter instead.
	version := p.providerVersion.Add(1)

	// Clear evicts every cached bridge; OnEvict shuts each one down in
	// the background. Wait for buffered writes to drain so a replacement
	// immediately followed by an Acquire always sees the cleared cache.
	p.cache.Clear()
	p.cache.Wait()

	p.logger.Info(context.Background(), "request bridge pool reloaded",
		slog.F("provider_count", len(snapshot)),
		slog.F("provider_version", version),
	)
}
// loadProviders returns the providers snapshot that is current at the time
// of the call, or nil when no snapshot has been stored. The returned slice
// is shared with other callers and must not be mutated.
func (p *CachedBridgePool) loadProviders() []aibridge.Provider {
	snapshot := p.providers.Load()
	if snapshot == nil {
		return nil
	}
	return *snapshot
}
// Acquire retrieves or creates a [*aibridge.RequestBridge] instance per given key.
//
// The key is derived from the request's initiator ID and API key ID. A cache
// hit returns the stored instance; a miss builds a new one, with concurrent
// builds for the same key (and provider version) collapsed into one.
//
// Each returned [*aibridge.RequestBridge] is safe for concurrent use.
// Each [*aibridge.RequestBridge] is stateful because it has MCP clients which maintain sessions to the configured MCP server.
//
// An error is returned when ctx is already done, when the pool is shutting
// down, or when the bridge cannot be constructed. On error the returned
// handler is a literal nil, so callers may safely compare it against nil.
func (p *CachedBridgePool) Acquire(ctx context.Context, req Request, clientFn ClientFunc, mcpProxyFactory MCPProxyBuilder) (_ http.Handler, outErr error) {
	spanAttrs := []attribute.KeyValue{
		attribute.String(tracing.InitiatorID, req.InitiatorID.String()),
		attribute.String(tracing.APIKeyID, req.APIKeyID),
	}
	ctx, span := p.tracer.Start(ctx, "CachedBridgePool.Acquire", trace.WithAttributes(spanAttrs...))
	defer tracing.EndSpanErr(span, &outErr)
	ctx = tracing.WithRequestBridgeAttributesInContext(ctx, spanAttrs)

	if err := ctx.Err(); err != nil {
		return nil, xerrors.Errorf("acquire: %w", err)
	}

	select {
	case <-p.shuttingDownCh:
		return nil, xerrors.New("pool shutting down")
	default:
	}

	// Wait for all buffered writes to be applied, otherwise multiple calls in quick succession
	// may visit the slow path unnecessarily.
	defer p.cache.Wait()

	// Fast path.
	cacheKey := req.InitiatorID.String() + "|" + req.APIKeyID
	bridge, ok := p.cache.Get(cacheKey)
	if ok && bridge != nil {
		// TODO: future improvement:
		// Once we can detect token expiry against an MCP server, we no longer need to let these instances
		// expire after the original TTL; we can extend the TTL on each Acquire() call.
		// For now, we need to let the instance expire to keep the MCP connections fresh.
		span.AddEvent("cache_hit")
		return bridge, nil
	}

	span.AddEvent("cache_miss")

	// Remember which provider version this build started under so that a
	// concurrent ReplaceProviders can be detected before caching.
	providerVersion := p.providerVersion.Load()

	recorder := aibridge.NewRecorder(p.logger.Named("recorder"), p.tracer, func() (aibridge.Recorder, error) {
		client, err := clientFn()
		if err != nil {
			return nil, xerrors.Errorf("acquire client: %w", err)
		}
		return &recorderTranslation{apiKeyID: req.APIKeyID, client: client}, nil
	})

	// Slow path.
	// Creating an *aibridge.RequestBridge may take some time, so gate all subsequent callers behind the initial request and return the resulting value.
	// TODO: track startup time since it adds latency to first request (histogram count will also help us see how often this occurs).
	//
	// The provider version is part of the key so callers arriving after a
	// provider replacement do not join a build started before it.
	// NOTE(review): the build runs with the first caller's ctx, req,
	// clientFn and mcpProxyFactory; callers that join it share its outcome.
	singleflightKey := cacheKey + "|" + strconv.FormatInt(providerVersion, 10)
	instance, err, _ := p.singleflight.Do(singleflightKey, func() (*aibridge.RequestBridge, error) {
		var (
			mcpServers mcp.ServerProxier
			err        error
		)

		mcpServers, err = mcpProxyFactory.Build(ctx, req, p.tracer)
		if err != nil {
			p.logger.Warn(ctx, "failed to create MCP server proxiers", slog.Error(err))
			// Don't fail here; MCP server injection can gracefully degrade.
		}

		if mcpServers != nil {
			// This will block while connections are established with upstream MCP server(s), and tools are listed.
			if err := mcpServers.Init(ctx); err != nil {
				p.logger.Warn(ctx, "failed to initialize MCP server proxier(s)", slog.Error(err))
			}
		}

		created, err := aibridge.NewRequestBridge(ctx, p.loadProviders(), recorder, mcpServers, p.logger, p.metrics, p.tracer)
		if err != nil {
			return nil, xerrors.Errorf("create new request bridge: %w", err)
		}

		// Only cache the bridge if the providers were not replaced while
		// it was being built; otherwise it is handed to the caller(s) of
		// this build without being stored.
		// NOTE(review): this check and the write below are not atomic with
		// respect to ReplaceProviders, and an uncached bridge is not shut
		// down by the pool.
		if p.providerVersion.Load() == providerVersion {
			p.cache.SetWithTTL(cacheKey, created, cacheCost, p.options.TTL)
		}
		return created, nil
	})
	if err != nil {
		// Return a literal nil: converting the nil *aibridge.RequestBridge
		// to http.Handler would yield a non-nil interface value.
		return nil, err
	}
	return instance, nil
}
// CacheMetrics exposes the metrics object of the underlying cache, or nil
// when the pool has no cache.
func (p *CachedBridgePool) CacheMetrics() PoolMetrics {
	if cache := p.cache; cache != nil {
		return cache.Metrics
	}
	return nil
}
// Shutdown marks the pool as shutting down and closes the cache, which is
// expected to trigger eviction of all the Bridge entries.
//
// The sequence runs at most once; repeated calls do nothing. The context
// argument is ignored and the returned error is always nil.
func (p *CachedBridgePool) Shutdown(_ context.Context) error {
	stop := func() {
		// Prevent new requests from being served.
		close(p.shuttingDownCh)

		p.cache.Close()
	}
	p.shutDownOnce.Do(stop)
	return nil
}