mirror of
https://github.com/coder/coder.git
synced 2026-06-03 13:08:25 +00:00
c3b6284955
Add cost tracking for LLM chat interactions with microdollar precision.

## Changes

- Add `chatcost` package for per-message cost calculation using `shopspring/decimal` for intermediate arithmetic
- **Ceil rounding policy**: fractional micros round UP to next whole micro (applied once after summing all components)
- Database migration: `total_cost_micros` BIGINT column with historical backfill and `created_at` index
- API endpoints: per-user cost summary and admin rollup under `/api/experimental/chats/cost/`
- SDK types: `ChatCostSummary`, `ChatCostModelBreakdown`, `ChatCostUserRollup`
- Fix `modeloptionsgen` to handle `decimal.Decimal` as opaque numeric type
- Update frontend pricing test fixtures for string decimal types

## Design decisions

- `NULL` = unpriced (no matching model config), `0` = free
- Reasoning tokens included in output tokens (no double-counting)
- Integer microdollars (BIGINT) for storage and API responses
- Price config uses `decimal.Decimal` for exact parsing; totals use `int64`

Frontend: #23037
164 lines
4.8 KiB
Go
164 lines
4.8 KiB
Go
package chatcost_test
|
|
|
|
import (
|
|
"testing"
|
|
|
|
"github.com/shopspring/decimal"
|
|
"github.com/stretchr/testify/require"
|
|
|
|
"github.com/coder/coder/v2/coderd/chatd/chatcost"
|
|
"github.com/coder/coder/v2/coderd/util/ptr"
|
|
"github.com/coder/coder/v2/codersdk"
|
|
)
|
|
|
|
func TestCalculateTotalCostMicros(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
tests := []struct {
|
|
name string
|
|
usage codersdk.ChatMessageUsage
|
|
cost *codersdk.ModelCostConfig
|
|
want *int64
|
|
}{
|
|
{
|
|
name: "nil cost returns nil",
|
|
usage: codersdk.ChatMessageUsage{InputTokens: ptr.Ref[int64](1000)},
|
|
cost: nil,
|
|
want: nil,
|
|
},
|
|
{
|
|
name: "all priced usage fields nil returns nil",
|
|
usage: codersdk.ChatMessageUsage{
|
|
TotalTokens: ptr.Ref[int64](1234),
|
|
ContextLimit: ptr.Ref[int64](8192),
|
|
},
|
|
cost: &codersdk.ModelCostConfig{
|
|
InputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("3")),
|
|
},
|
|
want: nil,
|
|
},
|
|
{
|
|
name: "sub-micro total rounds up to 1",
|
|
usage: codersdk.ChatMessageUsage{InputTokens: ptr.Ref[int64](1)},
|
|
cost: &codersdk.ModelCostConfig{
|
|
InputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("0.01")),
|
|
},
|
|
want: ptr.Ref[int64](1),
|
|
},
|
|
{
|
|
name: "simple input only",
|
|
usage: codersdk.ChatMessageUsage{InputTokens: ptr.Ref[int64](1000)},
|
|
cost: &codersdk.ModelCostConfig{
|
|
InputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("3")),
|
|
},
|
|
want: ptr.Ref[int64](3000),
|
|
},
|
|
{
|
|
name: "simple output only",
|
|
usage: codersdk.ChatMessageUsage{OutputTokens: ptr.Ref[int64](500)},
|
|
cost: &codersdk.ModelCostConfig{
|
|
OutputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("15")),
|
|
},
|
|
want: ptr.Ref[int64](7500),
|
|
},
|
|
{
|
|
name: "reasoning tokens included in output total",
|
|
usage: codersdk.ChatMessageUsage{
|
|
OutputTokens: ptr.Ref[int64](500),
|
|
ReasoningTokens: ptr.Ref[int64](200),
|
|
},
|
|
cost: &codersdk.ModelCostConfig{
|
|
OutputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("15")),
|
|
},
|
|
want: ptr.Ref[int64](7500),
|
|
},
|
|
{
|
|
name: "cache read tokens",
|
|
usage: codersdk.ChatMessageUsage{CacheReadTokens: ptr.Ref[int64](10000)},
|
|
cost: &codersdk.ModelCostConfig{
|
|
CacheReadPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("0.3")),
|
|
},
|
|
want: ptr.Ref[int64](3000),
|
|
},
|
|
{
|
|
name: "cache creation tokens",
|
|
usage: codersdk.ChatMessageUsage{CacheCreationTokens: ptr.Ref[int64](5000)},
|
|
cost: &codersdk.ModelCostConfig{
|
|
CacheWritePricePerMillionTokens: ptr.Ref(decimal.RequireFromString("3.75")),
|
|
},
|
|
want: ptr.Ref[int64](18750),
|
|
},
|
|
{
|
|
name: "full mixed usage totals all components exactly",
|
|
usage: codersdk.ChatMessageUsage{
|
|
InputTokens: ptr.Ref[int64](101),
|
|
OutputTokens: ptr.Ref[int64](201),
|
|
ReasoningTokens: ptr.Ref[int64](52),
|
|
CacheReadTokens: ptr.Ref[int64](1005),
|
|
CacheCreationTokens: ptr.Ref[int64](33),
|
|
TotalTokens: ptr.Ref[int64](1391),
|
|
ContextLimit: ptr.Ref[int64](4096),
|
|
},
|
|
cost: &codersdk.ModelCostConfig{
|
|
InputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("1.23")),
|
|
OutputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("4.56")),
|
|
CacheReadPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("0.7")),
|
|
CacheWritePricePerMillionTokens: ptr.Ref(decimal.RequireFromString("7.89")),
|
|
},
|
|
want: ptr.Ref[int64](2005),
|
|
},
|
|
{
|
|
name: "partial pricing only input contributes",
|
|
usage: codersdk.ChatMessageUsage{
|
|
InputTokens: ptr.Ref[int64](1234),
|
|
OutputTokens: ptr.Ref[int64](999),
|
|
ReasoningTokens: ptr.Ref[int64](111),
|
|
CacheReadTokens: ptr.Ref[int64](500),
|
|
CacheCreationTokens: ptr.Ref[int64](250),
|
|
},
|
|
cost: &codersdk.ModelCostConfig{
|
|
InputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("2.5")),
|
|
},
|
|
want: ptr.Ref[int64](3085),
|
|
},
|
|
{
|
|
name: "zero tokens with pricing returns zero pointer",
|
|
usage: codersdk.ChatMessageUsage{InputTokens: ptr.Ref[int64](0)},
|
|
cost: &codersdk.ModelCostConfig{
|
|
InputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("3")),
|
|
},
|
|
want: ptr.Ref[int64](0),
|
|
},
|
|
{
|
|
name: "usage only in unpriced categories returns nil",
|
|
usage: codersdk.ChatMessageUsage{InputTokens: ptr.Ref[int64](1000)},
|
|
cost: &codersdk.ModelCostConfig{
|
|
OutputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("15")),
|
|
},
|
|
want: nil,
|
|
},
|
|
{
|
|
name: "non nil usage with empty cost config returns nil",
|
|
usage: codersdk.ChatMessageUsage{InputTokens: ptr.Ref[int64](42)},
|
|
cost: &codersdk.ModelCostConfig{},
|
|
want: nil,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
tt := tt
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
got := chatcost.CalculateTotalCostMicros(tt.usage, tt.cost)
|
|
|
|
if tt.want == nil {
|
|
require.Nil(t, got)
|
|
} else {
|
|
require.NotNil(t, got)
|
|
require.Equal(t, *tt.want, *got)
|
|
}
|
|
})
|
|
}
|
|
}
|