Files
coder/enterprise/coderd/aibridge_test.go
T
Danny Kopping 13e3df67d6 feat: track client sessions (#22470)
This change adds support for tracking client session IDs in AI Bridge interceptions to enable better session-based auditing.

Depends on https://github.com/coder/aibridge/pull/198  
Fixes https://github.com/coder/internal/issues/1337

The session ID field is optional and not universally supported by all clients.
2026-03-06 14:43:53 +02:00

862 lines
28 KiB
Go

package coderd_test
import (
"database/sql"
"io"
"net/http"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
aiblib "github.com/coder/aibridge"
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/db2sdk"
"github.com/coder/coder/v2/coderd/database/dbgen"
"github.com/coder/coder/v2/coderd/database/dbtime"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/cryptorand"
"github.com/coder/coder/v2/enterprise/coderd/coderdenttest"
"github.com/coder/coder/v2/enterprise/coderd/license"
"github.com/coder/coder/v2/testutil"
"github.com/coder/serpent"
)
// TestAIBridgeListInterceptions exercises client.AIBridgeListInterceptions
// through a set of parallel subtests: license gating, an empty database,
// result contents and ordering, pagination, in-flight interceptions,
// per-user visibility, filtering, and filter validation errors.
func TestAIBridgeListInterceptions(t *testing.T) {
t.Parallel()
// RequiresLicenseFeature: with an empty license feature set the request
// must be rejected with 403 and the premium-feature message.
t.Run("RequiresLicenseFeature", func(t *testing.T) {
t.Parallel()
dv := coderdtest.DeploymentValues(t)
client, _ := coderdenttest.New(t, &coderdenttest.Options{
Options: &coderdtest.Options{
DeploymentValues: dv,
},
LicenseOptions: &coderdenttest.LicenseOptions{
// No aibridge feature
Features: license.Features{},
},
})
ctx := testutil.Context(t, testutil.WaitLong)
//nolint:gocritic // Owner role is irrelevant here.
_, err := client.AIBridgeListInterceptions(ctx, codersdk.AIBridgeListInterceptionsFilter{})
var sdkErr *codersdk.Error
require.ErrorAs(t, err, &sdkErr)
require.Equal(t, http.StatusForbidden, sdkErr.StatusCode())
require.Equal(t, "AI Bridge is a Premium feature. Contact sales!", sdkErr.Message)
})
// EmptyDB: with the feature licensed and enabled but no rows inserted,
// the call succeeds and returns no results.
t.Run("EmptyDB", func(t *testing.T) {
t.Parallel()
dv := coderdtest.DeploymentValues(t)
dv.AI.BridgeConfig.Enabled = serpent.Bool(true)
client, _ := coderdenttest.New(t, &coderdenttest.Options{
Options: &coderdtest.Options{
DeploymentValues: dv,
},
LicenseOptions: &coderdenttest.LicenseOptions{
Features: license.Features{
codersdk.FeatureAIBridge: 1,
},
},
})
ctx := testutil.Context(t, testutil.WaitLong)
//nolint:gocritic // Owner role is irrelevant here.
res, err := client.AIBridgeListInterceptions(ctx, codersdk.AIBridgeListInterceptionsFilter{})
require.NoError(t, err)
require.Empty(t, res.Results)
})
// OK: inserts two interceptions (one with token usages, user prompts and
// tool usages attached) and compares the full response against the values
// produced by db2sdk.AIBridgeInterception.
t.Run("OK", func(t *testing.T) {
t.Parallel()
dv := coderdtest.DeploymentValues(t)
dv.AI.BridgeConfig.Enabled = serpent.Bool(true)
client, db, firstUser := coderdenttest.NewWithDatabase(t, &coderdenttest.Options{
Options: &coderdtest.Options{
DeploymentValues: dv,
},
LicenseOptions: &coderdenttest.LicenseOptions{
Features: license.Features{
codersdk.FeatureAIBridge: 1,
},
},
})
ctx := testutil.Context(t, testutil.WaitLong)
user1, err := client.User(ctx, codersdk.Me)
require.NoError(t, err)
// The VisibleUser values mirror the SDK users so they can be fed to the
// db2sdk conversion below.
user1Visible := database.VisibleUser{
ID: user1.ID,
Username: user1.Username,
Name: user1.Name,
AvatarURL: user1.AvatarURL,
}
_, user2 := coderdtest.CreateAnotherUser(t, client, firstUser.OrganizationID)
user2Visible := database.VisibleUser{
ID: user2.ID,
Username: user2.Username,
Name: user2.Name,
AvatarURL: user2.AvatarURL,
}
// Insert a bunch of test data.
now := dbtime.Now()
i1ApiKey := sql.NullString{String: "some-api-key", Valid: true}
// i1 starts one hour before now; this timestamp is one minute after that start.
i1EndedAt := now.Add(-time.Hour + time.Minute)
i1 := dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{
APIKeyID: i1ApiKey,
InitiatorID: user1.ID,
StartedAt: now.Add(-time.Hour),
}, &i1EndedAt)
// For each child record type, the "1" row is created at now and the "2"
// row one minute earlier.
i1tok1 := dbgen.AIBridgeTokenUsage(t, db, database.InsertAIBridgeTokenUsageParams{
InterceptionID: i1.ID,
CreatedAt: now,
})
i1tok2 := dbgen.AIBridgeTokenUsage(t, db, database.InsertAIBridgeTokenUsageParams{
InterceptionID: i1.ID,
CreatedAt: now.Add(-time.Minute),
})
i1up1 := dbgen.AIBridgeUserPrompt(t, db, database.InsertAIBridgeUserPromptParams{
InterceptionID: i1.ID,
CreatedAt: now,
})
i1up2 := dbgen.AIBridgeUserPrompt(t, db, database.InsertAIBridgeUserPromptParams{
InterceptionID: i1.ID,
CreatedAt: now.Add(-time.Minute),
})
i1tool1 := dbgen.AIBridgeToolUsage(t, db, database.InsertAIBridgeToolUsageParams{
InterceptionID: i1.ID,
CreatedAt: now,
})
i1tool2 := dbgen.AIBridgeToolUsage(t, db, database.InsertAIBridgeToolUsageParams{
InterceptionID: i1.ID,
CreatedAt: now.Add(-time.Minute),
})
// i2 has no API key and no child records.
i2 := dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{
InitiatorID: user2.ID,
StartedAt: now,
}, &now)
// Convert to SDK types for response comparison.
// You may notice that the ordering of the inner arrays is ASC (the rows
// created a minute earlier come first); this is intentional.
i1SDK := db2sdk.AIBridgeInterception(i1, user1Visible, []database.AIBridgeTokenUsage{i1tok2, i1tok1}, []database.AIBridgeUserPrompt{i1up2, i1up1}, []database.AIBridgeToolUsage{i1tool2, i1tool1})
i2SDK := db2sdk.AIBridgeInterception(i2, user2Visible, nil, nil, nil)
res, err := client.AIBridgeListInterceptions(ctx, codersdk.AIBridgeListInterceptionsFilter{})
require.NoError(t, err)
require.Len(t, res.Results, 2)
// i2 (started at now) is expected ahead of i1 (started an hour earlier).
require.Equal(t, i2SDK.ID, res.Results[0].ID)
require.Equal(t, i1SDK.ID, res.Results[1].ID)
// These two checks inspect the locally converted values; the response
// itself is covered by the whole-slice comparison at the end.
require.Equal(t, &i1ApiKey.String, i1SDK.APIKeyID)
require.Nil(t, i2SDK.APIKeyID)
// Normalize timestamps in the response so we can compare the whole
// thing easily.
res.Results[0].StartedAt = i2SDK.StartedAt
res.Results[1].StartedAt = i1SDK.StartedAt
// For each child slice: check length and IDs first, then overwrite the
// response timestamps with the expected ones.
require.Len(t, res.Results[1].TokenUsages, 2)
require.Equal(t, i1SDK.TokenUsages[0].ID, res.Results[1].TokenUsages[0].ID)
require.Equal(t, i1SDK.TokenUsages[1].ID, res.Results[1].TokenUsages[1].ID)
res.Results[1].TokenUsages[0].CreatedAt = i1SDK.TokenUsages[0].CreatedAt
res.Results[1].TokenUsages[1].CreatedAt = i1SDK.TokenUsages[1].CreatedAt
require.Len(t, res.Results[1].UserPrompts, 2)
require.Equal(t, i1SDK.UserPrompts[0].ID, res.Results[1].UserPrompts[0].ID)
require.Equal(t, i1SDK.UserPrompts[1].ID, res.Results[1].UserPrompts[1].ID)
res.Results[1].UserPrompts[0].CreatedAt = i1SDK.UserPrompts[0].CreatedAt
res.Results[1].UserPrompts[1].CreatedAt = i1SDK.UserPrompts[1].CreatedAt
require.Len(t, res.Results[1].ToolUsages, 2)
require.Equal(t, i1SDK.ToolUsages[0].ID, res.Results[1].ToolUsages[0].ID)
require.Equal(t, i1SDK.ToolUsages[1].ID, res.Results[1].ToolUsages[1].ID)
res.Results[1].ToolUsages[0].CreatedAt = i1SDK.ToolUsages[0].CreatedAt
res.Results[1].ToolUsages[1].CreatedAt = i1SDK.ToolUsages[1].CreatedAt
// Time comparison
// NOTE(review): the next two assertions repeat the length and ID checks
// made right after the list call.
require.Len(t, res.Results, 2)
require.Equal(t, res.Results[0].ID, i2SDK.ID)
require.NotNil(t, res.Results[0].EndedAt)
require.WithinDuration(t, now, *res.Results[0].EndedAt, 5*time.Second)
res.Results[0].EndedAt = i2SDK.EndedAt
// Only non-nil-ness is checked for i1's EndedAt before it is normalized.
require.NotNil(t, res.Results[1].EndedAt)
res.Results[1].EndedAt = i1SDK.EndedAt
require.Equal(t, []codersdk.AIBridgeInterception{i2SDK, i1SDK}, res.Results)
})
// Pagination: inserts 20 interceptions, checks that invalid pagination
// parameters are rejected, then walks the full list one item at a time
// in both after_id and offset modes.
t.Run("Pagination", func(t *testing.T) {
t.Parallel()
dv := coderdtest.DeploymentValues(t)
dv.AI.BridgeConfig.Enabled = serpent.Bool(true)
client, db, firstUser := coderdenttest.NewWithDatabase(t, &coderdenttest.Options{
Options: &coderdtest.Options{
DeploymentValues: dv,
},
LicenseOptions: &coderdenttest.LicenseOptions{
Features: license.Features{
codersdk.FeatureAIBridge: 1,
},
},
})
ctx := testutil.Context(t, testutil.WaitLong)
allInterceptionIDs := make([]uuid.UUID, 0, 20)
// Create 10 interceptions with the same started_at time. The returned
// order for these should still be deterministic.
now := dbtime.Now()
for i := range 10 {
interception := dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{
// Fixed IDs: only the first byte is set, to the loop index.
ID: uuid.UUID{byte(i)},
InitiatorID: firstUser.UserID,
StartedAt: now,
}, &now)
allInterceptionIDs = append(allInterceptionIDs, interception.ID)
}
// Create 10 interceptions with a random started_at time.
for i := range 10 {
randomOffset, err := cryptorand.Intn(10000)
require.NoError(t, err)
randomOffsetDur := time.Duration(randomOffset) * time.Second
// Each one ends a minute after its (randomized) start.
endedAt := now.Add(randomOffsetDur + time.Minute)
interception := dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{
// Offset by 10 so these IDs do not collide with the first batch.
ID: uuid.UUID{byte(i + 10)},
InitiatorID: firstUser.UserID,
StartedAt: now.Add(randomOffsetDur),
}, &endedAt)
allInterceptionIDs = append(allInterceptionIDs, interception.ID)
}
// Try to fetch with an invalid limit.
res, err := client.AIBridgeListInterceptions(ctx, codersdk.AIBridgeListInterceptionsFilter{
Pagination: codersdk.Pagination{
Limit: 1001,
},
})
var sdkErr *codersdk.Error
require.ErrorAs(t, err, &sdkErr)
require.Contains(t, sdkErr.Message, "Invalid pagination limit value.")
require.Empty(t, res.Results)
// Try to fetch with both after_id and offset pagination.
res, err = client.AIBridgeListInterceptions(ctx, codersdk.AIBridgeListInterceptionsFilter{
Pagination: codersdk.Pagination{
AfterID: allInterceptionIDs[0],
Offset: 1,
},
})
require.ErrorAs(t, err, &sdkErr)
require.Contains(t, sdkErr.Message, "Query parameters have invalid values")
require.Contains(t, sdkErr.Detail, "Cannot use both after_id and offset pagination in the same request.")
// Iterate over all interceptions using both cursor and offset
// pagination modes.
for _, paginationMode := range []string{"after_id", "offset"} {
t.Run(paginationMode, func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
// Get all interceptions one by one using the given pagination
// mode.
getAllInterceptionsOneByOne := func() []uuid.UUID {
interceptionIDs := []uuid.UUID{}
for {
pagination := codersdk.Pagination{
Limit: 1,
}
if paginationMode == "after_id" {
// The first page is requested without a cursor; later pages
// use the last ID seen so far.
if len(interceptionIDs) > 0 {
pagination.AfterID = interceptionIDs[len(interceptionIDs)-1]
}
} else {
pagination.Offset = len(interceptionIDs)
}
res, err := client.AIBridgeListInterceptions(ctx, codersdk.AIBridgeListInterceptionsFilter{
Pagination: pagination,
})
require.NoError(t, err)
// An empty page marks the end of the listing.
if len(res.Results) == 0 {
break
}
// Count is expected to be the total number of interceptions on
// every non-empty page, not the page size.
require.EqualValues(t, len(allInterceptionIDs), res.Count)
require.Len(t, res.Results, 1)
interceptionIDs = append(interceptionIDs, res.Results[0].ID)
}
return interceptionIDs
}
// First attempt: get all interceptions one by one.
gotInterceptionIDs1 := getAllInterceptionsOneByOne()
// We should have all of the interceptions returned:
require.ElementsMatch(t, allInterceptionIDs, gotInterceptionIDs1)
// Second attempt: get all interceptions one by one again.
gotInterceptionIDs2 := getAllInterceptionsOneByOne()
// They should be returned in the exact same order.
require.Equal(t, gotInterceptionIDs1, gotInterceptionIDs2)
})
}
})
// InflightInterceptions: of two inserted interceptions, only the one
// created with a non-nil final argument to dbgen.AIBridgeInterception is
// expected in the listing.
t.Run("InflightInterceptions", func(t *testing.T) {
t.Parallel()
dv := coderdtest.DeploymentValues(t)
dv.AI.BridgeConfig.Enabled = serpent.Bool(true)
client, db, firstUser := coderdenttest.NewWithDatabase(t, &coderdenttest.Options{
Options: &coderdtest.Options{
DeploymentValues: dv,
},
LicenseOptions: &coderdenttest.LicenseOptions{
Features: license.Features{
codersdk.FeatureAIBridge: 1,
},
},
})
ctx := testutil.Context(t, testutil.WaitLong)
now := dbtime.Now()
i1EndedAt := now.Add(time.Minute)
i1 := dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{
InitiatorID: firstUser.UserID,
StartedAt: now,
}, &i1EndedAt)
// Passing nil here presumably leaves the interception without an end
// time (i.e. still in flight) — confirm against dbgen.
dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{
InitiatorID: firstUser.UserID,
StartedAt: now.Add(-time.Hour),
}, nil)
res, err := client.AIBridgeListInterceptions(ctx, codersdk.AIBridgeListInterceptionsFilter{})
require.NoError(t, err)
require.EqualValues(t, 1, res.Count)
require.Len(t, res.Results, 1)
require.Equal(t, i1.ID, res.Results[0].ID)
})
// Authorized: the first (admin) client sees interceptions from both
// users, while the second user's client only sees its own.
t.Run("Authorized", func(t *testing.T) {
t.Parallel()
dv := coderdtest.DeploymentValues(t)
dv.AI.BridgeConfig.Enabled = serpent.Bool(true)
adminClient, db, firstUser := coderdenttest.NewWithDatabase(t, &coderdenttest.Options{
Options: &coderdtest.Options{
DeploymentValues: dv,
},
LicenseOptions: &coderdenttest.LicenseOptions{
Features: license.Features{
codersdk.FeatureAIBridge: 1,
},
},
})
ctx := testutil.Context(t, testutil.WaitLong)
secondUserClient, secondUser := coderdtest.CreateAnotherUser(t, adminClient, firstUser.OrganizationID)
now := dbtime.Now()
i1EndedAt := now.Add(time.Minute)
i1 := dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{
InitiatorID: firstUser.UserID,
StartedAt: now,
}, &i1EndedAt)
i2 := dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{
InitiatorID: secondUser.ID,
StartedAt: now.Add(-time.Hour),
}, &now)
// Admin can see all interceptions.
res, err := adminClient.AIBridgeListInterceptions(ctx, codersdk.AIBridgeListInterceptionsFilter{})
require.NoError(t, err)
require.EqualValues(t, 2, res.Count)
require.Len(t, res.Results, 2)
require.Equal(t, i1.ID, res.Results[0].ID)
require.Equal(t, i2.ID, res.Results[1].ID)
// Second user can only see their own interceptions.
res, err = secondUserClient.AIBridgeListInterceptions(ctx, codersdk.AIBridgeListInterceptionsFilter{})
require.NoError(t, err)
require.EqualValues(t, 1, res.Count)
require.Len(t, res.Results, 1)
require.Equal(t, i2.ID, res.Results[0].ID)
})
// Filter: inserts three interceptions that differ in initiator, provider,
// model, client and start time, then runs a table of filters and compares
// the returned IDs (in order) against the expected ones.
t.Run("Filter", func(t *testing.T) {
t.Parallel()
dv := coderdtest.DeploymentValues(t)
dv.AI.BridgeConfig.Enabled = serpent.Bool(true)
client, db, firstUser := coderdenttest.NewWithDatabase(t, &coderdenttest.Options{
Options: &coderdtest.Options{
DeploymentValues: dv,
},
LicenseOptions: &coderdenttest.LicenseOptions{
Features: license.Features{
codersdk.FeatureAIBridge: 1,
},
},
})
ctx := testutil.Context(t, testutil.WaitLong)
user1, err := client.User(ctx, codersdk.Me)
require.NoError(t, err)
user1Visible := database.VisibleUser{
ID: user1.ID,
Username: user1.Username,
Name: user1.Name,
AvatarURL: user1.AvatarURL,
}
_, user2 := coderdtest.CreateAnotherUser(t, client, firstUser.OrganizationID)
user2Visible := database.VisibleUser{
ID: user2.ID,
Username: user2.Username,
Name: user2.Name,
AvatarURL: user2.AvatarURL,
}
// Insert a bunch of test data with varying filterable fields.
now := dbtime.Now()
i1EndedAt := now.Add(time.Minute)
// i1: user1, provider/model "one", started now, no Client set.
i1 := dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{
ID: uuid.MustParse("00000000-0000-0000-0000-000000000001"),
InitiatorID: user1.ID,
Provider: "one",
Model: "one",
StartedAt: now,
}, &i1EndedAt)
// i2: user1, provider/model "two", started one hour ago, Cursor client.
i2 := dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{
ID: uuid.MustParse("00000000-0000-0000-0000-000000000002"),
InitiatorID: user1.ID,
Provider: "two",
Model: "two",
StartedAt: now.Add(-time.Hour),
Client: sql.NullString{String: string(aiblib.ClientCursor), Valid: true},
}, &now)
// i3: user2, provider/model "three", started two hours ago, Claude Code client.
i3 := dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{
ID: uuid.MustParse("00000000-0000-0000-0000-000000000003"),
InitiatorID: user2.ID,
Provider: "three",
Model: "three",
StartedAt: now.Add(-2 * time.Hour),
Client: sql.NullString{String: string(aiblib.ClientClaudeCode), Valid: true},
}, &now)
// Convert to SDK types for response comparison. We don't care about the
// inner arrays for this test.
i1SDK := db2sdk.AIBridgeInterception(i1, user1Visible, nil, nil, nil)
i2SDK := db2sdk.AIBridgeInterception(i2, user1Visible, nil, nil, nil)
i3SDK := db2sdk.AIBridgeInterception(i3, user2Visible, nil, nil, nil)
// Each case pairs a filter with the interceptions expected back, listed
// in the order the response should return them.
cases := []struct {
name string
filter codersdk.AIBridgeListInterceptionsFilter
want []codersdk.AIBridgeInterception
}{
{
name: "NoFilter",
filter: codersdk.AIBridgeListInterceptionsFilter{},
want: []codersdk.AIBridgeInterception{i1SDK, i2SDK, i3SDK},
},
{
name: "Initiator/NoMatch",
filter: codersdk.AIBridgeListInterceptionsFilter{Initiator: uuid.New().String()},
want: []codersdk.AIBridgeInterception{},
},
{
name: "Initiator/Me",
filter: codersdk.AIBridgeListInterceptionsFilter{Initiator: codersdk.Me},
want: []codersdk.AIBridgeInterception{i1SDK, i2SDK},
},
{
name: "Initiator/UserID",
filter: codersdk.AIBridgeListInterceptionsFilter{Initiator: user2.ID.String()},
want: []codersdk.AIBridgeInterception{i3SDK},
},
{
name: "Initiator/Username",
filter: codersdk.AIBridgeListInterceptionsFilter{Initiator: user2.Username},
want: []codersdk.AIBridgeInterception{i3SDK},
},
{
name: "Provider/NoMatch",
filter: codersdk.AIBridgeListInterceptionsFilter{Provider: "nonsense"},
want: []codersdk.AIBridgeInterception{},
},
{
name: "Provider/OK",
filter: codersdk.AIBridgeListInterceptionsFilter{Provider: "two"},
want: []codersdk.AIBridgeInterception{i2SDK},
},
{
name: "Model/NoMatch",
filter: codersdk.AIBridgeListInterceptionsFilter{Model: "nonsense"},
want: []codersdk.AIBridgeInterception{},
},
{
name: "Model/OK",
filter: codersdk.AIBridgeListInterceptionsFilter{Model: "three"},
want: []codersdk.AIBridgeInterception{i3SDK},
},
{
// i1 is the only interception inserted without a Client value.
name: "Client/Unknown",
filter: codersdk.AIBridgeListInterceptionsFilter{Client: "Unknown"},
want: []codersdk.AIBridgeInterception{i1SDK},
},
{
name: "Client/Match",
filter: codersdk.AIBridgeListInterceptionsFilter{Client: string(aiblib.ClientCursor)},
want: []codersdk.AIBridgeInterception{i2SDK},
},
{
name: "Client/NoMatch",
filter: codersdk.AIBridgeListInterceptionsFilter{Client: "nonsense"},
want: []codersdk.AIBridgeInterception{},
},
{
name: "StartedAfter/NoMatch",
filter: codersdk.AIBridgeListInterceptionsFilter{
StartedAfter: i1.StartedAt.Add(10 * time.Minute),
},
want: []codersdk.AIBridgeInterception{},
},
{
name: "StartedAfter/OK",
filter: codersdk.AIBridgeListInterceptionsFilter{
StartedAfter: i2.StartedAt.Add(-10 * time.Minute),
},
want: []codersdk.AIBridgeInterception{i1SDK, i2SDK},
},
{
name: "StartedBefore/NoMatch",
filter: codersdk.AIBridgeListInterceptionsFilter{
StartedBefore: i3.StartedAt.Add(-10 * time.Minute),
},
want: []codersdk.AIBridgeInterception{},
},
{
name: "StartedBefore/OK",
filter: codersdk.AIBridgeListInterceptionsFilter{
StartedBefore: i3.StartedAt.Add(10 * time.Minute),
},
want: []codersdk.AIBridgeInterception{i3SDK},
},
{
name: "BothBeforeAndAfter/NoMatch",
filter: codersdk.AIBridgeListInterceptionsFilter{
StartedAfter: i1.StartedAt.Add(10 * time.Minute),
StartedBefore: i1.StartedAt.Add(20 * time.Minute),
},
want: []codersdk.AIBridgeInterception{},
},
{
name: "BothBeforeAndAfter/OK",
filter: codersdk.AIBridgeListInterceptionsFilter{
StartedAfter: i2.StartedAt.Add(-10 * time.Minute),
StartedBefore: i2.StartedAt.Add(10 * time.Minute),
},
want: []codersdk.AIBridgeInterception{i2SDK},
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
res, err := client.AIBridgeListInterceptions(ctx, tc.filter)
require.NoError(t, err)
require.EqualValues(t, len(tc.want), res.Count)
// We just compare UUID strings for the sake of this test.
wantIDs := make([]string, len(tc.want))
for i, r := range tc.want {
wantIDs[i] = r.ID.String()
}
gotIDs := make([]string, len(res.Results))
for i, r := range res.Results {
gotIDs[i] = r.ID.String()
}
require.Equal(t, wantIDs, gotIDs)
})
}
})
// FilterErrors: sends malformed FilterQuery strings and checks the exact
// validation errors carried by the returned codersdk.Error.
t.Run("FilterErrors", func(t *testing.T) {
t.Parallel()
dv := coderdtest.DeploymentValues(t)
dv.AI.BridgeConfig.Enabled = serpent.Bool(true)
client, _ := coderdenttest.New(t, &coderdenttest.Options{
Options: &coderdtest.Options{
DeploymentValues: dv,
},
LicenseOptions: &coderdenttest.LicenseOptions{
Features: license.Features{
codersdk.FeatureAIBridge: 1,
},
},
})
// No need to insert any test data, we're just testing the filter
// errors.
cases := []struct {
name string
q string
want []codersdk.ValidationError
}{
{
name: "UnknownUsername",
q: "initiator:unknown",
want: []codersdk.ValidationError{
{
Field: "initiator",
Detail: `Query param "initiator" has invalid value: user "unknown" either does not exist, or you are unauthorized to view them`,
},
},
},
{
// The query uses lower-case "invalid", but the expected detail quotes
// the value as "INVALID".
name: "InvalidStartedAfter",
q: "started_after:invalid",
want: []codersdk.ValidationError{
{
Field: "started_after",
Detail: `Query param "started_after" must be a valid date format (2006-01-02T15:04:05.999999999Z07:00): parsing time "INVALID" as "2006-01-02T15:04:05.999999999Z07:00": cannot parse "INVALID" as "2006"`,
},
},
},
{
name: "InvalidStartedBefore",
q: "started_before:invalid",
want: []codersdk.ValidationError{
{
Field: "started_before",
Detail: `Query param "started_before" must be a valid date format (2006-01-02T15:04:05.999999999Z07:00): parsing time "INVALID" as "2006-01-02T15:04:05.999999999Z07:00": cannot parse "INVALID" as "2006"`,
},
},
},
{
name: "InvalidBeforeAfterRange",
// Before MUST be after After if both are set
q: `started_after:"2025-01-01T00:00:00Z" started_before:"2024-01-01T00:00:00Z"`,
want: []codersdk.ValidationError{
{
Field: "started_before",
Detail: `Query param "started_before" has invalid value: "started_before" must be after "started_after" if set`,
},
},
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
res, err := client.AIBridgeListInterceptions(ctx, codersdk.AIBridgeListInterceptionsFilter{
FilterQuery: tc.q,
})
var sdkErr *codersdk.Error
require.ErrorAs(t, err, &sdkErr)
require.Equal(t, tc.want, sdkErr.Validations)
require.Empty(t, res.Results)
})
}
})
}
// TestAIBridgeRouting checks that a request sent to a path under the
// /api/v2/aibridge prefix reaches the registered in-memory handler with that
// prefix removed from the request path.
func TestAIBridgeRouting(t *testing.T) {
	t.Parallel()

	// routeCase pairs the path the client requests with the path the
	// handler is expected to observe.
	type routeCase struct {
		name         string
		path         string
		expectedPath string
	}
	routes := []routeCase{
		{
			name:         "StablePrefix",
			path:         "/api/v2/aibridge/openai/v1/chat/completions",
			expectedPath: "/openai/v1/chat/completions",
		},
	}

	values := coderdtest.DeploymentValues(t)
	values.AI.BridgeConfig.Enabled = serpent.Bool(true)
	client, closer, api, _ := coderdenttest.NewWithAPI(t, &coderdenttest.Options{
		Options: &coderdtest.Options{
			DeploymentValues: values,
		},
		LicenseOptions: &coderdenttest.LicenseOptions{
			Features: license.Features{
				codersdk.FeatureAIBridge: 1,
			},
		},
	})
	t.Cleanup(func() {
		_ = closer.Close()
	})

	// The handler replies 200 and writes the path it received as the body,
	// so each subtest can read back what the router forwarded.
	api.RegisterInMemoryAIBridgedHTTPHandler(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
		rw.WriteHeader(http.StatusOK)
		_, _ = rw.Write([]byte(r.URL.Path))
	}))

	for _, route := range routes {
		t.Run(route.name, func(t *testing.T) {
			t.Parallel()

			ctx := testutil.Context(t, testutil.WaitLong)
			target := client.URL.String() + route.path
			req, err := http.NewRequestWithContext(ctx, http.MethodPost, target, nil)
			require.NoError(t, err)
			req.Header.Set(codersdk.SessionTokenHeader, client.SessionToken())

			resp, err := (&http.Client{}).Do(req)
			require.NoError(t, err)
			defer resp.Body.Close()
			require.Equal(t, http.StatusOK, resp.StatusCode)

			// The echoed body is the path as seen by the handler.
			echoed, err := io.ReadAll(resp.Body)
			require.NoError(t, err)
			require.Equal(t, route.expectedPath, string(echoed))
		})
	}
}
// TestAIBridgeRateLimiting configures an AI Bridge rate limit of two and
// checks that the first two POSTs succeed while the third is answered with
// 429 Too Many Requests and a Retry-After header.
func TestAIBridgeRateLimiting(t *testing.T) {
	t.Parallel()

	// Kept deliberately small so the limit is reached after two requests.
	// The same constant drives both the configuration and the request loop.
	const allowedRequests = 2

	values := coderdtest.DeploymentValues(t)
	values.AI.BridgeConfig.Enabled = serpent.Bool(true)
	values.AI.BridgeConfig.RateLimit = allowedRequests
	client, closer, api, _ := coderdenttest.NewWithAPI(t, &coderdenttest.Options{
		Options: &coderdtest.Options{
			DeploymentValues: values,
		},
		LicenseOptions: &coderdenttest.LicenseOptions{
			Features: license.Features{
				codersdk.FeatureAIBridge: 1,
			},
		},
	})
	t.Cleanup(func() {
		_ = closer.Close()
	})

	// A trivial handler: any request that gets past the limiter sees 200.
	api.RegisterInMemoryAIBridgedHTTPHandler(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
		rw.WriteHeader(http.StatusOK)
	}))

	ctx := testutil.Context(t, testutil.WaitLong)
	httpClient := &http.Client{}
	endpoint := client.URL.String() + "/api/v2/aibridge/test"

	// send issues one authenticated POST to the endpoint and returns the
	// response; the caller is responsible for closing the body.
	send := func() *http.Response {
		req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, nil)
		require.NoError(t, err)
		req.Header.Set(codersdk.SessionTokenHeader, client.SessionToken())
		resp, err := httpClient.Do(req)
		require.NoError(t, err)
		return resp
	}

	// Requests up to the limit must all succeed.
	for range allowedRequests {
		resp := send()
		_ = resp.Body.Close()
		require.Equal(t, http.StatusOK, resp.StatusCode)
	}

	// One more request goes over the limit.
	limited := send()
	defer limited.Body.Close()
	require.Equal(t, http.StatusTooManyRequests, limited.StatusCode)
	require.NotEmpty(t, limited.Header.Get("Retry-After"))
}
// TestAIBridgeConcurrencyLimiting sets the AI Bridge maximum concurrency to
// one, holds a first request open inside the handler, and checks that a
// second request issued meanwhile is rejected with 503 Service Unavailable.
//
// The handler never blocks unconditionally: every wait also watches the
// request context, and the release of the blocked request is deferred, so a
// failed assertion cannot leave a handler goroutine stuck (which could
// otherwise stall server shutdown during cleanup).
func TestAIBridgeConcurrencyLimiting(t *testing.T) {
	t.Parallel()

	dv := coderdtest.DeploymentValues(t)
	dv.AI.BridgeConfig.Enabled = serpent.Bool(true)
	// Set a low concurrency limit for testing.
	dv.AI.BridgeConfig.MaxConcurrency = 1
	client, closer, api, _ := coderdenttest.NewWithAPI(t, &coderdenttest.Options{
		Options: &coderdtest.Options{
			DeploymentValues: dv,
		},
		LicenseOptions: &coderdenttest.LicenseOptions{
			Features: license.Features{
				codersdk.FeatureAIBridge: 1,
			},
		},
	})
	t.Cleanup(func() {
		_ = closer.Close()
	})

	// started is signaled by the handler once a request is being processed;
	// unblock is closed by the test to let that request finish.
	started := make(chan struct{})
	unblock := make(chan struct{})

	// release closes unblock at most once. It is only ever called from the
	// test goroutine (explicitly below and via defer), so the plain bool
	// guard needs no synchronization. Deferring it guarantees the held
	// request is let go even when an assertion aborts the test early.
	released := false
	release := func() {
		if released {
			return
		}
		released = true
		close(unblock)
	}
	defer release()

	// Register a handler that blocks until signaled or until the request is
	// abandoned by the client.
	testHandler := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
		// If nobody is receiving on started (e.g. the limiter wrongly let a
		// second request through), give up when the request is canceled
		// instead of blocking forever on the unbuffered send.
		select {
		case started <- struct{}{}:
		case <-r.Context().Done():
			return
		}
		select {
		case <-unblock:
		case <-r.Context().Done():
			return
		}
		rw.WriteHeader(http.StatusOK)
	})
	api.RegisterInMemoryAIBridgedHTTPHandler(testHandler)

	ctx := testutil.Context(t, testutil.WaitLong)
	httpClient := &http.Client{}
	url := client.URL.String() + "/api/v2/aibridge/test"

	// Start a request that will block.
	done := make(chan struct{})
	go func() {
		defer close(done)
		req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, nil)
		if err != nil {
			// Not reported here: the test goroutine will time out waiting
			// on started and fail with a clear message.
			return
		}
		req.Header.Set(codersdk.SessionTokenHeader, client.SessionToken())
		resp, err := httpClient.Do(req)
		if err == nil {
			_ = resp.Body.Close()
		}
	}()

	// Wait for the first request to start processing.
	select {
	case <-started:
	case <-ctx.Done():
		t.Fatal("timed out waiting for first request to start")
	}

	// Second request should be rejected with 503.
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, nil)
	require.NoError(t, err)
	req.Header.Set(codersdk.SessionTokenHeader, client.SessionToken())
	resp, err := httpClient.Do(req)
	require.NoError(t, err)
	defer resp.Body.Close()
	require.Equal(t, http.StatusServiceUnavailable, resp.StatusCode)

	// Unblock the first request and wait for it to complete.
	release()
	select {
	case <-done:
	case <-ctx.Done():
		t.Fatal("timed out waiting for first request to complete")
	}
}