mirror of
https://github.com/coder/coder.git
synced 2026-06-03 04:58:23 +00:00
ddfe630757
## Summary Replaces fantasy's `Agent` abstraction with a direct step loop calling `LanguageModel.Stream()`. Fantasy is retained as the provider abstraction layer (streaming parsers, types, tool schema) but we no longer use `fantasy.Agent`, `AgentStreamCall`, `AgentResult`, or `StepResult`. ## Problems solved | Problem | Before | After | |---|---|---| | **Sentinel prompt hack** | fantasy.Agent requires non-empty Prompt → UUID sentinel generated and stripped in PrepareStep | Messages passed directly to `model.Stream()` | | **Discarded PersistStep errors** | `_ = opts.OnStepFinish(result)` silently swallows errors | Errors propagate directly from `PersistStep()` | | **Shadow draft state** | ~160 LOC tracking content in parallel because fantasy doesn't expose in-progress content on interruption | `stepResult` owns content directly; `flushActiveState()` is trivial | | **Nested retry layers** | fantasy's 2-attempt retry nested inside chatretry's indefinite retry | Single `chatretry.Retry` layer | | **Callback-mediated compaction** | Mutex + boolean flag + coordination between OnStepFinish/PrepareStep callbacks | Inline `if` statement between steps | | **Duplicate compaction paths** | `compactStep()` + `maybeCompact()` sharing ~80% logic | Single `tryCompact()` function | ## Changes ### `coderd/chatd/chatloop/chatloop.go` — Rewritten - **Removed**: `fantasy.NewAgent()`, `AgentStreamCall`, sentinel prompt, shadow draft state (~160 LOC of closures), `compactedMu`/`compacted` flag, `PrepareStepResult` - **Added**: `stepResult` struct, `processStepStream()` (stream consumer), `executeTools()` (sequential tool execution), `flushActiveState()` (interrupt handling), `buildToolDefinitions()`, `toResponseMessages()` - **Changed**: `Run()` return type from `(*fantasy.AgentResult, error)` to `error` (callers already discarded the result) - **Preserved**: Anthropic prompt caching, reasoning title extraction, `extractContextLimit()`, `ErrInterrupted` semantics ### 
`coderd/chatd/chatloop/compaction.go` — Simplified - Merged `compactStep()` + `maybeCompact()` → single `tryCompact()` - Removed `[]StepResult` parameter from `generateCompactionSummary()` (caller provides complete message list) - Kept helper functions: `normalizedCompactionConfig`, `contextTokensFromUsage`, `resolveContextLimit`, `shouldCompact` ### `coderd/chatd/chatd.go` — Caller updates - Removed `AgentStreamCall` construction - Changed `_, err = chatloop.Run(...)` to `err = chatloop.Run(...)` - Model parameters moved from `AgentStreamCall` fields to `RunOptions` fields ### Tests — 4 new tests - `MidLoopCompactionReloadsMessages` — compaction fires mid-loop, messages reloaded - `PostRunCompactionSkippedAfterMidLoop` — no double compaction - `MultiStepToolExecution` — tools execute between steps, results feed next step - `PersistStepErrorPropagates` — persistence errors propagate (was silently discarded)
956 lines
27 KiB
Go
956 lines
27 KiB
Go
package chatloop
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"encoding/json"
|
|
"errors"
|
|
"slices"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"charm.land/fantasy"
|
|
fantasyanthropic "charm.land/fantasy/providers/anthropic"
|
|
"charm.land/fantasy/schema"
|
|
"golang.org/x/xerrors"
|
|
|
|
"github.com/coder/coder/v2/coderd/chatd/chatprompt"
|
|
"github.com/coder/coder/v2/coderd/chatd/chatretry"
|
|
"github.com/coder/coder/v2/codersdk"
|
|
)
|
|
|
|
// interruptedToolResultErrorMessage is the error text recorded as the
// synthetic result of any tool call that was still unanswered when the
// chat was interrupted.
const interruptedToolResultErrorMessage = "tool call was interrupted before it produced a result"
|
|
|
|
// ErrInterrupted is the sentinel error for an interrupted chat. When
// Run's context is canceled with ErrInterrupted as its cause, Run
// persists the partial step and returns ErrInterrupted.
var (
	ErrInterrupted = xerrors.New("chat interrupted")
)
|
|
|
|
// PersistedStep contains the full content of a completed or
|
|
// interrupted agent step. Content includes both assistant blocks
|
|
// (text, reasoning, tool calls) and tool result blocks. The
|
|
// persistence layer is responsible for splitting these into
|
|
// separate database messages by role.
|
|
type PersistedStep struct {
|
|
Content []fantasy.Content
|
|
Usage fantasy.Usage
|
|
ContextLimit sql.NullInt64
|
|
}
|
|
|
|
// RunOptions configures a single streaming chat loop run.
|
|
type RunOptions struct {
|
|
Model fantasy.LanguageModel
|
|
Messages []fantasy.Message
|
|
Tools []fantasy.AgentTool
|
|
MaxSteps int
|
|
|
|
ActiveTools []string
|
|
ContextLimitFallback int64
|
|
|
|
// ModelConfig holds per-call LLM parameters (temperature,
|
|
// max tokens, etc.) read from the chat model configuration.
|
|
ModelConfig codersdk.ChatModelCallConfig
|
|
// ProviderOptions are provider-specific call options
|
|
// converted from ModelConfig.ProviderOptions. This is a
|
|
// separate field because the conversion requires knowledge
|
|
// of the provider, which lives in chatd, not chatloop.
|
|
ProviderOptions fantasy.ProviderOptions
|
|
|
|
PersistStep func(context.Context, PersistedStep) error
|
|
PublishMessagePart func(
|
|
role fantasy.MessageRole,
|
|
part codersdk.ChatMessagePart,
|
|
)
|
|
Compaction *CompactionOptions
|
|
ReloadMessages func(context.Context) ([]fantasy.Message, error)
|
|
|
|
// OnRetry is called before each retry attempt when the LLM
|
|
// stream fails with a retryable error. It provides the attempt
|
|
// number, error, and backoff delay so callers can publish status
|
|
// events to connected clients.
|
|
OnRetry chatretry.OnRetryFn
|
|
|
|
OnInterruptedPersistError func(error)
|
|
}
|
|
|
|
// stepResult holds the accumulated output of a single streaming
|
|
// step. Since we own the stream consumer, all content is tracked
|
|
// directly here — no shadow draft state needed.
|
|
type stepResult struct {
|
|
content []fantasy.Content
|
|
usage fantasy.Usage
|
|
providerMetadata fantasy.ProviderMetadata
|
|
finishReason fantasy.FinishReason
|
|
toolCalls []fantasy.ToolCallContent
|
|
shouldContinue bool
|
|
}
|
|
|
|
// toResponseMessages converts step content into messages suitable
|
|
// for appending to the conversation. Mirrors fantasy's
|
|
// toResponseMessages logic.
|
|
func (r stepResult) toResponseMessages() []fantasy.Message {
|
|
var assistantParts []fantasy.MessagePart
|
|
var toolParts []fantasy.MessagePart
|
|
|
|
for _, c := range r.content {
|
|
switch c.GetType() {
|
|
case fantasy.ContentTypeText:
|
|
text, ok := fantasy.AsContentType[fantasy.TextContent](c)
|
|
if !ok {
|
|
continue
|
|
}
|
|
assistantParts = append(assistantParts, fantasy.TextPart{
|
|
Text: text.Text,
|
|
ProviderOptions: fantasy.ProviderOptions(text.ProviderMetadata),
|
|
})
|
|
case fantasy.ContentTypeReasoning:
|
|
reasoning, ok := fantasy.AsContentType[fantasy.ReasoningContent](c)
|
|
if !ok {
|
|
continue
|
|
}
|
|
assistantParts = append(assistantParts, fantasy.ReasoningPart{
|
|
Text: reasoning.Text,
|
|
ProviderOptions: fantasy.ProviderOptions(reasoning.ProviderMetadata),
|
|
})
|
|
case fantasy.ContentTypeToolCall:
|
|
toolCall, ok := fantasy.AsContentType[fantasy.ToolCallContent](c)
|
|
if !ok {
|
|
continue
|
|
}
|
|
assistantParts = append(assistantParts, fantasy.ToolCallPart{
|
|
ToolCallID: toolCall.ToolCallID,
|
|
ToolName: toolCall.ToolName,
|
|
Input: toolCall.Input,
|
|
ProviderExecuted: toolCall.ProviderExecuted,
|
|
ProviderOptions: fantasy.ProviderOptions(toolCall.ProviderMetadata),
|
|
})
|
|
case fantasy.ContentTypeFile:
|
|
file, ok := fantasy.AsContentType[fantasy.FileContent](c)
|
|
if !ok {
|
|
continue
|
|
}
|
|
assistantParts = append(assistantParts, fantasy.FilePart{
|
|
Data: file.Data,
|
|
MediaType: file.MediaType,
|
|
ProviderOptions: fantasy.ProviderOptions(file.ProviderMetadata),
|
|
})
|
|
case fantasy.ContentTypeSource:
|
|
// Sources are metadata about references; they don't
|
|
// need to be included in conversation messages.
|
|
continue
|
|
case fantasy.ContentTypeToolResult:
|
|
result, ok := fantasy.AsContentType[fantasy.ToolResultContent](c)
|
|
if !ok {
|
|
continue
|
|
}
|
|
toolParts = append(toolParts, fantasy.ToolResultPart{
|
|
ToolCallID: result.ToolCallID,
|
|
Output: result.Result,
|
|
ProviderOptions: fantasy.ProviderOptions(result.ProviderMetadata),
|
|
})
|
|
default:
|
|
continue
|
|
}
|
|
}
|
|
|
|
var messages []fantasy.Message
|
|
if len(assistantParts) > 0 {
|
|
messages = append(messages, fantasy.Message{
|
|
Role: fantasy.MessageRoleAssistant,
|
|
Content: assistantParts,
|
|
})
|
|
}
|
|
if len(toolParts) > 0 {
|
|
messages = append(messages, fantasy.Message{
|
|
Role: fantasy.MessageRoleTool,
|
|
Content: toolParts,
|
|
})
|
|
}
|
|
return messages
|
|
}
|
|
|
|
// reasoningState accumulates reasoning content and provider
|
|
// metadata while the stream is in flight.
|
|
type reasoningState struct {
|
|
text string
|
|
options fantasy.ProviderMetadata
|
|
}
|
|
|
|
// Run executes the chat step-stream loop and delegates
|
|
// persistence/publishing to callbacks.
|
|
func Run(ctx context.Context, opts RunOptions) error {
|
|
if opts.Model == nil {
|
|
return xerrors.New("chat model is required")
|
|
}
|
|
if opts.PersistStep == nil {
|
|
return xerrors.New("persist step callback is required")
|
|
}
|
|
if opts.MaxSteps <= 0 {
|
|
opts.MaxSteps = 1
|
|
}
|
|
|
|
publishMessagePart := func(role fantasy.MessageRole, part codersdk.ChatMessagePart) {
|
|
if opts.PublishMessagePart == nil {
|
|
return
|
|
}
|
|
opts.PublishMessagePart(role, part)
|
|
}
|
|
|
|
tools := buildToolDefinitions(opts.Tools, opts.ActiveTools)
|
|
applyAnthropicCaching := shouldApplyAnthropicPromptCaching(opts.Model)
|
|
|
|
messages := opts.Messages
|
|
alreadyCompacted := false
|
|
var lastUsage fantasy.Usage
|
|
var lastProviderMetadata fantasy.ProviderMetadata
|
|
|
|
for step := 0; step < opts.MaxSteps; step++ {
|
|
// Copy messages so that provider-specific caching
|
|
// mutations don't leak back to the caller's slice.
|
|
// copy copies Message structs by value, so field
|
|
// reassignments in addAnthropicPromptCaching only
|
|
// affect the prepared slice.
|
|
prepared := make([]fantasy.Message, len(messages))
|
|
copy(prepared, messages)
|
|
if applyAnthropicCaching {
|
|
addAnthropicPromptCaching(prepared)
|
|
}
|
|
|
|
call := fantasy.Call{
|
|
Prompt: prepared,
|
|
Tools: tools,
|
|
MaxOutputTokens: opts.ModelConfig.MaxOutputTokens,
|
|
Temperature: opts.ModelConfig.Temperature,
|
|
TopP: opts.ModelConfig.TopP,
|
|
TopK: opts.ModelConfig.TopK,
|
|
PresencePenalty: opts.ModelConfig.PresencePenalty,
|
|
FrequencyPenalty: opts.ModelConfig.FrequencyPenalty,
|
|
ProviderOptions: opts.ProviderOptions,
|
|
}
|
|
|
|
var result stepResult
|
|
err := chatretry.Retry(ctx, func(retryCtx context.Context) error {
|
|
stream, streamErr := opts.Model.Stream(retryCtx, call)
|
|
if streamErr != nil {
|
|
return streamErr
|
|
}
|
|
var processErr error
|
|
result, processErr = processStepStream(retryCtx, stream, publishMessagePart)
|
|
return processErr
|
|
}, func(attempt int, retryErr error, delay time.Duration) {
|
|
// Reset result from the failed attempt so the next
|
|
// attempt starts clean.
|
|
result = stepResult{}
|
|
if opts.OnRetry != nil {
|
|
opts.OnRetry(attempt, retryErr, delay)
|
|
}
|
|
})
|
|
if err != nil {
|
|
if errors.Is(err, ErrInterrupted) {
|
|
persistInterruptedStep(ctx, opts, &result)
|
|
return ErrInterrupted
|
|
}
|
|
return xerrors.Errorf("stream response: %w", err)
|
|
}
|
|
|
|
// Execute tools before persisting so that tool results
|
|
// are included in the persisted step content. The
|
|
// persistence layer splits assistant and tool-result
|
|
// blocks into separate database messages by role.
|
|
var toolResults []fantasy.ToolResultContent
|
|
if result.shouldContinue {
|
|
// Check for context cancellation before starting
|
|
// tool execution. If the chat was interrupted
|
|
// between stream completion and here, persist
|
|
// what we have and bail out.
|
|
if ctx.Err() != nil {
|
|
if errors.Is(context.Cause(ctx), ErrInterrupted) {
|
|
persistInterruptedStep(ctx, opts, &result)
|
|
return ErrInterrupted
|
|
}
|
|
return ctx.Err()
|
|
}
|
|
|
|
toolResults = executeTools(ctx, opts.Tools, result.toolCalls, func(tr fantasy.ToolResultContent) {
|
|
publishMessagePart(
|
|
fantasy.MessageRoleTool,
|
|
chatprompt.PartFromContent(tr),
|
|
)
|
|
})
|
|
for _, tr := range toolResults {
|
|
result.content = append(result.content, tr)
|
|
}
|
|
}
|
|
|
|
// Extract context limit from provider metadata.
|
|
contextLimit := extractContextLimit(result.providerMetadata)
|
|
if !contextLimit.Valid && opts.ContextLimitFallback > 0 {
|
|
contextLimit = sql.NullInt64{
|
|
Int64: opts.ContextLimitFallback,
|
|
Valid: true,
|
|
}
|
|
}
|
|
|
|
// Persist the step — errors propagate directly.
|
|
if err := opts.PersistStep(ctx, PersistedStep{
|
|
Content: result.content,
|
|
Usage: result.usage,
|
|
ContextLimit: contextLimit,
|
|
}); err != nil {
|
|
return xerrors.Errorf("persist step: %w", err)
|
|
}
|
|
|
|
lastUsage = result.usage
|
|
lastProviderMetadata = result.providerMetadata
|
|
|
|
// Inline compaction.
|
|
if opts.Compaction != nil && opts.ReloadMessages != nil {
|
|
did, compactErr := tryCompact(
|
|
ctx,
|
|
opts.Model,
|
|
opts.Compaction,
|
|
opts.ContextLimitFallback,
|
|
result.usage,
|
|
result.providerMetadata,
|
|
messages,
|
|
)
|
|
if compactErr != nil && opts.Compaction.OnError != nil {
|
|
opts.Compaction.OnError(compactErr)
|
|
}
|
|
if did {
|
|
alreadyCompacted = true
|
|
reloaded, reloadErr := opts.ReloadMessages(ctx)
|
|
if reloadErr != nil {
|
|
return xerrors.Errorf("reload messages after compaction: %w", reloadErr)
|
|
}
|
|
messages = reloaded
|
|
}
|
|
}
|
|
|
|
if !result.shouldContinue {
|
|
break
|
|
}
|
|
|
|
// Build messages from the step for the next iteration.
|
|
// toResponseMessages produces assistant-role content
|
|
// (text, reasoning, tool calls) and tool-result content.
|
|
stepMessages := result.toResponseMessages()
|
|
messages = append(messages, stepMessages...)
|
|
}
|
|
|
|
// Post-run compaction safety net: if we never compacted
|
|
// during the loop, try once at the end.
|
|
if !alreadyCompacted && opts.Compaction != nil {
|
|
if _, err := tryCompact(
|
|
ctx,
|
|
opts.Model,
|
|
opts.Compaction,
|
|
opts.ContextLimitFallback,
|
|
lastUsage,
|
|
lastProviderMetadata,
|
|
messages,
|
|
); err != nil {
|
|
if opts.Compaction.OnError != nil {
|
|
opts.Compaction.OnError(err)
|
|
}
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// processStepStream consumes a fantasy StreamResponse and
|
|
// accumulates all content into a stepResult. Callbacks fire
|
|
// inline and their errors propagate directly.
|
|
func processStepStream(
|
|
ctx context.Context,
|
|
stream fantasy.StreamResponse,
|
|
publishMessagePart func(fantasy.MessageRole, codersdk.ChatMessagePart),
|
|
) (stepResult, error) {
|
|
var result stepResult
|
|
|
|
activeToolCalls := make(map[string]*fantasy.ToolCallContent)
|
|
activeTextContent := make(map[string]string)
|
|
activeReasoningContent := make(map[string]reasoningState)
|
|
// Track tool names by ID for input delta publishing.
|
|
toolNames := make(map[string]string)
|
|
// Track reasoning text/titles for title extraction.
|
|
reasoningTitles := make(map[string]string)
|
|
reasoningText := make(map[string]string)
|
|
|
|
setReasoningTitleFromText := func(id string, text string) {
|
|
if id == "" || strings.TrimSpace(text) == "" {
|
|
return
|
|
}
|
|
if reasoningTitles[id] != "" {
|
|
return
|
|
}
|
|
reasoningText[id] += text
|
|
if !strings.ContainsAny(reasoningText[id], "\r\n") {
|
|
return
|
|
}
|
|
title := chatprompt.ReasoningTitleFromFirstLine(reasoningText[id])
|
|
if title == "" {
|
|
return
|
|
}
|
|
reasoningTitles[id] = title
|
|
}
|
|
|
|
for part := range stream {
|
|
switch part.Type {
|
|
case fantasy.StreamPartTypeTextStart:
|
|
activeTextContent[part.ID] = ""
|
|
|
|
case fantasy.StreamPartTypeTextDelta:
|
|
if _, exists := activeTextContent[part.ID]; exists {
|
|
activeTextContent[part.ID] += part.Delta
|
|
}
|
|
publishMessagePart(fantasy.MessageRoleAssistant, codersdk.ChatMessagePart{
|
|
Type: codersdk.ChatMessagePartTypeText,
|
|
Text: part.Delta,
|
|
})
|
|
|
|
case fantasy.StreamPartTypeTextEnd:
|
|
if text, exists := activeTextContent[part.ID]; exists {
|
|
result.content = append(result.content, fantasy.TextContent{
|
|
Text: text,
|
|
ProviderMetadata: part.ProviderMetadata,
|
|
})
|
|
delete(activeTextContent, part.ID)
|
|
}
|
|
|
|
case fantasy.StreamPartTypeReasoningStart:
|
|
activeReasoningContent[part.ID] = reasoningState{
|
|
text: part.Delta,
|
|
options: part.ProviderMetadata,
|
|
}
|
|
|
|
case fantasy.StreamPartTypeReasoningDelta:
|
|
if active, exists := activeReasoningContent[part.ID]; exists {
|
|
active.text += part.Delta
|
|
active.options = part.ProviderMetadata
|
|
activeReasoningContent[part.ID] = active
|
|
}
|
|
setReasoningTitleFromText(part.ID, part.Delta)
|
|
title := reasoningTitles[part.ID]
|
|
publishMessagePart(fantasy.MessageRoleAssistant, codersdk.ChatMessagePart{
|
|
Type: codersdk.ChatMessagePartTypeReasoning,
|
|
Text: part.Delta,
|
|
Title: title,
|
|
})
|
|
|
|
case fantasy.StreamPartTypeReasoningEnd:
|
|
if active, exists := activeReasoningContent[part.ID]; exists {
|
|
if part.ProviderMetadata != nil {
|
|
active.options = part.ProviderMetadata
|
|
}
|
|
content := fantasy.ReasoningContent{
|
|
Text: active.text,
|
|
ProviderMetadata: active.options,
|
|
}
|
|
result.content = append(result.content, content)
|
|
delete(activeReasoningContent, part.ID)
|
|
|
|
// Derive reasoning title at end of reasoning
|
|
// block if we haven't yet.
|
|
if reasoningTitles[part.ID] == "" {
|
|
reasoningTitles[part.ID] = chatprompt.ReasoningTitleFromFirstLine(
|
|
reasoningText[part.ID],
|
|
)
|
|
}
|
|
title := reasoningTitles[part.ID]
|
|
if title != "" {
|
|
publishMessagePart(fantasy.MessageRoleAssistant, codersdk.ChatMessagePart{
|
|
Type: codersdk.ChatMessagePartTypeReasoning,
|
|
Title: title,
|
|
})
|
|
}
|
|
}
|
|
|
|
case fantasy.StreamPartTypeToolInputStart:
|
|
activeToolCalls[part.ID] = &fantasy.ToolCallContent{
|
|
ToolCallID: part.ID,
|
|
ToolName: part.ToolCallName,
|
|
Input: "",
|
|
ProviderExecuted: part.ProviderExecuted,
|
|
}
|
|
if strings.TrimSpace(part.ToolCallName) != "" {
|
|
toolNames[part.ID] = part.ToolCallName
|
|
}
|
|
|
|
case fantasy.StreamPartTypeToolInputDelta:
|
|
if toolCall, exists := activeToolCalls[part.ID]; exists {
|
|
toolCall.Input += part.Delta
|
|
}
|
|
toolName := toolNames[part.ID]
|
|
publishMessagePart(fantasy.MessageRoleAssistant, codersdk.ChatMessagePart{
|
|
Type: codersdk.ChatMessagePartTypeToolCall,
|
|
ToolCallID: part.ID,
|
|
ToolName: toolName,
|
|
ArgsDelta: part.Delta,
|
|
})
|
|
|
|
case fantasy.StreamPartTypeToolInputEnd:
|
|
// No callback needed; the full tool call arrives in
|
|
// StreamPartTypeToolCall.
|
|
|
|
case fantasy.StreamPartTypeToolCall:
|
|
tc := fantasy.ToolCallContent{
|
|
ToolCallID: part.ID,
|
|
ToolName: part.ToolCallName,
|
|
Input: part.ToolCallInput,
|
|
ProviderExecuted: part.ProviderExecuted,
|
|
ProviderMetadata: part.ProviderMetadata,
|
|
}
|
|
result.toolCalls = append(result.toolCalls, tc)
|
|
result.content = append(result.content, tc)
|
|
if strings.TrimSpace(part.ToolCallName) != "" {
|
|
toolNames[part.ID] = part.ToolCallName
|
|
}
|
|
// Clean up active tool call tracking.
|
|
delete(activeToolCalls, part.ID)
|
|
|
|
publishMessagePart(
|
|
fantasy.MessageRoleAssistant,
|
|
chatprompt.PartFromContent(tc),
|
|
)
|
|
|
|
case fantasy.StreamPartTypeSource:
|
|
sourceContent := fantasy.SourceContent{
|
|
SourceType: part.SourceType,
|
|
ID: part.ID,
|
|
URL: part.URL,
|
|
Title: part.Title,
|
|
ProviderMetadata: part.ProviderMetadata,
|
|
}
|
|
result.content = append(result.content, sourceContent)
|
|
publishMessagePart(
|
|
fantasy.MessageRoleAssistant,
|
|
chatprompt.PartFromContent(sourceContent),
|
|
)
|
|
|
|
case fantasy.StreamPartTypeFinish:
|
|
result.usage = part.Usage
|
|
result.finishReason = part.FinishReason
|
|
result.providerMetadata = part.ProviderMetadata
|
|
|
|
case fantasy.StreamPartTypeError:
|
|
// Detect interruption: context canceled with
|
|
// ErrInterrupted as the cause.
|
|
if errors.Is(part.Error, context.Canceled) &&
|
|
errors.Is(context.Cause(ctx), ErrInterrupted) {
|
|
// Flush in-progress content so that
|
|
// persistInterruptedStep has access to partial
|
|
// text, reasoning, and tool calls that were
|
|
// still streaming when the interrupt arrived.
|
|
flushActiveState(
|
|
&result,
|
|
activeTextContent,
|
|
activeReasoningContent,
|
|
activeToolCalls,
|
|
toolNames,
|
|
)
|
|
return result, ErrInterrupted
|
|
}
|
|
return result, part.Error
|
|
}
|
|
}
|
|
|
|
result.shouldContinue = len(result.toolCalls) > 0 &&
|
|
result.finishReason == fantasy.FinishReasonToolCalls
|
|
return result, nil
|
|
}
|
|
|
|
// executeTools runs each tool call sequentially after the stream
|
|
// completes. Results are published via onResult as each tool
|
|
// finishes.
|
|
func executeTools(
|
|
ctx context.Context,
|
|
allTools []fantasy.AgentTool,
|
|
toolCalls []fantasy.ToolCallContent,
|
|
onResult func(fantasy.ToolResultContent),
|
|
) []fantasy.ToolResultContent {
|
|
if len(toolCalls) == 0 {
|
|
return nil
|
|
}
|
|
|
|
toolMap := make(map[string]fantasy.AgentTool, len(allTools))
|
|
for _, t := range allTools {
|
|
toolMap[t.Info().Name] = t
|
|
}
|
|
|
|
results := make([]fantasy.ToolResultContent, 0, len(toolCalls))
|
|
for _, tc := range toolCalls {
|
|
tr := executeSingleTool(ctx, toolMap, tc)
|
|
results = append(results, tr)
|
|
if onResult != nil {
|
|
onResult(tr)
|
|
}
|
|
}
|
|
return results
|
|
}
|
|
|
|
// executeSingleTool executes one tool call and converts the
|
|
// response into a ToolResultContent.
|
|
func executeSingleTool(
|
|
ctx context.Context,
|
|
toolMap map[string]fantasy.AgentTool,
|
|
tc fantasy.ToolCallContent,
|
|
) fantasy.ToolResultContent {
|
|
result := fantasy.ToolResultContent{
|
|
ToolCallID: tc.ToolCallID,
|
|
ToolName: tc.ToolName,
|
|
ProviderExecuted: false,
|
|
}
|
|
|
|
tool, exists := toolMap[tc.ToolName]
|
|
if !exists {
|
|
result.Result = fantasy.ToolResultOutputContentError{
|
|
Error: xerrors.New("Tool not found: " + tc.ToolName),
|
|
}
|
|
return result
|
|
}
|
|
|
|
resp, err := tool.Run(ctx, fantasy.ToolCall{
|
|
ID: tc.ToolCallID,
|
|
Name: tc.ToolName,
|
|
Input: tc.Input,
|
|
})
|
|
if err != nil {
|
|
result.Result = fantasy.ToolResultOutputContentError{
|
|
Error: err,
|
|
}
|
|
result.ClientMetadata = resp.Metadata
|
|
return result
|
|
}
|
|
|
|
result.ClientMetadata = resp.Metadata
|
|
switch {
|
|
case resp.IsError:
|
|
result.Result = fantasy.ToolResultOutputContentError{
|
|
Error: xerrors.New(resp.Content),
|
|
}
|
|
case resp.Type == "image" || resp.Type == "media":
|
|
result.Result = fantasy.ToolResultOutputContentMedia{
|
|
Data: string(resp.Data),
|
|
MediaType: resp.MediaType,
|
|
Text: resp.Content,
|
|
}
|
|
default:
|
|
result.Result = fantasy.ToolResultOutputContentText{
|
|
Text: resp.Content,
|
|
}
|
|
}
|
|
return result
|
|
}
|
|
|
|
// flushActiveState moves any in-progress text, reasoning, and
|
|
// tool calls from the active tracking maps into result.content
|
|
// and result.toolCalls. This is called on interruption so that
|
|
// partial content from an incomplete stream is available for
|
|
// persistence.
|
|
func flushActiveState(
|
|
result *stepResult,
|
|
activeText map[string]string,
|
|
activeReasoning map[string]reasoningState,
|
|
activeToolCalls map[string]*fantasy.ToolCallContent,
|
|
toolNames map[string]string,
|
|
) {
|
|
// Flush partial text content.
|
|
for _, text := range activeText {
|
|
if text != "" {
|
|
result.content = append(result.content, fantasy.TextContent{Text: text})
|
|
}
|
|
}
|
|
|
|
// Flush partial reasoning content.
|
|
for _, rs := range activeReasoning {
|
|
if rs.text != "" {
|
|
result.content = append(result.content, fantasy.ReasoningContent{
|
|
Text: rs.text,
|
|
ProviderMetadata: rs.options,
|
|
})
|
|
}
|
|
}
|
|
|
|
// Flush in-progress tool calls. These haven't received a
|
|
// StreamPartTypeToolCall yet, so they only exist in
|
|
// activeToolCalls. We add them to both content and toolCalls
|
|
// so persistInterruptedStep can generate synthetic error
|
|
// results for them.
|
|
for id, tc := range activeToolCalls {
|
|
if tc == nil {
|
|
continue
|
|
}
|
|
// Prefer the tool name from the toolNames map since
|
|
// ToolInputStart may provide a cleaner name.
|
|
toolName := tc.ToolName
|
|
if name, ok := toolNames[id]; ok && strings.TrimSpace(name) != "" {
|
|
toolName = name
|
|
}
|
|
flushed := fantasy.ToolCallContent{
|
|
ToolCallID: tc.ToolCallID,
|
|
ToolName: toolName,
|
|
Input: tc.Input,
|
|
ProviderExecuted: tc.ProviderExecuted,
|
|
}
|
|
result.content = append(result.content, flushed)
|
|
result.toolCalls = append(result.toolCalls, flushed)
|
|
}
|
|
}
|
|
|
|
// persistInterruptedStep saves all accumulated content from a
|
|
// partial stream. Since we own the stepResult directly, no shadow
|
|
// state is needed.
|
|
func persistInterruptedStep(
|
|
ctx context.Context,
|
|
opts RunOptions,
|
|
result *stepResult,
|
|
) {
|
|
if result == nil || (len(result.content) == 0 && len(result.toolCalls) == 0) {
|
|
return
|
|
}
|
|
|
|
// Track which tool calls already have results in the content.
|
|
answeredToolCalls := make(map[string]struct{})
|
|
for _, c := range result.content {
|
|
tr, ok := fantasy.AsContentType[fantasy.ToolResultContent](c)
|
|
if ok && tr.ToolCallID != "" {
|
|
answeredToolCalls[tr.ToolCallID] = struct{}{}
|
|
}
|
|
}
|
|
|
|
// Build combined content: all accumulated content + synthetic
|
|
// interrupted results for any unanswered tool calls.
|
|
content := make([]fantasy.Content, 0, len(result.content))
|
|
content = append(content, result.content...)
|
|
|
|
for _, tc := range result.toolCalls {
|
|
if tc.ToolCallID == "" {
|
|
continue
|
|
}
|
|
if _, exists := answeredToolCalls[tc.ToolCallID]; exists {
|
|
continue
|
|
}
|
|
content = append(content, fantasy.ToolResultContent{
|
|
ToolCallID: tc.ToolCallID,
|
|
ToolName: tc.ToolName,
|
|
Result: fantasy.ToolResultOutputContentError{
|
|
Error: xerrors.New(interruptedToolResultErrorMessage),
|
|
},
|
|
})
|
|
answeredToolCalls[tc.ToolCallID] = struct{}{}
|
|
}
|
|
|
|
persistCtx := context.WithoutCancel(ctx)
|
|
if err := opts.PersistStep(persistCtx, PersistedStep{
|
|
Content: content,
|
|
}); err != nil {
|
|
if opts.OnInterruptedPersistError != nil {
|
|
opts.OnInterruptedPersistError(err)
|
|
}
|
|
}
|
|
}
|
|
|
|
// buildToolDefinitions converts AgentTool definitions into the
|
|
// fantasy.Tool slice expected by fantasy.Call. When activeTools
|
|
// is non-empty, only tools whose name appears in the list are
|
|
// included. This mirrors fantasy's agent.prepareTools filtering.
|
|
func buildToolDefinitions(tools []fantasy.AgentTool, activeTools []string) []fantasy.Tool {
|
|
prepared := make([]fantasy.Tool, 0, len(tools))
|
|
for _, tool := range tools {
|
|
info := tool.Info()
|
|
if len(activeTools) > 0 && !slices.Contains(activeTools, info.Name) {
|
|
continue
|
|
}
|
|
inputSchema := map[string]any{
|
|
"type": "object",
|
|
"properties": info.Parameters,
|
|
"required": info.Required,
|
|
}
|
|
schema.Normalize(inputSchema)
|
|
prepared = append(prepared, fantasy.FunctionTool{
|
|
Name: info.Name,
|
|
Description: info.Description,
|
|
InputSchema: inputSchema,
|
|
ProviderOptions: tool.ProviderOptions(),
|
|
})
|
|
}
|
|
return prepared
|
|
}
|
|
|
|
func shouldApplyAnthropicPromptCaching(model fantasy.LanguageModel) bool {
|
|
if model == nil {
|
|
return false
|
|
}
|
|
return model.Provider() == fantasyanthropic.Name
|
|
}
|
|
|
|
// addAnthropicPromptCaching mutates messages in-place, setting
|
|
// ProviderOptions for Anthropic prompt caching on the last system
|
|
// message and the final two messages.
|
|
func addAnthropicPromptCaching(messages []fantasy.Message) {
|
|
for i := range messages {
|
|
messages[i].ProviderOptions = nil
|
|
}
|
|
|
|
providerOption := fantasy.ProviderOptions{
|
|
fantasyanthropic.Name: &fantasyanthropic.ProviderCacheControlOptions{
|
|
CacheControl: fantasyanthropic.CacheControl{Type: "ephemeral"},
|
|
},
|
|
}
|
|
|
|
lastSystemRoleIdx := -1
|
|
systemMessageUpdated := false
|
|
for i, msg := range messages {
|
|
if msg.Role == fantasy.MessageRoleSystem {
|
|
lastSystemRoleIdx = i
|
|
} else if !systemMessageUpdated && lastSystemRoleIdx >= 0 {
|
|
messages[lastSystemRoleIdx].ProviderOptions = providerOption
|
|
systemMessageUpdated = true
|
|
}
|
|
if i > len(messages)-3 {
|
|
messages[i].ProviderOptions = providerOption
|
|
}
|
|
}
|
|
}
|
|
|
|
func extractContextLimit(metadata fantasy.ProviderMetadata) sql.NullInt64 {
|
|
if len(metadata) == 0 {
|
|
return sql.NullInt64{}
|
|
}
|
|
|
|
encoded, err := json.Marshal(metadata)
|
|
if err != nil || len(encoded) == 0 {
|
|
return sql.NullInt64{}
|
|
}
|
|
|
|
var payload any
|
|
if err := json.Unmarshal(encoded, &payload); err != nil {
|
|
return sql.NullInt64{}
|
|
}
|
|
|
|
limit, ok := findContextLimitValue(payload)
|
|
if !ok {
|
|
return sql.NullInt64{}
|
|
}
|
|
|
|
return sql.NullInt64{
|
|
Int64: limit,
|
|
Valid: true,
|
|
}
|
|
}
|
|
|
|
func findContextLimitValue(value any) (int64, bool) {
|
|
var (
|
|
limit int64
|
|
found bool
|
|
)
|
|
|
|
collectContextLimitValues(value, func(candidate int64) {
|
|
if !found || candidate > limit {
|
|
limit = candidate
|
|
found = true
|
|
}
|
|
})
|
|
|
|
return limit, found
|
|
}
|
|
|
|
func collectContextLimitValues(value any, onValue func(int64)) {
|
|
switch typed := value.(type) {
|
|
case map[string]any:
|
|
for key, child := range typed {
|
|
if isContextLimitKey(key) {
|
|
if numeric, ok := numericContextLimitValue(child); ok {
|
|
onValue(numeric)
|
|
}
|
|
}
|
|
collectContextLimitValues(child, onValue)
|
|
}
|
|
case []any:
|
|
for _, child := range typed {
|
|
collectContextLimitValues(child, onValue)
|
|
}
|
|
}
|
|
}
|
|
|
|
func isContextLimitKey(key string) bool {
|
|
normalized := normalizeMetadataKey(key)
|
|
if normalized == "" {
|
|
return false
|
|
}
|
|
|
|
switch normalized {
|
|
case
|
|
"contextlimit",
|
|
"contextwindow",
|
|
"contextlength",
|
|
"maxcontext",
|
|
"maxcontexttokens",
|
|
"maxinputtokens",
|
|
"maxinputtoken",
|
|
"inputtokenlimit":
|
|
return true
|
|
}
|
|
|
|
return strings.Contains(normalized, "context") &&
|
|
(strings.Contains(normalized, "limit") ||
|
|
strings.Contains(normalized, "window") ||
|
|
strings.Contains(normalized, "length") ||
|
|
strings.HasPrefix(normalized, "max"))
|
|
}
|
|
|
|
// normalizeMetadataKey lowercases ASCII letters and drops every rune
// that is not an ASCII letter or digit. As a result, "max_context-Tokens"
// and "MaxContextTokens" normalize to the same key.
func normalizeMetadataKey(key string) string {
	return strings.Map(func(r rune) rune {
		switch {
		case 'a' <= r && r <= 'z', '0' <= r && r <= '9':
			return r
		case 'A' <= r && r <= 'Z':
			return r - 'A' + 'a'
		default:
			// A negative return tells strings.Map to drop the rune.
			return -1
		}
	}, key)
}
|
|
|
|
func numericContextLimitValue(value any) (int64, bool) {
|
|
switch typed := value.(type) {
|
|
case int64:
|
|
return positiveInt64(typed)
|
|
case int32:
|
|
return positiveInt64(int64(typed))
|
|
case int:
|
|
return positiveInt64(int64(typed))
|
|
case float64:
|
|
casted := int64(typed)
|
|
if typed > 0 && float64(casted) == typed {
|
|
return casted, true
|
|
}
|
|
case string:
|
|
parsed, err := strconv.ParseInt(strings.TrimSpace(typed), 10, 64)
|
|
if err == nil {
|
|
return positiveInt64(parsed)
|
|
}
|
|
case json.Number:
|
|
parsed, err := typed.Int64()
|
|
if err == nil {
|
|
return positiveInt64(parsed)
|
|
}
|
|
}
|
|
|
|
return 0, false
|
|
}
|
|
|
|
// positiveInt64 returns (value, true) when value is strictly positive
// and (0, false) otherwise.
func positiveInt64(value int64) (int64, bool) {
	if value > 0 {
		return value, true
	}
	return 0, false
}
|