Files
coder/coderd/x/chatd/tasks_test.go
T
Hugo Dutka 658a04d28f pr 3
2026-06-04 18:51:22 +00:00

1126 lines
40 KiB
Go

//nolint:testpackage // These tests exercise package-private task seams.
package chatd
import (
"context"
"database/sql"
"encoding/json"
"sync"
"testing"
"time"
"github.com/google/uuid"
"github.com/sqlc-dev/pqtype"
"github.com/stretchr/testify/require"
"golang.org/x/xerrors"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbgen"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
dbpubsub "github.com/coder/coder/v2/coderd/database/pubsub"
coderdpubsub "github.com/coder/coder/v2/coderd/pubsub"
"github.com/coder/coder/v2/coderd/x/chatd/chaterror"
"github.com/coder/coder/v2/coderd/x/chatd/chatprompt"
"github.com/coder/coder/v2/coderd/x/chatd/chatretry"
"github.com/coder/coder/v2/coderd/x/chatd/chatstate"
"github.com/coder/coder/v2/coderd/x/chatd/messagepartbuffer"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/testutil"
"github.com/coder/quartz"
)
// TestRetryWrapper_ExpectedExitsDoNotRetry verifies that a task returning
// errTaskExpectedExit is treated as a clean exit: runTaskWithRetry reports no
// error and invokes the task exactly once.
func TestRetryWrapper_ExpectedExitsDoNotRetry(t *testing.T) {
	t.Parallel()
	ctx := testutil.Context(t, testutil.WaitShort)
	opts := retryWrapperOptions{
		clock:        quartz.NewMock(t),
		initialDelay: time.Second,
		maxDelay:     time.Second,
	}
	invocations := 0
	task := func(context.Context) error {
		invocations++
		return errTaskExpectedExit
	}

	err := runTaskWithRetry(ctx, opts, taskKindInterrupt, task)

	require.NoError(t, err)
	require.Equal(t, 1, invocations)
}
// TestRetryWrapper_UnexpectedErrorsRetry verifies that runTaskWithRetry
// retries a task after an unexpected (non-errTaskExpectedExit) error. The first
// invocation fails, the wrapper arms a retry timer on the mock clock, and once
// the clock is advanced by the retry delay the second invocation succeeds.
func TestRetryWrapper_UnexpectedErrorsRetry(t *testing.T) {
	t.Parallel()
	clock := quartz.NewMock(t)
	// Trap the retry timer so the test can wait until the wrapper is sleeping
	// before advancing the clock.
	trap := clock.Trap().NewTimer("chatworker", "task-retry-requires_action_timeout")
	defer trap.Close()
	ctx := testutil.Context(t, testutil.WaitLong)
	// calls is written by the task function and only read after the wrapper's
	// result has been received from done.
	calls := 0
	done := make(chan error, 1)
	go func() {
		done <- runTaskWithRetry(ctx, retryWrapperOptions{
			clock:        clock,
			initialDelay: time.Minute,
			maxDelay:     time.Minute,
		}, taskKindRequiresActionTimeout, func(context.Context) error {
			calls++
			if calls == 1 {
				return xerrors.New("database unavailable")
			}
			return nil
		})
	}()
	trap.MustWait(ctx).MustRelease(ctx)
	clock.Advance(time.Minute).MustWait(ctx)
	// Bound the wait by ctx so a wrapper that never returns fails the test
	// promptly instead of hanging until the global test timeout.
	var err error
	select {
	case err = <-done:
	case <-ctx.Done():
		t.Fatalf("timed out waiting for runTaskWithRetry to return: %v", ctx.Err())
	}
	require.NoError(t, err)
	require.Equal(t, 2, calls)
}
// TestRetryWrapper_PanicsRetry verifies that a panic inside the task is
// treated like an unexpected failure. runTaskWithRetry recovers, arms a retry
// timer on the mock clock, and after the delay re-invokes the task, which then
// succeeds.
func TestRetryWrapper_PanicsRetry(t *testing.T) {
	t.Parallel()
	clock := quartz.NewMock(t)
	// Trap the retry timer so the clock is only advanced once the wrapper
	// is waiting on it.
	trap := clock.Trap().NewTimer("chatworker", "task-retry-generation")
	defer trap.Close()
	ctx := testutil.Context(t, testutil.WaitLong)
	// calls is written by the task function and only read after the wrapper's
	// result has been received from done.
	calls := 0
	done := make(chan error, 1)
	go func() {
		done <- runTaskWithRetry(ctx, retryWrapperOptions{
			clock:        clock,
			initialDelay: time.Minute,
			maxDelay:     time.Minute,
		}, taskKindGeneration, func(context.Context) error {
			calls++
			if calls == 1 {
				panic("database unavailable")
			}
			return nil
		})
	}()
	trap.MustWait(ctx).MustRelease(ctx)
	clock.Advance(time.Minute).MustWait(ctx)
	// Bound the wait by ctx so a wrapper that never returns fails the test
	// promptly instead of hanging until the global test timeout.
	var err error
	select {
	case err = <-done:
	case <-ctx.Done():
		t.Fatalf("timed out waiting for runTaskWithRetry to return: %v", ctx.Err())
	}
	require.NoError(t, err)
	require.Equal(t, 2, calls)
}
// TestInterruptTask_FinishInterruptionOnly runs StartInterrupt against a chat
// that was interrupted (SendMessage with BusyBehaviorInterrupt) while a partial
// assistant text part was sitting in the message-part buffer. It expects:
//   - the chat to end in running,
//   - a state hint and an interruption outcome to be recorded,
//   - no runner cleanup to be requested,
//   - a status-change watch event to be published,
//   - the buffered text to be persisted as the second-to-last message, followed
//     by a user message.
func TestInterruptTask_FinishInterruptionOnly(t *testing.T) {
	t.Parallel()
	f := newTaskTestFixture(t)
	chat := f.createRunningChat(t)
	workerID := uuid.New()
	runnerID := uuid.New()
	acquired := f.acquireChat(t, chat.ID, workerID, runnerID)
	// Seed a buffer episode keyed by the acquired chat's versions so the
	// interrupt task has partial output to persist.
	buffer := messagepartbuffer.New(messagepartbuffer.Options{})
	key := messagepartbuffer.Key{
		ChatID:            chat.ID,
		HistoryVersion:    acquired.HistoryVersion,
		GenerationAttempt: acquired.GenerationAttempt,
	}
	require.NoError(t, buffer.CreateEpisode(key))
	require.NoError(t, buffer.AddPart(key, codersdk.ChatMessageRoleAssistant, codersdk.ChatMessageText("partial answer")))
	interrupting := f.interruptChat(t, chat.ID)
	require.Equal(t, database.ChatStatusInterrupting, interrupting.Status)
	recorder := newTaskSideEffectRecorder()
	starter := newTestTaskStarter(t, f, buffer, recorder)
	// The task input carries the versions read back after the interrupt.
	err := starter.StartInterrupt(testutil.Context(t, testutil.WaitLong), chatWorkerTaskStartInput{
		ChatID:            chat.ID,
		WorkerID:          workerID,
		RunnerID:          runnerID,
		HistoryVersion:    interrupting.HistoryVersion,
		GenerationAttempt: interrupting.GenerationAttempt,
		Status:            database.ChatStatusInterrupting,
	})
	require.NoError(t, err)
	latest, err := f.db.GetChatByID(testutil.Context(t, testutil.WaitShort), chat.ID)
	require.NoError(t, err)
	// NOTE(review): running (rather than waiting) presumably follows from the
	// "interrupt" user message sent by interruptChat; confirm.
	require.Equal(t, database.ChatStatusRunning, latest.Status)
	recorder.requireStateHint(t, chat.ID, latest.SnapshotVersion, database.ChatStatusRunning)
	recorder.requireInterruptionOutcome(t, chat.ID, database.ChatStatusRunning)
	recorder.requireCleanupCount(t, 0)
	f.requireWatchEvent(t, chat.ID, codersdk.ChatWatchEventKindStatusChange)
	messages, err := f.db.GetChatMessagesByChatID(testutil.Context(t, testutil.WaitShort), database.GetChatMessagesByChatIDParams{ChatID: chat.ID})
	require.NoError(t, err)
	require.GreaterOrEqual(t, len(messages), 3)
	// The partial assistant output lands just before the trailing user message.
	parts, err := chatprompt.ParseContent(messages[len(messages)-2])
	require.NoError(t, err)
	require.Equal(t, []codersdk.ChatMessagePart{codersdk.ChatMessageText("partial answer")}, parts)
	require.Equal(t, database.ChatMessageRoleUser, messages[len(messages)-1].Role)
}
// TestInterruptTask_StaleFenceExits verifies that an interrupt task started
// with an ownership that has since been superseded (another worker/runner pair
// re-acquired the chat) exits with errTaskExpectedExit and leaves the chat
// untouched. The status stays interrupting, the new owner is kept, and no
// state hints or watch events are produced.
func TestInterruptTask_StaleFenceExits(t *testing.T) {
	t.Parallel()
	f := newTaskTestFixture(t)
	chat := f.createRunningChat(t)
	workerID := uuid.New()
	runnerID := uuid.New()
	f.acquireChat(t, chat.ID, workerID, runnerID)
	interrupting := f.interruptChat(t, chat.ID)
	// A second owner takes over after the interrupt, making the original
	// worker/runner pair stale.
	otherWorkerID := uuid.New()
	otherRunnerID := uuid.New()
	f.acquireChat(t, chat.ID, otherWorkerID, otherRunnerID)
	recorder := newTaskSideEffectRecorder()
	starter := newTestTaskStarter(t, f, messagepartbuffer.New(messagepartbuffer.Options{}), recorder)
	// Start the task with the original (now stale) owner identity.
	err := starter.StartInterrupt(testutil.Context(t, testutil.WaitLong), chatWorkerTaskStartInput{
		ChatID:            chat.ID,
		WorkerID:          workerID,
		RunnerID:          runnerID,
		HistoryVersion:    interrupting.HistoryVersion,
		GenerationAttempt: interrupting.GenerationAttempt,
		Status:            database.ChatStatusInterrupting,
	})
	require.ErrorIs(t, err, errTaskExpectedExit)
	latest, err := f.db.GetChatByID(testutil.Context(t, testutil.WaitShort), chat.ID)
	require.NoError(t, err)
	require.Equal(t, database.ChatStatusInterrupting, latest.Status)
	require.Equal(t, otherWorkerID, latest.WorkerID.UUID)
	require.Equal(t, otherRunnerID, latest.RunnerID.UUID)
	recorder.requireStateHintCount(t, 0)
	f.requireNoWatchEvents(t)
}
// TestInterruptTask_MissingEpisodePersistsNilPartials covers an interrupt when
// the message-part buffer has no episode for the chat. The chat is forced
// straight into interrupting (no interrupt message is sent). The task is
// expected to succeed, settle the chat in waiting, record a waiting
// interruption outcome and state hint, and add no messages.
func TestInterruptTask_MissingEpisodePersistsNilPartials(t *testing.T) {
	t.Parallel()
	f := newTaskTestFixture(t)
	chat := f.createRunningChat(t)
	workerID := uuid.New()
	runnerID := uuid.New()
	f.acquireChat(t, chat.ID, workerID, runnerID)
	// Flip the status directly instead of going through SendMessage, so no
	// follow-up user message is written.
	interrupting := f.forceExecutionState(t, chat.ID, database.ChatStatusInterrupting, false, sql.NullTime{})
	recorder := newTaskSideEffectRecorder()
	// Fresh buffer: CreateEpisode is never called for this chat.
	starter := newTestTaskStarter(t, f, messagepartbuffer.New(messagepartbuffer.Options{}), recorder)
	err := starter.StartInterrupt(testutil.Context(t, testutil.WaitLong), chatWorkerTaskStartInput{
		ChatID:            chat.ID,
		WorkerID:          workerID,
		RunnerID:          runnerID,
		HistoryVersion:    interrupting.HistoryVersion,
		GenerationAttempt: interrupting.GenerationAttempt,
		Status:            database.ChatStatusInterrupting,
	})
	require.NoError(t, err)
	latest, err := f.db.GetChatByID(testutil.Context(t, testutil.WaitShort), chat.ID)
	require.NoError(t, err)
	require.Equal(t, database.ChatStatusWaiting, latest.Status)
	recorder.requireInterruptionOutcome(t, chat.ID, database.ChatStatusWaiting)
	messages, err := f.db.GetChatMessagesByChatID(testutil.Context(t, testutil.WaitShort), database.GetChatMessagesByChatIDParams{ChatID: chat.ID})
	require.NoError(t, err)
	// Presumably only the initial "hello" message from createRunningChat.
	require.Len(t, messages, 1)
	recorder.requireStateHint(t, chat.ID, latest.SnapshotVersion, database.ChatStatusWaiting)
}
// TestInterruptTask_BufferedPartsBecomePartialMessages buffers an assistant
// tool-call part, interrupts the chat, and runs the interrupt task. It checks
// that the buffered call is persisted as an assistant message and is paired
// with a tool message. That tool message must hold exactly one error tool
// result for the same call ID.
func TestInterruptTask_BufferedPartsBecomePartialMessages(t *testing.T) {
	t.Parallel()
	f := newTaskTestFixture(t)
	chat := f.createRunningChat(t)
	workerID := uuid.New()
	runnerID := uuid.New()
	acquired := f.acquireChat(t, chat.ID, workerID, runnerID)
	buffer := messagepartbuffer.New(messagepartbuffer.Options{})
	key := messagepartbuffer.Key{ChatID: chat.ID, HistoryVersion: acquired.HistoryVersion, GenerationAttempt: acquired.GenerationAttempt}
	require.NoError(t, buffer.CreateEpisode(key))
	// A tool call with no result yet: the interrupt has to close it out.
	callID := "call_" + uuid.NewString()
	require.NoError(t, buffer.AddPart(key, codersdk.ChatMessageRoleAssistant, codersdk.ChatMessagePart{
		Type:       codersdk.ChatMessagePartTypeToolCall,
		ToolCallID: callID,
		ToolName:   "local_tool",
		Args:       json.RawMessage(`{"value":1}`),
	}))
	interrupting := f.interruptChat(t, chat.ID)
	recorder := newTaskSideEffectRecorder()
	starter := newTestTaskStarter(t, f, buffer, recorder)
	err := starter.StartInterrupt(testutil.Context(t, testutil.WaitLong), chatWorkerTaskStartInput{
		ChatID:            chat.ID,
		WorkerID:          workerID,
		RunnerID:          runnerID,
		HistoryVersion:    interrupting.HistoryVersion,
		GenerationAttempt: interrupting.GenerationAttempt,
		Status:            database.ChatStatusInterrupting,
	})
	require.NoError(t, err)
	messages, err := f.db.GetChatMessagesByChatID(testutil.Context(t, testutil.WaitShort), database.GetChatMessagesByChatIDParams{ChatID: chat.ID})
	require.NoError(t, err)
	require.GreaterOrEqual(t, len(messages), 4)
	// The last message is not asserted here; presumably it is the "interrupt"
	// user message sent by interruptChat. TODO confirm.
	assistant := messages[len(messages)-3]
	tool := messages[len(messages)-2]
	require.Equal(t, database.ChatMessageRoleAssistant, assistant.Role)
	require.Equal(t, database.ChatMessageRoleTool, tool.Role)
	toolParts, err := chatprompt.ParseContent(tool)
	require.NoError(t, err)
	require.Len(t, toolParts, 1)
	require.Equal(t, codersdk.ChatMessagePartTypeToolResult, toolParts[0].Type)
	require.Equal(t, callID, toolParts[0].ToolCallID)
	require.True(t, toolParts[0].IsError)
}
// TestRequiresActionTimeout_ExpiredCancelsOnly sets a requires-action deadline
// one minute in the past and runs the requires-action timeout task. It expects
// the chat to move to running with the deadline cleared. It also expects a
// matching state hint and a status-change watch event.
func TestRequiresActionTimeout_ExpiredCancelsOnly(t *testing.T) {
	t.Parallel()
	f := newTaskTestFixture(t)
	chat := f.createRequiresActionChat(t)
	workerID := uuid.New()
	runnerID := uuid.New()
	acquired := f.acquireChat(t, chat.ID, workerID, runnerID)
	expired := f.setRequiresActionDeadline(t, chat.ID, sql.NullTime{Time: time.Now().Add(-time.Minute), Valid: true})
	recorder := newTaskSideEffectRecorder()
	starter := newTestTaskStarter(t, f, messagepartbuffer.New(messagepartbuffer.Options{}), recorder)
	err := starter.StartRequiresActionTimeout(testutil.Context(t, testutil.WaitLong), chatWorkerTaskStartInput{
		ChatID:                   chat.ID,
		WorkerID:                 workerID,
		RunnerID:                 runnerID,
		HistoryVersion:           acquired.HistoryVersion,
		Status:                   database.ChatStatusRequiresAction,
		RequiresActionDeadlineAt: expired.RequiresActionDeadlineAt,
	})
	require.NoError(t, err)
	latest, err := f.db.GetChatByID(testutil.Context(t, testutil.WaitShort), chat.ID)
	require.NoError(t, err)
	require.Equal(t, database.ChatStatusRunning, latest.Status)
	require.False(t, latest.RequiresActionDeadlineAt.Valid)
	recorder.requireStateHint(t, chat.ID, latest.SnapshotVersion, database.ChatStatusRunning)
	f.requireWatchEvent(t, chat.ID, codersdk.ChatWatchEventKindStatusChange)
}
// TestRequiresActionTimeout_NullDeadlineCancelsImmediately clears the
// requires-action deadline (NULL) and checks that the timeout task acts on it
// at once. The chat should move to running and a running state hint should be
// recorded.
func TestRequiresActionTimeout_NullDeadlineCancelsImmediately(t *testing.T) {
	t.Parallel()
	f := newTaskTestFixture(t)
	chat := f.createRequiresActionChat(t)
	workerID := uuid.New()
	runnerID := uuid.New()
	acquired := f.acquireChat(t, chat.ID, workerID, runnerID)
	nullDeadline := f.setRequiresActionDeadline(t, chat.ID, sql.NullTime{})
	recorder := newTaskSideEffectRecorder()
	starter := newTestTaskStarter(t, f, messagepartbuffer.New(messagepartbuffer.Options{}), recorder)
	err := starter.StartRequiresActionTimeout(testutil.Context(t, testutil.WaitLong), chatWorkerTaskStartInput{
		ChatID:                   chat.ID,
		WorkerID:                 workerID,
		RunnerID:                 runnerID,
		HistoryVersion:           acquired.HistoryVersion,
		Status:                   database.ChatStatusRequiresAction,
		RequiresActionDeadlineAt: nullDeadline.RequiresActionDeadlineAt,
	})
	require.NoError(t, err)
	latest, err := f.db.GetChatByID(testutil.Context(t, testutil.WaitShort), chat.ID)
	require.NoError(t, err)
	require.Equal(t, database.ChatStatusRunning, latest.Status)
	recorder.requireStateHint(t, chat.ID, latest.SnapshotVersion, database.ChatStatusRunning)
}
// TestRequiresActionTimeout_StaleFenceExitsAfterToolResult expires the
// requires-action deadline and then forces the chat to running before the
// timeout task runs. The name suggests this simulates a tool result arriving
// first; here it is done with a direct status write. The task still carries
// the requires-action status and must exit with errTaskExpectedExit. It must
// not change the chat, record state hints, or publish watch events.
func TestRequiresActionTimeout_StaleFenceExitsAfterToolResult(t *testing.T) {
	t.Parallel()
	f := newTaskTestFixture(t)
	chat := f.createRequiresActionChat(t)
	workerID := uuid.New()
	runnerID := uuid.New()
	acquired := f.acquireChat(t, chat.ID, workerID, runnerID)
	expired := f.setRequiresActionDeadline(t, chat.ID, sql.NullTime{Time: time.Now().Add(-time.Minute), Valid: true})
	// Move the chat on before the task runs; this also clears the deadline.
	f.forceExecutionState(t, chat.ID, database.ChatStatusRunning, false, sql.NullTime{})
	recorder := newTaskSideEffectRecorder()
	starter := newTestTaskStarter(t, f, messagepartbuffer.New(messagepartbuffer.Options{}), recorder)
	err := starter.StartRequiresActionTimeout(testutil.Context(t, testutil.WaitLong), chatWorkerTaskStartInput{
		ChatID:                   chat.ID,
		WorkerID:                 workerID,
		RunnerID:                 runnerID,
		HistoryVersion:           acquired.HistoryVersion,
		Status:                   database.ChatStatusRequiresAction,
		RequiresActionDeadlineAt: expired.RequiresActionDeadlineAt,
	})
	require.ErrorIs(t, err, errTaskExpectedExit)
	latest, err := f.db.GetChatByID(testutil.Context(t, testutil.WaitShort), chat.ID)
	require.NoError(t, err)
	require.Equal(t, database.ChatStatusRunning, latest.Status)
	recorder.requireStateHintCount(t, 0)
	f.requireNoWatchEvents(t)
}
// TestAbandonTask_AbandonOnly verifies that the abandon task run by the current
// owner clears the chat's worker and runner and requests cleanup for that
// runner. It must not emit state hints or watch events.
func TestAbandonTask_AbandonOnly(t *testing.T) {
	t.Parallel()
	f := newTaskTestFixture(t)
	chat := f.createRunningChat(t)
	workerID := uuid.New()
	runnerID := uuid.New()
	acquired := f.acquireChat(t, chat.ID, workerID, runnerID)
	recorder := newTaskSideEffectRecorder()
	starter := newTestTaskStarter(t, f, messagepartbuffer.New(messagepartbuffer.Options{}), recorder)
	err := starter.StartAbandon(testutil.Context(t, testutil.WaitLong), chatWorkerTaskStartInput{
		ChatID:         chat.ID,
		WorkerID:       workerID,
		RunnerID:       runnerID,
		HistoryVersion: acquired.HistoryVersion,
		Status:         database.ChatStatusRunning,
	})
	require.NoError(t, err)
	latest, err := f.db.GetChatByID(testutil.Context(t, testutil.WaitShort), chat.ID)
	require.NoError(t, err)
	require.False(t, latest.WorkerID.Valid)
	require.False(t, latest.RunnerID.Valid)
	recorder.requireCleanup(t, chat.ID, runnerID)
	recorder.requireStateHintCount(t, 0)
	f.requireNoWatchEvents(t)
}
// TestAbandonTask_OwnershipMismatchRequestsCleanup runs the abandon task for a
// worker/runner pair that no longer owns the chat. The task is expected to
// succeed without disturbing the new owner. It must still request cleanup for
// the stale runner.
func TestAbandonTask_OwnershipMismatchRequestsCleanup(t *testing.T) {
	t.Parallel()
	f := newTaskTestFixture(t)
	chat := f.createRunningChat(t)
	workerID := uuid.New()
	runnerID := uuid.New()
	f.acquireChat(t, chat.ID, workerID, runnerID)
	// Another worker/runner pair takes ownership.
	otherWorkerID := uuid.New()
	otherRunnerID := uuid.New()
	latestOwner := f.acquireChat(t, chat.ID, otherWorkerID, otherRunnerID)
	recorder := newTaskSideEffectRecorder()
	starter := newTestTaskStarter(t, f, messagepartbuffer.New(messagepartbuffer.Options{}), recorder)
	// The history version is current; only the owner identity is stale.
	err := starter.StartAbandon(testutil.Context(t, testutil.WaitLong), chatWorkerTaskStartInput{
		ChatID:         chat.ID,
		WorkerID:       workerID,
		RunnerID:       runnerID,
		HistoryVersion: latestOwner.HistoryVersion,
		Status:         database.ChatStatusRunning,
	})
	require.NoError(t, err)
	latest, err := f.db.GetChatByID(testutil.Context(t, testutil.WaitShort), chat.ID)
	require.NoError(t, err)
	require.Equal(t, otherWorkerID, latest.WorkerID.UUID)
	require.Equal(t, otherRunnerID, latest.RunnerID.UUID)
	recorder.requireCleanup(t, chat.ID, runnerID)
}
// TestAbandonTask_StaleStatusFenceExits runs the abandon task with a status
// (waiting) that does not match the chat's actual status (interrupting). The
// task must exit with errTaskExpectedExit, leave ownership and status intact,
// and request no cleanup.
func TestAbandonTask_StaleStatusFenceExits(t *testing.T) {
	t.Parallel()
	f := newTaskTestFixture(t)
	chat := f.createRunningChat(t)
	workerID := uuid.New()
	runnerID := uuid.New()
	acquired := f.acquireChat(t, chat.ID, workerID, runnerID)
	f.forceExecutionState(t, chat.ID, database.ChatStatusInterrupting, false, sql.NullTime{})
	recorder := newTaskSideEffectRecorder()
	starter := newTestTaskStarter(t, f, messagepartbuffer.New(messagepartbuffer.Options{}), recorder)
	// Status intentionally differs from the stored interrupting status.
	err := starter.StartAbandon(testutil.Context(t, testutil.WaitLong), chatWorkerTaskStartInput{
		ChatID:         chat.ID,
		WorkerID:       workerID,
		RunnerID:       runnerID,
		HistoryVersion: acquired.HistoryVersion,
		Status:         database.ChatStatusWaiting,
	})
	require.ErrorIs(t, err, errTaskExpectedExit)
	latest, err := f.db.GetChatByID(testutil.Context(t, testutil.WaitShort), chat.ID)
	require.NoError(t, err)
	require.True(t, latest.WorkerID.Valid)
	require.True(t, latest.RunnerID.Valid)
	require.Equal(t, database.ChatStatusInterrupting, latest.Status)
	recorder.requireCleanupCount(t, 0)
}
// TestGenerationTask_RecordRetryState begins a first generation attempt and
// then records a retryable rate-limit error. It checks that:
//   - the decision asks for a retry of attempt 1 with chatretry.Delay(0),
//   - the persisted retry state is valid and stamped with the current snapshot
//     version,
//   - recording the retry does not advance the generation attempt,
//   - no state hints are emitted,
//   - the stored payload mirrors the classified error.
func TestGenerationTask_RecordRetryState(t *testing.T) {
	t.Parallel()
	f := newTaskTestFixture(t)
	chat := f.createRunningChat(t)
	workerID := uuid.New()
	runnerID := uuid.New()
	acquired := f.acquireChat(t, chat.ID, workerID, runnerID)
	recorder := newTaskSideEffectRecorder()
	starter := newTestTaskStarter(t, f, messagepartbuffer.New(messagepartbuffer.Options{}), recorder)

	// Both task calls run under the same owner identity and history version.
	input := chatWorkerTaskStartInput{
		ChatID:         chat.ID,
		WorkerID:       workerID,
		RunnerID:       runnerID,
		HistoryVersion: acquired.HistoryVersion,
		Status:         database.ChatStatusRunning,
	}

	attempt, _, _, closeEpisode, err := starter.beginGenerationAttempt(
		testutil.Context(t, testutil.WaitLong),
		chatstate.NewChatMachine(f.db, f.pubsub, chat.ID, chatstate.Options{}),
		input,
	)
	require.NoError(t, err)
	closeEpisode()
	require.Equal(t, int64(1), attempt)

	before, err := f.db.GetChatByID(testutil.Context(t, testutil.WaitShort), chat.ID)
	require.NoError(t, err)
	require.False(t, before.RetryState.Valid)

	rateLimited := chaterror.ClassifiedError{
		Message:    "OpenAI is rate limiting requests.",
		Kind:       codersdk.ChatErrorKindRateLimit,
		Provider:   "openai",
		Retryable:  true,
		StatusCode: 429,
	}
	decision, err := starter.recordGenerationRetry(
		testutil.Context(t, testutil.WaitLong),
		chatstate.NewChatMachine(f.db, f.pubsub, chat.ID, chatstate.Options{}),
		input,
		rateLimited,
	)
	require.NoError(t, err)
	require.True(t, decision.retry)
	require.Equal(t, int64(1), decision.generationAttempt)
	require.Equal(t, chatretry.Delay(0), decision.delay)

	latest, err := f.db.GetChatByID(testutil.Context(t, testutil.WaitShort), chat.ID)
	require.NoError(t, err)
	require.True(t, latest.RetryState.Valid)
	require.Equal(t, latest.SnapshotVersion, latest.RetryStateVersion)
	require.Greater(t, latest.RetryStateVersion, before.RetryStateVersion)
	require.Equal(t, before.GenerationAttempt, latest.GenerationAttempt)
	recorder.requireStateHintCount(t, 0)

	var retryPayload codersdk.ChatStreamRetry
	require.NoError(t, json.Unmarshal(latest.RetryState.RawMessage, &retryPayload))
	require.Equal(t, 1, retryPayload.Attempt)
	require.Equal(t, chatretry.Delay(0).Milliseconds(), retryPayload.DelayMs)
	require.Equal(t, "OpenAI is rate limiting requests.", retryPayload.Error)
	require.Equal(t, codersdk.ChatErrorKindRateLimit, retryPayload.Kind)
	require.Equal(t, "openai", retryPayload.Provider)
	require.Equal(t, 429, retryPayload.StatusCode)
	require.False(t, retryPayload.RetryingAt.IsZero())
}
// TestGenerationTask_RecordRetryStateUsesDurableGenerationAttempt begins three
// generation attempts in a row on one chat machine and then records a
// retryable timeout. Both the retry decision and the persisted payload must
// report attempt 3 with the chatretry.Delay(2) backoff.
func TestGenerationTask_RecordRetryStateUsesDurableGenerationAttempt(t *testing.T) {
	t.Parallel()
	f := newTaskTestFixture(t)
	chat := f.createRunningChat(t)
	workerID := uuid.New()
	runnerID := uuid.New()
	acquired := f.acquireChat(t, chat.ID, workerID, runnerID)
	starter := newTestTaskStarter(t, f, messagepartbuffer.New(messagepartbuffer.Options{}), newTaskSideEffectRecorder())
	machine := chatstate.NewChatMachine(f.db, f.pubsub, chat.ID, chatstate.Options{})

	input := chatWorkerTaskStartInput{
		ChatID:         chat.ID,
		WorkerID:       workerID,
		RunnerID:       runnerID,
		HistoryVersion: acquired.HistoryVersion,
		Status:         database.ChatStatusRunning,
	}

	for range 3 {
		attempt, _, _, closeEpisode, err := starter.beginGenerationAttempt(
			testutil.Context(t, testutil.WaitLong),
			machine,
			input,
		)
		require.NoError(t, err)
		closeEpisode()
		require.Positive(t, attempt)
	}

	timeout := chaterror.ClassifiedError{
		Message:   "OpenAI is temporarily unavailable.",
		Kind:      codersdk.ChatErrorKindTimeout,
		Provider:  "openai",
		Retryable: true,
	}
	decision, err := starter.recordGenerationRetry(
		testutil.Context(t, testutil.WaitLong),
		machine,
		input,
		timeout,
	)
	require.NoError(t, err)
	require.True(t, decision.retry)
	require.Equal(t, int64(3), decision.generationAttempt)
	require.Equal(t, chatretry.Delay(2), decision.delay)

	latest, err := f.db.GetChatByID(testutil.Context(t, testutil.WaitShort), chat.ID)
	require.NoError(t, err)
	var retryPayload codersdk.ChatStreamRetry
	require.NoError(t, json.Unmarshal(latest.RetryState.RawMessage, &retryPayload))
	require.Equal(t, 3, retryPayload.Attempt)
	require.Equal(t, chatretry.Delay(2).Milliseconds(), retryPayload.DelayMs)
}
// TestGenerationTask_RecordRetryStateClearedByNextAttempt records retry state
// after attempt 1 and then begins attempt 2. Beginning the next attempt must
// clear the retry state and bump RetryStateVersion to the current snapshot
// version.
func TestGenerationTask_RecordRetryStateClearedByNextAttempt(t *testing.T) {
	t.Parallel()
	f := newTaskTestFixture(t)
	chat := f.createRunningChat(t)
	workerID := uuid.New()
	runnerID := uuid.New()
	acquired := f.acquireChat(t, chat.ID, workerID, runnerID)
	starter := newTestTaskStarter(t, f, messagepartbuffer.New(messagepartbuffer.Options{}), newTaskSideEffectRecorder())
	machine := chatstate.NewChatMachine(f.db, f.pubsub, chat.ID, chatstate.Options{})
	input := chatWorkerTaskStartInput{
		ChatID:         chat.ID,
		WorkerID:       workerID,
		RunnerID:       runnerID,
		HistoryVersion: acquired.HistoryVersion,
		Status:         database.ChatStatusRunning,
	}
	attempt, _, _, closeEpisode, err := starter.beginGenerationAttempt(testutil.Context(t, testutil.WaitLong), machine, input)
	require.NoError(t, err)
	closeEpisode()
	require.Equal(t, int64(1), attempt)
	// Record a retryable failure so there is retry state to clear.
	_, err = starter.recordGenerationRetry(
		testutil.Context(t, testutil.WaitLong),
		machine,
		input,
		chaterror.ClassifiedError{
			Message:   "OpenAI is temporarily unavailable.",
			Kind:      codersdk.ChatErrorKindTimeout,
			Provider:  "openai",
			Retryable: true,
		},
	)
	require.NoError(t, err)
	withRetry, err := f.db.GetChatByID(testutil.Context(t, testutil.WaitShort), chat.ID)
	require.NoError(t, err)
	require.True(t, withRetry.RetryState.Valid)
	// The next attempt should reset the retry state.
	attempt, _, _, closeEpisode, err = starter.beginGenerationAttempt(testutil.Context(t, testutil.WaitLong), machine, input)
	require.NoError(t, err)
	closeEpisode()
	require.Equal(t, int64(2), attempt)
	after, err := f.db.GetChatByID(testutil.Context(t, testutil.WaitShort), chat.ID)
	require.NoError(t, err)
	require.False(t, after.RetryState.Valid)
	require.Equal(t, after.SnapshotVersion, after.RetryStateVersion)
	require.Greater(t, after.RetryStateVersion, withRetry.RetryStateVersion)
}
// TestGenerationTask_RecordRetryStateStaleFenceExits begins a generation
// attempt and then lets another worker/runner pair take over the chat. The
// original owner's recordGenerationRetry call must then exit with
// errTaskExpectedExit. No retry state may be persisted, and the new owner must
// be kept.
func TestGenerationTask_RecordRetryStateStaleFenceExits(t *testing.T) {
	t.Parallel()
	f := newTaskTestFixture(t)
	chat := f.createRunningChat(t)
	workerID := uuid.New()
	runnerID := uuid.New()
	acquired := f.acquireChat(t, chat.ID, workerID, runnerID)
	starter := newTestTaskStarter(t, f, messagepartbuffer.New(messagepartbuffer.Options{}), newTaskSideEffectRecorder())
	machine := chatstate.NewChatMachine(f.db, f.pubsub, chat.ID, chatstate.Options{})

	// The original owner's identity, reused after ownership changes hands.
	staleInput := chatWorkerTaskStartInput{
		ChatID:         chat.ID,
		WorkerID:       workerID,
		RunnerID:       runnerID,
		HistoryVersion: acquired.HistoryVersion,
		Status:         database.ChatStatusRunning,
	}

	attempt, _, _, closeEpisode, err := starter.beginGenerationAttempt(
		testutil.Context(t, testutil.WaitLong),
		machine,
		staleInput,
	)
	require.NoError(t, err)
	closeEpisode()
	require.Equal(t, int64(1), attempt)

	otherWorkerID := uuid.New()
	otherRunnerID := uuid.New()
	f.acquireChat(t, chat.ID, otherWorkerID, otherRunnerID)

	timeout := chaterror.ClassifiedError{
		Message:   "OpenAI is temporarily unavailable.",
		Kind:      codersdk.ChatErrorKindTimeout,
		Provider:  "openai",
		Retryable: true,
	}
	_, err = starter.recordGenerationRetry(
		testutil.Context(t, testutil.WaitLong),
		machine,
		staleInput,
		timeout,
	)
	require.ErrorIs(t, err, errTaskExpectedExit)

	latest, err := f.db.GetChatByID(testutil.Context(t, testutil.WaitShort), chat.ID)
	require.NoError(t, err)
	require.False(t, latest.RetryState.Valid)
	require.Equal(t, otherWorkerID, latest.WorkerID.UUID)
	require.Equal(t, otherRunnerID, latest.RunnerID.UUID)
}
// TestRunner_StartsRealInterruptTask exercises the end-to-end path. A real
// task worker (startRealTaskWorker) takes ownership of a running chat, and the
// chat is then interrupted. The worker is expected to eventually bring the chat
// back to running while keeping ownership. A status-change watch event must be
// published.
func TestRunner_StartsRealInterruptTask(t *testing.T) {
	t.Parallel()
	f := newTaskTestFixture(t)
	chat := f.createRunningChat(t)
	worker := startRealTaskWorker(t, f, messagepartbuffer.New(messagepartbuffer.Options{}))
	waitOwnedChat(t, f, chat.ID, worker.chatWorkerID())
	interrupting := f.interruptChat(t, chat.ID)
	require.Equal(t, database.ChatStatusInterrupting, interrupting.Status)
	// Poll until the worker's interrupt task has moved the chat off interrupting.
	testutil.Eventually(testutil.Context(t, testutil.WaitLong), t, func(ctx context.Context) bool {
		latest, err := f.db.GetChatByID(ctx, chat.ID)
		return err == nil && latest.Status == database.ChatStatusRunning
	}, testutil.IntervalFast)
	latest, err := f.db.GetChatByID(testutil.Context(t, testutil.WaitShort), chat.ID)
	require.NoError(t, err)
	require.Equal(t, worker.chatWorkerID(), latest.WorkerID.UUID)
	f.requireWatchEvent(t, chat.ID, codersdk.ChatWatchEventKindStatusChange)
}
// TestRunner_StartsRealRequiresActionTimeoutTask gives a requires-action chat
// an already-expired deadline before a real task worker starts. The chat is
// expected to eventually be running, owned by that worker, and assigned a
// runner. A status-change watch event must be published.
func TestRunner_StartsRealRequiresActionTimeoutTask(t *testing.T) {
	t.Parallel()
	f := newTaskTestFixture(t)
	chat := f.createRequiresActionChat(t)
	f.setRequiresActionDeadline(t, chat.ID, sql.NullTime{Time: time.Now().Add(-time.Minute), Valid: true})
	worker := startRealTaskWorker(t, f, messagepartbuffer.New(messagepartbuffer.Options{}))
	testutil.Eventually(testutil.Context(t, testutil.WaitLong), t, func(ctx context.Context) bool {
		latest, err := f.db.GetChatByID(ctx, chat.ID)
		return err == nil && latest.Status == database.ChatStatusRunning && latest.WorkerID.Valid && latest.WorkerID.UUID == worker.chatWorkerID()
	}, testutil.IntervalFast)
	latest, err := f.db.GetChatByID(testutil.Context(t, testutil.WaitShort), chat.ID)
	require.NoError(t, err)
	require.True(t, latest.RunnerID.Valid)
	f.requireWatchEvent(t, chat.ID, codersdk.ChatWatchEventKindStatusChange)
}
// TestRunner_StartsRealAbandonTask waits until a real task worker owns a chat
// and then forces the chat into the error status. It expects the worker to
// eventually abandon the chat, clearing both worker and runner.
func TestRunner_StartsRealAbandonTask(t *testing.T) {
	t.Parallel()
	f := newTaskTestFixture(t)
	chat := f.createRunningChat(t)
	worker := startRealTaskWorker(t, f, messagepartbuffer.New(messagepartbuffer.Options{}))
	waitOwnedChat(t, f, chat.ID, worker.chatWorkerID())
	updated := f.forceExecutionState(t, chat.ID, database.ChatStatusError, false, sql.NullTime{})
	// forceExecutionState writes through the store directly; the update is
	// published by hand, presumably so the worker notices it. TODO confirm.
	f.publishChatUpdate(t, updated)
	testutil.Eventually(testutil.Context(t, testutil.WaitLong), t, func(ctx context.Context) bool {
		latest, err := f.db.GetChatByID(ctx, chat.ID)
		return err == nil && !latest.WorkerID.Valid && !latest.RunnerID.Valid
	}, testutil.IntervalFast)
}
// taskTestFixture bundles the database handles and seed records shared by the
// chat task tests. Build it with newTaskTestFixture.
type taskTestFixture struct {
	// db is the database.Store used for all reads and writes.
	db database.Store
	// pubsub wraps the database pubsub and records every Publish call.
	pubsub *taskRecordingPubsub
	// sqlDB is the raw *sql.DB returned alongside db by dbtestutil.
	sqlDB *sql.DB
	// user owns every chat the fixture creates.
	user database.User
	// org is the organization the user is a member of.
	org database.Organization
	// model is the default "openai" chat model config.
	model database.ChatModelConfig
}
// newTaskTestFixture provisions a fresh test database and seeds it with:
//   - one user,
//   - one organization that the user belongs to,
//   - an "openai" chat provider pointing at an invalid base URL,
//   - a default model config for that provider.
//
// The database pubsub is wrapped in a recording pubsub.
func newTaskTestFixture(t *testing.T) *taskTestFixture {
	t.Helper()
	store, ps, rawDB := dbtestutil.NewDBWithSQLDB(t)
	owner := dbgen.User(t, store, database.User{})
	organization := dbgen.Organization(t, store, database.Organization{})
	dbgen.OrganizationMember(t, store, database.OrganizationMember{
		UserID:         owner.ID,
		OrganizationID: organization.ID,
	})
	dbgen.ChatProvider(t, store, database.ChatProvider{
		Provider:    "openai",
		DisplayName: "openai",
		BaseUrl:     "http://example.invalid",
	})
	defaultModel := dbgen.ChatModelConfig(t, store, database.ChatModelConfig{
		Provider:  "openai",
		IsDefault: true,
	})
	return &taskTestFixture{
		db:     store,
		pubsub: newTaskRecordingPubsub(ps),
		sqlDB:  rawDB,
		user:   owner,
		org:    organization,
		model:  defaultModel,
	}
}
// createRunningChat creates an API chat titled "test" for the fixture user.
// The chat uses the fixture model and has a single "hello" user message.
// Recorded pubsub events are cleared before the created chat is returned.
func (f *taskTestFixture) createRunningChat(t *testing.T) database.Chat {
	t.Helper()
	ctx := testutil.Context(t, testutil.WaitShort)
	input := chatstate.CreateChatInput{
		OrganizationID:    f.org.ID,
		OwnerID:           f.user.ID,
		LastModelConfigID: f.model.ID,
		Title:             "test",
		ClientType:        database.ChatClientTypeApi,
		InitialMessages: []chatstate.Message{
			taskUserTextMessage(t, "hello", f.user.ID, f.model.ID),
		},
	}
	created, err := chatstate.CreateChat(ctx, f.db, f.pubsub, input)
	require.NoError(t, err)
	f.pubsub.clear()
	return created.Chat
}
// createRequiresActionChat creates a chat that declares one dynamic tool with
// a randomized name. It commits an assistant step containing a call to that
// tool and then moves the chat into requires-action through the chat state
// machine. Recorded pubsub events are cleared, and the chat row is re-read
// before it is returned.
func (f *taskTestFixture) createRequiresActionChat(t *testing.T) database.Chat {
	t.Helper()
	// Unique tool name so tests sharing a database don't collide.
	toolName := "dynamic_" + uuid.NewString()
	dynamicTools, err := json.Marshal([]codersdk.DynamicTool{{
		Name:        toolName,
		Description: "test tool",
		InputSchema: json.RawMessage(`{"type":"object"}`),
	}})
	require.NoError(t, err)
	res, err := chatstate.CreateChat(testutil.Context(t, testutil.WaitShort), f.db, f.pubsub, chatstate.CreateChatInput{
		OrganizationID:    f.org.ID,
		OwnerID:           f.user.ID,
		LastModelConfigID: f.model.ID,
		Title:             "test",
		ClientType:        database.ChatClientTypeApi,
		DynamicTools:      pqtype.NullRawMessage{RawMessage: dynamicTools, Valid: true},
		InitialMessages:   []chatstate.Message{taskUserTextMessage(t, "hello", f.user.ID, f.model.ID)},
	})
	require.NoError(t, err)
	machine := chatstate.NewChatMachine(f.db, f.pubsub, res.Chat.ID, chatstate.Options{})
	// Step 1: the assistant calls the dynamic tool.
	require.NoError(t, machine.Update(testutil.Context(t, testutil.WaitShort), func(tx *chatstate.Tx) error {
		_, err := tx.CommitStep(chatstate.CommitStepInput{Messages: []chatstate.Message{taskAssistantToolCallMessage(t, f.model.ID, toolName)}})
		return err
	}))
	// Step 2: enter requires-action, waiting on that tool call.
	require.NoError(t, machine.Update(testutil.Context(t, testutil.WaitShort), func(tx *chatstate.Tx) error {
		_, err := tx.EnterRequiresAction(chatstate.EnterRequiresActionInput{})
		return err
	}))
	chat, err := f.db.GetChatByID(testutil.Context(t, testutil.WaitShort), res.Chat.ID)
	require.NoError(t, err)
	f.pubsub.clear()
	return chat
}
// acquireChat assigns chatID to the given worker/runner pair via the chat
// state machine's Acquire transition. It clears recorded pubsub events and
// returns the re-read chat row.
func (f *taskTestFixture) acquireChat(t *testing.T, chatID uuid.UUID, workerID uuid.UUID, runnerID uuid.UUID) database.Chat {
	t.Helper()
	acquire := func(tx *chatstate.Tx) error {
		_, err := tx.Acquire(chatstate.AcquireInput{WorkerID: workerID, RunnerID: runnerID})
		return err
	}
	m := chatstate.NewChatMachine(f.db, f.pubsub, chatID, chatstate.Options{})
	updateErr := m.Update(testutil.Context(t, testutil.WaitShort), acquire)
	require.NoError(t, updateErr)

	owned, err := f.db.GetChatByID(testutil.Context(t, testutil.WaitShort), chatID)
	require.NoError(t, err)
	f.pubsub.clear()
	return owned
}
// interruptChat sends an "interrupt" user message with
// BusyBehaviorInterrupt through the chat state machine. It clears recorded
// pubsub events and returns the re-read chat row.
func (f *taskTestFixture) interruptChat(t *testing.T, chatID uuid.UUID) database.Chat {
	t.Helper()
	send := func(tx *chatstate.Tx) error {
		_, err := tx.SendMessage(chatstate.SendMessageInput{
			Message:      taskUserTextMessage(t, "interrupt", f.user.ID, f.model.ID),
			BusyBehavior: chatstate.BusyBehaviorInterrupt,
		})
		return err
	}
	m := chatstate.NewChatMachine(f.db, f.pubsub, chatID, chatstate.Options{})
	updateErr := m.Update(testutil.Context(t, testutil.WaitShort), send)
	require.NoError(t, updateErr)

	interrupted, err := f.db.GetChatByID(testutil.Context(t, testutil.WaitShort), chatID)
	require.NoError(t, err)
	f.pubsub.clear()
	return interrupted
}
// forceExecutionState bypasses the chat state machine. It writes the given
// status, archived flag, and requires-action deadline directly with
// UpdateChatExecutionState, keeping the chat's current worker, runner, and
// last error. The row is first locked with LockChatAndBumpSnapshotVersion, and
// everything runs in one transaction. Recorded pubsub events are cleared
// afterwards. It returns the updated chat row.
func (f *taskTestFixture) forceExecutionState(t *testing.T, chatID uuid.UUID, status database.ChatStatus, archived bool, deadline sql.NullTime) database.Chat {
	t.Helper()
	// One context bounds the whole transaction. Previously each statement
	// created its own WaitShort deadline and cleanup.
	ctx := testutil.Context(t, testutil.WaitShort)
	var updated database.Chat
	require.NoError(t, f.db.InTx(func(store database.Store) error {
		if _, err := store.LockChatAndBumpSnapshotVersion(ctx, chatID); err != nil {
			return err
		}
		// Re-read under the lock so the preserved fields are current.
		chat, err := store.GetChatByID(ctx, chatID)
		if err != nil {
			return err
		}
		updated, err = store.UpdateChatExecutionState(ctx, database.UpdateChatExecutionStateParams{
			ID:                       chat.ID,
			Status:                   status,
			Archived:                 archived,
			WorkerID:                 chat.WorkerID,
			RunnerID:                 chat.RunnerID,
			LastError:                chat.LastError,
			RequiresActionDeadlineAt: deadline,
		})
		return err
	}, nil))
	f.pubsub.clear()
	return updated
}
// setRequiresActionDeadline replaces only the chat's requires-action deadline.
// It reads the current status and archived flag and passes them, unchanged,
// to forceExecutionState.
func (f *taskTestFixture) setRequiresActionDeadline(t *testing.T, chatID uuid.UUID, deadline sql.NullTime) database.Chat {
	t.Helper()
	current, err := f.db.GetChatByID(testutil.Context(t, testutil.WaitShort), chatID)
	require.NoError(t, err)
	status, archived := current.Status, current.Archived
	return f.forceExecutionState(t, chatID, status, archived, deadline)
}
// publishChatUpdate publishes a JSON ChatStateUpdateMessage for chat on its
// ChatStateUpdateChannel. The message mirrors the row's version counters,
// status, and archived flag. Worker and runner IDs are included only when
// they are set.
func (f *taskTestFixture) publishChatUpdate(t *testing.T, chat database.Chat) {
	t.Helper()
	msg := coderdpubsub.ChatStateUpdateMessage{
		SnapshotVersion:   chat.SnapshotVersion,
		HistoryVersion:    chat.HistoryVersion,
		QueueVersion:      chat.QueueVersion,
		RetryStateVersion: chat.RetryStateVersion,
		GenerationAttempt: chat.GenerationAttempt,
		Status:            string(chat.Status),
		Archived:          chat.Archived,
	}
	// Copy each nullable ID into its own local so the pointer does not alias
	// the caller's row.
	if worker := chat.WorkerID; worker.Valid {
		msg.WorkerID = &worker.UUID
	}
	if runner := chat.RunnerID; runner.Valid {
		msg.RunnerID = &runner.UUID
	}
	payload, err := json.Marshal(msg)
	require.NoError(t, err)
	channel := coderdpubsub.ChatStateUpdateChannel(chat.ID)
	require.NoError(t, f.pubsub.Publish(channel, payload))
}
// requireWatchEvent fails the test unless the recorded pubsub traffic contains
// a watch event of the given kind for chatID.
func (f *taskTestFixture) requireWatchEvent(t *testing.T, chatID uuid.UUID, kind codersdk.ChatWatchEventKind) {
	t.Helper()
	events := f.pubsub.watchEvents(t)
	found := false
	for _, ev := range events {
		// Check Kind first, as before, so Chat is only inspected for
		// matching kinds.
		if ev.Kind == kind && ev.Chat.ID == chatID {
			found = true
			break
		}
	}
	if !found {
		t.Fatalf("missing watch event kind=%s chat_id=%s events=%v", kind, chatID, events)
	}
}
// requireNoWatchEvents fails the test if any watch event was recorded since
// the pubsub recorder was last cleared.
func (f *taskTestFixture) requireNoWatchEvents(t *testing.T) {
	t.Helper()
	got := f.pubsub.watchEvents(t)
	require.Empty(t, got)
}
// taskUserTextMessage builds a user-role chatstate.Message whose content is a
// single text part containing text. The message is visible to both sides, uses
// the current content version, is attributed to createdBy, and is tagged with
// modelConfigID.
func taskUserTextMessage(t *testing.T, text string, createdBy uuid.UUID, modelConfigID uuid.UUID) chatstate.Message {
	t.Helper()
	parts := []codersdk.ChatMessagePart{codersdk.ChatMessageText(text)}
	content, err := chatprompt.MarshalParts(parts)
	require.NoError(t, err)

	creator := uuid.NullUUID{UUID: createdBy, Valid: true}
	model := uuid.NullUUID{UUID: modelConfigID, Valid: true}
	return chatstate.Message{
		Role:           database.ChatMessageRoleUser,
		Visibility:     database.ChatMessageVisibilityBoth,
		ContentVersion: chatprompt.CurrentContentVersion,
		Content:        content,
		CreatedBy:      creator,
		ModelConfigID:  model,
	}
}
// taskAssistantToolCallMessage builds an assistant-role chatstate.Message whose
// content is one tool-call part for toolName. The part has empty JSON arguments
// and a freshly generated "call_"-prefixed tool call ID. The message is visible
// to both sides, uses the current content version, and is tagged with
// modelConfigID.
func taskAssistantToolCallMessage(t *testing.T, modelConfigID uuid.UUID, toolName string) chatstate.Message {
	t.Helper()
	call := codersdk.ChatMessagePart{
		Type:       codersdk.ChatMessagePartTypeToolCall,
		ToolCallID: "call_" + uuid.NewString(),
		ToolName:   toolName,
		Args:       json.RawMessage(`{}`),
	}
	content, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{call})
	require.NoError(t, err)

	model := uuid.NullUUID{UUID: modelConfigID, Valid: true}
	return chatstate.Message{
		Role:           database.ChatMessageRoleAssistant,
		Visibility:     database.ChatMessageVisibilityBoth,
		ContentVersion: chatprompt.CurrentContentVersion,
		Content:        content,
		ModelConfigID:  model,
	}
}
// taskPublishedEvent is a single Publish call captured by taskRecordingPubsub.
type taskPublishedEvent struct {
	// channel is the pubsub channel the payload was published to.
	channel string
	// payload is a copy of the published bytes, so later mutation of the
	// caller's slice does not affect the recording.
	payload []byte
}
// taskRecordingPubsub wraps a dbpubsub.Pubsub. It records every Publish call
// and forwards publishes and subscriptions to the wrapped implementation, so
// tests can assert on what was published.
type taskRecordingPubsub struct {
	// inner receives every forwarded Publish/SubscribeWithErr call.
	inner dbpubsub.Pubsub
	// mu guards sent.
	mu    sync.Mutex
	// sent holds recorded publishes in call order; clear resets it to nil.
	sent  []taskPublishedEvent
}
// newTaskRecordingPubsub returns a recording wrapper around inner with an empty
// publish log.
func newTaskRecordingPubsub(inner dbpubsub.Pubsub) *taskRecordingPubsub {
	rec := &taskRecordingPubsub{}
	rec.inner = inner
	return rec
}
// Publish records a copy of payload under channel, then forwards the call to
// the wrapped pubsub and returns its error. The event is recorded even if the
// forwarded publish fails.
func (p *taskRecordingPubsub) Publish(channel string, payload []byte) error {
	// Copy outside the lock; only the append to the shared log needs mu.
	recorded := taskPublishedEvent{
		channel: channel,
		payload: append([]byte(nil), payload...),
	}
	p.mu.Lock()
	p.sent = append(p.sent, recorded)
	p.mu.Unlock()
	return p.inner.Publish(channel, payload)
}
// SubscribeWithErr forwards the subscription to the wrapped pubsub unchanged.
// Subscriptions are not recorded.
func (p *taskRecordingPubsub) SubscribeWithErr(channel string, listener dbpubsub.ListenerWithErr) (func(), error) {
	cancel, err := p.inner.SubscribeWithErr(channel, listener)
	return cancel, err
}
// clear discards every recorded publish.
func (p *taskRecordingPubsub) clear() {
	p.mu.Lock()
	defer p.mu.Unlock()
	p.sent = nil
}
// events returns a copy of the recorded publishes, taken under the lock, so
// callers can iterate without racing concurrent Publish calls. It returns nil
// when nothing has been recorded.
func (p *taskRecordingPubsub) events() []taskPublishedEvent {
	p.mu.Lock()
	snapshot := append([]taskPublishedEvent(nil), p.sent...)
	p.mu.Unlock()
	return snapshot
}
// watchEvents decodes each recorded payload as a codersdk.ChatWatchEvent. It
// keeps only events that were published on the watch channel of the decoded
// chat's owner. Payloads that fail to decode are skipped. The result is never
// nil.
func (p *taskRecordingPubsub) watchEvents(t *testing.T) []codersdk.ChatWatchEvent {
	t.Helper()
	decoded := make([]codersdk.ChatWatchEvent, 0)
	for _, recorded := range p.events() {
		var ev codersdk.ChatWatchEvent
		if json.Unmarshal(recorded.payload, &ev) != nil {
			continue
		}
		// Other publishes may decode successfully, but they are only kept
		// if they were sent on the owner's watch channel.
		if recorded.channel == coderdpubsub.ChatWatchEventChannel(ev.Chat.OwnerID) {
			decoded = append(decoded, ev)
		}
	}
	return decoded
}
// startRealTaskWorker constructs a chatWorker backed by the fixture's store and
// recording pubsub, then starts it. A cleanup that closes it is registered only
// after Start succeeds. The acquisition, runner-sync, heartbeat and
// heartbeat-cleanup intervals are all one hour, presumably so those periodic
// loops stay quiet during a test. Task retry backoff is one millisecond.
func startRealTaskWorker(t *testing.T, f *taskTestFixture, buffer *messagepartbuffer.Buffer) *chatWorker {
	t.Helper()
	const idle = time.Hour
	opts := chatWorkerOptions{
		WorkerID:                   uuid.New(),
		Store:                      f.db,
		Pubsub:                     f.pubsub,
		Logger:                     slog.Make(),
		MessagePartBuffer:          buffer,
		AcquisitionInterval:        idle,
		AcquisitionBatchSize:       10,
		RunnerSyncInterval:         idle,
		HeartbeatInterval:          idle,
		HeartbeatCleanupInterval:   idle,
		HeartbeatStaleSeconds:      30,
		StateChannelSize:           16,
		RunnerManagerChannelSize:   16,
		AcquisitionWakeChannelSize: 1,
		TaskRetryInitialBackoff:    time.Millisecond,
		TaskRetryMaxBackoff:        time.Millisecond,
	}
	worker, err := newChatWorker(nil, opts)
	require.NoError(t, err)
	require.NoError(t, worker.Start(context.Background()))
	t.Cleanup(func() {
		require.NoError(t, worker.Close())
	})
	return worker
}
// waitOwnedChat polls the store at testutil.IntervalFast, within a
// testutil.WaitLong context, until chatID is assigned to workerID and has a
// valid runner ID. Lookup errors are treated as "not yet" and retried. It
// returns the most recently read chat row.
func waitOwnedChat(t *testing.T, f *taskTestFixture, chatID uuid.UUID, workerID uuid.UUID) database.Chat {
	t.Helper()
	waitCtx := testutil.Context(t, testutil.WaitLong)
	var owned database.Chat
	testutil.Eventually(waitCtx, t, func(ctx context.Context) bool {
		chat, err := f.db.GetChatByID(ctx, chatID)
		if err != nil {
			return false
		}
		owned = chat
		if !chat.WorkerID.Valid || chat.WorkerID.UUID != workerID {
			return false
		}
		return chat.RunnerID.Valid
	}, testutil.IntervalFast)
	return owned
}
// taskSideEffectRecorder captures the callbacks a taskStarter invokes so tests
// can assert on them. Those callbacks are state hints, cleanup requests and
// interruption outcomes. All slices are guarded by mu.
type taskSideEffectRecorder struct {
	mu         sync.Mutex
	// hints collects every routeStateHint call in call order.
	hints      []runnerStateUpdate
	// cleanups collects every requestCleanup call in call order.
	cleanups   []runnerKey
	// interrupts collects every afterInterruptionOutcome call in call order.
	interrupts []interruptionOutcome
}
// newTaskSideEffectRecorder returns a recorder with nothing recorded yet.
func newTaskSideEffectRecorder() *taskSideEffectRecorder {
	return new(taskSideEffectRecorder)
}
// routeStateHint appends update to the recorded hints. The context is unused;
// the signature matches the state-hint callback passed to newTaskStarter.
func (r *taskSideEffectRecorder) routeStateHint(_ context.Context, update runnerStateUpdate) {
	r.mu.Lock()
	defer r.mu.Unlock()
	r.hints = append(r.hints, update)
}
// requestCleanup appends key to the recorded cleanup requests. The context is
// unused; the signature matches the cleanup callback passed to newTaskStarter.
func (r *taskSideEffectRecorder) requestCleanup(_ context.Context, runner runnerKey) {
	r.mu.Lock()
	defer r.mu.Unlock()
	r.cleanups = append(r.cleanups, runner)
}
// afterInterruptionOutcome appends result to the recorded interruption
// outcomes and always returns nil. The context is unused.
func (r *taskSideEffectRecorder) afterInterruptionOutcome(_ context.Context, result interruptionOutcome) error {
	r.mu.Lock()
	defer r.mu.Unlock()
	r.interrupts = append(r.interrupts, result)
	return nil
}
// requireStateHint fails the test unless some recorded hint matches chatID,
// snapshot and status exactly. The lock is held for the whole check, including
// the failure message that lists every recorded hint.
func (r *taskSideEffectRecorder) requireStateHint(t *testing.T, chatID uuid.UUID, snapshot int64, status database.ChatStatus) {
	t.Helper()
	matches := func(h runnerStateUpdate) bool {
		return h.ChatID == chatID && h.SnapshotVersion == snapshot && h.Status == status
	}

	r.mu.Lock()
	defer r.mu.Unlock()
	for _, h := range r.hints {
		if matches(h) {
			return
		}
	}
	t.Fatalf("missing state hint chat_id=%s snapshot=%d status=%s hints=%v", chatID, snapshot, status, r.hints)
}
// requireStateHintCount fails the test unless exactly count state hints have
// been recorded.
func (r *taskSideEffectRecorder) requireStateHintCount(t *testing.T, count int) {
	t.Helper()
	r.mu.Lock()
	defer r.mu.Unlock()
	recorded := r.hints
	require.Len(t, recorded, count)
}
// requireCleanup fails the test unless a cleanup was requested for the runner
// identified by chatID and runnerID.
func (r *taskSideEffectRecorder) requireCleanup(t *testing.T, chatID uuid.UUID, runnerID uuid.UUID) {
	t.Helper()
	r.mu.Lock()
	defer r.mu.Unlock()
	found := false
	for i := range r.cleanups {
		if r.cleanups[i].ChatID == chatID && r.cleanups[i].RunnerID == runnerID {
			found = true
			break
		}
	}
	if !found {
		t.Fatalf("missing cleanup chat_id=%s runner_id=%s cleanups=%v", chatID, runnerID, r.cleanups)
	}
}
// requireCleanupCount fails the test unless exactly count cleanup requests
// have been recorded.
func (r *taskSideEffectRecorder) requireCleanupCount(t *testing.T, count int) {
	t.Helper()
	r.mu.Lock()
	defer r.mu.Unlock()
	recorded := r.cleanups
	require.Len(t, recorded, count)
}
// requireInterruptionOutcome fails the test unless some recorded interruption
// outcome refers to chatID and carries a chat in the given status.
func (r *taskSideEffectRecorder) requireInterruptionOutcome(t *testing.T, chatID uuid.UUID, status database.ChatStatus) {
	t.Helper()
	r.mu.Lock()
	defer r.mu.Unlock()
	for _, result := range r.interrupts {
		if result.Chat.ID != chatID || result.Chat.Status != status {
			continue
		}
		return
	}
	t.Fatalf("missing interruption outcome chat_id=%s status=%s outcomes=%v", chatID, status, r.interrupts)
}
// newTestTaskStarter builds a taskStarter over the fixture's store and
// recording pubsub, using a real clock and one-millisecond retry backoff.
// State hints, cleanup requests and interruption outcomes are all routed into
// recorder.
func newTestTaskStarter(t *testing.T, f *taskTestFixture, buffer *messagepartbuffer.Buffer, recorder *taskSideEffectRecorder) *taskStarter {
	t.Helper()
	opts := chatWorkerOptions{
		Store:                   f.db,
		Pubsub:                  f.pubsub,
		Logger:                  slog.Make(),
		Clock:                   quartz.NewReal(),
		MessagePartBuffer:       buffer,
		TaskRetryInitialBackoff: time.Millisecond,
		TaskRetryMaxBackoff:     time.Millisecond,
	}
	starter, err := newTaskStarter(nil, opts, recorder.routeStateHint, recorder.requestCleanup)
	require.NoError(t, err)
	// The interruption hook is a field rather than a constructor argument,
	// so it is wired in after construction.
	starter.afterInterruptionOutcome = recorder.afterInterruptionOutcome
	return starter
}