diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 0c0dd9fd83..0dca4d1eff 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -275,8 +275,10 @@ SELECT FROM aibridge_interceptions WHERE + -- Remove inflight interceptions (ones which lack an ended_at value). + aibridge_interceptions.ended_at IS NOT NULL -- Filter by time frame - CASE + AND CASE WHEN $1::timestamptz != '0001-01-01 00:00:00+00'::timestamptz THEN aibridge_interceptions.started_at >= $1::timestamptz ELSE true END @@ -744,8 +746,10 @@ FROM JOIN visible_users ON visible_users.id = aibridge_interceptions.initiator_id WHERE + -- Remove inflight interceptions (ones which lack an ended_at value). + aibridge_interceptions.ended_at IS NOT NULL -- Filter by time frame - CASE + AND CASE WHEN $1::timestamptz != '0001-01-01 00:00:00+00'::timestamptz THEN aibridge_interceptions.started_at >= $1::timestamptz ELSE true END diff --git a/coderd/database/queries/aibridge.sql b/coderd/database/queries/aibridge.sql index 4a1e346c86..cf87598115 100644 --- a/coderd/database/queries/aibridge.sql +++ b/coderd/database/queries/aibridge.sql @@ -89,8 +89,10 @@ SELECT FROM aibridge_interceptions WHERE + -- Remove inflight interceptions (ones which lack an ended_at value). + aibridge_interceptions.ended_at IS NOT NULL -- Filter by time frame - CASE + AND CASE WHEN @started_after::timestamptz != '0001-01-01 00:00:00+00'::timestamptz THEN aibridge_interceptions.started_at >= @started_after::timestamptz ELSE true END @@ -126,8 +128,10 @@ FROM JOIN visible_users ON visible_users.id = aibridge_interceptions.initiator_id WHERE + -- Remove inflight interceptions (ones which lack an ended_at value). + aibridge_interceptions.ended_at IS NOT NULL -- Filter by time frame - CASE + AND CASE WHEN @started_after::timestamptz != '0001-01-01 00:00:00+00'::timestamptz THEN aibridge_interceptions.started_at >= @started_after::timestamptz ELSE true END diff --git a/enterprise/cli/aibridge_test.go b/enterprise/cli/aibridge_test.go index a5b48a14e1..666dc69858 100644 --- a/enterprise/cli/aibridge_test.go +++ b/enterprise/cli/aibridge_test.go @@ -43,10 +43,11 @@ func TestAIBridgeListInterceptions(t *testing.T) { InitiatorID: member.ID, StartedAt: now.Add(-time.Hour), }, &now) + interception2EndedAt := now.Add(time.Minute) interception2 := dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{ InitiatorID: member.ID, StartedAt: now, - }, nil) + }, &interception2EndedAt) // Should not be returned because the user can't see it. _ = dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{ InitiatorID: owner.UserID, @@ -91,12 +92,13 @@ func TestAIBridgeListInterceptions(t *testing.T) { now := dbtime.Now() // This interception should be returned since it matches all filters. + goodInterceptionEndedAt := now.Add(time.Minute) goodInterception := dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{ InitiatorID: member.ID, Provider: "real-provider", Model: "real-model", StartedAt: now, - }, nil) + }, &goodInterceptionEndedAt) // These interceptions should not be returned since they don't match the // filters. @@ -173,10 +175,11 @@ func TestAIBridgeListInterceptions(t *testing.T) { memberClient, member := coderdtest.CreateAnotherUser(t, client, owner.OrganizationID) now := dbtime.Now() + firstInterceptionEndedAt := now.Add(time.Minute) firstInterception := dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{ InitiatorID: member.ID, StartedAt: now, - }, nil) + }, &firstInterceptionEndedAt) returnedInterception := dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{ InitiatorID: member.ID, StartedAt: now.Add(-time.Hour), diff --git a/enterprise/coderd/aibridge_test.go b/enterprise/coderd/aibridge_test.go index e95f1d99e2..2913fe516a 100644 --- a/enterprise/coderd/aibridge_test.go +++ b/enterprise/coderd/aibridge_test.go @@ -103,11 +103,12 @@ func TestAIBridgeListInterceptions(t *testing.T) { // Insert a bunch of test data. now := dbtime.Now() i1ApiKey := sql.NullString{String: "some-api-key", Valid: true} + i1EndedAt := now.Add(-time.Hour + time.Minute) i1 := dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{ APIKeyID: i1ApiKey, InitiatorID: user1.ID, StartedAt: now.Add(-time.Hour), - }, nil) + }, &i1EndedAt) i1tok1 := dbgen.AIBridgeTokenUsage(t, db, database.InsertAIBridgeTokenUsageParams{ InterceptionID: i1.ID, CreatedAt: now, @@ -175,9 +176,11 @@ func TestAIBridgeListInterceptions(t *testing.T) { // Time comparison require.Len(t, res.Results, 2) require.Equal(t, res.Results[0].ID, i2SDK.ID) - require.NotNil(t, now, res.Results[0].EndedAt) + require.NotNil(t, res.Results[0].EndedAt) require.WithinDuration(t, now, *res.Results[0].EndedAt, 5*time.Second) res.Results[0].EndedAt = i2SDK.EndedAt + require.NotNil(t, res.Results[1].EndedAt) + res.Results[1].EndedAt = i1SDK.EndedAt require.Equal(t, []codersdk.AIBridgeInterception{i2SDK, i1SDK}, res.Results) }) @@ -217,11 +220,12 @@ func TestAIBridgeListInterceptions(t *testing.T) { randomOffset, err := cryptorand.Intn(10000) require.NoError(t, err) randomOffsetDur := time.Duration(randomOffset) * time.Second + endedAt := now.Add(randomOffsetDur + time.Minute) interception := dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{ ID: uuid.UUID{byte(i + 10)}, InitiatorID: firstUser.UserID, StartedAt: now.Add(randomOffsetDur), - }, nil) + }, &endedAt) allInterceptionIDs = append(allInterceptionIDs, interception.ID) } @@ -297,6 +301,39 @@ func TestAIBridgeListInterceptions(t *testing.T) { } }) + t.Run("InflightInterceptions", func(t *testing.T) { + t.Parallel() + dv := coderdtest.DeploymentValues(t) + client, db, firstUser := coderdenttest.NewWithDatabase(t, &coderdenttest.Options{ + Options: &coderdtest.Options{ + DeploymentValues: dv, + }, + LicenseOptions: &coderdenttest.LicenseOptions{ + Features: license.Features{ + codersdk.FeatureAIBridge: 1, + }, + }, + }) + ctx := testutil.Context(t, testutil.WaitLong) + + now := dbtime.Now() + i1EndedAt := now.Add(time.Minute) + i1 := dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{ + InitiatorID: firstUser.UserID, + StartedAt: now, + }, &i1EndedAt) + dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{ + InitiatorID: firstUser.UserID, + StartedAt: now.Add(-time.Hour), + }, nil) + + res, err := client.AIBridgeListInterceptions(ctx, codersdk.AIBridgeListInterceptionsFilter{}) + require.NoError(t, err) + require.EqualValues(t, 1, res.Count) + require.Len(t, res.Results, 1) + require.Equal(t, i1.ID, res.Results[0].ID) + }) + t.Run("Authorized", func(t *testing.T) { t.Parallel() dv := coderdtest.DeploymentValues(t) @@ -315,10 +352,11 @@ func TestAIBridgeListInterceptions(t *testing.T) { secondUserClient, secondUser := coderdtest.CreateAnotherUser(t, adminClient, firstUser.OrganizationID) now := dbtime.Now() + i1EndedAt := now.Add(time.Minute) i1 := dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{ InitiatorID: firstUser.UserID, StartedAt: now, - }, nil) + }, &i1EndedAt) i2 := dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{ InitiatorID: secondUser.ID, StartedAt: now.Add(-time.Hour), @@ -374,13 +412,14 @@ func TestAIBridgeListInterceptions(t *testing.T) { // Insert a bunch of test data with varying filterable fields. now := dbtime.Now() + i1EndedAt := now.Add(time.Minute) i1 := dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{ ID: uuid.MustParse("00000000-0000-0000-0000-000000000001"), InitiatorID: user1.ID, Provider: "one", Model: "one", StartedAt: now, - }, nil) + }, &i1EndedAt) i2 := dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{ ID: uuid.MustParse("00000000-0000-0000-0000-000000000002"), InitiatorID: user1.ID,