From c3b6284955e498005a01b8fa1f5adf4c3a5f9f1c Mon Sep 17 00:00:00 2001 From: Michael Suchacz <203725896+ibetitsmike@users.noreply.github.com> Date: Fri, 13 Mar 2026 18:30:49 +0100 Subject: [PATCH] feat: add chat cost analytics backend (#23036) Add cost tracking for LLM chat interactions with microdollar precision. ## Changes - Add `chatcost` package for per-message cost calculation using `shopspring/decimal` for intermediate arithmetic - **Ceil rounding policy**: fractional micros round UP to next whole micro (applied once after summing all components) - Database migration: `total_cost_micros` BIGINT column with historical backfill and `created_at` index - API endpoints: per-user cost summary and admin rollup under `/api/experimental/chats/cost/` - SDK types: `ChatCostSummary`, `ChatCostModelBreakdown`, `ChatCostUserRollup` - Fix `modeloptionsgen` to handle `decimal.Decimal` as opaque numeric type - Update frontend pricing test fixtures for string decimal types ## Design decisions - `NULL` = unpriced (no matching model config), `0` = free - Reasoning tokens included in output tokens (no double-counting) - Integer microdollars (BIGINT) for storage and API responses - Price config uses `decimal.Decimal` for exact parsing; totals use `int64` Frontend: #23037 --- AGENTS.md | 25 ++ coderd/chatd/chatcost/chatcost.go | 71 ++++ coderd/chatd/chatcost/chatcost_test.go | 163 ++++++++ coderd/chatd/chatd.go | 49 +++ coderd/chatd/chatprovider/chatprovider.go | 28 -- .../chatd/chatprovider/chatprovider_test.go | 59 --- coderd/chats.go | 233 ++++++++++- coderd/chats_test.go | 339 +++++++++++++++- coderd/coderd.go | 7 + coderd/database/dbauthz/dbauthz.go | 28 ++ coderd/database/dbauthz/dbauthz_test.go | 75 ++++ coderd/database/dbmetrics/querymetrics.go | 32 ++ coderd/database/dbmock/dbmock.go | 60 +++ coderd/database/dump.sql | 5 +- .../000435_add_cost_to_chat_messages.down.sql | 3 + .../000435_add_cost_to_chat_messages.up.sql | 68 ++++ coderd/database/models.go | 1 + 
coderd/database/querier.go | 13 + coderd/database/queries.sql.go | 373 +++++++++++++++++- coderd/database/queries/chats.sql | 167 +++++++- coderd/database/sqlc.yaml | 11 + codersdk/chats.go | 161 +++++++- codersdk/chats_test.go | 63 +++ docs/reference/api/chat.md | 1 + go.mod | 1 + go.sum | 2 + scripts/apitypings/main.go | 4 + scripts/modeloptionsgen/main.go | 14 +- site/src/api/typesGenerated.ts | 103 ++++- .../ModelsSection.test.tsx | 99 ----- .../modelConfigFormLogic.test.ts | 16 +- .../modelConfigFormLogic.ts | 2 +- .../ChatModelAdminPanel/pricingFields.test.ts | 8 +- .../ChatModelAdminPanel/pricingFields.ts | 12 +- 34 files changed, 2034 insertions(+), 262 deletions(-) create mode 100644 coderd/chatd/chatcost/chatcost.go create mode 100644 coderd/chatd/chatcost/chatcost_test.go create mode 100644 coderd/database/migrations/000435_add_cost_to_chat_messages.down.sql create mode 100644 coderd/database/migrations/000435_add_cost_to_chat_messages.up.sql create mode 100644 docs/reference/api/chat.md delete mode 100644 site/src/pages/AgentsPage/ChatModelAdminPanel/ModelsSection.test.tsx diff --git a/AGENTS.md b/AGENTS.md index 832f090978..c3abfa662b 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -100,6 +100,31 @@ app, err := api.Database.GetOAuth2ProviderAppByClientID(dbauthz.AsSystemRestrict app, err := api.Database.GetOAuth2ProviderAppByClientID(ctx, clientID) ``` +### API Design + +- Add swagger annotations when introducing new HTTP endpoints. Do this in + the same change as the handler so the docs do not get missed before + release. +- For user-scoped or resource-scoped routes, prefer path parameters over + query parameters when that matches existing route patterns. +- For experimental or unstable API paths, skip public doc generation with + `// @x-apidocgen {"skip": true}` after the `@Router` annotation. This + keeps them out of the published API reference until they stabilize. + +### Database Query Naming + +- Use `ByX` when `X` is the lookup or filter column. 
+- Use `PerX` or `GroupedByX` when `X` is the aggregation or grouping + dimension. +- Avoid `ByX` names for grouped queries. + +### Database-to-SDK Conversions + +- Extract explicit db-to-SDK conversion helpers instead of inlining large + conversion blocks inside handlers. +- Keep nullable-field handling, type coercion, and response shaping in the + converter so handlers stay focused on request flow and authorization. + ## Quick Reference ### Full workflows available in imported WORKFLOWS.md diff --git a/coderd/chatd/chatcost/chatcost.go b/coderd/chatd/chatcost/chatcost.go new file mode 100644 index 0000000000..a3a04f14a4 --- /dev/null +++ b/coderd/chatd/chatcost/chatcost.go @@ -0,0 +1,71 @@ +package chatcost + +import ( + "github.com/shopspring/decimal" + + "github.com/coder/coder/v2/coderd/util/ptr" + "github.com/coder/coder/v2/codersdk" +) + +// CalculateTotalCostMicros returns the cost in micros -- millionths of a +// dollar -- rounded up to the next whole microdollar. It returns nil when +// pricing is not configured, when all priced usage fields are nil, or when no +// reported usage has a configured price, so "unpriced" differs from "zero cost". +func CalculateTotalCostMicros( + usage codersdk.ChatMessageUsage, + cost *codersdk.ModelCostConfig, +) *int64 { + if cost == nil { + return nil + } + + // A cost config with no prices set means pricing is effectively + // unconfigured — return nil (unpriced) rather than zero. + if cost.InputPricePerMillionTokens == nil && + cost.OutputPricePerMillionTokens == nil && + cost.CacheReadPricePerMillionTokens == nil && + cost.CacheWritePricePerMillionTokens == nil { + return nil + } + + if usage.InputTokens == nil && + usage.OutputTokens == nil && + usage.ReasoningTokens == nil && + usage.CacheCreationTokens == nil && + usage.CacheReadTokens == nil { + return nil + } + + // OutputTokens already includes reasoning tokens per provider + // semantics (e.g. OpenAI's completion_tokens encompasses + // reasoning_tokens). Adding ReasoningTokens here would + // double-count. 
+ + // Preserve nil when usage exists only in categories without configured + // pricing, so callers can distinguish "unpriced" from "priced at zero". + hasMatchingPrice := (usage.InputTokens != nil && cost.InputPricePerMillionTokens != nil) || + (usage.OutputTokens != nil && cost.OutputPricePerMillionTokens != nil) || + (usage.CacheReadTokens != nil && cost.CacheReadPricePerMillionTokens != nil) || + (usage.CacheCreationTokens != nil && cost.CacheWritePricePerMillionTokens != nil) + if !hasMatchingPrice { + return nil + } + + inputMicros := calcCost(usage.InputTokens, cost.InputPricePerMillionTokens) + outputMicros := calcCost(usage.OutputTokens, cost.OutputPricePerMillionTokens) + cacheReadMicros := calcCost(usage.CacheReadTokens, cost.CacheReadPricePerMillionTokens) + cacheWriteMicros := calcCost(usage.CacheCreationTokens, cost.CacheWritePricePerMillionTokens) + + total := inputMicros. + Add(outputMicros). + Add(cacheReadMicros). + Add(cacheWriteMicros) + rounded := total.Ceil().IntPart() + return &rounded +} + +// calcCost returns the cost in fractional microdollars (millionths of a USD) +// for the given token count at the specified per-million-token price. 
+func calcCost(tokens *int64, pricePerMillion *decimal.Decimal) decimal.Decimal { + return decimal.NewFromInt(ptr.NilToEmpty(tokens)).Mul(ptr.NilToEmpty(pricePerMillion)) +} diff --git a/coderd/chatd/chatcost/chatcost_test.go b/coderd/chatd/chatcost/chatcost_test.go new file mode 100644 index 0000000000..0142f4f612 --- /dev/null +++ b/coderd/chatd/chatcost/chatcost_test.go @@ -0,0 +1,163 @@ +package chatcost_test + +import ( + "testing" + + "github.com/shopspring/decimal" + "github.com/stretchr/testify/require" + + "github.com/coder/coder/v2/coderd/chatd/chatcost" + "github.com/coder/coder/v2/coderd/util/ptr" + "github.com/coder/coder/v2/codersdk" +) + +func TestCalculateTotalCostMicros(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + usage codersdk.ChatMessageUsage + cost *codersdk.ModelCostConfig + want *int64 + }{ + { + name: "nil cost returns nil", + usage: codersdk.ChatMessageUsage{InputTokens: ptr.Ref[int64](1000)}, + cost: nil, + want: nil, + }, + { + name: "all priced usage fields nil returns nil", + usage: codersdk.ChatMessageUsage{ + TotalTokens: ptr.Ref[int64](1234), + ContextLimit: ptr.Ref[int64](8192), + }, + cost: &codersdk.ModelCostConfig{ + InputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("3")), + }, + want: nil, + }, + { + name: "sub-micro total rounds up to 1", + usage: codersdk.ChatMessageUsage{InputTokens: ptr.Ref[int64](1)}, + cost: &codersdk.ModelCostConfig{ + InputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("0.01")), + }, + want: ptr.Ref[int64](1), + }, + { + name: "simple input only", + usage: codersdk.ChatMessageUsage{InputTokens: ptr.Ref[int64](1000)}, + cost: &codersdk.ModelCostConfig{ + InputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("3")), + }, + want: ptr.Ref[int64](3000), + }, + { + name: "simple output only", + usage: codersdk.ChatMessageUsage{OutputTokens: ptr.Ref[int64](500)}, + cost: &codersdk.ModelCostConfig{ + OutputPricePerMillionTokens: 
ptr.Ref(decimal.RequireFromString("15")), + }, + want: ptr.Ref[int64](7500), + }, + { + name: "reasoning tokens included in output total", + usage: codersdk.ChatMessageUsage{ + OutputTokens: ptr.Ref[int64](500), + ReasoningTokens: ptr.Ref[int64](200), + }, + cost: &codersdk.ModelCostConfig{ + OutputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("15")), + }, + want: ptr.Ref[int64](7500), + }, + { + name: "cache read tokens", + usage: codersdk.ChatMessageUsage{CacheReadTokens: ptr.Ref[int64](10000)}, + cost: &codersdk.ModelCostConfig{ + CacheReadPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("0.3")), + }, + want: ptr.Ref[int64](3000), + }, + { + name: "cache creation tokens", + usage: codersdk.ChatMessageUsage{CacheCreationTokens: ptr.Ref[int64](5000)}, + cost: &codersdk.ModelCostConfig{ + CacheWritePricePerMillionTokens: ptr.Ref(decimal.RequireFromString("3.75")), + }, + want: ptr.Ref[int64](18750), + }, + { + name: "full mixed usage totals all components exactly", + usage: codersdk.ChatMessageUsage{ + InputTokens: ptr.Ref[int64](101), + OutputTokens: ptr.Ref[int64](201), + ReasoningTokens: ptr.Ref[int64](52), + CacheReadTokens: ptr.Ref[int64](1005), + CacheCreationTokens: ptr.Ref[int64](33), + TotalTokens: ptr.Ref[int64](1391), + ContextLimit: ptr.Ref[int64](4096), + }, + cost: &codersdk.ModelCostConfig{ + InputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("1.23")), + OutputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("4.56")), + CacheReadPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("0.7")), + CacheWritePricePerMillionTokens: ptr.Ref(decimal.RequireFromString("7.89")), + }, + want: ptr.Ref[int64](2005), + }, + { + name: "partial pricing only input contributes", + usage: codersdk.ChatMessageUsage{ + InputTokens: ptr.Ref[int64](1234), + OutputTokens: ptr.Ref[int64](999), + ReasoningTokens: ptr.Ref[int64](111), + CacheReadTokens: ptr.Ref[int64](500), + CacheCreationTokens: ptr.Ref[int64](250), + }, + cost: 
&codersdk.ModelCostConfig{ + InputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("2.5")), + }, + want: ptr.Ref[int64](3085), + }, + { + name: "zero tokens with pricing returns zero pointer", + usage: codersdk.ChatMessageUsage{InputTokens: ptr.Ref[int64](0)}, + cost: &codersdk.ModelCostConfig{ + InputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("3")), + }, + want: ptr.Ref[int64](0), + }, + { + name: "usage only in unpriced categories returns nil", + usage: codersdk.ChatMessageUsage{InputTokens: ptr.Ref[int64](1000)}, + cost: &codersdk.ModelCostConfig{ + OutputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("15")), + }, + want: nil, + }, + { + name: "non nil usage with empty cost config returns nil", + usage: codersdk.ChatMessageUsage{InputTokens: ptr.Ref[int64](42)}, + cost: &codersdk.ModelCostConfig{}, + want: nil, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + got := chatcost.CalculateTotalCostMicros(tt.usage, tt.cost) + + if tt.want == nil { + require.Nil(t, got) + } else { + require.NotNil(t, got) + require.Equal(t, *tt.want, *got) + } + }) + } +} diff --git a/coderd/chatd/chatd.go b/coderd/chatd/chatd.go index fbdb59b267..3093fb1cfe 100644 --- a/coderd/chatd/chatd.go +++ b/coderd/chatd/chatd.go @@ -19,6 +19,7 @@ import ( "golang.org/x/xerrors" "cdr.dev/slog/v3" + "github.com/coder/coder/v2/coderd/chatd/chatcost" "github.com/coder/coder/v2/coderd/chatd/chatloop" "github.com/coder/coder/v2/coderd/chatd/chatprompt" "github.com/coder/coder/v2/coderd/chatd/chatprovider" @@ -293,6 +294,7 @@ func (p *Server) CreateChat(ctx context.Context, opts CreateOptions) (database.C CacheReadTokens: sql.NullInt64{}, ContextLimit: sql.NullInt64{}, Compressed: sql.NullBool{}, + TotalCostMicros: sql.NullInt64{}, }) if err != nil { return xerrors.Errorf("insert system message: %w", err) @@ -321,6 +323,7 @@ func (p *Server) CreateChat(ctx context.Context, opts CreateOptions) (database.C 
CacheCreationTokens: sql.NullInt64{}, CacheReadTokens: sql.NullInt64{}, ContextLimit: sql.NullInt64{}, + TotalCostMicros: sql.NullInt64{}, Compressed: sql.NullBool{}, }) if err != nil { @@ -910,6 +913,7 @@ func insertUserMessageAndSetPending( CacheCreationTokens: sql.NullInt64{}, CacheReadTokens: sql.NullInt64{}, ContextLimit: sql.NullInt64{}, + TotalCostMicros: sql.NullInt64{}, Compressed: sql.NullBool{}, }) if err != nil { @@ -1969,6 +1973,7 @@ func (p *Server) processChat(ctx context.Context, chat database.Chat) { CacheCreationTokens: sql.NullInt64{}, CacheReadTokens: sql.NullInt64{}, ContextLimit: sql.NullInt64{}, + TotalCostMicros: sql.NullInt64{}, Compressed: sql.NullBool{}, }) if insertErr != nil { @@ -2347,6 +2352,30 @@ func (p *Server) runChat( } hasUsage := step.Usage != (fantasy.Usage{}) + var usageForCost codersdk.ChatMessageUsage + if hasUsage { + // Only populate fields with a non-zero token count. A zero + // count is left nil, so the calculator treats it as "no data" + // rather than as a reported zero. 
+ if step.Usage.InputTokens != 0 { + usageForCost.InputTokens = int64Ptr(step.Usage.InputTokens) + } + if step.Usage.OutputTokens != 0 { + usageForCost.OutputTokens = int64Ptr(step.Usage.OutputTokens) + } + if step.Usage.ReasoningTokens != 0 { + usageForCost.ReasoningTokens = int64Ptr(step.Usage.ReasoningTokens) + } + if step.Usage.CacheCreationTokens != 0 { + usageForCost.CacheCreationTokens = int64Ptr(step.Usage.CacheCreationTokens) + } + if step.Usage.CacheReadTokens != 0 { + usageForCost.CacheReadTokens = int64Ptr(step.Usage.CacheReadTokens) + } + } + + totalCostMicros := chatcost.CalculateTotalCostMicros(usageForCost, callConfig.Cost) + assistantMessage, insertErr := tx.InsertChatMessage(persistCtx, database.InsertChatMessageParams{ ChatID: chat.ID, CreatedBy: uuid.NullUUID{}, @@ -2369,6 +2398,11 @@ func (p *Server) runChat( CacheReadTokens: usageNullInt64(step.Usage.CacheReadTokens, hasUsage), ContextLimit: step.ContextLimit, Compressed: sql.NullBool{}, + // TotalCostMicros is nullable: NULL means "unpriced" + // (pricing config was missing or no priced token + // breakdown available), while 0 means "priced at + // zero cost" (e.g., a free model). 
+ TotalCostMicros: usageNullInt64Ptr(totalCostMicros), }) if insertErr != nil { return xerrors.Errorf("insert assistant message: %w", insertErr) @@ -2398,6 +2432,7 @@ func (p *Server) runChat( CacheCreationTokens: sql.NullInt64{}, CacheReadTokens: sql.NullInt64{}, ContextLimit: sql.NullInt64{}, + TotalCostMicros: sql.NullInt64{}, Compressed: sql.NullBool{}, }) if insertErr != nil { @@ -2740,6 +2775,7 @@ func (p *Server) persistChatContextSummary( CacheCreationTokens: sql.NullInt64{}, CacheReadTokens: sql.NullInt64{}, ContextLimit: sql.NullInt64{}, + TotalCostMicros: sql.NullInt64{}, }) if txErr != nil { return xerrors.Errorf("insert hidden summary message: %w", txErr) @@ -2764,6 +2800,7 @@ func (p *Server) persistChatContextSummary( CacheCreationTokens: sql.NullInt64{}, CacheReadTokens: sql.NullInt64{}, ContextLimit: sql.NullInt64{}, + TotalCostMicros: sql.NullInt64{}, }) if txErr != nil { return xerrors.Errorf("insert summary tool call message: %w", txErr) @@ -2789,6 +2826,7 @@ func (p *Server) persistChatContextSummary( CacheCreationTokens: sql.NullInt64{}, CacheReadTokens: sql.NullInt64{}, ContextLimit: sql.NullInt64{}, + TotalCostMicros: sql.NullInt64{}, }) if txErr != nil { return xerrors.Errorf("insert summary tool result message: %w", txErr) @@ -2901,6 +2939,10 @@ func (p *Server) resolveModelConfig( return defaultConfig, nil } +func int64Ptr(value int64) *int64 { + return &value +} + //nolint:revive // Boolean controls SQL NULL validity. 
func usageNullInt64(value int64, valid bool) sql.NullInt64 { if !valid { @@ -2912,6 +2954,13 @@ func usageNullInt64(value int64, valid bool) sql.NullInt64 { } } +func usageNullInt64Ptr(v *int64) sql.NullInt64 { + if v == nil { + return sql.NullInt64{} + } + return sql.NullInt64{Int64: *v, Valid: true} +} + func refreshChatWorkspaceSnapshot( ctx context.Context, chat database.Chat, diff --git a/coderd/chatd/chatprovider/chatprovider.go b/coderd/chatd/chatprovider/chatprovider.go index 26447cefde..edef337e7b 100644 --- a/coderd/chatd/chatprovider/chatprovider.go +++ b/coderd/chatd/chatprovider/chatprovider.go @@ -553,34 +553,6 @@ func normalizedEnumValue(value string, allowed ...string) *string { return nil } -// MergeMissingCallConfig fills unset call config values from a provider or -// profile default config. -func MergeMissingCallConfig( - dst *codersdk.ChatModelCallConfig, - defaults codersdk.ChatModelCallConfig, -) { - if dst.MaxOutputTokens == nil { - dst.MaxOutputTokens = defaults.MaxOutputTokens - } - if dst.Temperature == nil { - dst.Temperature = defaults.Temperature - } - if dst.TopP == nil { - dst.TopP = defaults.TopP - } - if dst.TopK == nil { - dst.TopK = defaults.TopK - } - if dst.PresencePenalty == nil { - dst.PresencePenalty = defaults.PresencePenalty - } - if dst.FrequencyPenalty == nil { - dst.FrequencyPenalty = defaults.FrequencyPenalty - } - MergeMissingModelCostConfig(&dst.Cost, defaults.Cost) - MergeMissingProviderOptions(&dst.ProviderOptions, defaults.ProviderOptions) -} - // MergeMissingModelCostConfig fills unset pricing metadata from defaults. 
func MergeMissingModelCostConfig( dst **codersdk.ModelCostConfig, diff --git a/coderd/chatd/chatprovider/chatprovider_test.go b/coderd/chatd/chatprovider/chatprovider_test.go index 920e8ea363..57f5e1708b 100644 --- a/coderd/chatd/chatprovider/chatprovider_test.go +++ b/coderd/chatd/chatprovider/chatprovider_test.go @@ -137,61 +137,6 @@ func TestMergeMissingProviderOptions_OpenRouterNested(t *testing.T) { require.Equal(t, "latency", *options.OpenRouter.Provider.Sort) } -func TestMergeMissingCallConfig_FillsUnsetFields(t *testing.T) { - t.Parallel() - - dst := codersdk.ChatModelCallConfig{ - Temperature: float64Ptr(0.2), - Cost: &codersdk.ModelCostConfig{ - OutputPricePerMillionTokens: float64Ptr(0.7), - }, - ProviderOptions: &codersdk.ChatModelProviderOptions{ - OpenAI: &codersdk.ChatModelOpenAIProviderOptions{ - User: stringPtr("alice"), - }, - }, - } - defaultCallConfig := codersdk.ChatModelCallConfig{ - MaxOutputTokens: int64Ptr(512), - Temperature: float64Ptr(0.9), - TopP: float64Ptr(0.8), - Cost: &codersdk.ModelCostConfig{ - InputPricePerMillionTokens: float64Ptr(0.15), - OutputPricePerMillionTokens: float64Ptr(0.9), - CacheReadPricePerMillionTokens: float64Ptr(0.03), - CacheWritePricePerMillionTokens: float64Ptr(0.3), - }, - ProviderOptions: &codersdk.ChatModelProviderOptions{ - OpenAI: &codersdk.ChatModelOpenAIProviderOptions{ - User: stringPtr("bob"), - ReasoningEffort: stringPtr("medium"), - }, - }, - } - - chatprovider.MergeMissingCallConfig(&dst, defaultCallConfig) - - require.NotNil(t, dst.MaxOutputTokens) - require.EqualValues(t, 512, *dst.MaxOutputTokens) - require.NotNil(t, dst.Temperature) - require.Equal(t, 0.2, *dst.Temperature) - require.NotNil(t, dst.TopP) - require.Equal(t, 0.8, *dst.TopP) - require.NotNil(t, dst.Cost) - require.NotNil(t, dst.Cost.InputPricePerMillionTokens) - require.Equal(t, 0.15, *dst.Cost.InputPricePerMillionTokens) - require.NotNil(t, dst.Cost.OutputPricePerMillionTokens) - require.Equal(t, 0.7, 
*dst.Cost.OutputPricePerMillionTokens) - require.NotNil(t, dst.Cost.CacheReadPricePerMillionTokens) - require.Equal(t, 0.03, *dst.Cost.CacheReadPricePerMillionTokens) - require.NotNil(t, dst.Cost.CacheWritePricePerMillionTokens) - require.Equal(t, 0.3, *dst.Cost.CacheWritePricePerMillionTokens) - require.NotNil(t, dst.ProviderOptions) - require.NotNil(t, dst.ProviderOptions.OpenAI) - require.Equal(t, "alice", *dst.ProviderOptions.OpenAI.User) - require.Equal(t, "medium", *dst.ProviderOptions.OpenAI.ReasoningEffort) -} - func stringPtr(value string) *string { return &value } @@ -203,7 +148,3 @@ func boolPtr(value bool) *bool { func int64Ptr(value int64) *int64 { return &value } - -func float64Ptr(value float64) *float64 { - return &value -} diff --git a/coderd/chats.go b/coderd/chats.go index 48aea9ca4b..3e12728657 100644 --- a/coderd/chats.go +++ b/coderd/chats.go @@ -9,6 +9,7 @@ import ( "errors" "fmt" "io" + "math" "mime" "net/http" "net/http/httptest" @@ -20,6 +21,7 @@ import ( "github.com/go-chi/chi/v5" "github.com/google/uuid" + "github.com/shopspring/decimal" "golang.org/x/xerrors" "cdr.dev/slog/v3" @@ -354,6 +356,187 @@ func (api *API) listChatModels(rw http.ResponseWriter, r *http.Request) { httpapi.Write(ctx, rw, http.StatusOK, response) } +func (api *API) chatCostSummary(rw http.ResponseWriter, r *http.Request) { + ctx := r.Context() + apiKey := httpmw.APIKey(r) + + // Default date range: last 30 days. 
+ now := time.Now() + defaultStart := now.AddDate(0, 0, -30) + + qp := r.URL.Query() + p := httpapi.NewQueryParamParser() + startDate := p.Time(qp, defaultStart, "start_date", time.RFC3339) + endDate := p.Time(qp, now, "end_date", time.RFC3339) + p.ErrorExcessParams(qp) + if len(p.Errors) > 0 { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "Invalid query parameters.", + Validations: p.Errors, + }) + return + } + + targetUser := httpmw.UserParam(r) + if targetUser.ID != apiKey.UserID && !api.Authorize(r, policy.ActionRead, rbac.ResourceChat.WithOwner(targetUser.ID.String())) { + httpapi.Forbidden(rw) + return + } + + summary, err := api.Database.GetChatCostSummary(ctx, database.GetChatCostSummaryParams{ + OwnerID: targetUser.ID, + StartDate: startDate, + EndDate: endDate, + }) + if err != nil { + httpapi.InternalServerError(rw, err) + return + } + + byModel, err := api.Database.GetChatCostPerModel(ctx, database.GetChatCostPerModelParams{ + OwnerID: targetUser.ID, + StartDate: startDate, + EndDate: endDate, + }) + if err != nil { + httpapi.InternalServerError(rw, err) + return + } + + byChat, err := api.Database.GetChatCostPerChat(ctx, database.GetChatCostPerChatParams{ + OwnerID: targetUser.ID, + StartDate: startDate, + EndDate: endDate, + }) + if err != nil { + httpapi.InternalServerError(rw, err) + return + } + + modelBreakdowns := make([]codersdk.ChatCostModelBreakdown, 0, len(byModel)) + for _, model := range byModel { + modelBreakdowns = append(modelBreakdowns, convertChatCostModelBreakdown(model)) + } + + chatBreakdowns := make([]codersdk.ChatCostChatBreakdown, 0, len(byChat)) + for _, chat := range byChat { + chatBreakdowns = append(chatBreakdowns, convertChatCostChatBreakdown(chat)) + } + + httpapi.Write(ctx, rw, http.StatusOK, codersdk.ChatCostSummary{ + StartDate: startDate, + EndDate: endDate, + TotalCostMicros: summary.TotalCostMicros, + PricedMessageCount: summary.PricedMessageCount, + UnpricedMessageCount: 
summary.UnpricedMessageCount, + TotalInputTokens: summary.TotalInputTokens, + TotalOutputTokens: summary.TotalOutputTokens, + ByModel: modelBreakdowns, + ByChat: chatBreakdowns, + }) +} + +func (api *API) chatCostUsers(rw http.ResponseWriter, r *http.Request) { + ctx := r.Context() + if !api.Authorize(r, policy.ActionRead, rbac.ResourceChat) { + httpapi.Forbidden(rw) + return + } + + now := time.Now() + defaultStart := now.AddDate(0, 0, -30) + + qp := r.URL.Query() + p := httpapi.NewQueryParamParser() + startDate := p.Time(qp, defaultStart, "start_date", time.RFC3339) + endDate := p.Time(qp, now, "end_date", time.RFC3339) + username := strings.TrimSpace(p.String(qp, "", "username")) + limit := p.Int(qp, 10, "limit") + offset := p.Int(qp, 0, "offset") + p.ErrorExcessParams(qp) + if len(p.Errors) > 0 { + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "Invalid query parameters.", + Validations: p.Errors, + }) + return + } + if limit <= 0 { + limit = 10 + } + if offset < 0 || offset > math.MaxInt32 || limit > math.MaxInt32 { + validations := make([]codersdk.ValidationError, 0, 2) + if offset < 0 { + validations = append(validations, codersdk.ValidationError{ + Field: "offset", + Detail: "Must be greater than or equal to 0.", + }) + } + if offset > math.MaxInt32 { + validations = append(validations, codersdk.ValidationError{ + Field: "offset", + Detail: fmt.Sprintf("Must be less than or equal to %d.", math.MaxInt32), + }) + } + if limit > math.MaxInt32 { + validations = append(validations, codersdk.ValidationError{ + Field: "limit", + Detail: fmt.Sprintf("Must be less than or equal to %d.", math.MaxInt32), + }) + } + httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ + Message: "Invalid query parameters.", + Validations: validations, + }) + return + } + + users, err := api.Database.GetChatCostPerUser(ctx, database.GetChatCostPerUserParams{ + StartDate: startDate, + EndDate: endDate, + Username: username, + // #nosec G115 - 
Pagination limits are validated to fit in int32 above. + PageLimit: int32(limit), + // #nosec G115 - Pagination offsets are validated to fit in int32 above. + PageOffset: int32(offset), + }) + if err != nil { + httpapi.InternalServerError(rw, err) + return + } + + rollups := make([]codersdk.ChatCostUserRollup, 0, len(users)) + count := int64(0) + for _, user := range users { + count = user.TotalCount + rollups = append(rollups, convertChatCostUserRollup(user)) + } + + if len(users) == 0 && offset > 0 { + countUsers, countErr := api.Database.GetChatCostPerUser(ctx, database.GetChatCostPerUserParams{ + StartDate: startDate, + EndDate: endDate, + Username: username, + PageLimit: 1, + PageOffset: 0, + }) + if countErr != nil { + httpapi.InternalServerError(rw, countErr) + return + } + if len(countUsers) > 0 { + count = countUsers[0].TotalCount + } + } + + httpapi.Write(ctx, rw, http.StatusOK, codersdk.ChatCostUsersResponse{ + StartDate: startDate, + EndDate: endDate, + Count: count, + Users: rollups, + }) +} + // EXPERIMENTAL: this endpoint is experimental and is subject to change. // //nolint:revive // HTTP handler writes to ResponseWriter. 
@@ -2172,6 +2355,48 @@ func convertChats(chats []database.Chat, diffStatusesByChatID map[uuid.UUID]data return result } +func convertChatCostModelBreakdown(model database.GetChatCostPerModelRow) codersdk.ChatCostModelBreakdown { + displayName := strings.TrimSpace(model.DisplayName) + if displayName == "" { + displayName = model.Model + } + return codersdk.ChatCostModelBreakdown{ + ModelConfigID: model.ModelConfigID, + DisplayName: displayName, + Provider: model.Provider, + Model: model.Model, + TotalCostMicros: model.TotalCostMicros, + MessageCount: model.MessageCount, + TotalInputTokens: model.TotalInputTokens, + TotalOutputTokens: model.TotalOutputTokens, + } +} + +func convertChatCostChatBreakdown(chat database.GetChatCostPerChatRow) codersdk.ChatCostChatBreakdown { + return codersdk.ChatCostChatBreakdown{ + RootChatID: chat.RootChatID, + ChatTitle: chat.ChatTitle, + TotalCostMicros: chat.TotalCostMicros, + MessageCount: chat.MessageCount, + TotalInputTokens: chat.TotalInputTokens, + TotalOutputTokens: chat.TotalOutputTokens, + } +} + +func convertChatCostUserRollup(user database.GetChatCostPerUserRow) codersdk.ChatCostUserRollup { + return codersdk.ChatCostUserRollup{ + UserID: user.UserID, + Username: user.Username, + Name: user.Name, + AvatarURL: user.AvatarURL, + TotalCostMicros: user.TotalCostMicros, + MessageCount: user.MessageCount, + ChatCount: user.ChatCount, + TotalInputTokens: user.TotalInputTokens, + TotalOutputTokens: user.TotalOutputTokens, + } +} + func convertChatQueuedMessage(m database.ChatQueuedMessage) codersdk.ChatQueuedMessage { return db2sdk.ChatQueuedMessage(m) } @@ -3117,7 +3342,7 @@ func validateChatModelCallConfig(modelConfig *codersdk.ChatModelCallConfig) erro pricingFields := []struct { name string - value *float64 + value *decimal.Decimal }{ {name: "cost.input_price_per_million_tokens", value: costConfig.InputPricePerMillionTokens}, {name: "cost.output_price_per_million_tokens", value: costConfig.OutputPricePerMillionTokens}, @@ 
-3125,7 +3350,7 @@ func validateChatModelCallConfig(modelConfig *codersdk.ChatModelCallConfig) erro {name: "cost.cache_write_price_per_million_tokens", value: costConfig.CacheWritePricePerMillionTokens}, } for _, field := range pricingFields { - if err := validateNonNegativeFloat64Field(field.name, field.value); err != nil { + if err := validateNonNegativeDecimalField(field.name, field.value); err != nil { return err } } @@ -3133,11 +3358,11 @@ func validateChatModelCallConfig(modelConfig *codersdk.ChatModelCallConfig) erro return nil } -func validateNonNegativeFloat64Field(name string, value *float64) error { +func validateNonNegativeDecimalField(name string, value *decimal.Decimal) error { if value == nil { return nil } - if *value < 0 { + if value.IsNegative() { return xerrors.Errorf("%s must be greater than or equal to zero", name) } return nil diff --git a/coderd/chats_test.go b/coderd/chats_test.go index 17c31202e5..bde9cb3957 100644 --- a/coderd/chats_test.go +++ b/coderd/chats_test.go @@ -13,6 +13,7 @@ import ( "time" "github.com/google/uuid" + "github.com/shopspring/decimal" "github.com/stretchr/testify/require" "github.com/coder/coder/v2/coderd/coderdtest" @@ -22,7 +23,6 @@ import ( "github.com/coder/coder/v2/coderd/database/dbfake" "github.com/coder/coder/v2/coderd/externalauth" coderdpubsub "github.com/coder/coder/v2/coderd/pubsub" - "github.com/coder/coder/v2/coderd/util/ptr" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/testutil" "github.com/coder/websocket" @@ -939,10 +939,10 @@ func TestListChatModelConfigs(t *testing.T) { require.Equal(t, storedConfig.ID, configs[0].ID) requireChatModelPricing(t, configs[0].ModelConfig, &codersdk.ChatModelCallConfig{ Cost: &codersdk.ModelCostConfig{ - InputPricePerMillionTokens: ptr.Ref(0.15), - OutputPricePerMillionTokens: ptr.Ref(0.6), - CacheReadPricePerMillionTokens: ptr.Ref(0.03), - CacheWritePricePerMillionTokens: ptr.Ref(0.3), + InputPricePerMillionTokens: decRef("0.15"), + 
OutputPricePerMillionTokens: decRef("0.6"), + CacheReadPricePerMillionTokens: decRef("0.03"), + CacheWritePricePerMillionTokens: decRef("0.3"), }, }) }) @@ -993,10 +993,10 @@ func TestCreateChatModelConfig(t *testing.T) { isDefault := true pricing := &codersdk.ChatModelCallConfig{ Cost: &codersdk.ModelCostConfig{ - InputPricePerMillionTokens: ptr.Ref(0.15), - OutputPricePerMillionTokens: ptr.Ref(0.6), - CacheReadPricePerMillionTokens: ptr.Ref(0.03), - CacheWritePricePerMillionTokens: ptr.Ref(0.3), + InputPricePerMillionTokens: decRef("0.15"), + OutputPricePerMillionTokens: decRef("0.6"), + CacheReadPricePerMillionTokens: decRef("0.03"), + CacheWritePricePerMillionTokens: decRef("0.3"), }, } modelConfig, err := client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{ @@ -1040,7 +1040,7 @@ func TestCreateChatModelConfig(t *testing.T) { ContextLimit: &contextLimit, ModelConfig: &codersdk.ChatModelCallConfig{ Cost: &codersdk.ModelCostConfig{ - InputPricePerMillionTokens: ptr.Ref(-0.01), + InputPricePerMillionTokens: decRef("-0.01"), }, }, }) @@ -1123,10 +1123,10 @@ func TestUpdateChatModelConfig(t *testing.T) { contextLimit := int64(8192) pricing := &codersdk.ChatModelCallConfig{ Cost: &codersdk.ModelCostConfig{ - InputPricePerMillionTokens: ptr.Ref(0.2), - OutputPricePerMillionTokens: ptr.Ref(0.8), - CacheReadPricePerMillionTokens: ptr.Ref(0.04), - CacheWritePricePerMillionTokens: ptr.Ref(0.4), + InputPricePerMillionTokens: decRef("0.2"), + OutputPricePerMillionTokens: decRef("0.8"), + CacheReadPricePerMillionTokens: decRef("0.04"), + CacheWritePricePerMillionTokens: decRef("0.4"), }, } updated, err := client.UpdateChatModelConfig(ctx, modelConfig.ID, codersdk.UpdateChatModelConfigRequest{ @@ -1157,7 +1157,7 @@ func TestUpdateChatModelConfig(t *testing.T) { _, err := client.UpdateChatModelConfig(ctx, modelConfig.ID, codersdk.UpdateChatModelConfigRequest{ ModelConfig: &codersdk.ChatModelCallConfig{ Cost: &codersdk.ModelCostConfig{ - 
OutputPricePerMillionTokens: ptr.Ref(-1.0), + OutputPricePerMillionTokens: decRef("-1.0"), }, }, }) @@ -3479,6 +3479,302 @@ func TestGetChatFile(t *testing.T) { }) } +func TestChatCostSummary(t *testing.T) { + t.Parallel() + + t.Run("BasicSummary", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client, db := newChatClientWithDatabase(t) + firstUser := coderdtest.CreateFirstUser(t, client) + modelConfig := createChatModelConfig(t, client) + + chat, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{ + OwnerID: firstUser.UserID, + LastModelConfigID: modelConfig.ID, + Title: "test chat", + }) + require.NoError(t, err) + + for i := 0; i < 2; i++ { + _, err = db.InsertChatMessage(dbauthz.AsSystemRestricted(ctx), database.InsertChatMessageParams{ + ChatID: chat.ID, + ModelConfigID: uuid.NullUUID{UUID: modelConfig.ID, Valid: true}, + Role: "assistant", + Visibility: database.ChatMessageVisibilityBoth, + InputTokens: sql.NullInt64{Int64: 100, Valid: true}, + OutputTokens: sql.NullInt64{Int64: 50, Valid: true}, + TotalCostMicros: sql.NullInt64{Int64: 500, Valid: true}, + }) + require.NoError(t, err) + } + + summary, err := client.GetChatCostSummary(ctx, "me", codersdk.ChatCostSummaryOptions{}) + require.NoError(t, err) + + require.Equal(t, int64(1000), summary.TotalCostMicros) + require.Equal(t, int64(2), summary.PricedMessageCount) + require.Equal(t, int64(0), summary.UnpricedMessageCount) + require.Equal(t, int64(200), summary.TotalInputTokens) + require.Equal(t, int64(100), summary.TotalOutputTokens) + + require.Len(t, summary.ByModel, 1) + require.Equal(t, modelConfig.ID, summary.ByModel[0].ModelConfigID) + require.Equal(t, int64(1000), summary.ByModel[0].TotalCostMicros) + require.Equal(t, int64(2), summary.ByModel[0].MessageCount) + + require.Len(t, summary.ByChat, 1) + require.Equal(t, chat.ID, summary.ByChat[0].RootChatID) + require.Equal(t, int64(1000), summary.ByChat[0].TotalCostMicros) + 
require.Equal(t, int64(2), summary.ByChat[0].MessageCount) + }) +} + +func TestChatCostSummary_AdminDrilldown(t *testing.T) { + t.Parallel() + + seedCtx := testutil.Context(t, testutil.WaitLong) + client, db := newChatClientWithDatabase(t) + firstUser := coderdtest.CreateFirstUser(t, client) + memberClient, member := coderdtest.CreateAnotherUser(t, client, firstUser.OrganizationID) + modelConfig := createChatModelConfig(t, client) + + chat, err := db.InsertChat(dbauthz.AsSystemRestricted(seedCtx), database.InsertChatParams{ + OwnerID: member.ID, + LastModelConfigID: modelConfig.ID, + Title: "member chat", + }) + require.NoError(t, err) + + _, err = db.InsertChatMessage(dbauthz.AsSystemRestricted(seedCtx), database.InsertChatMessageParams{ + ChatID: chat.ID, + ModelConfigID: uuid.NullUUID{UUID: modelConfig.ID, Valid: true}, + Role: "assistant", + Visibility: database.ChatMessageVisibilityBoth, + InputTokens: sql.NullInt64{Int64: 200, Valid: true}, + OutputTokens: sql.NullInt64{Int64: 100, Valid: true}, + TotalCostMicros: sql.NullInt64{Int64: 750, Valid: true}, + }) + require.NoError(t, err) + + t.Run("AdminCanDrilldown", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + summary, err := client.GetChatCostSummary(ctx, member.ID.String(), codersdk.ChatCostSummaryOptions{}) + require.NoError(t, err) + require.Equal(t, int64(750), summary.TotalCostMicros) + require.Equal(t, int64(1), summary.PricedMessageCount) + }) + + t.Run("MemberCannotDrilldownOtherUser", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + _, err := memberClient.GetChatCostSummary(ctx, firstUser.UserID.String(), codersdk.ChatCostSummaryOptions{}) + require.Error(t, err) + var sdkErr *codersdk.Error + require.ErrorAs(t, err, &sdkErr) + require.Equal(t, http.StatusNotFound, sdkErr.StatusCode()) + }) +} + +func TestChatCostUsers(t *testing.T) { + t.Parallel() + + seedCtx := testutil.Context(t, testutil.WaitLong) + client, db := 
newChatClientWithDatabase(t) + firstUser := coderdtest.CreateFirstUser(t, client) + memberClient, member := coderdtest.CreateAnotherUser(t, client, firstUser.OrganizationID) + firstUserRecord, err := db.GetUserByID(dbauthz.AsSystemRestricted(seedCtx), firstUser.UserID) + require.NoError(t, err) + modelConfig := createChatModelConfig(t, client) + + adminChat, err := db.InsertChat(dbauthz.AsSystemRestricted(seedCtx), database.InsertChatParams{ + OwnerID: firstUser.UserID, + LastModelConfigID: modelConfig.ID, + Title: "admin chat", + }) + require.NoError(t, err) + _, err = db.InsertChatMessage(dbauthz.AsSystemRestricted(seedCtx), database.InsertChatMessageParams{ + ChatID: adminChat.ID, + ModelConfigID: uuid.NullUUID{UUID: modelConfig.ID, Valid: true}, + Role: "assistant", + Visibility: database.ChatMessageVisibilityBoth, + InputTokens: sql.NullInt64{Int64: 100, Valid: true}, + OutputTokens: sql.NullInt64{Int64: 50, Valid: true}, + TotalCostMicros: sql.NullInt64{Int64: 300, Valid: true}, + }) + require.NoError(t, err) + + memberChat, err := db.InsertChat(dbauthz.AsSystemRestricted(seedCtx), database.InsertChatParams{ + OwnerID: member.ID, + LastModelConfigID: modelConfig.ID, + Title: "member chat", + }) + require.NoError(t, err) + _, err = db.InsertChatMessage(dbauthz.AsSystemRestricted(seedCtx), database.InsertChatMessageParams{ + ChatID: memberChat.ID, + ModelConfigID: uuid.NullUUID{UUID: modelConfig.ID, Valid: true}, + Role: "assistant", + Visibility: database.ChatMessageVisibilityBoth, + InputTokens: sql.NullInt64{Int64: 200, Valid: true}, + OutputTokens: sql.NullInt64{Int64: 100, Valid: true}, + TotalCostMicros: sql.NullInt64{Int64: 800, Valid: true}, + }) + require.NoError(t, err) + + t.Run("AdminCanListUsers", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + resp, err := client.GetChatCostUsers(ctx, codersdk.ChatCostUsersOptions{}) + require.NoError(t, err) + require.Equal(t, int64(2), resp.Count) + require.Len(t, 
resp.Users, 2) + require.Equal(t, member.ID, resp.Users[0].UserID) + require.Equal(t, member.Username, resp.Users[0].Username) + require.Equal(t, int64(800), resp.Users[0].TotalCostMicros) + require.Equal(t, int64(1), resp.Users[0].MessageCount) + require.Equal(t, int64(1), resp.Users[0].ChatCount) + require.Equal(t, firstUser.UserID, resp.Users[1].UserID) + require.Equal(t, firstUserRecord.Username, resp.Users[1].Username) + require.Equal(t, int64(300), resp.Users[1].TotalCostMicros) + }) + + t.Run("AdminCanFilterAndPaginateUsers", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + resp, err := client.GetChatCostUsers(ctx, codersdk.ChatCostUsersOptions{ + Username: member.Username, + Pagination: codersdk.Pagination{ + Limit: 1, + Offset: 0, + }, + }) + require.NoError(t, err) + require.Equal(t, int64(1), resp.Count) + require.Len(t, resp.Users, 1) + require.Equal(t, member.ID, resp.Users[0].UserID) + require.Equal(t, member.Username, resp.Users[0].Username) + }) + + t.Run("MemberCannotListUsers", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + _, err := memberClient.GetChatCostUsers(ctx, codersdk.ChatCostUsersOptions{}) + require.Error(t, err) + var sdkErr *codersdk.Error + require.ErrorAs(t, err, &sdkErr) + require.Equal(t, http.StatusForbidden, sdkErr.StatusCode()) + }) +} + +func TestChatCostSummary_DateRange(t *testing.T) { + t.Parallel() + + seedCtx := testutil.Context(t, testutil.WaitLong) + client, db := newChatClientWithDatabase(t) + firstUser := coderdtest.CreateFirstUser(t, client) + modelConfig := createChatModelConfig(t, client) + + chat, err := db.InsertChat(dbauthz.AsSystemRestricted(seedCtx), database.InsertChatParams{ + OwnerID: firstUser.UserID, + LastModelConfigID: modelConfig.ID, + Title: "date range test", + }) + require.NoError(t, err) + + _, err = db.InsertChatMessage(dbauthz.AsSystemRestricted(seedCtx), database.InsertChatMessageParams{ + ChatID: chat.ID, + 
ModelConfigID: uuid.NullUUID{UUID: modelConfig.ID, Valid: true}, + Role: "assistant", + Visibility: database.ChatMessageVisibilityBoth, + InputTokens: sql.NullInt64{Int64: 100, Valid: true}, + OutputTokens: sql.NullInt64{Int64: 50, Valid: true}, + TotalCostMicros: sql.NullInt64{Int64: 500, Valid: true}, + }) + require.NoError(t, err) + + now := time.Now() + + t.Run("MessageInRange", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + summary, err := client.GetChatCostSummary(ctx, "me", codersdk.ChatCostSummaryOptions{ + StartDate: now.Add(-time.Hour), + EndDate: now.Add(time.Hour), + }) + require.NoError(t, err) + require.Equal(t, int64(500), summary.TotalCostMicros) + require.Equal(t, int64(1), summary.PricedMessageCount) + }) + + t.Run("MessageOutOfRange", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + summary, err := client.GetChatCostSummary(ctx, "me", codersdk.ChatCostSummaryOptions{ + StartDate: now.Add(time.Hour), + EndDate: now.Add(2 * time.Hour), + }) + require.NoError(t, err) + require.Equal(t, int64(0), summary.TotalCostMicros) + require.Equal(t, int64(0), summary.PricedMessageCount) + }) +} + +func TestChatCostSummary_UnpricedMessages(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client, db := newChatClientWithDatabase(t) + firstUser := coderdtest.CreateFirstUser(t, client) + modelConfig := createChatModelConfig(t, client) + + chat, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{ + OwnerID: firstUser.UserID, + LastModelConfigID: modelConfig.ID, + Title: "unpriced test", + }) + require.NoError(t, err) + + _, err = db.InsertChatMessage(dbauthz.AsSystemRestricted(ctx), database.InsertChatMessageParams{ + ChatID: chat.ID, + ModelConfigID: uuid.NullUUID{UUID: modelConfig.ID, Valid: true}, + Role: "assistant", + Visibility: database.ChatMessageVisibilityBoth, + InputTokens: sql.NullInt64{Int64: 100, Valid: true}, + 
OutputTokens: sql.NullInt64{Int64: 50, Valid: true}, + TotalCostMicros: sql.NullInt64{Int64: 500, Valid: true}, + }) + require.NoError(t, err) + + _, err = db.InsertChatMessage(dbauthz.AsSystemRestricted(ctx), database.InsertChatMessageParams{ + ChatID: chat.ID, + ModelConfigID: uuid.NullUUID{UUID: modelConfig.ID, Valid: true}, + Role: "assistant", + Visibility: database.ChatMessageVisibilityBoth, + InputTokens: sql.NullInt64{Int64: 200, Valid: true}, + OutputTokens: sql.NullInt64{Int64: 75, Valid: true}, + TotalCostMicros: sql.NullInt64{}, + }) + require.NoError(t, err) + + summary, err := client.GetChatCostSummary(ctx, "me", codersdk.ChatCostSummaryOptions{}) + require.NoError(t, err) + + require.Equal(t, int64(500), summary.TotalCostMicros) + require.Equal(t, int64(1), summary.PricedMessageCount) + require.Equal(t, int64(1), summary.UnpricedMessageCount) + require.Equal(t, int64(300), summary.TotalInputTokens) + require.Equal(t, int64(125), summary.TotalOutputTokens) +} + func requireChatModelPricing( t *testing.T, actual *codersdk.ChatModelCallConfig, @@ -3495,10 +3791,15 @@ func requireChatModelPricing( require.NotNil(t, actual.Cost.CacheReadPricePerMillionTokens) require.NotNil(t, actual.Cost.CacheWritePricePerMillionTokens) - require.Equal(t, *expected.Cost.InputPricePerMillionTokens, *actual.Cost.InputPricePerMillionTokens) - require.Equal(t, *expected.Cost.OutputPricePerMillionTokens, *actual.Cost.OutputPricePerMillionTokens) - require.Equal(t, *expected.Cost.CacheReadPricePerMillionTokens, *actual.Cost.CacheReadPricePerMillionTokens) - require.Equal(t, *expected.Cost.CacheWritePricePerMillionTokens, *actual.Cost.CacheWritePricePerMillionTokens) + require.True(t, expected.Cost.InputPricePerMillionTokens.Equal(*actual.Cost.InputPricePerMillionTokens)) + require.True(t, expected.Cost.OutputPricePerMillionTokens.Equal(*actual.Cost.OutputPricePerMillionTokens)) + require.True(t, 
expected.Cost.CacheReadPricePerMillionTokens.Equal(*actual.Cost.CacheReadPricePerMillionTokens)) + require.True(t, expected.Cost.CacheWritePricePerMillionTokens.Equal(*actual.Cost.CacheWritePricePerMillionTokens)) +} + +func decRef(value string) *decimal.Decimal { + d := decimal.RequireFromString(value) + return &d } func createChatModelConfig(t *testing.T, client *codersdk.Client) codersdk.ChatModelConfig { diff --git a/coderd/coderd.go b/coderd/coderd.go index 19caea0445..18966bd11d 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -1139,6 +1139,13 @@ func New(options *Options) *API { r.Post("/", api.postChats) r.Get("/models", api.listChatModels) r.Get("/watch", api.watchChats) + r.Route("/cost", func(r chi.Router) { + r.Get("/users", api.chatCostUsers) + r.Route("/{user}", func(r chi.Router) { + r.Use(httpmw.ExtractUserParam(options.Database)) + r.Get("/summary", api.chatCostSummary) + }) + }) r.Route("/files", func(r chi.Router) { r.Use(httpmw.RateLimit(options.FilesRateLimit, time.Minute)) r.Post("/", api.postChatFile) diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index 3e57a0cf90..2197fd8722 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -2426,6 +2426,34 @@ func (q *querier) GetChatByIDForUpdate(ctx context.Context, id uuid.UUID) (datab return fetch(q.log, q.auth, q.db.GetChatByIDForUpdate)(ctx, id) } +func (q *querier) GetChatCostPerChat(ctx context.Context, arg database.GetChatCostPerChatParams) ([]database.GetChatCostPerChatRow, error) { + if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceChat.WithOwner(arg.OwnerID.String())); err != nil { + return nil, err + } + return q.db.GetChatCostPerChat(ctx, arg) +} + +func (q *querier) GetChatCostPerModel(ctx context.Context, arg database.GetChatCostPerModelParams) ([]database.GetChatCostPerModelRow, error) { + if err := q.authorizeContext(ctx, policy.ActionRead, 
rbac.ResourceChat.WithOwner(arg.OwnerID.String())); err != nil { + return nil, err + } + return q.db.GetChatCostPerModel(ctx, arg) +} + +func (q *querier) GetChatCostPerUser(ctx context.Context, arg database.GetChatCostPerUserParams) ([]database.GetChatCostPerUserRow, error) { + if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceChat); err != nil { + return nil, err + } + return q.db.GetChatCostPerUser(ctx, arg) +} + +func (q *querier) GetChatCostSummary(ctx context.Context, arg database.GetChatCostSummaryParams) (database.GetChatCostSummaryRow, error) { + if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceChat.WithOwner(arg.OwnerID.String())); err != nil { + return database.GetChatCostSummaryRow{}, err + } + return q.db.GetChatCostSummary(ctx, arg) +} + func (q *querier) GetChatDiffStatusByChatID(ctx context.Context, chatID uuid.UUID) (database.ChatDiffStatus, error) { // Authorize read on the parent chat. _, err := q.GetChatByID(ctx, chatID) diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go index 4035f78cd5..778d670900 100644 --- a/coderd/database/dbauthz/dbauthz_test.go +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -438,6 +438,81 @@ func (s *MethodTestSuite) TestChats() { dbm.EXPECT().GetChatByIDForUpdate(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes() check.Args(chat.ID).Asserts(chat, policy.ActionRead).Returns(chat) })) + s.Run("GetChatCostPerChat", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + arg := database.GetChatCostPerChatParams{ + OwnerID: uuid.New(), + StartDate: time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC), + EndDate: time.Date(2025, 2, 1, 0, 0, 0, 0, time.UTC), + } + rows := []database.GetChatCostPerChatRow{{ + RootChatID: uuid.New(), + ChatTitle: "chat-cost", + TotalCostMicros: 123, + MessageCount: 4, + TotalInputTokens: 55, + TotalOutputTokens: 89, + }} + dbm.EXPECT().GetChatCostPerChat(gomock.Any(), arg).Return(rows, nil).AnyTimes() + 
check.Args(arg).Asserts(rbac.ResourceChat.WithOwner(arg.OwnerID.String()), policy.ActionRead).Returns(rows) + })) + s.Run("GetChatCostPerModel", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + arg := database.GetChatCostPerModelParams{ + OwnerID: uuid.New(), + StartDate: time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC), + EndDate: time.Date(2025, 2, 1, 0, 0, 0, 0, time.UTC), + } + rows := []database.GetChatCostPerModelRow{{ + ModelConfigID: uuid.New(), + DisplayName: "GPT 4.1", + Provider: "openai", + Model: "gpt-4.1", + TotalCostMicros: 456, + MessageCount: 7, + TotalInputTokens: 144, + TotalOutputTokens: 233, + }} + dbm.EXPECT().GetChatCostPerModel(gomock.Any(), arg).Return(rows, nil).AnyTimes() + check.Args(arg).Asserts(rbac.ResourceChat.WithOwner(arg.OwnerID.String()), policy.ActionRead).Returns(rows) + })) + s.Run("GetChatCostPerUser", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + arg := database.GetChatCostPerUserParams{ + PageOffset: 0, + PageLimit: 25, + StartDate: time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC), + EndDate: time.Date(2025, 2, 1, 0, 0, 0, 0, time.UTC), + Username: "cost-user", + } + rows := []database.GetChatCostPerUserRow{{ + UserID: uuid.New(), + Username: "cost-user", + Name: "Cost User", + AvatarURL: "https://example.com/avatar.png", + TotalCostMicros: 789, + MessageCount: 11, + ChatCount: 3, + TotalInputTokens: 377, + TotalOutputTokens: 610, + TotalCount: 1, + }} + dbm.EXPECT().GetChatCostPerUser(gomock.Any(), arg).Return(rows, nil).AnyTimes() + check.Args(arg).Asserts(rbac.ResourceChat, policy.ActionRead).Returns(rows) + })) + s.Run("GetChatCostSummary", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + arg := database.GetChatCostSummaryParams{ + OwnerID: uuid.New(), + StartDate: time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC), + EndDate: time.Date(2025, 2, 1, 0, 0, 0, 0, time.UTC), + } + row := database.GetChatCostSummaryRow{ + TotalCostMicros: 987, + 
PricedMessageCount: 12, + UnpricedMessageCount: 2, + TotalInputTokens: 400, + TotalOutputTokens: 800, + } + dbm.EXPECT().GetChatCostSummary(gomock.Any(), arg).Return(row, nil).AnyTimes() + check.Args(arg).Asserts(rbac.ResourceChat.WithOwner(arg.OwnerID.String()), policy.ActionRead).Returns(row) + })) s.Run("GetChatDiffStatusByChatID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { chat := testutil.Fake(s.T(), faker, database.Chat{}) diffStatus := testutil.Fake(s.T(), faker, database.ChatDiffStatus{ChatID: chat.ID}) diff --git a/coderd/database/dbmetrics/querymetrics.go b/coderd/database/dbmetrics/querymetrics.go index c2286d0602..93ab9fd354 100644 --- a/coderd/database/dbmetrics/querymetrics.go +++ b/coderd/database/dbmetrics/querymetrics.go @@ -983,6 +983,38 @@ func (m queryMetricsStore) GetChatByIDForUpdate(ctx context.Context, id uuid.UUI return r0, r1 } +func (m queryMetricsStore) GetChatCostPerChat(ctx context.Context, arg database.GetChatCostPerChatParams) ([]database.GetChatCostPerChatRow, error) { + start := time.Now() + r0, r1 := m.s.GetChatCostPerChat(ctx, arg) + m.queryLatencies.WithLabelValues("GetChatCostPerChat").Observe(time.Since(start).Seconds()) + m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatCostPerChat").Inc() + return r0, r1 +} + +func (m queryMetricsStore) GetChatCostPerModel(ctx context.Context, arg database.GetChatCostPerModelParams) ([]database.GetChatCostPerModelRow, error) { + start := time.Now() + r0, r1 := m.s.GetChatCostPerModel(ctx, arg) + m.queryLatencies.WithLabelValues("GetChatCostPerModel").Observe(time.Since(start).Seconds()) + m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatCostPerModel").Inc() + return r0, r1 +} + +func (m queryMetricsStore) GetChatCostPerUser(ctx context.Context, arg database.GetChatCostPerUserParams) ([]database.GetChatCostPerUserRow, error) { + start := time.Now() + r0, r1 := 
m.s.GetChatCostPerUser(ctx, arg) + m.queryLatencies.WithLabelValues("GetChatCostPerUser").Observe(time.Since(start).Seconds()) + m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatCostPerUser").Inc() + return r0, r1 +} + +func (m queryMetricsStore) GetChatCostSummary(ctx context.Context, arg database.GetChatCostSummaryParams) (database.GetChatCostSummaryRow, error) { + start := time.Now() + r0, r1 := m.s.GetChatCostSummary(ctx, arg) + m.queryLatencies.WithLabelValues("GetChatCostSummary").Observe(time.Since(start).Seconds()) + m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatCostSummary").Inc() + return r0, r1 +} + func (m queryMetricsStore) GetChatDiffStatusByChatID(ctx context.Context, chatID uuid.UUID) (database.ChatDiffStatus, error) { start := time.Now() r0, r1 := m.s.GetChatDiffStatusByChatID(ctx, chatID) diff --git a/coderd/database/dbmock/dbmock.go b/coderd/database/dbmock/dbmock.go index 89c8493c85..26266f9dfc 100644 --- a/coderd/database/dbmock/dbmock.go +++ b/coderd/database/dbmock/dbmock.go @@ -1778,6 +1778,66 @@ func (mr *MockStoreMockRecorder) GetChatByIDForUpdate(ctx, id any) *gomock.Call return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatByIDForUpdate", reflect.TypeOf((*MockStore)(nil).GetChatByIDForUpdate), ctx, id) } +// GetChatCostPerChat mocks base method. +func (m *MockStore) GetChatCostPerChat(ctx context.Context, arg database.GetChatCostPerChatParams) ([]database.GetChatCostPerChatRow, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetChatCostPerChat", ctx, arg) + ret0, _ := ret[0].([]database.GetChatCostPerChatRow) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetChatCostPerChat indicates an expected call of GetChatCostPerChat. 
+func (mr *MockStoreMockRecorder) GetChatCostPerChat(ctx, arg any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatCostPerChat", reflect.TypeOf((*MockStore)(nil).GetChatCostPerChat), ctx, arg) +} + +// GetChatCostPerModel mocks base method. +func (m *MockStore) GetChatCostPerModel(ctx context.Context, arg database.GetChatCostPerModelParams) ([]database.GetChatCostPerModelRow, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetChatCostPerModel", ctx, arg) + ret0, _ := ret[0].([]database.GetChatCostPerModelRow) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetChatCostPerModel indicates an expected call of GetChatCostPerModel. +func (mr *MockStoreMockRecorder) GetChatCostPerModel(ctx, arg any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatCostPerModel", reflect.TypeOf((*MockStore)(nil).GetChatCostPerModel), ctx, arg) +} + +// GetChatCostPerUser mocks base method. +func (m *MockStore) GetChatCostPerUser(ctx context.Context, arg database.GetChatCostPerUserParams) ([]database.GetChatCostPerUserRow, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetChatCostPerUser", ctx, arg) + ret0, _ := ret[0].([]database.GetChatCostPerUserRow) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetChatCostPerUser indicates an expected call of GetChatCostPerUser. +func (mr *MockStoreMockRecorder) GetChatCostPerUser(ctx, arg any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatCostPerUser", reflect.TypeOf((*MockStore)(nil).GetChatCostPerUser), ctx, arg) +} + +// GetChatCostSummary mocks base method. 
+func (m *MockStore) GetChatCostSummary(ctx context.Context, arg database.GetChatCostSummaryParams) (database.GetChatCostSummaryRow, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetChatCostSummary", ctx, arg) + ret0, _ := ret[0].(database.GetChatCostSummaryRow) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetChatCostSummary indicates an expected call of GetChatCostSummary. +func (mr *MockStoreMockRecorder) GetChatCostSummary(ctx, arg any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatCostSummary", reflect.TypeOf((*MockStore)(nil).GetChatCostSummary), ctx, arg) +} + // GetChatDiffStatusByChatID mocks base method. func (m *MockStore) GetChatDiffStatusByChatID(ctx context.Context, chatID uuid.UUID) (database.ChatDiffStatus, error) { m.ctrl.T.Helper() diff --git a/coderd/database/dump.sql b/coderd/database/dump.sql index 784b6628cc..90f0b65b52 100644 --- a/coderd/database/dump.sql +++ b/coderd/database/dump.sql @@ -1226,7 +1226,8 @@ CREATE TABLE chat_messages ( context_limit bigint, compressed boolean DEFAULT false NOT NULL, created_by uuid, - content_version smallint NOT NULL + content_version smallint NOT NULL, + total_cost_micros bigint ); CREATE SEQUENCE chat_messages_id_seq @@ -3534,6 +3535,8 @@ CREATE INDEX idx_chat_messages_chat_created ON chat_messages USING btree (chat_i CREATE INDEX idx_chat_messages_compressed_summary_boundary ON chat_messages USING btree (chat_id, created_at DESC, id DESC) WHERE ((compressed = true) AND (role = 'system'::chat_message_role) AND (visibility = ANY (ARRAY['model'::chat_message_visibility, 'both'::chat_message_visibility]))); +CREATE INDEX idx_chat_messages_created_at ON chat_messages USING btree (created_at); + CREATE INDEX idx_chat_model_configs_enabled ON chat_model_configs USING btree (enabled); CREATE INDEX idx_chat_model_configs_provider ON chat_model_configs USING btree (provider); diff --git 
a/coderd/database/migrations/000435_add_cost_to_chat_messages.down.sql b/coderd/database/migrations/000435_add_cost_to_chat_messages.down.sql new file mode 100644 index 0000000000..471a9b5452 --- /dev/null +++ b/coderd/database/migrations/000435_add_cost_to_chat_messages.down.sql @@ -0,0 +1,3 @@ +DROP INDEX IF EXISTS idx_chat_messages_created_at; + +ALTER TABLE chat_messages DROP COLUMN total_cost_micros; diff --git a/coderd/database/migrations/000435_add_cost_to_chat_messages.up.sql b/coderd/database/migrations/000435_add_cost_to_chat_messages.up.sql new file mode 100644 index 0000000000..b17e47a882 --- /dev/null +++ b/coderd/database/migrations/000435_add_cost_to_chat_messages.up.sql @@ -0,0 +1,68 @@ +ALTER TABLE chat_messages ADD COLUMN total_cost_micros BIGINT; + +WITH message_costs AS ( + SELECT + msg.id, + ROUND( + COALESCE(msg.input_tokens, 0)::numeric * COALESCE(pricing.input_price, 0) + + COALESCE(msg.output_tokens, 0)::numeric * COALESCE(pricing.output_price, 0) + + COALESCE(msg.cache_read_tokens, 0)::numeric * COALESCE(pricing.cache_read_price, 0) + + COALESCE(msg.cache_creation_tokens, 0)::numeric * COALESCE(pricing.cache_write_price, 0) + )::bigint AS total_cost_micros + FROM + chat_messages AS msg + JOIN + chat_model_configs AS cfg + ON + cfg.id = msg.model_config_id + CROSS JOIN LATERAL ( + SELECT + COALESCE( + (cfg.options -> 'cost' ->> 'input_price_per_million_tokens')::numeric, + (cfg.options ->> 'input_price_per_million_tokens')::numeric + ) AS input_price, + COALESCE( + (cfg.options -> 'cost' ->> 'output_price_per_million_tokens')::numeric, + (cfg.options ->> 'output_price_per_million_tokens')::numeric + ) AS output_price, + COALESCE( + (cfg.options -> 'cost' ->> 'cache_read_price_per_million_tokens')::numeric, + (cfg.options ->> 'cache_read_price_per_million_tokens')::numeric + ) AS cache_read_price, + COALESCE( + (cfg.options -> 'cost' ->> 'cache_write_price_per_million_tokens')::numeric, + (cfg.options ->> 
'cache_write_price_per_million_tokens')::numeric + ) AS cache_write_price + ) AS pricing + WHERE + msg.total_cost_micros IS NULL + AND ( + msg.input_tokens IS NOT NULL + OR msg.output_tokens IS NOT NULL + OR msg.reasoning_tokens IS NOT NULL + OR msg.cache_creation_tokens IS NOT NULL + OR msg.cache_read_tokens IS NOT NULL + ) + AND ( + pricing.input_price IS NOT NULL + OR pricing.output_price IS NOT NULL + OR pricing.cache_read_price IS NOT NULL + OR pricing.cache_write_price IS NOT NULL + ) + AND ( + (msg.input_tokens IS NOT NULL AND pricing.input_price IS NOT NULL) + OR (msg.output_tokens IS NOT NULL AND pricing.output_price IS NOT NULL) + OR (msg.cache_read_tokens IS NOT NULL AND pricing.cache_read_price IS NOT NULL) + OR (msg.cache_creation_tokens IS NOT NULL AND pricing.cache_write_price IS NOT NULL) + ) +) +UPDATE + chat_messages AS msg +SET + total_cost_micros = message_costs.total_cost_micros +FROM + message_costs +WHERE + msg.id = message_costs.id; + +CREATE INDEX idx_chat_messages_created_at ON chat_messages (created_at); diff --git a/coderd/database/models.go b/coderd/database/models.go index 948ab02739..633b9d2ec4 100644 --- a/coderd/database/models.go +++ b/coderd/database/models.go @@ -4020,6 +4020,7 @@ type ChatMessage struct { Compressed bool `db:"compressed" json:"compressed"` CreatedBy uuid.NullUUID `db:"created_by" json:"created_by"` ContentVersion int16 `db:"content_version" json:"content_version"` + TotalCostMicros sql.NullInt64 `db:"total_cost_micros" json:"total_cost_micros"` } type ChatModelConfig struct { diff --git a/coderd/database/querier.go b/coderd/database/querier.go index 8ebd3c001d..e01ae7c66f 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -215,6 +215,19 @@ type sqlcQuerier interface { GetAuthorizationUserRoles(ctx context.Context, userID uuid.UUID) (GetAuthorizationUserRolesRow, error) GetChatByID(ctx context.Context, id uuid.UUID) (Chat, error) GetChatByIDForUpdate(ctx context.Context, id uuid.UUID) 
(Chat, error) + // Per-root-chat cost breakdown for a single user within a date range. + // Groups by root_chat_id so forked chats roll up under their root. + // Only counts assistant-role messages. + GetChatCostPerChat(ctx context.Context, arg GetChatCostPerChatParams) ([]GetChatCostPerChatRow, error) + // Per-model cost breakdown for a single user within a date range. + // Only counts assistant-role messages that have a model_config_id. + GetChatCostPerModel(ctx context.Context, arg GetChatCostPerModelParams) ([]GetChatCostPerModelRow, error) + // Deployment-wide per-user cost rollup within a date range. + // Only counts assistant-role messages. + GetChatCostPerUser(ctx context.Context, arg GetChatCostPerUserParams) ([]GetChatCostPerUserRow, error) + // Aggregate cost summary for a single user within a date range. + // Only counts assistant-role messages. + GetChatCostSummary(ctx context.Context, arg GetChatCostSummaryParams) (GetChatCostSummaryRow, error) GetChatDiffStatusByChatID(ctx context.Context, chatID uuid.UUID) (ChatDiffStatus, error) GetChatDiffStatusesByChatIDs(ctx context.Context, chatIds []uuid.UUID) ([]ChatDiffStatus, error) GetChatFileByID(ctx context.Context, id uuid.UUID) (ChatFile, error) diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 24de7b15b5..8a67e80e84 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -3229,6 +3229,353 @@ func (q *sqlQuerier) GetChatByIDForUpdate(ctx context.Context, id uuid.UUID) (Ch return i, err } +const getChatCostPerChat = `-- name: GetChatCostPerChat :many +WITH chat_costs AS ( + SELECT + COALESCE(c.root_chat_id, c.id) AS root_chat_id, + COALESCE(SUM(cm.total_cost_micros), 0)::bigint AS total_cost_micros, + COUNT(*) FILTER ( + WHERE cm.input_tokens IS NOT NULL + OR cm.output_tokens IS NOT NULL + OR cm.reasoning_tokens IS NOT NULL + OR cm.cache_creation_tokens IS NOT NULL + OR cm.cache_read_tokens IS NOT NULL + )::bigint AS message_count, + 
COALESCE(SUM(cm.input_tokens), 0)::bigint AS total_input_tokens, + COALESCE(SUM(cm.output_tokens), 0)::bigint AS total_output_tokens + FROM chat_messages cm + JOIN chats c ON c.id = cm.chat_id + WHERE c.owner_id = $1::uuid + AND cm.role = 'assistant' + AND cm.created_at >= $2::timestamptz + AND cm.created_at < $3::timestamptz + GROUP BY COALESCE(c.root_chat_id, c.id) +) +SELECT + cc.root_chat_id, + COALESCE(rc.title, '') AS chat_title, + cc.total_cost_micros, + cc.message_count, + cc.total_input_tokens, + cc.total_output_tokens +FROM chat_costs cc +LEFT JOIN chats rc ON rc.id = cc.root_chat_id +ORDER BY cc.total_cost_micros DESC +` + +type GetChatCostPerChatParams struct { + OwnerID uuid.UUID `db:"owner_id" json:"owner_id"` + StartDate time.Time `db:"start_date" json:"start_date"` + EndDate time.Time `db:"end_date" json:"end_date"` +} + +type GetChatCostPerChatRow struct { + RootChatID uuid.UUID `db:"root_chat_id" json:"root_chat_id"` + ChatTitle string `db:"chat_title" json:"chat_title"` + TotalCostMicros int64 `db:"total_cost_micros" json:"total_cost_micros"` + MessageCount int64 `db:"message_count" json:"message_count"` + TotalInputTokens int64 `db:"total_input_tokens" json:"total_input_tokens"` + TotalOutputTokens int64 `db:"total_output_tokens" json:"total_output_tokens"` +} + +// Per-root-chat cost breakdown for a single user within a date range. +// Groups by root_chat_id so forked chats roll up under their root. +// Only counts assistant-role messages. 
+func (q *sqlQuerier) GetChatCostPerChat(ctx context.Context, arg GetChatCostPerChatParams) ([]GetChatCostPerChatRow, error) { + rows, err := q.db.QueryContext(ctx, getChatCostPerChat, arg.OwnerID, arg.StartDate, arg.EndDate) + if err != nil { + return nil, err + } + defer rows.Close() + var items []GetChatCostPerChatRow + for rows.Next() { + var i GetChatCostPerChatRow + if err := rows.Scan( + &i.RootChatID, + &i.ChatTitle, + &i.TotalCostMicros, + &i.MessageCount, + &i.TotalInputTokens, + &i.TotalOutputTokens, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const getChatCostPerModel = `-- name: GetChatCostPerModel :many +SELECT + cmc.id AS model_config_id, + cmc.display_name, + cmc.provider, + cmc.model, + COALESCE(SUM(cm.total_cost_micros), 0)::bigint AS total_cost_micros, + COUNT(*) FILTER ( + WHERE cm.input_tokens IS NOT NULL + OR cm.output_tokens IS NOT NULL + OR cm.reasoning_tokens IS NOT NULL + OR cm.cache_creation_tokens IS NOT NULL + OR cm.cache_read_tokens IS NOT NULL + )::bigint AS message_count, + COALESCE(SUM(cm.input_tokens), 0)::bigint AS total_input_tokens, + COALESCE(SUM(cm.output_tokens), 0)::bigint AS total_output_tokens +FROM + chat_messages cm +JOIN + chats c ON c.id = cm.chat_id +JOIN + chat_model_configs cmc ON cmc.id = cm.model_config_id +WHERE + c.owner_id = $1::uuid + AND cm.role = 'assistant' + AND cm.created_at >= $2::timestamptz + AND cm.created_at < $3::timestamptz +GROUP BY + cmc.id, cmc.display_name, cmc.provider, cmc.model +ORDER BY + total_cost_micros DESC +` + +type GetChatCostPerModelParams struct { + OwnerID uuid.UUID `db:"owner_id" json:"owner_id"` + StartDate time.Time `db:"start_date" json:"start_date"` + EndDate time.Time `db:"end_date" json:"end_date"` +} + +type GetChatCostPerModelRow struct { + ModelConfigID uuid.UUID `db:"model_config_id" 
json:"model_config_id"` + DisplayName string `db:"display_name" json:"display_name"` + Provider string `db:"provider" json:"provider"` + Model string `db:"model" json:"model"` + TotalCostMicros int64 `db:"total_cost_micros" json:"total_cost_micros"` + MessageCount int64 `db:"message_count" json:"message_count"` + TotalInputTokens int64 `db:"total_input_tokens" json:"total_input_tokens"` + TotalOutputTokens int64 `db:"total_output_tokens" json:"total_output_tokens"` +} + +// Per-model cost breakdown for a single user within a date range. +// Only counts assistant-role messages that have a model_config_id. +func (q *sqlQuerier) GetChatCostPerModel(ctx context.Context, arg GetChatCostPerModelParams) ([]GetChatCostPerModelRow, error) { + rows, err := q.db.QueryContext(ctx, getChatCostPerModel, arg.OwnerID, arg.StartDate, arg.EndDate) + if err != nil { + return nil, err + } + defer rows.Close() + var items []GetChatCostPerModelRow + for rows.Next() { + var i GetChatCostPerModelRow + if err := rows.Scan( + &i.ModelConfigID, + &i.DisplayName, + &i.Provider, + &i.Model, + &i.TotalCostMicros, + &i.MessageCount, + &i.TotalInputTokens, + &i.TotalOutputTokens, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const getChatCostPerUser = `-- name: GetChatCostPerUser :many +WITH chat_cost_users AS ( + SELECT + c.owner_id AS user_id, + u.username, + u.name, + u.avatar_url, + COALESCE(SUM(cm.total_cost_micros), 0)::bigint AS total_cost_micros, + COUNT(*) FILTER ( + WHERE cm.input_tokens IS NOT NULL + OR cm.output_tokens IS NOT NULL + OR cm.reasoning_tokens IS NOT NULL + OR cm.cache_creation_tokens IS NOT NULL + OR cm.cache_read_tokens IS NOT NULL + )::bigint AS message_count, + COUNT(DISTINCT COALESCE(c.root_chat_id, c.id))::bigint AS chat_count, + COALESCE(SUM(cm.input_tokens), 0)::bigint AS total_input_tokens, + 
COALESCE(SUM(cm.output_tokens), 0)::bigint AS total_output_tokens + FROM + chat_messages cm + JOIN + chats c ON c.id = cm.chat_id + JOIN + users u ON u.id = c.owner_id + WHERE + cm.role = 'assistant' + AND cm.created_at >= $3::timestamptz + AND cm.created_at < $4::timestamptz + AND ( + $5::text = '' + OR u.username ILIKE '%' || $5::text || '%' + ) + GROUP BY + c.owner_id, + u.username, + u.name, + u.avatar_url +) +SELECT + user_id, + username, + name, + avatar_url, + total_cost_micros, + message_count, + chat_count, + total_input_tokens, + total_output_tokens, + COUNT(*) OVER()::bigint AS total_count +FROM + chat_cost_users +ORDER BY + total_cost_micros DESC, + username ASC +LIMIT + $2::int +OFFSET + $1::int +` + +type GetChatCostPerUserParams struct { + PageOffset int32 `db:"page_offset" json:"page_offset"` + PageLimit int32 `db:"page_limit" json:"page_limit"` + StartDate time.Time `db:"start_date" json:"start_date"` + EndDate time.Time `db:"end_date" json:"end_date"` + Username string `db:"username" json:"username"` +} + +type GetChatCostPerUserRow struct { + UserID uuid.UUID `db:"user_id" json:"user_id"` + Username string `db:"username" json:"username"` + Name string `db:"name" json:"name"` + AvatarURL string `db:"avatar_url" json:"avatar_url"` + TotalCostMicros int64 `db:"total_cost_micros" json:"total_cost_micros"` + MessageCount int64 `db:"message_count" json:"message_count"` + ChatCount int64 `db:"chat_count" json:"chat_count"` + TotalInputTokens int64 `db:"total_input_tokens" json:"total_input_tokens"` + TotalOutputTokens int64 `db:"total_output_tokens" json:"total_output_tokens"` + TotalCount int64 `db:"total_count" json:"total_count"` +} + +// Deployment-wide per-user cost rollup within a date range. +// Only counts assistant-role messages. 
+func (q *sqlQuerier) GetChatCostPerUser(ctx context.Context, arg GetChatCostPerUserParams) ([]GetChatCostPerUserRow, error) { + rows, err := q.db.QueryContext(ctx, getChatCostPerUser, + arg.PageOffset, + arg.PageLimit, + arg.StartDate, + arg.EndDate, + arg.Username, + ) + if err != nil { + return nil, err + } + defer rows.Close() + var items []GetChatCostPerUserRow + for rows.Next() { + var i GetChatCostPerUserRow + if err := rows.Scan( + &i.UserID, + &i.Username, + &i.Name, + &i.AvatarURL, + &i.TotalCostMicros, + &i.MessageCount, + &i.ChatCount, + &i.TotalInputTokens, + &i.TotalOutputTokens, + &i.TotalCount, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const getChatCostSummary = `-- name: GetChatCostSummary :one +SELECT + COALESCE(SUM(cm.total_cost_micros), 0)::bigint AS total_cost_micros, + COUNT(*) FILTER ( + WHERE cm.total_cost_micros IS NOT NULL + )::bigint AS priced_message_count, + COUNT(*) FILTER ( + WHERE cm.total_cost_micros IS NULL + AND ( + cm.input_tokens IS NOT NULL + OR cm.output_tokens IS NOT NULL + OR cm.reasoning_tokens IS NOT NULL + OR cm.cache_creation_tokens IS NOT NULL + OR cm.cache_read_tokens IS NOT NULL + ) + )::bigint AS unpriced_message_count, + COALESCE(SUM(cm.input_tokens), 0)::bigint AS total_input_tokens, + COALESCE(SUM(cm.output_tokens), 0)::bigint AS total_output_tokens +FROM + chat_messages cm +JOIN + chats c ON c.id = cm.chat_id +WHERE + c.owner_id = $1::uuid + AND cm.role = 'assistant' + AND cm.created_at >= $2::timestamptz + AND cm.created_at < $3::timestamptz +` + +type GetChatCostSummaryParams struct { + OwnerID uuid.UUID `db:"owner_id" json:"owner_id"` + StartDate time.Time `db:"start_date" json:"start_date"` + EndDate time.Time `db:"end_date" json:"end_date"` +} + +type GetChatCostSummaryRow struct { + TotalCostMicros int64 `db:"total_cost_micros" 
json:"total_cost_micros"` + PricedMessageCount int64 `db:"priced_message_count" json:"priced_message_count"` + UnpricedMessageCount int64 `db:"unpriced_message_count" json:"unpriced_message_count"` + TotalInputTokens int64 `db:"total_input_tokens" json:"total_input_tokens"` + TotalOutputTokens int64 `db:"total_output_tokens" json:"total_output_tokens"` +} + +// Aggregate cost summary for a single user within a date range. +// Only counts assistant-role messages. +func (q *sqlQuerier) GetChatCostSummary(ctx context.Context, arg GetChatCostSummaryParams) (GetChatCostSummaryRow, error) { + row := q.db.QueryRowContext(ctx, getChatCostSummary, arg.OwnerID, arg.StartDate, arg.EndDate) + var i GetChatCostSummaryRow + err := row.Scan( + &i.TotalCostMicros, + &i.PricedMessageCount, + &i.UnpricedMessageCount, + &i.TotalInputTokens, + &i.TotalOutputTokens, + ) + return i, err +} + const getChatDiffStatusByChatID = `-- name: GetChatDiffStatusByChatID :one SELECT chat_id, url, pull_request_state, changes_requested, additions, deletions, changed_files, refreshed_at, stale_at, created_at, updated_at, git_branch, git_remote_origin, pull_request_title, pull_request_draft @@ -3311,7 +3658,7 @@ func (q *sqlQuerier) GetChatDiffStatusesByChatIDs(ctx context.Context, chatIds [ const getChatMessageByID = `-- name: GetChatMessageByID :one SELECT - id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version + id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros FROM chat_messages WHERE @@ -3339,13 +3686,14 @@ func (q *sqlQuerier) GetChatMessageByID(ctx context.Context, id int64) (ChatMess &i.Compressed, &i.CreatedBy, &i.ContentVersion, 
+ &i.TotalCostMicros, ) return i, err } const getChatMessagesByChatID = `-- name: GetChatMessagesByChatID :many SELECT - id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version + id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros FROM chat_messages WHERE @@ -3388,6 +3736,7 @@ func (q *sqlQuerier) GetChatMessagesByChatID(ctx context.Context, arg GetChatMes &i.Compressed, &i.CreatedBy, &i.ContentVersion, + &i.TotalCostMicros, ); err != nil { return nil, err } @@ -3419,7 +3768,7 @@ WITH latest_compressed_summary AS ( 1 ) SELECT - id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version + id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros FROM chat_messages WHERE @@ -3486,6 +3835,7 @@ func (q *sqlQuerier) GetChatMessagesForPromptByChatID(ctx context.Context, chatI &i.Compressed, &i.CreatedBy, &i.ContentVersion, + &i.TotalCostMicros, ); err != nil { return nil, err } @@ -3629,7 +3979,7 @@ func (q *sqlQuerier) GetChatsByOwnerID(ctx context.Context, arg GetChatsByOwnerI const getLastChatMessageByRole = `-- name: GetLastChatMessageByRole :one SELECT - id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, 
created_by, content_version + id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros FROM chat_messages WHERE @@ -3667,6 +4017,7 @@ func (q *sqlQuerier) GetLastChatMessageByRole(ctx context.Context, arg GetLastCh &i.Compressed, &i.CreatedBy, &i.ContentVersion, + &i.TotalCostMicros, ) return i, err } @@ -3806,7 +4157,8 @@ INSERT INTO chat_messages ( cache_creation_tokens, cache_read_tokens, context_limit, - compressed + compressed, + total_cost_micros ) VALUES ( $1::uuid, $2::uuid, @@ -3822,10 +4174,11 @@ INSERT INTO chat_messages ( $12::bigint, $13::bigint, $14::bigint, - COALESCE($15::boolean, FALSE) + COALESCE($15::boolean, FALSE), + $16::bigint ) RETURNING - id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version + id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros ` type InsertChatMessageParams struct { @@ -3844,6 +4197,7 @@ type InsertChatMessageParams struct { CacheReadTokens sql.NullInt64 `db:"cache_read_tokens" json:"cache_read_tokens"` ContextLimit sql.NullInt64 `db:"context_limit" json:"context_limit"` Compressed sql.NullBool `db:"compressed" json:"compressed"` + TotalCostMicros sql.NullInt64 `db:"total_cost_micros" json:"total_cost_micros"` } func (q *sqlQuerier) InsertChatMessage(ctx context.Context, arg InsertChatMessageParams) (ChatMessage, error) { @@ -3863,6 +4217,7 @@ func (q *sqlQuerier) InsertChatMessage(ctx context.Context, arg InsertChatMessag arg.CacheReadTokens, arg.ContextLimit, arg.Compressed, 
+ arg.TotalCostMicros, ) var i ChatMessage err := row.Scan( @@ -3883,6 +4238,7 @@ func (q *sqlQuerier) InsertChatMessage(ctx context.Context, arg InsertChatMessag &i.Compressed, &i.CreatedBy, &i.ContentVersion, + &i.TotalCostMicros, ) return i, err } @@ -4017,7 +4373,7 @@ SET WHERE id = $3::bigint RETURNING - id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version + id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros ` type UpdateChatMessageByIDParams struct { @@ -4047,6 +4403,7 @@ func (q *sqlQuerier) UpdateChatMessageByID(ctx context.Context, arg UpdateChatMe &i.Compressed, &i.CreatedBy, &i.ContentVersion, + &i.TotalCostMicros, ) return i, err } diff --git a/coderd/database/queries/chats.sql b/coderd/database/queries/chats.sql index d341b9d931..e18f913173 100644 --- a/coderd/database/queries/chats.sql +++ b/coderd/database/queries/chats.sql @@ -179,7 +179,8 @@ INSERT INTO chat_messages ( cache_creation_tokens, cache_read_tokens, context_limit, - compressed + compressed, + total_cost_micros ) VALUES ( @chat_id::uuid, sqlc.narg('created_by')::uuid, @@ -195,7 +196,8 @@ INSERT INTO chat_messages ( sqlc.narg('cache_creation_tokens')::bigint, sqlc.narg('cache_read_tokens')::bigint, sqlc.narg('context_limit')::bigint, - COALESCE(sqlc.narg('compressed')::boolean, FALSE) + COALESCE(sqlc.narg('compressed')::boolean, FALSE), + sqlc.narg('total_cost_micros')::bigint ) RETURNING *; @@ -481,3 +483,164 @@ SET updated_at = NOW() WHERE chat_id = @chat_id::uuid; + +-- name: GetChatCostSummary :one +-- Aggregate cost summary for a single user within a date range. +-- Only counts assistant-role messages. 
+SELECT + COALESCE(SUM(cm.total_cost_micros), 0)::bigint AS total_cost_micros, + COUNT(*) FILTER ( + WHERE cm.total_cost_micros IS NOT NULL + )::bigint AS priced_message_count, + COUNT(*) FILTER ( + WHERE cm.total_cost_micros IS NULL + AND ( + cm.input_tokens IS NOT NULL + OR cm.output_tokens IS NOT NULL + OR cm.reasoning_tokens IS NOT NULL + OR cm.cache_creation_tokens IS NOT NULL + OR cm.cache_read_tokens IS NOT NULL + ) + )::bigint AS unpriced_message_count, + COALESCE(SUM(cm.input_tokens), 0)::bigint AS total_input_tokens, + COALESCE(SUM(cm.output_tokens), 0)::bigint AS total_output_tokens +FROM + chat_messages cm +JOIN + chats c ON c.id = cm.chat_id +WHERE + c.owner_id = @owner_id::uuid + AND cm.role = 'assistant' + AND cm.created_at >= @start_date::timestamptz + AND cm.created_at < @end_date::timestamptz; + +-- name: GetChatCostPerModel :many +-- Per-model cost breakdown for a single user within a date range. +-- Only counts assistant-role messages that have a model_config_id. +SELECT + cmc.id AS model_config_id, + cmc.display_name, + cmc.provider, + cmc.model, + COALESCE(SUM(cm.total_cost_micros), 0)::bigint AS total_cost_micros, + COUNT(*) FILTER ( + WHERE cm.input_tokens IS NOT NULL + OR cm.output_tokens IS NOT NULL + OR cm.reasoning_tokens IS NOT NULL + OR cm.cache_creation_tokens IS NOT NULL + OR cm.cache_read_tokens IS NOT NULL + )::bigint AS message_count, + COALESCE(SUM(cm.input_tokens), 0)::bigint AS total_input_tokens, + COALESCE(SUM(cm.output_tokens), 0)::bigint AS total_output_tokens +FROM + chat_messages cm +JOIN + chats c ON c.id = cm.chat_id +JOIN + chat_model_configs cmc ON cmc.id = cm.model_config_id +WHERE + c.owner_id = @owner_id::uuid + AND cm.role = 'assistant' + AND cm.created_at >= @start_date::timestamptz + AND cm.created_at < @end_date::timestamptz +GROUP BY + cmc.id, cmc.display_name, cmc.provider, cmc.model +ORDER BY + total_cost_micros DESC; + +-- name: GetChatCostPerChat :many +-- Per-root-chat cost breakdown for a single user 
within a date range. +-- Groups by root_chat_id so forked chats roll up under their root. +-- Only counts assistant-role messages. +WITH chat_costs AS ( + SELECT + COALESCE(c.root_chat_id, c.id) AS root_chat_id, + COALESCE(SUM(cm.total_cost_micros), 0)::bigint AS total_cost_micros, + COUNT(*) FILTER ( + WHERE cm.input_tokens IS NOT NULL + OR cm.output_tokens IS NOT NULL + OR cm.reasoning_tokens IS NOT NULL + OR cm.cache_creation_tokens IS NOT NULL + OR cm.cache_read_tokens IS NOT NULL + )::bigint AS message_count, + COALESCE(SUM(cm.input_tokens), 0)::bigint AS total_input_tokens, + COALESCE(SUM(cm.output_tokens), 0)::bigint AS total_output_tokens + FROM chat_messages cm + JOIN chats c ON c.id = cm.chat_id + WHERE c.owner_id = @owner_id::uuid + AND cm.role = 'assistant' + AND cm.created_at >= @start_date::timestamptz + AND cm.created_at < @end_date::timestamptz + GROUP BY COALESCE(c.root_chat_id, c.id) +) +SELECT + cc.root_chat_id, + COALESCE(rc.title, '') AS chat_title, + cc.total_cost_micros, + cc.message_count, + cc.total_input_tokens, + cc.total_output_tokens +FROM chat_costs cc +LEFT JOIN chats rc ON rc.id = cc.root_chat_id +ORDER BY cc.total_cost_micros DESC; + +-- name: GetChatCostPerUser :many +-- Deployment-wide per-user cost rollup within a date range. +-- Only counts assistant-role messages. 
+WITH chat_cost_users AS ( + SELECT + c.owner_id AS user_id, + u.username, + u.name, + u.avatar_url, + COALESCE(SUM(cm.total_cost_micros), 0)::bigint AS total_cost_micros, + COUNT(*) FILTER ( + WHERE cm.input_tokens IS NOT NULL + OR cm.output_tokens IS NOT NULL + OR cm.reasoning_tokens IS NOT NULL + OR cm.cache_creation_tokens IS NOT NULL + OR cm.cache_read_tokens IS NOT NULL + )::bigint AS message_count, + COUNT(DISTINCT COALESCE(c.root_chat_id, c.id))::bigint AS chat_count, + COALESCE(SUM(cm.input_tokens), 0)::bigint AS total_input_tokens, + COALESCE(SUM(cm.output_tokens), 0)::bigint AS total_output_tokens + FROM + chat_messages cm + JOIN + chats c ON c.id = cm.chat_id + JOIN + users u ON u.id = c.owner_id + WHERE + cm.role = 'assistant' + AND cm.created_at >= @start_date::timestamptz + AND cm.created_at < @end_date::timestamptz + AND ( + @username::text = '' + OR u.username ILIKE '%' || @username::text || '%' + ) + GROUP BY + c.owner_id, + u.username, + u.name, + u.avatar_url +) +SELECT + user_id, + username, + name, + avatar_url, + total_cost_micros, + message_count, + chat_count, + total_input_tokens, + total_output_tokens, + COUNT(*) OVER()::bigint AS total_count +FROM + chat_cost_users +ORDER BY + total_cost_micros DESC, + username ASC +LIMIT + sqlc.arg('page_limit')::int +OFFSET + sqlc.arg('page_offset')::int; diff --git a/coderd/database/sqlc.yaml b/coderd/database/sqlc.yaml index bc5e216726..b692606fab 100644 --- a/coderd/database/sqlc.yaml +++ b/coderd/database/sqlc.yaml @@ -148,6 +148,17 @@ sql: go_type: "database/sql.NullTime" - column: "task_event_data.first_status_after_resume_at" go_type: "database/sql.NullTime" + - db_type: "pg_catalog.numeric" + go_type: + import: "github.com/shopspring/decimal" + type: "Decimal" + package: "decimal" + - db_type: "pg_catalog.numeric" + nullable: true + go_type: + import: "github.com/shopspring/decimal" + type: "NullDecimal" + package: "decimal" rename: group_member: GroupMemberTable group_members_expanded: 
GroupMember diff --git a/codersdk/chats.go b/codersdk/chats.go index 9881b748be..f793661622 100644 --- a/codersdk/chats.go +++ b/codersdk/chats.go @@ -7,10 +7,13 @@ import ( "io" "mime" "net/http" + "net/url" + "strconv" "strings" "time" "github.com/google/uuid" + "github.com/shopspring/decimal" "github.com/coder/websocket" "github.com/coder/websocket/wsjson" @@ -198,11 +201,11 @@ func ChatMessageFileReference(fileName string, startLine, endLine int, content s } // ChatMessageSource builds a source chat message part. -func ChatMessageSource(sourceID, url, title string) ChatMessagePart { +func ChatMessageSource(sourceID, sourceURL, title string) ChatMessagePart { return ChatMessagePart{ Type: ChatMessagePartTypeSource, SourceID: sourceID, - URL: url, + URL: sourceURL, Title: title, } } @@ -519,13 +522,10 @@ type ChatModelVercelProviderOptions struct { // ModelCostConfig stores pricing metadata for a chat model. type ModelCostConfig struct { - // Pricing is stored as configuration metadata and currently only needs to - // round-trip cleanly through the API and admin UI. If we later use these - // values for billing-grade arithmetic, switch to a fixed-point type. 
- InputPricePerMillionTokens *float64 `json:"input_price_per_million_tokens,omitempty" description:"Input token price in USD per 1M tokens"` - OutputPricePerMillionTokens *float64 `json:"output_price_per_million_tokens,omitempty" description:"Output token price in USD per 1M tokens"` - CacheReadPricePerMillionTokens *float64 `json:"cache_read_price_per_million_tokens,omitempty" description:"Cache read token price in USD per 1M tokens"` - CacheWritePricePerMillionTokens *float64 `json:"cache_write_price_per_million_tokens,omitempty" description:"Cache write or cache creation token price in USD per 1M tokens"` + InputPricePerMillionTokens *decimal.Decimal `json:"input_price_per_million_tokens,omitempty" description:"Input token price in USD per 1M tokens"` + OutputPricePerMillionTokens *decimal.Decimal `json:"output_price_per_million_tokens,omitempty" description:"Output token price in USD per 1M tokens"` + CacheReadPricePerMillionTokens *decimal.Decimal `json:"cache_read_price_per_million_tokens,omitempty" description:"Cache read token price in USD per 1M tokens"` + CacheWritePricePerMillionTokens *decimal.Decimal `json:"cache_write_price_per_million_tokens,omitempty" description:"Cache write or cache creation token price in USD per 1M tokens"` } // ChatModelCallConfig configures per-call model behavior defaults. 
@@ -546,10 +546,10 @@ func (c *ChatModelCallConfig) UnmarshalJSON(data []byte) error { type chatModelCallConfigAlias ChatModelCallConfig aux := struct { *chatModelCallConfigAlias - InputPricePerMillionTokens *float64 `json:"input_price_per_million_tokens,omitempty"` - OutputPricePerMillionTokens *float64 `json:"output_price_per_million_tokens,omitempty"` - CacheReadPricePerMillionTokens *float64 `json:"cache_read_price_per_million_tokens,omitempty"` - CacheWritePricePerMillionTokens *float64 `json:"cache_write_price_per_million_tokens,omitempty"` + InputPricePerMillionTokens *decimal.Decimal `json:"input_price_per_million_tokens,omitempty"` + OutputPricePerMillionTokens *decimal.Decimal `json:"output_price_per_million_tokens,omitempty"` + CacheReadPricePerMillionTokens *decimal.Decimal `json:"cache_read_price_per_million_tokens,omitempty"` + CacheWritePricePerMillionTokens *decimal.Decimal `json:"cache_write_price_per_million_tokens,omitempty"` }{ chatModelCallConfigAlias: (*chatModelCallConfigAlias)(c), } @@ -710,6 +710,76 @@ type chatStreamEnvelope struct { Data json.RawMessage `json:"data,omitempty"` } +// ChatCostSummaryOptions are optional query parameters for GetChatCostSummary. +type ChatCostSummaryOptions struct { + StartDate time.Time + EndDate time.Time +} + +// ChatCostUsersOptions are optional query parameters for GetChatCostUsers. +type ChatCostUsersOptions struct { + StartDate time.Time + EndDate time.Time + Username string + Pagination +} + +// ChatCostSummary is the response from the chat cost summary endpoint. 
+type ChatCostSummary struct { + StartDate time.Time `json:"start_date" format:"date-time"` + EndDate time.Time `json:"end_date" format:"date-time"` + TotalCostMicros int64 `json:"total_cost_micros"` + PricedMessageCount int64 `json:"priced_message_count"` + UnpricedMessageCount int64 `json:"unpriced_message_count"` + TotalInputTokens int64 `json:"total_input_tokens"` + TotalOutputTokens int64 `json:"total_output_tokens"` + ByModel []ChatCostModelBreakdown `json:"by_model"` + ByChat []ChatCostChatBreakdown `json:"by_chat"` +} + +// ChatCostModelBreakdown contains per-model cost aggregation. +type ChatCostModelBreakdown struct { + ModelConfigID uuid.UUID `json:"model_config_id" format:"uuid"` + DisplayName string `json:"display_name"` + Provider string `json:"provider"` + Model string `json:"model"` + TotalCostMicros int64 `json:"total_cost_micros"` + MessageCount int64 `json:"message_count"` + TotalInputTokens int64 `json:"total_input_tokens"` + TotalOutputTokens int64 `json:"total_output_tokens"` +} + +// ChatCostChatBreakdown contains per-root-chat cost aggregation. +type ChatCostChatBreakdown struct { + RootChatID uuid.UUID `json:"root_chat_id" format:"uuid"` + ChatTitle string `json:"chat_title"` + TotalCostMicros int64 `json:"total_cost_micros"` + MessageCount int64 `json:"message_count"` + TotalInputTokens int64 `json:"total_input_tokens"` + TotalOutputTokens int64 `json:"total_output_tokens"` +} + +// ChatCostUserRollup contains per-user cost aggregation for admin views. 
+type ChatCostUserRollup struct { + UserID uuid.UUID `json:"user_id" format:"uuid"` + Username string `json:"username"` + Name string `json:"name"` + AvatarURL string `json:"avatar_url"` + TotalCostMicros int64 `json:"total_cost_micros"` + MessageCount int64 `json:"message_count"` + ChatCount int64 `json:"chat_count"` + TotalInputTokens int64 `json:"total_input_tokens"` + TotalOutputTokens int64 `json:"total_output_tokens"` +} + +// ChatCostUsersResponse is the response from the admin chat cost users endpoint. +type ChatCostUsersResponse struct { + StartDate time.Time `json:"start_date" format:"date-time"` + EndDate time.Time `json:"end_date" format:"date-time"` + Count int64 `json:"count"` + Users []ChatCostUserRollup `json:"users"` +} + // ListChatsOptions are optional parameters for ListChats. type ListChatsOptions struct { Query string @@ -872,6 +942,71 @@ func (c *Client) DeleteChatModelConfig(ctx context.Context, modelConfigID uuid.U return nil } +// GetChatCostSummary returns an aggregate cost summary for the specified +// user. Zero-valued StartDate or EndDate fields are omitted from the +// request, letting the server apply its own defaults (typically the last +// 30 days). +func (c *Client) GetChatCostSummary(ctx context.Context, user string, opts ChatCostSummaryOptions) (ChatCostSummary, error) { + qp := url.Values{} + if !opts.StartDate.IsZero() { + qp.Set("start_date", opts.StartDate.Format(time.RFC3339)) + } + if !opts.EndDate.IsZero() { + qp.Set("end_date", opts.EndDate.Format(time.RFC3339)) + } + reqURL := fmt.Sprintf("/api/experimental/chats/cost/%s/summary", user) + if len(qp) > 0 { + reqURL += "?" 
+ qp.Encode() + } + res, err := c.Request(ctx, http.MethodGet, reqURL, nil) + if err != nil { + return ChatCostSummary{}, err + } + defer res.Body.Close() + if res.StatusCode != http.StatusOK { + return ChatCostSummary{}, ReadBodyAsError(res) + } + var summary ChatCostSummary + return summary, json.NewDecoder(res.Body).Decode(&summary) +} + +// GetChatCostUsers returns a per-user cost rollup for the deployment +// (admin only). Zero-valued StartDate or EndDate fields are omitted from +// the request, letting the server apply its own defaults (typically the +// last 30 days). +func (c *Client) GetChatCostUsers(ctx context.Context, opts ChatCostUsersOptions) (ChatCostUsersResponse, error) { + qp := url.Values{} + if !opts.StartDate.IsZero() { + qp.Set("start_date", opts.StartDate.Format(time.RFC3339)) + } + if !opts.EndDate.IsZero() { + qp.Set("end_date", opts.EndDate.Format(time.RFC3339)) + } + if opts.Username != "" { + qp.Set("username", opts.Username) + } + if opts.Limit > 0 { + qp.Set("limit", strconv.Itoa(opts.Limit)) + } + if opts.Offset > 0 { + qp.Set("offset", strconv.Itoa(opts.Offset)) + } + reqURL := "/api/experimental/chats/cost/users" + if len(qp) > 0 { + reqURL += "?" + qp.Encode() + } + res, err := c.Request(ctx, http.MethodGet, reqURL, nil) + if err != nil { + return ChatCostUsersResponse{}, err + } + defer res.Body.Close() + if res.StatusCode != http.StatusOK { + return ChatCostUsersResponse{}, ReadBodyAsError(res) + } + var resp ChatCostUsersResponse + return resp, json.NewDecoder(res.Body).Decode(&resp) +} + // GetChatSystemPrompt returns the deployment-wide chat system prompt. 
func (c *Client) GetChatSystemPrompt(ctx context.Context) (ChatSystemPromptResponse, error) { res, err := c.Request(ctx, http.MethodGet, "/api/experimental/chats/config/system-prompt", nil) diff --git a/codersdk/chats_test.go b/codersdk/chats_test.go index 5e61aa81b5..7526ea23ce 100644 --- a/codersdk/chats_test.go +++ b/codersdk/chats_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/google/uuid" + "github.com/shopspring/decimal" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -109,3 +110,65 @@ func TestChatMessagePart_StripInternal(t *testing.T) { assert.Equal(t, codersdk.ChatMessagePartTypeText, part.Type) }) } + +func TestModelCostConfig_LegacyNumericJSON(t *testing.T) { + t.Parallel() + + var decoded codersdk.ModelCostConfig + err := json.Unmarshal([]byte("{\"input_price_per_million_tokens\": 1.5}"), &decoded) + require.NoError(t, err) + require.NotNil(t, decoded.InputPricePerMillionTokens) + require.True(t, decoded.InputPricePerMillionTokens.Equal(decimal.RequireFromString("1.5"))) +} + +func TestModelCostConfig_QuotedDecimalJSON(t *testing.T) { + t.Parallel() + + var decoded codersdk.ModelCostConfig + err := json.Unmarshal([]byte("{\"input_price_per_million_tokens\": \"1.5\"}"), &decoded) + require.NoError(t, err) + require.NotNil(t, decoded.InputPricePerMillionTokens) + require.True(t, decoded.InputPricePerMillionTokens.Equal(decimal.RequireFromString("1.5"))) +} + +func TestModelCostConfig_NilVsZero(t *testing.T) { + t.Parallel() + + zero := decimal.Zero + raw, err := json.Marshal(struct { + Nil codersdk.ModelCostConfig `json:"nil"` + Zero codersdk.ModelCostConfig `json:"zero"` + }{ + Nil: codersdk.ModelCostConfig{}, + Zero: codersdk.ModelCostConfig{InputPricePerMillionTokens: &zero}, + }) + require.NoError(t, err) + require.Contains(t, string(raw), "\"zero\":{\"input_price_per_million_tokens\":\"0\"}") + require.Contains(t, string(raw), "\"nil\":{}") +} + +func TestChatModelCallConfig_UnmarshalLegacyPricing(t *testing.T) 
{ + t.Parallel() + + var decoded codersdk.ChatModelCallConfig + err := json.Unmarshal([]byte("{\"input_price_per_million_tokens\": 1.5}"), &decoded) + require.NoError(t, err) + require.NotNil(t, decoded.Cost) + require.NotNil(t, decoded.Cost.InputPricePerMillionTokens) + require.True(t, decoded.Cost.InputPricePerMillionTokens.Equal(decimal.RequireFromString("1.5"))) +} + +func TestChatCostSummary_JSONRoundTrip(t *testing.T) { + t.Parallel() + + original := codersdk.ChatCostSummary{ + TotalCostMicros: 123, + } + raw, err := json.Marshal(original) + require.NoError(t, err) + + var decoded codersdk.ChatCostSummary + err = json.Unmarshal(raw, &decoded) + require.NoError(t, err) + require.Equal(t, original.TotalCostMicros, decoded.TotalCostMicros) +} diff --git a/docs/reference/api/chat.md b/docs/reference/api/chat.md new file mode 100644 index 0000000000..279df4ad79 --- /dev/null +++ b/docs/reference/api/chat.md @@ -0,0 +1 @@ +# Chat diff --git a/go.mod b/go.mod index a22334d40b..a6f51ca425 100644 --- a/go.mod +++ b/go.mod @@ -492,6 +492,7 @@ require ( github.com/go-git/go-git/v5 v5.17.0 github.com/mark3labs/mcp-go v0.38.0 github.com/openai/openai-go/v3 v3.15.0 + github.com/shopspring/decimal v1.4.0 gonum.org/v1/gonum v0.17.0 ) diff --git a/go.sum b/go.sum index f18b4984f0..30ddc05e1d 100644 --- a/go.sum +++ b/go.sum @@ -1059,6 +1059,8 @@ github.com/sergi/go-diff v1.4.0 h1:n/SP9D5ad1fORl+llWyN+D6qoUETXNZARKjyY2/KVCw= github.com/sergi/go-diff v1.4.0/go.mod h1:A0bzQcvG0E7Rwjx0REVgAGH58e96+X0MeOfepqsbeW4= github.com/shirou/gopsutil/v4 v4.26.1 h1:TOkEyriIXk2HX9d4isZJtbjXbEjf5qyKPAzbzY0JWSo= github.com/shirou/gopsutil/v4 v4.26.1/go.mod h1:medLI9/UNAb0dOI9Q3/7yWSqKkj00u+1tgY8nvv41pc= +github.com/shopspring/decimal v1.4.0 h1:bxl37RwXBklmTi0C79JfXCEBD1cqqHt0bbgBAGFp81k= +github.com/shopspring/decimal v1.4.0/go.mod h1:gawqmDU56v4yIKSwfBSFip1HdCCXN8/+DMd9qYNcwME= github.com/sirupsen/logrus v1.7.0/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0= github.com/sirupsen/logrus 
v1.9.4 h1:TsZE7l11zFCLZnZ+teH4Umoq5BhEIfIzfRDZ1Uzql2w= github.com/sirupsen/logrus v1.9.4/go.mod h1:ftWc9WdOfJ0a92nsE2jF5u5ZwH8Bv2zdeOC42RjbV2g= diff --git a/scripts/apitypings/main.go b/scripts/apitypings/main.go index 65483a34bc..599f2b00bf 100644 --- a/scripts/apitypings/main.go +++ b/scripts/apitypings/main.go @@ -130,6 +130,10 @@ func TypeMappings(gen *guts.GoParser) error { "github.com/coder/serpent.URL": "string", "github.com/coder/serpent.HostPort": "string", "encoding/json.RawMessage": "map[string]string", + // decimal.Decimal preserves exact pricing precision (e.g. $3.50 per + // million tokens) and serializes as a JSON string to avoid + // floating-point loss in transit. + "github.com/shopspring/decimal.Decimal": "string", }) if err != nil { return xerrors.Errorf("include custom: %w", err) diff --git a/scripts/modeloptionsgen/main.go b/scripts/modeloptionsgen/main.go index 7f4d039e18..5f6746bbb5 100644 --- a/scripts/modeloptionsgen/main.go +++ b/scripts/modeloptionsgen/main.go @@ -7,6 +7,8 @@ import ( "reflect" "strings" + "github.com/shopspring/decimal" + "github.com/coder/coder/v2/codersdk" ) @@ -117,11 +119,15 @@ func extractFields(t reflect.Type, prefix string, skip map[string]bool) FieldGro // so that entire sub-objects can be marked hidden. hidden := f.Tag.Get("hidden") == "true" + // decimal.Decimal is an opaque numeric type used for pricing + // precision; do not recurse into its internal struct fields. + isDecimal := ft == reflect.TypeOf(decimal.Decimal{}) + // If the field is a struct (not a map), recurse to flatten // its children using dot-separated names — unless the // entire struct is marked hidden, in which case emit it // as a single opaque field. - if ft.Kind() == reflect.Struct && !hidden { + if ft.Kind() == reflect.Struct && !hidden && !isDecimal { nested := extractFields(ft, fullJSONName, nil) fields = append(fields, nested.Fields...) 
continue @@ -206,6 +212,12 @@ func goTypeToSchemaType(t reflect.Type) string { t = t.Elem() } + // decimal.Decimal represents a precise numeric value and should + // map to the "number" schema type. + if t == reflect.TypeOf(decimal.Decimal{}) { + return "number" + } + switch t.Kind() { case reflect.String: return "string" diff --git a/site/src/api/typesGenerated.ts b/site/src/api/typesGenerated.ts index bcae027ddf..0a08487faf 100644 --- a/site/src/api/typesGenerated.ts +++ b/site/src/api/typesGenerated.ts @@ -1071,6 +1071,96 @@ export interface Chat { readonly archived: boolean; } +// From codersdk/chats.go +/** + * ChatCostChatBreakdown contains per-root-chat cost aggregation. + */ +export interface ChatCostChatBreakdown { + readonly root_chat_id: string; + readonly chat_title: string; + readonly total_cost_micros: number; + readonly message_count: number; + readonly total_input_tokens: number; + readonly total_output_tokens: number; +} + +// From codersdk/chats.go +/** + * ChatCostModelBreakdown contains per-model cost aggregation. + */ +export interface ChatCostModelBreakdown { + readonly model_config_id: string; + readonly display_name: string; + readonly provider: string; + readonly model: string; + readonly total_cost_micros: number; + readonly message_count: number; + readonly total_input_tokens: number; + readonly total_output_tokens: number; +} + +// From codersdk/chats.go +/** + * ChatCostSummary is the response from the chat cost summary endpoint. 
+ */ +export interface ChatCostSummary { + readonly start_date: string; + readonly end_date: string; + readonly total_cost_micros: number; + readonly priced_message_count: number; + readonly unpriced_message_count: number; + readonly total_input_tokens: number; + readonly total_output_tokens: number; + readonly by_model: readonly ChatCostModelBreakdown[]; + readonly by_chat: readonly ChatCostChatBreakdown[]; +} + +// From codersdk/chats.go +/** + * ChatCostSummaryOptions are optional query parameters for GetChatCostSummary. + */ +export interface ChatCostSummaryOptions { + readonly StartDate: string; + readonly EndDate: string; +} + +// From codersdk/chats.go +/** + * ChatCostUserRollup contains per-user cost aggregation for admin views. + */ +export interface ChatCostUserRollup { + readonly user_id: string; + readonly username: string; + readonly name: string; + readonly avatar_url: string; + readonly total_cost_micros: number; + readonly message_count: number; + readonly chat_count: number; + readonly total_input_tokens: number; + readonly total_output_tokens: number; +} + +// From codersdk/chats.go +/** + * ChatCostUsersOptions are optional query parameters for GetChatCostUsers. + */ +export interface ChatCostUsersOptions extends Pagination { + readonly StartDate: string; + readonly EndDate: string; + readonly Username: string; +} + +// From codersdk/chats.go +/** + * ChatCostUsersResponse is the response from the admin chat cost users endpoint. + */ +export interface ChatCostUsersResponse { + readonly start_date: string; + readonly end_date: string; + readonly count: number; + readonly users: readonly ChatCostUserRollup[]; +} + // From codersdk/chats.go /** * ChatDiffContents represents the resolved diff text for a chat. @@ -3466,15 +3556,10 @@ export interface MinimalUser { * ModelCostConfig stores pricing metadata for a chat model. 
*/ export interface ModelCostConfig { - /** - * Pricing is stored as configuration metadata and currently only needs to - * round-trip cleanly through the API and admin UI. If we later use these - * values for billing-grade arithmetic, switch to a fixed-point type. - */ - readonly input_price_per_million_tokens?: number; - readonly output_price_per_million_tokens?: number; - readonly cache_read_price_per_million_tokens?: number; - readonly cache_write_price_per_million_tokens?: number; + readonly input_price_per_million_tokens?: string; + readonly output_price_per_million_tokens?: string; + readonly cache_read_price_per_million_tokens?: string; + readonly cache_write_price_per_million_tokens?: string; } // From netcheck/netcheck.go diff --git a/site/src/pages/AgentsPage/ChatModelAdminPanel/ModelsSection.test.tsx b/site/src/pages/AgentsPage/ChatModelAdminPanel/ModelsSection.test.tsx deleted file mode 100644 index 190b5c7720..0000000000 --- a/site/src/pages/AgentsPage/ChatModelAdminPanel/ModelsSection.test.tsx +++ /dev/null @@ -1,99 +0,0 @@ -import { render, screen } from "@testing-library/react"; -import type * as TypesGen from "api/typesGenerated"; -import { TooltipProvider } from "components/Tooltip/Tooltip"; -import { describe, expect, it, vi } from "vitest"; -import type { ProviderState } from "./ChatModelAdminPanel"; -import { ModelsSection } from "./ModelsSection"; - -vi.mock("./ProviderIcon", () => ({ - ProviderIcon: ({ provider }: { provider: string }) => ( -