Files
coder/coderd/chatd/chatloop/chatloop_test.go
T
Kyle Carberry 42c12176a0 fix(chatd): persist interrupted tool call steps instead of losing them (#23011)
## Problem

When a chat is interrupted while tools are executing, the step content
(text, reasoning, tool calls, and partial tool results) was being lost.
Two gaps existed:

1. **During tool execution**: `executeTools` returns with error results
for interrupted tools, but the subsequent `PersistStep(ctx, ...)` fails
on the canceled context and returns `ErrInterrupted` without persisting
anything.

2. **PersistStep race**: If the context is canceled between the
post-tool interrupt check and the `PersistStep` call, the same loss
occurs.

This is inconsistent with how we handle stream interruptions (which
properly flush and persist partial content via `persistInterruptedStep`)
and how [coder/blink](https://github.com/coder/blink) handles
interruptions (always inserting the response message regardless of
execution phase).

## Fix

Two changes in `chatloop.go`:

- **Post-tool-execution interrupt check**: After `executeTools` returns,
check if the context was interrupted and route through
`persistInterruptedStep` (which uses `context.WithoutCancel` internally)
to save the accumulated content.

- **PersistStep fallback**: If `PersistStep` returns `ErrInterrupted`,
retry via `persistInterruptedStep` so partial content is not lost.

## Tests

- `TestRun_InterruptedDuringToolExecutionPersistsStep`: Verifies that
when a tool is blocked and the chat is interrupted, the step (text +
reasoning + tool call + tool error result) is persisted via the
interrupt-safe path.

- `TestRun_PersistStepInterruptedFallback`: Verifies that when
`PersistStep` itself returns `ErrInterrupted`, the step is retried via
the fallback path and content is saved.
2026-03-12 16:59:16 -04:00

767 lines
24 KiB
Go

package chatloop //nolint:testpackage // Uses internal symbols.
import (
"context"
"errors"
"iter"
"strings"
"sync"
"testing"
"charm.land/fantasy"
fantasyanthropic "charm.land/fantasy/providers/anthropic"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/xerrors"
)
// activeToolName is the single tool name that
// TestRun_ActiveToolsPrepareBehavior lists in RunOptions.ActiveTools; it is
// also registered as one of that test's noop tools.
const (
	activeToolName = "read_file"
)
// TestRun_ActiveToolsPrepareBehavior verifies how Run prepares the model
// call when ActiveTools narrows the tool set:
//   - only the active tool is sent;
//   - the internal prompt sentinel is absent;
//   - Anthropic ephemeral cache control is set on prompt messages 1, 3 and
//     4, but not on messages 0 or 2;
//   - the single persisted step records ContextLimitFallback as its context
//     limit.
func TestRun_ActiveToolsPrepareBehavior(t *testing.T) {
	t.Parallel()

	var (
		seenCall     fantasy.Call
		persistCount int
		lastStep     PersistedStep
	)

	model := &loopTestModel{
		provider: fantasyanthropic.Name,
		streamFn: func(_ context.Context, call fantasy.Call) (fantasy.StreamResponse, error) {
			seenCall = call
			return streamFromParts([]fantasy.StreamPart{
				{Type: fantasy.StreamPartTypeTextStart, ID: "text-1"},
				{Type: fantasy.StreamPartTypeTextDelta, ID: "text-1", Delta: "done"},
				{Type: fantasy.StreamPartTypeTextEnd, ID: "text-1"},
				{Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop},
			}), nil
		},
	}

	history := []fantasy.Message{
		textMessage(fantasy.MessageRoleSystem, "sys-1"),
		textMessage(fantasy.MessageRoleSystem, "sys-2"),
		textMessage(fantasy.MessageRoleUser, "hello"),
		textMessage(fantasy.MessageRoleAssistant, "working"),
		textMessage(fantasy.MessageRoleUser, "continue"),
	}
	tools := []fantasy.AgentTool{
		newNoopTool(activeToolName),
		newNoopTool("write_file"),
	}

	err := Run(context.Background(), RunOptions{
		Model:                model,
		Messages:             history,
		Tools:                tools,
		MaxSteps:             3,
		ActiveTools:          []string{activeToolName},
		ContextLimitFallback: 4096,
		PersistStep: func(_ context.Context, step PersistedStep) error {
			persistCount++
			lastStep = step
			return nil
		},
	})
	require.NoError(t, err)

	// Persisted step bookkeeping.
	require.Equal(t, 1, persistCount)
	require.True(t, lastStep.ContextLimit.Valid)
	require.Equal(t, int64(4096), lastStep.ContextLimit.Int64)

	// Prepared prompt and tool set.
	require.NotEmpty(t, seenCall.Prompt)
	require.False(t, containsPromptSentinel(seenCall.Prompt))
	require.Len(t, seenCall.Tools, 1)
	require.Equal(t, activeToolName, seenCall.Tools[0].GetName())
	require.Len(t, seenCall.Prompt, 5)

	// Expected cache-control placement, indexed by prompt position.
	wantEphemeral := []bool{false, true, false, true, true}
	for i, want := range wantEphemeral {
		got := hasAnthropicEphemeralCacheControl(seenCall.Prompt[i])
		if want {
			require.True(t, got)
		} else {
			require.False(t, got)
		}
	}
}
// TestRun_InterruptedStepPersistsSyntheticToolResult verifies behavior when
// the chat is interrupted mid-stream, after a partial tool call and partial
// text have been emitted:
//   - Run returns ErrInterrupted;
//   - PersistStep receives a non-canceled context;
//   - the persisted content includes the partial text, the partial tool
//     call, and an error tool result for that call.
func TestRun_InterruptedStepPersistsSyntheticToolResult(t *testing.T) {
	t.Parallel()

	started := make(chan struct{})
	model := &loopTestModel{
		provider: "fake",
		streamFn: func(ctx context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
			return iter.Seq[fantasy.StreamPart](func(yield func(fantasy.StreamPart) bool) {
				parts := []fantasy.StreamPart{
					{
						Type:         fantasy.StreamPartTypeToolInputStart,
						ID:           "interrupt-tool-1",
						ToolCallName: "read_file",
					},
					{
						Type:         fantasy.StreamPartTypeToolInputDelta,
						ID:           "interrupt-tool-1",
						ToolCallName: "read_file",
						Delta:        `{"path":"main.go"`,
					},
					{Type: fantasy.StreamPartTypeTextStart, ID: "text-1"},
					{Type: fantasy.StreamPartTypeTextDelta, ID: "text-1", Delta: "partial assistant output"},
				}
				for _, part := range parts {
					if !yield(part) {
						return
					}
				}
				// Signal (at most once) that partial content is in
				// flight, then stall until the interrupt cancels ctx.
				select {
				case <-started:
				default:
					close(started)
				}
				<-ctx.Done()
				_ = yield(fantasy.StreamPart{
					Type:  fantasy.StreamPartTypeError,
					Error: ctx.Err(),
				})
			}), nil
		},
	}

	ctx, cancel := context.WithCancelCause(context.Background())
	defer cancel(nil)
	go func() {
		// Also watch ctx so the deferred cancel releases this goroutine
		// if the stream never reaches the started signal; otherwise it
		// would block on <-started forever and leak.
		select {
		case <-started:
			cancel(ErrInterrupted)
		case <-ctx.Done():
		}
	}()

	persistedAssistantCtxErr := xerrors.New("unset")
	var persistedContent []fantasy.Content
	err := Run(ctx, RunOptions{
		Model: model,
		Messages: []fantasy.Message{
			textMessage(fantasy.MessageRoleUser, "hello"),
		},
		Tools: []fantasy.AgentTool{
			newNoopTool("read_file"),
		},
		MaxSteps: 3,
		PersistStep: func(persistCtx context.Context, step PersistedStep) error {
			persistedAssistantCtxErr = persistCtx.Err()
			persistedContent = append([]fantasy.Content(nil), step.Content...)
			return nil
		},
	})
	require.ErrorIs(t, err, ErrInterrupted)
	require.NoError(t, persistedAssistantCtxErr)
	require.NotEmpty(t, persistedContent)

	var (
		foundText       bool
		foundToolCall   bool
		foundToolResult bool
	)
	for _, block := range persistedContent {
		if text, ok := fantasy.AsContentType[fantasy.TextContent](block); ok {
			if strings.Contains(text.Text, "partial assistant output") {
				foundText = true
			}
			continue
		}
		if toolCall, ok := fantasy.AsContentType[fantasy.ToolCallContent](block); ok {
			if toolCall.ToolCallID == "interrupt-tool-1" &&
				toolCall.ToolName == "read_file" &&
				strings.Contains(toolCall.Input, `"path":"main.go"`) {
				foundToolCall = true
			}
			continue
		}
		if toolResult, ok := fantasy.AsContentType[fantasy.ToolResultContent](block); ok {
			if toolResult.ToolCallID == "interrupt-tool-1" &&
				toolResult.ToolName == "read_file" {
				_, isErr := toolResult.Result.(fantasy.ToolResultOutputContentError)
				require.True(t, isErr, "interrupted tool result should be an error")
				foundToolResult = true
			}
		}
	}
	require.True(t, foundText)
	require.True(t, foundToolCall)
	require.True(t, foundToolResult)
}
// loopTestModel is a scriptable fake language model for chatloop tests.
// Generate and Stream delegate to generateFn/streamFn when set. Otherwise,
// Generate returns an empty response and Stream returns a single Stop
// finish part. Empty provider/model names report "fake".
type loopTestModel struct {
	provider   string
	model      string
	generateFn func(context.Context, fantasy.Call) (*fantasy.Response, error)
	streamFn   func(context.Context, fantasy.Call) (fantasy.StreamResponse, error)
}

// Provider returns the configured provider name, or "fake" when unset.
func (m *loopTestModel) Provider() string {
	if m.provider == "" {
		return "fake"
	}
	return m.provider
}

// Model returns the configured model name, or "fake" when unset.
func (m *loopTestModel) Model() string {
	if m.model == "" {
		return "fake"
	}
	return m.model
}

// Generate delegates to generateFn when present; otherwise it returns an
// empty, non-nil response and no error.
func (m *loopTestModel) Generate(ctx context.Context, call fantasy.Call) (*fantasy.Response, error) {
	if fn := m.generateFn; fn != nil {
		return fn(ctx, call)
	}
	return &fantasy.Response{}, nil
}

// Stream delegates to streamFn when present; otherwise it yields a single
// Finish part with FinishReasonStop.
func (m *loopTestModel) Stream(ctx context.Context, call fantasy.Call) (fantasy.StreamResponse, error) {
	if fn := m.streamFn; fn != nil {
		return fn(ctx, call)
	}
	finish := fantasy.StreamPart{
		Type:         fantasy.StreamPartTypeFinish,
		FinishReason: fantasy.FinishReasonStop,
	}
	return streamFromParts([]fantasy.StreamPart{finish}), nil
}

// GenerateObject is unsupported by the fake and always returns an error.
func (*loopTestModel) GenerateObject(context.Context, fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
	return nil, xerrors.New("not implemented")
}

// StreamObject is unsupported by the fake and always returns an error.
func (*loopTestModel) StreamObject(context.Context, fantasy.ObjectCall) (fantasy.ObjectStreamResponse, error) {
	return nil, xerrors.New("not implemented")
}
// streamFromParts returns a StreamResponse that yields parts in order,
// stopping early if the consumer stops iterating.
func streamFromParts(parts []fantasy.StreamPart) fantasy.StreamResponse {
	seq := func(yield func(fantasy.StreamPart) bool) {
		for i := range parts {
			if !yield(parts[i]) {
				return
			}
		}
	}
	return iter.Seq[fantasy.StreamPart](seq)
}
func newNoopTool(name string) fantasy.AgentTool {
return fantasy.NewAgentTool(
name,
"test noop tool",
func(context.Context, struct{}, fantasy.ToolCall) (fantasy.ToolResponse, error) {
return fantasy.ToolResponse{}, nil
},
)
}
// textMessage builds a message with the given role whose only content part
// is a TextPart holding text.
func textMessage(role fantasy.MessageRole, text string) fantasy.Message {
	part := fantasy.TextPart{Text: text}
	return fantasy.Message{
		Role:    role,
		Content: []fantasy.MessagePart{part},
	}
}
func containsPromptSentinel(prompt []fantasy.Message) bool {
for _, message := range prompt {
if message.Role != fantasy.MessageRoleUser || len(message.Content) != 1 {
continue
}
textPart, ok := fantasy.AsMessagePart[fantasy.TextPart](message.Content[0])
if !ok {
continue
}
if strings.HasPrefix(textPart.Text, "__chatd_agent_prompt_sentinel_") {
return true
}
}
return false
}
// TestRun_MultiStepToolExecution verifies multi-step tool execution:
//   - a tool call produced in step 0 is executed;
//   - the step 1 prompt contains both the assistant tool call and the
//     matching tool-result message;
//   - PersistStep runs once per step.
func TestRun_MultiStepToolExecution(t *testing.T) {
	t.Parallel()

	// mu guards every value written from the model and persist callbacks.
	var (
		mu               sync.Mutex
		streamCalls      int
		persistStepCalls int
		secondCallPrompt []fantasy.Message
	)
	model := &loopTestModel{
		provider: "fake",
		streamFn: func(_ context.Context, call fantasy.Call) (fantasy.StreamResponse, error) {
			mu.Lock()
			step := streamCalls
			streamCalls++
			mu.Unlock()
			switch step {
			case 0:
				// Step 0: produce a tool call.
				return streamFromParts([]fantasy.StreamPart{
					{Type: fantasy.StreamPartTypeToolInputStart, ID: "tc-1", ToolCallName: "read_file"},
					{Type: fantasy.StreamPartTypeToolInputDelta, ID: "tc-1", Delta: `{"path":"main.go"}`},
					{Type: fantasy.StreamPartTypeToolInputEnd, ID: "tc-1"},
					{
						Type:          fantasy.StreamPartTypeToolCall,
						ID:            "tc-1",
						ToolCallName:  "read_file",
						ToolCallInput: `{"path":"main.go"}`,
					},
					{Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonToolCalls},
				}), nil
			default:
				// Step 1: capture the prompt the loop sent us,
				// then return plain text.
				mu.Lock()
				secondCallPrompt = append([]fantasy.Message(nil), call.Prompt...)
				mu.Unlock()
				return streamFromParts([]fantasy.StreamPart{
					{Type: fantasy.StreamPartTypeTextStart, ID: "text-1"},
					{Type: fantasy.StreamPartTypeTextDelta, ID: "text-1", Delta: "all done"},
					{Type: fantasy.StreamPartTypeTextEnd, ID: "text-1"},
					{Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop},
				}), nil
			}
		},
	}

	err := Run(context.Background(), RunOptions{
		Model: model,
		Messages: []fantasy.Message{
			textMessage(fantasy.MessageRoleUser, "please read main.go"),
		},
		Tools: []fantasy.AgentTool{
			newNoopTool("read_file"),
		},
		MaxSteps: 5,
		PersistStep: func(_ context.Context, _ PersistedStep) error {
			mu.Lock()
			persistStepCalls++
			mu.Unlock()
			return nil
		},
	})
	require.NoError(t, err)

	// Snapshot the guarded state under the same mutex used for the
	// writes, so reads are consistent with them.
	mu.Lock()
	gotStreamCalls := streamCalls
	gotPersistCalls := persistStepCalls
	prompt := secondCallPrompt
	mu.Unlock()

	// Stream was called twice: once for the tool-call step,
	// once for the follow-up text step.
	require.Equal(t, 2, gotStreamCalls)
	// PersistStep is called once per step.
	require.Equal(t, 2, gotPersistCalls)
	// The second call's prompt must contain the assistant message
	// from step 0 (with the tool call) and a tool-result message.
	require.NotEmpty(t, prompt)

	var foundAssistantToolCall bool
	var foundToolResult bool
	for _, msg := range prompt {
		if msg.Role == fantasy.MessageRoleAssistant {
			for _, part := range msg.Content {
				if tc, ok := fantasy.AsMessagePart[fantasy.ToolCallPart](part); ok {
					if tc.ToolCallID == "tc-1" && tc.ToolName == "read_file" {
						foundAssistantToolCall = true
					}
				}
			}
		}
		if msg.Role == fantasy.MessageRoleTool {
			for _, part := range msg.Content {
				if tr, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part); ok {
					if tr.ToolCallID == "tc-1" {
						foundToolResult = true
					}
				}
			}
		}
	}
	require.True(t, foundAssistantToolCall, "second call prompt should contain assistant tool call from step 0")
	require.True(t, foundToolResult, "second call prompt should contain tool result message")
}
// TestRun_PersistStepErrorPropagates verifies that an error returned by
// PersistStep makes Run fail, and that the original error message is
// preserved in the returned error.
func TestRun_PersistStepErrorPropagates(t *testing.T) {
	t.Parallel()

	persistErr := xerrors.New("database write failed")
	failingPersist := func(_ context.Context, _ PersistedStep) error {
		return persistErr
	}

	model := &loopTestModel{
		provider: "fake",
		streamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
			parts := []fantasy.StreamPart{
				{Type: fantasy.StreamPartTypeTextStart, ID: "text-1"},
				{Type: fantasy.StreamPartTypeTextDelta, ID: "text-1", Delta: "hello"},
				{Type: fantasy.StreamPartTypeTextEnd, ID: "text-1"},
				{Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop},
			}
			return streamFromParts(parts), nil
		},
	}

	opts := RunOptions{
		Model:       model,
		Messages:    []fantasy.Message{textMessage(fantasy.MessageRoleUser, "hello")},
		MaxSteps:    1,
		PersistStep: failingPersist,
	}
	err := Run(context.Background(), opts)

	require.Error(t, err)
	require.ErrorContains(t, err, "database write failed")
}
// TestRun_ShutdownDuringToolExecutionReturnsContextCanceled verifies that
// when the parent context is canceled (simulating server shutdown) while
// a tool is blocked, Run returns context.Canceled — not ErrInterrupted.
// This matters because the caller uses the error type to decide whether
// to set chat status to "pending" (retryable on another worker) vs
// "waiting" (stuck forever).
func TestRun_ShutdownDuringToolExecutionReturnsContextCanceled(t *testing.T) {
t.Parallel()
toolStarted := make(chan struct{})
// Model returns a single tool call, then finishes.
model := &loopTestModel{
provider: "fake",
streamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
return streamFromParts([]fantasy.StreamPart{
{Type: fantasy.StreamPartTypeToolInputStart, ID: "tc-block", ToolCallName: "blocking_tool"},
{Type: fantasy.StreamPartTypeToolInputDelta, ID: "tc-block", Delta: `{}`},
{Type: fantasy.StreamPartTypeToolInputEnd, ID: "tc-block"},
{
Type: fantasy.StreamPartTypeToolCall,
ID: "tc-block",
ToolCallName: "blocking_tool",
ToolCallInput: `{}`,
},
{Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonToolCalls},
}), nil
},
}
// Tool that blocks until its context is canceled, simulating
// a long-running operation like wait_agent.
blockingTool := fantasy.NewAgentTool(
"blocking_tool",
"blocks until context canceled",
func(ctx context.Context, _ struct{}, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
close(toolStarted)
<-ctx.Done()
return fantasy.ToolResponse{}, ctx.Err()
},
)
// Simulate the server context (parent) and chat context
// (child). Canceling the parent simulates graceful shutdown.
serverCtx, serverCancel := context.WithCancel(context.Background())
defer serverCancel()
serverCancelDone := make(chan struct{})
go func() {
defer close(serverCancelDone)
<-toolStarted
t.Logf("tool started, canceling server context to simulate shutdown")
serverCancel()
}()
// persistStep mirrors the FIXED chatd.go code: it only returns
// ErrInterrupted when the context was actually canceled due to
// an interruption (cause is ErrInterrupted). For shutdown
// (plain context.Canceled), it returns the original error so
// callers can distinguish the two.
persistStep := func(persistCtx context.Context, _ PersistedStep) error {
if persistCtx.Err() != nil {
if errors.Is(context.Cause(persistCtx), ErrInterrupted) {
return ErrInterrupted
}
return persistCtx.Err()
}
return nil
}
err := Run(serverCtx, RunOptions{
Model: model,
Messages: []fantasy.Message{
textMessage(fantasy.MessageRoleUser, "run the blocking tool"),
},
Tools: []fantasy.AgentTool{blockingTool},
MaxSteps: 3,
PersistStep: persistStep,
})
// Wait for the cancel goroutine to finish to aid flake
// diagnosis if the test ever hangs.
<-serverCancelDone
require.Error(t, err)
// The error must NOT be ErrInterrupted — it should propagate
// as context.Canceled so the caller can distinguish shutdown
// from user interruption. Use assert (not require) so both
// checks are evaluated even if the first fails.
assert.NotErrorIs(t, err, ErrInterrupted, "shutdown cancellation must not be converted to ErrInterrupted")
assert.ErrorIs(t, err, context.Canceled, "shutdown should propagate as context.Canceled")
}
// TestToResponseMessages_ProviderExecutedToolResultInAssistantMessage
// verifies how toResponseMessages splits a step's content:
//   - the provider-executed tool call and its result stay inline in the
//     assistant message;
//   - the local tool call also stays in the assistant message;
//   - only the local tool result moves into a separate tool-role message.
func TestToResponseMessages_ProviderExecutedToolResultInAssistantMessage(t *testing.T) {
	t.Parallel()

	// Provider-executed tool call (e.g. web_search) and its result, which
	// must stay in the assistant message.
	providerCall := fantasy.ToolCallContent{
		ToolCallID:       "provider-tc-1",
		ToolName:         "web_search",
		Input:            `{"query":"coder"}`,
		ProviderExecuted: true,
	}
	providerResult := fantasy.ToolResultContent{
		ToolCallID:       "provider-tc-1",
		ToolName:         "web_search",
		ProviderExecuted: true,
		ProviderMetadata: fantasy.ProviderMetadata{"anthropic": nil},
	}
	// Local tool call (e.g. read_file) and its result, which should go
	// into the tool message.
	localCall := fantasy.ToolCallContent{
		ToolCallID:       "local-tc-1",
		ToolName:         "read_file",
		Input:            `{"path":"main.go"}`,
		ProviderExecuted: false,
	}
	localResult := fantasy.ToolResultContent{
		ToolCallID:       "local-tc-1",
		ToolName:         "read_file",
		Result:           fantasy.ToolResultOutputContentText{Text: "some result"},
		ProviderExecuted: false,
	}

	sr := stepResult{
		content: []fantasy.Content{providerCall, providerResult, localCall, localResult},
	}
	msgs := sr.toResponseMessages()
	require.Len(t, msgs, 2, "expected assistant + tool messages")

	// Assistant turn: provider call, provider result, local call.
	assistant := msgs[0]
	assert.Equal(t, fantasy.MessageRoleAssistant, assistant.Role)
	require.Len(t, assistant.Content, 3,
		"assistant message should have provider ToolCallPart, provider ToolResultPart, and local ToolCallPart")

	gotProviderCall, ok := fantasy.AsMessagePart[fantasy.ToolCallPart](assistant.Content[0])
	require.True(t, ok, "part 0 should be ToolCallPart")
	assert.Equal(t, "provider-tc-1", gotProviderCall.ToolCallID)
	assert.True(t, gotProviderCall.ProviderExecuted)

	gotProviderResult, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](assistant.Content[1])
	require.True(t, ok, "part 1 should be ToolResultPart")
	assert.Equal(t, "provider-tc-1", gotProviderResult.ToolCallID)
	assert.True(t, gotProviderResult.ProviderExecuted)

	gotLocalCall, ok := fantasy.AsMessagePart[fantasy.ToolCallPart](assistant.Content[2])
	require.True(t, ok, "part 2 should be ToolCallPart")
	assert.Equal(t, "local-tc-1", gotLocalCall.ToolCallID)
	assert.False(t, gotLocalCall.ProviderExecuted)

	// Tool turn: only the local result.
	tool := msgs[1]
	assert.Equal(t, fantasy.MessageRoleTool, tool.Role)
	require.Len(t, tool.Content, 1,
		"tool message should have only the local ToolResultPart")

	gotLocalResult, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](tool.Content[0])
	require.True(t, ok, "tool part should be ToolResultPart")
	assert.Equal(t, "local-tc-1", gotLocalResult.ToolCallID)
	assert.False(t, gotLocalResult.ProviderExecuted)
}
// hasAnthropicEphemeralCacheControl reports whether message has Anthropic
// provider options of type *fantasyanthropic.ProviderCacheControlOptions
// whose CacheControl.Type is "ephemeral". It returns false in three cases:
//   - the Anthropic entry is missing (including a nil/empty options map);
//   - the entry has a different type;
//   - the entry is a nil pointer.
func hasAnthropicEphemeralCacheControl(message fantasy.Message) bool {
	// Indexing a nil or empty map yields ok == false, so no separate
	// length check is needed.
	options, ok := message.ProviderOptions[fantasyanthropic.Name]
	if !ok {
		return false
	}
	cacheOptions, ok := options.(*fantasyanthropic.ProviderCacheControlOptions)
	// A typed-nil pointer satisfies the assertion; guard it so the field
	// access below cannot panic.
	return ok && cacheOptions != nil && cacheOptions.CacheControl.Type == "ephemeral"
}
// TestRun_InterruptedDuringToolExecutionPersistsStep verifies that when
// tools are executing and the chat is interrupted, the accumulated step
// content (assistant blocks + tool results) is persisted via the
// interrupt-safe path rather than being lost.
func TestRun_InterruptedDuringToolExecutionPersistsStep(t *testing.T) {
	t.Parallel()

	toolStarted := make(chan struct{})
	// Model returns a completed tool call in the stream.
	model := &loopTestModel{
		provider: "fake",
		streamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
			return streamFromParts([]fantasy.StreamPart{
				{Type: fantasy.StreamPartTypeTextStart, ID: "text-1"},
				{Type: fantasy.StreamPartTypeTextDelta, ID: "text-1", Delta: "calling tool"},
				{Type: fantasy.StreamPartTypeTextEnd, ID: "text-1"},
				{Type: fantasy.StreamPartTypeReasoningStart, ID: "reason-1"},
				{Type: fantasy.StreamPartTypeReasoningDelta, ID: "reason-1", Delta: "let me think"},
				{Type: fantasy.StreamPartTypeReasoningEnd, ID: "reason-1"},
				{Type: fantasy.StreamPartTypeToolInputStart, ID: "tc-1", ToolCallName: "slow_tool"},
				{Type: fantasy.StreamPartTypeToolInputDelta, ID: "tc-1", Delta: `{"key":"value"}`},
				{Type: fantasy.StreamPartTypeToolInputEnd, ID: "tc-1"},
				{
					Type:          fantasy.StreamPartTypeToolCall,
					ID:            "tc-1",
					ToolCallName:  "slow_tool",
					ToolCallInput: `{"key":"value"}`,
				},
				{Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonToolCalls},
			}), nil
		},
	}

	// Tool that blocks until context is canceled, simulating a
	// long-running operation interrupted by the user. The close is
	// once-guarded so a regression that re-runs the tool fails an
	// assertion rather than panicking on a double close.
	var toolStartedOnce sync.Once
	slowTool := fantasy.NewAgentTool(
		"slow_tool",
		"blocks until canceled",
		func(ctx context.Context, _ struct{}, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
			toolStartedOnce.Do(func() { close(toolStarted) })
			<-ctx.Done()
			return fantasy.ToolResponse{}, ctx.Err()
		},
	)

	ctx, cancel := context.WithCancelCause(context.Background())
	defer cancel(nil)
	go func() {
		// Also watch ctx so the deferred cancel releases this goroutine
		// if the tool never starts; otherwise it would leak.
		select {
		case <-toolStarted:
			cancel(ErrInterrupted)
		case <-ctx.Done():
		}
	}()

	var persistedContent []fantasy.Content
	persistedCtxErr := xerrors.New("unset")
	err := Run(ctx, RunOptions{
		Model: model,
		Messages: []fantasy.Message{
			textMessage(fantasy.MessageRoleUser, "run the slow tool"),
		},
		Tools:    []fantasy.AgentTool{slowTool},
		MaxSteps: 3,
		PersistStep: func(persistCtx context.Context, step PersistedStep) error {
			persistedCtxErr = persistCtx.Err()
			persistedContent = append([]fantasy.Content(nil), step.Content...)
			return nil
		},
	})
	require.ErrorIs(t, err, ErrInterrupted)
	// persistInterruptedStep uses context.WithoutCancel, so the
	// persist callback should see a non-canceled context.
	require.NoError(t, persistedCtxErr)
	require.NotEmpty(t, persistedContent)

	var (
		foundText       bool
		foundReasoning  bool
		foundToolCall   bool
		foundToolResult bool
	)
	for _, block := range persistedContent {
		if text, ok := fantasy.AsContentType[fantasy.TextContent](block); ok {
			if strings.Contains(text.Text, "calling tool") {
				foundText = true
			}
			continue
		}
		if reasoning, ok := fantasy.AsContentType[fantasy.ReasoningContent](block); ok {
			if strings.Contains(reasoning.Text, "let me think") {
				foundReasoning = true
			}
			continue
		}
		if toolCall, ok := fantasy.AsContentType[fantasy.ToolCallContent](block); ok {
			if toolCall.ToolCallID == "tc-1" && toolCall.ToolName == "slow_tool" {
				foundToolCall = true
			}
			continue
		}
		if toolResult, ok := fantasy.AsContentType[fantasy.ToolResultContent](block); ok {
			if toolResult.ToolCallID == "tc-1" {
				foundToolResult = true
			}
		}
	}
	require.True(t, foundText, "persisted content should include text from the stream")
	require.True(t, foundReasoning, "persisted content should include reasoning from the stream")
	require.True(t, foundToolCall, "persisted content should include the tool call")
	require.True(t, foundToolResult, "persisted content should include the tool result (error from cancellation)")
}
// TestRun_PersistStepInterruptedFallback verifies the fallback when the
// normal PersistStep call returns ErrInterrupted (e.g., context canceled in
// a race):
//   - the step is retried, so PersistStep is called a second time;
//   - the retry's content is saved;
//   - Run still reports ErrInterrupted.
func TestRun_PersistStepInterruptedFallback(t *testing.T) {
	t.Parallel()

	model := &loopTestModel{
		provider: "fake",
		streamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
			parts := []fantasy.StreamPart{
				{Type: fantasy.StreamPartTypeTextStart, ID: "text-1"},
				{Type: fantasy.StreamPartTypeTextDelta, ID: "text-1", Delta: "hello world"},
				{Type: fantasy.StreamPartTypeTextEnd, ID: "text-1"},
				{Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop},
			}
			return streamFromParts(parts), nil
		},
	}

	// mu guards the callback's bookkeeping against concurrent access.
	var (
		mu           sync.Mutex
		persistCalls int
		savedContent []fantasy.Content
	)
	persist := func(_ context.Context, step PersistedStep) error {
		mu.Lock()
		defer mu.Unlock()
		persistCalls++
		if persistCalls > 1 {
			// Retry (from the persistInterruptedStep fallback):
			// accept the content.
			savedContent = append([]fantasy.Content(nil), step.Content...)
			return nil
		}
		// First call: simulate an interrupt race by returning
		// ErrInterrupted without persisting.
		return ErrInterrupted
	}

	err := Run(context.Background(), RunOptions{
		Model:       model,
		Messages:    []fantasy.Message{textMessage(fantasy.MessageRoleUser, "hello")},
		MaxSteps:    1,
		PersistStep: persist,
	})
	require.ErrorIs(t, err, ErrInterrupted)

	mu.Lock()
	defer mu.Unlock()
	require.Equal(t, 2, persistCalls, "PersistStep should be called twice: once normally (failing), once via fallback")
	require.NotEmpty(t, savedContent)

	foundText := false
	for _, block := range savedContent {
		text, ok := fantasy.AsContentType[fantasy.TextContent](block)
		if ok && strings.Contains(text.Text, "hello world") {
			foundText = true
			break
		}
	}
	require.True(t, foundText, "fallback should persist the text content")
}