Files
coder/coderd/x/chatd/chattool/mcpworkspace.go
T
Cian Johnston a02339c66a fix(coderd/x/chatd): prevent invalid tool results from poisoning chat history (#24663)
- **computeruse.go**: Decode base64 screenshot data before storing in
`ToolResponse.Data` (was casting base64 string to bytes without
decoding)
- **chatloop.go**: Re-encode `ToolResponse.Data` to base64 via
`base64.StdEncoding.EncodeToString` instead of `string()` cast
- **mcpclient.go**: Sanitize all text from MCP responses to valid UTF-8 in
`convertCallResult()` using `strings.ToValidUTF8`
- **chatprompt.go (persist)**: Defense-in-depth UTF-8 sanitization of
text and media Text fields before database storage
- **chatprompt.go (replay)**: Antivenom layer that validates base64 and
UTF-8 at read time, auto-healing already-poisoned chats without
requiring a migration
- Adds `TestToolResultAntivenom`: 4 subtests covering poisoned text, poisoned
media, valid media round-trip, and media with invalid UTF-8 text
- Adds `TestConvertCallResult_UTF8Sanitization`: 4 subtests covering invalid
UTF-8 in TextContent, EmbeddedResource, valid passthrough, and
multi-part
- Adds `TestComputerUseTool_Run_ScreenshotDataIsDecodedBinary`: Verifies no
double-encode in the computer-use path
- Updates existing computer-use tests for the new decoded-binary
contract

> 🤖
2026-04-23 19:58:38 +01:00

152 lines
3.5 KiB
Go

package chattool
import (
"context"
"encoding/base64"
"encoding/json"
"strings"
"charm.land/fantasy"
"github.com/coder/coder/v2/codersdk/workspacesdk"
)
// WorkspaceMCPTool exposes one MCP tool that was discovered inside a
// workspace as a fantasy.AgentTool, so it can be registered next to the
// built-in chat tools. Every invocation is forwarded to the workspace
// agent over the connection returned by getConn.
type WorkspaceMCPTool struct {
	// info is the tool metadata advertised to the model.
	info fantasy.ToolInfo
	// getConn resolves the agent connection used to proxy each call.
	getConn func(context.Context) (workspacesdk.AgentConn, error)
	// providerOpts holds the provider-specific options assigned via
	// SetProviderOptions.
	providerOpts fantasy.ProviderOptions
}
// NewWorkspaceMCPTool builds a WorkspaceMCPTool from tool metadata
// reported by a workspace agent. Calls made through the returned tool
// are proxied over the connection produced by getConn.
//
// The tool is always marked as safe for parallel invocation, and a nil
// Required list is replaced with an empty slice.
func NewWorkspaceMCPTool(
	tool workspacesdk.MCPToolInfo,
	getConn func(context.Context) (workspacesdk.AgentConn, error),
) *WorkspaceMCPTool {
	info := fantasy.ToolInfo{
		Name:        tool.Name,
		Description: tool.Description,
		Parameters:  tool.Schema,
		Required:    tool.Required,
		Parallel:    true,
	}
	// Normalize nil to an empty slice. NOTE(review): presumably so the
	// schema serializes as [] rather than null — confirm.
	if info.Required == nil {
		info.Required = []string{}
	}
	return &WorkspaceMCPTool{
		info:    info,
		getConn: getConn,
	}
}
// Info reports the metadata that describes this tool, as captured when
// the tool was constructed.
func (t *WorkspaceMCPTool) Info() fantasy.ToolInfo {
	return t.info
}
// Run executes the wrapped MCP tool on the workspace agent.
//
// When the call's JSON input is non-empty, Run decodes it into an
// argument map and forwards it along with the tool name. Failures to
// obtain a connection, decode the input, or complete the remote call are
// reported to the model as error text responses. Run itself always
// returns a nil error.
func (t *WorkspaceMCPTool) Run(
	ctx context.Context,
	params fantasy.ToolCall,
) (fantasy.ToolResponse, error) {
	conn, connErr := t.getConn(ctx)
	if connErr != nil {
		return fantasy.NewTextErrorResponse("workspace connection failed: " + connErr.Error()), nil
	}

	// An empty input leaves arguments nil rather than an empty map.
	var arguments map[string]any
	if input := params.Input; input != "" {
		if decodeErr := json.Unmarshal([]byte(input), &arguments); decodeErr != nil {
			return fantasy.NewTextErrorResponse("invalid JSON input: " + decodeErr.Error()), nil
		}
	}

	req := workspacesdk.CallMCPToolRequest{
		ToolName:  t.info.Name,
		Arguments: arguments,
	}
	resp, callErr := conn.CallMCPTool(ctx, req)
	if callErr != nil {
		return fantasy.NewTextErrorResponse(callErr.Error()), nil
	}
	return convertMCPToolResponse(resp), nil
}
// ProviderOptions returns the provider-specific options most recently
// assigned via SetProviderOptions. If none were assigned, it returns the
// zero value, because NewWorkspaceMCPTool leaves the field unset.
func (t *WorkspaceMCPTool) ProviderOptions() fantasy.ProviderOptions {
	return t.providerOpts
}
// SetProviderOptions replaces the provider-specific options attached to
// this tool. The receiver is mutated directly, with no locking.
func (t *WorkspaceMCPTool) SetProviderOptions(
	opts fantasy.ProviderOptions,
) {
	t.providerOpts = opts
}
// convertMCPToolResponse translates a workspace agent MCP tool
// response into a fantasy.ToolResponse, matching the mcpclient
// conversion strategy:
//
//   - Text from "text" blocks, and non-empty text from blocks of any
//     other unrecognized type, is sanitized to valid UTF-8 and joined
//     with newlines.
//   - "image" and "audio" blocks carry base64 data, which is decoded to
//     raw bytes. A block with empty data is ignored. A block whose data
//     fails to decode contributes an error note to the text instead.
//   - Binary content is returned only when no text was collected. In
//     that case the first successfully decoded media block wins.
//
// The IsError flag from resp is propagated to the result in every case.
func convertMCPToolResponse(
	resp workspacesdk.CallMCPToolResponse,
) fantasy.ToolResponse {
	var (
		textParts    []string
		binaryResult *fantasy.ToolResponse
	)
	for _, c := range resp.Content {
		switch c.Type {
		case "text":
			textParts = append(textParts, strings.ToValidUTF8(c.Text, "\uFFFD"))
		case "image", "audio":
			if c.Data == "" {
				continue
			}
			data, err := base64.StdEncoding.DecodeString(c.Data)
			if err != nil {
				textParts = append(textParts,
					"[binary decode error: "+err.Error()+"]",
				)
				continue
			}
			// Keep only the first decodable media block; later ones
			// are dropped.
			if binaryResult == nil {
				binaryResult = &fantasy.ToolResponse{
					Type:      c.Type,
					Data:      data,
					MediaType: c.MediaType,
					IsError:   resp.IsError,
				}
			}
		default:
			// Blocks of unrecognized type contribute text only when
			// they actually carry some. Appending an empty string would
			// inject a blank line into the joined output. Worse, it
			// would make textParts non-empty and suppress the binary
			// fallback below even though no real text was produced.
			if c.Text == "" {
				continue
			}
			textParts = append(textParts, strings.ToValidUTF8(c.Text, "\uFFFD"))
		}
	}

	// Prefer text content. Only fall back to binary when no text was
	// collected.
	if len(textParts) > 0 {
		r := fantasy.NewTextResponse(strings.Join(textParts, "\n"))
		r.IsError = resp.IsError
		return r
	}
	if binaryResult != nil {
		return *binaryResult
	}
	r := fantasy.NewTextResponse("")
	r.IsError = resp.IsError
	return r
}