mirror of
https://github.com/coder/coder.git
synced 2026-06-05 05:58:20 +00:00
322 lines
7.8 KiB
Go
322 lines
7.8 KiB
Go
package chatd
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"sync"
|
|
|
|
"github.com/google/uuid"
|
|
"golang.org/x/xerrors"
|
|
|
|
"github.com/coder/coder/v2/coderd/database"
|
|
coderdpubsub "github.com/coder/coder/v2/coderd/pubsub"
|
|
)
|
|
|
|
// taskKind names the category of work a runner task performs.
type taskKind string

// The task kinds a runner can spawn. runTask dispatches on these values to
// pick the matching TaskStarter method.
const (
	taskKindGeneration            = taskKind("generation")
	taskKindInterrupt             = taskKind("interrupt")
	taskKindRequiresActionTimeout = taskKind("requires_action_timeout")
	taskKindAbandon               = taskKind("abandon")
)
|
|
|
|
// taskInstanceID identifies a single spawned task instance. Its underlying
// type is uuid.UUID.
type taskInstanceID uuid.UUID
|
|
|
|
type localWorkKey struct {
|
|
historyVersion int64
|
|
status database.ChatStatus
|
|
}
|
|
|
|
type taskIndexKey struct {
|
|
kind taskKind
|
|
key localWorkKey
|
|
}
|
|
|
|
type taskRecord struct {
|
|
id taskInstanceID
|
|
kind taskKind
|
|
localKey localWorkKey
|
|
cancel context.CancelFunc
|
|
done <-chan struct{}
|
|
}
|
|
|
|
type runner struct {
|
|
ctx context.Context
|
|
mgr *runnerManager
|
|
rec *runnerRecord
|
|
opts chatWorkerOptions
|
|
|
|
lastSnapshotVersion int64
|
|
hasAcceptedState bool
|
|
latestState runnerStateUpdate
|
|
|
|
activeTaskID taskInstanceID
|
|
activeTaskSet bool
|
|
tasks map[taskInstanceID]*taskRecord
|
|
tasksByIndex map[taskIndexKey]taskInstanceID
|
|
localLocks *localLockSet
|
|
}
|
|
|
|
func newRunner(ctx context.Context, mgr *runnerManager, rec *runnerRecord, opts chatWorkerOptions) *runner {
|
|
return &runner{
|
|
ctx: ctx,
|
|
mgr: mgr,
|
|
rec: rec,
|
|
opts: opts,
|
|
tasks: make(map[taskInstanceID]*taskRecord),
|
|
tasksByIndex: make(map[taskIndexKey]taskInstanceID),
|
|
localLocks: newLocalLockSet(),
|
|
}
|
|
}
|
|
|
|
func (r *runner) run() {
|
|
if !r.bootstrap() {
|
|
return
|
|
}
|
|
for {
|
|
select {
|
|
case state := <-r.rec.stateCh:
|
|
r.processState(state)
|
|
case <-r.ctx.Done():
|
|
r.cancelActiveTask()
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
func (r *runner) bootstrap() bool {
|
|
channel := coderdpubsub.ChatStateUpdateChannel(r.rec.key.ChatID)
|
|
unsubscribe, err := r.opts.Pubsub.SubscribeWithErr(channel, coderdpubsub.HandleChatStateUpdate(
|
|
func(ctx context.Context, payload coderdpubsub.ChatStateUpdateMessage, err error) {
|
|
if err != nil {
|
|
r.opts.Logger.Warn(ctx, "chatworker state update decode failed", slogError(err))
|
|
return
|
|
}
|
|
r.mgr.RouteStateHint(ctx, stateUpdateFromPubsub(r.rec.key.ChatID, payload))
|
|
},
|
|
))
|
|
if err != nil {
|
|
r.mgr.requestCleanup(r.ctx, r.rec.key)
|
|
return false
|
|
}
|
|
if !r.rec.setUnsubscribe(unsubscribe) {
|
|
return false
|
|
}
|
|
chat, err := r.opts.Store.GetChatByID(r.ctx, r.rec.key.ChatID)
|
|
if err != nil {
|
|
r.opts.Logger.Warn(r.ctx, "chatworker runner bootstrap failed", slogError(err))
|
|
r.mgr.requestCleanup(r.ctx, r.rec.key)
|
|
return false
|
|
}
|
|
r.mgr.RouteStateHint(r.ctx, stateUpdateFromChat(chat))
|
|
return true
|
|
}
|
|
|
|
func stateUpdateFromPubsub(chatID uuid.UUID, payload coderdpubsub.ChatStateUpdateMessage) runnerStateUpdate {
|
|
return runnerStateUpdate{
|
|
ChatID: chatID,
|
|
WorkerID: payload.WorkerID,
|
|
RunnerID: payload.RunnerID,
|
|
SnapshotVersion: payload.SnapshotVersion,
|
|
HistoryVersion: payload.HistoryVersion,
|
|
QueueVersion: payload.QueueVersion,
|
|
GenerationAttempt: payload.GenerationAttempt,
|
|
Status: database.ChatStatus(payload.Status),
|
|
Archived: payload.Archived,
|
|
}
|
|
}
|
|
|
|
func (r *runner) processState(state runnerStateUpdate) {
|
|
if state.SnapshotVersion <= r.lastSnapshotVersion {
|
|
return
|
|
}
|
|
|
|
r.removeFinishedTasks()
|
|
|
|
if !uuidPtrEqual(state.WorkerID, r.rec.workerID) || !uuidPtrEqual(state.RunnerID, r.rec.key.RunnerID) {
|
|
r.acceptState(state)
|
|
r.mgr.requestCleanup(r.ctx, r.rec.key)
|
|
return
|
|
}
|
|
|
|
changed := !r.hasAcceptedState ||
|
|
r.latestState.HistoryVersion != state.HistoryVersion ||
|
|
r.latestState.Status != state.Status ||
|
|
r.latestState.Archived != state.Archived
|
|
if !changed {
|
|
r.acceptState(state)
|
|
return
|
|
}
|
|
if r.hasAcceptedState && r.activeTaskSet {
|
|
r.cancelActiveTask()
|
|
}
|
|
|
|
r.spawnForState(state)
|
|
r.acceptState(state)
|
|
}
|
|
|
|
func (r *runner) acceptState(state runnerStateUpdate) {
|
|
r.hasAcceptedState = true
|
|
r.latestState = state
|
|
r.lastSnapshotVersion = state.SnapshotVersion
|
|
}
|
|
|
|
func (r *runner) spawnForState(state runnerStateUpdate) {
|
|
if state.Archived {
|
|
r.spawnTaskIfNeeded(taskKindAbandon, state)
|
|
return
|
|
}
|
|
switch state.Status {
|
|
case database.ChatStatusRunning:
|
|
r.spawnTaskIfNeeded(taskKindGeneration, state)
|
|
case database.ChatStatusInterrupting:
|
|
r.spawnTaskIfNeeded(taskKindInterrupt, state)
|
|
case database.ChatStatusRequiresAction:
|
|
r.spawnTaskIfNeeded(taskKindRequiresActionTimeout, state)
|
|
case database.ChatStatusWaiting, database.ChatStatusError:
|
|
r.spawnTaskIfNeeded(taskKindAbandon, state)
|
|
default:
|
|
r.spawnTaskIfNeeded(taskKindAbandon, state)
|
|
}
|
|
}
|
|
|
|
func (r *runner) spawnTaskIfNeeded(kind taskKind, state runnerStateUpdate) {
|
|
key := localWorkKey{historyVersion: state.HistoryVersion, status: state.Status}
|
|
idx := taskIndexKey{kind: kind, key: key}
|
|
if r.activeTaskSet && r.tasksByIndex[idx] == r.activeTaskID {
|
|
return
|
|
}
|
|
|
|
id := taskInstanceID(uuid.New())
|
|
taskCtx, cancel := context.WithCancel(r.ctx)
|
|
done := make(chan struct{})
|
|
record := &taskRecord{
|
|
id: id,
|
|
kind: kind,
|
|
localKey: key,
|
|
cancel: cancel,
|
|
done: done,
|
|
}
|
|
r.tasks[id] = record
|
|
r.tasksByIndex[idx] = id
|
|
r.activeTaskID = id
|
|
r.activeTaskSet = true
|
|
|
|
input := chatWorkerTaskStartInput{
|
|
TaskID: uuid.UUID(id),
|
|
ChatID: r.rec.key.ChatID,
|
|
WorkerID: r.rec.workerID,
|
|
RunnerID: r.rec.key.RunnerID,
|
|
HistoryVersion: state.HistoryVersion,
|
|
GenerationAttempt: state.GenerationAttempt,
|
|
Status: state.Status,
|
|
RequiresActionDeadlineAt: state.RequiresActionDeadlineAt,
|
|
}
|
|
go r.runTask(taskCtx, kind, key, input, done)
|
|
}
|
|
|
|
func (r *runner) runTask(
|
|
ctx context.Context,
|
|
kind taskKind,
|
|
key localWorkKey,
|
|
input chatWorkerTaskStartInput,
|
|
done chan<- struct{},
|
|
) {
|
|
defer close(done)
|
|
err := runTaskWithRetry(ctx, r.opts.retryOptions(), kind, func(ctx context.Context) error {
|
|
unlock, ok := r.localLocks.acquire(ctx, key)
|
|
if !ok {
|
|
return errTaskExpectedExit
|
|
}
|
|
defer unlock()
|
|
if ctx.Err() != nil {
|
|
return errTaskExpectedExit
|
|
}
|
|
|
|
switch kind {
|
|
case taskKindGeneration:
|
|
return r.opts.TaskStarter.StartGeneration(ctx, input)
|
|
case taskKindInterrupt:
|
|
return r.opts.TaskStarter.StartInterrupt(ctx, input)
|
|
case taskKindRequiresActionTimeout:
|
|
return r.opts.TaskStarter.StartRequiresActionTimeout(ctx, input)
|
|
case taskKindAbandon:
|
|
return r.opts.TaskStarter.StartAbandon(ctx, input)
|
|
default:
|
|
return errors.Join(errTaskExpectedExit, xerrors.Errorf("unknown task kind %q", kind))
|
|
}
|
|
})
|
|
if err != nil && ctx.Err() == nil {
|
|
r.opts.Logger.Warn(ctx, "chatworker task failed", slogError(err))
|
|
}
|
|
}
|
|
|
|
func (r *runner) cancelActiveTask() {
|
|
if !r.activeTaskSet {
|
|
return
|
|
}
|
|
id := r.activeTaskID
|
|
r.activeTaskSet = false
|
|
if record := r.tasks[id]; record != nil {
|
|
record.cancel()
|
|
}
|
|
}
|
|
|
|
func (r *runner) removeFinishedTasks() {
|
|
for id, record := range r.tasks {
|
|
select {
|
|
case <-record.done:
|
|
delete(r.tasks, id)
|
|
idx := taskIndexKey{kind: record.kind, key: record.localKey}
|
|
if r.tasksByIndex[idx] == id {
|
|
delete(r.tasksByIndex, idx)
|
|
}
|
|
if r.activeTaskSet && r.activeTaskID == id {
|
|
r.activeTaskSet = false
|
|
}
|
|
default:
|
|
}
|
|
}
|
|
}
|
|
|
|
func uuidPtrEqual(got *uuid.UUID, want uuid.UUID) bool {
|
|
return got != nil && *got == want
|
|
}
|
|
|
|
type localLockSet struct {
|
|
mu sync.Mutex
|
|
locked map[localWorkKey]chan struct{}
|
|
}
|
|
|
|
func newLocalLockSet() *localLockSet {
|
|
return &localLockSet{locked: make(map[localWorkKey]chan struct{})}
|
|
}
|
|
|
|
func (l *localLockSet) acquire(ctx context.Context, key localWorkKey) (func(), bool) {
|
|
for {
|
|
l.mu.Lock()
|
|
wait, ok := l.locked[key]
|
|
if !ok {
|
|
released := make(chan struct{})
|
|
l.locked[key] = released
|
|
l.mu.Unlock()
|
|
return func() {
|
|
l.mu.Lock()
|
|
if l.locked[key] == released {
|
|
delete(l.locked, key)
|
|
close(released)
|
|
}
|
|
l.mu.Unlock()
|
|
}, true
|
|
}
|
|
l.mu.Unlock()
|
|
|
|
select {
|
|
case <-wait:
|
|
case <-ctx.Done():
|
|
return nil, false
|
|
}
|
|
}
|
|
}
|