mirror of
https://github.com/coder/coder.git
synced 2026-06-03 04:58:23 +00:00
ddec110b0e
In order to allow Coder Agents to use AI Gateway in OSS, we need to rehome the `aibridged`-related code into the AGPL path. The HTTP API is only registered under enterprise, so it will still require the AI Governance Add-on to be present in order to use it, whereas Coder Agents uses an in-memory pipe to the same handlers.
182 lines
6.0 KiB
Go
182 lines
6.0 KiB
Go
package aibridged_test
|
|
|
|
import (
|
|
"context"
|
|
"testing"
|
|
"testing/synctest"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/stretchr/testify/require"
|
|
"go.opentelemetry.io/otel/trace"
|
|
"go.uber.org/mock/gomock"
|
|
|
|
"cdr.dev/slog/v3/sloggers/slogtest"
|
|
"github.com/coder/coder/v2/aibridge/mcp"
|
|
"github.com/coder/coder/v2/aibridge/mcpmock"
|
|
"github.com/coder/coder/v2/coderd/aibridged"
|
|
mock "github.com/coder/coder/v2/coderd/aibridged/aibridgedmock"
|
|
)
|
|
|
|
// TestPool validates the published behavior of [aibridged.CachedBridgePool].
|
|
// It is not meant to be an exhaustive test of the internal cache's functionality,
|
|
// since that is already covered by its library.
|
|
func TestPool(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
logger := slogtest.Make(t, nil)
|
|
|
|
ctrl := gomock.NewController(t)
|
|
client := mock.NewMockDRPCClient(ctrl)
|
|
mcpProxy := mcpmock.NewMockServerProxier(ctrl)
|
|
|
|
opts := aibridged.PoolOptions{MaxItems: 1, TTL: time.Second}
|
|
pool, err := aibridged.NewCachedBridgePool(opts, nil, logger, nil, testTracer)
|
|
require.NoError(t, err)
|
|
t.Cleanup(func() { pool.Shutdown(context.Background()) })
|
|
|
|
id, id2, apiKeyID1, apiKeyID2 := uuid.New(), uuid.New(), uuid.New(), uuid.New()
|
|
clientFn := func() (aibridged.DRPCClient, error) {
|
|
return client, nil
|
|
}
|
|
|
|
// Once a pool instance is initialized, it will try setup its MCP proxier(s).
|
|
// This is called exactly once since the instance below is only created once.
|
|
mcpProxy.EXPECT().Init(gomock.Any()).Times(1).Return(nil)
|
|
// This is part of the lifecycle.
|
|
mcpProxy.EXPECT().Shutdown(gomock.Any()).AnyTimes().Return(nil)
|
|
|
|
// Acquiring a pool instance will create one the first time it sees an
|
|
// initiator ID...
|
|
inst, err := pool.Acquire(t.Context(), aibridged.Request{
|
|
SessionKey: "key",
|
|
InitiatorID: id,
|
|
APIKeyID: apiKeyID1.String(),
|
|
}, clientFn, newMockMCPFactory(mcpProxy))
|
|
require.NoError(t, err, "acquire pool instance")
|
|
|
|
// ...and it will return it when acquired again.
|
|
instB, err := pool.Acquire(t.Context(), aibridged.Request{
|
|
SessionKey: "key",
|
|
InitiatorID: id,
|
|
APIKeyID: apiKeyID1.String(),
|
|
}, clientFn, newMockMCPFactory(mcpProxy))
|
|
require.NoError(t, err, "acquire pool instance")
|
|
require.Same(t, inst, instB)
|
|
|
|
cacheMetrics := pool.CacheMetrics()
|
|
require.EqualValues(t, 1, cacheMetrics.KeysAdded())
|
|
require.EqualValues(t, 0, cacheMetrics.KeysEvicted())
|
|
require.EqualValues(t, 1, cacheMetrics.Hits())
|
|
require.EqualValues(t, 1, cacheMetrics.Misses())
|
|
|
|
// This will get called again because a new instance will be created.
|
|
mcpProxy.EXPECT().Init(gomock.Any()).Times(1).Return(nil)
|
|
|
|
// But that key will be evicted when a new initiator is seen (maxItems=1):
|
|
inst2, err := pool.Acquire(t.Context(), aibridged.Request{
|
|
SessionKey: "key",
|
|
InitiatorID: id2,
|
|
APIKeyID: apiKeyID1.String(),
|
|
}, clientFn, newMockMCPFactory(mcpProxy))
|
|
require.NoError(t, err, "acquire pool instance")
|
|
require.NotSame(t, inst, inst2)
|
|
|
|
cacheMetrics = pool.CacheMetrics()
|
|
require.EqualValues(t, 2, cacheMetrics.KeysAdded())
|
|
require.EqualValues(t, 1, cacheMetrics.KeysEvicted())
|
|
require.EqualValues(t, 1, cacheMetrics.Hits())
|
|
require.EqualValues(t, 2, cacheMetrics.Misses())
|
|
|
|
// This will get called again because a new instance will be created.
|
|
mcpProxy.EXPECT().Init(gomock.Any()).Times(1).Return(nil)
|
|
|
|
// New instance is created for different api key id
|
|
inst2B, err := pool.Acquire(t.Context(), aibridged.Request{
|
|
SessionKey: "key",
|
|
InitiatorID: id2,
|
|
APIKeyID: apiKeyID2.String(),
|
|
}, clientFn, newMockMCPFactory(mcpProxy))
|
|
require.NoError(t, err, "acquire pool instance 2B")
|
|
require.NotSame(t, inst2, inst2B)
|
|
|
|
cacheMetrics = pool.CacheMetrics()
|
|
require.EqualValues(t, 3, cacheMetrics.KeysAdded())
|
|
require.EqualValues(t, 2, cacheMetrics.KeysEvicted())
|
|
require.EqualValues(t, 1, cacheMetrics.Hits())
|
|
require.EqualValues(t, 3, cacheMetrics.Misses())
|
|
}
|
|
|
|
func TestPool_Expiry(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
synctest.Test(t, func(t *testing.T) {
|
|
logger := slogtest.Make(t, nil)
|
|
ctrl := gomock.NewController(t)
|
|
client := mock.NewMockDRPCClient(ctrl)
|
|
mcpProxy := mcpmock.NewMockServerProxier(ctrl)
|
|
mcpProxy.EXPECT().Init(gomock.Any()).AnyTimes().Return(nil)
|
|
mcpProxy.EXPECT().Shutdown(gomock.Any()).AnyTimes().Return(nil)
|
|
|
|
const ttl = time.Second
|
|
opts := aibridged.PoolOptions{MaxItems: 1, TTL: ttl}
|
|
pool, err := aibridged.NewCachedBridgePool(opts, nil, logger, nil, testTracer)
|
|
require.NoError(t, err)
|
|
t.Cleanup(func() { pool.Shutdown(context.Background()) })
|
|
|
|
req := aibridged.Request{
|
|
SessionKey: "key",
|
|
InitiatorID: uuid.New(),
|
|
APIKeyID: uuid.New().String(),
|
|
}
|
|
clientFn := func() (aibridged.DRPCClient, error) {
|
|
return client, nil
|
|
}
|
|
|
|
ctx := t.Context()
|
|
|
|
// First acquire is a cache miss.
|
|
_, err = pool.Acquire(ctx, req, clientFn, newMockMCPFactory(mcpProxy))
|
|
require.NoError(t, err)
|
|
|
|
// Second acquire is a cache hit.
|
|
_, err = pool.Acquire(ctx, req, clientFn, newMockMCPFactory(mcpProxy))
|
|
require.NoError(t, err)
|
|
|
|
metrics := pool.CacheMetrics()
|
|
require.EqualValues(t, 1, metrics.Misses())
|
|
require.EqualValues(t, 1, metrics.Hits())
|
|
|
|
// TTL expires
|
|
time.Sleep(ttl + time.Millisecond)
|
|
|
|
// Third acquire is a cache miss because the entry expired.
|
|
_, err = pool.Acquire(ctx, req, clientFn, newMockMCPFactory(mcpProxy))
|
|
require.NoError(t, err)
|
|
|
|
metrics = pool.CacheMetrics()
|
|
require.EqualValues(t, 2, metrics.Misses())
|
|
require.EqualValues(t, 1, metrics.Hits())
|
|
|
|
// Wait for all eviction goroutines to complete before gomock's ctrl.Finish()
|
|
// runs in test cleanup. ristretto's OnEvict callback spawns goroutines that
|
|
// need to finish calling mcpProxy.Shutdown() before ctrl.finish clears the
|
|
// expectations.
|
|
synctest.Wait()
|
|
})
|
|
}
|
|
|
|
var _ aibridged.MCPProxyBuilder = &mockMCPFactory{}
|
|
|
|
type mockMCPFactory struct {
|
|
proxy *mcpmock.MockServerProxier
|
|
}
|
|
|
|
func newMockMCPFactory(proxy *mcpmock.MockServerProxier) *mockMCPFactory {
|
|
return &mockMCPFactory{proxy: proxy}
|
|
}
|
|
|
|
func (m *mockMCPFactory) Build(ctx context.Context, req aibridged.Request, tracer trace.Tracer) (mcp.ServerProxier, error) {
|
|
return m.proxy, nil
|
|
}
|