Files
coder/coderd/x/chatd/chatdebug/recorder_test.go
T

183 lines
4.3 KiB
Go

package chatdebug //nolint:testpackage // Uses unexported recorder helpers.
import (
"context"
"slices"
"sync"
"testing"
"charm.land/fantasy"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"github.com/coder/coder/v2/coderd/database/dbmock"
"github.com/coder/coder/v2/coderd/x/chatd/chattest"
"github.com/coder/coder/v2/testutil"
)
// TestAttemptSink_ThreadSafe records n attempts from n concurrent goroutines
// and verifies the snapshot contains every attempt exactly once, with each
// attempt's fields intact.
func TestAttemptSink_ThreadSafe(t *testing.T) {
	t.Parallel()

	const n = 256
	sink := &attemptSink{}

	var wg sync.WaitGroup
	for i := range n {
		wg.Go(func() {
			sink.record(Attempt{Number: i + 1, ResponseStatus: 200 + i})
		})
	}
	wg.Wait()

	attempts := sink.snapshot()
	require.Len(t, attempts, n)

	// Validate each attempt as a whole instead of sorting Number and
	// ResponseStatus independently: independent sorting would not notice
	// fields from different concurrent records being mixed together.
	// Indexing into a local bitmap also avoids mutating the snapshot.
	seen := make([]bool, n)
	for _, attempt := range attempts {
		idx := attempt.Number - 1
		require.GreaterOrEqual(t, idx, 0, "attempt number %d out of range", attempt.Number)
		require.Less(t, idx, n, "attempt number %d out of range", attempt.Number)
		require.False(t, seen[idx], "duplicate attempt number %d", attempt.Number)
		seen[idx] = true
		require.Equal(t, 199+attempt.Number, attempt.ResponseStatus,
			"attempt %d has mismatched ResponseStatus", attempt.Number)
	}
	// len(attempts) == n, every Number is in [1, n], and none repeat, so
	// every goroutine's attempt was recorded.
}
// TestAttemptSinkContext checks that a bare context carries no attempt sink and
// that a sink stored with withAttemptSink is returned (by identity) from
// attemptSinkFromContext.
func TestAttemptSinkContext(t *testing.T) {
	t.Parallel()

	bare := context.Background()
	require.Nil(t, attemptSinkFromContext(bare))

	want := &attemptSink{}
	withSink := withAttemptSink(bare, want)
	require.Same(t, want, attemptSinkFromContext(withSink))
}
func TestWrapModel_NilModel(t *testing.T) {
t.Parallel()
require.Panics(t, func() {
WrapModel(nil, &Service{}, RecorderOptions{})
})
}
// TestWrapModel_NilService asserts that without a Service, WrapModel returns
// the very same model instance it was given (no wrapper is applied).
func TestWrapModel_NilService(t *testing.T) {
	t.Parallel()

	inner := &chattest.FakeModel{ProviderName: "provider", ModelName: "model"}
	require.Same(t, inner, WrapModel(inner, nil, RecorderOptions{}))
}
// TestNextStepNumber_Concurrent draws step numbers for one run from many
// goroutines at once and verifies that, once sorted, they are exactly 1..n
// with no gaps or duplicates.
func TestNextStepNumber_Concurrent(t *testing.T) {
	t.Parallel()

	const workers = 256
	runID := uuid.New()
	t.Cleanup(func() { CleanupStepCounter(runID) })

	// Each goroutine writes only its own slot, so no extra locking is needed.
	got := make([]int, workers)
	var wg sync.WaitGroup
	for slot := range workers {
		wg.Go(func() {
			got[slot] = int(nextStepNumber(runID))
		})
	}
	wg.Wait()

	slices.Sort(got)
	for idx, step := range got {
		require.Equal(t, idx+1, step)
	}
}
// TestStepFinalizeContext_StripsCancellation verifies that stepFinalizeContext,
// given an already-canceled parent, returns a context that is not canceled but
// does carry a deadline.
func TestStepFinalizeContext_StripsCancellation(t *testing.T) {
	t.Parallel()

	parent, cancelParent := context.WithCancel(context.Background())
	cancelParent()
	require.ErrorIs(t, parent.Err(), context.Canceled)

	ctx, cancel := stepFinalizeContext(parent)
	defer cancel()

	require.NoError(t, ctx.Err())
	_, hasDeadline := ctx.Deadline()
	require.True(t, hasDeadline)
}
// TestSyncStepCounter_AdvancesCounter verifies that after syncing a run's
// counter to a value, the next step number handed out is that value plus one.
func TestSyncStepCounter_AdvancesCounter(t *testing.T) {
	t.Parallel()

	runID := uuid.New()
	t.Cleanup(func() { CleanupStepCounter(runID) })

	const synced = 7
	syncStepCounter(runID, synced)
	require.Equal(t, int32(synced+1), nextStepNumber(runID))
}
// TestStepHandleFinish_NilHandle verifies that calling finish on a nil
// *stepHandle is a safe no-op. beginStep can return a nil handle (see
// TestBeginStep_NilService), so callers may reach this path.
func TestStepHandleFinish_NilHandle(t *testing.T) {
	t.Parallel()

	var handle *stepHandle
	// Assert explicitly rather than relying on an implicit panic: an
	// unrecovered panic aborts the whole test binary (including unrelated
	// parallel tests) instead of failing just this test with a clear message.
	require.NotPanics(t, func() {
		handle.finish(context.Background(), StatusCompleted, nil, nil, nil, nil)
	})
}
// TestBeginStep_NilService verifies that beginStep with no Service returns a
// nil handle and a context that carries neither an attempt sink nor a step.
func TestBeginStep_NilService(t *testing.T) {
	t.Parallel()

	handle, ctx := beginStep(context.Background(), nil, RecorderOptions{}, OperationGenerate, nil)
	require.Nil(t, handle)
	require.Nil(t, attemptSinkFromContext(ctx))

	_, found := StepFromContext(ctx)
	require.False(t, found)
}
func TestBeginStep_FallsBackToRunChatID(t *testing.T) {
t.Parallel()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
runID := uuid.New()
runChatID := uuid.New()
ownerID := uuid.New()
expectDebugLoggingEnabled(t, db, ownerID)
expectCreateStepNumberWithRequestValidity(t, db, runID, runChatID, 1, OperationGenerate, false)
ctx := ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: runChatID})
svc := NewService(db, testutil.Logger(t), nil)
handle, enriched := beginStep(ctx, svc, RecorderOptions{OwnerID: ownerID}, OperationGenerate, nil)
require.NotNil(t, handle)
require.Equal(t, runChatID, handle.stepCtx.ChatID)
stepCtx, ok := StepFromContext(enriched)
require.True(t, ok)
require.Equal(t, runChatID, stepCtx.ChatID)
}
// TestWrapModel_ReturnsDebugModel verifies that with a non-nil Service,
// WrapModel returns a distinct *debugModel that still satisfies
// fantasy.LanguageModel and reports the wrapped model's provider and model
// names.
func TestWrapModel_ReturnsDebugModel(t *testing.T) {
	t.Parallel()

	inner := &chattest.FakeModel{ProviderName: "provider", ModelName: "model"}
	got := WrapModel(inner, &Service{}, RecorderOptions{})

	require.NotSame(t, inner, got)
	require.IsType(t, &debugModel{}, got)
	require.Implements(t, (*fantasy.LanguageModel)(nil), got)

	// The wrapper must expose the same identity as the model it wraps.
	require.Equal(t, inner.Provider(), got.Provider())
	require.Equal(t, inner.Model(), got.Model())
}