feat: add chat cost analytics backend (#23036)

Add cost tracking for LLM chat interactions with microdollar precision.

## Changes
- Add `chatcost` package for per-message cost calculation using
`shopspring/decimal` for intermediate arithmetic
- **Ceil rounding policy**: fractional micros round UP to next whole
micro (applied once after summing all components)
- Database migration: `total_cost_micros` BIGINT column with historical
backfill and `created_at` index
- API endpoints: per-user cost summary and admin rollup under
`/api/experimental/chats/cost/`
- SDK types: `ChatCostSummary`, `ChatCostModelBreakdown`,
`ChatCostUserRollup`
- Fix `modeloptionsgen` to handle `decimal.Decimal` as opaque numeric
type
- Update frontend pricing test fixtures for string decimal types

## Design decisions
- `NULL` = unpriced (no matching model config), `0` = free
- Reasoning tokens included in output tokens (no double-counting)
- Integer microdollars (BIGINT) for storage and API responses
- Price config uses `decimal.Decimal` for exact parsing; totals use
`int64`

Frontend: #23037
This commit is contained in:
Michael Suchacz
2026-03-13 18:30:49 +01:00
committed by GitHub
parent 1152b61ebb
commit c3b6284955
34 changed files with 2034 additions and 262 deletions
+25
View File
@@ -100,6 +100,31 @@ app, err := api.Database.GetOAuth2ProviderAppByClientID(dbauthz.AsSystemRestrict
app, err := api.Database.GetOAuth2ProviderAppByClientID(ctx, clientID)
```
### API Design
- Add swagger annotations when introducing new HTTP endpoints. Do this in
the same change as the handler so the docs do not get missed before
release.
- For user-scoped or resource-scoped routes, prefer path parameters over
query parameters when that matches existing route patterns.
- For experimental or unstable API paths, skip public doc generation with
`// @x-apidocgen {"skip": true}` after the `@Router` annotation. This
keeps them out of the published API reference until they stabilize.
### Database Query Naming
- Use `ByX` when `X` is the lookup or filter column.
- Use `PerX` or `GroupedByX` when `X` is the aggregation or grouping
dimension.
- Avoid `ByX` names for grouped queries.
### Database-to-SDK Conversions
- Extract explicit db-to-SDK conversion helpers instead of inlining large
conversion blocks inside handlers.
- Keep nullable-field handling, type coercion, and response shaping in the
converter so handlers stay focused on request flow and authorization.
## Quick Reference
### Full workflows available in imported WORKFLOWS.md
+71
View File
@@ -0,0 +1,71 @@
package chatcost
import (
"github.com/shopspring/decimal"
"github.com/coder/coder/v2/coderd/util/ptr"
"github.com/coder/coder/v2/codersdk"
)
// Returns cost in micros -- millionths of a dollar, rounded up to the next
// whole microdollar.
// Returns nil when pricing is not configured or when all priced usage fields
// are nil, allowing callers to distinguish "zero cost" from "unpriced".
func CalculateTotalCostMicros(
usage codersdk.ChatMessageUsage,
cost *codersdk.ModelCostConfig,
) *int64 {
if cost == nil {
return nil
}
// A cost config with no prices set means pricing is effectively
// unconfigured — return nil (unpriced) rather than zero.
if cost.InputPricePerMillionTokens == nil &&
cost.OutputPricePerMillionTokens == nil &&
cost.CacheReadPricePerMillionTokens == nil &&
cost.CacheWritePricePerMillionTokens == nil {
return nil
}
if usage.InputTokens == nil &&
usage.OutputTokens == nil &&
usage.ReasoningTokens == nil &&
usage.CacheCreationTokens == nil &&
usage.CacheReadTokens == nil {
return nil
}
// OutputTokens already includes reasoning tokens per provider
// semantics (e.g. OpenAI's completion_tokens encompasses
// reasoning_tokens). Adding ReasoningTokens here would
// double-count.
// Preserve nil when usage exists only in categories without configured
// pricing, so callers can distinguish "unpriced" from "priced at zero".
hasMatchingPrice := (usage.InputTokens != nil && cost.InputPricePerMillionTokens != nil) ||
(usage.OutputTokens != nil && cost.OutputPricePerMillionTokens != nil) ||
(usage.CacheReadTokens != nil && cost.CacheReadPricePerMillionTokens != nil) ||
(usage.CacheCreationTokens != nil && cost.CacheWritePricePerMillionTokens != nil)
if !hasMatchingPrice {
return nil
}
inputMicros := calcCost(usage.InputTokens, cost.InputPricePerMillionTokens)
outputMicros := calcCost(usage.OutputTokens, cost.OutputPricePerMillionTokens)
cacheReadMicros := calcCost(usage.CacheReadTokens, cost.CacheReadPricePerMillionTokens)
cacheWriteMicros := calcCost(usage.CacheCreationTokens, cost.CacheWritePricePerMillionTokens)
total := inputMicros.
Add(outputMicros).
Add(cacheReadMicros).
Add(cacheWriteMicros)
rounded := total.Ceil().IntPart()
return &rounded
}
// calcCost returns the cost in fractional microdollars (millionths of a USD)
// for the given token count at the specified per-million-token price.
func calcCost(tokens *int64, pricePerMillion *decimal.Decimal) decimal.Decimal {
return decimal.NewFromInt(ptr.NilToEmpty(tokens)).Mul(ptr.NilToEmpty(pricePerMillion))
}
+163
View File
@@ -0,0 +1,163 @@
package chatcost_test
import (
"testing"
"github.com/shopspring/decimal"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/coderd/chatd/chatcost"
"github.com/coder/coder/v2/coderd/util/ptr"
"github.com/coder/coder/v2/codersdk"
)
func TestCalculateTotalCostMicros(t *testing.T) {
t.Parallel()
tests := []struct {
name string
usage codersdk.ChatMessageUsage
cost *codersdk.ModelCostConfig
want *int64
}{
{
name: "nil cost returns nil",
usage: codersdk.ChatMessageUsage{InputTokens: ptr.Ref[int64](1000)},
cost: nil,
want: nil,
},
{
name: "all priced usage fields nil returns nil",
usage: codersdk.ChatMessageUsage{
TotalTokens: ptr.Ref[int64](1234),
ContextLimit: ptr.Ref[int64](8192),
},
cost: &codersdk.ModelCostConfig{
InputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("3")),
},
want: nil,
},
{
name: "sub-micro total rounds up to 1",
usage: codersdk.ChatMessageUsage{InputTokens: ptr.Ref[int64](1)},
cost: &codersdk.ModelCostConfig{
InputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("0.01")),
},
want: ptr.Ref[int64](1),
},
{
name: "simple input only",
usage: codersdk.ChatMessageUsage{InputTokens: ptr.Ref[int64](1000)},
cost: &codersdk.ModelCostConfig{
InputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("3")),
},
want: ptr.Ref[int64](3000),
},
{
name: "simple output only",
usage: codersdk.ChatMessageUsage{OutputTokens: ptr.Ref[int64](500)},
cost: &codersdk.ModelCostConfig{
OutputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("15")),
},
want: ptr.Ref[int64](7500),
},
{
name: "reasoning tokens included in output total",
usage: codersdk.ChatMessageUsage{
OutputTokens: ptr.Ref[int64](500),
ReasoningTokens: ptr.Ref[int64](200),
},
cost: &codersdk.ModelCostConfig{
OutputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("15")),
},
want: ptr.Ref[int64](7500),
},
{
name: "cache read tokens",
usage: codersdk.ChatMessageUsage{CacheReadTokens: ptr.Ref[int64](10000)},
cost: &codersdk.ModelCostConfig{
CacheReadPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("0.3")),
},
want: ptr.Ref[int64](3000),
},
{
name: "cache creation tokens",
usage: codersdk.ChatMessageUsage{CacheCreationTokens: ptr.Ref[int64](5000)},
cost: &codersdk.ModelCostConfig{
CacheWritePricePerMillionTokens: ptr.Ref(decimal.RequireFromString("3.75")),
},
want: ptr.Ref[int64](18750),
},
{
name: "full mixed usage totals all components exactly",
usage: codersdk.ChatMessageUsage{
InputTokens: ptr.Ref[int64](101),
OutputTokens: ptr.Ref[int64](201),
ReasoningTokens: ptr.Ref[int64](52),
CacheReadTokens: ptr.Ref[int64](1005),
CacheCreationTokens: ptr.Ref[int64](33),
TotalTokens: ptr.Ref[int64](1391),
ContextLimit: ptr.Ref[int64](4096),
},
cost: &codersdk.ModelCostConfig{
InputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("1.23")),
OutputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("4.56")),
CacheReadPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("0.7")),
CacheWritePricePerMillionTokens: ptr.Ref(decimal.RequireFromString("7.89")),
},
want: ptr.Ref[int64](2005),
},
{
name: "partial pricing only input contributes",
usage: codersdk.ChatMessageUsage{
InputTokens: ptr.Ref[int64](1234),
OutputTokens: ptr.Ref[int64](999),
ReasoningTokens: ptr.Ref[int64](111),
CacheReadTokens: ptr.Ref[int64](500),
CacheCreationTokens: ptr.Ref[int64](250),
},
cost: &codersdk.ModelCostConfig{
InputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("2.5")),
},
want: ptr.Ref[int64](3085),
},
{
name: "zero tokens with pricing returns zero pointer",
usage: codersdk.ChatMessageUsage{InputTokens: ptr.Ref[int64](0)},
cost: &codersdk.ModelCostConfig{
InputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("3")),
},
want: ptr.Ref[int64](0),
},
{
name: "usage only in unpriced categories returns nil",
usage: codersdk.ChatMessageUsage{InputTokens: ptr.Ref[int64](1000)},
cost: &codersdk.ModelCostConfig{
OutputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("15")),
},
want: nil,
},
{
name: "non nil usage with empty cost config returns nil",
usage: codersdk.ChatMessageUsage{InputTokens: ptr.Ref[int64](42)},
cost: &codersdk.ModelCostConfig{},
want: nil,
},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := chatcost.CalculateTotalCostMicros(tt.usage, tt.cost)
if tt.want == nil {
require.Nil(t, got)
} else {
require.NotNil(t, got)
require.Equal(t, *tt.want, *got)
}
})
}
}
+49
View File
@@ -19,6 +19,7 @@ import (
"golang.org/x/xerrors"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/coderd/chatd/chatcost"
"github.com/coder/coder/v2/coderd/chatd/chatloop"
"github.com/coder/coder/v2/coderd/chatd/chatprompt"
"github.com/coder/coder/v2/coderd/chatd/chatprovider"
@@ -293,6 +294,7 @@ func (p *Server) CreateChat(ctx context.Context, opts CreateOptions) (database.C
CacheReadTokens: sql.NullInt64{},
ContextLimit: sql.NullInt64{},
Compressed: sql.NullBool{},
TotalCostMicros: sql.NullInt64{},
})
if err != nil {
return xerrors.Errorf("insert system message: %w", err)
@@ -321,6 +323,7 @@ func (p *Server) CreateChat(ctx context.Context, opts CreateOptions) (database.C
CacheCreationTokens: sql.NullInt64{},
CacheReadTokens: sql.NullInt64{},
ContextLimit: sql.NullInt64{},
TotalCostMicros: sql.NullInt64{},
Compressed: sql.NullBool{},
})
if err != nil {
@@ -910,6 +913,7 @@ func insertUserMessageAndSetPending(
CacheCreationTokens: sql.NullInt64{},
CacheReadTokens: sql.NullInt64{},
ContextLimit: sql.NullInt64{},
TotalCostMicros: sql.NullInt64{},
Compressed: sql.NullBool{},
})
if err != nil {
@@ -1969,6 +1973,7 @@ func (p *Server) processChat(ctx context.Context, chat database.Chat) {
CacheCreationTokens: sql.NullInt64{},
CacheReadTokens: sql.NullInt64{},
ContextLimit: sql.NullInt64{},
TotalCostMicros: sql.NullInt64{},
Compressed: sql.NullBool{},
})
if insertErr != nil {
@@ -2347,6 +2352,30 @@ func (p *Server) runChat(
}
hasUsage := step.Usage != (fantasy.Usage{})
var usageForCost codersdk.ChatMessageUsage
if hasUsage {
// Only populate fields that the provider explicitly
// reported. Nil fields tell the calculator "no data"
// vs zero meaning "reported as zero tokens."
if step.Usage.InputTokens != 0 {
usageForCost.InputTokens = int64Ptr(step.Usage.InputTokens)
}
if step.Usage.OutputTokens != 0 {
usageForCost.OutputTokens = int64Ptr(step.Usage.OutputTokens)
}
if step.Usage.ReasoningTokens != 0 {
usageForCost.ReasoningTokens = int64Ptr(step.Usage.ReasoningTokens)
}
if step.Usage.CacheCreationTokens != 0 {
usageForCost.CacheCreationTokens = int64Ptr(step.Usage.CacheCreationTokens)
}
if step.Usage.CacheReadTokens != 0 {
usageForCost.CacheReadTokens = int64Ptr(step.Usage.CacheReadTokens)
}
}
totalCostMicros := chatcost.CalculateTotalCostMicros(usageForCost, callConfig.Cost)
assistantMessage, insertErr := tx.InsertChatMessage(persistCtx, database.InsertChatMessageParams{
ChatID: chat.ID,
CreatedBy: uuid.NullUUID{},
@@ -2369,6 +2398,11 @@ func (p *Server) runChat(
CacheReadTokens: usageNullInt64(step.Usage.CacheReadTokens, hasUsage),
ContextLimit: step.ContextLimit,
Compressed: sql.NullBool{},
// TotalCostMicros is nullable: NULL means "unpriced"
// (pricing config was missing or no priced token
// breakdown available), while 0 means "priced at
// zero cost" (e.g., a free model).
TotalCostMicros: usageNullInt64Ptr(totalCostMicros),
})
if insertErr != nil {
return xerrors.Errorf("insert assistant message: %w", insertErr)
@@ -2398,6 +2432,7 @@ func (p *Server) runChat(
CacheCreationTokens: sql.NullInt64{},
CacheReadTokens: sql.NullInt64{},
ContextLimit: sql.NullInt64{},
TotalCostMicros: sql.NullInt64{},
Compressed: sql.NullBool{},
})
if insertErr != nil {
@@ -2740,6 +2775,7 @@ func (p *Server) persistChatContextSummary(
CacheCreationTokens: sql.NullInt64{},
CacheReadTokens: sql.NullInt64{},
ContextLimit: sql.NullInt64{},
TotalCostMicros: sql.NullInt64{},
})
if txErr != nil {
return xerrors.Errorf("insert hidden summary message: %w", txErr)
@@ -2764,6 +2800,7 @@ func (p *Server) persistChatContextSummary(
CacheCreationTokens: sql.NullInt64{},
CacheReadTokens: sql.NullInt64{},
ContextLimit: sql.NullInt64{},
TotalCostMicros: sql.NullInt64{},
})
if txErr != nil {
return xerrors.Errorf("insert summary tool call message: %w", txErr)
@@ -2789,6 +2826,7 @@ func (p *Server) persistChatContextSummary(
CacheCreationTokens: sql.NullInt64{},
CacheReadTokens: sql.NullInt64{},
ContextLimit: sql.NullInt64{},
TotalCostMicros: sql.NullInt64{},
})
if txErr != nil {
return xerrors.Errorf("insert summary tool result message: %w", txErr)
@@ -2901,6 +2939,10 @@ func (p *Server) resolveModelConfig(
return defaultConfig, nil
}
func int64Ptr(value int64) *int64 {
return &value
}
//nolint:revive // Boolean controls SQL NULL validity.
func usageNullInt64(value int64, valid bool) sql.NullInt64 {
if !valid {
@@ -2912,6 +2954,13 @@ func usageNullInt64(value int64, valid bool) sql.NullInt64 {
}
}
func usageNullInt64Ptr(v *int64) sql.NullInt64 {
if v == nil {
return sql.NullInt64{}
}
return sql.NullInt64{Int64: *v, Valid: true}
}
func refreshChatWorkspaceSnapshot(
ctx context.Context,
chat database.Chat,
-28
View File
@@ -553,34 +553,6 @@ func normalizedEnumValue(value string, allowed ...string) *string {
return nil
}
// MergeMissingCallConfig fills unset call config values from a provider or
// profile default config.
func MergeMissingCallConfig(
dst *codersdk.ChatModelCallConfig,
defaults codersdk.ChatModelCallConfig,
) {
if dst.MaxOutputTokens == nil {
dst.MaxOutputTokens = defaults.MaxOutputTokens
}
if dst.Temperature == nil {
dst.Temperature = defaults.Temperature
}
if dst.TopP == nil {
dst.TopP = defaults.TopP
}
if dst.TopK == nil {
dst.TopK = defaults.TopK
}
if dst.PresencePenalty == nil {
dst.PresencePenalty = defaults.PresencePenalty
}
if dst.FrequencyPenalty == nil {
dst.FrequencyPenalty = defaults.FrequencyPenalty
}
MergeMissingModelCostConfig(&dst.Cost, defaults.Cost)
MergeMissingProviderOptions(&dst.ProviderOptions, defaults.ProviderOptions)
}
// MergeMissingModelCostConfig fills unset pricing metadata from defaults.
func MergeMissingModelCostConfig(
dst **codersdk.ModelCostConfig,
@@ -137,61 +137,6 @@ func TestMergeMissingProviderOptions_OpenRouterNested(t *testing.T) {
require.Equal(t, "latency", *options.OpenRouter.Provider.Sort)
}
func TestMergeMissingCallConfig_FillsUnsetFields(t *testing.T) {
t.Parallel()
dst := codersdk.ChatModelCallConfig{
Temperature: float64Ptr(0.2),
Cost: &codersdk.ModelCostConfig{
OutputPricePerMillionTokens: float64Ptr(0.7),
},
ProviderOptions: &codersdk.ChatModelProviderOptions{
OpenAI: &codersdk.ChatModelOpenAIProviderOptions{
User: stringPtr("alice"),
},
},
}
defaultCallConfig := codersdk.ChatModelCallConfig{
MaxOutputTokens: int64Ptr(512),
Temperature: float64Ptr(0.9),
TopP: float64Ptr(0.8),
Cost: &codersdk.ModelCostConfig{
InputPricePerMillionTokens: float64Ptr(0.15),
OutputPricePerMillionTokens: float64Ptr(0.9),
CacheReadPricePerMillionTokens: float64Ptr(0.03),
CacheWritePricePerMillionTokens: float64Ptr(0.3),
},
ProviderOptions: &codersdk.ChatModelProviderOptions{
OpenAI: &codersdk.ChatModelOpenAIProviderOptions{
User: stringPtr("bob"),
ReasoningEffort: stringPtr("medium"),
},
},
}
chatprovider.MergeMissingCallConfig(&dst, defaultCallConfig)
require.NotNil(t, dst.MaxOutputTokens)
require.EqualValues(t, 512, *dst.MaxOutputTokens)
require.NotNil(t, dst.Temperature)
require.Equal(t, 0.2, *dst.Temperature)
require.NotNil(t, dst.TopP)
require.Equal(t, 0.8, *dst.TopP)
require.NotNil(t, dst.Cost)
require.NotNil(t, dst.Cost.InputPricePerMillionTokens)
require.Equal(t, 0.15, *dst.Cost.InputPricePerMillionTokens)
require.NotNil(t, dst.Cost.OutputPricePerMillionTokens)
require.Equal(t, 0.7, *dst.Cost.OutputPricePerMillionTokens)
require.NotNil(t, dst.Cost.CacheReadPricePerMillionTokens)
require.Equal(t, 0.03, *dst.Cost.CacheReadPricePerMillionTokens)
require.NotNil(t, dst.Cost.CacheWritePricePerMillionTokens)
require.Equal(t, 0.3, *dst.Cost.CacheWritePricePerMillionTokens)
require.NotNil(t, dst.ProviderOptions)
require.NotNil(t, dst.ProviderOptions.OpenAI)
require.Equal(t, "alice", *dst.ProviderOptions.OpenAI.User)
require.Equal(t, "medium", *dst.ProviderOptions.OpenAI.ReasoningEffort)
}
func stringPtr(value string) *string {
return &value
}
@@ -203,7 +148,3 @@ func boolPtr(value bool) *bool {
func int64Ptr(value int64) *int64 {
return &value
}
func float64Ptr(value float64) *float64 {
return &value
}
+229 -4
View File
@@ -9,6 +9,7 @@ import (
"errors"
"fmt"
"io"
"math"
"mime"
"net/http"
"net/http/httptest"
@@ -20,6 +21,7 @@ import (
"github.com/go-chi/chi/v5"
"github.com/google/uuid"
"github.com/shopspring/decimal"
"golang.org/x/xerrors"
"cdr.dev/slog/v3"
@@ -354,6 +356,187 @@ func (api *API) listChatModels(rw http.ResponseWriter, r *http.Request) {
httpapi.Write(ctx, rw, http.StatusOK, response)
}
func (api *API) chatCostSummary(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
apiKey := httpmw.APIKey(r)
// Default date range: last 30 days.
now := time.Now()
defaultStart := now.AddDate(0, 0, -30)
qp := r.URL.Query()
p := httpapi.NewQueryParamParser()
startDate := p.Time(qp, defaultStart, "start_date", time.RFC3339)
endDate := p.Time(qp, now, "end_date", time.RFC3339)
p.ErrorExcessParams(qp)
if len(p.Errors) > 0 {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Invalid query parameters.",
Validations: p.Errors,
})
return
}
targetUser := httpmw.UserParam(r)
if targetUser.ID != apiKey.UserID && !api.Authorize(r, policy.ActionRead, rbac.ResourceChat.WithOwner(targetUser.ID.String())) {
httpapi.Forbidden(rw)
return
}
summary, err := api.Database.GetChatCostSummary(ctx, database.GetChatCostSummaryParams{
OwnerID: targetUser.ID,
StartDate: startDate,
EndDate: endDate,
})
if err != nil {
httpapi.InternalServerError(rw, err)
return
}
byModel, err := api.Database.GetChatCostPerModel(ctx, database.GetChatCostPerModelParams{
OwnerID: targetUser.ID,
StartDate: startDate,
EndDate: endDate,
})
if err != nil {
httpapi.InternalServerError(rw, err)
return
}
byChat, err := api.Database.GetChatCostPerChat(ctx, database.GetChatCostPerChatParams{
OwnerID: targetUser.ID,
StartDate: startDate,
EndDate: endDate,
})
if err != nil {
httpapi.InternalServerError(rw, err)
return
}
modelBreakdowns := make([]codersdk.ChatCostModelBreakdown, 0, len(byModel))
for _, model := range byModel {
modelBreakdowns = append(modelBreakdowns, convertChatCostModelBreakdown(model))
}
chatBreakdowns := make([]codersdk.ChatCostChatBreakdown, 0, len(byChat))
for _, chat := range byChat {
chatBreakdowns = append(chatBreakdowns, convertChatCostChatBreakdown(chat))
}
httpapi.Write(ctx, rw, http.StatusOK, codersdk.ChatCostSummary{
StartDate: startDate,
EndDate: endDate,
TotalCostMicros: summary.TotalCostMicros,
PricedMessageCount: summary.PricedMessageCount,
UnpricedMessageCount: summary.UnpricedMessageCount,
TotalInputTokens: summary.TotalInputTokens,
TotalOutputTokens: summary.TotalOutputTokens,
ByModel: modelBreakdowns,
ByChat: chatBreakdowns,
})
}
func (api *API) chatCostUsers(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
if !api.Authorize(r, policy.ActionRead, rbac.ResourceChat) {
httpapi.Forbidden(rw)
return
}
now := time.Now()
defaultStart := now.AddDate(0, 0, -30)
qp := r.URL.Query()
p := httpapi.NewQueryParamParser()
startDate := p.Time(qp, defaultStart, "start_date", time.RFC3339)
endDate := p.Time(qp, now, "end_date", time.RFC3339)
username := strings.TrimSpace(p.String(qp, "", "username"))
limit := p.Int(qp, 10, "limit")
offset := p.Int(qp, 0, "offset")
p.ErrorExcessParams(qp)
if len(p.Errors) > 0 {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Invalid query parameters.",
Validations: p.Errors,
})
return
}
if limit <= 0 {
limit = 10
}
if offset < 0 || offset > math.MaxInt32 || limit > math.MaxInt32 {
validations := make([]codersdk.ValidationError, 0, 2)
if offset < 0 {
validations = append(validations, codersdk.ValidationError{
Field: "offset",
Detail: "Must be greater than or equal to 0.",
})
}
if offset > math.MaxInt32 {
validations = append(validations, codersdk.ValidationError{
Field: "offset",
Detail: fmt.Sprintf("Must be less than or equal to %d.", math.MaxInt32),
})
}
if limit > math.MaxInt32 {
validations = append(validations, codersdk.ValidationError{
Field: "limit",
Detail: fmt.Sprintf("Must be less than or equal to %d.", math.MaxInt32),
})
}
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Invalid query parameters.",
Validations: validations,
})
return
}
users, err := api.Database.GetChatCostPerUser(ctx, database.GetChatCostPerUserParams{
StartDate: startDate,
EndDate: endDate,
Username: username,
// #nosec G115 - Pagination limits are validated to fit in int32 above.
PageLimit: int32(limit),
// #nosec G115 - Pagination offsets are validated to fit in int32 above.
PageOffset: int32(offset),
})
if err != nil {
httpapi.InternalServerError(rw, err)
return
}
rollups := make([]codersdk.ChatCostUserRollup, 0, len(users))
count := int64(0)
for _, user := range users {
count = user.TotalCount
rollups = append(rollups, convertChatCostUserRollup(user))
}
if len(users) == 0 && offset > 0 {
countUsers, countErr := api.Database.GetChatCostPerUser(ctx, database.GetChatCostPerUserParams{
StartDate: startDate,
EndDate: endDate,
Username: username,
PageLimit: 1,
PageOffset: 0,
})
if countErr != nil {
httpapi.InternalServerError(rw, countErr)
return
}
if len(countUsers) > 0 {
count = countUsers[0].TotalCount
}
}
httpapi.Write(ctx, rw, http.StatusOK, codersdk.ChatCostUsersResponse{
StartDate: startDate,
EndDate: endDate,
Count: count,
Users: rollups,
})
}
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
//
//nolint:revive // HTTP handler writes to ResponseWriter.
@@ -2172,6 +2355,48 @@ func convertChats(chats []database.Chat, diffStatusesByChatID map[uuid.UUID]data
return result
}
func convertChatCostModelBreakdown(model database.GetChatCostPerModelRow) codersdk.ChatCostModelBreakdown {
displayName := strings.TrimSpace(model.DisplayName)
if displayName == "" {
displayName = model.Model
}
return codersdk.ChatCostModelBreakdown{
ModelConfigID: model.ModelConfigID,
DisplayName: displayName,
Provider: model.Provider,
Model: model.Model,
TotalCostMicros: model.TotalCostMicros,
MessageCount: model.MessageCount,
TotalInputTokens: model.TotalInputTokens,
TotalOutputTokens: model.TotalOutputTokens,
}
}
func convertChatCostChatBreakdown(chat database.GetChatCostPerChatRow) codersdk.ChatCostChatBreakdown {
return codersdk.ChatCostChatBreakdown{
RootChatID: chat.RootChatID,
ChatTitle: chat.ChatTitle,
TotalCostMicros: chat.TotalCostMicros,
MessageCount: chat.MessageCount,
TotalInputTokens: chat.TotalInputTokens,
TotalOutputTokens: chat.TotalOutputTokens,
}
}
func convertChatCostUserRollup(user database.GetChatCostPerUserRow) codersdk.ChatCostUserRollup {
return codersdk.ChatCostUserRollup{
UserID: user.UserID,
Username: user.Username,
Name: user.Name,
AvatarURL: user.AvatarURL,
TotalCostMicros: user.TotalCostMicros,
MessageCount: user.MessageCount,
ChatCount: user.ChatCount,
TotalInputTokens: user.TotalInputTokens,
TotalOutputTokens: user.TotalOutputTokens,
}
}
func convertChatQueuedMessage(m database.ChatQueuedMessage) codersdk.ChatQueuedMessage {
return db2sdk.ChatQueuedMessage(m)
}
@@ -3117,7 +3342,7 @@ func validateChatModelCallConfig(modelConfig *codersdk.ChatModelCallConfig) erro
pricingFields := []struct {
name string
value *float64
value *decimal.Decimal
}{
{name: "cost.input_price_per_million_tokens", value: costConfig.InputPricePerMillionTokens},
{name: "cost.output_price_per_million_tokens", value: costConfig.OutputPricePerMillionTokens},
@@ -3125,7 +3350,7 @@ func validateChatModelCallConfig(modelConfig *codersdk.ChatModelCallConfig) erro
{name: "cost.cache_write_price_per_million_tokens", value: costConfig.CacheWritePricePerMillionTokens},
}
for _, field := range pricingFields {
if err := validateNonNegativeFloat64Field(field.name, field.value); err != nil {
if err := validateNonNegativeDecimalField(field.name, field.value); err != nil {
return err
}
}
@@ -3133,11 +3358,11 @@ func validateChatModelCallConfig(modelConfig *codersdk.ChatModelCallConfig) erro
return nil
}
func validateNonNegativeFloat64Field(name string, value *float64) error {
func validateNonNegativeDecimalField(name string, value *decimal.Decimal) error {
if value == nil {
return nil
}
if *value < 0 {
if value.IsNegative() {
return xerrors.Errorf("%s must be greater than or equal to zero", name)
}
return nil
+320 -19
View File
@@ -13,6 +13,7 @@ import (
"time"
"github.com/google/uuid"
"github.com/shopspring/decimal"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/coderd/coderdtest"
@@ -22,7 +23,6 @@ import (
"github.com/coder/coder/v2/coderd/database/dbfake"
"github.com/coder/coder/v2/coderd/externalauth"
coderdpubsub "github.com/coder/coder/v2/coderd/pubsub"
"github.com/coder/coder/v2/coderd/util/ptr"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/testutil"
"github.com/coder/websocket"
@@ -939,10 +939,10 @@ func TestListChatModelConfigs(t *testing.T) {
require.Equal(t, storedConfig.ID, configs[0].ID)
requireChatModelPricing(t, configs[0].ModelConfig, &codersdk.ChatModelCallConfig{
Cost: &codersdk.ModelCostConfig{
InputPricePerMillionTokens: ptr.Ref(0.15),
OutputPricePerMillionTokens: ptr.Ref(0.6),
CacheReadPricePerMillionTokens: ptr.Ref(0.03),
CacheWritePricePerMillionTokens: ptr.Ref(0.3),
InputPricePerMillionTokens: decRef("0.15"),
OutputPricePerMillionTokens: decRef("0.6"),
CacheReadPricePerMillionTokens: decRef("0.03"),
CacheWritePricePerMillionTokens: decRef("0.3"),
},
})
})
@@ -993,10 +993,10 @@ func TestCreateChatModelConfig(t *testing.T) {
isDefault := true
pricing := &codersdk.ChatModelCallConfig{
Cost: &codersdk.ModelCostConfig{
InputPricePerMillionTokens: ptr.Ref(0.15),
OutputPricePerMillionTokens: ptr.Ref(0.6),
CacheReadPricePerMillionTokens: ptr.Ref(0.03),
CacheWritePricePerMillionTokens: ptr.Ref(0.3),
InputPricePerMillionTokens: decRef("0.15"),
OutputPricePerMillionTokens: decRef("0.6"),
CacheReadPricePerMillionTokens: decRef("0.03"),
CacheWritePricePerMillionTokens: decRef("0.3"),
},
}
modelConfig, err := client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
@@ -1040,7 +1040,7 @@ func TestCreateChatModelConfig(t *testing.T) {
ContextLimit: &contextLimit,
ModelConfig: &codersdk.ChatModelCallConfig{
Cost: &codersdk.ModelCostConfig{
InputPricePerMillionTokens: ptr.Ref(-0.01),
InputPricePerMillionTokens: decRef("-0.01"),
},
},
})
@@ -1123,10 +1123,10 @@ func TestUpdateChatModelConfig(t *testing.T) {
contextLimit := int64(8192)
pricing := &codersdk.ChatModelCallConfig{
Cost: &codersdk.ModelCostConfig{
InputPricePerMillionTokens: ptr.Ref(0.2),
OutputPricePerMillionTokens: ptr.Ref(0.8),
CacheReadPricePerMillionTokens: ptr.Ref(0.04),
CacheWritePricePerMillionTokens: ptr.Ref(0.4),
InputPricePerMillionTokens: decRef("0.2"),
OutputPricePerMillionTokens: decRef("0.8"),
CacheReadPricePerMillionTokens: decRef("0.04"),
CacheWritePricePerMillionTokens: decRef("0.4"),
},
}
updated, err := client.UpdateChatModelConfig(ctx, modelConfig.ID, codersdk.UpdateChatModelConfigRequest{
@@ -1157,7 +1157,7 @@ func TestUpdateChatModelConfig(t *testing.T) {
_, err := client.UpdateChatModelConfig(ctx, modelConfig.ID, codersdk.UpdateChatModelConfigRequest{
ModelConfig: &codersdk.ChatModelCallConfig{
Cost: &codersdk.ModelCostConfig{
OutputPricePerMillionTokens: ptr.Ref(-1.0),
OutputPricePerMillionTokens: decRef("-1.0"),
},
},
})
@@ -3479,6 +3479,302 @@ func TestGetChatFile(t *testing.T) {
})
}
func TestChatCostSummary(t *testing.T) {
t.Parallel()
t.Run("BasicSummary", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client, db := newChatClientWithDatabase(t)
firstUser := coderdtest.CreateFirstUser(t, client)
modelConfig := createChatModelConfig(t, client)
chat, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{
OwnerID: firstUser.UserID,
LastModelConfigID: modelConfig.ID,
Title: "test chat",
})
require.NoError(t, err)
for i := 0; i < 2; i++ {
_, err = db.InsertChatMessage(dbauthz.AsSystemRestricted(ctx), database.InsertChatMessageParams{
ChatID: chat.ID,
ModelConfigID: uuid.NullUUID{UUID: modelConfig.ID, Valid: true},
Role: "assistant",
Visibility: database.ChatMessageVisibilityBoth,
InputTokens: sql.NullInt64{Int64: 100, Valid: true},
OutputTokens: sql.NullInt64{Int64: 50, Valid: true},
TotalCostMicros: sql.NullInt64{Int64: 500, Valid: true},
})
require.NoError(t, err)
}
summary, err := client.GetChatCostSummary(ctx, "me", codersdk.ChatCostSummaryOptions{})
require.NoError(t, err)
require.Equal(t, int64(1000), summary.TotalCostMicros)
require.Equal(t, int64(2), summary.PricedMessageCount)
require.Equal(t, int64(0), summary.UnpricedMessageCount)
require.Equal(t, int64(200), summary.TotalInputTokens)
require.Equal(t, int64(100), summary.TotalOutputTokens)
require.Len(t, summary.ByModel, 1)
require.Equal(t, modelConfig.ID, summary.ByModel[0].ModelConfigID)
require.Equal(t, int64(1000), summary.ByModel[0].TotalCostMicros)
require.Equal(t, int64(2), summary.ByModel[0].MessageCount)
require.Len(t, summary.ByChat, 1)
require.Equal(t, chat.ID, summary.ByChat[0].RootChatID)
require.Equal(t, int64(1000), summary.ByChat[0].TotalCostMicros)
require.Equal(t, int64(2), summary.ByChat[0].MessageCount)
})
}
func TestChatCostSummary_AdminDrilldown(t *testing.T) {
t.Parallel()
seedCtx := testutil.Context(t, testutil.WaitLong)
client, db := newChatClientWithDatabase(t)
firstUser := coderdtest.CreateFirstUser(t, client)
memberClient, member := coderdtest.CreateAnotherUser(t, client, firstUser.OrganizationID)
modelConfig := createChatModelConfig(t, client)
chat, err := db.InsertChat(dbauthz.AsSystemRestricted(seedCtx), database.InsertChatParams{
OwnerID: member.ID,
LastModelConfigID: modelConfig.ID,
Title: "member chat",
})
require.NoError(t, err)
_, err = db.InsertChatMessage(dbauthz.AsSystemRestricted(seedCtx), database.InsertChatMessageParams{
ChatID: chat.ID,
ModelConfigID: uuid.NullUUID{UUID: modelConfig.ID, Valid: true},
Role: "assistant",
Visibility: database.ChatMessageVisibilityBoth,
InputTokens: sql.NullInt64{Int64: 200, Valid: true},
OutputTokens: sql.NullInt64{Int64: 100, Valid: true},
TotalCostMicros: sql.NullInt64{Int64: 750, Valid: true},
})
require.NoError(t, err)
t.Run("AdminCanDrilldown", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
summary, err := client.GetChatCostSummary(ctx, member.ID.String(), codersdk.ChatCostSummaryOptions{})
require.NoError(t, err)
require.Equal(t, int64(750), summary.TotalCostMicros)
require.Equal(t, int64(1), summary.PricedMessageCount)
})
t.Run("MemberCannotDrilldownOtherUser", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
_, err := memberClient.GetChatCostSummary(ctx, firstUser.UserID.String(), codersdk.ChatCostSummaryOptions{})
require.Error(t, err)
var sdkErr *codersdk.Error
require.ErrorAs(t, err, &sdkErr)
require.Equal(t, http.StatusNotFound, sdkErr.StatusCode())
})
}
func TestChatCostUsers(t *testing.T) {
t.Parallel()
seedCtx := testutil.Context(t, testutil.WaitLong)
client, db := newChatClientWithDatabase(t)
firstUser := coderdtest.CreateFirstUser(t, client)
memberClient, member := coderdtest.CreateAnotherUser(t, client, firstUser.OrganizationID)
firstUserRecord, err := db.GetUserByID(dbauthz.AsSystemRestricted(seedCtx), firstUser.UserID)
require.NoError(t, err)
modelConfig := createChatModelConfig(t, client)
adminChat, err := db.InsertChat(dbauthz.AsSystemRestricted(seedCtx), database.InsertChatParams{
OwnerID: firstUser.UserID,
LastModelConfigID: modelConfig.ID,
Title: "admin chat",
})
require.NoError(t, err)
_, err = db.InsertChatMessage(dbauthz.AsSystemRestricted(seedCtx), database.InsertChatMessageParams{
ChatID: adminChat.ID,
ModelConfigID: uuid.NullUUID{UUID: modelConfig.ID, Valid: true},
Role: "assistant",
Visibility: database.ChatMessageVisibilityBoth,
InputTokens: sql.NullInt64{Int64: 100, Valid: true},
OutputTokens: sql.NullInt64{Int64: 50, Valid: true},
TotalCostMicros: sql.NullInt64{Int64: 300, Valid: true},
})
require.NoError(t, err)
memberChat, err := db.InsertChat(dbauthz.AsSystemRestricted(seedCtx), database.InsertChatParams{
OwnerID: member.ID,
LastModelConfigID: modelConfig.ID,
Title: "member chat",
})
require.NoError(t, err)
_, err = db.InsertChatMessage(dbauthz.AsSystemRestricted(seedCtx), database.InsertChatMessageParams{
ChatID: memberChat.ID,
ModelConfigID: uuid.NullUUID{UUID: modelConfig.ID, Valid: true},
Role: "assistant",
Visibility: database.ChatMessageVisibilityBoth,
InputTokens: sql.NullInt64{Int64: 200, Valid: true},
OutputTokens: sql.NullInt64{Int64: 100, Valid: true},
TotalCostMicros: sql.NullInt64{Int64: 800, Valid: true},
})
require.NoError(t, err)
t.Run("AdminCanListUsers", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
resp, err := client.GetChatCostUsers(ctx, codersdk.ChatCostUsersOptions{})
require.NoError(t, err)
require.Equal(t, int64(2), resp.Count)
require.Len(t, resp.Users, 2)
require.Equal(t, member.ID, resp.Users[0].UserID)
require.Equal(t, member.Username, resp.Users[0].Username)
require.Equal(t, int64(800), resp.Users[0].TotalCostMicros)
require.Equal(t, int64(1), resp.Users[0].MessageCount)
require.Equal(t, int64(1), resp.Users[0].ChatCount)
require.Equal(t, firstUser.UserID, resp.Users[1].UserID)
require.Equal(t, firstUserRecord.Username, resp.Users[1].Username)
require.Equal(t, int64(300), resp.Users[1].TotalCostMicros)
})
t.Run("AdminCanFilterAndPaginateUsers", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
resp, err := client.GetChatCostUsers(ctx, codersdk.ChatCostUsersOptions{
Username: member.Username,
Pagination: codersdk.Pagination{
Limit: 1,
Offset: 0,
},
})
require.NoError(t, err)
require.Equal(t, int64(1), resp.Count)
require.Len(t, resp.Users, 1)
require.Equal(t, member.ID, resp.Users[0].UserID)
require.Equal(t, member.Username, resp.Users[0].Username)
})
t.Run("MemberCannotListUsers", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
_, err := memberClient.GetChatCostUsers(ctx, codersdk.ChatCostUsersOptions{})
require.Error(t, err)
var sdkErr *codersdk.Error
require.ErrorAs(t, err, &sdkErr)
require.Equal(t, http.StatusForbidden, sdkErr.StatusCode())
})
}
func TestChatCostSummary_DateRange(t *testing.T) {
t.Parallel()
seedCtx := testutil.Context(t, testutil.WaitLong)
client, db := newChatClientWithDatabase(t)
firstUser := coderdtest.CreateFirstUser(t, client)
modelConfig := createChatModelConfig(t, client)
chat, err := db.InsertChat(dbauthz.AsSystemRestricted(seedCtx), database.InsertChatParams{
OwnerID: firstUser.UserID,
LastModelConfigID: modelConfig.ID,
Title: "date range test",
})
require.NoError(t, err)
_, err = db.InsertChatMessage(dbauthz.AsSystemRestricted(seedCtx), database.InsertChatMessageParams{
ChatID: chat.ID,
ModelConfigID: uuid.NullUUID{UUID: modelConfig.ID, Valid: true},
Role: "assistant",
Visibility: database.ChatMessageVisibilityBoth,
InputTokens: sql.NullInt64{Int64: 100, Valid: true},
OutputTokens: sql.NullInt64{Int64: 50, Valid: true},
TotalCostMicros: sql.NullInt64{Int64: 500, Valid: true},
})
require.NoError(t, err)
now := time.Now()
t.Run("MessageInRange", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
summary, err := client.GetChatCostSummary(ctx, "me", codersdk.ChatCostSummaryOptions{
StartDate: now.Add(-time.Hour),
EndDate: now.Add(time.Hour),
})
require.NoError(t, err)
require.Equal(t, int64(500), summary.TotalCostMicros)
require.Equal(t, int64(1), summary.PricedMessageCount)
})
t.Run("MessageOutOfRange", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
summary, err := client.GetChatCostSummary(ctx, "me", codersdk.ChatCostSummaryOptions{
StartDate: now.Add(time.Hour),
EndDate: now.Add(2 * time.Hour),
})
require.NoError(t, err)
require.Equal(t, int64(0), summary.TotalCostMicros)
require.Equal(t, int64(0), summary.PricedMessageCount)
})
}
func TestChatCostSummary_UnpricedMessages(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client, db := newChatClientWithDatabase(t)
firstUser := coderdtest.CreateFirstUser(t, client)
modelConfig := createChatModelConfig(t, client)
chat, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{
OwnerID: firstUser.UserID,
LastModelConfigID: modelConfig.ID,
Title: "unpriced test",
})
require.NoError(t, err)
_, err = db.InsertChatMessage(dbauthz.AsSystemRestricted(ctx), database.InsertChatMessageParams{
ChatID: chat.ID,
ModelConfigID: uuid.NullUUID{UUID: modelConfig.ID, Valid: true},
Role: "assistant",
Visibility: database.ChatMessageVisibilityBoth,
InputTokens: sql.NullInt64{Int64: 100, Valid: true},
OutputTokens: sql.NullInt64{Int64: 50, Valid: true},
TotalCostMicros: sql.NullInt64{Int64: 500, Valid: true},
})
require.NoError(t, err)
_, err = db.InsertChatMessage(dbauthz.AsSystemRestricted(ctx), database.InsertChatMessageParams{
ChatID: chat.ID,
ModelConfigID: uuid.NullUUID{UUID: modelConfig.ID, Valid: true},
Role: "assistant",
Visibility: database.ChatMessageVisibilityBoth,
InputTokens: sql.NullInt64{Int64: 200, Valid: true},
OutputTokens: sql.NullInt64{Int64: 75, Valid: true},
TotalCostMicros: sql.NullInt64{},
})
require.NoError(t, err)
summary, err := client.GetChatCostSummary(ctx, "me", codersdk.ChatCostSummaryOptions{})
require.NoError(t, err)
require.Equal(t, int64(500), summary.TotalCostMicros)
require.Equal(t, int64(1), summary.PricedMessageCount)
require.Equal(t, int64(1), summary.UnpricedMessageCount)
require.Equal(t, int64(300), summary.TotalInputTokens)
require.Equal(t, int64(125), summary.TotalOutputTokens)
}
func requireChatModelPricing(
t *testing.T,
actual *codersdk.ChatModelCallConfig,
@@ -3495,10 +3791,15 @@ func requireChatModelPricing(
require.NotNil(t, actual.Cost.CacheReadPricePerMillionTokens)
require.NotNil(t, actual.Cost.CacheWritePricePerMillionTokens)
require.Equal(t, *expected.Cost.InputPricePerMillionTokens, *actual.Cost.InputPricePerMillionTokens)
require.Equal(t, *expected.Cost.OutputPricePerMillionTokens, *actual.Cost.OutputPricePerMillionTokens)
require.Equal(t, *expected.Cost.CacheReadPricePerMillionTokens, *actual.Cost.CacheReadPricePerMillionTokens)
require.Equal(t, *expected.Cost.CacheWritePricePerMillionTokens, *actual.Cost.CacheWritePricePerMillionTokens)
require.True(t, expected.Cost.InputPricePerMillionTokens.Equal(*actual.Cost.InputPricePerMillionTokens))
require.True(t, expected.Cost.OutputPricePerMillionTokens.Equal(*actual.Cost.OutputPricePerMillionTokens))
require.True(t, expected.Cost.CacheReadPricePerMillionTokens.Equal(*actual.Cost.CacheReadPricePerMillionTokens))
require.True(t, expected.Cost.CacheWritePricePerMillionTokens.Equal(*actual.Cost.CacheWritePricePerMillionTokens))
}
func decRef(value string) *decimal.Decimal {
d := decimal.RequireFromString(value)
return &d
}
func createChatModelConfig(t *testing.T, client *codersdk.Client) codersdk.ChatModelConfig {
+7
View File
@@ -1139,6 +1139,13 @@ func New(options *Options) *API {
r.Post("/", api.postChats)
r.Get("/models", api.listChatModels)
r.Get("/watch", api.watchChats)
r.Route("/cost", func(r chi.Router) {
r.Get("/users", api.chatCostUsers)
r.Route("/{user}", func(r chi.Router) {
r.Use(httpmw.ExtractUserParam(options.Database))
r.Get("/summary", api.chatCostSummary)
})
})
r.Route("/files", func(r chi.Router) {
r.Use(httpmw.RateLimit(options.FilesRateLimit, time.Minute))
r.Post("/", api.postChatFile)
+28
View File
@@ -2426,6 +2426,34 @@ func (q *querier) GetChatByIDForUpdate(ctx context.Context, id uuid.UUID) (datab
return fetch(q.log, q.auth, q.db.GetChatByIDForUpdate)(ctx, id)
}
func (q *querier) GetChatCostPerChat(ctx context.Context, arg database.GetChatCostPerChatParams) ([]database.GetChatCostPerChatRow, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceChat.WithOwner(arg.OwnerID.String())); err != nil {
return nil, err
}
return q.db.GetChatCostPerChat(ctx, arg)
}
func (q *querier) GetChatCostPerModel(ctx context.Context, arg database.GetChatCostPerModelParams) ([]database.GetChatCostPerModelRow, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceChat.WithOwner(arg.OwnerID.String())); err != nil {
return nil, err
}
return q.db.GetChatCostPerModel(ctx, arg)
}
func (q *querier) GetChatCostPerUser(ctx context.Context, arg database.GetChatCostPerUserParams) ([]database.GetChatCostPerUserRow, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceChat); err != nil {
return nil, err
}
return q.db.GetChatCostPerUser(ctx, arg)
}
func (q *querier) GetChatCostSummary(ctx context.Context, arg database.GetChatCostSummaryParams) (database.GetChatCostSummaryRow, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceChat.WithOwner(arg.OwnerID.String())); err != nil {
return database.GetChatCostSummaryRow{}, err
}
return q.db.GetChatCostSummary(ctx, arg)
}
func (q *querier) GetChatDiffStatusByChatID(ctx context.Context, chatID uuid.UUID) (database.ChatDiffStatus, error) {
// Authorize read on the parent chat.
_, err := q.GetChatByID(ctx, chatID)
+75
View File
@@ -438,6 +438,81 @@ func (s *MethodTestSuite) TestChats() {
dbm.EXPECT().GetChatByIDForUpdate(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
check.Args(chat.ID).Asserts(chat, policy.ActionRead).Returns(chat)
}))
s.Run("GetChatCostPerChat", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
arg := database.GetChatCostPerChatParams{
OwnerID: uuid.New(),
StartDate: time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC),
EndDate: time.Date(2025, 2, 1, 0, 0, 0, 0, time.UTC),
}
rows := []database.GetChatCostPerChatRow{{
RootChatID: uuid.New(),
ChatTitle: "chat-cost",
TotalCostMicros: 123,
MessageCount: 4,
TotalInputTokens: 55,
TotalOutputTokens: 89,
}}
dbm.EXPECT().GetChatCostPerChat(gomock.Any(), arg).Return(rows, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceChat.WithOwner(arg.OwnerID.String()), policy.ActionRead).Returns(rows)
}))
s.Run("GetChatCostPerModel", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
arg := database.GetChatCostPerModelParams{
OwnerID: uuid.New(),
StartDate: time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC),
EndDate: time.Date(2025, 2, 1, 0, 0, 0, 0, time.UTC),
}
rows := []database.GetChatCostPerModelRow{{
ModelConfigID: uuid.New(),
DisplayName: "GPT 4.1",
Provider: "openai",
Model: "gpt-4.1",
TotalCostMicros: 456,
MessageCount: 7,
TotalInputTokens: 144,
TotalOutputTokens: 233,
}}
dbm.EXPECT().GetChatCostPerModel(gomock.Any(), arg).Return(rows, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceChat.WithOwner(arg.OwnerID.String()), policy.ActionRead).Returns(rows)
}))
s.Run("GetChatCostPerUser", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
arg := database.GetChatCostPerUserParams{
PageOffset: 0,
PageLimit: 25,
StartDate: time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC),
EndDate: time.Date(2025, 2, 1, 0, 0, 0, 0, time.UTC),
Username: "cost-user",
}
rows := []database.GetChatCostPerUserRow{{
UserID: uuid.New(),
Username: "cost-user",
Name: "Cost User",
AvatarURL: "https://example.com/avatar.png",
TotalCostMicros: 789,
MessageCount: 11,
ChatCount: 3,
TotalInputTokens: 377,
TotalOutputTokens: 610,
TotalCount: 1,
}}
dbm.EXPECT().GetChatCostPerUser(gomock.Any(), arg).Return(rows, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceChat, policy.ActionRead).Returns(rows)
}))
s.Run("GetChatCostSummary", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
arg := database.GetChatCostSummaryParams{
OwnerID: uuid.New(),
StartDate: time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC),
EndDate: time.Date(2025, 2, 1, 0, 0, 0, 0, time.UTC),
}
row := database.GetChatCostSummaryRow{
TotalCostMicros: 987,
PricedMessageCount: 12,
UnpricedMessageCount: 2,
TotalInputTokens: 400,
TotalOutputTokens: 800,
}
dbm.EXPECT().GetChatCostSummary(gomock.Any(), arg).Return(row, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceChat.WithOwner(arg.OwnerID.String()), policy.ActionRead).Returns(row)
}))
s.Run("GetChatDiffStatusByChatID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
diffStatus := testutil.Fake(s.T(), faker, database.ChatDiffStatus{ChatID: chat.ID})
+32
View File
@@ -983,6 +983,38 @@ func (m queryMetricsStore) GetChatByIDForUpdate(ctx context.Context, id uuid.UUI
return r0, r1
}
func (m queryMetricsStore) GetChatCostPerChat(ctx context.Context, arg database.GetChatCostPerChatParams) ([]database.GetChatCostPerChatRow, error) {
start := time.Now()
r0, r1 := m.s.GetChatCostPerChat(ctx, arg)
m.queryLatencies.WithLabelValues("GetChatCostPerChat").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatCostPerChat").Inc()
return r0, r1
}
func (m queryMetricsStore) GetChatCostPerModel(ctx context.Context, arg database.GetChatCostPerModelParams) ([]database.GetChatCostPerModelRow, error) {
start := time.Now()
r0, r1 := m.s.GetChatCostPerModel(ctx, arg)
m.queryLatencies.WithLabelValues("GetChatCostPerModel").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatCostPerModel").Inc()
return r0, r1
}
func (m queryMetricsStore) GetChatCostPerUser(ctx context.Context, arg database.GetChatCostPerUserParams) ([]database.GetChatCostPerUserRow, error) {
start := time.Now()
r0, r1 := m.s.GetChatCostPerUser(ctx, arg)
m.queryLatencies.WithLabelValues("GetChatCostPerUser").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatCostPerUser").Inc()
return r0, r1
}
func (m queryMetricsStore) GetChatCostSummary(ctx context.Context, arg database.GetChatCostSummaryParams) (database.GetChatCostSummaryRow, error) {
start := time.Now()
r0, r1 := m.s.GetChatCostSummary(ctx, arg)
m.queryLatencies.WithLabelValues("GetChatCostSummary").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatCostSummary").Inc()
return r0, r1
}
func (m queryMetricsStore) GetChatDiffStatusByChatID(ctx context.Context, chatID uuid.UUID) (database.ChatDiffStatus, error) {
start := time.Now()
r0, r1 := m.s.GetChatDiffStatusByChatID(ctx, chatID)
+60
View File
@@ -1778,6 +1778,66 @@ func (mr *MockStoreMockRecorder) GetChatByIDForUpdate(ctx, id any) *gomock.Call
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatByIDForUpdate", reflect.TypeOf((*MockStore)(nil).GetChatByIDForUpdate), ctx, id)
}
// GetChatCostPerChat mocks base method.
func (m *MockStore) GetChatCostPerChat(ctx context.Context, arg database.GetChatCostPerChatParams) ([]database.GetChatCostPerChatRow, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetChatCostPerChat", ctx, arg)
ret0, _ := ret[0].([]database.GetChatCostPerChatRow)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetChatCostPerChat indicates an expected call of GetChatCostPerChat.
func (mr *MockStoreMockRecorder) GetChatCostPerChat(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatCostPerChat", reflect.TypeOf((*MockStore)(nil).GetChatCostPerChat), ctx, arg)
}
// GetChatCostPerModel mocks base method.
func (m *MockStore) GetChatCostPerModel(ctx context.Context, arg database.GetChatCostPerModelParams) ([]database.GetChatCostPerModelRow, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetChatCostPerModel", ctx, arg)
ret0, _ := ret[0].([]database.GetChatCostPerModelRow)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetChatCostPerModel indicates an expected call of GetChatCostPerModel.
func (mr *MockStoreMockRecorder) GetChatCostPerModel(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatCostPerModel", reflect.TypeOf((*MockStore)(nil).GetChatCostPerModel), ctx, arg)
}
// GetChatCostPerUser mocks base method.
func (m *MockStore) GetChatCostPerUser(ctx context.Context, arg database.GetChatCostPerUserParams) ([]database.GetChatCostPerUserRow, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetChatCostPerUser", ctx, arg)
ret0, _ := ret[0].([]database.GetChatCostPerUserRow)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetChatCostPerUser indicates an expected call of GetChatCostPerUser.
func (mr *MockStoreMockRecorder) GetChatCostPerUser(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatCostPerUser", reflect.TypeOf((*MockStore)(nil).GetChatCostPerUser), ctx, arg)
}
// GetChatCostSummary mocks base method.
func (m *MockStore) GetChatCostSummary(ctx context.Context, arg database.GetChatCostSummaryParams) (database.GetChatCostSummaryRow, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetChatCostSummary", ctx, arg)
ret0, _ := ret[0].(database.GetChatCostSummaryRow)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetChatCostSummary indicates an expected call of GetChatCostSummary.
func (mr *MockStoreMockRecorder) GetChatCostSummary(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatCostSummary", reflect.TypeOf((*MockStore)(nil).GetChatCostSummary), ctx, arg)
}
// GetChatDiffStatusByChatID mocks base method.
func (m *MockStore) GetChatDiffStatusByChatID(ctx context.Context, chatID uuid.UUID) (database.ChatDiffStatus, error) {
m.ctrl.T.Helper()
+4 -1
View File
@@ -1226,7 +1226,8 @@ CREATE TABLE chat_messages (
context_limit bigint,
compressed boolean DEFAULT false NOT NULL,
created_by uuid,
content_version smallint NOT NULL
content_version smallint NOT NULL,
total_cost_micros bigint
);
CREATE SEQUENCE chat_messages_id_seq
@@ -3534,6 +3535,8 @@ CREATE INDEX idx_chat_messages_chat_created ON chat_messages USING btree (chat_i
CREATE INDEX idx_chat_messages_compressed_summary_boundary ON chat_messages USING btree (chat_id, created_at DESC, id DESC) WHERE ((compressed = true) AND (role = 'system'::chat_message_role) AND (visibility = ANY (ARRAY['model'::chat_message_visibility, 'both'::chat_message_visibility])));
CREATE INDEX idx_chat_messages_created_at ON chat_messages USING btree (created_at);
CREATE INDEX idx_chat_model_configs_enabled ON chat_model_configs USING btree (enabled);
CREATE INDEX idx_chat_model_configs_provider ON chat_model_configs USING btree (provider);
@@ -0,0 +1,3 @@
DROP INDEX IF EXISTS idx_chat_messages_created_at;
ALTER TABLE chat_messages DROP COLUMN total_cost_micros;
@@ -0,0 +1,68 @@
ALTER TABLE chat_messages ADD COLUMN total_cost_micros BIGINT;
WITH message_costs AS (
SELECT
msg.id,
ROUND(
COALESCE(msg.input_tokens, 0)::numeric * COALESCE(pricing.input_price, 0)
+ COALESCE(msg.output_tokens, 0)::numeric * COALESCE(pricing.output_price, 0)
+ COALESCE(msg.cache_read_tokens, 0)::numeric * COALESCE(pricing.cache_read_price, 0)
+ COALESCE(msg.cache_creation_tokens, 0)::numeric * COALESCE(pricing.cache_write_price, 0)
)::bigint AS total_cost_micros
FROM
chat_messages AS msg
JOIN
chat_model_configs AS cfg
ON
cfg.id = msg.model_config_id
CROSS JOIN LATERAL (
SELECT
COALESCE(
(cfg.options -> 'cost' ->> 'input_price_per_million_tokens')::numeric,
(cfg.options ->> 'input_price_per_million_tokens')::numeric
) AS input_price,
COALESCE(
(cfg.options -> 'cost' ->> 'output_price_per_million_tokens')::numeric,
(cfg.options ->> 'output_price_per_million_tokens')::numeric
) AS output_price,
COALESCE(
(cfg.options -> 'cost' ->> 'cache_read_price_per_million_tokens')::numeric,
(cfg.options ->> 'cache_read_price_per_million_tokens')::numeric
) AS cache_read_price,
COALESCE(
(cfg.options -> 'cost' ->> 'cache_write_price_per_million_tokens')::numeric,
(cfg.options ->> 'cache_write_price_per_million_tokens')::numeric
) AS cache_write_price
) AS pricing
WHERE
msg.total_cost_micros IS NULL
AND (
msg.input_tokens IS NOT NULL
OR msg.output_tokens IS NOT NULL
OR msg.reasoning_tokens IS NOT NULL
OR msg.cache_creation_tokens IS NOT NULL
OR msg.cache_read_tokens IS NOT NULL
)
AND (
pricing.input_price IS NOT NULL
OR pricing.output_price IS NOT NULL
OR pricing.cache_read_price IS NOT NULL
OR pricing.cache_write_price IS NOT NULL
)
AND (
(msg.input_tokens IS NOT NULL AND pricing.input_price IS NOT NULL)
OR (msg.output_tokens IS NOT NULL AND pricing.output_price IS NOT NULL)
OR (msg.cache_read_tokens IS NOT NULL AND pricing.cache_read_price IS NOT NULL)
OR (msg.cache_creation_tokens IS NOT NULL AND pricing.cache_write_price IS NOT NULL)
)
)
UPDATE
chat_messages AS msg
SET
total_cost_micros = message_costs.total_cost_micros
FROM
message_costs
WHERE
msg.id = message_costs.id;
CREATE INDEX idx_chat_messages_created_at ON chat_messages (created_at);
+1
View File
@@ -4020,6 +4020,7 @@ type ChatMessage struct {
Compressed bool `db:"compressed" json:"compressed"`
CreatedBy uuid.NullUUID `db:"created_by" json:"created_by"`
ContentVersion int16 `db:"content_version" json:"content_version"`
TotalCostMicros sql.NullInt64 `db:"total_cost_micros" json:"total_cost_micros"`
}
type ChatModelConfig struct {
+13
View File
@@ -215,6 +215,19 @@ type sqlcQuerier interface {
GetAuthorizationUserRoles(ctx context.Context, userID uuid.UUID) (GetAuthorizationUserRolesRow, error)
GetChatByID(ctx context.Context, id uuid.UUID) (Chat, error)
GetChatByIDForUpdate(ctx context.Context, id uuid.UUID) (Chat, error)
// Per-root-chat cost breakdown for a single user within a date range.
// Groups by root_chat_id so forked chats roll up under their root.
// Only counts assistant-role messages.
GetChatCostPerChat(ctx context.Context, arg GetChatCostPerChatParams) ([]GetChatCostPerChatRow, error)
// Per-model cost breakdown for a single user within a date range.
// Only counts assistant-role messages that have a model_config_id.
GetChatCostPerModel(ctx context.Context, arg GetChatCostPerModelParams) ([]GetChatCostPerModelRow, error)
// Deployment-wide per-user cost rollup within a date range.
// Only counts assistant-role messages.
GetChatCostPerUser(ctx context.Context, arg GetChatCostPerUserParams) ([]GetChatCostPerUserRow, error)
// Aggregate cost summary for a single user within a date range.
// Only counts assistant-role messages.
GetChatCostSummary(ctx context.Context, arg GetChatCostSummaryParams) (GetChatCostSummaryRow, error)
GetChatDiffStatusByChatID(ctx context.Context, chatID uuid.UUID) (ChatDiffStatus, error)
GetChatDiffStatusesByChatIDs(ctx context.Context, chatIds []uuid.UUID) ([]ChatDiffStatus, error)
GetChatFileByID(ctx context.Context, id uuid.UUID) (ChatFile, error)
+365 -8
View File
@@ -3229,6 +3229,353 @@ func (q *sqlQuerier) GetChatByIDForUpdate(ctx context.Context, id uuid.UUID) (Ch
return i, err
}
const getChatCostPerChat = `-- name: GetChatCostPerChat :many
WITH chat_costs AS (
SELECT
COALESCE(c.root_chat_id, c.id) AS root_chat_id,
COALESCE(SUM(cm.total_cost_micros), 0)::bigint AS total_cost_micros,
COUNT(*) FILTER (
WHERE cm.input_tokens IS NOT NULL
OR cm.output_tokens IS NOT NULL
OR cm.reasoning_tokens IS NOT NULL
OR cm.cache_creation_tokens IS NOT NULL
OR cm.cache_read_tokens IS NOT NULL
)::bigint AS message_count,
COALESCE(SUM(cm.input_tokens), 0)::bigint AS total_input_tokens,
COALESCE(SUM(cm.output_tokens), 0)::bigint AS total_output_tokens
FROM chat_messages cm
JOIN chats c ON c.id = cm.chat_id
WHERE c.owner_id = $1::uuid
AND cm.role = 'assistant'
AND cm.created_at >= $2::timestamptz
AND cm.created_at < $3::timestamptz
GROUP BY COALESCE(c.root_chat_id, c.id)
)
SELECT
cc.root_chat_id,
COALESCE(rc.title, '') AS chat_title,
cc.total_cost_micros,
cc.message_count,
cc.total_input_tokens,
cc.total_output_tokens
FROM chat_costs cc
LEFT JOIN chats rc ON rc.id = cc.root_chat_id
ORDER BY cc.total_cost_micros DESC
`
type GetChatCostPerChatParams struct {
OwnerID uuid.UUID `db:"owner_id" json:"owner_id"`
StartDate time.Time `db:"start_date" json:"start_date"`
EndDate time.Time `db:"end_date" json:"end_date"`
}
type GetChatCostPerChatRow struct {
RootChatID uuid.UUID `db:"root_chat_id" json:"root_chat_id"`
ChatTitle string `db:"chat_title" json:"chat_title"`
TotalCostMicros int64 `db:"total_cost_micros" json:"total_cost_micros"`
MessageCount int64 `db:"message_count" json:"message_count"`
TotalInputTokens int64 `db:"total_input_tokens" json:"total_input_tokens"`
TotalOutputTokens int64 `db:"total_output_tokens" json:"total_output_tokens"`
}
// Per-root-chat cost breakdown for a single user within a date range.
// Groups by root_chat_id so forked chats roll up under their root.
// Only counts assistant-role messages.
func (q *sqlQuerier) GetChatCostPerChat(ctx context.Context, arg GetChatCostPerChatParams) ([]GetChatCostPerChatRow, error) {
rows, err := q.db.QueryContext(ctx, getChatCostPerChat, arg.OwnerID, arg.StartDate, arg.EndDate)
if err != nil {
return nil, err
}
defer rows.Close()
var items []GetChatCostPerChatRow
for rows.Next() {
var i GetChatCostPerChatRow
if err := rows.Scan(
&i.RootChatID,
&i.ChatTitle,
&i.TotalCostMicros,
&i.MessageCount,
&i.TotalInputTokens,
&i.TotalOutputTokens,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const getChatCostPerModel = `-- name: GetChatCostPerModel :many
SELECT
cmc.id AS model_config_id,
cmc.display_name,
cmc.provider,
cmc.model,
COALESCE(SUM(cm.total_cost_micros), 0)::bigint AS total_cost_micros,
COUNT(*) FILTER (
WHERE cm.input_tokens IS NOT NULL
OR cm.output_tokens IS NOT NULL
OR cm.reasoning_tokens IS NOT NULL
OR cm.cache_creation_tokens IS NOT NULL
OR cm.cache_read_tokens IS NOT NULL
)::bigint AS message_count,
COALESCE(SUM(cm.input_tokens), 0)::bigint AS total_input_tokens,
COALESCE(SUM(cm.output_tokens), 0)::bigint AS total_output_tokens
FROM
chat_messages cm
JOIN
chats c ON c.id = cm.chat_id
JOIN
chat_model_configs cmc ON cmc.id = cm.model_config_id
WHERE
c.owner_id = $1::uuid
AND cm.role = 'assistant'
AND cm.created_at >= $2::timestamptz
AND cm.created_at < $3::timestamptz
GROUP BY
cmc.id, cmc.display_name, cmc.provider, cmc.model
ORDER BY
total_cost_micros DESC
`
type GetChatCostPerModelParams struct {
OwnerID uuid.UUID `db:"owner_id" json:"owner_id"`
StartDate time.Time `db:"start_date" json:"start_date"`
EndDate time.Time `db:"end_date" json:"end_date"`
}
type GetChatCostPerModelRow struct {
ModelConfigID uuid.UUID `db:"model_config_id" json:"model_config_id"`
DisplayName string `db:"display_name" json:"display_name"`
Provider string `db:"provider" json:"provider"`
Model string `db:"model" json:"model"`
TotalCostMicros int64 `db:"total_cost_micros" json:"total_cost_micros"`
MessageCount int64 `db:"message_count" json:"message_count"`
TotalInputTokens int64 `db:"total_input_tokens" json:"total_input_tokens"`
TotalOutputTokens int64 `db:"total_output_tokens" json:"total_output_tokens"`
}
// Per-model cost breakdown for a single user within a date range.
// Only counts assistant-role messages that have a model_config_id.
func (q *sqlQuerier) GetChatCostPerModel(ctx context.Context, arg GetChatCostPerModelParams) ([]GetChatCostPerModelRow, error) {
rows, err := q.db.QueryContext(ctx, getChatCostPerModel, arg.OwnerID, arg.StartDate, arg.EndDate)
if err != nil {
return nil, err
}
defer rows.Close()
var items []GetChatCostPerModelRow
for rows.Next() {
var i GetChatCostPerModelRow
if err := rows.Scan(
&i.ModelConfigID,
&i.DisplayName,
&i.Provider,
&i.Model,
&i.TotalCostMicros,
&i.MessageCount,
&i.TotalInputTokens,
&i.TotalOutputTokens,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const getChatCostPerUser = `-- name: GetChatCostPerUser :many
WITH chat_cost_users AS (
SELECT
c.owner_id AS user_id,
u.username,
u.name,
u.avatar_url,
COALESCE(SUM(cm.total_cost_micros), 0)::bigint AS total_cost_micros,
COUNT(*) FILTER (
WHERE cm.input_tokens IS NOT NULL
OR cm.output_tokens IS NOT NULL
OR cm.reasoning_tokens IS NOT NULL
OR cm.cache_creation_tokens IS NOT NULL
OR cm.cache_read_tokens IS NOT NULL
)::bigint AS message_count,
COUNT(DISTINCT COALESCE(c.root_chat_id, c.id))::bigint AS chat_count,
COALESCE(SUM(cm.input_tokens), 0)::bigint AS total_input_tokens,
COALESCE(SUM(cm.output_tokens), 0)::bigint AS total_output_tokens
FROM
chat_messages cm
JOIN
chats c ON c.id = cm.chat_id
JOIN
users u ON u.id = c.owner_id
WHERE
cm.role = 'assistant'
AND cm.created_at >= $3::timestamptz
AND cm.created_at < $4::timestamptz
AND (
$5::text = ''
OR u.username ILIKE '%' || $5::text || '%'
)
GROUP BY
c.owner_id,
u.username,
u.name,
u.avatar_url
)
SELECT
user_id,
username,
name,
avatar_url,
total_cost_micros,
message_count,
chat_count,
total_input_tokens,
total_output_tokens,
COUNT(*) OVER()::bigint AS total_count
FROM
chat_cost_users
ORDER BY
total_cost_micros DESC,
username ASC
LIMIT
$2::int
OFFSET
$1::int
`
type GetChatCostPerUserParams struct {
PageOffset int32 `db:"page_offset" json:"page_offset"`
PageLimit int32 `db:"page_limit" json:"page_limit"`
StartDate time.Time `db:"start_date" json:"start_date"`
EndDate time.Time `db:"end_date" json:"end_date"`
Username string `db:"username" json:"username"`
}
type GetChatCostPerUserRow struct {
UserID uuid.UUID `db:"user_id" json:"user_id"`
Username string `db:"username" json:"username"`
Name string `db:"name" json:"name"`
AvatarURL string `db:"avatar_url" json:"avatar_url"`
TotalCostMicros int64 `db:"total_cost_micros" json:"total_cost_micros"`
MessageCount int64 `db:"message_count" json:"message_count"`
ChatCount int64 `db:"chat_count" json:"chat_count"`
TotalInputTokens int64 `db:"total_input_tokens" json:"total_input_tokens"`
TotalOutputTokens int64 `db:"total_output_tokens" json:"total_output_tokens"`
TotalCount int64 `db:"total_count" json:"total_count"`
}
// Deployment-wide per-user cost rollup within a date range.
// Only counts assistant-role messages.
func (q *sqlQuerier) GetChatCostPerUser(ctx context.Context, arg GetChatCostPerUserParams) ([]GetChatCostPerUserRow, error) {
rows, err := q.db.QueryContext(ctx, getChatCostPerUser,
arg.PageOffset,
arg.PageLimit,
arg.StartDate,
arg.EndDate,
arg.Username,
)
if err != nil {
return nil, err
}
defer rows.Close()
var items []GetChatCostPerUserRow
for rows.Next() {
var i GetChatCostPerUserRow
if err := rows.Scan(
&i.UserID,
&i.Username,
&i.Name,
&i.AvatarURL,
&i.TotalCostMicros,
&i.MessageCount,
&i.ChatCount,
&i.TotalInputTokens,
&i.TotalOutputTokens,
&i.TotalCount,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const getChatCostSummary = `-- name: GetChatCostSummary :one
SELECT
COALESCE(SUM(cm.total_cost_micros), 0)::bigint AS total_cost_micros,
COUNT(*) FILTER (
WHERE cm.total_cost_micros IS NOT NULL
)::bigint AS priced_message_count,
COUNT(*) FILTER (
WHERE cm.total_cost_micros IS NULL
AND (
cm.input_tokens IS NOT NULL
OR cm.output_tokens IS NOT NULL
OR cm.reasoning_tokens IS NOT NULL
OR cm.cache_creation_tokens IS NOT NULL
OR cm.cache_read_tokens IS NOT NULL
)
)::bigint AS unpriced_message_count,
COALESCE(SUM(cm.input_tokens), 0)::bigint AS total_input_tokens,
COALESCE(SUM(cm.output_tokens), 0)::bigint AS total_output_tokens
FROM
chat_messages cm
JOIN
chats c ON c.id = cm.chat_id
WHERE
c.owner_id = $1::uuid
AND cm.role = 'assistant'
AND cm.created_at >= $2::timestamptz
AND cm.created_at < $3::timestamptz
`
type GetChatCostSummaryParams struct {
OwnerID uuid.UUID `db:"owner_id" json:"owner_id"`
StartDate time.Time `db:"start_date" json:"start_date"`
EndDate time.Time `db:"end_date" json:"end_date"`
}
type GetChatCostSummaryRow struct {
TotalCostMicros int64 `db:"total_cost_micros" json:"total_cost_micros"`
PricedMessageCount int64 `db:"priced_message_count" json:"priced_message_count"`
UnpricedMessageCount int64 `db:"unpriced_message_count" json:"unpriced_message_count"`
TotalInputTokens int64 `db:"total_input_tokens" json:"total_input_tokens"`
TotalOutputTokens int64 `db:"total_output_tokens" json:"total_output_tokens"`
}
// Aggregate cost summary for a single user within a date range.
// Only counts assistant-role messages.
func (q *sqlQuerier) GetChatCostSummary(ctx context.Context, arg GetChatCostSummaryParams) (GetChatCostSummaryRow, error) {
row := q.db.QueryRowContext(ctx, getChatCostSummary, arg.OwnerID, arg.StartDate, arg.EndDate)
var i GetChatCostSummaryRow
err := row.Scan(
&i.TotalCostMicros,
&i.PricedMessageCount,
&i.UnpricedMessageCount,
&i.TotalInputTokens,
&i.TotalOutputTokens,
)
return i, err
}
const getChatDiffStatusByChatID = `-- name: GetChatDiffStatusByChatID :one
SELECT
chat_id, url, pull_request_state, changes_requested, additions, deletions, changed_files, refreshed_at, stale_at, created_at, updated_at, git_branch, git_remote_origin, pull_request_title, pull_request_draft
@@ -3311,7 +3658,7 @@ func (q *sqlQuerier) GetChatDiffStatusesByChatIDs(ctx context.Context, chatIds [
const getChatMessageByID = `-- name: GetChatMessageByID :one
SELECT
id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version
id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros
FROM
chat_messages
WHERE
@@ -3339,13 +3686,14 @@ func (q *sqlQuerier) GetChatMessageByID(ctx context.Context, id int64) (ChatMess
&i.Compressed,
&i.CreatedBy,
&i.ContentVersion,
&i.TotalCostMicros,
)
return i, err
}
const getChatMessagesByChatID = `-- name: GetChatMessagesByChatID :many
SELECT
id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version
id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros
FROM
chat_messages
WHERE
@@ -3388,6 +3736,7 @@ func (q *sqlQuerier) GetChatMessagesByChatID(ctx context.Context, arg GetChatMes
&i.Compressed,
&i.CreatedBy,
&i.ContentVersion,
&i.TotalCostMicros,
); err != nil {
return nil, err
}
@@ -3419,7 +3768,7 @@ WITH latest_compressed_summary AS (
1
)
SELECT
id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version
id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros
FROM
chat_messages
WHERE
@@ -3486,6 +3835,7 @@ func (q *sqlQuerier) GetChatMessagesForPromptByChatID(ctx context.Context, chatI
&i.Compressed,
&i.CreatedBy,
&i.ContentVersion,
&i.TotalCostMicros,
); err != nil {
return nil, err
}
@@ -3629,7 +3979,7 @@ func (q *sqlQuerier) GetChatsByOwnerID(ctx context.Context, arg GetChatsByOwnerI
const getLastChatMessageByRole = `-- name: GetLastChatMessageByRole :one
SELECT
id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version
id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros
FROM
chat_messages
WHERE
@@ -3667,6 +4017,7 @@ func (q *sqlQuerier) GetLastChatMessageByRole(ctx context.Context, arg GetLastCh
&i.Compressed,
&i.CreatedBy,
&i.ContentVersion,
&i.TotalCostMicros,
)
return i, err
}
@@ -3806,7 +4157,8 @@ INSERT INTO chat_messages (
cache_creation_tokens,
cache_read_tokens,
context_limit,
compressed
compressed,
total_cost_micros
) VALUES (
$1::uuid,
$2::uuid,
@@ -3822,10 +4174,11 @@ INSERT INTO chat_messages (
$12::bigint,
$13::bigint,
$14::bigint,
COALESCE($15::boolean, FALSE)
COALESCE($15::boolean, FALSE),
$16::bigint
)
RETURNING
id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version
id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros
`
type InsertChatMessageParams struct {
@@ -3844,6 +4197,7 @@ type InsertChatMessageParams struct {
CacheReadTokens sql.NullInt64 `db:"cache_read_tokens" json:"cache_read_tokens"`
ContextLimit sql.NullInt64 `db:"context_limit" json:"context_limit"`
Compressed sql.NullBool `db:"compressed" json:"compressed"`
TotalCostMicros sql.NullInt64 `db:"total_cost_micros" json:"total_cost_micros"`
}
func (q *sqlQuerier) InsertChatMessage(ctx context.Context, arg InsertChatMessageParams) (ChatMessage, error) {
@@ -3863,6 +4217,7 @@ func (q *sqlQuerier) InsertChatMessage(ctx context.Context, arg InsertChatMessag
arg.CacheReadTokens,
arg.ContextLimit,
arg.Compressed,
arg.TotalCostMicros,
)
var i ChatMessage
err := row.Scan(
@@ -3883,6 +4238,7 @@ func (q *sqlQuerier) InsertChatMessage(ctx context.Context, arg InsertChatMessag
&i.Compressed,
&i.CreatedBy,
&i.ContentVersion,
&i.TotalCostMicros,
)
return i, err
}
@@ -4017,7 +4373,7 @@ SET
WHERE
id = $3::bigint
RETURNING
id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version
id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros
`
type UpdateChatMessageByIDParams struct {
@@ -4047,6 +4403,7 @@ func (q *sqlQuerier) UpdateChatMessageByID(ctx context.Context, arg UpdateChatMe
&i.Compressed,
&i.CreatedBy,
&i.ContentVersion,
&i.TotalCostMicros,
)
return i, err
}
+165 -2
View File
@@ -179,7 +179,8 @@ INSERT INTO chat_messages (
cache_creation_tokens,
cache_read_tokens,
context_limit,
compressed
compressed,
total_cost_micros
) VALUES (
@chat_id::uuid,
sqlc.narg('created_by')::uuid,
@@ -195,7 +196,8 @@ INSERT INTO chat_messages (
sqlc.narg('cache_creation_tokens')::bigint,
sqlc.narg('cache_read_tokens')::bigint,
sqlc.narg('context_limit')::bigint,
COALESCE(sqlc.narg('compressed')::boolean, FALSE)
COALESCE(sqlc.narg('compressed')::boolean, FALSE),
sqlc.narg('total_cost_micros')::bigint
)
RETURNING
*;
@@ -481,3 +483,164 @@ SET
updated_at = NOW()
WHERE
chat_id = @chat_id::uuid;
-- name: GetChatCostSummary :one
-- Aggregate cost summary for a single user within a date range.
-- Only counts assistant-role messages.
SELECT
COALESCE(SUM(cm.total_cost_micros), 0)::bigint AS total_cost_micros,
COUNT(*) FILTER (
WHERE cm.total_cost_micros IS NOT NULL
)::bigint AS priced_message_count,
COUNT(*) FILTER (
WHERE cm.total_cost_micros IS NULL
AND (
cm.input_tokens IS NOT NULL
OR cm.output_tokens IS NOT NULL
OR cm.reasoning_tokens IS NOT NULL
OR cm.cache_creation_tokens IS NOT NULL
OR cm.cache_read_tokens IS NOT NULL
)
)::bigint AS unpriced_message_count,
COALESCE(SUM(cm.input_tokens), 0)::bigint AS total_input_tokens,
COALESCE(SUM(cm.output_tokens), 0)::bigint AS total_output_tokens
FROM
chat_messages cm
JOIN
chats c ON c.id = cm.chat_id
WHERE
c.owner_id = @owner_id::uuid
AND cm.role = 'assistant'
AND cm.created_at >= @start_date::timestamptz
AND cm.created_at < @end_date::timestamptz;
-- name: GetChatCostPerModel :many
-- Per-model cost breakdown for a single user within a date range.
-- Only counts assistant-role messages that have a model_config_id.
SELECT
cmc.id AS model_config_id,
cmc.display_name,
cmc.provider,
cmc.model,
COALESCE(SUM(cm.total_cost_micros), 0)::bigint AS total_cost_micros,
COUNT(*) FILTER (
WHERE cm.input_tokens IS NOT NULL
OR cm.output_tokens IS NOT NULL
OR cm.reasoning_tokens IS NOT NULL
OR cm.cache_creation_tokens IS NOT NULL
OR cm.cache_read_tokens IS NOT NULL
)::bigint AS message_count,
COALESCE(SUM(cm.input_tokens), 0)::bigint AS total_input_tokens,
COALESCE(SUM(cm.output_tokens), 0)::bigint AS total_output_tokens
FROM
chat_messages cm
JOIN
chats c ON c.id = cm.chat_id
JOIN
chat_model_configs cmc ON cmc.id = cm.model_config_id
WHERE
c.owner_id = @owner_id::uuid
AND cm.role = 'assistant'
AND cm.created_at >= @start_date::timestamptz
AND cm.created_at < @end_date::timestamptz
GROUP BY
cmc.id, cmc.display_name, cmc.provider, cmc.model
ORDER BY
total_cost_micros DESC;
-- name: GetChatCostPerChat :many
-- Per-root-chat cost breakdown for a single user within a date range.
-- Groups by root_chat_id so forked chats roll up under their root.
-- Only counts assistant-role messages.
WITH chat_costs AS (
SELECT
COALESCE(c.root_chat_id, c.id) AS root_chat_id,
COALESCE(SUM(cm.total_cost_micros), 0)::bigint AS total_cost_micros,
COUNT(*) FILTER (
WHERE cm.input_tokens IS NOT NULL
OR cm.output_tokens IS NOT NULL
OR cm.reasoning_tokens IS NOT NULL
OR cm.cache_creation_tokens IS NOT NULL
OR cm.cache_read_tokens IS NOT NULL
)::bigint AS message_count,
COALESCE(SUM(cm.input_tokens), 0)::bigint AS total_input_tokens,
COALESCE(SUM(cm.output_tokens), 0)::bigint AS total_output_tokens
FROM chat_messages cm
JOIN chats c ON c.id = cm.chat_id
WHERE c.owner_id = @owner_id::uuid
AND cm.role = 'assistant'
AND cm.created_at >= @start_date::timestamptz
AND cm.created_at < @end_date::timestamptz
GROUP BY COALESCE(c.root_chat_id, c.id)
)
SELECT
cc.root_chat_id,
COALESCE(rc.title, '') AS chat_title,
cc.total_cost_micros,
cc.message_count,
cc.total_input_tokens,
cc.total_output_tokens
FROM chat_costs cc
LEFT JOIN chats rc ON rc.id = cc.root_chat_id
ORDER BY cc.total_cost_micros DESC;
-- name: GetChatCostPerUser :many
-- Deployment-wide per-user cost rollup within a date range.
-- Only counts assistant-role messages.
WITH chat_cost_users AS (
SELECT
c.owner_id AS user_id,
u.username,
u.name,
u.avatar_url,
COALESCE(SUM(cm.total_cost_micros), 0)::bigint AS total_cost_micros,
COUNT(*) FILTER (
WHERE cm.input_tokens IS NOT NULL
OR cm.output_tokens IS NOT NULL
OR cm.reasoning_tokens IS NOT NULL
OR cm.cache_creation_tokens IS NOT NULL
OR cm.cache_read_tokens IS NOT NULL
)::bigint AS message_count,
COUNT(DISTINCT COALESCE(c.root_chat_id, c.id))::bigint AS chat_count,
COALESCE(SUM(cm.input_tokens), 0)::bigint AS total_input_tokens,
COALESCE(SUM(cm.output_tokens), 0)::bigint AS total_output_tokens
FROM
chat_messages cm
JOIN
chats c ON c.id = cm.chat_id
JOIN
users u ON u.id = c.owner_id
WHERE
cm.role = 'assistant'
AND cm.created_at >= @start_date::timestamptz
AND cm.created_at < @end_date::timestamptz
AND (
@username::text = ''
OR u.username ILIKE '%' || @username::text || '%'
)
GROUP BY
c.owner_id,
u.username,
u.name,
u.avatar_url
)
SELECT
user_id,
username,
name,
avatar_url,
total_cost_micros,
message_count,
chat_count,
total_input_tokens,
total_output_tokens,
COUNT(*) OVER()::bigint AS total_count
FROM
chat_cost_users
ORDER BY
total_cost_micros DESC,
username ASC
LIMIT
sqlc.arg('page_limit')::int
OFFSET
sqlc.arg('page_offset')::int;
+11
View File
@@ -148,6 +148,17 @@ sql:
go_type: "database/sql.NullTime"
- column: "task_event_data.first_status_after_resume_at"
go_type: "database/sql.NullTime"
- db_type: "pg_catalog.numeric"
go_type:
import: "github.com/shopspring/decimal"
type: "Decimal"
package: "decimal"
- db_type: "pg_catalog.numeric"
nullable: true
go_type:
import: "github.com/shopspring/decimal"
type: "NullDecimal"
package: "decimal"
rename:
group_member: GroupMemberTable
group_members_expanded: GroupMember
+148 -13
View File
@@ -7,10 +7,13 @@ import (
"io"
"mime"
"net/http"
"net/url"
"strconv"
"strings"
"time"
"github.com/google/uuid"
"github.com/shopspring/decimal"
"github.com/coder/websocket"
"github.com/coder/websocket/wsjson"
@@ -198,11 +201,11 @@ func ChatMessageFileReference(fileName string, startLine, endLine int, content s
}
// ChatMessageSource builds a source chat message part.
func ChatMessageSource(sourceID, url, title string) ChatMessagePart {
func ChatMessageSource(sourceID, sourceURL, title string) ChatMessagePart {
return ChatMessagePart{
Type: ChatMessagePartTypeSource,
SourceID: sourceID,
URL: url,
URL: sourceURL,
Title: title,
}
}
@@ -519,13 +522,10 @@ type ChatModelVercelProviderOptions struct {
// ModelCostConfig stores pricing metadata for a chat model.
type ModelCostConfig struct {
// Pricing is stored as configuration metadata and currently only needs to
// round-trip cleanly through the API and admin UI. If we later use these
// values for billing-grade arithmetic, switch to a fixed-point type.
InputPricePerMillionTokens *float64 `json:"input_price_per_million_tokens,omitempty" description:"Input token price in USD per 1M tokens"`
OutputPricePerMillionTokens *float64 `json:"output_price_per_million_tokens,omitempty" description:"Output token price in USD per 1M tokens"`
CacheReadPricePerMillionTokens *float64 `json:"cache_read_price_per_million_tokens,omitempty" description:"Cache read token price in USD per 1M tokens"`
CacheWritePricePerMillionTokens *float64 `json:"cache_write_price_per_million_tokens,omitempty" description:"Cache write or cache creation token price in USD per 1M tokens"`
InputPricePerMillionTokens *decimal.Decimal `json:"input_price_per_million_tokens,omitempty" description:"Input token price in USD per 1M tokens"`
OutputPricePerMillionTokens *decimal.Decimal `json:"output_price_per_million_tokens,omitempty" description:"Output token price in USD per 1M tokens"`
CacheReadPricePerMillionTokens *decimal.Decimal `json:"cache_read_price_per_million_tokens,omitempty" description:"Cache read token price in USD per 1M tokens"`
CacheWritePricePerMillionTokens *decimal.Decimal `json:"cache_write_price_per_million_tokens,omitempty" description:"Cache write or cache creation token price in USD per 1M tokens"`
}
// ChatModelCallConfig configures per-call model behavior defaults.
@@ -546,10 +546,10 @@ func (c *ChatModelCallConfig) UnmarshalJSON(data []byte) error {
type chatModelCallConfigAlias ChatModelCallConfig
aux := struct {
*chatModelCallConfigAlias
InputPricePerMillionTokens *float64 `json:"input_price_per_million_tokens,omitempty"`
OutputPricePerMillionTokens *float64 `json:"output_price_per_million_tokens,omitempty"`
CacheReadPricePerMillionTokens *float64 `json:"cache_read_price_per_million_tokens,omitempty"`
CacheWritePricePerMillionTokens *float64 `json:"cache_write_price_per_million_tokens,omitempty"`
InputPricePerMillionTokens *decimal.Decimal `json:"input_price_per_million_tokens,omitempty"`
OutputPricePerMillionTokens *decimal.Decimal `json:"output_price_per_million_tokens,omitempty"`
CacheReadPricePerMillionTokens *decimal.Decimal `json:"cache_read_price_per_million_tokens,omitempty"`
CacheWritePricePerMillionTokens *decimal.Decimal `json:"cache_write_price_per_million_tokens,omitempty"`
}{
chatModelCallConfigAlias: (*chatModelCallConfigAlias)(c),
}
@@ -710,6 +710,76 @@ type chatStreamEnvelope struct {
Data json.RawMessage `json:"data,omitempty"`
}
// ChatCostSummaryOptions are optional query parameters for GetChatCostSummary.
type ChatCostSummaryOptions struct {
StartDate time.Time
EndDate time.Time
}
// ChatCostUsersOptions are optional query parameters for GetChatCostUsers.
type ChatCostUsersOptions struct {
StartDate time.Time
EndDate time.Time
Username string
Pagination
}
// ChatCostSummary is the response from the chat cost summary endpoint.
type ChatCostSummary struct {
StartDate time.Time `json:"start_date" format:"date-time"`
EndDate time.Time `json:"end_date" format:"date-time"`
TotalCostMicros int64 `json:"total_cost_micros"`
PricedMessageCount int64 `json:"priced_message_count"`
UnpricedMessageCount int64 `json:"unpriced_message_count"`
TotalInputTokens int64 `json:"total_input_tokens"`
TotalOutputTokens int64 `json:"total_output_tokens"`
ByModel []ChatCostModelBreakdown `json:"by_model"`
ByChat []ChatCostChatBreakdown `json:"by_chat"`
}
// ChatCostModelBreakdown contains per-model cost aggregation.
type ChatCostModelBreakdown struct {
ModelConfigID uuid.UUID `json:"model_config_id" format:"uuid"`
DisplayName string `json:"display_name"`
Provider string `json:"provider"`
Model string `json:"model"`
TotalCostMicros int64 `json:"total_cost_micros"`
MessageCount int64 `json:"message_count"`
TotalInputTokens int64 `json:"total_input_tokens"`
TotalOutputTokens int64 `json:"total_output_tokens"`
}
// ChatCostChatBreakdown contains per-root-chat cost aggregation.
type ChatCostChatBreakdown struct {
RootChatID uuid.UUID `json:"root_chat_id" format:"uuid"`
ChatTitle string `json:"chat_title"`
TotalCostMicros int64 `json:"total_cost_micros"`
MessageCount int64 `json:"message_count"`
TotalInputTokens int64 `json:"total_input_tokens"`
TotalOutputTokens int64 `json:"total_output_tokens"`
}
// ChatCostUserRollup contains per-user cost aggregation for admin views.
type ChatCostUserRollup struct {
UserID uuid.UUID `json:"user_id" format:"uuid"`
Username string `json:"username"`
Name string `json:"name"`
AvatarURL string `json:"avatar_url"`
TotalCostMicros int64 `json:"total_cost_micros"`
MessageCount int64 `json:"message_count"`
ChatCount int64 `json:"chat_count"`
TotalInputTokens int64 `json:"total_input_tokens"`
TotalOutputTokens int64 `json:"total_output_tokens"`
}
// ChatCostUsersResponse is the response from the admin chat cost users endpoint.
type ChatCostUsersResponse struct {
StartDate time.Time `json:"start_date" format:"date-time"`
EndDate time.Time `json:"end_date" format:"date-time"`
Count int64 `json:"count"`
Users []ChatCostUserRollup `json:"users"`
}
// ListChatsOptions are optional parameters for ListChats.
type ListChatsOptions struct {
Query string
@@ -872,6 +942,71 @@ func (c *Client) DeleteChatModelConfig(ctx context.Context, modelConfigID uuid.U
return nil
}
// GetChatCostSummary returns an aggregate cost summary for the specified
// user. Zero-valued StartDate or EndDate fields are omitted from the
// request, letting the server apply its own defaults (typically the last
// 30 days).
func (c *Client) GetChatCostSummary(ctx context.Context, user string, opts ChatCostSummaryOptions) (ChatCostSummary, error) {
qp := url.Values{}
if !opts.StartDate.IsZero() {
qp.Set("start_date", opts.StartDate.Format(time.RFC3339))
}
if !opts.EndDate.IsZero() {
qp.Set("end_date", opts.EndDate.Format(time.RFC3339))
}
reqURL := fmt.Sprintf("/api/experimental/chats/cost/%s/summary", user)
if len(qp) > 0 {
reqURL += "?" + qp.Encode()
}
res, err := c.Request(ctx, http.MethodGet, reqURL, nil)
if err != nil {
return ChatCostSummary{}, err
}
defer res.Body.Close()
if res.StatusCode != http.StatusOK {
return ChatCostSummary{}, ReadBodyAsError(res)
}
var summary ChatCostSummary
return summary, json.NewDecoder(res.Body).Decode(&summary)
}
// GetChatCostUsers returns a per-user cost rollup for the deployment
// (admin only). Zero-valued StartDate or EndDate fields are omitted from
// the request, letting the server apply its own defaults (typically the
// last 30 days).
func (c *Client) GetChatCostUsers(ctx context.Context, opts ChatCostUsersOptions) (ChatCostUsersResponse, error) {
qp := url.Values{}
if !opts.StartDate.IsZero() {
qp.Set("start_date", opts.StartDate.Format(time.RFC3339))
}
if !opts.EndDate.IsZero() {
qp.Set("end_date", opts.EndDate.Format(time.RFC3339))
}
if opts.Username != "" {
qp.Set("username", opts.Username)
}
if opts.Limit > 0 {
qp.Set("limit", strconv.Itoa(opts.Limit))
}
if opts.Offset > 0 {
qp.Set("offset", strconv.Itoa(opts.Offset))
}
reqURL := "/api/experimental/chats/cost/users"
if len(qp) > 0 {
reqURL += "?" + qp.Encode()
}
res, err := c.Request(ctx, http.MethodGet, reqURL, nil)
if err != nil {
return ChatCostUsersResponse{}, err
}
defer res.Body.Close()
if res.StatusCode != http.StatusOK {
return ChatCostUsersResponse{}, ReadBodyAsError(res)
}
var resp ChatCostUsersResponse
return resp, json.NewDecoder(res.Body).Decode(&resp)
}
// GetChatSystemPrompt returns the deployment-wide chat system prompt.
func (c *Client) GetChatSystemPrompt(ctx context.Context) (ChatSystemPromptResponse, error) {
res, err := c.Request(ctx, http.MethodGet, "/api/experimental/chats/config/system-prompt", nil)
+63
View File
@@ -5,6 +5,7 @@ import (
"testing"
"github.com/google/uuid"
"github.com/shopspring/decimal"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -109,3 +110,65 @@ func TestChatMessagePart_StripInternal(t *testing.T) {
assert.Equal(t, codersdk.ChatMessagePartTypeText, part.Type)
})
}
func TestModelCostConfig_LegacyNumericJSON(t *testing.T) {
t.Parallel()
var decoded codersdk.ModelCostConfig
err := json.Unmarshal([]byte("{\"input_price_per_million_tokens\": 1.5}"), &decoded)
require.NoError(t, err)
require.NotNil(t, decoded.InputPricePerMillionTokens)
require.True(t, decoded.InputPricePerMillionTokens.Equal(decimal.RequireFromString("1.5")))
}
func TestModelCostConfig_QuotedDecimalJSON(t *testing.T) {
t.Parallel()
var decoded codersdk.ModelCostConfig
err := json.Unmarshal([]byte("{\"input_price_per_million_tokens\": \"1.5\"}"), &decoded)
require.NoError(t, err)
require.NotNil(t, decoded.InputPricePerMillionTokens)
require.True(t, decoded.InputPricePerMillionTokens.Equal(decimal.RequireFromString("1.5")))
}
func TestModelCostConfig_NilVsZero(t *testing.T) {
t.Parallel()
zero := decimal.Zero
raw, err := json.Marshal(struct {
Nil codersdk.ModelCostConfig `json:"nil"`
Zero codersdk.ModelCostConfig `json:"zero"`
}{
Nil: codersdk.ModelCostConfig{},
Zero: codersdk.ModelCostConfig{InputPricePerMillionTokens: &zero},
})
require.NoError(t, err)
require.Contains(t, string(raw), "\"zero\":{\"input_price_per_million_tokens\":\"0\"}")
require.Contains(t, string(raw), "\"nil\":{}")
}
func TestChatModelCallConfig_UnmarshalLegacyPricing(t *testing.T) {
t.Parallel()
var decoded codersdk.ChatModelCallConfig
err := json.Unmarshal([]byte("{\"input_price_per_million_tokens\": 1.5}"), &decoded)
require.NoError(t, err)
require.NotNil(t, decoded.Cost)
require.NotNil(t, decoded.Cost.InputPricePerMillionTokens)
require.True(t, decoded.Cost.InputPricePerMillionTokens.Equal(decimal.RequireFromString("1.5")))
}
func TestChatCostSummary_JSONRoundTrip(t *testing.T) {
t.Parallel()
original := codersdk.ChatCostSummary{
TotalCostMicros: 123,
}
raw, err := json.Marshal(original)
require.NoError(t, err)
var decoded codersdk.ChatCostSummary
err = json.Unmarshal(raw, &decoded)
require.NoError(t, err)
require.Equal(t, original.TotalCostMicros, decoded.TotalCostMicros)
}
+1
View File
@@ -0,0 +1 @@
# Chat
+1
View File
@@ -492,6 +492,7 @@ require (
github.com/go-git/go-git/v5 v5.17.0
github.com/mark3labs/mcp-go v0.38.0
github.com/openai/openai-go/v3 v3.15.0
github.com/shopspring/decimal v1.4.0
gonum.org/v1/gonum v0.17.0
)
+2
View File
@@ -1059,6 +1059,8 @@ github.com/sergi/go-diff v1.4.0 h1:n/SP9D5ad1fORl+llWyN+D6qoUETXNZARKjyY2/KVCw=
github.com/sergi/go-diff v1.4.0/go.mod h1:A0bzQcvG0E7Rwjx0REVgAGH58e96+X0MeOfepqsbeW4=
github.com/shirou/gopsutil/v4 v4.26.1 h1:TOkEyriIXk2HX9d4isZJtbjXbEjf5qyKPAzbzY0JWSo=
github.com/shirou/gopsutil/v4 v4.26.1/go.mod h1:medLI9/UNAb0dOI9Q3/7yWSqKkj00u+1tgY8nvv41pc=
github.com/shopspring/decimal v1.4.0 h1:bxl37RwXBklmTi0C79JfXCEBD1cqqHt0bbgBAGFp81k=
github.com/shopspring/decimal v1.4.0/go.mod h1:gawqmDU56v4yIKSwfBSFip1HdCCXN8/+DMd9qYNcwME=
github.com/sirupsen/logrus v1.7.0/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0=
github.com/sirupsen/logrus v1.9.4 h1:TsZE7l11zFCLZnZ+teH4Umoq5BhEIfIzfRDZ1Uzql2w=
github.com/sirupsen/logrus v1.9.4/go.mod h1:ftWc9WdOfJ0a92nsE2jF5u5ZwH8Bv2zdeOC42RjbV2g=
+4
View File
@@ -130,6 +130,10 @@ func TypeMappings(gen *guts.GoParser) error {
"github.com/coder/serpent.URL": "string",
"github.com/coder/serpent.HostPort": "string",
"encoding/json.RawMessage": "map[string]string",
// decimal.Decimal preserves exact pricing precision (e.g. $3.50 per
// million tokens) and serializes as a JSON string to avoid
// floating-point loss in transit.
"github.com/shopspring/decimal.Decimal": "string",
})
if err != nil {
return xerrors.Errorf("include custom: %w", err)
+13 -1
View File
@@ -7,6 +7,8 @@ import (
"reflect"
"strings"
"github.com/shopspring/decimal"
"github.com/coder/coder/v2/codersdk"
)
@@ -117,11 +119,15 @@ func extractFields(t reflect.Type, prefix string, skip map[string]bool) FieldGro
// so that entire sub-objects can be marked hidden.
hidden := f.Tag.Get("hidden") == "true"
// decimal.Decimal is an opaque numeric type used for pricing
// precision; do not recurse into its internal struct fields.
isDecimal := ft == reflect.TypeOf(decimal.Decimal{})
// If the field is a struct (not a map), recurse to flatten
// its children using dot-separated names — unless the
// entire struct is marked hidden, in which case emit it
// as a single opaque field.
if ft.Kind() == reflect.Struct && !hidden {
if ft.Kind() == reflect.Struct && !hidden && !isDecimal {
nested := extractFields(ft, fullJSONName, nil)
fields = append(fields, nested.Fields...)
continue
@@ -206,6 +212,12 @@ func goTypeToSchemaType(t reflect.Type) string {
t = t.Elem()
}
// decimal.Decimal represents a precise numeric value and should
// map to the "number" schema type.
if t == reflect.TypeOf(decimal.Decimal{}) {
return "number"
}
switch t.Kind() {
case reflect.String:
return "string"
+94 -9
View File
@@ -1071,6 +1071,96 @@ export interface Chat {
readonly archived: boolean;
}
// From codersdk/chats.go
/**
* ChatCostChatBreakdown contains per-root-chat cost aggregation.
*/
export interface ChatCostChatBreakdown {
readonly root_chat_id: string;
readonly chat_title: string;
readonly total_cost_micros: number;
readonly message_count: number;
readonly total_input_tokens: number;
readonly total_output_tokens: number;
}
// From codersdk/chats.go
/**
* ChatCostModelBreakdown contains per-model cost aggregation.
*/
export interface ChatCostModelBreakdown {
readonly model_config_id: string;
readonly display_name: string;
readonly provider: string;
readonly model: string;
readonly total_cost_micros: number;
readonly message_count: number;
readonly total_input_tokens: number;
readonly total_output_tokens: number;
}
// From codersdk/chats.go
/**
* ChatCostSummary is the response from the chat cost summary endpoint.
*/
export interface ChatCostSummary {
readonly start_date: string;
readonly end_date: string;
readonly total_cost_micros: number;
readonly priced_message_count: number;
readonly unpriced_message_count: number;
readonly total_input_tokens: number;
readonly total_output_tokens: number;
readonly by_model: readonly ChatCostModelBreakdown[];
readonly by_chat: readonly ChatCostChatBreakdown[];
}
// From codersdk/chats.go
/**
* ChatCostSummaryOptions are optional query parameters for GetChatCostSummary.
*/
export interface ChatCostSummaryOptions {
readonly StartDate: string;
readonly EndDate: string;
}
// From codersdk/chats.go
/**
* ChatCostUserRollup contains per-user cost aggregation for admin views.
*/
export interface ChatCostUserRollup {
readonly user_id: string;
readonly username: string;
readonly name: string;
readonly avatar_url: string;
readonly total_cost_micros: number;
readonly message_count: number;
readonly chat_count: number;
readonly total_input_tokens: number;
readonly total_output_tokens: number;
}
// From codersdk/chats.go
/**
* ChatCostUsersOptions are optional query parameters for GetChatCostUsers.
*/
export interface ChatCostUsersOptions extends Pagination {
readonly StartDate: string;
readonly EndDate: string;
readonly Username: string;
}
// From codersdk/chats.go
/**
* ChatCostUsersResponse is the response from the admin chat cost users endpoint.
*/
export interface ChatCostUsersResponse {
readonly start_date: string;
readonly end_date: string;
readonly count: number;
readonly users: readonly ChatCostUserRollup[];
}
// From codersdk/chats.go
/**
* ChatDiffContents represents the resolved diff text for a chat.
@@ -3466,15 +3556,10 @@ export interface MinimalUser {
* ModelCostConfig stores pricing metadata for a chat model.
*/
export interface ModelCostConfig {
/**
* Pricing is stored as configuration metadata and currently only needs to
* round-trip cleanly through the API and admin UI. If we later use these
* values for billing-grade arithmetic, switch to a fixed-point type.
*/
readonly input_price_per_million_tokens?: number;
readonly output_price_per_million_tokens?: number;
readonly cache_read_price_per_million_tokens?: number;
readonly cache_write_price_per_million_tokens?: number;
readonly input_price_per_million_tokens?: string;
readonly output_price_per_million_tokens?: string;
readonly cache_read_price_per_million_tokens?: string;
readonly cache_write_price_per_million_tokens?: string;
}
// From netcheck/netcheck.go
@@ -1,99 +0,0 @@
import { render, screen } from "@testing-library/react";
import type * as TypesGen from "api/typesGenerated";
import { TooltipProvider } from "components/Tooltip/Tooltip";
import { describe, expect, it, vi } from "vitest";
import type { ProviderState } from "./ChatModelAdminPanel";
import { ModelsSection } from "./ModelsSection";
vi.mock("./ProviderIcon", () => ({
ProviderIcon: ({ provider }: { provider: string }) => (
<div data-testid="provider-icon">{provider}</div>
),
}));
const providerState: ProviderState = {
provider: "openai",
label: "OpenAI",
providerConfig: {
id: "provider-config-id",
provider: "openai",
display_name: "OpenAI",
enabled: true,
has_api_key: true,
base_url: undefined,
source: "database",
created_at: "2025-01-01T00:00:00Z",
updated_at: "2025-01-01T00:00:00Z",
},
modelConfigs: [],
catalogModelCount: 0,
hasManagedAPIKey: true,
hasCatalogAPIKey: true,
hasEffectiveAPIKey: true,
isEnvPreset: false,
baseURL: "",
};
const baseModelConfig: TypesGen.ChatModelConfig = {
id: "model-config-id",
provider: "openai",
model: "gpt-4.1",
display_name: "GPT-4.1",
enabled: true,
is_default: false,
context_limit: 128000,
compression_threshold: 80,
created_at: "2025-01-01T00:00:00Z",
updated_at: "2025-01-01T00:00:00Z",
};
const renderModelsSection = (
modelConfigs: readonly TypesGen.ChatModelConfig[],
) => {
return render(
<TooltipProvider>
<ModelsSection
sectionLabel="Models"
providerStates={[providerState]}
selectedProvider="openai"
selectedProviderState={providerState}
onSelectedProviderChange={vi.fn()}
modelConfigs={modelConfigs}
modelConfigsUnavailable={false}
isCreating={false}
isUpdating={false}
isDeleting={false}
onCreateModel={vi.fn()}
onUpdateModel={vi.fn()}
onDeleteModel={vi.fn()}
/>
</TooltipProvider>,
);
};
describe("ModelsSection", () => {
it("shows a warning when a model has no custom pricing configured", () => {
renderModelsSection([baseModelConfig]);
expect(
screen.getByText("Model pricing is not defined"),
).toBeInTheDocument();
});
it("hides the warning when a model has explicit zero pricing", () => {
renderModelsSection([
{
...baseModelConfig,
model_config: {
cost: {
output_price_per_million_tokens: 0,
},
},
},
]);
expect(
screen.queryByText("Model pricing is not defined"),
).not.toBeInTheDocument();
});
});
@@ -210,10 +210,10 @@ describe("extractModelConfigFormState", () => {
...baseChatModelConfig,
model_config: {
cost: {
input_price_per_million_tokens: 0.15,
output_price_per_million_tokens: 0.6,
cache_read_price_per_million_tokens: 0.03,
cache_write_price_per_million_tokens: 0.3,
input_price_per_million_tokens: "0.15",
output_price_per_million_tokens: "0.6",
cache_read_price_per_million_tokens: "0.03",
cache_write_price_per_million_tokens: "0.3",
},
},
};
@@ -553,10 +553,10 @@ describe("buildModelConfigFromForm", () => {
expect(result.fieldErrors).toEqual({});
expect(result.modelConfig).toMatchObject({
cost: {
input_price_per_million_tokens: 0.15,
output_price_per_million_tokens: 0.6,
cache_read_price_per_million_tokens: 0.03,
cache_write_price_per_million_tokens: 0.3,
input_price_per_million_tokens: "0.15",
output_price_per_million_tokens: "0.6",
cache_read_price_per_million_tokens: "0.03",
cache_write_price_per_million_tokens: "0.3",
},
});
});
@@ -110,7 +110,7 @@ function convertFormValue(value: string, field: FieldSchema): unknown {
case "integer":
return Number.parseInt(trimmed, 10);
case "number":
return Number(trimmed);
return isNonNegativePricingField(field) ? trimmed : Number(trimmed);
case "boolean":
return trimmed === "true";
case "array":
@@ -10,7 +10,7 @@ import {
describe("pricingFields", () => {
it("uses $0 defaults for every pricing field", () => {
for (const fieldName of pricingFieldNameList) {
expect(getDefaultPricingForField(fieldName)).toBe(0);
expect(getDefaultPricingForField(fieldName)).toBe("0");
expect(getPricingPlaceholderForField(fieldName)).toBe("0");
}
});
@@ -23,8 +23,8 @@ describe("pricingFields", () => {
expect(
hasCustomPricing({
cost: {
input_price_per_million_tokens: 0,
output_price_per_million_tokens: 0,
input_price_per_million_tokens: "0",
output_price_per_million_tokens: "0",
},
} satisfies TypesGen.ChatModelCallConfig),
).toBe(true);
@@ -34,7 +34,7 @@ describe("pricingFields", () => {
expect(
hasCustomPricing({
cost: {
cache_write_price_per_million_tokens: 0.25,
cache_write_price_per_million_tokens: "0.25",
},
} satisfies TypesGen.ChatModelCallConfig),
).toBe(true);
@@ -14,11 +14,11 @@ export const pricingFieldNames = new Set<string>(pricingFieldNameList);
type PricingFieldName = (typeof pricingFieldNameList)[number];
export const defaultPricingByFieldName = {
"cost.input_price_per_million_tokens": 0,
"cost.output_price_per_million_tokens": 0,
"cost.cache_read_price_per_million_tokens": 0,
"cost.cache_write_price_per_million_tokens": 0,
} as const satisfies Record<PricingFieldName, number>;
"cost.input_price_per_million_tokens": "0",
"cost.output_price_per_million_tokens": "0",
"cost.cache_read_price_per_million_tokens": "0",
"cost.cache_write_price_per_million_tokens": "0",
} as const satisfies Record<PricingFieldName, string>;
export const pricingPlaceholderByFieldName = {
"cost.input_price_per_million_tokens": "0",
@@ -29,7 +29,7 @@ export const pricingPlaceholderByFieldName = {
export const getDefaultPricingForField = (
fieldName: string,
): number | undefined =>
): string | undefined =>
defaultPricingByFieldName[
fieldName as keyof typeof defaultPricingByFieldName
];