fix: address review comments on InsertChatMessages (#23239)

Follow-up to #23220, addressing Cian's review comments:

- **SQL casing**: Uppercase `UNNEST` to match `NULLIF`/`COALESCE`
convention in the query.
- **Builder pattern**: `chatMessage` struct now uses unexported fields
with a `newChatMessage` constructor for required fields (role, content,
visibility, modelConfigID, contentVersion) and chainable builder methods
(`withCreatedBy`, `withCompressed`, `withUsage`, `withContextLimit`,
`withTotalCostMicros`, `withRuntimeMs`) for optional/nullable fields.
- **Batch test in chats_test**: Replaced the `for i := 0; i < 2` loop
with a single batch insert of 2 messages to actually exercise the batch
logic.
- **Multi-message querier test**: Added `BatchInsertMultipleMessages`
test verifying 3-message batch insert with role ordering, sequential
IDs, nullable field semantics (NULL for zero UUIDs and zero ints), and
token/cost assertions.

---------

Co-authored-by: Cian Johnston <cian@coder.com>
This commit is contained in:
Kyle Carberry
2026-03-18 13:06:44 -04:00
committed by GitHub
parent c46136ff73
commit d4a072b61e
5 changed files with 294 additions and 194 deletions
+62
View File
@@ -9534,6 +9534,68 @@ func TestInsertChatMessages(t *testing.T) {
require.NoError(t, err)
require.Equal(t, modelConfigA.ID, gotChat.LastModelConfigID)
})
t.Run("BatchInsertMultipleMessages", func(t *testing.T) {
t.Parallel()
store, ctx, user, chat, _, modelConfigA := setupChat(t)
msgs, err := store.InsertChatMessages(ctx, database.InsertChatMessagesParams{
ChatID: chat.ID,
CreatedBy: []uuid.UUID{user.ID, uuid.Nil, uuid.Nil},
ModelConfigID: []uuid.UUID{modelConfigA.ID, modelConfigA.ID, modelConfigA.ID},
Role: []database.ChatMessageRole{database.ChatMessageRoleUser, database.ChatMessageRoleAssistant, database.ChatMessageRoleTool},
ContentVersion: []int16{chatprompt.CurrentContentVersion, chatprompt.CurrentContentVersion, chatprompt.CurrentContentVersion},
Visibility: []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth, database.ChatMessageVisibilityBoth, database.ChatMessageVisibilityBoth},
Content: []string{`"hello"`, `"response"`, `"tool result"`},
InputTokens: []int64{10, 0, 0},
OutputTokens: []int64{0, 20, 0},
TotalTokens: []int64{10, 20, 0},
ReasoningTokens: []int64{0, 5, 0},
CacheCreationTokens: []int64{0, 0, 0},
CacheReadTokens: []int64{0, 0, 0},
ContextLimit: []int64{0, 0, 0},
Compressed: []bool{false, false, false},
TotalCostMicros: []int64{0, 100, 0},
RuntimeMs: []int64{0, 500, 0},
})
require.NoError(t, err)
require.Len(t, msgs, 3)
// Verify ordering and roles.
require.Equal(t, database.ChatMessageRoleUser, msgs[0].Role)
require.Equal(t, database.ChatMessageRoleAssistant, msgs[1].Role)
require.Equal(t, database.ChatMessageRoleTool, msgs[2].Role)
// Verify IDs are sequential.
require.Less(t, msgs[0].ID, msgs[1].ID)
require.Less(t, msgs[1].ID, msgs[2].ID)
// Verify nullable fields: user message has CreatedBy set.
require.True(t, msgs[0].CreatedBy.Valid)
require.Equal(t, user.ID, msgs[0].CreatedBy.UUID)
// Assistant and tool messages have NULL CreatedBy.
require.False(t, msgs[1].CreatedBy.Valid)
require.False(t, msgs[2].CreatedBy.Valid)
// Verify token fields stored as NULL when zero.
require.True(t, msgs[0].InputTokens.Valid)
require.Equal(t, int64(10), msgs[0].InputTokens.Int64)
require.False(t, msgs[0].OutputTokens.Valid) // 0 → NULL
require.True(t, msgs[1].OutputTokens.Valid)
require.Equal(t, int64(20), msgs[1].OutputTokens.Int64)
// Verify cost: assistant has cost, others NULL.
require.True(t, msgs[1].TotalCostMicros.Valid)
require.Equal(t, int64(100), msgs[1].TotalCostMicros.Int64)
require.False(t, msgs[0].TotalCostMicros.Valid)
require.False(t, msgs[2].TotalCostMicros.Valid)
// Verify runtime_ms on assistant message.
require.True(t, msgs[1].RuntimeMs.Valid)
require.Equal(t, int64(500), msgs[1].RuntimeMs.Int64)
require.False(t, msgs[0].RuntimeMs.Valid)
})
}
func TestGetChatMessagesForPromptByChatID(t *testing.T) {
+19 -19
View File
@@ -4848,7 +4848,7 @@ WITH updated_chat AS (
SET
last_model_config_id = (
SELECT val
FROM unnest($3::uuid[])
FROM UNNEST($3::uuid[])
WITH ORDINALITY AS t(val, ord)
WHERE val != '00000000-0000-0000-0000-000000000000'::uuid
ORDER BY ord DESC
@@ -4858,12 +4858,12 @@ WITH updated_chat AS (
id = $1::uuid
AND EXISTS (
SELECT 1
FROM unnest($3::uuid[])
FROM UNNEST($3::uuid[])
WHERE unnest != '00000000-0000-0000-0000-000000000000'::uuid
)
AND chats.last_model_config_id IS DISTINCT FROM (
SELECT val
FROM unnest($3::uuid[])
FROM UNNEST($3::uuid[])
WITH ORDINALITY AS t(val, ord)
WHERE val != '00000000-0000-0000-0000-000000000000'::uuid
ORDER BY ord DESC
@@ -4891,22 +4891,22 @@ INSERT INTO chat_messages (
)
SELECT
$1::uuid,
NULLIF(unnest($2::uuid[]), '00000000-0000-0000-0000-000000000000'::uuid),
NULLIF(unnest($3::uuid[]), '00000000-0000-0000-0000-000000000000'::uuid),
unnest($4::chat_message_role[]),
unnest($5::text[])::jsonb,
unnest($6::smallint[]),
unnest($7::chat_message_visibility[]),
NULLIF(unnest($8::bigint[]), 0),
NULLIF(unnest($9::bigint[]), 0),
NULLIF(unnest($10::bigint[]), 0),
NULLIF(unnest($11::bigint[]), 0),
NULLIF(unnest($12::bigint[]), 0),
NULLIF(unnest($13::bigint[]), 0),
NULLIF(unnest($14::bigint[]), 0),
unnest($15::boolean[]),
NULLIF(unnest($16::bigint[]), 0),
NULLIF(unnest($17::bigint[]), 0)
NULLIF(UNNEST($2::uuid[]), '00000000-0000-0000-0000-000000000000'::uuid),
NULLIF(UNNEST($3::uuid[]), '00000000-0000-0000-0000-000000000000'::uuid),
UNNEST($4::chat_message_role[]),
UNNEST($5::text[])::jsonb,
UNNEST($6::smallint[]),
UNNEST($7::chat_message_visibility[]),
NULLIF(UNNEST($8::bigint[]), 0),
NULLIF(UNNEST($9::bigint[]), 0),
NULLIF(UNNEST($10::bigint[]), 0),
NULLIF(UNNEST($11::bigint[]), 0),
NULLIF(UNNEST($12::bigint[]), 0),
NULLIF(UNNEST($13::bigint[]), 0),
NULLIF(UNNEST($14::bigint[]), 0),
UNNEST($15::boolean[]),
NULLIF(UNNEST($16::bigint[]), 0),
NULLIF(UNNEST($17::bigint[]), 0)
RETURNING
id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros, runtime_ms
`
+19 -19
View File
@@ -185,7 +185,7 @@ WITH updated_chat AS (
SET
last_model_config_id = (
SELECT val
FROM unnest(@model_config_id::uuid[])
FROM UNNEST(@model_config_id::uuid[])
WITH ORDINALITY AS t(val, ord)
WHERE val != '00000000-0000-0000-0000-000000000000'::uuid
ORDER BY ord DESC
@@ -195,12 +195,12 @@ WITH updated_chat AS (
id = @chat_id::uuid
AND EXISTS (
SELECT 1
FROM unnest(@model_config_id::uuid[])
FROM UNNEST(@model_config_id::uuid[])
WHERE unnest != '00000000-0000-0000-0000-000000000000'::uuid
)
AND chats.last_model_config_id IS DISTINCT FROM (
SELECT val
FROM unnest(@model_config_id::uuid[])
FROM UNNEST(@model_config_id::uuid[])
WITH ORDINALITY AS t(val, ord)
WHERE val != '00000000-0000-0000-0000-000000000000'::uuid
ORDER BY ord DESC
@@ -228,22 +228,22 @@ INSERT INTO chat_messages (
)
SELECT
@chat_id::uuid,
NULLIF(unnest(@created_by::uuid[]), '00000000-0000-0000-0000-000000000000'::uuid),
NULLIF(unnest(@model_config_id::uuid[]), '00000000-0000-0000-0000-000000000000'::uuid),
unnest(@role::chat_message_role[]),
unnest(@content::text[])::jsonb,
unnest(@content_version::smallint[]),
unnest(@visibility::chat_message_visibility[]),
NULLIF(unnest(@input_tokens::bigint[]), 0),
NULLIF(unnest(@output_tokens::bigint[]), 0),
NULLIF(unnest(@total_tokens::bigint[]), 0),
NULLIF(unnest(@reasoning_tokens::bigint[]), 0),
NULLIF(unnest(@cache_creation_tokens::bigint[]), 0),
NULLIF(unnest(@cache_read_tokens::bigint[]), 0),
NULLIF(unnest(@context_limit::bigint[]), 0),
unnest(@compressed::boolean[]),
NULLIF(unnest(@total_cost_micros::bigint[]), 0),
NULLIF(unnest(@runtime_ms::bigint[]), 0)
NULLIF(UNNEST(@created_by::uuid[]), '00000000-0000-0000-0000-000000000000'::uuid),
NULLIF(UNNEST(@model_config_id::uuid[]), '00000000-0000-0000-0000-000000000000'::uuid),
UNNEST(@role::chat_message_role[]),
UNNEST(@content::text[])::jsonb,
UNNEST(@content_version::smallint[]),
UNNEST(@visibility::chat_message_visibility[]),
NULLIF(UNNEST(@input_tokens::bigint[]), 0),
NULLIF(UNNEST(@output_tokens::bigint[]), 0),
NULLIF(UNNEST(@total_tokens::bigint[]), 0),
NULLIF(UNNEST(@reasoning_tokens::bigint[]), 0),
NULLIF(UNNEST(@cache_creation_tokens::bigint[]), 0),
NULLIF(UNNEST(@cache_read_tokens::bigint[]), 0),
NULLIF(UNNEST(@context_limit::bigint[]), 0),
UNNEST(@compressed::boolean[]),
NULLIF(UNNEST(@total_cost_micros::bigint[]), 0),
NULLIF(UNNEST(@runtime_ms::bigint[]), 0)
RETURNING
*;