diff --git a/coderd/database/modelqueries.go b/coderd/database/modelqueries.go index 2b92947a14..356d262600 100644 --- a/coderd/database/modelqueries.go +++ b/coderd/database/modelqueries.go @@ -996,8 +996,6 @@ func (q *sqlQuerier) ListAuthorizedAIBridgeSessions(ctx context.Context, arg Lis query := fmt.Sprintf("-- name: ListAuthorizedAIBridgeSessions :many\n%s", filtered) rows, err := q.db.QueryContext(ctx, query, arg.AfterSessionID, - arg.Offset, - arg.Limit, arg.StartedAfter, arg.StartedBefore, arg.InitiatorID, @@ -1005,6 +1003,8 @@ func (q *sqlQuerier) ListAuthorizedAIBridgeSessions(ctx context.Context, arg Lis arg.Model, arg.Client, arg.SessionID, + arg.Offset, + arg.Limit, ) if err != nil { return nil, err diff --git a/coderd/database/querier.go b/coderd/database/querier.go index 186603a66c..f73dcee469 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -788,6 +788,10 @@ type sqlcQuerier interface { // Returns paginated sessions with aggregated metadata, token counts, and // the most recent user prompt. A "session" is a logical grouping of // interceptions that share the same session_id (set by the client). + // + // Pagination-first strategy: identify the page of sessions cheaply via a + // single GROUP BY scan, then do expensive lateral joins (tokens, prompts, + // first-interception metadata) only for the ~page-size result set. ListAIBridgeSessions(ctx context.Context, arg ListAIBridgeSessionsParams) ([]ListAIBridgeSessionsRow, error) ListAIBridgeTokenUsagesByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]AIBridgeTokenUsage, error) ListAIBridgeToolUsagesByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]AIBridgeToolUsage, error) diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 5287e13814..f1c76fb7fe 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -1347,95 +1347,87 @@ func (q *sqlQuerier) ListAIBridgeSessionThreads(ctx context.Context, arg ListAIB } const listAIBridgeSessions = `-- name: ListAIBridgeSessions :many -WITH filtered_interceptions AS ( +WITH cursor_pos AS ( + -- Resolve the cursor's started_at once, outside the HAVING clause, + -- so the planner cannot accidentally re-evaluate it per group. + SELECT MIN(aibridge_interceptions.started_at) AS started_at + FROM aibridge_interceptions + WHERE aibridge_interceptions.session_id = $1 AND aibridge_interceptions.ended_at IS NOT NULL +), +session_page AS ( + -- Paginate at the session level first; only cheap aggregates here. SELECT - aibridge_interceptions.id, aibridge_interceptions.initiator_id, aibridge_interceptions.provider, aibridge_interceptions.model, aibridge_interceptions.started_at, aibridge_interceptions.metadata, aibridge_interceptions.ended_at, aibridge_interceptions.api_key_id, aibridge_interceptions.client, aibridge_interceptions.thread_parent_id, aibridge_interceptions.thread_root_id, aibridge_interceptions.client_session_id, aibridge_interceptions.session_id + ai.session_id, + ai.initiator_id, + MIN(ai.started_at) AS started_at, + MAX(ai.ended_at) AS ended_at, + COUNT(*) FILTER (WHERE ai.thread_root_id IS NULL) AS threads FROM - aibridge_interceptions + aibridge_interceptions ai WHERE -- Remove inflight interceptions (ones which lack an ended_at value). - aibridge_interceptions.ended_at IS NOT NULL + ai.ended_at IS NOT NULL -- Filter by time frame AND CASE - WHEN $4::timestamptz != '0001-01-01 00:00:00+00'::timestamptz THEN aibridge_interceptions.started_at >= $4::timestamptz + WHEN $2::timestamptz != '0001-01-01 00:00:00+00'::timestamptz THEN ai.started_at >= $2::timestamptz ELSE true END AND CASE - WHEN $5::timestamptz != '0001-01-01 00:00:00+00'::timestamptz THEN aibridge_interceptions.started_at <= $5::timestamptz + WHEN $3::timestamptz != '0001-01-01 00:00:00+00'::timestamptz THEN ai.started_at <= $3::timestamptz ELSE true END -- Filter initiator_id AND CASE - WHEN $6::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN aibridge_interceptions.initiator_id = $6::uuid + WHEN $4::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN ai.initiator_id = $4::uuid ELSE true END -- Filter provider AND CASE - WHEN $7::text != '' THEN aibridge_interceptions.provider = $7::text + WHEN $5::text != '' THEN ai.provider = $5::text ELSE true END -- Filter model AND CASE - WHEN $8::text != '' THEN aibridge_interceptions.model = $8::text + WHEN $6::text != '' THEN ai.model = $6::text ELSE true END -- Filter client AND CASE - WHEN $9::text != '' THEN COALESCE(aibridge_interceptions.client, 'Unknown') = $9::text + WHEN $7::text != '' THEN COALESCE(ai.client, 'Unknown') = $7::text ELSE true END -- Filter session_id AND CASE - WHEN $10::text != '' THEN aibridge_interceptions.session_id = $10::text + WHEN $8::text != '' THEN ai.session_id = $8::text ELSE true END -- Authorize Filter clause will be injected below in ListAuthorizedAIBridgeSessions -- @authorize_filter -), -session_tokens AS ( - -- Aggregate token usage across all interceptions in each session. - -- Group by (session_id, initiator_id) to avoid merging sessions from - -- different users who happen to share the same client_session_id. - SELECT - fi.session_id, - fi.initiator_id, - COALESCE(SUM(tu.input_tokens), 0)::bigint AS input_tokens, - COALESCE(SUM(tu.output_tokens), 0)::bigint AS output_tokens - -- TODO: add extra token types once https://github.com/coder/aibridge/issues/150 lands. - FROM - filtered_interceptions fi - LEFT JOIN - aibridge_token_usages tu ON fi.id = tu.interception_id GROUP BY - fi.session_id, fi.initiator_id -), -session_root AS ( - -- Build one summary row per session. Group by (session_id, initiator_id) - -- to avoid merging sessions from different users who happen to share the - -- same client_session_id. The ARRAY_AGG with ORDER BY picks values from - -- the chronologically first interception for fields that should represent - -- the session as a whole (client, metadata). Threads are counted as - -- distinct root interception IDs: an interception with a NULL - -- thread_root_id is itself a thread root. - SELECT - fi.session_id, - fi.initiator_id, - (ARRAY_AGG(fi.client ORDER BY fi.started_at, fi.id))[1] AS client, - (ARRAY_AGG(fi.metadata ORDER BY fi.started_at, fi.id))[1] AS metadata, - ARRAY_AGG(DISTINCT fi.provider ORDER BY fi.provider) AS providers, - ARRAY_AGG(DISTINCT fi.model ORDER BY fi.model) AS models, - MIN(fi.started_at) AS started_at, - MAX(fi.ended_at) AS ended_at, - COUNT(DISTINCT COALESCE(fi.thread_root_id, fi.id)) AS threads, - -- Collect IDs for lateral prompt lookup. - ARRAY_AGG(fi.id) AS interception_ids - FROM - filtered_interceptions fi - GROUP BY - fi.session_id, fi.initiator_id + ai.session_id, ai.initiator_id + HAVING + -- Cursor pagination: uses a composite (started_at, session_id) + -- cursor to support keyset pagination. The less-than comparison + -- matches the DESC sort order so rows after the cursor come + -- later in results. The cursor value comes from cursor_pos to + -- guarantee single evaluation. + CASE + WHEN $1::text != '' THEN ( + (MIN(ai.started_at), ai.session_id) < ( + (SELECT started_at FROM cursor_pos), + $1::text + ) + ) + ELSE true + END + ORDER BY + MIN(ai.started_at) DESC, + ai.session_id DESC + LIMIT COALESCE(NULLIF($10::integer, 0), 100) + OFFSET $9 ) SELECT - sr.session_id, + sp.session_id, visible_users.id AS user_id, visible_users.username AS user_username, visible_users.name AS user_name, @@ -1444,51 +1436,52 @@ SELECT sr.models::text[] AS models, COALESCE(sr.client, '')::varchar(64) AS client, sr.metadata::jsonb AS metadata, - sr.started_at::timestamptz AS started_at, - sr.ended_at::timestamptz AS ended_at, - sr.threads, + sp.started_at::timestamptz AS started_at, + sp.ended_at::timestamptz AS ended_at, + sp.threads, COALESCE(st.input_tokens, 0)::bigint AS input_tokens, COALESCE(st.output_tokens, 0)::bigint AS output_tokens, COALESCE(slp.prompt, '') AS last_prompt FROM - session_root sr + session_page sp JOIN - visible_users ON visible_users.id = sr.initiator_id -LEFT JOIN - session_tokens st ON st.session_id = sr.session_id AND st.initiator_id = sr.initiator_id + visible_users ON visible_users.id = sp.initiator_id LEFT JOIN LATERAL ( - -- Lateral join to efficiently fetch only the most recent user prompt - -- across all interceptions in the session, avoiding a full aggregation. + SELECT + (ARRAY_AGG(ai.client ORDER BY ai.started_at, ai.id))[1] AS client, + (ARRAY_AGG(ai.metadata ORDER BY ai.started_at, ai.id))[1] AS metadata, + ARRAY_AGG(DISTINCT ai.provider ORDER BY ai.provider) AS providers, + ARRAY_AGG(DISTINCT ai.model ORDER BY ai.model) AS models, + ARRAY_AGG(ai.id) AS interception_ids + FROM aibridge_interceptions ai + WHERE ai.session_id = sp.session_id + AND ai.initiator_id = sp.initiator_id + AND ai.ended_at IS NOT NULL +) sr ON true +LEFT JOIN LATERAL ( + -- Aggregate tokens only for this session's interceptions. + SELECT + COALESCE(SUM(tu.input_tokens), 0)::bigint AS input_tokens, + COALESCE(SUM(tu.output_tokens), 0)::bigint AS output_tokens + FROM aibridge_token_usages tu + WHERE tu.interception_id = ANY(sr.interception_ids) +) st ON true +LEFT JOIN LATERAL ( + -- Fetch only the most recent user prompt across all interceptions + -- in the session. SELECT up.prompt FROM aibridge_user_prompts up WHERE up.interception_id = ANY(sr.interception_ids) ORDER BY up.created_at DESC, up.id DESC LIMIT 1 ) slp ON true -WHERE - -- Cursor pagination: uses a composite (started_at, session_id) cursor - -- to support keyset pagination. The less-than comparison matches the - -- DESC sort order so that rows after the cursor come later in results. - CASE - WHEN $1::text != '' THEN ( - (sr.started_at, sr.session_id) < ( - (SELECT started_at FROM session_root WHERE session_id = $1), - $1::text - ) - ) - ELSE true - END ORDER BY - sr.started_at DESC, - sr.session_id DESC -LIMIT COALESCE(NULLIF($3::integer, 0), 100) -OFFSET $2 + sp.started_at DESC, + sp.session_id DESC ` type ListAIBridgeSessionsParams struct { AfterSessionID string `db:"after_session_id" json:"after_session_id"` - Offset int32 `db:"offset_" json:"offset_"` - Limit int32 `db:"limit_" json:"limit_"` StartedAfter time.Time `db:"started_after" json:"started_after"` StartedBefore time.Time `db:"started_before" json:"started_before"` InitiatorID uuid.UUID `db:"initiator_id" json:"initiator_id"` @@ -1496,6 +1489,8 @@ type ListAIBridgeSessionsParams struct { Model string `db:"model" json:"model"` Client string `db:"client" json:"client"` SessionID string `db:"session_id" json:"session_id"` + Offset int32 `db:"offset_" json:"offset_"` + Limit int32 `db:"limit_" json:"limit_"` } type ListAIBridgeSessionsRow struct { @@ -1519,11 +1514,13 @@ type ListAIBridgeSessionsRow struct { // Returns paginated sessions with aggregated metadata, token counts, and // the most recent user prompt. A "session" is a logical grouping of // interceptions that share the same session_id (set by the client). +// +// Pagination-first strategy: identify the page of sessions cheaply via a +// single GROUP BY scan, then do expensive lateral joins (tokens, prompts, +// first-interception metadata) only for the ~page-size result set. func (q *sqlQuerier) ListAIBridgeSessions(ctx context.Context, arg ListAIBridgeSessionsParams) ([]ListAIBridgeSessionsRow, error) { rows, err := q.db.QueryContext(ctx, listAIBridgeSessions, arg.AfterSessionID, - arg.Offset, - arg.Limit, arg.StartedAfter, arg.StartedBefore, arg.InitiatorID, @@ -1531,6 +1528,8 @@ func (q *sqlQuerier) ListAIBridgeSessions(ctx context.Context, arg ListAIBridgeS arg.Model, arg.Client, arg.SessionID, + arg.Offset, + arg.Limit, ) if err != nil { return nil, err diff --git a/coderd/database/queries/aibridge.sql b/coderd/database/queries/aibridge.sql index f4f03ff1cc..005e6d01f7 100644 --- a/coderd/database/queries/aibridge.sql +++ b/coderd/database/queries/aibridge.sql @@ -454,95 +454,91 @@ WHERE -- Returns paginated sessions with aggregated metadata, token counts, and -- the most recent user prompt. A "session" is a logical grouping of -- interceptions that share the same session_id (set by the client). -WITH filtered_interceptions AS ( +-- +-- Pagination-first strategy: identify the page of sessions cheaply via a +-- single GROUP BY scan, then do expensive lateral joins (tokens, prompts, +-- first-interception metadata) only for the ~page-size result set. +WITH cursor_pos AS ( + -- Resolve the cursor's started_at once, outside the HAVING clause, + -- so the planner cannot accidentally re-evaluate it per group. + SELECT MIN(aibridge_interceptions.started_at) AS started_at + FROM aibridge_interceptions + WHERE aibridge_interceptions.session_id = @after_session_id AND aibridge_interceptions.ended_at IS NOT NULL +), +session_page AS ( + -- Paginate at the session level first; only cheap aggregates here. SELECT - aibridge_interceptions.* + ai.session_id, + ai.initiator_id, + MIN(ai.started_at) AS started_at, + MAX(ai.ended_at) AS ended_at, + COUNT(*) FILTER (WHERE ai.thread_root_id IS NULL) AS threads FROM - aibridge_interceptions + aibridge_interceptions ai WHERE -- Remove inflight interceptions (ones which lack an ended_at value). - aibridge_interceptions.ended_at IS NOT NULL + ai.ended_at IS NOT NULL -- Filter by time frame AND CASE - WHEN @started_after::timestamptz != '0001-01-01 00:00:00+00'::timestamptz THEN aibridge_interceptions.started_at >= @started_after::timestamptz + WHEN @started_after::timestamptz != '0001-01-01 00:00:00+00'::timestamptz THEN ai.started_at >= @started_after::timestamptz ELSE true END AND CASE - WHEN @started_before::timestamptz != '0001-01-01 00:00:00+00'::timestamptz THEN aibridge_interceptions.started_at <= @started_before::timestamptz + WHEN @started_before::timestamptz != '0001-01-01 00:00:00+00'::timestamptz THEN ai.started_at <= @started_before::timestamptz ELSE true END -- Filter initiator_id AND CASE - WHEN @initiator_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN aibridge_interceptions.initiator_id = @initiator_id::uuid + WHEN @initiator_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN ai.initiator_id = @initiator_id::uuid ELSE true END -- Filter provider AND CASE - WHEN @provider::text != '' THEN aibridge_interceptions.provider = @provider::text + WHEN @provider::text != '' THEN ai.provider = @provider::text ELSE true END -- Filter model AND CASE - WHEN @model::text != '' THEN aibridge_interceptions.model = @model::text + WHEN @model::text != '' THEN ai.model = @model::text ELSE true END -- Filter client AND CASE - WHEN @client::text != '' THEN COALESCE(aibridge_interceptions.client, 'Unknown') = @client::text + WHEN @client::text != '' THEN COALESCE(ai.client, 'Unknown') = @client::text ELSE true END -- Filter session_id AND CASE - WHEN @session_id::text != '' THEN aibridge_interceptions.session_id = @session_id::text + WHEN @session_id::text != '' THEN ai.session_id = @session_id::text ELSE true END -- Authorize Filter clause will be injected below in ListAuthorizedAIBridgeSessions -- @authorize_filter -), -session_tokens AS ( - -- Aggregate token usage across all interceptions in each session. - -- Group by (session_id, initiator_id) to avoid merging sessions from - -- different users who happen to share the same client_session_id. - SELECT - fi.session_id, - fi.initiator_id, - COALESCE(SUM(tu.input_tokens), 0)::bigint AS input_tokens, - COALESCE(SUM(tu.output_tokens), 0)::bigint AS output_tokens - -- TODO: add extra token types once https://github.com/coder/aibridge/issues/150 lands. - FROM - filtered_interceptions fi - LEFT JOIN - aibridge_token_usages tu ON fi.id = tu.interception_id GROUP BY - fi.session_id, fi.initiator_id -), -session_root AS ( - -- Build one summary row per session. Group by (session_id, initiator_id) - -- to avoid merging sessions from different users who happen to share the - -- same client_session_id. The ARRAY_AGG with ORDER BY picks values from - -- the chronologically first interception for fields that should represent - -- the session as a whole (client, metadata). Threads are counted as - -- distinct root interception IDs: an interception with a NULL - -- thread_root_id is itself a thread root. - SELECT - fi.session_id, - fi.initiator_id, - (ARRAY_AGG(fi.client ORDER BY fi.started_at, fi.id))[1] AS client, - (ARRAY_AGG(fi.metadata ORDER BY fi.started_at, fi.id))[1] AS metadata, - ARRAY_AGG(DISTINCT fi.provider ORDER BY fi.provider) AS providers, - ARRAY_AGG(DISTINCT fi.model ORDER BY fi.model) AS models, - MIN(fi.started_at) AS started_at, - MAX(fi.ended_at) AS ended_at, - COUNT(DISTINCT COALESCE(fi.thread_root_id, fi.id)) AS threads, - -- Collect IDs for lateral prompt lookup. - ARRAY_AGG(fi.id) AS interception_ids - FROM - filtered_interceptions fi - GROUP BY - fi.session_id, fi.initiator_id + ai.session_id, ai.initiator_id + HAVING + -- Cursor pagination: uses a composite (started_at, session_id) + -- cursor to support keyset pagination. The less-than comparison + -- matches the DESC sort order so rows after the cursor come + -- later in results. The cursor value comes from cursor_pos to + -- guarantee single evaluation. + CASE + WHEN @after_session_id::text != '' THEN ( + (MIN(ai.started_at), ai.session_id) < ( + (SELECT started_at FROM cursor_pos), + @after_session_id::text + ) + ) + ELSE true + END + ORDER BY + MIN(ai.started_at) DESC, + ai.session_id DESC + LIMIT COALESCE(NULLIF(@limit_::integer, 0), 100) + OFFSET @offset_ ) SELECT - sr.session_id, + sp.session_id, visible_users.id AS user_id, visible_users.username AS user_username, visible_users.name AS user_name, @@ -551,45 +547,48 @@ SELECT sr.models::text[] AS models, COALESCE(sr.client, '')::varchar(64) AS client, sr.metadata::jsonb AS metadata, - sr.started_at::timestamptz AS started_at, - sr.ended_at::timestamptz AS ended_at, - sr.threads, + sp.started_at::timestamptz AS started_at, + sp.ended_at::timestamptz AS ended_at, + sp.threads, COALESCE(st.input_tokens, 0)::bigint AS input_tokens, COALESCE(st.output_tokens, 0)::bigint AS output_tokens, COALESCE(slp.prompt, '') AS last_prompt FROM - session_root sr + session_page sp JOIN - visible_users ON visible_users.id = sr.initiator_id -LEFT JOIN - session_tokens st ON st.session_id = sr.session_id AND st.initiator_id = sr.initiator_id + visible_users ON visible_users.id = sp.initiator_id LEFT JOIN LATERAL ( - -- Lateral join to efficiently fetch only the most recent user prompt - -- across all interceptions in the session, avoiding a full aggregation. + SELECT + (ARRAY_AGG(ai.client ORDER BY ai.started_at, ai.id))[1] AS client, + (ARRAY_AGG(ai.metadata ORDER BY ai.started_at, ai.id))[1] AS metadata, + ARRAY_AGG(DISTINCT ai.provider ORDER BY ai.provider) AS providers, + ARRAY_AGG(DISTINCT ai.model ORDER BY ai.model) AS models, + ARRAY_AGG(ai.id) AS interception_ids + FROM aibridge_interceptions ai + WHERE ai.session_id = sp.session_id + AND ai.initiator_id = sp.initiator_id + AND ai.ended_at IS NOT NULL +) sr ON true +LEFT JOIN LATERAL ( + -- Aggregate tokens only for this session's interceptions. + SELECT + COALESCE(SUM(tu.input_tokens), 0)::bigint AS input_tokens, + COALESCE(SUM(tu.output_tokens), 0)::bigint AS output_tokens + FROM aibridge_token_usages tu + WHERE tu.interception_id = ANY(sr.interception_ids) +) st ON true +LEFT JOIN LATERAL ( + -- Fetch only the most recent user prompt across all interceptions + -- in the session. SELECT up.prompt FROM aibridge_user_prompts up WHERE up.interception_id = ANY(sr.interception_ids) ORDER BY up.created_at DESC, up.id DESC LIMIT 1 ) slp ON true -WHERE - -- Cursor pagination: uses a composite (started_at, session_id) cursor - -- to support keyset pagination. The less-than comparison matches the - -- DESC sort order so that rows after the cursor come later in results. - CASE - WHEN @after_session_id::text != '' THEN ( - (sr.started_at, sr.session_id) < ( - (SELECT started_at FROM session_root WHERE session_id = @after_session_id), - @after_session_id::text - ) - ) - ELSE true - END ORDER BY - sr.started_at DESC, - sr.session_id DESC -LIMIT COALESCE(NULLIF(@limit_::integer, 0), 100) -OFFSET @offset_ + sp.started_at DESC, + sp.session_id DESC ; -- name: ListAIBridgeSessionThreads :many diff --git a/enterprise/coderd/aibridge_test.go b/enterprise/coderd/aibridge_test.go index 6f2f80258b..8f088c3689 100644 --- a/enterprise/coderd/aibridge_test.go +++ b/enterprise/coderd/aibridge_test.go @@ -440,7 +440,7 @@ func TestAIBridgeListInterceptions(t *testing.T) { }, { name: "Client/Unknown", - filter: codersdk.AIBridgeListInterceptionsFilter{Client: "Unknown"}, + filter: codersdk.AIBridgeListInterceptionsFilter{Client: string(aiblib.ClientUnknown)}, want: []codersdk.AIBridgeInterception{i1SDK}, }, { @@ -1213,6 +1213,302 @@ func TestAIBridgeListSessions(t *testing.T) { require.Contains(t, sdkErr.Message, "Invalid pagination limit value.") require.Empty(t, res.Sessions) }) + + t.Run("StartedBeforeFilter", func(t *testing.T) { + t.Parallel() + client, db, firstUser := coderdenttest.NewWithDatabase(t, aibridgeOpts(t)) + ctx := testutil.Context(t, testutil.WaitLong) + + now := dbtime.Now() + + // Session started recently. + recentEndedAt := now.Add(time.Minute) + dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{ + InitiatorID: firstUser.UserID, + StartedAt: now, + }, &recentEndedAt) + + // Session started 2 hours ago. + oldEndedAt := now.Add(-2*time.Hour + time.Minute) + old := dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{ + InitiatorID: firstUser.UserID, + StartedAt: now.Add(-2 * time.Hour), + }, &oldEndedAt) + + // Only the old session should be returned when started_before + // is set to 1 hour ago. + //nolint:gocritic // Owner role is irrelevant; testing filter. + res, err := client.AIBridgeListSessions(ctx, codersdk.AIBridgeListSessionsFilter{ + StartedBefore: now.Add(-time.Hour), + }) + require.NoError(t, err) + require.EqualValues(t, 1, res.Count) + require.Len(t, res.Sessions, 1) + require.Equal(t, old.ID.String(), res.Sessions[0].ID) + }) + + t.Run("NullClientCoalescesToUnknown", func(t *testing.T) { + t.Parallel() + client, db, firstUser := coderdenttest.NewWithDatabase(t, aibridgeOpts(t)) + ctx := testutil.Context(t, testutil.WaitLong) + + now := dbtime.Now() + + // Session with explicit client. + withClientEndedAt := now.Add(time.Minute) + dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{ + InitiatorID: firstUser.UserID, + StartedAt: now, + Client: sql.NullString{String: "claude-code", Valid: true}, + }, &withClientEndedAt) + + // Session with NULL client (should COALESCE to ClientUnknown). + nullClientEndedAt := now.Add(-time.Hour + time.Minute) + nullClient := dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{ + InitiatorID: firstUser.UserID, + StartedAt: now.Add(-time.Hour), + // Client field deliberately omitted (NULL). + }, &nullClientEndedAt) + + // Filtering by ClientUnknown should return only the NULL-client + // session. + //nolint:gocritic // Owner role is irrelevant; testing COALESCE. + res, err := client.AIBridgeListSessions(ctx, codersdk.AIBridgeListSessionsFilter{ + Client: string(aiblib.ClientUnknown), + }) + require.NoError(t, err) + require.EqualValues(t, 1, res.Count) + require.Len(t, res.Sessions, 1) + require.Equal(t, nullClient.ID.String(), res.Sessions[0].ID) + }) + + t.Run("MetadataFromFirstInterception", func(t *testing.T) { + t.Parallel() + client, db, firstUser := coderdenttest.NewWithDatabase(t, aibridgeOpts(t)) + ctx := testutil.Context(t, testutil.WaitLong) + + now := dbtime.Now() + + // First interception (chronologically) carries the expected + // metadata for the session. + i1EndedAt := now.Add(time.Minute) + dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{ + InitiatorID: firstUser.UserID, + StartedAt: now, + Metadata: json.RawMessage(`{"editor":"vscode"}`), + Client: sql.NullString{String: "claude-code", Valid: true}, + ClientSessionID: sql.NullString{String: "meta-session", Valid: true}, + }, &i1EndedAt) + + // Second interception has different metadata. + i2EndedAt := now.Add(2 * time.Minute) + dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{ + InitiatorID: firstUser.UserID, + StartedAt: now.Add(time.Minute), + Metadata: json.RawMessage(`{"editor":"jetbrains"}`), + Client: sql.NullString{String: "claude-code", Valid: true}, + ClientSessionID: sql.NullString{String: "meta-session", Valid: true}, + }, &i2EndedAt) + + //nolint:gocritic // Owner role is irrelevant; testing metadata. + res, err := client.AIBridgeListSessions(ctx, codersdk.AIBridgeListSessionsFilter{}) + require.NoError(t, err) + require.Len(t, res.Sessions, 1) + // Metadata should come from the first interception. + require.Equal(t, "vscode", res.Sessions[0].Metadata["editor"]) + }) + + t.Run("SessionTimestamps", func(t *testing.T) { + t.Parallel() + client, db, firstUser := coderdenttest.NewWithDatabase(t, aibridgeOpts(t)) + ctx := testutil.Context(t, testutil.WaitLong) + + now := dbtime.Now() + + // Two interceptions in the same session with different + // started_at and ended_at values. The session should report + // MIN(started_at) and MAX(ended_at). + i1StartedAt := now + i1EndedAt := now.Add(time.Minute) + dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{ + InitiatorID: firstUser.UserID, + StartedAt: i1StartedAt, + ClientSessionID: sql.NullString{String: "ts-session", Valid: true}, + }, &i1EndedAt) + + i2StartedAt := now.Add(2 * time.Minute) + i2EndedAt := now.Add(5 * time.Minute) + dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{ + InitiatorID: firstUser.UserID, + StartedAt: i2StartedAt, + ClientSessionID: sql.NullString{String: "ts-session", Valid: true}, + }, &i2EndedAt) + + //nolint:gocritic // Owner role is irrelevant; testing timestamps. + res, err := client.AIBridgeListSessions(ctx, codersdk.AIBridgeListSessionsFilter{}) + require.NoError(t, err) + require.Len(t, res.Sessions, 1) + s := res.Sessions[0] + require.WithinDuration(t, i1StartedAt, s.StartedAt, time.Millisecond, + "session started_at should be MIN of interception started_at values") + require.NotNil(t, s.EndedAt) + require.WithinDuration(t, i2EndedAt, *s.EndedAt, time.Millisecond, + "session ended_at should be MAX of interception ended_at values") + }) + + t.Run("LastPromptAcrossInterceptions", func(t *testing.T) { + t.Parallel() + client, db, firstUser := coderdenttest.NewWithDatabase(t, aibridgeOpts(t)) + ctx := testutil.Context(t, testutil.WaitLong) + + now := dbtime.Now() + + // Two interceptions in the same session. + i1EndedAt := now.Add(time.Minute) + i1 := dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{ + InitiatorID: firstUser.UserID, + StartedAt: now, + ClientSessionID: sql.NullString{String: "prompt-session", Valid: true}, + }, &i1EndedAt) + i2EndedAt := now.Add(3 * time.Minute) + i2 := dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{ + InitiatorID: firstUser.UserID, + StartedAt: now.Add(2 * time.Minute), + ClientSessionID: sql.NullString{String: "prompt-session", Valid: true}, + }, &i2EndedAt) + + // Add prompts to both interceptions. The most recent prompt + // overall belongs to the second interception. + dbgen.AIBridgeUserPrompt(t, db, database.InsertAIBridgeUserPromptParams{ + InterceptionID: i1.ID, + Prompt: "early prompt from i1", + CreatedAt: now, + }) + dbgen.AIBridgeUserPrompt(t, db, database.InsertAIBridgeUserPromptParams{ + InterceptionID: i2.ID, + Prompt: "latest prompt from i2", + CreatedAt: now.Add(2 * time.Minute), + }) + + //nolint:gocritic // Owner role is irrelevant; testing lateral join. + res, err := client.AIBridgeListSessions(ctx, codersdk.AIBridgeListSessionsFilter{}) + require.NoError(t, err) + require.Len(t, res.Sessions, 1) + require.NotNil(t, res.Sessions[0].LastPrompt) + require.Equal(t, "latest prompt from i2", *res.Sessions[0].LastPrompt, + "last_prompt should be the most recent prompt across all interceptions in the session") + }) + + t.Run("CombinedFilters", func(t *testing.T) { + t.Parallel() + client, db, firstUser := coderdenttest.NewWithDatabase(t, aibridgeOpts(t)) + ctx := testutil.Context(t, testutil.WaitLong) + + _, user2 := coderdtest.CreateAnotherUser(t, client, firstUser.OrganizationID) + + now := dbtime.Now() + + // Session A: user1, anthropic, claude-4, started now. + aEndedAt := now.Add(time.Minute) + a := dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{ + InitiatorID: firstUser.UserID, + Provider: "anthropic", + Model: "claude-4", + StartedAt: now, + }, &aEndedAt) + + // Session B: user1, anthropic, gpt-4, started 2h ago. + bEndedAt := now.Add(-2*time.Hour + time.Minute) + dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{ + InitiatorID: firstUser.UserID, + Provider: "anthropic", + Model: "gpt-4", + StartedAt: now.Add(-2 * time.Hour), + }, &bEndedAt) + + // Session C: user2, anthropic, claude-4, started 1h ago. + cEndedAt := now.Add(-time.Hour + time.Minute) + dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{ + InitiatorID: user2.ID, + Provider: "anthropic", + Model: "claude-4", + StartedAt: now.Add(-time.Hour), + }, &cEndedAt) + + // Combining provider + model + started_after should return + // only session A (user1, anthropic, claude-4, recent). + //nolint:gocritic // Owner role is irrelevant; testing combined filters. + res, err := client.AIBridgeListSessions(ctx, codersdk.AIBridgeListSessionsFilter{ + Provider: "anthropic", + Model: "claude-4", + StartedAfter: now.Add(-30 * time.Minute), + }) + require.NoError(t, err) + require.EqualValues(t, 1, res.Count) + require.Len(t, res.Sessions, 1) + require.Equal(t, a.ID.String(), res.Sessions[0].ID) + }) + + t.Run("CursorPaginationWithTiedStartedAt", func(t *testing.T) { + t.Parallel() + client, db, firstUser := coderdenttest.NewWithDatabase(t, aibridgeOpts(t)) + ctx := testutil.Context(t, testutil.WaitLong) + + now := dbtime.Now() + + // Create 3 standalone sessions all starting at the same time. + // The tie-breaker is session_id DESC. + for range 3 { + endedAt := now.Add(time.Minute) + dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{ + InitiatorID: firstUser.UserID, + StartedAt: now, + }, &endedAt) + } + + // Fetch all to learn the sort order (started_at DESC, + // session_id DESC). + //nolint:gocritic // Owner role is irrelevant; testing cursor. + all, err := client.AIBridgeListSessions(ctx, codersdk.AIBridgeListSessionsFilter{}) + require.NoError(t, err) + require.Len(t, all.Sessions, 3) + + // Use the first result as cursor. The remaining 2 should be + // returned. + afterID := all.Sessions[0].ID + page, err := client.AIBridgeListSessions(ctx, codersdk.AIBridgeListSessionsFilter{ + Pagination: codersdk.Pagination{Limit: 10}, + AfterSessionID: afterID, + }) + require.NoError(t, err) + require.Len(t, page.Sessions, 2) + require.Equal(t, all.Sessions[1].ID, page.Sessions[0].ID) + require.Equal(t, all.Sessions[2].ID, page.Sessions[1].ID) + }) + + t.Run("DefaultLimit", func(t *testing.T) { + t.Parallel() + client, db, firstUser := coderdenttest.NewWithDatabase(t, aibridgeOpts(t)) + ctx := testutil.Context(t, testutil.WaitLong) + + now := dbtime.Now() + // Create 3 sessions. Without an explicit limit the default of + // 100 should apply and return all 3. + for i := range 3 { + endedAt := now.Add(-time.Duration(i)*time.Hour + time.Minute) + dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{ + InitiatorID: firstUser.UserID, + StartedAt: now.Add(-time.Duration(i) * time.Hour), + }, &endedAt) + } + + // No Pagination.Limit set. + //nolint:gocritic // Owner role is irrelevant; testing default limit. + res, err := client.AIBridgeListSessions(ctx, codersdk.AIBridgeListSessionsFilter{}) + require.NoError(t, err) + require.Len(t, res.Sessions, 3) + require.EqualValues(t, 3, res.Count) + }) } func TestAIBridgeListClients(t *testing.T) {