mirror of
https://github.com/coder/coder.git
synced 2026-06-05 14:08:20 +00:00
988 lines
27 KiB
Go
988 lines
27 KiB
Go
package chatloop
|
|
|
|
import (
|
|
"context"
|
|
"encoding/base64"
|
|
"errors"
|
|
"iter"
|
|
"sync"
|
|
"sync/atomic"
|
|
"testing"
|
|
"time"
|
|
"unicode/utf8"
|
|
|
|
"charm.land/fantasy"
|
|
fantasyanthropic "charm.land/fantasy/providers/anthropic"
|
|
"github.com/prometheus/client_golang/prometheus"
|
|
promtestutil "github.com/prometheus/client_golang/prometheus/testutil"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
"golang.org/x/xerrors"
|
|
|
|
"cdr.dev/slog/v3"
|
|
"github.com/coder/coder/v2/coderd/x/chatd/chaterror"
|
|
"github.com/coder/coder/v2/coderd/x/chatd/chatsanitize"
|
|
"github.com/coder/coder/v2/coderd/x/chatd/chattest"
|
|
"github.com/coder/coder/v2/codersdk"
|
|
"github.com/coder/coder/v2/testutil"
|
|
"github.com/coder/quartz"
|
|
)
|
|
|
|
func validWebSearchProviderMetadataForTest() fantasy.ProviderMetadata {
|
|
return fantasy.ProviderMetadata{
|
|
fantasyanthropic.Name: &fantasyanthropic.WebSearchResultMetadata{
|
|
Results: []fantasyanthropic.WebSearchResultItem{
|
|
{
|
|
URL: "https://example.com",
|
|
Title: "Example",
|
|
EncryptedContent: "encrypted",
|
|
},
|
|
},
|
|
},
|
|
}
|
|
}
|
|
|
|
func safeToolCallContent(block fantasy.Content) (fantasy.ToolCallContent, bool) {
|
|
var zero fantasy.ToolCallContent
|
|
switch value := block.(type) {
|
|
case fantasy.ToolCallContent:
|
|
return value, true
|
|
case *fantasy.ToolCallContent:
|
|
if value == nil {
|
|
return zero, false
|
|
}
|
|
return *value, true
|
|
default:
|
|
return zero, false
|
|
}
|
|
}
|
|
|
|
func safeToolResultContent(block fantasy.Content) (fantasy.ToolResultContent, bool) {
|
|
var zero fantasy.ToolResultContent
|
|
switch value := block.(type) {
|
|
case fantasy.ToolResultContent:
|
|
return value, true
|
|
case *fantasy.ToolResultContent:
|
|
if value == nil {
|
|
return zero, false
|
|
}
|
|
return *value, true
|
|
default:
|
|
return zero, false
|
|
}
|
|
}
|
|
|
|
func toolCallContentToPart(toolCall fantasy.ToolCallContent) fantasy.ToolCallPart {
|
|
return fantasy.ToolCallPart{
|
|
ToolCallID: toolCall.ToolCallID,
|
|
ToolName: toolCall.ToolName,
|
|
Input: toolCall.Input,
|
|
ProviderExecuted: toolCall.ProviderExecuted,
|
|
ProviderOptions: fantasy.ProviderOptions(toolCall.ProviderMetadata),
|
|
}
|
|
}
|
|
|
|
func toolResultContentToPart(toolResult fantasy.ToolResultContent) fantasy.ToolResultPart {
|
|
return fantasy.ToolResultPart{
|
|
ToolCallID: toolResult.ToolCallID,
|
|
Output: toolResult.Result,
|
|
ProviderExecuted: toolResult.ProviderExecuted,
|
|
ProviderOptions: fantasy.ProviderOptions(toolResult.ProviderMetadata),
|
|
}
|
|
}
|
|
|
|
func TestStreamSilenceGuard_DisarmAndFireRace(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
for range 128 {
|
|
var cancels atomic.Int32
|
|
guard := newStreamSilenceGuard(quartz.NewReal(), time.Hour, func(err error) {
|
|
if errors.Is(err, errStreamSilenceTimeout) {
|
|
cancels.Add(1)
|
|
}
|
|
})
|
|
|
|
start := make(chan struct{})
|
|
var wg sync.WaitGroup
|
|
wg.Add(2)
|
|
|
|
go func() {
|
|
defer wg.Done()
|
|
<-start
|
|
guard.onTimeout()
|
|
}()
|
|
|
|
go func() {
|
|
defer wg.Done()
|
|
<-start
|
|
guard.Disarm()
|
|
}()
|
|
|
|
close(start)
|
|
wg.Wait()
|
|
|
|
guard.onTimeout()
|
|
guard.Disarm()
|
|
|
|
require.LessOrEqual(t, cancels.Load(), int32(1))
|
|
}
|
|
}
|
|
|
|
func TestStreamSilenceGuard_DisarmPreservesPermanentError(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
attemptCtx, cancelAttempt := context.WithCancelCause(context.Background())
|
|
defer cancelAttempt(nil)
|
|
|
|
guard := newStreamSilenceGuard(quartz.NewReal(), time.Hour, cancelAttempt)
|
|
guard.Disarm()
|
|
guard.onTimeout()
|
|
|
|
classified := chaterror.Classify(classifyStreamSilenceTimeout(
|
|
attemptCtx,
|
|
"openai",
|
|
xerrors.New("invalid model"),
|
|
))
|
|
require.Equal(t, codersdk.ChatErrorKindConfig, classified.Kind)
|
|
require.False(t, classified.Retryable)
|
|
require.Nil(t, context.Cause(attemptCtx))
|
|
}
|
|
|
|
func TestGenerateAssistant_ProviderContextSurvivesStreamError(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
model := &chattest.FakeModel{
|
|
ProviderName: "openai",
|
|
ModelName: "test-model",
|
|
StreamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
|
|
return nil, xerrors.New("upstream returned status 400")
|
|
},
|
|
}
|
|
|
|
_, err := GenerateAssistant(context.Background(), GenerateAssistantOptions{
|
|
Model: model,
|
|
Messages: []fantasy.Message{
|
|
textMessage(fantasy.MessageRoleUser, "hello"),
|
|
},
|
|
})
|
|
require.Error(t, err)
|
|
classified := chaterror.Classify(err)
|
|
require.Equal(t, "openai", classified.Provider)
|
|
require.Equal(t, "OpenAI returned an unexpected error.", classified.Message)
|
|
}
|
|
|
|
func TestGenerateAssistant_HTTP2TransportErrorClassifiedAsRetryableTimeout(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
for _, provider := range []string{"anthropic", "openai"} {
|
|
provider := provider
|
|
t.Run(provider, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
model := &chattest.FakeModel{
|
|
ProviderName: provider,
|
|
ModelName: "test-model",
|
|
StreamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
|
|
return nil, xerrors.New("http2: client connection force closed via ClientConn.Close")
|
|
},
|
|
}
|
|
|
|
_, err := GenerateAssistant(context.Background(), GenerateAssistantOptions{
|
|
Model: model,
|
|
})
|
|
require.Error(t, err)
|
|
classified := chaterror.Classify(err)
|
|
require.Equal(t, codersdk.ChatErrorKindTimeout, classified.Kind)
|
|
require.Equal(t, provider, classified.Provider)
|
|
require.True(t, classified.Retryable)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestGenerateAssistant_StreamSilenceTimeoutRetryClassification(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
t.Run("timeout while opening stream", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
const silenceTimeout = 5 * time.Millisecond
|
|
clock := quartz.NewMock(t)
|
|
trap := clock.Trap().AfterFunc(streamSilenceGuardTimerTag)
|
|
defer trap.Close()
|
|
var calls atomic.Int32
|
|
model := &chattest.FakeModel{
|
|
ProviderName: "openai",
|
|
ModelName: "test-model",
|
|
StreamFn: func(ctx context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
|
|
if calls.Add(1) == 1 {
|
|
<-ctx.Done()
|
|
return nil, ctx.Err()
|
|
}
|
|
return streamFromParts([]fantasy.StreamPart{{
|
|
Type: fantasy.StreamPartTypeFinish,
|
|
FinishReason: fantasy.FinishReasonStop,
|
|
}}), nil
|
|
},
|
|
}
|
|
done := make(chan error, 1)
|
|
go func() {
|
|
_, err := GenerateAssistant(context.Background(), GenerateAssistantOptions{
|
|
Model: model,
|
|
Clock: clock,
|
|
StreamSilenceTimeout: silenceTimeout,
|
|
})
|
|
done <- err
|
|
}()
|
|
|
|
trap.MustWait(ctx).MustRelease(ctx)
|
|
_, waiter := clock.AdvanceNext()
|
|
waiter.MustWait(ctx)
|
|
require.Error(t, <-done)
|
|
require.Equal(t, int32(1), calls.Load())
|
|
})
|
|
|
|
t.Run("timeout before first part", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
const silenceTimeout = 5 * time.Millisecond
|
|
clock := quartz.NewMock(t)
|
|
trap := clock.Trap().AfterFunc(streamSilenceGuardTimerTag)
|
|
defer trap.Close()
|
|
var calls atomic.Int32
|
|
model := &chattest.FakeModel{
|
|
ProviderName: "openai",
|
|
ModelName: "test-model",
|
|
StreamFn: func(ctx context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
|
|
calls.Add(1)
|
|
return func(yield func(fantasy.StreamPart) bool) {
|
|
<-ctx.Done()
|
|
yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeError, Error: ctx.Err()})
|
|
}, nil
|
|
},
|
|
}
|
|
done := make(chan error, 1)
|
|
go func() {
|
|
_, err := GenerateAssistant(context.Background(), GenerateAssistantOptions{
|
|
Model: model,
|
|
Clock: clock,
|
|
StreamSilenceTimeout: silenceTimeout,
|
|
})
|
|
done <- err
|
|
}()
|
|
|
|
trap.MustWait(ctx).MustRelease(ctx)
|
|
_, waiter := clock.AdvanceNext()
|
|
waiter.MustWait(ctx)
|
|
err := <-done
|
|
require.Error(t, err)
|
|
classified := chaterror.Classify(err)
|
|
require.Equal(t, codersdk.ChatErrorKindStreamSilenceTimeout, classified.Kind)
|
|
require.Equal(t, "openai", classified.Provider)
|
|
require.True(t, classified.Retryable)
|
|
require.Equal(t, int32(1), calls.Load())
|
|
})
|
|
|
|
t.Run("first part disarms timeout", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
const silenceTimeout = 5 * time.Millisecond
|
|
clock := quartz.NewMock(t)
|
|
trap := clock.Trap().AfterFunc(streamSilenceGuardTimerTag)
|
|
defer trap.Close()
|
|
var calls atomic.Int32
|
|
continueStream := make(chan struct{})
|
|
model := &chattest.FakeModel{
|
|
ProviderName: "openai",
|
|
ModelName: "test-model",
|
|
StreamFn: func(ctx context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
|
|
calls.Add(1)
|
|
return func(yield func(fantasy.StreamPart) bool) {
|
|
if !yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeTextStart, ID: "text-1"}) {
|
|
return
|
|
}
|
|
select {
|
|
case <-continueStream:
|
|
case <-ctx.Done():
|
|
yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeError, Error: ctx.Err()})
|
|
return
|
|
}
|
|
yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeTextDelta, ID: "text-1", Delta: "done"})
|
|
yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeTextEnd, ID: "text-1"})
|
|
yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop})
|
|
}, nil
|
|
},
|
|
}
|
|
done := make(chan error, 1)
|
|
go func() {
|
|
_, err := GenerateAssistant(context.Background(), GenerateAssistantOptions{
|
|
Model: model,
|
|
Clock: clock,
|
|
StreamSilenceTimeout: silenceTimeout,
|
|
})
|
|
done <- err
|
|
}()
|
|
|
|
trap.MustWait(ctx).MustRelease(ctx)
|
|
close(continueStream)
|
|
require.NoError(t, <-done)
|
|
require.Equal(t, int32(1), calls.Load())
|
|
})
|
|
|
|
t.Run("silent stream close after timeout", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := testutil.Context(t, testutil.WaitLong)
|
|
const silenceTimeout = 5 * time.Millisecond
|
|
clock := quartz.NewMock(t)
|
|
trap := clock.Trap().AfterFunc(streamSilenceGuardTimerTag)
|
|
defer trap.Close()
|
|
var calls atomic.Int32
|
|
model := &chattest.FakeModel{
|
|
ProviderName: "openai",
|
|
ModelName: "test-model",
|
|
StreamFn: func(ctx context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
|
|
calls.Add(1)
|
|
return func(func(fantasy.StreamPart) bool) {
|
|
<-ctx.Done()
|
|
}, nil
|
|
},
|
|
}
|
|
done := make(chan error, 1)
|
|
go func() {
|
|
_, err := GenerateAssistant(context.Background(), GenerateAssistantOptions{
|
|
Model: model,
|
|
Clock: clock,
|
|
StreamSilenceTimeout: silenceTimeout,
|
|
})
|
|
done <- err
|
|
}()
|
|
|
|
trap.MustWait(ctx).MustRelease(ctx)
|
|
_, waiter := clock.AdvanceNext()
|
|
waiter.MustWait(ctx)
|
|
err := <-done
|
|
require.Error(t, err)
|
|
classified := chaterror.Classify(err)
|
|
require.Equal(t, codersdk.ChatErrorKindStreamSilenceTimeout, classified.Kind)
|
|
require.Equal(t, int32(1), calls.Load())
|
|
})
|
|
}
|
|
|
|
func TestGenerateAssistant_PanicInPublishMessagePartReleasesAttempt(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
attemptReleased := make(chan struct{})
|
|
model := &chattest.FakeModel{
|
|
ProviderName: "openai",
|
|
ModelName: "test-model",
|
|
StreamFn: func(ctx context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) {
|
|
go func() {
|
|
<-ctx.Done()
|
|
close(attemptReleased)
|
|
}()
|
|
return streamFromParts([]fantasy.StreamPart{
|
|
{Type: fantasy.StreamPartTypeTextStart, ID: "text-1"},
|
|
{Type: fantasy.StreamPartTypeTextDelta, ID: "text-1", Delta: "boom"},
|
|
}), nil
|
|
},
|
|
}
|
|
|
|
defer func() {
|
|
r := recover()
|
|
require.NotNil(t, r)
|
|
select {
|
|
case <-attemptReleased:
|
|
case <-time.After(time.Second):
|
|
t.Fatal("attempt context was not released after panic")
|
|
}
|
|
}()
|
|
|
|
_, _ = GenerateAssistant(context.Background(), GenerateAssistantOptions{
|
|
Model: model,
|
|
PublishMessagePart: func(codersdk.ChatMessageRole, codersdk.ChatMessagePart) {
|
|
panic("publish panic")
|
|
},
|
|
})
|
|
|
|
t.Fatal("expected GenerateAssistant to panic")
|
|
}
|
|
|
|
func requireToolResultErrorMessage(
|
|
t *testing.T,
|
|
result fantasy.ToolResultContent,
|
|
expected string,
|
|
) {
|
|
t.Helper()
|
|
|
|
output, ok := result.Result.(fantasy.ToolResultOutputContentError)
|
|
require.Truef(t, ok, "expected error tool result, got %T", result.Result)
|
|
require.Error(t, output.Error)
|
|
require.Equal(t, expected, output.Error.Error())
|
|
}
|
|
|
|
func streamFromParts(parts []fantasy.StreamPart) fantasy.StreamResponse {
|
|
return iter.Seq[fantasy.StreamPart](func(yield func(fantasy.StreamPart) bool) {
|
|
for _, part := range parts {
|
|
if !yield(part) {
|
|
return
|
|
}
|
|
}
|
|
})
|
|
}
|
|
|
|
func textMessage(role fantasy.MessageRole, text string) fantasy.Message {
|
|
return fantasy.Message{
|
|
Role: role,
|
|
Content: []fantasy.MessagePart{
|
|
fantasy.TextPart{Text: text},
|
|
},
|
|
}
|
|
}
|
|
|
|
func requireTextContent(t *testing.T, content []fantasy.Content, text string) fantasy.TextContent {
|
|
t.Helper()
|
|
|
|
for _, block := range content {
|
|
textContent, ok := fantasy.AsContentType[fantasy.TextContent](block)
|
|
if ok && textContent.Text == text {
|
|
return textContent
|
|
}
|
|
}
|
|
t.Fatalf("missing text content %q", text)
|
|
return fantasy.TextContent{}
|
|
}
|
|
|
|
func TestExclusiveToolPolicy_MixedBatchErrors(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
results, violated := applyExclusiveToolPolicy(
|
|
[]fantasy.ToolCallContent{
|
|
{ToolCallID: "advisor-1", ToolName: "advisor", Input: `{}`},
|
|
{ToolCallID: "read-1", ToolName: "read_file", Input: `{"path":"main.go"}`},
|
|
},
|
|
map[string]bool{"advisor": true},
|
|
NopMetrics(),
|
|
"fake",
|
|
"",
|
|
)
|
|
|
|
require.True(t, violated)
|
|
require.Len(t, results, 2)
|
|
require.Equal(t, "advisor-1", results[0].ToolCallID)
|
|
require.Equal(t, "read-1", results[1].ToolCallID)
|
|
requireToolResultErrorMessage(
|
|
t,
|
|
results[0],
|
|
"advisor must be called alone, without other tools in the same batch. Retry with only the advisor call.",
|
|
)
|
|
requireToolResultErrorMessage(
|
|
t,
|
|
results[1],
|
|
"this tool was skipped because advisor must run alone in its batch. Retry your tool calls without advisor, or call advisor separately first.",
|
|
)
|
|
}
|
|
|
|
func TestApplyExclusiveToolPolicy_RecordsErrorMetrics(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
reg := prometheus.NewPedanticRegistry()
|
|
m := NewMetrics(reg)
|
|
|
|
_, violated := applyExclusiveToolPolicy(
|
|
[]fantasy.ToolCallContent{
|
|
{ToolCallID: "advisor-1", ToolName: "advisor", Input: `{}`},
|
|
{ToolCallID: "read-1", ToolName: "read_file", Input: `{"path":"main.go"}`},
|
|
},
|
|
map[string]bool{"advisor": true},
|
|
m,
|
|
"fake",
|
|
"claude-test",
|
|
)
|
|
require.True(t, violated)
|
|
|
|
require.Equal(t, 1.0, promtestutil.ToFloat64(
|
|
m.ToolErrorsTotal.WithLabelValues("fake", "claude-test", "advisor"),
|
|
))
|
|
require.Equal(t, 1.0, promtestutil.ToFloat64(
|
|
m.ToolErrorsTotal.WithLabelValues("fake", "claude-test", "read_file"),
|
|
))
|
|
}
|
|
|
|
func TestExclusiveToolPolicy_MultipleExclusive(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
results, violated := applyExclusiveToolPolicy(
|
|
[]fantasy.ToolCallContent{
|
|
{ToolCallID: "advisor-1", ToolName: "advisor", Input: `{}`},
|
|
{ToolCallID: "advisor-2", ToolName: "advisor", Input: `{"mode":"second-opinion"}`},
|
|
},
|
|
map[string]bool{"advisor": true},
|
|
NopMetrics(),
|
|
"fake",
|
|
"",
|
|
)
|
|
|
|
require.True(t, violated)
|
|
require.Len(t, results, 2)
|
|
requireToolResultErrorMessage(
|
|
t,
|
|
results[0],
|
|
"advisor must be called alone, without other tools in the same batch. Retry with only the advisor call.",
|
|
)
|
|
requireToolResultErrorMessage(
|
|
t,
|
|
results[1],
|
|
"advisor must be called alone, without other tools in the same batch. Retry with only the advisor call.",
|
|
)
|
|
}
|
|
|
|
func TestSanitizeAnthropicProviderToolContent(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
providerCall := func(id, name, input string) fantasy.ToolCallContent {
|
|
return fantasy.ToolCallContent{
|
|
ToolCallID: id,
|
|
ToolName: name,
|
|
Input: input,
|
|
ProviderExecuted: true,
|
|
}
|
|
}
|
|
providerResult := func(id, name string) fantasy.ToolResultContent {
|
|
return fantasy.ToolResultContent{
|
|
ToolCallID: id,
|
|
ToolName: name,
|
|
ProviderExecuted: true,
|
|
ProviderMetadata: validWebSearchProviderMetadataForTest(),
|
|
Result: fantasy.ToolResultOutputContentText{Text: "ok"},
|
|
}
|
|
}
|
|
localCall := func(id, name string) fantasy.ToolCallContent {
|
|
return fantasy.ToolCallContent{
|
|
ToolCallID: id,
|
|
ToolName: name,
|
|
Input: `{}`,
|
|
}
|
|
}
|
|
localResult := func(id, name string) fantasy.ToolResultContent {
|
|
return fantasy.ToolResultContent{
|
|
ToolCallID: id,
|
|
ToolName: name,
|
|
Result: fantasy.ToolResultOutputContentText{Text: "ok"},
|
|
}
|
|
}
|
|
type contentSummary struct {
|
|
providerCalls []string
|
|
providerResults []string
|
|
localCalls []string
|
|
localResults []string
|
|
}
|
|
summarizeContent := func(content []fantasy.Content) contentSummary {
|
|
var summary contentSummary
|
|
for _, block := range content {
|
|
if toolCall, ok := safeToolCallContent(block); ok {
|
|
if toolCall.ProviderExecuted {
|
|
summary.providerCalls = append(summary.providerCalls, toolCall.ToolCallID)
|
|
} else {
|
|
summary.localCalls = append(summary.localCalls, toolCall.ToolCallID)
|
|
}
|
|
continue
|
|
}
|
|
if toolResult, ok := safeToolResultContent(block); ok {
|
|
if toolResult.ProviderExecuted {
|
|
summary.providerResults = append(summary.providerResults, toolResult.ToolCallID)
|
|
} else {
|
|
summary.localResults = append(summary.localResults, toolResult.ToolCallID)
|
|
}
|
|
}
|
|
}
|
|
return summary
|
|
}
|
|
assertProviderHistoryValid := func(t *testing.T, content []fantasy.Content) {
|
|
t.Helper()
|
|
|
|
parts := make([]fantasy.MessagePart, 0)
|
|
for _, block := range content {
|
|
if toolCall, ok := safeToolCallContent(block); ok && toolCall.ProviderExecuted {
|
|
parts = append(parts, toolCallContentToPart(toolCall))
|
|
continue
|
|
}
|
|
if toolResult, ok := safeToolResultContent(block); ok && toolResult.ProviderExecuted {
|
|
parts = append(parts, toolResultContentToPart(toolResult))
|
|
}
|
|
}
|
|
if len(parts) == 0 {
|
|
return
|
|
}
|
|
require.Empty(t, chatsanitize.ValidateAnthropicProviderToolHistory([]fantasy.Message{
|
|
{
|
|
Role: fantasy.MessageRoleAssistant,
|
|
Content: parts,
|
|
},
|
|
}))
|
|
}
|
|
|
|
metadataCall := providerCall("ws-meta", "web_search", `{"query":"coder"}`)
|
|
metadataCall.ProviderMetadata = fantasy.ProviderMetadata{fantasyanthropic.Name: nil}
|
|
metadataResult := providerResult("ws-meta", "web_search")
|
|
metadataResult.ProviderMetadata = fantasy.ProviderMetadata{fantasyanthropic.Name: nil}
|
|
pointerCall := providerCall("ws-pointer", "web_search", `{"query":"coder"}`)
|
|
var nilToolCall *fantasy.ToolCallContent
|
|
|
|
testCases := []struct {
|
|
name string
|
|
provider string
|
|
content []fantasy.Content
|
|
wantSummary contentSummary
|
|
wantRemovedCalls int
|
|
wantRemovedResults int
|
|
wantTexts []string
|
|
validateAnthropic bool
|
|
}{
|
|
{
|
|
name: "orphan provider result textified",
|
|
provider: fantasyanthropic.Name,
|
|
content: []fantasy.Content{
|
|
fantasy.TextContent{Text: "keep"},
|
|
providerResult("ws-1", "web_search"),
|
|
},
|
|
wantRemovedResults: 1,
|
|
wantTexts: []string{"keep", "ok"},
|
|
validateAnthropic: true,
|
|
},
|
|
{
|
|
name: "result before call removes both provider blocks",
|
|
provider: fantasyanthropic.Name,
|
|
content: []fantasy.Content{
|
|
providerResult("ws-1", "web_search"),
|
|
providerCall("ws-1", "web_search", `{"query":"coder"}`),
|
|
},
|
|
wantRemovedCalls: 1,
|
|
wantRemovedResults: 1,
|
|
wantTexts: []string{"ok"},
|
|
validateAnthropic: true,
|
|
},
|
|
{
|
|
name: "valid web search pair preserved",
|
|
provider: fantasyanthropic.Name,
|
|
content: []fantasy.Content{
|
|
providerCall("ws-1", "web_search", `{"query":"coder"}`),
|
|
providerResult("ws-1", "web_search"),
|
|
fantasy.TextContent{Text: "search done"},
|
|
},
|
|
wantSummary: contentSummary{
|
|
providerCalls: []string{"ws-1"},
|
|
providerResults: []string{"ws-1"},
|
|
},
|
|
wantTexts: []string{"search done"},
|
|
validateAnthropic: true,
|
|
},
|
|
{
|
|
name: "invalid JSON provider call drops pair",
|
|
provider: fantasyanthropic.Name,
|
|
content: []fantasy.Content{
|
|
providerCall("ws-1", "web_search", `{`),
|
|
providerResult("ws-1", "web_search"),
|
|
},
|
|
wantRemovedCalls: 1,
|
|
wantRemovedResults: 1,
|
|
wantTexts: []string{"ok"},
|
|
validateAnthropic: true,
|
|
},
|
|
{
|
|
name: "empty ID provider call drops pair",
|
|
provider: fantasyanthropic.Name,
|
|
content: []fantasy.Content{
|
|
providerCall("", "web_search", `{"query":"coder"}`),
|
|
providerResult("", "web_search"),
|
|
},
|
|
wantRemovedCalls: 1,
|
|
wantRemovedResults: 1,
|
|
wantTexts: []string{"ok"},
|
|
validateAnthropic: true,
|
|
},
|
|
{
|
|
name: "empty tool name provider call drops pair",
|
|
provider: fantasyanthropic.Name,
|
|
content: []fantasy.Content{
|
|
providerCall("ws-empty", "", `{"query":"coder"}`),
|
|
providerResult("ws-empty", ""),
|
|
},
|
|
wantRemovedCalls: 1,
|
|
wantRemovedResults: 1,
|
|
wantTexts: []string{"ok"},
|
|
validateAnthropic: true,
|
|
},
|
|
{
|
|
name: "non web search provider pair drops through serializable helper",
|
|
provider: fantasyanthropic.Name,
|
|
content: []fantasy.Content{
|
|
providerCall("code-1", "code_execution", `{"code":"print(1)"}`),
|
|
providerResult("code-1", "code_execution"),
|
|
},
|
|
wantRemovedCalls: 1,
|
|
wantRemovedResults: 1,
|
|
wantTexts: []string{"ok"},
|
|
validateAnthropic: true,
|
|
},
|
|
{
|
|
name: "mismatched provider result tool name drops pair",
|
|
provider: fantasyanthropic.Name,
|
|
content: []fantasy.Content{
|
|
providerCall("ws-mismatch", "web_search", `{"query":"coder"}`),
|
|
providerResult("ws-mismatch", "code_execution"),
|
|
},
|
|
wantRemovedCalls: 1,
|
|
wantRemovedResults: 1,
|
|
wantTexts: []string{"ok"},
|
|
validateAnthropic: true,
|
|
},
|
|
{
|
|
name: "duplicate provider IDs drop all provider content for ID",
|
|
provider: fantasyanthropic.Name,
|
|
content: []fantasy.Content{
|
|
providerCall("dup-1", "web_search", `{"query":"coder"}`),
|
|
providerResult("dup-1", "web_search"),
|
|
providerCall("dup-1", "web_search", `{"query":"coder"}`),
|
|
},
|
|
wantRemovedCalls: 2,
|
|
wantRemovedResults: 1,
|
|
wantTexts: []string{"ok"},
|
|
validateAnthropic: true,
|
|
},
|
|
{
|
|
name: "mismatched provider flags remove only provider side",
|
|
provider: fantasyanthropic.Name,
|
|
content: []fantasy.Content{
|
|
providerCall("mix-1", "web_search", `{"query":"coder"}`),
|
|
localResult("mix-1", "web_search"),
|
|
localCall("mix-2", "read_file"),
|
|
providerResult("mix-2", "web_search"),
|
|
},
|
|
wantSummary: contentSummary{
|
|
localCalls: []string{"mix-2"},
|
|
localResults: []string{"mix-1"},
|
|
},
|
|
wantRemovedCalls: 1,
|
|
wantRemovedResults: 1,
|
|
wantTexts: []string{"ok"},
|
|
validateAnthropic: true,
|
|
},
|
|
{
|
|
name: "malformed provider metadata textifies result",
|
|
provider: fantasyanthropic.Name,
|
|
content: []fantasy.Content{
|
|
metadataCall,
|
|
metadataResult,
|
|
},
|
|
wantRemovedCalls: 1,
|
|
wantRemovedResults: 1,
|
|
wantTexts: []string{"ok"},
|
|
validateAnthropic: true,
|
|
},
|
|
{
|
|
name: "pointer and nil pointer variants are handled safely",
|
|
provider: fantasyanthropic.Name,
|
|
content: []fantasy.Content{
|
|
nilToolCall,
|
|
&pointerCall,
|
|
providerResult("ws-pointer", "web_search"),
|
|
},
|
|
wantSummary: contentSummary{
|
|
providerCalls: []string{"ws-pointer"},
|
|
providerResults: []string{"ws-pointer"},
|
|
},
|
|
validateAnthropic: true,
|
|
},
|
|
{
|
|
name: "local tool content is unchanged",
|
|
provider: fantasyanthropic.Name,
|
|
content: []fantasy.Content{
|
|
localCall("tc-1", "read_file"),
|
|
localResult("tc-1", "read_file"),
|
|
},
|
|
wantSummary: contentSummary{
|
|
localCalls: []string{"tc-1"},
|
|
localResults: []string{"tc-1"},
|
|
},
|
|
validateAnthropic: true,
|
|
},
|
|
{
|
|
name: "non Anthropic provider content is unchanged",
|
|
provider: "fake",
|
|
content: []fantasy.Content{
|
|
providerCall("ws-1", "web_search", `{"query":"coder"}`),
|
|
},
|
|
wantSummary: contentSummary{
|
|
providerCalls: []string{"ws-1"},
|
|
},
|
|
},
|
|
}
|
|
|
|
for _, tc := range testCases {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
sanitized, stats := chatsanitize.SanitizeAnthropicProviderToolContent(tc.provider, tc.content)
|
|
require.Equal(t, tc.wantRemovedCalls, stats.RemovedToolCalls)
|
|
require.Equal(t, tc.wantRemovedResults, stats.RemovedToolResults)
|
|
require.Zero(t, stats.DroppedMessages)
|
|
|
|
summary := summarizeContent(sanitized)
|
|
assert.ElementsMatch(t, tc.wantSummary.providerCalls, summary.providerCalls)
|
|
assert.ElementsMatch(t, tc.wantSummary.providerResults, summary.providerResults)
|
|
assert.ElementsMatch(t, tc.wantSummary.localCalls, summary.localCalls)
|
|
assert.ElementsMatch(t, tc.wantSummary.localResults, summary.localResults)
|
|
for _, text := range tc.wantTexts {
|
|
requireTextContent(t, sanitized, text)
|
|
}
|
|
if tc.validateAnthropic {
|
|
assertProviderHistoryValid(t, sanitized)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestExecuteSingleTool_MediaBase64Encoding(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
originalBytes := []byte{0xFF, 0xD8, 0xFF, 0xE0, 0x00, 0x10}
|
|
metrics := NewMetrics(prometheus.NewRegistry())
|
|
logger := slog.Make()
|
|
|
|
t.Run("EncodesRawBytesToBase64", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
tool := fantasy.NewAgentTool(
|
|
"screenshot",
|
|
"takes a screenshot",
|
|
func(_ context.Context, _ struct{}, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
|
return fantasy.ToolResponse{
|
|
Type: "image",
|
|
Data: originalBytes,
|
|
MediaType: "image/jpeg",
|
|
}, nil
|
|
},
|
|
)
|
|
|
|
toolMap := map[string]fantasy.AgentTool{
|
|
"screenshot": tool,
|
|
}
|
|
tc := fantasy.ToolCallContent{
|
|
ToolCallID: "call-1",
|
|
ToolName: "screenshot",
|
|
Input: "{}",
|
|
}
|
|
|
|
result := executeSingleTool(
|
|
context.Background(),
|
|
toolMap,
|
|
tc,
|
|
metrics,
|
|
logger,
|
|
"fake", "fake-model",
|
|
map[string]bool{},
|
|
[]string{"screenshot"},
|
|
map[string]struct{}{},
|
|
nil,
|
|
)
|
|
|
|
media, ok := result.Result.(fantasy.ToolResultOutputContentMedia)
|
|
require.True(t, ok, "expected ToolResultOutputContentMedia")
|
|
require.Equal(t, "image/jpeg", media.MediaType)
|
|
|
|
decoded, err := base64.StdEncoding.DecodeString(media.Data)
|
|
require.NoError(t, err, "Data should be valid base64")
|
|
require.Equal(t, originalBytes, decoded)
|
|
})
|
|
|
|
t.Run("SanitizesInvalidUTF8InContent", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
tool := fantasy.NewAgentTool(
|
|
"screenshot",
|
|
"takes a screenshot",
|
|
func(_ context.Context, _ struct{}, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
|
return fantasy.ToolResponse{
|
|
Type: "image",
|
|
Data: originalBytes,
|
|
MediaType: "image/png",
|
|
Content: "hello\xffworld",
|
|
}, nil
|
|
},
|
|
)
|
|
|
|
toolMap := map[string]fantasy.AgentTool{
|
|
"screenshot": tool,
|
|
}
|
|
tc := fantasy.ToolCallContent{
|
|
ToolCallID: "call-2",
|
|
ToolName: "screenshot",
|
|
Input: "{}",
|
|
}
|
|
|
|
result := executeSingleTool(
|
|
context.Background(),
|
|
toolMap,
|
|
tc,
|
|
metrics,
|
|
logger,
|
|
"fake", "fake-model",
|
|
map[string]bool{},
|
|
[]string{"screenshot"},
|
|
map[string]struct{}{},
|
|
nil,
|
|
)
|
|
|
|
media, ok := result.Result.(fantasy.ToolResultOutputContentMedia)
|
|
require.True(t, ok, "expected ToolResultOutputContentMedia")
|
|
require.True(t, utf8.ValidString(media.Text), "Text should be valid UTF-8")
|
|
require.Contains(t, media.Text, "hello")
|
|
require.Contains(t, media.Text, "world")
|
|
})
|
|
|
|
t.Run("SanitizesInvalidUTF8InTextResult", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
tool := fantasy.NewAgentTool(
|
|
"echo",
|
|
"echoes input",
|
|
func(_ context.Context, _ struct{}, _ fantasy.ToolCall) (fantasy.ToolResponse, error) {
|
|
return fantasy.ToolResponse{
|
|
Content: "hello\xffworld",
|
|
}, nil
|
|
},
|
|
)
|
|
|
|
toolMap := map[string]fantasy.AgentTool{
|
|
"echo": tool,
|
|
}
|
|
tc := fantasy.ToolCallContent{
|
|
ToolCallID: "call-3",
|
|
ToolName: "echo",
|
|
Input: "{}",
|
|
}
|
|
|
|
result := executeSingleTool(
|
|
context.Background(),
|
|
toolMap,
|
|
tc,
|
|
metrics,
|
|
logger,
|
|
"fake", "fake-model",
|
|
map[string]bool{},
|
|
[]string{"echo"},
|
|
map[string]struct{}{},
|
|
nil,
|
|
)
|
|
|
|
textOutput, ok := result.Result.(fantasy.ToolResultOutputContentText)
|
|
require.True(t, ok, "expected ToolResultOutputContentText, got %T", result.Result)
|
|
require.True(t, utf8.ValidString(textOutput.Text), "Text should be valid UTF-8")
|
|
require.Contains(t, textOutput.Text, "hello")
|
|
require.Contains(t, textOutput.Text, "world")
|
|
})
|
|
}
|