mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
fix: save media message type to db (#23427)
We had a bug where computer use base64-encoded screenshots would not be interpreted as screenshots anymore once saved to the db, loaded back into memory, and sent to Anthropic. Instead, they would be interpreted as regular text. Once a computer use agent made enough screenshots and stopped, and you tried sending it another message, you'd get an out of context error: <img width="808" height="367" alt="Screenshot 2026-03-23 at 12 02 54" src="https://github.com/user-attachments/assets/f0bf6be2-4863-47ca-a7a9-9e6d9dfceeed" /> This PR fixes that.
This commit is contained in:
@@ -3693,7 +3693,7 @@ func (p *Server) persistChatContextSummary(
|
||||
return xerrors.Errorf("encode summary result payload: %w", err)
|
||||
}
|
||||
toolResult, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{
|
||||
codersdk.ChatMessageToolResult(toolCallID, "chat_summarized", summaryResult, false),
|
||||
codersdk.ChatMessageToolResult(toolCallID, "chat_summarized", summaryResult, false, false),
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("encode summary tool result: %w", err)
|
||||
|
||||
@@ -160,7 +160,7 @@ func tryCompact(
|
||||
})
|
||||
config.PublishMessagePart(
|
||||
codersdk.ChatMessageRoleTool,
|
||||
codersdk.ChatMessageToolResult(config.ToolCallID, config.ToolName, resultJSON, false),
|
||||
codersdk.ChatMessageToolResult(config.ToolCallID, config.ToolName, resultJSON, false, false),
|
||||
)
|
||||
}
|
||||
|
||||
@@ -178,7 +178,7 @@ func publishCompactionError(config CompactionOptions, msg string) {
|
||||
})
|
||||
config.PublishMessagePart(
|
||||
codersdk.ChatMessageRoleTool,
|
||||
codersdk.ChatMessageToolResult(config.ToolCallID, config.ToolName, errJSON, true),
|
||||
codersdk.ChatMessageToolResult(config.ToolCallID, config.ToolName, errJSON, true, false),
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
@@ -405,7 +405,7 @@ func parseToolRole(raw pqtype.NullRawMessage) ([]codersdk.ChatMessagePart, error
|
||||
}
|
||||
parts = make([]codersdk.ChatMessagePart, 0, len(rows))
|
||||
for _, row := range rows {
|
||||
part := codersdk.ChatMessageToolResult(row.ToolCallID, row.ToolName, row.Result, row.IsError)
|
||||
part := codersdk.ChatMessageToolResult(row.ToolCallID, row.ToolName, row.Result, row.IsError, row.IsMedia)
|
||||
part.ProviderExecuted = row.ProviderExecuted
|
||||
part.ProviderMetadata = row.ProviderMetadata
|
||||
parts = append(parts, part)
|
||||
@@ -528,6 +528,7 @@ type toolResultRaw struct {
|
||||
ToolName string `json:"tool_name"`
|
||||
Result json.RawMessage `json:"result"`
|
||||
IsError bool `json:"is_error,omitempty"`
|
||||
IsMedia bool `json:"is_media,omitempty"`
|
||||
ProviderExecuted bool `json:"provider_executed,omitempty"`
|
||||
ProviderMetadata json.RawMessage `json:"provider_metadata,omitempty"`
|
||||
}
|
||||
@@ -669,8 +670,8 @@ func MarshalContent(blocks []fantasy.Content, fileIDs map[int]uuid.UUID) (pqtype
|
||||
// tool-row format. Retained for test fixtures that create
|
||||
// legacy-format DB rows. Production write paths use MarshalParts.
|
||||
// The stored shape is
|
||||
// [{"tool_call_id":…,"tool_name":…,"result":…,"is_error":…}].
|
||||
func MarshalToolResult(toolCallID, toolName string, result json.RawMessage, isError bool, providerExecuted bool, providerMetadata fantasy.ProviderMetadata) (pqtype.NullRawMessage, error) {
|
||||
// [{"tool_call_id":…,"tool_name":…,"result":…,"is_error":…,"is_media":…}].
|
||||
func MarshalToolResult(toolCallID, toolName string, result json.RawMessage, isError bool, isMedia bool, providerExecuted bool, providerMetadata fantasy.ProviderMetadata) (pqtype.NullRawMessage, error) {
|
||||
var metaJSON json.RawMessage
|
||||
if len(providerMetadata) > 0 {
|
||||
var err error
|
||||
@@ -684,6 +685,7 @@ func MarshalToolResult(toolCallID, toolName string, result json.RawMessage, isEr
|
||||
ToolName: toolName,
|
||||
Result: result,
|
||||
IsError: isError,
|
||||
IsMedia: isMedia,
|
||||
ProviderExecuted: providerExecuted,
|
||||
ProviderMetadata: metaJSON,
|
||||
}
|
||||
@@ -779,18 +781,20 @@ func PartFromContent(block fantasy.Content) codersdk.ChatMessagePart {
|
||||
}
|
||||
}
|
||||
|
||||
// ToolResultToPart converts a tool call ID, raw result, and error
|
||||
// flag into a ChatMessagePart. This is the minimal conversion used
|
||||
// both during streaming and when reading from the database.
|
||||
func ToolResultToPart(toolCallID, toolName string, result json.RawMessage, isError bool) codersdk.ChatMessagePart {
|
||||
return codersdk.ChatMessageToolResult(toolCallID, toolName, result, isError)
|
||||
// ToolResultToPart converts a tool call ID, raw result, error flag,
|
||||
// and media flag into a ChatMessagePart. This is the minimal
|
||||
// conversion used both during streaming and when reading from the
|
||||
// database.
|
||||
func ToolResultToPart(toolCallID, toolName string, result json.RawMessage, isError bool, isMedia bool) codersdk.ChatMessagePart {
|
||||
return codersdk.ChatMessageToolResult(toolCallID, toolName, result, isError, isMedia)
|
||||
}
|
||||
|
||||
// toolResultContentToPart converts a fantasy ToolResultContent
|
||||
// directly into a ChatMessagePart without an intermediate struct.
|
||||
// toolResultContentToPart converts a fantasy ToolResultContent into a
|
||||
// ChatMessagePart.
|
||||
func toolResultContentToPart(content fantasy.ToolResultContent) codersdk.ChatMessagePart {
|
||||
var result json.RawMessage
|
||||
var isError bool
|
||||
var isMedia bool
|
||||
|
||||
switch output := content.Result.(type) {
|
||||
case fantasy.ToolResultOutputContentError:
|
||||
@@ -807,16 +811,17 @@ func toolResultContentToPart(content fantasy.ToolResultContent) codersdk.ChatMes
|
||||
result, _ = json.Marshal(map[string]any{"output": output.Text})
|
||||
}
|
||||
case fantasy.ToolResultOutputContentMedia:
|
||||
result, _ = json.Marshal(map[string]any{
|
||||
"data": output.Data,
|
||||
"mime_type": output.MediaType,
|
||||
"text": output.Text,
|
||||
isMedia = true
|
||||
result, _ = json.Marshal(persistedMediaResult{
|
||||
Data: output.Data,
|
||||
MimeType: output.MediaType,
|
||||
Text: output.Text,
|
||||
})
|
||||
default:
|
||||
result = []byte(`{}`)
|
||||
}
|
||||
|
||||
part := ToolResultToPart(content.ToolCallID, content.ToolName, result, isError)
|
||||
part := ToolResultToPart(content.ToolCallID, content.ToolName, result, isError, isMedia)
|
||||
part.ProviderExecuted = content.ProviderExecuted
|
||||
part.ProviderMetadata = marshalProviderMetadata(content.ProviderMetadata)
|
||||
return part
|
||||
@@ -1213,6 +1218,44 @@ func toolResultPartToMessagePart(logger slog.Logger, part codersdk.ChatMessagePa
|
||||
}
|
||||
}
|
||||
|
||||
// IsError takes precedence and is handled above.
|
||||
// Detect media content flagged by toolResultContentToPart.
|
||||
// Screenshots from the computer use tool are stored as
|
||||
// {"data":"<base64>","mime_type":"image/png","text":"..."}.
|
||||
// Without this detection, the entire base64 payload is sent
|
||||
// as text tokens, which quickly exceeds the context limit
|
||||
// on follow-up messages.
|
||||
if part.IsMedia {
|
||||
var media persistedMediaResult
|
||||
unmarshalErr := json.Unmarshal(part.Result, &media)
|
||||
if unmarshalErr == nil && media.Data != "" && media.MimeType != "" {
|
||||
return fantasy.ToolResultPart{
|
||||
ToolCallID: toolCallID,
|
||||
ProviderExecuted: part.ProviderExecuted,
|
||||
Output: fantasy.ToolResultOutputContentMedia{
|
||||
Data: media.Data,
|
||||
MediaType: media.MimeType,
|
||||
Text: media.Text,
|
||||
},
|
||||
ProviderOptions: opts,
|
||||
}
|
||||
}
|
||||
|
||||
fields := []slog.Field{
|
||||
slog.F("tool_call_id", toolCallID),
|
||||
slog.F("tool_name", part.ToolName),
|
||||
slog.F("has_data", media.Data != ""),
|
||||
slog.F("has_mime_type", media.MimeType != ""),
|
||||
}
|
||||
if unmarshalErr != nil {
|
||||
fields = append(fields, slog.Error(unmarshalErr))
|
||||
}
|
||||
logger.Warn(context.Background(),
|
||||
"media tool result failed reconstruction, falling through to text",
|
||||
fields...,
|
||||
)
|
||||
}
|
||||
|
||||
return fantasy.ToolResultPart{
|
||||
ToolCallID: toolCallID,
|
||||
ProviderExecuted: part.ProviderExecuted,
|
||||
@@ -1223,6 +1266,21 @@ func toolResultPartToMessagePart(logger slog.Logger, part codersdk.ChatMessagePa
|
||||
}
|
||||
}
|
||||
|
||||
// persistedMediaResult is the JSON shape used to store media tool
|
||||
// results (e.g. computer-use screenshots) in the database. Both
|
||||
// the write path (toolResultContentToPart) and the read path
|
||||
// (toolResultPartToMessagePart) use this struct so the two sides
|
||||
// cannot drift.
|
||||
//
|
||||
// The "mime_type" key intentionally diverges from the fantasy
|
||||
// struct tag (json:"media_type"). Do not change it without
|
||||
// updating both paths.
|
||||
type persistedMediaResult struct {
|
||||
Data string `json:"data"`
|
||||
MimeType string `json:"mime_type"`
|
||||
Text string `json:"text"`
|
||||
}
|
||||
|
||||
// partsToMessageParts converts SDK chat message parts into fantasy
|
||||
// message parts for LLM dispatch. It handles file data injection
|
||||
// from resolved files, file-reference to text conversion, and
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"github.com/sqlc-dev/pqtype"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog/v3/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
@@ -92,6 +93,7 @@ func TestConvertMessages_NormalizesAssistantToolCallInput(t *testing.T) {
|
||||
json.RawMessage(`{"error":"tool call was interrupted before it produced a result"}`),
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
nil,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
@@ -297,7 +299,7 @@ func TestInjectMissingToolResults_SkipsProviderExecuted(t *testing.T) {
|
||||
localResult := mustMarshalToolResult(t,
|
||||
"toolu_local", "spawn_agent",
|
||||
json.RawMessage(`{"status":"done"}`),
|
||||
false, false,
|
||||
false, false, false,
|
||||
)
|
||||
|
||||
prompt, err := chatprompt.ConvertMessages([]database.ChatMessage{
|
||||
@@ -361,12 +363,12 @@ func TestInjectMissingToolUses_DropsProviderExecutedOrphans(t *testing.T) {
|
||||
resultA := mustMarshalToolResult(t,
|
||||
"toolu_A", "spawn_agent",
|
||||
json.RawMessage(`{"status":"done"}`),
|
||||
false, false,
|
||||
false, false, false,
|
||||
)
|
||||
resultB := mustMarshalToolResult(t,
|
||||
"toolu_B", "spawn_agent",
|
||||
json.RawMessage(`{"status":"done"}`),
|
||||
false, false,
|
||||
false, false, false,
|
||||
)
|
||||
|
||||
// Step 2: assistant with sources/text + wait_agent x2.
|
||||
@@ -389,17 +391,17 @@ func TestInjectMissingToolUses_DropsProviderExecutedOrphans(t *testing.T) {
|
||||
resultC := mustMarshalToolResult(t,
|
||||
"srvtoolu_C", "web_search",
|
||||
json.RawMessage(`{}`),
|
||||
false, true, // provider_executed = true
|
||||
false, false, true, // provider_executed = true
|
||||
)
|
||||
resultD := mustMarshalToolResult(t,
|
||||
"toolu_D", "wait_agent",
|
||||
json.RawMessage(`{"report":"done"}`),
|
||||
false, false,
|
||||
false, false, false,
|
||||
)
|
||||
resultE := mustMarshalToolResult(t,
|
||||
"toolu_E", "wait_agent",
|
||||
json.RawMessage(`{"report":"done"}`),
|
||||
false, false,
|
||||
false, false, false,
|
||||
)
|
||||
|
||||
prompt, err := chatprompt.ConvertMessages([]database.ChatMessage{
|
||||
@@ -475,7 +477,7 @@ func TestInjectMissingToolUses_DropsOnlyProviderExecutedMessage(t *testing.T) {
|
||||
localResult := mustMarshalToolResult(t,
|
||||
"toolu_local", "execute",
|
||||
json.RawMessage(`{"output":"file.txt"}`),
|
||||
false, false,
|
||||
false, false, false,
|
||||
)
|
||||
|
||||
// Second assistant with only local tool call.
|
||||
@@ -487,7 +489,7 @@ func TestInjectMissingToolUses_DropsOnlyProviderExecutedMessage(t *testing.T) {
|
||||
peResult := mustMarshalToolResult(t,
|
||||
"srvtoolu_orphan", "web_search",
|
||||
json.RawMessage(`{}`),
|
||||
false, true,
|
||||
false, false, true,
|
||||
)
|
||||
|
||||
prompt, err := chatprompt.ConvertMessages([]database.ChatMessage{
|
||||
@@ -600,12 +602,12 @@ func TestProviderExecutedResult_LegacyToolRow(t *testing.T) {
|
||||
peResult := mustMarshalToolResult(t,
|
||||
"srvtoolu_WS", "web_search",
|
||||
json.RawMessage(`{"results":"cached"}`),
|
||||
false, true, // providerExecuted = true
|
||||
false, false, true, // providerExecuted = true
|
||||
)
|
||||
execResult := mustMarshalToolResult(t,
|
||||
"toolu_exec", "execute",
|
||||
json.RawMessage(`{"output":"file.txt"}`),
|
||||
false, false,
|
||||
false, false, false,
|
||||
)
|
||||
|
||||
prompt, err := chatprompt.ConvertMessages([]database.ChatMessage{
|
||||
@@ -1193,7 +1195,7 @@ func TestMixedFormatConversation(t *testing.T) {
|
||||
// 4. Old tool (legacy result rows).
|
||||
oldToolRaw, err := chatprompt.MarshalToolResult(
|
||||
"call_1", "analyze_image",
|
||||
json.RawMessage(`{"description":"a cat"}`), false,
|
||||
json.RawMessage(`{"description":"a cat"}`), false, false,
|
||||
false, nil,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
@@ -1425,9 +1427,9 @@ func mustMarshalContent(t *testing.T, content []fantasy.Content) pqtype.NullRawM
|
||||
return result
|
||||
}
|
||||
|
||||
func mustMarshalToolResult(t *testing.T, toolCallID, toolName string, result json.RawMessage, isError, providerExecuted bool) pqtype.NullRawMessage {
|
||||
func mustMarshalToolResult(t *testing.T, toolCallID, toolName string, result json.RawMessage, isError, isMedia, providerExecuted bool) pqtype.NullRawMessage {
|
||||
t.Helper()
|
||||
raw, err := chatprompt.MarshalToolResult(toolCallID, toolName, result, isError, providerExecuted, nil)
|
||||
raw, err := chatprompt.MarshalToolResult(toolCallID, toolName, result, isError, isMedia, providerExecuted, nil)
|
||||
require.NoError(t, err)
|
||||
return raw
|
||||
}
|
||||
@@ -1608,7 +1610,7 @@ func TestNulEscapeRoundTrip(t *testing.T) {
|
||||
|
||||
resultJSON := json.RawMessage(`"output:\u0000done"`)
|
||||
parts := []codersdk.ChatMessagePart{
|
||||
codersdk.ChatMessageToolResult("call-1", "my_tool", resultJSON, false),
|
||||
codersdk.ChatMessageToolResult("call-1", "my_tool", resultJSON, false, false),
|
||||
}
|
||||
|
||||
encoded, err := chatprompt.MarshalParts(parts)
|
||||
@@ -1676,7 +1678,7 @@ func TestConvertMessagesWithFiles_FiltersEmptyTextAndReasoningParts(t *testing.T
|
||||
codersdk.ChatMessageText(" hello "), // kept with original whitespace
|
||||
codersdk.ChatMessageReasoning("thinking deeply"), // kept
|
||||
codersdk.ChatMessageToolCall("call-1", "my_tool", json.RawMessage(`{"x":1}`)),
|
||||
codersdk.ChatMessageToolResult("call-1", "my_tool", json.RawMessage(`{"ok":true}`), false),
|
||||
codersdk.ChatMessageToolResult("call-1", "my_tool", json.RawMessage(`{"ok":true}`), false, false),
|
||||
}
|
||||
|
||||
prompt, err := chatprompt.ConvertMessagesWithFiles(
|
||||
@@ -1926,3 +1928,400 @@ func convertSingleResolvedFileMessage(t *testing.T, fileID uuid.UUID, fileData c
|
||||
require.NoError(t, err)
|
||||
return prompt
|
||||
}
|
||||
|
||||
func TestMediaToolResultRoundTrip(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Full DB round-trip test: insert messages into PostgreSQL,
|
||||
// load them back via GetChatMessagesForPromptByChatID, and
|
||||
// verify the fantasy message parts are identical after the
|
||||
// round-trip.
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
|
||||
user := dbgen.User(t, db, database.User{})
|
||||
|
||||
_, err := db.InsertChatProvider(ctx, database.InsertChatProviderParams{
|
||||
Provider: "anthropic",
|
||||
DisplayName: "anthropic",
|
||||
APIKey: "test-key",
|
||||
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
||||
Enabled: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
model, err := db.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
|
||||
Provider: "anthropic",
|
||||
Model: "test-model",
|
||||
DisplayName: "Test Model",
|
||||
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
||||
UpdatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
||||
Enabled: true,
|
||||
IsDefault: true,
|
||||
ContextLimit: 200000,
|
||||
CompressionThreshold: 70,
|
||||
Options: json.RawMessage(`{}`),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Small base64 payload standing in for a real screenshot.
|
||||
const imageData = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAAC0lEQVQI12NgAAIABQAB"
|
||||
|
||||
// insertPair writes an assistant tool-call message and a
|
||||
// tool-result message into the database, returning the chat
|
||||
// they belong to.
|
||||
insertPair := func(
|
||||
t *testing.T,
|
||||
callID, toolName string,
|
||||
resultParts []codersdk.ChatMessagePart,
|
||||
) database.Chat {
|
||||
t.Helper()
|
||||
|
||||
chat, chatErr := db.InsertChat(ctx, database.InsertChatParams{
|
||||
OwnerID: user.ID,
|
||||
LastModelConfigID: model.ID,
|
||||
Title: "media-roundtrip-" + callID,
|
||||
})
|
||||
require.NoError(t, chatErr)
|
||||
|
||||
// Assistant message with the tool call.
|
||||
callPart := codersdk.ChatMessageToolCall(callID, toolName, json.RawMessage(`{}`))
|
||||
assistantEncoded, encErr := chatprompt.MarshalParts([]codersdk.ChatMessagePart{callPart})
|
||||
require.NoError(t, encErr)
|
||||
|
||||
// Tool result message.
|
||||
resultEncoded, encErr := chatprompt.MarshalParts(resultParts)
|
||||
require.NoError(t, encErr)
|
||||
|
||||
_, insertErr := db.InsertChatMessages(ctx, database.InsertChatMessagesParams{
|
||||
ChatID: chat.ID,
|
||||
CreatedBy: []uuid.UUID{user.ID, user.ID},
|
||||
ModelConfigID: []uuid.UUID{model.ID, model.ID},
|
||||
Role: []database.ChatMessageRole{database.ChatMessageRoleAssistant, database.ChatMessageRoleTool},
|
||||
Content: []string{string(assistantEncoded.RawMessage), string(resultEncoded.RawMessage)},
|
||||
ContentVersion: []int16{chatprompt.CurrentContentVersion, chatprompt.CurrentContentVersion},
|
||||
Visibility: []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth, database.ChatMessageVisibilityBoth},
|
||||
InputTokens: []int64{0, 0},
|
||||
OutputTokens: []int64{0, 0},
|
||||
TotalTokens: []int64{0, 0},
|
||||
ReasoningTokens: []int64{0, 0},
|
||||
CacheCreationTokens: []int64{0, 0},
|
||||
CacheReadTokens: []int64{0, 0},
|
||||
ContextLimit: []int64{0, 0},
|
||||
Compressed: []bool{false, false},
|
||||
TotalCostMicros: []int64{0, 0},
|
||||
RuntimeMs: []int64{0, 0},
|
||||
})
|
||||
require.NoError(t, insertErr)
|
||||
return chat
|
||||
}
|
||||
|
||||
// loadPrompt reads messages back from the DB via the same
|
||||
// path used by runChat, and converts them to fantasy messages.
|
||||
loadPrompt := func(t *testing.T, chat database.Chat) []fantasy.Message {
|
||||
t.Helper()
|
||||
dbMsgs, loadErr := db.GetChatMessagesForPromptByChatID(ctx, chat.ID)
|
||||
require.NoError(t, loadErr)
|
||||
prompt, convErr := chatprompt.ConvertMessagesWithFiles(
|
||||
ctx, dbMsgs, nil, slogtest.Make(t, nil),
|
||||
)
|
||||
require.NoError(t, convErr)
|
||||
return prompt
|
||||
}
|
||||
|
||||
t.Run("MediaResultRoundTripsAsMedia", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const callID = "call-screenshot-1"
|
||||
const toolName = "computer"
|
||||
const mimeType = "image/png"
|
||||
|
||||
// Use PartFromContent (the production write path) to
|
||||
// produce the SDK part, rather than hand-crafting JSON.
|
||||
// Computer use is a provider-defined tool, but Coder executes it
|
||||
// locally via chatloop.ProviderTool.Runner, so screenshot results
|
||||
// persist as tool-role messages with ProviderExecuted=false.
|
||||
sdkPart := chatprompt.PartFromContent(fantasy.ToolResultContent{
|
||||
ToolCallID: callID,
|
||||
ToolName: toolName,
|
||||
Result: fantasy.ToolResultOutputContentMedia{
|
||||
Data: imageData,
|
||||
MediaType: mimeType,
|
||||
},
|
||||
})
|
||||
|
||||
chat := insertPair(t, callID, toolName, []codersdk.ChatMessagePart{sdkPart})
|
||||
|
||||
prompt := loadPrompt(t, chat)
|
||||
// assistant + tool
|
||||
require.Len(t, prompt, 2)
|
||||
|
||||
toolMsg := prompt[1]
|
||||
require.Equal(t, fantasy.MessageRoleTool, toolMsg.Role)
|
||||
require.Len(t, toolMsg.Content, 1)
|
||||
|
||||
resultPart, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](toolMsg.Content[0])
|
||||
require.True(t, ok, "expected ToolResultPart")
|
||||
require.Equal(t, callID, resultPart.ToolCallID)
|
||||
require.False(t, resultPart.ProviderExecuted)
|
||||
|
||||
mediaOutput, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](resultPart.Output)
|
||||
require.True(t, ok, "expected ToolResultOutputContentMedia, got %T", resultPart.Output)
|
||||
require.Equal(t, imageData, mediaOutput.Data)
|
||||
require.Equal(t, mimeType, mediaOutput.MediaType)
|
||||
})
|
||||
|
||||
t.Run("MediaResultWithText", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const callID = "call-screenshot-2"
|
||||
const toolName = "computer"
|
||||
const mimeType = "image/png"
|
||||
|
||||
sdkPart := chatprompt.PartFromContent(fantasy.ToolResultContent{
|
||||
ToolCallID: callID,
|
||||
ToolName: toolName,
|
||||
Result: fantasy.ToolResultOutputContentMedia{
|
||||
Data: imageData,
|
||||
MediaType: mimeType,
|
||||
Text: "screenshot after click",
|
||||
},
|
||||
})
|
||||
|
||||
chat := insertPair(t, callID, toolName, []codersdk.ChatMessagePart{sdkPart})
|
||||
|
||||
prompt := loadPrompt(t, chat)
|
||||
require.Len(t, prompt, 2)
|
||||
|
||||
resultPart, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](prompt[1].Content[0])
|
||||
require.True(t, ok)
|
||||
require.False(t, resultPart.ProviderExecuted)
|
||||
|
||||
mediaOutput, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](resultPart.Output)
|
||||
require.True(t, ok, "expected media output")
|
||||
require.Equal(t, imageData, mediaOutput.Data)
|
||||
require.Equal(t, mimeType, mediaOutput.MediaType)
|
||||
require.Equal(t, "screenshot after click", mediaOutput.Text)
|
||||
})
|
||||
|
||||
t.Run("TextResultStaysText", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const callID = "call-text-1"
|
||||
const toolName = "read_file"
|
||||
|
||||
textResult := json.RawMessage(`{"output":"file contents here"}`)
|
||||
|
||||
chat := insertPair(t, callID, toolName, []codersdk.ChatMessagePart{
|
||||
codersdk.ChatMessageToolResult(callID, toolName, textResult, false, false),
|
||||
})
|
||||
|
||||
prompt := loadPrompt(t, chat)
|
||||
require.Len(t, prompt, 2)
|
||||
|
||||
resultPart, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](prompt[1].Content[0])
|
||||
require.True(t, ok)
|
||||
|
||||
_, isMedia := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](resultPart.Output)
|
||||
require.False(t, isMedia, "text result should not be detected as media")
|
||||
|
||||
textOutput, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](resultPart.Output)
|
||||
require.True(t, ok, "expected ToolResultOutputContentText")
|
||||
require.JSONEq(t, string(textResult), textOutput.Text)
|
||||
})
|
||||
|
||||
t.Run("MissingMimeTypeStaysText", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const callID = "call-no-mime"
|
||||
const toolName = "computer"
|
||||
|
||||
noMimeJSON := json.RawMessage(`{"data":"some_base64","text":""}`)
|
||||
|
||||
chat := insertPair(t, callID, toolName, []codersdk.ChatMessagePart{
|
||||
codersdk.ChatMessageToolResult(callID, toolName, noMimeJSON, false, false),
|
||||
})
|
||||
|
||||
prompt := loadPrompt(t, chat)
|
||||
require.Len(t, prompt, 2)
|
||||
|
||||
resultPart, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](prompt[1].Content[0])
|
||||
require.True(t, ok)
|
||||
|
||||
_, isMedia := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](resultPart.Output)
|
||||
require.False(t, isMedia, "missing mime_type should not produce media")
|
||||
})
|
||||
|
||||
t.Run("MissingDataStaysText", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const callID = "call-no-data"
|
||||
const toolName = "computer"
|
||||
|
||||
noDataJSON := json.RawMessage(`{"mime_type":"image/png","text":""}`)
|
||||
|
||||
chat := insertPair(t, callID, toolName, []codersdk.ChatMessagePart{
|
||||
codersdk.ChatMessageToolResult(callID, toolName, noDataJSON, false, false),
|
||||
})
|
||||
|
||||
prompt := loadPrompt(t, chat)
|
||||
require.Len(t, prompt, 2)
|
||||
|
||||
resultPart, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](prompt[1].Content[0])
|
||||
require.True(t, ok)
|
||||
|
||||
_, isMedia := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](resultPart.Output)
|
||||
require.False(t, isMedia, "missing data should not produce media")
|
||||
})
|
||||
|
||||
t.Run("ErrorResultStaysError", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const callID = "call-err"
|
||||
const toolName = "computer"
|
||||
|
||||
// Use PartFromContent to go through the production
|
||||
// write path for error results.
|
||||
sdkPart := chatprompt.PartFromContent(fantasy.ToolResultContent{
|
||||
ToolCallID: callID,
|
||||
ToolName: toolName,
|
||||
Result: fantasy.ToolResultOutputContentError{
|
||||
Error: xerrors.New("screenshot failed"),
|
||||
},
|
||||
})
|
||||
|
||||
chat := insertPair(t, callID, toolName, []codersdk.ChatMessagePart{sdkPart})
|
||||
|
||||
prompt := loadPrompt(t, chat)
|
||||
require.Len(t, prompt, 2)
|
||||
|
||||
resultPart, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](prompt[1].Content[0])
|
||||
require.True(t, ok)
|
||||
|
||||
errOutput, isError := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](resultPart.Output)
|
||||
require.True(t, isError, "error result should remain error")
|
||||
require.Contains(t, errOutput.Error.Error(), "screenshot failed")
|
||||
})
|
||||
|
||||
t.Run("NonMediaResultTypeStaysText", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// A text tool result that happens to contain "data" and
|
||||
// "mime_type" fields must NOT be misidentified as media
|
||||
// when IsMedia is false. The protection is entirely the
|
||||
// IsMedia boolean flag on the ChatMessagePart.
|
||||
const callID = "call-not-media"
|
||||
const toolName = "list_files"
|
||||
|
||||
textJSON, jsonErr := json.Marshal(map[string]any{
|
||||
"result_type": "listing",
|
||||
"data": "file1.txt",
|
||||
"mime_type": "text/csv",
|
||||
})
|
||||
require.NoError(t, jsonErr)
|
||||
|
||||
chat := insertPair(t, callID, toolName, []codersdk.ChatMessagePart{
|
||||
codersdk.ChatMessageToolResult(callID, toolName, textJSON, false, false),
|
||||
})
|
||||
|
||||
prompt := loadPrompt(t, chat)
|
||||
require.Len(t, prompt, 2)
|
||||
|
||||
resultPart, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](prompt[1].Content[0])
|
||||
require.True(t, ok)
|
||||
|
||||
_, isMedia := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](resultPart.Output)
|
||||
require.False(t, isMedia, "non-media result_type must not be detected as media")
|
||||
|
||||
textOutput, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](resultPart.Output)
|
||||
require.True(t, ok, "expected ToolResultOutputContentText")
|
||||
require.JSONEq(t, string(textJSON), textOutput.Text)
|
||||
})
|
||||
|
||||
t.Run("IsMediaTrueButMissingMimeType", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// IsMedia is true but the JSON payload has no mime_type
|
||||
// field. The media reconstruction guard should fail and
|
||||
// the result should fall through to text.
|
||||
const callID = "call-media-no-mime"
|
||||
const toolName = "computer"
|
||||
|
||||
noMimeJSON := json.RawMessage(`{"data":"some_base64","text":""}`)
|
||||
|
||||
chat := insertPair(t, callID, toolName, []codersdk.ChatMessagePart{
|
||||
codersdk.ChatMessageToolResult(callID, toolName, noMimeJSON, false, true),
|
||||
})
|
||||
|
||||
prompt := loadPrompt(t, chat)
|
||||
require.Len(t, prompt, 2)
|
||||
|
||||
resultPart, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](prompt[1].Content[0])
|
||||
require.True(t, ok)
|
||||
|
||||
_, isMedia := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](resultPart.Output)
|
||||
require.False(t, isMedia, "IsMedia=true with missing mime_type should fall through to text")
|
||||
|
||||
_, isText := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](resultPart.Output)
|
||||
require.True(t, isText, "expected ToolResultOutputContentText")
|
||||
})
|
||||
|
||||
t.Run("IsMediaTrueButMissingData", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// IsMedia is true but the JSON payload has no data field.
|
||||
// The media reconstruction guard should fail and the result
|
||||
// should fall through to text.
|
||||
const callID = "call-media-no-data"
|
||||
const toolName = "computer"
|
||||
|
||||
noDataJSON := json.RawMessage(`{"mime_type":"image/png","text":""}`)
|
||||
|
||||
chat := insertPair(t, callID, toolName, []codersdk.ChatMessagePart{
|
||||
codersdk.ChatMessageToolResult(callID, toolName, noDataJSON, false, true),
|
||||
})
|
||||
|
||||
prompt := loadPrompt(t, chat)
|
||||
require.Len(t, prompt, 2)
|
||||
|
||||
resultPart, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](prompt[1].Content[0])
|
||||
require.True(t, ok)
|
||||
|
||||
_, isMedia := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](resultPart.Output)
|
||||
require.False(t, isMedia, "IsMedia=true with missing data should fall through to text")
|
||||
|
||||
_, isText := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](resultPart.Output)
|
||||
require.True(t, isText, "expected ToolResultOutputContentText")
|
||||
})
|
||||
|
||||
t.Run("IsMediaTrueButGarbageJSON", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// IsMedia is true but the result is a JSON string, not
|
||||
// an object. Unmarshal into persistedMediaResult fails
|
||||
// and the result should fall through to text. Truly
|
||||
// invalid JSON cannot reach the read path because both
|
||||
// MarshalParts and PostgreSQL jsonb reject it, so a
|
||||
// non-object JSON value is the realistic edge case.
|
||||
const callID = "call-media-garbage"
|
||||
const toolName = "computer"
|
||||
|
||||
garbageJSON := json.RawMessage(`"not a json object"`)
|
||||
|
||||
chat := insertPair(t, callID, toolName, []codersdk.ChatMessagePart{
|
||||
codersdk.ChatMessageToolResult(callID, toolName, garbageJSON, false, true),
|
||||
})
|
||||
|
||||
prompt := loadPrompt(t, chat)
|
||||
require.Len(t, prompt, 2)
|
||||
|
||||
resultPart, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](prompt[1].Content[0])
|
||||
require.True(t, ok)
|
||||
|
||||
_, isMedia := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](resultPart.Output)
|
||||
require.False(t, isMedia, "IsMedia=true with garbage JSON should fall through to text")
|
||||
|
||||
_, isText := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](resultPart.Output)
|
||||
require.True(t, isText, "expected ToolResultOutputContentText")
|
||||
})
|
||||
}
|
||||
|
||||
+6
-1
@@ -159,6 +159,7 @@ type ChatMessagePart struct {
|
||||
Result json.RawMessage `json:"result,omitempty" variants:"tool-result?"`
|
||||
ResultDelta string `json:"result_delta,omitempty"`
|
||||
IsError bool `json:"is_error,omitempty" variants:"tool-result?"`
|
||||
IsMedia bool `json:"is_media,omitempty" variants:"tool-result?"`
|
||||
SourceID string `json:"source_id,omitempty" variants:"source?"`
|
||||
URL string `json:"url" variants:"source"`
|
||||
Title string `json:"title,omitempty" variants:"source?"`
|
||||
@@ -241,13 +242,17 @@ func ChatMessageToolCall(toolCallID, toolName string, args json.RawMessage) Chat
|
||||
}
|
||||
|
||||
// ChatMessageToolResult builds a tool-result chat message part.
|
||||
func ChatMessageToolResult(toolCallID, toolName string, result json.RawMessage, isError bool) ChatMessagePart {
|
||||
// The isMedia flag marks the result as carrying binary media content
|
||||
// (e.g. a screenshot) so that round-trip reconstruction preserves
|
||||
// the media type instead of sending raw base64 as text tokens.
|
||||
func ChatMessageToolResult(toolCallID, toolName string, result json.RawMessage, isError bool, isMedia bool) ChatMessagePart {
|
||||
return ChatMessagePart{
|
||||
Type: ChatMessagePartTypeToolResult,
|
||||
ToolCallID: toolCallID,
|
||||
ToolName: toolName,
|
||||
Result: result,
|
||||
IsError: isError,
|
||||
IsMedia: isMedia,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Generated
+1
@@ -1944,6 +1944,7 @@ export interface ChatToolResultPart {
|
||||
readonly mcp_server_config_id?: string;
|
||||
readonly result?: Record<string, string>;
|
||||
readonly is_error?: boolean;
|
||||
readonly is_media?: boolean;
|
||||
/**
|
||||
* ProviderExecuted indicates the tool call was executed by
|
||||
* the provider (e.g. Anthropic computer use).
|
||||
|
||||
Reference in New Issue
Block a user