mirror of
https://github.com/coder/coder.git
synced 2026-06-05 14:08:20 +00:00
059ed7ab5c
## Problem

Flaky test: `TestCloseDuringShutdownContextCanceledShouldRetryOnNewReplica` (coder/internal#1371)

The test intermittently fails because the chat ends up in `waiting` status instead of `pending` after server shutdown.

## Root Cause

There is a race condition in `processChat` where `runChat` completes successfully just as the server context is being canceled during `Close()`. The sequence:

1. Server calls `Close()`, canceling the server context.
2. The LLM HTTP response has already been fully written by the mock server (the stream closes normally before context cancellation propagates to the HTTP client).
3. `runChat` returns `nil` (success) instead of `context.Canceled`.
4. The existing `isShutdownCancellation` check only runs when `runChat` returns an error, so the shutdown is not detected.
5. `processChat`'s deferred cleanup marks the chat as `waiting` instead of `pending`.
6. The test's assertion that the chat is `pending` never becomes true.

This race is timing-dependent — it only triggers when the mock server's HTTP response completes in the narrow window between context cancellation being initiated and it propagating through the HTTP transport layer.

## Fix

Add a server context check after `runChat` returns successfully. If the server is shutting down (`ctx.Err() != nil`), override the status to `pending` so another replica can pick up the chat. This is the same pattern already used for the error path (`isShutdownCancellation`), extended to cover the success path.
2631 lines
76 KiB
Go
2631 lines
76 KiB
Go
package chatd
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"net/http"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"charm.land/fantasy"
|
|
"github.com/google/uuid"
|
|
"github.com/sqlc-dev/pqtype"
|
|
"golang.org/x/xerrors"
|
|
|
|
"cdr.dev/slog/v3"
|
|
"github.com/coder/coder/v2/coderd/chatd/chatloop"
|
|
"github.com/coder/coder/v2/coderd/chatd/chatprompt"
|
|
"github.com/coder/coder/v2/coderd/chatd/chatprovider"
|
|
"github.com/coder/coder/v2/coderd/chatd/chattool"
|
|
"github.com/coder/coder/v2/coderd/database"
|
|
"github.com/coder/coder/v2/coderd/database/db2sdk"
|
|
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
|
"github.com/coder/coder/v2/coderd/database/pubsub"
|
|
coderdpubsub "github.com/coder/coder/v2/coderd/pubsub"
|
|
"github.com/coder/coder/v2/coderd/webpush"
|
|
"github.com/coder/coder/v2/codersdk"
|
|
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
|
)
|
|
|
|
const (
|
|
// DefaultPendingChatAcquireInterval is the default time between attempts to
|
|
// acquire pending chats.
|
|
DefaultPendingChatAcquireInterval = time.Second
|
|
// DefaultInFlightChatStaleAfter is the default age after which a running
|
|
// chat is considered stale and should be recovered.
|
|
DefaultInFlightChatStaleAfter = 5 * time.Minute
|
|
|
|
homeInstructionLookupTimeout = 5 * time.Second
|
|
instructionCacheTTL = 5 * time.Minute
|
|
chatHeartbeatInterval = 30 * time.Second
|
|
maxChatSteps = 1200
|
|
|
|
// staleRecoveryIntervalDivisor determines how often the stale
|
|
// recovery loop runs relative to the stale threshold. A value
|
|
// of 5 means recovery runs at 1/5 of the stale-after duration.
|
|
staleRecoveryIntervalDivisor = 5
|
|
|
|
defaultSubagentInstruction = "You are running as a delegated sub-agent chat. Complete the delegated task and provide clear, concise assistant responses for the parent agent."
|
|
)
|
|
|
|
// Server handles background processing of pending chats.
|
|
type Server struct {
|
|
cancel context.CancelFunc
|
|
closed chan struct{}
|
|
inflight sync.WaitGroup
|
|
|
|
db database.Store
|
|
workerID uuid.UUID
|
|
logger slog.Logger
|
|
|
|
remotePartsProvider RemotePartsProvider
|
|
|
|
agentConnFn AgentConnFunc
|
|
createWorkspaceFn chattool.CreateWorkspaceFn
|
|
pubsub pubsub.Pubsub
|
|
webpushDispatcher webpush.Dispatcher
|
|
providerAPIKeys chatprovider.ProviderAPIKeys
|
|
|
|
// streamMu guards chatStreams which tracks in-flight chat
|
|
// stream state for broadcasting ephemeral events.
|
|
streamMu sync.Mutex
|
|
chatStreams map[uuid.UUID]*chatStreamState
|
|
|
|
// instructionCache caches home instruction file contents by
|
|
// workspace agent ID so we don't re-dial on every chat turn.
|
|
instructionCacheMu sync.Mutex
|
|
instructionCache map[uuid.UUID]cachedInstruction
|
|
|
|
// Configuration
|
|
pendingChatAcquireInterval time.Duration
|
|
inFlightChatStaleAfter time.Duration
|
|
}
|
|
|
|
// cachedInstruction is one entry of Server.instructionCache.
type cachedInstruction struct {
	// instruction is the cached instruction text.
	instruction string
	// fetchedAt records when the entry was stored; presumably compared
	// against instructionCacheTTL by the cache reader (not shown here).
	fetchedAt time.Time
}
|
|
|
|
// AgentConnFunc provides access to workspace agent connections.
|
|
type AgentConnFunc func(ctx context.Context, agentID uuid.UUID) (workspacesdk.AgentConn, func(), error)
|
|
|
|
// ReplicaAddressResolver maps a replica ID to its relay address.
|
|
type ReplicaAddressResolver func(context.Context, uuid.UUID) (string, bool)
|
|
|
|
// RemotePartsProvider returns a snapshot and live stream of message_part
|
|
// events from the replica that is running the chat. Called when the chat
|
|
// is actively running on a different replica. Nil in AGPL single-replica
|
|
// deployments.
|
|
type RemotePartsProvider func(
|
|
ctx context.Context,
|
|
chatID uuid.UUID,
|
|
workerID uuid.UUID,
|
|
requestHeader http.Header,
|
|
) (
|
|
snapshot []codersdk.ChatStreamEvent,
|
|
parts <-chan codersdk.ChatStreamEvent,
|
|
cancel func(),
|
|
err error,
|
|
)
|
|
|
|
type chatStreamState struct {
|
|
buffer []codersdk.ChatStreamEvent
|
|
buffering bool
|
|
subscribers map[uuid.UUID]chan codersdk.ChatStreamEvent
|
|
}
|
|
|
|
// MaxQueueSize caps how many user messages may be queued for a single
// chat; SendMessage returns ErrMessageQueueFull once the cap is reached.
const MaxQueueSize = 20
|
|
|
|
var (
|
|
// ErrMessageQueueFull indicates the per-chat queue limit was reached.
|
|
ErrMessageQueueFull = xerrors.New("chat message queue is full")
|
|
// ErrEditedMessageNotFound indicates the edited message does not exist
|
|
// in the target chat.
|
|
ErrEditedMessageNotFound = xerrors.New("edited message not found")
|
|
// ErrEditedMessageNotUser indicates a non-user message edit attempt.
|
|
ErrEditedMessageNotUser = xerrors.New("only user messages can be edited")
|
|
)
|
|
|
|
// CreateOptions controls chat creation in the shared chat mutation path.
|
|
type CreateOptions struct {
|
|
OwnerID uuid.UUID
|
|
WorkspaceID uuid.NullUUID
|
|
ParentChatID uuid.NullUUID
|
|
RootChatID uuid.NullUUID
|
|
Title string
|
|
ModelConfigID uuid.UUID
|
|
SystemPrompt string
|
|
InitialUserContent []fantasy.Content
|
|
}
|
|
|
|
// SendMessageBusyBehavior selects how SendMessage treats a chat that is
// already active.
type SendMessageBusyBehavior string

