fix: save media message type to db (#23427)

We had a bug where base64-encoded computer-use screenshots were no longer
interpreted as screenshots once they had been saved to the db, loaded back
into memory, and sent to Anthropic. Instead, they were interpreted
as regular text. Once a computer-use agent had taken enough screenshots and
stopped, and you tried sending it another message, you'd get an
out-of-context error:

<img width="808" height="367" alt="Screenshot 2026-03-23 at 12 02 54"
src="https://github.com/user-attachments/assets/f0bf6be2-4863-47ca-a7a9-9e6d9dfceeed"
/>

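This PR fixes that by persisting an `is_media` flag on tool results, so that stored screenshots are reconstructed as media instead of text when loaded back.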
This commit is contained in:
Hugo Dutka
2026-03-25 18:11:21 +01:00
committed by GitHub
parent d9fc5a5be1
commit 84740f4619
6 changed files with 497 additions and 34 deletions
+1 -1
View File
@@ -3693,7 +3693,7 @@ func (p *Server) persistChatContextSummary(
return xerrors.Errorf("encode summary result payload: %w", err)
}
toolResult, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{
codersdk.ChatMessageToolResult(toolCallID, "chat_summarized", summaryResult, false),
codersdk.ChatMessageToolResult(toolCallID, "chat_summarized", summaryResult, false, false),
})
if err != nil {
return xerrors.Errorf("encode summary tool result: %w", err)
+2 -2
View File
@@ -160,7 +160,7 @@ func tryCompact(
})
config.PublishMessagePart(
codersdk.ChatMessageRoleTool,
codersdk.ChatMessageToolResult(config.ToolCallID, config.ToolName, resultJSON, false),
codersdk.ChatMessageToolResult(config.ToolCallID, config.ToolName, resultJSON, false, false),
)
}
@@ -178,7 +178,7 @@ func publishCompactionError(config CompactionOptions, msg string) {
})
config.PublishMessagePart(
codersdk.ChatMessageRoleTool,
codersdk.ChatMessageToolResult(config.ToolCallID, config.ToolName, errJSON, true),
codersdk.ChatMessageToolResult(config.ToolCallID, config.ToolName, errJSON, true, false),
)
}
+73 -15
View File
@@ -405,7 +405,7 @@ func parseToolRole(raw pqtype.NullRawMessage) ([]codersdk.ChatMessagePart, error
}
parts = make([]codersdk.ChatMessagePart, 0, len(rows))
for _, row := range rows {
part := codersdk.ChatMessageToolResult(row.ToolCallID, row.ToolName, row.Result, row.IsError)
part := codersdk.ChatMessageToolResult(row.ToolCallID, row.ToolName, row.Result, row.IsError, row.IsMedia)
part.ProviderExecuted = row.ProviderExecuted
part.ProviderMetadata = row.ProviderMetadata
parts = append(parts, part)
@@ -528,6 +528,7 @@ type toolResultRaw struct {
ToolName string `json:"tool_name"`
Result json.RawMessage `json:"result"`
IsError bool `json:"is_error,omitempty"`
IsMedia bool `json:"is_media,omitempty"`
ProviderExecuted bool `json:"provider_executed,omitempty"`
ProviderMetadata json.RawMessage `json:"provider_metadata,omitempty"`
}
@@ -669,8 +670,8 @@ func MarshalContent(blocks []fantasy.Content, fileIDs map[int]uuid.UUID) (pqtype
// tool-row format. Retained for test fixtures that create
// legacy-format DB rows. Production write paths use MarshalParts.
// The stored shape is
// [{"tool_call_id":…,"tool_name":…,"result":…,"is_error":…}].
func MarshalToolResult(toolCallID, toolName string, result json.RawMessage, isError bool, providerExecuted bool, providerMetadata fantasy.ProviderMetadata) (pqtype.NullRawMessage, error) {
// [{"tool_call_id":…,"tool_name":…,"result":…,"is_error":…,"is_media":…}].
func MarshalToolResult(toolCallID, toolName string, result json.RawMessage, isError bool, isMedia bool, providerExecuted bool, providerMetadata fantasy.ProviderMetadata) (pqtype.NullRawMessage, error) {
var metaJSON json.RawMessage
if len(providerMetadata) > 0 {
var err error
@@ -684,6 +685,7 @@ func MarshalToolResult(toolCallID, toolName string, result json.RawMessage, isEr
ToolName: toolName,
Result: result,
IsError: isError,
IsMedia: isMedia,
ProviderExecuted: providerExecuted,
ProviderMetadata: metaJSON,
}
@@ -779,18 +781,20 @@ func PartFromContent(block fantasy.Content) codersdk.ChatMessagePart {
}
}
// ToolResultToPart converts a tool call ID, raw result, and error
// flag into a ChatMessagePart. This is the minimal conversion used
// both during streaming and when reading from the database.
func ToolResultToPart(toolCallID, toolName string, result json.RawMessage, isError bool) codersdk.ChatMessagePart {
return codersdk.ChatMessageToolResult(toolCallID, toolName, result, isError)
// ToolResultToPart builds a tool-result ChatMessagePart from a tool
// call ID, tool name, raw JSON result, error flag, and media flag.
// It delegates to codersdk.ChatMessageToolResult and sets no other
// fields on the returned part.
func ToolResultToPart(toolCallID, toolName string, result json.RawMessage, isError bool, isMedia bool) codersdk.ChatMessagePart {
	part := codersdk.ChatMessageToolResult(toolCallID, toolName, result, isError, isMedia)
	return part
}
// toolResultContentToPart converts a fantasy ToolResultContent
// directly into a ChatMessagePart without an intermediate struct.
// toolResultContentToPart converts a fantasy ToolResultContent into a
// ChatMessagePart.
func toolResultContentToPart(content fantasy.ToolResultContent) codersdk.ChatMessagePart {
var result json.RawMessage
var isError bool
var isMedia bool
switch output := content.Result.(type) {
case fantasy.ToolResultOutputContentError:
@@ -807,16 +811,17 @@ func toolResultContentToPart(content fantasy.ToolResultContent) codersdk.ChatMes
result, _ = json.Marshal(map[string]any{"output": output.Text})
}
case fantasy.ToolResultOutputContentMedia:
result, _ = json.Marshal(map[string]any{
"data": output.Data,
"mime_type": output.MediaType,
"text": output.Text,
isMedia = true
result, _ = json.Marshal(persistedMediaResult{
Data: output.Data,
MimeType: output.MediaType,
Text: output.Text,
})
default:
result = []byte(`{}`)
}
part := ToolResultToPart(content.ToolCallID, content.ToolName, result, isError)
part := ToolResultToPart(content.ToolCallID, content.ToolName, result, isError, isMedia)
part.ProviderExecuted = content.ProviderExecuted
part.ProviderMetadata = marshalProviderMetadata(content.ProviderMetadata)
return part
@@ -1213,6 +1218,44 @@ func toolResultPartToMessagePart(logger slog.Logger, part codersdk.ChatMessagePa
}
}
// IsError takes precedence and is handled above.
// Detect media content flagged by toolResultContentToPart.
// Screenshots from the computer use tool are stored as
// {"data":"<base64>","mime_type":"image/png","text":"..."}.
// Without this detection, the entire base64 payload is sent
// as text tokens, which quickly exceeds the context limit
// on follow-up messages.
if part.IsMedia {
var media persistedMediaResult
unmarshalErr := json.Unmarshal(part.Result, &media)
if unmarshalErr == nil && media.Data != "" && media.MimeType != "" {
return fantasy.ToolResultPart{
ToolCallID: toolCallID,
ProviderExecuted: part.ProviderExecuted,
Output: fantasy.ToolResultOutputContentMedia{
Data: media.Data,
MediaType: media.MimeType,
Text: media.Text,
},
ProviderOptions: opts,
}
}
fields := []slog.Field{
slog.F("tool_call_id", toolCallID),
slog.F("tool_name", part.ToolName),
slog.F("has_data", media.Data != ""),
slog.F("has_mime_type", media.MimeType != ""),
}
if unmarshalErr != nil {
fields = append(fields, slog.Error(unmarshalErr))
}
logger.Warn(context.Background(),
"media tool result failed reconstruction, falling through to text",
fields...,
)
}
return fantasy.ToolResultPart{
ToolCallID: toolCallID,
ProviderExecuted: part.ProviderExecuted,
@@ -1223,6 +1266,21 @@ func toolResultPartToMessagePart(logger slog.Logger, part codersdk.ChatMessagePa
}
}
// persistedMediaResult is the JSON shape used to store media tool
// results (e.g. computer-use screenshots) in the database. Both
// the write path (toolResultContentToPart) and the read path
// (toolResultPartToMessagePart) use this struct so the two sides
// cannot drift.
//
// The "mime_type" key intentionally diverges from the fantasy
// struct tag (json:"media_type"). Do not change it without
// updating both paths. Renaming a key would also leave the
// corresponding field empty when decoding rows stored under the old
// key, and the read path falls back to text when Data or MimeType
// is empty.
type persistedMediaResult struct {
// Data is the media payload; for screenshots this is the
// base64-encoded image. Must be non-empty for the read path to
// rebuild a media result.
Data string `json:"data"`
// MimeType is copied to and from the fantasy MediaType field
// (e.g. "image/png"). Must be non-empty for the read path to
// rebuild a media result.
MimeType string `json:"mime_type"`
// Text is the text carried alongside the media; it may be empty.
Text string `json:"text"`
}
// partsToMessageParts converts SDK chat message parts into fantasy
// message parts for LLM dispatch. It handles file data injection
// from resolved files, file-reference to text conversion, and
+414 -15
View File
@@ -13,6 +13,7 @@ import (
"github.com/sqlc-dev/pqtype"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/xerrors"
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/coder/v2/coderd/database"
@@ -92,6 +93,7 @@ func TestConvertMessages_NormalizesAssistantToolCallInput(t *testing.T) {
json.RawMessage(`{"error":"tool call was interrupted before it produced a result"}`),
true,
false,
false,
nil,
)
require.NoError(t, err)
@@ -297,7 +299,7 @@ func TestInjectMissingToolResults_SkipsProviderExecuted(t *testing.T) {
localResult := mustMarshalToolResult(t,
"toolu_local", "spawn_agent",
json.RawMessage(`{"status":"done"}`),
false, false,
false, false, false,
)
prompt, err := chatprompt.ConvertMessages([]database.ChatMessage{
@@ -361,12 +363,12 @@ func TestInjectMissingToolUses_DropsProviderExecutedOrphans(t *testing.T) {
resultA := mustMarshalToolResult(t,
"toolu_A", "spawn_agent",
json.RawMessage(`{"status":"done"}`),
false, false,
false, false, false,
)
resultB := mustMarshalToolResult(t,
"toolu_B", "spawn_agent",
json.RawMessage(`{"status":"done"}`),
false, false,
false, false, false,
)
// Step 2: assistant with sources/text + wait_agent x2.
@@ -389,17 +391,17 @@ func TestInjectMissingToolUses_DropsProviderExecutedOrphans(t *testing.T) {
resultC := mustMarshalToolResult(t,
"srvtoolu_C", "web_search",
json.RawMessage(`{}`),
false, true, // provider_executed = true
false, false, true, // provider_executed = true
)
resultD := mustMarshalToolResult(t,
"toolu_D", "wait_agent",
json.RawMessage(`{"report":"done"}`),
false, false,
false, false, false,
)
resultE := mustMarshalToolResult(t,
"toolu_E", "wait_agent",
json.RawMessage(`{"report":"done"}`),
false, false,
false, false, false,
)
prompt, err := chatprompt.ConvertMessages([]database.ChatMessage{
@@ -475,7 +477,7 @@ func TestInjectMissingToolUses_DropsOnlyProviderExecutedMessage(t *testing.T) {
localResult := mustMarshalToolResult(t,
"toolu_local", "execute",
json.RawMessage(`{"output":"file.txt"}`),
false, false,
false, false, false,
)
// Second assistant with only local tool call.
@@ -487,7 +489,7 @@ func TestInjectMissingToolUses_DropsOnlyProviderExecutedMessage(t *testing.T) {
peResult := mustMarshalToolResult(t,
"srvtoolu_orphan", "web_search",
json.RawMessage(`{}`),
false, true,
false, false, true,
)
prompt, err := chatprompt.ConvertMessages([]database.ChatMessage{
@@ -600,12 +602,12 @@ func TestProviderExecutedResult_LegacyToolRow(t *testing.T) {
peResult := mustMarshalToolResult(t,
"srvtoolu_WS", "web_search",
json.RawMessage(`{"results":"cached"}`),
false, true, // providerExecuted = true
false, false, true, // providerExecuted = true
)
execResult := mustMarshalToolResult(t,
"toolu_exec", "execute",
json.RawMessage(`{"output":"file.txt"}`),
false, false,
false, false, false,
)
prompt, err := chatprompt.ConvertMessages([]database.ChatMessage{
@@ -1193,7 +1195,7 @@ func TestMixedFormatConversation(t *testing.T) {
// 4. Old tool (legacy result rows).
oldToolRaw, err := chatprompt.MarshalToolResult(
"call_1", "analyze_image",
json.RawMessage(`{"description":"a cat"}`), false,
json.RawMessage(`{"description":"a cat"}`), false, false,
false, nil,
)
require.NoError(t, err)
@@ -1425,9 +1427,9 @@ func mustMarshalContent(t *testing.T, content []fantasy.Content) pqtype.NullRawM
return result
}
func mustMarshalToolResult(t *testing.T, toolCallID, toolName string, result json.RawMessage, isError, providerExecuted bool) pqtype.NullRawMessage {
func mustMarshalToolResult(t *testing.T, toolCallID, toolName string, result json.RawMessage, isError, isMedia, providerExecuted bool) pqtype.NullRawMessage {
t.Helper()
raw, err := chatprompt.MarshalToolResult(toolCallID, toolName, result, isError, providerExecuted, nil)
raw, err := chatprompt.MarshalToolResult(toolCallID, toolName, result, isError, isMedia, providerExecuted, nil)
require.NoError(t, err)
return raw
}
@@ -1608,7 +1610,7 @@ func TestNulEscapeRoundTrip(t *testing.T) {
resultJSON := json.RawMessage(`"output:\u0000done"`)
parts := []codersdk.ChatMessagePart{
codersdk.ChatMessageToolResult("call-1", "my_tool", resultJSON, false),
codersdk.ChatMessageToolResult("call-1", "my_tool", resultJSON, false, false),
}
encoded, err := chatprompt.MarshalParts(parts)
@@ -1676,7 +1678,7 @@ func TestConvertMessagesWithFiles_FiltersEmptyTextAndReasoningParts(t *testing.T
codersdk.ChatMessageText(" hello "), // kept with original whitespace
codersdk.ChatMessageReasoning("thinking deeply"), // kept
codersdk.ChatMessageToolCall("call-1", "my_tool", json.RawMessage(`{"x":1}`)),
codersdk.ChatMessageToolResult("call-1", "my_tool", json.RawMessage(`{"ok":true}`), false),
codersdk.ChatMessageToolResult("call-1", "my_tool", json.RawMessage(`{"ok":true}`), false, false),
}
prompt, err := chatprompt.ConvertMessagesWithFiles(
@@ -1926,3 +1928,400 @@ func convertSingleResolvedFileMessage(t *testing.T, fileID uuid.UUID, fileData c
require.NoError(t, err)
return prompt
}
// TestMediaToolResultRoundTrip checks how tool-result parts survive a
// database round trip: parts are encoded with MarshalParts, inserted as
// chat messages, loaded with GetChatMessagesForPromptByChatID, and
// converted back to fantasy messages with ConvertMessagesWithFiles.
// Subtests cover media results, plain text results, error results, and
// payloads whose IsMedia flag is set but whose JSON cannot be rebuilt
// into a media output.
func TestMediaToolResultRoundTrip(t *testing.T) {
t.Parallel()
// Full DB round-trip test: insert messages into PostgreSQL,
// load them back via GetChatMessagesForPromptByChatID, and
// verify the fantasy message parts are identical after the
// round-trip.
db, _ := dbtestutil.NewDB(t)
ctx := testutil.Context(t, testutil.WaitShort)
// Shared fixtures: one user, one provider, and one model config
// are reused by every subtest; each subtest creates its own chat.
user := dbgen.User(t, db, database.User{})
_, err := db.InsertChatProvider(ctx, database.InsertChatProviderParams{
Provider: "anthropic",
DisplayName: "anthropic",
APIKey: "test-key",
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
Enabled: true,
})
require.NoError(t, err)
model, err := db.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{
Provider: "anthropic",
Model: "test-model",
DisplayName: "Test Model",
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
UpdatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
Enabled: true,
IsDefault: true,
ContextLimit: 200000,
CompressionThreshold: 70,
Options: json.RawMessage(`{}`),
})
require.NoError(t, err)
// Small base64 payload standing in for a real screenshot.
const imageData = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAAC0lEQVQI12NgAAIABQAB"
// insertPair writes an assistant tool-call message and a
// tool-result message into the database, returning the chat
// they belong to.
insertPair := func(
t *testing.T,
callID, toolName string,
resultParts []codersdk.ChatMessagePart,
) database.Chat {
t.Helper()
// The chat title embeds callID, so each subtest's chat is
// distinguishable by its call ID.
chat, chatErr := db.InsertChat(ctx, database.InsertChatParams{
OwnerID: user.ID,
LastModelConfigID: model.ID,
Title: "media-roundtrip-" + callID,
})
require.NoError(t, chatErr)
// Assistant message with the tool call.
callPart := codersdk.ChatMessageToolCall(callID, toolName, json.RawMessage(`{}`))
assistantEncoded, encErr := chatprompt.MarshalParts([]codersdk.ChatMessagePart{callPart})
require.NoError(t, encErr)
// Tool result message.
resultEncoded, encErr := chatprompt.MarshalParts(resultParts)
require.NoError(t, encErr)
// Every slice below has two entries: index 0 is the assistant
// message, index 1 is the tool-result message.
_, insertErr := db.InsertChatMessages(ctx, database.InsertChatMessagesParams{
ChatID: chat.ID,
CreatedBy: []uuid.UUID{user.ID, user.ID},
ModelConfigID: []uuid.UUID{model.ID, model.ID},
Role: []database.ChatMessageRole{database.ChatMessageRoleAssistant, database.ChatMessageRoleTool},
Content: []string{string(assistantEncoded.RawMessage), string(resultEncoded.RawMessage)},
ContentVersion: []int16{chatprompt.CurrentContentVersion, chatprompt.CurrentContentVersion},
Visibility: []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth, database.ChatMessageVisibilityBoth},
InputTokens: []int64{0, 0},
OutputTokens: []int64{0, 0},
TotalTokens: []int64{0, 0},
ReasoningTokens: []int64{0, 0},
CacheCreationTokens: []int64{0, 0},
CacheReadTokens: []int64{0, 0},
ContextLimit: []int64{0, 0},
Compressed: []bool{false, false},
TotalCostMicros: []int64{0, 0},
RuntimeMs: []int64{0, 0},
})
require.NoError(t, insertErr)
return chat
}
// loadPrompt reads messages back from the DB via the same
// path used by runChat, and converts them to fantasy messages.
loadPrompt := func(t *testing.T, chat database.Chat) []fantasy.Message {
t.Helper()
dbMsgs, loadErr := db.GetChatMessagesForPromptByChatID(ctx, chat.ID)
require.NoError(t, loadErr)
prompt, convErr := chatprompt.ConvertMessagesWithFiles(
ctx, dbMsgs, nil, slogtest.Make(t, nil),
)
require.NoError(t, convErr)
return prompt
}
t.Run("MediaResultRoundTripsAsMedia", func(t *testing.T) {
t.Parallel()
const callID = "call-screenshot-1"
const toolName = "computer"
const mimeType = "image/png"
// Use PartFromContent (the production write path) to
// produce the SDK part, rather than hand-crafting JSON.
// Computer use is a provider-defined tool, but Coder executes it
// locally via chatloop.ProviderTool.Runner, so screenshot results
// persist as tool-role messages with ProviderExecuted=false.
sdkPart := chatprompt.PartFromContent(fantasy.ToolResultContent{
ToolCallID: callID,
ToolName: toolName,
Result: fantasy.ToolResultOutputContentMedia{
Data: imageData,
MediaType: mimeType,
},
})
chat := insertPair(t, callID, toolName, []codersdk.ChatMessagePart{sdkPart})
prompt := loadPrompt(t, chat)
// assistant + tool
require.Len(t, prompt, 2)
toolMsg := prompt[1]
require.Equal(t, fantasy.MessageRoleTool, toolMsg.Role)
require.Len(t, toolMsg.Content, 1)
resultPart, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](toolMsg.Content[0])
require.True(t, ok, "expected ToolResultPart")
require.Equal(t, callID, resultPart.ToolCallID)
require.False(t, resultPart.ProviderExecuted)
mediaOutput, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](resultPart.Output)
require.True(t, ok, "expected ToolResultOutputContentMedia, got %T", resultPart.Output)
require.Equal(t, imageData, mediaOutput.Data)
require.Equal(t, mimeType, mediaOutput.MediaType)
})
t.Run("MediaResultWithText", func(t *testing.T) {
t.Parallel()
// Same as the media round trip above, but the media output also
// carries a Text value that must come back unchanged.
const callID = "call-screenshot-2"
const toolName = "computer"
const mimeType = "image/png"
sdkPart := chatprompt.PartFromContent(fantasy.ToolResultContent{
ToolCallID: callID,
ToolName: toolName,
Result: fantasy.ToolResultOutputContentMedia{
Data: imageData,
MediaType: mimeType,
Text: "screenshot after click",
},
})
chat := insertPair(t, callID, toolName, []codersdk.ChatMessagePart{sdkPart})
prompt := loadPrompt(t, chat)
require.Len(t, prompt, 2)
resultPart, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](prompt[1].Content[0])
require.True(t, ok)
require.False(t, resultPart.ProviderExecuted)
mediaOutput, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](resultPart.Output)
require.True(t, ok, "expected media output")
require.Equal(t, imageData, mediaOutput.Data)
require.Equal(t, mimeType, mediaOutput.MediaType)
require.Equal(t, "screenshot after click", mediaOutput.Text)
})
t.Run("TextResultStaysText", func(t *testing.T) {
t.Parallel()
// A plain tool result stored with IsMedia=false must load as
// text output whose JSON equals the stored result.
const callID = "call-text-1"
const toolName = "read_file"
textResult := json.RawMessage(`{"output":"file contents here"}`)
chat := insertPair(t, callID, toolName, []codersdk.ChatMessagePart{
codersdk.ChatMessageToolResult(callID, toolName, textResult, false, false),
})
prompt := loadPrompt(t, chat)
require.Len(t, prompt, 2)
resultPart, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](prompt[1].Content[0])
require.True(t, ok)
_, isMedia := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](resultPart.Output)
require.False(t, isMedia, "text result should not be detected as media")
textOutput, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](resultPart.Output)
require.True(t, ok, "expected ToolResultOutputContentText")
require.JSONEq(t, string(textResult), textOutput.Text)
})
t.Run("MissingMimeTypeStaysText", func(t *testing.T) {
t.Parallel()
// IsMedia is false and the payload has "data" but no
// "mime_type"; the result must not load as media.
const callID = "call-no-mime"
const toolName = "computer"
noMimeJSON := json.RawMessage(`{"data":"some_base64","text":""}`)
chat := insertPair(t, callID, toolName, []codersdk.ChatMessagePart{
codersdk.ChatMessageToolResult(callID, toolName, noMimeJSON, false, false),
})
prompt := loadPrompt(t, chat)
require.Len(t, prompt, 2)
resultPart, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](prompt[1].Content[0])
require.True(t, ok)
_, isMedia := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](resultPart.Output)
require.False(t, isMedia, "missing mime_type should not produce media")
})
t.Run("MissingDataStaysText", func(t *testing.T) {
t.Parallel()
// IsMedia is false and the payload has "mime_type" but no
// "data"; the result must not load as media.
const callID = "call-no-data"
const toolName = "computer"
noDataJSON := json.RawMessage(`{"mime_type":"image/png","text":""}`)
chat := insertPair(t, callID, toolName, []codersdk.ChatMessagePart{
codersdk.ChatMessageToolResult(callID, toolName, noDataJSON, false, false),
})
prompt := loadPrompt(t, chat)
require.Len(t, prompt, 2)
resultPart, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](prompt[1].Content[0])
require.True(t, ok)
_, isMedia := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](resultPart.Output)
require.False(t, isMedia, "missing data should not produce media")
})
t.Run("ErrorResultStaysError", func(t *testing.T) {
t.Parallel()
const callID = "call-err"
const toolName = "computer"
// Use PartFromContent to go through the production
// write path for error results.
sdkPart := chatprompt.PartFromContent(fantasy.ToolResultContent{
ToolCallID: callID,
ToolName: toolName,
Result: fantasy.ToolResultOutputContentError{
Error: xerrors.New("screenshot failed"),
},
})
chat := insertPair(t, callID, toolName, []codersdk.ChatMessagePart{sdkPart})
prompt := loadPrompt(t, chat)
require.Len(t, prompt, 2)
resultPart, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](prompt[1].Content[0])
require.True(t, ok)
errOutput, isError := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](resultPart.Output)
require.True(t, isError, "error result should remain error")
require.Contains(t, errOutput.Error.Error(), "screenshot failed")
})
t.Run("NonMediaResultTypeStaysText", func(t *testing.T) {
t.Parallel()
// A text tool result that happens to contain "data" and
// "mime_type" fields must NOT be misidentified as media
// when IsMedia is false. The protection is entirely the
// IsMedia boolean flag on the ChatMessagePart.
const callID = "call-not-media"
const toolName = "list_files"
textJSON, jsonErr := json.Marshal(map[string]any{
"result_type": "listing",
"data": "file1.txt",
"mime_type": "text/csv",
})
require.NoError(t, jsonErr)
chat := insertPair(t, callID, toolName, []codersdk.ChatMessagePart{
codersdk.ChatMessageToolResult(callID, toolName, textJSON, false, false),
})
prompt := loadPrompt(t, chat)
require.Len(t, prompt, 2)
resultPart, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](prompt[1].Content[0])
require.True(t, ok)
_, isMedia := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](resultPart.Output)
require.False(t, isMedia, "non-media result_type must not be detected as media")
textOutput, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](resultPart.Output)
require.True(t, ok, "expected ToolResultOutputContentText")
require.JSONEq(t, string(textJSON), textOutput.Text)
})
t.Run("IsMediaTrueButMissingMimeType", func(t *testing.T) {
t.Parallel()
// IsMedia is true but the JSON payload has no mime_type
// field. The media reconstruction guard should fail and
// the result should fall through to text.
const callID = "call-media-no-mime"
const toolName = "computer"
noMimeJSON := json.RawMessage(`{"data":"some_base64","text":""}`)
chat := insertPair(t, callID, toolName, []codersdk.ChatMessagePart{
codersdk.ChatMessageToolResult(callID, toolName, noMimeJSON, false, true),
})
prompt := loadPrompt(t, chat)
require.Len(t, prompt, 2)
resultPart, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](prompt[1].Content[0])
require.True(t, ok)
_, isMedia := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](resultPart.Output)
require.False(t, isMedia, "IsMedia=true with missing mime_type should fall through to text")
_, isText := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](resultPart.Output)
require.True(t, isText, "expected ToolResultOutputContentText")
})
t.Run("IsMediaTrueButMissingData", func(t *testing.T) {
t.Parallel()
// IsMedia is true but the JSON payload has no data field.
// The media reconstruction guard should fail and the result
// should fall through to text.
const callID = "call-media-no-data"
const toolName = "computer"
noDataJSON := json.RawMessage(`{"mime_type":"image/png","text":""}`)
chat := insertPair(t, callID, toolName, []codersdk.ChatMessagePart{
codersdk.ChatMessageToolResult(callID, toolName, noDataJSON, false, true),
})
prompt := loadPrompt(t, chat)
require.Len(t, prompt, 2)
resultPart, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](prompt[1].Content[0])
require.True(t, ok)
_, isMedia := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](resultPart.Output)
require.False(t, isMedia, "IsMedia=true with missing data should fall through to text")
_, isText := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](resultPart.Output)
require.True(t, isText, "expected ToolResultOutputContentText")
})
t.Run("IsMediaTrueButGarbageJSON", func(t *testing.T) {
t.Parallel()
// IsMedia is true but the result is a JSON string, not
// an object. Unmarshal into persistedMediaResult fails
// and the result should fall through to text. Truly
// invalid JSON cannot reach the read path because both
// MarshalParts and PostgreSQL jsonb reject it, so a
// non-object JSON value is the realistic edge case.
const callID = "call-media-garbage"
const toolName = "computer"
garbageJSON := json.RawMessage(`"not a json object"`)
chat := insertPair(t, callID, toolName, []codersdk.ChatMessagePart{
codersdk.ChatMessageToolResult(callID, toolName, garbageJSON, false, true),
})
prompt := loadPrompt(t, chat)
require.Len(t, prompt, 2)
resultPart, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](prompt[1].Content[0])
require.True(t, ok)
_, isMedia := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](resultPart.Output)
require.False(t, isMedia, "IsMedia=true with garbage JSON should fall through to text")
_, isText := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](resultPart.Output)
require.True(t, isText, "expected ToolResultOutputContentText")
})
}
+6 -1
View File
@@ -159,6 +159,7 @@ type ChatMessagePart struct {
Result json.RawMessage `json:"result,omitempty" variants:"tool-result?"`
ResultDelta string `json:"result_delta,omitempty"`
IsError bool `json:"is_error,omitempty" variants:"tool-result?"`
IsMedia bool `json:"is_media,omitempty" variants:"tool-result?"`
SourceID string `json:"source_id,omitempty" variants:"source?"`
URL string `json:"url" variants:"source"`
Title string `json:"title,omitempty" variants:"source?"`
@@ -241,13 +242,17 @@ func ChatMessageToolCall(toolCallID, toolName string, args json.RawMessage) Chat
}
// ChatMessageToolResult builds a tool-result chat message part.
func ChatMessageToolResult(toolCallID, toolName string, result json.RawMessage, isError bool) ChatMessagePart {
// The isMedia flag records that result carries binary media content
// (for example a screenshot), so that a later round-trip can rebuild
// it as media rather than sending the raw base64 as text tokens.
func ChatMessageToolResult(toolCallID, toolName string, result json.RawMessage, isError bool, isMedia bool) ChatMessagePart {
	part := ChatMessagePart{Type: ChatMessagePartTypeToolResult}
	part.ToolCallID = toolCallID
	part.ToolName = toolName
	part.Result = result
	part.IsError = isError
	part.IsMedia = isMedia
	return part
}
+1
View File
@@ -1944,6 +1944,7 @@ export interface ChatToolResultPart {
readonly mcp_server_config_id?: string;
readonly result?: Record<string, string>;
readonly is_error?: boolean;
readonly is_media?: boolean;
/**
* ProviderExecuted indicates the tool call was executed by
* the provider (e.g. Anthropic computer use).