// Mirror of https://github.com/coder/coder.git, synced 2026-06-05 05:58:20 +00:00.
// Original file: 533 lines, 24 KiB, Go.
package chatd //nolint:testpackage // Uses unexported chatworker helpers.
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"encoding/json"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
"charm.land/fantasy"
|
|
"github.com/google/uuid"
|
|
"github.com/shopspring/decimal"
|
|
"github.com/sqlc-dev/pqtype"
|
|
"github.com/stretchr/testify/require"
|
|
|
|
"cdr.dev/slog/v3"
|
|
"cdr.dev/slog/v3/sloggers/slogtest"
|
|
"github.com/coder/coder/v2/coderd/database"
|
|
"github.com/coder/coder/v2/coderd/x/chatd/chatprompt"
|
|
"github.com/coder/coder/v2/coderd/x/chatd/messagepartbuffer"
|
|
"github.com/coder/coder/v2/codersdk"
|
|
)
|
|
|
|
func TestBuildCommitStepMessages_AssistantTextAndReasoning(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
modelConfigID := uuid.New()
|
|
startedAt := time.Date(2026, 1, 2, 3, 4, 5, 0, time.UTC)
|
|
completedAt := startedAt.Add(2 * time.Second)
|
|
got, err := buildCommitStepMessages(buildCommitStepMessagesInput{
|
|
modelConfigID: modelConfigID,
|
|
contentVersion: chatprompt.CurrentContentVersion,
|
|
logger: slog.Make(),
|
|
step: stepData{
|
|
Content: []fantasy.Content{
|
|
fantasy.ReasoningContent{Text: "thinking"},
|
|
fantasy.TextContent{Text: "hello"},
|
|
},
|
|
ReasoningStartedAt: []time.Time{startedAt},
|
|
ReasoningCompletedAt: []time.Time{completedAt},
|
|
},
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, got.Messages, 1)
|
|
require.Equal(t, []int{0}, got.VisibleIndexes)
|
|
|
|
msg := got.Messages[0]
|
|
require.Equal(t, database.ChatMessageRoleAssistant, msg.Role)
|
|
require.Equal(t, database.ChatMessageVisibilityBoth, msg.Visibility)
|
|
require.Equal(t, uuid.NullUUID{UUID: modelConfigID, Valid: true}, msg.ModelConfigID)
|
|
require.Equal(t, chatprompt.CurrentContentVersion, msg.ContentVersion)
|
|
parts := parseMessageParts(t, msg.Role, msg.Content)
|
|
require.Len(t, parts, 2)
|
|
require.Equal(t, codersdk.ChatMessagePartTypeReasoning, parts[0].Type)
|
|
require.Equal(t, "thinking", parts[0].Text)
|
|
require.Equal(t, startedAt, requireNotNilTime(t, parts[0].CreatedAt))
|
|
require.Equal(t, completedAt, requireNotNilTime(t, parts[0].CompletedAt))
|
|
require.Equal(t, codersdk.ChatMessagePartTypeText, parts[1].Type)
|
|
require.Equal(t, "hello", parts[1].Text)
|
|
}
|
|
|
|
func TestBuildCommitStepMessages_LocalToolResultsBecomeToolMessages(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
modelConfigID := uuid.New()
|
|
got, err := buildCommitStepMessages(buildCommitStepMessagesInput{
|
|
modelConfigID: modelConfigID,
|
|
contentVersion: chatprompt.CurrentContentVersion,
|
|
logger: slog.Make(),
|
|
step: stepData{Content: []fantasy.Content{
|
|
fantasy.ToolCallContent{ToolCallID: "call-1", ToolName: "execute", Input: `{"cmd":"pwd"}`},
|
|
fantasy.ToolResultContent{
|
|
ToolCallID: "call-1",
|
|
ToolName: "execute",
|
|
Result: fantasy.ToolResultOutputContentText{Text: `{"stdout":"/tmp"}`},
|
|
},
|
|
}},
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, got.Messages, 2)
|
|
require.Equal(t, []int{0, 1}, got.VisibleIndexes)
|
|
|
|
assistantParts := parseMessageParts(t, got.Messages[0].Role, got.Messages[0].Content)
|
|
require.Len(t, assistantParts, 1)
|
|
require.Equal(t, codersdk.ChatMessagePartTypeToolCall, assistantParts[0].Type)
|
|
require.Equal(t, "call-1", assistantParts[0].ToolCallID)
|
|
require.Equal(t, "execute", assistantParts[0].ToolName)
|
|
|
|
toolParts := parseMessageParts(t, got.Messages[1].Role, got.Messages[1].Content)
|
|
require.Len(t, toolParts, 1)
|
|
require.Equal(t, codersdk.ChatMessagePartTypeToolResult, toolParts[0].Type)
|
|
require.Equal(t, "call-1", toolParts[0].ToolCallID)
|
|
require.Equal(t, "execute", toolParts[0].ToolName)
|
|
require.JSONEq(t, `{"stdout":"/tmp"}`, string(toolParts[0].Result))
|
|
}
|
|
|
|
func TestBuildCommitStepMessages_ProviderExecutedResultsStayAssistantContent(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
got, err := buildCommitStepMessages(buildCommitStepMessagesInput{
|
|
modelConfigID: uuid.New(),
|
|
contentVersion: chatprompt.CurrentContentVersion,
|
|
logger: slog.Make(),
|
|
step: stepData{Content: []fantasy.Content{
|
|
fantasy.ToolCallContent{
|
|
ToolCallID: "web-1",
|
|
ToolName: "web_search",
|
|
ProviderExecuted: true,
|
|
},
|
|
fantasy.ToolResultContent{
|
|
ToolCallID: "web-1",
|
|
ToolName: "web_search",
|
|
ProviderExecuted: true,
|
|
Result: fantasy.ToolResultOutputContentText{Text: `{"ok":true}`},
|
|
},
|
|
}},
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, got.Messages, 1)
|
|
parts := parseMessageParts(t, got.Messages[0].Role, got.Messages[0].Content)
|
|
require.Len(t, parts, 2)
|
|
require.Equal(t, codersdk.ChatMessagePartTypeToolCall, parts[0].Type)
|
|
require.True(t, parts[0].ProviderExecuted)
|
|
require.Equal(t, codersdk.ChatMessagePartTypeToolResult, parts[1].Type)
|
|
require.True(t, parts[1].ProviderExecuted)
|
|
}
|
|
|
|
func TestBuildCommitStepMessages_UsageCostRuntimeProviderResponseID(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
inputPrice := decimal.NewFromFloat(2.5)
|
|
outputPrice := decimal.NewFromFloat(7.5)
|
|
got, err := buildCommitStepMessages(buildCommitStepMessagesInput{
|
|
modelConfigID: uuid.New(),
|
|
contentVersion: chatprompt.CurrentContentVersion,
|
|
logger: slog.Make(),
|
|
modelCallConfig: codersdk.ChatModelCallConfig{
|
|
Cost: &codersdk.ModelCostConfig{
|
|
InputPricePerMillionTokens: &inputPrice,
|
|
OutputPricePerMillionTokens: &outputPrice,
|
|
},
|
|
},
|
|
step: stepData{
|
|
Content: []fantasy.Content{fantasy.TextContent{Text: "usage"}},
|
|
Usage: fantasy.Usage{InputTokens: 100, OutputTokens: 20, TotalTokens: 120, ReasoningTokens: 3, CacheCreationTokens: 4, CacheReadTokens: 5},
|
|
ContextLimit: sql.NullInt64{Int64: 4096, Valid: true},
|
|
ProviderResponseID: "resp-123",
|
|
Runtime: 1500 * time.Millisecond,
|
|
},
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, got.Messages, 1)
|
|
msg := got.Messages[0]
|
|
require.Equal(t, sql.NullInt64{Int64: 100, Valid: true}, msg.InputTokens)
|
|
require.Equal(t, sql.NullInt64{Int64: 20, Valid: true}, msg.OutputTokens)
|
|
require.Equal(t, sql.NullInt64{Int64: 120, Valid: true}, msg.TotalTokens)
|
|
require.Equal(t, sql.NullInt64{Int64: 3, Valid: true}, msg.ReasoningTokens)
|
|
require.Equal(t, sql.NullInt64{Int64: 4, Valid: true}, msg.CacheCreationTokens)
|
|
require.Equal(t, sql.NullInt64{Int64: 5, Valid: true}, msg.CacheReadTokens)
|
|
require.Equal(t, sql.NullInt64{Int64: 4096, Valid: true}, msg.ContextLimit)
|
|
require.Equal(t, sql.NullInt64{Int64: 1500, Valid: true}, msg.RuntimeMs)
|
|
require.Equal(t, sql.NullString{String: "resp-123", Valid: true}, msg.ProviderResponseID)
|
|
require.True(t, msg.TotalCostMicros.Valid)
|
|
require.Greater(t, msg.TotalCostMicros.Int64, int64(0))
|
|
}
|
|
|
|
func TestBuildCommitStepMessages_ToolTimestampsAndMCPConfigIDs(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
callAt := time.Date(2026, 2, 3, 4, 5, 6, 0, time.UTC)
|
|
resultAt := callAt.Add(3 * time.Second)
|
|
configID := uuid.New()
|
|
got, err := buildCommitStepMessages(buildCommitStepMessagesInput{
|
|
modelConfigID: uuid.New(),
|
|
contentVersion: chatprompt.CurrentContentVersion,
|
|
logger: slog.Make(),
|
|
toolNameToConfigID: map[string]uuid.UUID{
|
|
"mcp_tool": configID,
|
|
},
|
|
step: stepData{Content: []fantasy.Content{
|
|
fantasy.ToolCallContent{ToolCallID: "call-1", ToolName: "mcp_tool", Input: `{}`},
|
|
fantasy.ToolResultContent{ToolCallID: "call-1", ToolName: "mcp_tool", Result: fantasy.ToolResultOutputContentText{Text: `{"ok":true}`}},
|
|
}, ToolCallCreatedAt: map[string]time.Time{
|
|
"call-1": callAt,
|
|
}, ToolResultCreatedAt: map[string]time.Time{
|
|
"call-1": resultAt,
|
|
}},
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, got.Messages, 2)
|
|
callPart := parseMessageParts(t, got.Messages[0].Role, got.Messages[0].Content)[0]
|
|
resultPart := parseMessageParts(t, got.Messages[1].Role, got.Messages[1].Content)[0]
|
|
require.Equal(t, uuid.NullUUID{UUID: configID, Valid: true}, callPart.MCPServerConfigID)
|
|
require.Equal(t, callAt, requireNotNilTime(t, callPart.CreatedAt))
|
|
require.Equal(t, uuid.NullUUID{UUID: configID, Valid: true}, resultPart.MCPServerConfigID)
|
|
require.Equal(t, resultAt, requireNotNilTime(t, resultPart.CreatedAt))
|
|
}
|
|
|
|
func TestBuildCompactionMessages_CompressedSummaryToolCallAndResult(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
modelConfigID := uuid.New()
|
|
got, err := buildCompactionMessages(buildCompactionMessagesInput{
|
|
modelConfigID: modelConfigID,
|
|
contentVersion: chatprompt.CurrentContentVersion,
|
|
toolCallID: "summary-1",
|
|
toolName: "chat_summarized",
|
|
compaction: compactionOutcome{
|
|
SystemSummary: "system summary",
|
|
SummaryReport: "user report",
|
|
ThresholdPercent: 70,
|
|
UsagePercent: 81.5,
|
|
ContextTokens: 815,
|
|
ContextLimit: 1000,
|
|
},
|
|
})
|
|
require.NoError(t, err)
|
|
require.Equal(t, 1, got.HiddenCount)
|
|
require.Len(t, got.Messages, 3)
|
|
|
|
require.Equal(t, database.ChatMessageRoleUser, got.Messages[0].Role)
|
|
require.Equal(t, database.ChatMessageVisibilityModel, got.Messages[0].Visibility)
|
|
require.True(t, got.Messages[0].Compressed)
|
|
require.Equal(t, uuid.NullUUID{UUID: modelConfigID, Valid: true}, got.Messages[0].ModelConfigID)
|
|
require.Equal(t, "system summary", parseMessageParts(t, got.Messages[0].Role, got.Messages[0].Content)[0].Text)
|
|
|
|
require.Equal(t, database.ChatMessageRoleAssistant, got.Messages[1].Role)
|
|
require.Equal(t, database.ChatMessageVisibilityUser, got.Messages[1].Visibility)
|
|
require.True(t, got.Messages[1].Compressed)
|
|
callPart := parseMessageParts(t, got.Messages[1].Role, got.Messages[1].Content)[0]
|
|
require.Equal(t, codersdk.ChatMessagePartTypeToolCall, callPart.Type)
|
|
require.Equal(t, "summary-1", callPart.ToolCallID)
|
|
require.JSONEq(t, `{"source":"automatic","threshold_percent":70}`, string(callPart.Args))
|
|
|
|
require.Equal(t, database.ChatMessageRoleTool, got.Messages[2].Role)
|
|
require.Equal(t, database.ChatMessageVisibilityBoth, got.Messages[2].Visibility)
|
|
require.True(t, got.Messages[2].Compressed)
|
|
resultPart := parseMessageParts(t, got.Messages[2].Role, got.Messages[2].Content)[0]
|
|
require.Equal(t, codersdk.ChatMessagePartTypeToolResult, resultPart.Type)
|
|
require.Equal(t, "summary-1", resultPart.ToolCallID)
|
|
require.JSONEq(t, `{"summary":"user report","source":"automatic","threshold_percent":70,"usage_percent":81.5,"context_tokens":815,"context_limit_tokens":1000}`, string(resultPart.Result))
|
|
}
|
|
|
|
func TestCurrentTurnStepCount_ExcludesCompressedCompactionMessages(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
messages := []database.ChatMessage{
|
|
dbMessage(t, 1, database.ChatMessageRoleUser, false, codersdk.ChatMessageText("start")),
|
|
dbMessage(t, 2, database.ChatMessageRoleAssistant, false, codersdk.ChatMessageText("first")),
|
|
dbMessage(t, 3, database.ChatMessageRoleUser, true, codersdk.ChatMessageText("compressed summary")),
|
|
dbMessage(t, 4, database.ChatMessageRoleAssistant, true, codersdk.ChatMessageToolCall("summary", "chat_summarized", nil)),
|
|
dbMessage(t, 5, database.ChatMessageRoleTool, true, codersdk.ChatMessageToolResult("summary", "chat_summarized", json.RawMessage(`{}`), false, false)),
|
|
dbMessage(t, 6, database.ChatMessageRoleAssistant, false, codersdk.ChatMessageText("second")),
|
|
}
|
|
got := currentTurnStepCount(messages)
|
|
require.Equal(t, 2, got)
|
|
}
|
|
|
|
func TestCurrentTurnStepCount_CountsAssistantMessagesAfterLatestUser(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
messages := []database.ChatMessage{
|
|
dbMessage(t, 1, database.ChatMessageRoleUser, false, codersdk.ChatMessageText("old")),
|
|
dbMessage(t, 2, database.ChatMessageRoleAssistant, false, codersdk.ChatMessageText("old answer")),
|
|
dbMessage(t, 3, database.ChatMessageRoleUser, false, codersdk.ChatMessageText("new")),
|
|
dbMessage(t, 4, database.ChatMessageRoleAssistant, false, codersdk.ChatMessageText("one")),
|
|
dbMessage(t, 5, database.ChatMessageRoleTool, false, codersdk.ChatMessageToolResult("call", "tool", json.RawMessage(`{}`), false, false)),
|
|
dbMessage(t, 6, database.ChatMessageRoleAssistant, false, codersdk.ChatMessageText("two")),
|
|
}
|
|
got := currentTurnStepCount(messages)
|
|
require.Equal(t, 2, got)
|
|
}
|
|
|
|
func TestDecisionDetectsStopAfterToolFromCommittedHistory(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
messages := []database.ChatMessage{
|
|
dbMessage(t, 1, database.ChatMessageRoleUser, false, codersdk.ChatMessageText("plan")),
|
|
dbMessage(t, 2, database.ChatMessageRoleAssistant, false, codersdk.ChatMessageToolCall("plan-1", "propose_plan", json.RawMessage(`{}`))),
|
|
dbMessage(t, 3, database.ChatMessageRoleTool, false, codersdk.ChatMessageToolResult("plan-1", "propose_plan", json.RawMessage(`{"ok":true}`), false, false)),
|
|
}
|
|
got, err := historyHasStopAfterToolResult(messages, map[string]struct{}{"propose_plan": {}})
|
|
require.NoError(t, err)
|
|
require.True(t, got)
|
|
|
|
messages[2] = dbMessage(t, 3, database.ChatMessageRoleTool, false, codersdk.ChatMessageToolResult("plan-1", "propose_plan", json.RawMessage(`{"error":"no"}`), true, false))
|
|
got, err = historyHasStopAfterToolResult(messages, map[string]struct{}{"propose_plan": {}})
|
|
require.NoError(t, err)
|
|
require.False(t, got)
|
|
}
|
|
|
|
func TestDecisionDetectsCurrentHistoryCompletion(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
complete, err := currentHistoryComplete([]database.ChatMessage{
|
|
dbMessage(t, 1, database.ChatMessageRoleUser, false, codersdk.ChatMessageText("hello")),
|
|
dbMessage(t, 2, database.ChatMessageRoleAssistant, false, codersdk.ChatMessageText("done")),
|
|
})
|
|
require.NoError(t, err)
|
|
require.True(t, complete)
|
|
|
|
complete, err = currentHistoryComplete([]database.ChatMessage{
|
|
dbMessage(t, 1, database.ChatMessageRoleUser, false, codersdk.ChatMessageText("hello")),
|
|
dbMessage(t, 2, database.ChatMessageRoleAssistant, false, codersdk.ChatMessageToolCall("call-1", "execute", json.RawMessage(`{}`))),
|
|
})
|
|
require.NoError(t, err)
|
|
require.False(t, complete)
|
|
|
|
complete, err = currentHistoryComplete([]database.ChatMessage{
|
|
dbMessage(t, 1, database.ChatMessageRoleUser, false, codersdk.ChatMessageText("hello")),
|
|
dbMessage(t, 2, database.ChatMessageRoleAssistant, false, codersdk.ChatMessageToolCall("call-1", "execute", json.RawMessage(`{}`))),
|
|
dbMessage(t, 3, database.ChatMessageRoleTool, false, codersdk.ChatMessageToolResult("call-1", "execute", json.RawMessage(`{"ok":true}`), false, false)),
|
|
})
|
|
require.NoError(t, err)
|
|
require.False(t, complete)
|
|
}
|
|
|
|
func TestBufferedPartsToPartialMessages_NormalizesToolCallDeltasBeforeFinal(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
createdAt := time.Date(2026, 3, 4, 5, 6, 7, 0, time.UTC)
|
|
parts := []messagepartbuffer.Part{
|
|
{Seq: 1, Role: codersdk.ChatMessageRoleAssistant, MessagePart: codersdk.ChatMessageText("partial ")},
|
|
{Seq: 2, Role: codersdk.ChatMessageRoleAssistant, MessagePart: codersdk.ChatMessagePart{Type: codersdk.ChatMessagePartTypeToolCall, ToolCallID: "call-1", ToolName: "execute", ArgsDelta: `{"cmd":`}},
|
|
{Seq: 3, Role: codersdk.ChatMessageRoleAssistant, MessagePart: codersdk.ChatMessagePart{Type: codersdk.ChatMessagePartTypeToolCall, ToolCallID: "call-1", ToolName: "execute", ArgsDelta: `"ignored"}`}},
|
|
{Seq: 4, Role: codersdk.ChatMessageRoleAssistant, MessagePart: codersdk.ChatMessageToolCall("call-1", "execute", json.RawMessage(`{"cmd":"pwd"}`))},
|
|
}
|
|
got, err := bufferedPartsToPartialMessages(bufferedPartsToPartialMessagesInput{
|
|
parts: parts,
|
|
modelConfigID: uuid.New(),
|
|
contentVersion: chatprompt.CurrentContentVersion,
|
|
logger: slog.Make(),
|
|
interruptedAt: createdAt,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, got, 2)
|
|
assistantParts := parseMessageParts(t, got[0].Role, got[0].Content)
|
|
require.Len(t, assistantParts, 2)
|
|
require.Equal(t, codersdk.ChatMessagePartTypeText, assistantParts[0].Type)
|
|
call := assistantParts[1]
|
|
require.Equal(t, codersdk.ChatMessagePartTypeToolCall, call.Type)
|
|
require.Equal(t, "call-1", call.ToolCallID)
|
|
require.Empty(t, call.ArgsDelta)
|
|
require.JSONEq(t, `{"cmd":"pwd"}`, string(call.Args))
|
|
syntheticParts := parseMessageParts(t, got[1].Role, got[1].Content)
|
|
require.Len(t, syntheticParts, 1)
|
|
require.Equal(t, "call-1", syntheticParts[0].ToolCallID)
|
|
}
|
|
|
|
func TestBufferedPartsToPartialMessages_MergesToolCallDeltasWithoutFinal(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
parts := []messagepartbuffer.Part{
|
|
{Seq: 1, Role: codersdk.ChatMessageRoleAssistant, MessagePart: codersdk.ChatMessagePart{Type: codersdk.ChatMessagePartTypeToolCall, ToolCallID: "call-1", ToolName: "execute", ArgsDelta: `{"cmd":`}},
|
|
{Seq: 2, Role: codersdk.ChatMessageRoleAssistant, MessagePart: codersdk.ChatMessagePart{Type: codersdk.ChatMessagePartTypeToolCall, ToolCallID: "call-1", ToolName: "execute", ArgsDelta: `"pwd"}`}},
|
|
}
|
|
got, err := bufferedPartsToPartialMessages(bufferedPartsToPartialMessagesInput{
|
|
parts: parts,
|
|
modelConfigID: uuid.New(),
|
|
contentVersion: chatprompt.CurrentContentVersion,
|
|
logger: slog.Make(),
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, got, 2)
|
|
assistantParts := parseMessageParts(t, got[0].Role, got[0].Content)
|
|
require.Len(t, assistantParts, 1)
|
|
require.Empty(t, assistantParts[0].ArgsDelta)
|
|
require.JSONEq(t, `{"cmd":"pwd"}`, string(assistantParts[0].Args))
|
|
syntheticParts := parseMessageParts(t, got[1].Role, got[1].Content)
|
|
require.Len(t, syntheticParts, 1)
|
|
require.Equal(t, "call-1", syntheticParts[0].ToolCallID)
|
|
}
|
|
|
|
func TestBufferedPartsToPartialMessages_DeltaOnlyToolResultDoesNotAnswer(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
logSink := &partialConversionLogSink{}
|
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).AppendSinks(logSink)
|
|
parts := []messagepartbuffer.Part{
|
|
{Seq: 1, Role: codersdk.ChatMessageRoleAssistant, MessagePart: codersdk.ChatMessageToolCall("call-1", "advisor", json.RawMessage(`{}`))},
|
|
{Seq: 2, Role: codersdk.ChatMessageRoleTool, MessagePart: codersdk.ChatMessagePart{Type: codersdk.ChatMessagePartTypeToolResult, ToolCallID: "call-1", ToolName: "advisor", ResultDelta: `{"type":"advice"}`}},
|
|
}
|
|
got, err := bufferedPartsToPartialMessages(bufferedPartsToPartialMessagesInput{
|
|
parts: parts,
|
|
modelConfigID: uuid.New(),
|
|
contentVersion: chatprompt.CurrentContentVersion,
|
|
logger: logger,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, got, 2)
|
|
toolParts := parseMessageParts(t, got[1].Role, got[1].Content)
|
|
require.Len(t, toolParts, 1)
|
|
require.Equal(t, "call-1", toolParts[0].ToolCallID)
|
|
require.True(t, toolParts[0].IsError)
|
|
require.Empty(t, toolParts[0].ResultDelta)
|
|
require.JSONEq(t, `{"error":"tool call was interrupted before it produced a result"}`, string(toolParts[0].Result))
|
|
require.NotEmpty(t, logSink.entriesAtLevelWithMessage(slog.LevelWarn, "skipping buffered chat message part"))
|
|
}
|
|
|
|
func TestBufferedPartsToPartialMessages_LogsMalformedSkippedParts(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
logSink := &partialConversionLogSink{}
|
|
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).AppendSinks(logSink)
|
|
parts := []messagepartbuffer.Part{
|
|
{Seq: 1, Role: codersdk.ChatMessageRoleSystem, MessagePart: codersdk.ChatMessageText("bad role")},
|
|
{Seq: 2, Role: codersdk.ChatMessageRoleAssistant, MessagePart: codersdk.ChatMessagePart{}},
|
|
{Seq: 3, Role: codersdk.ChatMessageRoleTool, MessagePart: codersdk.ChatMessagePart{Type: codersdk.ChatMessagePartTypeToolResult, ToolName: "execute", Result: json.RawMessage(`{"ok":true}`)}},
|
|
{Seq: 4, Role: codersdk.ChatMessageRoleAssistant, MessagePart: codersdk.ChatMessagePart{Type: codersdk.ChatMessagePartTypeToolCall, ToolCallID: "bad-args", ToolName: "execute", ArgsDelta: `{"cmd":`}},
|
|
}
|
|
got, err := bufferedPartsToPartialMessages(bufferedPartsToPartialMessagesInput{
|
|
parts: parts,
|
|
modelConfigID: uuid.New(),
|
|
contentVersion: chatprompt.CurrentContentVersion,
|
|
logger: logger,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Empty(t, got)
|
|
require.GreaterOrEqual(t, len(logSink.entriesAtLevelWithMessage(slog.LevelWarn, "skipping buffered chat message part")), 4)
|
|
}
|
|
|
|
func TestBufferedPartsToPartialMessages_SynthesizesMissingToolResults(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
modelConfigID := uuid.New()
|
|
createdAt := time.Date(2026, 3, 4, 5, 6, 7, 0, time.UTC)
|
|
reasoningStartedAt := createdAt.Add(-2 * time.Second)
|
|
reasoningPart := codersdk.ChatMessageReasoning("partial thought")
|
|
reasoningPart.CreatedAt = &reasoningStartedAt
|
|
parts := []messagepartbuffer.Part{
|
|
{Seq: 1, Role: codersdk.ChatMessageRoleAssistant, MessagePart: codersdk.ChatMessageText("partial ")},
|
|
{Seq: 2, Role: codersdk.ChatMessageRoleAssistant, MessagePart: reasoningPart},
|
|
{Seq: 3, Role: codersdk.ChatMessageRoleAssistant, MessagePart: codersdk.ChatMessageToolCall("call-1", "execute", json.RawMessage(`{}`))},
|
|
{Seq: 4, Role: codersdk.ChatMessageRoleAssistant, MessagePart: codersdk.ChatMessageToolCall("call-2", "read_file", json.RawMessage(`{}`))},
|
|
{Seq: 5, Role: codersdk.ChatMessageRoleTool, MessagePart: withCreatedAt(codersdk.ChatMessageToolResult("call-2", "read_file", json.RawMessage(`{"ok":true}`), false, false), createdAt)},
|
|
}
|
|
got, err := bufferedPartsToPartialMessages(bufferedPartsToPartialMessagesInput{
|
|
parts: parts,
|
|
modelConfigID: modelConfigID,
|
|
contentVersion: chatprompt.CurrentContentVersion,
|
|
logger: slog.Make(),
|
|
interruptedAt: createdAt,
|
|
})
|
|
require.NoError(t, err)
|
|
require.Len(t, got, 3)
|
|
require.Equal(t, database.ChatMessageRoleAssistant, got[0].Role)
|
|
assistantParts := parseMessageParts(t, got[0].Role, got[0].Content)
|
|
require.Len(t, assistantParts, 4)
|
|
require.Equal(t, codersdk.ChatMessagePartTypeReasoning, assistantParts[1].Type)
|
|
require.Equal(t, "partial thought", assistantParts[1].Text)
|
|
require.Equal(t, reasoningStartedAt, requireNotNilTime(t, assistantParts[1].CreatedAt))
|
|
require.Equal(t, createdAt, requireNotNilTime(t, assistantParts[1].CompletedAt))
|
|
require.Equal(t, codersdk.ChatMessagePartTypeToolCall, assistantParts[2].Type)
|
|
require.Equal(t, codersdk.ChatMessagePartTypeToolCall, assistantParts[3].Type)
|
|
|
|
require.Equal(t, database.ChatMessageRoleTool, got[1].Role)
|
|
toolParts := parseMessageParts(t, got[1].Role, got[1].Content)
|
|
require.Equal(t, "call-2", toolParts[0].ToolCallID)
|
|
require.Equal(t, createdAt, requireNotNilTime(t, toolParts[0].CreatedAt))
|
|
|
|
require.Equal(t, database.ChatMessageRoleTool, got[2].Role)
|
|
syntheticParts := parseMessageParts(t, got[2].Role, got[2].Content)
|
|
require.Len(t, syntheticParts, 1)
|
|
require.Equal(t, "call-1", syntheticParts[0].ToolCallID)
|
|
require.Equal(t, "execute", syntheticParts[0].ToolName)
|
|
require.True(t, syntheticParts[0].IsError)
|
|
require.JSONEq(t, `{"error":"tool call was interrupted before it produced a result"}`, string(syntheticParts[0].Result))
|
|
require.Equal(t, createdAt, requireNotNilTime(t, syntheticParts[0].CreatedAt))
|
|
require.Equal(t, uuid.NullUUID{UUID: modelConfigID, Valid: true}, got[2].ModelConfigID)
|
|
}
|
|
|
|
func parseMessageParts(t *testing.T, role database.ChatMessageRole, raw pqtype.NullRawMessage) []codersdk.ChatMessagePart {
|
|
t.Helper()
|
|
parts, err := chatprompt.ParseContent(database.ChatMessage{
|
|
Role: role,
|
|
Content: raw,
|
|
})
|
|
require.NoError(t, err)
|
|
return parts
|
|
}
|
|
|
|
func dbMessage(t *testing.T, id int64, role database.ChatMessageRole, compressed bool, parts ...codersdk.ChatMessagePart) database.ChatMessage {
|
|
t.Helper()
|
|
raw, err := chatprompt.MarshalParts(parts)
|
|
require.NoError(t, err)
|
|
return database.ChatMessage{
|
|
ID: id,
|
|
Role: role,
|
|
Content: raw,
|
|
ContentVersion: chatprompt.CurrentContentVersion,
|
|
Visibility: database.ChatMessageVisibilityBoth,
|
|
Compressed: compressed,
|
|
}
|
|
}
|
|
|
|
func requireNotNilTime(t *testing.T, value *time.Time) time.Time {
|
|
t.Helper()
|
|
require.NotNil(t, value)
|
|
return *value
|
|
}
|
|
|
|
func withCreatedAt(part codersdk.ChatMessagePart, createdAt time.Time) codersdk.ChatMessagePart {
|
|
part.CreatedAt = &createdAt
|
|
return part
|
|
}
|
|
|
|
type partialConversionLogSink struct {
|
|
mu sync.Mutex
|
|
entries []slog.SinkEntry
|
|
}
|
|
|
|
func (s *partialConversionLogSink) LogEntry(_ context.Context, entry slog.SinkEntry) {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
s.entries = append(s.entries, entry)
|
|
}
|
|
|
|
func (*partialConversionLogSink) Sync() {}
|
|
|
|
func (s *partialConversionLogSink) entriesAtLevelWithMessage(level slog.Level, message string) []slog.SinkEntry {
|
|
s.mu.Lock()
|
|
defer s.mu.Unlock()
|
|
|
|
entries := make([]slog.SinkEntry, 0, len(s.entries))
|
|
for _, entry := range s.entries {
|
|
if entry.Level == level && entry.Message == message {
|
|
entries = append(entries, entry)
|
|
}
|
|
}
|
|
return entries
|
|
}
|