feat: replace InsertChatMessage with batch InsertChatMessages (#23220)

Replaces the singular `InsertChatMessage` query with
`InsertChatMessages` that uses PostgreSQL's `unnest()` for batch
inserts. This reduces the number of database round-trips when inserting
multiple messages in a single transaction.

## Changes

- **SQL**: New `InsertChatMessages :many` query using `unnest()` arrays
following the existing codebase pattern (e.g.,
`InsertWorkspaceAgentStats`). Preserves the CTE that updates
`chats.last_model_config_id` using the last non-null model config from
the batch. Uses `NULLIF` for UUID columns to handle NULL foreign keys.
- **Go layers**: Updated `querier.go`, `dbauthz.go`,
`dbmetrics/querymetrics.go`, `dbmock/dbmock.go`, and `queries.sql.go` to
use the new batch signature (`[]ChatMessage` return type, array params).
- **chatd.go**: All call sites converted to batch inserts:
  - **CreateChat**: System prompt + user message batched into one call
- **persistStep**: Assistant message + tool messages batched into one
call
- **persistSummary**: Hidden summary + assistant + tool messages batched
into one call
  - Single-message sites use the same API with single-element arrays
- **Helper**: New `appendChatMessage` function simplifies building batch
params at each call site.
- **Tests**: All test files updated to use the new API.

Builds on top of #23213.
This commit is contained in:
Kyle Carberry
2026-03-18 12:27:07 -04:00
committed by GitHub
parent 1f5f6c9ccb
commit 483adc59fe
11 changed files with 692 additions and 584 deletions
+4 -4
View File
@@ -4603,16 +4603,16 @@ func (q *querier) InsertChatFile(ctx context.Context, arg database.InsertChatFil
return insert(q.log, q.auth, rbac.ResourceChat.WithOwner(arg.OwnerID.String()).InOrg(arg.OrganizationID), q.db.InsertChatFile)(ctx, arg)
}
func (q *querier) InsertChatMessage(ctx context.Context, arg database.InsertChatMessageParams) (database.ChatMessage, error) {
func (q *querier) InsertChatMessages(ctx context.Context, arg database.InsertChatMessagesParams) ([]database.ChatMessage, error) {
// Authorize create on the parent chat (using update permission).
chat, err := q.db.GetChatByID(ctx, arg.ChatID)
if err != nil {
return database.ChatMessage{}, err
return nil, err
}
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
return database.ChatMessage{}, err
return nil, err
}
return q.db.InsertChatMessage(ctx, arg)
return q.db.InsertChatMessages(ctx, arg)
}
func (q *querier) InsertChatModelConfig(ctx context.Context, arg database.InsertChatModelConfigParams) (database.ChatModelConfig, error) {
+5 -5
View File
@@ -675,13 +675,13 @@ func (s *MethodTestSuite) TestChats() {
dbm.EXPECT().InsertChatFile(gomock.Any(), arg).Return(file, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceChat.WithOwner(arg.OwnerID.String()).InOrg(arg.OrganizationID), policy.ActionCreate).Returns(file)
}))
s.Run("InsertChatMessage", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
s.Run("InsertChatMessages", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
arg := testutil.Fake(s.T(), faker, database.InsertChatMessageParams{ChatID: chat.ID})
msg := testutil.Fake(s.T(), faker, database.ChatMessage{ChatID: chat.ID})
arg := testutil.Fake(s.T(), faker, database.InsertChatMessagesParams{ChatID: chat.ID})
msgs := []database.ChatMessage{testutil.Fake(s.T(), faker, database.ChatMessage{ChatID: chat.ID})}
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
dbm.EXPECT().InsertChatMessage(gomock.Any(), arg).Return(msg, nil).AnyTimes()
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(msg)
dbm.EXPECT().InsertChatMessages(gomock.Any(), arg).Return(msgs, nil).AnyTimes()
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(msgs)
}))
s.Run("InsertChatQueuedMessage", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
+4 -4
View File
@@ -3064,11 +3064,11 @@ func (m queryMetricsStore) InsertChatFile(ctx context.Context, arg database.Inse
return r0, r1
}
func (m queryMetricsStore) InsertChatMessage(ctx context.Context, arg database.InsertChatMessageParams) (database.ChatMessage, error) {
func (m queryMetricsStore) InsertChatMessages(ctx context.Context, arg database.InsertChatMessagesParams) ([]database.ChatMessage, error) {
start := time.Now()
r0, r1 := m.s.InsertChatMessage(ctx, arg)
m.queryLatencies.WithLabelValues("InsertChatMessage").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "InsertChatMessage").Inc()
r0, r1 := m.s.InsertChatMessages(ctx, arg)
m.queryLatencies.WithLabelValues("InsertChatMessages").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "InsertChatMessages").Inc()
return r0, r1
}
+7 -7
View File
@@ -5735,19 +5735,19 @@ func (mr *MockStoreMockRecorder) InsertChatFile(ctx, arg any) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertChatFile", reflect.TypeOf((*MockStore)(nil).InsertChatFile), ctx, arg)
}
// InsertChatMessage mocks base method.
func (m *MockStore) InsertChatMessage(ctx context.Context, arg database.InsertChatMessageParams) (database.ChatMessage, error) {
// InsertChatMessages mocks base method.
func (m *MockStore) InsertChatMessages(ctx context.Context, arg database.InsertChatMessagesParams) ([]database.ChatMessage, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "InsertChatMessage", ctx, arg)
ret0, _ := ret[0].(database.ChatMessage)
ret := m.ctrl.Call(m, "InsertChatMessages", ctx, arg)
ret0, _ := ret[0].([]database.ChatMessage)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// InsertChatMessage indicates an expected call of InsertChatMessage.
func (mr *MockStoreMockRecorder) InsertChatMessage(ctx, arg any) *gomock.Call {
// InsertChatMessages indicates an expected call of InsertChatMessages.
func (mr *MockStoreMockRecorder) InsertChatMessages(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertChatMessage", reflect.TypeOf((*MockStore)(nil).InsertChatMessage), ctx, arg)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertChatMessages", reflect.TypeOf((*MockStore)(nil).InsertChatMessages), ctx, arg)
}
// InsertChatModelConfig mocks base method.
+1 -1
View File
@@ -637,7 +637,7 @@ type sqlcQuerier interface {
InsertAuditLog(ctx context.Context, arg InsertAuditLogParams) (AuditLog, error)
InsertChat(ctx context.Context, arg InsertChatParams) (Chat, error)
InsertChatFile(ctx context.Context, arg InsertChatFileParams) (InsertChatFileRow, error)
InsertChatMessage(ctx context.Context, arg InsertChatMessageParams) (ChatMessage, error)
InsertChatMessages(ctx context.Context, arg InsertChatMessagesParams) ([]ChatMessage, error)
InsertChatModelConfig(ctx context.Context, arg InsertChatModelConfigParams) (ChatModelConfig, error)
InsertChatProvider(ctx context.Context, arg InsertChatProviderParams) (ChatProvider, error)
InsertChatQueuedMessage(ctx context.Context, arg InsertChatQueuedMessageParams) (ChatQueuedMessage, error)
+38 -23
View File
@@ -9404,7 +9404,7 @@ func TestInsertWorkspaceAgentDevcontainers(t *testing.T) {
}
}
func TestInsertChatMessage(t *testing.T) {
func TestInsertChatMessages(t *testing.T) {
t.Parallel()
insertModelConfig := func(
@@ -9478,17 +9478,24 @@ func TestInsertChatMessage(t *testing.T) {
insertMessage := func(t *testing.T, store database.Store, ctx context.Context, chatID, userID, modelConfigID uuid.UUID, content string) {
t.Helper()
_, err := store.InsertChatMessage(ctx, database.InsertChatMessageParams{
ChatID: chatID,
CreatedBy: uuid.NullUUID{UUID: userID, Valid: true},
ModelConfigID: uuid.NullUUID{UUID: modelConfigID, Valid: true},
Role: database.ChatMessageRoleUser,
ContentVersion: chatprompt.CurrentContentVersion,
Visibility: database.ChatMessageVisibilityBoth,
Content: pqtype.NullRawMessage{
RawMessage: json.RawMessage(fmt.Sprintf("%q", content)),
Valid: true,
},
_, err := store.InsertChatMessages(ctx, database.InsertChatMessagesParams{
ChatID: chatID,
CreatedBy: []uuid.UUID{userID},
ModelConfigID: []uuid.UUID{modelConfigID},
Role: []database.ChatMessageRole{database.ChatMessageRoleUser},
ContentVersion: []int16{chatprompt.CurrentContentVersion},
Visibility: []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth},
Content: []string{fmt.Sprintf("%q", content)},
InputTokens: []int64{0},
OutputTokens: []int64{0},
TotalTokens: []int64{0},
ReasoningTokens: []int64{0},
CacheCreationTokens: []int64{0},
CacheReadTokens: []int64{0},
ContextLimit: []int64{0},
Compressed: []bool{false},
TotalCostMicros: []int64{0},
RuntimeMs: []int64{0},
})
require.NoError(t, err)
}
@@ -9583,19 +9590,27 @@ func TestGetChatMessagesForPromptByChatID(t *testing.T) {
content string,
) database.ChatMessage {
t.Helper()
msg, err := db.InsertChatMessage(ctx, database.InsertChatMessageParams{
ChatID: chatID,
Role: role,
ContentVersion: chatprompt.CurrentContentVersion,
Visibility: vis,
Compressed: sql.NullBool{Bool: compressed, Valid: true},
Content: pqtype.NullRawMessage{
RawMessage: json.RawMessage(`"` + content + `"`),
Valid: true,
},
results, err := db.InsertChatMessages(ctx, database.InsertChatMessagesParams{
ChatID: chatID,
CreatedBy: []uuid.UUID{uuid.Nil},
ModelConfigID: []uuid.UUID{uuid.Nil},
Role: []database.ChatMessageRole{role},
ContentVersion: []int16{chatprompt.CurrentContentVersion},
Visibility: []database.ChatMessageVisibility{vis},
Compressed: []bool{compressed},
Content: []string{`"` + content + `"`},
InputTokens: []int64{0},
OutputTokens: []int64{0},
TotalTokens: []int64{0},
ReasoningTokens: []int64{0},
CacheCreationTokens: []int64{0},
CacheReadTokens: []int64{0},
ContextLimit: []int64{0},
TotalCostMicros: []int64{0},
RuntimeMs: []int64{0},
})
require.NoError(t, err)
return msg
return results[0]
}
msgIDs := func(msgs []database.ChatMessage) []int64 {
+115 -81
View File
@@ -4841,16 +4841,34 @@ func (q *sqlQuerier) InsertChat(ctx context.Context, arg InsertChatParams) (Chat
return i, err
}
const insertChatMessage = `-- name: InsertChatMessage :one
const insertChatMessages = `-- name: InsertChatMessages :many
WITH updated_chat AS (
UPDATE
chats
SET
last_model_config_id = $3::uuid
last_model_config_id = (
SELECT val
FROM unnest($3::uuid[])
WITH ORDINALITY AS t(val, ord)
WHERE val != '00000000-0000-0000-0000-000000000000'::uuid
ORDER BY ord DESC
LIMIT 1
)
WHERE
id = $1::uuid
AND $3::uuid IS NOT NULL
AND chats.last_model_config_id IS DISTINCT FROM $3::uuid
AND EXISTS (
SELECT 1
FROM unnest($3::uuid[])
WHERE unnest != '00000000-0000-0000-0000-000000000000'::uuid
)
AND chats.last_model_config_id IS DISTINCT FROM (
SELECT val
FROM unnest($3::uuid[])
WITH ORDINALITY AS t(val, ord)
WHERE val != '00000000-0000-0000-0000-000000000000'::uuid
ORDER BY ord DESC
LIMIT 1
)
)
INSERT INTO chat_messages (
chat_id,
@@ -4870,92 +4888,108 @@ INSERT INTO chat_messages (
compressed,
total_cost_micros,
runtime_ms
) VALUES (
$1::uuid,
$2::uuid,
$3::uuid,
$4::chat_message_role,
$5::jsonb,
$6::smallint,
$7::chat_message_visibility,
$8::bigint,
$9::bigint,
$10::bigint,
$11::bigint,
$12::bigint,
$13::bigint,
$14::bigint,
COALESCE($15::boolean, FALSE),
$16::bigint,
$17::bigint
)
SELECT
$1::uuid,
NULLIF(unnest($2::uuid[]), '00000000-0000-0000-0000-000000000000'::uuid),
NULLIF(unnest($3::uuid[]), '00000000-0000-0000-0000-000000000000'::uuid),
unnest($4::chat_message_role[]),
unnest($5::text[])::jsonb,
unnest($6::smallint[]),
unnest($7::chat_message_visibility[]),
NULLIF(unnest($8::bigint[]), 0),
NULLIF(unnest($9::bigint[]), 0),
NULLIF(unnest($10::bigint[]), 0),
NULLIF(unnest($11::bigint[]), 0),
NULLIF(unnest($12::bigint[]), 0),
NULLIF(unnest($13::bigint[]), 0),
NULLIF(unnest($14::bigint[]), 0),
unnest($15::boolean[]),
NULLIF(unnest($16::bigint[]), 0),
NULLIF(unnest($17::bigint[]), 0)
RETURNING
id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros, runtime_ms
`
type InsertChatMessageParams struct {
ChatID uuid.UUID `db:"chat_id" json:"chat_id"`
CreatedBy uuid.NullUUID `db:"created_by" json:"created_by"`
ModelConfigID uuid.NullUUID `db:"model_config_id" json:"model_config_id"`
Role ChatMessageRole `db:"role" json:"role"`
Content pqtype.NullRawMessage `db:"content" json:"content"`
ContentVersion int16 `db:"content_version" json:"content_version"`
Visibility ChatMessageVisibility `db:"visibility" json:"visibility"`
InputTokens sql.NullInt64 `db:"input_tokens" json:"input_tokens"`
OutputTokens sql.NullInt64 `db:"output_tokens" json:"output_tokens"`
TotalTokens sql.NullInt64 `db:"total_tokens" json:"total_tokens"`
ReasoningTokens sql.NullInt64 `db:"reasoning_tokens" json:"reasoning_tokens"`
CacheCreationTokens sql.NullInt64 `db:"cache_creation_tokens" json:"cache_creation_tokens"`
CacheReadTokens sql.NullInt64 `db:"cache_read_tokens" json:"cache_read_tokens"`
ContextLimit sql.NullInt64 `db:"context_limit" json:"context_limit"`
Compressed sql.NullBool `db:"compressed" json:"compressed"`
TotalCostMicros sql.NullInt64 `db:"total_cost_micros" json:"total_cost_micros"`
RuntimeMs sql.NullInt64 `db:"runtime_ms" json:"runtime_ms"`
type InsertChatMessagesParams struct {
ChatID uuid.UUID `db:"chat_id" json:"chat_id"`
CreatedBy []uuid.UUID `db:"created_by" json:"created_by"`
ModelConfigID []uuid.UUID `db:"model_config_id" json:"model_config_id"`
Role []ChatMessageRole `db:"role" json:"role"`
Content []string `db:"content" json:"content"`
ContentVersion []int16 `db:"content_version" json:"content_version"`
Visibility []ChatMessageVisibility `db:"visibility" json:"visibility"`
InputTokens []int64 `db:"input_tokens" json:"input_tokens"`
OutputTokens []int64 `db:"output_tokens" json:"output_tokens"`
TotalTokens []int64 `db:"total_tokens" json:"total_tokens"`
ReasoningTokens []int64 `db:"reasoning_tokens" json:"reasoning_tokens"`
CacheCreationTokens []int64 `db:"cache_creation_tokens" json:"cache_creation_tokens"`
CacheReadTokens []int64 `db:"cache_read_tokens" json:"cache_read_tokens"`
ContextLimit []int64 `db:"context_limit" json:"context_limit"`
Compressed []bool `db:"compressed" json:"compressed"`
TotalCostMicros []int64 `db:"total_cost_micros" json:"total_cost_micros"`
RuntimeMs []int64 `db:"runtime_ms" json:"runtime_ms"`
}
func (q *sqlQuerier) InsertChatMessage(ctx context.Context, arg InsertChatMessageParams) (ChatMessage, error) {
row := q.db.QueryRowContext(ctx, insertChatMessage,
func (q *sqlQuerier) InsertChatMessages(ctx context.Context, arg InsertChatMessagesParams) ([]ChatMessage, error) {
rows, err := q.db.QueryContext(ctx, insertChatMessages,
arg.ChatID,
arg.CreatedBy,
arg.ModelConfigID,
arg.Role,
arg.Content,
arg.ContentVersion,
arg.Visibility,
arg.InputTokens,
arg.OutputTokens,
arg.TotalTokens,
arg.ReasoningTokens,
arg.CacheCreationTokens,
arg.CacheReadTokens,
arg.ContextLimit,
arg.Compressed,
arg.TotalCostMicros,
arg.RuntimeMs,
pq.Array(arg.CreatedBy),
pq.Array(arg.ModelConfigID),
pq.Array(arg.Role),
pq.Array(arg.Content),
pq.Array(arg.ContentVersion),
pq.Array(arg.Visibility),
pq.Array(arg.InputTokens),
pq.Array(arg.OutputTokens),
pq.Array(arg.TotalTokens),
pq.Array(arg.ReasoningTokens),
pq.Array(arg.CacheCreationTokens),
pq.Array(arg.CacheReadTokens),
pq.Array(arg.ContextLimit),
pq.Array(arg.Compressed),
pq.Array(arg.TotalCostMicros),
pq.Array(arg.RuntimeMs),
)
var i ChatMessage
err := row.Scan(
&i.ID,
&i.ChatID,
&i.ModelConfigID,
&i.CreatedAt,
&i.Role,
&i.Content,
&i.Visibility,
&i.InputTokens,
&i.OutputTokens,
&i.TotalTokens,
&i.ReasoningTokens,
&i.CacheCreationTokens,
&i.CacheReadTokens,
&i.ContextLimit,
&i.Compressed,
&i.CreatedBy,
&i.ContentVersion,
&i.TotalCostMicros,
&i.RuntimeMs,
)
return i, err
if err != nil {
return nil, err
}
defer rows.Close()
var items []ChatMessage
for rows.Next() {
var i ChatMessage
if err := rows.Scan(
&i.ID,
&i.ChatID,
&i.ModelConfigID,
&i.CreatedAt,
&i.Role,
&i.Content,
&i.Visibility,
&i.InputTokens,
&i.OutputTokens,
&i.TotalTokens,
&i.ReasoningTokens,
&i.CacheCreationTokens,
&i.CacheReadTokens,
&i.ContextLimit,
&i.Compressed,
&i.CreatedBy,
&i.ContentVersion,
&i.TotalCostMicros,
&i.RuntimeMs,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const insertChatQueuedMessage = `-- name: InsertChatQueuedMessage :one
+40 -22
View File
@@ -178,16 +178,34 @@ INSERT INTO chats (
RETURNING
*;
-- name: InsertChatMessage :one
-- name: InsertChatMessages :many
WITH updated_chat AS (
UPDATE
chats
SET
last_model_config_id = sqlc.narg('model_config_id')::uuid
last_model_config_id = (
SELECT val
FROM unnest(@model_config_id::uuid[])
WITH ORDINALITY AS t(val, ord)
WHERE val != '00000000-0000-0000-0000-000000000000'::uuid
ORDER BY ord DESC
LIMIT 1
)
WHERE
id = @chat_id::uuid
AND sqlc.narg('model_config_id')::uuid IS NOT NULL
AND chats.last_model_config_id IS DISTINCT FROM sqlc.narg('model_config_id')::uuid
AND EXISTS (
SELECT 1
FROM unnest(@model_config_id::uuid[])
WHERE unnest != '00000000-0000-0000-0000-000000000000'::uuid
)
AND chats.last_model_config_id IS DISTINCT FROM (
SELECT val
FROM unnest(@model_config_id::uuid[])
WITH ORDINALITY AS t(val, ord)
WHERE val != '00000000-0000-0000-0000-000000000000'::uuid
ORDER BY ord DESC
LIMIT 1
)
)
INSERT INTO chat_messages (
chat_id,
@@ -207,25 +225,25 @@ INSERT INTO chat_messages (
compressed,
total_cost_micros,
runtime_ms
) VALUES (
@chat_id::uuid,
sqlc.narg('created_by')::uuid,
sqlc.narg('model_config_id')::uuid,
@role::chat_message_role,
sqlc.narg('content')::jsonb,
@content_version::smallint,
@visibility::chat_message_visibility,
sqlc.narg('input_tokens')::bigint,
sqlc.narg('output_tokens')::bigint,
sqlc.narg('total_tokens')::bigint,
sqlc.narg('reasoning_tokens')::bigint,
sqlc.narg('cache_creation_tokens')::bigint,
sqlc.narg('cache_read_tokens')::bigint,
sqlc.narg('context_limit')::bigint,
COALESCE(sqlc.narg('compressed')::boolean, FALSE),
sqlc.narg('total_cost_micros')::bigint,
sqlc.narg('runtime_ms')::bigint
)
SELECT
@chat_id::uuid,
NULLIF(unnest(@created_by::uuid[]), '00000000-0000-0000-0000-000000000000'::uuid),
NULLIF(unnest(@model_config_id::uuid[]), '00000000-0000-0000-0000-000000000000'::uuid),
unnest(@role::chat_message_role[]),
unnest(@content::text[])::jsonb,
unnest(@content_version::smallint[]),
unnest(@visibility::chat_message_visibility[]),
NULLIF(unnest(@input_tokens::bigint[]), 0),
NULLIF(unnest(@output_tokens::bigint[]), 0),
NULLIF(unnest(@total_tokens::bigint[]), 0),
NULLIF(unnest(@reasoning_tokens::bigint[]), 0),
NULLIF(unnest(@cache_creation_tokens::bigint[]), 0),
NULLIF(unnest(@cache_read_tokens::bigint[]), 0),
NULLIF(unnest(@context_limit::bigint[]), 0),
unnest(@compressed::boolean[]),
NULLIF(unnest(@total_cost_micros::bigint[]), 0),
NULLIF(unnest(@runtime_ms::bigint[]), 0)
RETURNING
*;