package aibridged_test import ( "context" "testing" "testing/synctest" "time" "github.com/google/uuid" "github.com/stretchr/testify/require" "go.opentelemetry.io/otel/trace" "go.uber.org/mock/gomock" "cdr.dev/slog/v3/sloggers/slogtest" "github.com/coder/coder/v2/aibridge/mcp" "github.com/coder/coder/v2/aibridge/mcpmock" "github.com/coder/coder/v2/coderd/aibridged" mock "github.com/coder/coder/v2/coderd/aibridged/aibridgedmock" ) // TestPool validates the published behavior of [aibridged.CachedBridgePool]. // It is not meant to be an exhaustive test of the internal cache's functionality, // since that is already covered by its library. func TestPool(t *testing.T) { t.Parallel() logger := slogtest.Make(t, nil) ctrl := gomock.NewController(t) client := mock.NewMockDRPCClient(ctrl) mcpProxy := mcpmock.NewMockServerProxier(ctrl) opts := aibridged.PoolOptions{MaxItems: 1, TTL: time.Second} pool, err := aibridged.NewCachedBridgePool(opts, nil, logger, nil, testTracer) require.NoError(t, err) t.Cleanup(func() { pool.Shutdown(context.Background()) }) id, id2, apiKeyID1, apiKeyID2 := uuid.New(), uuid.New(), uuid.New(), uuid.New() clientFn := func() (aibridged.DRPCClient, error) { return client, nil } // Once a pool instance is initialized, it will try setup its MCP proxier(s). // This is called exactly once since the instance below is only created once. mcpProxy.EXPECT().Init(gomock.Any()).Times(1).Return(nil) // This is part of the lifecycle. mcpProxy.EXPECT().Shutdown(gomock.Any()).AnyTimes().Return(nil) // Acquiring a pool instance will create one the first time it sees an // initiator ID... inst, err := pool.Acquire(t.Context(), aibridged.Request{ SessionKey: "key", InitiatorID: id, APIKeyID: apiKeyID1.String(), }, clientFn, newMockMCPFactory(mcpProxy)) require.NoError(t, err, "acquire pool instance") // ...and it will return it when acquired again. instB, err := pool.Acquire(t.Context(), aibridged.Request{ SessionKey: "key", InitiatorID: id, APIKeyID: apiKeyID1.String(), }, clientFn, newMockMCPFactory(mcpProxy)) require.NoError(t, err, "acquire pool instance") require.Same(t, inst, instB) cacheMetrics := pool.CacheMetrics() require.EqualValues(t, 1, cacheMetrics.KeysAdded()) require.EqualValues(t, 0, cacheMetrics.KeysEvicted()) require.EqualValues(t, 1, cacheMetrics.Hits()) require.EqualValues(t, 1, cacheMetrics.Misses()) // This will get called again because a new instance will be created. mcpProxy.EXPECT().Init(gomock.Any()).Times(1).Return(nil) // But that key will be evicted when a new initiator is seen (maxItems=1): inst2, err := pool.Acquire(t.Context(), aibridged.Request{ SessionKey: "key", InitiatorID: id2, APIKeyID: apiKeyID1.String(), }, clientFn, newMockMCPFactory(mcpProxy)) require.NoError(t, err, "acquire pool instance") require.NotSame(t, inst, inst2) cacheMetrics = pool.CacheMetrics() require.EqualValues(t, 2, cacheMetrics.KeysAdded()) require.EqualValues(t, 1, cacheMetrics.KeysEvicted()) require.EqualValues(t, 1, cacheMetrics.Hits()) require.EqualValues(t, 2, cacheMetrics.Misses()) // This will get called again because a new instance will be created. mcpProxy.EXPECT().Init(gomock.Any()).Times(1).Return(nil) // New instance is created for different api key id inst2B, err := pool.Acquire(t.Context(), aibridged.Request{ SessionKey: "key", InitiatorID: id2, APIKeyID: apiKeyID2.String(), }, clientFn, newMockMCPFactory(mcpProxy)) require.NoError(t, err, "acquire pool instance 2B") require.NotSame(t, inst2, inst2B) cacheMetrics = pool.CacheMetrics() require.EqualValues(t, 3, cacheMetrics.KeysAdded()) require.EqualValues(t, 2, cacheMetrics.KeysEvicted()) require.EqualValues(t, 1, cacheMetrics.Hits()) require.EqualValues(t, 3, cacheMetrics.Misses()) } func TestPool_Expiry(t *testing.T) { t.Parallel() synctest.Test(t, func(t *testing.T) { logger := slogtest.Make(t, nil) ctrl := gomock.NewController(t) client := mock.NewMockDRPCClient(ctrl) mcpProxy := mcpmock.NewMockServerProxier(ctrl) mcpProxy.EXPECT().Init(gomock.Any()).AnyTimes().Return(nil) mcpProxy.EXPECT().Shutdown(gomock.Any()).AnyTimes().Return(nil) const ttl = time.Second opts := aibridged.PoolOptions{MaxItems: 1, TTL: ttl} pool, err := aibridged.NewCachedBridgePool(opts, nil, logger, nil, testTracer) require.NoError(t, err) t.Cleanup(func() { pool.Shutdown(context.Background()) }) req := aibridged.Request{ SessionKey: "key", InitiatorID: uuid.New(), APIKeyID: uuid.New().String(), } clientFn := func() (aibridged.DRPCClient, error) { return client, nil } ctx := t.Context() // First acquire is a cache miss. _, err = pool.Acquire(ctx, req, clientFn, newMockMCPFactory(mcpProxy)) require.NoError(t, err) // Second acquire is a cache hit. _, err = pool.Acquire(ctx, req, clientFn, newMockMCPFactory(mcpProxy)) require.NoError(t, err) metrics := pool.CacheMetrics() require.EqualValues(t, 1, metrics.Misses()) require.EqualValues(t, 1, metrics.Hits()) // TTL expires time.Sleep(ttl + time.Millisecond) // Third acquire is a cache miss because the entry expired. _, err = pool.Acquire(ctx, req, clientFn, newMockMCPFactory(mcpProxy)) require.NoError(t, err) metrics = pool.CacheMetrics() require.EqualValues(t, 2, metrics.Misses()) require.EqualValues(t, 1, metrics.Hits()) // Wait for all eviction goroutines to complete before gomock's ctrl.Finish() // runs in test cleanup. ristretto's OnEvict callback spawns goroutines that // need to finish calling mcpProxy.Shutdown() before ctrl.finish clears the // expectations. synctest.Wait() }) } var _ aibridged.MCPProxyBuilder = &mockMCPFactory{} type mockMCPFactory struct { proxy *mcpmock.MockServerProxier } func newMockMCPFactory(proxy *mcpmock.MockServerProxier) *mockMCPFactory { return &mockMCPFactory{proxy: proxy} } func (m *mockMCPFactory) Build(ctx context.Context, req aibridged.Request, tracer trace.Tracer) (mcp.ServerProxier, error) { return m.proxy, nil }