Files
coder/coderd/aibridgedserver/aibridgedserver_test.go
T
Danny Kopping eddd4a8c2f feat(coderd): accept delegated API key ID from in-process aibridge callers (#25625)
Allows an `api_key_id` to be passed from a trusted in-memory transport
(currently: `chatd`) to `aibridged` for use in authenticating LLM
requests.

This value can _only_ be passed via context, and all users of the
in-memory transport _must_ provide it.

It can be used in conjunction with BYOK headers.

---------

Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-25 11:08:07 +02:00

1801 lines
64 KiB
Go

package aibridgedserver_test
import (
"bufio"
"bytes"
"context"
"database/sql"
"encoding/json"
"fmt"
"net"
"net/url"
"testing"
"time"
"github.com/google/uuid"
"github.com/sqlc-dev/pqtype"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
protobufproto "google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/anypb"
"google.golang.org/protobuf/types/known/structpb"
"google.golang.org/protobuf/types/known/timestamppb"
"cdr.dev/slog/v3"
"cdr.dev/slog/v3/sloggers/slogjson"
"github.com/coder/coder/v2/coderd/aibridged"
"github.com/coder/coder/v2/coderd/aibridged/proto"
"github.com/coder/coder/v2/coderd/aibridgedserver"
agplaiseats "github.com/coder/coder/v2/coderd/aiseats"
"github.com/coder/coder/v2/coderd/apikey"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbgen"
"github.com/coder/coder/v2/coderd/database/dbmock"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/coderd/database/dbtime"
"github.com/coder/coder/v2/coderd/externalauth"
codermcp "github.com/coder/coder/v2/coderd/mcp"
"github.com/coder/coder/v2/coderd/util/ptr"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/cryptorand"
"github.com/coder/coder/v2/testutil"
"github.com/coder/serpent"
)
// requiredExperiments is the set of experiments that must both be enabled for
// the server to advertise the built-in Coder MCP server config (see
// TestGetMCPServerConfigs). Most tests pass it to NewServer so that
// experiment gating does not interfere with what they exercise.
var requiredExperiments = []codersdk.Experiment{
	codersdk.ExperimentMCPServerHTTP,
	codersdk.ExperimentOAuth2,
}
// TestAuthorization validates IsAuthorized when the caller presents a full
// "<id>-<secret>" API key. Each case covers one reason for rejection (bad key
// format, unknown key, expired key, wrong secret, unknown/deleted/system
// owner) or a successful authorization. The key-ID-only (delegated) path is
// covered by TestAuthorization_Delegated; aibridgedserver is additionally
// exercised via integration tests in the aibridged package (see
// aibridged/aibridged_integration_test.go).
func TestAuthorization(t *testing.T) {
	t.Parallel()

	cases := []struct {
		name string
		// key is the raw key sent in the request. When unset, the valid
		// "<id>-<secret>" token generated for the subtest is used.
		key string
		// mocksFn is called with a valid API key and user. If the test needs
		// invalid values, it should just mutate them directly.
		mocksFn     func(db *dbmock.MockStore, apiKey database.APIKey, user database.User)
		expectedErr error
	}{
		{
			name:        "invalid key format",
			key:         "foo",
			expectedErr: aibridgedserver.ErrInvalidKey,
		},
		{
			name:        "unknown key",
			expectedErr: aibridgedserver.ErrUnknownKey,
			mocksFn: func(db *dbmock.MockStore, apiKey database.APIKey, user database.User) {
				db.EXPECT().GetAPIKeyByID(gomock.Any(), apiKey.ID).Times(1).Return(database.APIKey{}, sql.ErrNoRows)
			},
		},
		{
			name:        "expired",
			expectedErr: aibridgedserver.ErrExpired,
			mocksFn: func(db *dbmock.MockStore, apiKey database.APIKey, user database.User) {
				apiKey.ExpiresAt = dbtime.Now().Add(-time.Hour)
				db.EXPECT().GetAPIKeyByID(gomock.Any(), apiKey.ID).Times(1).Return(apiKey, nil)
			},
		},
		{
			name:        "invalid key secret",
			expectedErr: aibridgedserver.ErrInvalidKey,
			mocksFn: func(db *dbmock.MockStore, apiKey database.APIKey, user database.User) {
				apiKey.HashedSecret = []byte("differentsecret")
				db.EXPECT().GetAPIKeyByID(gomock.Any(), apiKey.ID).Times(1).Return(apiKey, nil)
			},
		},
		{
			name:        "unknown user",
			expectedErr: aibridgedserver.ErrUnknownUser,
			mocksFn: func(db *dbmock.MockStore, apiKey database.APIKey, user database.User) {
				db.EXPECT().GetAPIKeyByID(gomock.Any(), apiKey.ID).Times(1).Return(apiKey, nil)
				db.EXPECT().GetUserByID(gomock.Any(), user.ID).Times(1).Return(database.User{}, sql.ErrNoRows)
			},
		},
		{
			name:        "deleted user",
			expectedErr: aibridgedserver.ErrDeletedUser,
			mocksFn: func(db *dbmock.MockStore, apiKey database.APIKey, user database.User) {
				db.EXPECT().GetAPIKeyByID(gomock.Any(), apiKey.ID).Times(1).Return(apiKey, nil)
				db.EXPECT().GetUserByID(gomock.Any(), user.ID).Times(1).Return(database.User{ID: user.ID, Deleted: true}, nil)
			},
		},
		{
			name:        "system user",
			expectedErr: aibridgedserver.ErrSystemUser,
			mocksFn: func(db *dbmock.MockStore, apiKey database.APIKey, user database.User) {
				db.EXPECT().GetAPIKeyByID(gomock.Any(), apiKey.ID).Times(1).Return(apiKey, nil)
				db.EXPECT().GetUserByID(gomock.Any(), user.ID).Times(1).Return(database.User{ID: user.ID, IsSystem: true}, nil)
			},
		},
		{
			name: "valid",
			mocksFn: func(db *dbmock.MockStore, apiKey database.APIKey, user database.User) {
				db.EXPECT().GetAPIKeyByID(gomock.Any(), apiKey.ID).Times(1).Return(apiKey, nil)
				db.EXPECT().GetUserByID(gomock.Any(), user.ID).Times(1).Return(user, nil)
			},
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			ctrl := gomock.NewController(t)
			db := dbmock.NewMockStore(ctrl)
			logger := testutil.Logger(t)

			// Make a fake user and an API key for the mock calls.
			now := dbtime.Now()
			user := database.User{
				ID:         uuid.New(),
				Email:      "test@coder.com",
				Username:   "test",
				Name:       "Test User",
				CreatedAt:  now,
				UpdatedAt:  now,
				RBACRoles:  []string{},
				LoginType:  database.LoginTypePassword,
				Status:     database.UserStatusActive,
				LastSeenAt: now,
			}

			// Fail loudly if key material cannot be generated rather than
			// silently continuing with empty values.
			keyID, err := cryptorand.String(10)
			require.NoError(t, err)
			keySecret, keySecretHashed, err := apikey.GenerateSecret(22)
			require.NoError(t, err)
			token := fmt.Sprintf("%s-%s", keyID, keySecret)

			apiKey := database.APIKey{
				ID:              keyID,
				LifetimeSeconds: 86400, // default in db
				HashedSecret:    keySecretHashed,
				IPAddress: pqtype.Inet{
					IPNet: net.IPNet{
						IP:   net.IPv4(127, 0, 0, 1),
						Mask: net.IPv4Mask(255, 255, 255, 255),
					},
					Valid: true,
				},
				UserID:    user.ID,
				LastUsed:  now,
				ExpiresAt: now.Add(time.Hour),
				CreatedAt: now,
				UpdatedAt: now,
				LoginType: database.LoginTypePassword,
				Scopes:    []database.APIKeyScope{database.ApiKeyScopeCoderAll},
				TokenName: "",
			}

			// Use a local instead of mutating the table entry.
			key := tc.key
			if key == "" {
				key = token
			}

			// Define any case-specific mocks.
			if tc.mocksFn != nil {
				tc.mocksFn(db, apiKey, user)
			}

			srv, err := aibridgedserver.NewServer(t.Context(), db, logger, "/", codersdk.AIBridgeConfig{}, nil, requiredExperiments, agplaiseats.Noop{})
			require.NoError(t, err)
			require.NotNil(t, srv)

			resp, err := srv.IsAuthorized(t.Context(), &proto.IsAuthorizedRequest{Key: key})
			if tc.expectedErr != nil {
				require.Error(t, err)
				require.ErrorIs(t, err, tc.expectedErr)
				return
			}
			require.NoError(t, err)
			require.Equal(t, &proto.IsAuthorizedResponse{
				OwnerId:  user.ID.String(),
				ApiKeyId: keyID,
				Username: user.Username,
			}, resp)
		})
	}
}
// TestAuthorization_Delegated covers IsAuthorized when the request carries
// KeyId instead of Key. On this path the server skips the secret check and
// validates only that the key exists, is unexpired, and belongs to a
// non-deleted, non-system user. This is the path used by in-process delegated
// callers (e.g., chatd) that hold only the key ID.
func TestAuthorization_Delegated(t *testing.T) {
	t.Parallel()

	cases := []struct {
		name    string
		mocksFn func(db *dbmock.MockStore, apiKey database.APIKey, user database.User)
		// bothFields additionally sets Key on the request alongside KeyId.
		bothFields  bool
		expectedErr error
	}{
		{
			name: "valid",
			mocksFn: func(db *dbmock.MockStore, apiKey database.APIKey, user database.User) {
				db.EXPECT().GetAPIKeyByID(gomock.Any(), apiKey.ID).Times(1).Return(apiKey, nil)
				db.EXPECT().GetUserByID(gomock.Any(), user.ID).Times(1).Return(user, nil)
			},
		},
		{
			name:        "unknown key",
			expectedErr: aibridgedserver.ErrUnknownKey,
			mocksFn: func(db *dbmock.MockStore, apiKey database.APIKey, _ database.User) {
				db.EXPECT().GetAPIKeyByID(gomock.Any(), apiKey.ID).Times(1).Return(database.APIKey{}, sql.ErrNoRows)
			},
		},
		{
			name:        "expired",
			expectedErr: aibridgedserver.ErrExpired,
			mocksFn: func(db *dbmock.MockStore, apiKey database.APIKey, _ database.User) {
				apiKey.ExpiresAt = dbtime.Now().Add(-time.Hour)
				db.EXPECT().GetAPIKeyByID(gomock.Any(), apiKey.ID).Times(1).Return(apiKey, nil)
			},
		},
		{
			// Sending both Key and KeyId is an API misuse and must be
			// rejected to avoid ambiguity about which path was taken.
			name:        "both fields set",
			bothFields:  true,
			expectedErr: aibridgedserver.ErrAmbiguousAuth,
		},
		{
			// A bogus secret has no effect on the delegated path because
			// the secret is never checked. This is the load-bearing
			// security property: trust is established out-of-band, not in
			// this RPC.
			name: "secret hash mismatch is ignored",
			mocksFn: func(db *dbmock.MockStore, apiKey database.APIKey, user database.User) {
				apiKey.HashedSecret = []byte("not-the-real-hash")
				db.EXPECT().GetAPIKeyByID(gomock.Any(), apiKey.ID).Times(1).Return(apiKey, nil)
				db.EXPECT().GetUserByID(gomock.Any(), user.ID).Times(1).Return(user, nil)
			},
		},
		{
			// The delegated path must still reject keys whose owner has
			// been deleted; trust at the transport boundary does not
			// extend to bypassing user-status checks.
			name:        "deleted user",
			expectedErr: aibridgedserver.ErrDeletedUser,
			mocksFn: func(db *dbmock.MockStore, apiKey database.APIKey, user database.User) {
				db.EXPECT().GetAPIKeyByID(gomock.Any(), apiKey.ID).Times(1).Return(apiKey, nil)
				db.EXPECT().GetUserByID(gomock.Any(), user.ID).Times(1).Return(database.User{ID: user.ID, Deleted: true}, nil)
			},
		},
		{
			// Likewise, a system user must never be authenticated through
			// the delegated path.
			name:        "system user",
			expectedErr: aibridgedserver.ErrSystemUser,
			mocksFn: func(db *dbmock.MockStore, apiKey database.APIKey, user database.User) {
				db.EXPECT().GetAPIKeyByID(gomock.Any(), apiKey.ID).Times(1).Return(apiKey, nil)
				db.EXPECT().GetUserByID(gomock.Any(), user.ID).Times(1).Return(database.User{ID: user.ID, IsSystem: true}, nil)
			},
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			ctrl := gomock.NewController(t)
			db := dbmock.NewMockStore(ctrl)
			logger := testutil.Logger(t)

			now := dbtime.Now()
			user := database.User{
				ID:         uuid.New(),
				Email:      "test@coder.com",
				Username:   "test",
				Name:       "Test User",
				CreatedAt:  now,
				UpdatedAt:  now,
				RBACRoles:  []string{},
				LoginType:  database.LoginTypePassword,
				Status:     database.UserStatusActive,
				LastSeenAt: now,
			}

			// Fail loudly if key material cannot be generated rather than
			// silently continuing with empty values.
			keyID, err := cryptorand.String(10)
			require.NoError(t, err)
			_, keySecretHashed, err := apikey.GenerateSecret(22)
			require.NoError(t, err)

			apiKey := database.APIKey{
				ID:              keyID,
				LifetimeSeconds: 86400,
				HashedSecret:    keySecretHashed,
				UserID:          user.ID,
				LastUsed:        now,
				ExpiresAt:       now.Add(time.Hour),
				CreatedAt:       now,
				UpdatedAt:       now,
				LoginType:       database.LoginTypePassword,
				Scopes:          []database.APIKeyScope{database.ApiKeyScopeCoderAll},
			}
			if tc.mocksFn != nil {
				tc.mocksFn(db, apiKey, user)
			}

			srv, err := aibridgedserver.NewServer(t.Context(), db, logger, "/", codersdk.AIBridgeConfig{}, nil, requiredExperiments, agplaiseats.Noop{})
			require.NoError(t, err)
			require.NotNil(t, srv)

			req := &proto.IsAuthorizedRequest{KeyId: keyID}
			if tc.bothFields {
				req.Key = "anything-anything"
			}

			resp, err := srv.IsAuthorized(t.Context(), req)
			if tc.expectedErr != nil {
				require.Error(t, err)
				require.ErrorIs(t, err, tc.expectedErr)
				return
			}
			require.NoError(t, err)
			require.Equal(t, &proto.IsAuthorizedResponse{
				OwnerId:  user.ID.String(),
				ApiKeyId: keyID,
				Username: user.Username,
			}, resp)
		})
	}
}
// TestGetMCPServerConfigs checks which MCP server configs GetMCPServerConfigs
// returns. The built-in Coder MCP server is expected only when both the MCP
// HTTP and OAuth2 experiments are enabled and InjectCoderMCPTools is not
// disabled. External configs are expected only for external auth providers
// that define an MCPURL, independent of experiments.
func TestGetMCPServerConfigs(t *testing.T) {
	t.Parallel()

	externalAuthCfgs := []*externalauth.Config{
		{
			ID:     "1",
			MCPURL: "1.com/mcp",
		},
		{
			// Not eligible for inclusion: no MCPURL is defined.
			ID: "2",
		},
	}

	type testCase struct {
		name                     string
		disableCoderMCPInjection bool
		experiments              codersdk.Experiments
		externalAuthConfigs      []*externalauth.Config
		expectCoderMCP           bool
		expectExternalMCP        bool
	}

	cases := []testCase{
		{
			name:        "experiments not enabled",
			experiments: codersdk.Experiments{},
		},
		{
			name:        "MCP experiment enabled, not OAuth2",
			experiments: codersdk.Experiments{codersdk.ExperimentMCPServerHTTP},
		},
		{
			name:        "OAuth2 experiment enabled, not MCP",
			experiments: codersdk.Experiments{codersdk.ExperimentOAuth2},
		},
		{
			name:           "only internal MCP",
			experiments:    requiredExperiments,
			expectCoderMCP: true,
		},
		{
			name:                "only external MCP",
			externalAuthConfigs: externalAuthCfgs,
			expectExternalMCP:   true,
		},
		{
			name:                "both internal & external MCP",
			experiments:         requiredExperiments,
			externalAuthConfigs: externalAuthCfgs,
			expectCoderMCP:      true,
			expectExternalMCP:   true,
		},
		{
			name:                     "both internal & external MCP, but coder MCP tools not injected",
			disableCoderMCPInjection: true,
			experiments:              requiredExperiments,
			externalAuthConfigs:      externalAuthCfgs,
			expectCoderMCP:           false,
			expectExternalMCP:        true,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			const accessURL = "https://my-cool-deployment.com"
			mockDB := dbmock.NewMockStore(gomock.NewController(t))
			bridgeCfg := codersdk.AIBridgeConfig{
				InjectCoderMCPTools: serpent.Bool(!tc.disableCoderMCPInjection),
			}

			srv, err := aibridgedserver.NewServer(t.Context(), mockDB, testutil.Logger(t), accessURL, bridgeCfg, tc.externalAuthConfigs, tc.experiments, agplaiseats.Noop{})
			require.NoError(t, err)
			require.NotNil(t, srv)

			resp, err := srv.GetMCPServerConfigs(t.Context(), &proto.GetMCPServerConfigsRequest{})
			require.NoError(t, err)
			require.NotNil(t, resp)

			// Built-in Coder MCP server.
			coderConfig := resp.GetCoderMcpConfig()
			if !tc.expectCoderMCP {
				require.Empty(t, coderConfig)
			} else {
				require.NotNil(t, coderConfig)
				require.Equal(t, aibridged.InternalMCPServerID, coderConfig.GetId())

				wantURL, joinErr := url.JoinPath(accessURL, codermcp.MCPEndpoint)
				require.NoError(t, joinErr)
				require.Equal(t, wantURL, coderConfig.GetUrl())
				require.Empty(t, coderConfig.GetToolAllowRegex())
				require.Empty(t, coderConfig.GetToolDenyRegex())
			}

			// External-auth-backed MCP servers.
			externalConfigs := resp.GetExternalAuthMcpConfigs()
			if !tc.expectExternalMCP {
				require.Empty(t, externalConfigs)
			} else {
				require.Len(t, externalConfigs, 1)
			}
		})
	}
}
// TestGetMCPServerAccessTokensBatch verifies that GetMCPServerAccessTokensBatch
// returns tokens only for requested servers that have an MCP URL configured
// and a non-expired OAuth token. Every other requested server gets a
// per-server error instead of failing the whole request, and duplicate server
// IDs in the request are tolerated.
func TestGetMCPServerAccessTokensBatch(t *testing.T) {
	t.Parallel()

	ctrl := gomock.NewController(t)
	db := dbmock.NewMockStore(ctrl)
	logger := testutil.Logger(t)

	// Given: 2 external auth configs with an MCP URL and 1 without.
	srv, err := aibridgedserver.NewServer(t.Context(), db, logger, "/", codersdk.AIBridgeConfig{}, []*externalauth.Config{
		{
			ID:     "1",
			MCPURL: "1.com/mcp",
		},
		{
			ID:     "2",
			MCPURL: "2.com/mcp",
		},
		{
			ID: "3",
		},
	}, requiredExperiments, agplaiseats.Noop{})
	require.NoError(t, err)
	require.NotNil(t, srv)

	// And: the requesting user has an external auth link for every provider.
	// The expectation is bound to this user's ID (rather than gomock.Any())
	// so the test fails if links are looked up for a different user.
	userID := uuid.New()
	db.EXPECT().GetExternalAuthLinksByUserID(gomock.Any(), userID).MinTimes(1).DoAndReturn(func(_ context.Context, id uuid.UUID) ([]database.ExternalAuthLink, error) {
		return []database.ExternalAuthLink{
			{
				UserID:           id,
				ProviderID:       "1",
				OAuthAccessToken: "1-token",
			},
			{
				UserID:           id,
				ProviderID:       "2",
				OAuthAccessToken: "2-token",
				OAuthExpiry:      dbtime.Now().Add(-time.Minute), // This token is expired and should not be returned.
			},
			{
				UserID:           id,
				ProviderID:       "3",
				OAuthAccessToken: "3-token",
			},
		}, nil
	})

	// When: accessing the MCP server access tokens. Only the 2 servers with MCP
	// configured are eligible; the 1 without should not fail the request but
	// rather have an error returned specifically for that server.
	resp, err := srv.GetMCPServerAccessTokensBatch(t.Context(), &proto.GetMCPServerAccessTokensBatchRequest{
		UserId:             userID.String(),
		McpServerConfigIds: []string{"1", "1", "2", "3"}, // Duplicates must be tolerated.
	})
	require.NoError(t, err)

	// Then: of the 2 eligible servers only 1 returns a valid token, as the
	// other's token has expired.
	require.Len(t, resp.GetAccessTokens(), 1)
	require.Equal(t, "1-token", resp.GetAccessTokens()["1"])
	require.Len(t, resp.GetErrors(), 2)
	require.Contains(t, resp.GetErrors()["2"], aibridgedserver.ErrExpiredOrInvalidOAuthToken.Error())
	require.Contains(t, resp.GetErrors()["3"], aibridgedserver.ErrNoMCPConfigFound.Error())
}
// TestRecordInterception exercises RecordInterception through the shared
// testRecordMethod harness. The cases cover:
//   - request validation (interception ID, initiator ID, API key ID);
//   - the exact InsertAIBridgeInterceptionParams derived from a request
//     (provider name fallback, credential kind/hint, client session ID);
//   - database failures;
//   - thread lineage resolved from a correlating tool call ID.
func TestRecordInterception(t *testing.T) {
	t.Parallel()

	// A single string-valued metadata entry, and the JSON it is stored as.
	metadataProto := map[string]*anypb.Any{
		"key": mustMarshalAny(t, &structpb.Value{Kind: &structpb.Value_StringValue{StringValue: "value"}}),
	}
	metadataJSON := `{"key":"value"}`

	// parseIDs parses the interception and initiator UUIDs out of req,
	// recording (without aborting on) any parse failure.
	parseIDs := func(t *testing.T, req *proto.RecordInterceptionRequest) (uuid.UUID, uuid.UUID) {
		t.Helper()
		interceptionID, err := uuid.Parse(req.GetId())
		assert.NoError(t, err, "parse interception UUID")
		initiatorID, err := uuid.Parse(req.GetInitiatorId())
		assert.NoError(t, err, "parse interception initiator UUID")
		return interceptionID, initiatorID
	}

	// expectCopilotInsert expects one insert for a "copilot" interception
	// without metadata, whose stored provider name must equal providerName.
	expectCopilotInsert := func(t *testing.T, db *dbmock.MockStore, req *proto.RecordInterceptionRequest, providerName string) {
		t.Helper()
		interceptionID, initiatorID := parseIDs(t, req)
		startedAt := req.GetStartedAt().AsTime().UTC()
		db.EXPECT().InsertAIBridgeInterception(gomock.Any(), database.InsertAIBridgeInterceptionParams{
			ID:             interceptionID,
			APIKeyID:       sql.NullString{String: req.GetApiKeyId(), Valid: true},
			InitiatorID:    initiatorID,
			Provider:       "copilot",
			ProviderName:   providerName,
			Model:          req.GetModel(),
			Metadata:       json.RawMessage("{}"),
			StartedAt:      startedAt,
			CredentialKind: database.CredentialKindCentralized,
		}).Return(database.AIBridgeInterception{
			ID:           interceptionID,
			InitiatorID:  initiatorID,
			Provider:     "copilot",
			ProviderName: providerName,
			Model:        req.GetModel(),
			StartedAt:    startedAt,
		}, nil)
	}

	// lineageMatcher matches insert params carrying the given interception ID
	// and thread parent/root IDs; rootMsg labels the root assertion.
	lineageMatcher := func(t *testing.T, selfID uuid.UUID, parent, root uuid.NullUUID, rootMsg string) gomock.Matcher {
		return gomock.Cond(func(p database.InsertAIBridgeInterceptionParams) bool {
			return assert.Equal(t, selfID, p.ID, "ID") &&
				assert.Equal(t, parent, p.ThreadParentInterceptionID, "thread parent interception ID") &&
				assert.Equal(t, root, p.ThreadRootInterceptionID, rootMsg)
		})
	}

	cases := []testRecordMethodCase[*proto.RecordInterceptionRequest]{
		{
			name: "valid interception",
			request: &proto.RecordInterceptionRequest{
				Id:             uuid.NewString(),
				ApiKeyId:       uuid.NewString(),
				InitiatorId:    uuid.NewString(),
				Provider:       "anthropic",
				ProviderName:   "anthropic",
				Model:          "claude-4-opus",
				Metadata:       metadataProto,
				StartedAt:      timestamppb.Now(),
				CredentialKind: "byok",
				CredentialHint: "sk-a...efgh",
			},
			setupMocks: func(t *testing.T, db *dbmock.MockStore, req *proto.RecordInterceptionRequest) {
				interceptionID, initiatorID := parseIDs(t, req)
				apiKeyID := sql.NullString{String: req.GetApiKeyId(), Valid: true}
				startedAt := req.GetStartedAt().AsTime().UTC()
				db.EXPECT().InsertAIBridgeInterception(gomock.Any(), database.InsertAIBridgeInterceptionParams{
					ID:             interceptionID,
					APIKeyID:       apiKeyID,
					InitiatorID:    initiatorID,
					Provider:       req.GetProvider(),
					ProviderName:   req.GetProviderName(),
					Model:          req.GetModel(),
					Metadata:       json.RawMessage(metadataJSON),
					StartedAt:      startedAt,
					CredentialKind: database.CredentialKindByok,
					CredentialHint: "sk-a...efgh",
				}).Return(database.AIBridgeInterception{
					ID:             interceptionID,
					APIKeyID:       apiKeyID,
					InitiatorID:    initiatorID,
					Provider:       req.GetProvider(),
					ProviderName:   req.GetProviderName(),
					Model:          req.GetModel(),
					StartedAt:      startedAt,
					CredentialKind: database.CredentialKindByok,
					CredentialHint: "sk-a...efgh",
				}, nil)
			},
		},
		{
			name: "valid interception with client session ID",
			request: &proto.RecordInterceptionRequest{
				Id:              uuid.NewString(),
				ApiKeyId:        uuid.NewString(),
				InitiatorId:     uuid.NewString(),
				Provider:        "anthropic",
				Model:           "claude-4-opus",
				Metadata:        metadataProto,
				StartedAt:       timestamppb.Now(),
				ClientSessionId: ptr.Ref("session-abc-123"),
			},
			setupMocks: func(t *testing.T, db *dbmock.MockStore, req *proto.RecordInterceptionRequest) {
				interceptionID, initiatorID := parseIDs(t, req)
				apiKeyID := sql.NullString{String: req.GetApiKeyId(), Valid: true}
				startedAt := req.GetStartedAt().AsTime().UTC()
				sessionID := sql.NullString{String: "session-abc-123", Valid: true}
				db.EXPECT().InsertAIBridgeInterception(gomock.Any(), database.InsertAIBridgeInterceptionParams{
					ID:              interceptionID,
					APIKeyID:        apiKeyID,
					InitiatorID:     initiatorID,
					Provider:        req.GetProvider(),
					ProviderName:    req.GetProvider(),
					Model:           req.GetModel(),
					Metadata:        json.RawMessage(metadataJSON),
					StartedAt:       startedAt,
					ClientSessionID: sessionID,
					CredentialKind:  database.CredentialKindCentralized,
				}).Return(database.AIBridgeInterception{
					ID:              interceptionID,
					APIKeyID:        apiKeyID,
					InitiatorID:     initiatorID,
					Provider:        req.GetProvider(),
					ProviderName:    req.GetProvider(),
					Model:           req.GetModel(),
					StartedAt:       startedAt,
					ClientSessionID: sessionID,
				}, nil)
			},
		},
		{
			name: "empty client session ID treated as null",
			request: &proto.RecordInterceptionRequest{
				Id:              uuid.NewString(),
				ApiKeyId:        uuid.NewString(),
				InitiatorId:     uuid.NewString(),
				Provider:        "anthropic",
				Model:           "claude-4-opus",
				Metadata:        metadataProto,
				StartedAt:       timestamppb.Now(),
				ClientSessionId: ptr.Ref(""),
			},
			setupMocks: func(t *testing.T, db *dbmock.MockStore, req *proto.RecordInterceptionRequest) {
				interceptionID, initiatorID := parseIDs(t, req)
				apiKeyID := sql.NullString{String: req.GetApiKeyId(), Valid: true}
				startedAt := req.GetStartedAt().AsTime().UTC()
				db.EXPECT().InsertAIBridgeInterception(gomock.Any(), database.InsertAIBridgeInterceptionParams{
					ID:              interceptionID,
					APIKeyID:        apiKeyID,
					InitiatorID:     initiatorID,
					Provider:        req.GetProvider(),
					ProviderName:    req.GetProvider(),
					Model:           req.GetModel(),
					Metadata:        json.RawMessage(metadataJSON),
					StartedAt:       startedAt,
					ClientSessionID: sql.NullString{},
					CredentialKind:  database.CredentialKindCentralized,
				}).Return(database.AIBridgeInterception{
					ID:           interceptionID,
					APIKeyID:     apiKeyID,
					InitiatorID:  initiatorID,
					Provider:     req.GetProvider(),
					ProviderName: req.GetProvider(),
					Model:        req.GetModel(),
					StartedAt:    startedAt,
				}, nil)
			},
		},
		{
			name: "invalid interception ID",
			request: &proto.RecordInterceptionRequest{
				Id:          "not-a-uuid",
				InitiatorId: uuid.NewString(),
				ApiKeyId:    uuid.NewString(),
				Provider:    "anthropic",
				Model:       "claude-4-opus",
				StartedAt:   timestamppb.Now(),
			},
			expectedErr: "invalid interception ID",
		},
		{
			name: "invalid initiator ID",
			request: &proto.RecordInterceptionRequest{
				Id:          uuid.NewString(),
				ApiKeyId:    uuid.NewString(),
				InitiatorId: "not-a-uuid",
				Provider:    "anthropic",
				Model:       "claude-4-opus",
				StartedAt:   timestamppb.Now(),
			},
			expectedErr: "invalid initiator ID",
		},
		{
			name: "invalid interception no api key set",
			request: &proto.RecordInterceptionRequest{
				Id:          uuid.NewString(),
				InitiatorId: uuid.NewString(),
				Provider:    "anthropic",
				Model:       "claude-4-opus",
				Metadata:    metadataProto,
				StartedAt:   timestamppb.Now(),
			},
			expectedErr: "empty API key ID",
		},
		{
			name: "provider name differs from provider type",
			request: &proto.RecordInterceptionRequest{
				Id:           uuid.NewString(),
				ApiKeyId:     uuid.NewString(),
				InitiatorId:  uuid.NewString(),
				Provider:     "copilot",
				ProviderName: "copilot-business",
				Model:        "gpt-4o",
				StartedAt:    timestamppb.Now(),
			},
			setupMocks: func(t *testing.T, db *dbmock.MockStore, req *proto.RecordInterceptionRequest) {
				expectCopilotInsert(t, db, req, "copilot-business")
			},
		},
		{
			name: "empty provider name defaults to provider",
			request: &proto.RecordInterceptionRequest{
				Id:          uuid.NewString(),
				ApiKeyId:    uuid.NewString(),
				InitiatorId: uuid.NewString(),
				Provider:    "copilot",
				Model:       "gpt-4o",
				StartedAt:   timestamppb.Now(),
			},
			setupMocks: func(t *testing.T, db *dbmock.MockStore, req *proto.RecordInterceptionRequest) {
				expectCopilotInsert(t, db, req, "copilot")
			},
		},
		{
			name: "whitespace provider name defaults to provider",
			request: &proto.RecordInterceptionRequest{
				Id:           uuid.NewString(),
				ApiKeyId:     uuid.NewString(),
				InitiatorId:  uuid.NewString(),
				Provider:     "copilot",
				ProviderName: " ",
				Model:        "gpt-4o",
				StartedAt:    timestamppb.Now(),
			},
			setupMocks: func(t *testing.T, db *dbmock.MockStore, req *proto.RecordInterceptionRequest) {
				expectCopilotInsert(t, db, req, "copilot")
			},
		},
		{
			name: "database error",
			request: &proto.RecordInterceptionRequest{
				Id:          uuid.NewString(),
				ApiKeyId:    uuid.NewString(),
				InitiatorId: uuid.NewString(),
				Provider:    "anthropic",
				Model:       "claude-4-opus",
				StartedAt:   timestamppb.Now(),
			},
			setupMocks: func(_ *testing.T, db *dbmock.MockStore, _ *proto.RecordInterceptionRequest) {
				db.EXPECT().
					InsertAIBridgeInterception(gomock.Any(), gomock.Any()).
					Return(database.AIBridgeInterception{}, sql.ErrConnDone)
			},
			expectedErr: "start interception",
		},
		{
			name: "ok with parent correlation",
			request: &proto.RecordInterceptionRequest{
				Id:                    uuid.UUID{3}.String(),
				ApiKeyId:              uuid.NewString(),
				InitiatorId:           uuid.NewString(),
				Provider:              "anthropic",
				Model:                 "claude-4-opus",
				StartedAt:             timestamppb.Now(),
				CorrelatingToolCallId: ptr.Ref("call_abc"),
			},
			setupMocks: func(t *testing.T, db *dbmock.MockStore, req *proto.RecordInterceptionRequest) {
				selfID, err := uuid.Parse(req.GetId())
				assert.NoError(t, err, "parse self UUID")
				parentID, rootID := uuid.UUID{4}, uuid.UUID{5}
				db.EXPECT().
					GetAIBridgeInterceptionLineageByToolCallID(gomock.Any(), "call_abc").
					Return(database.GetAIBridgeInterceptionLineageByToolCallIDRow{
						ThreadParentID: parentID,
						ThreadRootID:   rootID,
					}, nil)
				db.EXPECT().
					InsertAIBridgeInterception(gomock.Any(), lineageMatcher(t, selfID,
						uuid.NullUUID{UUID: parentID, Valid: true},
						uuid.NullUUID{UUID: rootID, Valid: true},
						"thread root interception ID",
					)).
					Return(database.AIBridgeInterception{ID: selfID}, nil)
			},
		},
		{
			name: "no lineage",
			request: &proto.RecordInterceptionRequest{
				Id:                    uuid.UUID{3}.String(),
				ApiKeyId:              uuid.NewString(),
				InitiatorId:           uuid.NewString(),
				Provider:              "anthropic",
				Model:                 "claude-4-opus",
				StartedAt:             timestamppb.Now(),
				CorrelatingToolCallId: ptr.Ref("call_abc"),
			},
			setupMocks: func(t *testing.T, db *dbmock.MockStore, req *proto.RecordInterceptionRequest) {
				selfID, err := uuid.Parse(req.GetId())
				assert.NoError(t, err, "parse self UUID")
				db.EXPECT().
					GetAIBridgeInterceptionLineageByToolCallID(gomock.Any(), "call_abc").
					Return(database.GetAIBridgeInterceptionLineageByToolCallIDRow{}, sql.ErrNoRows)
				db.EXPECT().
					InsertAIBridgeInterception(gomock.Any(), lineageMatcher(t, selfID,
						uuid.NullUUID{},
						uuid.NullUUID{},
						"thread root interception ID",
					)).
					Return(database.AIBridgeInterception{ID: selfID}, nil)
			},
		},
		{
			// This should never happen since GetAIBridgeInterceptionLineageByToolCallID
			// always returns both, but still...
			name: "parent without root",
			request: &proto.RecordInterceptionRequest{
				Id:                    uuid.UUID{3}.String(),
				ApiKeyId:              uuid.NewString(),
				InitiatorId:           uuid.NewString(),
				Provider:              "anthropic",
				Model:                 "claude-4-opus",
				StartedAt:             timestamppb.Now(),
				CorrelatingToolCallId: ptr.Ref("call_abc"),
			},
			setupMocks: func(t *testing.T, db *dbmock.MockStore, req *proto.RecordInterceptionRequest) {
				selfID, err := uuid.Parse(req.GetId())
				assert.NoError(t, err, "parse self UUID")
				parentID := uuid.UUID{4}
				db.EXPECT().
					GetAIBridgeInterceptionLineageByToolCallID(gomock.Any(), "call_abc").
					Return(database.GetAIBridgeInterceptionLineageByToolCallIDRow{
						ThreadParentID: parentID,
					}, nil)
				db.EXPECT().
					InsertAIBridgeInterception(gomock.Any(), lineageMatcher(t, selfID,
						uuid.NullUUID{UUID: parentID, Valid: true},
						uuid.NullUUID{},
						"thread root interception ID not expected",
					)).
					Return(database.AIBridgeInterception{ID: selfID}, nil)
			},
		},
		{
			name: "ok no parent found",
			request: &proto.RecordInterceptionRequest{
				Id:                    uuid.UUID{5}.String(),
				ApiKeyId:              uuid.NewString(),
				InitiatorId:           uuid.NewString(),
				Provider:              "anthropic",
				Model:                 "claude-4-opus",
				StartedAt:             timestamppb.Now(),
				CorrelatingToolCallId: ptr.Ref("call_orphan"),
			},
			setupMocks: func(t *testing.T, db *dbmock.MockStore, req *proto.RecordInterceptionRequest) {
				selfID, err := uuid.Parse(req.GetId())
				assert.NoError(t, err, "parse self UUID")
				db.EXPECT().
					GetAIBridgeInterceptionLineageByToolCallID(gomock.Any(), "call_orphan").
					Return(database.GetAIBridgeInterceptionLineageByToolCallIDRow{}, sql.ErrNoRows)
				db.EXPECT().
					InsertAIBridgeInterception(gomock.Any(), lineageMatcher(t, selfID,
						uuid.NullUUID{},
						uuid.NullUUID{},
						"thread root interception ID",
					)).
					Return(database.AIBridgeInterception{ID: selfID}, nil)
			},
		},
	}

	testRecordMethod(t,
		func(srv *aibridgedserver.Server, ctx context.Context, req *proto.RecordInterceptionRequest) (*proto.RecordInterceptionResponse, error) {
			return srv.RecordInterception(ctx, req)
		},
		cases,
	)
}
func TestRecordInterceptionEnded(t *testing.T) {
t.Parallel()
testRecordMethod(t,
func(srv *aibridgedserver.Server, ctx context.Context, req *proto.RecordInterceptionEndedRequest) (*proto.RecordInterceptionEndedResponse, error) {
return srv.RecordInterceptionEnded(ctx, req)
},
[]testRecordMethodCase[*proto.RecordInterceptionEndedRequest]{
{
name: "ok",
request: &proto.RecordInterceptionEndedRequest{
Id: uuid.UUID{1}.String(),
EndedAt: timestamppb.Now(),
},
setupMocks: func(t *testing.T, db *dbmock.MockStore, req *proto.RecordInterceptionEndedRequest) {
interceptionID, err := uuid.Parse(req.GetId())
assert.NoError(t, err, "parse interception UUID")
db.EXPECT().UpdateAIBridgeInterceptionEnded(gomock.Any(), database.UpdateAIBridgeInterceptionEndedParams{
ID: interceptionID,
EndedAt: req.EndedAt.AsTime(),
}).Return(database.AIBridgeInterception{
ID: interceptionID,
InitiatorID: uuid.UUID{2},
Provider: "prov",
Model: "mod",
StartedAt: time.Now(),
EndedAt: sql.NullTime{Time: req.EndedAt.AsTime(), Valid: true},
}, nil)
},
},
{
name: "bad_uuid_error",
request: &proto.RecordInterceptionEndedRequest{
Id: "this-is-not-uuid",
},
setupMocks: func(t *testing.T, db *dbmock.MockStore, req *proto.RecordInterceptionEndedRequest) {},
expectedErr: "invalid interception ID",
},
{
name: "database_error",
request: &proto.RecordInterceptionEndedRequest{
Id: uuid.UUID{1}.String(),
EndedAt: timestamppb.Now(),
},
setupMocks: func(t *testing.T, db *dbmock.MockStore, req *proto.RecordInterceptionEndedRequest) {
db.EXPECT().UpdateAIBridgeInterceptionEnded(gomock.Any(), gomock.Any()).Return(database.AIBridgeInterception{}, sql.ErrConnDone)
},
expectedErr: "end interception: " + sql.ErrConnDone.Error(),
},
},
)
}
// TestRecordTokenUsage verifies that Server.RecordTokenUsage persists every
// token counter plus the JSON-encoded metadata, rejects a malformed
// interception ID, and reports store failures.
func TestRecordTokenUsage(t *testing.T) {
	t.Parallel()

	const metadataJSON = `{"key":"value"}`
	metadataProto := map[string]*anypb.Any{
		"key": mustMarshalAny(t, &structpb.Value{Kind: &structpb.Value_StringValue{StringValue: "value"}}),
	}

	call := func(srv *aibridgedserver.Server, ctx context.Context, req *proto.RecordTokenUsageRequest) (*proto.RecordTokenUsageResponse, error) {
		return srv.RecordTokenUsage(ctx, req)
	}

	cases := []testRecordMethodCase[*proto.RecordTokenUsageRequest]{
		{
			name: "valid token usage",
			request: &proto.RecordTokenUsageRequest{
				InterceptionId:        uuid.NewString(),
				MsgId:                 "msg_123",
				InputTokens:           100,
				OutputTokens:          200,
				CacheReadInputTokens:  50,
				CacheWriteInputTokens: 10,
				Metadata:              metadataProto,
				CreatedAt:             timestamppb.Now(),
			},
			setupMocks: func(t *testing.T, db *dbmock.MockStore, req *proto.RecordTokenUsageRequest) {
				id, err := uuid.Parse(req.GetInterceptionId())
				assert.NoError(t, err, "parse interception UUID")

				// Each assert reports its own failure; && stops at the first mismatch.
				matches := gomock.Cond(func(p database.InsertAIBridgeTokenUsageParams) bool {
					return assert.NotEqual(t, uuid.Nil, p.ID, "ID") &&
						assert.Equal(t, id, p.InterceptionID, "interception ID") &&
						assert.Equal(t, req.GetMsgId(), p.ProviderResponseID, "provider response ID") &&
						assert.Equal(t, req.GetInputTokens(), p.InputTokens, "input tokens") &&
						assert.Equal(t, req.GetOutputTokens(), p.OutputTokens, "output tokens") &&
						assert.Equal(t, req.GetCacheReadInputTokens(), p.CacheReadInputTokens, "cache read input tokens") &&
						assert.Equal(t, req.GetCacheWriteInputTokens(), p.CacheWriteInputTokens, "cache write input tokens") &&
						assert.JSONEq(t, metadataJSON, string(p.Metadata), "metadata") &&
						assert.WithinDuration(t, req.GetCreatedAt().AsTime(), p.CreatedAt, time.Second, "created at")
				})
				db.EXPECT().InsertAIBridgeTokenUsage(gomock.Any(), matches).Return(database.AIBridgeTokenUsage{
					ID:                    uuid.New(),
					InterceptionID:        id,
					ProviderResponseID:    req.GetMsgId(),
					InputTokens:           req.GetInputTokens(),
					OutputTokens:          req.GetOutputTokens(),
					CacheReadInputTokens:  req.GetCacheReadInputTokens(),
					CacheWriteInputTokens: req.GetCacheWriteInputTokens(),
					Metadata: pqtype.NullRawMessage{
						RawMessage: json.RawMessage(metadataJSON),
						Valid:      true,
					},
					CreatedAt: req.GetCreatedAt().AsTime(),
				}, nil)
			},
		},
		{
			name: "invalid interception ID",
			request: &proto.RecordTokenUsageRequest{
				InterceptionId: "not-a-uuid",
				MsgId:          "msg_123",
				InputTokens:    100,
				OutputTokens:   200,
				CreatedAt:      timestamppb.Now(),
			},
			expectedErr: "failed to parse interception_id",
		},
		{
			name: "database error",
			request: &proto.RecordTokenUsageRequest{
				InterceptionId: uuid.NewString(),
				MsgId:          "msg_123",
				InputTokens:    100,
				OutputTokens:   200,
				CreatedAt:      timestamppb.Now(),
			},
			setupMocks: func(t *testing.T, db *dbmock.MockStore, req *proto.RecordTokenUsageRequest) {
				db.EXPECT().InsertAIBridgeTokenUsage(gomock.Any(), gomock.Any()).Return(database.AIBridgeTokenUsage{}, sql.ErrConnDone)
			},
			expectedErr: "insert token usage",
		},
	}

	testRecordMethod(t, call, cases)
}
// TestRecordPromptUsage verifies that Server.RecordPromptUsage stores the user
// prompt with its JSON-encoded metadata, rejects a malformed interception ID,
// and reports store failures.
func TestRecordPromptUsage(t *testing.T) {
	t.Parallel()

	const metadataJSON = `{"key":"value"}`
	metadataProto := map[string]*anypb.Any{
		"key": mustMarshalAny(t, &structpb.Value{Kind: &structpb.Value_StringValue{StringValue: "value"}}),
	}

	call := func(srv *aibridgedserver.Server, ctx context.Context, req *proto.RecordPromptUsageRequest) (*proto.RecordPromptUsageResponse, error) {
		return srv.RecordPromptUsage(ctx, req)
	}

	cases := []testRecordMethodCase[*proto.RecordPromptUsageRequest]{
		{
			name: "valid prompt usage",
			request: &proto.RecordPromptUsageRequest{
				InterceptionId: uuid.NewString(),
				MsgId:          "msg_123",
				Prompt:         "yo",
				Metadata:       metadataProto,
				CreatedAt:      timestamppb.Now(),
			},
			setupMocks: func(t *testing.T, db *dbmock.MockStore, req *proto.RecordPromptUsageRequest) {
				id, err := uuid.Parse(req.GetInterceptionId())
				assert.NoError(t, err, "parse interception UUID")

				// Each assert reports its own failure; && stops at the first mismatch.
				matches := gomock.Cond(func(p database.InsertAIBridgeUserPromptParams) bool {
					return assert.NotEqual(t, uuid.Nil, p.ID, "ID") &&
						assert.Equal(t, id, p.InterceptionID, "interception ID") &&
						assert.Equal(t, req.GetMsgId(), p.ProviderResponseID, "provider response ID") &&
						assert.Equal(t, req.GetPrompt(), p.Prompt, "prompt") &&
						assert.JSONEq(t, metadataJSON, string(p.Metadata), "metadata") &&
						assert.WithinDuration(t, req.GetCreatedAt().AsTime(), p.CreatedAt, time.Second, "created at")
				})
				db.EXPECT().InsertAIBridgeUserPrompt(gomock.Any(), matches).Return(database.AIBridgeUserPrompt{
					ID:                 uuid.New(),
					InterceptionID:     id,
					ProviderResponseID: req.GetMsgId(),
					Prompt:             req.GetPrompt(),
					Metadata: pqtype.NullRawMessage{
						RawMessage: json.RawMessage(metadataJSON),
						Valid:      true,
					},
					CreatedAt: req.GetCreatedAt().AsTime(),
				}, nil)
			},
		},
		{
			name: "invalid interception ID",
			request: &proto.RecordPromptUsageRequest{
				InterceptionId: "not-a-uuid",
				MsgId:          "msg_123",
				Prompt:         "yo",
				CreatedAt:      timestamppb.Now(),
			},
			expectedErr: "failed to parse interception_id",
		},
		{
			name: "database error",
			request: &proto.RecordPromptUsageRequest{
				InterceptionId: uuid.NewString(),
				MsgId:          "msg_123",
				Prompt:         "yo",
				CreatedAt:      timestamppb.Now(),
			},
			setupMocks: func(t *testing.T, db *dbmock.MockStore, req *proto.RecordPromptUsageRequest) {
				db.EXPECT().InsertAIBridgeUserPrompt(gomock.Any(), gomock.Any()).Return(database.AIBridgeUserPrompt{}, sql.ErrConnDone)
			},
			expectedErr: "insert user prompt",
		},
	}

	testRecordMethod(t, call, cases)
}
// TestRecordToolUsage verifies that Server.RecordToolUsage stores the tool call
// (including its optional server URL and invocation error) with JSON-encoded
// metadata, rejects a malformed interception ID, and reports store failures.
func TestRecordToolUsage(t *testing.T) {
	t.Parallel()

	const metadataJSON = `{"key":123.45}`
	metadataProto := map[string]*anypb.Any{
		"key": mustMarshalAny(t, &structpb.Value{Kind: &structpb.Value_NumberValue{NumberValue: 123.45}}),
	}

	// nullString maps an optional proto string onto its nullable column form.
	nullString := func(s *string) sql.NullString {
		if s == nil {
			return sql.NullString{}
		}
		return sql.NullString{String: *s, Valid: true}
	}

	call := func(srv *aibridgedserver.Server, ctx context.Context, req *proto.RecordToolUsageRequest) (*proto.RecordToolUsageResponse, error) {
		return srv.RecordToolUsage(ctx, req)
	}

	cases := []testRecordMethodCase[*proto.RecordToolUsageRequest]{
		{
			name: "valid tool usage with all fields",
			request: &proto.RecordToolUsageRequest{
				InterceptionId:  uuid.NewString(),
				MsgId:           "msg_123",
				ToolCallId:      "call_xyz",
				ServerUrl:       ptr.Ref("https://api.example.com"),
				Tool:            "read_file",
				Input:           `{"path": "/etc/hosts"}`,
				Injected:        false,
				InvocationError: ptr.Ref("permission denied"),
				Metadata:        metadataProto,
				CreatedAt:       timestamppb.Now(),
			},
			setupMocks: func(t *testing.T, db *dbmock.MockStore, req *proto.RecordToolUsageRequest) {
				id, err := uuid.Parse(req.GetInterceptionId())
				assert.NoError(t, err, "parse interception UUID")

				serverURL := nullString(req.ServerUrl)
				invocationErr := nullString(req.InvocationError)

				// Each assert reports its own failure; && stops at the first mismatch.
				matches := gomock.Cond(func(p database.InsertAIBridgeToolUsageParams) bool {
					return assert.NotEqual(t, uuid.Nil, p.ID, "ID") &&
						assert.Equal(t, id, p.InterceptionID, "interception ID") &&
						assert.Equal(t, req.GetMsgId(), p.ProviderResponseID, "provider response ID") &&
						assert.Equal(t, sql.NullString{String: "call_xyz", Valid: true}, p.ProviderToolCallID, "provider tool call ID") &&
						assert.Equal(t, req.GetTool(), p.Tool, "tool") &&
						assert.Equal(t, serverURL, p.ServerUrl, "server URL") &&
						assert.Equal(t, req.GetInput(), p.Input, "input") &&
						assert.Equal(t, req.GetInjected(), p.Injected, "injected") &&
						assert.Equal(t, invocationErr, p.InvocationError, "invocation error") &&
						assert.JSONEq(t, metadataJSON, string(p.Metadata), "metadata") &&
						assert.WithinDuration(t, req.GetCreatedAt().AsTime(), p.CreatedAt, time.Second, "created at")
				})
				db.EXPECT().InsertAIBridgeToolUsage(gomock.Any(), matches).Return(database.AIBridgeToolUsage{
					ID:                 uuid.New(),
					InterceptionID:     id,
					ProviderResponseID: req.GetMsgId(),
					Tool:               req.GetTool(),
					ServerUrl:          serverURL,
					Input:              req.GetInput(),
					Injected:           req.GetInjected(),
					InvocationError:    invocationErr,
					Metadata: pqtype.NullRawMessage{
						RawMessage: json.RawMessage(metadataJSON),
						Valid:      true,
					},
					CreatedAt: req.GetCreatedAt().AsTime(),
				}, nil)
			},
		},
		{
			name: "invalid interception ID",
			request: &proto.RecordToolUsageRequest{
				InterceptionId: "not-a-uuid",
				MsgId:          "msg_123",
				Tool:           "read_file",
				Input:          `{"path": "/etc/hosts"}`,
				CreatedAt:      timestamppb.Now(),
			},
			expectedErr: "failed to parse interception_id",
		},
		{
			name: "database error",
			request: &proto.RecordToolUsageRequest{
				InterceptionId: uuid.NewString(),
				MsgId:          "msg_123",
				Tool:           "read_file",
				Input:          `{"path": "/etc/hosts"}`,
				CreatedAt:      timestamppb.Now(),
			},
			setupMocks: func(t *testing.T, db *dbmock.MockStore, req *proto.RecordToolUsageRequest) {
				db.EXPECT().InsertAIBridgeToolUsage(gomock.Any(), gomock.Any()).Return(database.AIBridgeToolUsage{}, sql.ErrConnDone)
			},
			expectedErr: "insert tool usage",
		},
	}

	testRecordMethod(t, call, cases)
}
// TestRecordModelThought verifies that Server.RecordModelThought stores the
// thought content with its JSON-encoded metadata, rejects a malformed
// interception ID, and reports store failures.
func TestRecordModelThought(t *testing.T) {
	t.Parallel()

	const metadataJSON = `{"key":"value"}`
	metadataProto := map[string]*anypb.Any{
		"key": mustMarshalAny(t, &structpb.Value{Kind: &structpb.Value_StringValue{StringValue: "value"}}),
	}

	call := func(srv *aibridgedserver.Server, ctx context.Context, req *proto.RecordModelThoughtRequest) (*proto.RecordModelThoughtResponse, error) {
		return srv.RecordModelThought(ctx, req)
	}

	cases := []testRecordMethodCase[*proto.RecordModelThoughtRequest]{
		{
			name: "valid model thought",
			request: &proto.RecordModelThoughtRequest{
				InterceptionId: uuid.NewString(),
				Content:        "I should list the files.",
				Metadata:       metadataProto,
				CreatedAt:      timestamppb.Now(),
			},
			setupMocks: func(t *testing.T, db *dbmock.MockStore, req *proto.RecordModelThoughtRequest) {
				id, err := uuid.Parse(req.GetInterceptionId())
				assert.NoError(t, err, "parse interception UUID")

				// Each assert reports its own failure; && stops at the first mismatch.
				matches := gomock.Cond(func(p database.InsertAIBridgeModelThoughtParams) bool {
					return assert.Equal(t, id, p.InterceptionID, "interception ID") &&
						assert.Equal(t, "I should list the files.", p.Content, "content") &&
						assert.JSONEq(t, metadataJSON, string(p.Metadata), "metadata")
				})
				db.EXPECT().InsertAIBridgeModelThought(gomock.Any(), matches).Return(database.AIBridgeModelThought{
					InterceptionID: id,
					Content:        "I should list the files.",
					Metadata: pqtype.NullRawMessage{
						RawMessage: json.RawMessage(metadataJSON),
						Valid:      true,
					},
				}, nil)
			},
		},
		{
			name: "invalid interception ID",
			request: &proto.RecordModelThoughtRequest{
				InterceptionId: "not-a-uuid",
				Content:        "thinking...",
				CreatedAt:      timestamppb.Now(),
			},
			expectedErr: "failed to parse interception_id",
		},
		{
			name: "database error",
			request: &proto.RecordModelThoughtRequest{
				InterceptionId: uuid.NewString(),
				Content:        "thinking...",
				CreatedAt:      timestamppb.Now(),
			},
			setupMocks: func(t *testing.T, db *dbmock.MockStore, req *proto.RecordModelThoughtRequest) {
				db.EXPECT().InsertAIBridgeModelThought(gomock.Any(), gomock.Any()).Return(database.AIBridgeModelThought{}, sql.ErrConnDone)
			},
			expectedErr: "insert model thought",
		},
	}

	testRecordMethod(t, call, cases)
}
// testRecordMethodCase is one table-driven case consumed by testRecordMethod.
type testRecordMethodCase[Req any] struct {
	// name is used as the subtest name.
	name string
	// request is passed unchanged to the method under test.
	request Req
	// setupMocks, when non-nil, registers expectations on the mock store for
	// request before the server under test is constructed.
	setupMocks func(t *testing.T, db *dbmock.MockStore, req Req)
	// expectedErr, when non-empty, must appear as a substring of the returned
	// error; when empty the call must succeed with a non-nil response.
	expectedErr string
}
// testRecordMethod drives the shared table-test pattern for every Record*
// method. Each case runs as a parallel subtest that gets its own mock store
// and Server, registers the case's expectations, invokes callMethod, and then
// checks either the expected error substring or a successful, non-nil response.
func testRecordMethod[Req any, Resp any](
	t *testing.T,
	callMethod func(srv *aibridgedserver.Server, ctx context.Context, req Req) (Resp, error),
	cases []testRecordMethodCase[Req],
) {
	t.Helper()

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			store := dbmock.NewMockStore(gomock.NewController(t))
			logger := testutil.Logger(t)
			if setup := tc.setupMocks; setup != nil {
				setup(t, store, tc.request)
			}

			ctx := testutil.Context(t, testutil.WaitLong)
			srv, err := aibridgedserver.NewServer(ctx, store, logger, "/", codersdk.AIBridgeConfig{}, nil, requiredExperiments, agplaiseats.Noop{})
			require.NoError(t, err)

			resp, err := callMethod(srv, ctx, tc.request)
			if tc.expectedErr == "" {
				require.NoError(t, err, "Unexpected error for test case: %s", tc.name)
				require.NotNil(t, resp)
				return
			}
			require.Error(t, err, "Expected error for test case: %s", tc.name)
			require.Contains(t, err.Error(), tc.expectedErr)
		})
	}
}
// Helper functions.

