mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
5b10268827
_Disclosure: created with Coder Agents._ When providers are disabled, we should serve a sentinel error so the requesting client (Claude Code, Coder Agents, etc) is informed. Coder Agents can also conditionalize its display to show a helpful error message. --------- Signed-off-by: Danny Kopping <danny@coder.com> Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
259 lines
8.5 KiB
Go
259 lines
8.5 KiB
Go
package aibridged
|
|
|
|
import (
|
|
"context"
|
|
"net/http"
|
|
"slices"
|
|
"strconv"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"github.com/dgraph-io/ristretto/v2"
|
|
"go.opentelemetry.io/otel/attribute"
|
|
"go.opentelemetry.io/otel/trace"
|
|
"golang.org/x/xerrors"
|
|
"tailscale.com/util/singleflight"
|
|
|
|
"cdr.dev/slog/v3"
|
|
"github.com/coder/coder/v2/aibridge"
|
|
"github.com/coder/coder/v2/aibridge/mcp"
|
|
"github.com/coder/coder/v2/aibridge/tracing"
|
|
)
|
|
|
|
const (
|
|
cacheCost = 1 // We can't know the actual size in bytes of the value (it'll change over time).
|
|
)
|
|
|
|
// Pooler describes a pool of [*aibridge.RequestBridge] instances from which instances can be retrieved.
|
|
// One [*aibridge.RequestBridge] instance is created per given key.
|
|
type Pooler interface {
|
|
Acquire(ctx context.Context, req Request, clientFn ClientFunc, mcpBootstrapper MCPProxyBuilder) (http.Handler, error)
|
|
// ReplaceProviders swaps the providers used to construct future
|
|
// RequestBridge instances and clears the cache. Disabled providers
|
|
// must be included; the bridge serves a 503 sentinel on their
|
|
// routes.
|
|
ReplaceProviders(providers []aibridge.Provider)
|
|
Shutdown(ctx context.Context) error
|
|
}
|
|
|
|
type PoolMetrics interface {
|
|
Hits() uint64
|
|
Misses() uint64
|
|
KeysAdded() uint64
|
|
KeysEvicted() uint64
|
|
}
|
|
|
|
type PoolOptions struct {
|
|
MaxItems int64
|
|
TTL time.Duration
|
|
}
|
|
|
|
var DefaultPoolOptions = PoolOptions{MaxItems: 5000, TTL: time.Minute * 15}
|
|
|
|
var _ Pooler = &CachedBridgePool{}
|
|
|
|
type CachedBridgePool struct {
|
|
cache *ristretto.Cache[string, *aibridge.RequestBridge]
|
|
// providers is the live provider set used by new RequestBridge
|
|
// instances. Includes disabled providers.
|
|
providers atomic.Pointer[[]aibridge.Provider]
|
|
providerVersion atomic.Int64
|
|
logger slog.Logger
|
|
options PoolOptions
|
|
|
|
singleflight *singleflight.Group[string, *aibridge.RequestBridge]
|
|
|
|
metrics *aibridge.Metrics
|
|
tracer trace.Tracer
|
|
|
|
shutDownOnce sync.Once
|
|
shuttingDownCh chan struct{}
|
|
}
|
|
|
|
func NewCachedBridgePool(options PoolOptions, providers []aibridge.Provider, logger slog.Logger, metrics *aibridge.Metrics, tracer trace.Tracer) (*CachedBridgePool, error) {
|
|
cache, err := ristretto.NewCache(&ristretto.Config[string, *aibridge.RequestBridge]{
|
|
NumCounters: options.MaxItems * 10, // Docs suggest setting this 10x number of keys.
|
|
MaxCost: options.MaxItems * cacheCost, // Up to n instances.
|
|
IgnoreInternalCost: true, // Don't try estimate cost using bytes (ristretto does this naïvely anyway, just using the size of the value struct not the REAL memory usage).
|
|
BufferItems: 64, // Sticking with recommendation from docs.
|
|
Metrics: true, // Collect metrics (only used in tests, for now).
|
|
OnEvict: func(item *ristretto.Item[*aibridge.RequestBridge]) {
|
|
if item == nil || item.Value == nil {
|
|
return
|
|
}
|
|
// Capture the value synchronously: ristretto reuses the
|
|
// item slot after OnEvict returns, so reading item.Value
|
|
// from the goroutine below races with the caller of
|
|
// Clear/Set. The shutdown still runs in the background to
|
|
// avoid blocking ristretto's eviction loop.
|
|
bridge := item.Value
|
|
go func() {
|
|
shutdownCtx, cancel := context.WithTimeout(context.Background(), time.Second*5)
|
|
defer cancel()
|
|
_ = bridge.Shutdown(shutdownCtx)
|
|
}()
|
|
},
|
|
})
|
|
if err != nil {
|
|
return nil, xerrors.Errorf("create cache: %w", err)
|
|
}
|
|
|
|
pool := &CachedBridgePool{
|
|
cache: cache,
|
|
options: options,
|
|
metrics: metrics,
|
|
tracer: tracer,
|
|
logger: logger,
|
|
|
|
singleflight: &singleflight.Group[string, *aibridge.RequestBridge]{},
|
|
|
|
shuttingDownCh: make(chan struct{}),
|
|
}
|
|
initial := slices.Clone(providers)
|
|
pool.providers.Store(&initial)
|
|
return pool, nil
|
|
}
|
|
|
|
// ReplaceProviders swaps the provider snapshot used by future Acquires.
|
|
// It is safe to call concurrently with Acquire and is a no-op after
|
|
// Shutdown.
|
|
func (p *CachedBridgePool) ReplaceProviders(providers []aibridge.Provider) {
|
|
select {
|
|
case <-p.shuttingDownCh:
|
|
return
|
|
default:
|
|
}
|
|
snapshot := slices.Clone(providers)
|
|
p.providers.Store(&snapshot)
|
|
version := time.Now().UnixNano()
|
|
p.providerVersion.Store(version)
|
|
// Clear evicts every cached bridge; OnEvict shuts each one down in
|
|
// the background. Wait for buffered writes to drain so a replacement
|
|
// immediately followed by an Acquire always sees the cleared cache.
|
|
p.cache.Clear()
|
|
p.cache.Wait()
|
|
p.logger.Info(context.Background(), "request bridge pool reloaded",
|
|
slog.F("provider_count", len(snapshot)),
|
|
slog.F("provider_version", version),
|
|
)
|
|
}
|
|
|
|
// loadProviders returns the current providers snapshot. The returned
|
|
// slice must not be mutated.
|
|
func (p *CachedBridgePool) loadProviders() []aibridge.Provider {
|
|
if ptr := p.providers.Load(); ptr != nil {
|
|
return *ptr
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Acquire retrieves or creates a [*aibridge.RequestBridge] instance per given key.
|
|
//
|
|
// Each returned [*aibridge.RequestBridge] is safe for concurrent use.
|
|
// Each [*aibridge.RequestBridge] is stateful because it has MCP clients which maintain sessions to the configured MCP server.
|
|
func (p *CachedBridgePool) Acquire(ctx context.Context, req Request, clientFn ClientFunc, mcpProxyFactory MCPProxyBuilder) (_ http.Handler, outErr error) {
|
|
spanAttrs := []attribute.KeyValue{
|
|
attribute.String(tracing.InitiatorID, req.InitiatorID.String()),
|
|
attribute.String(tracing.APIKeyID, req.APIKeyID),
|
|
}
|
|
ctx, span := p.tracer.Start(ctx, "CachedBridgePool.Acquire", trace.WithAttributes(spanAttrs...))
|
|
defer tracing.EndSpanErr(span, &outErr)
|
|
ctx = tracing.WithRequestBridgeAttributesInContext(ctx, spanAttrs)
|
|
|
|
if err := ctx.Err(); err != nil {
|
|
return nil, xerrors.Errorf("acquire: %w", err)
|
|
}
|
|
|
|
select {
|
|
case <-p.shuttingDownCh:
|
|
return nil, xerrors.New("pool shutting down")
|
|
default:
|
|
}
|
|
|
|
// Wait for all buffered writes to be applied, otherwise multiple calls in quick succession
|
|
// may visit the slow path unnecessarily.
|
|
defer p.cache.Wait()
|
|
|
|
// Fast path.
|
|
cacheKey := req.InitiatorID.String() + "|" + req.APIKeyID
|
|
bridge, ok := p.cache.Get(cacheKey)
|
|
if ok && bridge != nil {
|
|
// TODO: future improvement:
|
|
// Once we can detect token expiry against an MCP server, we no longer need to let these instances
|
|
// expire after the original TTL; we can extend the TTL on each Acquire() call.
|
|
// For now, we need to let the instance expiry to keep the MCP connections fresh.
|
|
|
|
span.AddEvent("cache_hit")
|
|
return bridge, nil
|
|
}
|
|
|
|
span.AddEvent("cache_miss")
|
|
providerVersion := p.providerVersion.Load()
|
|
recorder := aibridge.NewRecorder(p.logger.Named("recorder"), p.tracer, func() (aibridge.Recorder, error) {
|
|
client, err := clientFn()
|
|
if err != nil {
|
|
return nil, xerrors.Errorf("acquire client: %w", err)
|
|
}
|
|
|
|
return &recorderTranslation{apiKeyID: req.APIKeyID, client: client}, nil
|
|
})
|
|
|
|
// Slow path.
|
|
// Creating an *aibridge.RequestBridge may take some time, so gate all subsequent callers behind the initial request and return the resulting value.
|
|
// TODO: track startup time since it adds latency to first request (histogram count will also help us see how often this occurs).
|
|
singleflightKey := cacheKey + "|" + strconv.FormatInt(providerVersion, 10)
|
|
instance, err, _ := p.singleflight.Do(singleflightKey, func() (*aibridge.RequestBridge, error) {
|
|
var (
|
|
mcpServers mcp.ServerProxier
|
|
err error
|
|
)
|
|
|
|
mcpServers, err = mcpProxyFactory.Build(ctx, req, p.tracer)
|
|
if err != nil {
|
|
p.logger.Warn(ctx, "failed to create MCP server proxiers", slog.Error(err))
|
|
// Don't fail here; MCP server injection can gracefully degrade.
|
|
}
|
|
|
|
if mcpServers != nil {
|
|
// This will block while connections are established with upstream MCP server(s), and tools are listed.
|
|
if err := mcpServers.Init(ctx); err != nil {
|
|
p.logger.Warn(ctx, "failed to initialize MCP server proxier(s)", slog.Error(err))
|
|
}
|
|
}
|
|
|
|
bridge, err := aibridge.NewRequestBridge(ctx, p.loadProviders(), recorder, mcpServers, p.logger, p.metrics, p.tracer)
|
|
if err != nil {
|
|
return nil, xerrors.Errorf("create new request bridge: %w", err)
|
|
}
|
|
|
|
if p.providerVersion.Load() == providerVersion {
|
|
p.cache.SetWithTTL(cacheKey, bridge, cacheCost, p.options.TTL)
|
|
}
|
|
|
|
return bridge, nil
|
|
})
|
|
|
|
return instance, err
|
|
}
|
|
|
|
func (p *CachedBridgePool) CacheMetrics() PoolMetrics {
|
|
if p.cache == nil {
|
|
return nil
|
|
}
|
|
|
|
return p.cache.Metrics
|
|
}
|
|
|
|
// Shutdown will close the cache which will trigger eviction of all the Bridge entries.
|
|
func (p *CachedBridgePool) Shutdown(_ context.Context) error {
|
|
p.shutDownOnce.Do(func() {
|
|
// Prevent new requests from being served.
|
|
close(p.shuttingDownCh)
|
|
|
|
p.cache.Close()
|
|
})
|
|
|
|
return nil
|
|
}
|