mirror of
https://github.com/coder/coder.git
synced 2026-06-04 13:38:21 +00:00
724 lines
22 KiB
Go
724 lines
22 KiB
Go
package chatdebug
|
|
|
|
import (
|
|
"context"
|
|
"io"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"strings"
|
|
"testing"
|
|
|
|
"charm.land/fantasy"
|
|
"github.com/google/uuid"
|
|
"github.com/stretchr/testify/require"
|
|
"go.uber.org/mock/gomock"
|
|
"golang.org/x/xerrors"
|
|
|
|
"github.com/coder/coder/v2/coderd/database/dbmock"
|
|
"github.com/coder/coder/v2/coderd/x/chatd/chattest"
|
|
"github.com/coder/coder/v2/testutil"
|
|
)
|
|
|
|
type testError struct{ message string }
|
|
|
|
func (e *testError) Error() string { return e.message }
|
|
|
|
func TestDebugModel_Provider(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
inner := &chattest.FakeModel{ProviderName: "provider-a", ModelName: "model-a"}
|
|
model := &debugModel{inner: inner}
|
|
|
|
require.Equal(t, inner.Provider(), model.Provider())
|
|
}
|
|
|
|
func TestDebugModel_Model(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
inner := &chattest.FakeModel{ProviderName: "provider-a", ModelName: "model-a"}
|
|
model := &debugModel{inner: inner}
|
|
|
|
require.Equal(t, inner.Model(), model.Model())
|
|
}
|
|
|
|
func TestDebugModel_Disabled(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
chatID := uuid.New()
|
|
ownerID := uuid.New()
|
|
svc := NewService(db, testutil.Logger(t), nil)
|
|
respWant := &fantasy.Response{FinishReason: fantasy.FinishReasonStop}
|
|
inner := &chattest.FakeModel{
|
|
GenerateFn: func(ctx context.Context, call fantasy.Call) (*fantasy.Response, error) {
|
|
_, ok := StepFromContext(ctx)
|
|
require.False(t, ok)
|
|
require.Nil(t, attemptSinkFromContext(ctx))
|
|
return respWant, nil
|
|
},
|
|
}
|
|
|
|
model := &debugModel{
|
|
inner: inner,
|
|
svc: svc,
|
|
opts: RecorderOptions{
|
|
ChatID: chatID,
|
|
OwnerID: ownerID,
|
|
},
|
|
}
|
|
|
|
resp, err := model.Generate(context.Background(), fantasy.Call{})
|
|
require.NoError(t, err)
|
|
require.Same(t, respWant, resp)
|
|
}
|
|
|
|
func TestDebugModel_Generate(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
chatID := uuid.New()
|
|
ownerID := uuid.New()
|
|
runID := uuid.New()
|
|
call := fantasy.Call{
|
|
Prompt: fantasy.Prompt{fantasy.NewUserMessage("hello")},
|
|
MaxOutputTokens: int64Ptr(128),
|
|
Temperature: float64Ptr(0.25),
|
|
}
|
|
respWant := &fantasy.Response{
|
|
Content: fantasy.ResponseContent{
|
|
fantasy.TextContent{Text: "hello"},
|
|
fantasy.ToolCallContent{ToolCallID: "tool-1", ToolName: "tool", Input: `{}`},
|
|
fantasy.SourceContent{ID: "source-1", Title: "docs", URL: "https://example.com"},
|
|
},
|
|
FinishReason: fantasy.FinishReasonStop,
|
|
Usage: fantasy.Usage{InputTokens: 10, OutputTokens: 4, TotalTokens: 14},
|
|
Warnings: []fantasy.CallWarning{{Message: "warning"}},
|
|
}
|
|
|
|
svc := NewService(db, testutil.Logger(t), nil)
|
|
inner := &chattest.FakeModel{
|
|
GenerateFn: func(ctx context.Context, got fantasy.Call) (*fantasy.Response, error) {
|
|
require.Equal(t, call, got)
|
|
stepCtx, ok := StepFromContext(ctx)
|
|
require.True(t, ok)
|
|
require.Equal(t, runID, stepCtx.RunID)
|
|
require.Equal(t, chatID, stepCtx.ChatID)
|
|
require.Equal(t, int32(1), stepCtx.StepNumber)
|
|
require.Equal(t, OperationGenerate, stepCtx.Operation)
|
|
require.NotEqual(t, uuid.Nil, stepCtx.StepID)
|
|
require.NotNil(t, attemptSinkFromContext(ctx))
|
|
return respWant, nil
|
|
},
|
|
}
|
|
|
|
model := &debugModel{
|
|
inner: inner,
|
|
svc: svc,
|
|
opts: RecorderOptions{ChatID: chatID, OwnerID: ownerID},
|
|
}
|
|
t.Cleanup(func() { CleanupStepCounter(runID) })
|
|
ctx := ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: chatID})
|
|
|
|
resp, err := model.Generate(ctx, call)
|
|
require.NoError(t, err)
|
|
require.Same(t, respWant, resp)
|
|
}
|
|
|
|
func TestDebugModel_GeneratePersistsAttemptsWithoutResponseClose(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
chatID := uuid.New()
|
|
ownerID := uuid.New()
|
|
runID := uuid.New()
|
|
|
|
server := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
|
|
body, err := io.ReadAll(req.Body)
|
|
require.NoError(t, err)
|
|
require.JSONEq(t, `{"message":"hello","api_key":"super-secret"}`,
|
|
string(body))
|
|
require.Equal(t, "Bearer top-secret", req.Header.Get("Authorization"))
|
|
|
|
rw.Header().Set("Content-Type", "application/json")
|
|
rw.Header().Set("X-API-Key", "response-secret")
|
|
rw.WriteHeader(http.StatusCreated)
|
|
_, _ = rw.Write([]byte(`{"token":"response-secret","safe":"ok"}`))
|
|
}))
|
|
defer server.Close()
|
|
|
|
svc := NewService(db, testutil.Logger(t), nil)
|
|
inner := &chattest.FakeModel{
|
|
GenerateFn: func(ctx context.Context, call fantasy.Call) (*fantasy.Response, error) {
|
|
client := &http.Client{Transport: &RecordingTransport{Base: server.Client().Transport}}
|
|
req, err := http.NewRequestWithContext(
|
|
ctx,
|
|
http.MethodPost,
|
|
server.URL,
|
|
strings.NewReader(`{"message":"hello","api_key":"super-secret"}`),
|
|
)
|
|
require.NoError(t, err)
|
|
req.Header.Set("Authorization", "Bearer top-secret")
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
resp, err := client.Do(req)
|
|
require.NoError(t, err)
|
|
|
|
body, err := io.ReadAll(resp.Body)
|
|
require.NoError(t, err)
|
|
require.JSONEq(t, `{"token":"response-secret","safe":"ok"}`, string(body))
|
|
require.NoError(t, resp.Body.Close())
|
|
return &fantasy.Response{FinishReason: fantasy.FinishReasonStop}, nil
|
|
},
|
|
}
|
|
|
|
model := &debugModel{
|
|
inner: inner,
|
|
svc: svc,
|
|
opts: RecorderOptions{ChatID: chatID, OwnerID: ownerID},
|
|
}
|
|
t.Cleanup(func() { CleanupStepCounter(runID) })
|
|
ctx := ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: chatID})
|
|
|
|
resp, err := model.Generate(ctx, fantasy.Call{})
|
|
require.NoError(t, err)
|
|
require.NotNil(t, resp)
|
|
}
|
|
|
|
func TestDebugModel_GenerateError(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
chatID := uuid.New()
|
|
ownerID := uuid.New()
|
|
runID := uuid.New()
|
|
wantErr := &testError{message: "boom"}
|
|
|
|
svc := NewService(db, testutil.Logger(t), nil)
|
|
model := &debugModel{
|
|
inner: &chattest.FakeModel{
|
|
GenerateFn: func(context.Context, fantasy.Call) (*fantasy.Response, error) {
|
|
return nil, wantErr
|
|
},
|
|
},
|
|
svc: svc,
|
|
opts: RecorderOptions{ChatID: chatID, OwnerID: ownerID},
|
|
}
|
|
t.Cleanup(func() { CleanupStepCounter(runID) })
|
|
ctx := ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: chatID})
|
|
|
|
resp, err := model.Generate(ctx, fantasy.Call{})
|
|
require.Nil(t, resp)
|
|
require.ErrorIs(t, err, wantErr)
|
|
}
|
|
|
|
func TestStepStatusForError(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
t.Run("Canceled", func(t *testing.T) {
|
|
t.Parallel()
|
|
require.Equal(t, StatusInterrupted, stepStatusForError(context.Canceled))
|
|
})
|
|
|
|
t.Run("DeadlineExceeded", func(t *testing.T) {
|
|
t.Parallel()
|
|
require.Equal(t, StatusInterrupted, stepStatusForError(context.DeadlineExceeded))
|
|
})
|
|
|
|
t.Run("OtherError", func(t *testing.T) {
|
|
t.Parallel()
|
|
require.Equal(t, StatusError, stepStatusForError(xerrors.New("boom")))
|
|
})
|
|
}
|
|
|
|
func TestDebugModel_Stream(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
chatID := uuid.New()
|
|
ownerID := uuid.New()
|
|
runID := uuid.New()
|
|
errPart := xerrors.New("chunk failed")
|
|
parts := []fantasy.StreamPart{
|
|
{Type: fantasy.StreamPartTypeTextDelta, Delta: "hel"},
|
|
{Type: fantasy.StreamPartTypeToolCall, ID: "tool-call-1", ToolCallName: "tool"},
|
|
{Type: fantasy.StreamPartTypeSource, ID: "source-1", URL: "https://example.com", Title: "docs"},
|
|
{Type: fantasy.StreamPartTypeWarnings, Warnings: []fantasy.CallWarning{{Message: "w1"}, {Message: "w2"}}},
|
|
{Type: fantasy.StreamPartTypeError, Error: errPart},
|
|
{Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop, Usage: fantasy.Usage{InputTokens: 8, OutputTokens: 3, TotalTokens: 11}},
|
|
}
|
|
|
|
svc := NewService(db, testutil.Logger(t), nil)
|
|
model := &debugModel{
|
|
inner: &chattest.FakeModel{
|
|
StreamFn: func(ctx context.Context, call fantasy.Call) (fantasy.StreamResponse, error) {
|
|
stepCtx, ok := StepFromContext(ctx)
|
|
require.True(t, ok)
|
|
require.Equal(t, runID, stepCtx.RunID)
|
|
require.Equal(t, chatID, stepCtx.ChatID)
|
|
require.Equal(t, int32(1), stepCtx.StepNumber)
|
|
require.Equal(t, OperationStream, stepCtx.Operation)
|
|
require.NotEqual(t, uuid.Nil, stepCtx.StepID)
|
|
require.NotNil(t, attemptSinkFromContext(ctx))
|
|
return partsToSeq(parts), nil
|
|
},
|
|
},
|
|
svc: svc,
|
|
opts: RecorderOptions{ChatID: chatID, OwnerID: ownerID},
|
|
}
|
|
t.Cleanup(func() { CleanupStepCounter(runID) })
|
|
ctx := ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: chatID})
|
|
|
|
seq, err := model.Stream(ctx, fantasy.Call{})
|
|
require.NoError(t, err)
|
|
|
|
got := make([]fantasy.StreamPart, 0, len(parts))
|
|
for part := range seq {
|
|
got = append(got, part)
|
|
}
|
|
|
|
require.Equal(t, parts, got)
|
|
}
|
|
|
|
func TestDebugModel_StreamObject(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
chatID := uuid.New()
|
|
ownerID := uuid.New()
|
|
runID := uuid.New()
|
|
parts := []fantasy.ObjectStreamPart{
|
|
{Type: fantasy.ObjectStreamPartTypeTextDelta, Delta: "ob"},
|
|
{Type: fantasy.ObjectStreamPartTypeTextDelta, Delta: "ject"},
|
|
{Type: fantasy.ObjectStreamPartTypeObject, Object: map[string]any{"value": "object"}},
|
|
{Type: fantasy.ObjectStreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop, Usage: fantasy.Usage{InputTokens: 5, OutputTokens: 2, TotalTokens: 7}},
|
|
}
|
|
|
|
svc := NewService(db, testutil.Logger(t), nil)
|
|
model := &debugModel{
|
|
inner: &chattest.FakeModel{
|
|
StreamObjectFn: func(ctx context.Context, call fantasy.ObjectCall) (fantasy.ObjectStreamResponse, error) {
|
|
stepCtx, ok := StepFromContext(ctx)
|
|
require.True(t, ok)
|
|
require.Equal(t, runID, stepCtx.RunID)
|
|
require.Equal(t, chatID, stepCtx.ChatID)
|
|
require.Equal(t, int32(1), stepCtx.StepNumber)
|
|
require.Equal(t, OperationStream, stepCtx.Operation)
|
|
require.NotEqual(t, uuid.Nil, stepCtx.StepID)
|
|
require.NotNil(t, attemptSinkFromContext(ctx))
|
|
return objectPartsToSeq(parts), nil
|
|
},
|
|
},
|
|
svc: svc,
|
|
opts: RecorderOptions{ChatID: chatID, OwnerID: ownerID},
|
|
}
|
|
t.Cleanup(func() { CleanupStepCounter(runID) })
|
|
ctx := ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: chatID})
|
|
|
|
seq, err := model.StreamObject(ctx, fantasy.ObjectCall{})
|
|
require.NoError(t, err)
|
|
|
|
got := make([]fantasy.ObjectStreamPart, 0, len(parts))
|
|
for part := range seq {
|
|
got = append(got, part)
|
|
}
|
|
|
|
require.Equal(t, parts, got)
|
|
}
|
|
|
|
// TestDebugModel_StreamCompletedAfterFinish verifies that when a consumer
|
|
// stops iteration after receiving a finish part, the step is marked as
|
|
// completed rather than interrupted.
|
|
func TestDebugModel_StreamCompletedAfterFinish(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
handle := &stepHandle{
|
|
stepCtx: &StepContext{StepID: uuid.New(), RunID: uuid.New(), ChatID: uuid.New()},
|
|
sink: &attemptSink{},
|
|
}
|
|
|
|
parts := []fantasy.StreamPart{
|
|
{Type: fantasy.StreamPartTypeTextDelta, Delta: "hello"},
|
|
{Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop, Usage: fantasy.Usage{InputTokens: 5, OutputTokens: 1, TotalTokens: 6}},
|
|
}
|
|
seq := wrapStreamSeq(context.Background(), handle, partsToSeq(parts))
|
|
|
|
// Consumer reads through the finish part then breaks. The wrapper
|
|
// should finalize as completed, not interrupted.
|
|
for part := range seq {
|
|
if part.Type == fantasy.StreamPartTypeFinish {
|
|
break
|
|
}
|
|
}
|
|
|
|
handle.mu.Lock()
|
|
status := handle.status
|
|
handle.mu.Unlock()
|
|
require.Equal(t, StatusCompleted, status)
|
|
}
|
|
|
|
// TestDebugModel_StreamInterruptedBeforeFinish verifies that when a consumer
|
|
// stops iteration before receiving a finish part, the step is marked as
|
|
// interrupted.
|
|
func TestDebugModel_StreamInterruptedBeforeFinish(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
handle := &stepHandle{
|
|
stepCtx: &StepContext{StepID: uuid.New(), RunID: uuid.New(), ChatID: uuid.New()},
|
|
sink: &attemptSink{},
|
|
}
|
|
|
|
parts := []fantasy.StreamPart{
|
|
{Type: fantasy.StreamPartTypeTextDelta, Delta: "hello"},
|
|
{Type: fantasy.StreamPartTypeTextDelta, Delta: " world"},
|
|
{Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop},
|
|
}
|
|
seq := wrapStreamSeq(context.Background(), handle, partsToSeq(parts))
|
|
|
|
// Consumer reads the first delta then breaks before finish.
|
|
count := 0
|
|
for range seq {
|
|
count++
|
|
if count == 1 {
|
|
break
|
|
}
|
|
}
|
|
require.Equal(t, 1, count)
|
|
|
|
handle.mu.Lock()
|
|
status := handle.status
|
|
handle.mu.Unlock()
|
|
require.Equal(t, StatusInterrupted, status)
|
|
}
|
|
|
|
func TestDebugModel_StreamRejectsNilSequence(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
chatID := uuid.New()
|
|
runID := uuid.New()
|
|
svc := NewService(db, testutil.Logger(t), nil)
|
|
model := &debugModel{
|
|
inner: &chattest.FakeModel{
|
|
StreamFn: func(context.Context, fantasy.Call) (fantasy.StreamResponse, error) {
|
|
var nilStream fantasy.StreamResponse
|
|
return nilStream, nil
|
|
},
|
|
},
|
|
svc: svc,
|
|
opts: RecorderOptions{ChatID: chatID, OwnerID: uuid.New()},
|
|
}
|
|
t.Cleanup(func() { CleanupStepCounter(runID) })
|
|
ctx := ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: chatID})
|
|
|
|
seq, err := model.Stream(ctx, fantasy.Call{})
|
|
require.Nil(t, seq)
|
|
require.ErrorIs(t, err, ErrNilModelResult)
|
|
}
|
|
|
|
func TestDebugModel_StreamObjectRejectsNilSequence(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
chatID := uuid.New()
|
|
runID := uuid.New()
|
|
svc := NewService(db, testutil.Logger(t), nil)
|
|
model := &debugModel{
|
|
inner: &chattest.FakeModel{
|
|
StreamObjectFn: func(context.Context, fantasy.ObjectCall) (fantasy.ObjectStreamResponse, error) {
|
|
var nilStream fantasy.ObjectStreamResponse
|
|
return nilStream, nil
|
|
},
|
|
},
|
|
svc: svc,
|
|
opts: RecorderOptions{ChatID: chatID, OwnerID: uuid.New()},
|
|
}
|
|
t.Cleanup(func() { CleanupStepCounter(runID) })
|
|
ctx := ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: chatID})
|
|
|
|
seq, err := model.StreamObject(ctx, fantasy.ObjectCall{})
|
|
require.Nil(t, seq)
|
|
require.ErrorIs(t, err, ErrNilModelResult)
|
|
}
|
|
|
|
func TestDebugModel_StreamEarlyStop(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
chatID := uuid.New()
|
|
ownerID := uuid.New()
|
|
runID := uuid.New()
|
|
parts := []fantasy.StreamPart{
|
|
{Type: fantasy.StreamPartTypeTextDelta, Delta: "first"},
|
|
{Type: fantasy.StreamPartTypeTextDelta, Delta: "second"},
|
|
}
|
|
|
|
svc := NewService(db, testutil.Logger(t), nil)
|
|
model := &debugModel{
|
|
inner: &chattest.FakeModel{
|
|
StreamFn: func(context.Context, fantasy.Call) (fantasy.StreamResponse, error) {
|
|
return partsToSeq(parts), nil
|
|
},
|
|
},
|
|
svc: svc,
|
|
opts: RecorderOptions{ChatID: chatID, OwnerID: ownerID},
|
|
}
|
|
t.Cleanup(func() { CleanupStepCounter(runID) })
|
|
ctx := ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: chatID})
|
|
|
|
seq, err := model.Stream(ctx, fantasy.Call{})
|
|
require.NoError(t, err)
|
|
|
|
count := 0
|
|
for part := range seq {
|
|
require.Equal(t, parts[0], part)
|
|
count++
|
|
break
|
|
}
|
|
require.Equal(t, 1, count)
|
|
}
|
|
|
|
func TestStreamErrorStatus(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
t.Run("CancellationBecomesInterrupted", func(t *testing.T) {
|
|
t.Parallel()
|
|
require.Equal(t, StatusInterrupted, streamErrorStatus(StatusCompleted, context.Canceled))
|
|
})
|
|
|
|
t.Run("DeadlineExceededBecomesInterrupted", func(t *testing.T) {
|
|
t.Parallel()
|
|
require.Equal(t, StatusInterrupted, streamErrorStatus(StatusCompleted, context.DeadlineExceeded))
|
|
})
|
|
|
|
t.Run("NilErrorBecomesError", func(t *testing.T) {
|
|
t.Parallel()
|
|
require.Equal(t, StatusError, streamErrorStatus(StatusCompleted, nil))
|
|
})
|
|
|
|
t.Run("ExistingErrorWins", func(t *testing.T) {
|
|
t.Parallel()
|
|
require.Equal(t, StatusError, streamErrorStatus(StatusError, context.Canceled))
|
|
})
|
|
}
|
|
|
|
func objectPartsToSeq(parts []fantasy.ObjectStreamPart) fantasy.ObjectStreamResponse {
|
|
return func(yield func(fantasy.ObjectStreamPart) bool) {
|
|
for _, part := range parts {
|
|
if !yield(part) {
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func partsToSeq(parts []fantasy.StreamPart) fantasy.StreamResponse {
|
|
return func(yield func(fantasy.StreamPart) bool) {
|
|
for _, part := range parts {
|
|
if !yield(part) {
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestDebugModel_GenerateObject(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
chatID := uuid.New()
|
|
ownerID := uuid.New()
|
|
runID := uuid.New()
|
|
call := fantasy.ObjectCall{
|
|
Prompt: fantasy.Prompt{fantasy.NewUserMessage("summarize")},
|
|
SchemaName: "Summary",
|
|
MaxOutputTokens: int64Ptr(256),
|
|
}
|
|
respWant := &fantasy.ObjectResponse{
|
|
RawText: `{"title":"test"}`,
|
|
FinishReason: fantasy.FinishReasonStop,
|
|
Usage: fantasy.Usage{InputTokens: 5, OutputTokens: 3, TotalTokens: 8},
|
|
}
|
|
|
|
svc := NewService(db, testutil.Logger(t), nil)
|
|
inner := &chattest.FakeModel{
|
|
GenerateObjectFn: func(ctx context.Context, got fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
|
|
require.Equal(t, call, got)
|
|
stepCtx, ok := StepFromContext(ctx)
|
|
require.True(t, ok)
|
|
require.Equal(t, runID, stepCtx.RunID)
|
|
require.Equal(t, chatID, stepCtx.ChatID)
|
|
require.Equal(t, OperationGenerate, stepCtx.Operation)
|
|
require.NotEqual(t, uuid.Nil, stepCtx.StepID)
|
|
require.NotNil(t, attemptSinkFromContext(ctx))
|
|
return respWant, nil
|
|
},
|
|
}
|
|
|
|
model := &debugModel{
|
|
inner: inner,
|
|
svc: svc,
|
|
opts: RecorderOptions{ChatID: chatID, OwnerID: ownerID},
|
|
}
|
|
t.Cleanup(func() { CleanupStepCounter(runID) })
|
|
ctx := ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: chatID})
|
|
|
|
resp, err := model.GenerateObject(ctx, call)
|
|
require.NoError(t, err)
|
|
require.Same(t, respWant, resp)
|
|
}
|
|
|
|
func TestDebugModel_GenerateObjectError(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
chatID := uuid.New()
|
|
runID := uuid.New()
|
|
wantErr := &testError{message: "object boom"}
|
|
|
|
svc := NewService(db, testutil.Logger(t), nil)
|
|
model := &debugModel{
|
|
inner: &chattest.FakeModel{
|
|
GenerateObjectFn: func(context.Context, fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
|
|
return nil, wantErr
|
|
},
|
|
},
|
|
svc: svc,
|
|
opts: RecorderOptions{ChatID: chatID, OwnerID: uuid.New()},
|
|
}
|
|
t.Cleanup(func() { CleanupStepCounter(runID) })
|
|
ctx := ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: chatID})
|
|
|
|
resp, err := model.GenerateObject(ctx, fantasy.ObjectCall{})
|
|
require.Nil(t, resp)
|
|
require.ErrorIs(t, err, wantErr)
|
|
}
|
|
|
|
func TestDebugModel_GenerateObjectRejectsNilResponse(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
chatID := uuid.New()
|
|
runID := uuid.New()
|
|
|
|
svc := NewService(db, testutil.Logger(t), nil)
|
|
model := &debugModel{
|
|
inner: &chattest.FakeModel{
|
|
GenerateObjectFn: func(context.Context, fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
|
|
return nil, nil //nolint:nilnil // Intentionally testing nil response handling.
|
|
},
|
|
},
|
|
svc: svc,
|
|
opts: RecorderOptions{ChatID: chatID, OwnerID: uuid.New()},
|
|
}
|
|
t.Cleanup(func() { CleanupStepCounter(runID) })
|
|
ctx := ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: chatID})
|
|
|
|
resp, err := model.GenerateObject(ctx, fantasy.ObjectCall{})
|
|
require.Nil(t, resp)
|
|
require.ErrorIs(t, err, ErrNilModelResult)
|
|
}
|
|
|
|
func TestWrapStreamSeq_CompletedNotDowngradedByCtxCancel(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
handle := &stepHandle{
|
|
stepCtx: &StepContext{StepID: uuid.New(), RunID: uuid.New(), ChatID: uuid.New()},
|
|
sink: &attemptSink{},
|
|
}
|
|
|
|
// Create a context that we cancel after the stream finishes.
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
|
|
parts := []fantasy.StreamPart{
|
|
{Type: fantasy.StreamPartTypeTextDelta, Delta: "hello"},
|
|
{Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop, Usage: fantasy.Usage{InputTokens: 5, OutputTokens: 1, TotalTokens: 6}},
|
|
}
|
|
seq := wrapStreamSeq(ctx, handle, partsToSeq(parts))
|
|
|
|
//nolint:revive // Intentionally consuming iterator to trigger side-effects.
|
|
for range seq {
|
|
}
|
|
|
|
// Cancel the context after the stream has been fully consumed
|
|
// and finalized. The status should remain completed.
|
|
cancel()
|
|
|
|
handle.mu.Lock()
|
|
status := handle.status
|
|
handle.mu.Unlock()
|
|
require.Equal(t, StatusCompleted, status)
|
|
}
|
|
|
|
func TestWrapObjectStreamSeq_CompletedNotDowngradedByCtxCancel(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
handle := &stepHandle{
|
|
stepCtx: &StepContext{StepID: uuid.New(), RunID: uuid.New(), ChatID: uuid.New()},
|
|
sink: &attemptSink{},
|
|
}
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
|
|
parts := []fantasy.ObjectStreamPart{
|
|
{Type: fantasy.ObjectStreamPartTypeTextDelta, Delta: "obj"},
|
|
{Type: fantasy.ObjectStreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop, Usage: fantasy.Usage{InputTokens: 3, OutputTokens: 1, TotalTokens: 4}},
|
|
}
|
|
seq := wrapObjectStreamSeq(ctx, handle, objectPartsToSeq(parts))
|
|
|
|
//nolint:revive // Intentionally consuming iterator to trigger side-effects.
|
|
for range seq {
|
|
}
|
|
|
|
cancel()
|
|
|
|
handle.mu.Lock()
|
|
status := handle.status
|
|
handle.mu.Unlock()
|
|
require.Equal(t, StatusCompleted, status)
|
|
}
|
|
|
|
func TestWrapStreamSeq_DroppedStreamFinalizedOnCtxCancel(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
handle := &stepHandle{
|
|
stepCtx: &StepContext{StepID: uuid.New(), RunID: uuid.New(), ChatID: uuid.New()},
|
|
sink: &attemptSink{},
|
|
}
|
|
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
parts := []fantasy.StreamPart{
|
|
{Type: fantasy.StreamPartTypeTextDelta, Delta: "hello"},
|
|
{Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop},
|
|
}
|
|
|
|
// Create the wrapped stream but never iterate it.
|
|
_ = wrapStreamSeq(ctx, handle, partsToSeq(parts))
|
|
|
|
// Cancel the context; the AfterFunc safety net should finalize
|
|
// the step as interrupted.
|
|
cancel()
|
|
|
|
// AfterFunc fires asynchronously; give it a moment.
|
|
require.Eventually(t, func() bool {
|
|
handle.mu.Lock()
|
|
defer handle.mu.Unlock()
|
|
return handle.status == StatusInterrupted
|
|
}, testutil.WaitShort, testutil.IntervalFast)
|
|
}
|
|
|
|
func int64Ptr(v int64) *int64 { return &v }
|
|
|
|
func float64Ptr(v float64) *float64 { return &v }
|