// mustMarshalAny packs msg into an anypb.Any and fails the test immediately if
// packing returns an error.
func mustMarshalAny(t *testing.T, msg protobufproto.Message) *anypb.Any {
	t.Helper()
	packed, err := anypb.New(msg)
	require.NoError(t, err)
	return packed
}
// logLine represents a parsed JSON log entry.
type logLine struct {
	Msg    string         `json:"msg"`
	Level  string         `json:"level"`
	Fields map[string]any `json:"fields"`
}

// parseLogLines drains buf and decodes each newline-delimited JSON log entry,
// returning them in order. Lines that are blank or not valid JSON are skipped.
//
// It reads with bufio.Reader.ReadBytes rather than bufio.Scanner on purpose:
// Scanner stops with ErrTooLong on any line longer than
// bufio.MaxScanTokenSize (64 KiB). That would silently drop the oversized
// entry and every entry after it, e.g. a log line carrying a large prompt or
// tool input.
func parseLogLines(buf *bytes.Buffer) []logLine {
	var lines []logLine
	reader := bufio.NewReader(buf)
	for {
		raw, err := reader.ReadBytes('\n')
		// raw may hold a final line without a trailing newline even when err
		// is non-nil, so decode it before checking err. json.Unmarshal rejects
		// empty input, so blank lines fall through here.
		var line logLine
		if json.Unmarshal(raw, &line) == nil {
			lines = append(lines, line)
		}
		if err != nil {
			// Reading from a bytes.Buffer only fails with io.EOF, i.e. the
			// buffer has been fully consumed.
			return lines
		}
	}
}
// getLogLinesWithMessage filters lines down to the entries whose Msg equals
// msg, preserving their original order. It returns nil when nothing matches.
func getLogLinesWithMessage(lines []logLine, msg string) []logLine {
	var matched []logLine
	for i := range lines {
		if lines[i].Msg != msg {
			continue
		}
		matched = append(matched, lines[i])
	}
	return matched
}
// TestStructuredLogging verifies that each Record* method emits a structured
// log line tagged with aibridgedserver.InterceptionLogMarker when structured
// logging is enabled, emits nothing when it is disabled, and still logs when
// the underlying store call fails.
func TestStructuredLogging(t *testing.T) {
	t.Parallel()
	metadataProto := map[string]*anypb.Any{
		"key": mustMarshalAny(t, &structpb.Value{Kind: &structpb.Value_StringValue{StringValue: "value"}}),
	}
	type testCase struct {
		name              string
		structuredLogging bool
		// expectedErr, when non-nil, must be reported by recordFn; its message
		// is expected to appear in the returned error.
		expectedErr error
		setupMocks  func(db *dbmock.MockStore, interceptionID uuid.UUID)
		recordFn    func(srv *aibridgedserver.Server, ctx context.Context, interceptionID uuid.UUID) error
		// expectedFields is nil when no log output at all is expected.
		expectedFields map[string]any
	}
	interceptionID := uuid.UUID{1}
	initiatorID := uuid.UUID{2}
	threadParentID := uuid.UUID{3}
	threadRootID := uuid.UUID{4}
	toolCallID := "my-tool-call"
	sessionID := "some-session-id"
	cases := []testCase{
		{
			name:              "RecordInterception_logs_when_enabled",
			structuredLogging: true,
			setupMocks: func(db *dbmock.MockStore, intcID uuid.UUID) {
				db.EXPECT().GetAIBridgeInterceptionLineageByToolCallID(gomock.Any(), toolCallID).Return(database.GetAIBridgeInterceptionLineageByToolCallIDRow{
					ThreadParentID: threadParentID,
					ThreadRootID:   threadRootID,
				}, nil)
				db.EXPECT().InsertAIBridgeInterception(gomock.Any(), gomock.Any()).Return(database.AIBridgeInterception{
					ID:             intcID,
					InitiatorID:    initiatorID,
					ThreadParentID: uuid.NullUUID{UUID: threadParentID, Valid: true},
					ThreadRootID:   uuid.NullUUID{UUID: threadRootID, Valid: true},
				}, nil)
			},
			recordFn: func(srv *aibridgedserver.Server, ctx context.Context, intcID uuid.UUID) error {
				_, err := srv.RecordInterception(ctx, &proto.RecordInterceptionRequest{
					Id:                    intcID.String(),
					ApiKeyId:              "api-key-123",
					InitiatorId:           initiatorID.String(),
					Provider:              "anthropic",
					Model:                 "claude-4-opus",
					Metadata:              metadataProto,
					StartedAt:             timestamppb.Now(),
					CorrelatingToolCallId: ptr.Ref(toolCallID),
					ClientSessionId:       ptr.Ref(sessionID),
				})
				return err
			},
			expectedFields: map[string]any{
				"record_type":              "interception_start",
				"interception_id":          interceptionID.String(),
				"initiator_id":             initiatorID.String(),
				"provider":                 "anthropic",
				"model":                    "claude-4-opus",
				"correlating_tool_call_id": toolCallID,
				"thread_parent_id":         threadParentID.String(),
				"thread_root_id":           threadRootID.String(),
				"client_session_id":        sessionID,
			},
		},
		{
			name:              "RecordInterception_does_not_log_when_disabled",
			structuredLogging: false,
			setupMocks: func(db *dbmock.MockStore, intcID uuid.UUID) {
				db.EXPECT().InsertAIBridgeInterception(gomock.Any(), gomock.Any()).Return(database.AIBridgeInterception{
					ID:          intcID,
					InitiatorID: initiatorID,
				}, nil)
			},
			recordFn: func(srv *aibridgedserver.Server, ctx context.Context, intcID uuid.UUID) error {
				_, err := srv.RecordInterception(ctx, &proto.RecordInterceptionRequest{
					Id:          intcID.String(),
					ApiKeyId:    "api-key-123",
					InitiatorId: initiatorID.String(),
					Provider:    "anthropic",
					Model:       "claude-4-opus",
					StartedAt:   timestamppb.Now(),
				})
				return err
			},
			expectedFields: nil, // No log expected.
		},
		{
			name:              "RecordInterception_log_on_db_error",
			structuredLogging: true,
			expectedErr:       sql.ErrConnDone,
			setupMocks: func(db *dbmock.MockStore, intcID uuid.UUID) {
				db.EXPECT().InsertAIBridgeInterception(gomock.Any(), gomock.Any()).Return(database.AIBridgeInterception{}, sql.ErrConnDone)
			},
			recordFn: func(srv *aibridgedserver.Server, ctx context.Context, intcID uuid.UUID) error {
				_, err := srv.RecordInterception(ctx, &proto.RecordInterceptionRequest{
					Id:          intcID.String(),
					ApiKeyId:    "api-key-123",
					InitiatorId: initiatorID.String(),
					Provider:    "anthropic",
					Model:       "claude-4-opus",
					StartedAt:   timestamppb.Now(),
				})
				return err
			},
			// Even though the database call errored, we must still write the logs.
			expectedFields: map[string]any{
				"record_type":     "interception_start",
				"interception_id": interceptionID.String(),
				"initiator_id":    initiatorID.String(),
				"provider":        "anthropic",
				"model":           "claude-4-opus",
			},
		},
		{
			name:              "RecordInterceptionEnded_logs_when_enabled",
			structuredLogging: true,
			setupMocks: func(db *dbmock.MockStore, intcID uuid.UUID) {
				db.EXPECT().UpdateAIBridgeInterceptionEnded(gomock.Any(), gomock.Any()).Return(database.AIBridgeInterception{
					ID: intcID,
				}, nil)
			},
			recordFn: func(srv *aibridgedserver.Server, ctx context.Context, intcID uuid.UUID) error {
				_, err := srv.RecordInterceptionEnded(ctx, &proto.RecordInterceptionEndedRequest{
					Id:      intcID.String(),
					EndedAt: timestamppb.Now(),
				})
				return err
			},
			expectedFields: map[string]any{
				"record_type":     "interception_end",
				"interception_id": interceptionID.String(),
			},
		},
		{
			name:              "RecordTokenUsage_logs_when_enabled",
			structuredLogging: true,
			setupMocks: func(db *dbmock.MockStore, intcID uuid.UUID) {
				db.EXPECT().InsertAIBridgeTokenUsage(gomock.Any(), gomock.Any()).Return(database.AIBridgeTokenUsage{
					ID:             uuid.New(),
					InterceptionID: intcID,
				}, nil)
			},
			recordFn: func(srv *aibridgedserver.Server, ctx context.Context, intcID uuid.UUID) error {
				_, err := srv.RecordTokenUsage(ctx, &proto.RecordTokenUsageRequest{
					InterceptionId:        intcID.String(),
					MsgId:                 "msg_123",
					InputTokens:           100,
					OutputTokens:          200,
					CacheReadInputTokens:  50,
					CacheWriteInputTokens: 10,
					Metadata:              metadataProto,
					CreatedAt:             timestamppb.Now(),
				})
				return err
			},
			expectedFields: map[string]any{
				"record_type":              "token_usage",
				"interception_id":          interceptionID.String(),
				"input_tokens":             float64(100), // JSON numbers are float64.
				"output_tokens":            float64(200),
				"cache_read_input_tokens":  float64(50),
				"cache_write_input_tokens": float64(10),
			},
		},
		{
			name:              "RecordPromptUsage_logs_when_enabled",
			structuredLogging: true,
			setupMocks: func(db *dbmock.MockStore, intcID uuid.UUID) {
				db.EXPECT().InsertAIBridgeUserPrompt(gomock.Any(), gomock.Any()).Return(database.AIBridgeUserPrompt{
					ID:             uuid.New(),
					InterceptionID: intcID,
				}, nil)
			},
			recordFn: func(srv *aibridgedserver.Server, ctx context.Context, intcID uuid.UUID) error {
				_, err := srv.RecordPromptUsage(ctx, &proto.RecordPromptUsageRequest{
					InterceptionId: intcID.String(),
					MsgId:          "msg_123",
					Prompt:         "Hello, Claude!",
					Metadata:       metadataProto,
					CreatedAt:      timestamppb.Now(),
				})
				return err
			},
			expectedFields: map[string]any{
				"record_type":     "prompt_usage",
				"interception_id": interceptionID.String(),
				"prompt":          "Hello, Claude!",
			},
		},
		{
			name:              "RecordToolUsage_logs_when_enabled",
			structuredLogging: true,
			setupMocks: func(db *dbmock.MockStore, intcID uuid.UUID) {
				db.EXPECT().InsertAIBridgeToolUsage(gomock.Any(), gomock.Any()).Return(database.AIBridgeToolUsage{
					ID:             uuid.New(),
					InterceptionID: intcID,
				}, nil)
			},
			recordFn: func(srv *aibridgedserver.Server, ctx context.Context, intcID uuid.UUID) error {
				_, err := srv.RecordToolUsage(ctx, &proto.RecordToolUsageRequest{
					InterceptionId:  intcID.String(),
					MsgId:           "msg_123",
					ServerUrl:       ptr.Ref("https://api.example.com"),
					Tool:            "read_file",
					Input:           `{"path": "/etc/hosts"}`,
					Injected:        true,
					InvocationError: ptr.Ref("permission denied"),
					Metadata:        metadataProto,
					CreatedAt:       timestamppb.Now(),
				})
				return err
			},
			expectedFields: map[string]any{
				"record_type":      "tool_usage",
				"interception_id":  interceptionID.String(),
				"tool":             "read_file",
				"input":            `{"path": "/etc/hosts"}`,
				"injected":         true,
				"invocation_error": "permission denied",
			},
		},
		{
			name:              "RecordModelThought_logs_when_enabled",
			structuredLogging: true,
			setupMocks: func(db *dbmock.MockStore, intcID uuid.UUID) {
				db.EXPECT().InsertAIBridgeModelThought(gomock.Any(), gomock.Any()).Return(database.AIBridgeModelThought{
					InterceptionID: intcID,
				}, nil)
			},
			recordFn: func(srv *aibridgedserver.Server, ctx context.Context, intcID uuid.UUID) error {
				_, err := srv.RecordModelThought(ctx, &proto.RecordModelThoughtRequest{
					InterceptionId: intcID.String(),
					Content:        "I need to list the files.",
					Metadata:       metadataProto,
					CreatedAt:      timestamppb.Now(),
				})
				return err
			},
			expectedFields: map[string]any{
				"record_type":     "model_thought",
				"interception_id": interceptionID.String(),
				"content":         "I need to list the files.",
			},
		},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			ctrl := gomock.NewController(t)
			db := dbmock.NewMockStore(ctrl)
			buf := &bytes.Buffer{}
			logger := slog.Make(slogjson.Sink(buf)).Leveled(slog.LevelDebug)
			tc.setupMocks(db, interceptionID)
			ctx := testutil.Context(t, testutil.WaitLong)
			srv, err := aibridgedserver.NewServer(ctx, db, logger, "/", codersdk.AIBridgeConfig{
				StructuredLogging: serpent.Bool(tc.structuredLogging),
			}, nil, requiredExperiments, agplaiseats.Noop{})
			require.NoError(t, err)
			err = tc.recordFn(srv, ctx, interceptionID)
			if tc.expectedErr != nil {
				// Check the actual failure surfaced, not just that some error did.
				require.ErrorContains(t, err, tc.expectedErr.Error())
			} else {
				require.NoError(t, err)
			}
			lines := parseLogLines(buf)
			if tc.expectedFields == nil {
				// Structured logging is disabled, so nothing at all may be
				// logged. Note the DB-error case still expects logs and so
				// takes the branch below.
				require.Empty(t, lines)
			} else {
				matchedLines := getLogLinesWithMessage(lines, aibridgedserver.InterceptionLogMarker)
				require.GreaterOrEqual(t, len(matchedLines), 1, "expected at least 1 log line(s) with message %q", aibridgedserver.InterceptionLogMarker)
				fields := matchedLines[0].Fields
				for key, expected := range tc.expectedFields {
					require.Equal(t, expected, fields[key], "field %q mismatch", key)
				}
			}
		})
	}
}
// TestInferredThreadsByToolCalls verifies that a chain of interceptions linked
// via tool call IDs correctly propagates thread_parent_id and thread_root_id.
//
// The chain is: A → B → C
//   - A is the root (no parent, no root)
//   - B correlates via a tool call recorded by A (parent=A, root=A)
//   - C correlates via a tool call recorded by B (parent=B, root=A)
func TestInferredThreadsByToolCalls(t *testing.T) {
	t.Parallel()

	db, _ := dbtestutil.NewDB(t)
	ctx := testutil.Context(t, testutil.WaitLong)
	logger := testutil.Logger(t)
	user := dbgen.User(t, db, database.User{})

	srv, err := aibridgedserver.NewServer(ctx, db, logger, "/", codersdk.AIBridgeConfig{}, nil, requiredExperiments, agplaiseats.Noop{})
	require.NoError(t, err)

	// startInterception records interception id for user, optionally
	// correlated with an earlier tool call, then reads it back and returns the
	// thread parent/root the server stored.
	startInterception := func(id uuid.UUID, correlatingToolCallID *string) (parent, root uuid.NullUUID) {
		t.Helper()
		_, recErr := srv.RecordInterception(ctx, &proto.RecordInterceptionRequest{
			Id:                    id.String(),
			ApiKeyId:              uuid.NewString(),
			InitiatorId:           user.ID.String(),
			Provider:              "anthropic",
			Model:                 "claude-4-opus",
			StartedAt:             timestamppb.Now(),
			CorrelatingToolCallId: correlatingToolCallID,
		})
		require.NoError(t, recErr)

		stored, getErr := db.GetAIBridgeInterceptionByID(ctx, id)
		require.NoError(t, getErr)
		return stored.ThreadParentID, stored.ThreadRootID
	}

	// recordToolCall records a tool call made during interception id, which
	// later interceptions can correlate against via toolCallID.
	recordToolCall := func(id uuid.UUID, msgID, toolCallID string) {
		t.Helper()
		_, recErr := srv.RecordToolUsage(ctx, &proto.RecordToolUsageRequest{
			InterceptionId: id.String(),
			MsgId:          msgID,
			ToolCallId:     toolCallID,
			Tool:           "bash",
			Input:          "{}",
			CreatedAt:      timestamppb.Now(),
		})
		require.NoError(t, recErr)
	}

	aID := uuid.New()
	bID := uuid.New()
	cID := uuid.New()

	// A starts the chain: nothing to correlate with, so no thread association.
	parent, root := startInterception(aID, nil)
	require.Equal(t, uuid.NullUUID{}, parent)
	require.Equal(t, uuid.NullUUID{}, root)
	recordToolCall(aID, "resp_a", "call_a")

	// B answers A's tool call: A is both its parent and the thread root.
	parent, root = startInterception(bID, ptr.Ref("call_a"))
	require.Equal(t, uuid.NullUUID{UUID: aID, Valid: true}, parent)
	require.Equal(t, uuid.NullUUID{UUID: aID, Valid: true}, root)
	recordToolCall(bID, "resp_b", "call_b")

	// C answers B's tool call: B is the parent, A remains the thread root.
	parent, root = startInterception(cID, ptr.Ref("call_b"))
	require.Equal(t, uuid.NullUUID{UUID: bID, Valid: true}, parent)
	require.Equal(t, uuid.NullUUID{UUID: aID, Valid: true}, root)
}