mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
feat: add chat cost analytics backend (#23036)
Add cost tracking for LLM chat interactions with microdollar precision.

## Changes

- Add `chatcost` package for per-message cost calculation using `shopspring/decimal` for intermediate arithmetic
- **Ceil rounding policy**: fractional micros round UP to next whole micro (applied once after summing all components)
- Database migration: `total_cost_micros` BIGINT column with historical backfill and `created_at` index
- API endpoints: per-user cost summary and admin rollup under `/api/experimental/chats/cost/`
- SDK types: `ChatCostSummary`, `ChatCostModelBreakdown`, `ChatCostUserRollup`
- Fix `modeloptionsgen` to handle `decimal.Decimal` as opaque numeric type
- Update frontend pricing test fixtures for string decimal types

## Design decisions

- `NULL` = unpriced (no matching model config), `0` = free
- Reasoning tokens included in output tokens (no double-counting)
- Integer microdollars (BIGINT) for storage and API responses
- Price config uses `decimal.Decimal` for exact parsing; totals use `int64`

Frontend: #23037
This commit is contained in:
@@ -100,6 +100,31 @@ app, err := api.Database.GetOAuth2ProviderAppByClientID(dbauthz.AsSystemRestrict
|
||||
app, err := api.Database.GetOAuth2ProviderAppByClientID(ctx, clientID)
|
||||
```
|
||||
|
||||
### API Design
|
||||
|
||||
- Add swagger annotations when introducing new HTTP endpoints. Do this in
|
||||
the same change as the handler so the docs do not get missed before
|
||||
release.
|
||||
- For user-scoped or resource-scoped routes, prefer path parameters over
|
||||
query parameters when that matches existing route patterns.
|
||||
- For experimental or unstable API paths, skip public doc generation with
|
||||
`// @x-apidocgen {"skip": true}` after the `@Router` annotation. This
|
||||
keeps them out of the published API reference until they stabilize.
|
||||
|
||||
### Database Query Naming
|
||||
|
||||
- Use `ByX` when `X` is the lookup or filter column.
|
||||
- Use `PerX` or `GroupedByX` when `X` is the aggregation or grouping
|
||||
dimension.
|
||||
- Avoid `ByX` names for grouped queries.
|
||||
|
||||
### Database-to-SDK Conversions
|
||||
|
||||
- Extract explicit db-to-SDK conversion helpers instead of inlining large
|
||||
conversion blocks inside handlers.
|
||||
- Keep nullable-field handling, type coercion, and response shaping in the
|
||||
converter so handlers stay focused on request flow and authorization.
|
||||
|
||||
## Quick Reference
|
||||
|
||||
### Full workflows available in imported WORKFLOWS.md
|
||||
|
||||
@@ -0,0 +1,71 @@
|
||||
package chatcost
|
||||
|
||||
import (
|
||||
"github.com/shopspring/decimal"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/util/ptr"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
// Returns cost in micros -- millionths of a dollar, rounded up to the next
|
||||
// whole microdollar.
|
||||
// Returns nil when pricing is not configured or when all priced usage fields
|
||||
// are nil, allowing callers to distinguish "zero cost" from "unpriced".
|
||||
func CalculateTotalCostMicros(
|
||||
usage codersdk.ChatMessageUsage,
|
||||
cost *codersdk.ModelCostConfig,
|
||||
) *int64 {
|
||||
if cost == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// A cost config with no prices set means pricing is effectively
|
||||
// unconfigured — return nil (unpriced) rather than zero.
|
||||
if cost.InputPricePerMillionTokens == nil &&
|
||||
cost.OutputPricePerMillionTokens == nil &&
|
||||
cost.CacheReadPricePerMillionTokens == nil &&
|
||||
cost.CacheWritePricePerMillionTokens == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if usage.InputTokens == nil &&
|
||||
usage.OutputTokens == nil &&
|
||||
usage.ReasoningTokens == nil &&
|
||||
usage.CacheCreationTokens == nil &&
|
||||
usage.CacheReadTokens == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// OutputTokens already includes reasoning tokens per provider
|
||||
// semantics (e.g. OpenAI's completion_tokens encompasses
|
||||
// reasoning_tokens). Adding ReasoningTokens here would
|
||||
// double-count.
|
||||
|
||||
// Preserve nil when usage exists only in categories without configured
|
||||
// pricing, so callers can distinguish "unpriced" from "priced at zero".
|
||||
hasMatchingPrice := (usage.InputTokens != nil && cost.InputPricePerMillionTokens != nil) ||
|
||||
(usage.OutputTokens != nil && cost.OutputPricePerMillionTokens != nil) ||
|
||||
(usage.CacheReadTokens != nil && cost.CacheReadPricePerMillionTokens != nil) ||
|
||||
(usage.CacheCreationTokens != nil && cost.CacheWritePricePerMillionTokens != nil)
|
||||
if !hasMatchingPrice {
|
||||
return nil
|
||||
}
|
||||
|
||||
inputMicros := calcCost(usage.InputTokens, cost.InputPricePerMillionTokens)
|
||||
outputMicros := calcCost(usage.OutputTokens, cost.OutputPricePerMillionTokens)
|
||||
cacheReadMicros := calcCost(usage.CacheReadTokens, cost.CacheReadPricePerMillionTokens)
|
||||
cacheWriteMicros := calcCost(usage.CacheCreationTokens, cost.CacheWritePricePerMillionTokens)
|
||||
|
||||
total := inputMicros.
|
||||
Add(outputMicros).
|
||||
Add(cacheReadMicros).
|
||||
Add(cacheWriteMicros)
|
||||
rounded := total.Ceil().IntPart()
|
||||
return &rounded
|
||||
}
|
||||
|
||||
// calcCost returns the cost in fractional microdollars (millionths of a USD)
|
||||
// for the given token count at the specified per-million-token price.
|
||||
func calcCost(tokens *int64, pricePerMillion *decimal.Decimal) decimal.Decimal {
|
||||
return decimal.NewFromInt(ptr.NilToEmpty(tokens)).Mul(ptr.NilToEmpty(pricePerMillion))
|
||||
}
|
||||
@@ -0,0 +1,163 @@
|
||||
package chatcost_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/shopspring/decimal"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/chatd/chatcost"
|
||||
"github.com/coder/coder/v2/coderd/util/ptr"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
func TestCalculateTotalCostMicros(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
usage codersdk.ChatMessageUsage
|
||||
cost *codersdk.ModelCostConfig
|
||||
want *int64
|
||||
}{
|
||||
{
|
||||
name: "nil cost returns nil",
|
||||
usage: codersdk.ChatMessageUsage{InputTokens: ptr.Ref[int64](1000)},
|
||||
cost: nil,
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "all priced usage fields nil returns nil",
|
||||
usage: codersdk.ChatMessageUsage{
|
||||
TotalTokens: ptr.Ref[int64](1234),
|
||||
ContextLimit: ptr.Ref[int64](8192),
|
||||
},
|
||||
cost: &codersdk.ModelCostConfig{
|
||||
InputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("3")),
|
||||
},
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "sub-micro total rounds up to 1",
|
||||
usage: codersdk.ChatMessageUsage{InputTokens: ptr.Ref[int64](1)},
|
||||
cost: &codersdk.ModelCostConfig{
|
||||
InputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("0.01")),
|
||||
},
|
||||
want: ptr.Ref[int64](1),
|
||||
},
|
||||
{
|
||||
name: "simple input only",
|
||||
usage: codersdk.ChatMessageUsage{InputTokens: ptr.Ref[int64](1000)},
|
||||
cost: &codersdk.ModelCostConfig{
|
||||
InputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("3")),
|
||||
},
|
||||
want: ptr.Ref[int64](3000),
|
||||
},
|
||||
{
|
||||
name: "simple output only",
|
||||
usage: codersdk.ChatMessageUsage{OutputTokens: ptr.Ref[int64](500)},
|
||||
cost: &codersdk.ModelCostConfig{
|
||||
OutputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("15")),
|
||||
},
|
||||
want: ptr.Ref[int64](7500),
|
||||
},
|
||||
{
|
||||
name: "reasoning tokens included in output total",
|
||||
usage: codersdk.ChatMessageUsage{
|
||||
OutputTokens: ptr.Ref[int64](500),
|
||||
ReasoningTokens: ptr.Ref[int64](200),
|
||||
},
|
||||
cost: &codersdk.ModelCostConfig{
|
||||
OutputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("15")),
|
||||
},
|
||||
want: ptr.Ref[int64](7500),
|
||||
},
|
||||
{
|
||||
name: "cache read tokens",
|
||||
usage: codersdk.ChatMessageUsage{CacheReadTokens: ptr.Ref[int64](10000)},
|
||||
cost: &codersdk.ModelCostConfig{
|
||||
CacheReadPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("0.3")),
|
||||
},
|
||||
want: ptr.Ref[int64](3000),
|
||||
},
|
||||
{
|
||||
name: "cache creation tokens",
|
||||
usage: codersdk.ChatMessageUsage{CacheCreationTokens: ptr.Ref[int64](5000)},
|
||||
cost: &codersdk.ModelCostConfig{
|
||||
CacheWritePricePerMillionTokens: ptr.Ref(decimal.RequireFromString("3.75")),
|
||||
},
|
||||
want: ptr.Ref[int64](18750),
|
||||
},
|
||||
{
|
||||
name: "full mixed usage totals all components exactly",
|
||||
usage: codersdk.ChatMessageUsage{
|
||||
InputTokens: ptr.Ref[int64](101),
|
||||
OutputTokens: ptr.Ref[int64](201),
|
||||
ReasoningTokens: ptr.Ref[int64](52),
|
||||
CacheReadTokens: ptr.Ref[int64](1005),
|
||||
CacheCreationTokens: ptr.Ref[int64](33),
|
||||
TotalTokens: ptr.Ref[int64](1391),
|
||||
ContextLimit: ptr.Ref[int64](4096),
|
||||
},
|
||||
cost: &codersdk.ModelCostConfig{
|
||||
InputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("1.23")),
|
||||
OutputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("4.56")),
|
||||
CacheReadPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("0.7")),
|
||||
CacheWritePricePerMillionTokens: ptr.Ref(decimal.RequireFromString("7.89")),
|
||||
},
|
||||
want: ptr.Ref[int64](2005),
|
||||
},
|
||||
{
|
||||
name: "partial pricing only input contributes",
|
||||
usage: codersdk.ChatMessageUsage{
|
||||
InputTokens: ptr.Ref[int64](1234),
|
||||
OutputTokens: ptr.Ref[int64](999),
|
||||
ReasoningTokens: ptr.Ref[int64](111),
|
||||
CacheReadTokens: ptr.Ref[int64](500),
|
||||
CacheCreationTokens: ptr.Ref[int64](250),
|
||||
},
|
||||
cost: &codersdk.ModelCostConfig{
|
||||
InputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("2.5")),
|
||||
},
|
||||
want: ptr.Ref[int64](3085),
|
||||
},
|
||||
{
|
||||
name: "zero tokens with pricing returns zero pointer",
|
||||
usage: codersdk.ChatMessageUsage{InputTokens: ptr.Ref[int64](0)},
|
||||
cost: &codersdk.ModelCostConfig{
|
||||
InputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("3")),
|
||||
},
|
||||
want: ptr.Ref[int64](0),
|
||||
},
|
||||
{
|
||||
name: "usage only in unpriced categories returns nil",
|
||||
usage: codersdk.ChatMessageUsage{InputTokens: ptr.Ref[int64](1000)},
|
||||
cost: &codersdk.ModelCostConfig{
|
||||
OutputPricePerMillionTokens: ptr.Ref(decimal.RequireFromString("15")),
|
||||
},
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "non nil usage with empty cost config returns nil",
|
||||
usage: codersdk.ChatMessageUsage{InputTokens: ptr.Ref[int64](42)},
|
||||
cost: &codersdk.ModelCostConfig{},
|
||||
want: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got := chatcost.CalculateTotalCostMicros(tt.usage, tt.cost)
|
||||
|
||||
if tt.want == nil {
|
||||
require.Nil(t, got)
|
||||
} else {
|
||||
require.NotNil(t, got)
|
||||
require.Equal(t, *tt.want, *got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/coderd/chatd/chatcost"
|
||||
"github.com/coder/coder/v2/coderd/chatd/chatloop"
|
||||
"github.com/coder/coder/v2/coderd/chatd/chatprompt"
|
||||
"github.com/coder/coder/v2/coderd/chatd/chatprovider"
|
||||
@@ -293,6 +294,7 @@ func (p *Server) CreateChat(ctx context.Context, opts CreateOptions) (database.C
|
||||
CacheReadTokens: sql.NullInt64{},
|
||||
ContextLimit: sql.NullInt64{},
|
||||
Compressed: sql.NullBool{},
|
||||
TotalCostMicros: sql.NullInt64{},
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("insert system message: %w", err)
|
||||
@@ -321,6 +323,7 @@ func (p *Server) CreateChat(ctx context.Context, opts CreateOptions) (database.C
|
||||
CacheCreationTokens: sql.NullInt64{},
|
||||
CacheReadTokens: sql.NullInt64{},
|
||||
ContextLimit: sql.NullInt64{},
|
||||
TotalCostMicros: sql.NullInt64{},
|
||||
Compressed: sql.NullBool{},
|
||||
})
|
||||
if err != nil {
|
||||
@@ -910,6 +913,7 @@ func insertUserMessageAndSetPending(
|
||||
CacheCreationTokens: sql.NullInt64{},
|
||||
CacheReadTokens: sql.NullInt64{},
|
||||
ContextLimit: sql.NullInt64{},
|
||||
TotalCostMicros: sql.NullInt64{},
|
||||
Compressed: sql.NullBool{},
|
||||
})
|
||||
if err != nil {
|
||||
@@ -1969,6 +1973,7 @@ func (p *Server) processChat(ctx context.Context, chat database.Chat) {
|
||||
CacheCreationTokens: sql.NullInt64{},
|
||||
CacheReadTokens: sql.NullInt64{},
|
||||
ContextLimit: sql.NullInt64{},
|
||||
TotalCostMicros: sql.NullInt64{},
|
||||
Compressed: sql.NullBool{},
|
||||
})
|
||||
if insertErr != nil {
|
||||
@@ -2347,6 +2352,30 @@ func (p *Server) runChat(
|
||||
}
|
||||
|
||||
hasUsage := step.Usage != (fantasy.Usage{})
|
||||
var usageForCost codersdk.ChatMessageUsage
|
||||
if hasUsage {
|
||||
// Only populate fields that the provider explicitly
|
||||
// reported. Nil fields tell the calculator "no data"
|
||||
// vs zero meaning "reported as zero tokens."
|
||||
if step.Usage.InputTokens != 0 {
|
||||
usageForCost.InputTokens = int64Ptr(step.Usage.InputTokens)
|
||||
}
|
||||
if step.Usage.OutputTokens != 0 {
|
||||
usageForCost.OutputTokens = int64Ptr(step.Usage.OutputTokens)
|
||||
}
|
||||
if step.Usage.ReasoningTokens != 0 {
|
||||
usageForCost.ReasoningTokens = int64Ptr(step.Usage.ReasoningTokens)
|
||||
}
|
||||
if step.Usage.CacheCreationTokens != 0 {
|
||||
usageForCost.CacheCreationTokens = int64Ptr(step.Usage.CacheCreationTokens)
|
||||
}
|
||||
if step.Usage.CacheReadTokens != 0 {
|
||||
usageForCost.CacheReadTokens = int64Ptr(step.Usage.CacheReadTokens)
|
||||
}
|
||||
}
|
||||
|
||||
totalCostMicros := chatcost.CalculateTotalCostMicros(usageForCost, callConfig.Cost)
|
||||
|
||||
assistantMessage, insertErr := tx.InsertChatMessage(persistCtx, database.InsertChatMessageParams{
|
||||
ChatID: chat.ID,
|
||||
CreatedBy: uuid.NullUUID{},
|
||||
@@ -2369,6 +2398,11 @@ func (p *Server) runChat(
|
||||
CacheReadTokens: usageNullInt64(step.Usage.CacheReadTokens, hasUsage),
|
||||
ContextLimit: step.ContextLimit,
|
||||
Compressed: sql.NullBool{},
|
||||
// TotalCostMicros is nullable: NULL means "unpriced"
|
||||
// (pricing config was missing or no priced token
|
||||
// breakdown available), while 0 means "priced at
|
||||
// zero cost" (e.g., a free model).
|
||||
TotalCostMicros: usageNullInt64Ptr(totalCostMicros),
|
||||
})
|
||||
if insertErr != nil {
|
||||
return xerrors.Errorf("insert assistant message: %w", insertErr)
|
||||
@@ -2398,6 +2432,7 @@ func (p *Server) runChat(
|
||||
CacheCreationTokens: sql.NullInt64{},
|
||||
CacheReadTokens: sql.NullInt64{},
|
||||
ContextLimit: sql.NullInt64{},
|
||||
TotalCostMicros: sql.NullInt64{},
|
||||
Compressed: sql.NullBool{},
|
||||
})
|
||||
if insertErr != nil {
|
||||
@@ -2740,6 +2775,7 @@ func (p *Server) persistChatContextSummary(
|
||||
CacheCreationTokens: sql.NullInt64{},
|
||||
CacheReadTokens: sql.NullInt64{},
|
||||
ContextLimit: sql.NullInt64{},
|
||||
TotalCostMicros: sql.NullInt64{},
|
||||
})
|
||||
if txErr != nil {
|
||||
return xerrors.Errorf("insert hidden summary message: %w", txErr)
|
||||
@@ -2764,6 +2800,7 @@ func (p *Server) persistChatContextSummary(
|
||||
CacheCreationTokens: sql.NullInt64{},
|
||||
CacheReadTokens: sql.NullInt64{},
|
||||
ContextLimit: sql.NullInt64{},
|
||||
TotalCostMicros: sql.NullInt64{},
|
||||
})
|
||||
if txErr != nil {
|
||||
return xerrors.Errorf("insert summary tool call message: %w", txErr)
|
||||
@@ -2789,6 +2826,7 @@ func (p *Server) persistChatContextSummary(
|
||||
CacheCreationTokens: sql.NullInt64{},
|
||||
CacheReadTokens: sql.NullInt64{},
|
||||
ContextLimit: sql.NullInt64{},
|
||||
TotalCostMicros: sql.NullInt64{},
|
||||
})
|
||||
if txErr != nil {
|
||||
return xerrors.Errorf("insert summary tool result message: %w", txErr)
|
||||
@@ -2901,6 +2939,10 @@ func (p *Server) resolveModelConfig(
|
||||
return defaultConfig, nil
|
||||
}
|
||||
|
||||
// int64Ptr returns a pointer to a freshly allocated copy of value.
func int64Ptr(value int64) *int64 {
	out := new(int64)
	*out = value
	return out
}
|
||||
|
||||
//nolint:revive // Boolean controls SQL NULL validity.
|
||||
func usageNullInt64(value int64, valid bool) sql.NullInt64 {
|
||||
if !valid {
|
||||
@@ -2912,6 +2954,13 @@ func usageNullInt64(value int64, valid bool) sql.NullInt64 {
|
||||
}
|
||||
}
|
||||
|
||||
// usageNullInt64Ptr converts an optional int64 into a sql.NullInt64. A nil
// pointer becomes SQL NULL (Valid is false); any other pointer becomes a valid
// value holding the pointed-to number, including zero.
func usageNullInt64Ptr(v *int64) sql.NullInt64 {
	var out sql.NullInt64
	if v != nil {
		out.Int64 = *v
		out.Valid = true
	}
	return out
}
|
||||
|
||||
func refreshChatWorkspaceSnapshot(
|
||||
ctx context.Context,
|
||||
chat database.Chat,
|
||||
|
||||
@@ -553,34 +553,6 @@ func normalizedEnumValue(value string, allowed ...string) *string {
|
||||
return nil
|
||||
}
|
||||
|
||||
// MergeMissingCallConfig fills unset call config values from a provider or
|
||||
// profile default config.
|
||||
func MergeMissingCallConfig(
|
||||
dst *codersdk.ChatModelCallConfig,
|
||||
defaults codersdk.ChatModelCallConfig,
|
||||
) {
|
||||
if dst.MaxOutputTokens == nil {
|
||||
dst.MaxOutputTokens = defaults.MaxOutputTokens
|
||||
}
|
||||
if dst.Temperature == nil {
|
||||
dst.Temperature = defaults.Temperature
|
||||
}
|
||||
if dst.TopP == nil {
|
||||
dst.TopP = defaults.TopP
|
||||
}
|
||||
if dst.TopK == nil {
|
||||
dst.TopK = defaults.TopK
|
||||
}
|
||||
if dst.PresencePenalty == nil {
|
||||
dst.PresencePenalty = defaults.PresencePenalty
|
||||
}
|
||||
if dst.FrequencyPenalty == nil {
|
||||
dst.FrequencyPenalty = defaults.FrequencyPenalty
|
||||
}
|
||||
MergeMissingModelCostConfig(&dst.Cost, defaults.Cost)
|
||||
MergeMissingProviderOptions(&dst.ProviderOptions, defaults.ProviderOptions)
|
||||
}
|
||||
|
||||
// MergeMissingModelCostConfig fills unset pricing metadata from defaults.
|
||||
func MergeMissingModelCostConfig(
|
||||
dst **codersdk.ModelCostConfig,
|
||||
|
||||
@@ -137,61 +137,6 @@ func TestMergeMissingProviderOptions_OpenRouterNested(t *testing.T) {
|
||||
require.Equal(t, "latency", *options.OpenRouter.Provider.Sort)
|
||||
}
|
||||
|
||||
func TestMergeMissingCallConfig_FillsUnsetFields(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dst := codersdk.ChatModelCallConfig{
|
||||
Temperature: float64Ptr(0.2),
|
||||
Cost: &codersdk.ModelCostConfig{
|
||||
OutputPricePerMillionTokens: float64Ptr(0.7),
|
||||
},
|
||||
ProviderOptions: &codersdk.ChatModelProviderOptions{
|
||||
OpenAI: &codersdk.ChatModelOpenAIProviderOptions{
|
||||
User: stringPtr("alice"),
|
||||
},
|
||||
},
|
||||
}
|
||||
defaultCallConfig := codersdk.ChatModelCallConfig{
|
||||
MaxOutputTokens: int64Ptr(512),
|
||||
Temperature: float64Ptr(0.9),
|
||||
TopP: float64Ptr(0.8),
|
||||
Cost: &codersdk.ModelCostConfig{
|
||||
InputPricePerMillionTokens: float64Ptr(0.15),
|
||||
OutputPricePerMillionTokens: float64Ptr(0.9),
|
||||
CacheReadPricePerMillionTokens: float64Ptr(0.03),
|
||||
CacheWritePricePerMillionTokens: float64Ptr(0.3),
|
||||
},
|
||||
ProviderOptions: &codersdk.ChatModelProviderOptions{
|
||||
OpenAI: &codersdk.ChatModelOpenAIProviderOptions{
|
||||
User: stringPtr("bob"),
|
||||
ReasoningEffort: stringPtr("medium"),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
chatprovider.MergeMissingCallConfig(&dst, defaultCallConfig)
|
||||
|
||||
require.NotNil(t, dst.MaxOutputTokens)
|
||||
require.EqualValues(t, 512, *dst.MaxOutputTokens)
|
||||
require.NotNil(t, dst.Temperature)
|
||||
require.Equal(t, 0.2, *dst.Temperature)
|
||||
require.NotNil(t, dst.TopP)
|
||||
require.Equal(t, 0.8, *dst.TopP)
|
||||
require.NotNil(t, dst.Cost)
|
||||
require.NotNil(t, dst.Cost.InputPricePerMillionTokens)
|
||||
require.Equal(t, 0.15, *dst.Cost.InputPricePerMillionTokens)
|
||||
require.NotNil(t, dst.Cost.OutputPricePerMillionTokens)
|
||||
require.Equal(t, 0.7, *dst.Cost.OutputPricePerMillionTokens)
|
||||
require.NotNil(t, dst.Cost.CacheReadPricePerMillionTokens)
|
||||
require.Equal(t, 0.03, *dst.Cost.CacheReadPricePerMillionTokens)
|
||||
require.NotNil(t, dst.Cost.CacheWritePricePerMillionTokens)
|
||||
require.Equal(t, 0.3, *dst.Cost.CacheWritePricePerMillionTokens)
|
||||
require.NotNil(t, dst.ProviderOptions)
|
||||
require.NotNil(t, dst.ProviderOptions.OpenAI)
|
||||
require.Equal(t, "alice", *dst.ProviderOptions.OpenAI.User)
|
||||
require.Equal(t, "medium", *dst.ProviderOptions.OpenAI.ReasoningEffort)
|
||||
}
|
||||
|
||||
// stringPtr returns a pointer to a freshly allocated copy of value.
func stringPtr(value string) *string {
	out := new(string)
	*out = value
	return out
}
|
||||
@@ -203,7 +148,3 @@ func boolPtr(value bool) *bool {
|
||||
// int64Ptr returns a pointer to a freshly allocated copy of value.
func int64Ptr(value int64) *int64 {
	out := new(int64)
	*out = value
	return out
}
|
||||
|
||||
// float64Ptr returns a pointer to a freshly allocated copy of value.
func float64Ptr(value float64) *float64 {
	out := new(float64)
	*out = value
	return out
}
|
||||
|
||||
+229
-4
@@ -9,6 +9,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"math"
|
||||
"mime"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
@@ -20,6 +21,7 @@ import (
|
||||
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/google/uuid"
|
||||
"github.com/shopspring/decimal"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
@@ -354,6 +356,187 @@ func (api *API) listChatModels(rw http.ResponseWriter, r *http.Request) {
|
||||
httpapi.Write(ctx, rw, http.StatusOK, response)
|
||||
}
|
||||
|
||||
func (api *API) chatCostSummary(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
apiKey := httpmw.APIKey(r)
|
||||
|
||||
// Default date range: last 30 days.
|
||||
now := time.Now()
|
||||
defaultStart := now.AddDate(0, 0, -30)
|
||||
|
||||
qp := r.URL.Query()
|
||||
p := httpapi.NewQueryParamParser()
|
||||
startDate := p.Time(qp, defaultStart, "start_date", time.RFC3339)
|
||||
endDate := p.Time(qp, now, "end_date", time.RFC3339)
|
||||
p.ErrorExcessParams(qp)
|
||||
if len(p.Errors) > 0 {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Invalid query parameters.",
|
||||
Validations: p.Errors,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
targetUser := httpmw.UserParam(r)
|
||||
if targetUser.ID != apiKey.UserID && !api.Authorize(r, policy.ActionRead, rbac.ResourceChat.WithOwner(targetUser.ID.String())) {
|
||||
httpapi.Forbidden(rw)
|
||||
return
|
||||
}
|
||||
|
||||
summary, err := api.Database.GetChatCostSummary(ctx, database.GetChatCostSummaryParams{
|
||||
OwnerID: targetUser.ID,
|
||||
StartDate: startDate,
|
||||
EndDate: endDate,
|
||||
})
|
||||
if err != nil {
|
||||
httpapi.InternalServerError(rw, err)
|
||||
return
|
||||
}
|
||||
|
||||
byModel, err := api.Database.GetChatCostPerModel(ctx, database.GetChatCostPerModelParams{
|
||||
OwnerID: targetUser.ID,
|
||||
StartDate: startDate,
|
||||
EndDate: endDate,
|
||||
})
|
||||
if err != nil {
|
||||
httpapi.InternalServerError(rw, err)
|
||||
return
|
||||
}
|
||||
|
||||
byChat, err := api.Database.GetChatCostPerChat(ctx, database.GetChatCostPerChatParams{
|
||||
OwnerID: targetUser.ID,
|
||||
StartDate: startDate,
|
||||
EndDate: endDate,
|
||||
})
|
||||
if err != nil {
|
||||
httpapi.InternalServerError(rw, err)
|
||||
return
|
||||
}
|
||||
|
||||
modelBreakdowns := make([]codersdk.ChatCostModelBreakdown, 0, len(byModel))
|
||||
for _, model := range byModel {
|
||||
modelBreakdowns = append(modelBreakdowns, convertChatCostModelBreakdown(model))
|
||||
}
|
||||
|
||||
chatBreakdowns := make([]codersdk.ChatCostChatBreakdown, 0, len(byChat))
|
||||
for _, chat := range byChat {
|
||||
chatBreakdowns = append(chatBreakdowns, convertChatCostChatBreakdown(chat))
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, codersdk.ChatCostSummary{
|
||||
StartDate: startDate,
|
||||
EndDate: endDate,
|
||||
TotalCostMicros: summary.TotalCostMicros,
|
||||
PricedMessageCount: summary.PricedMessageCount,
|
||||
UnpricedMessageCount: summary.UnpricedMessageCount,
|
||||
TotalInputTokens: summary.TotalInputTokens,
|
||||
TotalOutputTokens: summary.TotalOutputTokens,
|
||||
ByModel: modelBreakdowns,
|
||||
ByChat: chatBreakdowns,
|
||||
})
|
||||
}
|
||||
|
||||
func (api *API) chatCostUsers(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
if !api.Authorize(r, policy.ActionRead, rbac.ResourceChat) {
|
||||
httpapi.Forbidden(rw)
|
||||
return
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
defaultStart := now.AddDate(0, 0, -30)
|
||||
|
||||
qp := r.URL.Query()
|
||||
p := httpapi.NewQueryParamParser()
|
||||
startDate := p.Time(qp, defaultStart, "start_date", time.RFC3339)
|
||||
endDate := p.Time(qp, now, "end_date", time.RFC3339)
|
||||
username := strings.TrimSpace(p.String(qp, "", "username"))
|
||||
limit := p.Int(qp, 10, "limit")
|
||||
offset := p.Int(qp, 0, "offset")
|
||||
p.ErrorExcessParams(qp)
|
||||
if len(p.Errors) > 0 {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Invalid query parameters.",
|
||||
Validations: p.Errors,
|
||||
})
|
||||
return
|
||||
}
|
||||
if limit <= 0 {
|
||||
limit = 10
|
||||
}
|
||||
if offset < 0 || offset > math.MaxInt32 || limit > math.MaxInt32 {
|
||||
validations := make([]codersdk.ValidationError, 0, 2)
|
||||
if offset < 0 {
|
||||
validations = append(validations, codersdk.ValidationError{
|
||||
Field: "offset",
|
||||
Detail: "Must be greater than or equal to 0.",
|
||||
})
|
||||
}
|
||||
if offset > math.MaxInt32 {
|
||||
validations = append(validations, codersdk.ValidationError{
|
||||
Field: "offset",
|
||||
Detail: fmt.Sprintf("Must be less than or equal to %d.", math.MaxInt32),
|
||||
})
|
||||
}
|
||||
if limit > math.MaxInt32 {
|
||||
validations = append(validations, codersdk.ValidationError{
|
||||
Field: "limit",
|
||||
Detail: fmt.Sprintf("Must be less than or equal to %d.", math.MaxInt32),
|
||||
})
|
||||
}
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Invalid query parameters.",
|
||||
Validations: validations,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
users, err := api.Database.GetChatCostPerUser(ctx, database.GetChatCostPerUserParams{
|
||||
StartDate: startDate,
|
||||
EndDate: endDate,
|
||||
Username: username,
|
||||
// #nosec G115 - Pagination limits are validated to fit in int32 above.
|
||||
PageLimit: int32(limit),
|
||||
// #nosec G115 - Pagination offsets are validated to fit in int32 above.
|
||||
PageOffset: int32(offset),
|
||||
})
|
||||
if err != nil {
|
||||
httpapi.InternalServerError(rw, err)
|
||||
return
|
||||
}
|
||||
|
||||
rollups := make([]codersdk.ChatCostUserRollup, 0, len(users))
|
||||
count := int64(0)
|
||||
for _, user := range users {
|
||||
count = user.TotalCount
|
||||
rollups = append(rollups, convertChatCostUserRollup(user))
|
||||
}
|
||||
|
||||
if len(users) == 0 && offset > 0 {
|
||||
countUsers, countErr := api.Database.GetChatCostPerUser(ctx, database.GetChatCostPerUserParams{
|
||||
StartDate: startDate,
|
||||
EndDate: endDate,
|
||||
Username: username,
|
||||
PageLimit: 1,
|
||||
PageOffset: 0,
|
||||
})
|
||||
if countErr != nil {
|
||||
httpapi.InternalServerError(rw, countErr)
|
||||
return
|
||||
}
|
||||
if len(countUsers) > 0 {
|
||||
count = countUsers[0].TotalCount
|
||||
}
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, codersdk.ChatCostUsersResponse{
|
||||
StartDate: startDate,
|
||||
EndDate: endDate,
|
||||
Count: count,
|
||||
Users: rollups,
|
||||
})
|
||||
}
|
||||
|
||||
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
||||
//
|
||||
//nolint:revive // HTTP handler writes to ResponseWriter.
|
||||
@@ -2172,6 +2355,48 @@ func convertChats(chats []database.Chat, diffStatusesByChatID map[uuid.UUID]data
|
||||
return result
|
||||
}
|
||||
|
||||
func convertChatCostModelBreakdown(model database.GetChatCostPerModelRow) codersdk.ChatCostModelBreakdown {
|
||||
displayName := strings.TrimSpace(model.DisplayName)
|
||||
if displayName == "" {
|
||||
displayName = model.Model
|
||||
}
|
||||
return codersdk.ChatCostModelBreakdown{
|
||||
ModelConfigID: model.ModelConfigID,
|
||||
DisplayName: displayName,
|
||||
Provider: model.Provider,
|
||||
Model: model.Model,
|
||||
TotalCostMicros: model.TotalCostMicros,
|
||||
MessageCount: model.MessageCount,
|
||||
TotalInputTokens: model.TotalInputTokens,
|
||||
TotalOutputTokens: model.TotalOutputTokens,
|
||||
}
|
||||
}
|
||||
|
||||
func convertChatCostChatBreakdown(chat database.GetChatCostPerChatRow) codersdk.ChatCostChatBreakdown {
|
||||
return codersdk.ChatCostChatBreakdown{
|
||||
RootChatID: chat.RootChatID,
|
||||
ChatTitle: chat.ChatTitle,
|
||||
TotalCostMicros: chat.TotalCostMicros,
|
||||
MessageCount: chat.MessageCount,
|
||||
TotalInputTokens: chat.TotalInputTokens,
|
||||
TotalOutputTokens: chat.TotalOutputTokens,
|
||||
}
|
||||
}
|
||||
|
||||
func convertChatCostUserRollup(user database.GetChatCostPerUserRow) codersdk.ChatCostUserRollup {
|
||||
return codersdk.ChatCostUserRollup{
|
||||
UserID: user.UserID,
|
||||
Username: user.Username,
|
||||
Name: user.Name,
|
||||
AvatarURL: user.AvatarURL,
|
||||
TotalCostMicros: user.TotalCostMicros,
|
||||
MessageCount: user.MessageCount,
|
||||
ChatCount: user.ChatCount,
|
||||
TotalInputTokens: user.TotalInputTokens,
|
||||
TotalOutputTokens: user.TotalOutputTokens,
|
||||
}
|
||||
}
|
||||
|
||||
func convertChatQueuedMessage(m database.ChatQueuedMessage) codersdk.ChatQueuedMessage {
|
||||
return db2sdk.ChatQueuedMessage(m)
|
||||
}
|
||||
@@ -3117,7 +3342,7 @@ func validateChatModelCallConfig(modelConfig *codersdk.ChatModelCallConfig) erro
|
||||
|
||||
pricingFields := []struct {
|
||||
name string
|
||||
value *float64
|
||||
value *decimal.Decimal
|
||||
}{
|
||||
{name: "cost.input_price_per_million_tokens", value: costConfig.InputPricePerMillionTokens},
|
||||
{name: "cost.output_price_per_million_tokens", value: costConfig.OutputPricePerMillionTokens},
|
||||
@@ -3125,7 +3350,7 @@ func validateChatModelCallConfig(modelConfig *codersdk.ChatModelCallConfig) erro
|
||||
{name: "cost.cache_write_price_per_million_tokens", value: costConfig.CacheWritePricePerMillionTokens},
|
||||
}
|
||||
for _, field := range pricingFields {
|
||||
if err := validateNonNegativeFloat64Field(field.name, field.value); err != nil {
|
||||
if err := validateNonNegativeDecimalField(field.name, field.value); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -3133,11 +3358,11 @@ func validateChatModelCallConfig(modelConfig *codersdk.ChatModelCallConfig) erro
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateNonNegativeFloat64Field(name string, value *float64) error {
|
||||
func validateNonNegativeDecimalField(name string, value *decimal.Decimal) error {
|
||||
if value == nil {
|
||||
return nil
|
||||
}
|
||||
if *value < 0 {
|
||||
if value.IsNegative() {
|
||||
return xerrors.Errorf("%s must be greater than or equal to zero", name)
|
||||
}
|
||||
return nil
|
||||
|
||||
+320
-19
@@ -13,6 +13,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/shopspring/decimal"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/coderdtest"
|
||||
@@ -22,7 +23,6 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/database/dbfake"
|
||||
"github.com/coder/coder/v2/coderd/externalauth"
|
||||
coderdpubsub "github.com/coder/coder/v2/coderd/pubsub"
|
||||
"github.com/coder/coder/v2/coderd/util/ptr"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
"github.com/coder/websocket"
|
||||
@@ -939,10 +939,10 @@ func TestListChatModelConfigs(t *testing.T) {
|
||||
require.Equal(t, storedConfig.ID, configs[0].ID)
|
||||
requireChatModelPricing(t, configs[0].ModelConfig, &codersdk.ChatModelCallConfig{
|
||||
Cost: &codersdk.ModelCostConfig{
|
||||
InputPricePerMillionTokens: ptr.Ref(0.15),
|
||||
OutputPricePerMillionTokens: ptr.Ref(0.6),
|
||||
CacheReadPricePerMillionTokens: ptr.Ref(0.03),
|
||||
CacheWritePricePerMillionTokens: ptr.Ref(0.3),
|
||||
InputPricePerMillionTokens: decRef("0.15"),
|
||||
OutputPricePerMillionTokens: decRef("0.6"),
|
||||
CacheReadPricePerMillionTokens: decRef("0.03"),
|
||||
CacheWritePricePerMillionTokens: decRef("0.3"),
|
||||
},
|
||||
})
|
||||
})
|
||||
@@ -993,10 +993,10 @@ func TestCreateChatModelConfig(t *testing.T) {
|
||||
isDefault := true
|
||||
pricing := &codersdk.ChatModelCallConfig{
|
||||
Cost: &codersdk.ModelCostConfig{
|
||||
InputPricePerMillionTokens: ptr.Ref(0.15),
|
||||
OutputPricePerMillionTokens: ptr.Ref(0.6),
|
||||
CacheReadPricePerMillionTokens: ptr.Ref(0.03),
|
||||
CacheWritePricePerMillionTokens: ptr.Ref(0.3),
|
||||
InputPricePerMillionTokens: decRef("0.15"),
|
||||
OutputPricePerMillionTokens: decRef("0.6"),
|
||||
CacheReadPricePerMillionTokens: decRef("0.03"),
|
||||
CacheWritePricePerMillionTokens: decRef("0.3"),
|
||||
},
|
||||
}
|
||||
modelConfig, err := client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
|
||||
@@ -1040,7 +1040,7 @@ func TestCreateChatModelConfig(t *testing.T) {
|
||||
ContextLimit: &contextLimit,
|
||||
ModelConfig: &codersdk.ChatModelCallConfig{
|
||||
Cost: &codersdk.ModelCostConfig{
|
||||
InputPricePerMillionTokens: ptr.Ref(-0.01),
|
||||
InputPricePerMillionTokens: decRef("-0.01"),
|
||||
},
|
||||
},
|
||||
})
|
||||
@@ -1123,10 +1123,10 @@ func TestUpdateChatModelConfig(t *testing.T) {
|
||||
contextLimit := int64(8192)
|
||||
pricing := &codersdk.ChatModelCallConfig{
|
||||
Cost: &codersdk.ModelCostConfig{
|
||||
InputPricePerMillionTokens: ptr.Ref(0.2),
|
||||
OutputPricePerMillionTokens: ptr.Ref(0.8),
|
||||
CacheReadPricePerMillionTokens: ptr.Ref(0.04),
|
||||
CacheWritePricePerMillionTokens: ptr.Ref(0.4),
|
||||
InputPricePerMillionTokens: decRef("0.2"),
|
||||
OutputPricePerMillionTokens: decRef("0.8"),
|
||||
CacheReadPricePerMillionTokens: decRef("0.04"),
|
||||
CacheWritePricePerMillionTokens: decRef("0.4"),
|
||||
},
|
||||
}
|
||||
updated, err := client.UpdateChatModelConfig(ctx, modelConfig.ID, codersdk.UpdateChatModelConfigRequest{
|
||||
@@ -1157,7 +1157,7 @@ func TestUpdateChatModelConfig(t *testing.T) {
|
||||
_, err := client.UpdateChatModelConfig(ctx, modelConfig.ID, codersdk.UpdateChatModelConfigRequest{
|
||||
ModelConfig: &codersdk.ChatModelCallConfig{
|
||||
Cost: &codersdk.ModelCostConfig{
|
||||
OutputPricePerMillionTokens: ptr.Ref(-1.0),
|
||||
OutputPricePerMillionTokens: decRef("-1.0"),
|
||||
},
|
||||
},
|
||||
})
|
||||
@@ -3479,6 +3479,302 @@ func TestGetChatFile(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestChatCostSummary(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("BasicSummary", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client, db := newChatClientWithDatabase(t)
|
||||
firstUser := coderdtest.CreateFirstUser(t, client)
|
||||
modelConfig := createChatModelConfig(t, client)
|
||||
|
||||
chat, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{
|
||||
OwnerID: firstUser.UserID,
|
||||
LastModelConfigID: modelConfig.ID,
|
||||
Title: "test chat",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
for i := 0; i < 2; i++ {
|
||||
_, err = db.InsertChatMessage(dbauthz.AsSystemRestricted(ctx), database.InsertChatMessageParams{
|
||||
ChatID: chat.ID,
|
||||
ModelConfigID: uuid.NullUUID{UUID: modelConfig.ID, Valid: true},
|
||||
Role: "assistant",
|
||||
Visibility: database.ChatMessageVisibilityBoth,
|
||||
InputTokens: sql.NullInt64{Int64: 100, Valid: true},
|
||||
OutputTokens: sql.NullInt64{Int64: 50, Valid: true},
|
||||
TotalCostMicros: sql.NullInt64{Int64: 500, Valid: true},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
summary, err := client.GetChatCostSummary(ctx, "me", codersdk.ChatCostSummaryOptions{})
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, int64(1000), summary.TotalCostMicros)
|
||||
require.Equal(t, int64(2), summary.PricedMessageCount)
|
||||
require.Equal(t, int64(0), summary.UnpricedMessageCount)
|
||||
require.Equal(t, int64(200), summary.TotalInputTokens)
|
||||
require.Equal(t, int64(100), summary.TotalOutputTokens)
|
||||
|
||||
require.Len(t, summary.ByModel, 1)
|
||||
require.Equal(t, modelConfig.ID, summary.ByModel[0].ModelConfigID)
|
||||
require.Equal(t, int64(1000), summary.ByModel[0].TotalCostMicros)
|
||||
require.Equal(t, int64(2), summary.ByModel[0].MessageCount)
|
||||
|
||||
require.Len(t, summary.ByChat, 1)
|
||||
require.Equal(t, chat.ID, summary.ByChat[0].RootChatID)
|
||||
require.Equal(t, int64(1000), summary.ByChat[0].TotalCostMicros)
|
||||
require.Equal(t, int64(2), summary.ByChat[0].MessageCount)
|
||||
})
|
||||
}
|
||||
|
||||
func TestChatCostSummary_AdminDrilldown(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
seedCtx := testutil.Context(t, testutil.WaitLong)
|
||||
client, db := newChatClientWithDatabase(t)
|
||||
firstUser := coderdtest.CreateFirstUser(t, client)
|
||||
memberClient, member := coderdtest.CreateAnotherUser(t, client, firstUser.OrganizationID)
|
||||
modelConfig := createChatModelConfig(t, client)
|
||||
|
||||
chat, err := db.InsertChat(dbauthz.AsSystemRestricted(seedCtx), database.InsertChatParams{
|
||||
OwnerID: member.ID,
|
||||
LastModelConfigID: modelConfig.ID,
|
||||
Title: "member chat",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = db.InsertChatMessage(dbauthz.AsSystemRestricted(seedCtx), database.InsertChatMessageParams{
|
||||
ChatID: chat.ID,
|
||||
ModelConfigID: uuid.NullUUID{UUID: modelConfig.ID, Valid: true},
|
||||
Role: "assistant",
|
||||
Visibility: database.ChatMessageVisibilityBoth,
|
||||
InputTokens: sql.NullInt64{Int64: 200, Valid: true},
|
||||
OutputTokens: sql.NullInt64{Int64: 100, Valid: true},
|
||||
TotalCostMicros: sql.NullInt64{Int64: 750, Valid: true},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("AdminCanDrilldown", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
summary, err := client.GetChatCostSummary(ctx, member.ID.String(), codersdk.ChatCostSummaryOptions{})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(750), summary.TotalCostMicros)
|
||||
require.Equal(t, int64(1), summary.PricedMessageCount)
|
||||
})
|
||||
|
||||
t.Run("MemberCannotDrilldownOtherUser", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
_, err := memberClient.GetChatCostSummary(ctx, firstUser.UserID.String(), codersdk.ChatCostSummaryOptions{})
|
||||
require.Error(t, err)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
require.Equal(t, http.StatusNotFound, sdkErr.StatusCode())
|
||||
})
|
||||
}
|
||||
|
||||
func TestChatCostUsers(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
seedCtx := testutil.Context(t, testutil.WaitLong)
|
||||
client, db := newChatClientWithDatabase(t)
|
||||
firstUser := coderdtest.CreateFirstUser(t, client)
|
||||
memberClient, member := coderdtest.CreateAnotherUser(t, client, firstUser.OrganizationID)
|
||||
firstUserRecord, err := db.GetUserByID(dbauthz.AsSystemRestricted(seedCtx), firstUser.UserID)
|
||||
require.NoError(t, err)
|
||||
modelConfig := createChatModelConfig(t, client)
|
||||
|
||||
adminChat, err := db.InsertChat(dbauthz.AsSystemRestricted(seedCtx), database.InsertChatParams{
|
||||
OwnerID: firstUser.UserID,
|
||||
LastModelConfigID: modelConfig.ID,
|
||||
Title: "admin chat",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
_, err = db.InsertChatMessage(dbauthz.AsSystemRestricted(seedCtx), database.InsertChatMessageParams{
|
||||
ChatID: adminChat.ID,
|
||||
ModelConfigID: uuid.NullUUID{UUID: modelConfig.ID, Valid: true},
|
||||
Role: "assistant",
|
||||
Visibility: database.ChatMessageVisibilityBoth,
|
||||
InputTokens: sql.NullInt64{Int64: 100, Valid: true},
|
||||
OutputTokens: sql.NullInt64{Int64: 50, Valid: true},
|
||||
TotalCostMicros: sql.NullInt64{Int64: 300, Valid: true},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
memberChat, err := db.InsertChat(dbauthz.AsSystemRestricted(seedCtx), database.InsertChatParams{
|
||||
OwnerID: member.ID,
|
||||
LastModelConfigID: modelConfig.ID,
|
||||
Title: "member chat",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
_, err = db.InsertChatMessage(dbauthz.AsSystemRestricted(seedCtx), database.InsertChatMessageParams{
|
||||
ChatID: memberChat.ID,
|
||||
ModelConfigID: uuid.NullUUID{UUID: modelConfig.ID, Valid: true},
|
||||
Role: "assistant",
|
||||
Visibility: database.ChatMessageVisibilityBoth,
|
||||
InputTokens: sql.NullInt64{Int64: 200, Valid: true},
|
||||
OutputTokens: sql.NullInt64{Int64: 100, Valid: true},
|
||||
TotalCostMicros: sql.NullInt64{Int64: 800, Valid: true},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("AdminCanListUsers", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
resp, err := client.GetChatCostUsers(ctx, codersdk.ChatCostUsersOptions{})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(2), resp.Count)
|
||||
require.Len(t, resp.Users, 2)
|
||||
require.Equal(t, member.ID, resp.Users[0].UserID)
|
||||
require.Equal(t, member.Username, resp.Users[0].Username)
|
||||
require.Equal(t, int64(800), resp.Users[0].TotalCostMicros)
|
||||
require.Equal(t, int64(1), resp.Users[0].MessageCount)
|
||||
require.Equal(t, int64(1), resp.Users[0].ChatCount)
|
||||
require.Equal(t, firstUser.UserID, resp.Users[1].UserID)
|
||||
require.Equal(t, firstUserRecord.Username, resp.Users[1].Username)
|
||||
require.Equal(t, int64(300), resp.Users[1].TotalCostMicros)
|
||||
})
|
||||
|
||||
t.Run("AdminCanFilterAndPaginateUsers", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
resp, err := client.GetChatCostUsers(ctx, codersdk.ChatCostUsersOptions{
|
||||
Username: member.Username,
|
||||
Pagination: codersdk.Pagination{
|
||||
Limit: 1,
|
||||
Offset: 0,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(1), resp.Count)
|
||||
require.Len(t, resp.Users, 1)
|
||||
require.Equal(t, member.ID, resp.Users[0].UserID)
|
||||
require.Equal(t, member.Username, resp.Users[0].Username)
|
||||
})
|
||||
|
||||
t.Run("MemberCannotListUsers", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
_, err := memberClient.GetChatCostUsers(ctx, codersdk.ChatCostUsersOptions{})
|
||||
require.Error(t, err)
|
||||
var sdkErr *codersdk.Error
|
||||
require.ErrorAs(t, err, &sdkErr)
|
||||
require.Equal(t, http.StatusForbidden, sdkErr.StatusCode())
|
||||
})
|
||||
}
|
||||
|
||||
func TestChatCostSummary_DateRange(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
seedCtx := testutil.Context(t, testutil.WaitLong)
|
||||
client, db := newChatClientWithDatabase(t)
|
||||
firstUser := coderdtest.CreateFirstUser(t, client)
|
||||
modelConfig := createChatModelConfig(t, client)
|
||||
|
||||
chat, err := db.InsertChat(dbauthz.AsSystemRestricted(seedCtx), database.InsertChatParams{
|
||||
OwnerID: firstUser.UserID,
|
||||
LastModelConfigID: modelConfig.ID,
|
||||
Title: "date range test",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = db.InsertChatMessage(dbauthz.AsSystemRestricted(seedCtx), database.InsertChatMessageParams{
|
||||
ChatID: chat.ID,
|
||||
ModelConfigID: uuid.NullUUID{UUID: modelConfig.ID, Valid: true},
|
||||
Role: "assistant",
|
||||
Visibility: database.ChatMessageVisibilityBoth,
|
||||
InputTokens: sql.NullInt64{Int64: 100, Valid: true},
|
||||
OutputTokens: sql.NullInt64{Int64: 50, Valid: true},
|
||||
TotalCostMicros: sql.NullInt64{Int64: 500, Valid: true},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
now := time.Now()
|
||||
|
||||
t.Run("MessageInRange", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
summary, err := client.GetChatCostSummary(ctx, "me", codersdk.ChatCostSummaryOptions{
|
||||
StartDate: now.Add(-time.Hour),
|
||||
EndDate: now.Add(time.Hour),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(500), summary.TotalCostMicros)
|
||||
require.Equal(t, int64(1), summary.PricedMessageCount)
|
||||
})
|
||||
|
||||
t.Run("MessageOutOfRange", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
summary, err := client.GetChatCostSummary(ctx, "me", codersdk.ChatCostSummaryOptions{
|
||||
StartDate: now.Add(time.Hour),
|
||||
EndDate: now.Add(2 * time.Hour),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, int64(0), summary.TotalCostMicros)
|
||||
require.Equal(t, int64(0), summary.PricedMessageCount)
|
||||
})
|
||||
}
|
||||
|
||||
func TestChatCostSummary_UnpricedMessages(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client, db := newChatClientWithDatabase(t)
|
||||
firstUser := coderdtest.CreateFirstUser(t, client)
|
||||
modelConfig := createChatModelConfig(t, client)
|
||||
|
||||
chat, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{
|
||||
OwnerID: firstUser.UserID,
|
||||
LastModelConfigID: modelConfig.ID,
|
||||
Title: "unpriced test",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = db.InsertChatMessage(dbauthz.AsSystemRestricted(ctx), database.InsertChatMessageParams{
|
||||
ChatID: chat.ID,
|
||||
ModelConfigID: uuid.NullUUID{UUID: modelConfig.ID, Valid: true},
|
||||
Role: "assistant",
|
||||
Visibility: database.ChatMessageVisibilityBoth,
|
||||
InputTokens: sql.NullInt64{Int64: 100, Valid: true},
|
||||
OutputTokens: sql.NullInt64{Int64: 50, Valid: true},
|
||||
TotalCostMicros: sql.NullInt64{Int64: 500, Valid: true},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = db.InsertChatMessage(dbauthz.AsSystemRestricted(ctx), database.InsertChatMessageParams{
|
||||
ChatID: chat.ID,
|
||||
ModelConfigID: uuid.NullUUID{UUID: modelConfig.ID, Valid: true},
|
||||
Role: "assistant",
|
||||
Visibility: database.ChatMessageVisibilityBoth,
|
||||
InputTokens: sql.NullInt64{Int64: 200, Valid: true},
|
||||
OutputTokens: sql.NullInt64{Int64: 75, Valid: true},
|
||||
TotalCostMicros: sql.NullInt64{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
summary, err := client.GetChatCostSummary(ctx, "me", codersdk.ChatCostSummaryOptions{})
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, int64(500), summary.TotalCostMicros)
|
||||
require.Equal(t, int64(1), summary.PricedMessageCount)
|
||||
require.Equal(t, int64(1), summary.UnpricedMessageCount)
|
||||
require.Equal(t, int64(300), summary.TotalInputTokens)
|
||||
require.Equal(t, int64(125), summary.TotalOutputTokens)
|
||||
}
|
||||
|
||||
func requireChatModelPricing(
|
||||
t *testing.T,
|
||||
actual *codersdk.ChatModelCallConfig,
|
||||
@@ -3495,10 +3791,15 @@ func requireChatModelPricing(
|
||||
require.NotNil(t, actual.Cost.CacheReadPricePerMillionTokens)
|
||||
require.NotNil(t, actual.Cost.CacheWritePricePerMillionTokens)
|
||||
|
||||
require.Equal(t, *expected.Cost.InputPricePerMillionTokens, *actual.Cost.InputPricePerMillionTokens)
|
||||
require.Equal(t, *expected.Cost.OutputPricePerMillionTokens, *actual.Cost.OutputPricePerMillionTokens)
|
||||
require.Equal(t, *expected.Cost.CacheReadPricePerMillionTokens, *actual.Cost.CacheReadPricePerMillionTokens)
|
||||
require.Equal(t, *expected.Cost.CacheWritePricePerMillionTokens, *actual.Cost.CacheWritePricePerMillionTokens)
|
||||
require.True(t, expected.Cost.InputPricePerMillionTokens.Equal(*actual.Cost.InputPricePerMillionTokens))
|
||||
require.True(t, expected.Cost.OutputPricePerMillionTokens.Equal(*actual.Cost.OutputPricePerMillionTokens))
|
||||
require.True(t, expected.Cost.CacheReadPricePerMillionTokens.Equal(*actual.Cost.CacheReadPricePerMillionTokens))
|
||||
require.True(t, expected.Cost.CacheWritePricePerMillionTokens.Equal(*actual.Cost.CacheWritePricePerMillionTokens))
|
||||
}
|
||||
|
||||
func decRef(value string) *decimal.Decimal {
|
||||
d := decimal.RequireFromString(value)
|
||||
return &d
|
||||
}
|
||||
|
||||
func createChatModelConfig(t *testing.T, client *codersdk.Client) codersdk.ChatModelConfig {
|
||||
|
||||
@@ -1139,6 +1139,13 @@ func New(options *Options) *API {
|
||||
r.Post("/", api.postChats)
|
||||
r.Get("/models", api.listChatModels)
|
||||
r.Get("/watch", api.watchChats)
|
||||
r.Route("/cost", func(r chi.Router) {
|
||||
r.Get("/users", api.chatCostUsers)
|
||||
r.Route("/{user}", func(r chi.Router) {
|
||||
r.Use(httpmw.ExtractUserParam(options.Database))
|
||||
r.Get("/summary", api.chatCostSummary)
|
||||
})
|
||||
})
|
||||
r.Route("/files", func(r chi.Router) {
|
||||
r.Use(httpmw.RateLimit(options.FilesRateLimit, time.Minute))
|
||||
r.Post("/", api.postChatFile)
|
||||
|
||||
@@ -2426,6 +2426,34 @@ func (q *querier) GetChatByIDForUpdate(ctx context.Context, id uuid.UUID) (datab
|
||||
return fetch(q.log, q.auth, q.db.GetChatByIDForUpdate)(ctx, id)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatCostPerChat(ctx context.Context, arg database.GetChatCostPerChatParams) ([]database.GetChatCostPerChatRow, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceChat.WithOwner(arg.OwnerID.String())); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.GetChatCostPerChat(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatCostPerModel(ctx context.Context, arg database.GetChatCostPerModelParams) ([]database.GetChatCostPerModelRow, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceChat.WithOwner(arg.OwnerID.String())); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.GetChatCostPerModel(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatCostPerUser(ctx context.Context, arg database.GetChatCostPerUserParams) ([]database.GetChatCostPerUserRow, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceChat); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.GetChatCostPerUser(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatCostSummary(ctx context.Context, arg database.GetChatCostSummaryParams) (database.GetChatCostSummaryRow, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceChat.WithOwner(arg.OwnerID.String())); err != nil {
|
||||
return database.GetChatCostSummaryRow{}, err
|
||||
}
|
||||
return q.db.GetChatCostSummary(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatDiffStatusByChatID(ctx context.Context, chatID uuid.UUID) (database.ChatDiffStatus, error) {
|
||||
// Authorize read on the parent chat.
|
||||
_, err := q.GetChatByID(ctx, chatID)
|
||||
|
||||
@@ -438,6 +438,81 @@ func (s *MethodTestSuite) TestChats() {
|
||||
dbm.EXPECT().GetChatByIDForUpdate(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
check.Args(chat.ID).Asserts(chat, policy.ActionRead).Returns(chat)
|
||||
}))
|
||||
s.Run("GetChatCostPerChat", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
arg := database.GetChatCostPerChatParams{
|
||||
OwnerID: uuid.New(),
|
||||
StartDate: time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC),
|
||||
EndDate: time.Date(2025, 2, 1, 0, 0, 0, 0, time.UTC),
|
||||
}
|
||||
rows := []database.GetChatCostPerChatRow{{
|
||||
RootChatID: uuid.New(),
|
||||
ChatTitle: "chat-cost",
|
||||
TotalCostMicros: 123,
|
||||
MessageCount: 4,
|
||||
TotalInputTokens: 55,
|
||||
TotalOutputTokens: 89,
|
||||
}}
|
||||
dbm.EXPECT().GetChatCostPerChat(gomock.Any(), arg).Return(rows, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceChat.WithOwner(arg.OwnerID.String()), policy.ActionRead).Returns(rows)
|
||||
}))
|
||||
s.Run("GetChatCostPerModel", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
arg := database.GetChatCostPerModelParams{
|
||||
OwnerID: uuid.New(),
|
||||
StartDate: time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC),
|
||||
EndDate: time.Date(2025, 2, 1, 0, 0, 0, 0, time.UTC),
|
||||
}
|
||||
rows := []database.GetChatCostPerModelRow{{
|
||||
ModelConfigID: uuid.New(),
|
||||
DisplayName: "GPT 4.1",
|
||||
Provider: "openai",
|
||||
Model: "gpt-4.1",
|
||||
TotalCostMicros: 456,
|
||||
MessageCount: 7,
|
||||
TotalInputTokens: 144,
|
||||
TotalOutputTokens: 233,
|
||||
}}
|
||||
dbm.EXPECT().GetChatCostPerModel(gomock.Any(), arg).Return(rows, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceChat.WithOwner(arg.OwnerID.String()), policy.ActionRead).Returns(rows)
|
||||
}))
|
||||
s.Run("GetChatCostPerUser", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
arg := database.GetChatCostPerUserParams{
|
||||
PageOffset: 0,
|
||||
PageLimit: 25,
|
||||
StartDate: time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC),
|
||||
EndDate: time.Date(2025, 2, 1, 0, 0, 0, 0, time.UTC),
|
||||
Username: "cost-user",
|
||||
}
|
||||
rows := []database.GetChatCostPerUserRow{{
|
||||
UserID: uuid.New(),
|
||||
Username: "cost-user",
|
||||
Name: "Cost User",
|
||||
AvatarURL: "https://example.com/avatar.png",
|
||||
TotalCostMicros: 789,
|
||||
MessageCount: 11,
|
||||
ChatCount: 3,
|
||||
TotalInputTokens: 377,
|
||||
TotalOutputTokens: 610,
|
||||
TotalCount: 1,
|
||||
}}
|
||||
dbm.EXPECT().GetChatCostPerUser(gomock.Any(), arg).Return(rows, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceChat, policy.ActionRead).Returns(rows)
|
||||
}))
|
||||
s.Run("GetChatCostSummary", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) {
|
||||
arg := database.GetChatCostSummaryParams{
|
||||
OwnerID: uuid.New(),
|
||||
StartDate: time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC),
|
||||
EndDate: time.Date(2025, 2, 1, 0, 0, 0, 0, time.UTC),
|
||||
}
|
||||
row := database.GetChatCostSummaryRow{
|
||||
TotalCostMicros: 987,
|
||||
PricedMessageCount: 12,
|
||||
UnpricedMessageCount: 2,
|
||||
TotalInputTokens: 400,
|
||||
TotalOutputTokens: 800,
|
||||
}
|
||||
dbm.EXPECT().GetChatCostSummary(gomock.Any(), arg).Return(row, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceChat.WithOwner(arg.OwnerID.String()), policy.ActionRead).Returns(row)
|
||||
}))
|
||||
s.Run("GetChatDiffStatusByChatID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
diffStatus := testutil.Fake(s.T(), faker, database.ChatDiffStatus{ChatID: chat.ID})
|
||||
|
||||
@@ -983,6 +983,38 @@ func (m queryMetricsStore) GetChatByIDForUpdate(ctx context.Context, id uuid.UUI
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatCostPerChat(ctx context.Context, arg database.GetChatCostPerChatParams) ([]database.GetChatCostPerChatRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatCostPerChat(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("GetChatCostPerChat").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatCostPerChat").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatCostPerModel(ctx context.Context, arg database.GetChatCostPerModelParams) ([]database.GetChatCostPerModelRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatCostPerModel(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("GetChatCostPerModel").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatCostPerModel").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatCostPerUser(ctx context.Context, arg database.GetChatCostPerUserParams) ([]database.GetChatCostPerUserRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatCostPerUser(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("GetChatCostPerUser").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatCostPerUser").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatCostSummary(ctx context.Context, arg database.GetChatCostSummaryParams) (database.GetChatCostSummaryRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatCostSummary(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("GetChatCostSummary").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatCostSummary").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatDiffStatusByChatID(ctx context.Context, chatID uuid.UUID) (database.ChatDiffStatus, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatDiffStatusByChatID(ctx, chatID)
|
||||
|
||||
@@ -1778,6 +1778,66 @@ func (mr *MockStoreMockRecorder) GetChatByIDForUpdate(ctx, id any) *gomock.Call
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatByIDForUpdate", reflect.TypeOf((*MockStore)(nil).GetChatByIDForUpdate), ctx, id)
|
||||
}
|
||||
|
||||
// GetChatCostPerChat mocks base method.
|
||||
func (m *MockStore) GetChatCostPerChat(ctx context.Context, arg database.GetChatCostPerChatParams) ([]database.GetChatCostPerChatRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChatCostPerChat", ctx, arg)
|
||||
ret0, _ := ret[0].([]database.GetChatCostPerChatRow)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetChatCostPerChat indicates an expected call of GetChatCostPerChat.
|
||||
func (mr *MockStoreMockRecorder) GetChatCostPerChat(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatCostPerChat", reflect.TypeOf((*MockStore)(nil).GetChatCostPerChat), ctx, arg)
|
||||
}
|
||||
|
||||
// GetChatCostPerModel mocks base method.
|
||||
func (m *MockStore) GetChatCostPerModel(ctx context.Context, arg database.GetChatCostPerModelParams) ([]database.GetChatCostPerModelRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChatCostPerModel", ctx, arg)
|
||||
ret0, _ := ret[0].([]database.GetChatCostPerModelRow)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetChatCostPerModel indicates an expected call of GetChatCostPerModel.
|
||||
func (mr *MockStoreMockRecorder) GetChatCostPerModel(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatCostPerModel", reflect.TypeOf((*MockStore)(nil).GetChatCostPerModel), ctx, arg)
|
||||
}
|
||||
|
||||
// GetChatCostPerUser mocks base method.
|
||||
func (m *MockStore) GetChatCostPerUser(ctx context.Context, arg database.GetChatCostPerUserParams) ([]database.GetChatCostPerUserRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChatCostPerUser", ctx, arg)
|
||||
ret0, _ := ret[0].([]database.GetChatCostPerUserRow)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetChatCostPerUser indicates an expected call of GetChatCostPerUser.
|
||||
func (mr *MockStoreMockRecorder) GetChatCostPerUser(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatCostPerUser", reflect.TypeOf((*MockStore)(nil).GetChatCostPerUser), ctx, arg)
|
||||
}
|
||||
|
||||
// GetChatCostSummary mocks base method.
|
||||
func (m *MockStore) GetChatCostSummary(ctx context.Context, arg database.GetChatCostSummaryParams) (database.GetChatCostSummaryRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChatCostSummary", ctx, arg)
|
||||
ret0, _ := ret[0].(database.GetChatCostSummaryRow)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetChatCostSummary indicates an expected call of GetChatCostSummary.
|
||||
func (mr *MockStoreMockRecorder) GetChatCostSummary(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatCostSummary", reflect.TypeOf((*MockStore)(nil).GetChatCostSummary), ctx, arg)
|
||||
}
|
||||
|
||||
// GetChatDiffStatusByChatID mocks base method.
|
||||
func (m *MockStore) GetChatDiffStatusByChatID(ctx context.Context, chatID uuid.UUID) (database.ChatDiffStatus, error) {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
Generated
+4
-1
@@ -1226,7 +1226,8 @@ CREATE TABLE chat_messages (
|
||||
context_limit bigint,
|
||||
compressed boolean DEFAULT false NOT NULL,
|
||||
created_by uuid,
|
||||
content_version smallint NOT NULL
|
||||
content_version smallint NOT NULL,
|
||||
total_cost_micros bigint
|
||||
);
|
||||
|
||||
CREATE SEQUENCE chat_messages_id_seq
|
||||
@@ -3534,6 +3535,8 @@ CREATE INDEX idx_chat_messages_chat_created ON chat_messages USING btree (chat_i
|
||||
|
||||
CREATE INDEX idx_chat_messages_compressed_summary_boundary ON chat_messages USING btree (chat_id, created_at DESC, id DESC) WHERE ((compressed = true) AND (role = 'system'::chat_message_role) AND (visibility = ANY (ARRAY['model'::chat_message_visibility, 'both'::chat_message_visibility])));
|
||||
|
||||
CREATE INDEX idx_chat_messages_created_at ON chat_messages USING btree (created_at);
|
||||
|
||||
CREATE INDEX idx_chat_model_configs_enabled ON chat_model_configs USING btree (enabled);
|
||||
|
||||
CREATE INDEX idx_chat_model_configs_provider ON chat_model_configs USING btree (provider);
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
-- Down migration: remove the created_at index and the per-message cost
-- column. Both statements use IF EXISTS so the script can be re-run safely
-- after a partial rollback.
DROP INDEX IF EXISTS idx_chat_messages_created_at;

ALTER TABLE chat_messages DROP COLUMN IF EXISTS total_cost_micros;
|
||||
@@ -0,0 +1,68 @@
|
||||
-- Per-message cost in integer micros. NULL means the message is unpriced;
-- 0 means it was priced and cost nothing.
ALTER TABLE chat_messages ADD COLUMN total_cost_micros BIGINT;

-- Backfill historical messages from the pricing stored on their model
-- config. Prices are expressed per million tokens, so tokens * price is
-- already in micros of the price's currency unit.
WITH message_costs AS (
    SELECT
        msg.id,
        -- Round fractional micros UP, once, after summing all components.
        -- This follows the ceil rounding policy for message costs; ROUND
        -- would under-charge any sum whose fraction is below one half.
        CEIL(
            COALESCE(msg.input_tokens, 0)::numeric * COALESCE(pricing.input_price, 0)
            + COALESCE(msg.output_tokens, 0)::numeric * COALESCE(pricing.output_price, 0)
            + COALESCE(msg.cache_read_tokens, 0)::numeric * COALESCE(pricing.cache_read_price, 0)
            + COALESCE(msg.cache_creation_tokens, 0)::numeric * COALESCE(pricing.cache_write_price, 0)
        )::bigint AS total_cost_micros
    FROM
        chat_messages AS msg
    JOIN
        chat_model_configs AS cfg
    ON
        cfg.id = msg.model_config_id
    -- Each price is read from options -> 'cost' first, falling back to the
    -- same key at the top level of options.
    CROSS JOIN LATERAL (
        SELECT
            COALESCE(
                (cfg.options -> 'cost' ->> 'input_price_per_million_tokens')::numeric,
                (cfg.options ->> 'input_price_per_million_tokens')::numeric
            ) AS input_price,
            COALESCE(
                (cfg.options -> 'cost' ->> 'output_price_per_million_tokens')::numeric,
                (cfg.options ->> 'output_price_per_million_tokens')::numeric
            ) AS output_price,
            COALESCE(
                (cfg.options -> 'cost' ->> 'cache_read_price_per_million_tokens')::numeric,
                (cfg.options ->> 'cache_read_price_per_million_tokens')::numeric
            ) AS cache_read_price,
            COALESCE(
                (cfg.options -> 'cost' ->> 'cache_write_price_per_million_tokens')::numeric,
                (cfg.options ->> 'cache_write_price_per_million_tokens')::numeric
            ) AS cache_write_price
    ) AS pricing
    WHERE
        msg.total_cost_micros IS NULL
        -- The message recorded at least one usage figure.
        AND (
            msg.input_tokens IS NOT NULL
            OR msg.output_tokens IS NOT NULL
            OR msg.reasoning_tokens IS NOT NULL
            OR msg.cache_creation_tokens IS NOT NULL
            OR msg.cache_read_tokens IS NOT NULL
        )
        -- The model config defines at least one price.
        AND (
            pricing.input_price IS NOT NULL
            OR pricing.output_price IS NOT NULL
            OR pricing.cache_read_price IS NOT NULL
            OR pricing.cache_write_price IS NOT NULL
        )
        -- At least one usage figure has a matching price. Messages that
        -- fail this check keep a NULL cost.
        AND (
            (msg.input_tokens IS NOT NULL AND pricing.input_price IS NOT NULL)
            OR (msg.output_tokens IS NOT NULL AND pricing.output_price IS NOT NULL)
            OR (msg.cache_read_tokens IS NOT NULL AND pricing.cache_read_price IS NOT NULL)
            OR (msg.cache_creation_tokens IS NOT NULL AND pricing.cache_write_price IS NOT NULL)
        )
)
UPDATE
    chat_messages AS msg
SET
    total_cost_micros = message_costs.total_cost_micros
FROM
    message_costs
WHERE
    msg.id = message_costs.id;

CREATE INDEX idx_chat_messages_created_at ON chat_messages (created_at);
|
||||
@@ -4020,6 +4020,7 @@ type ChatMessage struct {
|
||||
Compressed bool `db:"compressed" json:"compressed"`
|
||||
CreatedBy uuid.NullUUID `db:"created_by" json:"created_by"`
|
||||
ContentVersion int16 `db:"content_version" json:"content_version"`
|
||||
TotalCostMicros sql.NullInt64 `db:"total_cost_micros" json:"total_cost_micros"`
|
||||
}
|
||||
|
||||
type ChatModelConfig struct {
|
||||
|
||||
@@ -215,6 +215,19 @@ type sqlcQuerier interface {
|
||||
GetAuthorizationUserRoles(ctx context.Context, userID uuid.UUID) (GetAuthorizationUserRolesRow, error)
|
||||
GetChatByID(ctx context.Context, id uuid.UUID) (Chat, error)
|
||||
GetChatByIDForUpdate(ctx context.Context, id uuid.UUID) (Chat, error)
|
||||
// Per-root-chat cost breakdown for a single user within a date range.
|
||||
// Groups by root_chat_id so forked chats roll up under their root.
|
||||
// Only counts assistant-role messages.
|
||||
GetChatCostPerChat(ctx context.Context, arg GetChatCostPerChatParams) ([]GetChatCostPerChatRow, error)
|
||||
// Per-model cost breakdown for a single user within a date range.
|
||||
// Only counts assistant-role messages that have a model_config_id.
|
||||
GetChatCostPerModel(ctx context.Context, arg GetChatCostPerModelParams) ([]GetChatCostPerModelRow, error)
|
||||
// Deployment-wide per-user cost rollup within a date range.
|
||||
// Only counts assistant-role messages.
|
||||
GetChatCostPerUser(ctx context.Context, arg GetChatCostPerUserParams) ([]GetChatCostPerUserRow, error)
|
||||
// Aggregate cost summary for a single user within a date range.
|
||||
// Only counts assistant-role messages.
|
||||
GetChatCostSummary(ctx context.Context, arg GetChatCostSummaryParams) (GetChatCostSummaryRow, error)
|
||||
GetChatDiffStatusByChatID(ctx context.Context, chatID uuid.UUID) (ChatDiffStatus, error)
|
||||
GetChatDiffStatusesByChatIDs(ctx context.Context, chatIds []uuid.UUID) ([]ChatDiffStatus, error)
|
||||
GetChatFileByID(ctx context.Context, id uuid.UUID) (ChatFile, error)
|
||||
|
||||
@@ -3229,6 +3229,353 @@ func (q *sqlQuerier) GetChatByIDForUpdate(ctx context.Context, id uuid.UUID) (Ch
|
||||
return i, err
|
||||
}
|
||||
|
||||
// getChatCostPerChat is the SQL text executed by GetChatCostPerChat.
//
// The chat_costs CTE aggregates assistant-role messages per root chat
// (COALESCE(root_chat_id, id), so a chat without a root groups under itself)
// for the owner bound to $1, restricted to created_at in [$2, $3). The outer
// SELECT attaches the root chat's title (empty string when the LEFT JOIN finds
// no row) and sorts by total cost, highest first.
const getChatCostPerChat = `-- name: GetChatCostPerChat :many
WITH chat_costs AS (
	SELECT
		COALESCE(c.root_chat_id, c.id) AS root_chat_id,
		COALESCE(SUM(cm.total_cost_micros), 0)::bigint AS total_cost_micros,
		COUNT(*) FILTER (
			WHERE cm.input_tokens IS NOT NULL
				OR cm.output_tokens IS NOT NULL
				OR cm.reasoning_tokens IS NOT NULL
				OR cm.cache_creation_tokens IS NOT NULL
				OR cm.cache_read_tokens IS NOT NULL
		)::bigint AS message_count,
		COALESCE(SUM(cm.input_tokens), 0)::bigint AS total_input_tokens,
		COALESCE(SUM(cm.output_tokens), 0)::bigint AS total_output_tokens
	FROM chat_messages cm
	JOIN chats c ON c.id = cm.chat_id
	WHERE c.owner_id = $1::uuid
		AND cm.role = 'assistant'
		AND cm.created_at >= $2::timestamptz
		AND cm.created_at < $3::timestamptz
	GROUP BY COALESCE(c.root_chat_id, c.id)
)
SELECT
	cc.root_chat_id,
	COALESCE(rc.title, '') AS chat_title,
	cc.total_cost_micros,
	cc.message_count,
	cc.total_input_tokens,
	cc.total_output_tokens
FROM chat_costs cc
LEFT JOIN chats rc ON rc.id = cc.root_chat_id
ORDER BY cc.total_cost_micros DESC
`
|
||||
|
||||
type GetChatCostPerChatParams struct {
|
||||
OwnerID uuid.UUID `db:"owner_id" json:"owner_id"`
|
||||
StartDate time.Time `db:"start_date" json:"start_date"`
|
||||
EndDate time.Time `db:"end_date" json:"end_date"`
|
||||
}
|
||||
|
||||
type GetChatCostPerChatRow struct {
|
||||
RootChatID uuid.UUID `db:"root_chat_id" json:"root_chat_id"`
|
||||
ChatTitle string `db:"chat_title" json:"chat_title"`
|
||||
TotalCostMicros int64 `db:"total_cost_micros" json:"total_cost_micros"`
|
||||
MessageCount int64 `db:"message_count" json:"message_count"`
|
||||
TotalInputTokens int64 `db:"total_input_tokens" json:"total_input_tokens"`
|
||||
TotalOutputTokens int64 `db:"total_output_tokens" json:"total_output_tokens"`
|
||||
}
|
||||
|
||||
// Per-root-chat cost breakdown for a single user within a date range.
|
||||
// Groups by root_chat_id so forked chats roll up under their root.
|
||||
// Only counts assistant-role messages.
|
||||
func (q *sqlQuerier) GetChatCostPerChat(ctx context.Context, arg GetChatCostPerChatParams) ([]GetChatCostPerChatRow, error) {
|
||||
rows, err := q.db.QueryContext(ctx, getChatCostPerChat, arg.OwnerID, arg.StartDate, arg.EndDate)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []GetChatCostPerChatRow
|
||||
for rows.Next() {
|
||||
var i GetChatCostPerChatRow
|
||||
if err := rows.Scan(
|
||||
&i.RootChatID,
|
||||
&i.ChatTitle,
|
||||
&i.TotalCostMicros,
|
||||
&i.MessageCount,
|
||||
&i.TotalInputTokens,
|
||||
&i.TotalOutputTokens,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
// getChatCostPerModel is the SQL text executed by GetChatCostPerModel.
//
// It aggregates assistant-role messages per chat model config for the owner
// bound to $1, restricted to created_at in [$2, $3). The inner join on
// chat_model_configs drops messages whose model_config_id matches no config
// row. Rows are sorted by total cost, highest first.
const getChatCostPerModel = `-- name: GetChatCostPerModel :many
SELECT
	cmc.id AS model_config_id,
	cmc.display_name,
	cmc.provider,
	cmc.model,
	COALESCE(SUM(cm.total_cost_micros), 0)::bigint AS total_cost_micros,
	COUNT(*) FILTER (
		WHERE cm.input_tokens IS NOT NULL
			OR cm.output_tokens IS NOT NULL
			OR cm.reasoning_tokens IS NOT NULL
			OR cm.cache_creation_tokens IS NOT NULL
			OR cm.cache_read_tokens IS NOT NULL
	)::bigint AS message_count,
	COALESCE(SUM(cm.input_tokens), 0)::bigint AS total_input_tokens,
	COALESCE(SUM(cm.output_tokens), 0)::bigint AS total_output_tokens
FROM
	chat_messages cm
JOIN
	chats c ON c.id = cm.chat_id
JOIN
	chat_model_configs cmc ON cmc.id = cm.model_config_id
WHERE
	c.owner_id = $1::uuid
	AND cm.role = 'assistant'
	AND cm.created_at >= $2::timestamptz
	AND cm.created_at < $3::timestamptz
GROUP BY
	cmc.id, cmc.display_name, cmc.provider, cmc.model
ORDER BY
	total_cost_micros DESC
`
|
||||
|
||||
type GetChatCostPerModelParams struct {
|
||||
OwnerID uuid.UUID `db:"owner_id" json:"owner_id"`
|
||||
StartDate time.Time `db:"start_date" json:"start_date"`
|
||||
EndDate time.Time `db:"end_date" json:"end_date"`
|
||||
}
|
||||
|
||||
type GetChatCostPerModelRow struct {
|
||||
ModelConfigID uuid.UUID `db:"model_config_id" json:"model_config_id"`
|
||||
DisplayName string `db:"display_name" json:"display_name"`
|
||||
Provider string `db:"provider" json:"provider"`
|
||||
Model string `db:"model" json:"model"`
|
||||
TotalCostMicros int64 `db:"total_cost_micros" json:"total_cost_micros"`
|
||||
MessageCount int64 `db:"message_count" json:"message_count"`
|
||||
TotalInputTokens int64 `db:"total_input_tokens" json:"total_input_tokens"`
|
||||
TotalOutputTokens int64 `db:"total_output_tokens" json:"total_output_tokens"`
|
||||
}
|
||||
|
||||
// Per-model cost breakdown for a single user within a date range.
|
||||
// Only counts assistant-role messages that have a model_config_id.
|
||||
func (q *sqlQuerier) GetChatCostPerModel(ctx context.Context, arg GetChatCostPerModelParams) ([]GetChatCostPerModelRow, error) {
|
||||
rows, err := q.db.QueryContext(ctx, getChatCostPerModel, arg.OwnerID, arg.StartDate, arg.EndDate)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []GetChatCostPerModelRow
|
||||
for rows.Next() {
|
||||
var i GetChatCostPerModelRow
|
||||
if err := rows.Scan(
|
||||
&i.ModelConfigID,
|
||||
&i.DisplayName,
|
||||
&i.Provider,
|
||||
&i.Model,
|
||||
&i.TotalCostMicros,
|
||||
&i.MessageCount,
|
||||
&i.TotalInputTokens,
|
||||
&i.TotalOutputTokens,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
// getChatCostPerUser is the SQL text executed by GetChatCostPerUser.
//
// The chat_cost_users CTE aggregates assistant-role messages per chat owner
// for created_at in [$3, $4). When $5 is non-empty it keeps only users whose
// username contains $5 (case-insensitive ILIKE). The outer SELECT adds
// total_count (the number of aggregated users before paging), sorts by total
// cost descending and then username ascending, and pages with LIMIT $2
// OFFSET $1.
const getChatCostPerUser = `-- name: GetChatCostPerUser :many
WITH chat_cost_users AS (
	SELECT
		c.owner_id AS user_id,
		u.username,
		u.name,
		u.avatar_url,
		COALESCE(SUM(cm.total_cost_micros), 0)::bigint AS total_cost_micros,
		COUNT(*) FILTER (
			WHERE cm.input_tokens IS NOT NULL
				OR cm.output_tokens IS NOT NULL
				OR cm.reasoning_tokens IS NOT NULL
				OR cm.cache_creation_tokens IS NOT NULL
				OR cm.cache_read_tokens IS NOT NULL
		)::bigint AS message_count,
		COUNT(DISTINCT COALESCE(c.root_chat_id, c.id))::bigint AS chat_count,
		COALESCE(SUM(cm.input_tokens), 0)::bigint AS total_input_tokens,
		COALESCE(SUM(cm.output_tokens), 0)::bigint AS total_output_tokens
	FROM
		chat_messages cm
	JOIN
		chats c ON c.id = cm.chat_id
	JOIN
		users u ON u.id = c.owner_id
	WHERE
		cm.role = 'assistant'
		AND cm.created_at >= $3::timestamptz
		AND cm.created_at < $4::timestamptz
		AND (
			$5::text = ''
			OR u.username ILIKE '%' || $5::text || '%'
		)
	GROUP BY
		c.owner_id,
		u.username,
		u.name,
		u.avatar_url
)
SELECT
	user_id,
	username,
	name,
	avatar_url,
	total_cost_micros,
	message_count,
	chat_count,
	total_input_tokens,
	total_output_tokens,
	COUNT(*) OVER()::bigint AS total_count
FROM
	chat_cost_users
ORDER BY
	total_cost_micros DESC,
	username ASC
LIMIT
	$2::int
OFFSET
	$1::int
`
|
||||
|
||||
// GetChatCostPerUserParams carries the bind arguments for GetChatCostPerUser.
type GetChatCostPerUserParams struct {
	// PageOffset is the OFFSET applied to the sorted per-user rows.
	PageOffset int32 `db:"page_offset" json:"page_offset"`
	// PageLimit is the LIMIT applied to the sorted per-user rows.
	PageLimit int32 `db:"page_limit" json:"page_limit"`
	// StartDate is the inclusive lower bound on chat_messages.created_at.
	StartDate time.Time `db:"start_date" json:"start_date"`
	// EndDate is the exclusive upper bound on chat_messages.created_at.
	EndDate time.Time `db:"end_date" json:"end_date"`
	// Username is a case-insensitive substring filter on users.username; the
	// empty string disables the filter.
	Username string `db:"username" json:"username"`
}
|
||||
|
||||
type GetChatCostPerUserRow struct {
|
||||
UserID uuid.UUID `db:"user_id" json:"user_id"`
|
||||
Username string `db:"username" json:"username"`
|
||||
Name string `db:"name" json:"name"`
|
||||
AvatarURL string `db:"avatar_url" json:"avatar_url"`
|
||||
TotalCostMicros int64 `db:"total_cost_micros" json:"total_cost_micros"`
|
||||
MessageCount int64 `db:"message_count" json:"message_count"`
|
||||
ChatCount int64 `db:"chat_count" json:"chat_count"`
|
||||
TotalInputTokens int64 `db:"total_input_tokens" json:"total_input_tokens"`
|
||||
TotalOutputTokens int64 `db:"total_output_tokens" json:"total_output_tokens"`
|
||||
TotalCount int64 `db:"total_count" json:"total_count"`
|
||||
}
|
||||
|
||||
// Deployment-wide per-user cost rollup within a date range.
|
||||
// Only counts assistant-role messages.
|
||||
func (q *sqlQuerier) GetChatCostPerUser(ctx context.Context, arg GetChatCostPerUserParams) ([]GetChatCostPerUserRow, error) {
|
||||
rows, err := q.db.QueryContext(ctx, getChatCostPerUser,
|
||||
arg.PageOffset,
|
||||
arg.PageLimit,
|
||||
arg.StartDate,
|
||||
arg.EndDate,
|
||||
arg.Username,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []GetChatCostPerUserRow
|
||||
for rows.Next() {
|
||||
var i GetChatCostPerUserRow
|
||||
if err := rows.Scan(
|
||||
&i.UserID,
|
||||
&i.Username,
|
||||
&i.Name,
|
||||
&i.AvatarURL,
|
||||
&i.TotalCostMicros,
|
||||
&i.MessageCount,
|
||||
&i.ChatCount,
|
||||
&i.TotalInputTokens,
|
||||
&i.TotalOutputTokens,
|
||||
&i.TotalCount,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
// getChatCostSummary is the SQL text executed by GetChatCostSummary.
//
// It produces a single aggregate row over assistant-role messages of the owner
// bound to $1 with created_at in [$2, $3). A message is "priced" when
// total_cost_micros is non-NULL, and "unpriced" when it is NULL while at least
// one token column is non-NULL.
const getChatCostSummary = `-- name: GetChatCostSummary :one
SELECT
	COALESCE(SUM(cm.total_cost_micros), 0)::bigint AS total_cost_micros,
	COUNT(*) FILTER (
		WHERE cm.total_cost_micros IS NOT NULL
	)::bigint AS priced_message_count,
	COUNT(*) FILTER (
		WHERE cm.total_cost_micros IS NULL
			AND (
				cm.input_tokens IS NOT NULL
				OR cm.output_tokens IS NOT NULL
				OR cm.reasoning_tokens IS NOT NULL
				OR cm.cache_creation_tokens IS NOT NULL
				OR cm.cache_read_tokens IS NOT NULL
			)
	)::bigint AS unpriced_message_count,
	COALESCE(SUM(cm.input_tokens), 0)::bigint AS total_input_tokens,
	COALESCE(SUM(cm.output_tokens), 0)::bigint AS total_output_tokens
FROM
	chat_messages cm
JOIN
	chats c ON c.id = cm.chat_id
WHERE
	c.owner_id = $1::uuid
	AND cm.role = 'assistant'
	AND cm.created_at >= $2::timestamptz
	AND cm.created_at < $3::timestamptz
`
|
||||
|
||||
type GetChatCostSummaryParams struct {
|
||||
OwnerID uuid.UUID `db:"owner_id" json:"owner_id"`
|
||||
StartDate time.Time `db:"start_date" json:"start_date"`
|
||||
EndDate time.Time `db:"end_date" json:"end_date"`
|
||||
}
|
||||
|
||||
// GetChatCostSummaryRow is the single aggregate row returned by
// GetChatCostSummary.
type GetChatCostSummaryRow struct {
	// TotalCostMicros is the summed total_cost_micros (0 when none is set).
	TotalCostMicros int64 `db:"total_cost_micros" json:"total_cost_micros"`
	// PricedMessageCount counts messages whose total_cost_micros is non-NULL.
	PricedMessageCount int64 `db:"priced_message_count" json:"priced_message_count"`
	// UnpricedMessageCount counts messages with a NULL total_cost_micros but
	// at least one non-NULL token column.
	UnpricedMessageCount int64 `db:"unpriced_message_count" json:"unpriced_message_count"`
	// TotalInputTokens is the summed input_tokens (0 when none is set).
	TotalInputTokens int64 `db:"total_input_tokens" json:"total_input_tokens"`
	// TotalOutputTokens is the summed output_tokens (0 when none is set).
	TotalOutputTokens int64 `db:"total_output_tokens" json:"total_output_tokens"`
}
|
||||
|
||||
// Aggregate cost summary for a single user within a date range.
|
||||
// Only counts assistant-role messages.
|
||||
func (q *sqlQuerier) GetChatCostSummary(ctx context.Context, arg GetChatCostSummaryParams) (GetChatCostSummaryRow, error) {
|
||||
row := q.db.QueryRowContext(ctx, getChatCostSummary, arg.OwnerID, arg.StartDate, arg.EndDate)
|
||||
var i GetChatCostSummaryRow
|
||||
err := row.Scan(
|
||||
&i.TotalCostMicros,
|
||||
&i.PricedMessageCount,
|
||||
&i.UnpricedMessageCount,
|
||||
&i.TotalInputTokens,
|
||||
&i.TotalOutputTokens,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getChatDiffStatusByChatID = `-- name: GetChatDiffStatusByChatID :one
|
||||
SELECT
|
||||
chat_id, url, pull_request_state, changes_requested, additions, deletions, changed_files, refreshed_at, stale_at, created_at, updated_at, git_branch, git_remote_origin, pull_request_title, pull_request_draft
|
||||
@@ -3311,7 +3658,7 @@ func (q *sqlQuerier) GetChatDiffStatusesByChatIDs(ctx context.Context, chatIds [
|
||||
|
||||
const getChatMessageByID = `-- name: GetChatMessageByID :one
|
||||
SELECT
|
||||
id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version
|
||||
id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros
|
||||
FROM
|
||||
chat_messages
|
||||
WHERE
|
||||
@@ -3339,13 +3686,14 @@ func (q *sqlQuerier) GetChatMessageByID(ctx context.Context, id int64) (ChatMess
|
||||
&i.Compressed,
|
||||
&i.CreatedBy,
|
||||
&i.ContentVersion,
|
||||
&i.TotalCostMicros,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getChatMessagesByChatID = `-- name: GetChatMessagesByChatID :many
|
||||
SELECT
|
||||
id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version
|
||||
id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros
|
||||
FROM
|
||||
chat_messages
|
||||
WHERE
|
||||
@@ -3388,6 +3736,7 @@ func (q *sqlQuerier) GetChatMessagesByChatID(ctx context.Context, arg GetChatMes
|
||||
&i.Compressed,
|
||||
&i.CreatedBy,
|
||||
&i.ContentVersion,
|
||||
&i.TotalCostMicros,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -3419,7 +3768,7 @@ WITH latest_compressed_summary AS (
|
||||
1
|
||||
)
|
||||
SELECT
|
||||
id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version
|
||||
id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros
|
||||
FROM
|
||||
chat_messages
|
||||
WHERE
|
||||
@@ -3486,6 +3835,7 @@ func (q *sqlQuerier) GetChatMessagesForPromptByChatID(ctx context.Context, chatI
|
||||
&i.Compressed,
|
||||
&i.CreatedBy,
|
||||
&i.ContentVersion,
|
||||
&i.TotalCostMicros,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -3629,7 +3979,7 @@ func (q *sqlQuerier) GetChatsByOwnerID(ctx context.Context, arg GetChatsByOwnerI
|
||||
|
||||
const getLastChatMessageByRole = `-- name: GetLastChatMessageByRole :one
|
||||
SELECT
|
||||
id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version
|
||||
id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros
|
||||
FROM
|
||||
chat_messages
|
||||
WHERE
|
||||
@@ -3667,6 +4017,7 @@ func (q *sqlQuerier) GetLastChatMessageByRole(ctx context.Context, arg GetLastCh
|
||||
&i.Compressed,
|
||||
&i.CreatedBy,
|
||||
&i.ContentVersion,
|
||||
&i.TotalCostMicros,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
@@ -3806,7 +4157,8 @@ INSERT INTO chat_messages (
|
||||
cache_creation_tokens,
|
||||
cache_read_tokens,
|
||||
context_limit,
|
||||
compressed
|
||||
compressed,
|
||||
total_cost_micros
|
||||
) VALUES (
|
||||
$1::uuid,
|
||||
$2::uuid,
|
||||
@@ -3822,10 +4174,11 @@ INSERT INTO chat_messages (
|
||||
$12::bigint,
|
||||
$13::bigint,
|
||||
$14::bigint,
|
||||
COALESCE($15::boolean, FALSE)
|
||||
COALESCE($15::boolean, FALSE),
|
||||
$16::bigint
|
||||
)
|
||||
RETURNING
|
||||
id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version
|
||||
id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros
|
||||
`
|
||||
|
||||
type InsertChatMessageParams struct {
|
||||
@@ -3844,6 +4197,7 @@ type InsertChatMessageParams struct {
|
||||
CacheReadTokens sql.NullInt64 `db:"cache_read_tokens" json:"cache_read_tokens"`
|
||||
ContextLimit sql.NullInt64 `db:"context_limit" json:"context_limit"`
|
||||
Compressed sql.NullBool `db:"compressed" json:"compressed"`
|
||||
TotalCostMicros sql.NullInt64 `db:"total_cost_micros" json:"total_cost_micros"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) InsertChatMessage(ctx context.Context, arg InsertChatMessageParams) (ChatMessage, error) {
|
||||
@@ -3863,6 +4217,7 @@ func (q *sqlQuerier) InsertChatMessage(ctx context.Context, arg InsertChatMessag
|
||||
arg.CacheReadTokens,
|
||||
arg.ContextLimit,
|
||||
arg.Compressed,
|
||||
arg.TotalCostMicros,
|
||||
)
|
||||
var i ChatMessage
|
||||
err := row.Scan(
|
||||
@@ -3883,6 +4238,7 @@ func (q *sqlQuerier) InsertChatMessage(ctx context.Context, arg InsertChatMessag
|
||||
&i.Compressed,
|
||||
&i.CreatedBy,
|
||||
&i.ContentVersion,
|
||||
&i.TotalCostMicros,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
@@ -4017,7 +4373,7 @@ SET
|
||||
WHERE
|
||||
id = $3::bigint
|
||||
RETURNING
|
||||
id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version
|
||||
id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros
|
||||
`
|
||||
|
||||
type UpdateChatMessageByIDParams struct {
|
||||
@@ -4047,6 +4403,7 @@ func (q *sqlQuerier) UpdateChatMessageByID(ctx context.Context, arg UpdateChatMe
|
||||
&i.Compressed,
|
||||
&i.CreatedBy,
|
||||
&i.ContentVersion,
|
||||
&i.TotalCostMicros,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
@@ -179,7 +179,8 @@ INSERT INTO chat_messages (
|
||||
cache_creation_tokens,
|
||||
cache_read_tokens,
|
||||
context_limit,
|
||||
compressed
|
||||
compressed,
|
||||
total_cost_micros
|
||||
) VALUES (
|
||||
@chat_id::uuid,
|
||||
sqlc.narg('created_by')::uuid,
|
||||
@@ -195,7 +196,8 @@ INSERT INTO chat_messages (
|
||||
sqlc.narg('cache_creation_tokens')::bigint,
|
||||
sqlc.narg('cache_read_tokens')::bigint,
|
||||
sqlc.narg('context_limit')::bigint,
|
||||
COALESCE(sqlc.narg('compressed')::boolean, FALSE)
|
||||
COALESCE(sqlc.narg('compressed')::boolean, FALSE),
|
||||
sqlc.narg('total_cost_micros')::bigint
|
||||
)
|
||||
RETURNING
|
||||
*;
|
||||
@@ -481,3 +483,164 @@ SET
|
||||
updated_at = NOW()
|
||||
WHERE
|
||||
chat_id = @chat_id::uuid;
|
||||
|
||||
-- name: GetChatCostSummary :one
-- Aggregate cost summary for a single user within a date range.
-- Only counts assistant-role messages.
SELECT
	COALESCE(SUM(cm.total_cost_micros), 0)::bigint AS total_cost_micros,
	COUNT(*) FILTER (
		WHERE cm.total_cost_micros IS NOT NULL
	)::bigint AS priced_message_count,
	COUNT(*) FILTER (
		WHERE cm.total_cost_micros IS NULL
			AND (
				cm.input_tokens IS NOT NULL
				OR cm.output_tokens IS NOT NULL
				OR cm.reasoning_tokens IS NOT NULL
				OR cm.cache_creation_tokens IS NOT NULL
				OR cm.cache_read_tokens IS NOT NULL
			)
	)::bigint AS unpriced_message_count,
	COALESCE(SUM(cm.input_tokens), 0)::bigint AS total_input_tokens,
	COALESCE(SUM(cm.output_tokens), 0)::bigint AS total_output_tokens
FROM
	chat_messages cm
JOIN
	chats c ON c.id = cm.chat_id
WHERE
	c.owner_id = @owner_id::uuid
	AND cm.role = 'assistant'
	AND cm.created_at >= @start_date::timestamptz
	AND cm.created_at < @end_date::timestamptz;
|
||||
|
||||
-- name: GetChatCostPerModel :many
-- Per-model cost breakdown for a single user within a date range.
-- Only counts assistant-role messages that have a model_config_id.
SELECT
	cmc.id AS model_config_id,
	cmc.display_name,
	cmc.provider,
	cmc.model,
	COALESCE(SUM(cm.total_cost_micros), 0)::bigint AS total_cost_micros,
	COUNT(*) FILTER (
		WHERE cm.input_tokens IS NOT NULL
			OR cm.output_tokens IS NOT NULL
			OR cm.reasoning_tokens IS NOT NULL
			OR cm.cache_creation_tokens IS NOT NULL
			OR cm.cache_read_tokens IS NOT NULL
	)::bigint AS message_count,
	COALESCE(SUM(cm.input_tokens), 0)::bigint AS total_input_tokens,
	COALESCE(SUM(cm.output_tokens), 0)::bigint AS total_output_tokens
FROM
	chat_messages cm
JOIN
	chats c ON c.id = cm.chat_id
JOIN
	chat_model_configs cmc ON cmc.id = cm.model_config_id
WHERE
	c.owner_id = @owner_id::uuid
	AND cm.role = 'assistant'
	AND cm.created_at >= @start_date::timestamptz
	AND cm.created_at < @end_date::timestamptz
GROUP BY
	cmc.id, cmc.display_name, cmc.provider, cmc.model
ORDER BY
	total_cost_micros DESC;
|
||||
|
||||
-- name: GetChatCostPerChat :many
-- Per-root-chat cost breakdown for a single user within a date range.
-- Groups by root_chat_id so forked chats roll up under their root.
-- Only counts assistant-role messages.
WITH chat_costs AS (
	SELECT
		COALESCE(c.root_chat_id, c.id) AS root_chat_id,
		COALESCE(SUM(cm.total_cost_micros), 0)::bigint AS total_cost_micros,
		COUNT(*) FILTER (
			WHERE cm.input_tokens IS NOT NULL
				OR cm.output_tokens IS NOT NULL
				OR cm.reasoning_tokens IS NOT NULL
				OR cm.cache_creation_tokens IS NOT NULL
				OR cm.cache_read_tokens IS NOT NULL
		)::bigint AS message_count,
		COALESCE(SUM(cm.input_tokens), 0)::bigint AS total_input_tokens,
		COALESCE(SUM(cm.output_tokens), 0)::bigint AS total_output_tokens
	FROM chat_messages cm
	JOIN chats c ON c.id = cm.chat_id
	WHERE c.owner_id = @owner_id::uuid
		AND cm.role = 'assistant'
		AND cm.created_at >= @start_date::timestamptz
		AND cm.created_at < @end_date::timestamptz
	GROUP BY COALESCE(c.root_chat_id, c.id)
)
SELECT
	cc.root_chat_id,
	COALESCE(rc.title, '') AS chat_title,
	cc.total_cost_micros,
	cc.message_count,
	cc.total_input_tokens,
	cc.total_output_tokens
FROM chat_costs cc
LEFT JOIN chats rc ON rc.id = cc.root_chat_id
ORDER BY cc.total_cost_micros DESC;
|
||||
|
||||
-- name: GetChatCostPerUser :many
-- Deployment-wide per-user cost rollup within a date range.
-- Only counts assistant-role messages.
WITH chat_cost_users AS (
	SELECT
		c.owner_id AS user_id,
		u.username,
		u.name,
		u.avatar_url,
		COALESCE(SUM(cm.total_cost_micros), 0)::bigint AS total_cost_micros,
		COUNT(*) FILTER (
			WHERE cm.input_tokens IS NOT NULL
				OR cm.output_tokens IS NOT NULL
				OR cm.reasoning_tokens IS NOT NULL
				OR cm.cache_creation_tokens IS NOT NULL
				OR cm.cache_read_tokens IS NOT NULL
		)::bigint AS message_count,
		COUNT(DISTINCT COALESCE(c.root_chat_id, c.id))::bigint AS chat_count,
		COALESCE(SUM(cm.input_tokens), 0)::bigint AS total_input_tokens,
		COALESCE(SUM(cm.output_tokens), 0)::bigint AS total_output_tokens
	FROM
		chat_messages cm
	JOIN
		chats c ON c.id = cm.chat_id
	JOIN
		users u ON u.id = c.owner_id
	WHERE
		cm.role = 'assistant'
		AND cm.created_at >= @start_date::timestamptz
		AND cm.created_at < @end_date::timestamptz
		AND (
			@username::text = ''
			OR u.username ILIKE '%' || @username::text || '%'
		)
	GROUP BY
		c.owner_id,
		u.username,
		u.name,
		u.avatar_url
)
SELECT
	user_id,
	username,
	name,
	avatar_url,
	total_cost_micros,
	message_count,
	chat_count,
	total_input_tokens,
	total_output_tokens,
	COUNT(*) OVER()::bigint AS total_count
FROM
	chat_cost_users
ORDER BY
	total_cost_micros DESC,
	username ASC
LIMIT
	sqlc.arg('page_limit')::int
OFFSET
	sqlc.arg('page_offset')::int;
|
||||
|
||||
@@ -148,6 +148,17 @@ sql:
|
||||
go_type: "database/sql.NullTime"
|
||||
- column: "task_event_data.first_status_after_resume_at"
|
||||
go_type: "database/sql.NullTime"
|
||||
- db_type: "pg_catalog.numeric"
|
||||
go_type:
|
||||
import: "github.com/shopspring/decimal"
|
||||
type: "Decimal"
|
||||
package: "decimal"
|
||||
- db_type: "pg_catalog.numeric"
|
||||
nullable: true
|
||||
go_type:
|
||||
import: "github.com/shopspring/decimal"
|
||||
type: "NullDecimal"
|
||||
package: "decimal"
|
||||
rename:
|
||||
group_member: GroupMemberTable
|
||||
group_members_expanded: GroupMember
|
||||
|
||||
+148
-13
@@ -7,10 +7,13 @@ import (
|
||||
"io"
|
||||
"mime"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/shopspring/decimal"
|
||||
|
||||
"github.com/coder/websocket"
|
||||
"github.com/coder/websocket/wsjson"
|
||||
@@ -198,11 +201,11 @@ func ChatMessageFileReference(fileName string, startLine, endLine int, content s
|
||||
}
|
||||
|
||||
// ChatMessageSource builds a source chat message part.
|
||||
func ChatMessageSource(sourceID, url, title string) ChatMessagePart {
|
||||
func ChatMessageSource(sourceID, sourceURL, title string) ChatMessagePart {
|
||||
return ChatMessagePart{
|
||||
Type: ChatMessagePartTypeSource,
|
||||
SourceID: sourceID,
|
||||
URL: url,
|
||||
URL: sourceURL,
|
||||
Title: title,
|
||||
}
|
||||
}
|
||||
@@ -519,13 +522,10 @@ type ChatModelVercelProviderOptions struct {
|
||||
|
||||
// ModelCostConfig stores pricing metadata for a chat model.
|
||||
type ModelCostConfig struct {
|
||||
// Pricing is stored as configuration metadata and currently only needs to
|
||||
// round-trip cleanly through the API and admin UI. If we later use these
|
||||
// values for billing-grade arithmetic, switch to a fixed-point type.
|
||||
InputPricePerMillionTokens *float64 `json:"input_price_per_million_tokens,omitempty" description:"Input token price in USD per 1M tokens"`
|
||||
OutputPricePerMillionTokens *float64 `json:"output_price_per_million_tokens,omitempty" description:"Output token price in USD per 1M tokens"`
|
||||
CacheReadPricePerMillionTokens *float64 `json:"cache_read_price_per_million_tokens,omitempty" description:"Cache read token price in USD per 1M tokens"`
|
||||
CacheWritePricePerMillionTokens *float64 `json:"cache_write_price_per_million_tokens,omitempty" description:"Cache write or cache creation token price in USD per 1M tokens"`
|
||||
InputPricePerMillionTokens *decimal.Decimal `json:"input_price_per_million_tokens,omitempty" description:"Input token price in USD per 1M tokens"`
|
||||
OutputPricePerMillionTokens *decimal.Decimal `json:"output_price_per_million_tokens,omitempty" description:"Output token price in USD per 1M tokens"`
|
||||
CacheReadPricePerMillionTokens *decimal.Decimal `json:"cache_read_price_per_million_tokens,omitempty" description:"Cache read token price in USD per 1M tokens"`
|
||||
CacheWritePricePerMillionTokens *decimal.Decimal `json:"cache_write_price_per_million_tokens,omitempty" description:"Cache write or cache creation token price in USD per 1M tokens"`
|
||||
}
|
||||
|
||||
// ChatModelCallConfig configures per-call model behavior defaults.
|
||||
@@ -546,10 +546,10 @@ func (c *ChatModelCallConfig) UnmarshalJSON(data []byte) error {
|
||||
type chatModelCallConfigAlias ChatModelCallConfig
|
||||
aux := struct {
|
||||
*chatModelCallConfigAlias
|
||||
InputPricePerMillionTokens *float64 `json:"input_price_per_million_tokens,omitempty"`
|
||||
OutputPricePerMillionTokens *float64 `json:"output_price_per_million_tokens,omitempty"`
|
||||
CacheReadPricePerMillionTokens *float64 `json:"cache_read_price_per_million_tokens,omitempty"`
|
||||
CacheWritePricePerMillionTokens *float64 `json:"cache_write_price_per_million_tokens,omitempty"`
|
||||
InputPricePerMillionTokens *decimal.Decimal `json:"input_price_per_million_tokens,omitempty"`
|
||||
OutputPricePerMillionTokens *decimal.Decimal `json:"output_price_per_million_tokens,omitempty"`
|
||||
CacheReadPricePerMillionTokens *decimal.Decimal `json:"cache_read_price_per_million_tokens,omitempty"`
|
||||
CacheWritePricePerMillionTokens *decimal.Decimal `json:"cache_write_price_per_million_tokens,omitempty"`
|
||||
}{
|
||||
chatModelCallConfigAlias: (*chatModelCallConfigAlias)(c),
|
||||
}
|
||||
@@ -710,6 +710,76 @@ type chatStreamEnvelope struct {
|
||||
Data json.RawMessage `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
// ChatCostSummaryOptions are optional query parameters for GetChatCostSummary.
// Both fields bound the reporting window. GetChatCostSummary sends each one
// only when it is non-zero, formatted as RFC 3339.
|
||||
type ChatCostSummaryOptions struct {
|
||||
// StartDate is sent as the "start_date" query parameter; the zero value
// omits it so the server picks its own default.
StartDate time.Time
|
||||
// EndDate is sent as the "end_date" query parameter; the zero value
// omits it so the server picks its own default.
EndDate time.Time
|
||||
}
|
||||
|
||||
// ChatCostUsersOptions are optional query parameters for GetChatCostUsers.
// Every field is optional; GetChatCostUsers only adds a query parameter for
// a field that is set to a non-zero value.
|
||||
type ChatCostUsersOptions struct {
|
||||
// StartDate is sent as "start_date" (RFC 3339) when non-zero.
StartDate time.Time
|
||||
// EndDate is sent as "end_date" (RFC 3339) when non-zero.
EndDate time.Time
|
||||
// Username is sent as the "username" query parameter when non-empty.
// NOTE(review): presumably a filter on the returned users; the exact
// matching rules live in the server handler, which is not shown here.
Username string
|
||||
// Pagination supplies Limit and Offset, which GetChatCostUsers forwards
// as "limit" and "offset" only when they are greater than zero. Any other
// Pagination fields are not forwarded by GetChatCostUsers.
Pagination
|
||||
}
|
||||
|
||||
// ChatCostSummary is the response from the chat cost summary endpoint.
// Monetary values are integer microdollars ("micros"), i.e. millionths of a
// US dollar, so they can be summed without floating-point error.
|
||||
type ChatCostSummary struct {
|
||||
// StartDate and EndDate describe the reporting window.
// NOTE(review): presumably the window the server actually applied after
// defaulting omitted request parameters; confirm against the handler.
StartDate time.Time `json:"start_date" format:"date-time"`
|
||||
EndDate time.Time `json:"end_date" format:"date-time"`
|
||||
// TotalCostMicros is the aggregate cost in microdollars.
TotalCostMicros int64 `json:"total_cost_micros"`
|
||||
// PricedMessageCount and UnpricedMessageCount split the counted messages
// by whether a cost is recorded for them.
// NOTE(review): per the commit notes a NULL stored cost means "unpriced"
// (no matching model config) while 0 means "free"; confirm that free
// messages are counted as priced.
PricedMessageCount int64 `json:"priced_message_count"`
|
||||
UnpricedMessageCount int64 `json:"unpriced_message_count"`
|
||||
TotalInputTokens int64 `json:"total_input_tokens"`
|
||||
TotalOutputTokens int64 `json:"total_output_tokens"`
|
||||
// ByModel holds the per-model breakdown of the totals above.
ByModel []ChatCostModelBreakdown `json:"by_model"`
|
||||
// ByChat holds the per-root-chat breakdown of the totals above.
ByChat []ChatCostChatBreakdown `json:"by_chat"`
|
||||
}
|
||||
|
||||
// ChatCostModelBreakdown contains per-model cost aggregation. One value
// describes a single model configuration and the usage attributed to it.
|
||||
type ChatCostModelBreakdown struct {
|
||||
// ModelConfigID identifies the model configuration this row aggregates.
ModelConfigID uuid.UUID `json:"model_config_id" format:"uuid"`
|
||||
DisplayName string `json:"display_name"`
|
||||
Provider string `json:"provider"`
|
||||
Model string `json:"model"`
|
||||
// TotalCostMicros is the cost for this model in microdollars
// (millionths of a US dollar).
TotalCostMicros int64 `json:"total_cost_micros"`
|
||||
MessageCount int64 `json:"message_count"`
|
||||
TotalInputTokens int64 `json:"total_input_tokens"`
|
||||
TotalOutputTokens int64 `json:"total_output_tokens"`
|
||||
}
|
||||
|
||||
// ChatCostChatBreakdown contains per-root-chat cost aggregation. Rows are
// keyed by the root chat.
// NOTE(review): presumably usage from descendant chats is rolled up into
// their root; confirm against the aggregation query.
|
||||
type ChatCostChatBreakdown struct {
|
||||
// RootChatID identifies the root chat this row aggregates.
RootChatID uuid.UUID `json:"root_chat_id" format:"uuid"`
|
||||
ChatTitle string `json:"chat_title"`
|
||||
// TotalCostMicros is the cost for this chat in microdollars
// (millionths of a US dollar).
TotalCostMicros int64 `json:"total_cost_micros"`
|
||||
MessageCount int64 `json:"message_count"`
|
||||
TotalInputTokens int64 `json:"total_input_tokens"`
|
||||
TotalOutputTokens int64 `json:"total_output_tokens"`
|
||||
}
|
||||
|
||||
// ChatCostUserRollup contains per-user cost aggregation for admin views.
// It is the element type of ChatCostUsersResponse.Users.
|
||||
type ChatCostUserRollup struct {
|
||||
// UserID, Username, Name and AvatarURL identify the user the row
// describes.
UserID uuid.UUID `json:"user_id" format:"uuid"`
|
||||
Username string `json:"username"`
|
||||
Name string `json:"name"`
|
||||
AvatarURL string `json:"avatar_url"`
|
||||
// TotalCostMicros is the user's cost in microdollars (millionths of a
// US dollar).
TotalCostMicros int64 `json:"total_cost_micros"`
|
||||
MessageCount int64 `json:"message_count"`
|
||||
// NOTE(review): presumably the number of distinct chats contributing to
// this row; confirm whether it counts root chats or all chats.
ChatCount int64 `json:"chat_count"`
|
||||
TotalInputTokens int64 `json:"total_input_tokens"`
|
||||
TotalOutputTokens int64 `json:"total_output_tokens"`
|
||||
}
|
||||
|
||||
// ChatCostUsersResponse is the response from the admin chat cost users endpoint.
// It is the value decoded by GetChatCostUsers.
|
||||
type ChatCostUsersResponse struct {
|
||||
// StartDate and EndDate describe the reporting window.
// NOTE(review): presumably the window the server actually applied after
// defaulting omitted request parameters; confirm against the handler.
StartDate time.Time `json:"start_date" format:"date-time"`
|
||||
EndDate time.Time `json:"end_date" format:"date-time"`
|
||||
// NOTE(review): presumably the total number of matching users before
// limit/offset are applied, not len(Users); confirm against the handler.
Count int64 `json:"count"`
|
||||
// Users holds one rollup row per user.
Users []ChatCostUserRollup `json:"users"`
|
||||
}
|
||||
|
||||
// ListChatsOptions are optional parameters for ListChats.
|
||||
type ListChatsOptions struct {
|
||||
Query string
|
||||
@@ -872,6 +942,71 @@ func (c *Client) DeleteChatModelConfig(ctx context.Context, modelConfigID uuid.U
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetChatCostSummary returns an aggregate cost summary for the specified
|
||||
// user. Zero-valued StartDate or EndDate fields are omitted from the
|
||||
// request, letting the server apply its own defaults (typically the last
|
||||
// 30 days).
|
||||
func (c *Client) GetChatCostSummary(ctx context.Context, user string, opts ChatCostSummaryOptions) (ChatCostSummary, error) {
|
||||
qp := url.Values{}
|
||||
if !opts.StartDate.IsZero() {
|
||||
qp.Set("start_date", opts.StartDate.Format(time.RFC3339))
|
||||
}
|
||||
if !opts.EndDate.IsZero() {
|
||||
qp.Set("end_date", opts.EndDate.Format(time.RFC3339))
|
||||
}
|
||||
reqURL := fmt.Sprintf("/api/experimental/chats/cost/%s/summary", user)
|
||||
if len(qp) > 0 {
|
||||
reqURL += "?" + qp.Encode()
|
||||
}
|
||||
res, err := c.Request(ctx, http.MethodGet, reqURL, nil)
|
||||
if err != nil {
|
||||
return ChatCostSummary{}, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != http.StatusOK {
|
||||
return ChatCostSummary{}, ReadBodyAsError(res)
|
||||
}
|
||||
var summary ChatCostSummary
|
||||
return summary, json.NewDecoder(res.Body).Decode(&summary)
|
||||
}
|
||||
|
||||
// GetChatCostUsers returns a per-user cost rollup for the deployment
|
||||
// (admin only). Zero-valued StartDate or EndDate fields are omitted from
|
||||
// the request, letting the server apply its own defaults (typically the
|
||||
// last 30 days).
|
||||
func (c *Client) GetChatCostUsers(ctx context.Context, opts ChatCostUsersOptions) (ChatCostUsersResponse, error) {
|
||||
qp := url.Values{}
|
||||
if !opts.StartDate.IsZero() {
|
||||
qp.Set("start_date", opts.StartDate.Format(time.RFC3339))
|
||||
}
|
||||
if !opts.EndDate.IsZero() {
|
||||
qp.Set("end_date", opts.EndDate.Format(time.RFC3339))
|
||||
}
|
||||
if opts.Username != "" {
|
||||
qp.Set("username", opts.Username)
|
||||
}
|
||||
if opts.Limit > 0 {
|
||||
qp.Set("limit", strconv.Itoa(opts.Limit))
|
||||
}
|
||||
if opts.Offset > 0 {
|
||||
qp.Set("offset", strconv.Itoa(opts.Offset))
|
||||
}
|
||||
reqURL := "/api/experimental/chats/cost/users"
|
||||
if len(qp) > 0 {
|
||||
reqURL += "?" + qp.Encode()
|
||||
}
|
||||
res, err := c.Request(ctx, http.MethodGet, reqURL, nil)
|
||||
if err != nil {
|
||||
return ChatCostUsersResponse{}, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != http.StatusOK {
|
||||
return ChatCostUsersResponse{}, ReadBodyAsError(res)
|
||||
}
|
||||
var resp ChatCostUsersResponse
|
||||
return resp, json.NewDecoder(res.Body).Decode(&resp)
|
||||
}
|
||||
|
||||
// GetChatSystemPrompt returns the deployment-wide chat system prompt.
|
||||
func (c *Client) GetChatSystemPrompt(ctx context.Context) (ChatSystemPromptResponse, error) {
|
||||
res, err := c.Request(ctx, http.MethodGet, "/api/experimental/chats/config/system-prompt", nil)
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/shopspring/decimal"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
@@ -109,3 +110,65 @@ func TestChatMessagePart_StripInternal(t *testing.T) {
|
||||
assert.Equal(t, codersdk.ChatMessagePartTypeText, part.Type)
|
||||
})
|
||||
}
|
||||
|
||||
func TestModelCostConfig_LegacyNumericJSON(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var decoded codersdk.ModelCostConfig
|
||||
err := json.Unmarshal([]byte("{\"input_price_per_million_tokens\": 1.5}"), &decoded)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, decoded.InputPricePerMillionTokens)
|
||||
require.True(t, decoded.InputPricePerMillionTokens.Equal(decimal.RequireFromString("1.5")))
|
||||
}
|
||||
|
||||
func TestModelCostConfig_QuotedDecimalJSON(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var decoded codersdk.ModelCostConfig
|
||||
err := json.Unmarshal([]byte("{\"input_price_per_million_tokens\": \"1.5\"}"), &decoded)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, decoded.InputPricePerMillionTokens)
|
||||
require.True(t, decoded.InputPricePerMillionTokens.Equal(decimal.RequireFromString("1.5")))
|
||||
}
|
||||
|
||||
func TestModelCostConfig_NilVsZero(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
zero := decimal.Zero
|
||||
raw, err := json.Marshal(struct {
|
||||
Nil codersdk.ModelCostConfig `json:"nil"`
|
||||
Zero codersdk.ModelCostConfig `json:"zero"`
|
||||
}{
|
||||
Nil: codersdk.ModelCostConfig{},
|
||||
Zero: codersdk.ModelCostConfig{InputPricePerMillionTokens: &zero},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, string(raw), "\"zero\":{\"input_price_per_million_tokens\":\"0\"}")
|
||||
require.Contains(t, string(raw), "\"nil\":{}")
|
||||
}
|
||||
|
||||
func TestChatModelCallConfig_UnmarshalLegacyPricing(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var decoded codersdk.ChatModelCallConfig
|
||||
err := json.Unmarshal([]byte("{\"input_price_per_million_tokens\": 1.5}"), &decoded)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, decoded.Cost)
|
||||
require.NotNil(t, decoded.Cost.InputPricePerMillionTokens)
|
||||
require.True(t, decoded.Cost.InputPricePerMillionTokens.Equal(decimal.RequireFromString("1.5")))
|
||||
}
|
||||
|
||||
func TestChatCostSummary_JSONRoundTrip(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
original := codersdk.ChatCostSummary{
|
||||
TotalCostMicros: 123,
|
||||
}
|
||||
raw, err := json.Marshal(original)
|
||||
require.NoError(t, err)
|
||||
|
||||
var decoded codersdk.ChatCostSummary
|
||||
err = json.Unmarshal(raw, &decoded)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, original.TotalCostMicros, decoded.TotalCostMicros)
|
||||
}
|
||||
|
||||
Generated
+1
@@ -0,0 +1 @@
|
||||
# Chat
|
||||
@@ -492,6 +492,7 @@ require (
|
||||
github.com/go-git/go-git/v5 v5.17.0
|
||||
github.com/mark3labs/mcp-go v0.38.0
|
||||
github.com/openai/openai-go/v3 v3.15.0
|
||||
github.com/shopspring/decimal v1.4.0
|
||||
gonum.org/v1/gonum v0.17.0
|
||||
)
|
||||
|
||||
|
||||
@@ -1059,6 +1059,8 @@ github.com/sergi/go-diff v1.4.0 h1:n/SP9D5ad1fORl+llWyN+D6qoUETXNZARKjyY2/KVCw=
|
||||
github.com/sergi/go-diff v1.4.0/go.mod h1:A0bzQcvG0E7Rwjx0REVgAGH58e96+X0MeOfepqsbeW4=
|
||||
github.com/shirou/gopsutil/v4 v4.26.1 h1:TOkEyriIXk2HX9d4isZJtbjXbEjf5qyKPAzbzY0JWSo=
|
||||
github.com/shirou/gopsutil/v4 v4.26.1/go.mod h1:medLI9/UNAb0dOI9Q3/7yWSqKkj00u+1tgY8nvv41pc=
|
||||
github.com/shopspring/decimal v1.4.0 h1:bxl37RwXBklmTi0C79JfXCEBD1cqqHt0bbgBAGFp81k=
|
||||
github.com/shopspring/decimal v1.4.0/go.mod h1:gawqmDU56v4yIKSwfBSFip1HdCCXN8/+DMd9qYNcwME=
|
||||
github.com/sirupsen/logrus v1.7.0/go.mod h1:yWOB1SBYBC5VeMP7gHvWumXLIWorT60ONWic61uBYv0=
|
||||
github.com/sirupsen/logrus v1.9.4 h1:TsZE7l11zFCLZnZ+teH4Umoq5BhEIfIzfRDZ1Uzql2w=
|
||||
github.com/sirupsen/logrus v1.9.4/go.mod h1:ftWc9WdOfJ0a92nsE2jF5u5ZwH8Bv2zdeOC42RjbV2g=
|
||||
|
||||
@@ -130,6 +130,10 @@ func TypeMappings(gen *guts.GoParser) error {
|
||||
"github.com/coder/serpent.URL": "string",
|
||||
"github.com/coder/serpent.HostPort": "string",
|
||||
"encoding/json.RawMessage": "map[string]string",
|
||||
// decimal.Decimal preserves exact pricing precision (e.g. $3.50 per
|
||||
// million tokens) and serializes as a JSON string to avoid
|
||||
// floating-point loss in transit.
|
||||
"github.com/shopspring/decimal.Decimal": "string",
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("include custom: %w", err)
|
||||
|
||||
@@ -7,6 +7,8 @@ import (
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/shopspring/decimal"
|
||||
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
@@ -117,11 +119,15 @@ func extractFields(t reflect.Type, prefix string, skip map[string]bool) FieldGro
|
||||
// so that entire sub-objects can be marked hidden.
|
||||
hidden := f.Tag.Get("hidden") == "true"
|
||||
|
||||
// decimal.Decimal is an opaque numeric type used for pricing
|
||||
// precision; do not recurse into its internal struct fields.
|
||||
isDecimal := ft == reflect.TypeOf(decimal.Decimal{})
|
||||
|
||||
// If the field is a struct (not a map), recurse to flatten
|
||||
// its children using dot-separated names — unless the
|
||||
// entire struct is marked hidden, in which case emit it
|
||||
// as a single opaque field.
|
||||
if ft.Kind() == reflect.Struct && !hidden {
|
||||
if ft.Kind() == reflect.Struct && !hidden && !isDecimal {
|
||||
nested := extractFields(ft, fullJSONName, nil)
|
||||
fields = append(fields, nested.Fields...)
|
||||
continue
|
||||
@@ -206,6 +212,12 @@ func goTypeToSchemaType(t reflect.Type) string {
|
||||
t = t.Elem()
|
||||
}
|
||||
|
||||
// decimal.Decimal represents a precise numeric value and should
|
||||
// map to the "number" schema type.
|
||||
if t == reflect.TypeOf(decimal.Decimal{}) {
|
||||
return "number"
|
||||
}
|
||||
|
||||
switch t.Kind() {
|
||||
case reflect.String:
|
||||
return "string"
|
||||
|
||||
Generated
+94
-9
@@ -1071,6 +1071,96 @@ export interface Chat {
|
||||
readonly archived: boolean;
|
||||
}
|
||||
|
||||
// From codersdk/chats.go
|
||||
/**
|
||||
* ChatCostChatBreakdown contains per-root-chat cost aggregation.
|
||||
*/
|
||||
export interface ChatCostChatBreakdown {
|
||||
readonly root_chat_id: string;
|
||||
readonly chat_title: string;
|
||||
readonly total_cost_micros: number;
|
||||
readonly message_count: number;
|
||||
readonly total_input_tokens: number;
|
||||
readonly total_output_tokens: number;
|
||||
}
|
||||
|
||||
// From codersdk/chats.go
|
||||
/**
|
||||
* ChatCostModelBreakdown contains per-model cost aggregation.
|
||||
*/
|
||||
export interface ChatCostModelBreakdown {
|
||||
readonly model_config_id: string;
|
||||
readonly display_name: string;
|
||||
readonly provider: string;
|
||||
readonly model: string;
|
||||
readonly total_cost_micros: number;
|
||||
readonly message_count: number;
|
||||
readonly total_input_tokens: number;
|
||||
readonly total_output_tokens: number;
|
||||
}
|
||||
|
||||
// From codersdk/chats.go
|
||||
/**
|
||||
* ChatCostSummary is the response from the chat cost summary endpoint.
|
||||
*/
|
||||
export interface ChatCostSummary {
|
||||
readonly start_date: string;
|
||||
readonly end_date: string;
|
||||
readonly total_cost_micros: number;
|
||||
readonly priced_message_count: number;
|
||||
readonly unpriced_message_count: number;
|
||||
readonly total_input_tokens: number;
|
||||
readonly total_output_tokens: number;
|
||||
readonly by_model: readonly ChatCostModelBreakdown[];
|
||||
readonly by_chat: readonly ChatCostChatBreakdown[];
|
||||
}
|
||||
|
||||
// From codersdk/chats.go
|
||||
/**
|
||||
* ChatCostSummaryOptions are optional query parameters for GetChatCostSummary.
|
||||
*/
|
||||
export interface ChatCostSummaryOptions {
|
||||
readonly StartDate: string;
|
||||
readonly EndDate: string;
|
||||
}
|
||||
|
||||
// From codersdk/chats.go
|
||||
/**
|
||||
* ChatCostUserRollup contains per-user cost aggregation for admin views.
|
||||
*/
|
||||
export interface ChatCostUserRollup {
|
||||
readonly user_id: string;
|
||||
readonly username: string;
|
||||
readonly name: string;
|
||||
readonly avatar_url: string;
|
||||
readonly total_cost_micros: number;
|
||||
readonly message_count: number;
|
||||
readonly chat_count: number;
|
||||
readonly total_input_tokens: number;
|
||||
readonly total_output_tokens: number;
|
||||
}
|
||||
|
||||
// From codersdk/chats.go
|
||||
/**
|
||||
* ChatCostUsersOptions are optional query parameters for GetChatCostUsers.
|
||||
*/
|
||||
export interface ChatCostUsersOptions extends Pagination {
|
||||
readonly StartDate: string;
|
||||
readonly EndDate: string;
|
||||
readonly Username: string;
|
||||
}
|
||||
|
||||
// From codersdk/chats.go
|
||||
/**
|
||||
* ChatCostUsersResponse is the response from the admin chat cost users endpoint.
|
||||
*/
|
||||
export interface ChatCostUsersResponse {
|
||||
readonly start_date: string;
|
||||
readonly end_date: string;
|
||||
readonly count: number;
|
||||
readonly users: readonly ChatCostUserRollup[];
|
||||
}
|
||||
|
||||
// From codersdk/chats.go
|
||||
/**
|
||||
* ChatDiffContents represents the resolved diff text for a chat.
|
||||
@@ -3466,15 +3556,10 @@ export interface MinimalUser {
|
||||
* ModelCostConfig stores pricing metadata for a chat model.
|
||||
*/
|
||||
export interface ModelCostConfig {
|
||||
/**
|
||||
* Pricing is stored as configuration metadata and currently only needs to
|
||||
* round-trip cleanly through the API and admin UI. If we later use these
|
||||
* values for billing-grade arithmetic, switch to a fixed-point type.
|
||||
*/
|
||||
readonly input_price_per_million_tokens?: number;
|
||||
readonly output_price_per_million_tokens?: number;
|
||||
readonly cache_read_price_per_million_tokens?: number;
|
||||
readonly cache_write_price_per_million_tokens?: number;
|
||||
readonly input_price_per_million_tokens?: string;
|
||||
readonly output_price_per_million_tokens?: string;
|
||||
readonly cache_read_price_per_million_tokens?: string;
|
||||
readonly cache_write_price_per_million_tokens?: string;
|
||||
}
|
||||
|
||||
// From netcheck/netcheck.go
|
||||
|
||||
@@ -1,99 +0,0 @@
|
||||
import { render, screen } from "@testing-library/react";
|
||||
import type * as TypesGen from "api/typesGenerated";
|
||||
import { TooltipProvider } from "components/Tooltip/Tooltip";
|
||||
import { describe, expect, it, vi } from "vitest";
|
||||
import type { ProviderState } from "./ChatModelAdminPanel";
|
||||
import { ModelsSection } from "./ModelsSection";
|
||||
|
||||
vi.mock("./ProviderIcon", () => ({
|
||||
ProviderIcon: ({ provider }: { provider: string }) => (
|
||||
<div data-testid="provider-icon">{provider}</div>
|
||||
),
|
||||
}));
|
||||
|
||||
const providerState: ProviderState = {
|
||||
provider: "openai",
|
||||
label: "OpenAI",
|
||||
providerConfig: {
|
||||
id: "provider-config-id",
|
||||
provider: "openai",
|
||||
display_name: "OpenAI",
|
||||
enabled: true,
|
||||
has_api_key: true,
|
||||
base_url: undefined,
|
||||
source: "database",
|
||||
created_at: "2025-01-01T00:00:00Z",
|
||||
updated_at: "2025-01-01T00:00:00Z",
|
||||
},
|
||||
modelConfigs: [],
|
||||
catalogModelCount: 0,
|
||||
hasManagedAPIKey: true,
|
||||
hasCatalogAPIKey: true,
|
||||
hasEffectiveAPIKey: true,
|
||||
isEnvPreset: false,
|
||||
baseURL: "",
|
||||
};
|
||||
|
||||
const baseModelConfig: TypesGen.ChatModelConfig = {
|
||||
id: "model-config-id",
|
||||
provider: "openai",
|
||||
model: "gpt-4.1",
|
||||
display_name: "GPT-4.1",
|
||||
enabled: true,
|
||||
is_default: false,
|
||||
context_limit: 128000,
|
||||
compression_threshold: 80,
|
||||
created_at: "2025-01-01T00:00:00Z",
|
||||
updated_at: "2025-01-01T00:00:00Z",
|
||||
};
|
||||
|
||||
const renderModelsSection = (
|
||||
modelConfigs: readonly TypesGen.ChatModelConfig[],
|
||||
) => {
|
||||
return render(
|
||||
<TooltipProvider>
|
||||
<ModelsSection
|
||||
sectionLabel="Models"
|
||||
providerStates={[providerState]}
|
||||
selectedProvider="openai"
|
||||
selectedProviderState={providerState}
|
||||
onSelectedProviderChange={vi.fn()}
|
||||
modelConfigs={modelConfigs}
|
||||
modelConfigsUnavailable={false}
|
||||
isCreating={false}
|
||||
isUpdating={false}
|
||||
isDeleting={false}
|
||||
onCreateModel={vi.fn()}
|
||||
onUpdateModel={vi.fn()}
|
||||
onDeleteModel={vi.fn()}
|
||||
/>
|
||||
</TooltipProvider>,
|
||||
);
|
||||
};
|
||||
|
||||
describe("ModelsSection", () => {
|
||||
it("shows a warning when a model has no custom pricing configured", () => {
|
||||
renderModelsSection([baseModelConfig]);
|
||||
|
||||
expect(
|
||||
screen.getByText("Model pricing is not defined"),
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
|
||||
it("hides the warning when a model has explicit zero pricing", () => {
|
||||
renderModelsSection([
|
||||
{
|
||||
...baseModelConfig,
|
||||
model_config: {
|
||||
cost: {
|
||||
output_price_per_million_tokens: 0,
|
||||
},
|
||||
},
|
||||
},
|
||||
]);
|
||||
|
||||
expect(
|
||||
screen.queryByText("Model pricing is not defined"),
|
||||
).not.toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
@@ -210,10 +210,10 @@ describe("extractModelConfigFormState", () => {
|
||||
...baseChatModelConfig,
|
||||
model_config: {
|
||||
cost: {
|
||||
input_price_per_million_tokens: 0.15,
|
||||
output_price_per_million_tokens: 0.6,
|
||||
cache_read_price_per_million_tokens: 0.03,
|
||||
cache_write_price_per_million_tokens: 0.3,
|
||||
input_price_per_million_tokens: "0.15",
|
||||
output_price_per_million_tokens: "0.6",
|
||||
cache_read_price_per_million_tokens: "0.03",
|
||||
cache_write_price_per_million_tokens: "0.3",
|
||||
},
|
||||
},
|
||||
};
|
||||
@@ -553,10 +553,10 @@ describe("buildModelConfigFromForm", () => {
|
||||
expect(result.fieldErrors).toEqual({});
|
||||
expect(result.modelConfig).toMatchObject({
|
||||
cost: {
|
||||
input_price_per_million_tokens: 0.15,
|
||||
output_price_per_million_tokens: 0.6,
|
||||
cache_read_price_per_million_tokens: 0.03,
|
||||
cache_write_price_per_million_tokens: 0.3,
|
||||
input_price_per_million_tokens: "0.15",
|
||||
output_price_per_million_tokens: "0.6",
|
||||
cache_read_price_per_million_tokens: "0.03",
|
||||
cache_write_price_per_million_tokens: "0.3",
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
@@ -110,7 +110,7 @@ function convertFormValue(value: string, field: FieldSchema): unknown {
|
||||
case "integer":
|
||||
return Number.parseInt(trimmed, 10);
|
||||
case "number":
|
||||
return Number(trimmed);
|
||||
return isNonNegativePricingField(field) ? trimmed : Number(trimmed);
|
||||
case "boolean":
|
||||
return trimmed === "true";
|
||||
case "array":
|
||||
|
||||
@@ -10,7 +10,7 @@ import {
|
||||
describe("pricingFields", () => {
|
||||
it("uses $0 defaults for every pricing field", () => {
|
||||
for (const fieldName of pricingFieldNameList) {
|
||||
expect(getDefaultPricingForField(fieldName)).toBe(0);
|
||||
expect(getDefaultPricingForField(fieldName)).toBe("0");
|
||||
expect(getPricingPlaceholderForField(fieldName)).toBe("0");
|
||||
}
|
||||
});
|
||||
@@ -23,8 +23,8 @@ describe("pricingFields", () => {
|
||||
expect(
|
||||
hasCustomPricing({
|
||||
cost: {
|
||||
input_price_per_million_tokens: 0,
|
||||
output_price_per_million_tokens: 0,
|
||||
input_price_per_million_tokens: "0",
|
||||
output_price_per_million_tokens: "0",
|
||||
},
|
||||
} satisfies TypesGen.ChatModelCallConfig),
|
||||
).toBe(true);
|
||||
@@ -34,7 +34,7 @@ describe("pricingFields", () => {
|
||||
expect(
|
||||
hasCustomPricing({
|
||||
cost: {
|
||||
cache_write_price_per_million_tokens: 0.25,
|
||||
cache_write_price_per_million_tokens: "0.25",
|
||||
},
|
||||
} satisfies TypesGen.ChatModelCallConfig),
|
||||
).toBe(true);
|
||||
|
||||
@@ -14,11 +14,11 @@ export const pricingFieldNames = new Set<string>(pricingFieldNameList);
|
||||
type PricingFieldName = (typeof pricingFieldNameList)[number];
|
||||
|
||||
export const defaultPricingByFieldName = {
|
||||
"cost.input_price_per_million_tokens": 0,
|
||||
"cost.output_price_per_million_tokens": 0,
|
||||
"cost.cache_read_price_per_million_tokens": 0,
|
||||
"cost.cache_write_price_per_million_tokens": 0,
|
||||
} as const satisfies Record<PricingFieldName, number>;
|
||||
"cost.input_price_per_million_tokens": "0",
|
||||
"cost.output_price_per_million_tokens": "0",
|
||||
"cost.cache_read_price_per_million_tokens": "0",
|
||||
"cost.cache_write_price_per_million_tokens": "0",
|
||||
} as const satisfies Record<PricingFieldName, string>;
|
||||
|
||||
export const pricingPlaceholderByFieldName = {
|
||||
"cost.input_price_per_million_tokens": "0",
|
||||
@@ -29,7 +29,7 @@ export const pricingPlaceholderByFieldName = {
|
||||
|
||||
export const getDefaultPricingForField = (
|
||||
fieldName: string,
|
||||
): number | undefined =>
|
||||
): string | undefined =>
|
||||
defaultPricingByFieldName[
|
||||
fieldName as keyof typeof defaultPricingByFieldName
|
||||
];
|
||||
|
||||
Reference in New Issue
Block a user