mirror of
https://github.com/coder/coder.git
synced 2026-06-06 06:28:20 +00:00
316 lines
10 KiB
Go
316 lines
10 KiB
Go
package chatd //nolint:testpackage // Tests unexported chat worker internals.
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"errors"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/stretchr/testify/require"
|
|
|
|
"github.com/coder/coder/v2/coderd/database"
|
|
"github.com/coder/coder/v2/testutil"
|
|
"github.com/coder/quartz"
|
|
)
|
|
|
|
func TestWorker_NewRequiresTaskStarterOrMessagePartBuffer(t *testing.T) {
|
|
t.Parallel()
|
|
f := newWorkerTestFixture(t)
|
|
_, err := newChatWorker(nil, chatWorkerOptions{WorkerID: uuid.New(), Store: f.db, Pubsub: f.pubsub})
|
|
require.ErrorContains(t, err, "task starter or message part buffer is required")
|
|
}
|
|
|
|
func TestWorker_NewRequiresWorkerID(t *testing.T) {
|
|
t.Parallel()
|
|
f := newWorkerTestFixture(t)
|
|
opts := testOptions(t, f, newRecordingTaskStarter())
|
|
opts.WorkerID = uuid.Nil
|
|
_, err := newChatWorker(nil, opts)
|
|
require.ErrorContains(t, err, "worker ID is required")
|
|
}
|
|
|
|
func TestWorker_UsesConfiguredWorkerID(t *testing.T) {
|
|
t.Parallel()
|
|
f := newWorkerTestFixture(t)
|
|
starter := newRecordingTaskStarter()
|
|
opts := testOptions(t, f, starter)
|
|
workerID := opts.WorkerID
|
|
worker, err := newChatWorker(nil, opts)
|
|
require.NoError(t, err)
|
|
require.Equal(t, workerID, worker.chatWorkerID())
|
|
require.NoError(t, worker.Start(context.Background()))
|
|
require.Equal(t, workerID, worker.chatWorkerID())
|
|
require.NoError(t, worker.Close())
|
|
}
|
|
|
|
func TestWorker_AcquiresRunnableChatFromOwnershipHint(t *testing.T) {
|
|
t.Parallel()
|
|
f := newWorkerTestFixture(t)
|
|
chat := f.createRunningChat(t)
|
|
starter := newRecordingTaskStarter()
|
|
worker := startWorker(t, testOptions(t, f, starter))
|
|
|
|
call := starter.waitCall(t, taskKindGeneration, chat.ID)
|
|
require.Equal(t, worker.chatWorkerID(), call.input.WorkerID)
|
|
require.Equal(t, database.ChatStatusRunning, call.input.Status)
|
|
require.NotEqual(t, uuid.Nil, call.input.RunnerID)
|
|
|
|
latest, err := f.db.GetChatByID(testutil.Context(t, testutil.WaitShort), chat.ID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, worker.chatWorkerID(), latest.WorkerID.UUID)
|
|
require.Equal(t, call.input.RunnerID, latest.RunnerID.UUID)
|
|
_, err = f.db.GetChatHeartbeat(testutil.Context(t, testutil.WaitShort), database.GetChatHeartbeatParams{
|
|
ChatID: chat.ID,
|
|
RunnerID: call.input.RunnerID,
|
|
})
|
|
require.NoError(t, err)
|
|
}
|
|
|
|
func TestWorker_AcquiresRequiresActionChatFromOwnershipHint(t *testing.T) {
|
|
t.Parallel()
|
|
f := newWorkerTestFixture(t)
|
|
chat := f.createRequiresActionChat(t)
|
|
starter := newRecordingTaskStarter()
|
|
startWorker(t, testOptions(t, f, starter))
|
|
|
|
call := starter.waitCall(t, taskKindRequiresActionTimeout, chat.ID)
|
|
require.Equal(t, database.ChatStatusRequiresAction, call.input.Status)
|
|
require.True(t, call.input.RequiresActionDeadlineAt.Valid)
|
|
}
|
|
|
|
func TestWorker_SkipsFreshlyOwnedChat(t *testing.T) {
|
|
t.Parallel()
|
|
f := newWorkerTestFixture(t)
|
|
chat := f.createRunningChat(t)
|
|
otherWorker := uuid.New()
|
|
otherRunner := uuid.New()
|
|
acquireChat(t, f, chat.ID, otherWorker, otherRunner)
|
|
starter := newRecordingTaskStarter()
|
|
worker := startWorker(t, testOptions(t, f, starter))
|
|
worker.Wake()
|
|
|
|
starter.assertNoCall(t)
|
|
latest, err := f.db.GetChatByID(testutil.Context(t, testutil.WaitShort), chat.ID)
|
|
require.NoError(t, err)
|
|
require.Equal(t, otherWorker, latest.WorkerID.UUID)
|
|
require.Equal(t, otherRunner, latest.RunnerID.UUID)
|
|
}
|
|
|
|
func TestWorker_TwoWorkersRaceSingleOwner(t *testing.T) {
|
|
t.Parallel()
|
|
f := newWorkerTestFixture(t)
|
|
chat := f.createRunningChat(t)
|
|
firstStarter := newRecordingTaskStarter()
|
|
secondStarter := newRecordingTaskStarter()
|
|
first := startWorker(t, testOptions(t, f, firstStarter))
|
|
second := startWorker(t, testOptions(t, f, secondStarter))
|
|
|
|
call := waitAnyTaskCall(t, firstStarter, secondStarter, taskKindGeneration, chat.ID)
|
|
require.Contains(t, []uuid.UUID{first.chatWorkerID(), second.chatWorkerID()}, call.input.WorkerID)
|
|
firstStarter.assertNoCall(t)
|
|
secondStarter.assertNoCall(t)
|
|
|
|
latest, err := f.db.GetChatByID(testutil.Context(t, testutil.WaitShort), chat.ID)
|
|
require.NoError(t, err)
|
|
require.True(t, latest.WorkerID.Valid)
|
|
require.True(t, latest.RunnerID.Valid)
|
|
require.Equal(t, call.input.WorkerID, latest.WorkerID.UUID)
|
|
require.Equal(t, call.input.RunnerID, latest.RunnerID.UUID)
|
|
}
|
|
|
|
func TestWorker_DrainsMultipleRunnableChatsOnWake(t *testing.T) {
|
|
t.Parallel()
|
|
f := newWorkerTestFixture(t)
|
|
first := f.createRunningChat(t)
|
|
second := f.createRunningChat(t)
|
|
third := f.createRunningChat(t)
|
|
starter := newRecordingTaskStarter()
|
|
opts := testOptions(t, f, starter)
|
|
opts.AcquisitionBatchSize = 1
|
|
startWorker(t, opts)
|
|
|
|
want := map[uuid.UUID]bool{first.ID: true, second.ID: true, third.ID: true}
|
|
for range 3 {
|
|
call := starter.waitCall(t, taskKindGeneration, uuid.Nil)
|
|
delete(want, call.input.ChatID)
|
|
}
|
|
require.Empty(t, want)
|
|
}
|
|
|
|
func TestWorker_DoesNotAcquireIdleOrArchivedChats(t *testing.T) {
|
|
t.Parallel()
|
|
f := newWorkerTestFixture(t)
|
|
waiting := f.createRunningChat(t)
|
|
finishTurn(t, f, waiting.ID)
|
|
errorChat := f.createRunningChat(t)
|
|
forceExecutionStateAndPublish(t, f, errorChat.ID, database.ChatStatusError, false)
|
|
archived := f.createRunningChat(t)
|
|
forceExecutionStateAndPublish(t, f, archived.ID, database.ChatStatusRunning, true)
|
|
starter := newRecordingTaskStarter()
|
|
worker := startWorker(t, testOptions(t, f, starter))
|
|
worker.Wake()
|
|
|
|
starter.assertNoCall(t)
|
|
}
|
|
|
|
func TestWorker_HeartbeatLoopRefreshesActiveRunnerHeartbeat(t *testing.T) {
|
|
t.Parallel()
|
|
f := newWorkerTestFixture(t)
|
|
chat := f.createRunningChat(t)
|
|
clock := quartz.NewMock(t)
|
|
heartbeatTrap := clock.Trap().NewTicker("chatworker", "heartbeat")
|
|
defer heartbeatTrap.Close()
|
|
starter := newBlockingTaskStarter(false)
|
|
opts := testOptions(t, f, starter)
|
|
opts.Clock = clock
|
|
opts.HeartbeatInterval = time.Minute
|
|
startWorker(t, opts)
|
|
heartbeatTrap.MustWait(testutil.Context(t, testutil.WaitLong)).MustRelease(testutil.Context(t, testutil.WaitLong))
|
|
call := starter.waitCall(t, taskKindGeneration, chat.ID)
|
|
oldHeartbeat := makeHeartbeatStale(t, f, chat.ID, call.input.RunnerID)
|
|
|
|
clock.Advance(time.Minute).MustWait(testutil.Context(t, testutil.WaitLong))
|
|
testutil.Eventually(testutil.Context(t, testutil.WaitLong), t, func(ctx context.Context) bool {
|
|
heartbeat, err := f.db.GetChatHeartbeat(ctx, database.GetChatHeartbeatParams{
|
|
ChatID: chat.ID,
|
|
RunnerID: call.input.RunnerID,
|
|
})
|
|
return err == nil && heartbeat.HeartbeatAt.After(oldHeartbeat)
|
|
}, testutil.IntervalFast, "heartbeat should be refreshed")
|
|
}
|
|
|
|
func TestWorker_HeartbeatCleanupDeletesStaleRows(t *testing.T) {
|
|
t.Parallel()
|
|
f := newWorkerTestFixture(t)
|
|
chat := f.createRunningChat(t)
|
|
finishTurn(t, f, chat.ID)
|
|
runnerID := uuid.New()
|
|
require.NoError(t, f.db.UpsertChatHeartbeat(testutil.Context(t, testutil.WaitShort), database.UpsertChatHeartbeatParams{
|
|
ChatID: chat.ID,
|
|
RunnerID: runnerID,
|
|
}))
|
|
makeHeartbeatStale(t, f, chat.ID, runnerID)
|
|
clock := quartz.NewMock(t)
|
|
cleanupTrap := clock.Trap().NewTicker("chatworker", "heartbeat-cleanup")
|
|
defer cleanupTrap.Close()
|
|
starter := newRecordingTaskStarter()
|
|
opts := testOptions(t, f, starter)
|
|
opts.Clock = clock
|
|
opts.HeartbeatCleanupInterval = time.Minute
|
|
startWorker(t, opts)
|
|
cleanupTrap.MustWait(testutil.Context(t, testutil.WaitLong)).MustRelease(testutil.Context(t, testutil.WaitLong))
|
|
|
|
clock.Advance(time.Minute).MustWait(testutil.Context(t, testutil.WaitLong))
|
|
testutil.Eventually(testutil.Context(t, testutil.WaitLong), t, func(ctx context.Context) bool {
|
|
_, err := f.db.GetChatHeartbeat(ctx, database.GetChatHeartbeatParams{
|
|
ChatID: chat.ID,
|
|
RunnerID: runnerID,
|
|
})
|
|
return errors.Is(err, sql.ErrNoRows)
|
|
}, testutil.IntervalFast)
|
|
}
|
|
|
|
func TestWorker_CloseDeletesOwnedHeartbeatsAndPublishesOwnershipHints(t *testing.T) {
|
|
t.Parallel()
|
|
f := newWorkerTestFixture(t)
|
|
first := f.createRunningChat(t)
|
|
second := f.createRunningChat(t)
|
|
starter := newBlockingTaskStarter(false)
|
|
pubsub := newRecordingPubsub(f.pubsub)
|
|
opts := testOptions(t, f, starter)
|
|
opts.Pubsub = pubsub
|
|
worker := startWorker(t, opts)
|
|
callsByChat := make(map[uuid.UUID]taskCall)
|
|
for range 2 {
|
|
call := starter.waitCall(t, taskKindGeneration, uuid.Nil)
|
|
callsByChat[call.input.ChatID] = call
|
|
}
|
|
require.Contains(t, callsByChat, first.ID)
|
|
require.Contains(t, callsByChat, second.ID)
|
|
|
|
require.NoError(t, worker.Close())
|
|
for _, call := range callsByChat {
|
|
_, err := f.db.GetChatHeartbeat(testutil.Context(t, testutil.WaitShort), database.GetChatHeartbeatParams{
|
|
ChatID: call.input.ChatID,
|
|
RunnerID: call.input.RunnerID,
|
|
})
|
|
require.ErrorIs(t, err, sql.ErrNoRows)
|
|
}
|
|
|
|
messages := pubsub.ownershipMessages(t)
|
|
seen := make(map[uuid.UUID]bool)
|
|
for _, msg := range messages {
|
|
seen[msg.ChatID] = true
|
|
require.NotZero(t, msg.SnapshotVersion)
|
|
}
|
|
require.True(t, seen[first.ID], "expected ownership hint for first runner")
|
|
require.True(t, seen[second.ID], "expected ownership hint for second runner")
|
|
}
|
|
|
|
func TestWorker_CloseIsIdempotentAndDoesNotBlock(t *testing.T) {
|
|
t.Parallel()
|
|
f := newWorkerTestFixture(t)
|
|
chat := f.createRunningChat(t)
|
|
starter := newBlockingTaskStarter(false)
|
|
worker := startWorker(t, testOptions(t, f, starter))
|
|
call := starter.waitCall(t, taskKindGeneration, chat.ID)
|
|
|
|
closed := make(chan error, 1)
|
|
go func() {
|
|
if err := worker.Close(); err != nil {
|
|
closed <- err
|
|
return
|
|
}
|
|
closed <- worker.Close()
|
|
}()
|
|
select {
|
|
case err := <-closed:
|
|
require.NoError(t, err)
|
|
case <-time.After(testutil.WaitLong):
|
|
t.Fatal("worker close did not return")
|
|
}
|
|
select {
|
|
case <-call.ctx.Done():
|
|
case <-time.After(testutil.WaitLong):
|
|
t.Fatal("active task was not canceled")
|
|
}
|
|
}
|
|
|
|
func waitAnyTaskCall(
|
|
t *testing.T,
|
|
first *recordingTaskStarter,
|
|
second *recordingTaskStarter,
|
|
kind taskKind,
|
|
chatID uuid.UUID,
|
|
) taskCall {
|
|
t.Helper()
|
|
deadline := time.After(testutil.WaitLong)
|
|
for {
|
|
select {
|
|
case call := <-first.callCh:
|
|
if call.kind == kind && call.input.ChatID == chatID {
|
|
return call
|
|
}
|
|
case call := <-second.callCh:
|
|
if call.kind == kind && call.input.ChatID == chatID {
|
|
return call
|
|
}
|
|
case <-deadline:
|
|
t.Fatal("timed out waiting for either worker to start task")
|
|
return taskCall{}
|
|
}
|
|
}
|
|
}
|
|
|
|
func requireTaskCanceled(t *testing.T, call taskCall) {
|
|
t.Helper()
|
|
select {
|
|
case <-call.ctx.Done():
|
|
require.True(t, errors.Is(call.ctx.Err(), context.Canceled))
|
|
case <-time.After(testutil.WaitLong):
|
|
t.Fatal("task context was not canceled")
|
|
}
|
|
}
|