Files
coder/coderd/aibridged/pool_test.go
T
Danny Kopping ddec110b0e refactor: move aibridged out of enterprise to AGPL (#25570)
To allow Coder Agents to use AI Gateway in OSS, we need to rehome the `aibridged`-related code into the AGPL path.

The HTTP API is only registered under enterprise, so it still requires the AI Governance Add-on to be present in order to be used, whereas Coder Agents uses an in-memory pipe to the same handlers.
2026-05-22 09:11:37 +02:00

182 lines
6.0 KiB
Go

package aibridged_test
import (
"context"
"testing"
"testing/synctest"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel/trace"
"go.uber.org/mock/gomock"
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/coder/v2/aibridge/mcp"
"github.com/coder/coder/v2/aibridge/mcpmock"
"github.com/coder/coder/v2/coderd/aibridged"
mock "github.com/coder/coder/v2/coderd/aibridged/aibridgedmock"
)
// TestPool validates the published behavior of [aibridged.CachedBridgePool].
// It is not meant to be an exhaustive test of the internal cache's functionality,
// since that is already covered by its library.
func TestPool(t *testing.T) {
t.Parallel()
logger := slogtest.Make(t, nil)
ctrl := gomock.NewController(t)
client := mock.NewMockDRPCClient(ctrl)
mcpProxy := mcpmock.NewMockServerProxier(ctrl)
opts := aibridged.PoolOptions{MaxItems: 1, TTL: time.Second}
pool, err := aibridged.NewCachedBridgePool(opts, nil, logger, nil, testTracer)
require.NoError(t, err)
t.Cleanup(func() { pool.Shutdown(context.Background()) })
id, id2, apiKeyID1, apiKeyID2 := uuid.New(), uuid.New(), uuid.New(), uuid.New()
clientFn := func() (aibridged.DRPCClient, error) {
return client, nil
}
// Once a pool instance is initialized, it will try setup its MCP proxier(s).
// This is called exactly once since the instance below is only created once.
mcpProxy.EXPECT().Init(gomock.Any()).Times(1).Return(nil)
// This is part of the lifecycle.
mcpProxy.EXPECT().Shutdown(gomock.Any()).AnyTimes().Return(nil)
// Acquiring a pool instance will create one the first time it sees an
// initiator ID...
inst, err := pool.Acquire(t.Context(), aibridged.Request{
SessionKey: "key",
InitiatorID: id,
APIKeyID: apiKeyID1.String(),
}, clientFn, newMockMCPFactory(mcpProxy))
require.NoError(t, err, "acquire pool instance")
// ...and it will return it when acquired again.
instB, err := pool.Acquire(t.Context(), aibridged.Request{
SessionKey: "key",
InitiatorID: id,
APIKeyID: apiKeyID1.String(),
}, clientFn, newMockMCPFactory(mcpProxy))
require.NoError(t, err, "acquire pool instance")
require.Same(t, inst, instB)
cacheMetrics := pool.CacheMetrics()
require.EqualValues(t, 1, cacheMetrics.KeysAdded())
require.EqualValues(t, 0, cacheMetrics.KeysEvicted())
require.EqualValues(t, 1, cacheMetrics.Hits())
require.EqualValues(t, 1, cacheMetrics.Misses())
// This will get called again because a new instance will be created.
mcpProxy.EXPECT().Init(gomock.Any()).Times(1).Return(nil)
// But that key will be evicted when a new initiator is seen (maxItems=1):
inst2, err := pool.Acquire(t.Context(), aibridged.Request{
SessionKey: "key",
InitiatorID: id2,
APIKeyID: apiKeyID1.String(),
}, clientFn, newMockMCPFactory(mcpProxy))
require.NoError(t, err, "acquire pool instance")
require.NotSame(t, inst, inst2)
cacheMetrics = pool.CacheMetrics()
require.EqualValues(t, 2, cacheMetrics.KeysAdded())
require.EqualValues(t, 1, cacheMetrics.KeysEvicted())
require.EqualValues(t, 1, cacheMetrics.Hits())
require.EqualValues(t, 2, cacheMetrics.Misses())
// This will get called again because a new instance will be created.
mcpProxy.EXPECT().Init(gomock.Any()).Times(1).Return(nil)
// New instance is created for different api key id
inst2B, err := pool.Acquire(t.Context(), aibridged.Request{
SessionKey: "key",
InitiatorID: id2,
APIKeyID: apiKeyID2.String(),
}, clientFn, newMockMCPFactory(mcpProxy))
require.NoError(t, err, "acquire pool instance 2B")
require.NotSame(t, inst2, inst2B)
cacheMetrics = pool.CacheMetrics()
require.EqualValues(t, 3, cacheMetrics.KeysAdded())
require.EqualValues(t, 2, cacheMetrics.KeysEvicted())
require.EqualValues(t, 1, cacheMetrics.Hits())
require.EqualValues(t, 3, cacheMetrics.Misses())
}
func TestPool_Expiry(t *testing.T) {
t.Parallel()
synctest.Test(t, func(t *testing.T) {
logger := slogtest.Make(t, nil)
ctrl := gomock.NewController(t)
client := mock.NewMockDRPCClient(ctrl)
mcpProxy := mcpmock.NewMockServerProxier(ctrl)
mcpProxy.EXPECT().Init(gomock.Any()).AnyTimes().Return(nil)
mcpProxy.EXPECT().Shutdown(gomock.Any()).AnyTimes().Return(nil)
const ttl = time.Second
opts := aibridged.PoolOptions{MaxItems: 1, TTL: ttl}
pool, err := aibridged.NewCachedBridgePool(opts, nil, logger, nil, testTracer)
require.NoError(t, err)
t.Cleanup(func() { pool.Shutdown(context.Background()) })
req := aibridged.Request{
SessionKey: "key",
InitiatorID: uuid.New(),
APIKeyID: uuid.New().String(),
}
clientFn := func() (aibridged.DRPCClient, error) {
return client, nil
}
ctx := t.Context()
// First acquire is a cache miss.
_, err = pool.Acquire(ctx, req, clientFn, newMockMCPFactory(mcpProxy))
require.NoError(t, err)
// Second acquire is a cache hit.
_, err = pool.Acquire(ctx, req, clientFn, newMockMCPFactory(mcpProxy))
require.NoError(t, err)
metrics := pool.CacheMetrics()
require.EqualValues(t, 1, metrics.Misses())
require.EqualValues(t, 1, metrics.Hits())
// TTL expires
time.Sleep(ttl + time.Millisecond)
// Third acquire is a cache miss because the entry expired.
_, err = pool.Acquire(ctx, req, clientFn, newMockMCPFactory(mcpProxy))
require.NoError(t, err)
metrics = pool.CacheMetrics()
require.EqualValues(t, 2, metrics.Misses())
require.EqualValues(t, 1, metrics.Hits())
// Wait for all eviction goroutines to complete before gomock's ctrl.Finish()
// runs in test cleanup. ristretto's OnEvict callback spawns goroutines that
// need to finish calling mcpProxy.Shutdown() before ctrl.finish clears the
// expectations.
synctest.Wait()
})
}
// Compile-time check that *mockMCPFactory satisfies aibridged.MCPProxyBuilder.
var _ aibridged.MCPProxyBuilder = (*mockMCPFactory)(nil)

// mockMCPFactory is an aibridged.MCPProxyBuilder whose Build method always
// hands back the same preconfigured mock proxier, whatever the request.
type mockMCPFactory struct {
	proxy *mcpmock.MockServerProxier
}

// newMockMCPFactory returns a factory that yields proxy from every Build call.
func newMockMCPFactory(proxy *mcpmock.MockServerProxier) *mockMCPFactory {
	factory := new(mockMCPFactory)
	factory.proxy = proxy
	return factory
}

// Build ignores its arguments and returns the wrapped mock proxier with a nil
// error. Tests drive the proxier's behavior through gomock expectations.
func (f *mockMCPFactory) Build(_ context.Context, _ aibridged.Request, _ trace.Tracer) (mcp.ServerProxier, error) {
	return f.proxy, nil
}