// Supported SendMessageBusyBehavior values.
const (
	// SendMessageBusyBehaviorQueue queues user messages while the chat
	// is busy.
	SendMessageBusyBehaviorQueue SendMessageBusyBehavior = "queue"
	// SendMessageBusyBehaviorInterrupt inserts the message immediately
	// and transitions the chat to pending, which interrupts the active
	// run.
	SendMessageBusyBehaviorInterrupt SendMessageBusyBehavior = "interrupt"
)
|
|
|
|
// SendMessageOptions controls user message insertion with busy-state behavior.
|
|
type SendMessageOptions struct {
|
|
ChatID uuid.UUID
|
|
Content []fantasy.Content
|
|
ModelConfigID *uuid.UUID
|
|
BusyBehavior SendMessageBusyBehavior
|
|
}
|
|
|
|
// SendMessageResult contains the outcome of user message processing.
|
|
type SendMessageResult struct {
|
|
Queued bool
|
|
QueuedMessage *database.ChatQueuedMessage
|
|
Message database.ChatMessage
|
|
Chat database.Chat
|
|
}
|
|
|
|
// EditMessageOptions controls in-place user message edits.
|
|
type EditMessageOptions struct {
|
|
ChatID uuid.UUID
|
|
EditedMessageID int64
|
|
Content []fantasy.Content
|
|
}
|
|
|
|
// EditMessageResult contains the updated user message and chat status.
|
|
type EditMessageResult struct {
|
|
Message database.ChatMessage
|
|
Chat database.Chat
|
|
}
|
|
|
|
// PromoteQueuedOptions controls queued-message promotion.
|
|
type PromoteQueuedOptions struct {
|
|
ChatID uuid.UUID
|
|
QueuedMessageID int64
|
|
ModelConfigID *uuid.UUID
|
|
}
|
|
|
|
// PromoteQueuedResult contains post-promotion message metadata.
|
|
type PromoteQueuedResult struct {
|
|
PromotedMessage database.ChatMessage
|
|
}
|
|
|
|
// CreateChat creates a chat, inserts optional system prompt and initial user
|
|
// message, and moves the chat into pending status.
|
|
func (p *Server) CreateChat(ctx context.Context, opts CreateOptions) (database.Chat, error) {
|
|
if opts.OwnerID == uuid.Nil {
|
|
return database.Chat{}, xerrors.New("owner_id is required")
|
|
}
|
|
if strings.TrimSpace(opts.Title) == "" {
|
|
return database.Chat{}, xerrors.New("title is required")
|
|
}
|
|
if len(opts.InitialUserContent) == 0 {
|
|
return database.Chat{}, xerrors.New("initial user content is required")
|
|
}
|
|
|
|
var chat database.Chat
|
|
txErr := p.db.InTx(func(tx database.Store) error {
|
|
insertedChat, err := tx.InsertChat(ctx, database.InsertChatParams{
|
|
OwnerID: opts.OwnerID,
|
|
WorkspaceID: opts.WorkspaceID,
|
|
ParentChatID: opts.ParentChatID,
|
|
RootChatID: opts.RootChatID,
|
|
LastModelConfigID: opts.ModelConfigID,
|
|
Title: opts.Title,
|
|
})
|
|
if err != nil {
|
|
return xerrors.Errorf("insert chat: %w", err)
|
|
}
|
|
|
|
systemPrompt := strings.TrimSpace(opts.SystemPrompt)
|
|
if systemPrompt != "" {
|
|
systemContent, err := json.Marshal(systemPrompt)
|
|
if err != nil {
|
|
return xerrors.Errorf("marshal system prompt: %w", err)
|
|
}
|
|
_, err = tx.InsertChatMessage(ctx, database.InsertChatMessageParams{
|
|
ChatID: insertedChat.ID,
|
|
ModelConfigID: uuid.NullUUID{
|
|
UUID: opts.ModelConfigID,
|
|
Valid: true,
|
|
},
|
|
Role: "system",
|
|
Content: pqtype.NullRawMessage{
|
|
RawMessage: systemContent,
|
|
Valid: len(systemContent) > 0,
|
|
},
|
|
Visibility: database.ChatMessageVisibilityModel,
|
|
InputTokens: sql.NullInt64{},
|
|
OutputTokens: sql.NullInt64{},
|
|
TotalTokens: sql.NullInt64{},
|
|
ReasoningTokens: sql.NullInt64{},
|
|
CacheCreationTokens: sql.NullInt64{},
|
|
CacheReadTokens: sql.NullInt64{},
|
|
ContextLimit: sql.NullInt64{},
|
|
Compressed: sql.NullBool{},
|
|
})
|
|
if err != nil {
|
|
return xerrors.Errorf("insert system message: %w", err)
|
|
}
|
|
}
|
|
|
|
userContent, err := chatprompt.MarshalContent(opts.InitialUserContent)
|
|
if err != nil {
|
|
return xerrors.Errorf("marshal initial user content: %w", err)
|
|
}
|
|
_, err = insertChatMessageWithStore(ctx, tx, database.InsertChatMessageParams{
|
|
ChatID: insertedChat.ID,
|
|
ModelConfigID: uuid.NullUUID{
|
|
UUID: opts.ModelConfigID,
|
|
Valid: true,
|
|
},
|
|
Role: "user",
|
|
Content: userContent,
|
|
Visibility: database.ChatMessageVisibilityBoth,
|
|
InputTokens: sql.NullInt64{},
|
|
OutputTokens: sql.NullInt64{},
|
|
TotalTokens: sql.NullInt64{},
|
|
ReasoningTokens: sql.NullInt64{},
|
|
CacheCreationTokens: sql.NullInt64{},
|
|
CacheReadTokens: sql.NullInt64{},
|
|
ContextLimit: sql.NullInt64{},
|
|
Compressed: sql.NullBool{},
|
|
})
|
|
if err != nil {
|
|
return xerrors.Errorf("insert initial user message: %w", err)
|
|
}
|
|
|
|
chat, err = setChatPendingWithStore(ctx, tx, insertedChat.ID)
|
|
if err != nil {
|
|
return xerrors.Errorf("set chat pending: %w", err)
|
|
}
|
|
|
|
if !chat.RootChatID.Valid && !chat.ParentChatID.Valid {
|
|
chat.RootChatID = uuid.NullUUID{UUID: chat.ID, Valid: true}
|
|
}
|
|
return nil
|
|
}, nil)
|
|
if txErr != nil {
|
|
return database.Chat{}, txErr
|
|
}
|
|
|
|
p.publishChatPubsubEvent(chat, coderdpubsub.ChatEventKindCreated)
|
|
return chat, nil
|
|
}
|
|
|
|
// SendMessage inserts a user message and optionally queues it while the chat
|
|
// is busy, then publishes stream + pubsub updates.
|
|
func (p *Server) SendMessage(
|
|
ctx context.Context,
|
|
opts SendMessageOptions,
|
|
) (SendMessageResult, error) {
|
|
if opts.ChatID == uuid.Nil {
|
|
return SendMessageResult{}, xerrors.New("chat_id is required")
|
|
}
|
|
if len(opts.Content) == 0 {
|
|
return SendMessageResult{}, xerrors.New("content is required")
|
|
}
|
|
|
|
busyBehavior := opts.BusyBehavior
|
|
if busyBehavior == "" {
|
|
busyBehavior = SendMessageBusyBehaviorQueue
|
|
}
|
|
switch busyBehavior {
|
|
case SendMessageBusyBehaviorQueue, SendMessageBusyBehaviorInterrupt:
|
|
default:
|
|
return SendMessageResult{}, xerrors.Errorf("invalid busy behavior %q", opts.BusyBehavior)
|
|
}
|
|
|
|
content, err := chatprompt.MarshalContent(opts.Content)
|
|
if err != nil {
|
|
return SendMessageResult{}, xerrors.Errorf("marshal message content: %w", err)
|
|
}
|
|
|
|
var (
|
|
result SendMessageResult
|
|
queuedMessagesSDK []codersdk.ChatQueuedMessage
|
|
)
|
|
|
|
txErr := p.db.InTx(func(tx database.Store) error {
|
|
lockedChat, err := tx.GetChatByIDForUpdate(ctx, opts.ChatID)
|
|
if err != nil {
|
|
return xerrors.Errorf("lock chat: %w", err)
|
|
}
|
|
modelConfigID := lockedChat.LastModelConfigID
|
|
if opts.ModelConfigID != nil {
|
|
modelConfigID = *opts.ModelConfigID
|
|
}
|
|
|
|
if busyBehavior == SendMessageBusyBehaviorQueue &&
|
|
shouldQueueUserMessage(lockedChat.Status) {
|
|
existingQueued, err := tx.GetChatQueuedMessages(ctx, opts.ChatID)
|
|
if err != nil {
|
|
return xerrors.Errorf("get queued messages: %w", err)
|
|
}
|
|
if len(existingQueued) >= MaxQueueSize {
|
|
return ErrMessageQueueFull
|
|
}
|
|
|
|
queued, err := tx.InsertChatQueuedMessage(ctx, database.InsertChatQueuedMessageParams{
|
|
ChatID: opts.ChatID,
|
|
Content: content.RawMessage,
|
|
})
|
|
if err != nil {
|
|
return xerrors.Errorf("insert queued message: %w", err)
|
|
}
|
|
|
|
queuedMessages, err := tx.GetChatQueuedMessages(ctx, opts.ChatID)
|
|
if err != nil {
|
|
return xerrors.Errorf("get queued messages: %w", err)
|
|
}
|
|
|
|
result.Queued = true
|
|
result.QueuedMessage = &queued
|
|
result.Chat = lockedChat
|
|
queuedMessagesSDK = db2sdk.ChatQueuedMessages(queuedMessages)
|
|
return nil
|
|
}
|
|
|
|
message, updatedChat, err := insertUserMessageAndSetPending(
|
|
ctx,
|
|
tx,
|
|
lockedChat,
|
|
modelConfigID,
|
|
content,
|
|
)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
result.Message = message
|
|
result.Chat = updatedChat
|
|
|
|
return nil
|
|
}, nil)
|
|
if txErr != nil {
|
|
return SendMessageResult{}, txErr
|
|
}
|
|
|
|
if result.Queued {
|
|
p.publishEvent(opts.ChatID, codersdk.ChatStreamEvent{
|
|
Type: codersdk.ChatStreamEventTypeQueueUpdate,
|
|
ChatID: opts.ChatID,
|
|
QueuedMessages: queuedMessagesSDK,
|
|
})
|
|
p.publishChatStreamNotify(opts.ChatID, coderdpubsub.ChatStreamNotifyMessage{
|
|
QueueUpdate: true,
|
|
})
|
|
return result, nil
|
|
}
|
|
|
|
p.publishMessage(opts.ChatID, result.Message)
|
|
p.publishStatus(opts.ChatID, result.Chat.Status, result.Chat.WorkerID)
|
|
p.publishChatPubsubEvent(result.Chat, coderdpubsub.ChatEventKindStatusChange)
|
|
return result, nil
|
|
}
|
|
|
|
// EditMessage updates a user message in-place, truncates all following messages,
|
|
// clears queued messages, and moves the chat into pending status.
|
|
func (p *Server) EditMessage(
|
|
ctx context.Context,
|
|
opts EditMessageOptions,
|
|
) (EditMessageResult, error) {
|
|
if opts.ChatID == uuid.Nil {
|
|
return EditMessageResult{}, xerrors.New("chat_id is required")
|
|
}
|
|
if opts.EditedMessageID <= 0 {
|
|
return EditMessageResult{}, xerrors.New("edited_message_id is required")
|
|
}
|
|
if len(opts.Content) == 0 {
|
|
return EditMessageResult{}, xerrors.New("content is required")
|
|
}
|
|
|
|
content, err := chatprompt.MarshalContent(opts.Content)
|
|
if err != nil {
|
|
return EditMessageResult{}, xerrors.Errorf("marshal message content: %w", err)
|
|
}
|
|
|
|
var result EditMessageResult
|
|
txErr := p.db.InTx(func(tx database.Store) error {
|
|
_, err := tx.GetChatByIDForUpdate(ctx, opts.ChatID)
|
|
if err != nil {
|
|
return xerrors.Errorf("lock chat: %w", err)
|
|
}
|
|
|
|
existing, err := tx.GetChatMessageByID(ctx, opts.EditedMessageID)
|
|
if err != nil {
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return ErrEditedMessageNotFound
|
|
}
|
|
return xerrors.Errorf("get edited message: %w", err)
|
|
}
|
|
if existing.ChatID != opts.ChatID {
|
|
return ErrEditedMessageNotFound
|
|
}
|
|
if existing.Role != "user" {
|
|
return ErrEditedMessageNotUser
|
|
}
|
|
|
|
updatedMessage, err := tx.UpdateChatMessageByID(ctx, database.UpdateChatMessageByIDParams{
|
|
ModelConfigID: uuid.NullUUID{},
|
|
Content: content,
|
|
ID: opts.EditedMessageID,
|
|
})
|
|
if err != nil {
|
|
return xerrors.Errorf("update chat message: %w", err)
|
|
}
|
|
|
|
err = tx.DeleteChatMessagesAfterID(ctx, database.DeleteChatMessagesAfterIDParams{
|
|
ChatID: opts.ChatID,
|
|
AfterID: opts.EditedMessageID,
|
|
})
|
|
if err != nil {
|
|
return xerrors.Errorf("delete later chat messages: %w", err)
|
|
}
|
|
|
|
err = tx.DeleteAllChatQueuedMessages(ctx, opts.ChatID)
|
|
if err != nil {
|
|
return xerrors.Errorf("delete queued messages: %w", err)
|
|
}
|
|
|
|
updatedChat, err := tx.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
|
|
ID: opts.ChatID,
|
|
Status: database.ChatStatusPending,
|
|
WorkerID: uuid.NullUUID{},
|
|
StartedAt: sql.NullTime{},
|
|
HeartbeatAt: sql.NullTime{},
|
|
LastError: sql.NullString{},
|
|
})
|
|
if err != nil {
|
|
return xerrors.Errorf("set chat pending: %w", err)
|
|
}
|
|
|
|
result.Message = updatedMessage
|
|
result.Chat = updatedChat
|
|
return nil
|
|
}, nil)
|
|
if txErr != nil {
|
|
return EditMessageResult{}, txErr
|
|
}
|
|
|
|
p.publishMessage(opts.ChatID, result.Message)
|
|
p.publishEvent(opts.ChatID, codersdk.ChatStreamEvent{
|
|
Type: codersdk.ChatStreamEventTypeQueueUpdate,
|
|
QueuedMessages: []codersdk.ChatQueuedMessage{},
|
|
})
|
|
p.publishChatStreamNotify(opts.ChatID, coderdpubsub.ChatStreamNotifyMessage{
|
|
QueueUpdate: true,
|
|
})
|
|
p.publishStatus(opts.ChatID, result.Chat.Status, result.Chat.WorkerID)
|
|
p.publishChatPubsubEvent(result.Chat, coderdpubsub.ChatEventKindStatusChange)
|
|
|
|
return result, nil
|
|
}
|
|
|
|
// ArchiveChat archives a chat and all descendants, then broadcasts a deleted event.
|
|
func (p *Server) ArchiveChat(ctx context.Context, chatID uuid.UUID) error {
|
|
if chatID == uuid.Nil {
|
|
return xerrors.New("chat_id is required")
|
|
}
|
|
|
|
chat, err := p.db.GetChatByID(ctx, chatID)
|
|
if err != nil {
|
|
return xerrors.Errorf("get chat: %w", err)
|
|
}
|
|
|
|
if err := p.db.ArchiveChatByID(ctx, chatID); err != nil {
|
|
return xerrors.Errorf("archive chat: %w", err)
|
|
}
|
|
|
|
p.publishChatPubsubEvent(chat, coderdpubsub.ChatEventKindDeleted)
|
|
return nil
|
|
}
|
|
|
|
// DeleteQueued removes a queued user message and publishes the queue update.
|
|
func (p *Server) DeleteQueued(
|
|
ctx context.Context,
|
|
chatID uuid.UUID,
|
|
queuedMessageID int64,
|
|
) error {
|
|
if chatID == uuid.Nil {
|
|
return xerrors.New("chat_id is required")
|
|
}
|
|
|
|
err := p.db.DeleteChatQueuedMessage(ctx, database.DeleteChatQueuedMessageParams{
|
|
ID: queuedMessageID,
|
|
ChatID: chatID,
|
|
})
|
|
if err != nil {
|
|
return xerrors.Errorf("delete queued message: %w", err)
|
|
}
|
|
|
|
queuedMessages, err := p.db.GetChatQueuedMessages(ctx, chatID)
|
|
if err != nil {
|
|
p.logger.Warn(ctx, "failed to load queued messages after delete",
|
|
slog.F("chat_id", chatID),
|
|
slog.F("queued_message_id", queuedMessageID),
|
|
slog.Error(err),
|
|
)
|
|
return nil
|
|
}
|
|
|
|
p.publishEvent(chatID, codersdk.ChatStreamEvent{
|
|
Type: codersdk.ChatStreamEventTypeQueueUpdate,
|
|
QueuedMessages: db2sdk.ChatQueuedMessages(queuedMessages),
|
|
})
|
|
p.publishChatStreamNotify(chatID, coderdpubsub.ChatStreamNotifyMessage{
|
|
QueueUpdate: true,
|
|
})
|
|
return nil
|
|
}
|
|
|
|
// PromoteQueued promotes a queued message into chat history and marks the chat pending.
|
|
func (p *Server) PromoteQueued(
|
|
ctx context.Context,
|
|
opts PromoteQueuedOptions,
|
|
) (PromoteQueuedResult, error) {
|
|
if opts.ChatID == uuid.Nil {
|
|
return PromoteQueuedResult{}, xerrors.New("chat_id is required")
|
|
}
|
|
|
|
var (
|
|
result PromoteQueuedResult
|
|
promoted database.ChatMessage
|
|
updatedChat database.Chat
|
|
remainingQueue []database.ChatQueuedMessage
|
|
)
|
|
|
|
txErr := p.db.InTx(func(tx database.Store) error {
|
|
lockedChat, err := tx.GetChatByIDForUpdate(ctx, opts.ChatID)
|
|
if err != nil {
|
|
return xerrors.Errorf("lock chat: %w", err)
|
|
}
|
|
modelConfigID := lockedChat.LastModelConfigID
|
|
if opts.ModelConfigID != nil {
|
|
modelConfigID = *opts.ModelConfigID
|
|
}
|
|
|
|
queuedMessages, err := tx.GetChatQueuedMessages(ctx, opts.ChatID)
|
|
if err != nil {
|
|
return xerrors.Errorf("get queued messages: %w", err)
|
|
}
|
|
|
|
var (
|
|
targetContent json.RawMessage
|
|
found bool
|
|
)
|
|
for _, qm := range queuedMessages {
|
|
if qm.ID == opts.QueuedMessageID {
|
|
targetContent = qm.Content
|
|
found = true
|
|
break
|
|
}
|
|
}
|
|
if !found {
|
|
return xerrors.New("queued message not found")
|
|
}
|
|
|
|
err = tx.DeleteChatQueuedMessage(ctx, database.DeleteChatQueuedMessageParams{
|
|
ID: opts.QueuedMessageID,
|
|
ChatID: opts.ChatID,
|
|
})
|
|
if err != nil {
|
|
return xerrors.Errorf("delete queued message: %w", err)
|
|
}
|
|
|
|
promoted, updatedChat, err = insertUserMessageAndSetPending(
|
|
ctx,
|
|
tx,
|
|
lockedChat,
|
|
modelConfigID,
|
|
pqtype.NullRawMessage{
|
|
RawMessage: targetContent,
|
|
Valid: len(targetContent) > 0,
|
|
},
|
|
)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
remainingQueue, err = tx.GetChatQueuedMessages(ctx, opts.ChatID)
|
|
if err != nil {
|
|
return xerrors.Errorf("get remaining queue: %w", err)
|
|
}
|
|
result.PromotedMessage = promoted
|
|
|
|
return nil
|
|
}, nil)
|
|
if txErr != nil {
|
|
return PromoteQueuedResult{}, txErr
|
|
}
|
|
|
|
p.publishEvent(opts.ChatID, codersdk.ChatStreamEvent{
|
|
Type: codersdk.ChatStreamEventTypeQueueUpdate,
|
|
QueuedMessages: db2sdk.ChatQueuedMessages(remainingQueue),
|
|
})
|
|
p.publishChatStreamNotify(opts.ChatID, coderdpubsub.ChatStreamNotifyMessage{
|
|
QueueUpdate: true,
|
|
})
|
|
p.publishMessage(opts.ChatID, promoted)
|
|
p.publishStatus(opts.ChatID, updatedChat.Status, updatedChat.WorkerID)
|
|
|
|
return result, nil
|
|
}
|
|
|
|
// InterruptChat interrupts execution, sets waiting status, and broadcasts status updates.
|
|
func (p *Server) InterruptChat(
|
|
ctx context.Context,
|
|
chat database.Chat,
|
|
) database.Chat {
|
|
if chat.ID == uuid.Nil {
|
|
return chat
|
|
}
|
|
|
|
updatedChat, err := p.setChatWaiting(ctx, chat.ID)
|
|
if err != nil {
|
|
p.logger.Error(ctx, "failed to mark chat as waiting",
|
|
slog.F("chat_id", chat.ID),
|
|
slog.Error(err),
|
|
)
|
|
return chat
|
|
}
|
|
return updatedChat
|
|
}
|
|
|
|
// RefreshStatus loads the latest chat status and publishes it to stream subscribers.
|
|
func (p *Server) RefreshStatus(ctx context.Context, chatID uuid.UUID) error {
|
|
if chatID == uuid.Nil {
|
|
return xerrors.New("chat_id is required")
|
|
}
|
|
|
|
chat, err := p.db.GetChatByID(ctx, chatID)
|
|
if err != nil {
|
|
return xerrors.Errorf("get chat: %w", err)
|
|
}
|
|
|
|
p.publishStatus(chat.ID, chat.Status, chat.WorkerID)
|
|
return nil
|
|
}
|
|
|
|
func setChatPendingWithStore(
|
|
ctx context.Context,
|
|
store database.Store,
|
|
chatID uuid.UUID,
|
|
) (database.Chat, error) {
|
|
chat, err := store.GetChatByID(ctx, chatID)
|
|
if err != nil {
|
|
return database.Chat{}, xerrors.Errorf("get chat: %w", err)
|
|
}
|
|
if chat.Status == database.ChatStatusPending {
|
|
return chat, nil
|
|
}
|
|
|
|
updatedChat, err := store.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
|
|
ID: chat.ID,
|
|
Status: database.ChatStatusPending,
|
|
WorkerID: uuid.NullUUID{},
|
|
StartedAt: sql.NullTime{},
|
|
HeartbeatAt: sql.NullTime{},
|
|
LastError: sql.NullString{},
|
|
})
|
|
if err != nil {
|
|
return database.Chat{}, xerrors.Errorf("set chat pending: %w", err)
|
|
}
|
|
return updatedChat, nil
|
|
}
|
|
|
|
func (p *Server) setChatWaiting(ctx context.Context, chatID uuid.UUID) (database.Chat, error) {
|
|
updatedChat, err := p.db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
|
|
ID: chatID,
|
|
Status: database.ChatStatusWaiting,
|
|
WorkerID: uuid.NullUUID{},
|
|
StartedAt: sql.NullTime{},
|
|
HeartbeatAt: sql.NullTime{},
|
|
LastError: sql.NullString{},
|
|
})
|
|
if err != nil {
|
|
return database.Chat{}, err
|
|
}
|
|
p.publishStatus(chatID, updatedChat.Status, updatedChat.WorkerID)
|
|
p.publishChatPubsubEvent(updatedChat, coderdpubsub.ChatEventKindStatusChange)
|
|
return updatedChat, nil
|
|
}
|
|
|
|
func insertChatMessageWithStore(
|
|
ctx context.Context,
|
|
store database.Store,
|
|
params database.InsertChatMessageParams,
|
|
) (database.ChatMessage, error) {
|
|
message, err := store.InsertChatMessage(ctx, params)
|
|
if err != nil {
|
|
return database.ChatMessage{}, xerrors.Errorf("insert chat message: %w", err)
|
|
}
|
|
return message, nil
|
|
}
|
|
|
|
func insertUserMessageAndSetPending(
|
|
ctx context.Context,
|
|
store database.Store,
|
|
lockedChat database.Chat,
|
|
modelConfigID uuid.UUID,
|
|
content pqtype.NullRawMessage,
|
|
) (database.ChatMessage, database.Chat, error) {
|
|
message, err := insertChatMessageWithStore(ctx, store, database.InsertChatMessageParams{
|
|
ChatID: lockedChat.ID,
|
|
ModelConfigID: uuid.NullUUID{UUID: modelConfigID, Valid: true},
|
|
Role: "user",
|
|
Content: content,
|
|
Visibility: database.ChatMessageVisibilityBoth,
|
|
InputTokens: sql.NullInt64{},
|
|
OutputTokens: sql.NullInt64{},
|
|
TotalTokens: sql.NullInt64{},
|
|
ReasoningTokens: sql.NullInt64{},
|
|
CacheCreationTokens: sql.NullInt64{},
|
|
CacheReadTokens: sql.NullInt64{},
|
|
ContextLimit: sql.NullInt64{},
|
|
Compressed: sql.NullBool{},
|
|
})
|
|
if err != nil {
|
|
return database.ChatMessage{}, database.Chat{}, err
|
|
}
|
|
|
|
if lockedChat.Status == database.ChatStatusPending {
|
|
return message, lockedChat, nil
|
|
}
|
|
|
|
updatedChat, err := store.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
|
|
ID: lockedChat.ID,
|
|
Status: database.ChatStatusPending,
|
|
WorkerID: uuid.NullUUID{},
|
|
StartedAt: sql.NullTime{},
|
|
HeartbeatAt: sql.NullTime{},
|
|
LastError: sql.NullString{},
|
|
})
|
|
if err != nil {
|
|
return database.ChatMessage{}, database.Chat{}, xerrors.Errorf("set chat pending: %w", err)
|
|
}
|
|
return message, updatedChat, nil
|
|
}
|
|
|
|
// shouldQueueUserMessage reports whether a user message should be
|
|
// queued while a chat is active.
|
|
func shouldQueueUserMessage(status database.ChatStatus) bool {
|
|
switch status {
|
|
case database.ChatStatusRunning, database.ChatStatusPending:
|
|
return true
|
|
default:
|
|
return false
|
|
}
|
|
}
|
|
|
|
// Config configures a chat processor.
|
|
type Config struct {
|
|
Logger slog.Logger
|
|
Database database.Store
|
|
ReplicaID uuid.UUID
|
|
RemotePartsProvider RemotePartsProvider
|
|
PendingChatAcquireInterval time.Duration
|
|
InFlightChatStaleAfter time.Duration
|
|
AgentConn AgentConnFunc
|
|
CreateWorkspace chattool.CreateWorkspaceFn
|
|
Pubsub pubsub.Pubsub
|
|
ProviderAPIKeys chatprovider.ProviderAPIKeys
|
|
WebpushDispatcher webpush.Dispatcher
|
|
}
|
|
|
|
// New creates a new chat processor. The processor polls for pending
|
|
// chats and processes them. It is the caller's responsibility to call Close
|
|
// on the returned instance.
|
|
func New(cfg Config) *Server {
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
|
|
pendingChatAcquireInterval := cfg.PendingChatAcquireInterval
|
|
if pendingChatAcquireInterval == 0 {
|
|
pendingChatAcquireInterval = DefaultPendingChatAcquireInterval
|
|
}
|
|
|
|
inFlightChatStaleAfter := cfg.InFlightChatStaleAfter
|
|
if inFlightChatStaleAfter == 0 {
|
|
inFlightChatStaleAfter = DefaultInFlightChatStaleAfter
|
|
}
|
|
|
|
workerID := cfg.ReplicaID
|
|
if workerID == uuid.Nil {
|
|
workerID = uuid.New()
|
|
}
|
|
|
|
p := &Server{
|
|
cancel: cancel,
|
|
closed: make(chan struct{}),
|
|
db: cfg.Database,
|
|
workerID: workerID,
|
|
logger: cfg.Logger.Named("chat-processor"),
|
|
remotePartsProvider: cfg.RemotePartsProvider,
|
|
agentConnFn: cfg.AgentConn,
|
|
createWorkspaceFn: cfg.CreateWorkspace,
|
|
pubsub: cfg.Pubsub,
|
|
webpushDispatcher: cfg.WebpushDispatcher,
|
|
providerAPIKeys: cfg.ProviderAPIKeys,
|
|
chatStreams: make(map[uuid.UUID]*chatStreamState),
|
|
instructionCache: make(map[uuid.UUID]cachedInstruction),
|
|
pendingChatAcquireInterval: pendingChatAcquireInterval,
|
|
inFlightChatStaleAfter: inFlightChatStaleAfter,
|
|
}
|
|
|
|
//nolint:gocritic // The chat processor uses a scoped chatd context.
|
|
ctx = dbauthz.AsChatd(ctx)
|
|
go p.start(ctx)
|
|
|
|
return p
|
|
}
|
|
|
|
func (p *Server) start(ctx context.Context) {
|
|
defer close(p.closed)
|
|
|
|
// Recover stale chats on startup and periodically thereafter
|
|
// to handle chats orphaned by crashed or redeployed workers.
|
|
p.recoverStaleChats(ctx)
|
|
|
|
acquireTicker := time.NewTicker(p.pendingChatAcquireInterval)
|
|
defer acquireTicker.Stop()
|
|
|
|
staleRecoveryInterval := p.inFlightChatStaleAfter / staleRecoveryIntervalDivisor
|
|
staleTicker := time.NewTicker(staleRecoveryInterval)
|
|
defer staleTicker.Stop()
|
|
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
case <-acquireTicker.C:
|
|
p.processOnce(ctx)
|
|
case <-staleTicker.C:
|
|
p.recoverStaleChats(ctx)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (p *Server) processOnce(ctx context.Context) {
|
|
// Try to acquire a pending chat.
|
|
chat, err := p.db.AcquireChat(ctx, database.AcquireChatParams{
|
|
StartedAt: time.Now(),
|
|
WorkerID: p.workerID,
|
|
})
|
|
if err != nil {
|
|
if !xerrors.Is(err, sql.ErrNoRows) {
|
|
p.logger.Error(ctx, "failed to acquire chat", slog.Error(err))
|
|
}
|
|
// No pending chats or error.
|
|
return
|
|
}
|
|
|
|
// Process the chat (don't block the main loop).
|
|
p.inflight.Add(1)
|
|
go func() {
|
|
defer p.inflight.Done()
|
|
p.processChat(ctx, chat)
|
|
}()
|
|
}
|
|
|
|
func (p *Server) publishToStream(chatID uuid.UUID, event codersdk.ChatStreamEvent) {
|
|
p.streamMu.Lock()
|
|
state := p.streamStateLocked(chatID)
|
|
if event.Type == codersdk.ChatStreamEventTypeMessagePart {
|
|
if !state.buffering {
|
|
p.streamMu.Unlock()
|
|
return
|
|
}
|
|
state.buffer = append(state.buffer, event)
|
|
}
|
|
subscribers := make([]chan codersdk.ChatStreamEvent, 0, len(state.subscribers))
|
|
for _, ch := range state.subscribers {
|
|
subscribers = append(subscribers, ch)
|
|
}
|
|
p.streamMu.Unlock()
|
|
|
|
for _, ch := range subscribers {
|
|
select {
|
|
case ch <- event:
|
|
default:
|
|
p.logger.Warn(context.Background(), "dropping chat stream event",
|
|
slog.F("chat_id", chatID), slog.F("type", event.Type))
|
|
}
|
|
}
|
|
}
|
|
|
|
func (p *Server) subscribeToStream(chatID uuid.UUID) (
|
|
[]codersdk.ChatStreamEvent,
|
|
<-chan codersdk.ChatStreamEvent,
|
|
func(),
|
|
) {
|
|
p.streamMu.Lock()
|
|
state := p.streamStateLocked(chatID)
|
|
snapshot := append([]codersdk.ChatStreamEvent(nil), state.buffer...)
|
|
id := uuid.New()
|
|
ch := make(chan codersdk.ChatStreamEvent, 128)
|
|
state.subscribers[id] = ch
|
|
p.streamMu.Unlock()
|
|
|
|
cancel := func() {
|
|
p.streamMu.Lock()
|
|
state, ok := p.chatStreams[chatID]
|
|
if ok {
|
|
if subscriber, exists := state.subscribers[id]; exists {
|
|
delete(state.subscribers, id)
|
|
close(subscriber)
|
|
}
|
|
p.cleanupStreamIfIdleLocked(chatID, state)
|
|
}
|
|
p.streamMu.Unlock()
|
|
}
|
|
|
|
return snapshot, ch, cancel
|
|
}
|
|
|
|
// cleanupStreamIfIdleLocked removes the chat entry when there
|
|
// are no subscribers and the stream is not buffering. The
|
|
// caller must hold p.streamMu.
|
|
func (p *Server) cleanupStreamIfIdleLocked(chatID uuid.UUID, state *chatStreamState) {
|
|
if !state.buffering && len(state.subscribers) == 0 {
|
|
delete(p.chatStreams, chatID)
|
|
}
|
|
}
|
|
|
|
func (p *Server) streamStateLocked(chatID uuid.UUID) *chatStreamState {
|
|
state, ok := p.chatStreams[chatID]
|
|
if !ok {
|
|
state = &chatStreamState{subscribers: make(map[uuid.UUID]chan codersdk.ChatStreamEvent)}
|
|
p.chatStreams[chatID] = state
|
|
}
|
|
return state
|
|
}
|
|
|
|
func (p *Server) Subscribe(
|
|
ctx context.Context,
|
|
chatID uuid.UUID,
|
|
requestHeader http.Header,
|
|
afterMessageID int64,
|
|
) (
|
|
[]codersdk.ChatStreamEvent,
|
|
<-chan codersdk.ChatStreamEvent,
|
|
func(),
|
|
bool,
|
|
) {
|
|
if p == nil {
|
|
return nil, nil, nil, false
|
|
}
|
|
if ctx == nil {
|
|
ctx = context.Background()
|
|
}
|
|
|
|
// Subscribe to local stream for message_parts (ephemeral).
|
|
localSnapshot, localParts, localCancel := p.subscribeToStream(chatID)
|
|
|
|
// Build initial snapshot synchronously
|
|
initialSnapshot := make([]codersdk.ChatStreamEvent, 0)
|
|
// Add local message_parts to snapshot
|
|
for _, event := range localSnapshot {
|
|
if event.Type == codersdk.ChatStreamEventTypeMessagePart {
|
|
initialSnapshot = append(initialSnapshot, event)
|
|
}
|
|
}
|
|
|
|
// Load initial messages from DB. When afterMessageID > 0 the
|
|
// caller already has messages up to that ID (e.g. from the REST
|
|
// endpoint), so we only fetch newer ones to avoid sending
|
|
// duplicate data.
|
|
messages, err := p.db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
|
ChatID: chatID,
|
|
AfterID: afterMessageID,
|
|
})
|
|
if err == nil {
|
|
for _, msg := range messages {
|
|
sdkMsg := db2sdk.ChatMessage(msg)
|
|
initialSnapshot = append(initialSnapshot, codersdk.ChatStreamEvent{
|
|
Type: codersdk.ChatStreamEventTypeMessage,
|
|
ChatID: chatID,
|
|
Message: &sdkMsg,
|
|
})
|
|
}
|
|
}
|
|
|
|
// Load initial queue
|
|
queued, err := p.db.GetChatQueuedMessages(ctx, chatID)
|
|
if err == nil && len(queued) > 0 {
|
|
initialSnapshot = append(initialSnapshot, codersdk.ChatStreamEvent{
|
|
Type: codersdk.ChatStreamEventTypeQueueUpdate,
|
|
ChatID: chatID,
|
|
QueuedMessages: db2sdk.ChatQueuedMessages(queued),
|
|
})
|
|
}
|
|
|
|
// Get initial chat state to determine if we need a relay
|
|
chat, err := p.db.GetChatByID(ctx, chatID)
|
|
var relayCancel func()
|
|
var relayParts <-chan codersdk.ChatStreamEvent
|
|
if err == nil && chat.Status == database.ChatStatusRunning && chat.WorkerID.Valid && chat.WorkerID.UUID != p.workerID && p.remotePartsProvider != nil {
|
|
// Open relay for initial snapshot
|
|
snapshot, parts, cancel, err := p.remotePartsProvider(ctx, chatID, chat.WorkerID.UUID, requestHeader)
|
|
if err == nil {
|
|
relayCancel = cancel
|
|
relayParts = parts
|
|
// Add relay message_parts to snapshot
|
|
for _, event := range snapshot {
|
|
if event.Type == codersdk.ChatStreamEventTypeMessagePart {
|
|
initialSnapshot = append(initialSnapshot, event)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Include the current chat status in the snapshot so the
|
|
// frontend can gate message_part processing correctly from
|
|
// the very first batch, without waiting for a separate REST
|
|
// query.
|
|
if err == nil {
|
|
statusEvent := codersdk.ChatStreamEvent{
|
|
Type: codersdk.ChatStreamEventTypeStatus,
|
|
ChatID: chatID,
|
|
Status: &codersdk.ChatStreamStatus{
|
|
Status: codersdk.ChatStatus(chat.Status),
|
|
},
|
|
}
|
|
// Prepend so the frontend sees the status before any
|
|
// message_part events.
|
|
initialSnapshot = append([]codersdk.ChatStreamEvent{statusEvent}, initialSnapshot...)
|
|
}
|
|
|
|
// Track the last message ID we've seen for DB queries
|
|
var lastMessageID int64
|
|
if len(messages) > 0 {
|
|
lastMessageID = messages[len(messages)-1].ID
|
|
}
|
|
|
|
// Merge all event sources
|
|
mergedCtx, mergedCancel := context.WithCancel(ctx)
|
|
mergedEvents := make(chan codersdk.ChatStreamEvent, 128)
|
|
var allCancels []func()
|
|
allCancels = append(allCancels, localCancel)
|
|
if relayCancel != nil {
|
|
allCancels = append(allCancels, relayCancel)
|
|
}
|
|
|
|
// Channel for async relay establishment.
|
|
type relayResult struct {
|
|
parts <-chan codersdk.ChatStreamEvent
|
|
cancel func()
|
|
}
|
|
relayReadyCh := make(chan relayResult, 1)
|
|
|
|
// Reconnect timer state.
|
|
var reconnectTimer *time.Timer
|
|
var reconnectCh <-chan time.Time
|
|
|
|
// Helper to close relay and stop any pending reconnect timer.
|
|
closeRelay := func() {
|
|
if relayCancel != nil {
|
|
relayCancel()
|
|
relayCancel = nil
|
|
}
|
|
relayParts = nil
|
|
if reconnectTimer != nil {
|
|
reconnectTimer.Stop()
|
|
reconnectTimer = nil
|
|
reconnectCh = nil
|
|
}
|
|
}
|
|
|
|
// openRelayAsync dials the remote replica in a background
|
|
// goroutine and delivers the result on relayReadyCh so the
|
|
// main select loop is never blocked by network I/O.
|
|
openRelayAsync := func(workerID uuid.UUID) {
|
|
if p.remotePartsProvider == nil {
|
|
return
|
|
}
|
|
closeRelay()
|
|
go func() {
|
|
snapshot, parts, cancel, err := p.remotePartsProvider(mergedCtx, chatID, workerID, requestHeader)
|
|
if err != nil {
|
|
p.logger.Warn(mergedCtx, "failed to open relay for message parts",
|
|
slog.F("chat_id", chatID),
|
|
slog.F("worker_id", workerID),
|
|
slog.Error(err),
|
|
)
|
|
return
|
|
}
|
|
// Wrap the relay channel so snapshot parts are
|
|
// delivered through the same channel as live parts.
|
|
wrappedParts := make(chan codersdk.ChatStreamEvent, 128)
|
|
go func() {
|
|
defer close(wrappedParts)
|
|
for _, event := range snapshot {
|
|
if event.Type == codersdk.ChatStreamEventTypeMessagePart {
|
|
select {
|
|
case wrappedParts <- event:
|
|
case <-mergedCtx.Done():
|
|
cancel()
|
|
return
|
|
}
|
|
}
|
|
}
|
|
for event := range parts {
|
|
select {
|
|
case wrappedParts <- event:
|
|
case <-mergedCtx.Done():
|
|
return
|
|
}
|
|
}
|
|
}()
|
|
select {
|
|
case relayReadyCh <- relayResult{parts: wrappedParts, cancel: cancel}:
|
|
case <-mergedCtx.Done():
|
|
cancel()
|
|
}
|
|
}()
|
|
}
|
|
|
|
// scheduleRelayReconnect arms a short timer so the select
|
|
// loop can re-check chat status and reopen the relay without
|
|
// spinning in a tight loop.
|
|
scheduleRelayReconnect := func() {
|
|
if p.remotePartsProvider == nil {
|
|
return
|
|
}
|
|
if reconnectTimer != nil {
|
|
reconnectTimer.Stop()
|
|
}
|
|
reconnectTimer = time.NewTimer(500 * time.Millisecond)
|
|
reconnectCh = reconnectTimer.C
|
|
}
|
|
|
|
//nolint:nestif
|
|
if p.pubsub != nil {
|
|
notifications := make(chan coderdpubsub.ChatStreamNotifyMessage, 10)
|
|
errCh := make(chan error, 1)
|
|
|
|
listener := func(_ context.Context, message []byte, err error) {
|
|
if err != nil {
|
|
select {
|
|
case <-mergedCtx.Done():
|
|
case errCh <- err:
|
|
}
|
|
return
|
|
}
|
|
var notify coderdpubsub.ChatStreamNotifyMessage
|
|
if unmarshalErr := json.Unmarshal(message, ¬ify); unmarshalErr != nil {
|
|
select {
|
|
case <-mergedCtx.Done():
|
|
case errCh <- xerrors.Errorf("unmarshal chat stream notify: %w", unmarshalErr):
|
|
}
|
|
return
|
|
}
|
|
select {
|
|
case <-mergedCtx.Done():
|
|
case notifications <- notify:
|
|
}
|
|
}
|
|
|
|
// Subscribe to pubsub for durable events
|
|
if pubsubCancel, err := p.pubsub.SubscribeWithErr(
|
|
coderdpubsub.ChatStreamNotifyChannel(chatID),
|
|
listener,
|
|
); err == nil {
|
|
allCancels = append(allCancels, pubsubCancel)
|
|
} else {
|
|
p.logger.Warn(mergedCtx, "failed to subscribe to chat stream notifications",
|
|
slog.F("chat_id", chatID),
|
|
slog.Error(err),
|
|
)
|
|
}
|
|
|
|
// Handle pubsub notifications in a goroutine
|
|
go func() {
|
|
defer close(mergedEvents)
|
|
defer closeRelay()
|
|
|
|
for {
|
|
relayPartsCh := relayParts
|
|
select {
|
|
case <-mergedCtx.Done():
|
|
return
|
|
case err := <-errCh:
|
|
p.logger.Error(mergedCtx, "chat stream pubsub error",
|
|
slog.F("chat_id", chatID),
|
|
slog.Error(err),
|
|
)
|
|
mergedEvents <- codersdk.ChatStreamEvent{
|
|
Type: codersdk.ChatStreamEventTypeError,
|
|
ChatID: chatID,
|
|
Error: &codersdk.ChatStreamError{
|
|
Message: err.Error(),
|
|
},
|
|
}
|
|
return
|
|
case result := <-relayReadyCh:
|
|
// An async relay dial completed; swap in the
|
|
// new relay channel.
|
|
closeRelay()
|
|
relayParts = result.parts
|
|
relayCancel = result.cancel
|
|
case <-reconnectCh:
|
|
reconnectCh = nil
|
|
// Re-check whether the chat is still running
|
|
// on a remote worker before reconnecting.
|
|
currentChat, chatErr := p.db.GetChatByID(mergedCtx, chatID)
|
|
if chatErr == nil && currentChat.Status == database.ChatStatusRunning &&
|
|
currentChat.WorkerID.Valid && currentChat.WorkerID.UUID != p.workerID {
|
|
openRelayAsync(currentChat.WorkerID.UUID)
|
|
}
|
|
case notify := <-notifications:
|
|
// Handle different notification types
|
|
if notify.AfterMessageID > 0 {
|
|
// Read only new messages from DB.
|
|
messages, err := p.db.GetChatMessagesByChatID(mergedCtx, database.GetChatMessagesByChatIDParams{
|
|
ChatID: chatID,
|
|
AfterID: lastMessageID,
|
|
})
|
|
if err == nil {
|
|
for _, msg := range messages {
|
|
sdkMsg := db2sdk.ChatMessage(msg)
|
|
select {
|
|
case <-mergedCtx.Done():
|
|
return
|
|
case mergedEvents <- codersdk.ChatStreamEvent{
|
|
Type: codersdk.ChatStreamEventTypeMessage,
|
|
ChatID: chatID,
|
|
Message: &sdkMsg,
|
|
}:
|
|
}
|
|
lastMessageID = msg.ID
|
|
}
|
|
}
|
|
}
|
|
if notify.Status != "" {
|
|
status := database.ChatStatus(notify.Status)
|
|
select {
|
|
case <-mergedCtx.Done():
|
|
return
|
|
case mergedEvents <- codersdk.ChatStreamEvent{
|
|
Type: codersdk.ChatStreamEventTypeStatus,
|
|
ChatID: chatID,
|
|
Status: &codersdk.ChatStreamStatus{Status: codersdk.ChatStatus(status)},
|
|
}:
|
|
}
|
|
// Manage relay lifecycle based on status.
|
|
if status == database.ChatStatusRunning && notify.WorkerID != "" {
|
|
workerID, err := uuid.Parse(notify.WorkerID)
|
|
if err == nil && workerID != p.workerID {
|
|
openRelayAsync(workerID)
|
|
} else if workerID == p.workerID {
|
|
closeRelay()
|
|
}
|
|
} else {
|
|
closeRelay()
|
|
}
|
|
}
|
|
if notify.Error != "" {
|
|
select {
|
|
case <-mergedCtx.Done():
|
|
return
|
|
case mergedEvents <- codersdk.ChatStreamEvent{
|
|
Type: codersdk.ChatStreamEventTypeError,
|
|
ChatID: chatID,
|
|
Error: &codersdk.ChatStreamError{
|
|
Message: notify.Error,
|
|
},
|
|
}:
|
|
}
|
|
}
|
|
if notify.QueueUpdate {
|
|
queued, err := p.db.GetChatQueuedMessages(mergedCtx, chatID)
|
|
if err == nil {
|
|
select {
|
|
case <-mergedCtx.Done():
|
|
return
|
|
case mergedEvents <- codersdk.ChatStreamEvent{
|
|
Type: codersdk.ChatStreamEventTypeQueueUpdate,
|
|
ChatID: chatID,
|
|
QueuedMessages: db2sdk.ChatQueuedMessages(queued),
|
|
}:
|
|
}
|
|
}
|
|
}
|
|
case event, ok := <-localParts:
|
|
if !ok {
|
|
// Local parts channel closed, but continue with pubsub
|
|
continue
|
|
}
|
|
// Only forward message_part events from local (durable events come via pubsub)
|
|
if event.Type == codersdk.ChatStreamEventTypeMessagePart {
|
|
select {
|
|
case <-mergedCtx.Done():
|
|
return
|
|
case mergedEvents <- event:
|
|
}
|
|
}
|
|
case event, ok := <-relayPartsCh:
|
|
if !ok {
|
|
relayParts = nil
|
|
// Schedule reconnection instead of giving up.
|
|
scheduleRelayReconnect()
|
|
continue
|
|
}
|
|
// Only forward message_part events from relay (durable events come via pubsub)
|
|
if event.Type == codersdk.ChatStreamEventTypeMessagePart {
|
|
select {
|
|
case <-mergedCtx.Done():
|
|
return
|
|
case mergedEvents <- event:
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}()
|
|
} else {
|
|
// No pubsub, just merge local parts.
|
|
// localSnapshot was already included in initialSnapshot,
|
|
// so only forward new events here.
|
|
go func() {
|
|
defer close(mergedEvents)
|
|
for event := range localParts {
|
|
select {
|
|
case <-mergedCtx.Done():
|
|
return
|
|
case mergedEvents <- event:
|
|
}
|
|
}
|
|
}()
|
|
}
|
|
cancel := func() {
|
|
mergedCancel()
|
|
for _, cancelFn := range allCancels {
|
|
if cancelFn != nil {
|
|
cancelFn()
|
|
}
|
|
}
|
|
if reconnectTimer != nil {
|
|
reconnectTimer.Stop()
|
|
}
|
|
}
|
|
|
|
return initialSnapshot, mergedEvents, cancel, true
|
|
}
|
|
|
|
func (p *Server) publishEvent(chatID uuid.UUID, event codersdk.ChatStreamEvent) {
|
|
if event.ChatID == uuid.Nil {
|
|
event.ChatID = chatID
|
|
}
|
|
p.publishToStream(chatID, event)
|
|
}
|
|
|
|
func (p *Server) publishStatus(chatID uuid.UUID, status database.ChatStatus, workerID uuid.NullUUID) {
|
|
p.publishEvent(chatID, codersdk.ChatStreamEvent{
|
|
Type: codersdk.ChatStreamEventTypeStatus,
|
|
Status: &codersdk.ChatStreamStatus{Status: codersdk.ChatStatus(status)},
|
|
})
|
|
notify := coderdpubsub.ChatStreamNotifyMessage{
|
|
Status: string(status),
|
|
}
|
|
if workerID.Valid {
|
|
notify.WorkerID = workerID.UUID.String()
|
|
}
|
|
p.publishChatStreamNotify(chatID, notify)
|
|
}
|
|
|
|
// publishChatStreamNotify broadcasts a per-chat stream notification via
|
|
// PostgreSQL pubsub so that all replicas can read updates from the database.
|
|
func (p *Server) publishChatStreamNotify(chatID uuid.UUID, notify coderdpubsub.ChatStreamNotifyMessage) {
|
|
if p.pubsub == nil {
|
|
return
|
|
}
|
|
payload, err := json.Marshal(notify)
|
|
if err != nil {
|
|
p.logger.Error(context.Background(), "failed to marshal chat stream notify",
|
|
slog.F("chat_id", chatID),
|
|
slog.Error(err),
|
|
)
|
|
return
|
|
}
|
|
if err := p.pubsub.Publish(coderdpubsub.ChatStreamNotifyChannel(chatID), payload); err != nil {
|
|
p.logger.Error(context.Background(), "failed to publish chat stream notify",
|
|
slog.F("chat_id", chatID),
|
|
slog.Error(err),
|
|
)
|
|
}
|
|
}
|
|
|
|
// publishChatPubsubEvent broadcasts a chat lifecycle event via PostgreSQL
|
|
// pubsub so that all replicas can push updates to watching clients.
|
|
func (p *Server) publishChatPubsubEvent(chat database.Chat, kind coderdpubsub.ChatEventKind) {
|
|
if p.pubsub == nil {
|
|
return
|
|
}
|
|
sdkChat := codersdk.Chat{
|
|
ID: chat.ID,
|
|
OwnerID: chat.OwnerID,
|
|
Title: chat.Title,
|
|
Status: codersdk.ChatStatus(chat.Status),
|
|
CreatedAt: chat.CreatedAt,
|
|
UpdatedAt: chat.UpdatedAt,
|
|
}
|
|
if chat.ParentChatID.Valid {
|
|
parentChatID := chat.ParentChatID.UUID
|
|
sdkChat.ParentChatID = &parentChatID
|
|
}
|
|
if chat.RootChatID.Valid {
|
|
rootChatID := chat.RootChatID.UUID
|
|
sdkChat.RootChatID = &rootChatID
|
|
} else if !chat.ParentChatID.Valid {
|
|
rootChatID := chat.ID
|
|
sdkChat.RootChatID = &rootChatID
|
|
}
|
|
if chat.WorkspaceID.Valid {
|
|
sdkChat.WorkspaceID = &chat.WorkspaceID.UUID
|
|
}
|
|
event := coderdpubsub.ChatEvent{
|
|
Kind: kind,
|
|
Chat: sdkChat,
|
|
}
|
|
payload, err := json.Marshal(event)
|
|
if err != nil {
|
|
p.logger.Error(context.Background(), "failed to marshal chat pubsub event",
|
|
slog.F("chat_id", chat.ID),
|
|
slog.Error(err),
|
|
)
|
|
return
|
|
}
|
|
if err := p.pubsub.Publish(coderdpubsub.ChatEventChannel(chat.OwnerID), payload); err != nil {
|
|
p.logger.Error(context.Background(), "failed to publish chat pubsub event",
|
|
slog.F("chat_id", chat.ID),
|
|
slog.F("kind", kind),
|
|
slog.Error(err),
|
|
)
|
|
}
|
|
}
|
|
|
|
// PublishDiffStatusChange broadcasts a diff_status_change event for
|
|
// the given chat so that watching clients know to re-fetch the diff
|
|
// status. This is called from the HTTP layer after the diff status
|
|
// is updated in the database.
|
|
func (p *Server) PublishDiffStatusChange(ctx context.Context, chatID uuid.UUID) error {
|
|
if p.pubsub == nil {
|
|
return nil
|
|
}
|
|
|
|
chat, err := p.db.GetChatByID(ctx, chatID)
|
|
if err != nil {
|
|
return xerrors.Errorf("get chat: %w", err)
|
|
}
|
|
|
|
p.publishChatPubsubEvent(chat, coderdpubsub.ChatEventKindDiffStatusChange)
|
|
return nil
|
|
}
|
|
|
|
func (p *Server) publishError(chatID uuid.UUID, message string) {
|
|
message = strings.TrimSpace(message)
|
|
if message == "" {
|
|
return
|
|
}
|
|
p.publishEvent(chatID, codersdk.ChatStreamEvent{
|
|
Type: codersdk.ChatStreamEventTypeError,
|
|
Error: &codersdk.ChatStreamError{Message: message},
|
|
})
|
|
p.publishChatStreamNotify(chatID, coderdpubsub.ChatStreamNotifyMessage{
|
|
Error: message,
|
|
})
|
|
}
|
|
|
|
// processingFailureReason extracts a failure reason from err. It
// returns the error text with surrounding whitespace removed and true,
// or "" and false when err is nil or its text is blank.
func processingFailureReason(err error) (string, bool) {
	if err == nil {
		return "", false
	}
	if text := strings.TrimSpace(err.Error()); text != "" {
		return text, true
	}
	return "", false
}
|
|
|
|
// panicFailureReason turns a value obtained from recover() into a
// failure message. Strings are used as-is, errors contribute their
// Error() text and anything else is formatted with fmt.Sprint. The
// detail is trimmed; when it is empty or "<nil>" only the generic
// prefix is returned.
func panicFailureReason(recovered any) string {
	const prefix = "chat processing panicked"

	var detail string
	if text, isString := recovered.(string); isString {
		detail = text
	} else if asErr, isErr := recovered.(error); isErr {
		detail = asErr.Error()
	} else {
		detail = fmt.Sprint(recovered)
	}
	detail = strings.TrimSpace(detail)

	if detail == "" || detail == "<nil>" {
		return prefix
	}
	return prefix + ": " + detail
}
|
|
|
|
func (p *Server) publishMessage(chatID uuid.UUID, message database.ChatMessage) {
|
|
sdkMessage := db2sdk.ChatMessage(message)
|
|
p.publishEvent(chatID, codersdk.ChatStreamEvent{
|
|
Type: codersdk.ChatStreamEventTypeMessage,
|
|
Message: &sdkMessage,
|
|
})
|
|
p.publishChatStreamNotify(chatID, coderdpubsub.ChatStreamNotifyMessage{
|
|
AfterMessageID: message.ID - 1,
|
|
})
|
|
}
|
|
|
|
func (p *Server) publishMessagePart(chatID uuid.UUID, role string, part codersdk.ChatMessagePart) {
|
|
if part.Type == "" {
|
|
return
|
|
}
|
|
p.publishEvent(chatID, codersdk.ChatStreamEvent{
|
|
Type: codersdk.ChatStreamEventTypeMessagePart,
|
|
MessagePart: &codersdk.ChatStreamMessagePart{
|
|
Role: role,
|
|
Part: part,
|
|
},
|
|
})
|
|
}
|
|
|
|
func shouldCancelChatFromControlNotification(
|
|
notify coderdpubsub.ChatStreamNotifyMessage,
|
|
workerID uuid.UUID,
|
|
) bool {
|
|
status := database.ChatStatus(strings.TrimSpace(notify.Status))
|
|
switch status {
|
|
case database.ChatStatusWaiting, database.ChatStatusPending, database.ChatStatusError:
|
|
return true
|
|
case database.ChatStatusRunning:
|
|
worker := strings.TrimSpace(notify.WorkerID)
|
|
if worker == "" {
|
|
return false
|
|
}
|
|
notifyWorkerID, err := uuid.Parse(worker)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
return notifyWorkerID != workerID
|
|
default:
|
|
return false
|
|
}
|
|
}
|
|
|
|
func (p *Server) subscribeChatControl(
|
|
ctx context.Context,
|
|
chatID uuid.UUID,
|
|
cancel context.CancelCauseFunc,
|
|
logger slog.Logger,
|
|
) func() {
|
|
if p.pubsub == nil {
|
|
return nil
|
|
}
|
|
|
|
listener := func(_ context.Context, message []byte, err error) {
|
|
if err != nil {
|
|
logger.Warn(ctx, "chat control pubsub error", slog.Error(err))
|
|
return
|
|
}
|
|
|
|
var notify coderdpubsub.ChatStreamNotifyMessage
|
|
if unmarshalErr := json.Unmarshal(message, ¬ify); unmarshalErr != nil {
|
|
logger.Warn(ctx, "failed to unmarshal chat control notify", slog.Error(unmarshalErr))
|
|
return
|
|
}
|
|
|
|
if shouldCancelChatFromControlNotification(notify, p.workerID) {
|
|
cancel(chatloop.ErrInterrupted)
|
|
}
|
|
}
|
|
|
|
controlCancel, err := p.pubsub.SubscribeWithErr(
|
|
coderdpubsub.ChatStreamNotifyChannel(chatID),
|
|
listener,
|
|
)
|
|
if err != nil {
|
|
logger.Warn(ctx, "failed to subscribe to chat control notifications", slog.Error(err))
|
|
return nil
|
|
}
|
|
return controlCancel
|
|
}
|
|
|
|
func (p *Server) processChat(ctx context.Context, chat database.Chat) {
|
|
logger := p.logger.With(slog.F("chat_id", chat.ID))
|
|
logger.Info(ctx, "processing chat request")
|
|
|
|
chatCtx, cancel := context.WithCancelCause(ctx)
|
|
defer cancel(nil)
|
|
|
|
controlCancel := p.subscribeChatControl(chatCtx, chat.ID, cancel, logger)
|
|
defer func() {
|
|
if controlCancel != nil {
|
|
controlCancel()
|
|
}
|
|
}()
|
|
|
|
// Periodically update the heartbeat so other replicas know this
|
|
// worker is still alive. The goroutine stops when chatCtx is
|
|
// canceled (either by completion or interruption).
|
|
go func() {
|
|
ticker := time.NewTicker(chatHeartbeatInterval)
|
|
defer ticker.Stop()
|
|
for {
|
|
select {
|
|
case <-chatCtx.Done():
|
|
return
|
|
case <-ticker.C:
|
|
rows, err := p.db.UpdateChatHeartbeat(chatCtx, database.UpdateChatHeartbeatParams{
|
|
ID: chat.ID,
|
|
WorkerID: p.workerID,
|
|
})
|
|
if err != nil {
|
|
logger.Warn(chatCtx, "failed to update chat heartbeat", slog.Error(err))
|
|
continue
|
|
}
|
|
if rows == 0 {
|
|
cancel(chatloop.ErrInterrupted)
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}()
|
|
|
|
p.publishStatus(chat.ID, database.ChatStatusRunning, uuid.NullUUID{
|
|
UUID: p.workerID,
|
|
Valid: true,
|
|
})
|
|
|
|
// Determine the final status and last error to set when we're done.
|
|
status := database.ChatStatusWaiting
|
|
lastError := ""
|
|
remainingQueuedMessages := []database.ChatQueuedMessage{}
|
|
shouldPublishQueueUpdate := false
|
|
|
|
defer func() {
|
|
// Use a context that is not canceled by Close() so we can
|
|
// reliably update the chat status in the database during
|
|
// graceful shutdown.
|
|
cleanupCtx := context.WithoutCancel(ctx)
|
|
|
|
// Handle panics gracefully.
|
|
if r := recover(); r != nil {
|
|
logger.Error(cleanupCtx, "panic during chat processing", slog.F("panic", r))
|
|
lastError = panicFailureReason(r)
|
|
p.publishError(chat.ID, lastError)
|
|
status = database.ChatStatusError
|
|
}
|
|
|
|
// Check for queued messages and auto-promote the next one.
|
|
// This must be done atomically with the status update to avoid
|
|
// races with the promote endpoint (which also sets status to
|
|
// pending). We use a transaction with FOR UPDATE to ensure we
|
|
// don't overwrite a status change made by another caller.
|
|
err := p.db.InTx(func(tx database.Store) error {
|
|
// Re-read the chat status under lock — another caller
|
|
// (e.g. promote) may have already set it to pending.
|
|
latestChat, lockErr := tx.GetChatByIDForUpdate(cleanupCtx, chat.ID)
|
|
if lockErr != nil {
|
|
return xerrors.Errorf("lock chat for release: %w", lockErr)
|
|
}
|
|
|
|
// If someone else already set the chat to pending (e.g.
|
|
// the promote endpoint), don't overwrite it — just clear
|
|
// the worker and let the processor pick it back up.
|
|
if latestChat.Status == database.ChatStatusPending && status == database.ChatStatusWaiting {
|
|
status = database.ChatStatusPending
|
|
} else if status == database.ChatStatusWaiting {
|
|
// Try to auto-promote the next queued message.
|
|
nextQueued, popErr := tx.PopNextQueuedMessage(cleanupCtx, chat.ID)
|
|
if popErr == nil {
|
|
msg, insertErr := tx.InsertChatMessage(cleanupCtx, database.InsertChatMessageParams{
|
|
ChatID: chat.ID,
|
|
ModelConfigID: uuid.NullUUID{UUID: latestChat.LastModelConfigID, Valid: true},
|
|
Role: "user",
|
|
Content: pqtype.NullRawMessage{
|
|
RawMessage: nextQueued.Content,
|
|
Valid: len(nextQueued.Content) > 0,
|
|
},
|
|
Visibility: database.ChatMessageVisibilityBoth,
|
|
InputTokens: sql.NullInt64{},
|
|
OutputTokens: sql.NullInt64{},
|
|
TotalTokens: sql.NullInt64{},
|
|
ReasoningTokens: sql.NullInt64{},
|
|
CacheCreationTokens: sql.NullInt64{},
|
|
CacheReadTokens: sql.NullInt64{},
|
|
ContextLimit: sql.NullInt64{},
|
|
Compressed: sql.NullBool{},
|
|
})
|
|
if insertErr != nil {
|
|
logger.Error(cleanupCtx, "failed to promote queued message",
|
|
slog.F("queued_message_id", nextQueued.ID), slog.Error(insertErr))
|
|
} else {
|
|
status = database.ChatStatusPending
|
|
|
|
sdkMsg := db2sdk.ChatMessage(msg)
|
|
p.publishEvent(chat.ID, codersdk.ChatStreamEvent{
|
|
Type: codersdk.ChatStreamEventTypeMessage,
|
|
Message: &sdkMsg,
|
|
})
|
|
|
|
remaining, qErr := tx.GetChatQueuedMessages(cleanupCtx, chat.ID)
|
|
if qErr == nil {
|
|
remainingQueuedMessages = remaining
|
|
shouldPublishQueueUpdate = true
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
_, updateErr := tx.UpdateChatStatus(cleanupCtx, database.UpdateChatStatusParams{
|
|
ID: chat.ID,
|
|
Status: status,
|
|
WorkerID: uuid.NullUUID{},
|
|
StartedAt: sql.NullTime{},
|
|
HeartbeatAt: sql.NullTime{},
|
|
LastError: sql.NullString{String: lastError, Valid: lastError != ""},
|
|
})
|
|
return updateErr
|
|
}, nil)
|
|
if err != nil {
|
|
logger.Error(cleanupCtx, "failed to release chat", slog.Error(err))
|
|
}
|
|
if err == nil && shouldPublishQueueUpdate {
|
|
p.publishEvent(chat.ID, codersdk.ChatStreamEvent{
|
|
Type: codersdk.ChatStreamEventTypeQueueUpdate,
|
|
QueuedMessages: db2sdk.ChatQueuedMessages(remainingQueuedMessages),
|
|
})
|
|
p.publishChatStreamNotify(chat.ID, coderdpubsub.ChatStreamNotifyMessage{
|
|
QueueUpdate: true,
|
|
})
|
|
}
|
|
|
|
p.publishStatus(chat.ID, status, uuid.NullUUID{})
|
|
// Re-read the chat from the database to pick up any title
|
|
// changes made during processing (e.g. AI-generated titles
|
|
// from maybeGenerateChatTitle). The local `chat` variable
|
|
// is a value copy and won't reflect updates made in runChat.
|
|
if freshChat, readErr := p.db.GetChatByID(cleanupCtx, chat.ID); readErr == nil {
|
|
chat = freshChat
|
|
} else {
|
|
logger.Warn(cleanupCtx, "failed to re-read chat for status event",
|
|
slog.F("chat_id", chat.ID), slog.Error(readErr))
|
|
}
|
|
chat.Status = status
|
|
p.publishChatPubsubEvent(chat, coderdpubsub.ChatEventKindStatusChange)
|
|
|
|
// Send a web push notification when the agent finishes
|
|
// processing. We only notify for terminal states (waiting
|
|
// = success, error = failure) and skip sub-agent chats to
|
|
// avoid spamming the user with notifications for internal
|
|
// delegation.
|
|
if p.webpushDispatcher != nil && p.webpushDispatcher.PublicKey() != "" && !chat.ParentChatID.Valid {
|
|
if status == database.ChatStatusWaiting || status == database.ChatStatusError {
|
|
pushMsg := codersdk.WebpushMessage{
|
|
Title: chat.Title,
|
|
Body: "Agent has finished running.",
|
|
Icon: "/favicon.ico",
|
|
}
|
|
if status == database.ChatStatusError {
|
|
pushMsg.Body = "Agent encountered an error."
|
|
if lastError != "" {
|
|
pushMsg.Body = lastError
|
|
}
|
|
}
|
|
if err := p.webpushDispatcher.Dispatch(cleanupCtx, chat.OwnerID, pushMsg); err != nil {
|
|
logger.Warn(cleanupCtx, "failed to send chat completion web push",
|
|
slog.F("chat_id", chat.ID),
|
|
slog.F("status", status),
|
|
slog.Error(err),
|
|
)
|
|
}
|
|
}
|
|
}
|
|
}()
|
|
|
|
if err := p.runChat(chatCtx, chat, logger); err != nil {
|
|
if errors.Is(err, chatloop.ErrInterrupted) || errors.Is(context.Cause(chatCtx), chatloop.ErrInterrupted) {
|
|
logger.Info(ctx, "chat interrupted")
|
|
status = database.ChatStatusWaiting
|
|
return
|
|
}
|
|
if isShutdownCancellation(ctx, chatCtx, err) {
|
|
logger.Info(ctx, "chat canceled during shutdown; returning to pending")
|
|
status = database.ChatStatusPending
|
|
lastError = ""
|
|
return
|
|
}
|
|
logger.Error(ctx, "failed to process chat", slog.Error(err))
|
|
if reason, ok := processingFailureReason(err); ok {
|
|
lastError = reason
|
|
p.publishError(chat.ID, lastError)
|
|
}
|
|
status = database.ChatStatusError
|
|
return
|
|
}
|
|
|
|
// If runChat completed successfully but the server context was
|
|
// canceled (e.g. during Close()), the chat should be returned
|
|
// to pending so another replica can pick it up. There is a
|
|
// race where the LLM stream finishes just as the server is
|
|
// shutting down — the HTTP response completes before context
|
|
// cancellation propagates, so runChat returns nil instead of
|
|
// a context.Canceled error. Without this check the chat would
|
|
// be marked "waiting" and never retried.
|
|
if ctx.Err() != nil {
|
|
logger.Info(ctx, "chat completed during shutdown; returning to pending")
|
|
status = database.ChatStatusPending
|
|
lastError = ""
|
|
return
|
|
}
|
|
}
|
|
|
|
// isShutdownCancellation reports whether err is the result of the
// server shutting down, in which case the in-flight chat should go back
// to pending so another replica can retry it.
//
// It returns true only when err is non-nil, serverCtx is already done,
// and either err wraps context.Canceled or the cancellation cause of
// chatCtx is context.Canceled.
func isShutdownCancellation(
	serverCtx context.Context,
	chatCtx context.Context,
	err error,
) bool {
	// Without an error, or with the server still running, this cannot
	// be a shutdown cancellation.
	if err == nil || serverCtx.Err() == nil {
		return false
	}
	return errors.Is(err, context.Canceled) ||
		errors.Is(context.Cause(chatCtx), context.Canceled)
}
|
|
|
|
func (p *Server) runChat(
|
|
ctx context.Context,
|
|
chat database.Chat,
|
|
logger slog.Logger,
|
|
) error {
|
|
model, modelConfig, err := p.resolveChatModel(ctx, chat)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
var callConfig codersdk.ChatModelCallConfig
|
|
if len(modelConfig.Options) > 0 {
|
|
if err := json.Unmarshal(modelConfig.Options, &callConfig); err != nil {
|
|
return xerrors.Errorf("parse model call config: %w", err)
|
|
}
|
|
}
|
|
|
|
messages, err := p.db.GetChatMessagesForPromptByChatID(ctx, chat.ID)
|
|
if err != nil {
|
|
return xerrors.Errorf("get chat messages: %w", err)
|
|
}
|
|
// Fire title generation asynchronously so it doesn't block the
|
|
// chat response. It uses a detached context so it can finish
|
|
// even after the chat processing context is canceled.
|
|
p.inflight.Add(1)
|
|
go func() {
|
|
defer p.inflight.Done()
|
|
p.maybeGenerateChatTitle(context.WithoutCancel(ctx), chat, messages, model, logger)
|
|
}()
|
|
|
|
prompt, err := chatprompt.ConvertMessages(messages)
|
|
if err != nil {
|
|
return xerrors.Errorf("build chat prompt: %w", err)
|
|
}
|
|
if chat.ParentChatID.Valid {
|
|
prompt = chatprompt.InsertSystem(prompt, defaultSubagentInstruction)
|
|
}
|
|
|
|
// Start buffering stream events for this chat so that new
|
|
// subscribers receive a snapshot of in-flight message parts.
|
|
p.streamMu.Lock()
|
|
startState := p.streamStateLocked(chat.ID)
|
|
startState.buffer = nil
|
|
startState.buffering = true
|
|
p.streamMu.Unlock()
|
|
defer func() {
|
|
p.streamMu.Lock()
|
|
if stopState, ok := p.chatStreams[chat.ID]; ok {
|
|
stopState.buffer = nil
|
|
stopState.buffering = false
|
|
p.cleanupStreamIfIdleLocked(chat.ID, stopState)
|
|
}
|
|
p.streamMu.Unlock()
|
|
}()
|
|
|
|
currentChat := chat
|
|
loadChatSnapshot := func(
|
|
loadCtx context.Context,
|
|
chatID uuid.UUID,
|
|
) (database.Chat, error) {
|
|
return p.db.GetChatByID(loadCtx, chatID)
|
|
}
|
|
var (
|
|
chatStateMu sync.Mutex
|
|
workspaceMu sync.Mutex
|
|
conn workspacesdk.AgentConn
|
|
releaseConn func()
|
|
)
|
|
closeConn := func() {
|
|
if releaseConn != nil {
|
|
releaseConn()
|
|
releaseConn = nil
|
|
}
|
|
}
|
|
defer closeConn()
|
|
|
|
getWorkspaceConn := func(ctx context.Context) (workspacesdk.AgentConn, error) {
|
|
chatStateMu.Lock()
|
|
if conn != nil {
|
|
currentConn := conn
|
|
chatStateMu.Unlock()
|
|
return currentConn, nil
|
|
}
|
|
chatSnapshot := currentChat
|
|
chatStateMu.Unlock()
|
|
|
|
if p.agentConnFn == nil {
|
|
return nil, xerrors.New("workspace agent connector is not configured")
|
|
}
|
|
|
|
if !chatSnapshot.WorkspaceID.Valid {
|
|
refreshedChat, refreshErr := refreshChatWorkspaceSnapshot(
|
|
ctx,
|
|
chatSnapshot,
|
|
loadChatSnapshot,
|
|
)
|
|
if refreshErr != nil {
|
|
return nil, refreshErr
|
|
}
|
|
if refreshedChat.WorkspaceID.Valid {
|
|
chatStateMu.Lock()
|
|
currentChat = refreshedChat
|
|
chatSnapshot = refreshedChat
|
|
chatStateMu.Unlock()
|
|
}
|
|
}
|
|
|
|
if !chatSnapshot.WorkspaceID.Valid {
|
|
return nil, xerrors.New("chat has no workspace")
|
|
}
|
|
|
|
agents, err := p.db.GetWorkspaceAgentsInLatestBuildByWorkspaceID(
|
|
ctx,
|
|
chatSnapshot.WorkspaceID.UUID,
|
|
)
|
|
if err != nil || len(agents) == 0 {
|
|
return nil, xerrors.New("chat has no workspace agent")
|
|
}
|
|
|
|
agentConn, agentRelease, err := p.agentConnFn(ctx, agents[0].ID)
|
|
if err != nil {
|
|
return nil, xerrors.Errorf("connect to workspace agent: %w", err)
|
|
}
|
|
|
|
chatStateMu.Lock()
|
|
if conn == nil {
|
|
conn = agentConn
|
|
releaseConn = agentRelease
|
|
chatStateMu.Unlock()
|
|
return agentConn, nil
|
|
}
|
|
currentConn := conn
|
|
chatStateMu.Unlock()
|
|
|
|
agentRelease()
|
|
return currentConn, nil
|
|
}
|
|
|
|
if instruction := p.resolveInstructions(ctx, chat, getWorkspaceConn); instruction != "" {
|
|
prompt = chatprompt.InsertSystem(prompt, instruction)
|
|
}
|
|
|
|
// Use the model config's context_limit as a fallback when the LLM
|
|
// provider doesn't include context_limit in its response metadata
|
|
// (which is the common case).
|
|
modelConfigContextLimit := modelConfig.ContextLimit
|
|
|
|
persistStep := func(persistCtx context.Context, step chatloop.PersistedStep) error {
|
|
// Split the step content into assistant blocks and tool
|
|
// result blocks so they can be stored as separate messages
|
|
// with the appropriate roles.
|
|
var assistantBlocks []fantasy.Content
|
|
var toolResults []fantasy.ToolResultContent
|
|
for _, block := range step.Content {
|
|
if tr, ok := fantasy.AsContentType[fantasy.ToolResultContent](block); ok {
|
|
toolResults = append(toolResults, tr)
|
|
continue
|
|
}
|
|
if trPtr, ok := fantasy.AsContentType[*fantasy.ToolResultContent](block); ok && trPtr != nil {
|
|
toolResults = append(toolResults, *trPtr)
|
|
continue
|
|
}
|
|
assistantBlocks = append(assistantBlocks, block)
|
|
}
|
|
|
|
if len(assistantBlocks) > 0 {
|
|
assistantContent, err := chatprompt.MarshalContent(assistantBlocks)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
hasUsage := step.Usage != (fantasy.Usage{})
|
|
assistantMessage, err := p.db.InsertChatMessage(persistCtx, database.InsertChatMessageParams{
|
|
ChatID: chat.ID,
|
|
ModelConfigID: uuid.NullUUID{UUID: modelConfig.ID, Valid: true},
|
|
Role: string(fantasy.MessageRoleAssistant),
|
|
Content: assistantContent,
|
|
Visibility: database.ChatMessageVisibilityBoth,
|
|
InputTokens: usageNullInt64(step.Usage.InputTokens, hasUsage),
|
|
OutputTokens: usageNullInt64(step.Usage.OutputTokens, hasUsage),
|
|
TotalTokens: usageNullInt64(step.Usage.TotalTokens, hasUsage),
|
|
ReasoningTokens: usageNullInt64(
|
|
step.Usage.ReasoningTokens,
|
|
hasUsage,
|
|
),
|
|
CacheCreationTokens: usageNullInt64(
|
|
step.Usage.CacheCreationTokens,
|
|
hasUsage,
|
|
),
|
|
CacheReadTokens: usageNullInt64(step.Usage.CacheReadTokens, hasUsage),
|
|
ContextLimit: step.ContextLimit,
|
|
Compressed: sql.NullBool{},
|
|
})
|
|
if err != nil {
|
|
return xerrors.Errorf("insert assistant message: %w", err)
|
|
}
|
|
p.publishMessage(chat.ID, assistantMessage)
|
|
}
|
|
|
|
for _, tr := range toolResults {
|
|
resultContent, err := chatprompt.MarshalToolResultContent(tr)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
toolMessage, err := p.db.InsertChatMessage(persistCtx, database.InsertChatMessageParams{
|
|
ChatID: chat.ID,
|
|
ModelConfigID: uuid.NullUUID{UUID: modelConfig.ID, Valid: true},
|
|
Role: string(fantasy.MessageRoleTool),
|
|
Content: resultContent,
|
|
Visibility: database.ChatMessageVisibilityBoth,
|
|
InputTokens: sql.NullInt64{},
|
|
OutputTokens: sql.NullInt64{},
|
|
TotalTokens: sql.NullInt64{},
|
|
ReasoningTokens: sql.NullInt64{},
|
|
CacheCreationTokens: sql.NullInt64{},
|
|
CacheReadTokens: sql.NullInt64{},
|
|
ContextLimit: sql.NullInt64{},
|
|
Compressed: sql.NullBool{},
|
|
})
|
|
if err != nil {
|
|
return xerrors.Errorf("insert tool result: %w", err)
|
|
}
|
|
|
|
p.publishMessage(chat.ID, toolMessage)
|
|
}
|
|
|
|
// Clear the stream buffer now that the step is
|
|
// persisted. Late-joining subscribers will load
|
|
// these messages from the database instead.
|
|
p.streamMu.Lock()
|
|
if state, ok := p.chatStreams[chat.ID]; ok {
|
|
state.buffer = nil
|
|
}
|
|
p.streamMu.Unlock()
|
|
|
|
return nil
|
|
}
|
|
|
|
// Apply the default MaxOutputTokens if the model config
|
|
// does not specify one.
|
|
if callConfig.MaxOutputTokens == nil {
|
|
maxOutputTokens := int64(32_000)
|
|
callConfig.MaxOutputTokens = &maxOutputTokens
|
|
}
|
|
|
|
// Generate the tool call ID up front so that the streaming
|
|
// parts and durable messages share the same identifier.
|
|
// Without this the client cannot correlate the
|
|
// "Summarizing..." tool call with the "Summarized" tool
|
|
// result.
|
|
compactionToolCallID := "chat_summarized_" + uuid.NewString()
|
|
compactionOptions := &chatloop.CompactionOptions{
|
|
ThresholdPercent: modelConfig.CompressionThreshold,
|
|
ContextLimit: modelConfig.ContextLimit,
|
|
Persist: func(
|
|
persistCtx context.Context,
|
|
result chatloop.CompactionResult,
|
|
) error {
|
|
if err := p.persistChatContextSummary(
|
|
persistCtx,
|
|
chat.ID,
|
|
modelConfig.ID,
|
|
compactionToolCallID,
|
|
result,
|
|
); err != nil {
|
|
return xerrors.Errorf("persist context summary: %w", err)
|
|
}
|
|
logger.Info(persistCtx, "chat context summarized",
|
|
slog.F("chat_id", chat.ID),
|
|
slog.F("threshold_percent", result.ThresholdPercent),
|
|
slog.F("usage_percent", result.UsagePercent),
|
|
slog.F("context_tokens", result.ContextTokens),
|
|
slog.F("context_limit", result.ContextLimit),
|
|
)
|
|
return nil
|
|
},
|
|
ToolCallID: compactionToolCallID,
|
|
ToolName: "chat_summarized",
|
|
PublishMessagePart: func(role fantasy.MessageRole, part codersdk.ChatMessagePart) {
|
|
p.publishMessagePart(chat.ID, string(role), part)
|
|
},
|
|
OnError: func(err error) {
|
|
logger.Warn(ctx, "failed to compact chat context", slog.Error(err))
|
|
},
|
|
}
|
|
|
|
// Here are all the tools we have for the chat.
|
|
tools := []fantasy.AgentTool{
|
|
chattool.ListTemplates(chattool.ListTemplatesOptions{
|
|
DB: p.db,
|
|
OwnerID: chat.OwnerID,
|
|
}),
|
|
chattool.ReadTemplate(chattool.ReadTemplateOptions{
|
|
DB: p.db,
|
|
OwnerID: chat.OwnerID,
|
|
}),
|
|
chattool.CreateWorkspace(chattool.CreateWorkspaceOptions{
|
|
DB: p.db,
|
|
OwnerID: chat.OwnerID,
|
|
ChatID: chat.ID,
|
|
CreateFn: p.createWorkspaceFn,
|
|
AgentConnFn: chattool.AgentConnFunc(p.agentConnFn),
|
|
WorkspaceMu: &workspaceMu,
|
|
}),
|
|
chattool.ReadFile(chattool.ReadFileOptions{
|
|
GetWorkspaceConn: getWorkspaceConn,
|
|
}),
|
|
chattool.WriteFile(chattool.WriteFileOptions{
|
|
GetWorkspaceConn: getWorkspaceConn,
|
|
}),
|
|
chattool.EditFiles(chattool.EditFilesOptions{
|
|
GetWorkspaceConn: getWorkspaceConn,
|
|
}),
|
|
chattool.Execute(chattool.ExecuteOptions{
|
|
GetWorkspaceConn: getWorkspaceConn,
|
|
ChatID: chat.ID.String(),
|
|
}),
|
|
chattool.ProcessOutput(chattool.ProcessToolOptions{
|
|
GetWorkspaceConn: getWorkspaceConn,
|
|
}),
|
|
chattool.ProcessList(chattool.ProcessToolOptions{
|
|
GetWorkspaceConn: getWorkspaceConn,
|
|
}),
|
|
chattool.ProcessSignal(chattool.ProcessToolOptions{
|
|
GetWorkspaceConn: getWorkspaceConn,
|
|
}),
|
|
}
|
|
// Only root chats (not delegated subagents) get subagent tools.
|
|
// Child agents must not spawn further subagents — they should
|
|
// focus on completing their delegated task.
|
|
if !chat.ParentChatID.Valid {
|
|
tools = append(tools, p.subagentTools(func() database.Chat {
|
|
return chat
|
|
})...)
|
|
}
|
|
|
|
err = chatloop.Run(ctx, chatloop.RunOptions{
|
|
Model: model,
|
|
Messages: prompt,
|
|
Tools: tools,
|
|
MaxSteps: maxChatSteps,
|
|
|
|
ModelConfig: callConfig,
|
|
ProviderOptions: chatprovider.ProviderOptionsFromChatModelConfig(model, callConfig.ProviderOptions),
|
|
|
|
ContextLimitFallback: modelConfigContextLimit,
|
|
|
|
PersistStep: persistStep,
|
|
PublishMessagePart: func(
|
|
role fantasy.MessageRole,
|
|
part codersdk.ChatMessagePart,
|
|
) {
|
|
p.publishMessagePart(chat.ID, string(role), part)
|
|
},
|
|
Compaction: compactionOptions,
|
|
|
|
OnRetry: func(attempt int, retryErr error, delay time.Duration) {
|
|
logger.Warn(ctx, "retrying LLM stream",
|
|
slog.F("attempt", attempt),
|
|
slog.F("delay", delay.String()),
|
|
slog.Error(retryErr),
|
|
)
|
|
p.publishEvent(chat.ID, codersdk.ChatStreamEvent{
|
|
Type: codersdk.ChatStreamEventTypeRetry,
|
|
ChatID: chat.ID,
|
|
Retry: &codersdk.ChatStreamRetry{
|
|
Attempt: attempt,
|
|
DelayMs: delay.Milliseconds(),
|
|
Error: retryErr.Error(),
|
|
RetryingAt: time.Now().Add(delay),
|
|
},
|
|
})
|
|
},
|
|
|
|
OnInterruptedPersistError: func(err error) {
|
|
p.logger.Warn(ctx, "failed to persist interrupted chat step", slog.Error(err))
|
|
},
|
|
})
|
|
return err
|
|
}
|
|
|
|
// persistChatContextSummary persists a chat context summary to the database.
|
|
// This is invoked via the chat loop's compaction callback.
|
|
func (p *Server) persistChatContextSummary(
|
|
ctx context.Context,
|
|
chatID uuid.UUID,
|
|
modelConfigID uuid.UUID,
|
|
toolCallID string,
|
|
result chatloop.CompactionResult,
|
|
) error {
|
|
if strings.TrimSpace(result.SystemSummary) == "" ||
|
|
strings.TrimSpace(result.SummaryReport) == "" {
|
|
return nil
|
|
}
|
|
|
|
systemContent, err := json.Marshal(result.SystemSummary)
|
|
if err != nil {
|
|
return xerrors.Errorf("encode system summary: %w", err)
|
|
}
|
|
|
|
_, err = p.db.InsertChatMessage(ctx, database.InsertChatMessageParams{
|
|
ChatID: chatID,
|
|
ModelConfigID: uuid.NullUUID{UUID: modelConfigID, Valid: true},
|
|
Role: string(fantasy.MessageRoleSystem),
|
|
Content: pqtype.NullRawMessage{
|
|
RawMessage: systemContent,
|
|
Valid: len(systemContent) > 0,
|
|
},
|
|
Visibility: database.ChatMessageVisibilityModel,
|
|
Compressed: sql.NullBool{Bool: true, Valid: true},
|
|
InputTokens: sql.NullInt64{},
|
|
OutputTokens: sql.NullInt64{},
|
|
TotalTokens: sql.NullInt64{},
|
|
ReasoningTokens: sql.NullInt64{},
|
|
CacheCreationTokens: sql.NullInt64{},
|
|
CacheReadTokens: sql.NullInt64{},
|
|
ContextLimit: sql.NullInt64{},
|
|
})
|
|
if err != nil {
|
|
return xerrors.Errorf("insert hidden summary message: %w", err)
|
|
}
|
|
|
|
args, err := json.Marshal(map[string]any{
|
|
"source": "automatic",
|
|
"threshold_percent": result.ThresholdPercent,
|
|
})
|
|
if err != nil {
|
|
return xerrors.Errorf("encode summary tool args: %w", err)
|
|
}
|
|
|
|
assistantContent, err := chatprompt.MarshalContent([]fantasy.Content{
|
|
fantasy.ToolCallContent{
|
|
ToolCallID: toolCallID,
|
|
ToolName: "chat_summarized",
|
|
Input: string(args),
|
|
},
|
|
})
|
|
if err != nil {
|
|
return xerrors.Errorf("encode summary tool call: %w", err)
|
|
}
|
|
|
|
assistantMessage, err := p.db.InsertChatMessage(ctx, database.InsertChatMessageParams{
|
|
ChatID: chatID,
|
|
ModelConfigID: uuid.NullUUID{UUID: modelConfigID, Valid: true},
|
|
Role: string(fantasy.MessageRoleAssistant),
|
|
Content: assistantContent,
|
|
Visibility: database.ChatMessageVisibilityUser,
|
|
Compressed: sql.NullBool{
|
|
Bool: true,
|
|
Valid: true,
|
|
},
|
|
InputTokens: sql.NullInt64{},
|
|
OutputTokens: sql.NullInt64{},
|
|
TotalTokens: sql.NullInt64{},
|
|
ReasoningTokens: sql.NullInt64{},
|
|
CacheCreationTokens: sql.NullInt64{},
|
|
CacheReadTokens: sql.NullInt64{},
|
|
ContextLimit: sql.NullInt64{},
|
|
})
|
|
if err != nil {
|
|
return xerrors.Errorf("insert summary tool call message: %w", err)
|
|
}
|
|
|
|
summaryResult, marshalErr := json.Marshal(map[string]any{
|
|
"summary": result.SummaryReport,
|
|
"source": "automatic",
|
|
"threshold_percent": result.ThresholdPercent,
|
|
"usage_percent": result.UsagePercent,
|
|
"context_tokens": result.ContextTokens,
|
|
"context_limit_tokens": result.ContextLimit,
|
|
})
|
|
if marshalErr != nil {
|
|
return xerrors.Errorf("encode summary result payload: %w", marshalErr)
|
|
}
|
|
toolResult, err := chatprompt.MarshalToolResult(
|
|
toolCallID,
|
|
"chat_summarized",
|
|
summaryResult,
|
|
false,
|
|
)
|
|
if err != nil {
|
|
return xerrors.Errorf("encode summary tool result: %w", err)
|
|
}
|
|
|
|
toolMessage, err := p.db.InsertChatMessage(ctx, database.InsertChatMessageParams{
|
|
ChatID: chatID,
|
|
ModelConfigID: uuid.NullUUID{UUID: modelConfigID, Valid: true},
|
|
Role: string(fantasy.MessageRoleTool),
|
|
Content: toolResult,
|
|
Visibility: database.ChatMessageVisibilityBoth,
|
|
Compressed: sql.NullBool{
|
|
Bool: true,
|
|
Valid: true,
|
|
},
|
|
InputTokens: sql.NullInt64{},
|
|
OutputTokens: sql.NullInt64{},
|
|
TotalTokens: sql.NullInt64{},
|
|
ReasoningTokens: sql.NullInt64{},
|
|
CacheCreationTokens: sql.NullInt64{},
|
|
CacheReadTokens: sql.NullInt64{},
|
|
ContextLimit: sql.NullInt64{},
|
|
})
|
|
if err != nil {
|
|
return xerrors.Errorf("insert summary tool result message: %w", err)
|
|
}
|
|
|
|
p.publishMessage(chatID, assistantMessage)
|
|
p.publishMessage(chatID, toolMessage)
|
|
return nil
|
|
}
|
|
|
|
func (p *Server) resolveChatModel(
|
|
ctx context.Context,
|
|
chat database.Chat,
|
|
) (fantasy.LanguageModel, database.ChatModelConfig, error) {
|
|
dbConfig, err := p.resolveModelConfig(ctx, chat)
|
|
if err != nil {
|
|
return nil, database.ChatModelConfig{}, xerrors.Errorf(
|
|
"resolve model config: %w", err,
|
|
)
|
|
}
|
|
|
|
providers, err := p.db.GetEnabledChatProviders(ctx)
|
|
if err != nil {
|
|
return nil, database.ChatModelConfig{}, xerrors.Errorf(
|
|
"get enabled chat providers: %w", err,
|
|
)
|
|
}
|
|
dbProviders := make(
|
|
[]chatprovider.ConfiguredProvider, 0, len(providers),
|
|
)
|
|
for _, provider := range providers {
|
|
dbProviders = append(dbProviders, chatprovider.ConfiguredProvider{
|
|
Provider: provider.Provider,
|
|
APIKey: provider.APIKey,
|
|
BaseURL: provider.BaseUrl,
|
|
})
|
|
}
|
|
keys := chatprovider.MergeProviderAPIKeys(
|
|
p.providerAPIKeys, dbProviders,
|
|
)
|
|
|
|
model, err := chatprovider.ModelFromConfig(
|
|
dbConfig.Provider, dbConfig.Model, keys,
|
|
)
|
|
if err != nil {
|
|
return nil, database.ChatModelConfig{}, xerrors.Errorf(
|
|
"create model: %w", err,
|
|
)
|
|
}
|
|
return model, dbConfig, nil
|
|
}
|
|
|
|
// resolveModelConfig looks up the chat's model config by its
|
|
// LastModelConfigID. If the referenced config no longer exists
|
|
// (e.g. it was deleted), it falls back to the default model
|
|
// config. Returns an error when no usable config is available.
|
|
func (p *Server) resolveModelConfig(
|
|
ctx context.Context,
|
|
chat database.Chat,
|
|
) (database.ChatModelConfig, error) {
|
|
if chat.LastModelConfigID != uuid.Nil {
|
|
modelConfig, err := p.db.GetChatModelConfigByID(
|
|
ctx, chat.LastModelConfigID,
|
|
)
|
|
if err == nil {
|
|
return modelConfig, nil
|
|
}
|
|
if !xerrors.Is(err, sql.ErrNoRows) {
|
|
return database.ChatModelConfig{}, xerrors.Errorf(
|
|
"get chat model config %s: %w",
|
|
chat.LastModelConfigID, err,
|
|
)
|
|
}
|
|
// Model config was deleted, fall through to default.
|
|
}
|
|
|
|
defaultConfig, err := p.db.GetDefaultChatModelConfig(ctx)
|
|
if err != nil {
|
|
if xerrors.Is(err, sql.ErrNoRows) {
|
|
return database.ChatModelConfig{}, xerrors.New(
|
|
"no default chat model config is available",
|
|
)
|
|
}
|
|
return database.ChatModelConfig{}, xerrors.Errorf(
|
|
"get default chat model config: %w", err,
|
|
)
|
|
}
|
|
return defaultConfig, nil
|
|
}
|
|
|
|
// usageNullInt64 wraps an int64 usage figure in a sql.NullInt64. When valid
// is false the zero value (SQL NULL) is returned and value is ignored;
// otherwise the result carries value and is marked valid.
//
//nolint:revive // Boolean controls SQL NULL validity.
func usageNullInt64(value int64, valid bool) sql.NullInt64 {
	var out sql.NullInt64
	if valid {
		out.Int64 = value
		out.Valid = true
	}
	return out
}
|
|
|
|
func refreshChatWorkspaceSnapshot(
|
|
ctx context.Context,
|
|
chat database.Chat,
|
|
loadChat func(context.Context, uuid.UUID) (database.Chat, error),
|
|
) (database.Chat, error) {
|
|
if chat.WorkspaceID.Valid || loadChat == nil {
|
|
return chat, nil
|
|
}
|
|
|
|
refreshedChat, err := loadChat(ctx, chat.ID)
|
|
if err != nil {
|
|
return chat, xerrors.Errorf("reload chat workspace state: %w", err)
|
|
}
|
|
|
|
return refreshedChat, nil
|
|
}
|
|
|
|
// resolveInstructions returns the combined system instructions for the
|
|
// workspace agent. It reads the home-level (~/.coder/AGENTS.md) and
|
|
// working-directory-level (<pwd>/AGENTS.md) instruction files, combines
|
|
// them with agent metadata (OS, directory), and caches the result.
|
|
func (p *Server) resolveInstructions(
|
|
ctx context.Context,
|
|
chat database.Chat,
|
|
getWorkspaceConn func(context.Context) (workspacesdk.AgentConn, error),
|
|
) string {
|
|
if !chat.WorkspaceID.Valid {
|
|
return ""
|
|
}
|
|
|
|
agents, agentsErr := p.db.GetWorkspaceAgentsInLatestBuildByWorkspaceID(
|
|
ctx,
|
|
chat.WorkspaceID.UUID,
|
|
)
|
|
if agentsErr != nil || len(agents) == 0 {
|
|
return ""
|
|
}
|
|
agentID := agents[0].ID
|
|
|
|
p.instructionCacheMu.Lock()
|
|
cached, ok := p.instructionCache[agentID]
|
|
p.instructionCacheMu.Unlock()
|
|
|
|
if ok && time.Since(cached.fetchedAt) < instructionCacheTTL {
|
|
return cached.instruction
|
|
}
|
|
|
|
// Look up the agent's OS and working directory.
|
|
agent, err := p.db.GetWorkspaceAgentByID(ctx, agentID)
|
|
if err != nil {
|
|
p.logger.Debug(ctx, "failed to look up workspace agent for instruction context",
|
|
slog.F("agent_id", agentID),
|
|
slog.Error(err),
|
|
)
|
|
}
|
|
directory := agent.ExpandedDirectory
|
|
if directory == "" {
|
|
directory = agent.Directory
|
|
}
|
|
|
|
// Read instruction files from the workspace agent.
|
|
var sections []instructionFileSection
|
|
if getWorkspaceConn != nil {
|
|
instructionCtx, cancel := context.WithTimeout(ctx, homeInstructionLookupTimeout)
|
|
defer cancel()
|
|
|
|
conn, connErr := getWorkspaceConn(instructionCtx)
|
|
if connErr != nil {
|
|
p.logger.Debug(ctx, "failed to resolve workspace connection for instruction files",
|
|
slog.F("chat_id", chat.ID),
|
|
slog.Error(connErr),
|
|
)
|
|
} else {
|
|
// ~/.coder/AGENTS.md
|
|
if content, source, truncated, err := readHomeInstructionFile(instructionCtx, conn); err != nil {
|
|
p.logger.Debug(ctx, "failed to load home instruction file",
|
|
slog.F("chat_id", chat.ID), slog.Error(err))
|
|
} else if content != "" {
|
|
sections = append(sections, instructionFileSection{content, source, truncated})
|
|
}
|
|
|
|
// <pwd>/AGENTS.md
|
|
if pwdPath := pwdInstructionFilePath(directory); pwdPath != "" {
|
|
if content, source, truncated, err := readInstructionFile(instructionCtx, conn, pwdPath); err != nil {
|
|
p.logger.Debug(ctx, "failed to load working directory instruction file",
|
|
slog.F("chat_id", chat.ID), slog.F("directory", directory), slog.Error(err))
|
|
} else if content != "" {
|
|
sections = append(sections, instructionFileSection{content, source, truncated})
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
instruction := formatSystemInstructions(agent.OperatingSystem, directory, sections)
|
|
|
|
p.instructionCacheMu.Lock()
|
|
p.instructionCache[agentID] = cachedInstruction{
|
|
instruction: instruction,
|
|
fetchedAt: time.Now(),
|
|
}
|
|
p.instructionCacheMu.Unlock()
|
|
|
|
return instruction
|
|
}
|
|
|
|
func (p *Server) recoverStaleChats(ctx context.Context) {
|
|
staleAfter := time.Now().Add(-p.inFlightChatStaleAfter)
|
|
staleChats, err := p.db.GetStaleChats(ctx, staleAfter)
|
|
if err != nil {
|
|
p.logger.Error(ctx, "failed to get stale chats", slog.Error(err))
|
|
return
|
|
}
|
|
|
|
for _, chat := range staleChats {
|
|
p.logger.Info(ctx, "recovering stale chat", slog.F("chat_id", chat.ID))
|
|
|
|
// Reset to pending so any replica can pick it up.
|
|
_, err := p.db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{
|
|
ID: chat.ID,
|
|
Status: database.ChatStatusPending,
|
|
WorkerID: uuid.NullUUID{},
|
|
StartedAt: sql.NullTime{},
|
|
HeartbeatAt: sql.NullTime{},
|
|
LastError: sql.NullString{},
|
|
})
|
|
if err != nil {
|
|
p.logger.Error(ctx, "failed to recover stale chat",
|
|
slog.F("chat_id", chat.ID), slog.Error(err))
|
|
}
|
|
}
|
|
|
|
if len(staleChats) > 0 {
|
|
p.logger.Info(ctx, "recovered stale chats", slog.F("count", len(staleChats)))
|
|
}
|
|
}
|
|
|
|
// Close stops the processor and waits for it to finish.
|
|
func (p *Server) Close() error {
|
|
p.cancel()
|
|
<-p.closed
|
|
p.inflight.Wait()
|
|
return nil
|
|
}
|