From ca560d36ce4f005df564a20d3116a60151ec0135 Mon Sep 17 00:00:00 2001 From: Jake Howell Date: Tue, 25 Nov 2025 10:23:39 +1100 Subject: [PATCH] fix: remove inflight interceptions from aibridge returned values (#20852) Addresses [`aibridge#54`](https://github.com/coder/aibridge/issues/54) When querying against the values in the database for `/api/experimental/aibridge/interceptions` we found strange behaviour wherein there was interceptions that lacked prompting and other various fields we want. Generally this was as a result of the data not actually existing for these values (as they were inflight). The simple solution to this was to hide them if they didn't exist. This PR addresses that. --------- Co-authored-by: Danny Kopping --- coderd/database/queries.sql.go | 8 +++-- coderd/database/queries/aibridge.sql | 8 +++-- enterprise/cli/aibridge_test.go | 9 +++-- enterprise/coderd/aibridge_test.go | 49 +++++++++++++++++++++++++--- 4 files changed, 62 insertions(+), 12 deletions(-) diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 0c0dd9fd83..0dca4d1eff 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -275,8 +275,10 @@ SELECT FROM aibridge_interceptions WHERE + -- Remove inflight interceptions (ones which lack an ended_at value). + aibridge_interceptions.ended_at IS NOT NULL -- Filter by time frame - CASE + AND CASE WHEN $1::timestamptz != '0001-01-01 00:00:00+00'::timestamptz THEN aibridge_interceptions.started_at >= $1::timestamptz ELSE true END @@ -744,8 +746,10 @@ FROM JOIN visible_users ON visible_users.id = aibridge_interceptions.initiator_id WHERE + -- Remove inflight interceptions (ones which lack an ended_at value). + aibridge_interceptions.ended_at IS NOT NULL -- Filter by time frame - CASE + AND CASE WHEN $1::timestamptz != '0001-01-01 00:00:00+00'::timestamptz THEN aibridge_interceptions.started_at >= $1::timestamptz ELSE true END diff --git a/coderd/database/queries/aibridge.sql b/coderd/database/queries/aibridge.sql index 4a1e346c86..cf87598115 100644 --- a/coderd/database/queries/aibridge.sql +++ b/coderd/database/queries/aibridge.sql @@ -89,8 +89,10 @@ SELECT FROM aibridge_interceptions WHERE + -- Remove inflight interceptions (ones which lack an ended_at value). + aibridge_interceptions.ended_at IS NOT NULL -- Filter by time frame - CASE + AND CASE WHEN @started_after::timestamptz != '0001-01-01 00:00:00+00'::timestamptz THEN aibridge_interceptions.started_at >= @started_after::timestamptz ELSE true END @@ -126,8 +128,10 @@ FROM JOIN visible_users ON visible_users.id = aibridge_interceptions.initiator_id WHERE + -- Remove inflight interceptions (ones which lack an ended_at value). + aibridge_interceptions.ended_at IS NOT NULL -- Filter by time frame - CASE + AND CASE WHEN @started_after::timestamptz != '0001-01-01 00:00:00+00'::timestamptz THEN aibridge_interceptions.started_at >= @started_after::timestamptz ELSE true END diff --git a/enterprise/cli/aibridge_test.go b/enterprise/cli/aibridge_test.go index a5b48a14e1..666dc69858 100644 --- a/enterprise/cli/aibridge_test.go +++ b/enterprise/cli/aibridge_test.go @@ -43,10 +43,11 @@ func TestAIBridgeListInterceptions(t *testing.T) { InitiatorID: member.ID, StartedAt: now.Add(-time.Hour), }, &now) + interception2EndedAt := now.Add(time.Minute) interception2 := dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{ InitiatorID: member.ID, StartedAt: now, - }, nil) + }, &interception2EndedAt) // Should not be returned because the user can't see it. _ = dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{ InitiatorID: owner.UserID, @@ -91,12 +92,13 @@ func TestAIBridgeListInterceptions(t *testing.T) { now := dbtime.Now() // This interception should be returned since it matches all filters. + goodInterceptionEndedAt := now.Add(time.Minute) goodInterception := dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{ InitiatorID: member.ID, Provider: "real-provider", Model: "real-model", StartedAt: now, - }, nil) + }, &goodInterceptionEndedAt) // These interceptions should not be returned since they don't match the // filters. @@ -173,10 +175,11 @@ func TestAIBridgeListInterceptions(t *testing.T) { memberClient, member := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID) now := dbtime.Now() + firstInterceptionEndedAt := now.Add(time.Minute) firstInterception := dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{ InitiatorID: member.ID, StartedAt: now, - }, nil) + }, &firstInterceptionEndedAt) returnedInterception := dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{ InitiatorID: member.ID, StartedAt: now.Add(-time.Hour), diff --git a/enterprise/coderd/aibridge_test.go b/enterprise/coderd/aibridge_test.go index e95f1d99e2..2913fe516a 100644 --- a/enterprise/coderd/aibridge_test.go +++ b/enterprise/coderd/aibridge_test.go @@ -103,11 +103,12 @@ func TestAIBridgeListInterceptions(t *testing.T) { // Insert a bunch of test data. now := dbtime.Now() i1ApiKey := sql.NullString{String: "some-api-key", Valid: true} + i1EndedAt := now.Add(-time.Hour + time.Minute) i1 := dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{ APIKeyID: i1ApiKey, InitiatorID: user1.ID, StartedAt: now.Add(-time.Hour), - }, nil) + }, &i1EndedAt) i1tok1 := dbgen.AIBridgeTokenUsage(t, db, database.InsertAIBridgeTokenUsageParams{ InterceptionID: i1.ID, CreatedAt: now, @@ -175,9 +176,11 @@ func TestAIBridgeListInterceptions(t *testing.T) { // Time comparison require.Len(t, res.Results, 2) require.Equal(t, res.Results[0].ID, i2SDK.ID) - require.NotNil(t, now, res.Results[0].EndedAt) + require.NotNil(t, res.Results[0].EndedAt) require.WithinDuration(t, now, *res.Results[0].EndedAt, 5*time.Second) res.Results[0].EndedAt = i2SDK.EndedAt + require.NotNil(t, res.Results[1].EndedAt) + res.Results[1].EndedAt = i1SDK.EndedAt require.Equal(t, []codersdk.AIBridgeInterception{i2SDK, i1SDK}, res.Results) }) @@ -217,11 +220,12 @@ func TestAIBridgeListInterceptions(t *testing.T) { randomOffset, err := cryptorand.Intn(10000) require.NoError(t, err) randomOffsetDur := time.Duration(randomOffset) * time.Second + endedAt := now.Add(randomOffsetDur + time.Minute) interception := dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{ ID: uuid.UUID{byte(i + 10)}, InitiatorID: firstUser.UserID, StartedAt: now.Add(randomOffsetDur), - }, nil) + }, &endedAt) allInterceptionIDs = append(allInterceptionIDs, interception.ID) } @@ -297,6 +301,39 @@ func TestAIBridgeListInterceptions(t *testing.T) { } }) + t.Run("InflightInterceptions", func(t *testing.T) { + t.Parallel() + dv := coderdtest.DeploymentValues(t) + client, db, firstUser := coderdenttest.NewWithDatabase(t, &coderdenttest.Options{ + Options: &coderdtest.Options{ + DeploymentValues: dv, + }, + LicenseOptions: &coderdenttest.LicenseOptions{ + Features: license.Features{ + codersdk.FeatureAIBridge: 1, + }, + }, + }) + ctx := testutil.Context(t, testutil.WaitLong) + + now := dbtime.Now() + i1EndedAt := now.Add(time.Minute) + i1 := dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{ + InitiatorID: firstUser.UserID, + StartedAt: now, + }, &i1EndedAt) + dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{ + InitiatorID: firstUser.UserID, + StartedAt: now.Add(-time.Hour), + }, nil) + + res, err := client.AIBridgeListInterceptions(ctx, codersdk.AIBridgeListInterceptionsFilter{}) + require.NoError(t, err) + require.EqualValues(t, 1, res.Count) + require.Len(t, res.Results, 1) + require.Equal(t, i1.ID, res.Results[0].ID) + }) + t.Run("Authorized", func(t *testing.T) { t.Parallel() dv := coderdtest.DeploymentValues(t) @@ -315,10 +352,11 @@ func TestAIBridgeListInterceptions(t *testing.T) { secondUserClient, secondUser := coderdtest.CreateAnotherUser(t, adminClient, firstUser.OrganizationID) now := dbtime.Now() + i1EndedAt := now.Add(time.Minute) i1 := dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{ InitiatorID: firstUser.UserID, StartedAt: now, - }, nil) + }, &i1EndedAt) i2 := dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{ InitiatorID: secondUser.ID, StartedAt: now.Add(-time.Hour), @@ -374,13 +412,14 @@ func TestAIBridgeListInterceptions(t *testing.T) { // Insert a bunch of test data with varying filterable fields. now := dbtime.Now() + i1EndedAt := now.Add(time.Minute) i1 := dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{ ID: uuid.MustParse("00000000-0000-0000-0000-000000000001"), InitiatorID: user1.ID, Provider: "one", Model: "one", StartedAt: now, - }, nil) + }, &i1EndedAt) i2 := dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{ ID: uuid.MustParse("00000000-0000-0000-0000-000000000002"), InitiatorID: user1.ID,