mirror of
https://github.com/coder/coder.git
synced 2026-06-03 13:08:25 +00:00
c3b6284955
Add cost tracking for LLM chat interactions with microdollar precision.

## Changes

- Add `chatcost` package for per-message cost calculation using `shopspring/decimal` for intermediate arithmetic
- **Ceil rounding policy**: fractional micros round UP to next whole micro (applied once after summing all components)
- Database migration: `total_cost_micros` BIGINT column with historical backfill and `created_at` index
- API endpoints: per-user cost summary and admin rollup under `/api/experimental/chats/cost/`
- SDK types: `ChatCostSummary`, `ChatCostModelBreakdown`, `ChatCostUserRollup`
- Fix `modeloptionsgen` to handle `decimal.Decimal` as opaque numeric type
- Update frontend pricing test fixtures for string decimal types

## Design decisions

- `NULL` = unpriced (no matching model config), `0` = free
- Reasoning tokens included in output tokens (no double-counting)
- Integer microdollars (BIGINT) for storage and API responses
- Price config uses `decimal.Decimal` for exact parsing; totals use `int64`

Frontend: #23037
175 lines
5.0 KiB
Go
175 lines
5.0 KiB
Go
package codersdk_test
|
|
|
|
import (
|
|
"encoding/json"
|
|
"testing"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/shopspring/decimal"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
|
|
"github.com/coder/coder/v2/codersdk"
|
|
)
|
|
|
|
func TestChatModelProviderOptions_MarshalJSON_UsesPlainProviderPayload(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
sendReasoning := true
|
|
effort := "high"
|
|
|
|
raw, err := json.Marshal(codersdk.ChatModelProviderOptions{
|
|
Anthropic: &codersdk.ChatModelAnthropicProviderOptions{
|
|
SendReasoning: &sendReasoning,
|
|
Effort: &effort,
|
|
},
|
|
})
|
|
require.NoError(t, err)
|
|
require.NotContains(t, string(raw), `"type":"anthropic.options"`)
|
|
require.NotContains(t, string(raw), `"data":`)
|
|
require.Contains(t, string(raw), `"send_reasoning":true`)
|
|
require.Contains(t, string(raw), `"effort":"high"`)
|
|
}
|
|
|
|
func TestChatModelProviderOptions_UnmarshalJSON_ParsesPlainProviderPayloads(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
raw := []byte(`{
|
|
"anthropic": {
|
|
"send_reasoning": true,
|
|
"effort": "high"
|
|
}
|
|
}`)
|
|
|
|
var decoded codersdk.ChatModelProviderOptions
|
|
err := json.Unmarshal(raw, &decoded)
|
|
require.NoError(t, err)
|
|
require.NotNil(t, decoded.Anthropic)
|
|
require.NotNil(t, decoded.Anthropic.SendReasoning)
|
|
require.True(t, *decoded.Anthropic.SendReasoning)
|
|
require.NotNil(t, decoded.Anthropic.Effort)
|
|
require.Equal(
|
|
t,
|
|
"high",
|
|
*decoded.Anthropic.Effort,
|
|
)
|
|
}
|
|
|
|
func TestChatMessagePart_StripInternal(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
t.Run("StripsProviderMetadata", func(t *testing.T) {
|
|
t.Parallel()
|
|
part := codersdk.ChatMessagePart{
|
|
Type: codersdk.ChatMessagePartTypeToolCall,
|
|
ToolCallID: "call-1",
|
|
ToolName: "some_tool",
|
|
Args: json.RawMessage(`{"key":"value"}`),
|
|
ProviderMetadata: json.RawMessage(`{"type":"ephemeral"}`),
|
|
}
|
|
part.StripInternal()
|
|
assert.Nil(t, part.ProviderMetadata)
|
|
// Public fields preserved.
|
|
assert.Equal(t, codersdk.ChatMessagePartTypeToolCall, part.Type)
|
|
assert.Equal(t, "call-1", part.ToolCallID)
|
|
assert.Equal(t, "some_tool", part.ToolName)
|
|
assert.JSONEq(t, `{"key":"value"}`, string(part.Args))
|
|
})
|
|
|
|
t.Run("StripsFileDataWhenFileIDSet", func(t *testing.T) {
|
|
t.Parallel()
|
|
id := uuid.New()
|
|
part := codersdk.ChatMessagePart{
|
|
Type: codersdk.ChatMessagePartTypeFile,
|
|
FileID: uuid.NullUUID{UUID: id, Valid: true},
|
|
MediaType: "image/png",
|
|
Data: []byte("binary-payload"),
|
|
}
|
|
part.StripInternal()
|
|
assert.Nil(t, part.Data)
|
|
assert.Equal(t, id, part.FileID.UUID)
|
|
assert.Equal(t, "image/png", part.MediaType)
|
|
})
|
|
|
|
t.Run("PreservesDataWhenNoFileID", func(t *testing.T) {
|
|
t.Parallel()
|
|
part := codersdk.ChatMessagePart{
|
|
Type: codersdk.ChatMessagePartTypeFile,
|
|
MediaType: "image/png",
|
|
Data: []byte("inline-data"),
|
|
}
|
|
part.StripInternal()
|
|
assert.Equal(t, []byte("inline-data"), part.Data)
|
|
})
|
|
|
|
t.Run("NoopOnCleanPart", func(t *testing.T) {
|
|
t.Parallel()
|
|
part := codersdk.ChatMessageText("hello")
|
|
part.StripInternal()
|
|
assert.Equal(t, "hello", part.Text)
|
|
assert.Equal(t, codersdk.ChatMessagePartTypeText, part.Type)
|
|
})
|
|
}
|
|
|
|
func TestModelCostConfig_LegacyNumericJSON(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
var decoded codersdk.ModelCostConfig
|
|
err := json.Unmarshal([]byte("{\"input_price_per_million_tokens\": 1.5}"), &decoded)
|
|
require.NoError(t, err)
|
|
require.NotNil(t, decoded.InputPricePerMillionTokens)
|
|
require.True(t, decoded.InputPricePerMillionTokens.Equal(decimal.RequireFromString("1.5")))
|
|
}
|
|
|
|
func TestModelCostConfig_QuotedDecimalJSON(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
var decoded codersdk.ModelCostConfig
|
|
err := json.Unmarshal([]byte("{\"input_price_per_million_tokens\": \"1.5\"}"), &decoded)
|
|
require.NoError(t, err)
|
|
require.NotNil(t, decoded.InputPricePerMillionTokens)
|
|
require.True(t, decoded.InputPricePerMillionTokens.Equal(decimal.RequireFromString("1.5")))
|
|
}
|
|
|
|
func TestModelCostConfig_NilVsZero(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
zero := decimal.Zero
|
|
raw, err := json.Marshal(struct {
|
|
Nil codersdk.ModelCostConfig `json:"nil"`
|
|
Zero codersdk.ModelCostConfig `json:"zero"`
|
|
}{
|
|
Nil: codersdk.ModelCostConfig{},
|
|
Zero: codersdk.ModelCostConfig{InputPricePerMillionTokens: &zero},
|
|
})
|
|
require.NoError(t, err)
|
|
require.Contains(t, string(raw), "\"zero\":{\"input_price_per_million_tokens\":\"0\"}")
|
|
require.Contains(t, string(raw), "\"nil\":{}")
|
|
}
|
|
|
|
func TestChatModelCallConfig_UnmarshalLegacyPricing(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
var decoded codersdk.ChatModelCallConfig
|
|
err := json.Unmarshal([]byte("{\"input_price_per_million_tokens\": 1.5}"), &decoded)
|
|
require.NoError(t, err)
|
|
require.NotNil(t, decoded.Cost)
|
|
require.NotNil(t, decoded.Cost.InputPricePerMillionTokens)
|
|
require.True(t, decoded.Cost.InputPricePerMillionTokens.Equal(decimal.RequireFromString("1.5")))
|
|
}
|
|
|
|
func TestChatCostSummary_JSONRoundTrip(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
original := codersdk.ChatCostSummary{
|
|
TotalCostMicros: 123,
|
|
}
|
|
raw, err := json.Marshal(original)
|
|
require.NoError(t, err)
|
|
|
|
var decoded codersdk.ChatCostSummary
|
|
err = json.Unmarshal(raw, &decoded)
|
|
require.NoError(t, err)
|
|
require.Equal(t, original.TotalCostMicros, decoded.TotalCostMicros)
|
|
}
|