mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
feat: session list API (#23202)
<!-- If you have used AI to produce some or all of this PR, please ensure you have read our [AI Contribution guidelines](https://coder.com/docs/about/contributing/AI_CONTRIBUTING) before submitting. --> _Disclaimer:_ _initially_ _produced_ _by_ _Claude_ _Opus_ _4\.6,_ _heavily_ _modified_ _and_ _reviewed_ _by_ _me._ Closes https://github.com/coder/internal/issues/1360 Adds a new `/api/v2/aibridge/sessions` API which returns "sessions". Sessions, as defined in the [RFC](https://www.notion.so/coderhq/AI-Bridge-Sessions-Threads-2ccd579be59280f28021d3baf7472fbe?source=copy_link), are a set of interceptions logically grouped by a session key issued by the client. The API design for this endpoint was done in [this doc](https://github.com/coder/internal/issues/1360). If the client has not provided a session ID, we will revert to the thread root ID, and if that's not present we use the interception's own ID (i.e. a session of a single interception - which is effectively what we show currently in our `/api/v2/aibridge/interceptions` API). The SQL query looks gnarly but it's relatively simple, and seems to perform well (~200ms) even when I import dogfood's `aibridge_*` tables into my workspace. If we need to improve performance on this later we can investigate materialized views, perhaps, but for now I don't think it's warranted. --- _The PR looks large but it's got a lot of generated code; the actual changes aren't huge._
This commit is contained in:
Generated
+123
@@ -163,6 +163,57 @@ const docTemplate = `{
|
||||
]
|
||||
}
|
||||
},
|
||||
"/aibridge/sessions": {
|
||||
"get": {
|
||||
"produces": [
|
||||
"application/json"
|
||||
],
|
||||
"tags": [
|
||||
"AI Bridge"
|
||||
],
|
||||
"summary": "List AI Bridge sessions",
|
||||
"operationId": "list-ai-bridge-sessions",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"description": "Search query in the format ` + "`" + `key:value` + "`" + `. Available keys are: initiator, provider, model, client, session_id, started_after, started_before.",
|
||||
"name": "q",
|
||||
"in": "query"
|
||||
},
|
||||
{
|
||||
"type": "integer",
|
||||
"description": "Page limit",
|
||||
"name": "limit",
|
||||
"in": "query"
|
||||
},
|
||||
{
|
||||
"type": "string",
|
||||
"description": "Cursor pagination after session ID (cannot be used with offset)",
|
||||
"name": "after_session_id",
|
||||
"in": "query"
|
||||
},
|
||||
{
|
||||
"type": "integer",
|
||||
"description": "Offset pagination (cannot be used with after_session_id)",
|
||||
"name": "offset",
|
||||
"in": "query"
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.AIBridgeListSessionsResponse"
|
||||
}
|
||||
}
|
||||
},
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"/appearance": {
|
||||
"get": {
|
||||
"produces": [
|
||||
@@ -12778,6 +12829,20 @@ const docTemplate = `{
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.AIBridgeListSessionsResponse": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"count": {
|
||||
"type": "integer"
|
||||
},
|
||||
"sessions": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/definitions/codersdk.AIBridgeSession"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.AIBridgeOpenAIConfig": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -12830,6 +12895,64 @@ const docTemplate = `{
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.AIBridgeSession": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"client": {
|
||||
"type": "string"
|
||||
},
|
||||
"ended_at": {
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
},
|
||||
"id": {
|
||||
"type": "string"
|
||||
},
|
||||
"initiator": {
|
||||
"$ref": "#/definitions/codersdk.MinimalUser"
|
||||
},
|
||||
"last_prompt": {
|
||||
"type": "string"
|
||||
},
|
||||
"metadata": {
|
||||
"type": "object",
|
||||
"additionalProperties": {}
|
||||
},
|
||||
"models": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"providers": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"started_at": {
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
},
|
||||
"threads": {
|
||||
"type": "integer"
|
||||
},
|
||||
"token_usage_summary": {
|
||||
"$ref": "#/definitions/codersdk.AIBridgeSessionTokenUsageSummary"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.AIBridgeSessionTokenUsageSummary": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"input_tokens": {
|
||||
"type": "integer"
|
||||
},
|
||||
"output_tokens": {
|
||||
"type": "integer"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.AIBridgeTokenUsage": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
|
||||
Generated
+119
@@ -136,6 +136,53 @@
|
||||
]
|
||||
}
|
||||
},
|
||||
"/aibridge/sessions": {
|
||||
"get": {
|
||||
"produces": ["application/json"],
|
||||
"tags": ["AI Bridge"],
|
||||
"summary": "List AI Bridge sessions",
|
||||
"operationId": "list-ai-bridge-sessions",
|
||||
"parameters": [
|
||||
{
|
||||
"type": "string",
|
||||
"description": "Search query in the format `key:value`. Available keys are: initiator, provider, model, client, session_id, started_after, started_before.",
|
||||
"name": "q",
|
||||
"in": "query"
|
||||
},
|
||||
{
|
||||
"type": "integer",
|
||||
"description": "Page limit",
|
||||
"name": "limit",
|
||||
"in": "query"
|
||||
},
|
||||
{
|
||||
"type": "string",
|
||||
"description": "Cursor pagination after session ID (cannot be used with offset)",
|
||||
"name": "after_session_id",
|
||||
"in": "query"
|
||||
},
|
||||
{
|
||||
"type": "integer",
|
||||
"description": "Offset pagination (cannot be used with after_session_id)",
|
||||
"name": "offset",
|
||||
"in": "query"
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"schema": {
|
||||
"$ref": "#/definitions/codersdk.AIBridgeListSessionsResponse"
|
||||
}
|
||||
}
|
||||
},
|
||||
"security": [
|
||||
{
|
||||
"CoderSessionToken": []
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"/appearance": {
|
||||
"get": {
|
||||
"produces": ["application/json"],
|
||||
@@ -11368,6 +11415,20 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.AIBridgeListSessionsResponse": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"count": {
|
||||
"type": "integer"
|
||||
},
|
||||
"sessions": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"$ref": "#/definitions/codersdk.AIBridgeSession"
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.AIBridgeOpenAIConfig": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -11420,6 +11481,64 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.AIBridgeSession": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"client": {
|
||||
"type": "string"
|
||||
},
|
||||
"ended_at": {
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
},
|
||||
"id": {
|
||||
"type": "string"
|
||||
},
|
||||
"initiator": {
|
||||
"$ref": "#/definitions/codersdk.MinimalUser"
|
||||
},
|
||||
"last_prompt": {
|
||||
"type": "string"
|
||||
},
|
||||
"metadata": {
|
||||
"type": "object",
|
||||
"additionalProperties": {}
|
||||
},
|
||||
"models": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"providers": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"started_at": {
|
||||
"type": "string",
|
||||
"format": "date-time"
|
||||
},
|
||||
"threads": {
|
||||
"type": "integer"
|
||||
},
|
||||
"token_usage_summary": {
|
||||
"$ref": "#/definitions/codersdk.AIBridgeSessionTokenUsageSummary"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.AIBridgeSessionTokenUsageSummary": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"input_tokens": {
|
||||
"type": "integer"
|
||||
},
|
||||
"output_tokens": {
|
||||
"type": "integer"
|
||||
}
|
||||
}
|
||||
},
|
||||
"codersdk.AIBridgeTokenUsage": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
|
||||
@@ -1021,6 +1021,44 @@ func AIBridgeInterception(interception database.AIBridgeInterception, initiator
|
||||
return intc
|
||||
}
|
||||
|
||||
func AIBridgeSession(row database.ListAIBridgeSessionsRow) codersdk.AIBridgeSession {
|
||||
session := codersdk.AIBridgeSession{
|
||||
ID: row.SessionID,
|
||||
Initiator: MinimalUserFromVisibleUser(database.VisibleUser{
|
||||
ID: row.UserID,
|
||||
Username: row.UserUsername,
|
||||
Name: row.UserName,
|
||||
AvatarURL: row.UserAvatarUrl,
|
||||
}),
|
||||
Providers: row.Providers,
|
||||
Models: row.Models,
|
||||
Metadata: jsonOrEmptyMap(pqtype.NullRawMessage{RawMessage: row.Metadata, Valid: len(row.Metadata) > 0}),
|
||||
StartedAt: row.StartedAt,
|
||||
Threads: row.Threads,
|
||||
TokenUsageSummary: codersdk.AIBridgeSessionTokenUsageSummary{
|
||||
InputTokens: row.InputTokens,
|
||||
OutputTokens: row.OutputTokens,
|
||||
},
|
||||
}
|
||||
// Ensure non-nil slices for JSON serialization.
|
||||
if session.Providers == nil {
|
||||
session.Providers = []string{}
|
||||
}
|
||||
if session.Models == nil {
|
||||
session.Models = []string{}
|
||||
}
|
||||
if row.Client != "" {
|
||||
session.Client = &row.Client
|
||||
}
|
||||
if !row.EndedAt.IsZero() {
|
||||
session.EndedAt = &row.EndedAt
|
||||
}
|
||||
if row.LastPrompt != "" {
|
||||
session.LastPrompt = &row.LastPrompt
|
||||
}
|
||||
return session
|
||||
}
|
||||
|
||||
func AIBridgeTokenUsage(usage database.AIBridgeTokenUsage) codersdk.AIBridgeTokenUsage {
|
||||
return codersdk.AIBridgeTokenUsage{
|
||||
ID: usage.ID,
|
||||
|
||||
@@ -1709,6 +1709,14 @@ func (q *querier) CountAIBridgeInterceptions(ctx context.Context, arg database.C
|
||||
return q.db.CountAuthorizedAIBridgeInterceptions(ctx, arg, prep)
|
||||
}
|
||||
|
||||
func (q *querier) CountAIBridgeSessions(ctx context.Context, arg database.CountAIBridgeSessionsParams) (int64, error) {
|
||||
prep, err := prepareSQLFilter(ctx, q.auth, policy.ActionRead, rbac.ResourceAibridgeInterception.Type)
|
||||
if err != nil {
|
||||
return 0, xerrors.Errorf("(dev error) prepare sql filter: %w", err)
|
||||
}
|
||||
return q.db.CountAuthorizedAIBridgeSessions(ctx, arg, prep)
|
||||
}
|
||||
|
||||
func (q *querier) CountAuditLogs(ctx context.Context, arg database.CountAuditLogsParams) (int64, error) {
|
||||
// Shortcut if the user is an owner. The SQL filter is noticeable,
|
||||
// and this is an easy win for owners. Which is the common case.
|
||||
@@ -5317,6 +5325,14 @@ func (q *querier) ListAIBridgeModels(ctx context.Context, arg database.ListAIBri
|
||||
return q.db.ListAuthorizedAIBridgeModels(ctx, arg, prep)
|
||||
}
|
||||
|
||||
func (q *querier) ListAIBridgeSessions(ctx context.Context, arg database.ListAIBridgeSessionsParams) ([]database.ListAIBridgeSessionsRow, error) {
|
||||
prep, err := prepareSQLFilter(ctx, q.auth, policy.ActionRead, rbac.ResourceAibridgeInterception.Type)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("(dev error) prepare sql filter: %w", err)
|
||||
}
|
||||
return q.db.ListAuthorizedAIBridgeSessions(ctx, arg, prep)
|
||||
}
|
||||
|
||||
func (q *querier) ListAIBridgeTokenUsagesByInterceptionIDs(ctx context.Context, interceptionIDs []uuid.UUID) ([]database.AIBridgeTokenUsage, error) {
|
||||
// This function is a system function until we implement a join for aibridge interceptions.
|
||||
// Matches the behavior of the workspaces listing endpoint.
|
||||
@@ -7128,6 +7144,14 @@ func (q *querier) ListAuthorizedAIBridgeModels(ctx context.Context, arg database
|
||||
return q.ListAIBridgeModels(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) ListAuthorizedAIBridgeSessions(ctx context.Context, arg database.ListAIBridgeSessionsParams, prepared rbac.PreparedAuthorized) ([]database.ListAIBridgeSessionsRow, error) {
|
||||
return q.db.ListAuthorizedAIBridgeSessions(ctx, arg, prepared)
|
||||
}
|
||||
|
||||
func (q *querier) CountAuthorizedAIBridgeSessions(ctx context.Context, arg database.CountAIBridgeSessionsParams, prepared rbac.PreparedAuthorized) (int64, error) {
|
||||
return q.db.CountAuthorizedAIBridgeSessions(ctx, arg, prepared)
|
||||
}
|
||||
|
||||
func (q *querier) GetAuthorizedChats(ctx context.Context, arg database.GetChatsParams, _ rbac.PreparedAuthorized) ([]database.Chat, error) {
|
||||
return q.GetChats(ctx, arg)
|
||||
}
|
||||
|
||||
@@ -5514,6 +5514,34 @@ func (s *MethodTestSuite) TestAIBridge() {
|
||||
check.Args(params, emptyPreparedAuthorized{}).Asserts()
|
||||
}))
|
||||
|
||||
s.Run("ListAIBridgeSessions", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
params := database.ListAIBridgeSessionsParams{}
|
||||
db.EXPECT().ListAuthorizedAIBridgeSessions(gomock.Any(), params, gomock.Any()).Return([]database.ListAIBridgeSessionsRow{}, nil).AnyTimes()
|
||||
// No asserts here because SQLFilter.
|
||||
check.Args(params).Asserts()
|
||||
}))
|
||||
|
||||
s.Run("ListAuthorizedAIBridgeSessions", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
params := database.ListAIBridgeSessionsParams{}
|
||||
db.EXPECT().ListAuthorizedAIBridgeSessions(gomock.Any(), params, gomock.Any()).Return([]database.ListAIBridgeSessionsRow{}, nil).AnyTimes()
|
||||
// No asserts here because SQLFilter.
|
||||
check.Args(params, emptyPreparedAuthorized{}).Asserts()
|
||||
}))
|
||||
|
||||
s.Run("CountAIBridgeSessions", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
params := database.CountAIBridgeSessionsParams{}
|
||||
db.EXPECT().CountAuthorizedAIBridgeSessions(gomock.Any(), params, gomock.Any()).Return(int64(0), nil).AnyTimes()
|
||||
// No asserts here because SQLFilter.
|
||||
check.Args(params).Asserts()
|
||||
}))
|
||||
|
||||
s.Run("CountAuthorizedAIBridgeSessions", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
params := database.CountAIBridgeSessionsParams{}
|
||||
db.EXPECT().CountAuthorizedAIBridgeSessions(gomock.Any(), params, gomock.Any()).Return(int64(0), nil).AnyTimes()
|
||||
// No asserts here because SQLFilter.
|
||||
check.Args(params, emptyPreparedAuthorized{}).Asserts()
|
||||
}))
|
||||
|
||||
s.Run("ListAIBridgeTokenUsagesByInterceptionIDs", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
ids := []uuid.UUID{{1}}
|
||||
db.EXPECT().ListAIBridgeTokenUsagesByInterceptionIDs(gomock.Any(), ids).Return([]database.AIBridgeTokenUsage{}, nil).AnyTimes()
|
||||
|
||||
@@ -280,6 +280,14 @@ func (m queryMetricsStore) CountAIBridgeInterceptions(ctx context.Context, arg d
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) CountAIBridgeSessions(ctx context.Context, arg database.CountAIBridgeSessionsParams) (int64, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.CountAIBridgeSessions(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("CountAIBridgeSessions").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "CountAIBridgeSessions").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) CountAuditLogs(ctx context.Context, arg database.CountAuditLogsParams) (int64, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.CountAuditLogs(ctx, arg)
|
||||
@@ -3720,6 +3728,14 @@ func (m queryMetricsStore) ListAIBridgeModels(ctx context.Context, arg database.
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) ListAIBridgeSessions(ctx context.Context, arg database.ListAIBridgeSessionsParams) ([]database.ListAIBridgeSessionsRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.ListAIBridgeSessions(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("ListAIBridgeSessions").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "ListAIBridgeSessions").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) ListAIBridgeTokenUsagesByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]database.AIBridgeTokenUsage, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.ListAIBridgeTokenUsagesByInterceptionIDs(ctx, interceptionIds)
|
||||
@@ -5136,6 +5152,22 @@ func (m queryMetricsStore) ListAuthorizedAIBridgeModels(ctx context.Context, arg
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) ListAuthorizedAIBridgeSessions(ctx context.Context, arg database.ListAIBridgeSessionsParams, prepared rbac.PreparedAuthorized) ([]database.ListAIBridgeSessionsRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.ListAuthorizedAIBridgeSessions(ctx, arg, prepared)
|
||||
m.queryLatencies.WithLabelValues("ListAuthorizedAIBridgeSessions").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "ListAuthorizedAIBridgeSessions").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) CountAuthorizedAIBridgeSessions(ctx context.Context, arg database.CountAIBridgeSessionsParams, prepared rbac.PreparedAuthorized) (int64, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.CountAuthorizedAIBridgeSessions(ctx, arg, prepared)
|
||||
m.queryLatencies.WithLabelValues("CountAuthorizedAIBridgeSessions").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "CountAuthorizedAIBridgeSessions").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetAuthorizedChats(ctx context.Context, arg database.GetChatsParams, prepared rbac.PreparedAuthorized) ([]database.Chat, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetAuthorizedChats(ctx, arg, prepared)
|
||||
|
||||
@@ -363,6 +363,21 @@ func (mr *MockStoreMockRecorder) CountAIBridgeInterceptions(ctx, arg any) *gomoc
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountAIBridgeInterceptions", reflect.TypeOf((*MockStore)(nil).CountAIBridgeInterceptions), ctx, arg)
|
||||
}
|
||||
|
||||
// CountAIBridgeSessions mocks base method.
|
||||
func (m *MockStore) CountAIBridgeSessions(ctx context.Context, arg database.CountAIBridgeSessionsParams) (int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "CountAIBridgeSessions", ctx, arg)
|
||||
ret0, _ := ret[0].(int64)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// CountAIBridgeSessions indicates an expected call of CountAIBridgeSessions.
|
||||
func (mr *MockStoreMockRecorder) CountAIBridgeSessions(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountAIBridgeSessions", reflect.TypeOf((*MockStore)(nil).CountAIBridgeSessions), ctx, arg)
|
||||
}
|
||||
|
||||
// CountAuditLogs mocks base method.
|
||||
func (m *MockStore) CountAuditLogs(ctx context.Context, arg database.CountAuditLogsParams) (int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -393,6 +408,21 @@ func (mr *MockStoreMockRecorder) CountAuthorizedAIBridgeInterceptions(ctx, arg,
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountAuthorizedAIBridgeInterceptions", reflect.TypeOf((*MockStore)(nil).CountAuthorizedAIBridgeInterceptions), ctx, arg, prepared)
|
||||
}
|
||||
|
||||
// CountAuthorizedAIBridgeSessions mocks base method.
|
||||
func (m *MockStore) CountAuthorizedAIBridgeSessions(ctx context.Context, arg database.CountAIBridgeSessionsParams, prepared rbac.PreparedAuthorized) (int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "CountAuthorizedAIBridgeSessions", ctx, arg, prepared)
|
||||
ret0, _ := ret[0].(int64)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// CountAuthorizedAIBridgeSessions indicates an expected call of CountAuthorizedAIBridgeSessions.
|
||||
func (mr *MockStoreMockRecorder) CountAuthorizedAIBridgeSessions(ctx, arg, prepared any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CountAuthorizedAIBridgeSessions", reflect.TypeOf((*MockStore)(nil).CountAuthorizedAIBridgeSessions), ctx, arg, prepared)
|
||||
}
|
||||
|
||||
// CountAuthorizedAuditLogs mocks base method.
|
||||
func (m *MockStore) CountAuthorizedAuditLogs(ctx context.Context, arg database.CountAuditLogsParams, prepared rbac.PreparedAuthorized) (int64, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -6947,6 +6977,21 @@ func (mr *MockStoreMockRecorder) ListAIBridgeModels(ctx, arg any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListAIBridgeModels", reflect.TypeOf((*MockStore)(nil).ListAIBridgeModels), ctx, arg)
|
||||
}
|
||||
|
||||
// ListAIBridgeSessions mocks base method.
|
||||
func (m *MockStore) ListAIBridgeSessions(ctx context.Context, arg database.ListAIBridgeSessionsParams) ([]database.ListAIBridgeSessionsRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ListAIBridgeSessions", ctx, arg)
|
||||
ret0, _ := ret[0].([]database.ListAIBridgeSessionsRow)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// ListAIBridgeSessions indicates an expected call of ListAIBridgeSessions.
|
||||
func (mr *MockStoreMockRecorder) ListAIBridgeSessions(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListAIBridgeSessions", reflect.TypeOf((*MockStore)(nil).ListAIBridgeSessions), ctx, arg)
|
||||
}
|
||||
|
||||
// ListAIBridgeTokenUsagesByInterceptionIDs mocks base method.
|
||||
func (m *MockStore) ListAIBridgeTokenUsagesByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]database.AIBridgeTokenUsage, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -7022,6 +7067,21 @@ func (mr *MockStoreMockRecorder) ListAuthorizedAIBridgeModels(ctx, arg, prepared
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListAuthorizedAIBridgeModels", reflect.TypeOf((*MockStore)(nil).ListAuthorizedAIBridgeModels), ctx, arg, prepared)
|
||||
}
|
||||
|
||||
// ListAuthorizedAIBridgeSessions mocks base method.
|
||||
func (m *MockStore) ListAuthorizedAIBridgeSessions(ctx context.Context, arg database.ListAIBridgeSessionsParams, prepared rbac.PreparedAuthorized) ([]database.ListAIBridgeSessionsRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "ListAuthorizedAIBridgeSessions", ctx, arg, prepared)
|
||||
ret0, _ := ret[0].([]database.ListAIBridgeSessionsRow)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// ListAuthorizedAIBridgeSessions indicates an expected call of ListAuthorizedAIBridgeSessions.
|
||||
func (mr *MockStoreMockRecorder) ListAuthorizedAIBridgeSessions(ctx, arg, prepared any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListAuthorizedAIBridgeSessions", reflect.TypeOf((*MockStore)(nil).ListAuthorizedAIBridgeSessions), ctx, arg, prepared)
|
||||
}
|
||||
|
||||
// ListChatUsageLimitGroupOverrides mocks base method.
|
||||
func (m *MockStore) ListChatUsageLimitGroupOverrides(ctx context.Context) ([]database.ListChatUsageLimitGroupOverridesRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
Generated
+10
-1
@@ -1099,7 +1099,8 @@ CREATE TABLE aibridge_interceptions (
|
||||
client character varying(64) DEFAULT 'Unknown'::character varying,
|
||||
thread_parent_id uuid,
|
||||
thread_root_id uuid,
|
||||
client_session_id character varying(256)
|
||||
client_session_id character varying(256),
|
||||
session_id text GENERATED ALWAYS AS (COALESCE(client_session_id, ((thread_root_id)::text)::character varying, ((id)::text)::character varying)) STORED NOT NULL
|
||||
);
|
||||
|
||||
COMMENT ON TABLE aibridge_interceptions IS 'Audit log of requests intercepted by AI Bridge';
|
||||
@@ -1112,6 +1113,8 @@ COMMENT ON COLUMN aibridge_interceptions.thread_root_id IS 'The root interceptio
|
||||
|
||||
COMMENT ON COLUMN aibridge_interceptions.client_session_id IS 'The session ID supplied by the client (optional and not universally supported).';
|
||||
|
||||
COMMENT ON COLUMN aibridge_interceptions.session_id IS 'Groups related interceptions into a logical session. Determined by a priority chain: (1) client_session_id — an explicit session identifier supplied by the calling client (e.g. Claude Code); (2) thread_root_id — the root of an agentic thread detected by Bridge through tool-call correlation, used when the client does not supply its own session ID; (3) id — the interception''s own ID, used as a last resort so every interception belongs to exactly one session even if it is standalone. This is a generated column stored on disk so it can be indexed and joined without recomputing the COALESCE on every query.';
|
||||
|
||||
CREATE TABLE aibridge_model_thoughts (
|
||||
interception_id uuid NOT NULL,
|
||||
content text NOT NULL,
|
||||
@@ -3654,6 +3657,10 @@ CREATE INDEX idx_aibridge_interceptions_model ON aibridge_interceptions USING bt
|
||||
|
||||
CREATE INDEX idx_aibridge_interceptions_provider ON aibridge_interceptions USING btree (provider);
|
||||
|
||||
CREATE INDEX idx_aibridge_interceptions_session_id ON aibridge_interceptions USING btree (session_id) WHERE (ended_at IS NOT NULL);
|
||||
|
||||
CREATE INDEX idx_aibridge_interceptions_sessions_filter ON aibridge_interceptions USING btree (initiator_id, started_at DESC, id DESC) WHERE (ended_at IS NOT NULL);
|
||||
|
||||
CREATE INDEX idx_aibridge_interceptions_started_id_desc ON aibridge_interceptions USING btree (started_at DESC, id DESC);
|
||||
|
||||
CREATE INDEX idx_aibridge_interceptions_thread_parent_id ON aibridge_interceptions USING btree (thread_parent_id);
|
||||
@@ -3672,6 +3679,8 @@ CREATE INDEX idx_aibridge_tool_usages_provider_tool_call_id ON aibridge_tool_usa
|
||||
|
||||
CREATE INDEX idx_aibridge_tool_usagesprovider_response_id ON aibridge_tool_usages USING btree (provider_response_id);
|
||||
|
||||
CREATE INDEX idx_aibridge_user_prompts_interception_created ON aibridge_user_prompts USING btree (interception_id, created_at DESC, id DESC);
|
||||
|
||||
CREATE INDEX idx_aibridge_user_prompts_interception_id ON aibridge_user_prompts USING btree (interception_id);
|
||||
|
||||
CREATE INDEX idx_aibridge_user_prompts_provider_response_id ON aibridge_user_prompts USING btree (provider_response_id);
|
||||
|
||||
@@ -0,0 +1,5 @@
|
||||
DROP INDEX IF EXISTS idx_aibridge_interceptions_session_id;
|
||||
DROP INDEX IF EXISTS idx_aibridge_user_prompts_interception_created;
|
||||
DROP INDEX IF EXISTS idx_aibridge_interceptions_sessions_filter;
|
||||
|
||||
ALTER TABLE aibridge_interceptions DROP COLUMN IF EXISTS session_id;
|
||||
@@ -0,0 +1,40 @@
|
||||
-- A "session" groups related interceptions together. See the COMMENT ON
|
||||
-- COLUMN below for the full business-logic description.
|
||||
ALTER TABLE aibridge_interceptions
|
||||
ADD COLUMN session_id TEXT NOT NULL
|
||||
GENERATED ALWAYS AS (
|
||||
COALESCE(
|
||||
client_session_id,
|
||||
thread_root_id::text,
|
||||
id::text
|
||||
)
|
||||
) STORED;
|
||||
|
||||
-- Searching and grouping on the resolved session ID will be common.
|
||||
CREATE INDEX idx_aibridge_interceptions_session_id
|
||||
ON aibridge_interceptions (session_id)
|
||||
WHERE ended_at IS NOT NULL;
|
||||
|
||||
COMMENT ON COLUMN aibridge_interceptions.session_id IS
|
||||
'Groups related interceptions into a logical session. '
|
||||
'Determined by a priority chain: '
|
||||
'(1) client_session_id — an explicit session identifier supplied by the '
|
||||
'calling client (e.g. Claude Code); '
|
||||
'(2) thread_root_id — the root of an agentic thread detected by Bridge '
|
||||
'through tool-call correlation, used when the client does not supply its '
|
||||
'own session ID; '
|
||||
'(3) id — the interception''s own ID, used as a last resort so every '
|
||||
'interception belongs to exactly one session even if it is standalone. '
|
||||
'This is a generated column stored on disk so it can be indexed and '
|
||||
'joined without recomputing the COALESCE on every query.';
|
||||
|
||||
-- Composite index for the most common filter path used by
|
||||
-- ListAIBridgeSessions: initiator_id equality + started_at range,
|
||||
-- with ended_at IS NOT NULL as a partial filter.
|
||||
CREATE INDEX idx_aibridge_interceptions_sessions_filter
|
||||
ON aibridge_interceptions (initiator_id, started_at DESC, id DESC)
|
||||
WHERE ended_at IS NOT NULL;
|
||||
|
||||
-- Supports lateral prompt lookup by interception + recency.
|
||||
CREATE INDEX idx_aibridge_user_prompts_interception_created
|
||||
ON aibridge_user_prompts (interception_id, created_at DESC, id DESC);
|
||||
@@ -806,6 +806,8 @@ type aibridgeQuerier interface {
|
||||
ListAuthorizedAIBridgeInterceptions(ctx context.Context, arg ListAIBridgeInterceptionsParams, prepared rbac.PreparedAuthorized) ([]ListAIBridgeInterceptionsRow, error)
|
||||
CountAuthorizedAIBridgeInterceptions(ctx context.Context, arg CountAIBridgeInterceptionsParams, prepared rbac.PreparedAuthorized) (int64, error)
|
||||
ListAuthorizedAIBridgeModels(ctx context.Context, arg ListAIBridgeModelsParams, prepared rbac.PreparedAuthorized) ([]string, error)
|
||||
ListAuthorizedAIBridgeSessions(ctx context.Context, arg ListAIBridgeSessionsParams, prepared rbac.PreparedAuthorized) ([]ListAIBridgeSessionsRow, error)
|
||||
CountAuthorizedAIBridgeSessions(ctx context.Context, arg CountAIBridgeSessionsParams, prepared rbac.PreparedAuthorized) (int64, error)
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) ListAuthorizedAIBridgeInterceptions(ctx context.Context, arg ListAIBridgeInterceptionsParams, prepared rbac.PreparedAuthorized) ([]ListAIBridgeInterceptionsRow, error) {
|
||||
@@ -852,6 +854,7 @@ func (q *sqlQuerier) ListAuthorizedAIBridgeInterceptions(ctx context.Context, ar
|
||||
&i.AIBridgeInterception.ThreadParentID,
|
||||
&i.AIBridgeInterception.ThreadRootID,
|
||||
&i.AIBridgeInterception.ClientSessionID,
|
||||
&i.AIBridgeInterception.SessionID,
|
||||
&i.VisibleUser.ID,
|
||||
&i.VisibleUser.Username,
|
||||
&i.VisibleUser.Name,
|
||||
@@ -939,6 +942,109 @@ func (q *sqlQuerier) ListAuthorizedAIBridgeModels(ctx context.Context, arg ListA
|
||||
return items, nil
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) ListAuthorizedAIBridgeSessions(ctx context.Context, arg ListAIBridgeSessionsParams, prepared rbac.PreparedAuthorized) ([]ListAIBridgeSessionsRow, error) {
|
||||
authorizedFilter, err := prepared.CompileToSQL(ctx, regosql.ConvertConfig{
|
||||
VariableConverter: regosql.AIBridgeInterceptionConverter(),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("compile authorized filter: %w", err)
|
||||
}
|
||||
filtered, err := insertAuthorizedFilter(listAIBridgeSessions, fmt.Sprintf(" AND %s", authorizedFilter))
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("insert authorized filter: %w", err)
|
||||
}
|
||||
|
||||
query := fmt.Sprintf("-- name: ListAuthorizedAIBridgeSessions :many\n%s", filtered)
|
||||
rows, err := q.db.QueryContext(ctx, query,
|
||||
arg.AfterSessionID,
|
||||
arg.Offset,
|
||||
arg.Limit,
|
||||
arg.StartedAfter,
|
||||
arg.StartedBefore,
|
||||
arg.InitiatorID,
|
||||
arg.Provider,
|
||||
arg.Model,
|
||||
arg.Client,
|
||||
arg.SessionID,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []ListAIBridgeSessionsRow
|
||||
for rows.Next() {
|
||||
var i ListAIBridgeSessionsRow
|
||||
if err := rows.Scan(
|
||||
&i.SessionID,
|
||||
&i.UserID,
|
||||
&i.UserUsername,
|
||||
&i.UserName,
|
||||
&i.UserAvatarUrl,
|
||||
pq.Array(&i.Providers),
|
||||
pq.Array(&i.Models),
|
||||
&i.Client,
|
||||
&i.Metadata,
|
||||
&i.StartedAt,
|
||||
&i.EndedAt,
|
||||
&i.Threads,
|
||||
&i.InputTokens,
|
||||
&i.OutputTokens,
|
||||
&i.LastPrompt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) CountAuthorizedAIBridgeSessions(ctx context.Context, arg CountAIBridgeSessionsParams, prepared rbac.PreparedAuthorized) (int64, error) {
|
||||
authorizedFilter, err := prepared.CompileToSQL(ctx, regosql.ConvertConfig{
|
||||
VariableConverter: regosql.AIBridgeInterceptionConverter(),
|
||||
})
|
||||
if err != nil {
|
||||
return 0, xerrors.Errorf("compile authorized filter: %w", err)
|
||||
}
|
||||
filtered, err := insertAuthorizedFilter(countAIBridgeSessions, fmt.Sprintf(" AND %s", authorizedFilter))
|
||||
if err != nil {
|
||||
return 0, xerrors.Errorf("insert authorized filter: %w", err)
|
||||
}
|
||||
|
||||
query := fmt.Sprintf("-- name: CountAuthorizedAIBridgeSessions :one\n%s", filtered)
|
||||
rows, err := q.db.QueryContext(ctx, query,
|
||||
arg.StartedAfter,
|
||||
arg.StartedBefore,
|
||||
arg.InitiatorID,
|
||||
arg.Provider,
|
||||
arg.Model,
|
||||
arg.Client,
|
||||
arg.SessionID,
|
||||
)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var count int64
|
||||
for rows.Next() {
|
||||
if err := rows.Scan(&count); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func insertAuthorizedFilter(query string, replaceWith string) (string, error) {
|
||||
if !strings.Contains(query, authorizedQueryPlaceholder) {
|
||||
return "", xerrors.Errorf("query does not contain authorized replace string, this is not an authorized query")
|
||||
|
||||
@@ -4036,6 +4036,8 @@ type AIBridgeInterception struct {
|
||||
ThreadRootID uuid.NullUUID `db:"thread_root_id" json:"thread_root_id"`
|
||||
// The session ID supplied by the client (optional and not universally supported).
|
||||
ClientSessionID sql.NullString `db:"client_session_id" json:"client_session_id"`
|
||||
// Groups related interceptions into a logical session. Determined by a priority chain: (1) client_session_id — an explicit session identifier supplied by the calling client (e.g. Claude Code); (2) thread_root_id — the root of an agentic thread detected by Bridge through tool-call correlation, used when the client does not supply its own session ID; (3) id — the interception's own ID, used as a last resort so every interception belongs to exactly one session even if it is standalone. This is a generated column stored on disk so it can be indexed and joined without recomputing the COALESCE on every query.
|
||||
SessionID string `db:"session_id" json:"session_id"`
|
||||
}
|
||||
|
||||
// Audit log of model thinking in intercepted requests in AI Bridge
|
||||
|
||||
@@ -76,6 +76,7 @@ type sqlcQuerier interface {
|
||||
CleanTailnetTunnels(ctx context.Context) error
|
||||
CleanupDeletedMCPServerIDsFromChats(ctx context.Context) error
|
||||
CountAIBridgeInterceptions(ctx context.Context, arg CountAIBridgeInterceptionsParams) (int64, error)
|
||||
CountAIBridgeSessions(ctx context.Context, arg CountAIBridgeSessionsParams) (int64, error)
|
||||
CountAuditLogs(ctx context.Context, arg CountAuditLogsParams) (int64, error)
|
||||
CountConnectionLogs(ctx context.Context, arg CountConnectionLogsParams) (int64, error)
|
||||
// Counts enabled, non-deleted model configs that lack both input and
|
||||
@@ -759,6 +760,10 @@ type sqlcQuerier interface {
|
||||
// (provider, model, client) in the given timeframe for telemetry reporting.
|
||||
ListAIBridgeInterceptionsTelemetrySummaries(ctx context.Context, arg ListAIBridgeInterceptionsTelemetrySummariesParams) ([]ListAIBridgeInterceptionsTelemetrySummariesRow, error)
|
||||
ListAIBridgeModels(ctx context.Context, arg ListAIBridgeModelsParams) ([]string, error)
|
||||
// Returns paginated sessions with aggregated metadata, token counts, and
|
||||
// the most recent user prompt. A "session" is a logical grouping of
|
||||
// interceptions that share the same session_id (set by the client).
|
||||
ListAIBridgeSessions(ctx context.Context, arg ListAIBridgeSessionsParams) ([]ListAIBridgeSessionsRow, error)
|
||||
ListAIBridgeTokenUsagesByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]AIBridgeTokenUsage, error)
|
||||
ListAIBridgeToolUsagesByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]AIBridgeToolUsage, error)
|
||||
ListAIBridgeUserPromptsByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]AIBridgeUserPrompt, error)
|
||||
|
||||
@@ -332,6 +332,77 @@ func (q *sqlQuerier) CountAIBridgeInterceptions(ctx context.Context, arg CountAI
|
||||
return count, err
|
||||
}
|
||||
|
||||
const countAIBridgeSessions = `-- name: CountAIBridgeSessions :one
|
||||
SELECT
|
||||
COUNT(DISTINCT (aibridge_interceptions.session_id, aibridge_interceptions.initiator_id))
|
||||
FROM
|
||||
aibridge_interceptions
|
||||
WHERE
|
||||
-- Remove inflight interceptions (ones which lack an ended_at value).
|
||||
aibridge_interceptions.ended_at IS NOT NULL
|
||||
-- Filter by time frame
|
||||
AND CASE
|
||||
WHEN $1::timestamptz != '0001-01-01 00:00:00+00'::timestamptz THEN aibridge_interceptions.started_at >= $1::timestamptz
|
||||
ELSE true
|
||||
END
|
||||
AND CASE
|
||||
WHEN $2::timestamptz != '0001-01-01 00:00:00+00'::timestamptz THEN aibridge_interceptions.started_at <= $2::timestamptz
|
||||
ELSE true
|
||||
END
|
||||
-- Filter initiator_id
|
||||
AND CASE
|
||||
WHEN $3::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN aibridge_interceptions.initiator_id = $3::uuid
|
||||
ELSE true
|
||||
END
|
||||
-- Filter provider
|
||||
AND CASE
|
||||
WHEN $4::text != '' THEN aibridge_interceptions.provider = $4::text
|
||||
ELSE true
|
||||
END
|
||||
-- Filter model
|
||||
AND CASE
|
||||
WHEN $5::text != '' THEN aibridge_interceptions.model = $5::text
|
||||
ELSE true
|
||||
END
|
||||
-- Filter client
|
||||
AND CASE
|
||||
WHEN $6::text != '' THEN COALESCE(aibridge_interceptions.client, 'Unknown') = $6::text
|
||||
ELSE true
|
||||
END
|
||||
-- Filter session_id
|
||||
AND CASE
|
||||
WHEN $7::text != '' THEN aibridge_interceptions.session_id = $7::text
|
||||
ELSE true
|
||||
END
|
||||
-- Authorize Filter clause will be injected below in CountAuthorizedAIBridgeSessions
|
||||
-- @authorize_filter
|
||||
`
|
||||
|
||||
type CountAIBridgeSessionsParams struct {
|
||||
StartedAfter time.Time `db:"started_after" json:"started_after"`
|
||||
StartedBefore time.Time `db:"started_before" json:"started_before"`
|
||||
InitiatorID uuid.UUID `db:"initiator_id" json:"initiator_id"`
|
||||
Provider string `db:"provider" json:"provider"`
|
||||
Model string `db:"model" json:"model"`
|
||||
Client string `db:"client" json:"client"`
|
||||
SessionID string `db:"session_id" json:"session_id"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) CountAIBridgeSessions(ctx context.Context, arg CountAIBridgeSessionsParams) (int64, error) {
|
||||
row := q.db.QueryRowContext(ctx, countAIBridgeSessions,
|
||||
arg.StartedAfter,
|
||||
arg.StartedBefore,
|
||||
arg.InitiatorID,
|
||||
arg.Provider,
|
||||
arg.Model,
|
||||
arg.Client,
|
||||
arg.SessionID,
|
||||
)
|
||||
var count int64
|
||||
err := row.Scan(&count)
|
||||
return count, err
|
||||
}
|
||||
|
||||
const deleteOldAIBridgeRecords = `-- name: DeleteOldAIBridgeRecords :one
|
||||
WITH
|
||||
-- We don't have FK relationships between the dependent tables and aibridge_interceptions, so we can't rely on DELETE CASCADE.
|
||||
@@ -384,7 +455,7 @@ func (q *sqlQuerier) DeleteOldAIBridgeRecords(ctx context.Context, beforeTime ti
|
||||
|
||||
const getAIBridgeInterceptionByID = `-- name: GetAIBridgeInterceptionByID :one
|
||||
SELECT
|
||||
id, initiator_id, provider, model, started_at, metadata, ended_at, api_key_id, client, thread_parent_id, thread_root_id, client_session_id
|
||||
id, initiator_id, provider, model, started_at, metadata, ended_at, api_key_id, client, thread_parent_id, thread_root_id, client_session_id, session_id
|
||||
FROM
|
||||
aibridge_interceptions
|
||||
WHERE
|
||||
@@ -407,6 +478,7 @@ func (q *sqlQuerier) GetAIBridgeInterceptionByID(ctx context.Context, id uuid.UU
|
||||
&i.ThreadParentID,
|
||||
&i.ThreadRootID,
|
||||
&i.ClientSessionID,
|
||||
&i.SessionID,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
@@ -441,7 +513,7 @@ func (q *sqlQuerier) GetAIBridgeInterceptionLineageByToolCallID(ctx context.Cont
|
||||
|
||||
const getAIBridgeInterceptions = `-- name: GetAIBridgeInterceptions :many
|
||||
SELECT
|
||||
id, initiator_id, provider, model, started_at, metadata, ended_at, api_key_id, client, thread_parent_id, thread_root_id, client_session_id
|
||||
id, initiator_id, provider, model, started_at, metadata, ended_at, api_key_id, client, thread_parent_id, thread_root_id, client_session_id, session_id
|
||||
FROM
|
||||
aibridge_interceptions
|
||||
`
|
||||
@@ -468,6 +540,7 @@ func (q *sqlQuerier) GetAIBridgeInterceptions(ctx context.Context) ([]AIBridgeIn
|
||||
&i.ThreadParentID,
|
||||
&i.ThreadRootID,
|
||||
&i.ClientSessionID,
|
||||
&i.SessionID,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -618,7 +691,7 @@ INSERT INTO aibridge_interceptions (
|
||||
) VALUES (
|
||||
$1, $2, $3, $4, $5, COALESCE($6::jsonb, '{}'::jsonb), $7, $8, $9, $10::uuid, $11::uuid
|
||||
)
|
||||
RETURNING id, initiator_id, provider, model, started_at, metadata, ended_at, api_key_id, client, thread_parent_id, thread_root_id, client_session_id
|
||||
RETURNING id, initiator_id, provider, model, started_at, metadata, ended_at, api_key_id, client, thread_parent_id, thread_root_id, client_session_id, session_id
|
||||
`
|
||||
|
||||
type InsertAIBridgeInterceptionParams struct {
|
||||
@@ -663,6 +736,7 @@ func (q *sqlQuerier) InsertAIBridgeInterception(ctx context.Context, arg InsertA
|
||||
&i.ThreadParentID,
|
||||
&i.ThreadRootID,
|
||||
&i.ClientSessionID,
|
||||
&i.SessionID,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
@@ -837,7 +911,7 @@ func (q *sqlQuerier) InsertAIBridgeUserPrompt(ctx context.Context, arg InsertAIB
|
||||
|
||||
const listAIBridgeInterceptions = `-- name: ListAIBridgeInterceptions :many
|
||||
SELECT
|
||||
aibridge_interceptions.id, aibridge_interceptions.initiator_id, aibridge_interceptions.provider, aibridge_interceptions.model, aibridge_interceptions.started_at, aibridge_interceptions.metadata, aibridge_interceptions.ended_at, aibridge_interceptions.api_key_id, aibridge_interceptions.client, aibridge_interceptions.thread_parent_id, aibridge_interceptions.thread_root_id, aibridge_interceptions.client_session_id,
|
||||
aibridge_interceptions.id, aibridge_interceptions.initiator_id, aibridge_interceptions.provider, aibridge_interceptions.model, aibridge_interceptions.started_at, aibridge_interceptions.metadata, aibridge_interceptions.ended_at, aibridge_interceptions.api_key_id, aibridge_interceptions.client, aibridge_interceptions.thread_parent_id, aibridge_interceptions.thread_root_id, aibridge_interceptions.client_session_id, aibridge_interceptions.session_id,
|
||||
visible_users.id, visible_users.username, visible_users.name, visible_users.avatar_url
|
||||
FROM
|
||||
aibridge_interceptions
|
||||
@@ -949,6 +1023,7 @@ func (q *sqlQuerier) ListAIBridgeInterceptions(ctx context.Context, arg ListAIBr
|
||||
&i.AIBridgeInterception.ThreadParentID,
|
||||
&i.AIBridgeInterception.ThreadRootID,
|
||||
&i.AIBridgeInterception.ClientSessionID,
|
||||
&i.AIBridgeInterception.SessionID,
|
||||
&i.VisibleUser.ID,
|
||||
&i.VisibleUser.Username,
|
||||
&i.VisibleUser.Name,
|
||||
@@ -1071,6 +1146,229 @@ func (q *sqlQuerier) ListAIBridgeModels(ctx context.Context, arg ListAIBridgeMod
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const listAIBridgeSessions = `-- name: ListAIBridgeSessions :many
|
||||
WITH filtered_interceptions AS (
|
||||
SELECT
|
||||
aibridge_interceptions.id, aibridge_interceptions.initiator_id, aibridge_interceptions.provider, aibridge_interceptions.model, aibridge_interceptions.started_at, aibridge_interceptions.metadata, aibridge_interceptions.ended_at, aibridge_interceptions.api_key_id, aibridge_interceptions.client, aibridge_interceptions.thread_parent_id, aibridge_interceptions.thread_root_id, aibridge_interceptions.client_session_id, aibridge_interceptions.session_id
|
||||
FROM
|
||||
aibridge_interceptions
|
||||
WHERE
|
||||
-- Remove inflight interceptions (ones which lack an ended_at value).
|
||||
aibridge_interceptions.ended_at IS NOT NULL
|
||||
-- Filter by time frame
|
||||
AND CASE
|
||||
WHEN $4::timestamptz != '0001-01-01 00:00:00+00'::timestamptz THEN aibridge_interceptions.started_at >= $4::timestamptz
|
||||
ELSE true
|
||||
END
|
||||
AND CASE
|
||||
WHEN $5::timestamptz != '0001-01-01 00:00:00+00'::timestamptz THEN aibridge_interceptions.started_at <= $5::timestamptz
|
||||
ELSE true
|
||||
END
|
||||
-- Filter initiator_id
|
||||
AND CASE
|
||||
WHEN $6::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN aibridge_interceptions.initiator_id = $6::uuid
|
||||
ELSE true
|
||||
END
|
||||
-- Filter provider
|
||||
AND CASE
|
||||
WHEN $7::text != '' THEN aibridge_interceptions.provider = $7::text
|
||||
ELSE true
|
||||
END
|
||||
-- Filter model
|
||||
AND CASE
|
||||
WHEN $8::text != '' THEN aibridge_interceptions.model = $8::text
|
||||
ELSE true
|
||||
END
|
||||
-- Filter client
|
||||
AND CASE
|
||||
WHEN $9::text != '' THEN COALESCE(aibridge_interceptions.client, 'Unknown') = $9::text
|
||||
ELSE true
|
||||
END
|
||||
-- Filter session_id
|
||||
AND CASE
|
||||
WHEN $10::text != '' THEN aibridge_interceptions.session_id = $10::text
|
||||
ELSE true
|
||||
END
|
||||
-- Authorize Filter clause will be injected below in ListAuthorizedAIBridgeSessions
|
||||
-- @authorize_filter
|
||||
),
|
||||
session_tokens AS (
|
||||
-- Aggregate token usage across all interceptions in each session.
|
||||
-- Group by (session_id, initiator_id) to avoid merging sessions from
|
||||
-- different users who happen to share the same client_session_id.
|
||||
SELECT
|
||||
fi.session_id,
|
||||
fi.initiator_id,
|
||||
COALESCE(SUM(tu.input_tokens), 0)::bigint AS input_tokens,
|
||||
COALESCE(SUM(tu.output_tokens), 0)::bigint AS output_tokens
|
||||
-- TODO: add extra token types once https://github.com/coder/aibridge/issues/150 lands.
|
||||
FROM
|
||||
filtered_interceptions fi
|
||||
LEFT JOIN
|
||||
aibridge_token_usages tu ON fi.id = tu.interception_id
|
||||
GROUP BY
|
||||
fi.session_id, fi.initiator_id
|
||||
),
|
||||
session_root AS (
|
||||
-- Build one summary row per session. Group by (session_id, initiator_id)
|
||||
-- to avoid merging sessions from different users who happen to share the
|
||||
-- same client_session_id. The ARRAY_AGG with ORDER BY picks values from
|
||||
-- the chronologically first interception for fields that should represent
|
||||
-- the session as a whole (client, metadata). Threads are counted as
|
||||
-- distinct root interception IDs: an interception with a NULL
|
||||
-- thread_root_id is itself a thread root.
|
||||
SELECT
|
||||
fi.session_id,
|
||||
fi.initiator_id,
|
||||
(ARRAY_AGG(fi.client ORDER BY fi.started_at, fi.id))[1] AS client,
|
||||
(ARRAY_AGG(fi.metadata ORDER BY fi.started_at, fi.id))[1] AS metadata,
|
||||
ARRAY_AGG(DISTINCT fi.provider ORDER BY fi.provider) AS providers,
|
||||
ARRAY_AGG(DISTINCT fi.model ORDER BY fi.model) AS models,
|
||||
MIN(fi.started_at) AS started_at,
|
||||
MAX(fi.ended_at) AS ended_at,
|
||||
COUNT(DISTINCT COALESCE(fi.thread_root_id, fi.id)) AS threads,
|
||||
-- Collect IDs for lateral prompt lookup.
|
||||
ARRAY_AGG(fi.id) AS interception_ids
|
||||
FROM
|
||||
filtered_interceptions fi
|
||||
GROUP BY
|
||||
fi.session_id, fi.initiator_id
|
||||
)
|
||||
SELECT
|
||||
sr.session_id,
|
||||
visible_users.id AS user_id,
|
||||
visible_users.username AS user_username,
|
||||
visible_users.name AS user_name,
|
||||
visible_users.avatar_url AS user_avatar_url,
|
||||
sr.providers::text[] AS providers,
|
||||
sr.models::text[] AS models,
|
||||
COALESCE(sr.client, '')::varchar(64) AS client,
|
||||
sr.metadata::jsonb AS metadata,
|
||||
sr.started_at::timestamptz AS started_at,
|
||||
sr.ended_at::timestamptz AS ended_at,
|
||||
sr.threads,
|
||||
COALESCE(st.input_tokens, 0)::bigint AS input_tokens,
|
||||
COALESCE(st.output_tokens, 0)::bigint AS output_tokens,
|
||||
COALESCE(slp.prompt, '') AS last_prompt
|
||||
FROM
|
||||
session_root sr
|
||||
JOIN
|
||||
visible_users ON visible_users.id = sr.initiator_id
|
||||
LEFT JOIN
|
||||
session_tokens st ON st.session_id = sr.session_id AND st.initiator_id = sr.initiator_id
|
||||
LEFT JOIN LATERAL (
|
||||
-- Lateral join to efficiently fetch only the most recent user prompt
|
||||
-- across all interceptions in the session, avoiding a full aggregation.
|
||||
SELECT up.prompt
|
||||
FROM aibridge_user_prompts up
|
||||
WHERE up.interception_id = ANY(sr.interception_ids)
|
||||
ORDER BY up.created_at DESC, up.id DESC
|
||||
LIMIT 1
|
||||
) slp ON true
|
||||
WHERE
|
||||
-- Cursor pagination: uses a composite (started_at, session_id) cursor
|
||||
-- to support keyset pagination. The less-than comparison matches the
|
||||
-- DESC sort order so that rows after the cursor come later in results.
|
||||
CASE
|
||||
WHEN $1::text != '' THEN (
|
||||
(sr.started_at, sr.session_id) < (
|
||||
(SELECT started_at FROM session_root WHERE session_id = $1),
|
||||
$1::text
|
||||
)
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
ORDER BY
|
||||
sr.started_at DESC,
|
||||
sr.session_id DESC
|
||||
LIMIT COALESCE(NULLIF($3::integer, 0), 100)
|
||||
OFFSET $2
|
||||
`
|
||||
|
||||
type ListAIBridgeSessionsParams struct {
|
||||
AfterSessionID string `db:"after_session_id" json:"after_session_id"`
|
||||
Offset int32 `db:"offset_" json:"offset_"`
|
||||
Limit int32 `db:"limit_" json:"limit_"`
|
||||
StartedAfter time.Time `db:"started_after" json:"started_after"`
|
||||
StartedBefore time.Time `db:"started_before" json:"started_before"`
|
||||
InitiatorID uuid.UUID `db:"initiator_id" json:"initiator_id"`
|
||||
Provider string `db:"provider" json:"provider"`
|
||||
Model string `db:"model" json:"model"`
|
||||
Client string `db:"client" json:"client"`
|
||||
SessionID string `db:"session_id" json:"session_id"`
|
||||
}
|
||||
|
||||
type ListAIBridgeSessionsRow struct {
|
||||
SessionID string `db:"session_id" json:"session_id"`
|
||||
UserID uuid.UUID `db:"user_id" json:"user_id"`
|
||||
UserUsername string `db:"user_username" json:"user_username"`
|
||||
UserName string `db:"user_name" json:"user_name"`
|
||||
UserAvatarUrl string `db:"user_avatar_url" json:"user_avatar_url"`
|
||||
Providers []string `db:"providers" json:"providers"`
|
||||
Models []string `db:"models" json:"models"`
|
||||
Client string `db:"client" json:"client"`
|
||||
Metadata json.RawMessage `db:"metadata" json:"metadata"`
|
||||
StartedAt time.Time `db:"started_at" json:"started_at"`
|
||||
EndedAt time.Time `db:"ended_at" json:"ended_at"`
|
||||
Threads int64 `db:"threads" json:"threads"`
|
||||
InputTokens int64 `db:"input_tokens" json:"input_tokens"`
|
||||
OutputTokens int64 `db:"output_tokens" json:"output_tokens"`
|
||||
LastPrompt string `db:"last_prompt" json:"last_prompt"`
|
||||
}
|
||||
|
||||
// Returns paginated sessions with aggregated metadata, token counts, and
|
||||
// the most recent user prompt. A "session" is a logical grouping of
|
||||
// interceptions that share the same session_id (set by the client).
|
||||
func (q *sqlQuerier) ListAIBridgeSessions(ctx context.Context, arg ListAIBridgeSessionsParams) ([]ListAIBridgeSessionsRow, error) {
|
||||
rows, err := q.db.QueryContext(ctx, listAIBridgeSessions,
|
||||
arg.AfterSessionID,
|
||||
arg.Offset,
|
||||
arg.Limit,
|
||||
arg.StartedAfter,
|
||||
arg.StartedBefore,
|
||||
arg.InitiatorID,
|
||||
arg.Provider,
|
||||
arg.Model,
|
||||
arg.Client,
|
||||
arg.SessionID,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []ListAIBridgeSessionsRow
|
||||
for rows.Next() {
|
||||
var i ListAIBridgeSessionsRow
|
||||
if err := rows.Scan(
|
||||
&i.SessionID,
|
||||
&i.UserID,
|
||||
&i.UserUsername,
|
||||
&i.UserName,
|
||||
&i.UserAvatarUrl,
|
||||
pq.Array(&i.Providers),
|
||||
pq.Array(&i.Models),
|
||||
&i.Client,
|
||||
&i.Metadata,
|
||||
&i.StartedAt,
|
||||
&i.EndedAt,
|
||||
&i.Threads,
|
||||
&i.InputTokens,
|
||||
&i.OutputTokens,
|
||||
&i.LastPrompt,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const listAIBridgeTokenUsagesByInterceptionIDs = `-- name: ListAIBridgeTokenUsagesByInterceptionIDs :many
|
||||
SELECT
|
||||
id, interception_id, provider_response_id, input_tokens, output_tokens, metadata, created_at
|
||||
@@ -1209,7 +1507,7 @@ UPDATE aibridge_interceptions
|
||||
WHERE
|
||||
id = $2::uuid
|
||||
AND ended_at IS NULL
|
||||
RETURNING id, initiator_id, provider, model, started_at, metadata, ended_at, api_key_id, client, thread_parent_id, thread_root_id, client_session_id
|
||||
RETURNING id, initiator_id, provider, model, started_at, metadata, ended_at, api_key_id, client, thread_parent_id, thread_root_id, client_session_id, session_id
|
||||
`
|
||||
|
||||
type UpdateAIBridgeInterceptionEndedParams struct {
|
||||
@@ -1233,6 +1531,7 @@ func (q *sqlQuerier) UpdateAIBridgeInterceptionEnded(ctx context.Context, arg Up
|
||||
&i.ThreadParentID,
|
||||
&i.ThreadRootID,
|
||||
&i.ClientSessionID,
|
||||
&i.SessionID,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
@@ -404,6 +404,194 @@ SELECT (
|
||||
(SELECT COUNT(*) FROM interceptions)
|
||||
)::bigint as total_deleted;
|
||||
|
||||
-- name: CountAIBridgeSessions :one
|
||||
SELECT
|
||||
COUNT(DISTINCT (aibridge_interceptions.session_id, aibridge_interceptions.initiator_id))
|
||||
FROM
|
||||
aibridge_interceptions
|
||||
WHERE
|
||||
-- Remove inflight interceptions (ones which lack an ended_at value).
|
||||
aibridge_interceptions.ended_at IS NOT NULL
|
||||
-- Filter by time frame
|
||||
AND CASE
|
||||
WHEN @started_after::timestamptz != '0001-01-01 00:00:00+00'::timestamptz THEN aibridge_interceptions.started_at >= @started_after::timestamptz
|
||||
ELSE true
|
||||
END
|
||||
AND CASE
|
||||
WHEN @started_before::timestamptz != '0001-01-01 00:00:00+00'::timestamptz THEN aibridge_interceptions.started_at <= @started_before::timestamptz
|
||||
ELSE true
|
||||
END
|
||||
-- Filter initiator_id
|
||||
AND CASE
|
||||
WHEN @initiator_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN aibridge_interceptions.initiator_id = @initiator_id::uuid
|
||||
ELSE true
|
||||
END
|
||||
-- Filter provider
|
||||
AND CASE
|
||||
WHEN @provider::text != '' THEN aibridge_interceptions.provider = @provider::text
|
||||
ELSE true
|
||||
END
|
||||
-- Filter model
|
||||
AND CASE
|
||||
WHEN @model::text != '' THEN aibridge_interceptions.model = @model::text
|
||||
ELSE true
|
||||
END
|
||||
-- Filter client
|
||||
AND CASE
|
||||
WHEN @client::text != '' THEN COALESCE(aibridge_interceptions.client, 'Unknown') = @client::text
|
||||
ELSE true
|
||||
END
|
||||
-- Filter session_id
|
||||
AND CASE
|
||||
WHEN @session_id::text != '' THEN aibridge_interceptions.session_id = @session_id::text
|
||||
ELSE true
|
||||
END
|
||||
-- Authorize Filter clause will be injected below in CountAuthorizedAIBridgeSessions
|
||||
-- @authorize_filter
|
||||
;
|
||||
|
||||
-- name: ListAIBridgeSessions :many
|
||||
-- Returns paginated sessions with aggregated metadata, token counts, and
|
||||
-- the most recent user prompt. A "session" is a logical grouping of
|
||||
-- interceptions that share the same session_id (set by the client).
|
||||
WITH filtered_interceptions AS (
|
||||
SELECT
|
||||
aibridge_interceptions.*
|
||||
FROM
|
||||
aibridge_interceptions
|
||||
WHERE
|
||||
-- Remove inflight interceptions (ones which lack an ended_at value).
|
||||
aibridge_interceptions.ended_at IS NOT NULL
|
||||
-- Filter by time frame
|
||||
AND CASE
|
||||
WHEN @started_after::timestamptz != '0001-01-01 00:00:00+00'::timestamptz THEN aibridge_interceptions.started_at >= @started_after::timestamptz
|
||||
ELSE true
|
||||
END
|
||||
AND CASE
|
||||
WHEN @started_before::timestamptz != '0001-01-01 00:00:00+00'::timestamptz THEN aibridge_interceptions.started_at <= @started_before::timestamptz
|
||||
ELSE true
|
||||
END
|
||||
-- Filter initiator_id
|
||||
AND CASE
|
||||
WHEN @initiator_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN aibridge_interceptions.initiator_id = @initiator_id::uuid
|
||||
ELSE true
|
||||
END
|
||||
-- Filter provider
|
||||
AND CASE
|
||||
WHEN @provider::text != '' THEN aibridge_interceptions.provider = @provider::text
|
||||
ELSE true
|
||||
END
|
||||
-- Filter model
|
||||
AND CASE
|
||||
WHEN @model::text != '' THEN aibridge_interceptions.model = @model::text
|
||||
ELSE true
|
||||
END
|
||||
-- Filter client
|
||||
AND CASE
|
||||
WHEN @client::text != '' THEN COALESCE(aibridge_interceptions.client, 'Unknown') = @client::text
|
||||
ELSE true
|
||||
END
|
||||
-- Filter session_id
|
||||
AND CASE
|
||||
WHEN @session_id::text != '' THEN aibridge_interceptions.session_id = @session_id::text
|
||||
ELSE true
|
||||
END
|
||||
-- Authorize Filter clause will be injected below in ListAuthorizedAIBridgeSessions
|
||||
-- @authorize_filter
|
||||
),
|
||||
session_tokens AS (
|
||||
-- Aggregate token usage across all interceptions in each session.
|
||||
-- Group by (session_id, initiator_id) to avoid merging sessions from
|
||||
-- different users who happen to share the same client_session_id.
|
||||
SELECT
|
||||
fi.session_id,
|
||||
fi.initiator_id,
|
||||
COALESCE(SUM(tu.input_tokens), 0)::bigint AS input_tokens,
|
||||
COALESCE(SUM(tu.output_tokens), 0)::bigint AS output_tokens
|
||||
-- TODO: add extra token types once https://github.com/coder/aibridge/issues/150 lands.
|
||||
FROM
|
||||
filtered_interceptions fi
|
||||
LEFT JOIN
|
||||
aibridge_token_usages tu ON fi.id = tu.interception_id
|
||||
GROUP BY
|
||||
fi.session_id, fi.initiator_id
|
||||
),
|
||||
session_root AS (
|
||||
-- Build one summary row per session. Group by (session_id, initiator_id)
|
||||
-- to avoid merging sessions from different users who happen to share the
|
||||
-- same client_session_id. The ARRAY_AGG with ORDER BY picks values from
|
||||
-- the chronologically first interception for fields that should represent
|
||||
-- the session as a whole (client, metadata). Threads are counted as
|
||||
-- distinct root interception IDs: an interception with a NULL
|
||||
-- thread_root_id is itself a thread root.
|
||||
SELECT
|
||||
fi.session_id,
|
||||
fi.initiator_id,
|
||||
(ARRAY_AGG(fi.client ORDER BY fi.started_at, fi.id))[1] AS client,
|
||||
(ARRAY_AGG(fi.metadata ORDER BY fi.started_at, fi.id))[1] AS metadata,
|
||||
ARRAY_AGG(DISTINCT fi.provider ORDER BY fi.provider) AS providers,
|
||||
ARRAY_AGG(DISTINCT fi.model ORDER BY fi.model) AS models,
|
||||
MIN(fi.started_at) AS started_at,
|
||||
MAX(fi.ended_at) AS ended_at,
|
||||
COUNT(DISTINCT COALESCE(fi.thread_root_id, fi.id)) AS threads,
|
||||
-- Collect IDs for lateral prompt lookup.
|
||||
ARRAY_AGG(fi.id) AS interception_ids
|
||||
FROM
|
||||
filtered_interceptions fi
|
||||
GROUP BY
|
||||
fi.session_id, fi.initiator_id
|
||||
)
|
||||
SELECT
|
||||
sr.session_id,
|
||||
visible_users.id AS user_id,
|
||||
visible_users.username AS user_username,
|
||||
visible_users.name AS user_name,
|
||||
visible_users.avatar_url AS user_avatar_url,
|
||||
sr.providers::text[] AS providers,
|
||||
sr.models::text[] AS models,
|
||||
COALESCE(sr.client, '')::varchar(64) AS client,
|
||||
sr.metadata::jsonb AS metadata,
|
||||
sr.started_at::timestamptz AS started_at,
|
||||
sr.ended_at::timestamptz AS ended_at,
|
||||
sr.threads,
|
||||
COALESCE(st.input_tokens, 0)::bigint AS input_tokens,
|
||||
COALESCE(st.output_tokens, 0)::bigint AS output_tokens,
|
||||
COALESCE(slp.prompt, '') AS last_prompt
|
||||
FROM
|
||||
session_root sr
|
||||
JOIN
|
||||
visible_users ON visible_users.id = sr.initiator_id
|
||||
LEFT JOIN
|
||||
session_tokens st ON st.session_id = sr.session_id AND st.initiator_id = sr.initiator_id
|
||||
LEFT JOIN LATERAL (
|
||||
-- Lateral join to efficiently fetch only the most recent user prompt
|
||||
-- across all interceptions in the session, avoiding a full aggregation.
|
||||
SELECT up.prompt
|
||||
FROM aibridge_user_prompts up
|
||||
WHERE up.interception_id = ANY(sr.interception_ids)
|
||||
ORDER BY up.created_at DESC, up.id DESC
|
||||
LIMIT 1
|
||||
) slp ON true
|
||||
WHERE
|
||||
-- Cursor pagination: uses a composite (started_at, session_id) cursor
|
||||
-- to support keyset pagination. The less-than comparison matches the
|
||||
-- DESC sort order so that rows after the cursor come later in results.
|
||||
CASE
|
||||
WHEN @after_session_id::text != '' THEN (
|
||||
(sr.started_at, sr.session_id) < (
|
||||
(SELECT started_at FROM session_root WHERE session_id = @after_session_id),
|
||||
@after_session_id::text
|
||||
)
|
||||
)
|
||||
ELSE true
|
||||
END
|
||||
ORDER BY
|
||||
sr.started_at DESC,
|
||||
sr.session_id DESC
|
||||
LIMIT COALESCE(NULLIF(@limit_::integer, 0), 100)
|
||||
OFFSET @offset_
|
||||
;
|
||||
|
||||
-- name: ListAIBridgeModels :many
|
||||
SELECT
|
||||
model
|
||||
|
||||
@@ -401,6 +401,49 @@ func AIBridgeInterceptions(ctx context.Context, db database.Store, query string,
|
||||
return filter, parser.Errors
|
||||
}
|
||||
|
||||
func AIBridgeSessions(ctx context.Context, db database.Store, query string, page codersdk.Pagination, actorID uuid.UUID, afterSessionID string) (database.ListAIBridgeSessionsParams, []codersdk.ValidationError) {
|
||||
// nolint:exhaustruct // Empty values just means "don't filter by that field".
|
||||
filter := database.ListAIBridgeSessionsParams{
|
||||
AfterSessionID: afterSessionID,
|
||||
// #nosec G115 - Safe conversion for pagination limit which is expected to be within int32 range
|
||||
Limit: int32(page.Limit),
|
||||
// #nosec G115 - Safe conversion for pagination offset which is expected to be within int32 range
|
||||
Offset: int32(page.Offset),
|
||||
}
|
||||
|
||||
if query == "" {
|
||||
return filter, nil
|
||||
}
|
||||
|
||||
values, errors := searchTerms(query, func(string, url.Values) error {
|
||||
// Do not specify a default search key; let's be explicit to prevent user confusion.
|
||||
return xerrors.New("no search key specified")
|
||||
})
|
||||
if len(errors) > 0 {
|
||||
return filter, errors
|
||||
}
|
||||
|
||||
parser := httpapi.NewQueryParamParser()
|
||||
filter.InitiatorID = parseUser(ctx, db, parser, values, "initiator", actorID)
|
||||
filter.Provider = parser.String(values, "", "provider")
|
||||
filter.Model = parser.String(values, "", "model")
|
||||
filter.Client = parser.String(values, "", "client")
|
||||
filter.SessionID = parser.String(values, "", "session_id")
|
||||
|
||||
// Time must be between started_after and started_before.
|
||||
filter.StartedAfter = parser.Time3339Nano(values, time.Time{}, "started_after")
|
||||
filter.StartedBefore = parser.Time3339Nano(values, time.Time{}, "started_before")
|
||||
if !filter.StartedBefore.IsZero() && !filter.StartedAfter.IsZero() && !filter.StartedBefore.After(filter.StartedAfter) {
|
||||
parser.Errors = append(parser.Errors, codersdk.ValidationError{
|
||||
Field: "started_before",
|
||||
Detail: `Query param "started_before" has invalid value: "started_before" must be after "started_after" if set`,
|
||||
})
|
||||
}
|
||||
|
||||
parser.ErrorExcessParams(values)
|
||||
return filter, parser.Errors
|
||||
}
|
||||
|
||||
func AIBridgeModels(query string, page codersdk.Pagination) (database.ListAIBridgeModelsParams, []codersdk.ValidationError) {
|
||||
// nolint:exhaustruct // Empty values just means "don't filter by that field".
|
||||
filter := database.ListAIBridgeModelsParams{
|
||||
|
||||
@@ -63,6 +63,51 @@ type AIBridgeListInterceptionsResponse struct {
|
||||
Results []AIBridgeInterception `json:"results"`
|
||||
}
|
||||
|
||||
type AIBridgeSession struct {
|
||||
ID string `json:"id"`
|
||||
Initiator MinimalUser `json:"initiator"`
|
||||
Providers []string `json:"providers"`
|
||||
Models []string `json:"models"`
|
||||
Client *string `json:"client"`
|
||||
Metadata map[string]any `json:"metadata"`
|
||||
StartedAt time.Time `json:"started_at" format:"date-time"`
|
||||
EndedAt *time.Time `json:"ended_at,omitempty" format:"date-time"`
|
||||
Threads int64 `json:"threads"`
|
||||
TokenUsageSummary AIBridgeSessionTokenUsageSummary `json:"token_usage_summary"`
|
||||
LastPrompt *string `json:"last_prompt,omitempty"`
|
||||
}
|
||||
|
||||
type AIBridgeSessionTokenUsageSummary struct {
|
||||
InputTokens int64 `json:"input_tokens"`
|
||||
OutputTokens int64 `json:"output_tokens"`
|
||||
}
|
||||
|
||||
type AIBridgeListSessionsResponse struct {
|
||||
Count int64 `json:"count"`
|
||||
Sessions []AIBridgeSession `json:"sessions"`
|
||||
}
|
||||
|
||||
// @typescript-ignore AIBridgeListSessionsFilter
|
||||
type AIBridgeListSessionsFilter struct {
|
||||
// Limit defaults to 100, max is 1000.
|
||||
Pagination Pagination `json:"pagination,omitempty"`
|
||||
|
||||
// Initiator is a user ID, username, or "me".
|
||||
Initiator string `json:"initiator,omitempty"`
|
||||
StartedBefore time.Time `json:"started_before,omitempty" format:"date-time"`
|
||||
StartedAfter time.Time `json:"started_after,omitempty" format:"date-time"`
|
||||
Provider string `json:"provider,omitempty"`
|
||||
Model string `json:"model,omitempty"`
|
||||
Client string `json:"client,omitempty"`
|
||||
SessionID string `json:"session_id,omitempty"`
|
||||
|
||||
// AfterSessionID is a cursor for pagination. It is the session ID of the
|
||||
// last session in the previous page.
|
||||
AfterSessionID string `json:"after_session_id,omitempty"`
|
||||
|
||||
FilterQuery string `json:"q,omitempty"`
|
||||
}
|
||||
|
||||
// @typescript-ignore AIBridgeListInterceptionsFilter
|
||||
type AIBridgeListInterceptionsFilter struct {
|
||||
// Limit defaults to 100, max is 1000.
|
||||
@@ -117,6 +162,44 @@ func (f AIBridgeListInterceptionsFilter) asRequestOption() RequestOption {
|
||||
}
|
||||
}
|
||||
|
||||
// asRequestOption returns a function that can be used in (*Client).Request.
|
||||
func (f AIBridgeListSessionsFilter) asRequestOption() RequestOption {
|
||||
return func(r *http.Request) {
|
||||
var params []string
|
||||
if f.Initiator != "" {
|
||||
params = append(params, fmt.Sprintf("initiator:%q", f.Initiator))
|
||||
}
|
||||
if !f.StartedBefore.IsZero() {
|
||||
params = append(params, fmt.Sprintf("started_before:%q", f.StartedBefore.Format(time.RFC3339Nano)))
|
||||
}
|
||||
if !f.StartedAfter.IsZero() {
|
||||
params = append(params, fmt.Sprintf("started_after:%q", f.StartedAfter.Format(time.RFC3339Nano)))
|
||||
}
|
||||
if f.Provider != "" {
|
||||
params = append(params, fmt.Sprintf("provider:%q", f.Provider))
|
||||
}
|
||||
if f.Model != "" {
|
||||
params = append(params, fmt.Sprintf("model:%q", f.Model))
|
||||
}
|
||||
if f.Client != "" {
|
||||
params = append(params, fmt.Sprintf("client:%q", f.Client))
|
||||
}
|
||||
if f.SessionID != "" {
|
||||
params = append(params, fmt.Sprintf("session_id:%q", f.SessionID))
|
||||
}
|
||||
if f.FilterQuery != "" {
|
||||
params = append(params, f.FilterQuery)
|
||||
}
|
||||
|
||||
q := r.URL.Query()
|
||||
q.Set("q", strings.Join(params, " "))
|
||||
if f.AfterSessionID != "" {
|
||||
q.Set("after_session_id", f.AfterSessionID)
|
||||
}
|
||||
r.URL.RawQuery = q.Encode()
|
||||
}
|
||||
}
|
||||
|
||||
// AIBridgeListInterceptions returns AI Bridge interceptions with the given
|
||||
// filter.
|
||||
func (c *Client) AIBridgeListInterceptions(ctx context.Context, filter AIBridgeListInterceptionsFilter) (AIBridgeListInterceptionsResponse, error) {
|
||||
@@ -131,3 +214,17 @@ func (c *Client) AIBridgeListInterceptions(ctx context.Context, filter AIBridgeL
|
||||
var resp AIBridgeListInterceptionsResponse
|
||||
return resp, json.NewDecoder(res.Body).Decode(&resp)
|
||||
}
|
||||
|
||||
// AIBridgeListSessions returns AI Bridge sessions with the given filter.
|
||||
func (c *Client) AIBridgeListSessions(ctx context.Context, filter AIBridgeListSessionsFilter) (AIBridgeListSessionsResponse, error) {
|
||||
res, err := c.Request(ctx, http.MethodGet, "/api/v2/aibridge/sessions", nil, filter.asRequestOption(), filter.Pagination.asRequestOption())
|
||||
if err != nil {
|
||||
return AIBridgeListSessionsResponse{}, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != http.StatusOK {
|
||||
return AIBridgeListSessionsResponse{}, ReadBodyAsError(res)
|
||||
}
|
||||
var resp AIBridgeListSessionsResponse
|
||||
return resp, json.NewDecoder(res.Body).Decode(&resp)
|
||||
}
|
||||
|
||||
@@ -16,10 +16,10 @@ AI Bridge interception data can be exported for external analysis, compliance re
|
||||
|
||||
### REST API
|
||||
|
||||
You can retrieve AI Bridge interceptions via the Coder API with filtering and pagination support.
|
||||
You can retrieve AI Bridge sessions via the Coder API, with filtering and pagination support.
|
||||
|
||||
```sh
|
||||
curl -X GET "https://coder.example.com/api/v2/aibridge/interceptions?q=initiator:me" \
|
||||
curl -X GET "https://coder.example.com/api/v2/aibridge/sessions" \
|
||||
-H "Coder-Session-Token: $CODER_SESSION_TOKEN"
|
||||
```
|
||||
|
||||
|
||||
Generated
+70
@@ -137,3 +137,73 @@ curl -X GET http://coder-server:8080/api/v2/aibridge/models \
|
||||
<h3 id="list-ai-bridge-models-responseschema">Response Schema</h3>
|
||||
|
||||
To perform this operation, you must be authenticated. [Learn more](authentication.md).
|
||||
|
||||
## List AI Bridge sessions
|
||||
|
||||
### Code samples
|
||||
|
||||
```shell
|
||||
# Example request using curl
|
||||
curl -X GET http://coder-server:8080/api/v2/aibridge/sessions \
|
||||
-H 'Accept: application/json' \
|
||||
-H 'Coder-Session-Token: API_KEY'
|
||||
```
|
||||
|
||||
`GET /aibridge/sessions`
|
||||
|
||||
### Parameters
|
||||
|
||||
| Name | In | Type | Required | Description |
|
||||
|--------------------|-------|---------|----------|--------------------------------------------------------------------------------------------------------------------------------------------|
|
||||
| `q` | query | string | false | Search query in the format `key:value`. Available keys are: initiator, provider, model, client, session_id, started_after, started_before. |
|
||||
| `limit` | query | integer | false | Page limit |
|
||||
| `after_session_id` | query | string | false | Cursor pagination after session ID (cannot be used with offset) |
|
||||
| `offset` | query | integer | false | Offset pagination (cannot be used with after_session_id) |
|
||||
|
||||
### Example responses
|
||||
|
||||
> 200 Response
|
||||
|
||||
```json
|
||||
{
|
||||
"count": 0,
|
||||
"sessions": [
|
||||
{
|
||||
"client": "string",
|
||||
"ended_at": "2019-08-24T14:15:22Z",
|
||||
"id": "string",
|
||||
"initiator": {
|
||||
"avatar_url": "http://example.com",
|
||||
"id": "497f6eca-6276-4993-bfeb-53cbbbba6f08",
|
||||
"name": "string",
|
||||
"username": "string"
|
||||
},
|
||||
"last_prompt": "string",
|
||||
"metadata": {
|
||||
"property1": null,
|
||||
"property2": null
|
||||
},
|
||||
"models": [
|
||||
"string"
|
||||
],
|
||||
"providers": [
|
||||
"string"
|
||||
],
|
||||
"started_at": "2019-08-24T14:15:22Z",
|
||||
"threads": 0,
|
||||
"token_usage_summary": {
|
||||
"input_tokens": 0,
|
||||
"output_tokens": 0
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
### Responses
|
||||
|
||||
| Status | Meaning | Description | Schema |
|
||||
|--------|---------------------------------------------------------|-------------|------------------------------------------------------------------------------------------|
|
||||
| 200 | [OK](https://tools.ietf.org/html/rfc7231#section-6.3.1) | OK | [codersdk.AIBridgeListSessionsResponse](schemas.md#codersdkaibridgelistsessionsresponse) |
|
||||
|
||||
To perform this operation, you must be authenticated. [Learn more](authentication.md).
|
||||
|
||||
Generated
+111
@@ -598,6 +598,51 @@
|
||||
| `count` | integer | false | | |
|
||||
| `results` | array of [codersdk.AIBridgeInterception](#codersdkaibridgeinterception) | false | | |
|
||||
|
||||
## codersdk.AIBridgeListSessionsResponse
|
||||
|
||||
```json
|
||||
{
|
||||
"count": 0,
|
||||
"sessions": [
|
||||
{
|
||||
"client": "string",
|
||||
"ended_at": "2019-08-24T14:15:22Z",
|
||||
"id": "string",
|
||||
"initiator": {
|
||||
"avatar_url": "http://example.com",
|
||||
"id": "497f6eca-6276-4993-bfeb-53cbbbba6f08",
|
||||
"name": "string",
|
||||
"username": "string"
|
||||
},
|
||||
"last_prompt": "string",
|
||||
"metadata": {
|
||||
"property1": null,
|
||||
"property2": null
|
||||
},
|
||||
"models": [
|
||||
"string"
|
||||
],
|
||||
"providers": [
|
||||
"string"
|
||||
],
|
||||
"started_at": "2019-08-24T14:15:22Z",
|
||||
"threads": 0,
|
||||
"token_usage_summary": {
|
||||
"input_tokens": 0,
|
||||
"output_tokens": 0
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
### Properties
|
||||
|
||||
| Name | Type | Required | Restrictions | Description |
|
||||
|------------|---------------------------------------------------------------|----------|--------------|-------------|
|
||||
| `count` | integer | false | | |
|
||||
| `sessions` | array of [codersdk.AIBridgeSession](#codersdkaibridgesession) | false | | |
|
||||
|
||||
## codersdk.AIBridgeOpenAIConfig
|
||||
|
||||
```json
|
||||
@@ -650,6 +695,72 @@
|
||||
| `upstream_proxy` | string | false | | |
|
||||
| `upstream_proxy_ca` | string | false | | |
|
||||
|
||||
## codersdk.AIBridgeSession
|
||||
|
||||
```json
|
||||
{
|
||||
"client": "string",
|
||||
"ended_at": "2019-08-24T14:15:22Z",
|
||||
"id": "string",
|
||||
"initiator": {
|
||||
"avatar_url": "http://example.com",
|
||||
"id": "497f6eca-6276-4993-bfeb-53cbbbba6f08",
|
||||
"name": "string",
|
||||
"username": "string"
|
||||
},
|
||||
"last_prompt": "string",
|
||||
"metadata": {
|
||||
"property1": null,
|
||||
"property2": null
|
||||
},
|
||||
"models": [
|
||||
"string"
|
||||
],
|
||||
"providers": [
|
||||
"string"
|
||||
],
|
||||
"started_at": "2019-08-24T14:15:22Z",
|
||||
"threads": 0,
|
||||
"token_usage_summary": {
|
||||
"input_tokens": 0,
|
||||
"output_tokens": 0
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Properties
|
||||
|
||||
| Name | Type | Required | Restrictions | Description |
|
||||
|-----------------------|----------------------------------------------------------------------------------------|----------|--------------|-------------|
|
||||
| `client` | string | false | | |
|
||||
| `ended_at` | string | false | | |
|
||||
| `id` | string | false | | |
|
||||
| `initiator` | [codersdk.MinimalUser](#codersdkminimaluser) | false | | |
|
||||
| `last_prompt` | string | false | | |
|
||||
| `metadata` | object | false | | |
|
||||
| » `[any property]` | any | false | | |
|
||||
| `models` | array of string | false | | |
|
||||
| `providers` | array of string | false | | |
|
||||
| `started_at` | string | false | | |
|
||||
| `threads` | integer | false | | |
|
||||
| `token_usage_summary` | [codersdk.AIBridgeSessionTokenUsageSummary](#codersdkaibridgesessiontokenusagesummary) | false | | |
|
||||
|
||||
## codersdk.AIBridgeSessionTokenUsageSummary
|
||||
|
||||
```json
|
||||
{
|
||||
"input_tokens": 0,
|
||||
"output_tokens": 0
|
||||
}
|
||||
```
|
||||
|
||||
### Properties
|
||||
|
||||
| Name | Type | Required | Restrictions | Description |
|
||||
|-----------------|---------|----------|--------------|-------------|
|
||||
| `input_tokens` | integer | false | | |
|
||||
| `output_tokens` | integer | false | | |
|
||||
|
||||
## codersdk.AIBridgeTokenUsage
|
||||
|
||||
```json
|
||||
|
||||
@@ -2,6 +2,7 @@ package coderd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
@@ -10,6 +11,7 @@ import (
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/coderd"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/db2sdk"
|
||||
@@ -22,8 +24,10 @@ import (
|
||||
|
||||
const (
|
||||
maxListInterceptionsLimit = 1000
|
||||
maxListSessionsLimit = 1000
|
||||
maxListModelsLimit = 1000
|
||||
defaultListInterceptionsLimit = 100
|
||||
defaultListSessionsLimit = 100
|
||||
defaultListModelsLimit = 100
|
||||
// aiBridgeRateLimitWindow is the fixed duration for rate limiting AI Bridge
|
||||
// requests. This is hardcoded to keep configuration simple.
|
||||
@@ -43,6 +47,7 @@ func aibridgeHandler(api *API, middlewares ...func(http.Handler) http.Handler) f
|
||||
r.Group(func(r chi.Router) {
|
||||
r.Use(middlewares...)
|
||||
r.Get("/interceptions", api.aiBridgeListInterceptions)
|
||||
r.Get("/sessions", api.aiBridgeListSessions)
|
||||
r.Get("/models", api.aiBridgeListModels)
|
||||
})
|
||||
|
||||
@@ -176,6 +181,130 @@ func (api *API) aiBridgeListInterceptions(rw http.ResponseWriter, r *http.Reques
|
||||
})
|
||||
}
|
||||
|
||||
// aiBridgeListSessions returns AI Bridge sessions (aggregated interceptions).
|
||||
//
|
||||
// @Summary List AI Bridge sessions
|
||||
// @ID list-ai-bridge-sessions
|
||||
// @Security CoderSessionToken
|
||||
// @Produce json
|
||||
// @Tags AI Bridge
|
||||
// @Param q query string false "Search query in the format `key:value`. Available keys are: initiator, provider, model, client, session_id, started_after, started_before."
|
||||
// @Param limit query int false "Page limit"
|
||||
// @Param after_session_id query string false "Cursor pagination after session ID (cannot be used with offset)"
|
||||
// @Param offset query int false "Offset pagination (cannot be used with after_session_id)"
|
||||
// @Success 200 {object} codersdk.AIBridgeListSessionsResponse
|
||||
// @Router /aibridge/sessions [get]
|
||||
func (api *API) aiBridgeListSessions(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
apiKey := httpmw.APIKey(r)
|
||||
|
||||
page, ok := coderd.ParsePagination(rw, r)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
afterSessionID := r.URL.Query().Get("after_session_id")
|
||||
if afterSessionID != "" && page.Offset != 0 {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Query parameters have invalid values.",
|
||||
Detail: "Cannot use both after_session_id and offset pagination in the same request.",
|
||||
})
|
||||
return
|
||||
}
|
||||
if page.Limit == 0 {
|
||||
page.Limit = defaultListSessionsLimit
|
||||
}
|
||||
if page.Limit > maxListSessionsLimit || page.Limit < 1 {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Invalid pagination limit value.",
|
||||
Detail: fmt.Sprintf("Pagination limit must be in range (0, %d]", maxListSessionsLimit),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
queryStr := r.URL.Query().Get("q")
|
||||
filter, errs := searchquery.AIBridgeSessions(ctx, api.Database, queryStr, page, apiKey.UserID, afterSessionID)
|
||||
if len(errs) > 0 {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Invalid session search query.",
|
||||
Validations: errs,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Validate the cursor session exists before running the main query.
|
||||
if afterSessionID != "" {
|
||||
//nolint:exhaustruct // Only need session_id filter and limit.
|
||||
cursor, err := api.Database.ListAIBridgeSessions(ctx, database.ListAIBridgeSessionsParams{
|
||||
SessionID: afterSessionID,
|
||||
Limit: 1,
|
||||
})
|
||||
if err != nil {
|
||||
api.Logger.Error(ctx, "error validating after_session_id cursor", slog.Error(err))
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error validating after_session_id cursor.",
|
||||
Detail: "", // Don't leak database issue to client.
|
||||
})
|
||||
return
|
||||
}
|
||||
if len(cursor) == 0 {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Query parameter has invalid value.",
|
||||
Detail: fmt.Sprintf("after_session_id: session %q not found", afterSessionID),
|
||||
})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
count int64
|
||||
rows []database.ListAIBridgeSessionsRow
|
||||
)
|
||||
err := api.Database.InTx(func(db database.Store) error {
|
||||
var err error
|
||||
count, err = db.CountAIBridgeSessions(ctx, database.CountAIBridgeSessionsParams{
|
||||
StartedAfter: filter.StartedAfter,
|
||||
StartedBefore: filter.StartedBefore,
|
||||
InitiatorID: filter.InitiatorID,
|
||||
Provider: filter.Provider,
|
||||
Model: filter.Model,
|
||||
Client: filter.Client,
|
||||
SessionID: filter.SessionID,
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("count authorized aibridge sessions: %w", err)
|
||||
}
|
||||
|
||||
rows, err = db.ListAIBridgeSessions(ctx, filter)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("list aibridge sessions: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}, &database.TxOptions{
|
||||
Isolation: sql.LevelRepeatableRead, // Consistency across queries tables while writes may be occurring.
|
||||
ReadOnly: true,
|
||||
TxIdentifier: "aibridge_list_sessions",
|
||||
})
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error getting AI Bridge sessions.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
sessions := make([]codersdk.AIBridgeSession, len(rows))
|
||||
for i, row := range rows {
|
||||
sessions[i] = db2sdk.AIBridgeSession(row)
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, codersdk.AIBridgeListSessionsResponse{
|
||||
Count: count,
|
||||
Sessions: sessions,
|
||||
})
|
||||
}
|
||||
|
||||
// aiBridgeListModels returns all AI Bridge models a user can see.
|
||||
//
|
||||
// @Summary List AI Bridge models
|
||||
|
||||
@@ -665,6 +665,560 @@ func TestAIBridgeListInterceptions(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func aibridgeOpts(t *testing.T) *coderdenttest.Options {
|
||||
t.Helper()
|
||||
dv := coderdtest.DeploymentValues(t)
|
||||
dv.AI.BridgeConfig.Enabled = serpent.Bool(true)
|
||||
return &coderdenttest.Options{
|
||||
Options: &coderdtest.Options{
|
||||
DeploymentValues: dv,
|
||||
},
|
||||
LicenseOptions: &coderdenttest.LicenseOptions{
|
||||
Features: license.Features{
|
||||
codersdk.FeatureAIBridge: 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func TestAIBridgeListSessions(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("EmptyDB", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client, _ := coderdenttest.New(t, aibridgeOpts(t))
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
//nolint:gocritic // Owner role is irrelevant here.
|
||||
res, err := client.AIBridgeListSessions(ctx, codersdk.AIBridgeListSessionsFilter{})
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, res.Sessions)
|
||||
require.EqualValues(t, 0, res.Count)
|
||||
})
|
||||
|
||||
t.Run("OK", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client, db, firstUser := coderdenttest.NewWithDatabase(t, aibridgeOpts(t))
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
now := dbtime.Now()
|
||||
|
||||
// Session 1: Two interceptions sharing client_session_id "session-A".
|
||||
s1i1EndedAt := now.Add(time.Minute)
|
||||
s1i1 := dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{
|
||||
InitiatorID: firstUser.UserID,
|
||||
Provider: "anthropic",
|
||||
Model: "claude-4",
|
||||
StartedAt: now,
|
||||
Client: sql.NullString{String: "claude-code", Valid: true},
|
||||
ClientSessionID: sql.NullString{String: "session-A", Valid: true},
|
||||
}, &s1i1EndedAt)
|
||||
s1i2EndedAt := now.Add(2 * time.Minute)
|
||||
dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{
|
||||
InitiatorID: firstUser.UserID,
|
||||
Provider: "anthropic",
|
||||
Model: "claude-4-haiku",
|
||||
StartedAt: now.Add(time.Minute),
|
||||
Client: sql.NullString{String: "claude-code", Valid: true},
|
||||
ClientSessionID: sql.NullString{String: "session-A", Valid: true},
|
||||
ThreadRootInterceptionID: uuid.NullUUID{UUID: s1i1.ID, Valid: true},
|
||||
ThreadParentInterceptionID: uuid.NullUUID{UUID: s1i1.ID, Valid: true},
|
||||
}, &s1i2EndedAt)
|
||||
|
||||
// Add token usages to session 1 interceptions.
|
||||
dbgen.AIBridgeTokenUsage(t, db, database.InsertAIBridgeTokenUsageParams{
|
||||
InterceptionID: s1i1.ID,
|
||||
InputTokens: 100,
|
||||
OutputTokens: 50,
|
||||
CreatedAt: now,
|
||||
})
|
||||
dbgen.AIBridgeTokenUsage(t, db, database.InsertAIBridgeTokenUsageParams{
|
||||
InterceptionID: s1i1.ID,
|
||||
InputTokens: 200,
|
||||
OutputTokens: 75,
|
||||
CreatedAt: now.Add(time.Second),
|
||||
})
|
||||
|
||||
// Add user prompts to session 1.
|
||||
dbgen.AIBridgeUserPrompt(t, db, database.InsertAIBridgeUserPromptParams{
|
||||
InterceptionID: s1i1.ID,
|
||||
Prompt: "first prompt",
|
||||
CreatedAt: now,
|
||||
})
|
||||
dbgen.AIBridgeUserPrompt(t, db, database.InsertAIBridgeUserPromptParams{
|
||||
InterceptionID: s1i1.ID,
|
||||
Prompt: "last prompt in session",
|
||||
CreatedAt: now.Add(time.Minute),
|
||||
})
|
||||
|
||||
// Session 2: Thread-based session (no client_session_id, shared thread_root_id).
|
||||
s2i1EndedAt := now.Add(-time.Hour + time.Minute)
|
||||
s2i1 := dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{
|
||||
InitiatorID: firstUser.UserID,
|
||||
Provider: "openai",
|
||||
Model: "gpt-4",
|
||||
StartedAt: now.Add(-time.Hour),
|
||||
}, &s2i1EndedAt)
|
||||
s2i2EndedAt := now.Add(-time.Hour + 2*time.Minute)
|
||||
dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{
|
||||
InitiatorID: firstUser.UserID,
|
||||
Provider: "openai",
|
||||
Model: "gpt-4",
|
||||
StartedAt: now.Add(-time.Hour + time.Minute),
|
||||
ThreadRootInterceptionID: uuid.NullUUID{UUID: s2i1.ID, Valid: true},
|
||||
ThreadParentInterceptionID: uuid.NullUUID{UUID: s2i1.ID, Valid: true},
|
||||
}, &s2i2EndedAt)
|
||||
|
||||
// Session 3: Standalone interception (no client_session_id, no thread_root_id).
|
||||
s3EndedAt := now.Add(-2*time.Hour + time.Minute)
|
||||
s3i1 := dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{
|
||||
InitiatorID: firstUser.UserID,
|
||||
Provider: "anthropic",
|
||||
Model: "claude-4",
|
||||
StartedAt: now.Add(-2 * time.Hour),
|
||||
}, &s3EndedAt)
|
||||
|
||||
// Session 4: Two distinct thread roots in one client_session_id.
|
||||
s4i1EndedAt := now.Add(-3*time.Hour + time.Minute)
|
||||
dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{
|
||||
InitiatorID: firstUser.UserID,
|
||||
Provider: "anthropic",
|
||||
Model: "claude-4",
|
||||
StartedAt: now.Add(-3 * time.Hour),
|
||||
ClientSessionID: sql.NullString{String: "session-multi", Valid: true},
|
||||
}, &s4i1EndedAt)
|
||||
s4i2EndedAt := now.Add(-3*time.Hour + 2*time.Minute)
|
||||
dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{
|
||||
InitiatorID: firstUser.UserID,
|
||||
Provider: "openai",
|
||||
Model: "gpt-4",
|
||||
StartedAt: now.Add(-3*time.Hour + time.Minute),
|
||||
ClientSessionID: sql.NullString{String: "session-multi", Valid: true},
|
||||
}, &s4i2EndedAt)
|
||||
|
||||
//nolint:gocritic // Owner role is irrelevant here.
|
||||
res, err := client.AIBridgeListSessions(ctx, codersdk.AIBridgeListSessionsFilter{})
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 4, res.Count)
|
||||
require.Len(t, res.Sessions, 4)
|
||||
|
||||
// Sessions ordered by started_at DESC: session-A (now), then
|
||||
// thread-based (now-1h), then standalone (now-2h), then
|
||||
// multi-thread (now-3h).
|
||||
require.Equal(t, "session-A", res.Sessions[0].ID)
|
||||
require.Equal(t, s2i1.ID.String(), res.Sessions[1].ID)
|
||||
require.Equal(t, s3i1.ID.String(), res.Sessions[2].ID)
|
||||
require.Equal(t, "session-multi", res.Sessions[3].ID)
|
||||
|
||||
// Verify session 1 aggregations.
|
||||
s1 := res.Sessions[0]
|
||||
require.ElementsMatch(t, []string{"anthropic"}, s1.Providers)
|
||||
require.ElementsMatch(t, []string{"claude-4", "claude-4-haiku"}, s1.Models)
|
||||
require.NotNil(t, s1.Client)
|
||||
require.Equal(t, "claude-code", *s1.Client)
|
||||
require.EqualValues(t, 300, s1.TokenUsageSummary.InputTokens)
|
||||
require.EqualValues(t, 125, s1.TokenUsageSummary.OutputTokens)
|
||||
require.NotNil(t, s1.LastPrompt)
|
||||
require.Equal(t, "last prompt in session", *s1.LastPrompt)
|
||||
// Two interceptions in session-A, but they share a thread root,
|
||||
// so thread count is 1.
|
||||
require.EqualValues(t, 1, s1.Threads)
|
||||
|
||||
// Verify session 2 (thread-based).
|
||||
s2 := res.Sessions[1]
|
||||
require.ElementsMatch(t, []string{"openai"}, s2.Providers)
|
||||
// Thread count: the root interception and its child share the same
|
||||
// thread root, so count is 1.
|
||||
require.EqualValues(t, 1, s2.Threads)
|
||||
|
||||
// Verify session 3 (standalone).
|
||||
s3 := res.Sessions[2]
|
||||
require.EqualValues(t, 1, s3.Threads)
|
||||
require.Nil(t, s3.LastPrompt)
|
||||
|
||||
// Verify session 4 (multiple threads). Thread A has a root +
|
||||
// child (1 thread), thread B is a standalone root (1 thread),
|
||||
// so total is 2.
|
||||
s4 := res.Sessions[3]
|
||||
require.EqualValues(t, 2, s4.Threads)
|
||||
require.ElementsMatch(t, []string{"anthropic", "openai"}, s4.Providers)
|
||||
require.ElementsMatch(t, []string{"claude-4", "gpt-4"}, s4.Models)
|
||||
})
|
||||
|
||||
t.Run("Pagination", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client, db, firstUser := coderdenttest.NewWithDatabase(t, aibridgeOpts(t))
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
now := dbtime.Now()
|
||||
// Create 5 standalone sessions with different start times.
|
||||
allSessionIDs := make([]string, 5)
|
||||
for i := range 5 {
|
||||
endedAt := now.Add(-time.Duration(i)*time.Hour + time.Minute)
|
||||
intc := dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{
|
||||
InitiatorID: firstUser.UserID,
|
||||
StartedAt: now.Add(-time.Duration(i) * time.Hour),
|
||||
}, &endedAt)
|
||||
// Standalone session: ID = interception UUID string.
|
||||
allSessionIDs[i] = intc.ID.String()
|
||||
}
|
||||
|
||||
// Test offset pagination.
|
||||
//nolint:gocritic // Owner role is irrelevant here.
|
||||
res, err := client.AIBridgeListSessions(ctx, codersdk.AIBridgeListSessionsFilter{
|
||||
Pagination: codersdk.Pagination{Limit: 2},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, res.Sessions, 2)
|
||||
require.EqualValues(t, 5, res.Count)
|
||||
require.Equal(t, allSessionIDs[0], res.Sessions[0].ID)
|
||||
require.Equal(t, allSessionIDs[1], res.Sessions[1].ID)
|
||||
|
||||
// Second page with offset.
|
||||
res, err = client.AIBridgeListSessions(ctx, codersdk.AIBridgeListSessionsFilter{
|
||||
Pagination: codersdk.Pagination{Limit: 2, Offset: 2},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, res.Sessions, 2)
|
||||
require.Equal(t, allSessionIDs[2], res.Sessions[0].ID)
|
||||
require.Equal(t, allSessionIDs[3], res.Sessions[1].ID)
|
||||
|
||||
// Test cursor pagination.
|
||||
res, err = client.AIBridgeListSessions(ctx, codersdk.AIBridgeListSessionsFilter{
|
||||
Pagination: codersdk.Pagination{Limit: 2},
|
||||
AfterSessionID: allSessionIDs[1],
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, res.Sessions, 2)
|
||||
require.Equal(t, allSessionIDs[2], res.Sessions[0].ID)
|
||||
require.Equal(t, allSessionIDs[3], res.Sessions[1].ID)
|
||||
|
||||
// Test mutual exclusion of cursor and offset.
|
||||
_, err = client.AIBridgeListSessions(ctx, codersdk.AIBridgeListSessionsFilter{
|
||||
Pagination: codersdk.Pagination{Limit: 2, Offset: 1},
|
||||
AfterSessionID: allSessionIDs[0],
|
||||
})
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
require.Contains(t, sdkErr.Detail, "Cannot use both after_session_id and offset pagination")
|
||||
})
|
||||
|
||||
t.Run("AfterSessionIDNotFound", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client, _ := coderdenttest.New(t, aibridgeOpts(t))
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
//nolint:gocritic // Owner role is irrelevant here.
|
||||
_, err := client.AIBridgeListSessions(ctx, codersdk.AIBridgeListSessionsFilter{
|
||||
Pagination: codersdk.Pagination{Limit: 10},
|
||||
AfterSessionID: "nonexistent-session-id",
|
||||
})
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
require.Equal(t, http.StatusBadRequest, sdkErr.StatusCode())
|
||||
require.Equal(t, `after_session_id: session "nonexistent-session-id" not found`, sdkErr.Detail)
|
||||
})
|
||||
|
||||
t.Run("Filters", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client, db, firstUser := coderdenttest.NewWithDatabase(t, aibridgeOpts(t))
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
_, user2 := coderdtest.CreateAnotherUser(t, client, firstUser.OrganizationID)
|
||||
|
||||
now := dbtime.Now()
|
||||
|
||||
// Session from user1 with provider "anthropic" and client "claude-code".
|
||||
s1EndedAt := now.Add(time.Minute)
|
||||
s1 := dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{
|
||||
InitiatorID: firstUser.UserID,
|
||||
Provider: "anthropic",
|
||||
Model: "claude-4",
|
||||
StartedAt: now,
|
||||
Client: sql.NullString{String: "claude-code", Valid: true},
|
||||
}, &s1EndedAt)
|
||||
|
||||
// Session from user2 with provider "openai".
|
||||
s2EndedAt := now.Add(-time.Hour + time.Minute)
|
||||
s2 := dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{
|
||||
InitiatorID: user2.ID,
|
||||
Provider: "openai",
|
||||
Model: "gpt-4",
|
||||
StartedAt: now.Add(-time.Hour),
|
||||
}, &s2EndedAt)
|
||||
|
||||
// Filter by initiator.
|
||||
//nolint:gocritic // Owner role is irrelevant; testing filter behavior.
|
||||
res, err := client.AIBridgeListSessions(ctx, codersdk.AIBridgeListSessionsFilter{
|
||||
Initiator: user2.Username,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 1, res.Count)
|
||||
require.Equal(t, s2.ID.String(), res.Sessions[0].ID)
|
||||
|
||||
// Filter by provider.
|
||||
res, err = client.AIBridgeListSessions(ctx, codersdk.AIBridgeListSessionsFilter{
|
||||
Provider: "anthropic",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 1, res.Count)
|
||||
require.Equal(t, s1.ID.String(), res.Sessions[0].ID)
|
||||
|
||||
// Filter by model.
|
||||
res, err = client.AIBridgeListSessions(ctx, codersdk.AIBridgeListSessionsFilter{
|
||||
Model: "gpt-4",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 1, res.Count)
|
||||
require.Equal(t, s2.ID.String(), res.Sessions[0].ID)
|
||||
|
||||
// Filter by client.
|
||||
res, err = client.AIBridgeListSessions(ctx, codersdk.AIBridgeListSessionsFilter{
|
||||
Client: "claude-code",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 1, res.Count)
|
||||
require.Equal(t, s1.ID.String(), res.Sessions[0].ID)
|
||||
|
||||
// Filter by time range.
|
||||
res, err = client.AIBridgeListSessions(ctx, codersdk.AIBridgeListSessionsFilter{
|
||||
StartedAfter: now.Add(-30 * time.Minute),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 1, res.Count)
|
||||
require.Equal(t, s1.ID.String(), res.Sessions[0].ID)
|
||||
|
||||
// Filter by session_id.
|
||||
res, err = client.AIBridgeListSessions(ctx, codersdk.AIBridgeListSessionsFilter{
|
||||
SessionID: s2.ID.String(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 1, res.Count)
|
||||
require.Len(t, res.Sessions, 1)
|
||||
require.Equal(t, s2.ID.String(), res.Sessions[0].ID)
|
||||
|
||||
// Filter by session_id with no match.
|
||||
res, err = client.AIBridgeListSessions(ctx, codersdk.AIBridgeListSessionsFilter{
|
||||
SessionID: "nonexistent-session-id",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 0, res.Count)
|
||||
require.Empty(t, res.Sessions)
|
||||
})
|
||||
|
||||
t.Run("Authorized", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
adminClient, db, firstUser := coderdenttest.NewWithDatabase(t, aibridgeOpts(t))
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
secondUserClient, secondUser := coderdtest.CreateAnotherUser(t, adminClient, firstUser.OrganizationID)
|
||||
|
||||
now := dbtime.Now()
|
||||
i1EndedAt := now.Add(time.Minute)
|
||||
i1 := dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{
|
||||
InitiatorID: firstUser.UserID,
|
||||
StartedAt: now,
|
||||
}, &i1EndedAt)
|
||||
i2 := dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{
|
||||
InitiatorID: secondUser.ID,
|
||||
StartedAt: now.Add(-time.Hour),
|
||||
}, &now)
|
||||
|
||||
// Admin can see all sessions.
|
||||
//nolint:gocritic // Intentionally testing admin/owner visibility.
|
||||
res, err := adminClient.AIBridgeListSessions(ctx, codersdk.AIBridgeListSessionsFilter{})
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 2, res.Count)
|
||||
require.Len(t, res.Sessions, 2)
|
||||
require.Equal(t, i1.ID.String(), res.Sessions[0].ID)
|
||||
require.Equal(t, i2.ID.String(), res.Sessions[1].ID)
|
||||
|
||||
// Second user can only see their own sessions.
|
||||
res, err = secondUserClient.AIBridgeListSessions(ctx, codersdk.AIBridgeListSessionsFilter{})
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 1, res.Count)
|
||||
require.Len(t, res.Sessions, 1)
|
||||
require.Equal(t, i2.ID.String(), res.Sessions[0].ID)
|
||||
})
|
||||
|
||||
t.Run("SessionIDCollisionAcrossUsers", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client, db, firstUser := coderdenttest.NewWithDatabase(t, aibridgeOpts(t))
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
_, user2 := coderdtest.CreateAnotherUser(t, client, firstUser.OrganizationID)
|
||||
|
||||
now := dbtime.Now()
|
||||
|
||||
// Two users share the same client_session_id. They must be
|
||||
// treated as distinct sessions.
|
||||
sharedSessionID := "shared-session-id"
|
||||
u1EndedAt := now.Add(time.Minute)
|
||||
u1Interception := dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{
|
||||
InitiatorID: firstUser.UserID,
|
||||
Provider: "anthropic",
|
||||
Model: "claude-4",
|
||||
StartedAt: now,
|
||||
Client: sql.NullString{String: "claude-code", Valid: true},
|
||||
ClientSessionID: sql.NullString{String: sharedSessionID, Valid: true},
|
||||
}, &u1EndedAt)
|
||||
dbgen.AIBridgeTokenUsage(t, db, database.InsertAIBridgeTokenUsageParams{
|
||||
InterceptionID: u1Interception.ID,
|
||||
InputTokens: 100,
|
||||
OutputTokens: 50,
|
||||
CreatedAt: now,
|
||||
})
|
||||
|
||||
u2EndedAt := now.Add(-time.Hour + time.Minute)
|
||||
u2Interception := dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{
|
||||
InitiatorID: user2.ID,
|
||||
Provider: "openai",
|
||||
Model: "gpt-4",
|
||||
StartedAt: now.Add(-time.Hour),
|
||||
Client: sql.NullString{String: "cursor", Valid: true},
|
||||
ClientSessionID: sql.NullString{String: sharedSessionID, Valid: true},
|
||||
}, &u2EndedAt)
|
||||
dbgen.AIBridgeTokenUsage(t, db, database.InsertAIBridgeTokenUsageParams{
|
||||
InterceptionID: u2Interception.ID,
|
||||
InputTokens: 200,
|
||||
OutputTokens: 75,
|
||||
CreatedAt: now.Add(-time.Hour),
|
||||
})
|
||||
|
||||
// Admin should see two distinct sessions despite the shared
|
||||
// session_id, each with the correct user and token counts.
|
||||
//nolint:gocritic // Owner role is irrelevant; testing collision behavior.
|
||||
res, err := client.AIBridgeListSessions(ctx, codersdk.AIBridgeListSessionsFilter{})
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 2, res.Count)
|
||||
require.Len(t, res.Sessions, 2)
|
||||
|
||||
// Both sessions share the same ID string but belong to
|
||||
// different users.
|
||||
require.Equal(t, sharedSessionID, res.Sessions[0].ID)
|
||||
require.Equal(t, sharedSessionID, res.Sessions[1].ID)
|
||||
require.NotEqual(t, res.Sessions[0].Initiator.ID, res.Sessions[1].Initiator.ID)
|
||||
|
||||
// Verify token counts are not merged across users.
|
||||
for _, s := range res.Sessions {
|
||||
if s.Initiator.ID == firstUser.UserID {
|
||||
require.EqualValues(t, 100, s.TokenUsageSummary.InputTokens)
|
||||
require.EqualValues(t, 50, s.TokenUsageSummary.OutputTokens)
|
||||
} else {
|
||||
require.EqualValues(t, 200, s.TokenUsageSummary.InputTokens)
|
||||
require.EqualValues(t, 75, s.TokenUsageSummary.OutputTokens)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("InflightSessions", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client, db, firstUser := coderdenttest.NewWithDatabase(t, aibridgeOpts(t))
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
now := dbtime.Now()
|
||||
i1EndedAt := now.Add(time.Minute)
|
||||
i1 := dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{
|
||||
InitiatorID: firstUser.UserID,
|
||||
StartedAt: now,
|
||||
}, &i1EndedAt)
|
||||
// Inflight interception (no ended_at) should not appear as a session.
|
||||
dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{
|
||||
InitiatorID: firstUser.UserID,
|
||||
StartedAt: now.Add(-time.Hour),
|
||||
}, nil)
|
||||
|
||||
//nolint:gocritic // Owner role is irrelevant; testing inflight filtering.
|
||||
res, err := client.AIBridgeListSessions(ctx, codersdk.AIBridgeListSessionsFilter{})
|
||||
require.NoError(t, err)
|
||||
require.EqualValues(t, 1, res.Count)
|
||||
require.Len(t, res.Sessions, 1)
|
||||
require.Equal(t, i1.ID.String(), res.Sessions[0].ID)
|
||||
})
|
||||
|
||||
t.Run("FilterErrors", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client, _ := coderdenttest.New(t, aibridgeOpts(t))
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
q string
|
||||
want []codersdk.ValidationError
|
||||
}{
|
||||
{
|
||||
name: "UnknownUsername",
|
||||
q: "initiator:unknown",
|
||||
want: []codersdk.ValidationError{
|
||||
{
|
||||
Field: "initiator",
|
||||
Detail: `Query param "initiator" has invalid value: user "unknown" either does not exist, or you are unauthorized to view them`,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "InvalidStartedAfter",
|
||||
q: "started_after:invalid",
|
||||
want: []codersdk.ValidationError{
|
||||
{
|
||||
Field: "started_after",
|
||||
Detail: `Query param "started_after" must be a valid date format (2006-01-02T15:04:05.999999999Z07:00): parsing time "INVALID" as "2006-01-02T15:04:05.999999999Z07:00": cannot parse "INVALID" as "2006"`,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "InvalidStartedBefore",
|
||||
q: "started_before:invalid",
|
||||
want: []codersdk.ValidationError{
|
||||
{
|
||||
Field: "started_before",
|
||||
Detail: `Query param "started_before" must be a valid date format (2006-01-02T15:04:05.999999999Z07:00): parsing time "INVALID" as "2006-01-02T15:04:05.999999999Z07:00": cannot parse "INVALID" as "2006"`,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "InvalidBeforeAfterRange",
|
||||
q: `started_after:"2025-01-01T00:00:00Z" started_before:"2024-01-01T00:00:00Z"`,
|
||||
want: []codersdk.ValidationError{
|
||||
{
|
||||
Field: "started_before",
|
||||
Detail: `Query param "started_before" has invalid value: "started_before" must be after "started_after" if set`,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
res, err := client.AIBridgeListSessions(ctx, codersdk.AIBridgeListSessionsFilter{
|
||||
FilterQuery: tc.q,
|
||||
})
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
require.Equal(t, tc.want, sdkErr.Validations)
|
||||
require.Empty(t, res.Sessions)
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("PaginationLimitValidation", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client, _ := coderdenttest.New(t, aibridgeOpts(t))
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
//nolint:gocritic // Owner role is irrelevant; testing pagination validation.
|
||||
res, err := client.AIBridgeListSessions(ctx, codersdk.AIBridgeListSessionsFilter{
|
||||
Pagination: codersdk.Pagination{
|
||||
Limit: 1001,
|
||||
},
|
||||
})
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
require.Contains(t, sdkErr.Message, "Invalid pagination limit value.")
|
||||
require.Empty(t, res.Sessions)
|
||||
})
|
||||
}
|
||||
|
||||
func TestAIBridgeRouting(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
Generated
+28
@@ -75,6 +75,12 @@ export interface AIBridgeListInterceptionsResponse {
|
||||
readonly results: readonly AIBridgeInterception[];
|
||||
}
|
||||
|
||||
// From codersdk/aibridge.go
|
||||
export interface AIBridgeListSessionsResponse {
|
||||
readonly count: number;
|
||||
readonly sessions: readonly AIBridgeSession[];
|
||||
}
|
||||
|
||||
// From codersdk/deployment.go
|
||||
export interface AIBridgeOpenAIConfig {
|
||||
readonly base_url: string;
|
||||
@@ -95,6 +101,28 @@ export interface AIBridgeProxyConfig {
|
||||
readonly allowed_private_cidrs: string;
|
||||
}
|
||||
|
||||
// From codersdk/aibridge.go
|
||||
export interface AIBridgeSession {
|
||||
readonly id: string;
|
||||
readonly initiator: MinimalUser;
|
||||
readonly providers: readonly string[];
|
||||
readonly models: readonly string[];
|
||||
readonly client: string | null;
|
||||
// empty interface{} type, falling back to unknown
|
||||
readonly metadata: Record<string, unknown>;
|
||||
readonly started_at: string;
|
||||
readonly ended_at?: string;
|
||||
readonly threads: number;
|
||||
readonly token_usage_summary: AIBridgeSessionTokenUsageSummary;
|
||||
readonly last_prompt?: string;
|
||||
}
|
||||
|
||||
// From codersdk/aibridge.go
|
||||
export interface AIBridgeSessionTokenUsageSummary {
|
||||
readonly input_tokens: number;
|
||||
readonly output_tokens: number;
|
||||
}
|
||||
|
||||
// From codersdk/aibridge.go
|
||||
export interface AIBridgeTokenUsage {
|
||||
readonly id: string;
|
||||
|
||||
Reference in New Issue
Block a user