diff --git a/coderd/apidoc/docs.go b/coderd/apidoc/docs.go index 4790ce5b20..7ef90f7e8a 100644 --- a/coderd/apidoc/docs.go +++ b/coderd/apidoc/docs.go @@ -214,6 +214,58 @@ const docTemplate = `{ ] } }, + "/aibridge/sessions/{session_id}": { + "get": { + "produces": [ + "application/json" + ], + "tags": [ + "AI Bridge" + ], + "summary": "Get AI Bridge session threads", + "operationId": "get-ai-bridge-session-threads", + "parameters": [ + { + "type": "string", + "description": "Session ID (client_session_id or interception UUID)", + "name": "session_id", + "in": "path", + "required": true + }, + { + "type": "string", + "description": "Thread pagination cursor (forward/older)", + "name": "after_id", + "in": "query" + }, + { + "type": "string", + "description": "Thread pagination cursor (backward/newer)", + "name": "before_id", + "in": "query" + }, + { + "type": "integer", + "description": "Number of threads per page (default 50)", + "name": "limit", + "in": "query" + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/codersdk.AIBridgeSessionThreadsResponse" + } + } + }, + "security": [ + { + "CoderSessionToken": [] + } + ] + } + }, "/appearance": { "get": { "produces": [ @@ -12675,6 +12727,29 @@ const docTemplate = `{ } } }, + "codersdk.AIBridgeAgenticAction": { + "type": "object", + "properties": { + "model": { + "type": "string" + }, + "thinking": { + "type": "array", + "items": { + "$ref": "#/definitions/codersdk.AIBridgeModelThought" + } + }, + "token_usage": { + "$ref": "#/definitions/codersdk.AIBridgeSessionThreadsTokenUsage" + }, + "tool_calls": { + "type": "array", + "items": { + "$ref": "#/definitions/codersdk.AIBridgeToolCall" + } + } + } + }, "codersdk.AIBridgeAnthropicConfig": { "type": "object", "properties": { @@ -12843,6 +12918,14 @@ const docTemplate = `{ } } }, + "codersdk.AIBridgeModelThought": { + "type": "object", + "properties": { + "text": { + "type": "string" + } + } + }, 
"codersdk.AIBridgeOpenAIConfig": { "type": "object", "properties": { @@ -12942,6 +13025,76 @@ const docTemplate = `{ } } }, + "codersdk.AIBridgeSessionThreadsResponse": { + "type": "object", + "properties": { + "client": { + "type": "string" + }, + "ended_at": { + "type": "string", + "format": "date-time" + }, + "id": { + "type": "string" + }, + "initiator": { + "$ref": "#/definitions/codersdk.MinimalUser" + }, + "metadata": { + "type": "object", + "additionalProperties": {} + }, + "models": { + "type": "array", + "items": { + "type": "string" + } + }, + "page_ended_at": { + "type": "string", + "format": "date-time" + }, + "page_started_at": { + "type": "string", + "format": "date-time" + }, + "providers": { + "type": "array", + "items": { + "type": "string" + } + }, + "started_at": { + "type": "string", + "format": "date-time" + }, + "threads": { + "type": "array", + "items": { + "$ref": "#/definitions/codersdk.AIBridgeThread" + } + }, + "token_usage_summary": { + "$ref": "#/definitions/codersdk.AIBridgeSessionThreadsTokenUsage" + } + } + }, + "codersdk.AIBridgeSessionThreadsTokenUsage": { + "type": "object", + "properties": { + "input_tokens": { + "type": "integer" + }, + "metadata": { + "type": "object", + "additionalProperties": {} + }, + "output_tokens": { + "type": "integer" + } + } + }, "codersdk.AIBridgeSessionTokenUsageSummary": { "type": "object", "properties": { @@ -12953,6 +13106,41 @@ const docTemplate = `{ } } }, + "codersdk.AIBridgeThread": { + "type": "object", + "properties": { + "agentic_actions": { + "type": "array", + "items": { + "$ref": "#/definitions/codersdk.AIBridgeAgenticAction" + } + }, + "ended_at": { + "type": "string", + "format": "date-time" + }, + "id": { + "type": "string", + "format": "uuid" + }, + "model": { + "type": "string" + }, + "prompt": { + "type": "string" + }, + "provider": { + "type": "string" + }, + "started_at": { + "type": "string", + "format": "date-time" + }, + "token_usage": { + "$ref": 
"#/definitions/codersdk.AIBridgeSessionThreadsTokenUsage" + } + } + }, "codersdk.AIBridgeTokenUsage": { "type": "object", "properties": { @@ -12983,6 +13171,42 @@ const docTemplate = `{ } } }, + "codersdk.AIBridgeToolCall": { + "type": "object", + "properties": { + "created_at": { + "type": "string", + "format": "date-time" + }, + "id": { + "type": "string", + "format": "uuid" + }, + "injected": { + "type": "boolean" + }, + "input": { + "type": "string" + }, + "interception_id": { + "type": "string", + "format": "uuid" + }, + "metadata": { + "type": "object", + "additionalProperties": {} + }, + "provider_response_id": { + "type": "string" + }, + "server_url": { + "type": "string" + }, + "tool": { + "type": "string" + } + } + }, "codersdk.AIBridgeToolUsage": { "type": "object", "properties": { diff --git a/coderd/apidoc/swagger.json b/coderd/apidoc/swagger.json index 9d2cd58d9f..0d220380ed 100644 --- a/coderd/apidoc/swagger.json +++ b/coderd/apidoc/swagger.json @@ -183,6 +183,54 @@ ] } }, + "/aibridge/sessions/{session_id}": { + "get": { + "produces": ["application/json"], + "tags": ["AI Bridge"], + "summary": "Get AI Bridge session threads", + "operationId": "get-ai-bridge-session-threads", + "parameters": [ + { + "type": "string", + "description": "Session ID (client_session_id or interception UUID)", + "name": "session_id", + "in": "path", + "required": true + }, + { + "type": "string", + "description": "Thread pagination cursor (forward/older)", + "name": "after_id", + "in": "query" + }, + { + "type": "string", + "description": "Thread pagination cursor (backward/newer)", + "name": "before_id", + "in": "query" + }, + { + "type": "integer", + "description": "Number of threads per page (default 50)", + "name": "limit", + "in": "query" + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/codersdk.AIBridgeSessionThreadsResponse" + } + } + }, + "security": [ + { + "CoderSessionToken": [] + } + ] + } + }, "/appearance": { 
"get": { "produces": ["application/json"], @@ -11261,6 +11309,29 @@ } } }, + "codersdk.AIBridgeAgenticAction": { + "type": "object", + "properties": { + "model": { + "type": "string" + }, + "thinking": { + "type": "array", + "items": { + "$ref": "#/definitions/codersdk.AIBridgeModelThought" + } + }, + "token_usage": { + "$ref": "#/definitions/codersdk.AIBridgeSessionThreadsTokenUsage" + }, + "tool_calls": { + "type": "array", + "items": { + "$ref": "#/definitions/codersdk.AIBridgeToolCall" + } + } + } + }, "codersdk.AIBridgeAnthropicConfig": { "type": "object", "properties": { @@ -11429,6 +11500,14 @@ } } }, + "codersdk.AIBridgeModelThought": { + "type": "object", + "properties": { + "text": { + "type": "string" + } + } + }, "codersdk.AIBridgeOpenAIConfig": { "type": "object", "properties": { @@ -11528,6 +11607,76 @@ } } }, + "codersdk.AIBridgeSessionThreadsResponse": { + "type": "object", + "properties": { + "client": { + "type": "string" + }, + "ended_at": { + "type": "string", + "format": "date-time" + }, + "id": { + "type": "string" + }, + "initiator": { + "$ref": "#/definitions/codersdk.MinimalUser" + }, + "metadata": { + "type": "object", + "additionalProperties": {} + }, + "models": { + "type": "array", + "items": { + "type": "string" + } + }, + "page_ended_at": { + "type": "string", + "format": "date-time" + }, + "page_started_at": { + "type": "string", + "format": "date-time" + }, + "providers": { + "type": "array", + "items": { + "type": "string" + } + }, + "started_at": { + "type": "string", + "format": "date-time" + }, + "threads": { + "type": "array", + "items": { + "$ref": "#/definitions/codersdk.AIBridgeThread" + } + }, + "token_usage_summary": { + "$ref": "#/definitions/codersdk.AIBridgeSessionThreadsTokenUsage" + } + } + }, + "codersdk.AIBridgeSessionThreadsTokenUsage": { + "type": "object", + "properties": { + "input_tokens": { + "type": "integer" + }, + "metadata": { + "type": "object", + "additionalProperties": {} + }, + "output_tokens": { + 
"type": "integer" + } + } + }, "codersdk.AIBridgeSessionTokenUsageSummary": { "type": "object", "properties": { @@ -11539,6 +11688,41 @@ } } }, + "codersdk.AIBridgeThread": { + "type": "object", + "properties": { + "agentic_actions": { + "type": "array", + "items": { + "$ref": "#/definitions/codersdk.AIBridgeAgenticAction" + } + }, + "ended_at": { + "type": "string", + "format": "date-time" + }, + "id": { + "type": "string", + "format": "uuid" + }, + "model": { + "type": "string" + }, + "prompt": { + "type": "string" + }, + "provider": { + "type": "string" + }, + "started_at": { + "type": "string", + "format": "date-time" + }, + "token_usage": { + "$ref": "#/definitions/codersdk.AIBridgeSessionThreadsTokenUsage" + } + } + }, "codersdk.AIBridgeTokenUsage": { "type": "object", "properties": { @@ -11569,6 +11753,42 @@ } } }, + "codersdk.AIBridgeToolCall": { + "type": "object", + "properties": { + "created_at": { + "type": "string", + "format": "date-time" + }, + "id": { + "type": "string", + "format": "uuid" + }, + "injected": { + "type": "boolean" + }, + "input": { + "type": "string" + }, + "interception_id": { + "type": "string", + "format": "uuid" + }, + "metadata": { + "type": "object", + "additionalProperties": {} + }, + "provider_response_id": { + "type": "string" + }, + "server_url": { + "type": "string" + }, + "tool": { + "type": "string" + } + } + }, "codersdk.AIBridgeToolUsage": { "type": "object", "properties": { diff --git a/coderd/database/db2sdk/db2sdk.go b/coderd/database/db2sdk/db2sdk.go index d79d35b88b..8f87e219d2 100644 --- a/coderd/database/db2sdk/db2sdk.go +++ b/coderd/database/db2sdk/db2sdk.go @@ -1097,6 +1097,287 @@ func AIBridgeToolUsage(usage database.AIBridgeToolUsage) codersdk.AIBridgeToolUs } } +// AIBridgeSessionThreads converts session metadata and thread interceptions +// into the threads response. 
It groups interceptions into threads, builds +// agentic actions from tool usages and model thoughts, and aggregates +// token usage with metadata. +func AIBridgeSessionThreads( + session database.ListAIBridgeSessionsRow, + interceptions []database.ListAIBridgeSessionThreadsRow, + tokenUsages []database.AIBridgeTokenUsage, + toolUsages []database.AIBridgeToolUsage, + userPrompts []database.AIBridgeUserPrompt, + modelThoughts []database.AIBridgeModelThought, +) codersdk.AIBridgeSessionThreadsResponse { + // Index subresources by interception ID. + tokensByInterception := make(map[uuid.UUID][]database.AIBridgeTokenUsage, len(interceptions)) + for _, tu := range tokenUsages { + tokensByInterception[tu.InterceptionID] = append(tokensByInterception[tu.InterceptionID], tu) + } + toolsByInterception := make(map[uuid.UUID][]database.AIBridgeToolUsage, len(interceptions)) + for _, tu := range toolUsages { + toolsByInterception[tu.InterceptionID] = append(toolsByInterception[tu.InterceptionID], tu) + } + promptsByInterception := make(map[uuid.UUID][]database.AIBridgeUserPrompt, len(interceptions)) + for _, up := range userPrompts { + promptsByInterception[up.InterceptionID] = append(promptsByInterception[up.InterceptionID], up) + } + thoughtsByInterception := make(map[uuid.UUID][]database.AIBridgeModelThought, len(interceptions)) + for _, mt := range modelThoughts { + thoughtsByInterception[mt.InterceptionID] = append(thoughtsByInterception[mt.InterceptionID], mt) + } + + // Group interceptions by thread_id, preserving the order returned by the + // SQL query. 
+ interceptionsByThread := make(map[uuid.UUID][]database.AIBridgeInterception, len(interceptions)) + var threadIDs []uuid.UUID + for _, row := range interceptions { + if _, ok := interceptionsByThread[row.ThreadID]; !ok { + threadIDs = append(threadIDs, row.ThreadID) + } + interceptionsByThread[row.ThreadID] = append(interceptionsByThread[row.ThreadID], row.AIBridgeInterception) + } + + // Build threads and track page time bounds. + threads := make([]codersdk.AIBridgeThread, 0, len(threadIDs)) + var pageStartedAt, pageEndedAt *time.Time + for _, threadID := range threadIDs { + intcs := interceptionsByThread[threadID] + thread := buildAIBridgeThread(threadID, intcs, tokensByInterception, toolsByInterception, promptsByInterception, thoughtsByInterception) + for _, intc := range intcs { + if pageStartedAt == nil || intc.StartedAt.Before(*pageStartedAt) { + t := intc.StartedAt + pageStartedAt = &t + } + if intc.EndedAt.Valid { + if pageEndedAt == nil || intc.EndedAt.Time.After(*pageEndedAt) { + t := intc.EndedAt.Time + pageEndedAt = &t + } + } + } + threads = append(threads, thread) + } + + // Aggregate session-level token usage metadata from all token + // usages in the session (not just the page). 
+ sessionTokenMeta := aggregateTokenMetadata(tokenUsages) + + resp := codersdk.AIBridgeSessionThreadsResponse{ + ID: session.SessionID, + Initiator: MinimalUserFromVisibleUser(database.VisibleUser{ + ID: session.UserID, + Username: session.UserUsername, + Name: session.UserName, + AvatarURL: session.UserAvatarUrl, + }), + Providers: session.Providers, + Models: session.Models, + Metadata: jsonOrEmptyMap(pqtype.NullRawMessage{RawMessage: session.Metadata, Valid: len(session.Metadata) > 0}), + StartedAt: session.StartedAt, + PageStartedAt: pageStartedAt, + PageEndedAt: pageEndedAt, + TokenUsageSummary: codersdk.AIBridgeSessionThreadsTokenUsage{ + InputTokens: session.InputTokens, + OutputTokens: session.OutputTokens, + Metadata: sessionTokenMeta, + }, + Threads: threads, + } + if resp.Providers == nil { + resp.Providers = []string{} + } + if resp.Models == nil { + resp.Models = []string{} + } + if session.Client != "" { + resp.Client = &session.Client + } + if !session.EndedAt.IsZero() { + resp.EndedAt = &session.EndedAt + } + return resp +} + +func buildAIBridgeThread( + threadID uuid.UUID, + interceptions []database.AIBridgeInterception, + tokensByInterception map[uuid.UUID][]database.AIBridgeTokenUsage, + toolsByInterception map[uuid.UUID][]database.AIBridgeToolUsage, + promptsByInterception map[uuid.UUID][]database.AIBridgeUserPrompt, + thoughtsByInterception map[uuid.UUID][]database.AIBridgeModelThought, +) codersdk.AIBridgeThread { + // Find the root interception (where id == threadID) to get the + // thread prompt and model. + var rootIntc *database.AIBridgeInterception + for i := range interceptions { + if interceptions[i].ID == threadID { + rootIntc = &interceptions[i] + break + } + } + // Fallback to first interception if root not found. 
+ if rootIntc == nil && len(interceptions) > 0 { + rootIntc = &interceptions[0] + } + + thread := codersdk.AIBridgeThread{ + ID: threadID, + } + if rootIntc != nil { + thread.Model = rootIntc.Model + thread.Provider = rootIntc.Provider + // Get first user prompt from root interception. + // A thread can only have one prompt, by definition, since we currently + // only store the last prompt observed in an interception. + if prompts := promptsByInterception[rootIntc.ID]; len(prompts) > 0 { + thread.Prompt = &prompts[0].Prompt + } + } + + // Compute thread time bounds from interceptions. + for _, intc := range interceptions { + if thread.StartedAt.IsZero() || intc.StartedAt.Before(thread.StartedAt) { + thread.StartedAt = intc.StartedAt + } + if intc.EndedAt.Valid { + if thread.EndedAt == nil || intc.EndedAt.Time.After(*thread.EndedAt) { + t := intc.EndedAt.Time + thread.EndedAt = &t + } + } + } + + // Build agentic actions grouped by interception. Each interception that + // has tool calls produces one action with all its tool calls, thinking + // blocks, and token usage. + var actions []codersdk.AIBridgeAgenticAction + for _, intc := range interceptions { + tools := toolsByInterception[intc.ID] + if len(tools) == 0 { + continue + } + + // Thinking blocks for this interception. + thoughts := thoughtsByInterception[intc.ID] + thinking := make([]codersdk.AIBridgeModelThought, 0, len(thoughts)) + for _, mt := range thoughts { + thinking = append(thinking, codersdk.AIBridgeModelThought{ + Text: mt.Content, + }) + } + + // Token usage for the interception. + actionTokenUsage := aggregateTokenUsage(tokensByInterception[intc.ID]) + + // Build tool call list. 
+ toolCalls := make([]codersdk.AIBridgeToolCall, 0, len(tools)) + for _, tu := range tools { + toolCalls = append(toolCalls, codersdk.AIBridgeToolCall{ + ID: tu.ID, + InterceptionID: tu.InterceptionID, + ProviderResponseID: tu.ProviderResponseID, + ServerURL: tu.ServerUrl.String, + Tool: tu.Tool, + Injected: tu.Injected, + Input: tu.Input, + Metadata: jsonOrEmptyMap(tu.Metadata), + CreatedAt: tu.CreatedAt, + }) + } + + actions = append(actions, codersdk.AIBridgeAgenticAction{ + Model: intc.Model, + TokenUsage: actionTokenUsage, + Thinking: thinking, + ToolCalls: toolCalls, + }) + } + + if actions == nil { + // Make an empty slice so we don't serialize `null`. + actions = make([]codersdk.AIBridgeAgenticAction, 0) + } + + thread.AgenticActions = actions + + // Aggregate thread-level token usage. + var threadTokens []database.AIBridgeTokenUsage + for _, intc := range interceptions { + threadTokens = append(threadTokens, tokensByInterception[intc.ID]...) + } + thread.TokenUsage = aggregateTokenUsage(threadTokens) + + return thread +} + +// aggregateTokenUsage sums token usage rows and aggregates metadata. +func aggregateTokenUsage(tokens []database.AIBridgeTokenUsage) codersdk.AIBridgeSessionThreadsTokenUsage { + var inputTokens, outputTokens int64 + for _, tu := range tokens { + inputTokens += tu.InputTokens + outputTokens += tu.OutputTokens + // TODO: once https://github.com/coder/aibridge/issues/150 lands we + // should aggregate the other token types. + } + return codersdk.AIBridgeSessionThreadsTokenUsage{ + InputTokens: inputTokens, + OutputTokens: outputTokens, + Metadata: aggregateTokenMetadata(tokens), + } +} + +// aggregateTokenMetadata sums all numeric values from the metadata +// JSONB across the given token usage rows by key. Nested objects are +// flattened using dot-notation (e.g. {"cache": {"read_tokens": 10}} +// becomes "cache.read_tokens"). Non-numeric leaves (strings, +// booleans, arrays, nulls) are silently skipped. 
+func aggregateTokenMetadata(tokens []database.AIBridgeTokenUsage) map[string]any { + sums := make(map[string]int64) + for _, tu := range tokens { + if !tu.Metadata.Valid || len(tu.Metadata.RawMessage) == 0 { + continue + } + var m map[string]json.RawMessage + if err := json.Unmarshal(tu.Metadata.RawMessage, &m); err != nil { + continue + } + flattenAndSum(sums, "", m) + } + result := make(map[string]any, len(sums)) + for k, v := range sums { + result[k] = v + } + return result +} + +// flattenAndSum recursively walks a JSON object and sums all numeric +// leaf values into sums, using dot-separated keys for nested objects. +func flattenAndSum(sums map[string]int64, prefix string, m map[string]json.RawMessage) { + for k, raw := range m { + key := k + if prefix != "" { + key = prefix + "." + k + } + + // Try as a number first. + var n json.Number + if err := json.Unmarshal(raw, &n); err == nil { + if v, err := n.Int64(); err == nil { + sums[key] += v + } + continue + } + + // Try as a nested object. + var nested map[string]json.RawMessage + if err := json.Unmarshal(raw, &nested); err == nil { + flattenAndSum(sums, key, nested) + } + // Arrays, strings, booleans, nulls are skipped. 
+ } +} + func InvalidatedPresets(invalidatedPresets []database.UpdatePresetsLastInvalidatedAtRow) []codersdk.InvalidatedPreset { var presets []codersdk.InvalidatedPreset for _, p := range invalidatedPresets { diff --git a/coderd/database/db2sdk/db2sdk_internal_test.go b/coderd/database/db2sdk/db2sdk_internal_test.go new file mode 100644 index 0000000000..2222238c90 --- /dev/null +++ b/coderd/database/db2sdk/db2sdk_internal_test.go @@ -0,0 +1,308 @@ +package db2sdk + +import ( + "encoding/json" + "testing" + + "github.com/google/uuid" + "github.com/sqlc-dev/pqtype" + "github.com/stretchr/testify/require" + + "github.com/coder/coder/v2/coderd/database" +) + +func TestAggregateTokenMetadata(t *testing.T) { + t.Parallel() + + t.Run("empty_input", func(t *testing.T) { + t.Parallel() + result := aggregateTokenMetadata(nil) + require.Empty(t, result) + }) + + t.Run("sums_across_rows", func(t *testing.T) { + t.Parallel() + tokens := []database.AIBridgeTokenUsage{ + { + ID: uuid.New(), + Metadata: pqtype.NullRawMessage{ + RawMessage: json.RawMessage(`{"cache_read_tokens":100,"reasoning_tokens":50}`), + Valid: true, + }, + }, + { + ID: uuid.New(), + Metadata: pqtype.NullRawMessage{ + RawMessage: json.RawMessage(`{"cache_read_tokens":200,"reasoning_tokens":75}`), + Valid: true, + }, + }, + } + + result := aggregateTokenMetadata(tokens) + require.Equal(t, int64(300), result["cache_read_tokens"]) + require.Equal(t, int64(125), result["reasoning_tokens"]) + require.Len(t, result, 2) + }) + + t.Run("skips_null_and_invalid_metadata", func(t *testing.T) { + t.Parallel() + tokens := []database.AIBridgeTokenUsage{ + { + ID: uuid.New(), + Metadata: pqtype.NullRawMessage{Valid: false}, + }, + { + ID: uuid.New(), + Metadata: pqtype.NullRawMessage{ + RawMessage: nil, + Valid: true, + }, + }, + { + ID: uuid.New(), + Metadata: pqtype.NullRawMessage{ + RawMessage: json.RawMessage(`{"tokens":42}`), + Valid: true, + }, + }, + } + + result := aggregateTokenMetadata(tokens) + require.Equal(t, 
int64(42), result["tokens"]) + require.Len(t, result, 1) + }) + + t.Run("skips_non_integer_values", func(t *testing.T) { + t.Parallel() + tokens := []database.AIBridgeTokenUsage{ + { + ID: uuid.New(), + Metadata: pqtype.NullRawMessage{ + // Float values fail json.Number.Int64(), so they + // are silently dropped. + RawMessage: json.RawMessage(`{"good":10,"fractional":1.5}`), + Valid: true, + }, + }, + } + + result := aggregateTokenMetadata(tokens) + require.Equal(t, int64(10), result["good"]) + _, hasFractional := result["fractional"] + require.False(t, hasFractional) + }) + + t.Run("skips_malformed_json", func(t *testing.T) { + t.Parallel() + tokens := []database.AIBridgeTokenUsage{ + { + ID: uuid.New(), + Metadata: pqtype.NullRawMessage{ + RawMessage: json.RawMessage(`not json`), + Valid: true, + }, + }, + { + ID: uuid.New(), + Metadata: pqtype.NullRawMessage{ + RawMessage: json.RawMessage(`{"tokens":5}`), + Valid: true, + }, + }, + } + + result := aggregateTokenMetadata(tokens) + // The malformed row is skipped, the valid one is counted. 
+ require.Equal(t, int64(5), result["tokens"]) + require.Len(t, result, 1) + }) + + t.Run("flattens_nested_objects", func(t *testing.T) { + t.Parallel() + tokens := []database.AIBridgeTokenUsage{ + { + ID: uuid.New(), + Metadata: pqtype.NullRawMessage{ + RawMessage: json.RawMessage(`{ + "cache_read_tokens": 100, + "cache": {"creation_tokens": 40, "read_tokens": 60}, + "reasoning_tokens": 50, + "tags": ["a", "b"] + }`), + Valid: true, + }, + }, + { + ID: uuid.New(), + Metadata: pqtype.NullRawMessage{ + RawMessage: json.RawMessage(`{ + "cache_read_tokens": 200, + "cache": {"creation_tokens": 10} + }`), + Valid: true, + }, + }, + } + + result := aggregateTokenMetadata(tokens) + require.Equal(t, int64(300), result["cache_read_tokens"]) + require.Equal(t, int64(50), result["reasoning_tokens"]) + require.Equal(t, int64(50), result["cache.creation_tokens"]) + require.Equal(t, int64(60), result["cache.read_tokens"]) + // Arrays are skipped. + _, hasTags := result["tags"] + require.False(t, hasTags) + require.Len(t, result, 4) + }) + + t.Run("flattens_deeply_nested_objects", func(t *testing.T) { + t.Parallel() + tokens := []database.AIBridgeTokenUsage{ + { + ID: uuid.New(), + Metadata: pqtype.NullRawMessage{ + RawMessage: json.RawMessage(`{ + "provider": { + "anthropic": {"cache_creation_tokens": 100, "cache_read_tokens": 200}, + "openai": {"reasoning_tokens": 50} + }, + "total": 500 + }`), + Valid: true, + }, + }, + } + + result := aggregateTokenMetadata(tokens) + require.Equal(t, int64(100), result["provider.anthropic.cache_creation_tokens"]) + require.Equal(t, int64(200), result["provider.anthropic.cache_read_tokens"]) + require.Equal(t, int64(50), result["provider.openai.reasoning_tokens"]) + require.Equal(t, int64(500), result["total"]) + require.Len(t, result, 4) + }) + + // Real-world provider metadata shapes from + // https://github.com/coder/aibridge/issues/150. 
+ t.Run("aggregates_real_provider_metadata", func(t *testing.T) { + t.Parallel() + tokens := []database.AIBridgeTokenUsage{ + { + // Anthropic-style: cache fields are top-level. + ID: uuid.New(), + Metadata: pqtype.NullRawMessage{ + RawMessage: json.RawMessage(`{ + "cache_creation_input_tokens": 0, + "cache_read_input_tokens": 23490 + }`), + Valid: true, + }, + }, + { + // OpenAI-style: cache fields are nested inside + // input_tokens_details. + ID: uuid.New(), + Metadata: pqtype.NullRawMessage{ + RawMessage: json.RawMessage(`{ + "input_tokens_details": {"cached_tokens": 11904} + }`), + Valid: true, + }, + }, + { + // Second Anthropic row to verify summing. + ID: uuid.New(), + Metadata: pqtype.NullRawMessage{ + RawMessage: json.RawMessage(`{ + "cache_creation_input_tokens": 500, + "cache_read_input_tokens": 10000 + }`), + Valid: true, + }, + }, + } + + result := aggregateTokenMetadata(tokens) + // Anthropic fields are summed across two rows. + require.Equal(t, int64(500), result["cache_creation_input_tokens"]) + require.Equal(t, int64(33490), result["cache_read_input_tokens"]) + // OpenAI nested field is flattened with dot notation. 
+ require.Equal(t, int64(11904), result["input_tokens_details.cached_tokens"]) + require.Len(t, result, 3) + }) + + t.Run("skips_string_boolean_null_values", func(t *testing.T) { + t.Parallel() + tokens := []database.AIBridgeTokenUsage{ + { + ID: uuid.New(), + Metadata: pqtype.NullRawMessage{ + RawMessage: json.RawMessage(`{"tokens":10,"name":"test","enabled":true,"nothing":null}`), + Valid: true, + }, + }, + } + + result := aggregateTokenMetadata(tokens) + require.Equal(t, int64(10), result["tokens"]) + require.Len(t, result, 1) + }) +} + +func TestAggregateTokenUsage(t *testing.T) { + t.Parallel() + + t.Run("empty_input", func(t *testing.T) { + t.Parallel() + result := aggregateTokenUsage(nil) + require.Equal(t, int64(0), result.InputTokens) + require.Equal(t, int64(0), result.OutputTokens) + require.Empty(t, result.Metadata) + }) + + t.Run("sums_tokens_and_metadata", func(t *testing.T) { + t.Parallel() + tokens := []database.AIBridgeTokenUsage{ + { + ID: uuid.New(), + InputTokens: 100, + OutputTokens: 50, + Metadata: pqtype.NullRawMessage{ + RawMessage: json.RawMessage(`{"reasoning_tokens":20}`), + Valid: true, + }, + }, + { + ID: uuid.New(), + InputTokens: 200, + OutputTokens: 75, + Metadata: pqtype.NullRawMessage{ + RawMessage: json.RawMessage(`{"reasoning_tokens":30}`), + Valid: true, + }, + }, + } + + result := aggregateTokenUsage(tokens) + require.Equal(t, int64(300), result.InputTokens) + require.Equal(t, int64(125), result.OutputTokens) + require.Equal(t, int64(50), result.Metadata["reasoning_tokens"]) + }) + + t.Run("handles_rows_without_metadata", func(t *testing.T) { + t.Parallel() + tokens := []database.AIBridgeTokenUsage{ + { + ID: uuid.New(), + InputTokens: 500, + OutputTokens: 200, + Metadata: pqtype.NullRawMessage{Valid: false}, + }, + } + + result := aggregateTokenUsage(tokens) + require.Equal(t, int64(500), result.InputTokens) + require.Equal(t, int64(200), result.OutputTokens) + require.Empty(t, result.Metadata) + }) +} diff --git 
a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index 17b6bad1d6..51e32e0ba4 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -5363,6 +5363,13 @@ func (q *querier) ListAIBridgeInterceptionsTelemetrySummaries(ctx context.Contex return q.db.ListAIBridgeInterceptionsTelemetrySummaries(ctx, arg) } +func (q *querier) ListAIBridgeModelThoughtsByInterceptionIDs(ctx context.Context, interceptionIDs []uuid.UUID) ([]database.AIBridgeModelThought, error) { + if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceAibridgeInterception); err != nil { + return nil, err + } + return q.db.ListAIBridgeModelThoughtsByInterceptionIDs(ctx, interceptionIDs) +} + func (q *querier) ListAIBridgeModels(ctx context.Context, arg database.ListAIBridgeModelsParams) ([]string, error) { prep, err := prepareSQLFilter(ctx, q.auth, policy.ActionRead, rbac.ResourceAibridgeInterception.Type) if err != nil { @@ -5371,6 +5378,14 @@ func (q *querier) ListAIBridgeModels(ctx context.Context, arg database.ListAIBri return q.db.ListAuthorizedAIBridgeModels(ctx, arg, prep) } +func (q *querier) ListAIBridgeSessionThreads(ctx context.Context, arg database.ListAIBridgeSessionThreadsParams) ([]database.ListAIBridgeSessionThreadsRow, error) { + prep, err := prepareSQLFilter(ctx, q.auth, policy.ActionRead, rbac.ResourceAibridgeInterception.Type) + if err != nil { + return nil, xerrors.Errorf("(dev error) prepare sql filter: %w", err) + } + return q.db.ListAuthorizedAIBridgeSessionThreads(ctx, arg, prep) +} + func (q *querier) ListAIBridgeSessions(ctx context.Context, arg database.ListAIBridgeSessionsParams) ([]database.ListAIBridgeSessionsRow, error) { prep, err := prepareSQLFilter(ctx, q.auth, policy.ActionRead, rbac.ResourceAibridgeInterception.Type) if err != nil { @@ -7221,6 +7236,10 @@ func (q *querier) CountAuthorizedAIBridgeSessions(ctx context.Context, arg datab return q.db.CountAuthorizedAIBridgeSessions(ctx, arg, prepared) } 
+func (q *querier) ListAuthorizedAIBridgeSessionThreads(ctx context.Context, arg database.ListAIBridgeSessionThreadsParams, prepared rbac.PreparedAuthorized) ([]database.ListAIBridgeSessionThreadsRow, error) { + return q.db.ListAuthorizedAIBridgeSessionThreads(ctx, arg, prepared) +} + func (q *querier) GetAuthorizedChats(ctx context.Context, arg database.GetChatsParams, _ rbac.PreparedAuthorized) ([]database.Chat, error) { return q.GetChats(ctx, arg) } diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go index befc38a551..76dd630590 100644 --- a/coderd/database/dbauthz/dbauthz_test.go +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -5528,6 +5528,26 @@ func (s *MethodTestSuite) TestAIBridge() { check.Args(ids).Asserts(rbac.ResourceAibridgeInterception, policy.ActionRead).Returns([]database.AIBridgeToolUsage{}) })) + s.Run("ListAIBridgeModelThoughtsByInterceptionIDs", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + ids := []uuid.UUID{{1}} + db.EXPECT().ListAIBridgeModelThoughtsByInterceptionIDs(gomock.Any(), ids).Return([]database.AIBridgeModelThought{}, nil).AnyTimes() + check.Args(ids).Asserts(rbac.ResourceAibridgeInterception, policy.ActionRead).Returns([]database.AIBridgeModelThought{}) + })) + + s.Run("ListAIBridgeSessionThreads", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + params := database.ListAIBridgeSessionThreadsParams{} + db.EXPECT().ListAuthorizedAIBridgeSessionThreads(gomock.Any(), params, gomock.Any()).Return([]database.ListAIBridgeSessionThreadsRow{}, nil).AnyTimes() + // No asserts here because SQLFilter. 
+ check.Args(params).Asserts() + })) + + s.Run("ListAuthorizedAIBridgeSessionThreads", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + params := database.ListAIBridgeSessionThreadsParams{} + db.EXPECT().ListAuthorizedAIBridgeSessionThreads(gomock.Any(), params, gomock.Any()).Return([]database.ListAIBridgeSessionThreadsRow{}, nil).AnyTimes() + // No asserts here because SQLFilter. + check.Args(params, emptyPreparedAuthorized{}).Asserts() + })) + s.Run("UpdateAIBridgeInterceptionEnded", s.Mocked(func(db *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { intcID := uuid.UUID{1} params := database.UpdateAIBridgeInterceptionEndedParams{ID: intcID} diff --git a/coderd/database/dbgen/dbgen.go b/coderd/database/dbgen/dbgen.go index ac30be56c5..5881f6e157 100644 --- a/coderd/database/dbgen/dbgen.go +++ b/coderd/database/dbgen/dbgen.go @@ -1663,6 +1663,17 @@ func AIBridgeToolUsage(t testing.TB, db database.Store, seed database.InsertAIBr return toolUsage } +func AIBridgeModelThought(t testing.TB, db database.Store, seed database.InsertAIBridgeModelThoughtParams) database.AIBridgeModelThought { + thought, err := db.InsertAIBridgeModelThought(genCtx, database.InsertAIBridgeModelThoughtParams{ + InterceptionID: takeFirst(seed.InterceptionID, uuid.New()), + Content: takeFirst(seed.Content, ""), + Metadata: takeFirstSlice(seed.Metadata, json.RawMessage("{}")), + CreatedAt: takeFirst(seed.CreatedAt, dbtime.Now()), + }) + require.NoError(t, err, "insert aibridge model thought") + return thought +} + func Task(t testing.TB, db database.Store, orig database.TaskTable) database.Task { t.Helper() diff --git a/coderd/database/dbmetrics/querymetrics.go b/coderd/database/dbmetrics/querymetrics.go index 25c2f936c8..0373ea802a 100644 --- a/coderd/database/dbmetrics/querymetrics.go +++ b/coderd/database/dbmetrics/querymetrics.go @@ -3760,6 +3760,14 @@ func (m queryMetricsStore) ListAIBridgeInterceptionsTelemetrySummaries(ctx conte return r0, r1 } +func 
(m queryMetricsStore) ListAIBridgeModelThoughtsByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]database.AIBridgeModelThought, error) { + start := time.Now() + r0, r1 := m.s.ListAIBridgeModelThoughtsByInterceptionIDs(ctx, interceptionIds) + m.queryLatencies.WithLabelValues("ListAIBridgeModelThoughtsByInterceptionIDs").Observe(time.Since(start).Seconds()) + m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "ListAIBridgeModelThoughtsByInterceptionIDs").Inc() + return r0, r1 +} + func (m queryMetricsStore) ListAIBridgeModels(ctx context.Context, arg database.ListAIBridgeModelsParams) ([]string, error) { start := time.Now() r0, r1 := m.s.ListAIBridgeModels(ctx, arg) @@ -3768,6 +3776,14 @@ func (m queryMetricsStore) ListAIBridgeModels(ctx context.Context, arg database. return r0, r1 } +func (m queryMetricsStore) ListAIBridgeSessionThreads(ctx context.Context, arg database.ListAIBridgeSessionThreadsParams) ([]database.ListAIBridgeSessionThreadsRow, error) { + start := time.Now() + r0, r1 := m.s.ListAIBridgeSessionThreads(ctx, arg) + m.queryLatencies.WithLabelValues("ListAIBridgeSessionThreads").Observe(time.Since(start).Seconds()) + m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "ListAIBridgeSessionThreads").Inc() + return r0, r1 +} + func (m queryMetricsStore) ListAIBridgeSessions(ctx context.Context, arg database.ListAIBridgeSessionsParams) ([]database.ListAIBridgeSessionsRow, error) { start := time.Now() r0, r1 := m.s.ListAIBridgeSessions(ctx, arg) @@ -5240,6 +5256,14 @@ func (m queryMetricsStore) CountAuthorizedAIBridgeSessions(ctx context.Context, return r0, r1 } +func (m queryMetricsStore) ListAuthorizedAIBridgeSessionThreads(ctx context.Context, arg database.ListAIBridgeSessionThreadsParams, prepared rbac.PreparedAuthorized) ([]database.ListAIBridgeSessionThreadsRow, error) { + start := time.Now() + r0, r1 := m.s.ListAuthorizedAIBridgeSessionThreads(ctx, arg, 
prepared) + m.queryLatencies.WithLabelValues("ListAuthorizedAIBridgeSessionThreads").Observe(time.Since(start).Seconds()) + m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "ListAuthorizedAIBridgeSessionThreads").Inc() + return r0, r1 +} + func (m queryMetricsStore) GetAuthorizedChats(ctx context.Context, arg database.GetChatsParams, prepared rbac.PreparedAuthorized) ([]database.Chat, error) { start := time.Now() r0, r1 := m.s.GetAuthorizedChats(ctx, arg, prepared) diff --git a/coderd/database/dbmock/dbmock.go b/coderd/database/dbmock/dbmock.go index 3a99e3d29e..4b51306e8f 100644 --- a/coderd/database/dbmock/dbmock.go +++ b/coderd/database/dbmock/dbmock.go @@ -7037,6 +7037,21 @@ func (mr *MockStoreMockRecorder) ListAIBridgeInterceptionsTelemetrySummaries(ctx return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListAIBridgeInterceptionsTelemetrySummaries", reflect.TypeOf((*MockStore)(nil).ListAIBridgeInterceptionsTelemetrySummaries), ctx, arg) } +// ListAIBridgeModelThoughtsByInterceptionIDs mocks base method. +func (m *MockStore) ListAIBridgeModelThoughtsByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]database.AIBridgeModelThought, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ListAIBridgeModelThoughtsByInterceptionIDs", ctx, interceptionIds) + ret0, _ := ret[0].([]database.AIBridgeModelThought) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ListAIBridgeModelThoughtsByInterceptionIDs indicates an expected call of ListAIBridgeModelThoughtsByInterceptionIDs. +func (mr *MockStoreMockRecorder) ListAIBridgeModelThoughtsByInterceptionIDs(ctx, interceptionIds any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListAIBridgeModelThoughtsByInterceptionIDs", reflect.TypeOf((*MockStore)(nil).ListAIBridgeModelThoughtsByInterceptionIDs), ctx, interceptionIds) +} + // ListAIBridgeModels mocks base method. 
func (m *MockStore) ListAIBridgeModels(ctx context.Context, arg database.ListAIBridgeModelsParams) ([]string, error) { m.ctrl.T.Helper() @@ -7052,6 +7067,21 @@ func (mr *MockStoreMockRecorder) ListAIBridgeModels(ctx, arg any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListAIBridgeModels", reflect.TypeOf((*MockStore)(nil).ListAIBridgeModels), ctx, arg) } +// ListAIBridgeSessionThreads mocks base method. +func (m *MockStore) ListAIBridgeSessionThreads(ctx context.Context, arg database.ListAIBridgeSessionThreadsParams) ([]database.ListAIBridgeSessionThreadsRow, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ListAIBridgeSessionThreads", ctx, arg) + ret0, _ := ret[0].([]database.ListAIBridgeSessionThreadsRow) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ListAIBridgeSessionThreads indicates an expected call of ListAIBridgeSessionThreads. +func (mr *MockStoreMockRecorder) ListAIBridgeSessionThreads(ctx, arg any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListAIBridgeSessionThreads", reflect.TypeOf((*MockStore)(nil).ListAIBridgeSessionThreads), ctx, arg) +} + // ListAIBridgeSessions mocks base method. func (m *MockStore) ListAIBridgeSessions(ctx context.Context, arg database.ListAIBridgeSessionsParams) ([]database.ListAIBridgeSessionsRow, error) { m.ctrl.T.Helper() @@ -7142,6 +7172,21 @@ func (mr *MockStoreMockRecorder) ListAuthorizedAIBridgeModels(ctx, arg, prepared return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListAuthorizedAIBridgeModels", reflect.TypeOf((*MockStore)(nil).ListAuthorizedAIBridgeModels), ctx, arg, prepared) } +// ListAuthorizedAIBridgeSessionThreads mocks base method. 
+func (m *MockStore) ListAuthorizedAIBridgeSessionThreads(ctx context.Context, arg database.ListAIBridgeSessionThreadsParams, prepared rbac.PreparedAuthorized) ([]database.ListAIBridgeSessionThreadsRow, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ListAuthorizedAIBridgeSessionThreads", ctx, arg, prepared) + ret0, _ := ret[0].([]database.ListAIBridgeSessionThreadsRow) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ListAuthorizedAIBridgeSessionThreads indicates an expected call of ListAuthorizedAIBridgeSessionThreads. +func (mr *MockStoreMockRecorder) ListAuthorizedAIBridgeSessionThreads(ctx, arg, prepared any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListAuthorizedAIBridgeSessionThreads", reflect.TypeOf((*MockStore)(nil).ListAuthorizedAIBridgeSessionThreads), ctx, arg, prepared) +} + // ListAuthorizedAIBridgeSessions mocks base method. func (m *MockStore) ListAuthorizedAIBridgeSessions(ctx context.Context, arg database.ListAIBridgeSessionsParams, prepared rbac.PreparedAuthorized) ([]database.ListAIBridgeSessionsRow, error) { m.ctrl.T.Helper() diff --git a/coderd/database/modelqueries.go b/coderd/database/modelqueries.go index 33d181053d..af1fd954c7 100644 --- a/coderd/database/modelqueries.go +++ b/coderd/database/modelqueries.go @@ -813,6 +813,7 @@ type aibridgeQuerier interface { ListAuthorizedAIBridgeModels(ctx context.Context, arg ListAIBridgeModelsParams, prepared rbac.PreparedAuthorized) ([]string, error) ListAuthorizedAIBridgeSessions(ctx context.Context, arg ListAIBridgeSessionsParams, prepared rbac.PreparedAuthorized) ([]ListAIBridgeSessionsRow, error) CountAuthorizedAIBridgeSessions(ctx context.Context, arg CountAIBridgeSessionsParams, prepared rbac.PreparedAuthorized) (int64, error) + ListAuthorizedAIBridgeSessionThreads(ctx context.Context, arg ListAIBridgeSessionThreadsParams, prepared rbac.PreparedAuthorized) ([]ListAIBridgeSessionThreadsRow, error) } func (q *sqlQuerier) 
ListAuthorizedAIBridgeInterceptions(ctx context.Context, arg ListAIBridgeInterceptionsParams, prepared rbac.PreparedAuthorized) ([]ListAIBridgeInterceptionsRow, error) { @@ -1050,11 +1051,66 @@ func (q *sqlQuerier) CountAuthorizedAIBridgeSessions(ctx context.Context, arg Co return count, nil } +func (q *sqlQuerier) ListAuthorizedAIBridgeSessionThreads(ctx context.Context, arg ListAIBridgeSessionThreadsParams, prepared rbac.PreparedAuthorized) ([]ListAIBridgeSessionThreadsRow, error) { + authorizedFilter, err := prepared.CompileToSQL(ctx, regosql.ConvertConfig{ + VariableConverter: regosql.AIBridgeInterceptionConverter(), + }) + if err != nil { + return nil, xerrors.Errorf("compile authorized filter: %w", err) + } + filtered, err := insertAuthorizedFilter(listAIBridgeSessionThreads, fmt.Sprintf(" AND %s", authorizedFilter)) + if err != nil { + return nil, xerrors.Errorf("insert authorized filter: %w", err) + } + + query := fmt.Sprintf("-- name: ListAuthorizedAIBridgeSessionThreads :many\n%s", filtered) + rows, err := q.db.QueryContext(ctx, query, + arg.SessionID, + arg.AfterID, + arg.BeforeID, + arg.Limit, + ) + if err != nil { + return nil, err + } + defer rows.Close() + var items []ListAIBridgeSessionThreadsRow + for rows.Next() { + var i ListAIBridgeSessionThreadsRow + if err := rows.Scan( + &i.ThreadID, + &i.AIBridgeInterception.ID, + &i.AIBridgeInterception.InitiatorID, + &i.AIBridgeInterception.Provider, + &i.AIBridgeInterception.Model, + &i.AIBridgeInterception.StartedAt, + &i.AIBridgeInterception.Metadata, + &i.AIBridgeInterception.EndedAt, + &i.AIBridgeInterception.APIKeyID, + &i.AIBridgeInterception.Client, + &i.AIBridgeInterception.ThreadParentID, + &i.AIBridgeInterception.ThreadRootID, + &i.AIBridgeInterception.ClientSessionID, + &i.AIBridgeInterception.SessionID, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + 
return items, nil +} + func insertAuthorizedFilter(query string, replaceWith string) (string, error) { if !strings.Contains(query, authorizedQueryPlaceholder) { return "", xerrors.Errorf("query does not contain authorized replace string, this is not an authorized query") } - filtered := strings.Replace(query, authorizedQueryPlaceholder, replaceWith, 1) + filtered := strings.ReplaceAll(query, authorizedQueryPlaceholder, replaceWith) return filtered, nil } diff --git a/coderd/database/querier.go b/coderd/database/querier.go index 6c74b523d7..d26febbe28 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -778,7 +778,11 @@ type sqlcQuerier interface { // Finds all unique AI Bridge interception telemetry summaries combinations // (provider, model, client) in the given timeframe for telemetry reporting. ListAIBridgeInterceptionsTelemetrySummaries(ctx context.Context, arg ListAIBridgeInterceptionsTelemetrySummariesParams) ([]ListAIBridgeInterceptionsTelemetrySummariesRow, error) + ListAIBridgeModelThoughtsByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]AIBridgeModelThought, error) ListAIBridgeModels(ctx context.Context, arg ListAIBridgeModelsParams) ([]string, error) + // Returns all interceptions belonging to paginated threads within a session. + // Threads are paginated by (started_at, thread_id) cursor. + ListAIBridgeSessionThreads(ctx context.Context, arg ListAIBridgeSessionThreadsParams) ([]ListAIBridgeSessionThreadsRow, error) // Returns paginated sessions with aggregated metadata, token counts, and // the most recent user prompt. A "session" is a logical grouping of // interceptions that share the same session_id (set by the client). 
diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 1a97955cab..81c69a6ae9 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -1092,6 +1092,45 @@ func (q *sqlQuerier) ListAIBridgeInterceptionsTelemetrySummaries(ctx context.Con return items, nil } +const listAIBridgeModelThoughtsByInterceptionIDs = `-- name: ListAIBridgeModelThoughtsByInterceptionIDs :many +SELECT + interception_id, content, metadata, created_at +FROM + aibridge_model_thoughts +WHERE + interception_id = ANY($1::uuid[]) +ORDER BY + created_at ASC +` + +func (q *sqlQuerier) ListAIBridgeModelThoughtsByInterceptionIDs(ctx context.Context, interceptionIds []uuid.UUID) ([]AIBridgeModelThought, error) { + rows, err := q.db.QueryContext(ctx, listAIBridgeModelThoughtsByInterceptionIDs, pq.Array(interceptionIds)) + if err != nil { + return nil, err + } + defer rows.Close() + var items []AIBridgeModelThought + for rows.Next() { + var i AIBridgeModelThought + if err := rows.Scan( + &i.InterceptionID, + &i.Content, + &i.Metadata, + &i.CreatedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const listAIBridgeModels = `-- name: ListAIBridgeModels :many SELECT model @@ -1146,6 +1185,115 @@ func (q *sqlQuerier) ListAIBridgeModels(ctx context.Context, arg ListAIBridgeMod return items, nil } +const listAIBridgeSessionThreads = `-- name: ListAIBridgeSessionThreads :many +WITH paginated_threads AS ( + SELECT + -- Find thread root interceptions (thread_root_id IS NULL), apply cursor + -- pagination, and return the page. 
+ aibridge_interceptions.id AS thread_id, + aibridge_interceptions.started_at + FROM + aibridge_interceptions + WHERE + aibridge_interceptions.session_id = $1::text + AND aibridge_interceptions.ended_at IS NOT NULL + AND aibridge_interceptions.thread_root_id IS NULL + -- Pagination cursor. + AND ($2::uuid = '00000000-0000-0000-0000-000000000000'::uuid OR + (aibridge_interceptions.started_at, aibridge_interceptions.id) > ( + (SELECT started_at FROM aibridge_interceptions ai2 WHERE ai2.id = $2), + $2::uuid + ) + ) + AND ($3::uuid = '00000000-0000-0000-0000-000000000000'::uuid OR + (aibridge_interceptions.started_at, aibridge_interceptions.id) < ( + (SELECT started_at FROM aibridge_interceptions ai2 WHERE ai2.id = $3), + $3::uuid + ) + ) + -- @authorize_filter + ORDER BY + aibridge_interceptions.started_at ASC, + aibridge_interceptions.id ASC + LIMIT COALESCE(NULLIF($4::integer, 0), 50) +) +SELECT + COALESCE(aibridge_interceptions.thread_root_id, aibridge_interceptions.id) AS thread_id, + aibridge_interceptions.id, aibridge_interceptions.initiator_id, aibridge_interceptions.provider, aibridge_interceptions.model, aibridge_interceptions.started_at, aibridge_interceptions.metadata, aibridge_interceptions.ended_at, aibridge_interceptions.api_key_id, aibridge_interceptions.client, aibridge_interceptions.thread_parent_id, aibridge_interceptions.thread_root_id, aibridge_interceptions.client_session_id, aibridge_interceptions.session_id +FROM + aibridge_interceptions +JOIN + paginated_threads pt + ON pt.thread_id = COALESCE(aibridge_interceptions.thread_root_id, aibridge_interceptions.id) +WHERE + aibridge_interceptions.session_id = $1::text + AND aibridge_interceptions.ended_at IS NOT NULL + -- @authorize_filter +ORDER BY + -- Ensure threads and their associated interceptions (agentic loops) are sorted chronologically. 
+ pt.started_at ASC, + pt.thread_id ASC, + aibridge_interceptions.started_at ASC, + aibridge_interceptions.id ASC +` + +type ListAIBridgeSessionThreadsParams struct { + SessionID string `db:"session_id" json:"session_id"` + AfterID uuid.UUID `db:"after_id" json:"after_id"` + BeforeID uuid.UUID `db:"before_id" json:"before_id"` + Limit int32 `db:"limit_" json:"limit_"` +} + +type ListAIBridgeSessionThreadsRow struct { + ThreadID uuid.UUID `db:"thread_id" json:"thread_id"` + AIBridgeInterception AIBridgeInterception `db:"aibridge_interception" json:"aibridge_interception"` +} + +// Returns all interceptions belonging to paginated threads within a session. +// Threads are paginated by (started_at, thread_id) cursor. +func (q *sqlQuerier) ListAIBridgeSessionThreads(ctx context.Context, arg ListAIBridgeSessionThreadsParams) ([]ListAIBridgeSessionThreadsRow, error) { + rows, err := q.db.QueryContext(ctx, listAIBridgeSessionThreads, + arg.SessionID, + arg.AfterID, + arg.BeforeID, + arg.Limit, + ) + if err != nil { + return nil, err + } + defer rows.Close() + var items []ListAIBridgeSessionThreadsRow + for rows.Next() { + var i ListAIBridgeSessionThreadsRow + if err := rows.Scan( + &i.ThreadID, + &i.AIBridgeInterception.ID, + &i.AIBridgeInterception.InitiatorID, + &i.AIBridgeInterception.Provider, + &i.AIBridgeInterception.Model, + &i.AIBridgeInterception.StartedAt, + &i.AIBridgeInterception.Metadata, + &i.AIBridgeInterception.EndedAt, + &i.AIBridgeInterception.APIKeyID, + &i.AIBridgeInterception.Client, + &i.AIBridgeInterception.ThreadParentID, + &i.AIBridgeInterception.ThreadRootID, + &i.AIBridgeInterception.ClientSessionID, + &i.AIBridgeInterception.SessionID, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const listAIBridgeSessions = `-- name: ListAIBridgeSessions :many WITH filtered_interceptions AS ( 
SELECT diff --git a/coderd/database/queries/aibridge.sql b/coderd/database/queries/aibridge.sql index 5804fda09a..0d62ddb301 100644 --- a/coderd/database/queries/aibridge.sql +++ b/coderd/database/queries/aibridge.sql @@ -592,6 +592,70 @@ LIMIT COALESCE(NULLIF(@limit_::integer, 0), 100) OFFSET @offset_ ; +-- name: ListAIBridgeSessionThreads :many +-- Returns all interceptions belonging to paginated threads within a session. +-- Threads are paginated by (started_at, thread_id) cursor. +WITH paginated_threads AS ( + SELECT + -- Find thread root interceptions (thread_root_id IS NULL), apply cursor + -- pagination, and return the page. + aibridge_interceptions.id AS thread_id, + aibridge_interceptions.started_at + FROM + aibridge_interceptions + WHERE + aibridge_interceptions.session_id = @session_id::text + AND aibridge_interceptions.ended_at IS NOT NULL + AND aibridge_interceptions.thread_root_id IS NULL + -- Pagination cursor. + AND (@after_id::uuid = '00000000-0000-0000-0000-000000000000'::uuid OR + (aibridge_interceptions.started_at, aibridge_interceptions.id) > ( + (SELECT started_at FROM aibridge_interceptions ai2 WHERE ai2.id = @after_id), + @after_id::uuid + ) + ) + AND (@before_id::uuid = '00000000-0000-0000-0000-000000000000'::uuid OR + (aibridge_interceptions.started_at, aibridge_interceptions.id) < ( + (SELECT started_at FROM aibridge_interceptions ai2 WHERE ai2.id = @before_id), + @before_id::uuid + ) + ) + -- @authorize_filter + ORDER BY + aibridge_interceptions.started_at ASC, + aibridge_interceptions.id ASC + LIMIT COALESCE(NULLIF(@limit_::integer, 0), 50) +) +SELECT + COALESCE(aibridge_interceptions.thread_root_id, aibridge_interceptions.id) AS thread_id, + sqlc.embed(aibridge_interceptions) +FROM + aibridge_interceptions +JOIN + paginated_threads pt + ON pt.thread_id = COALESCE(aibridge_interceptions.thread_root_id, aibridge_interceptions.id) +WHERE + aibridge_interceptions.session_id = @session_id::text + AND aibridge_interceptions.ended_at IS NOT 
NULL + -- @authorize_filter +ORDER BY + -- Ensure threads and their associated interceptions (agentic loops) are sorted chronologically. + pt.started_at ASC, + pt.thread_id ASC, + aibridge_interceptions.started_at ASC, + aibridge_interceptions.id ASC +; + +-- name: ListAIBridgeModelThoughtsByInterceptionIDs :many +SELECT + * +FROM + aibridge_model_thoughts +WHERE + interception_id = ANY(@interception_ids::uuid[]) +ORDER BY + created_at ASC; + -- name: ListAIBridgeModels :many SELECT model diff --git a/codersdk/aibridge.go b/codersdk/aibridge.go index 56b2260bfe..1b80b73488 100644 --- a/codersdk/aibridge.go +++ b/codersdk/aibridge.go @@ -87,6 +87,75 @@ type AIBridgeListSessionsResponse struct { Sessions []AIBridgeSession `json:"sessions"` } +// AIBridgeSessionThreadsResponse is the response for GET +// /api/v2/aibridge/sessions/{session_id} which returns a single +// session with fully expanded threads. +type AIBridgeSessionThreadsResponse struct { + ID string `json:"id"` + Initiator MinimalUser `json:"initiator"` + Providers []string `json:"providers"` + Models []string `json:"models"` + Client *string `json:"client,omitempty"` + Metadata map[string]any `json:"metadata"` + PageStartedAt *time.Time `json:"page_started_at,omitempty" format:"date-time"` + PageEndedAt *time.Time `json:"page_ended_at,omitempty" format:"date-time"` + StartedAt time.Time `json:"started_at" format:"date-time"` + EndedAt *time.Time `json:"ended_at,omitempty" format:"date-time"` + TokenUsageSummary AIBridgeSessionThreadsTokenUsage `json:"token_usage_summary"` + Threads []AIBridgeThread `json:"threads"` +} + +// AIBridgeSessionThreadsTokenUsage represents aggregated token usage +// with metadata containing provider-specific fields like +// cache_creation_input, cache_read_input, etc. 
+type AIBridgeSessionThreadsTokenUsage struct { + InputTokens int64 `json:"input_tokens"` + OutputTokens int64 `json:"output_tokens"` + Metadata map[string]any `json:"metadata"` +} + +// AIBridgeThread represents a single thread within a session. +// A thread groups interceptions by their thread_root_id. +type AIBridgeThread struct { + ID uuid.UUID `json:"id" format:"uuid"` + Prompt *string `json:"prompt,omitempty"` + Model string `json:"model"` + Provider string `json:"provider"` + StartedAt time.Time `json:"started_at" format:"date-time"` + EndedAt *time.Time `json:"ended_at,omitempty" format:"date-time"` + TokenUsage AIBridgeSessionThreadsTokenUsage `json:"token_usage"` + AgenticActions []AIBridgeAgenticAction `json:"agentic_actions"` +} + +// AIBridgeAgenticAction represents a tool call with associated +// thinking blocks and token usage from one or more interceptions. +type AIBridgeAgenticAction struct { + Model string `json:"model"` + TokenUsage AIBridgeSessionThreadsTokenUsage `json:"token_usage"` + Thinking []AIBridgeModelThought `json:"thinking"` + ToolCalls []AIBridgeToolCall `json:"tool_calls"` +} + +// AIBridgeModelThought represents a single thinking block from +// the model. +type AIBridgeModelThought struct { + Text string `json:"text"` +} + +// AIBridgeToolCall represents a tool call recorded during an +// interception. +type AIBridgeToolCall struct { + ID uuid.UUID `json:"id" format:"uuid"` + InterceptionID uuid.UUID `json:"interception_id" format:"uuid"` + ProviderResponseID string `json:"provider_response_id"` + ServerURL string `json:"server_url"` + Tool string `json:"tool"` + Injected bool `json:"injected"` + Input string `json:"input"` + Metadata map[string]any `json:"metadata"` + CreatedAt time.Time `json:"created_at" format:"date-time"` +} + // @typescript-ignore AIBridgeListSessionsFilter type AIBridgeListSessionsFilter struct { // Limit defaults to 100, max is 1000. 
@@ -228,3 +297,30 @@ func (c *Client) AIBridgeListSessions(ctx context.Context, filter AIBridgeListSe var resp AIBridgeListSessionsResponse return resp, json.NewDecoder(res.Body).Decode(&resp) } + +// AIBridgeGetSessionThreads returns a single session with expanded +// thread details including agentic actions and thinking blocks. +func (c *Client) AIBridgeGetSessionThreads(ctx context.Context, sessionID string, afterID, beforeID uuid.UUID, limit int32) (AIBridgeSessionThreadsResponse, error) { + res, err := c.Request(ctx, http.MethodGet, fmt.Sprintf("/api/v2/aibridge/sessions/%s", sessionID), nil, func(r *http.Request) { + q := r.URL.Query() + if afterID != uuid.Nil { + q.Set("after_id", afterID.String()) + } + if beforeID != uuid.Nil { + q.Set("before_id", beforeID.String()) + } + if limit > 0 { + q.Set("limit", fmt.Sprintf("%d", limit)) + } + r.URL.RawQuery = q.Encode() + }) + if err != nil { + return AIBridgeSessionThreadsResponse{}, err + } + defer res.Body.Close() + if res.StatusCode != http.StatusOK { + return AIBridgeSessionThreadsResponse{}, ReadBodyAsError(res) + } + var resp AIBridgeSessionThreadsResponse + return resp, json.NewDecoder(res.Body).Decode(&resp) +} diff --git a/docs/reference/api/aibridge.md b/docs/reference/api/aibridge.md index 28479b8991..6852e42d91 100644 --- a/docs/reference/api/aibridge.md +++ b/docs/reference/api/aibridge.md @@ -207,3 +207,124 @@ curl -X GET http://coder-server:8080/api/v2/aibridge/sessions \ | 200 | [OK](https://tools.ietf.org/html/rfc7231#section-6.3.1) | OK | [codersdk.AIBridgeListSessionsResponse](schemas.md#codersdkaibridgelistsessionsresponse) | To perform this operation, you must be authenticated. [Learn more](authentication.md). 
+ +## Get AI Bridge session threads + +### Code samples + +```shell +# Example request using curl +curl -X GET http://coder-server:8080/api/v2/aibridge/sessions/{session_id} \ + -H 'Accept: application/json' \ + -H 'Coder-Session-Token: API_KEY' +``` + +`GET /aibridge/sessions/{session_id}` + +### Parameters + +| Name | In | Type | Required | Description | +|--------------|-------|---------|----------|-----------------------------------------------------| +| `session_id` | path | string | true | Session ID (client_session_id or interception UUID) | +| `after_id` | query | string | false | Thread pagination cursor (forward/older) | +| `before_id` | query | string | false | Thread pagination cursor (backward/newer) | +| `limit` | query | integer | false | Number of threads per page (default 50) | + +### Example responses + +> 200 Response + +```json +{ + "client": "string", + "ended_at": "2019-08-24T14:15:22Z", + "id": "string", + "initiator": { + "avatar_url": "http://example.com", + "id": "497f6eca-6276-4993-bfeb-53cbbbba6f08", + "name": "string", + "username": "string" + }, + "metadata": { + "property1": null, + "property2": null + }, + "models": [ + "string" + ], + "page_ended_at": "2019-08-24T14:15:22Z", + "page_started_at": "2019-08-24T14:15:22Z", + "providers": [ + "string" + ], + "started_at": "2019-08-24T14:15:22Z", + "threads": [ + { + "agentic_actions": [ + { + "model": "string", + "thinking": [ + { + "text": "string" + } + ], + "token_usage": { + "input_tokens": 0, + "metadata": { + "property1": null, + "property2": null + }, + "output_tokens": 0 + }, + "tool_calls": [ + { + "created_at": "2019-08-24T14:15:22Z", + "id": "497f6eca-6276-4993-bfeb-53cbbbba6f08", + "injected": true, + "input": "string", + "interception_id": "34d9b688-63ad-46f4-88b5-665c1e7f7824", + "metadata": { + "property1": null, + "property2": null + }, + "provider_response_id": "string", + "server_url": "string", + "tool": "string" + } + ] + } + ], + "ended_at": "2019-08-24T14:15:22Z", + 
"id": "497f6eca-6276-4993-bfeb-53cbbbba6f08", + "model": "string", + "prompt": "string", + "provider": "string", + "started_at": "2019-08-24T14:15:22Z", + "token_usage": { + "input_tokens": 0, + "metadata": { + "property1": null, + "property2": null + }, + "output_tokens": 0 + } + } + ], + "token_usage_summary": { + "input_tokens": 0, + "metadata": { + "property1": null, + "property2": null + }, + "output_tokens": 0 + } +} +``` + +### Responses + +| Status | Meaning | Description | Schema | +|--------|---------------------------------------------------------|-------------|----------------------------------------------------------------------------------------------| +| 200 | [OK](https://tools.ietf.org/html/rfc7231#section-6.3.1) | OK | [codersdk.AIBridgeSessionThreadsResponse](schemas.md#codersdkaibridgesessionthreadsresponse) | + +To perform this operation, you must be authenticated. [Learn more](authentication.md). diff --git a/docs/reference/api/schemas.md b/docs/reference/api/schemas.md index 96ea900ffc..1b9d686c2a 100644 --- a/docs/reference/api/schemas.md +++ b/docs/reference/api/schemas.md @@ -337,6 +337,52 @@ | `groups` | array of [codersdk.Group](#codersdkgroup) | false | | | | `users` | array of [codersdk.ReducedUser](#codersdkreduceduser) | false | | | +## codersdk.AIBridgeAgenticAction + +```json +{ + "model": "string", + "thinking": [ + { + "text": "string" + } + ], + "token_usage": { + "input_tokens": 0, + "metadata": { + "property1": null, + "property2": null + }, + "output_tokens": 0 + }, + "tool_calls": [ + { + "created_at": "2019-08-24T14:15:22Z", + "id": "497f6eca-6276-4993-bfeb-53cbbbba6f08", + "injected": true, + "input": "string", + "interception_id": "34d9b688-63ad-46f4-88b5-665c1e7f7824", + "metadata": { + "property1": null, + "property2": null + }, + "provider_response_id": "string", + "server_url": "string", + "tool": "string" + } + ] +} +``` + +### Properties + +| Name | Type | Required | Restrictions | Description | 
+|---------------|----------------------------------------------------------------------------------------|----------|--------------|-------------| +| `model` | string | false | | | +| `thinking` | array of [codersdk.AIBridgeModelThought](#codersdkaibridgemodelthought) | false | | | +| `token_usage` | [codersdk.AIBridgeSessionThreadsTokenUsage](#codersdkaibridgesessionthreadstokenusage) | false | | | +| `tool_calls` | array of [codersdk.AIBridgeToolCall](#codersdkaibridgetoolcall) | false | | | + ## codersdk.AIBridgeAnthropicConfig ```json @@ -643,6 +689,20 @@ | `count` | integer | false | | | | `sessions` | array of [codersdk.AIBridgeSession](#codersdkaibridgesession) | false | | | +## codersdk.AIBridgeModelThought + +```json +{ + "text": "string" +} +``` + +### Properties + +| Name | Type | Required | Restrictions | Description | +|--------|--------|----------|--------------|-------------| +| `text` | string | false | | | + ## codersdk.AIBridgeOpenAIConfig ```json @@ -745,6 +805,135 @@ | `threads` | integer | false | | | | `token_usage_summary` | [codersdk.AIBridgeSessionTokenUsageSummary](#codersdkaibridgesessiontokenusagesummary) | false | | | +## codersdk.AIBridgeSessionThreadsResponse + +```json +{ + "client": "string", + "ended_at": "2019-08-24T14:15:22Z", + "id": "string", + "initiator": { + "avatar_url": "http://example.com", + "id": "497f6eca-6276-4993-bfeb-53cbbbba6f08", + "name": "string", + "username": "string" + }, + "metadata": { + "property1": null, + "property2": null + }, + "models": [ + "string" + ], + "page_ended_at": "2019-08-24T14:15:22Z", + "page_started_at": "2019-08-24T14:15:22Z", + "providers": [ + "string" + ], + "started_at": "2019-08-24T14:15:22Z", + "threads": [ + { + "agentic_actions": [ + { + "model": "string", + "thinking": [ + { + "text": "string" + } + ], + "token_usage": { + "input_tokens": 0, + "metadata": { + "property1": null, + "property2": null + }, + "output_tokens": 0 + }, + "tool_calls": [ + { + "created_at": 
"2019-08-24T14:15:22Z", + "id": "497f6eca-6276-4993-bfeb-53cbbbba6f08", + "injected": true, + "input": "string", + "interception_id": "34d9b688-63ad-46f4-88b5-665c1e7f7824", + "metadata": { + "property1": null, + "property2": null + }, + "provider_response_id": "string", + "server_url": "string", + "tool": "string" + } + ] + } + ], + "ended_at": "2019-08-24T14:15:22Z", + "id": "497f6eca-6276-4993-bfeb-53cbbbba6f08", + "model": "string", + "prompt": "string", + "provider": "string", + "started_at": "2019-08-24T14:15:22Z", + "token_usage": { + "input_tokens": 0, + "metadata": { + "property1": null, + "property2": null + }, + "output_tokens": 0 + } + } + ], + "token_usage_summary": { + "input_tokens": 0, + "metadata": { + "property1": null, + "property2": null + }, + "output_tokens": 0 + } +} +``` + +### Properties + +| Name | Type | Required | Restrictions | Description | +|-----------------------|----------------------------------------------------------------------------------------|----------|--------------|-------------| +| `client` | string | false | | | +| `ended_at` | string | false | | | +| `id` | string | false | | | +| `initiator` | [codersdk.MinimalUser](#codersdkminimaluser) | false | | | +| `metadata` | object | false | | | +| » `[any property]` | any | false | | | +| `models` | array of string | false | | | +| `page_ended_at` | string | false | | | +| `page_started_at` | string | false | | | +| `providers` | array of string | false | | | +| `started_at` | string | false | | | +| `threads` | array of [codersdk.AIBridgeThread](#codersdkaibridgethread) | false | | | +| `token_usage_summary` | [codersdk.AIBridgeSessionThreadsTokenUsage](#codersdkaibridgesessionthreadstokenusage) | false | | | + +## codersdk.AIBridgeSessionThreadsTokenUsage + +```json +{ + "input_tokens": 0, + "metadata": { + "property1": null, + "property2": null + }, + "output_tokens": 0 +} +``` + +### Properties + +| Name | Type | Required | Restrictions | Description | 
+|--------------------|---------|----------|--------------|-------------| +| `input_tokens` | integer | false | | | +| `metadata` | object | false | | | +| » `[any property]` | any | false | | | +| `output_tokens` | integer | false | | | + ## codersdk.AIBridgeSessionTokenUsageSummary ```json @@ -761,6 +950,74 @@ | `input_tokens` | integer | false | | | | `output_tokens` | integer | false | | | +## codersdk.AIBridgeThread + +```json +{ + "agentic_actions": [ + { + "model": "string", + "thinking": [ + { + "text": "string" + } + ], + "token_usage": { + "input_tokens": 0, + "metadata": { + "property1": null, + "property2": null + }, + "output_tokens": 0 + }, + "tool_calls": [ + { + "created_at": "2019-08-24T14:15:22Z", + "id": "497f6eca-6276-4993-bfeb-53cbbbba6f08", + "injected": true, + "input": "string", + "interception_id": "34d9b688-63ad-46f4-88b5-665c1e7f7824", + "metadata": { + "property1": null, + "property2": null + }, + "provider_response_id": "string", + "server_url": "string", + "tool": "string" + } + ] + } + ], + "ended_at": "2019-08-24T14:15:22Z", + "id": "497f6eca-6276-4993-bfeb-53cbbbba6f08", + "model": "string", + "prompt": "string", + "provider": "string", + "started_at": "2019-08-24T14:15:22Z", + "token_usage": { + "input_tokens": 0, + "metadata": { + "property1": null, + "property2": null + }, + "output_tokens": 0 + } +} +``` + +### Properties + +| Name | Type | Required | Restrictions | Description | +|-------------------|----------------------------------------------------------------------------------------|----------|--------------|-------------| +| `agentic_actions` | array of [codersdk.AIBridgeAgenticAction](#codersdkaibridgeagenticaction) | false | | | +| `ended_at` | string | false | | | +| `id` | string | false | | | +| `model` | string | false | | | +| `prompt` | string | false | | | +| `provider` | string | false | | | +| `started_at` | string | false | | | +| `token_usage` | 
[codersdk.AIBridgeSessionThreadsTokenUsage](#codersdkaibridgesessionthreadstokenusage) | false | | | + ## codersdk.AIBridgeTokenUsage ```json @@ -791,6 +1048,40 @@ | `output_tokens` | integer | false | | | | `provider_response_id` | string | false | | | +## codersdk.AIBridgeToolCall + +```json +{ + "created_at": "2019-08-24T14:15:22Z", + "id": "497f6eca-6276-4993-bfeb-53cbbbba6f08", + "injected": true, + "input": "string", + "interception_id": "34d9b688-63ad-46f4-88b5-665c1e7f7824", + "metadata": { + "property1": null, + "property2": null + }, + "provider_response_id": "string", + "server_url": "string", + "tool": "string" +} +``` + +### Properties + +| Name | Type | Required | Restrictions | Description | +|------------------------|---------|----------|--------------|-------------| +| `created_at` | string | false | | | +| `id` | string | false | | | +| `injected` | boolean | false | | | +| `input` | string | false | | | +| `interception_id` | string | false | | | +| `metadata` | object | false | | | +| » `[any property]` | any | false | | | +| `provider_response_id` | string | false | | | +| `server_url` | string | false | | | +| `tool` | string | false | | | + ## codersdk.AIBridgeToolUsage ```json diff --git a/enterprise/coderd/aibridge.go b/enterprise/coderd/aibridge.go index 77b6363d07..d1766ed4b2 100644 --- a/enterprise/coderd/aibridge.go +++ b/enterprise/coderd/aibridge.go @@ -3,8 +3,10 @@ package coderd import ( "context" "database/sql" + "errors" "fmt" "net/http" + "strconv" "time" "github.com/go-chi/chi/v5" @@ -33,6 +35,10 @@ const ( aiBridgeRateLimitWindow = time.Second ) +// errInvalidCursor is returned when a pagination cursor does not +// reference a valid resource in the expected scope. +var errInvalidCursor = xerrors.New("invalid pagination cursor") + // aibridgeHandler handles all aibridged-related endpoints. 
func aibridgeHandler(api *API, middlewares ...func(http.Handler) http.Handler) func(r chi.Router) { // Build the overload protection middleware chain for the aibridged handler. @@ -47,6 +53,7 @@ func aibridgeHandler(api *API, middlewares ...func(http.Handler) http.Handler) f r.Use(middlewares...) r.Get("/interceptions", api.aiBridgeListInterceptions) r.Get("/sessions", api.aiBridgeListSessions) + r.Get("/sessions/{session_id}", api.aiBridgeGetSessionThreads) r.Get("/models", api.aiBridgeListModels) }) @@ -125,12 +132,9 @@ func (api *API) aiBridgeListInterceptions(rw http.ResponseWriter, r *http.Reques rows []database.ListAIBridgeInterceptionsRow ) err := api.Database.InTx(func(db database.Store) error { - // Ensure the after_id interception exists and is visible to the user. - if page.AfterID != uuid.Nil { - _, err := db.GetAIBridgeInterceptionByID(ctx, page.AfterID) - if err != nil { - return xerrors.Errorf("get aibridge interception by id %s for cursor pagination: %w", page.AfterID, err) - } + // Validate the cursor interception exists and is visible. + if err := validateInterceptionCursor(ctx, db, page.AfterID, "after_id", ""); err != nil { + return err } var err error @@ -157,6 +161,13 @@ func (api *API) aiBridgeListInterceptions(rw http.ResponseWriter, r *http.Reques return nil }, nil) if err != nil { + if errors.Is(err, errInvalidCursor) { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "Invalid pagination cursor.", + Detail: err.Error(), + }) + return + } httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ Message: "Internal error getting AI Bridge interceptions.", Detail: err.Error(), @@ -304,6 +315,198 @@ func (api *API) aiBridgeListSessions(rw http.ResponseWriter, r *http.Request) { }) } +// aiBridgeGetSessionThreads returns a single session together with one page +// of its threads, expanded with agentic actions and thinking blocks.
+// +// @Summary Get AI Bridge session threads +// @ID get-ai-bridge-session-threads +// @Security CoderSessionToken +// @Produce json +// @Tags AI Bridge +// @Param session_id path string true "Session ID (client_session_id or interception UUID)" +// @Param after_id query string false "Thread pagination cursor (forward/older)" +// @Param before_id query string false "Thread pagination cursor (backward/newer)" +// @Param limit query int false "Number of threads per page (default 50)" +// @Success 200 {object} codersdk.AIBridgeSessionThreadsResponse +// @Router /aibridge/sessions/{session_id} [get] +func (api *API) aiBridgeGetSessionThreads(rw http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + sessionIDParam := chi.URLParam(r, "session_id") + if sessionIDParam == "" { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "Missing session_id path parameter.", + }) + return + } + + // Parse optional pagination cursors; an absent one stays the zero UUID. + var afterID, beforeID uuid.UUID + if v := r.URL.Query().Get("after_id"); v != "" { + var err error + afterID, err = uuid.Parse(v) + if err != nil { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "Invalid after_id query parameter.", + Detail: err.Error(), + }) + return + } + } + if v := r.URL.Query().Get("before_id"); v != "" { + var err error + beforeID, err = uuid.Parse(v) + if err != nil { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "Invalid before_id query parameter.", + Detail: err.Error(), + }) + return + } + } + if afterID != uuid.Nil && beforeID != uuid.Nil { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "Cannot use both after_id and before_id in the same request.", + }) + return + } + + var limit int32 = 50 + if v := r.URL.Query().Get("limit"); v != "" { + parsed, err := strconv.ParseInt(v, 10, 32) + if err != nil || parsed < 1 || parsed > 200 { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
+ Message: "Invalid limit query parameter.", + Detail: "Limit must be between 1 and 200.", + }) + return + } + limit = int32(parsed) + } + + // Fetch session metadata by reusing the sessions list query + // with a session_id filter. + //nolint:exhaustruct // Let's keep things concise. + sessions, err := api.Database.ListAIBridgeSessions(ctx, database.ListAIBridgeSessionsParams{ + Limit: 1, + SessionID: sessionIDParam, + }) + if err != nil { + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Internal error fetching session.", + Detail: err.Error(), + }) + return + } + if len(sessions) == 0 { + httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{ + Message: "Session not found.", + }) + return + } + session := sessions[0] + + // Fetch paginated threads and sub-resources in one repeatable-read + // transaction for consistency; `session` above was read outside it. + var ( + allRows []database.ListAIBridgeSessionThreadsRow + threadRows []database.ListAIBridgeSessionThreadsRow + tokenUsages []database.AIBridgeTokenUsage + toolUsages []database.AIBridgeToolUsage + userPrompts []database.AIBridgeUserPrompt + modelThoughts []database.AIBridgeModelThought + ) + err = api.Database.InTx(func(db database.Store) error { + // Validate cursor IDs before querying threads. The SQL + // subquery returns NULL for unknown cursors, which silently + // filters out all rows instead of surfacing an error. + if err := validateInterceptionCursor(ctx, db, afterID, "after_id", sessionIDParam); err != nil { + return err + } + if err := validateInterceptionCursor(ctx, db, beforeID, "before_id", sessionIDParam); err != nil { + return err + } + + var err error + + // Fetch all interceptions (unpaginated) so we can aggregate + // session-level token metadata across every thread. + //nolint:exhaustruct // Let's be concise.
+ allRows, err = db.ListAIBridgeSessionThreads(ctx, database.ListAIBridgeSessionThreadsParams{ + SessionID: sessionIDParam, + }) + if err != nil { + return xerrors.Errorf("list all session threads: %w", err) + } + + threadRows, err = db.ListAIBridgeSessionThreads(ctx, database.ListAIBridgeSessionThreadsParams{ + SessionID: sessionIDParam, + AfterID: afterID, + BeforeID: beforeID, + Limit: limit, + }) + if err != nil { + return xerrors.Errorf("list session threads: %w", err) + } + + // Use all interception IDs for token usage (session-level + // metadata aggregation needs every thread). Use only the + // page's IDs for other sub-resources. + allIDs := make([]uuid.UUID, len(allRows)) + for i, row := range allRows { + allIDs[i] = row.AIBridgeInterception.ID + } + ids := make([]uuid.UUID, len(threadRows)) + for i, row := range threadRows { + ids[i] = row.AIBridgeInterception.ID + } + + tokenUsages, err = db.ListAIBridgeTokenUsagesByInterceptionIDs(ctx, allIDs) + if err != nil { + return xerrors.Errorf("list token usages: %w", err) + } + + toolUsages, err = db.ListAIBridgeToolUsagesByInterceptionIDs(ctx, ids) + if err != nil { + return xerrors.Errorf("list tool usages: %w", err) + } + + userPrompts, err = db.ListAIBridgeUserPromptsByInterceptionIDs(ctx, ids) + if err != nil { + return xerrors.Errorf("list user prompts: %w", err) + } + + modelThoughts, err = db.ListAIBridgeModelThoughtsByInterceptionIDs(ctx, ids) + if err != nil { + return xerrors.Errorf("list model thoughts: %w", err) + } + + return nil + }, &database.TxOptions{ + Isolation: sql.LevelRepeatableRead, + ReadOnly: true, + TxIdentifier: "aibridge_get_session_threads", + }) + if err != nil { + if errors.Is(err, errInvalidCursor) { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "Invalid pagination cursor.", + Detail: err.Error(), + }) + return + } + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Internal error fetching session threads.", +
Detail: err.Error(), + }) + return + } + + resp := db2sdk.AIBridgeSessionThreads(session, threadRows, tokenUsages, toolUsages, userPrompts, modelThoughts) + + httpapi.Write(ctx, rw, http.StatusOK, resp) +} + // aiBridgeListModels returns all AI Bridge models a user can see. // // @Summary List AI Bridge models @@ -356,6 +559,24 @@ func (api *API) aiBridgeListModels(rw http.ResponseWriter, r *http.Request) { httpapi.Write(ctx, rw, http.StatusOK, models) } +// validateInterceptionCursor checks that a pagination cursor refers to an +// existing interception (uuid.Nil means no cursor and passes). When sessionID +// is non-empty the interception must also belong to that session. Every +// failure wraps errInvalidCursor; the lookup's own error is discarded. +func validateInterceptionCursor(ctx context.Context, db database.Store, cursorID uuid.UUID, cursorName, sessionID string) error { + if cursorID == uuid.Nil { + return nil + } + interception, err := db.GetAIBridgeInterceptionByID(ctx, cursorID) + if err != nil { + return xerrors.Errorf("%s: interception %s not found: %w", cursorName, cursorID, errInvalidCursor) + } + if sessionID != "" && interception.SessionID != sessionID { + return xerrors.Errorf("%s: interception %s does not belong to session %s: %w", cursorName, cursorID, sessionID, errInvalidCursor) + } + return nil +} + func populatedAndConvertAIBridgeInterceptions(ctx context.Context, db database.Store, dbInterceptions []database.ListAIBridgeInterceptionsRow) ([]codersdk.AIBridgeInterception, error) { if len(dbInterceptions) == 0 { return []codersdk.AIBridgeInterception{}, nil diff --git a/enterprise/coderd/aibridge_test.go b/enterprise/coderd/aibridge_test.go index 92686f9f61..9808916230 100644 --- a/enterprise/coderd/aibridge_test.go +++ b/enterprise/coderd/aibridge_test.go @@ -2,6 +2,7 @@ package coderd_test import ( "database/sql" + "encoding/json" "io" "net/http" "testing" @@ -53,18 +54,7 @@ func TestAIBridgeListInterceptions(t *testing.T) { t.Run("EmptyDB", func(t
*testing.T) { t.Parallel() - dv := coderdtest.DeploymentValues(t) - dv.AI.BridgeConfig.Enabled = serpent.Bool(true) - client, _ := coderdenttest.New(t, &coderdenttest.Options{ - Options: &coderdtest.Options{ - DeploymentValues: dv, - }, - LicenseOptions: &coderdenttest.LicenseOptions{ - Features: license.Features{ - codersdk.FeatureAIBridge: 1, - }, - }, - }) + client, _ := coderdenttest.New(t, aibridgeOpts(t)) ctx := testutil.Context(t, testutil.WaitLong) //nolint:gocritic // Owner role is irrelevant here. res, err := client.AIBridgeListInterceptions(ctx, codersdk.AIBridgeListInterceptionsFilter{}) @@ -74,18 +64,7 @@ func TestAIBridgeListInterceptions(t *testing.T) { t.Run("OK", func(t *testing.T) { t.Parallel() - dv := coderdtest.DeploymentValues(t) - dv.AI.BridgeConfig.Enabled = serpent.Bool(true) - client, db, firstUser := coderdenttest.NewWithDatabase(t, &coderdenttest.Options{ - Options: &coderdtest.Options{ - DeploymentValues: dv, - }, - LicenseOptions: &coderdenttest.LicenseOptions{ - Features: license.Features{ - codersdk.FeatureAIBridge: 1, - }, - }, - }) + client, db, firstUser := coderdenttest.NewWithDatabase(t, aibridgeOpts(t)) ctx := testutil.Context(t, testutil.WaitLong) user1, err := client.User(ctx, codersdk.Me) @@ -193,18 +172,7 @@ func TestAIBridgeListInterceptions(t *testing.T) { t.Run("Pagination", func(t *testing.T) { t.Parallel() - dv := coderdtest.DeploymentValues(t) - dv.AI.BridgeConfig.Enabled = serpent.Bool(true) - client, db, firstUser := coderdenttest.NewWithDatabase(t, &coderdenttest.Options{ - Options: &coderdtest.Options{ - DeploymentValues: dv, - }, - LicenseOptions: &coderdenttest.LicenseOptions{ - Features: license.Features{ - codersdk.FeatureAIBridge: 1, - }, - }, - }) + client, db, firstUser := coderdenttest.NewWithDatabase(t, aibridgeOpts(t)) ctx := testutil.Context(t, testutil.WaitLong) allInterceptionIDs := make([]uuid.UUID, 0, 20) @@ -309,18 +277,7 @@ func TestAIBridgeListInterceptions(t *testing.T) { 
t.Run("InflightInterceptions", func(t *testing.T) { t.Parallel() - dv := coderdtest.DeploymentValues(t) - dv.AI.BridgeConfig.Enabled = serpent.Bool(true) - client, db, firstUser := coderdenttest.NewWithDatabase(t, &coderdenttest.Options{ - Options: &coderdtest.Options{ - DeploymentValues: dv, - }, - LicenseOptions: &coderdenttest.LicenseOptions{ - Features: license.Features{ - codersdk.FeatureAIBridge: 1, - }, - }, - }) + client, db, firstUser := coderdenttest.NewWithDatabase(t, aibridgeOpts(t)) ctx := testutil.Context(t, testutil.WaitLong) now := dbtime.Now() @@ -343,18 +300,7 @@ func TestAIBridgeListInterceptions(t *testing.T) { t.Run("Authorized", func(t *testing.T) { t.Parallel() - dv := coderdtest.DeploymentValues(t) - dv.AI.BridgeConfig.Enabled = serpent.Bool(true) - adminClient, db, firstUser := coderdenttest.NewWithDatabase(t, &coderdenttest.Options{ - Options: &coderdtest.Options{ - DeploymentValues: dv, - }, - LicenseOptions: &coderdenttest.LicenseOptions{ - Features: license.Features{ - codersdk.FeatureAIBridge: 1, - }, - }, - }) + adminClient, db, firstUser := coderdenttest.NewWithDatabase(t, aibridgeOpts(t)) ctx := testutil.Context(t, testutil.WaitLong) secondUserClient, secondUser := coderdtest.CreateAnotherUser(t, adminClient, firstUser.OrganizationID) @@ -389,18 +335,7 @@ func TestAIBridgeListInterceptions(t *testing.T) { t.Run("Filter", func(t *testing.T) { t.Parallel() - dv := coderdtest.DeploymentValues(t) - dv.AI.BridgeConfig.Enabled = serpent.Bool(true) - client, db, firstUser := coderdenttest.NewWithDatabase(t, &coderdenttest.Options{ - Options: &coderdtest.Options{ - DeploymentValues: dv, - }, - LicenseOptions: &coderdenttest.LicenseOptions{ - Features: license.Features{ - codersdk.FeatureAIBridge: 1, - }, - }, - }) + client, db, firstUser := coderdenttest.NewWithDatabase(t, aibridgeOpts(t)) ctx := testutil.Context(t, testutil.WaitLong) user1, err := client.User(ctx, codersdk.Me) @@ -622,18 +557,7 @@ func TestAIBridgeListInterceptions(t 
*testing.T) { t.Run("FilterErrors", func(t *testing.T) { t.Parallel() - dv := coderdtest.DeploymentValues(t) - dv.AI.BridgeConfig.Enabled = serpent.Bool(true) - client, _ := coderdenttest.New(t, &coderdenttest.Options{ - Options: &coderdtest.Options{ - DeploymentValues: dv, - }, - LicenseOptions: &coderdenttest.LicenseOptions{ - Features: license.Features{ - codersdk.FeatureAIBridge: 1, - }, - }, - }) + client, _ := coderdenttest.New(t, aibridgeOpts(t)) // No need to insert any test data, we're just testing the filter // errors. @@ -700,6 +624,25 @@ func TestAIBridgeListInterceptions(t *testing.T) { }) } }) + + t.Run("InvalidCursor", func(t *testing.T) { + t.Parallel() + client, _ := coderdenttest.New(t, aibridgeOpts(t)) + ctx := testutil.Context(t, testutil.WaitLong) + + // Using a nonexistent UUID as after_id should return 400, + // not silently return an empty page. + //nolint:gocritic // Owner role is irrelevant here. + _, err := client.AIBridgeListInterceptions(ctx, codersdk.AIBridgeListInterceptionsFilter{ + Pagination: codersdk.Pagination{ + AfterID: uuid.New(), + }, + }) + var sdkErr *codersdk.Error + require.ErrorAs(t, err, &sdkErr) + require.Equal(t, http.StatusBadRequest, sdkErr.StatusCode()) + require.Contains(t, sdkErr.Message, "Invalid pagination cursor") + }) } func aibridgeOpts(t *testing.T) *coderdenttest.Options { @@ -1466,3 +1409,485 @@ func TestAIBridgeConcurrencyLimiting(t *testing.T) { t.Fatal("timed out waiting for first request to complete") } } + +func TestAIBridgeGetSessionThreads(t *testing.T) { + t.Parallel() + + t.Run("NotFound", func(t *testing.T) { + t.Parallel() + ownerClient, firstUser := coderdenttest.New(t, aibridgeOpts(t)) + memberClient, _ := coderdtest.CreateAnotherUser(t, ownerClient, firstUser.OrganizationID) + + ctx := testutil.Context(t, testutil.WaitLong) + _, err := memberClient.AIBridgeGetSessionThreads(ctx, "nonexistent-session-id", uuid.Nil, uuid.Nil, 0) + var sdkErr *codersdk.Error + require.ErrorAs(t, err, &sdkErr) + 
require.Equal(t, http.StatusNotFound, sdkErr.StatusCode()) + }) + + t.Run("LookupByClientSessionID", func(t *testing.T) { + t.Parallel() + client, db, firstUser := coderdenttest.NewWithDatabase(t, aibridgeOpts(t)) + ctx := testutil.Context(t, testutil.WaitLong) + + now := dbtime.Now() + endedAt := now.Add(time.Minute) + dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{ + InitiatorID: firstUser.UserID, + Provider: "anthropic", + Model: "claude-4", + StartedAt: now, + ClientSessionID: sql.NullString{String: "my-session", Valid: true}, + }, &endedAt) + + res, err := client.AIBridgeGetSessionThreads(ctx, "my-session", uuid.Nil, uuid.Nil, 0) + require.NoError(t, err) + require.Equal(t, "my-session", res.ID) + require.Len(t, res.Threads, 1) + require.Equal(t, "claude-4", res.Threads[0].Model) + require.Equal(t, "anthropic", res.Threads[0].Provider) + }) + + t.Run("LookupByInterceptionUUID", func(t *testing.T) { + t.Parallel() + client, db, firstUser := coderdenttest.NewWithDatabase(t, aibridgeOpts(t)) + ctx := testutil.Context(t, testutil.WaitLong) + + now := dbtime.Now() + endedAt := now.Add(time.Minute) + i1 := dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{ + InitiatorID: firstUser.UserID, + Provider: "openai", + Model: "gpt-4", + StartedAt: now, + }, &endedAt) + + // When no client session ID is set, the interception ID becomes the session identifier. + res, err := client.AIBridgeGetSessionThreads(ctx, i1.ID.String(), uuid.Nil, uuid.Nil, 0) + require.NoError(t, err) + require.Equal(t, i1.ID.String(), res.ID) + require.Len(t, res.Threads, 1) + }) + + t.Run("ThreadsWithAgenticActions", func(t *testing.T) { + t.Parallel() + client, db, firstUser := coderdenttest.NewWithDatabase(t, aibridgeOpts(t)) + ctx := testutil.Context(t, testutil.WaitLong) + + now := dbtime.Now() + + // Create a session with one thread. Root interception + child + // interception sharing thread_root_id. 
+ rootEndedAt := now.Add(time.Minute) + root := dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{ + InitiatorID: firstUser.UserID, + Provider: "anthropic", + Model: "claude-4", + StartedAt: now, + ClientSessionID: sql.NullString{String: "thread-session", Valid: true}, + }, &rootEndedAt) + + childEndedAt := now.Add(2 * time.Minute) + child := dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{ + InitiatorID: firstUser.UserID, + Provider: "anthropic", + Model: "claude-4", + StartedAt: now.Add(time.Minute), + ClientSessionID: sql.NullString{String: "thread-session", Valid: true}, + ThreadRootInterceptionID: uuid.NullUUID{UUID: root.ID, Valid: true}, + ThreadParentInterceptionID: uuid.NullUUID{UUID: root.ID, Valid: true}, + }, &childEndedAt) + + // Add a user prompt on the root. + dbgen.AIBridgeUserPrompt(t, db, database.InsertAIBridgeUserPromptParams{ + InterceptionID: root.ID, + Prompt: "implement login feature", + CreatedAt: now, + }) + + // Add token usage on root with metadata. + providerRespID := "resp-1" + dbgen.AIBridgeTokenUsage(t, db, database.InsertAIBridgeTokenUsageParams{ + InterceptionID: root.ID, + ProviderResponseID: providerRespID, + InputTokens: 100, + OutputTokens: 50, + Metadata: json.RawMessage(`{"cache_read_input": 20, "cache_creation_input": 10}`), + CreatedAt: now, + }) + + // Add two tool usages on root (demonstrates multiple tools per action). + dbgen.AIBridgeToolUsage(t, db, database.InsertAIBridgeToolUsageParams{ + InterceptionID: root.ID, + ProviderResponseID: providerRespID, + Tool: "read_file", + Input: `{"path": "/main.go"}`, + CreatedAt: now.Add(time.Second), + }) + dbgen.AIBridgeToolUsage(t, db, database.InsertAIBridgeToolUsageParams{ + InterceptionID: root.ID, + ProviderResponseID: providerRespID, + Tool: "list_dir", + Input: `{"path": "/"}`, + CreatedAt: now.Add(2 * time.Second), + }) + + // Add model thought for the root interception. 
+ dbgen.AIBridgeModelThought(t, db, database.InsertAIBridgeModelThoughtParams{ + InterceptionID: root.ID, + Content: "Let me read the main file first.", + CreatedAt: now.Add(time.Second), + }) + + // Add token usage on child. + dbgen.AIBridgeTokenUsage(t, db, database.InsertAIBridgeTokenUsageParams{ + InterceptionID: child.ID, + ProviderResponseID: "resp-2", + InputTokens: 200, + OutputTokens: 100, + Metadata: json.RawMessage(`{"cache_read_input": 30}`), + CreatedAt: now.Add(time.Minute), + }) + + // Add another tool usage on child. + dbgen.AIBridgeToolUsage(t, db, database.InsertAIBridgeToolUsageParams{ + InterceptionID: child.ID, + ProviderResponseID: "resp-2", + Tool: "write_file", + Input: `{"path": "/login.go"}`, + CreatedAt: now.Add(time.Minute + time.Second), + }) + + res, err := client.AIBridgeGetSessionThreads(ctx, "thread-session", uuid.Nil, uuid.Nil, 0) + require.NoError(t, err) + require.Equal(t, "thread-session", res.ID) + require.Len(t, res.Threads, 1) + + // PageStartedAt/PageEndedAt bracket the visible threads. + require.NotNil(t, res.PageStartedAt) + require.NotNil(t, res.PageEndedAt) + require.True(t, res.PageStartedAt.Equal(now), "PageStartedAt should equal root started_at") + require.True(t, res.PageEndedAt.Equal(childEndedAt), "PageEndedAt should equal child ended_at") + + thread := res.Threads[0] + require.Equal(t, root.ID, thread.ID) + require.NotNil(t, thread.Prompt) + require.Equal(t, "implement login feature", *thread.Prompt) + require.Equal(t, "claude-4", thread.Model) + require.Equal(t, "anthropic", thread.Provider) + + // Thread-level token aggregation. 
+ require.EqualValues(t, 300, thread.TokenUsage.InputTokens) + require.EqualValues(t, 150, thread.TokenUsage.OutputTokens) + require.NotEmpty(t, thread.TokenUsage.Metadata) + require.EqualValues(t, int64(50), thread.TokenUsage.Metadata["cache_read_input"]) + require.EqualValues(t, int64(10), thread.TokenUsage.Metadata["cache_creation_input"]) + + // Two agentic actions (one per interception with tool calls). + require.Len(t, thread.AgenticActions, 2) + + action1 := thread.AgenticActions[0] + // Root interception has two tool calls. + require.Len(t, action1.ToolCalls, 2) + require.Equal(t, "read_file", action1.ToolCalls[0].Tool) + require.Equal(t, "list_dir", action1.ToolCalls[1].Tool) + require.Len(t, action1.Thinking, 1) + require.Equal(t, "Let me read the main file first.", action1.Thinking[0].Text) + // Token usage for root interception. + require.EqualValues(t, 100, action1.TokenUsage.InputTokens) + require.EqualValues(t, 50, action1.TokenUsage.OutputTokens) + + action2 := thread.AgenticActions[1] + require.Len(t, action2.ToolCalls, 1) + require.Equal(t, "write_file", action2.ToolCalls[0].Tool) + require.Empty(t, action2.Thinking) + + // Session-level token aggregation. + require.EqualValues(t, 300, res.TokenUsageSummary.InputTokens) + require.EqualValues(t, 150, res.TokenUsageSummary.OutputTokens) + }) + + t.Run("MultiThreadPagination", func(t *testing.T) { + t.Parallel() + client, db, firstUser := coderdenttest.NewWithDatabase(t, aibridgeOpts(t)) + ctx := testutil.Context(t, testutil.WaitLong) + + now := dbtime.Now() + + // Create a session with 3 threads. Each thread is a standalone + // interception sharing client_session_id. 
+ startedAt := func(i int) time.Time { return now.Add(time.Duration(i) * time.Hour) } + endedAt := func(i int) time.Time { return now.Add(time.Duration(i)*time.Hour + time.Minute) } + threadIDs := make([]uuid.UUID, 3) + for i := range 3 { + ea := endedAt(i) + intc := dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{ + InitiatorID: firstUser.UserID, + Provider: "anthropic", + Model: "claude-4", + StartedAt: startedAt(i), + ClientSessionID: sql.NullString{String: "multi-thread-session", Valid: true}, + }, &ea) + threadIDs[i] = intc.ID + } + + // Get all threads (no pagination). + res, err := client.AIBridgeGetSessionThreads(ctx, "multi-thread-session", uuid.Nil, uuid.Nil, 0) + require.NoError(t, err) + require.Len(t, res.Threads, 3) + + // Threads are ordered by started_at ASC (chronological). + require.Equal(t, threadIDs[0], res.Threads[0].ID) + require.Equal(t, threadIDs[1], res.Threads[1].ID) + require.Equal(t, threadIDs[2], res.Threads[2].ID) + + // Page bounds span all 3 threads. + require.NotNil(t, res.PageStartedAt) + require.NotNil(t, res.PageEndedAt) + require.True(t, res.PageStartedAt.Equal(startedAt(0)), "all threads: PageStartedAt = thread 0 started_at") + require.True(t, res.PageEndedAt.Equal(endedAt(2)), "all threads: PageEndedAt = thread 2 ended_at") + + // Page with limit 1: should get only the oldest thread. + res, err = client.AIBridgeGetSessionThreads(ctx, "multi-thread-session", uuid.Nil, uuid.Nil, 1) + require.NoError(t, err) + require.Len(t, res.Threads, 1) + require.Equal(t, threadIDs[0], res.Threads[0].ID) + require.NotNil(t, res.PageStartedAt) + require.NotNil(t, res.PageEndedAt) + require.True(t, res.PageStartedAt.Equal(startedAt(0)), "page 1: PageStartedAt = thread 0 started_at") + require.True(t, res.PageEndedAt.Equal(endedAt(0)), "page 1: PageEndedAt = thread 0 ended_at") + + // Page forward using after_id: get next thread. 
+ res, err = client.AIBridgeGetSessionThreads(ctx, "multi-thread-session", threadIDs[0], uuid.Nil, 1) + require.NoError(t, err) + require.Len(t, res.Threads, 1) + require.Equal(t, threadIDs[1], res.Threads[0].ID) + require.NotNil(t, res.PageStartedAt) + require.NotNil(t, res.PageEndedAt) + require.True(t, res.PageStartedAt.Equal(startedAt(1)), "page 2: PageStartedAt = thread 1 started_at") + require.True(t, res.PageEndedAt.Equal(endedAt(1)), "page 2: PageEndedAt = thread 1 ended_at") + + // Page forward again. + res, err = client.AIBridgeGetSessionThreads(ctx, "multi-thread-session", threadIDs[1], uuid.Nil, 1) + require.NoError(t, err) + require.Len(t, res.Threads, 1) + require.Equal(t, threadIDs[2], res.Threads[0].ID) + require.NotNil(t, res.PageStartedAt) + require.NotNil(t, res.PageEndedAt) + require.True(t, res.PageStartedAt.Equal(startedAt(2)), "page 3: PageStartedAt = thread 2 started_at") + require.True(t, res.PageEndedAt.Equal(endedAt(2)), "page 3: PageEndedAt = thread 2 ended_at") + + // No more threads. + res, err = client.AIBridgeGetSessionThreads(ctx, "multi-thread-session", threadIDs[2], uuid.Nil, 1) + require.NoError(t, err) + require.Empty(t, res.Threads) + require.Nil(t, res.PageStartedAt, "empty page: PageStartedAt is nil") + require.Nil(t, res.PageEndedAt, "empty page: PageEndedAt is nil") + + // before_id filters to threads older than the given ID. + // before_id=newest → returns both older threads, ASC. 
+ res, err = client.AIBridgeGetSessionThreads(ctx, "multi-thread-session", uuid.Nil, threadIDs[2], 0) + require.NoError(t, err) + require.Len(t, res.Threads, 2) + require.Equal(t, threadIDs[0], res.Threads[0].ID) + require.Equal(t, threadIDs[1], res.Threads[1].ID) + require.NotNil(t, res.PageStartedAt) + require.NotNil(t, res.PageEndedAt) + require.True(t, res.PageStartedAt.Equal(startedAt(0)), "before_id=newest: PageStartedAt = thread 0 started_at") + require.True(t, res.PageEndedAt.Equal(endedAt(1)), "before_id=newest: PageEndedAt = thread 1 ended_at") + + // before_id=middle → returns only the oldest thread. + res, err = client.AIBridgeGetSessionThreads(ctx, "multi-thread-session", uuid.Nil, threadIDs[1], 0) + require.NoError(t, err) + require.Len(t, res.Threads, 1) + require.Equal(t, threadIDs[0], res.Threads[0].ID) + require.NotNil(t, res.PageStartedAt) + require.NotNil(t, res.PageEndedAt) + require.True(t, res.PageStartedAt.Equal(startedAt(0)), "before_id=middle: PageStartedAt = thread 0 started_at") + require.True(t, res.PageEndedAt.Equal(endedAt(0)), "before_id=middle: PageEndedAt = thread 0 ended_at") + + // before_id=oldest → no older threads exist. + res, err = client.AIBridgeGetSessionThreads(ctx, "multi-thread-session", uuid.Nil, threadIDs[0], 0) + require.NoError(t, err) + require.Empty(t, res.Threads) + + // Combining after_id and before_id is rejected. + _, err = client.AIBridgeGetSessionThreads(ctx, "multi-thread-session", threadIDs[2], threadIDs[0], 0) + var sdkErr *codersdk.Error + require.ErrorAs(t, err, &sdkErr) + require.Equal(t, http.StatusBadRequest, sdkErr.StatusCode()) + }) + + // Verify that session-level token metadata aggregates tokens from ALL + // threads, not just the ones visible in the current page. 
+ t.Run("SessionTokenAggregationAcrossPages", func(t *testing.T) { + t.Parallel() + client, db, firstUser := coderdenttest.NewWithDatabase(t, aibridgeOpts(t)) + ctx := testutil.Context(t, testutil.WaitLong) + + now := dbtime.Now() + + // Create 3 threads, each with token usage on both root and child + // interceptions to ensure child tokens are counted too. + var firstThreadID uuid.UUID + for i := range 3 { + offset := time.Duration(i) * time.Hour + rootEndedAt := now.Add(offset + 30*time.Minute) + root := dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{ + InitiatorID: firstUser.UserID, + Provider: "anthropic", + Model: "claude-4", + StartedAt: now.Add(offset), + ClientSessionID: sql.NullString{String: "token-agg-session", Valid: true}, + }, &rootEndedAt) + if i == 0 { + firstThreadID = root.ID + } + + // Token usage on root: 100 input, 50 output, with cache metadata. + dbgen.AIBridgeTokenUsage(t, db, database.InsertAIBridgeTokenUsageParams{ + InterceptionID: root.ID, + ProviderResponseID: "resp-root", + InputTokens: 100, + OutputTokens: 50, + Metadata: json.RawMessage(`{"cache_read_input": 20, "cache_creation_input": 5}`), + CreatedAt: now.Add(offset), + }) + + // Add a child interception with its own token usage. + childEndedAt := now.Add(offset + 45*time.Minute) + child := dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{ + InitiatorID: firstUser.UserID, + Provider: "anthropic", + Model: "claude-4", + StartedAt: now.Add(offset + 15*time.Minute), + ClientSessionID: sql.NullString{String: "token-agg-session", Valid: true}, + ThreadRootInterceptionID: uuid.NullUUID{UUID: root.ID, Valid: true}, + ThreadParentInterceptionID: uuid.NullUUID{UUID: root.ID, Valid: true}, + }, &childEndedAt) + + // Token usage on child: 200 input, 100 output, with cache metadata. 
+ dbgen.AIBridgeTokenUsage(t, db, database.InsertAIBridgeTokenUsageParams{ + InterceptionID: child.ID, + ProviderResponseID: "resp-child", + InputTokens: 200, + OutputTokens: 100, + Metadata: json.RawMessage(`{"cache_read_input": 30}`), + CreatedAt: now.Add(offset + 15*time.Minute), + }) + } + + // Request only the first thread (limit=1). The session-level + // token summary must still reflect ALL 3 threads. + res, err := client.AIBridgeGetSessionThreads(ctx, "token-agg-session", uuid.Nil, uuid.Nil, 1) + require.NoError(t, err) + require.Len(t, res.Threads, 1) + require.Equal(t, firstThreadID, res.Threads[0].ID) + + // Per-thread token usage: root(100) + child(200) = 300 input. + require.EqualValues(t, 300, res.Threads[0].TokenUsage.InputTokens) + require.EqualValues(t, 150, res.Threads[0].TokenUsage.OutputTokens) + + // Session-level summary must include tokens from all 3 threads + // (3 * 300 input, 3 * 150 output), not just the single page. + require.EqualValues(t, 900, res.TokenUsageSummary.InputTokens) + require.EqualValues(t, 450, res.TokenUsageSummary.OutputTokens) + + // Session-level metadata must aggregate across all 3 threads: + // cache_read_input: 3 * (root 20 + child 30) = 150 + // cache_creation_input: 3 * (root 5) = 15 + require.NotEmpty(t, res.TokenUsageSummary.Metadata) + require.EqualValues(t, int64(150), res.TokenUsageSummary.Metadata["cache_read_input"]) + require.EqualValues(t, int64(15), res.TokenUsageSummary.Metadata["cache_creation_input"]) + }) + + t.Run("InvalidCursor", func(t *testing.T) { + t.Parallel() + client, db, firstUser := coderdenttest.NewWithDatabase(t, aibridgeOpts(t)) + ctx := testutil.Context(t, testutil.WaitLong) + + now := dbtime.Now() + endedAt := now.Add(time.Minute) + dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{ + InitiatorID: firstUser.UserID, + Provider: "anthropic", + Model: "claude-4", + StartedAt: now, + ClientSessionID: sql.NullString{String: "cursor-test-session", Valid: true}, + 
}, &endedAt) + + // A completely nonexistent UUID as after_id should return 400. + _, err := client.AIBridgeGetSessionThreads(ctx, "cursor-test-session", uuid.New(), uuid.Nil, 0) + var sdkErr *codersdk.Error + require.ErrorAs(t, err, &sdkErr) + require.Equal(t, http.StatusBadRequest, sdkErr.StatusCode()) + require.Contains(t, sdkErr.Message, "Invalid pagination cursor") + + // A nonexistent UUID as before_id should also return 400. + _, err = client.AIBridgeGetSessionThreads(ctx, "cursor-test-session", uuid.Nil, uuid.New(), 0) + require.ErrorAs(t, err, &sdkErr) + require.Equal(t, http.StatusBadRequest, sdkErr.StatusCode()) + require.Contains(t, sdkErr.Message, "Invalid pagination cursor") + + // An interception from a different session should also return 400. + otherEndedAt := now.Add(time.Minute) + otherInterception := dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{ + InitiatorID: firstUser.UserID, + Provider: "anthropic", + Model: "claude-4", + StartedAt: now, + ClientSessionID: sql.NullString{String: "other-session", Valid: true}, + }, &otherEndedAt) + + _, err = client.AIBridgeGetSessionThreads(ctx, "cursor-test-session", otherInterception.ID, uuid.Nil, 0) + require.ErrorAs(t, err, &sdkErr) + require.Equal(t, http.StatusBadRequest, sdkErr.StatusCode()) + require.Contains(t, sdkErr.Message, "Invalid pagination cursor") + require.Contains(t, sdkErr.Detail, "does not belong to session") + }) + + t.Run("Authorization", func(t *testing.T) { + t.Parallel() + ownerClient, db, firstUser := coderdenttest.NewWithDatabase(t, aibridgeOpts(t)) + ctx := testutil.Context(t, testutil.WaitLong) + + memberClient, member := coderdtest.CreateAnotherUser(t, ownerClient, firstUser.OrganizationID) + + now := dbtime.Now() + endedAt := now.Add(time.Minute) + + // Create a session owned by the owner. 
+ dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{ + InitiatorID: firstUser.UserID, + Provider: "anthropic", + Model: "claude-4", + StartedAt: now, + ClientSessionID: sql.NullString{String: "owner-session", Valid: true}, + }, &endedAt) + + // Owner can see their own session. + res, err := ownerClient.AIBridgeGetSessionThreads(ctx, "owner-session", uuid.Nil, uuid.Nil, 0) + require.NoError(t, err) + require.Equal(t, "owner-session", res.ID) + + // Member cannot see the owner's session. + _, err = memberClient.AIBridgeGetSessionThreads(ctx, "owner-session", uuid.Nil, uuid.Nil, 0) + var sdkErr *codersdk.Error + require.ErrorAs(t, err, &sdkErr) + require.Equal(t, http.StatusNotFound, sdkErr.StatusCode()) + + // Create a session owned by the member. + dbgen.AIBridgeInterception(t, db, database.InsertAIBridgeInterceptionParams{ + InitiatorID: member.ID, + Provider: "anthropic", + Model: "claude-4", + StartedAt: now, + ClientSessionID: sql.NullString{String: "member-session", Valid: true}, + }, &endedAt) + + // Member cannot see their own session either (no read permission). + _, err = memberClient.AIBridgeGetSessionThreads(ctx, "member-session", uuid.Nil, uuid.Nil, 0) + require.ErrorAs(t, err, &sdkErr) + require.Equal(t, http.StatusNotFound, sdkErr.StatusCode()) + }) +} diff --git a/site/src/api/typesGenerated.ts b/site/src/api/typesGenerated.ts index 70548639be..1025682917 100644 --- a/site/src/api/typesGenerated.ts +++ b/site/src/api/typesGenerated.ts @@ -10,6 +10,18 @@ export interface ACLAvailable { readonly groups: readonly Group[]; } +// From codersdk/aibridge.go +/** + * AIBridgeAgenticAction represents a tool call with associated + * thinking blocks and token usage from one or more interceptions. 
+ */ +export interface AIBridgeAgenticAction { + readonly model: string; + readonly token_usage: AIBridgeSessionThreadsTokenUsage; + readonly thinking: readonly AIBridgeModelThought[]; + readonly tool_calls: readonly AIBridgeToolCall[]; +} + // From codersdk/deployment.go export interface AIBridgeAnthropicConfig { readonly base_url: string; @@ -81,6 +93,15 @@ export interface AIBridgeListSessionsResponse { readonly sessions: readonly AIBridgeSession[]; } +// From codersdk/aibridge.go +/** + * AIBridgeModelThought represents a single thinking block from + * the model. + */ +export interface AIBridgeModelThought { + readonly text: string; +} + // From codersdk/deployment.go export interface AIBridgeOpenAIConfig { readonly base_url: string; @@ -117,12 +138,63 @@ export interface AIBridgeSession { readonly last_prompt?: string; } +// From codersdk/aibridge.go +/** + * AIBridgeSessionThreadsResponse is the response for GET + * /api/v2/aibridge/sessions/{session_id} which returns a single + * session with fully expanded threads. + */ +export interface AIBridgeSessionThreadsResponse { + readonly id: string; + readonly initiator: MinimalUser; + readonly providers: readonly string[]; + readonly models: readonly string[]; + readonly client?: string; + // empty interface{} type, falling back to unknown + readonly metadata: Record<string, unknown>; + readonly page_started_at?: string; + readonly page_ended_at?: string; + readonly started_at: string; + readonly ended_at?: string; + readonly token_usage_summary: AIBridgeSessionThreadsTokenUsage; + readonly threads: readonly AIBridgeThread[]; +} + +// From codersdk/aibridge.go +/** + * AIBridgeSessionThreadsTokenUsage represents aggregated token usage + * with metadata containing provider-specific fields like + * cache_creation_input, cache_read_input, etc.
+ */ +export interface AIBridgeSessionThreadsTokenUsage { + readonly input_tokens: number; + readonly output_tokens: number; + // empty interface{} type, falling back to unknown + readonly metadata: Record<string, unknown>; +} + // From codersdk/aibridge.go export interface AIBridgeSessionTokenUsageSummary { readonly input_tokens: number; readonly output_tokens: number; } +// From codersdk/aibridge.go +/** + * AIBridgeThread represents a single thread within a session. + * A thread groups interceptions by their thread_root_id. + */ +export interface AIBridgeThread { + readonly id: string; + readonly prompt?: string; + readonly model: string; + readonly provider: string; + readonly started_at: string; + readonly ended_at?: string; + readonly token_usage: AIBridgeSessionThreadsTokenUsage; + readonly agentic_actions: readonly AIBridgeAgenticAction[]; +} + // From codersdk/aibridge.go export interface AIBridgeTokenUsage { readonly id: string; @@ -135,6 +207,24 @@ export interface AIBridgeTokenUsage { readonly created_at: string; } +// From codersdk/aibridge.go +/** + * AIBridgeToolCall represents a tool call recorded during an + * interception. + */ +export interface AIBridgeToolCall { + readonly id: string; + readonly interception_id: string; + readonly provider_response_id: string; + readonly server_url: string; + readonly tool: string; + readonly injected: boolean; + readonly input: string; + // empty interface{} type, falling back to unknown + readonly metadata: Record<string, unknown>; + readonly created_at: string; +} + // From codersdk/aibridge.go export interface AIBridgeToolUsage { readonly id: string;