diff --git a/coderd/apidoc/docs.go b/coderd/apidoc/docs.go index 683fb1a4e1..2d18a69576 100644 --- a/coderd/apidoc/docs.go +++ b/coderd/apidoc/docs.go @@ -163,6 +163,57 @@ const docTemplate = `{ ] } }, + "/aibridge/sessions": { + "get": { + "produces": [ + "application/json" + ], + "tags": [ + "AI Bridge" + ], + "summary": "List AI Bridge sessions", + "operationId": "list-ai-bridge-sessions", + "parameters": [ + { + "type": "string", + "description": "Search query in the format ` + "`" + `key:value` + "`" + `. Available keys are: initiator, provider, model, client, session_id, started_after, started_before.", + "name": "q", + "in": "query" + }, + { + "type": "integer", + "description": "Page limit", + "name": "limit", + "in": "query" + }, + { + "type": "string", + "description": "Cursor pagination after session ID (cannot be used with offset)", + "name": "after_session_id", + "in": "query" + }, + { + "type": "integer", + "description": "Offset pagination (cannot be used with after_session_id)", + "name": "offset", + "in": "query" + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/codersdk.AIBridgeListSessionsResponse" + } + } + }, + "security": [ + { + "CoderSessionToken": [] + } + ] + } + }, "/appearance": { "get": { "produces": [ @@ -12778,6 +12829,20 @@ const docTemplate = `{ } } }, + "codersdk.AIBridgeListSessionsResponse": { + "type": "object", + "properties": { + "count": { + "type": "integer" + }, + "sessions": { + "type": "array", + "items": { + "$ref": "#/definitions/codersdk.AIBridgeSession" + } + } + } + }, "codersdk.AIBridgeOpenAIConfig": { "type": "object", "properties": { @@ -12830,6 +12895,64 @@ const docTemplate = `{ } } }, + "codersdk.AIBridgeSession": { + "type": "object", + "properties": { + "client": { + "type": "string" + }, + "ended_at": { + "type": "string", + "format": "date-time" + }, + "id": { + "type": "string" + }, + "initiator": { + "$ref": "#/definitions/codersdk.MinimalUser" + }, + "last_prompt": { + "type": "string" + }, + "metadata": { + "type": "object", + "additionalProperties": {} + }, + "models": { + "type": "array", + "items": { + "type": "string" + } + }, + "providers": { + "type": "array", + "items": { + "type": "string" + } + }, + "started_at": { + "type": "string", + "format": "date-time" + }, + "threads": { + "type": "integer" + }, + "token_usage_summary": { + "$ref": "#/definitions/codersdk.AIBridgeSessionTokenUsageSummary" + } + } + }, + "codersdk.AIBridgeSessionTokenUsageSummary": { + "type": "object", + "properties": { + "input_tokens": { + "type": "integer" + }, + "output_tokens": { + "type": "integer" + } + } + }, "codersdk.AIBridgeTokenUsage": { "type": "object", "properties": { diff --git a/coderd/apidoc/swagger.json b/coderd/apidoc/swagger.json index d5a74adb09..075536ae96 100644 --- a/coderd/apidoc/swagger.json +++ b/coderd/apidoc/swagger.json @@ -136,6 +136,53 @@ ] } }, + "/aibridge/sessions": { + "get": { + "produces": ["application/json"], + "tags": ["AI Bridge"], + "summary": "List AI Bridge sessions", + "operationId": "list-ai-bridge-sessions", + "parameters": [ + { + "type": "string", + "description": "Search query in the format `key:value`. Available keys are: initiator, provider, model, client, session_id, started_after, started_before.", + "name": "q", + "in": "query" + }, + { + "type": "integer", + "description": "Page limit", + "name": "limit", + "in": "query" + }, + { + "type": "string", + "description": "Cursor pagination after session ID (cannot be used with offset)", + "name": "after_session_id", + "in": "query" + }, + { + "type": "integer", + "description": "Offset pagination (cannot be used with after_session_id)", + "name": "offset", + "in": "query" + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/codersdk.AIBridgeListSessionsResponse" + } + } + }, + "security": [ + { + "CoderSessionToken": [] + } + ] + } + }, "/appearance": { "get": { "produces": ["application/json"], @@ -11368,6 +11415,20 @@ } } }, + "codersdk.AIBridgeListSessionsResponse": { + "type": "object", + "properties": { + "count": { + "type": "integer" + }, + "sessions": { + "type": "array", + "items": { + "$ref": "#/definitions/codersdk.AIBridgeSession" + } + } + } + }, "codersdk.AIBridgeOpenAIConfig": { "type": "object", "properties": { @@ -11420,6 +11481,64 @@ } } }, + "codersdk.AIBridgeSession": { + "type": "object", + "properties": { + "client": { + "type": "string" + }, + "ended_at": { + "type": "string", + "format": "date-time" + }, + "id": { + "type": "string" + }, + "initiator": { + "$ref": "#/definitions/codersdk.MinimalUser" + }, + "last_prompt": { + "type": "string" + }, + "metadata": { + "type": "object", + "additionalProperties": {} + }, + "models": { + "type": "array", + "items": { + "type": "string" + } + }, + "providers": { + "type": "array", + "items": { + "type": "string" + } + }, + "started_at": { + "type": "string", + "format": "date-time" + }, + "threads": { + "type": "integer" + }, + "token_usage_summary": { + "$ref": "#/definitions/codersdk.AIBridgeSessionTokenUsageSummary" + } + } + }, + "codersdk.AIBridgeSessionTokenUsageSummary": { + "type": "object", + "properties": { + "input_tokens": { + "type": "integer" + }, + "output_tokens": { + "type": "integer" + } + } + }, "codersdk.AIBridgeTokenUsage": { "type": "object", "properties": { diff --git a/coderd/database/db2sdk/db2sdk.go b/coderd/database/db2sdk/db2sdk.go index 77b280ff7b..d79d35b88b 100644 --- a/coderd/database/db2sdk/db2sdk.go +++ b/coderd/database/db2sdk/db2sdk.go @@ -1021,6 +1021,44 @@ func AIBridgeInterception(interception database.AIBridgeInterception, initiator return intc } +func AIBridgeSession(row database.ListAIBridgeSessionsRow) codersdk.AIBridgeSession { + session := codersdk.AIBridgeSession{ + ID: row.SessionID, + Initiator: MinimalUserFromVisibleUser(database.VisibleUser{ + ID: row.UserID, + Username: row.UserUsername, + Name: row.UserName, + AvatarURL: row.UserAvatarUrl, + }), + Providers: row.Providers, + Models: row.Models, + Metadata: jsonOrEmptyMap(pqtype.NullRawMessage{RawMessage: row.Metadata, Valid: len(row.Metadata) > 0}), + StartedAt: row.StartedAt, + Threads: row.Threads, + TokenUsageSummary: codersdk.AIBridgeSessionTokenUsageSummary{ + InputTokens: row.InputTokens, + OutputTokens: row.OutputTokens, + }, + } + // Ensure non-nil slices for JSON serialization. + if session.Providers == nil { + session.Providers = []string{} + } + if session.Models == nil { + session.Models = []string{} + } + if row.Client != "" { + session.Client = &row.Client + } + if !row.EndedAt.IsZero() { + session.EndedAt = &row.EndedAt + } + if row.LastPrompt != "" { + session.LastPrompt = &row.LastPrompt + } + return session +} + func AIBridgeTokenUsage(usage database.AIBridgeTokenUsage) codersdk.AIBridgeTokenUsage { return codersdk.AIBridgeTokenUsage{ ID: usage.ID, diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index 310d5de494..321f325ef2 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -1709,6 +1709,14 @@ func (q *querier) CountAIBridgeInterceptions(ctx context.Context, arg database.C return q.db.CountAuthorizedAIBridgeInterceptions(ctx, arg, prep) } +func (q *querier) CountAIBridgeSessions(ctx context.Context, arg database.CountAIBridgeSessionsParams) (int64, error) { + prep, err := prepareSQLFilter(ctx, q.auth, policy.ActionRead, rbac.ResourceAibridgeInterception.Type) + if err != nil { + return 0, xerrors.Errorf("(dev error) prepare sql filter: %w", err) + } + return q.db.CountAuthorizedAIBridgeSessions(ctx, arg, prep) +} + func (q *querier) CountAuditLogs(ctx context.Context, arg database.CountAuditLogsParams) (int64, error) { // Shortcut if the user is an owner. The SQL filter is noticeable, // and this is an easy win for owners. Which is the common case. @@ -5317,6 +5325,14 @@ func (q *querier) ListAIBridgeModels(ctx context.Context, arg database.ListAIBri return q.db.ListAuthorizedAIBridgeModels(ctx, arg, prep) } +func (q *querier) ListAIBridgeSessions(ctx context.Context, arg database.ListAIBridgeSessionsParams) ([]database.ListAIBridgeSessionsRow, error) { + prep, err := prepareSQLFilter(ctx, q.auth, policy.ActionRead, rbac.ResourceAibridgeInterception.Type) + if err != nil { + return nil, xerrors.Errorf("(dev error) prepare sql filter: %w", err) + } + return q.db.ListAuthorizedAIBridgeSessions(ctx, arg, prep) +} + func (q *querier) ListAIBridgeTokenUsagesByInterceptionIDs(ctx context.Context, interceptionIDs []uuid.UUID) ([]database.AIBridgeTokenUsage, error) { // This function is a system function until we implement a join for aibridge interceptions. // Matches the behavior of the workspaces listing endpoint. @@ -7128,6 +7144,14 @@ func (q *querier) ListAuthorizedAIBridgeModels(ctx context.Context, arg database return q.ListAIBridgeModels(ctx, arg) } +func (q *querier) ListAuthorizedAIBridgeSessions(ctx context.Context, arg database.ListAIBridgeSessionsParams, prepared rbac.PreparedAuthorized) ([]database.ListAIBridgeSessionsRow, error) { + return q.db.ListAuthorizedAIBridgeSessions(ctx, arg, prepared) +} + +func (q *querier) CountAuthorizedAIBridgeSessions(ctx context.Context, arg database.CountAIBridgeSessionsParams, prepared rbac.PreparedAuthorized) (int64, error) { + return q.db.CountAuthorizedAIBridgeSessions(ctx, arg, prepared) +} + func (q *querier) GetAuthorizedChats(ctx context.Context, arg database.GetChatsParams, _ rbac.PreparedAuthorized) ([]database.Chat, error) { return q.GetChats(ctx, arg) } diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go index 8a9e4988fd..b38ff83ffd 100644 --- a/coderd/database/dbauthz/dbauthz_test.go +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -5514,6 +5514,34 @@ func (s *MethodTestSuite) TestAIBridge() { check.Args(params, emptyPreparedAuthorized{}).Asserts() })) + s.Run("ListAIBridgeSessions", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + params := database.ListAIBridgeSessionsParams{} + db.EXPECT().ListAuthorizedAIBridgeSessions(gomock.Any(), params, gomock.Any()).Return([]database.ListAIBridgeSessionsRow{}, nil).AnyTimes() + // No asserts here because SQLFilter. + check.Args(params).Asserts() + })) + + s.Run("ListAuthorizedAIBridgeSessions", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + params := database.ListAIBridgeSessionsParams{} + db.EXPECT().ListAuthorizedAIBridgeSessions(gomock.Any(), params, gomock.Any()).Return([]database.ListAIBridgeSessionsRow{}, nil).AnyTimes() + // No asserts here because SQLFilter. + check.Args(params, emptyPreparedAuthorized{}).Asserts() + })) + + s.Run("CountAIBridgeSessions", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + params := database.CountAIBridgeSessionsParams{} + db.EXPECT().CountAuthorizedAIBridgeSessions(gomock.Any(), params, gomock.Any()).Return(int64(0), nil).AnyTimes() + // No asserts here because SQLFilter. + check.Args(params).Asserts() + })) + + s.Run("CountAuthorizedAIBridgeSessions", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + params := database.CountAIBridgeSessionsParams{} + db.EXPECT().CountAuthorizedAIBridgeSessions(gomock.Any(), params, gomock.Any()).Return(int64(0), nil).AnyTimes() + // No asserts here because SQLFilter. + check.Args(params, emptyPreparedAuthorized{}).Asserts() + })) + s.Run("ListAIBridgeTokenUsagesByInterceptionIDs", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { ids := []uuid.UUID{{1}} db.EXPECT().ListAIBridgeTokenUsagesByInterceptionIDs(gomock.Any(), ids).Return([]database.AIBridgeTokenUsage{}, nil).AnyTimes() diff --git a/coderd/database/dbmetrics/querymetrics.go b/coderd/database/dbmetrics/querymetrics.go index 31f798cf34..c41e2a4647 100644 --- a/coderd/database/dbmetrics/querymetrics.go +++ b/coderd/database/dbmetrics/querymetrics.go @@ -280,6 +280,14 @@ func (m queryMetricsStore) CountAIBridgeInterceptions(ctx context.Context, arg d return r0, r1 } +func (m queryMetricsStore) CountAIBridgeSessions(ctx context.Context, arg database.CountAIBridgeSessionsParams) (int64, error) { + start := time.Now() + r0, r1 := m.s.CountAIBridgeSessions(ctx, arg) + m.queryLatencies.WithLabelValues("CountAIBridgeSessions").Observe(time.Since(start).Seconds()) + m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "CountAIBridgeSessions").Inc() + return r0, r1 +} + func (m queryMetricsStore) CountAuditLogs(ctx context.Context, arg database.CountAuditLogsParams) (int64, error) { start := time.Now() r0, r1 := m.s.CountAuditLogs(ctx, arg) @@ -3720,6 +3728,14 @@ func (m queryMetricsStore) ListAIBridgeModels(ctx context.Context, arg database. return r0, r1 } +func (m queryMetricsStore) ListAIBridgeSessions(ctx context.Context, arg database.ListAIBridgeSessionsParams) ([]database.ListAIBridgeSessionsRow, error) { + start := time.Now() + r0, r1 := m.s.ListAIBridgeSessions(ctx, arg) + m.queryLatencies.WithLabelValues("ListAIBridgeSessions").Observe(time.Since(start).Seconds()) + m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "ListAIBridgeSessions").Inc() + return r0, r1 +} + func (m queryMetricsStore) ListAIBridgeTokenUsagesByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]database.AIBridgeTokenUsage, error) { start := time.Now() r0, r1 := m.s.ListAIBridgeTokenUsagesByInterceptionIDs(ctx, interceptionIds) @@ -5136,6 +5152,22 @@ func (m queryMetricsStore) ListAuthorizedAIBridgeModels(ctx context.Context, arg return r0, r1 } +func (m queryMetricsStore) ListAuthorizedAIBridgeSessions(ctx context.Context, arg database.ListAIBridgeSessionsParams, prepared rbac.PreparedAuthorized) ([]database.ListAIBridgeSessionsRow, error) { + start := time.Now() + r0, r1 := m.s.ListAuthorizedAIBridgeSessions(ctx, arg, prepared) + m.queryLatencies.WithLabelValues("ListAuthorizedAIBridgeSessions").Observe(time.Since(start).Seconds()) + m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "ListAuthorizedAIBridgeSessions").Inc() + return r0, r1 +} + +func (m queryMetricsStore) CountAuthorizedAIBridgeSessions(ctx context.Context, arg database.CountAIBridgeSessionsParams, prepared rbac.PreparedAuthorized) (int64, error) { + start := time.Now() + r0, r1 := m.s.CountAuthorizedAIBridgeSessions(ctx, arg, prepared) + m.queryLatencies.WithLabelValues("CountAuthorizedAIBridgeSessions").Observe(time.Since(start).Seconds()) + m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "CountAuthorizedAIBridgeSessions").Inc() + return r0, r1 +} + func (m queryMetricsStore) GetAuthorizedChats(ctx context.Context, arg database.GetChatsParams, prepared rbac.PreparedAuthorized) ([]database.Chat, error) { start := time.Now() r0, r1 := m.s.GetAuthorizedChats(ctx, arg, prepared) diff --git a/coderd/database/dbmock/dbmock.go b/coderd/database/dbmock/dbmock.go index 82cc6c4145..a061d6ebab 100644 --- a/coderd/database/dbmock/dbmock.go +++ b/coderd/database/dbmock/dbmock.go @@ -363,6 +363,21 @@ func (mr *MockStoreMockRecorder) CountAIBridgeInterceptions(ctx, arg any) *gomoc return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountAIBridgeInterceptions", reflect.TypeOf((*MockStore)(nil).CountAIBridgeInterceptions), ctx, arg) } +// CountAIBridgeSessions mocks base method. +func (m *MockStore) CountAIBridgeSessions(ctx context.Context, arg database.CountAIBridgeSessionsParams) (int64, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CountAIBridgeSessions", ctx, arg) + ret0, _ := ret[0].(int64) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CountAIBridgeSessions indicates an expected call of CountAIBridgeSessions. +func (mr *MockStoreMockRecorder) CountAIBridgeSessions(ctx, arg any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountAIBridgeSessions", reflect.TypeOf((*MockStore)(nil).CountAIBridgeSessions), ctx, arg) +} + // CountAuditLogs mocks base method. func (m *MockStore) CountAuditLogs(ctx context.Context, arg database.CountAuditLogsParams) (int64, error) { m.ctrl.T.Helper() @@ -393,6 +408,21 @@ func (mr *MockStoreMockRecorder) CountAuthorizedAIBridgeInterceptions(ctx, arg, return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountAuthorizedAIBridgeInterceptions", reflect.TypeOf((*MockStore)(nil).CountAuthorizedAIBridgeInterceptions), ctx, arg, prepared) } +// CountAuthorizedAIBridgeSessions mocks base method. +func (m *MockStore) CountAuthorizedAIBridgeSessions(ctx context.Context, arg database.CountAIBridgeSessionsParams, prepared rbac.PreparedAuthorized) (int64, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "CountAuthorizedAIBridgeSessions", ctx, arg, prepared) + ret0, _ := ret[0].(int64) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CountAuthorizedAIBridgeSessions indicates an expected call of CountAuthorizedAIBridgeSessions. +func (mr *MockStoreMockRecorder) CountAuthorizedAIBridgeSessions(ctx, arg, prepared any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountAuthorizedAIBridgeSessions", reflect.TypeOf((*MockStore)(nil).CountAuthorizedAIBridgeSessions), ctx, arg, prepared) +} + // CountAuthorizedAuditLogs mocks base method. func (m *MockStore) CountAuthorizedAuditLogs(ctx context.Context, arg database.CountAuditLogsParams, prepared rbac.PreparedAuthorized) (int64, error) { m.ctrl.T.Helper() @@ -6947,6 +6977,21 @@ func (mr *MockStoreMockRecorder) ListAIBridgeModels(ctx, arg any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListAIBridgeModels", reflect.TypeOf((*MockStore)(nil).ListAIBridgeModels), ctx, arg) } +// ListAIBridgeSessions mocks base method. +func (m *MockStore) ListAIBridgeSessions(ctx context.Context, arg database.ListAIBridgeSessionsParams) ([]database.ListAIBridgeSessionsRow, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ListAIBridgeSessions", ctx, arg) + ret0, _ := ret[0].([]database.ListAIBridgeSessionsRow) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ListAIBridgeSessions indicates an expected call of ListAIBridgeSessions. +func (mr *MockStoreMockRecorder) ListAIBridgeSessions(ctx, arg any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListAIBridgeSessions", reflect.TypeOf((*MockStore)(nil).ListAIBridgeSessions), ctx, arg) +} + // ListAIBridgeTokenUsagesByInterceptionIDs mocks base method. func (m *MockStore) ListAIBridgeTokenUsagesByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]database.AIBridgeTokenUsage, error) { m.ctrl.T.Helper() @@ -7022,6 +7067,21 @@ func (mr *MockStoreMockRecorder) ListAuthorizedAIBridgeModels(ctx, arg, prepared return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListAuthorizedAIBridgeModels", reflect.TypeOf((*MockStore)(nil).ListAuthorizedAIBridgeModels), ctx, arg, prepared) } +// ListAuthorizedAIBridgeSessions mocks base method. +func (m *MockStore) ListAuthorizedAIBridgeSessions(ctx context.Context, arg database.ListAIBridgeSessionsParams, prepared rbac.PreparedAuthorized) ([]database.ListAIBridgeSessionsRow, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ListAuthorizedAIBridgeSessions", ctx, arg, prepared) + ret0, _ := ret[0].([]database.ListAIBridgeSessionsRow) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ListAuthorizedAIBridgeSessions indicates an expected call of ListAuthorizedAIBridgeSessions. +func (mr *MockStoreMockRecorder) ListAuthorizedAIBridgeSessions(ctx, arg, prepared any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListAuthorizedAIBridgeSessions", reflect.TypeOf((*MockStore)(nil).ListAuthorizedAIBridgeSessions), ctx, arg, prepared) +} + // ListChatUsageLimitGroupOverrides mocks base method. func (m *MockStore) ListChatUsageLimitGroupOverrides(ctx context.Context) ([]database.ListChatUsageLimitGroupOverridesRow, error) { m.ctrl.T.Helper() diff --git a/coderd/database/dump.sql b/coderd/database/dump.sql index 3a77928e1f..f44556cd72 100644 --- a/coderd/database/dump.sql +++ b/coderd/database/dump.sql @@ -1099,7 +1099,8 @@ CREATE TABLE aibridge_interceptions ( client character varying(64) DEFAULT 'Unknown'::character varying, thread_parent_id uuid, thread_root_id uuid, - client_session_id character varying(256) + client_session_id character varying(256), + session_id text GENERATED ALWAYS AS (COALESCE(client_session_id, ((thread_root_id)::text)::character varying, ((id)::text)::character varying)) STORED NOT NULL ); COMMENT ON TABLE aibridge_interceptions IS 'Audit log of requests intercepted by AI Bridge'; @@ -1112,6 +1113,8 @@ COMMENT ON COLUMN aibridge_interceptions.thread_root_id IS 'The root interceptio COMMENT ON COLUMN aibridge_interceptions.client_session_id IS 'The session ID supplied by the client (optional and not universally supported).'; +COMMENT ON COLUMN aibridge_interceptions.session_id IS 'Groups related interceptions into a logical session. Determined by a priority chain: (1) client_session_id — an explicit session identifier supplied by the calling client (e.g. Claude Code); (2) thread_root_id — the root of an agentic thread detected by Bridge through tool-call correlation, used when the client does not supply its own session ID; (3) id — the interception''s own ID, used as a last resort so every interception belongs to exactly one session even if it is standalone. This is a generated column stored on disk so it can be indexed and joined without recomputing the COALESCE on every query.'; + CREATE TABLE aibridge_model_thoughts ( interception_id uuid NOT NULL, content text NOT NULL, @@ -3654,6 +3657,10 @@ CREATE INDEX idx_aibridge_interceptions_model ON aibridge_interceptions USING bt CREATE INDEX idx_aibridge_interceptions_provider ON aibridge_interceptions USING btree (provider); +CREATE INDEX idx_aibridge_interceptions_session_id ON aibridge_interceptions USING btree (session_id) WHERE (ended_at IS NOT NULL); + +CREATE INDEX idx_aibridge_interceptions_sessions_filter ON aibridge_interceptions USING btree (initiator_id, started_at DESC, id DESC) WHERE (ended_at IS NOT NULL); + CREATE INDEX idx_aibridge_interceptions_started_id_desc ON aibridge_interceptions USING btree (started_at DESC, id DESC); CREATE INDEX idx_aibridge_interceptions_thread_parent_id ON aibridge_interceptions USING btree (thread_parent_id); @@ -3672,6 +3679,8 @@ CREATE INDEX idx_aibridge_tool_usages_provider_tool_call_id ON aibridge_tool_usa CREATE INDEX idx_aibridge_tool_usagesprovider_response_id ON aibridge_tool_usages USING btree (provider_response_id); +CREATE INDEX idx_aibridge_user_prompts_interception_created ON aibridge_user_prompts USING btree (interception_id, created_at DESC, id DESC); + CREATE INDEX idx_aibridge_user_prompts_interception_id ON aibridge_user_prompts USING btree (interception_id); CREATE INDEX idx_aibridge_user_prompts_provider_response_id ON aibridge_user_prompts USING btree (provider_response_id); diff --git a/coderd/database/migrations/000449_aibridge_session_indexes.down.sql b/coderd/database/migrations/000449_aibridge_session_indexes.down.sql new file mode 100644 index 0000000000..7f510a7cc5 --- /dev/null +++ b/coderd/database/migrations/000449_aibridge_session_indexes.down.sql @@ -0,0 +1,5 @@ +DROP INDEX IF EXISTS idx_aibridge_interceptions_session_id; +DROP INDEX IF EXISTS idx_aibridge_user_prompts_interception_created; +DROP INDEX IF EXISTS idx_aibridge_interceptions_sessions_filter; + +ALTER TABLE aibridge_interceptions DROP COLUMN IF EXISTS session_id; diff --git a/coderd/database/migrations/000449_aibridge_session_indexes.up.sql b/coderd/database/migrations/000449_aibridge_session_indexes.up.sql new file mode 100644 index 0000000000..3927f9c1ba --- /dev/null +++ b/coderd/database/migrations/000449_aibridge_session_indexes.up.sql @@ -0,0 +1,40 @@ +-- A "session" groups related interceptions together. See the COMMENT ON +-- COLUMN below for the full business-logic description. +ALTER TABLE aibridge_interceptions + ADD COLUMN session_id TEXT NOT NULL + GENERATED ALWAYS AS ( + COALESCE( + client_session_id, + thread_root_id::text, + id::text + ) + ) STORED; + +-- Searching and grouping on the resolved session ID will be common. +CREATE INDEX idx_aibridge_interceptions_session_id + ON aibridge_interceptions (session_id) + WHERE ended_at IS NOT NULL; + +COMMENT ON COLUMN aibridge_interceptions.session_id IS + 'Groups related interceptions into a logical session. ' + 'Determined by a priority chain: ' + '(1) client_session_id — an explicit session identifier supplied by the ' + 'calling client (e.g. Claude Code); ' + '(2) thread_root_id — the root of an agentic thread detected by Bridge ' + 'through tool-call correlation, used when the client does not supply its ' + 'own session ID; ' + '(3) id — the interception''s own ID, used as a last resort so every ' + 'interception belongs to exactly one session even if it is standalone. ' + 'This is a generated column stored on disk so it can be indexed and ' + 'joined without recomputing the COALESCE on every query.'; + +-- Composite index for the most common filter path used by +-- ListAIBridgeSessions: initiator_id equality + started_at range, +-- with ended_at IS NOT NULL as a partial filter. +CREATE INDEX idx_aibridge_interceptions_sessions_filter + ON aibridge_interceptions (initiator_id, started_at DESC, id DESC) + WHERE ended_at IS NOT NULL; + +-- Supports lateral prompt lookup by interception + recency. +CREATE INDEX idx_aibridge_user_prompts_interception_created + ON aibridge_user_prompts (interception_id, created_at DESC, id DESC); diff --git a/coderd/database/modelqueries.go b/coderd/database/modelqueries.go index c39b06202b..6a8aed8108 100644 --- a/coderd/database/modelqueries.go +++ b/coderd/database/modelqueries.go @@ -806,6 +806,8 @@ type aibridgeQuerier interface { ListAuthorizedAIBridgeInterceptions(ctx context.Context, arg ListAIBridgeInterceptionsParams, prepared rbac.PreparedAuthorized) ([]ListAIBridgeInterceptionsRow, error) CountAuthorizedAIBridgeInterceptions(ctx context.Context, arg CountAIBridgeInterceptionsParams, prepared rbac.PreparedAuthorized) (int64, error) ListAuthorizedAIBridgeModels(ctx context.Context, arg ListAIBridgeModelsParams, prepared rbac.PreparedAuthorized) ([]string, error) + ListAuthorizedAIBridgeSessions(ctx context.Context, arg ListAIBridgeSessionsParams, prepared rbac.PreparedAuthorized) ([]ListAIBridgeSessionsRow, error) + CountAuthorizedAIBridgeSessions(ctx context.Context, arg CountAIBridgeSessionsParams, prepared rbac.PreparedAuthorized) (int64, error) } func (q *sqlQuerier) ListAuthorizedAIBridgeInterceptions(ctx context.Context, arg ListAIBridgeInterceptionsParams, prepared rbac.PreparedAuthorized) ([]ListAIBridgeInterceptionsRow, error) { @@ -852,6 +854,7 @@ func (q *sqlQuerier) ListAuthorizedAIBridgeInterceptions(ctx context.Context, ar &i.AIBridgeInterception.ThreadParentID, &i.AIBridgeInterception.ThreadRootID, &i.AIBridgeInterception.ClientSessionID, + &i.AIBridgeInterception.SessionID, &i.VisibleUser.ID, &i.VisibleUser.Username, &i.VisibleUser.Name, @@ -939,6 +942,109 @@ func (q *sqlQuerier) ListAuthorizedAIBridgeModels(ctx context.Context, arg ListA return items, nil } +func (q *sqlQuerier) ListAuthorizedAIBridgeSessions(ctx context.Context, arg ListAIBridgeSessionsParams, prepared rbac.PreparedAuthorized) ([]ListAIBridgeSessionsRow, error) { + authorizedFilter, err := prepared.CompileToSQL(ctx, regosql.ConvertConfig{ + VariableConverter: regosql.AIBridgeInterceptionConverter(), + }) + if err != nil { + return nil, xerrors.Errorf("compile authorized filter: %w", err) + } + filtered, err := insertAuthorizedFilter(listAIBridgeSessions, fmt.Sprintf(" AND %s", authorizedFilter)) + if err != nil { + return nil, xerrors.Errorf("insert authorized filter: %w", err) + } + + query := fmt.Sprintf("-- name: ListAuthorizedAIBridgeSessions :many\n%s", filtered) + rows, err := q.db.QueryContext(ctx, query, + arg.AfterSessionID, + arg.Offset, + arg.Limit, + arg.StartedAfter, + arg.StartedBefore, + arg.InitiatorID, + arg.Provider, + arg.Model, + arg.Client, + arg.SessionID, + ) + if err != nil { + return nil, err + } + defer rows.Close() + var items []ListAIBridgeSessionsRow + for rows.Next() { + var i ListAIBridgeSessionsRow + if err := rows.Scan( + &i.SessionID, + &i.UserID, + &i.UserUsername, + &i.UserName, + &i.UserAvatarUrl, + pq.Array(&i.Providers), + pq.Array(&i.Models), + &i.Client, + &i.Metadata, + &i.StartedAt, + &i.EndedAt, + &i.Threads, + &i.InputTokens, + &i.OutputTokens, + &i.LastPrompt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +func (q *sqlQuerier) CountAuthorizedAIBridgeSessions(ctx context.Context, arg CountAIBridgeSessionsParams, prepared rbac.PreparedAuthorized) (int64, error) { + authorizedFilter, err := prepared.CompileToSQL(ctx, regosql.ConvertConfig{ + VariableConverter: regosql.AIBridgeInterceptionConverter(), + }) + if err != nil { + return 0, xerrors.Errorf("compile authorized filter: %w", err) + } + filtered, err := insertAuthorizedFilter(countAIBridgeSessions, fmt.Sprintf(" AND %s", authorizedFilter)) + if err != nil { + return 0, xerrors.Errorf("insert authorized filter: %w", err) + } + + query := fmt.Sprintf("-- name: CountAuthorizedAIBridgeSessions :one\n%s", filtered) + rows, err := q.db.QueryContext(ctx, query, + arg.StartedAfter, + arg.StartedBefore, + arg.InitiatorID, + arg.Provider, + arg.Model, + arg.Client, + arg.SessionID, + ) + if err != nil { + return 0, err + } + defer rows.Close() + var count int64 + for rows.Next() { + if err := rows.Scan(&count); err != nil { + return 0, err + } + } + if err := rows.Close(); err != nil { + return 0, err + } + if err := rows.Err(); err != nil { + return 0, err + } + return count, nil +} + func insertAuthorizedFilter(query string, replaceWith string) (string, error) { if !strings.Contains(query, authorizedQueryPlaceholder) { return "", xerrors.Errorf("query does not contain authorized replace string, this is not an authorized query") diff --git a/coderd/database/models.go b/coderd/database/models.go index 80972cf7d2..4b7feb6b02 100644 --- a/coderd/database/models.go +++ b/coderd/database/models.go @@ -4036,6 +4036,8 @@ type AIBridgeInterception struct { ThreadRootID uuid.NullUUID `db:"thread_root_id" json:"thread_root_id"` // The session ID supplied by the client (optional and not universally supported). ClientSessionID sql.NullString `db:"client_session_id" json:"client_session_id"` + // Groups related interceptions into a logical session. Determined by a priority chain: (1) client_session_id — an explicit session identifier supplied by the calling client (e.g. Claude Code); (2) thread_root_id — the root of an agentic thread detected by Bridge through tool-call correlation, used when the client does not supply its own session ID; (3) id — the interception's own ID, used as a last resort so every interception belongs to exactly one session even if it is standalone. This is a generated column stored on disk so it can be indexed and joined without recomputing the COALESCE on every query. + SessionID string `db:"session_id" json:"session_id"` } // Audit log of model thinking in intercepted requests in AI Bridge diff --git a/coderd/database/querier.go b/coderd/database/querier.go index 9bd876b683..78cad95f85 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -76,6 +76,7 @@ type sqlcQuerier interface { CleanTailnetTunnels(ctx context.Context) error CleanupDeletedMCPServerIDsFromChats(ctx context.Context) error CountAIBridgeInterceptions(ctx context.Context, arg CountAIBridgeInterceptionsParams) (int64, error) + CountAIBridgeSessions(ctx context.Context, arg CountAIBridgeSessionsParams) (int64, error) CountAuditLogs(ctx context.Context, arg CountAuditLogsParams) (int64, error) CountConnectionLogs(ctx context.Context, arg CountConnectionLogsParams) (int64, error) // Counts enabled, non-deleted model configs that lack both input and @@ -759,6 +760,10 @@ type sqlcQuerier interface { // (provider, model, client) in the given timeframe for telemetry reporting. ListAIBridgeInterceptionsTelemetrySummaries(ctx context.Context, arg ListAIBridgeInterceptionsTelemetrySummariesParams) ([]ListAIBridgeInterceptionsTelemetrySummariesRow, error) ListAIBridgeModels(ctx context.Context, arg ListAIBridgeModelsParams) ([]string, error) + // Returns paginated sessions with aggregated metadata, token counts, and + // the most recent user prompt. A "session" is a logical grouping of + // interceptions that share the same session_id (set by the client). + ListAIBridgeSessions(ctx context.Context, arg ListAIBridgeSessionsParams) ([]ListAIBridgeSessionsRow, error) ListAIBridgeTokenUsagesByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]AIBridgeTokenUsage, error) ListAIBridgeToolUsagesByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]AIBridgeToolUsage, error) ListAIBridgeUserPromptsByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]AIBridgeUserPrompt, error) diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 2a74a99f22..845333d399 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -332,6 +332,77 @@ func (q *sqlQuerier) CountAIBridgeInterceptions(ctx context.Context, arg CountAI return count, err } +const countAIBridgeSessions = `-- name: CountAIBridgeSessions :one +SELECT + COUNT(DISTINCT (aibridge_interceptions.session_id, aibridge_interceptions.initiator_id)) +FROM + aibridge_interceptions +WHERE + -- Remove inflight interceptions (ones which lack an ended_at value). + aibridge_interceptions.ended_at IS NOT NULL + -- Filter by time frame + AND CASE + WHEN $1::timestamptz != '0001-01-01 00:00:00+00'::timestamptz THEN aibridge_interceptions.started_at >= $1::timestamptz + ELSE true + END + AND CASE + WHEN $2::timestamptz != '0001-01-01 00:00:00+00'::timestamptz THEN aibridge_interceptions.started_at <= $2::timestamptz + ELSE true + END + -- Filter initiator_id + AND CASE + WHEN $3::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN aibridge_interceptions.initiator_id = $3::uuid + ELSE true + END + -- Filter provider + AND CASE + WHEN $4::text != '' THEN aibridge_interceptions.provider = $4::text + ELSE true + END + -- Filter model + AND CASE + WHEN $5::text != '' THEN aibridge_interceptions.model = $5::text + ELSE true + END + -- Filter client + AND CASE + WHEN $6::text != '' THEN COALESCE(aibridge_interceptions.client, 'Unknown') = $6::text + ELSE true + END + -- Filter session_id + AND CASE + WHEN $7::text != '' THEN aibridge_interceptions.session_id = $7::text + ELSE true + END + -- Authorize Filter clause will be injected below in CountAuthorizedAIBridgeSessions + -- @authorize_filter +` + +type CountAIBridgeSessionsParams struct { + StartedAfter time.Time `db:"started_after" json:"started_after"` + StartedBefore time.Time `db:"started_before" json:"started_before"` + InitiatorID uuid.UUID `db:"initiator_id" json:"initiator_id"` + Provider string `db:"provider" json:"provider"` + Model string `db:"model" json:"model"` + Client string `db:"client" json:"client"` + SessionID string `db:"session_id" json:"session_id"` +} + +func (q *sqlQuerier) CountAIBridgeSessions(ctx context.Context, arg CountAIBridgeSessionsParams) (int64, error) { + row := q.db.QueryRowContext(ctx, countAIBridgeSessions, + arg.StartedAfter, + arg.StartedBefore, + arg.InitiatorID, + arg.Provider, + arg.Model, + arg.Client, + arg.SessionID, + ) + var count int64 + err := row.Scan(&count) + return count, err +} + const deleteOldAIBridgeRecords = `-- name: DeleteOldAIBridgeRecords :one WITH -- We don't have FK relationships between the dependent tables and aibridge_interceptions, so we can't rely on DELETE CASCADE. @@ -384,7 +455,7 @@ func (q *sqlQuerier) DeleteOldAIBridgeRecords(ctx context.Context, beforeTime ti const getAIBridgeInterceptionByID = `-- name: GetAIBridgeInterceptionByID :one SELECT - id, initiator_id, provider, model, started_at, metadata, ended_at, api_key_id, client, thread_parent_id, thread_root_id, client_session_id + id, initiator_id, provider, model, started_at, metadata, ended_at, api_key_id, client, thread_parent_id, thread_root_id, client_session_id, session_id FROM aibridge_interceptions WHERE @@ -407,6 +478,7 @@ func (q *sqlQuerier) GetAIBridgeInterceptionByID(ctx context.Context, id uuid.UU &i.ThreadParentID, &i.ThreadRootID, &i.ClientSessionID, + &i.SessionID, ) return i, err } @@ -441,7 +513,7 @@ func (q *sqlQuerier) GetAIBridgeInterceptionLineageByToolCallID(ctx context.Cont const getAIBridgeInterceptions = `-- name: GetAIBridgeInterceptions :many SELECT - id, initiator_id, provider, model, started_at, metadata, ended_at, api_key_id, client, thread_parent_id, thread_root_id, client_session_id + id, initiator_id, provider, model, started_at, metadata, ended_at, api_key_id, client, thread_parent_id, thread_root_id, client_session_id, session_id FROM aibridge_interceptions ` @@ -468,6 +540,7 @@ func (q *sqlQuerier) GetAIBridgeInterceptions(ctx context.Context) ([]AIBridgeIn &i.ThreadParentID, &i.ThreadRootID, &i.ClientSessionID, + &i.SessionID, ); err != nil { return nil, err } @@ -618,7 +691,7 @@ INSERT INTO aibridge_interceptions ( ) VALUES ( $1, $2, $3, $4, $5, COALESCE($6::jsonb, '{}'::jsonb), $7, $8, $9, $10::uuid, $11::uuid ) -RETURNING id, initiator_id, provider, model, started_at, metadata, ended_at, api_key_id, client, thread_parent_id, thread_root_id, client_session_id +RETURNING id, initiator_id, provider, model, started_at, metadata, ended_at, api_key_id, client, thread_parent_id, thread_root_id, client_session_id, session_id ` type InsertAIBridgeInterceptionParams struct { @@ -663,6 +736,7 @@ func (q *sqlQuerier) InsertAIBridgeInterception(ctx context.Context, arg InsertA &i.ThreadParentID, &i.ThreadRootID, &i.ClientSessionID, + &i.SessionID, ) return i, err } @@ -837,7 +911,7 @@ func (q *sqlQuerier) InsertAIBridgeUserPrompt(ctx context.Context, arg InsertAIB const listAIBridgeInterceptions = `-- name: ListAIBridgeInterceptions :many SELECT - aibridge_interceptions.id, aibridge_interceptions.initiator_id, aibridge_interceptions.provider, aibridge_interceptions.model, aibridge_interceptions.started_at, aibridge_interceptions.metadata, aibridge_interceptions.ended_at, aibridge_interceptions.api_key_id, aibridge_interceptions.client, aibridge_interceptions.thread_parent_id, aibridge_interceptions.thread_root_id, aibridge_interceptions.client_session_id, + aibridge_interceptions.id, aibridge_interceptions.initiator_id, aibridge_interceptions.provider, aibridge_interceptions.model, aibridge_interceptions.started_at, aibridge_interceptions.metadata, aibridge_interceptions.ended_at, aibridge_interceptions.api_key_id, aibridge_interceptions.client, aibridge_interceptions.thread_parent_id, aibridge_interceptions.thread_root_id, aibridge_interceptions.client_session_id, aibridge_interceptions.session_id, visible_users.id, visible_users.username, visible_users.name, visible_users.avatar_url FROM aibridge_interceptions @@ -949,6 +1023,7 @@ func (q *sqlQuerier) ListAIBridgeInterceptions(ctx context.Context, arg ListAIBr &i.AIBridgeInterception.ThreadParentID, &i.AIBridgeInterception.ThreadRootID, &i.AIBridgeInterception.ClientSessionID, + &i.AIBridgeInterception.SessionID, &i.VisibleUser.ID, &i.VisibleUser.Username, &i.VisibleUser.Name, @@ -1071,6 +1146,229 @@ func (q *sqlQuerier) ListAIBridgeModels(ctx context.Context, arg ListAIBridgeMod return items, nil } +const listAIBridgeSessions = `-- name: ListAIBridgeSessions :many +WITH filtered_interceptions AS ( + SELECT + aibridge_interceptions.id, aibridge_interceptions.initiator_id, aibridge_interceptions.provider, aibridge_interceptions.model, aibridge_interceptions.started_at, aibridge_interceptions.metadata, aibridge_interceptions.ended_at, aibridge_interceptions.api_key_id, aibridge_interceptions.client, aibridge_interceptions.thread_parent_id, aibridge_interceptions.thread_root_id, aibridge_interceptions.client_session_id, aibridge_interceptions.session_id + FROM + aibridge_interceptions + WHERE + -- Remove inflight interceptions (ones which lack an ended_at value). + aibridge_interceptions.ended_at IS NOT NULL + -- Filter by time frame + AND CASE + WHEN $4::timestamptz != '0001-01-01 00:00:00+00'::timestamptz THEN aibridge_interceptions.started_at >= $4::timestamptz + ELSE true + END + AND CASE + WHEN $5::timestamptz != '0001-01-01 00:00:00+00'::timestamptz THEN aibridge_interceptions.started_at <= $5::timestamptz + ELSE true + END + -- Filter initiator_id + AND CASE + WHEN $6::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN aibridge_interceptions.initiator_id = $6::uuid + ELSE true + END + -- Filter provider + AND CASE + WHEN $7::text != '' THEN aibridge_interceptions.provider = $7::text + ELSE true + END + -- Filter model + AND CASE + WHEN $8::text != '' THEN aibridge_interceptions.model = $8::text + ELSE true + END + -- Filter client + AND CASE + WHEN $9::text != '' THEN COALESCE(aibridge_interceptions.client, 'Unknown') = $9::text + ELSE true + END + -- Filter session_id + AND CASE + WHEN $10::text != '' THEN aibridge_interceptions.session_id = $10::text + ELSE true + END + -- Authorize Filter clause will be injected below in ListAuthorizedAIBridgeSessions + -- @authorize_filter +), +session_tokens AS ( + -- Aggregate token usage across all interceptions in each session. + -- Group by (session_id, initiator_id) to avoid merging sessions from + -- different users who happen to share the same client_session_id. + SELECT + fi.session_id, + fi.initiator_id, + COALESCE(SUM(tu.input_tokens), 0)::bigint AS input_tokens, + COALESCE(SUM(tu.output_tokens), 0)::bigint AS output_tokens + -- TODO: add extra token types once https://github.com/coder/aibridge/issues/150 lands. + FROM + filtered_interceptions fi + LEFT JOIN + aibridge_token_usages tu ON fi.id = tu.interception_id + GROUP BY + fi.session_id, fi.initiator_id +), +session_root AS ( + -- Build one summary row per session. Group by (session_id, initiator_id) + -- to avoid merging sessions from different users who happen to share the + -- same client_session_id. The ARRAY_AGG with ORDER BY picks values from + -- the chronologically first interception for fields that should represent + -- the session as a whole (client, metadata). Threads are counted as + -- distinct root interception IDs: an interception with a NULL + -- thread_root_id is itself a thread root. + SELECT + fi.session_id, + fi.initiator_id, + (ARRAY_AGG(fi.client ORDER BY fi.started_at, fi.id))[1] AS client, + (ARRAY_AGG(fi.metadata ORDER BY fi.started_at, fi.id))[1] AS metadata, + ARRAY_AGG(DISTINCT fi.provider ORDER BY fi.provider) AS providers, + ARRAY_AGG(DISTINCT fi.model ORDER BY fi.model) AS models, + MIN(fi.started_at) AS started_at, + MAX(fi.ended_at) AS ended_at, + COUNT(DISTINCT COALESCE(fi.thread_root_id, fi.id)) AS threads, + -- Collect IDs for lateral prompt lookup. + ARRAY_AGG(fi.id) AS interception_ids + FROM + filtered_interceptions fi + GROUP BY + fi.session_id, fi.initiator_id +) +SELECT + sr.session_id, + visible_users.id AS user_id, + visible_users.username AS user_username, + visible_users.name AS user_name, + visible_users.avatar_url AS user_avatar_url, + sr.providers::text[] AS providers, + sr.models::text[] AS models, + COALESCE(sr.client, '')::varchar(64) AS client, + sr.metadata::jsonb AS metadata, + sr.started_at::timestamptz AS started_at, + sr.ended_at::timestamptz AS ended_at, + sr.threads, + COALESCE(st.input_tokens, 0)::bigint AS input_tokens, + COALESCE(st.output_tokens, 0)::bigint AS output_tokens, + COALESCE(slp.prompt, '') AS last_prompt +FROM + session_root sr +JOIN + visible_users ON visible_users.id = sr.initiator_id +LEFT JOIN + session_tokens st ON st.session_id = sr.session_id AND st.initiator_id = sr.initiator_id +LEFT JOIN LATERAL ( + -- Lateral join to efficiently fetch only the most recent user prompt + -- across all interceptions in the session, avoiding a full aggregation. + SELECT up.prompt + FROM aibridge_user_prompts up + WHERE up.interception_id = ANY(sr.interception_ids) + ORDER BY up.created_at DESC, up.id DESC + LIMIT 1 +) slp ON true +WHERE + -- Cursor pagination: uses a composite (started_at, session_id) cursor + -- to support keyset pagination. The less-than comparison matches the + -- DESC sort order so that rows after the cursor come later in results. + CASE + WHEN $1::text != '' THEN ( + (sr.started_at, sr.session_id) < ( + (SELECT started_at FROM session_root WHERE session_id = $1), + $1::text + ) + ) + ELSE true + END +ORDER BY + sr.started_at DESC, + sr.session_id DESC +LIMIT COALESCE(NULLIF($3::integer, 0), 100) +OFFSET $2 +` + +type ListAIBridgeSessionsParams struct { + AfterSessionID string `db:"after_session_id" json:"after_session_id"` + Offset int32 `db:"offset_" json:"offset_"` + Limit int32 `db:"limit_" json:"limit_"` + StartedAfter time.Time `db:"started_after" json:"started_after"` + StartedBefore time.Time `db:"started_before" json:"started_before"` + InitiatorID uuid.UUID `db:"initiator_id" json:"initiator_id"` + Provider string `db:"provider" json:"provider"` + Model string `db:"model" json:"model"` + Client string `db:"client" json:"client"` + SessionID string `db:"session_id" json:"session_id"` +} + +type ListAIBridgeSessionsRow struct { + SessionID string `db:"session_id" json:"session_id"` + UserID uuid.UUID `db:"user_id" json:"user_id"` + UserUsername string `db:"user_username" json:"user_username"` + UserName string `db:"user_name" json:"user_name"` + UserAvatarUrl string `db:"user_avatar_url" json:"user_avatar_url"` + Providers []string `db:"providers" json:"providers"` + Models []string `db:"models" json:"models"` + Client string `db:"client" json:"client"` + Metadata json.RawMessage `db:"metadata" json:"metadata"` + StartedAt time.Time `db:"started_at" json:"started_at"` + EndedAt time.Time `db:"ended_at" json:"ended_at"` + Threads int64 `db:"threads" json:"threads"` + InputTokens int64 `db:"input_tokens" json:"input_tokens"` + OutputTokens int64 `db:"output_tokens" json:"output_tokens"` + LastPrompt string `db:"last_prompt" json:"last_prompt"` +} + +// Returns paginated sessions with aggregated metadata, token counts, and +// the most recent user prompt. A "session" is a logical grouping of +// interceptions that share the same session_id (set by the client). +func (q *sqlQuerier) ListAIBridgeSessions(ctx context.Context, arg ListAIBridgeSessionsParams) ([]ListAIBridgeSessionsRow, error) { + rows, err := q.db.QueryContext(ctx, listAIBridgeSessions, + arg.AfterSessionID, + arg.Offset, + arg.Limit, + arg.StartedAfter, + arg.StartedBefore, + arg.InitiatorID, + arg.Provider, + arg.Model, + arg.Client, + arg.SessionID, + ) + if err != nil { + return nil, err + } + defer rows.Close() + var items []ListAIBridgeSessionsRow + for rows.Next() { + var i ListAIBridgeSessionsRow + if err := rows.Scan( + &i.SessionID, + &i.UserID, + &i.UserUsername, + &i.UserName, + &i.UserAvatarUrl, + pq.Array(&i.Providers), + pq.Array(&i.Models), + &i.Client, + &i.Metadata, + &i.StartedAt, + &i.EndedAt, + &i.Threads, + &i.InputTokens, + &i.OutputTokens, + &i.LastPrompt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const listAIBridgeTokenUsagesByInterceptionIDs = `-- name: ListAIBridgeTokenUsagesByInterceptionIDs :many SELECT id, interception_id, provider_response_id, input_tokens, output_tokens, metadata, created_at @@ -1209,7 +1507,7 @@ UPDATE aibridge_interceptions WHERE id = $2::uuid AND ended_at IS NULL -RETURNING id, initiator_id, provider, model, started_at, metadata, ended_at, api_key_id, client, thread_parent_id, thread_root_id, client_session_id +RETURNING id, initiator_id, provider, model, started_at, metadata, ended_at, api_key_id, client, thread_parent_id, thread_root_id, client_session_id, session_id ` type UpdateAIBridgeInterceptionEndedParams struct { @@ -1233,6 +1531,7 @@ func (q *sqlQuerier) UpdateAIBridgeInterceptionEnded(ctx context.Context, arg Up &i.ThreadParentID, &i.ThreadRootID, &i.ClientSessionID, + &i.SessionID, ) return i, err } diff --git a/coderd/database/queries/aibridge.sql b/coderd/database/queries/aibridge.sql index 2115ffebe7..5804fda09a 100644 --- a/coderd/database/queries/aibridge.sql +++ b/coderd/database/queries/aibridge.sql @@ -404,6 +404,194 @@ SELECT ( (SELECT COUNT(*) FROM interceptions) )::bigint as total_deleted; +-- name: CountAIBridgeSessions :one +SELECT + COUNT(DISTINCT (aibridge_interceptions.session_id, aibridge_interceptions.initiator_id)) +FROM + aibridge_interceptions +WHERE + -- Remove inflight interceptions (ones which lack an ended_at value). + aibridge_interceptions.ended_at IS NOT NULL + -- Filter by time frame + AND CASE + WHEN @started_after::timestamptz != '0001-01-01 00:00:00+00'::timestamptz THEN aibridge_interceptions.started_at >= @started_after::timestamptz + ELSE true + END + AND CASE + WHEN @started_before::timestamptz != '0001-01-01 00:00:00+00'::timestamptz THEN aibridge_interceptions.started_at <= @started_before::timestamptz + ELSE true + END + -- Filter initiator_id + AND CASE + WHEN @initiator_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN aibridge_interceptions.initiator_id = @initiator_id::uuid + ELSE true + END + -- Filter provider + AND CASE + WHEN @provider::text != '' THEN aibridge_interceptions.provider = @provider::text + ELSE true + END + -- Filter model + AND CASE + WHEN @model::text != '' THEN aibridge_interceptions.model = @model::text + ELSE true + END + -- Filter client + AND CASE + WHEN @client::text != '' THEN COALESCE(aibridge_interceptions.client, 'Unknown') = @client::text + ELSE true + END + -- Filter session_id + AND CASE + WHEN @session_id::text != '' THEN aibridge_interceptions.session_id = @session_id::text + ELSE true + END + -- Authorize Filter clause will be injected below in CountAuthorizedAIBridgeSessions + -- @authorize_filter +; + +-- name: ListAIBridgeSessions :many +-- Returns paginated sessions with aggregated metadata, token counts, and +-- the most recent user prompt. A "session" is a logical grouping of +-- interceptions that share the same session_id (set by the client). +WITH filtered_interceptions AS ( + SELECT + aibridge_interceptions.* + FROM + aibridge_interceptions + WHERE + -- Remove inflight interceptions (ones which lack an ended_at value). + aibridge_interceptions.ended_at IS NOT NULL + -- Filter by time frame + AND CASE + WHEN @started_after::timestamptz != '0001-01-01 00:00:00+00'::timestamptz THEN aibridge_interceptions.started_at >= @started_after::timestamptz + ELSE true + END + AND CASE + WHEN @started_before::timestamptz != '0001-01-01 00:00:00+00'::timestamptz THEN aibridge_interceptions.started_at <= @started_before::timestamptz + ELSE true + END + -- Filter initiator_id + AND CASE + WHEN @initiator_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN aibridge_interceptions.initiator_id = @initiator_id::uuid + ELSE true + END + -- Filter provider + AND CASE + WHEN @provider::text != '' THEN aibridge_interceptions.provider = @provider::text + ELSE true + END + -- Filter model + AND CASE + WHEN @model::text != '' THEN aibridge_interceptions.model = @model::text + ELSE true + END + -- Filter client + AND CASE + WHEN @client::text != '' THEN COALESCE(aibridge_interceptions.client, 'Unknown') = @client::text + ELSE true + END + -- Filter session_id + AND CASE + WHEN @session_id::text != '' THEN aibridge_interceptions.session_id = @session_id::text + ELSE true + END + -- Authorize Filter clause will be injected below in ListAuthorizedAIBridgeSessions + -- @authorize_filter +), +session_tokens AS ( + -- Aggregate token usage across all interceptions in each session. + -- Group by (session_id, initiator_id) to avoid merging sessions from + -- different users who happen to share the same client_session_id. + SELECT + fi.session_id, + fi.initiator_id, + COALESCE(SUM(tu.input_tokens), 0)::bigint AS input_tokens, + COALESCE(SUM(tu.output_tokens), 0)::bigint AS output_tokens + -- TODO: add extra token types once https://github.com/coder/aibridge/issues/150 lands. + FROM + filtered_interceptions fi + LEFT JOIN + aibridge_token_usages tu ON fi.id = tu.interception_id + GROUP BY + fi.session_id, fi.initiator_id +), +session_root AS ( + -- Build one summary row per session. Group by (session_id, initiator_id) + -- to avoid merging sessions from different users who happen to share the + -- same client_session_id. The ARRAY_AGG with ORDER BY picks values from + -- the chronologically first interception for fields that should represent + -- the session as a whole (client, metadata). Threads are counted as + -- distinct root interception IDs: an interception with a NULL + -- thread_root_id is itself a thread root. + SELECT + fi.session_id, + fi.initiator_id, + (ARRAY_AGG(fi.client ORDER BY fi.started_at, fi.id))[1] AS client, + (ARRAY_AGG(fi.metadata ORDER BY fi.started_at, fi.id))[1] AS metadata, + ARRAY_AGG(DISTINCT fi.provider ORDER BY fi.provider) AS providers, + ARRAY_AGG(DISTINCT fi.model ORDER BY fi.model) AS models, + MIN(fi.started_at) AS started_at, + MAX(fi.ended_at) AS ended_at, + COUNT(DISTINCT COALESCE(fi.thread_root_id, fi.id)) AS threads, + -- Collect IDs for lateral prompt lookup. + ARRAY_AGG(fi.id) AS interception_ids + FROM + filtered_interceptions fi + GROUP BY + fi.session_id, fi.initiator_id +) +SELECT + sr.session_id, + visible_users.id AS user_id, + visible_users.username AS user_username, + visible_users.name AS user_name, + visible_users.avatar_url AS user_avatar_url, + sr.providers::text[] AS providers, + sr.models::text[] AS models, + COALESCE(sr.client, '')::varchar(64) AS client, + sr.metadata::jsonb AS metadata, + sr.started_at::timestamptz AS started_at, + sr.ended_at::timestamptz AS ended_at, + sr.threads, + COALESCE(st.input_tokens, 0)::bigint AS input_tokens, + COALESCE(st.output_tokens, 0)::bigint AS output_tokens, + COALESCE(slp.prompt, '') AS last_prompt +FROM + session_root sr +JOIN + visible_users ON visible_users.id = sr.initiator_id +LEFT JOIN + session_tokens st ON st.session_id = sr.session_id AND st.initiator_id = sr.initiator_id +LEFT JOIN LATERAL ( + -- Lateral join to efficiently fetch only the most recent user prompt + -- across all interceptions in the session, avoiding a full aggregation. + SELECT up.prompt + FROM aibridge_user_prompts up + WHERE up.interception_id = ANY(sr.interception_ids) + ORDER BY up.created_at DESC, up.id DESC + LIMIT 1 +) slp ON true +WHERE + -- Cursor pagination: uses a composite (started_at, session_id) cursor + -- to support keyset pagination. The less-than comparison matches the + -- DESC sort order so that rows after the cursor come later in results. + CASE + WHEN @after_session_id::text != '' THEN ( + (sr.started_at, sr.session_id) < ( + (SELECT started_at FROM session_root WHERE session_id = @after_session_id), + @after_session_id::text + ) + ) + ELSE true + END +ORDER BY + sr.started_at DESC, + sr.session_id DESC +LIMIT COALESCE(NULLIF(@limit_::integer, 0), 100) +OFFSET @offset_ +; + -- name: ListAIBridgeModels :many SELECT model diff --git a/coderd/searchquery/search.go b/coderd/searchquery/search.go index 7d8f517d08..462562ae28 100644 --- a/coderd/searchquery/search.go +++ b/coderd/searchquery/search.go @@ -401,6 +401,49 @@ func AIBridgeInterceptions(ctx context.Context, db database.Store, query string, return filter, parser.Errors } +func AIBridgeSessions(ctx context.Context, db database.Store, query string, page codersdk.Pagination, actorID uuid.UUID, afterSessionID string) (database.ListAIBridgeSessionsParams, []codersdk.ValidationError) { + // nolint:exhaustruct // Empty values just means "don't filter by that field". + filter := database.ListAIBridgeSessionsParams{ + AfterSessionID: afterSessionID, + // #nosec G115 - Safe conversion for pagination limit which is expected to be within int32 range + Limit: int32(page.Limit), + // #nosec G115 - Safe conversion for pagination offset which is expected to be within int32 range + Offset: int32(page.Offset), + } + + if query == "" { + return filter, nil + } + + values, errors := searchTerms(query, func(string, url.Values) error { + // Do not specify a default search key; let's be explicit to prevent user confusion. + return xerrors.New("no search key specified") + }) + if len(errors) > 0 { + return filter, errors + } + + parser := httpapi.NewQueryParamParser() + filter.InitiatorID = parseUser(ctx, db, parser, values, "initiator", actorID) + filter.Provider = parser.String(values, "", "provider") + filter.Model = parser.String(values, "", "model") + filter.Client = parser.String(values, "", "client") + filter.SessionID = parser.String(values, "", "session_id") + + // Time must be between started_after and started_before. + filter.StartedAfter = parser.Time3339Nano(values, time.Time{}, "started_after") + filter.StartedBefore = parser.Time3339Nano(values, time.Time{}, "started_before") + if !filter.StartedBefore.IsZero() && !filter.StartedAfter.IsZero() && !filter.StartedBefore.After(filter.StartedAfter) { + parser.Errors = append(parser.Errors, codersdk.ValidationError{ + Field: "started_before", + Detail: `Query param "started_before" has invalid value: "started_before" must be after "started_after" if set`, + }) + } + + parser.ErrorExcessParams(values) + return filter, parser.Errors +} + func AIBridgeModels(query string, page codersdk.Pagination) (database.ListAIBridgeModelsParams, []codersdk.ValidationError) { // nolint:exhaustruct // Empty values just means "don't filter by that field". filter := database.ListAIBridgeModelsParams{ diff --git a/codersdk/aibridge.go b/codersdk/aibridge.go index 2d994558f5..56b2260bfe 100644 --- a/codersdk/aibridge.go +++ b/codersdk/aibridge.go @@ -63,6 +63,51 @@ type AIBridgeListInterceptionsResponse struct { Results []AIBridgeInterception `json:"results"` } +type AIBridgeSession struct { + ID string `json:"id"` + Initiator MinimalUser `json:"initiator"` + Providers []string `json:"providers"` + Models []string `json:"models"` + Client *string `json:"client"` + Metadata map[string]any `json:"metadata"` + StartedAt time.Time `json:"started_at" format:"date-time"` + EndedAt *time.Time `json:"ended_at,omitempty" format:"date-time"` + Threads int64 `json:"threads"` + TokenUsageSummary AIBridgeSessionTokenUsageSummary `json:"token_usage_summary"` + LastPrompt *string `json:"last_prompt,omitempty"` +} + +type AIBridgeSessionTokenUsageSummary struct { + InputTokens int64 `json:"input_tokens"` + OutputTokens int64 `json:"output_tokens"` +} + +type AIBridgeListSessionsResponse struct { + Count int64 `json:"count"` + Sessions []AIBridgeSession `json:"sessions"` +} + +// @typescript-ignore AIBridgeListSessionsFilter +type AIBridgeListSessionsFilter struct { + // Limit defaults to 100, max is 1000. + Pagination Pagination `json:"pagination,omitempty"` + + // Initiator is a user ID, username, or "me". + Initiator string `json:"initiator,omitempty"` + StartedBefore time.Time `json:"started_before,omitempty" format:"date-time"` + StartedAfter time.Time `json:"started_after,omitempty" format:"date-time"` + Provider string `json:"provider,omitempty"` + Model string `json:"model,omitempty"` + Client string `json:"client,omitempty"` + SessionID string `json:"session_id,omitempty"` + + // AfterSessionID is a cursor for pagination. It is the session ID of the + // last session in the previous page. + AfterSessionID string `json:"after_session_id,omitempty"` + + FilterQuery string `json:"q,omitempty"` +} + // @typescript-ignore AIBridgeListInterceptionsFilter type AIBridgeListInterceptionsFilter struct { // Limit defaults to 100, max is 1000. @@ -117,6 +162,44 @@ func (f AIBridgeListInterceptionsFilter) asRequestOption() RequestOption { } } +// asRequestOption returns a function that can be used in (*Client).Request. +func (f AIBridgeListSessionsFilter) asRequestOption() RequestOption { + return func(r *http.Request) { + var params []string + if f.Initiator != "" { + params = append(params, fmt.Sprintf("initiator:%q", f.Initiator)) + } + if !f.StartedBefore.IsZero() { + params = append(params, fmt.Sprintf("started_before:%q", f.StartedBefore.Format(time.RFC3339Nano))) + } + if !f.StartedAfter.IsZero() { + params = append(params, fmt.Sprintf("started_after:%q", f.StartedAfter.Format(time.RFC3339Nano))) + } + if f.Provider != "" { + params = append(params, fmt.Sprintf("provider:%q", f.Provider)) + } + if f.Model != "" { + params = append(params, fmt.Sprintf("model:%q", f.Model)) + } + if f.Client != "" { + params = append(params, fmt.Sprintf("client:%q", f.Client)) + } + if f.SessionID != "" { + params = append(params, fmt.Sprintf("session_id:%q", f.SessionID)) + } + if f.FilterQuery != "" { + params = append(params, f.FilterQuery) + } + + q := r.URL.Query() + q.Set("q", strings.Join(params, " ")) + if f.AfterSessionID != "" { + q.Set("after_session_id", f.AfterSessionID) + } + r.URL.RawQuery = q.Encode() + } +} + // AIBridgeListInterceptions returns AI Bridge interceptions with the given // filter. func (c *Client) AIBridgeListInterceptions(ctx context.Context, filter AIBridgeListInterceptionsFilter) (AIBridgeListInterceptionsResponse, error) { @@ -131,3 +214,17 @@ func (c *Client) AIBridgeListInterceptions(ctx context.Context, filter AIBridgeL var resp AIBridgeListInterceptionsResponse return resp, json.NewDecoder(res.Body).Decode(&resp) } + +// AIBridgeListSessions returns AI Bridge sessions with the given filter. +func (c *Client) AIBridgeListSessions(ctx context.Context, filter AIBridgeListSessionsFilter) (AIBridgeListSessionsResponse, error) { + res, err := c.Request(ctx, http.MethodGet, "/api/v2/aibridge/sessions", nil, filter.asRequestOption(), filter.Pagination.asRequestOption()) + if err != nil { + return AIBridgeListSessionsResponse{}, err + } + defer res.Body.Close() + if res.StatusCode != http.StatusOK { + return AIBridgeListSessionsResponse{}, ReadBodyAsError(res) + } + var resp AIBridgeListSessionsResponse + return resp, json.NewDecoder(res.Body).Decode(&resp) +} diff --git a/docs/ai-coder/ai-bridge/monitoring.md b/docs/ai-coder/ai-bridge/monitoring.md index f214eeb8a0..d339df3ee8 100644 --- a/docs/ai-coder/ai-bridge/monitoring.md +++ b/docs/ai-coder/ai-bridge/monitoring.md @@ -16,10 +16,10 @@ AI Bridge interception data can be exported for external analysis, compliance re ### REST API -You can retrieve AI Bridge interceptions via the Coder API with filtering and pagination support. +You can retrieve AI Bridge sessions via the Coder API, with filtering and pagination support. ```sh -curl -X GET "https://coder.example.com/api/v2/aibridge/interceptions?q=initiator:me" \ +curl -X GET "https://coder.example.com/api/v2/aibridge/sessions" \ -H "Coder-Session-Token: $CODER_SESSION_TOKEN" ``` diff --git a/docs/reference/api/aibridge.md b/docs/reference/api/aibridge.md index d5ca02bd5b..28479b8991 100644 --- a/docs/reference/api/aibridge.md +++ b/docs/reference/api/aibridge.md @@ -137,3 +137,73 @@ curl -X GET http://coder-server:8080/api/v2/aibridge/models \