mirror of
https://github.com/coder/coder.git
synced 2026-06-03 21:18:24 +00:00
183 lines
4.3 KiB
Go
183 lines
4.3 KiB
Go
package chatdebug //nolint:testpackage // Uses unexported recorder helpers.
|
|
|
|
import (
|
|
"context"
|
|
"slices"
|
|
"sync"
|
|
"testing"
|
|
|
|
"charm.land/fantasy"
|
|
"github.com/google/uuid"
|
|
"github.com/stretchr/testify/require"
|
|
"go.uber.org/mock/gomock"
|
|
|
|
"github.com/coder/coder/v2/coderd/database/dbmock"
|
|
"github.com/coder/coder/v2/coderd/x/chatd/chattest"
|
|
"github.com/coder/coder/v2/testutil"
|
|
)
|
|
|
|
func TestAttemptSink_ThreadSafe(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
const n = 256
|
|
|
|
sink := &attemptSink{}
|
|
var wg sync.WaitGroup
|
|
|
|
for i := range n {
|
|
wg.Go(func() {
|
|
sink.record(Attempt{Number: i + 1, ResponseStatus: 200 + i})
|
|
})
|
|
}
|
|
|
|
wg.Wait()
|
|
|
|
attempts := sink.snapshot()
|
|
require.Len(t, attempts, n)
|
|
|
|
numbers := make([]int, 0, n)
|
|
statuses := make([]int, 0, n)
|
|
for _, attempt := range attempts {
|
|
numbers = append(numbers, attempt.Number)
|
|
statuses = append(statuses, attempt.ResponseStatus)
|
|
}
|
|
slices.Sort(numbers)
|
|
slices.Sort(statuses)
|
|
|
|
for i := range n {
|
|
require.Equal(t, i+1, numbers[i])
|
|
require.Equal(t, 200+i, statuses[i])
|
|
}
|
|
}
|
|
|
|
func TestAttemptSinkContext(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := context.Background()
|
|
require.Nil(t, attemptSinkFromContext(ctx))
|
|
|
|
sink := &attemptSink{}
|
|
ctx = withAttemptSink(ctx, sink)
|
|
require.Same(t, sink, attemptSinkFromContext(ctx))
|
|
}
|
|
|
|
func TestWrapModel_NilModel(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
require.Panics(t, func() {
|
|
WrapModel(nil, &Service{}, RecorderOptions{})
|
|
})
|
|
}
|
|
|
|
func TestWrapModel_NilService(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
model := &chattest.FakeModel{ProviderName: "provider", ModelName: "model"}
|
|
wrapped := WrapModel(model, nil, RecorderOptions{})
|
|
require.Same(t, model, wrapped)
|
|
}
|
|
|
|
func TestNextStepNumber_Concurrent(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
const n = 256
|
|
|
|
runID := uuid.New()
|
|
t.Cleanup(func() { CleanupStepCounter(runID) })
|
|
|
|
results := make([]int, n)
|
|
var wg sync.WaitGroup
|
|
|
|
for i := range n {
|
|
wg.Go(func() {
|
|
results[i] = int(nextStepNumber(runID))
|
|
})
|
|
}
|
|
|
|
wg.Wait()
|
|
|
|
slices.Sort(results)
|
|
for i := range n {
|
|
require.Equal(t, i+1, results[i])
|
|
}
|
|
}
|
|
|
|
func TestStepFinalizeContext_StripsCancellation(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
baseCtx, cancelBase := context.WithCancel(context.Background())
|
|
cancelBase()
|
|
require.ErrorIs(t, baseCtx.Err(), context.Canceled)
|
|
|
|
finalizeCtx, cancelFinalize := stepFinalizeContext(baseCtx)
|
|
defer cancelFinalize()
|
|
|
|
require.NoError(t, finalizeCtx.Err())
|
|
_, hasDeadline := finalizeCtx.Deadline()
|
|
require.True(t, hasDeadline)
|
|
}
|
|
|
|
func TestSyncStepCounter_AdvancesCounter(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
runID := uuid.New()
|
|
t.Cleanup(func() { CleanupStepCounter(runID) })
|
|
|
|
syncStepCounter(runID, 7)
|
|
require.Equal(t, int32(8), nextStepNumber(runID))
|
|
}
|
|
|
|
func TestStepHandleFinish_NilHandle(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
var handle *stepHandle
|
|
handle.finish(context.Background(), StatusCompleted, nil, nil, nil, nil)
|
|
}
|
|
|
|
func TestBeginStep_NilService(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx := context.Background()
|
|
handle, enriched := beginStep(ctx, nil, RecorderOptions{}, OperationGenerate, nil)
|
|
require.Nil(t, handle)
|
|
require.Nil(t, attemptSinkFromContext(enriched))
|
|
_, ok := StepFromContext(enriched)
|
|
require.False(t, ok)
|
|
}
|
|
|
|
func TestBeginStep_FallsBackToRunChatID(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctrl := gomock.NewController(t)
|
|
db := dbmock.NewMockStore(ctrl)
|
|
runID := uuid.New()
|
|
runChatID := uuid.New()
|
|
ownerID := uuid.New()
|
|
expectDebugLoggingEnabled(t, db, ownerID)
|
|
expectCreateStepNumberWithRequestValidity(t, db, runID, runChatID, 1, OperationGenerate, false)
|
|
|
|
ctx := ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: runChatID})
|
|
svc := NewService(db, testutil.Logger(t), nil)
|
|
|
|
handle, enriched := beginStep(ctx, svc, RecorderOptions{OwnerID: ownerID}, OperationGenerate, nil)
|
|
require.NotNil(t, handle)
|
|
require.Equal(t, runChatID, handle.stepCtx.ChatID)
|
|
|
|
stepCtx, ok := StepFromContext(enriched)
|
|
require.True(t, ok)
|
|
require.Equal(t, runChatID, stepCtx.ChatID)
|
|
}
|
|
|
|
func TestWrapModel_ReturnsDebugModel(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
model := &chattest.FakeModel{ProviderName: "provider", ModelName: "model"}
|
|
wrapped := WrapModel(model, &Service{}, RecorderOptions{})
|
|
|
|
require.NotSame(t, model, wrapped)
|
|
require.IsType(t, &debugModel{}, wrapped)
|
|
require.Implements(t, (*fantasy.LanguageModel)(nil), wrapped)
|
|
require.Equal(t, model.Provider(), wrapped.Provider())
|
|
require.Equal(t, model.Model(), wrapped.Model())
|
|
}
|