From 43a1af3cd667cf8f14bf6272d765fc07d09b5d31 Mon Sep 17 00:00:00 2001 From: Danny Kopping Date: Tue, 24 Mar 2026 08:58:47 +0200 Subject: [PATCH] feat: session list API (#23202) _Disclaimer:_ _initially_ _produced_ _by_ _Claude_ _Opus_ _4\.6,_ _heavily_ _modified_ _and_ _reviewed_ _by_ _me._ Closes https://github.com/coder/internal/issues/1360 Adds a new `/api/v2/aibridge/sessions` API which returns "sessions". Sessions, as defined in the [RFC](https://www.notion.so/coderhq/AI-Bridge-Sessions-Threads-2ccd579be59280f28021d3baf7472fbe?source=copy_link), are a set of interceptions logically grouped by a session key issued by the client. The API design for this endpoint was done in [this doc](https://github.com/coder/internal/issues/1360). If the client has not provided a session ID, we will revert to the thread root ID, and if that's not present we use the interception's own ID (i.e. a session of a single interception - which is effectively what we show currently in our `/api/v2/aibridge/interceptions` API). The SQL query looks gnarly but it's relatively simple, and seems to perform well (~200ms) even when I import dogfood's `aibridge_*` tables into my workspace. If we need to improve performance on this later we can investigate materialized views, perhaps, but for now I don't think it's warranted. --- _The PR looks large but it's got a lot of generated code; the actual changes aren't huge._ --- coderd/apidoc/docs.go | 123 ++++ coderd/apidoc/swagger.json | 119 ++++ coderd/database/db2sdk/db2sdk.go | 38 ++ coderd/database/dbauthz/dbauthz.go | 24 + coderd/database/dbauthz/dbauthz_test.go | 28 + coderd/database/dbmetrics/querymetrics.go | 32 + coderd/database/dbmock/dbmock.go | 60 ++ coderd/database/dump.sql | 11 +- .../000449_aibridge_session_indexes.down.sql | 5 + .../000449_aibridge_session_indexes.up.sql | 40 ++ coderd/database/modelqueries.go | 106 ++++ coderd/database/models.go | 2 + coderd/database/querier.go | 5 + coderd/database/queries.sql.go | 309 +++++++++- coderd/database/queries/aibridge.sql | 188 ++++++ coderd/searchquery/search.go | 43 ++ codersdk/aibridge.go | 97 +++ docs/ai-coder/ai-bridge/monitoring.md | 4 +- docs/reference/api/aibridge.md | 70 +++ docs/reference/api/schemas.md | 111 ++++ enterprise/coderd/aibridge.go | 129 ++++ enterprise/coderd/aibridge_test.go | 554 ++++++++++++++++++ site/src/api/typesGenerated.ts | 28 + 23 files changed, 2118 insertions(+), 8 deletions(-) create mode 100644 coderd/database/migrations/000449_aibridge_session_indexes.down.sql create mode 100644 coderd/database/migrations/000449_aibridge_session_indexes.up.sql diff --git a/coderd/apidoc/docs.go b/coderd/apidoc/docs.go index 683fb1a4e1..2d18a69576 100644 --- a/coderd/apidoc/docs.go +++ b/coderd/apidoc/docs.go @@ -163,6 +163,57 @@ const docTemplate = `{ ] } }, + "/aibridge/sessions": { + "get": { + "produces": [ + "application/json" + ], + "tags": [ + "AI Bridge" + ], + "summary": "List AI Bridge sessions", + "operationId": "list-ai-bridge-sessions", + "parameters": [ + { + "type": "string", + "description": "Search query in the format ` + "`" + `key:value` + "`" + `. Available keys are: initiator, provider, model, client, session_id, started_after, started_before.", + "name": "q", + "in": "query" + }, + { + "type": "integer", + "description": "Page limit", + "name": "limit", + "in": "query" + }, + { + "type": "string", + "description": "Cursor pagination after session ID (cannot be used with offset)", + "name": "after_session_id", + "in": "query" + }, + { + "type": "integer", + "description": "Offset pagination (cannot be used with after_session_id)", + "name": "offset", + "in": "query" + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/codersdk.AIBridgeListSessionsResponse" + } + } + }, + "security": [ + { + "CoderSessionToken": [] + } + ] + } + }, "/appearance": { "get": { "produces": [ @@ -12778,6 +12829,20 @@ const docTemplate = `{ } } }, + "codersdk.AIBridgeListSessionsResponse": { + "type": "object", + "properties": { + "count": { + "type": "integer" + }, + "sessions": { + "type": "array", + "items": { + "$ref": "#/definitions/codersdk.AIBridgeSession" + } + } + } + }, "codersdk.AIBridgeOpenAIConfig": { "type": "object", "properties": { @@ -12830,6 +12895,64 @@ const docTemplate = `{ } } }, + "codersdk.AIBridgeSession": { + "type": "object", + "properties": { + "client": { + "type": "string" + }, + "ended_at": { + "type": "string", + "format": "date-time" + }, + "id": { + "type": "string" + }, + "initiator": { + "$ref": "#/definitions/codersdk.MinimalUser" + }, + "last_prompt": { + "type": "string" + }, + "metadata": { + "type": "object", + "additionalProperties": {} + }, + "models": { + "type": "array", + "items": { + "type": "string" + } + }, + "providers": { + "type": "array", + "items": { + "type": "string" + } + }, + "started_at": { + "type": "string", + "format": "date-time" + }, + "threads": { + "type": "integer" + }, + "token_usage_summary": { + "$ref": "#/definitions/codersdk.AIBridgeSessionTokenUsageSummary" + } + } + }, + "codersdk.AIBridgeSessionTokenUsageSummary": { + "type": "object", + "properties": { + "input_tokens": { + "type": "integer" + }, + "output_tokens": { + "type": "integer" + } + } + }, "codersdk.AIBridgeTokenUsage": { "type": "object", "properties": { diff --git a/coderd/apidoc/swagger.json b/coderd/apidoc/swagger.json index d5a74adb09..075536ae96 100644 --- a/coderd/apidoc/swagger.json +++ b/coderd/apidoc/swagger.json @@ -136,6 +136,53 @@ ] } }, + "/aibridge/sessions": { + "get": { + "produces": ["application/json"], + "tags": ["AI Bridge"], + "summary": "List AI Bridge sessions", + "operationId": "list-ai-bridge-sessions", + "parameters": [ + { + "type": "string", + "description": "Search query in the format `key:value`. Available keys are: initiator, provider, model, client, session_id, started_after, started_before.", + "name": "q", + "in": "query" + }, + { + "type": "integer", + "description": "Page limit", + "name": "limit", + "in": "query" + }, + { + "type": "string", + "description": "Cursor pagination after session ID (cannot be used with offset)", + "name": "after_session_id", + "in": "query" + }, + { + "type": "integer", + "description": "Offset pagination (cannot be used with after_session_id)", + "name": "offset", + "in": "query" + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/codersdk.AIBridgeListSessionsResponse" + } + } + }, + "security": [ + { + "CoderSessionToken": [] + } + ] + } + }, "/appearance": { "get": { "produces": ["application/json"], @@ -11368,6 +11415,20 @@ } } }, + "codersdk.AIBridgeListSessionsResponse": { + "type": "object", + "properties": { + "count": { + "type": "integer" + }, + "sessions": { + "type": "array", + "items": { + "$ref": "#/definitions/codersdk.AIBridgeSession" + } + } + } + }, "codersdk.AIBridgeOpenAIConfig": { "type": "object", "properties": { @@ -11420,6 +11481,64 @@ } } }, + "codersdk.AIBridgeSession": { + "type": "object", + "properties": { + "client": { + "type": "string" + }, + "ended_at": { + "type": "string", + "format": "date-time" + }, + "id": { + "type": "string" + }, + "initiator": { + "$ref": "#/definitions/codersdk.MinimalUser" + }, + "last_prompt": { + "type": "string" + }, + "metadata": { + "type": "object", + "additionalProperties": {} + }, + "models": { + "type": "array", + "items": { + "type": "string" + } + }, + "providers": { + "type": "array", + "items": { + "type": "string" + } + }, + "started_at": { + "type": "string", + "format": "date-time" + }, + "threads": { + "type": "integer" + }, + "token_usage_summary": { + "$ref": "#/definitions/codersdk.AIBridgeSessionTokenUsageSummary" + } + } + }, + "codersdk.AIBridgeSessionTokenUsageSummary": { + "type": "object", + "properties": { + "input_tokens": { + "type": "integer" + }, + "output_tokens": { + "type": "integer" + } + } + }, "codersdk.AIBridgeTokenUsage": { "type": "object", "properties": { diff --git a/coderd/database/db2sdk/db2sdk.go b/coderd/database/db2sdk/db2sdk.go index 77b280ff7b..d79d35b88b 100644 --- a/coderd/database/db2sdk/db2sdk.go +++ b/coderd/database/db2sdk/db2sdk.go @@ -1021,6 +1021,44 @@ func AIBridgeInterception(interception database.AIBridgeInterception, initiator return intc } +func AIBridgeSession(row database.ListAIBridgeSessionsRow) codersdk.AIBridgeSession { + session := codersdk.AIBridgeSession{ + ID: row.SessionID, + Initiator: MinimalUserFromVisibleUser(database.VisibleUser{ + ID: row.UserID, + Username: row.UserUsername, + Name: row.UserName, + AvatarURL: row.UserAvatarUrl, + }), + Providers: row.Providers, + Models: row.Models, + Metadata: jsonOrEmptyMap(pqtype.NullRawMessage{RawMessage: row.Metadata, Valid: len(row.Metadata) > 0}), + StartedAt: row.StartedAt, + Threads: row.Threads, + TokenUsageSummary: codersdk.AIBridgeSessionTokenUsageSummary{ + InputTokens: row.InputTokens, + OutputTokens: row.OutputTokens, + }, + } + // Ensure non-nil slices for JSON serialization. + if session.Providers == nil { + session.Providers = []string{} + } + if session.Models == nil { + session.Models = []string{} + } + if row.Client != "" { + session.Client = &row.Client + } + if !row.EndedAt.IsZero() { + session.EndedAt = &row.EndedAt + } + if row.LastPrompt != "" { + session.LastPrompt = &row.LastPrompt + } + return session +} + func AIBridgeTokenUsage(usage database.AIBridgeTokenUsage) codersdk.AIBridgeTokenUsage { return codersdk.AIBridgeTokenUsage{ ID: usage.ID, diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index 310d5de494..321f325ef2 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -1709,6 +1709,14 @@ func (q *querier) CountAIBridgeInterceptions(ctx context.Context, arg database.C return q.db.CountAuthorizedAIBridgeInterceptions(ctx, arg, prep) } +func (q *querier) CountAIBridgeSessions(ctx context.Context, arg database.CountAIBridgeSessionsParams) (int64, error) { + prep, err := prepareSQLFilter(ctx, q.auth, policy.ActionRead, rbac.ResourceAibridgeInterception.Type) + if err != nil { + return 0, xerrors.Errorf("(dev error) prepare sql filter: %w", err) + } + return q.db.CountAuthorizedAIBridgeSessions(ctx, arg, prep) +} + func (q *querier) CountAuditLogs(ctx context.Context, arg database.CountAuditLogsParams) (int64, error) { // Shortcut if the user is an owner. The SQL filter is noticeable, // and this is an easy win for owners. Which is the common case. @@ -5317,6 +5325,14 @@ func (q *querier) ListAIBridgeModels(ctx context.Context, arg database.ListAIBri return q.db.ListAuthorizedAIBridgeModels(ctx, arg, prep) } +func (q *querier) ListAIBridgeSessions(ctx context.Context, arg database.ListAIBridgeSessionsParams) ([]database.ListAIBridgeSessionsRow, error) { + prep, err := prepareSQLFilter(ctx, q.auth, policy.ActionRead, rbac.ResourceAibridgeInterception.Type) + if err != nil { + return nil, xerrors.Errorf("(dev error) prepare sql filter: %w", err) + } + return q.db.ListAuthorizedAIBridgeSessions(ctx, arg, prep) +} + func (q *querier) ListAIBridgeTokenUsagesByInterceptionIDs(ctx context.Context, interceptionIDs []uuid.UUID) ([]database.AIBridgeTokenUsage, error) { // This function is a system function until we implement a join for aibridge interceptions. // Matches the behavior of the workspaces listing endpoint. @@ -7128,6 +7144,14 @@ func (q *querier) ListAuthorizedAIBridgeModels(ctx context.Context, arg database return q.ListAIBridgeModels(ctx, arg) } +func (q *querier) ListAuthorizedAIBridgeSessions(ctx context.Context, arg database.ListAIBridgeSessionsParams, prepared rbac.PreparedAuthorized) ([]database.ListAIBridgeSessionsRow, error) { + return q.db.ListAuthorizedAIBridgeSessions(ctx, arg, prepared) +} + +func (q *querier) CountAuthorizedAIBridgeSessions(ctx context.Context, arg database.CountAIBridgeSessionsParams, prepared rbac.PreparedAuthorized) (int64, error) { + return q.db.CountAuthorizedAIBridgeSessions(ctx, arg, prepared) +} + func (q *querier) GetAuthorizedChats(ctx context.Context, arg database.GetChatsParams, _ rbac.PreparedAuthorized) ([]database.Chat, error) { return q.GetChats(ctx, arg) } diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go index 8a9e4988fd..b38ff83ffd 100644 --- a/coderd/database/dbauthz/dbauthz_test.go +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -5514,6 +5514,34 @@ func (s *MethodTestSuite) TestAIBridge() { check.Args(params, emptyPreparedAuthorized{}).Asserts() })) + s.Run("ListAIBridgeSessions", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + params := database.ListAIBridgeSessionsParams{} + db.EXPECT().ListAuthorizedAIBridgeSessions(gomock.Any(), params, gomock.Any()).Return([]database.ListAIBridgeSessionsRow{}, nil).AnyTimes() + // No asserts here because SQLFilter. + check.Args(params).Asserts() + })) + + s.Run("ListAuthorizedAIBridgeSessions", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + params := database.ListAIBridgeSessionsParams{} + db.EXPECT().ListAuthorizedAIBridgeSessions(gomock.Any(), params, gomock.Any()).Return([]database.ListAIBridgeSessionsRow{}, nil).AnyTimes() + // No asserts here because SQLFilter. + check.Args(params, emptyPreparedAuthorized{}).Asserts() + })) + + s.Run("CountAIBridgeSessions", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + params := database.CountAIBridgeSessionsParams{} + db.EXPECT().CountAuthorizedAIBridgeSessions(gomock.Any(), params, gomock.Any()).Return(int64(0), nil).AnyTimes() + // No asserts here because SQLFilter. + check.Args(params).Asserts() + })) + + s.Run("CountAuthorizedAIBridgeSessions", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + params := database.CountAIBridgeSessionsParams{} + db.EXPECT().CountAuthorizedAIBridgeSessions(gomock.Any(), params, gomock.Any()).Return(int64(0), nil).AnyTimes() + // No asserts here because SQLFilter. + check.Args(params, emptyPreparedAuthorized{}).Asserts() + })) + s.Run("ListAIBridgeTokenUsagesByInterceptionIDs", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { ids := []uuid.UUID{{1}} db.EXPECT().ListAIBridgeTokenUsagesByInterceptionIDs(gomock.Any(), ids).Return([]database.AIBridgeTokenUsage{}, nil).AnyTimes() diff --git a/coderd/database/dbmetrics/querymetrics.go b/coderd/database/dbmetrics/querymetrics.go index 31f798cf34..c41e2a4647 100644 --- a/coderd/database/dbmetrics/querymetrics.go +++ b/coderd/database/dbmetrics/querymetrics.go @@ -280,6 +280,14 @@ func (m queryMetricsStore) CountAIBridgeInterceptions(ctx context.Context, arg d return r0, r1 } +func (m queryMetricsStore) CountAIBridgeSessions(ctx context.Context, arg database.CountAIBridgeSessionsParams) (int64, error) { + start := time.Now() + r0, r1 := m.s.CountAIBridgeSessions(ctx, arg) + m.queryLatencies.WithLabelValues("CountAIBridgeSessions").Observe(time.Since(start).Seconds()) + m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "CountAIBridgeSessions").Inc() + return r0, r1 +} + func (m queryMetricsStore) CountAuditLogs(ctx context.Context, arg database.CountAuditLogsParams) (int64, error) { start := time.Now() r0, r1 := m.s.CountAuditLogs(ctx, arg) @@ -3720,6 +3728,14 @@ func (m queryMetricsStore) ListAIBridgeModels(ctx context.Context, arg database. return r0, r1 } +func (m queryMetricsStore) ListAIBridgeSessions(ctx context.Context, arg database.ListAIBridgeSessionsParams) ([]database.ListAIBridgeSessionsRow, error) { + start := time.Now() + r0, r1 := m.s.ListAIBridgeSessions(ctx, arg) + m.queryLatencies.WithLabelValues("ListAIBridgeSessions").Observe(time.Since(start).Seconds()) + m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "ListAIBridgeSessions").Inc() + return r0, r1 +} + func (m queryMetricsStore) ListAIBridgeTokenUsagesByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]database.AIBridgeTokenUsage, error) { start := time.Now() r0, r1 := m.s.ListAIBridgeTokenUsagesByInterceptionIDs(ctx, interceptionIds) @@ -5136,6 +5152,22 @@ func (m queryMetricsStore) ListAuthorizedAIBridgeModels(ctx context.Context, arg return r0, r1 } +func (m queryMetricsStore) ListAuthorizedAIBridgeSessions(ctx context.Context, arg database.ListAIBridgeSessionsParams, prepared rbac.PreparedAuthorized) ([]database.ListAIBridgeSessionsRow, error) { + start := time.Now() + r0, r1 := m.s.ListAuthorizedAIBridgeSessions(ctx, arg, prepared) + m.queryLatencies.WithLabelValues("ListAuthorizedAIBridgeSessions").Observe(time.Since(start).Seconds()) + m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "ListAuthorizedAIBridgeSessions").Inc() + return r0, r1 +} + +func (m queryMetricsStore) CountAuthorizedAIBridgeSessions(ctx context.Context, arg database.CountAIBridgeSessionsParams, prepared rbac.PreparedAuthorized) (int64, error) { + start := time.Now() + r0, r1 := m.s.CountAuthorizedAIBridgeSessions(ctx, arg, prepared) + m.queryLatencies.WithLabelValues("CountAuthorizedAIBridgeSessions").Observe(time.Since(start).Seconds()) + m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "CountAuthorizedAIBridgeSessions").Inc() + return r0, r1 +} + func (m queryMetricsStore) GetAuthorizedChats(ctx context.Context, arg database.GetChatsParams, prepared rbac.PreparedAuthorized) ([]database.Chat, error) { start := time.Now() r0, r1 := m.s.GetAuthorizedChats(ctx, arg, prepared) diff --git a/coderd/database/dbmock/dbmock.go b/coderd/database/dbmock/dbmock.go index 82cc6c4145..a061d6ebab 100644 --- a/coderd/database/dbmock/dbmock.go +++ b/coderd/database/dbmock/dbmock.go @@ -363,6 +363,21 @@ func (mr *MockStoreMockRecorder) CountAIBridgeInterceptions(ctx, arg any) *gomoc return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountAIBridgeInterceptions", reflect.TypeOf((*MockStore)(nil).CountAIBridgeInterceptions), ctx, arg) } +// CountAIBridgeSessions mocks base method. +func (m *MockStore) CountAIBridgeSessions(ctx context.Context, arg database.CountAIBridgeSessionsParams) (int64, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CountAIBridgeSessions", ctx, arg) + ret0, _ := ret[0].(int64) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CountAIBridgeSessions indicates an expected call of CountAIBridgeSessions. +func (mr *MockStoreMockRecorder) CountAIBridgeSessions(ctx, arg any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountAIBridgeSessions", reflect.TypeOf((*MockStore)(nil).CountAIBridgeSessions), ctx, arg) +} + // CountAuditLogs mocks base method. func (m *MockStore) CountAuditLogs(ctx context.Context, arg database.CountAuditLogsParams) (int64, error) { m.ctrl.T.Helper() @@ -393,6 +408,21 @@ func (mr *MockStoreMockRecorder) CountAuthorizedAIBridgeInterceptions(ctx, arg, return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountAuthorizedAIBridgeInterceptions", reflect.TypeOf((*MockStore)(nil).CountAuthorizedAIBridgeInterceptions), ctx, arg, prepared) } +// CountAuthorizedAIBridgeSessions mocks base method. +func (m *MockStore) CountAuthorizedAIBridgeSessions(ctx context.Context, arg database.CountAIBridgeSessionsParams, prepared rbac.PreparedAuthorized) (int64, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CountAuthorizedAIBridgeSessions", ctx, arg, prepared) + ret0, _ := ret[0].(int64) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CountAuthorizedAIBridgeSessions indicates an expected call of CountAuthorizedAIBridgeSessions. +func (mr *MockStoreMockRecorder) CountAuthorizedAIBridgeSessions(ctx, arg, prepared any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountAuthorizedAIBridgeSessions", reflect.TypeOf((*MockStore)(nil).CountAuthorizedAIBridgeSessions), ctx, arg, prepared) +} + // CountAuthorizedAuditLogs mocks base method. func (m *MockStore) CountAuthorizedAuditLogs(ctx context.Context, arg database.CountAuditLogsParams, prepared rbac.PreparedAuthorized) (int64, error) { m.ctrl.T.Helper() @@ -6947,6 +6977,21 @@ func (mr *MockStoreMockRecorder) ListAIBridgeModels(ctx, arg any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListAIBridgeModels", reflect.TypeOf((*MockStore)(nil).ListAIBridgeModels), ctx, arg) } +// ListAIBridgeSessions mocks base method. +func (m *MockStore) ListAIBridgeSessions(ctx context.Context, arg database.ListAIBridgeSessionsParams) ([]database.ListAIBridgeSessionsRow, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ListAIBridgeSessions", ctx, arg) + ret0, _ := ret[0].([]database.ListAIBridgeSessionsRow) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ListAIBridgeSessions indicates an expected call of ListAIBridgeSessions. +func (mr *MockStoreMockRecorder) ListAIBridgeSessions(ctx, arg any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListAIBridgeSessions", reflect.TypeOf((*MockStore)(nil).ListAIBridgeSessions), ctx, arg) +} + // ListAIBridgeTokenUsagesByInterceptionIDs mocks base method. func (m *MockStore) ListAIBridgeTokenUsagesByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]database.AIBridgeTokenUsage, error) { m.ctrl.T.Helper() @@ -7022,6 +7067,21 @@ func (mr *MockStoreMockRecorder) ListAuthorizedAIBridgeModels(ctx, arg, prepared return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListAuthorizedAIBridgeModels", reflect.TypeOf((*MockStore)(nil).ListAuthorizedAIBridgeModels), ctx, arg, prepared) } +// ListAuthorizedAIBridgeSessions mocks base method. +func (m *MockStore) ListAuthorizedAIBridgeSessions(ctx context.Context, arg database.ListAIBridgeSessionsParams, prepared rbac.PreparedAuthorized) ([]database.ListAIBridgeSessionsRow, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ListAuthorizedAIBridgeSessions", ctx, arg, prepared) + ret0, _ := ret[0].([]database.ListAIBridgeSessionsRow) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ListAuthorizedAIBridgeSessions indicates an expected call of ListAuthorizedAIBridgeSessions. +func (mr *MockStoreMockRecorder) ListAuthorizedAIBridgeSessions(ctx, arg, prepared any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListAuthorizedAIBridgeSessions", reflect.TypeOf((*MockStore)(nil).ListAuthorizedAIBridgeSessions), ctx, arg, prepared) +} + // ListChatUsageLimitGroupOverrides mocks base method. func (m *MockStore) ListChatUsageLimitGroupOverrides(ctx context.Context) ([]database.ListChatUsageLimitGroupOverridesRow, error) { m.ctrl.T.Helper() diff --git a/coderd/database/dump.sql b/coderd/database/dump.sql index 3a77928e1f..f44556cd72 100644 --- a/coderd/database/dump.sql +++ b/coderd/database/dump.sql @@ -1099,7 +1099,8 @@ CREATE TABLE aibridge_interceptions ( client character varying(64) DEFAULT 'Unknown'::character varying, thread_parent_id uuid, thread_root_id uuid, - client_session_id character varying(256) + client_session_id character varying(256), + session_id text GENERATED ALWAYS AS (COALESCE(client_session_id, ((thread_root_id)::text)::character varying, ((id)::text)::character varying)) STORED NOT NULL ); COMMENT ON TABLE aibridge_interceptions IS 'Audit log of requests intercepted by AI Bridge'; @@ -1112,6 +1113,8 @@ COMMENT ON COLUMN aibridge_interceptions.thread_root_id IS 'The root interceptio COMMENT ON COLUMN aibridge_interceptions.client_session_id IS 'The session ID supplied by the client (optional and not universally supported).'; +COMMENT ON COLUMN aibridge_interceptions.session_id IS 'Groups related interceptions into a logical session. Determined by a priority chain: (1) client_session_id — an explicit session identifier supplied by the calling client (e.g. Claude Code); (2) thread_root_id — the root of an agentic thread detected by Bridge through tool-call correlation, used when the client does not supply its own session ID; (3) id — the interception''s own ID, used as a last resort so every interception belongs to exactly one session even if it is standalone. This is a generated column stored on disk so it can be indexed and joined without recomputing the COALESCE on every query.'; + CREATE TABLE aibridge_model_thoughts ( interception_id uuid NOT NULL, content text NOT NULL, @@ -3654,6 +3657,10 @@ CREATE INDEX idx_aibridge_interceptions_model ON aibridge_interceptions USING bt CREATE INDEX idx_aibridge_interceptions_provider ON aibridge_interceptions USING btree (provider); +CREATE INDEX idx_aibridge_interceptions_session_id ON aibridge_interceptions USING btree (session_id) WHERE (ended_at IS NOT NULL); + +CREATE INDEX idx_aibridge_interceptions_sessions_filter ON aibridge_interceptions USING btree (initiator_id, started_at DESC, id DESC) WHERE (ended_at IS NOT NULL); + CREATE INDEX idx_aibridge_interceptions_started_id_desc ON aibridge_interceptions USING btree (started_at DESC, id DESC); CREATE INDEX idx_aibridge_interceptions_thread_parent_id ON aibridge_interceptions USING btree (thread_parent_id); @@ -3672,6 +3679,8 @@ CREATE INDEX idx_aibridge_tool_usages_provider_tool_call_id ON aibridge_tool_usa CREATE INDEX idx_aibridge_tool_usagesprovider_response_id ON aibridge_tool_usages USING btree (provider_response_id); +CREATE INDEX idx_aibridge_user_prompts_interception_created ON aibridge_user_prompts USING btree (interception_id, created_at DESC, id DESC); + CREATE INDEX idx_aibridge_user_prompts_interception_id ON aibridge_user_prompts USING btree (interception_id); CREATE INDEX idx_aibridge_user_prompts_provider_response_id ON aibridge_user_prompts USING btree (provider_response_id); diff --git a/coderd/database/migrations/000449_aibridge_session_indexes.down.sql b/coderd/database/migrations/000449_aibridge_session_indexes.down.sql new file mode 100644 index 0000000000..7f510a7cc5 --- /dev/null +++ b/coderd/database/migrations/000449_aibridge_session_indexes.down.sql @@ -0,0 +1,5 @@ +DROP INDEX IF EXISTS idx_aibridge_interceptions_session_id; +DROP INDEX IF EXISTS idx_aibridge_user_prompts_interception_created; +DROP INDEX IF EXISTS idx_aibridge_interceptions_sessions_filter; + +ALTER TABLE aibridge_interceptions DROP COLUMN IF EXISTS session_id; diff --git a/coderd/database/migrations/000449_aibridge_session_indexes.up.sql b/coderd/database/migrations/000449_aibridge_session_indexes.up.sql new file mode 100644 index 0000000000..3927f9c1ba --- /dev/null +++ b/coderd/database/migrations/000449_aibridge_session_indexes.up.sql @@ -0,0 +1,40 @@ +-- A "session" groups related interceptions together. See the COMMENT ON +-- COLUMN below for the full business-logic description. +ALTER TABLE aibridge_interceptions + ADD COLUMN session_id TEXT NOT NULL + GENERATED ALWAYS AS ( + COALESCE( + client_session_id, + thread_root_id::text, + id::text + ) + ) STORED; + +-- Searching and grouping on the resolved session ID will be common. +CREATE INDEX idx_aibridge_interceptions_session_id + ON aibridge_interceptions (session_id) + WHERE ended_at IS NOT NULL; + +COMMENT ON COLUMN aibridge_interceptions.session_id IS + 'Groups related interceptions into a logical session. ' + 'Determined by a priority chain: ' + '(1) client_session_id — an explicit session identifier supplied by the ' + 'calling client (e.g. Claude Code); ' + '(2) thread_root_id — the root of an agentic thread detected by Bridge ' + 'through tool-call correlation, used when the client does not supply its ' + 'own session ID; ' + '(3) id — the interception''s own ID, used as a last resort so every ' + 'interception belongs to exactly one session even if it is standalone. ' + 'This is a generated column stored on disk so it can be indexed and ' + 'joined without recomputing the COALESCE on every query.'; + +-- Composite index for the most common filter path used by +-- ListAIBridgeSessions: initiator_id equality + started_at range, +-- with ended_at IS NOT NULL as a partial filter. +CREATE INDEX idx_aibridge_interceptions_sessions_filter + ON aibridge_interceptions (initiator_id, started_at DESC, id DESC) + WHERE ended_at IS NOT NULL; + +-- Supports lateral prompt lookup by interception + recency. +CREATE INDEX idx_aibridge_user_prompts_interception_created + ON aibridge_user_prompts (interception_id, created_at DESC, id DESC); diff --git a/coderd/database/modelqueries.go b/coderd/database/modelqueries.go index c39b06202b..6a8aed8108 100644 --- a/coderd/database/modelqueries.go +++ b/coderd/database/modelqueries.go @@ -806,6 +806,8 @@ type aibridgeQuerier interface { ListAuthorizedAIBridgeInterceptions(ctx context.Context, arg ListAIBridgeInterceptionsParams, prepared rbac.PreparedAuthorized) ([]ListAIBridgeInterceptionsRow, error) CountAuthorizedAIBridgeInterceptions(ctx context.Context, arg CountAIBridgeInterceptionsParams, prepared rbac.PreparedAuthorized) (int64, error) ListAuthorizedAIBridgeModels(ctx context.Context, arg ListAIBridgeModelsParams, prepared rbac.PreparedAuthorized) ([]string, error) + ListAuthorizedAIBridgeSessions(ctx context.Context, arg ListAIBridgeSessionsParams, prepared rbac.PreparedAuthorized) ([]ListAIBridgeSessionsRow, error) + CountAuthorizedAIBridgeSessions(ctx context.Context, arg CountAIBridgeSessionsParams, prepared rbac.PreparedAuthorized) (int64, error) } func (q *sqlQuerier) ListAuthorizedAIBridgeInterceptions(ctx context.Context, arg ListAIBridgeInterceptionsParams, prepared rbac.PreparedAuthorized) ([]ListAIBridgeInterceptionsRow, error) { @@ -852,6 +854,7 @@ func (q *sqlQuerier) ListAuthorizedAIBridgeInterceptions(ctx context.Context, ar &i.AIBridgeInterception.ThreadParentID, &i.AIBridgeInterception.ThreadRootID, &i.AIBridgeInterception.ClientSessionID, + &i.AIBridgeInterception.SessionID, &i.VisibleUser.ID, &i.VisibleUser.Username, &i.VisibleUser.Name, @@ -939,6 +942,109 @@ func (q *sqlQuerier) ListAuthorizedAIBridgeModels(ctx context.Context, arg ListA return items, nil } +func (q *sqlQuerier) ListAuthorizedAIBridgeSessions(ctx context.Context, arg ListAIBridgeSessionsParams, prepared rbac.PreparedAuthorized) ([]ListAIBridgeSessionsRow, error) { + authorizedFilter, err := prepared.CompileToSQL(ctx, regosql.ConvertConfig{ + VariableConverter: regosql.AIBridgeInterceptionConverter(), + }) + if err != nil { + return nil, xerrors.Errorf("compile authorized filter: %w", err) + } + filtered, err := insertAuthorizedFilter(listAIBridgeSessions, fmt.Sprintf(" AND %s", authorizedFilter)) + if err != nil { + return nil, xerrors.Errorf("insert authorized filter: %w", err) + } + + query := fmt.Sprintf("-- name: ListAuthorizedAIBridgeSessions :many\n%s", filtered) + rows, err := q.db.QueryContext(ctx, query, + arg.AfterSessionID, + arg.Offset, + arg.Limit, + arg.StartedAfter, + arg.StartedBefore, + arg.InitiatorID, + arg.Provider, + arg.Model, + arg.Client, + arg.SessionID, + ) + if err != nil { + return nil, err + } + defer rows.Close() + var items []ListAIBridgeSessionsRow + for rows.Next() { + var i ListAIBridgeSessionsRow + if err := rows.Scan( + &i.SessionID, + &i.UserID, + &i.UserUsername, + &i.UserName, + &i.UserAvatarUrl, + pq.Array(&i.Providers), + pq.Array(&i.Models), + &i.Client, + &i.Metadata, + &i.StartedAt, + &i.EndedAt, + &i.Threads, + &i.InputTokens, + &i.OutputTokens, + &i.LastPrompt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +func (q *sqlQuerier) CountAuthorizedAIBridgeSessions(ctx context.Context, arg CountAIBridgeSessionsParams, prepared rbac.PreparedAuthorized) (int64, error) { + authorizedFilter, err := prepared.CompileToSQL(ctx, regosql.ConvertConfig{ + VariableConverter: regosql.AIBridgeInterceptionConverter(), + }) + if err != nil { + return 0, xerrors.Errorf("compile authorized filter: %w", err) + } + filtered, err := insertAuthorizedFilter(countAIBridgeSessions, fmt.Sprintf(" AND %s", authorizedFilter)) + if err != nil { + return 0, xerrors.Errorf("insert authorized filter: %w", err) + } + + query := fmt.Sprintf("-- name: CountAuthorizedAIBridgeSessions :one\n%s", filtered) + rows, err := q.db.QueryContext(ctx, query, + arg.StartedAfter, + arg.StartedBefore, + arg.InitiatorID, + arg.Provider, + arg.Model, + arg.Client, + arg.SessionID, + ) + if err != nil { + return 0, err + } + defer rows.Close() + var count int64 + for rows.Next() { + if err := rows.Scan(&count); err != nil { + return 0, err + } + } + if err := rows.Close(); err != nil { + return 0, err + } + if err := rows.Err(); err != nil { + return 0, err + } + return count, nil +} + func insertAuthorizedFilter(query string, replaceWith string) (string, error) { if !strings.Contains(query, authorizedQueryPlaceholder) { return "", xerrors.Errorf("query does not contain authorized replace string, this is not an authorized query") diff --git a/coderd/database/models.go b/coderd/database/models.go index 80972cf7d2..4b7feb6b02 100644 --- a/coderd/database/models.go +++ b/coderd/database/models.go @@ -4036,6 +4036,8 @@ type AIBridgeInterception struct { ThreadRootID uuid.NullUUID `db:"thread_root_id" json:"thread_root_id"` // The session ID supplied by the client (optional and not universally supported). ClientSessionID sql.NullString `db:"client_session_id" json:"client_session_id"` + // Groups related interceptions into a logical session. Determined by a priority chain: (1) client_session_id — an explicit session identifier supplied by the calling client (e.g. Claude Code); (2) thread_root_id — the root of an agentic thread detected by Bridge through tool-call correlation, used when the client does not supply its own session ID; (3) id — the interception's own ID, used as a last resort so every interception belongs to exactly one session even if it is standalone. This is a generated column stored on disk so it can be indexed and joined without recomputing the COALESCE on every query. + SessionID string `db:"session_id" json:"session_id"` } // Audit log of model thinking in intercepted requests in AI Bridge diff --git a/coderd/database/querier.go b/coderd/database/querier.go index 9bd876b683..78cad95f85 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -76,6 +76,7 @@ type sqlcQuerier interface { CleanTailnetTunnels(ctx context.Context) error CleanupDeletedMCPServerIDsFromChats(ctx context.Context) error CountAIBridgeInterceptions(ctx context.Context, arg CountAIBridgeInterceptionsParams) (int64, error) + CountAIBridgeSessions(ctx context.Context, arg CountAIBridgeSessionsParams) (int64, error) CountAuditLogs(ctx context.Context, arg CountAuditLogsParams) (int64, error) CountConnectionLogs(ctx context.Context, arg CountConnectionLogsParams) (int64, error) // Counts enabled, non-deleted model configs that lack both input and @@ -759,6 +760,10 @@ type sqlcQuerier interface { // (provider, model, client) in the given timeframe for telemetry reporting. ListAIBridgeInterceptionsTelemetrySummaries(ctx context.Context, arg ListAIBridgeInterceptionsTelemetrySummariesParams) ([]ListAIBridgeInterceptionsTelemetrySummariesRow, error) ListAIBridgeModels(ctx context.Context, arg ListAIBridgeModelsParams) ([]string, error) + // Returns paginated sessions with aggregated metadata, token counts, and + // the most recent user prompt. A "session" is a logical grouping of + // interceptions that share the same session_id (set by the client). + ListAIBridgeSessions(ctx context.Context, arg ListAIBridgeSessionsParams) ([]ListAIBridgeSessionsRow, error) ListAIBridgeTokenUsagesByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]AIBridgeTokenUsage, error) ListAIBridgeToolUsagesByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]AIBridgeToolUsage, error) ListAIBridgeUserPromptsByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]AIBridgeUserPrompt, error) diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 2a74a99f22..845333d399 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -332,6 +332,77 @@ func (q *sqlQuerier) CountAIBridgeInterceptions(ctx context.Context, arg CountAI return count, err } +const countAIBridgeSessions = `-- name: CountAIBridgeSessions :one +SELECT + COUNT(DISTINCT (aibridge_interceptions.session_id, aibridge_interceptions.initiator_id)) +FROM + aibridge_interceptions +WHERE + -- Remove inflight interceptions (ones which lack an ended_at value). + aibridge_interceptions.ended_at IS NOT NULL + -- Filter by time frame + AND CASE + WHEN $1::timestamptz != '0001-01-01 00:00:00+00'::timestamptz THEN aibridge_interceptions.started_at >= $1::timestamptz + ELSE true + END + AND CASE + WHEN $2::timestamptz != '0001-01-01 00:00:00+00'::timestamptz THEN aibridge_interceptions.started_at <= $2::timestamptz + ELSE true + END + -- Filter initiator_id + AND CASE + WHEN $3::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN aibridge_interceptions.initiator_id = $3::uuid + ELSE true + END + -- Filter provider + AND CASE + WHEN $4::text != '' THEN aibridge_interceptions.provider = $4::text + ELSE true + END + -- Filter model + AND CASE + WHEN $5::text != '' THEN aibridge_interceptions.model = $5::text + ELSE true + END + -- Filter client + AND CASE + WHEN $6::text != '' THEN COALESCE(aibridge_interceptions.client, 'Unknown') = $6::text + ELSE true + END + -- Filter session_id + AND CASE + WHEN $7::text != '' THEN aibridge_interceptions.session_id = $7::text + ELSE true + END + -- Authorize Filter clause will be injected below in CountAuthorizedAIBridgeSessions + -- @authorize_filter +` + +type CountAIBridgeSessionsParams struct { + StartedAfter time.Time `db:"started_after" json:"started_after"` + StartedBefore time.Time `db:"started_before" json:"started_before"` + InitiatorID uuid.UUID `db:"initiator_id" json:"initiator_id"` + Provider string `db:"provider" json:"provider"` + Model string `db:"model" json:"model"` + Client string `db:"client" json:"client"` + SessionID string `db:"session_id" json:"session_id"` +} + +func (q *sqlQuerier) CountAIBridgeSessions(ctx context.Context, arg CountAIBridgeSessionsParams) (int64, error) { + row := q.db.QueryRowContext(ctx, countAIBridgeSessions, + arg.StartedAfter, + arg.StartedBefore, + arg.InitiatorID, + arg.Provider, + arg.Model, + arg.Client, + arg.SessionID, + ) + var count int64 + err := row.Scan(&count) + return count, err +} + const deleteOldAIBridgeRecords = `-- name: DeleteOldAIBridgeRecords :one WITH -- We don't have FK relationships between the dependent tables and aibridge_interceptions, so we can't rely on DELETE CASCADE. @@ -384,7 +455,7 @@ func (q *sqlQuerier) DeleteOldAIBridgeRecords(ctx context.Context, beforeTime ti const getAIBridgeInterceptionByID = `-- name: GetAIBridgeInterceptionByID :one SELECT - id, initiator_id, provider, model, started_at, metadata, ended_at, api_key_id, client, thread_parent_id, thread_root_id, client_session_id + id, initiator_id, provider, model, started_at, metadata, ended_at, api_key_id, client, thread_parent_id, thread_root_id, client_session_id, session_id FROM aibridge_interceptions WHERE @@ -407,6 +478,7 @@ func (q *sqlQuerier) GetAIBridgeInterceptionByID(ctx context.Context, id uuid.UU &i.ThreadParentID, &i.ThreadRootID, &i.ClientSessionID, + &i.SessionID, ) return i, err } @@ -441,7 +513,7 @@ func (q *sqlQuerier) GetAIBridgeInterceptionLineageByToolCallID(ctx context.Cont const getAIBridgeInterceptions = `-- name: GetAIBridgeInterceptions :many SELECT - id, initiator_id, provider, model, started_at, metadata, ended_at, api_key_id, client, thread_parent_id, thread_root_id, client_session_id + id, initiator_id, provider, model, started_at, metadata, ended_at, api_key_id, client, thread_parent_id, thread_root_id, client_session_id, session_id FROM aibridge_interceptions ` @@ -468,6 +540,7 @@ func (q *sqlQuerier) GetAIBridgeInterceptions(ctx context.Context) ([]AIBridgeIn &i.ThreadParentID, &i.ThreadRootID, &i.ClientSessionID, + &i.SessionID, ); err != nil { return nil, err } @@ -618,7 +691,7 @@ INSERT INTO aibridge_interceptions ( ) VALUES ( $1, $2, $3, $4, $5, COALESCE($6::jsonb, '{}'::jsonb), $7, $8, $9, $10::uuid, $11::uuid ) -RETURNING id, initiator_id, provider, model, started_at, metadata, ended_at, api_key_id, client, thread_parent_id, thread_root_id, client_session_id +RETURNING id, initiator_id, provider, model, started_at, metadata, ended_at, api_key_id, client, thread_parent_id, thread_root_id, client_session_id, session_id ` type InsertAIBridgeInterceptionParams struct { @@ -663,6 +736,7 @@ func (q *sqlQuerier) InsertAIBridgeInterception(ctx context.Context, arg InsertA &i.ThreadParentID, &i.ThreadRootID, &i.ClientSessionID, + &i.SessionID, ) return i, err } @@ -837,7 +911,7 @@ func (q *sqlQuerier) InsertAIBridgeUserPrompt(ctx context.Context, arg InsertAIB const listAIBridgeInterceptions = `-- name: ListAIBridgeInterceptions :many SELECT - aibridge_interceptions.id, aibridge_interceptions.initiator_id, aibridge_interceptions.provider, aibridge_interceptions.model, aibridge_interceptions.started_at, aibridge_interceptions.metadata, aibridge_interceptions.ended_at, aibridge_interceptions.api_key_id, aibridge_interceptions.client, aibridge_interceptions.thread_parent_id, aibridge_interceptions.thread_root_id, aibridge_interceptions.client_session_id, + aibridge_interceptions.id, aibridge_interceptions.initiator_id, aibridge_interceptions.provider, aibridge_interceptions.model, aibridge_interceptions.started_at, aibridge_interceptions.metadata, aibridge_interceptions.ended_at, aibridge_interceptions.api_key_id, aibridge_interceptions.client, aibridge_interceptions.thread_parent_id, aibridge_interceptions.thread_root_id, aibridge_interceptions.client_session_id, aibridge_interceptions.session_id, visible_users.id, visible_users.username, visible_users.name, visible_users.avatar_url FROM aibridge_interceptions @@ -949,6 +1023,7 @@ func (q *sqlQuerier) ListAIBridgeInterceptions(ctx context.Context, arg ListAIBr &i.AIBridgeInterception.ThreadParentID, &i.AIBridgeInterception.ThreadRootID, &i.AIBridgeInterception.ClientSessionID, + &i.AIBridgeInterception.SessionID, &i.VisibleUser.ID, &i.VisibleUser.Username, &i.VisibleUser.Name, @@ -1071,6 +1146,229 @@ func (q *sqlQuerier) ListAIBridgeModels(ctx context.Context, arg ListAIBridgeMod return items, nil } +const listAIBridgeSessions = `-- name: ListAIBridgeSessions :many +WITH filtered_interceptions AS ( + SELECT + aibridge_interceptions.id, aibridge_interceptions.initiator_id, aibridge_interceptions.provider, aibridge_interceptions.model, aibridge_interceptions.started_at, aibridge_interceptions.metadata, aibridge_interceptions.ended_at, aibridge_interceptions.api_key_id, aibridge_interceptions.client, aibridge_interceptions.thread_parent_id, aibridge_interceptions.thread_root_id, aibridge_interceptions.client_session_id, aibridge_interceptions.session_id + FROM + aibridge_interceptions + WHERE + -- Remove inflight interceptions (ones which lack an ended_at value). + aibridge_interceptions.ended_at IS NOT NULL + -- Filter by time frame + AND CASE + WHEN $4::timestamptz != '0001-01-01 00:00:00+00'::timestamptz THEN aibridge_interceptions.started_at >= $4::timestamptz + ELSE true + END + AND CASE + WHEN $5::timestamptz != '0001-01-01 00:00:00+00'::timestamptz THEN aibridge_interceptions.started_at <= $5::timestamptz + ELSE true + END + -- Filter initiator_id + AND CASE + WHEN $6::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN aibridge_interceptions.initiator_id = $6::uuid + ELSE true + END + -- Filter provider + AND CASE + WHEN $7::text != '' THEN aibridge_interceptions.provider = $7::text + ELSE true + END + -- Filter model + AND CASE + WHEN $8::text != '' THEN aibridge_interceptions.model = $8::text + ELSE true + END + -- Filter client + AND CASE + WHEN $9::text != '' THEN COALESCE(aibridge_interceptions.client, 'Unknown') = $9::text + ELSE true + END + -- Filter session_id + AND CASE + WHEN $10::text != '' THEN aibridge_interceptions.session_id = $10::text + ELSE true + END + -- Authorize Filter clause will be injected below in ListAuthorizedAIBridgeSessions + -- @authorize_filter +), +session_tokens AS ( + -- Aggregate token usage across all interceptions in each session. + -- Group by (session_id, initiator_id) to avoid merging sessions from + -- different users who happen to share the same client_session_id. + SELECT + fi.session_id, + fi.initiator_id, + COALESCE(SUM(tu.input_tokens), 0)::bigint AS input_tokens, + COALESCE(SUM(tu.output_tokens), 0)::bigint AS output_tokens + -- TODO: add extra token types once https://github.com/coder/aibridge/issues/150 lands. + FROM + filtered_interceptions fi + LEFT JOIN + aibridge_token_usages tu ON fi.id = tu.interception_id + GROUP BY + fi.session_id, fi.initiator_id +), +session_root AS ( + -- Build one summary row per session. Group by (session_id, initiator_id) + -- to avoid merging sessions from different users who happen to share the + -- same client_session_id. The ARRAY_AGG with ORDER BY picks values from + -- the chronologically first interception for fields that should represent + -- the session as a whole (client, metadata). Threads are counted as + -- distinct root interception IDs: an interception with a NULL + -- thread_root_id is itself a thread root. + SELECT + fi.session_id, + fi.initiator_id, + (ARRAY_AGG(fi.client ORDER BY fi.started_at, fi.id))[1] AS client, + (ARRAY_AGG(fi.metadata ORDER BY fi.started_at, fi.id))[1] AS metadata, + ARRAY_AGG(DISTINCT fi.provider ORDER BY fi.provider) AS providers, + ARRAY_AGG(DISTINCT fi.model ORDER BY fi.model) AS models, + MIN(fi.started_at) AS started_at, + MAX(fi.ended_at) AS ended_at, + COUNT(DISTINCT COALESCE(fi.thread_root_id, fi.id)) AS threads, + -- Collect IDs for lateral prompt lookup. + ARRAY_AGG(fi.id) AS interception_ids + FROM + filtered_interceptions fi + GROUP BY + fi.session_id, fi.initiator_id +) +SELECT + sr.session_id, + visible_users.id AS user_id, + visible_users.username AS user_username, + visible_users.name AS user_name, + visible_users.avatar_url AS user_avatar_url, + sr.providers::text[] AS providers, + sr.models::text[] AS models, + COALESCE(sr.client, '')::varchar(64) AS client, + sr.metadata::jsonb AS metadata, + sr.started_at::timestamptz AS started_at, + sr.ended_at::timestamptz AS ended_at, + sr.threads, + COALESCE(st.input_tokens, 0)::bigint AS input_tokens, + COALESCE(st.output_tokens, 0)::bigint AS output_tokens, + COALESCE(slp.prompt, '') AS last_prompt +FROM + session_root sr +JOIN + visible_users ON visible_users.id = sr.initiator_id +LEFT JOIN + session_tokens st ON st.session_id = sr.session_id AND st.initiator_id = sr.initiator_id +LEFT JOIN LATERAL ( + -- Lateral join to efficiently fetch only the most recent user prompt + -- across all interceptions in the session, avoiding a full aggregation. + SELECT up.prompt + FROM aibridge_user_prompts up + WHERE up.interception_id = ANY(sr.interception_ids) + ORDER BY up.created_at DESC, up.id DESC + LIMIT 1 +) slp ON true +WHERE + -- Cursor pagination: uses a composite (started_at, session_id) cursor + -- to support keyset pagination. The less-than comparison matches the + -- DESC sort order so that rows after the cursor come later in results. + CASE + WHEN $1::text != '' THEN ( + (sr.started_at, sr.session_id) < ( + (SELECT started_at FROM session_root WHERE session_id = $1), + $1::text + ) + ) + ELSE true + END +ORDER BY + sr.started_at DESC, + sr.session_id DESC +LIMIT COALESCE(NULLIF($3::integer, 0), 100) +OFFSET $2 +` + +type ListAIBridgeSessionsParams struct { + AfterSessionID string `db:"after_session_id" json:"after_session_id"` + Offset int32 `db:"offset_" json:"offset_"` + Limit int32 `db:"limit_" json:"limit_"` + StartedAfter time.Time `db:"started_after" json:"started_after"` + StartedBefore time.Time `db:"started_before" json:"started_before"` + InitiatorID uuid.UUID `db:"initiator_id" json:"initiator_id"` + Provider string `db:"provider" json:"provider"` + Model string `db:"model" json:"model"` + Client string `db:"client" json:"client"` + SessionID string `db:"session_id" json:"session_id"` +} + +type ListAIBridgeSessionsRow struct { + SessionID string `db:"session_id" json:"session_id"` + UserID uuid.UUID `db:"user_id" json:"user_id"` + UserUsername string `db:"user_username" json:"user_username"` + UserName string `db:"user_name" json:"user_name"` + UserAvatarUrl string `db:"user_avatar_url" json:"user_avatar_url"` + Providers []string `db:"providers" json:"providers"` + Models []string `db:"models" json:"models"` + Client string `db:"client" json:"client"` + Metadata json.RawMessage `db:"metadata" json:"metadata"` + StartedAt time.Time `db:"started_at" json:"started_at"` + EndedAt time.Time `db:"ended_at" json:"ended_at"` + Threads int64 `db:"threads" json:"threads"` + InputTokens int64 `db:"input_tokens" json:"input_tokens"` + OutputTokens int64 `db:"output_tokens" json:"output_tokens"` + LastPrompt string `db:"last_prompt" json:"last_prompt"` +} + +// Returns paginated sessions with aggregated metadata, token counts, and +// the most recent user prompt. A "session" is a logical grouping of +// interceptions that share the same session_id (set by the client). +func (q *sqlQuerier) ListAIBridgeSessions(ctx context.Context, arg ListAIBridgeSessionsParams) ([]ListAIBridgeSessionsRow, error) { + rows, err := q.db.QueryContext(ctx, listAIBridgeSessions, + arg.AfterSessionID, + arg.Offset, + arg.Limit, + arg.StartedAfter, + arg.StartedBefore, + arg.InitiatorID, + arg.Provider, + arg.Model, + arg.Client, + arg.SessionID, + ) + if err != nil { + return nil, err + } + defer rows.Close() + var items []ListAIBridgeSessionsRow + for rows.Next() { + var i ListAIBridgeSessionsRow + if err := rows.Scan( + &i.SessionID, + &i.UserID, + &i.UserUsername, + &i.UserName, + &i.UserAvatarUrl, + pq.Array(&i.Providers), + pq.Array(&i.Models), + &i.Client, + &i.Metadata, + &i.StartedAt, + &i.EndedAt, + &i.Threads, + &i.InputTokens, + &i.OutputTokens, + &i.LastPrompt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const listAIBridgeTokenUsagesByInterceptionIDs = `-- name: ListAIBridgeTokenUsagesByInterceptionIDs :many SELECT id, interception_id, provider_response_id, input_tokens, output_tokens, metadata, created_at @@ -1209,7 +1507,7 @@ UPDATE aibridge_interceptions WHERE id = $2::uuid AND ended_at IS NULL -RETURNING id, initiator_id, provider, model, started_at, metadata, ended_at, api_key_id, client, thread_parent_id, thread_root_id, client_session_id +RETURNING id, initiator_id, provider, model, started_at, metadata, ended_at, api_key_id, client, thread_parent_id, thread_root_id, client_session_id, session_id ` type UpdateAIBridgeInterceptionEndedParams struct { @@ -1233,6 +1531,7 @@ func (q *sqlQuerier) UpdateAIBridgeInterceptionEnded(ctx context.Context, arg Up &i.ThreadParentID, &i.ThreadRootID, &i.ClientSessionID, + &i.SessionID, ) return i, err } diff --git a/coderd/database/queries/aibridge.sql b/coderd/database/queries/aibridge.sql index 2115ffebe7..5804fda09a 100644 --- a/coderd/database/queries/aibridge.sql +++ b/coderd/database/queries/aibridge.sql @@ -404,6 +404,194 @@ SELECT ( (SELECT COUNT(*) FROM interceptions) )::bigint as total_deleted; +-- name: CountAIBridgeSessions :one +SELECT + COUNT(DISTINCT (aibridge_interceptions.session_id, aibridge_interceptions.initiator_id)) +FROM + aibridge_interceptions +WHERE + -- Remove inflight interceptions (ones which lack an ended_at value). + aibridge_interceptions.ended_at IS NOT NULL + -- Filter by time frame + AND CASE + WHEN @started_after::timestamptz != '0001-01-01 00:00:00+00'::timestamptz THEN aibridge_interceptions.started_at >= @started_after::timestamptz + ELSE true + END + AND CASE + WHEN @started_before::timestamptz != '0001-01-01 00:00:00+00'::timestamptz THEN aibridge_interceptions.started_at <= @started_before::timestamptz + ELSE true + END + -- Filter initiator_id + AND CASE + WHEN @initiator_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN aibridge_interceptions.initiator_id = @initiator_id::uuid + ELSE true + END + -- Filter provider + AND CASE + WHEN @provider::text != '' THEN aibridge_interceptions.provider = @provider::text + ELSE true + END + -- Filter model + AND CASE + WHEN @model::text != '' THEN aibridge_interceptions.model = @model::text + ELSE true + END + -- Filter client + AND CASE + WHEN @client::text != '' THEN COALESCE(aibridge_interceptions.client, 'Unknown') = @client::text + ELSE true + END + -- Filter session_id + AND CASE + WHEN @session_id::text != '' THEN aibridge_interceptions.session_id = @session_id::text + ELSE true + END + -- Authorize Filter clause will be injected below in CountAuthorizedAIBridgeSessions + -- @authorize_filter +; + +-- name: ListAIBridgeSessions :many +-- Returns paginated sessions with aggregated metadata, token counts, and +-- the most recent user prompt. A "session" is a logical grouping of +-- interceptions that share the same session_id (set by the client). +WITH filtered_interceptions AS ( + SELECT + aibridge_interceptions.* + FROM + aibridge_interceptions + WHERE + -- Remove inflight interceptions (ones which lack an ended_at value). + aibridge_interceptions.ended_at IS NOT NULL + -- Filter by time frame + AND CASE + WHEN @started_after::timestamptz != '0001-01-01 00:00:00+00'::timestamptz THEN aibridge_interceptions.started_at >= @started_after::timestamptz + ELSE true + END + AND CASE + WHEN @started_before::timestamptz != '0001-01-01 00:00:00+00'::timestamptz THEN aibridge_interceptions.started_at <= @started_before::timestamptz + ELSE true + END + -- Filter initiator_id + AND CASE + WHEN @initiator_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN aibridge_interceptions.initiator_id = @initiator_id::uuid + ELSE true + END + -- Filter provider + AND CASE + WHEN @provider::text != '' THEN aibridge_interceptions.provider = @provider::text + ELSE true + END + -- Filter model + AND CASE + WHEN @model::text != '' THEN aibridge_interceptions.model = @model::text + ELSE true + END + -- Filter client + AND CASE + WHEN @client::text != '' THEN COALESCE(aibridge_interceptions.client, 'Unknown') = @client::text + ELSE true + END + -- Filter session_id + AND CASE + WHEN @session_id::text != '' THEN aibridge_interceptions.session_id = @session_id::text + ELSE true + END + -- Authorize Filter clause will be injected below in ListAuthorizedAIBridgeSessions + -- @authorize_filter +), +session_tokens AS ( + -- Aggregate token usage across all interceptions in each session. + -- Group by (session_id, initiator_id) to avoid merging sessions from + -- different users who happen to share the same client_session_id. + SELECT + fi.session_id, + fi.initiator_id, + COALESCE(SUM(tu.input_tokens), 0)::bigint AS input_tokens, + COALESCE(SUM(tu.output_tokens), 0)::bigint AS output_tokens + -- TODO: add extra token types once https://github.com/coder/aibridge/issues/150 lands. + FROM + filtered_interceptions fi + LEFT JOIN + aibridge_token_usages tu ON fi.id = tu.interception_id + GROUP BY + fi.session_id, fi.initiator_id +), +session_root AS ( + -- Build one summary row per session. Group by (session_id, initiator_id) + -- to avoid merging sessions from different users who happen to share the + -- same client_session_id. The ARRAY_AGG with ORDER BY picks values from + -- the chronologically first interception for fields that should represent + -- the session as a whole (client, metadata). Threads are counted as + -- distinct root interception IDs: an interception with a NULL + -- thread_root_id is itself a thread root. + SELECT + fi.session_id, + fi.initiator_id, + (ARRAY_AGG(fi.client ORDER BY fi.started_at, fi.id))[1] AS client, + (ARRAY_AGG(fi.metadata ORDER BY fi.started_at, fi.id))[1] AS metadata, + ARRAY_AGG(DISTINCT fi.provider ORDER BY fi.provider) AS providers, + ARRAY_AGG(DISTINCT fi.model ORDER BY fi.model) AS models, + MIN(fi.started_at) AS started_at, + MAX(fi.ended_at) AS ended_at, + COUNT(DISTINCT COALESCE(fi.thread_root_id, fi.id)) AS threads, + -- Collect IDs for lateral prompt lookup. + ARRAY_AGG(fi.id) AS interception_ids + FROM + filtered_interceptions fi + GROUP BY + fi.session_id, fi.initiator_id +) +SELECT + sr.session_id, + visible_users.id AS user_id, + visible_users.username AS user_username, + visible_users.name AS user_name, + visible_users.avatar_url AS user_avatar_url, + sr.providers::text[] AS providers, + sr.models::text[] AS models, + COALESCE(sr.client, '')::varchar(64) AS client, + sr.metadata::jsonb AS metadata, + sr.started_at::timestamptz AS started_at, + sr.ended_at::timestamptz AS ended_at, + sr.threads, + COALESCE(st.input_tokens, 0)::bigint AS input_tokens, + COALESCE(st.output_tokens, 0)::bigint AS output_tokens, + COALESCE(slp.prompt, '') AS last_prompt +FROM + session_root sr +JOIN + visible_users ON visible_users.id = sr.initiator_id +LEFT JOIN + session_tokens st ON st.session_id = sr.session_id AND st.initiator_id = sr.initiator_id +LEFT JOIN LATERAL ( + -- Lateral join to efficiently fetch only the most recent user prompt + -- across all interceptions in the session, avoiding a full aggregation. + SELECT up.prompt + FROM aibridge_user_prompts up + WHERE up.interception_id = ANY(sr.interception_ids) + ORDER BY up.created_at DESC, up.id DESC + LIMIT 1 +) slp ON true +WHERE + -- Cursor pagination: uses a composite (started_at, session_id) cursor + -- to support keyset pagination. The less-than comparison matches the + -- DESC sort order so that rows after the cursor come later in results. + CASE + WHEN @after_session_id::text != '' THEN ( + (sr.started_at, sr.session_id) < ( + (SELECT started_at FROM session_root WHERE session_id = @after_session_id), + @after_session_id::text + ) + ) + ELSE true + END +ORDER BY + sr.started_at DESC, + sr.session_id DESC +LIMIT COALESCE(NULLIF(@limit_::integer, 0), 100) +OFFSET @offset_ +; + -- name: ListAIBridgeModels :many SELECT model diff --git a/coderd/searchquery/search.go b/coderd/searchquery/search.go index 7d8f517d08..462562ae28 100644 --- a/coderd/searchquery/search.go +++ b/coderd/searchquery/search.go @@ -401,6 +401,49 @@ func AIBridgeInterceptions(ctx context.Context, db database.Store, query string, return filter, parser.Errors } +func AIBridgeSessions(ctx context.Context, db database.Store, query string, page codersdk.Pagination, actorID uuid.UUID, afterSessionID string) (database.ListAIBridgeSessionsParams, []codersdk.ValidationError) { + // nolint:exhaustruct // Empty values just means "don't filter by that field". + filter := database.ListAIBridgeSessionsParams{ + AfterSessionID: afterSessionID, + // #nosec G115 - Safe conversion for pagination limit which is expected to be within int32 range + Limit: int32(page.Limit), + // #nosec G115 - Safe conversion for pagination offset which is expected to be within int32 range + Offset: int32(page.Offset), + } + + if query == "" { + return filter, nil + } + + values, errors := searchTerms(query, func(string, url.Values) error { + // Do not specify a default search key; let's be explicit to prevent user confusion. + return xerrors.New("no search key specified") + }) + if len(errors) > 0 { + return filter, errors + } + + parser := httpapi.NewQueryParamParser() + filter.InitiatorID = parseUser(ctx, db, parser, values, "initiator", actorID) + filter.Provider = parser.String(values, "", "provider") + filter.Model = parser.String(values, "", "model") + filter.Client = parser.String(values, "", "client") + filter.SessionID = parser.String(values, "", "session_id") + + // Time must be between started_after and started_before. + filter.StartedAfter = parser.Time3339Nano(values, time.Time{}, "started_after") + filter.StartedBefore = parser.Time3339Nano(values, time.Time{}, "started_before") + if !filter.StartedBefore.IsZero() && !filter.StartedAfter.IsZero() && !filter.StartedBefore.After(filter.StartedAfter) { + parser.Errors = append(parser.Errors, codersdk.ValidationError{ + Field: "started_before", + Detail: `Query param "started_before" has invalid value: "started_before" must be after "started_after" if set`, + }) + } + + parser.ErrorExcessParams(values) + return filter, parser.Errors +} + func AIBridgeModels(query string, page codersdk.Pagination) (database.ListAIBridgeModelsParams, []codersdk.ValidationError) { // nolint:exhaustruct // Empty values just means "don't filter by that field". filter := database.ListAIBridgeModelsParams{ diff --git a/codersdk/aibridge.go b/codersdk/aibridge.go index 2d994558f5..56b2260bfe 100644 --- a/codersdk/aibridge.go +++ b/codersdk/aibridge.go @@ -63,6 +63,51 @@ type AIBridgeListInterceptionsResponse struct { Results []AIBridgeInterception `json:"results"` } +type AIBridgeSession struct { + ID string `json:"id"` + Initiator MinimalUser `json:"initiator"` + Providers []string `json:"providers"` + Models []string `json:"models"` + Client *string `json:"client"` + Metadata map[string]any `json:"metadata"` + StartedAt time.Time `json:"started_at" format:"date-time"` + EndedAt *time.Time `json:"ended_at,omitempty" format:"date-time"` + Threads int64 `json:"threads"` + TokenUsageSummary AIBridgeSessionTokenUsageSummary `json:"token_usage_summary"` + LastPrompt *string `json:"last_prompt,omitempty"` +} + +type AIBridgeSessionTokenUsageSummary struct { + InputTokens int64 `json:"input_tokens"` + OutputTokens int64 `json:"output_tokens"` +} + +type AIBridgeListSessionsResponse struct { + Count int64 `json:"count"` + Sessions []AIBridgeSession `json:"sessions"` +} + +// @typescript-ignore AIBridgeListSessionsFilter +type AIBridgeListSessionsFilter struct { + // Limit defaults to 100, max is 1000. + Pagination Pagination `json:"pagination,omitempty"` + + // Initiator is a user ID, username, or "me". + Initiator string `json:"initiator,omitempty"` + StartedBefore time.Time `json:"started_before,omitempty" format:"date-time"` + StartedAfter time.Time `json:"started_after,omitempty" format:"date-time"` + Provider string `json:"provider,omitempty"` + Model string `json:"model,omitempty"` + Client string `json:"client,omitempty"` + SessionID string `json:"session_id,omitempty"` + + // AfterSessionID is a cursor for pagination. It is the session ID of the + // last session in the previous page. + AfterSessionID string `json:"after_session_id,omitempty"` + + FilterQuery string `json:"q,omitempty"` +} + // @typescript-ignore AIBridgeListInterceptionsFilter type AIBridgeListInterceptionsFilter struct { // Limit defaults to 100, max is 1000. @@ -117,6 +162,44 @@ func (f AIBridgeListInterceptionsFilter) asRequestOption() RequestOption { } } +// asRequestOption returns a function that can be used in (*Client).Request. +func (f AIBridgeListSessionsFilter) asRequestOption() RequestOption { + return func(r *http.Request) { + var params []string + if f.Initiator != "" { + params = append(params, fmt.Sprintf("initiator:%q", f.Initiator)) + } + if !f.StartedBefore.IsZero() { + params = append(params, fmt.Sprintf("started_before:%q", f.StartedBefore.Format(time.RFC3339Nano))) + } + if !f.StartedAfter.IsZero() { + params = append(params, fmt.Sprintf("started_after:%q", f.StartedAfter.Format(time.RFC3339Nano))) + } + if f.Provider != "" { + params = append(params, fmt.Sprintf("provider:%q", f.Provider)) + } + if f.Model != "" { + params = append(params, fmt.Sprintf("model:%q", f.Model)) + } + if f.Client != "" { + params = append(params, fmt.Sprintf("client:%q", f.Client)) + } + if f.SessionID != "" { + params = append(params, fmt.Sprintf("session_id:%q", f.SessionID)) + } + if f.FilterQuery != "" { + params = append(params, f.FilterQuery) + } + + q := r.URL.Query() + q.Set("q", strings.Join(params, " ")) + if f.AfterSessionID != "" { + q.Set("after_session_id", f.AfterSessionID) + } + r.URL.RawQuery = q.Encode() + } +} + // AIBridgeListInterceptions returns AI Bridge interceptions with the given // filter. func (c *Client) AIBridgeListInterceptions(ctx context.Context, filter AIBridgeListInterceptionsFilter) (AIBridgeListInterceptionsResponse, error) { @@ -131,3 +214,17 @@ func (c *Client) AIBridgeListInterceptions(ctx context.Context, filter AIBridgeL var resp AIBridgeListInterceptionsResponse return resp, json.NewDecoder(res.Body).Decode(&resp) } + +// AIBridgeListSessions returns AI Bridge sessions with the given filter. +func (c *Client) AIBridgeListSessions(ctx context.Context, filter AIBridgeListSessionsFilter) (AIBridgeListSessionsResponse, error) { + res, err := c.Request(ctx, http.MethodGet, "/api/v2/aibridge/sessions", nil, filter.asRequestOption(), filter.Pagination.asRequestOption()) + if err != nil { + return AIBridgeListSessionsResponse{}, err + } + defer res.Body.Close() + if res.StatusCode != http.StatusOK { + return AIBridgeListSessionsResponse{}, ReadBodyAsError(res) + } + var resp AIBridgeListSessionsResponse + return resp, json.NewDecoder(res.Body).Decode(&resp) +} diff --git a/docs/ai-coder/ai-bridge/monitoring.md b/docs/ai-coder/ai-bridge/monitoring.md index f214eeb8a0..d339df3ee8 100644 --- a/docs/ai-coder/ai-bridge/monitoring.md +++ b/docs/ai-coder/ai-bridge/monitoring.md @@ -16,10 +16,10 @@ AI Bridge interception data can be exported for external analysis, compliance re ### REST API -You can retrieve AI Bridge interceptions via the Coder API with filtering and pagination support. +You can retrieve AI Bridge sessions via the Coder API, with filtering and pagination support. ```sh -curl -X GET "https://coder.example.com/api/v2/aibridge/interceptions?q=initiator:me" \ +curl -X GET "https://coder.example.com/api/v2/aibridge/sessions" \ -H "Coder-Session-Token: $CODER_SESSION_TOKEN" ``` diff --git a/docs/reference/api/aibridge.md b/docs/reference/api/aibridge.md index d5ca02bd5b..28479b8991 100644 --- a/docs/reference/api/aibridge.md +++ b/docs/reference/api/aibridge.md @@ -137,3 +137,73 @@ curl -X GET http://coder-server:8080/api/v2/aibridge/models \

Response Schema

To perform this operation, you must be authenticated. [Learn more](authentication.md). + +## List AI Bridge sessions + +### Code samples + +```shell +# Example request using curl +curl -X GET http://coder-server:8080/api/v2/aibridge/sessions \ + -H 'Accept: application/json' \ + -H 'Coder-Session-Token: API_KEY' +``` + +`GET /aibridge/sessions` + +### Parameters + +| Name | In | Type | Required | Description | +|--------------------|-------|---------|----------|--------------------------------------------------------------------------------------------------------------------------------------------| +| `q` | query | string | false | Search query in the format `key:value`. Available keys are: initiator, provider, model, client, session_id, started_after, started_before. | +| `limit` | query | integer | false | Page limit | +| `after_session_id` | query | string | false | Cursor pagination after session ID (cannot be used with offset) | +| `offset` | query | integer | false | Offset pagination (cannot be used with after_session_id) | + +### Example responses + +> 200 Response + +```json +{ + "count": 0, + "sessions": [ + { + "client": "string", + "ended_at": "2019-08-24T14:15:22Z", + "id": "string", + "initiator": { + "avatar_url": "http://example.com", + "id": "497f6eca-6276-4993-bfeb-53cbbbba6f08", + "name": "string", + "username": "string" + }, + "last_prompt": "string", + "metadata": { + "property1": null, + "property2": null + }, + "models": [ + "string" + ], + "providers": [ + "string" + ], + "started_at": "2019-08-24T14:15:22Z", + "threads": 0, + "token_usage_summary": { + "input_tokens": 0, + "output_tokens": 0 + } + } + ] +} +``` + +### Responses + +| Status | Meaning | Description | Schema | +|--------|---------------------------------------------------------|-------------|------------------------------------------------------------------------------------------| +| 200 | [OK](https://tools.ietf.org/html/rfc7231#section-6.3.1) | OK | [codersdk.AIBridgeListSessionsResponse](schemas.md#codersdkaibridgelistsessionsresponse) | + +To perform this operation, you must be authenticated. [Learn more](authentication.md). diff --git a/docs/reference/api/schemas.md b/docs/reference/api/schemas.md index 2956c2e7f8..2bea8f2b23 100644 --- a/docs/reference/api/schemas.md +++ b/docs/reference/api/schemas.md @@ -598,6 +598,51 @@ | `count` | integer | false | | | | `results` | array of [codersdk.AIBridgeInterception](#codersdkaibridgeinterception) | false | | | +## codersdk.AIBridgeListSessionsResponse + +```json +{ + "count": 0, + "sessions": [ + { + "client": "string", + "ended_at": "2019-08-24T14:15:22Z", + "id": "string", + "initiator": { + "avatar_url": "http://example.com", + "id": "497f6eca-6276-4993-bfeb-53cbbbba6f08", + "name": "string", + "username": "string" + }, + "last_prompt": "string", + "metadata": { + "property1": null, + "property2": null + }, + "models": [ + "string" + ], + "providers": [ + "string" + ], + "started_at": "2019-08-24T14:15:22Z", + "threads": 0, + "token_usage_summary": { + "input_tokens": 0, + "output_tokens": 0 + } + } + ] +} +``` + +### Properties + +| Name | Type | Required | Restrictions | Description | +|------------|---------------------------------------------------------------|----------|--------------|-------------| +| `count` | integer | false | | | +| `sessions` | array of [codersdk.AIBridgeSession](#codersdkaibridgesession) | false | | | + ## codersdk.AIBridgeOpenAIConfig ```json @@ -650,6 +695,72 @@ | `upstream_proxy` | string | false | | | | `upstream_proxy_ca` | string | false | | | +## codersdk.AIBridgeSession + +```json +{ + "client": "string", + "ended_at": "2019-08-24T14:15:22Z", + "id": "string", + "initiator": { + "avatar_url": "http://example.com", + "id": "497f6eca-6276-4993-bfeb-53cbbbba6f08", + "name": "string", + "username": "string" + }, + "last_prompt": "string", + "metadata": { + "property1": null, + "property2": null + }, + "models": [ + "string" + ], + "providers": [ + "string" + ], + "started_at": "2019-08-24T14:15:22Z", + "threads": 0, + "token_usage_summary": { + "input_tokens": 0, + "output_tokens": 0 + } +} +``` + +### Properties + +| Name | Type | Required | Restrictions | Description | +|-----------------------|----------------------------------------------------------------------------------------|----------|--------------|-------------| +| `client` | string | false | | | +| `ended_at` | string | false | | | +| `id` | string | false | | | +| `initiator` | [codersdk.MinimalUser](#codersdkminimaluser) | false | | | +| `last_prompt` | string | false | | | +| `metadata` | object | false | | | +| » `[any property]` | any | false | | | +| `models` | array of string | false | | | +| `providers` | array of string | false | | | +| `started_at` | string | false | | | +| `threads` | integer | false | | | +| `token_usage_summary` | [codersdk.AIBridgeSessionTokenUsageSummary](#codersdkaibridgesessiontokenusagesummary) | false | | | + +## codersdk.AIBridgeSessionTokenUsageSummary + +```json +{ + "input_tokens": 0, + "output_tokens": 0 +} +``` + +### Properties + +| Name | Type | Required | Restrictions | Description | +|-----------------|---------|----------|--------------|-------------| +| `input_tokens` | integer | false | | | +| `output_tokens` | integer | false | | | + ## codersdk.AIBridgeTokenUsage ```json diff --git a/enterprise/coderd/aibridge.go b/enterprise/coderd/aibridge.go index ce988006d3..11d972531e 100644 --- a/enterprise/coderd/aibridge.go +++ b/enterprise/coderd/aibridge.go @@ -2,6 +2,7 @@ package coderd import ( "context" + "database/sql" "fmt" "net/http" "time" @@ -10,6 +11,7 @@ import ( "github.com/google/uuid" "golang.org/x/xerrors" + "cdr.dev/slog/v3" "github.com/coder/coder/v2/coderd" "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/db2sdk" @@ -22,8 +24,10 @@ import ( const ( maxListInterceptionsLimit = 1000 + maxListSessionsLimit = 1000 maxListModelsLimit = 1000 defaultListInterceptionsLimit = 100 + defaultListSessionsLimit = 100 defaultListModelsLimit = 100 // aiBridgeRateLimitWindow is the fixed duration for rate limiting AI Bridge // requests. This is hardcoded to keep configuration simple. @@ -43,6 +47,7 @@ func aibridgeHandler(api *API, middlewares ...func(http.Handler) http.Handler) f r.Group(func(r chi.Router) { r.Use(middlewares...) r.Get("/interceptions", api.aiBridgeListInterceptions) + r.Get("/sessions", api.aiBridgeListSessions) r.Get("/models", api.aiBridgeListModels) }) @@ -176,6 +181,130 @@ func (api *API) aiBridgeListInterceptions(rw http.ResponseWriter, r *http.Reques }) } +// aiBridgeListSessions returns AI Bridge sessions (aggregated interceptions). +// +// @Summary List AI Bridge sessions +// @ID list-ai-bridge-sessions +// @Security CoderSessionToken +// @Produce json +// @Tags AI Bridge +// @Param q query string false "Search query in the format `key:value`. Available keys are: initiator, provider, model, client, session_id, started_after, started_before." +// @Param limit query int false "Page limit" +// @Param after_session_id query string false "Cursor pagination after session ID (cannot be used with offset)" +// @Param offset query int false "Offset pagination (cannot be used with after_session_id)" +// @Success 200 {object} codersdk.AIBridgeListSessionsResponse +// @Router /aibridge/sessions [get] +func (api *API) aiBridgeListSessions(rw http.ResponseWriter, r *http.Request) { + ctx := r.Context() + apiKey := httpmw.APIKey(r) + + page, ok := coderd.ParsePagination(rw, r) + if !ok { + return + } + + afterSessionID := r.URL.Query().Get("after_session_id") + if afterSessionID != "" && page.Offset != 0 { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "Query parameters have invalid values.", + Detail: "Cannot use both after_session_id and offset pagination in the same request.", + }) + return + } + if page.Limit == 0 { + page.Limit = defaultListSessionsLimit + } + if page.Limit > maxListSessionsLimit || page.Limit < 1 { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "Invalid pagination limit value.", + Detail: fmt.Sprintf("Pagination limit must be in range (0, %d]", maxListSessionsLimit), + }) + return + } + + queryStr := r.URL.Query().Get("q") + filter, errs := searchquery.AIBridgeSessions(ctx, api.Database, queryStr, page, apiKey.UserID, afterSessionID) + if len(errs) > 0 { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "Invalid session search query.", + Validations: errs, + }) + return + } + + // Validate the cursor session exists before running the main query. + if afterSessionID != "" { + //nolint:exhaustruct // Only need session_id filter and limit. + cursor, err := api.Database.ListAIBridgeSessions(ctx, database.ListAIBridgeSessionsParams{ + SessionID: afterSessionID, + Limit: 1, + }) + if err != nil { + api.Logger.Error(ctx, "error validating after_session_id cursor", slog.Error(err)) + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Internal error validating after_session_id cursor.", + Detail: "", // Don't leak database issue to client. + }) + return + } + if len(cursor) == 0 { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "Query parameter has invalid value.", + Detail: fmt.Sprintf("after_session_id: session %q not found", afterSessionID), + }) + return + } + } + + var ( + count int64 + rows []database.ListAIBridgeSessionsRow + ) + err := api.Database.InTx(func(db database.Store) error { + var err error + count, err = db.CountAIBridgeSessions(ctx, database.CountAIBridgeSessionsParams{ + StartedAfter: filter.StartedAfter, + StartedBefore: filter.StartedBefore, + InitiatorID: filter.InitiatorID, + Provider: filter.Provider, + Model: filter.Model, + Client: filter.Client, + SessionID: filter.SessionID, + }) + if err != nil { + return xerrors.Errorf("count authorized aibridge sessions: %w", err) + } + + rows, err = db.ListAIBridgeSessions(ctx, filter) + if err != nil { + return xerrors.Errorf("list aibridge sessions: %w", err) + } + + return nil + }, &database.TxOptions{ + Isolation: sql.LevelRepeatableRead, // Consistency across queries tables while writes may be occurring. + ReadOnly: true, + TxIdentifier: "aibridge_list_sessions", + }) + if err != nil { + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Internal error getting AI Bridge sessions.", + Detail: err.Error(), + }) + return + } + + sessions := make([]codersdk.AIBridgeSession, len(rows)) + for i, row := range rows { + sessions[i] = db2sdk.AIBridgeSession(row) + } + + httpapi.Write(ctx, rw, http.StatusOK, codersdk.AIBridgeListSessionsResponse{ + Count: count, + Sessions: sessions, + }) +} + // aiBridgeListModels returns all AI Bridge models a user can see. // // @Summary List AI Bridge models diff --git a/enterprise/coderd/aibridge_test.go b/enterprise/coderd/aibridge_test.go index c2a57e25ba..197f1763cd 100644 --- a/enterprise/coderd/aibridge_test.go +++ b/enterprise/coderd/aibridge_test.go @@ -665,6 +665,560 @@ func TestAIBridgeListInterceptions(t *testing.T) { }) } +func aibridgeOpts(t *testing.T) *coderdenttest.Options { + t.Helper() + dv := coderdtest.DeploymentValues(t) + dv.AI.BridgeConfig.Enabled = serpent.Bool(true) + return &coderdenttest.Options{ + Options: &coderdtest.Options{ + DeploymentValues: dv, + }, + LicenseOptions: &coderdenttest.LicenseOptions{ + Features: license.Features{ + codersdk.FeatureAIBridge: 1, + }, + }, + } +} + +func TestAIBridgeListSessions(t *testing.T) { + t.Parallel() + + t.Run("EmptyDB", func(t *testing.T) { + t.Parallel() + client, _ := coderdenttest.New(t, aibridgeOpts(t)) + ctx := testutil.Context(t, testutil.WaitLong) + //nolint:gocritic // Owner role is irrelevant here. + res, err := client.AIBridgeListSessions(ctx, codersdk.AIBridgeListSessionsFilter{}) + require.NoError(t, err) + require.Empty(t, res.Sessions) + require.EqualValues(t, 0, res.Count) + }) + + t.Run("OK", func(t *testing.T) { + t.Parallel() + client, db, firstUser := coderdenttest.NewWithDatabase(t, aibridgeOpts(t)) + ctx := testutil.Context(t, testutil.WaitLong) + + now := dbtime.Now() + + // Session 1: Two interceptions sharing client_session_id "session-A". + s1i1EndedAt := now.Add(time.Minute) + s1i1 := dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{ + InitiatorID: firstUser.UserID, + Provider: "anthropic", + Model: "claude-4", + StartedAt: now, + Client: sql.NullString{String: "claude-code", Valid: true}, + ClientSessionID: sql.NullString{String: "session-A", Valid: true}, + }, &s1i1EndedAt) + s1i2EndedAt := now.Add(2 * time.Minute) + dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{ + InitiatorID: firstUser.UserID, + Provider: "anthropic", + Model: "claude-4-haiku", + StartedAt: now.Add(time.Minute), + Client: sql.NullString{String: "claude-code", Valid: true}, + ClientSessionID: sql.NullString{String: "session-A", Valid: true}, + ThreadRootInterceptionID: uuid.NullUUID{UUID: s1i1.ID, Valid: true}, + ThreadParentInterceptionID: uuid.NullUUID{UUID: s1i1.ID, Valid: true}, + }, &s1i2EndedAt) + + // Add token usages to session 1 interceptions. + dbgen.AIBridgeTokenUsage(t, db, database.InsertAIBridgeTokenUsageParams{ + InterceptionID: s1i1.ID, + InputTokens: 100, + OutputTokens: 50, + CreatedAt: now, + }) + dbgen.AIBridgeTokenUsage(t, db, database.InsertAIBridgeTokenUsageParams{ + InterceptionID: s1i1.ID, + InputTokens: 200, + OutputTokens: 75, + CreatedAt: now.Add(time.Second), + }) + + // Add user prompts to session 1. + dbgen.AIBridgeUserPrompt(t, db, database.InsertAIBridgeUserPromptParams{ + InterceptionID: s1i1.ID, + Prompt: "first prompt", + CreatedAt: now, + }) + dbgen.AIBridgeUserPrompt(t, db, database.InsertAIBridgeUserPromptParams{ + InterceptionID: s1i1.ID, + Prompt: "last prompt in session", + CreatedAt: now.Add(time.Minute), + }) + + // Session 2: Thread-based session (no client_session_id, shared thread_root_id). + s2i1EndedAt := now.Add(-time.Hour + time.Minute) + s2i1 := dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{ + InitiatorID: firstUser.UserID, + Provider: "openai", + Model: "gpt-4", + StartedAt: now.Add(-time.Hour), + }, &s2i1EndedAt) + s2i2EndedAt := now.Add(-time.Hour + 2*time.Minute) + dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{ + InitiatorID: firstUser.UserID, + Provider: "openai", + Model: "gpt-4", + StartedAt: now.Add(-time.Hour + time.Minute), + ThreadRootInterceptionID: uuid.NullUUID{UUID: s2i1.ID, Valid: true}, + ThreadParentInterceptionID: uuid.NullUUID{UUID: s2i1.ID, Valid: true}, + }, &s2i2EndedAt) + + // Session 3: Standalone interception (no client_session_id, no thread_root_id). + s3EndedAt := now.Add(-2*time.Hour + time.Minute) + s3i1 := dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{ + InitiatorID: firstUser.UserID, + Provider: "anthropic", + Model: "claude-4", + StartedAt: now.Add(-2 * time.Hour), + }, &s3EndedAt) + + // Session 4: Two distinct thread roots in one client_session_id. + s4i1EndedAt := now.Add(-3*time.Hour + time.Minute) + dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{ + InitiatorID: firstUser.UserID, + Provider: "anthropic", + Model: "claude-4", + StartedAt: now.Add(-3 * time.Hour), + ClientSessionID: sql.NullString{String: "session-multi", Valid: true}, + }, &s4i1EndedAt) + s4i2EndedAt := now.Add(-3*time.Hour + 2*time.Minute) + dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{ + InitiatorID: firstUser.UserID, + Provider: "openai", + Model: "gpt-4", + StartedAt: now.Add(-3*time.Hour + time.Minute), + ClientSessionID: sql.NullString{String: "session-multi", Valid: true}, + }, &s4i2EndedAt) + + //nolint:gocritic // Owner role is irrelevant here. + res, err := client.AIBridgeListSessions(ctx, codersdk.AIBridgeListSessionsFilter{}) + require.NoError(t, err) + require.EqualValues(t, 4, res.Count) + require.Len(t, res.Sessions, 4) + + // Sessions ordered by started_at DESC: session-A (now), then + // thread-based (now-1h), then standalone (now-2h), then + // multi-thread (now-3h). + require.Equal(t, "session-A", res.Sessions[0].ID) + require.Equal(t, s2i1.ID.String(), res.Sessions[1].ID) + require.Equal(t, s3i1.ID.String(), res.Sessions[2].ID) + require.Equal(t, "session-multi", res.Sessions[3].ID) + + // Verify session 1 aggregations. + s1 := res.Sessions[0] + require.ElementsMatch(t, []string{"anthropic"}, s1.Providers) + require.ElementsMatch(t, []string{"claude-4", "claude-4-haiku"}, s1.Models) + require.NotNil(t, s1.Client) + require.Equal(t, "claude-code", *s1.Client) + require.EqualValues(t, 300, s1.TokenUsageSummary.InputTokens) + require.EqualValues(t, 125, s1.TokenUsageSummary.OutputTokens) + require.NotNil(t, s1.LastPrompt) + require.Equal(t, "last prompt in session", *s1.LastPrompt) + // Two interceptions in session-A, but they share a thread root, + // so thread count is 1. + require.EqualValues(t, 1, s1.Threads) + + // Verify session 2 (thread-based). + s2 := res.Sessions[1] + require.ElementsMatch(t, []string{"openai"}, s2.Providers) + // Thread count: the root interception and its child share the same + // thread root, so count is 1. + require.EqualValues(t, 1, s2.Threads) + + // Verify session 3 (standalone). + s3 := res.Sessions[2] + require.EqualValues(t, 1, s3.Threads) + require.Nil(t, s3.LastPrompt) + + // Verify session 4 (multiple threads). Thread A has a root + + // child (1 thread), thread B is a standalone root (1 thread), + // so total is 2. + s4 := res.Sessions[3] + require.EqualValues(t, 2, s4.Threads) + require.ElementsMatch(t, []string{"anthropic", "openai"}, s4.Providers) + require.ElementsMatch(t, []string{"claude-4", "gpt-4"}, s4.Models) + }) + + t.Run("Pagination", func(t *testing.T) { + t.Parallel() + client, db, firstUser := coderdenttest.NewWithDatabase(t, aibridgeOpts(t)) + ctx := testutil.Context(t, testutil.WaitLong) + + now := dbtime.Now() + // Create 5 standalone sessions with different start times. + allSessionIDs := make([]string, 5) + for i := range 5 { + endedAt := now.Add(-time.Duration(i)*time.Hour + time.Minute) + intc := dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{ + InitiatorID: firstUser.UserID, + StartedAt: now.Add(-time.Duration(i) * time.Hour), + }, &endedAt) + // Standalone session: ID = interception UUID string. + allSessionIDs[i] = intc.ID.String() + } + + // Test offset pagination. + //nolint:gocritic // Owner role is irrelevant here. + res, err := client.AIBridgeListSessions(ctx, codersdk.AIBridgeListSessionsFilter{ + Pagination: codersdk.Pagination{Limit: 2}, + }) + require.NoError(t, err) + require.Len(t, res.Sessions, 2) + require.EqualValues(t, 5, res.Count) + require.Equal(t, allSessionIDs[0], res.Sessions[0].ID) + require.Equal(t, allSessionIDs[1], res.Sessions[1].ID) + + // Second page with offset. + res, err = client.AIBridgeListSessions(ctx, codersdk.AIBridgeListSessionsFilter{ + Pagination: codersdk.Pagination{Limit: 2, Offset: 2}, + }) + require.NoError(t, err) + require.Len(t, res.Sessions, 2) + require.Equal(t, allSessionIDs[2], res.Sessions[0].ID) + require.Equal(t, allSessionIDs[3], res.Sessions[1].ID) + + // Test cursor pagination. + res, err = client.AIBridgeListSessions(ctx, codersdk.AIBridgeListSessionsFilter{ + Pagination: codersdk.Pagination{Limit: 2}, + AfterSessionID: allSessionIDs[1], + }) + require.NoError(t, err) + require.Len(t, res.Sessions, 2) + require.Equal(t, allSessionIDs[2], res.Sessions[0].ID) + require.Equal(t, allSessionIDs[3], res.Sessions[1].ID) + + // Test mutual exclusion of cursor and offset. + _, err = client.AIBridgeListSessions(ctx, codersdk.AIBridgeListSessionsFilter{ + Pagination: codersdk.Pagination{Limit: 2, Offset: 1}, + AfterSessionID: allSessionIDs[0], + }) + var sdkErr *codersdk.Error + require.ErrorAs(t, err, &sdkErr) + require.Contains(t, sdkErr.Detail, "Cannot use both after_session_id and offset pagination") + }) + + t.Run("AfterSessionIDNotFound", func(t *testing.T) { + t.Parallel() + client, _ := coderdenttest.New(t, aibridgeOpts(t)) + ctx := testutil.Context(t, testutil.WaitLong) + + //nolint:gocritic // Owner role is irrelevant here. + _, err := client.AIBridgeListSessions(ctx, codersdk.AIBridgeListSessionsFilter{ + Pagination: codersdk.Pagination{Limit: 10}, + AfterSessionID: "nonexistent-session-id", + }) + var sdkErr *codersdk.Error + require.ErrorAs(t, err, &sdkErr) + require.Equal(t, http.StatusBadRequest, sdkErr.StatusCode()) + require.Equal(t, `after_session_id: session "nonexistent-session-id" not found`, sdkErr.Detail) + }) + + t.Run("Filters", func(t *testing.T) { + t.Parallel() + client, db, firstUser := coderdenttest.NewWithDatabase(t, aibridgeOpts(t)) + ctx := testutil.Context(t, testutil.WaitLong) + + _, user2 := coderdtest.CreateAnotherUser(t, client, firstUser.OrganizationID) + + now := dbtime.Now() + + // Session from user1 with provider "anthropic" and client "claude-code". + s1EndedAt := now.Add(time.Minute) + s1 := dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{ + InitiatorID: firstUser.UserID, + Provider: "anthropic", + Model: "claude-4", + StartedAt: now, + Client: sql.NullString{String: "claude-code", Valid: true}, + }, &s1EndedAt) + + // Session from user2 with provider "openai". + s2EndedAt := now.Add(-time.Hour + time.Minute) + s2 := dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{ + InitiatorID: user2.ID, + Provider: "openai", + Model: "gpt-4", + StartedAt: now.Add(-time.Hour), + }, &s2EndedAt) + + // Filter by initiator. + //nolint:gocritic // Owner role is irrelevant; testing filter behavior. + res, err := client.AIBridgeListSessions(ctx, codersdk.AIBridgeListSessionsFilter{ + Initiator: user2.Username, + }) + require.NoError(t, err) + require.EqualValues(t, 1, res.Count) + require.Equal(t, s2.ID.String(), res.Sessions[0].ID) + + // Filter by provider. + res, err = client.AIBridgeListSessions(ctx, codersdk.AIBridgeListSessionsFilter{ + Provider: "anthropic", + }) + require.NoError(t, err) + require.EqualValues(t, 1, res.Count) + require.Equal(t, s1.ID.String(), res.Sessions[0].ID) + + // Filter by model. + res, err = client.AIBridgeListSessions(ctx, codersdk.AIBridgeListSessionsFilter{ + Model: "gpt-4", + }) + require.NoError(t, err) + require.EqualValues(t, 1, res.Count) + require.Equal(t, s2.ID.String(), res.Sessions[0].ID) + + // Filter by client. + res, err = client.AIBridgeListSessions(ctx, codersdk.AIBridgeListSessionsFilter{ + Client: "claude-code", + }) + require.NoError(t, err) + require.EqualValues(t, 1, res.Count) + require.Equal(t, s1.ID.String(), res.Sessions[0].ID) + + // Filter by time range. + res, err = client.AIBridgeListSessions(ctx, codersdk.AIBridgeListSessionsFilter{ + StartedAfter: now.Add(-30 * time.Minute), + }) + require.NoError(t, err) + require.EqualValues(t, 1, res.Count) + require.Equal(t, s1.ID.String(), res.Sessions[0].ID) + + // Filter by session_id. + res, err = client.AIBridgeListSessions(ctx, codersdk.AIBridgeListSessionsFilter{ + SessionID: s2.ID.String(), + }) + require.NoError(t, err) + require.EqualValues(t, 1, res.Count) + require.Len(t, res.Sessions, 1) + require.Equal(t, s2.ID.String(), res.Sessions[0].ID) + + // Filter by session_id with no match. + res, err = client.AIBridgeListSessions(ctx, codersdk.AIBridgeListSessionsFilter{ + SessionID: "nonexistent-session-id", + }) + require.NoError(t, err) + require.EqualValues(t, 0, res.Count) + require.Empty(t, res.Sessions) + }) + + t.Run("Authorized", func(t *testing.T) { + t.Parallel() + adminClient, db, firstUser := coderdenttest.NewWithDatabase(t, aibridgeOpts(t)) + ctx := testutil.Context(t, testutil.WaitLong) + + secondUserClient, secondUser := coderdtest.CreateAnotherUser(t, adminClient, firstUser.OrganizationID) + + now := dbtime.Now() + i1EndedAt := now.Add(time.Minute) + i1 := dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{ + InitiatorID: firstUser.UserID, + StartedAt: now, + }, &i1EndedAt) + i2 := dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{ + InitiatorID: secondUser.ID, + StartedAt: now.Add(-time.Hour), + }, &now) + + // Admin can see all sessions. + //nolint:gocritic // Intentionally testing admin/owner visibility. + res, err := adminClient.AIBridgeListSessions(ctx, codersdk.AIBridgeListSessionsFilter{}) + require.NoError(t, err) + require.EqualValues(t, 2, res.Count) + require.Len(t, res.Sessions, 2) + require.Equal(t, i1.ID.String(), res.Sessions[0].ID) + require.Equal(t, i2.ID.String(), res.Sessions[1].ID) + + // Second user can only see their own sessions. + res, err = secondUserClient.AIBridgeListSessions(ctx, codersdk.AIBridgeListSessionsFilter{}) + require.NoError(t, err) + require.EqualValues(t, 1, res.Count) + require.Len(t, res.Sessions, 1) + require.Equal(t, i2.ID.String(), res.Sessions[0].ID) + }) + + t.Run("SessionIDCollisionAcrossUsers", func(t *testing.T) { + t.Parallel() + client, db, firstUser := coderdenttest.NewWithDatabase(t, aibridgeOpts(t)) + ctx := testutil.Context(t, testutil.WaitLong) + + _, user2 := coderdtest.CreateAnotherUser(t, client, firstUser.OrganizationID) + + now := dbtime.Now() + + // Two users share the same client_session_id. They must be + // treated as distinct sessions. + sharedSessionID := "shared-session-id" + u1EndedAt := now.Add(time.Minute) + u1Interception := dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{ + InitiatorID: firstUser.UserID, + Provider: "anthropic", + Model: "claude-4", + StartedAt: now, + Client: sql.NullString{String: "claude-code", Valid: true}, + ClientSessionID: sql.NullString{String: sharedSessionID, Valid: true}, + }, &u1EndedAt) + dbgen.AIBridgeTokenUsage(t, db, database.InsertAIBridgeTokenUsageParams{ + InterceptionID: u1Interception.ID, + InputTokens: 100, + OutputTokens: 50, + CreatedAt: now, + }) + + u2EndedAt := now.Add(-time.Hour + time.Minute) + u2Interception := dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{ + InitiatorID: user2.ID, + Provider: "openai", + Model: "gpt-4", + StartedAt: now.Add(-time.Hour), + Client: sql.NullString{String: "cursor", Valid: true}, + ClientSessionID: sql.NullString{String: sharedSessionID, Valid: true}, + }, &u2EndedAt) + dbgen.AIBridgeTokenUsage(t, db, database.InsertAIBridgeTokenUsageParams{ + InterceptionID: u2Interception.ID, + InputTokens: 200, + OutputTokens: 75, + CreatedAt: now.Add(-time.Hour), + }) + + // Admin should see two distinct sessions despite the shared + // session_id, each with the correct user and token counts. + //nolint:gocritic // Owner role is irrelevant; testing collision behavior. + res, err := client.AIBridgeListSessions(ctx, codersdk.AIBridgeListSessionsFilter{}) + require.NoError(t, err) + require.EqualValues(t, 2, res.Count) + require.Len(t, res.Sessions, 2) + + // Both sessions share the same ID string but belong to + // different users. + require.Equal(t, sharedSessionID, res.Sessions[0].ID) + require.Equal(t, sharedSessionID, res.Sessions[1].ID) + require.NotEqual(t, res.Sessions[0].Initiator.ID, res.Sessions[1].Initiator.ID) + + // Verify token counts are not merged across users. + for _, s := range res.Sessions { + if s.Initiator.ID == firstUser.UserID { + require.EqualValues(t, 100, s.TokenUsageSummary.InputTokens) + require.EqualValues(t, 50, s.TokenUsageSummary.OutputTokens) + } else { + require.EqualValues(t, 200, s.TokenUsageSummary.InputTokens) + require.EqualValues(t, 75, s.TokenUsageSummary.OutputTokens) + } + } + }) + + t.Run("InflightSessions", func(t *testing.T) { + t.Parallel() + client, db, firstUser := coderdenttest.NewWithDatabase(t, aibridgeOpts(t)) + ctx := testutil.Context(t, testutil.WaitLong) + + now := dbtime.Now() + i1EndedAt := now.Add(time.Minute) + i1 := dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{ + InitiatorID: firstUser.UserID, + StartedAt: now, + }, &i1EndedAt) + // Inflight interception (no ended_at) should not appear as a session. + dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{ + InitiatorID: firstUser.UserID, + StartedAt: now.Add(-time.Hour), + }, nil) + + //nolint:gocritic // Owner role is irrelevant; testing inflight filtering. + res, err := client.AIBridgeListSessions(ctx, codersdk.AIBridgeListSessionsFilter{}) + require.NoError(t, err) + require.EqualValues(t, 1, res.Count) + require.Len(t, res.Sessions, 1) + require.Equal(t, i1.ID.String(), res.Sessions[0].ID) + }) + + t.Run("FilterErrors", func(t *testing.T) { + t.Parallel() + client, _ := coderdenttest.New(t, aibridgeOpts(t)) + + cases := []struct { + name string + q string + want []codersdk.ValidationError + }{ + { + name: "UnknownUsername", + q: "initiator:unknown", + want: []codersdk.ValidationError{ + { + Field: "initiator", + Detail: `Query param "initiator" has invalid value: user "unknown" either does not exist, or you are unauthorized to view them`, + }, + }, + }, + { + name: "InvalidStartedAfter", + q: "started_after:invalid", + want: []codersdk.ValidationError{ + { + Field: "started_after", + Detail: `Query param "started_after" must be a valid date format (2006-01-02T15:04:05.999999999Z07:00): parsing time "INVALID" as "2006-01-02T15:04:05.999999999Z07:00": cannot parse "INVALID" as "2006"`, + }, + }, + }, + { + name: "InvalidStartedBefore", + q: "started_before:invalid", + want: []codersdk.ValidationError{ + { + Field: "started_before", + Detail: `Query param "started_before" must be a valid date format (2006-01-02T15:04:05.999999999Z07:00): parsing time "INVALID" as "2006-01-02T15:04:05.999999999Z07:00": cannot parse "INVALID" as "2006"`, + }, + }, + }, + { + name: "InvalidBeforeAfterRange", + q: `started_after:"2025-01-01T00:00:00Z" started_before:"2024-01-01T00:00:00Z"`, + want: []codersdk.ValidationError{ + { + Field: "started_before", + Detail: `Query param "started_before" has invalid value: "started_before" must be after "started_after" if set`, + }, + }, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitLong) + res, err := client.AIBridgeListSessions(ctx, codersdk.AIBridgeListSessionsFilter{ + FilterQuery: tc.q, + }) + var sdkErr *codersdk.Error + require.ErrorAs(t, err, &sdkErr) + require.Equal(t, tc.want, sdkErr.Validations) + require.Empty(t, res.Sessions) + }) + } + }) + + t.Run("PaginationLimitValidation", func(t *testing.T) { + t.Parallel() + client, _ := coderdenttest.New(t, aibridgeOpts(t)) + ctx := testutil.Context(t, testutil.WaitLong) + + //nolint:gocritic // Owner role is irrelevant; testing pagination validation. + res, err := client.AIBridgeListSessions(ctx, codersdk.AIBridgeListSessionsFilter{ + Pagination: codersdk.Pagination{ + Limit: 1001, + }, + }) + var sdkErr *codersdk.Error + require.ErrorAs(t, err, &sdkErr) + require.Contains(t, sdkErr.Message, "Invalid pagination limit value.") + require.Empty(t, res.Sessions) + }) +} + func TestAIBridgeRouting(t *testing.T) { t.Parallel() diff --git a/site/src/api/typesGenerated.ts b/site/src/api/typesGenerated.ts index c6d13d3f83..4cb3cf5087 100644 --- a/site/src/api/typesGenerated.ts +++ b/site/src/api/typesGenerated.ts @@ -75,6 +75,12 @@ export interface AIBridgeListInterceptionsResponse { readonly results: readonly AIBridgeInterception[]; } +// From codersdk/aibridge.go +export interface AIBridgeListSessionsResponse { + readonly count: number; + readonly sessions: readonly AIBridgeSession[]; +} + // From codersdk/deployment.go export interface AIBridgeOpenAIConfig { readonly base_url: string; @@ -95,6 +101,28 @@ export interface AIBridgeProxyConfig { readonly allowed_private_cidrs: string; } +// From codersdk/aibridge.go +export interface AIBridgeSession { + readonly id: string; + readonly initiator: MinimalUser; + readonly providers: readonly string[]; + readonly models: readonly string[]; + readonly client: string | null; + // empty interface{} type, falling back to unknown + readonly metadata: Record; + readonly started_at: string; + readonly ended_at?: string; + readonly threads: number; + readonly token_usage_summary: AIBridgeSessionTokenUsageSummary; + readonly last_prompt?: string; +} + +// From codersdk/aibridge.go +export interface AIBridgeSessionTokenUsageSummary { + readonly input_tokens: number; + readonly output_tokens: number; +} + // From codersdk/aibridge.go export interface AIBridgeTokenUsage { readonly id: string;