Files
coder/coderd/x/chatd/chatd.go
T
Hugo Dutka 658a04d28f pr 3
2026-06-04 18:51:22 +00:00

6585 lines
208 KiB
Go

package chatd
import (
"bytes"
"cmp"
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"math"
"net/http"
"slices"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
"charm.land/fantasy"
"charm.land/fantasy/providers/anthropic"
"github.com/dustin/go-humanize"
"github.com/google/uuid"
"github.com/prometheus/client_golang/prometheus"
"github.com/shopspring/decimal"
"github.com/sqlc-dev/pqtype"
"golang.org/x/sync/errgroup"
"golang.org/x/xerrors"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/coderd/aibridge"
"github.com/coder/coder/v2/coderd/audit"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/db2sdk"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/database/pubsub"
"github.com/coder/coder/v2/coderd/notifications"
coderdpubsub "github.com/coder/coder/v2/coderd/pubsub"
"github.com/coder/coder/v2/coderd/rbac"
"github.com/coder/coder/v2/coderd/util/ptr"
"github.com/coder/coder/v2/coderd/util/xjson"
"github.com/coder/coder/v2/coderd/webpush"
"github.com/coder/coder/v2/coderd/workspacestats"
"github.com/coder/coder/v2/coderd/x/chatd/chatadvisor"
"github.com/coder/coder/v2/coderd/x/chatd/chatcost"
"github.com/coder/coder/v2/coderd/x/chatd/chatdebug"
"github.com/coder/coder/v2/coderd/x/chatd/chaterror"
"github.com/coder/coder/v2/coderd/x/chatd/chatloop"
"github.com/coder/coder/v2/coderd/x/chatd/chatopenai"
"github.com/coder/coder/v2/coderd/x/chatd/chatprompt"
"github.com/coder/coder/v2/coderd/x/chatd/chatprovider"
"github.com/coder/coder/v2/coderd/x/chatd/chatstate"
"github.com/coder/coder/v2/coderd/x/chatd/chattool"
"github.com/coder/coder/v2/coderd/x/chatd/internal/agentselect"
"github.com/coder/coder/v2/coderd/x/chatd/mcpclient"
"github.com/coder/coder/v2/coderd/x/chatd/messagepartbuffer"
skillspkg "github.com/coder/coder/v2/coderd/x/skills"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/workspacesdk"
"github.com/coder/quartz"
)
const (
// DefaultPendingChatAcquireInterval is the default time between attempts to
// acquire pending chats.
DefaultPendingChatAcquireInterval = time.Second
// DefaultInFlightChatStaleAfter is the default age after which a running
// chat is considered stale and should be recovered.
DefaultInFlightChatStaleAfter = 5 * time.Minute
// homeInstructionLookupTimeout presumably bounds the lookup of
// home-directory instruction files; its use is outside this
// part of the file. TODO confirm.
homeInstructionLookupTimeout = 5 * time.Second
// planPathLookupTimeout presumably bounds the plan path lookup;
// its use is outside this part of the file. TODO confirm.
planPathLookupTimeout = 5 * time.Second
// workspaceDialValidationDelay is presumably the delay before a
// workspace dial is validated (see dialWithLazyValidation);
// its use is outside this part of the file. TODO confirm.
workspaceDialValidationDelay = 5 * time.Second
// Must exceed agent/x/agentmcp.connectTimeout (30s) so a
// cold-start agent's first MCP reload can settle before
// chatd gives up.
workspaceMCPDiscoveryTimeout = 35 * time.Second
// workspaceMCPPrimeMaxWait bounds the deadline used by the
// create_workspace / start_workspace post-ready cache primer
// loop. The primer checks the deadline only after each
// discoverWorkspaceMCPTools call returns, so total wall-clock
// time can exceed this by one such call (dialTimeout +
// workspaceMCPDiscoveryTimeout in the worst case). The constant
// caps when new retries can start, not when an in-flight call
// must finish. Empty results usually mean the agent's MCP
// Connect is still racing with agent startup. The agent-side
// budget is agent/x/agentmcp.connectTimeout (30s).
workspaceMCPPrimeMaxWait = 30 * time.Second
// workspaceMCPPrimeRetryInterval is the short backoff between
// re-attempts inside the primer when ListMCPTools returns an
// empty list without error.
workspaceMCPPrimeRetryInterval = 2 * time.Second
// turnStatusLabelWriteTimeout presumably bounds the write of a
// turn status label; its use is outside this part of the file.
// TODO confirm.
turnStatusLabelWriteTimeout = 5 * time.Second
// defaultDialTimeout matches the timeout used by ~8 other
// server-side AgentConn callers.
defaultDialTimeout = 30 * time.Second
// DefaultChatHeartbeatInterval is the default time between chat
// heartbeat updates while a chat is being processed.
DefaultChatHeartbeatInterval = 30 * time.Second
// maxChatSteps is the upper bound on loop steps in a chat run.
// newAdvisorRuntime also uses it as the effective advisor usage
// limit when the advisor config requests "unlimited" (0).
maxChatSteps = 1200
// RelaySentinelAfterID is the after_id sentinel used by cross-replica
// relay subscribers. It instructs the peer to skip the durable DB
// snapshot and only deliver buffered message_part events. The
// buffer itself filters committed parts out (see snapshotBufferLocked),
// so the sentinel resolves to "send me any in-progress streaming
// parts you have; I will receive durable messages through pubsub."
RelaySentinelAfterID = math.MaxInt64
// maxConcurrentRecordingUploads caps the number of recording
// stop-and-store operations that can run concurrently. Each
// slot buffers up to MaxRecordingSize + MaxThumbnailSize
// (110 MB) in memory, so this value implicitly bounds memory
// to roughly maxConcurrentRecordingUploads * 110 MB.
maxConcurrentRecordingUploads = 25
// bufferRetainGracePeriod is how long the per-chat stream
// state is kept after processing completes. The retained
// state lets late-connecting cross-replica relay subscribers
// register against the live stream before the next worker
// run starts, preventing a race between cleanupStreamIfIdle
// and subscriber registration. The buffer itself is no
// longer useful at this point: every part has been claimed
// by its durable assistant message and is filtered out of
// the subscriber snapshot.
bufferRetainGracePeriod = 5 * time.Second
// chatStreamControlFetchTimeout bounds subscriber-owned
// control-path DB reads when the caller has no deadline.
chatStreamControlFetchTimeout = 5 * time.Second
// streamJanitorInterval is how often sweepIdleStreams runs.
// Worst-case retention is bufferRetainGracePeriod +
// streamJanitorInterval.
streamJanitorInterval = 30 * time.Second
// agentDisconnectedRecoveryThreshold is how long the latest
// workspace agent must be disconnected before chatd suggests
// destructive stop/start recovery. This is intentionally longer
// than the inactive-disconnect timeout so short heartbeat gaps do
// not prompt a workspace restart.
agentDisconnectedRecoveryThreshold = 90 * time.Second
// DefaultMaxChatsPerAcquire is the maximum number of chats to
// acquire in a single processOnce call. Batching avoids
// waiting a full polling interval between acquisitions
// when many chats are pending.
DefaultMaxChatsPerAcquire int32 = 10
// defaultSubagentInstruction is the instruction text for delegated
// sub-agent chats; presumably used when no custom instruction is
// supplied (its use is outside this part of the file).
defaultSubagentInstruction = "You are running as a delegated sub-agent chat. Complete the delegated task and provide clear, concise assistant responses for the parent agent."
// defaultAdvisorMaxOutputTokens caps the nested advisor response
// when the admin config omits the field (or sets it to <= 0).
// It is intentionally generous relative to the advisor's concise
// guidance remit so short plans are not truncated mid-reasoning.
defaultAdvisorMaxOutputTokens = 16384
)
// Sentinel errors describing why a chat cannot reach its workspace
// agent. The message texts name the tools (start_workspace,
// stop_workspace) to use for recovery, so treat them as
// model-facing text rather than ordinary log strings.
var (
// errChatHasNoWorkspaceAgent is returned when the latest build of
// the chat's workspace has no agents.
errChatHasNoWorkspaceAgent = xerrors.New("workspace has no running agent: the workspace is likely stopped. Use the start_workspace tool to start it")
// errChatAgentDisconnected signals a long-lived agent disconnect.
// NOTE(review): the "90 seconds" in the message appears to mirror
// agentDisconnectedRecoveryThreshold; keep the two in sync.
errChatAgentDisconnected = xerrors.New(
"workspace agent has been disconnected for at least 90 seconds " +
"and cannot execute tools. To recover, call stop_workspace " +
"to stop the workspace, then start_workspace to start it " +
"again",
)
// errChatDialTimeout signals that dialing the workspace agent
// timed out; the message suggests a retry may succeed.
errChatDialTimeout = xerrors.New(
"connection to the workspace agent timed out. " +
"The agent may still be reachable on the next attempt.",
)
// errChatExternalAgentUnavailable is the sentinel matched by
// chatExternalAgentUnavailableError.Is.
errChatExternalAgentUnavailable = xerrors.New("external workspace agent unavailable")
)
// chatExternalAgentUnavailableError is an error value that carries an
// agent-specific message while still comparing equal to the
// errChatExternalAgentUnavailable sentinel through its Is method.
type chatExternalAgentUnavailableError struct {
	message string
}

// Error returns the agent-specific message stored at construction.
func (unavailable chatExternalAgentUnavailableError) Error() string {
	return unavailable.message
}

// Is reports whether target is the errChatExternalAgentUnavailable
// sentinel, letting Is-style checks match any value of this type.
func (chatExternalAgentUnavailableError) Is(target error) bool {
	return errChatExternalAgentUnavailable == target
}
// newChatExternalAgentUnavailableError builds the error reported when an
// external workspace agent is unavailable. The message text comes from
// chattool.ExternalAgentUnavailableMessage for the given agent.
func newChatExternalAgentUnavailableError(agent database.WorkspaceAgent) error {
	msg := chattool.ExternalAgentUnavailableMessage(agent)
	return chatExternalAgentUnavailableError{message: msg}
}
// Server handles background processing of pending chats.
type Server struct {
cancel context.CancelFunc
ctx context.Context
wg sync.WaitGroup
inflight sync.WaitGroup
inflightMu sync.Mutex
// db is the store used for all chat, workspace, and agent lookups.
db database.Store
workerID uuid.UUID
logger slog.Logger
subscribeFn SubscribeFn
// agentConnFn dials workspace agents. getWorkspaceConn returns an
// error when it is nil.
agentConnFn AgentConnFunc
// agentInactiveDisconnectTimeout is passed to
// database.WorkspaceAgent.Status when classifying an agent as
// disconnected or timed out.
agentInactiveDisconnectTimeout time.Duration
dialTimeout time.Duration
instructionLookupTimeout time.Duration
createWorkspaceFn chattool.CreateWorkspaceFn
startWorkspaceFn chattool.StartWorkspaceFn
stopWorkspaceFn chattool.StopWorkspaceFn
pubsub pubsub.Pubsub
webpushDispatcher webpush.Dispatcher
providerAPIKeys chatprovider.ProviderAPIKeys
allowBYOK bool
oidcTokenSource mcpclient.UserOIDCTokenSource
debugSvc *chatdebug.Service
debugSvcFactory func() *chatdebug.Service
debugSvcReady atomic.Bool
debugSvcInit sync.Once
// configCache serves cached deployment chat config, including the
// advisor config read by loadAdvisorConfig.
configCache *chatConfigCache
configCacheUnsubscribe func()
// chatStreams stores per-chat stream state. Using sync.Map
// gives each chat independent locking — concurrent chats
// never contend with each other.
chatStreams sync.Map // uuid.UUID -> *chatStreamState
// workspaceMCPToolsCache caches workspace MCP tool definitions
// per chat to avoid re-fetching on every turn. The cache is
// keyed by chat ID and invalidated when the agent changes.
workspaceMCPToolsCache sync.Map // uuid.UUID -> *cachedWorkspaceMCPTools
usageTracker *workspacestats.UsageTracker
// clock is the time source for deadlines, timers, and agent
// status checks; being a quartz.Clock it can presumably be
// substituted in tests.
clock quartz.Clock
metrics *chatloop.Metrics
chatWorker *chatWorker
messagePartBuffer *messagepartbuffer.Buffer
recordingSem chan struct{}
aibridgeTransportFactory *atomic.Pointer[aibridge.TransportFactory]
aiGatewayRoutingEnabled bool
// Configuration
pendingChatAcquireInterval time.Duration
maxChatsPerAcquire int32
inFlightChatStaleAfter time.Duration
chatHeartbeatInterval time.Duration
}
// chatTemplateAllowlist returns the deployment-wide template
// allowlist as a set of permitted template IDs. The callback
// signature matches what the chat tools expect. When the
// allowlist is empty or cannot be loaded the function returns
// nil, which the tools interpret as "all templates allowed".
//
// The lookup runs on a fresh background context bounded to one
// second because the callback signature carries no context.
// Load and parse failures are logged at warn level and reported
// as nil rather than surfaced to the caller.
func (p *Server) chatTemplateAllowlist() map[uuid.UUID]bool {
	//nolint:gocritic // AsChatd provides narrowly-scoped daemon
	// access for reading deployment config.
	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()
	//nolint:gocritic // AsChatd provides narrowly-scoped read
	// access to deployment config (the template allowlist).
	ctx = dbauthz.AsChatd(ctx)

	raw, err := p.db.GetChatTemplateAllowlist(ctx)
	if err != nil {
		p.logger.Warn(ctx, "failed to load chat template allowlist", slog.Error(err))
		return nil
	}
	ids, err := xjson.ParseUUIDList(raw)
	if err != nil {
		p.logger.Warn(ctx, "failed to parse chat template allowlist", slog.Error(err))
		return nil
	}
	// Honor the documented contract: an empty allowlist is reported
	// as nil, not as a non-nil empty map. A consumer that tests for
	// nil would otherwise read the empty map as "no template is
	// allowed" instead of "all templates allowed".
	if len(ids) == 0 {
		return nil
	}

	allowed := make(map[uuid.UUID]bool, len(ids))
	for _, id := range ids {
		allowed[id] = true
	}
	return allowed
}
// loadAdvisorConfig reads the advisor configuration from the config
// cache. A load failure is logged at warn level on the supplied logger
// and the zero-value config is returned instead.
func (p *Server) loadAdvisorConfig(ctx context.Context, logger slog.Logger) codersdk.AdvisorConfig {
	advisorCfg, loadErr := p.configCache.AdvisorConfig(ctx)
	if loadErr == nil {
		return advisorCfg
	}
	logger.Warn(ctx, "failed to load advisor config", slog.Error(loadErr))
	return codersdk.AdvisorConfig{}
}
// stripAdvisorGuidanceBlock removes any system message whose text content
// matches chatadvisor.ParentGuidanceBlock after whitespace normalization.
// The block is meant for the parent agent (it advertises the advisor tool)
// and would waste context tokens if forwarded to the advisor's nested run.
//
// The result is always a fresh slice; msgs itself is never modified.
// Filtering in place (msgs[:0]) would overwrite the caller's backing
// array, leaving the caller's own view of the prompt without the
// guidance block and with a duplicated tail.
func stripAdvisorGuidanceBlock(msgs []fantasy.Message) []fantasy.Message {
	if len(msgs) == 0 {
		// Nothing to filter; preserve nil vs. empty as given.
		return msgs
	}
	filtered := make([]fantasy.Message, 0, len(msgs))
	for _, msg := range msgs {
		if msg.Role == fantasy.MessageRoleSystem && isAdvisorGuidanceMessage(msg) {
			continue
		}
		filtered = append(filtered, msg)
	}
	return filtered
}
// isAdvisorGuidanceMessage reports whether msg consists of exactly one
// text part whose whitespace-trimmed text equals the whitespace-trimmed
// chatadvisor.ParentGuidanceBlock.
func isAdvisorGuidanceMessage(msg fantasy.Message) bool {
	if len(msg.Content) == 1 {
		if part, isText := msg.Content[0].(fantasy.TextPart); isText {
			want := strings.TrimSpace(chatadvisor.ParentGuidanceBlock)
			return strings.TrimSpace(part.Text) == want
		}
	}
	return false
}
// resolveAdvisorModelOverride picks the model and call config used for
// the advisor's nested run.
//
// With no override configured (ModelConfigID is uuid.Nil) the chat's own
// model and call config are returned. Otherwise the override is loaded,
// its options parsed, its route resolved, and a model built from it. Any
// failure along the way logs a warning and falls back to the chat model,
// with one exception: when AI gateway routing is in use and the override
// names an AI provider, route and model-creation failures are returned
// as errors instead.
func (p *Server) resolveAdvisorModelOverride(
	ctx context.Context,
	chat database.Chat,
	advisorCfg codersdk.AdvisorConfig,
	fallbackModel fantasy.LanguageModel,
	fallbackCallConfig codersdk.ChatModelCallConfig,
	providerKeys chatprovider.ProviderAPIKeys,
	modelOpts modelBuildOptions,
	logger slog.Logger,
) (fantasy.LanguageModel, codersdk.ChatModelCallConfig, error) {
	configID := advisorCfg.ModelConfigID
	if configID == uuid.Nil {
		return fallbackModel, fallbackCallConfig, nil
	}

	// Re-read the override instead of using the cache so disabled models
	// or providers stop routing advisor prompts immediately.
	overrideConfig, err := p.db.GetEnabledChatModelConfigByID(ctx, configID)
	switch {
	case err == nil:
		// Override found; continue below.
	case xerrors.Is(err, sql.ErrNoRows):
		logger.Warn(
			ctx,
			"advisor model config is disabled or unavailable, continuing with chat model",
			slog.F("model_config_id", configID),
		)
		return fallbackModel, fallbackCallConfig, nil
	default:
		logger.Warn(
			ctx,
			"failed to resolve advisor model config, continuing with chat model",
			slog.F("model_config_id", configID),
			slog.Error(err),
		)
		return fallbackModel, fallbackCallConfig, nil
	}

	var overrideCallConfig codersdk.ChatModelCallConfig
	if len(overrideConfig.Options) > 0 {
		if parseErr := json.Unmarshal(overrideConfig.Options, &overrideCallConfig); parseErr != nil {
			logger.Warn(
				ctx,
				"failed to parse advisor model config, continuing with chat model",
				slog.F("model_config_id", configID),
				slog.Error(parseErr),
			)
			return fallbackModel, fallbackCallConfig, nil
		}
	}

	// failHard reports whether a routing/model failure must surface as
	// an error rather than silently falling back to the chat model.
	// It is only evaluated on the failure paths.
	failHard := func() bool {
		return p.shouldUseAIGatewayRouting() && overrideConfig.AIProviderID.Valid
	}

	route, err := p.resolveModelRouteForConfig(ctx, chat.OwnerID, overrideConfig, providerKeys)
	if err != nil {
		if failHard() {
			return nil, codersdk.ChatModelCallConfig{}, xerrors.Errorf("resolve advisor override route: %w", err)
		}
		logger.Warn(
			ctx,
			"failed to resolve advisor override route, continuing with chat model",
			slog.F("model_config_id", configID),
			slog.Error(err),
		)
		return fallbackModel, fallbackCallConfig, nil
	}

	request := modelClientRequest{
		Chat:         chat,
		ModelName:    overrideConfig.Model,
		UserAgent:    chatprovider.UserAgent(),
		ExtraHeaders: chatprovider.CoderHeaders(chat),
	}
	overrideModel, err := p.newModel(ctx, request, route, modelOpts)
	if err != nil {
		if failHard() {
			return nil, codersdk.ChatModelCallConfig{}, xerrors.Errorf("create advisor override model: %w", err)
		}
		logger.Warn(
			ctx,
			"failed to create advisor override model, continuing with chat model",
			slog.F("model_config_id", configID),
			slog.Error(err),
		)
		return fallbackModel, fallbackCallConfig, nil
	}
	return overrideModel, overrideCallConfig, nil
}
// newAdvisorRuntime assembles the chatadvisor.Runtime for a turn.
//
// It first resolves the advisor model (an error from that step is
// returned as-is), then normalizes the usage and output-token limits
// from advisorCfg. A nil runtime with a nil error means the advisor is
// skipped for this turn: that happens for a negative MaxUsesPerRun or
// when chatadvisor.NewRuntime fails, both of which are logged at warn
// level.
func (p *Server) newAdvisorRuntime(
	ctx context.Context,
	chat database.Chat,
	advisorCfg codersdk.AdvisorConfig,
	fallbackModel fantasy.LanguageModel,
	fallbackCallConfig codersdk.ChatModelCallConfig,
	providerKeys chatprovider.ProviderAPIKeys,
	modelOpts modelBuildOptions,
	logger slog.Logger,
) (*chatadvisor.Runtime, error) {
	advisorModel, advisorCallConfig, err := p.resolveAdvisorModelOverride(
		ctx, chat, advisorCfg,
		fallbackModel, fallbackCallConfig,
		providerKeys, modelOpts, logger,
	)
	if err != nil {
		return nil, err
	}

	maxUsesPerRun := advisorCfg.MaxUsesPerRun
	if maxUsesPerRun < 0 {
		logger.Warn(
			ctx,
			"invalid advisor max uses per run, continuing without advisor",
			slog.F("max_uses_per_run", maxUsesPerRun),
		)
		return nil, nil //nolint:nilnil // Nil runtime with nil error means advisor is skipped for this turn.
	}
	if maxUsesPerRun == 0 {
		// Advisor config treats 0 as unlimited, but the runtime
		// requires a positive bound. maxChatSteps is the
		// effective upper bound because advisor can run at most
		// once per loop step.
		maxUsesPerRun = maxChatSteps
	}

	maxOutputTokens := advisorCfg.MaxOutputTokens
	if maxOutputTokens <= 0 {
		maxOutputTokens = defaultAdvisorMaxOutputTokens
	}
	advisorCallConfig.MaxOutputTokens = ptr.Ref(maxOutputTokens)

	runtimeCfg := chatadvisor.RuntimeConfig{
		Model:       advisorModel,
		ModelConfig: advisorCallConfig,
		ProviderOptions: chatprovider.ProviderOptionsFromChatModelConfig(
			advisorModel,
			advisorCallConfig.ProviderOptions,
		),
		MaxUsesPerRun:   maxUsesPerRun,
		MaxOutputTokens: maxOutputTokens,
	}
	rt, err := chatadvisor.NewRuntime(runtimeCfg)
	if err != nil {
		logger.Warn(
			ctx,
			"failed to create advisor runtime, continuing without advisor",
			slog.Error(err),
		)
		return nil, nil //nolint:nilnil // Nil runtime with nil error means advisor is skipped for this turn.
	}
	return rt, nil
}
// cachedWorkspaceMCPTools stores workspace MCP tools discovered
// from a workspace agent, keyed by the agent ID that provided them.
// Values of this type are stored in Server.workspaceMCPToolsCache
// under the chat ID.
type cachedWorkspaceMCPTools struct {
// agentID is the agent the tools were listed from. A cache entry is
// ignored when it does not match the chat's current agent.
agentID uuid.UUID
// tools is the tool list returned by the agent's ListMCPTools call.
tools []workspacesdk.MCPToolInfo
}
// loadCachedWorkspaceContext consults the per-chat MCP tools cache.
// It yields nil when there is no entry for chatID, when the entry was
// recorded for a different agent, or when the entry holds no tools.
// A non-nil result tells the caller it can skip the slow MCP
// discovery path. Each returned tool is handed a callback that
// deletes the chat's cache entry.
func (p *Server) loadCachedWorkspaceContext(
	chatID uuid.UUID,
	agent database.WorkspaceAgent,
	getConn func(context.Context) (workspacesdk.AgentConn, error),
) []fantasy.AgentTool {
	raw, found := p.workspaceMCPToolsCache.Load(chatID)
	if !found {
		return nil
	}
	entry, isEntry := raw.(*cachedWorkspaceMCPTools)
	if !isEntry || entry.agentID != agent.ID {
		return nil
	}
	if len(entry.tools) == 0 {
		// Keep the nil result for an empty entry so it reads as a miss.
		return nil
	}

	dropCache := func() { p.workspaceMCPToolsCache.Delete(chatID) }
	tools := make([]fantasy.AgentTool, 0, len(entry.tools))
	for _, info := range entry.tools {
		tools = append(tools, chattool.NewWorkspaceMCPTool(info, getConn, dropCache))
	}
	return tools
}
// discoverWorkspaceMCPTools resolves the chat's workspace agent and
// lists the workspace MCP tools that agent advertises. A non-empty
// result is cached per chat, keyed on the agent ID, so later calls can
// be served from the cache. Every failure mode yields nil (never an
// error) so the caller can carry on without MCP tools.
//
// The helper serves both the initial discovery path and the mid-turn
// workspace binding path that runs after create_workspace or
// start_workspace bind a workspace to a chat that started without one.
func (p *Server) discoverWorkspaceMCPTools(
	ctx context.Context,
	logger slog.Logger,
	chatID uuid.UUID,
	workspaceCtx *turnWorkspaceContext,
) []fantasy.AgentTool {
	// Fast path: consult the cache using the in-memory agent
	// (ensureWorkspaceAgent is free when already loaded). This
	// avoids a per-turn latest-build DB query on the common
	// subsequent-turn path.
	cachedAgent, cachedAgentErr := workspaceCtx.getWorkspaceAgent(ctx)
	if cachedAgentErr == nil {
		cachedTools := p.loadCachedWorkspaceContext(
			chatID, cachedAgent, workspaceCtx.getWorkspaceConn,
		)
		if cachedTools != nil {
			return cachedTools
		}
	}

	// Cache miss, agent changed, or no cache: validate that the
	// workspace still has a live agent before attempting a dial.
	if _, _, resolveErr := workspaceCtx.workspaceAgentIDForConn(ctx); resolveErr != nil {
		if xerrors.Is(resolveErr, errChatHasNoWorkspaceAgent) {
			p.workspaceMCPToolsCache.Delete(chatID)
		} else {
			logger.Warn(ctx, "failed to resolve workspace agent for MCP tools",
				slog.Error(resolveErr))
		}
		return nil
	}

	// List workspace MCP tools via the agent conn.
	agentConn, connErr := workspaceCtx.getWorkspaceConn(ctx)
	if connErr != nil {
		logger.Warn(ctx, "failed to get workspace conn for MCP tools",
			slog.Error(connErr))
		return nil
	}
	listCtx, cancelList := context.WithTimeout(ctx, workspaceMCPDiscoveryTimeout)
	defer cancelList()
	listed, listErr := agentConn.ListMCPTools(listCtx)
	if listErr != nil {
		logger.Warn(ctx, "failed to list workspace MCP tools",
			slog.Error(listErr))
		return nil
	}
	discovered := listed.Tools

	// Cache the result for subsequent turns. An empty list is not
	// cached because the agent's MCP Connect may not have finished
	// yet; caching it would hide tools permanently.
	if len(discovered) > 0 {
		if owner, ownerErr := workspaceCtx.getWorkspaceAgent(ctx); ownerErr == nil {
			p.workspaceMCPToolsCache.Store(chatID, &cachedWorkspaceMCPTools{
				agentID: owner.ID,
				tools:   discovered,
			})
		}
	}

	dropCache := func() { p.workspaceMCPToolsCache.Delete(chatID) }
	tools := make([]fantasy.AgentTool, 0, len(discovered))
	for _, info := range discovered {
		tools = append(tools, chattool.NewWorkspaceMCPTool(info, workspaceCtx.getWorkspaceConn, dropCache))
	}
	return tools
}
// primeWorkspaceMCPCache fills workspaceMCPToolsCache once the
// create_workspace or start_workspace tool has finished waiting for
// the workspace agent to become reachable. The agent is already Ready
// at that point, so one ListMCPTools call is usually enough. If the
// agent's MCP server is still racing with agent startup, the first
// call may return an empty list without error; the primer then retries
// every workspaceMCPPrimeRetryInterval until workspaceMCPPrimeMaxWait
// has elapsed, so the generation action following the tool call finds
// the workspace MCP tools in the cache and need not dial again.
//
// Nothing is returned and no failure is surfaced. If the agent
// advertises no tools within the budget the chat simply continues
// without workspace MCP tools, and the next user turn re-runs
// top-of-turn discovery from scratch.
func (p *Server) primeWorkspaceMCPCache(
	ctx context.Context,
	logger slog.Logger,
	chatID uuid.UUID,
	workspaceCtx *turnWorkspaceContext,
) {
	giveUpAt := p.clock.Now().Add(workspaceMCPPrimeMaxWait)
	for attempt := 1; ; attempt++ {
		tools := p.discoverWorkspaceMCPTools(ctx, logger, chatID, workspaceCtx)
		if len(tools) > 0 {
			logger.Debug(ctx, "primed workspace MCP cache",
				slog.F("chat_id", chatID),
				slog.F("tool_count", len(tools)),
				slog.F("attempts", attempt),
			)
			return
		}
		if ctx.Err() != nil {
			return
		}
		// The deadline only gates whether another retry may start.
		if !p.clock.Now().Before(giveUpAt) {
			logger.Debug(ctx,
				"workspace MCP cache primer gave up waiting for tools",
				slog.F("chat_id", chatID),
				slog.F("attempts", attempt),
			)
			return
		}

		retryTimer := p.clock.NewTimer(workspaceMCPPrimeRetryInterval, "chatd", "workspace-mcp-prime")
		select {
		case <-ctx.Done():
			retryTimer.Stop()
			return
		case <-retryTimer.C:
			// Backoff elapsed; try again.
		}
	}
}
// turnWorkspaceContext tracks the workspace, agent, and agent
// connection a chat is using, and caches the resolved agent and
// connection for reuse. The cached state is dropped when the chat's
// workspace changes or clearCachedWorkspaceState is called.
type turnWorkspaceContext struct {
server *Server
// chatStateMu guards currentChat (see setCurrentChat and
// currentChatSnapshot). It is a pointer, so presumably it is shared
// with other holders of the same chat state — confirm at the
// construction site.
chatStateMu *sync.Mutex
currentChat *database.Chat
// loadChatSnapshot is passed to refreshChatWorkspaceSnapshot when
// the in-memory chat has no workspace ID.
loadChatSnapshot func(context.Context, uuid.UUID) (database.Chat, error)
// mu guards the cached fields below: agent, agentLoaded, conn,
// releaseConn, and cachedWorkspaceID.
mu sync.Mutex
agent database.WorkspaceAgent
agentLoaded bool
conn workspacesdk.AgentConn
// releaseConn, when non-nil, is invoked after the cached
// connection is dropped.
releaseConn func()
// cachedWorkspaceID is the workspace the cached agent/conn were
// resolved for; a mismatch with the current chat invalidates them.
cachedWorkspaceID uuid.NullUUID
}
// close drops all cached workspace state and releases the cached
// connection, if any.
func (c *turnWorkspaceContext) close() {
	c.clearCachedWorkspaceState()
}

// clearCachedWorkspaceState resets the cached agent, connection, and
// workspace ID under c.mu. The connection's release func, if one was
// cached, is called only after the mutex has been released.
func (c *turnWorkspaceContext) clearCachedWorkspaceState() {
	// detach resets the cached fields while holding the lock and
	// hands back the release func to run once the lock is dropped.
	detach := func() func() {
		c.mu.Lock()
		defer c.mu.Unlock()
		pending := c.releaseConn
		c.agent = database.WorkspaceAgent{}
		c.agentLoaded = false
		c.conn = nil
		c.releaseConn = nil
		c.cachedWorkspaceID = uuid.NullUUID{}
		return pending
	}
	if release := detach(); release != nil {
		release()
	}
}
// setCurrentChat overwrites the shared chat value with chat while
// holding chatStateMu.
func (c *turnWorkspaceContext) setCurrentChat(chat database.Chat) {
	c.chatStateMu.Lock()
	defer c.chatStateMu.Unlock()
	*c.currentChat = chat
}
// currentChatSnapshot returns a copy of the shared chat value, taken
// while holding chatStateMu.
func (c *turnWorkspaceContext) currentChatSnapshot() database.Chat {
	c.chatStateMu.Lock()
	defer c.chatStateMu.Unlock()
	return *c.currentChat
}
// selectWorkspace records chat as the current chat and then discards
// any cached agent and connection so the next lookup resolves them
// again for the chat's workspace.
func (c *turnWorkspaceContext) selectWorkspace(chat database.Chat) {
c.setCurrentChat(chat)
c.clearCachedWorkspaceState()
}
// currentWorkspaceMatches takes a fresh chat snapshot and reports
// whether its workspace ID equals expected. The snapshot is returned
// alongside the verdict.
func (c *turnWorkspaceContext) currentWorkspaceMatches(expected uuid.NullUUID) (database.Chat, bool) {
	latest := c.currentChatSnapshot()
	same := nullUUIDEqual(latest.WorkspaceID, expected)
	return latest, same
}
// trackWorkspaceUsage forwards a workspace usage event for the chat to
// the server, using a logger tagged with the chat and owner IDs. It is
// a no-op when there is no server or the chat has no workspace.
func (c *turnWorkspaceContext) trackWorkspaceUsage(ctx context.Context, chatSnapshot database.Chat) {
	srv := c.server
	if srv == nil || !chatSnapshot.WorkspaceID.Valid {
		return
	}
	usageLogger := srv.logger.With(
		slog.F("chat_id", chatSnapshot.ID),
		slog.F("owner_id", chatSnapshot.OwnerID),
	)
	srv.trackWorkspaceUsage(ctx, chatSnapshot.ID, chatSnapshot.WorkspaceID, usageLogger)
}
// nullUUIDEqual compares two nullable UUIDs. Two invalid (null) values
// are equal regardless of their UUID field; a valid and an invalid
// value are never equal; two valid values compare by UUID.
func nullUUIDEqual(left, right uuid.NullUUID) bool {
	switch {
	case left.Valid != right.Valid:
		return false
	case !left.Valid:
		return true
	default:
		return left.UUID == right.UUID
	}
}
// persistBuildAgentBinding stores the build and agent binding for the
// chat in the database and, on success, publishes the updated chat as
// the current chat. On failure the input snapshot is returned together
// with a wrapped error and the current chat is left untouched.
func (c *turnWorkspaceContext) persistBuildAgentBinding(
	ctx context.Context,
	chatSnapshot database.Chat,
	buildID uuid.UUID,
	agentID uuid.UUID,
) (database.Chat, error) {
	params := database.UpdateChatBuildAgentBindingParams{
		ID:      chatSnapshot.ID,
		BuildID: uuid.NullUUID{UUID: buildID, Valid: true},
		AgentID: uuid.NullUUID{UUID: agentID, Valid: true},
	}
	bound, err := c.server.db.UpdateChatBuildAgentBinding(ctx, params)
	if err != nil {
		return chatSnapshot, xerrors.Errorf(
			"update chat build/agent binding: %w", err,
		)
	}
	c.setCurrentChat(bound)
	return bound, nil
}
// getWorkspaceAgent is ensureWorkspaceAgent without the chat snapshot:
// it returns only the resolved agent and any error.
func (c *turnWorkspaceContext) getWorkspaceAgent(ctx context.Context) (database.WorkspaceAgent, error) {
	_, resolved, resolveErr := c.ensureWorkspaceAgent(ctx)
	return resolved, resolveErr
}
// ensureWorkspaceAgent returns the chat snapshot and workspace agent
// for the chat, holding c.mu throughout. A cached agent is reused when
// it was resolved for the chat's current workspace; if the workspace
// has changed the cached agent is discarded and resolved again.
func (c *turnWorkspaceContext) ensureWorkspaceAgent(
	ctx context.Context,
) (database.Chat, database.WorkspaceAgent, error) {
	c.mu.Lock()
	defer c.mu.Unlock()

	if !c.agentLoaded {
		return c.loadWorkspaceAgentLocked(ctx)
	}
	snapshot := c.currentChatSnapshot()
	if nullUUIDEqual(c.cachedWorkspaceID, snapshot.WorkspaceID) {
		return snapshot, c.agent, nil
	}
	// The cached agent belongs to a previous workspace; drop it.
	c.agent = database.WorkspaceAgent{}
	c.agentLoaded = false
	return c.loadWorkspaceAgentLocked(ctx)
}
// loadWorkspaceAgentLocked resolves the workspace agent for the current
// chat and caches it on c. The caller must hold c.mu (the only caller
// in view, ensureWorkspaceAgent, does).
//
// Resolution makes at most two attempts; a second attempt happens only
// when the chat's workspace changed while the first was in progress.
// Each attempt:
//  1. refreshes the chat snapshot if it has no workspace ID, failing
//     when there is still no workspace afterwards;
//  2. uses the chat's bound agent ID if that agent row can be loaded;
//  3. otherwise selects an agent from the workspace's latest build,
//     persists the build/agent binding on the chat, and caches it.
//
// It returns the chat snapshot in use, the agent, and an error. On
// error the returned agent is the zero value.
func (c *turnWorkspaceContext) loadWorkspaceAgentLocked(
ctx context.Context,
) (database.Chat, database.WorkspaceAgent, error) {
chatSnapshot := c.currentChatSnapshot()
for attempt := 0; attempt < 2; attempt++ {
// The in-memory chat may predate a workspace binding, so try to
// pick one up from a refreshed snapshot before giving up.
if !chatSnapshot.WorkspaceID.Valid {
refreshedChat, refreshErr := refreshChatWorkspaceSnapshot(
ctx,
chatSnapshot,
c.loadChatSnapshot,
)
if refreshErr != nil {
return chatSnapshot, database.WorkspaceAgent{}, refreshErr
}
if refreshedChat.WorkspaceID.Valid {
c.setCurrentChat(refreshedChat)
chatSnapshot = refreshedChat
}
}
if !chatSnapshot.WorkspaceID.Valid {
return chatSnapshot, database.WorkspaceAgent{}, xerrors.New("no workspace is associated with this chat. Use the create_workspace tool to create one")
}
// Prefer the agent already bound to the chat.
if chatSnapshot.AgentID.Valid {
agent, err := c.server.db.GetWorkspaceAgentByID(ctx, chatSnapshot.AgentID.UUID)
if err == nil {
// Re-check the workspace after the DB call: if it changed in
// the meantime, retry against the latest snapshot.
latestChat, workspaceMatches := c.currentWorkspaceMatches(chatSnapshot.WorkspaceID)
if !workspaceMatches {
chatSnapshot = latestChat
continue
}
c.agent = agent
c.agentLoaded = true
c.cachedWorkspaceID = chatSnapshot.WorkspaceID
return chatSnapshot, c.agent, nil
}
// A missing row is not logged; any other lookup error is
// logged. Either way resolution falls through to the
// latest build below.
if !xerrors.Is(err, sql.ErrNoRows) {
c.server.logger.Warn(ctx, "agent binding lookup failed, re-resolving",
slog.F("agent_id", chatSnapshot.AgentID.UUID),
slog.Error(err),
)
}
}
// No usable binding: select an agent from the latest build.
agents, err := c.server.db.GetWorkspaceAgentsInLatestBuildByWorkspaceID(
ctx,
chatSnapshot.WorkspaceID.UUID,
)
if err != nil {
return chatSnapshot, database.WorkspaceAgent{}, xerrors.Errorf(
"get workspace agents in latest build: %w",
err,
)
}
if len(agents) == 0 {
return chatSnapshot, database.WorkspaceAgent{}, errChatHasNoWorkspaceAgent
}
selected, err := agentselect.FindChatAgent(agents)
if err != nil {
return chatSnapshot, database.WorkspaceAgent{}, xerrors.Errorf(
"find chat agent: %w",
err,
)
}
build, err := c.server.db.GetLatestWorkspaceBuildByWorkspaceID(ctx, chatSnapshot.WorkspaceID.UUID)
if err != nil {
return chatSnapshot, database.WorkspaceAgent{}, xerrors.Errorf("get latest workspace build: %w", err)
}
// Persist the new binding; this also updates the current chat.
updatedChat, err := c.persistBuildAgentBinding(
ctx,
chatSnapshot,
build.ID,
selected.ID,
)
if err != nil {
return chatSnapshot, database.WorkspaceAgent{}, err
}
chatSnapshot = updatedChat
// Same workspace-changed check as on the bound-agent path.
latestChat, workspaceMatches := c.currentWorkspaceMatches(chatSnapshot.WorkspaceID)
if !workspaceMatches {
chatSnapshot = latestChat
continue
}
c.agent = selected
c.agentLoaded = true
c.cachedWorkspaceID = chatSnapshot.WorkspaceID
return chatSnapshot, c.agent, nil
}
// Both attempts saw the workspace change underneath them.
return chatSnapshot, database.WorkspaceAgent{}, xerrors.New(
"chat workspace changed while resolving agent",
)
}
// latestWorkspaceAgentID returns the ID of the chat agent chosen by
// agentselect.FindChatAgent among the agents of the workspace's latest
// build. It returns errChatHasNoWorkspaceAgent when that build has no
// agents, and wraps database and selection errors.
func (c *turnWorkspaceContext) latestWorkspaceAgentID(
	ctx context.Context,
	workspaceID uuid.UUID,
) (uuid.UUID, error) {
	candidates, err := c.server.db.GetWorkspaceAgentsInLatestBuildByWorkspaceID(ctx, workspaceID)
	switch {
	case err != nil:
		return uuid.Nil, xerrors.Errorf(
			"get workspace agents in latest build: %w",
			err,
		)
	case len(candidates) == 0:
		return uuid.Nil, errChatHasNoWorkspaceAgent
	}

	chosen, err := agentselect.FindChatAgent(candidates)
	if err != nil {
		return uuid.Nil, xerrors.Errorf(
			"find chat agent: %w",
			err,
		)
	}
	return chosen.ID, nil
}
// workspaceAgentIDForConn determines which agent ID a connection should
// target. If the chat lacks a workspace or agent binding it defers to
// ensureWorkspaceAgent. Otherwise it looks up the selected agent in the
// workspace's latest build, clearing the cached workspace state when
// that build has no agents. The lookup is retried once if the chat's
// workspace changed while it ran; a second change yields an error.
func (c *turnWorkspaceContext) workspaceAgentIDForConn(
	ctx context.Context,
) (database.Chat, uuid.UUID, error) {
	for attempt := 0; attempt < 2; attempt++ {
		snapshot := c.currentChatSnapshot()

		bound := snapshot.WorkspaceID.Valid && snapshot.AgentID.Valid
		if !bound {
			resolvedChat, agent, err := c.ensureWorkspaceAgent(ctx)
			if err != nil {
				return resolvedChat, uuid.Nil, err
			}
			return resolvedChat, agent.ID, nil
		}

		latestID, err := c.latestWorkspaceAgentID(ctx, snapshot.WorkspaceID.UUID)
		if err != nil {
			if xerrors.Is(err, errChatHasNoWorkspaceAgent) {
				c.clearCachedWorkspaceState()
			}
			return snapshot, uuid.Nil, err
		}

		// Only trust the lookup if the workspace is still the same.
		if latestChat, same := c.currentWorkspaceMatches(snapshot.WorkspaceID); same {
			return latestChat, latestID, nil
		}
	}
	finalSnapshot := c.currentChatSnapshot()
	return finalSnapshot, uuid.Nil, xerrors.New(
		"chat workspace changed while resolving agent",
	)
}
// getWorkspaceConnLocked hands back the cached connection as long as it
// was established for the chat's current workspace. If the workspace
// has changed, the stale cached state is cleared and the connection's
// release func is returned so the caller can invoke it after unlocking.
// With no cached connection both results are nil.
func (c *turnWorkspaceContext) getWorkspaceConnLocked() (workspacesdk.AgentConn, func()) {
	if c.conn == nil {
		return nil, nil
	}
	if snapshot := c.currentChatSnapshot(); nullUUIDEqual(c.cachedWorkspaceID, snapshot.WorkspaceID) {
		return c.conn, nil
	}

	staleRelease := c.releaseConn
	c.agent, c.agentLoaded = database.WorkspaceAgent{}, false
	c.conn, c.releaseConn = nil, nil
	c.cachedWorkspaceID = uuid.NullUUID{}
	return nil, staleRelease
}
// isAgentUnreachable reports whether the agent row's status, computed
// from the row at time now with the given inactive timeout, is
// disconnected or timed out. The "connecting" state is deliberately
// not treated as unreachable because it is normal right after a fresh
// workspace build.
func isAgentUnreachable(now time.Time, agent database.WorkspaceAgent, inactiveTimeout time.Duration) bool {
	switch agent.Status(now, inactiveTimeout).Status {
	case database.WorkspaceAgentStatusDisconnected,
		database.WorkspaceAgentStatusTimeout:
		return true
	default:
		return false
	}
}
// agentDisconnectedFor returns how long the agent has been disconnected
// as of now, and true, when the agent's status is disconnected and a
// disconnect timestamp is available. A negative duration is clamped to
// zero. In every other case it returns (0, false).
func agentDisconnectedFor(now time.Time, agent database.WorkspaceAgent, inactiveTimeout time.Duration) (time.Duration, bool) {
	status := agent.Status(now, inactiveTimeout)
	isDisconnected := status.Status == database.WorkspaceAgentStatusDisconnected
	if !isDisconnected || status.DisconnectedAt == nil {
		return 0, false
	}
	elapsed := now.Sub(*status.DisconnectedAt)
	if elapsed < 0 {
		return 0, true
	}
	return elapsed, true
}
// latestWorkspaceAgentNeedsRestart reports whether the selected agent
// of the workspace's latest build has been disconnected for at least
// agentDisconnectedRecoveryThreshold. errChatHasNoWorkspaceAgent is
// passed through to the caller; every other lookup failure is logged
// at warn level and reported as (false, nil).
func (c *turnWorkspaceContext) latestWorkspaceAgentNeedsRestart(
	ctx context.Context,
	workspaceID uuid.UUID,
) (bool, error) {
	srv := c.server

	agentID, err := c.latestWorkspaceAgentID(ctx, workspaceID)
	switch {
	case err == nil:
		// Resolved; continue below.
	case xerrors.Is(err, errChatHasNoWorkspaceAgent):
		return false, err
	default:
		srv.logger.Warn(ctx, "failed to resolve latest agent for timeout classification", slog.Error(err))
		return false, nil
	}

	agent, err := srv.db.GetWorkspaceAgentByID(ctx, agentID)
	if err != nil {
		srv.logger.Warn(ctx, "failed to load latest agent for timeout classification",
			slog.F("agent_id", agentID),
			slog.Error(err),
		)
		return false, nil
	}

	elapsed, isDisconnected := agentDisconnectedFor(srv.clock.Now(), agent, srv.agentInactiveDisconnectTimeout)
	if !isDisconnected {
		return false, nil
	}
	return elapsed >= agentDisconnectedRecoveryThreshold, nil
}
// externalAgentError returns the external-agent-unavailable error when
// the agent is confirmed to be an external workspace agent. If the
// check fails or the agent is not external, fallback is returned
// unchanged.
func (c *turnWorkspaceContext) externalAgentError(
	ctx context.Context,
	agent database.WorkspaceAgent,
	fallback error,
) error {
	external, checkErr := chattool.IsExternalWorkspaceAgent(ctx, c.server.db, agent)
	if checkErr == nil && external {
		return newChatExternalAgentUnavailableError(agent)
	}
	return fallback
}
// externalAgentPreflightError decides whether dialing should be skipped
// for an external agent that is clearly offline. It returns the
// external-agent-unavailable error only when all of the following
// hold: the agent is disconnected or timed out, it is an external
// workspace agent, the chat has a workspace, and the agent is still
// the selected agent of the workspace's latest build. Otherwise, and
// whenever one of the lookups fails, it returns nil.
func (c *turnWorkspaceContext) externalAgentPreflightError(
	ctx context.Context,
	chatSnapshot database.Chat,
	agent database.WorkspaceAgent,
) error {
	srv := c.server

	// Mirror the cache-hit gate: only short-circuit on clearly offline
	// states (Disconnected/Timeout). Connecting is allowed through so
	// an external agent the user just started can still connect inside
	// the normal dial window.
	if reachable := !isAgentUnreachable(srv.clock.Now(), agent, srv.agentInactiveDisconnectTimeout); reachable {
		return nil
	}

	external, checkErr := chattool.IsExternalWorkspaceAgent(ctx, srv.db, agent)
	switch {
	case checkErr != nil, !external, !chatSnapshot.WorkspaceID.Valid:
		return nil
	}

	// Stale agent bindings rely on dialWithLazyValidation to discover
	// replacement agents, so only skip the dial when this agent is still
	// the latest selected chat agent for the workspace.
	latestID, lookupErr := c.latestWorkspaceAgentID(ctx, chatSnapshot.WorkspaceID.UUID)
	if lookupErr == nil && latestID == agent.ID {
		return newChatExternalAgentUnavailableError(agent)
	}
	return nil
}
// getWorkspaceConn returns a connection to the chat's workspace agent,
// reusing the cached connection when one exists and dialing otherwise.
//
// The body runs at most twice. A pass is abandoned and retried when a
// cached connection belongs to an agent that is now reported as
// disconnected, or when the workspace no longer matches after the
// dial. If both passes are abandoned the call fails with
// "chat workspace changed while connecting".
//
// Errors: a missing agent connector, failures from
// ensureWorkspaceAgent, the external-agent preflight error, dial
// errors (with dial timeouts reclassified, see below), and DB or
// binding errors while recording a switched agent. Every path that
// discards a freshly dialed connection invokes its release callback
// when one was provided.
func (c *turnWorkspaceContext) getWorkspaceConn(ctx context.Context) (workspacesdk.AgentConn, error) {
if c.server.agentConnFn == nil {
return nil, xerrors.New("workspace agent connector is not configured")
}
for attempt := 0; attempt < 2; attempt++ {
c.mu.Lock()
currentConn, staleRelease := c.getWorkspaceConnLocked()
// Capture agentID in the same lock section as
// currentConn to prevent a TOCTOU race with
// concurrent clearCachedWorkspaceState calls.
agentID := c.agent.ID
c.mu.Unlock()
// Status check on cache hit: re-fetch the agent
// row so we see the latest heartbeat rather than
// a potentially stale cached copy.
if currentConn != nil {
chatSnapshot := c.currentChatSnapshot()
if agentID != uuid.Nil {
freshAgent, err := c.server.db.GetWorkspaceAgentByID(ctx, agentID)
if err != nil {
c.server.logger.Warn(ctx, "failed to re-fetch agent for status check",
slog.F("agent_id", agentID),
slog.Error(err),
)
// On DB error the check re-runs on the
// next tool call.
} else if _, disconnected := agentDisconnectedFor(
c.server.clock.Now(),
freshAgent,
c.server.agentInactiveDisconnectTimeout,
); disconnected {
// The cached connection points at a disconnected agent:
// drop the cached state and retry via the dial path.
c.clearCachedWorkspaceState()
continue
}
}
c.trackWorkspaceUsage(ctx, chatSnapshot)
return currentConn, nil
}
// Cache miss: run any stale release callback handed back by
// getWorkspaceConnLocked before dialing.
if staleRelease != nil {
staleRelease()
}
chatSnapshot, agent, err := c.ensureWorkspaceAgent(ctx)
if err != nil {
return nil, err
}
if err := c.externalAgentPreflightError(ctx, chatSnapshot, agent); err != nil {
return nil, err
}
// Wrap the dial in a timeout to bound the time spent
// waiting for an unreachable agent. The timeout scopes
// only dialWithLazyValidation, not ensureWorkspaceAgent
// or the post-dial binding steps.
dialCtx, dialCancel := context.WithTimeoutCause(ctx, c.server.dialTimeout, errChatDialTimeout)
dialResult, err := dialWithLazyValidation(
dialCtx,
agent.ID,
chatSnapshot.WorkspaceID.UUID,
DialFunc(c.server.agentConnFn),
func(ctx context.Context, workspaceID uuid.UUID) (uuid.UUID, error) {
return c.latestWorkspaceAgentID(ctx, workspaceID)
},
workspaceDialValidationDelay,
)
dialCancel()
if err != nil {
if xerrors.Is(err, errChatHasNoWorkspaceAgent) {
c.clearCachedWorkspaceState()
return nil, err
}
// Surface the dial timeout sentinel only when the
// parent context is still alive. If the parent was
// canceled (e.g. ErrInterrupted), its error must
// propagate unchanged so the chatloop can detect it.
if ctx.Err() == nil && errors.Is(context.Cause(dialCtx), errChatDialTimeout) {
c.clearCachedWorkspaceState()
// Classify the timeout: an agent disconnected past the
// recovery threshold is reported as errChatAgentDisconnected,
// anything else as errChatDialTimeout. externalAgentError may
// replace either with the external-agent-unavailable error.
needsRestart, statusErr := c.latestWorkspaceAgentNeedsRestart(ctx, chatSnapshot.WorkspaceID.UUID)
if statusErr != nil {
return nil, statusErr
}
if needsRestart {
return nil, c.externalAgentError(ctx, agent, errChatAgentDisconnected)
}
return nil, c.externalAgentError(ctx, agent, errChatDialTimeout)
}
return nil, err
}
agentConn := dialResult.Conn
agentRelease := dialResult.Release
// The dial ended up on a different agent than the one it was
// asked for: persist the new build/agent binding and refresh
// the cached agent before the connection is used.
if dialResult.WasSwitched {
build, err := c.server.db.GetLatestWorkspaceBuildByWorkspaceID(ctx, chatSnapshot.WorkspaceID.UUID)
if err != nil {
if agentRelease != nil {
agentRelease()
}
return nil, xerrors.Errorf("get latest workspace build: %w", err)
}
switchedAgent, err := c.server.db.GetWorkspaceAgentByID(ctx, dialResult.AgentID)
if err != nil {
if agentRelease != nil {
agentRelease()
}
return nil, xerrors.Errorf("get workspace agent by id: %w", err)
}
updatedChat, err := c.persistBuildAgentBinding(
ctx,
chatSnapshot,
build.ID,
switchedAgent.ID,
)
if err != nil {
if agentRelease != nil {
agentRelease()
}
return nil, err
}
chatSnapshot = updatedChat
c.mu.Lock()
c.agent = switchedAgent
c.agentLoaded = true
c.cachedWorkspaceID = chatSnapshot.WorkspaceID
c.mu.Unlock()
}
// The workspace no longer matches the one that was dialed:
// discard this connection and retry.
if _, workspaceMatches := c.currentWorkspaceMatches(chatSnapshot.WorkspaceID); !workspaceMatches {
if agentRelease != nil {
agentRelease()
}
c.clearCachedWorkspaceState()
continue
}
c.mu.Lock()
if c.conn == nil {
// No connection is cached yet: cache this one and tag it with
// the chat ID and ancestor chat ID headers.
c.conn = agentConn
c.releaseConn = agentRelease
c.cachedWorkspaceID = chatSnapshot.WorkspaceID
var ancestorIDs []string
if chatSnapshot.ParentChatID.Valid {
ancestorIDs = append(ancestorIDs, chatSnapshot.ParentChatID.UUID.String())
}
ancestorJSON, marshalErr := json.Marshal(ancestorIDs)
if marshalErr != nil {
// Fall back to an empty JSON array rather than failing the call.
ancestorJSON = []byte("[]")
}
agentConn.SetExtraHeaders(http.Header{
workspacesdk.CoderChatIDHeader: {chatSnapshot.ID.String()},
workspacesdk.CoderAncestorChatIDsHeader: {string(ancestorJSON)},
})
c.mu.Unlock()
c.server.logger.Debug(ctx, "set chat headers on agent conn",
slog.F("chat_id", chatSnapshot.ID),
slog.F("ancestor_chat_ids", ancestorIDs),
slog.F("workspace_id", chatSnapshot.WorkspaceID.UUID),
slog.F("agent_id", dialResult.AgentID),
)
c.trackWorkspaceUsage(ctx, chatSnapshot)
return agentConn, nil
}
// c.conn was populated while this call was dialing (presumably by
// a concurrent caller): keep the cached connection and release
// the one dialed here.
currentConn = c.conn
c.mu.Unlock()
if agentRelease != nil {
agentRelease()
}
c.trackWorkspaceUsage(ctx, chatSnapshot)
return currentConn, nil
}
return nil, xerrors.New("chat workspace changed while connecting")
}
// AgentConnFunc provides access to workspace agent connections. It is
// called with the ID of the agent to reach and returns the connection,
// a func() callback, and an error.
//
// NOTE(review): the func() result looks like a release callback (the
// dial result used by getWorkspaceConn exposes a Release func that is
// invoked whenever a dialed connection is discarded) — confirm.
type AgentConnFunc func(ctx context.Context, agentID uuid.UUID) (workspacesdk.AgentConn, func(), error)
// SubscribeFn replaces the default local-only subscription with a
// multi-replica-aware implementation that merges pubsub notifications,
// remote relay streams, and local parts into a single event channel.
// When set, Subscribe delegates the event-merge goroutine to this
// function instead of using simple local forwarding.
//
// Parameters:
//   - ctx: subscription lifetime context (canceled on unsubscribe).
//   - params: all state needed to build the merged stream.
//
// Returns the merged, receive-only event channel. There is no separate
// cleanup callback: cleanup is driven by ctx cancellation, and the
// merge goroutine tears down all relay state in its defer when ctx is
// done.
//
// Set by enterprise for HA deployments. Nil in AGPL single-replica.
type SubscribeFn func(
ctx context.Context,
params SubscribeFnParams,
) <-chan codersdk.ChatStreamEvent
// StatusNotification informs the enterprise relay manager of chat
// status changes so it can open or close relay connections.
type StatusNotification struct {
// Status is the chat status being reported.
Status database.ChatStatus
// WorkerID accompanies the status change.
// NOTE(review): presumably the ID of the worker (replica)
// currently processing the chat — confirm against the publisher.
WorkerID uuid.UUID
}
// SubscribeFnParams carries the state that the enterprise
// SubscribeFn implementation needs from the OSS Subscribe preamble.
type SubscribeFnParams struct {
// ChatID identifies the chat being subscribed to.
ChatID uuid.UUID
// Chat is the chat row supplied by the Subscribe preamble.
Chat database.Chat
// WorkerID is a worker identifier supplied by the preamble.
// NOTE(review): presumably the worker that owns the chat at
// subscribe time — confirm against Subscribe.
WorkerID uuid.UUID
// StatusNotifications is a receive-only stream of chat status
// changes for the relay manager.
StatusNotifications <-chan StatusNotification
// RequestHeader carries HTTP headers from the preamble.
// NOTE(review): presumably the subscriber's request headers,
// forwarded to relay connections — confirm.
RequestHeader http.Header
// DB is the database store available to the implementation.
DB database.Store
// Logger is the logger available to the implementation.
Logger slog.Logger
}
// bufferedStreamPart is a buffered message_part event with its
// committed-message linkage. Parts that have not yet been claimed by
// a durable assistant message carry committedMessageID == 0 and are
// considered "in progress". When an assistant message is published,
// every still-in-progress part is claimed by that durable message
// ID, marking the part as redundant for any subscriber that will
// receive the durable message via REST or pubsub.
type bufferedStreamPart struct {
// event is the buffered stream event itself.
event codersdk.ChatStreamEvent
// committedMessageID is the durable assistant message ID that
// claimed this part, or 0 while the part belongs to the
// in-progress turn. snapshotBufferLocked drops parts with
// committedMessageID != 0 because the subscriber will receive
// the durable message through a different channel (REST snapshot,
// initial DB query in SubscribeAuthorized, or pubsub).
committedMessageID int64
}
// chatStreamState is the per-chat stream state: the buffer of
// message parts, the durable-message cache, and the set of
// subscribers, plus bookkeeping for drop warnings, retries, and
// post-completion retention.
type chatStreamState struct {
// mu is held by streamStateCollector.Collect while it reads buffer
// and subscribers.
// NOTE(review): presumably it guards every field below — confirm.
mu sync.Mutex
// buffer holds the buffered message parts (see bufferedStreamPart).
buffer []bufferedStreamPart
// buffering is false once the state has entered the
// post-completion grace window (see bufferRetainedAt).
// NOTE(review): presumably true while a turn is in flight — confirm.
buffering bool
// durableMessages caches durable message events.
// NOTE(review): eviction policy is not visible here — confirm.
durableMessages []codersdk.ChatStreamEvent
durableEvictedBefore int64 // highest message ID evicted from durable cache
// subscribers maps a UUID key to each subscriber's event channel.
// NOTE(review): the key is presumably a per-subscription ID — confirm.
subscribers map[uuid.UUID]chan codersdk.ChatStreamEvent
// bufferDropCount / bufferLastWarnAt and subscriberDropCount /
// subscriberLastWarnAt are drop counters paired with a
// last-warning timestamp.
// NOTE(review): presumably used to rate-limit drop warnings — confirm.
bufferDropCount int64
bufferLastWarnAt time.Time
subscriberDropCount int64
subscriberLastWarnAt time.Time
// currentRetry records the current retry phase for late-joining
// same-replica subscribers. Nil when the stream is not waiting
// to retry.
currentRetry *codersdk.ChatStreamRetry
// bufferRetainedAt records when processing completed and
// the per-chat stream state entered the post-completion
// grace window. Zero while buffering is active. When
// non-zero, cleanupStreamIfIdle skips GC until the grace
// period expires so cross-replica relay subscribers can
// register without racing state deletion. The buffer
// itself does not deliver content here: every part is
// claimed by a durable assistant message before
// bufferRetainedAt is set, so snapshotBufferLocked
// returns no parts during the grace window.
bufferRetainedAt time.Time
}
// streamStateCollector exposes scrape-time gauges derived from
// Server.chatStreams. Scrape cost is O(n) in the number of stream
// state entries, with a brief per-state mutex held for two len()
// reads; acceptable at typical scrape cadences.
type streamStateCollector struct {
// server is the Server whose chatStreams map is walked on each
// Collect call.
server *Server
}
// Descriptors for the gauges emitted by streamStateCollector. All of
// them are declared without variable or constant labels.
var (
// streamsActiveDesc counts chat stream state entries.
streamsActiveDesc = prometheus.NewDesc(
"coderd_chatd_streams_active",
"Current number of chat stream state entries (in-flight plus retained).",
nil, nil,
)
// streamBufferSizeMaxDesc is the largest buffer length of any
// single chat stream.
streamBufferSizeMaxDesc = prometheus.NewDesc(
"coderd_chatd_stream_buffer_size_max",
"Maximum current buffer length across all chat streams.",
nil, nil,
)
// streamBufferEventsDesc is the total of all buffer lengths.
streamBufferEventsDesc = prometheus.NewDesc(
"coderd_chatd_stream_buffer_events",
"Sum of current buffer lengths across all chat streams.",
nil, nil,
)
// streamSubscribersDesc is the total subscriber count.
streamSubscribersDesc = prometheus.NewDesc(
"coderd_chatd_stream_subscribers",
"Current number of chat stream subscribers across all chat streams.",
nil, nil,
)
)
// Describe sends the descriptor of every gauge this collector emits
// to ch, in the same order Collect reports them.
func (*streamStateCollector) Describe(ch chan<- *prometheus.Desc) {
	descriptors := [...]*prometheus.Desc{
		streamsActiveDesc,
		streamBufferSizeMaxDesc,
		streamBufferEventsDesc,
		streamSubscribersDesc,
	}
	for _, desc := range descriptors {
		ch <- desc
	}
}
// Collect walks every entry of the server's chatStreams map and emits
// four gauges: the number of stream states, the largest buffer length,
// the sum of buffer lengths, and the total subscriber count. Entries
// that are not *chatStreamState are ignored. Each state's mutex is
// held only long enough to read the two lengths.
func (c *streamStateCollector) Collect(ch chan<- prometheus.Metric) {
	var (
		streamCount   int
		bufferedTotal int
		bufferedPeak  int
		subscribers   int
	)
	c.server.chatStreams.Range(func(_, value any) bool {
		st, isState := value.(*chatStreamState)
		if isState {
			streamCount++
			st.mu.Lock()
			buffered, listening := len(st.buffer), len(st.subscribers)
			st.mu.Unlock()
			bufferedTotal += buffered
			subscribers += listening
			if buffered > bufferedPeak {
				bufferedPeak = buffered
			}
		}
		return true
	})

	emit := func(desc *prometheus.Desc, value int) {
		ch <- prometheus.MustNewConstMetric(desc, prometheus.GaugeValue, float64(value))
	}
	emit(streamsActiveDesc, streamCount)
	emit(streamBufferSizeMaxDesc, bufferedPeak)
	emit(streamBufferEventsDesc, bufferedTotal)
	emit(streamSubscribersDesc, subscribers)
}
// Sentinel errors returned by the chat mutation methods. Callers can
// match them with errors.Is.
var (
// ErrInvalidModelConfigID indicates the requested model config does not exist.
ErrInvalidModelConfigID = xerrors.New("invalid model config ID")
// ErrEditedMessageNotFound indicates the edited message does not exist
// in the target chat.
ErrEditedMessageNotFound = xerrors.New("edited message not found")
// ErrEditedMessageNotUser indicates a non-user message edit attempt.
ErrEditedMessageNotUser = xerrors.New("only user messages can be edited")
// ErrChatArchived indicates the chat is archived and cannot
// accept modifications (messages, edits, promotions, or
// tool-result submissions).
ErrChatArchived = xerrors.New("chat is archived")
)
// UsageLimitExceededError indicates the user has exceeded their chat spend
// limit.
type UsageLimitExceededError struct {
// LimitMicros is the spend limit, in millionths of a dollar.
LimitMicros int64
// ConsumedMicros is the spend recorded against the limit, in
// millionths of a dollar.
ConsumedMicros int64
// PeriodEnd is reported in the error message as the time the
// limit resets.
PeriodEnd time.Time
}
// formatMicrosAsDollars renders an amount given in millionths of a
// dollar as a "$"-prefixed dollar string with two fixed decimal
// places.
func formatMicrosAsDollars(micros int64) string {
	// Six decimal places separate micros from whole dollars.
	const microsExponent = 6
	dollars := decimal.NewFromInt(micros).Shift(-microsExponent)
	return "$" + dollars.StringFixed(2)
}
// Error implements the error interface. The message reports the
// amount spent, the limit (both formatted as dollars), and the reset
// time in RFC 3339 form.
func (e *UsageLimitExceededError) Error() string {
	spent := formatMicrosAsDollars(e.ConsumedMicros)
	limit := formatMicrosAsDollars(e.LimitMicros)
	resetsAt := e.PeriodEnd.Format(time.RFC3339)
	return fmt.Sprintf(
		"usage limit exceeded: spent %s of %s limit, resets at %s",
		spent, limit, resetsAt,
	)
}
// CreateOptions controls chat creation in the shared chat mutation path.
// Fields without a note are passed through to chatstate.CreateChat
// unchanged.
type CreateOptions struct {
// OrganizationID is required; CreateChat rejects uuid.Nil.
OrganizationID uuid.UUID
// OwnerID is required; CreateChat rejects uuid.Nil. It is also
// the user checked against the usage limit.
OwnerID uuid.UUID
// WorkspaceID selects the workspace-awareness system message:
// attached when Valid, detached otherwise.
WorkspaceID uuid.NullUUID
BuildID uuid.NullUUID
AgentID uuid.NullUUID
ParentChatID uuid.NullUUID
RootChatID uuid.NullUUID
// Title is required and must not be blank after trimming.
Title string
// ModelConfigID is stamped on every initial message and stored
// as the chat's last model config.
ModelConfigID uuid.UUID
ChatMode database.NullChatMode
PlanMode database.NullChatPlanMode
// ClientType defaults to database.ChatClientTypeApi when empty
// and must be a valid value otherwise.
ClientType database.ChatClientType
// SystemPrompt is sanitized and, when non-empty afterwards,
// added as a system message after the deployment prompt.
SystemPrompt string
// InitialUserContent is required and must be non-empty; it
// becomes the first user message.
InitialUserContent []codersdk.ChatMessagePart
// APIKeyID is recorded on the initial user message.
APIKeyID string
// MCPServerIDs is normalized from nil to an empty slice.
MCPServerIDs []uuid.UUID
// Labels is normalized from nil to an empty map and stored as JSON.
Labels database.StringMap
// DynamicTools is stored as-is; an empty value is stored as NULL.
DynamicTools json.RawMessage
}
// SendMessageBusyBehavior controls what happens when a chat is already active.
// SendMessage treats the empty value as SendMessageBusyBehaviorQueue
// and rejects any value other than the constants below.
type SendMessageBusyBehavior string
const (
// SendMessageBusyBehaviorQueue queues user messages while the chat is busy.
SendMessageBusyBehaviorQueue SendMessageBusyBehavior = "queue"
// SendMessageBusyBehaviorInterrupt queues the message and
// interrupts the active run. The queued message is
// auto-promoted after the interrupted assistant response is
// persisted, ensuring correct message ordering.
SendMessageBusyBehaviorInterrupt SendMessageBusyBehavior = "interrupt"
)
// SendMessageOptions controls user message insertion with busy-state behavior.
type SendMessageOptions struct {
// ChatID is required; SendMessage rejects uuid.Nil.
ChatID uuid.UUID
// CreatedBy is recorded as the author of the user message.
CreatedBy uuid.UUID
// Content is required and must be non-empty.
Content []codersdk.ChatMessagePart
// ModelConfigID selects the model config for the message. When
// uuid.Nil, the chat's last model config is used if it still
// exists, otherwise the default model config.
ModelConfigID uuid.UUID
// APIKeyID is recorded on the user message.
APIKeyID string
// BusyBehavior defaults to SendMessageBusyBehaviorQueue when empty.
BusyBehavior SendMessageBusyBehavior
// PlanMode, when non-nil, updates the chat's plan mode in the
// same transaction. Nil leaves it unchanged.
PlanMode *database.NullChatPlanMode
// MCPServerIDs, when non-nil, replaces the chat's MCP server IDs
// in the same transaction. Nil leaves them unchanged. The update
// is ignored (with a warning) for explore subagent chats.
MCPServerIDs *[]uuid.UUID
}
// SendMessageResult contains the outcome of user message processing.
type SendMessageResult struct {
// Queued is true when the message was queued instead of being
// inserted into history.
Queued bool
// QueuedMessage is the queued row; set only when Queued is true.
QueuedMessage *database.ChatQueuedMessage
// Message is the inserted user message; the zero value when the
// message was queued.
Message database.ChatMessage
// Chat is the chat row re-read inside the transaction after the
// transition.
Chat database.Chat
}
// EditMessageOptions controls user message edits via soft-delete and re-insert.
type EditMessageOptions struct {
// ChatID is required; EditMessage rejects uuid.Nil.
ChatID uuid.UUID
// CreatedBy is passed to the edit transition as the author of
// the replacement message.
CreatedBy uuid.UUID
// EditedMessageID is required and must be positive. The message
// must belong to ChatID.
EditedMessageID int64
// Content is required and must be non-empty.
Content []codersdk.ChatMessagePart
// NOTE(review): APIKeyID is not read by EditMessage in the code
// visible here — confirm whether it is consumed elsewhere.
APIKeyID string
// ModelConfigID, when non-zero, overrides the model used for
// the replacement user message. When set to uuid.Nil the
// original message's model is preserved.
ModelConfigID uuid.UUID
}
// EditMessageResult contains the replacement user message and chat status.
type EditMessageResult struct {
// Message is the replacement message produced by the edit
// transition.
Message database.ChatMessage
// Chat is the chat row re-read inside the transaction after the
// edit.
Chat database.Chat
}
// PromoteQueuedOptions controls queued-message promotion.
type PromoteQueuedOptions struct {
// ChatID is required; PromoteQueued rejects uuid.Nil.
ChatID uuid.UUID
// CreatedBy identifies the user requesting the promotion.
// NOTE(review): presumably recorded as the author of the promoted
// message — confirm against PromoteQueued.
CreatedBy uuid.UUID
// QueuedMessageID is the ID of the queued message to promote.
QueuedMessageID int64
}
// PromoteQueuedResult contains post-promotion message metadata.
type PromoteQueuedResult struct {
// PromotedMessage is the inserted user message. For a chat that
// was running at promote time, the insertion is deferred to the
// worker's auto-promote and PromotedMessage is the zero value,
// so callers must not assume its ID is set.
PromotedMessage database.ChatMessage
}
// CreateChat creates a chat with its initial history through
// chatstate.CreateChat. The new chat starts in `running` status per
// the RFC. Ownership hints wake chat workers.
//
// opts is validated first: OrganizationID, OwnerID, a non-blank Title
// and non-empty InitialUserContent are required, and ClientType must
// be valid (it defaults to database.ChatClientTypeApi). The owner's
// usage limit is checked before anything is written.
//
// The initial history is built in this order: the deployment system
// prompt (if any), the sanitized user system prompt (if any), the
// workspace-awareness system message, then the user message.
//
// On success the created chat is returned and a "created" sidebar
// watch event is published. On failure the zero Chat is returned with
// a validation error, the usage-limit error from checkUsageLimit, a
// marshal error, or the error from chatstate.CreateChat.
func (p *Server) CreateChat(ctx context.Context, opts CreateOptions) (database.Chat, error) {
if opts.OrganizationID == uuid.Nil {
return database.Chat{}, xerrors.New("organization_id is required")
}
if opts.OwnerID == uuid.Nil {
return database.Chat{}, xerrors.New("owner_id is required")
}
if strings.TrimSpace(opts.Title) == "" {
return database.Chat{}, xerrors.New("title is required")
}
if len(opts.InitialUserContent) == 0 {
return database.Chat{}, xerrors.New("initial user content is required")
}
// Ensure MCPServerIDs is non-nil so pq.Array produces '{}'
// instead of SQL NULL, which violates the NOT NULL column
// constraint.
if opts.MCPServerIDs == nil {
opts.MCPServerIDs = []uuid.UUID{}
}
// Likewise normalize nil labels to an empty map before they are
// marshaled to JSON below.
if opts.Labels == nil {
opts.Labels = database.StringMap{}
}
// An unset client type defaults to the API client type.
opts.ClientType = cmp.Or(opts.ClientType, database.ChatClientTypeApi)
if !opts.ClientType.Valid() {
return database.Chat{}, xerrors.Errorf("invalid client_type: %q", opts.ClientType)
}
// Resolve the deployment prompt before opening the transaction so
// chat creation does not hold one DB connection while waiting for
// another pool checkout.
deploymentPrompt := p.resolveDeploymentSystemPrompt(ctx)
// Usage limits gate the create before we touch the state machine.
if limitErr := p.checkUsageLimit(ctx, p.db, opts.OwnerID, uuid.NullUUID{UUID: opts.OrganizationID, Valid: true}); limitErr != nil {
return database.Chat{}, limitErr
}
labelsJSON, err := json.Marshal(opts.Labels)
if err != nil {
return database.Chat{}, xerrors.Errorf("marshal labels: %w", err)
}
userPrompt := SanitizePromptText(opts.SystemPrompt)
// Tell the model whether a workspace is attached to this chat.
var workspaceAwareness string
if opts.WorkspaceID.Valid {
workspaceAwareness = workspaceAttachedAwareness
} else {
workspaceAwareness = workspaceDetachedAwareness
}
workspaceAwarenessContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{
codersdk.ChatMessageText(workspaceAwareness),
})
if err != nil {
return database.Chat{}, xerrors.Errorf("marshal workspace awareness: %w", err)
}
userContent, err := chatprompt.MarshalParts(opts.InitialUserContent)
if err != nil {
return database.Chat{}, xerrors.Errorf("marshal initial user content: %w", err)
}
// Capacity 4 covers the maximum history: deployment prompt, user
// system prompt, workspace awareness, and the user message.
initialMessages := make([]chatstate.Message, 0, 4)
if deploymentPrompt != "" {
deploymentContent, marshalErr := chatprompt.MarshalParts([]codersdk.ChatMessagePart{
codersdk.ChatMessageText(deploymentPrompt),
})
if marshalErr != nil {
return database.Chat{}, xerrors.Errorf("marshal deployment system prompt: %w", marshalErr)
}
initialMessages = append(initialMessages, systemMessage(deploymentContent, opts.ModelConfigID))
}
if userPrompt != "" {
userPromptContent, marshalErr := chatprompt.MarshalParts([]codersdk.ChatMessagePart{
codersdk.ChatMessageText(userPrompt),
})
if marshalErr != nil {
return database.Chat{}, xerrors.Errorf("marshal user system prompt: %w", marshalErr)
}
initialMessages = append(initialMessages, systemMessage(userPromptContent, opts.ModelConfigID))
}
initialMessages = append(initialMessages, systemMessage(workspaceAwarenessContent, opts.ModelConfigID))
initialMessages = append(initialMessages, userMessageWithAPIKeyID(userContent, opts.ModelConfigID, opts.OwnerID, opts.APIKeyID))
result, err := chatstate.CreateChat(ctx, p.db, p.pubsub, chatstate.CreateChatInput{
OrganizationID: opts.OrganizationID,
OwnerID: opts.OwnerID,
WorkspaceID: opts.WorkspaceID,
BuildID: opts.BuildID,
AgentID: opts.AgentID,
ParentChatID: opts.ParentChatID,
RootChatID: opts.RootChatID,
LastModelConfigID: opts.ModelConfigID,
Title: opts.Title,
Mode: opts.ChatMode,
PlanMode: opts.PlanMode,
MCPServerIDs: opts.MCPServerIDs,
Labels: pqtype.NullRawMessage{
RawMessage: labelsJSON,
Valid: true,
},
// Empty dynamic tools are stored as NULL rather than as an
// empty JSON value.
DynamicTools: pqtype.NullRawMessage{
RawMessage: opts.DynamicTools,
Valid: len(opts.DynamicTools) > 0,
},
ClientType: opts.ClientType,
InitialMessages: initialMessages,
})
if err != nil {
return database.Chat{}, err
}
chat := result.Chat
// A chat with neither a root nor a parent is its own root; reflect
// that in the returned (in-memory) chat.
if !chat.RootChatID.Valid && !chat.ParentChatID.Valid {
chat.RootChatID = uuid.NullUUID{UUID: chat.ID, Valid: true}
}
// Publish the sidebar watch event explicitly after chatstate has
// committed and emitted its own state-machine notifications. The
// watch endpoint is intentionally outside the RFC refactor scope.
p.publishChatPubsubEvent(chat, codersdk.ChatWatchEventKindCreated, nil)
return chat, nil
}
// SendMessage admits a user message through the chatstate.SendMessage
// transition. Pre-transition admission policy (usage limit, plan-mode
// metadata update, MCP server ID update, model-config resolution, queue
// cap) runs inside the same chatstate transaction via tx.Store() so
// everything commits or rolls back together.
//
// opts.ChatID and a non-empty opts.Content are required, and
// opts.BusyBehavior must be empty (treated as queue), "queue" or
// "interrupt".
//
// The result reports either the queued message (Queued == true) or
// the inserted user message, together with the chat row re-read after
// the transition. On failure the zero result is returned; notable
// errors are ErrChatArchived, the usage-limit error from
// checkUsageLimit, and ErrInvalidModelConfigID (wrapped).
func (p *Server) SendMessage(
ctx context.Context,
opts SendMessageOptions,
) (SendMessageResult, error) {
if opts.ChatID == uuid.Nil {
return SendMessageResult{}, xerrors.New("chat_id is required")
}
if len(opts.Content) == 0 {
return SendMessageResult{}, xerrors.New("content is required")
}
// An unset busy behavior defaults to queueing.
busyBehavior := opts.BusyBehavior
if busyBehavior == "" {
busyBehavior = SendMessageBusyBehaviorQueue
}
switch busyBehavior {
case SendMessageBusyBehaviorQueue, SendMessageBusyBehaviorInterrupt:
default:
return SendMessageResult{}, xerrors.Errorf("invalid busy behavior %q", opts.BusyBehavior)
}
content, err := chatprompt.MarshalParts(opts.Content)
if err != nil {
return SendMessageResult{}, xerrors.Errorf("marshal message content: %w", err)
}
requestedPlanMode := opts.PlanMode
requestedMCPServerIDs := opts.MCPServerIDs
var result SendMessageResult
machine := p.newChatMachine(opts.ChatID)
updateErr := machine.Update(ctx, func(tx *chatstate.Tx) error {
store := tx.Store()
lockedChat, err := store.GetChatByID(ctx, opts.ChatID)
if err != nil {
return xerrors.Errorf("load chat: %w", err)
}
if lockedChat.Archived {
return ErrChatArchived
}
// Enforce usage limits before any state-machine work.
if limitErr := p.checkUsageLimit(ctx, store, lockedChat.OwnerID, uuid.NullUUID{UUID: lockedChat.OrganizationID, Valid: true}); limitErr != nil {
return limitErr
}
// Apply the plan-mode update first so the model-config
// resolution below sees the updated chat row.
if requestedPlanMode != nil {
lockedChat, err = store.UpdateChatPlanModeByID(ctx, database.UpdateChatPlanModeByIDParams{
PlanMode: *requestedPlanMode,
ID: opts.ChatID,
})
if err != nil {
return xerrors.Errorf("update chat plan mode: %w", err)
}
}
modelConfigID, err := resolveSendMessageModelConfigID(
ctx,
store,
lockedChat,
opts.ModelConfigID,
)
if err != nil {
return err
}
// Update MCP server IDs on the chat when explicitly provided.
// Explore child chats keep the spawn-time snapshot immutable.
if requestedMCPServerIDs != nil {
if isExploreSubagentMode(lockedChat.Mode) {
p.logger.Warn(ctx,
"ignoring explore subagent mcp server ids update, snapshot is immutable after spawn",
slog.F("chat_id", opts.ChatID),
)
} else {
lockedChat, err = store.UpdateChatMCPServerIDs(ctx, database.UpdateChatMCPServerIDsParams{
ID: opts.ChatID,
MCPServerIDs: *requestedMCPServerIDs,
})
if err != nil {
return xerrors.Errorf("update chat mcp server ids: %w", err)
}
}
}
// Queue capacity is enforced inside tx.SendMessage; this
// wrapper only propagates the typed error.
sendResult, err := tx.SendMessage(chatstate.SendMessageInput{
Message: userMessageWithAPIKeyID(content, modelConfigID, opts.CreatedBy, opts.APIKeyID),
BusyBehavior: busyBehaviorToChatState(busyBehavior),
})
if err != nil {
return err
}
if sendResult.QueuedMessage != nil {
result.Queued = true
result.QueuedMessage = sendResult.QueuedMessage
} else if len(sendResult.InsertedMessages) > 0 {
// The state machine prepends synthetic tool-result
// cancellation messages; the user message is always
// last in the inserted slice.
result.Message = sendResult.InsertedMessages[len(sendResult.InsertedMessages)-1]
}
// Capture the post-transition chat inside the same
// transaction so the returned chat and the watch event
// reflect the snapshot bump and status change produced by
// the transition itself.
refreshed, err := store.GetChatByID(ctx, opts.ChatID)
if err != nil {
return xerrors.Errorf("reload chat after send: %w", err)
}
result.Chat = refreshed
return nil
})
if updateErr != nil {
// Return the zero result rather than the partially populated one.
return SendMessageResult{}, updateErr
}
// Sidebar watch event keeps the chat list in sync. Stream side
// effects are handled by chat:update consumers.
p.publishChatPubsubEvent(result.Chat, codersdk.ChatWatchEventKindStatusChange, nil)
return result, nil
}
// checkUsageLimit returns a *UsageLimitExceededError when the owner's
// current spend has reached or passed their spend limit, and nil
// otherwise.
//
// The check fails open: if the limit status cannot be resolved the
// failure is logged as a warning and nil is returned, so a
// limit-resolution problem never blocks chat. A nil status or an
// unset spend limit also yields nil.
//
// NOTE(review): the evaluation time is the wall clock (time.Now), not
// the server clock used elsewhere in this file — confirm that is
// intended.
func (p *Server) checkUsageLimit(ctx context.Context, store database.Store, ownerID uuid.UUID, organizationID uuid.NullUUID) error {
	status, err := ResolveUsageLimitStatus(ctx, store, ownerID, organizationID, time.Now())
	switch {
	case err != nil:
		// Fail open: never block chat due to a limit-resolution failure.
		p.logger.Warn(ctx, "usage limit check failed, allowing message",
			slog.F("owner_id", ownerID),
			slog.Error(err),
		)
		return nil
	case status == nil, status.SpendLimitMicros == nil:
		return nil
	}

	// Spend equal to the limit already blocks, so a user who has hit
	// the limit exactly cannot start new conversations.
	limit := *status.SpendLimitMicros
	if status.CurrentSpend < limit {
		return nil
	}
	return &UsageLimitExceededError{
		LimitMicros:    limit,
		ConsumedMicros: status.CurrentSpend,
		PeriodEnd:      status.PeriodEnd,
	}
}
// chatdModelConfigLookupContext returns ctx wrapped by dbauthz.AsChatd.
// The model-config lookups in this file use it so that chat message
// admission can read deployment-level model configuration.
func chatdModelConfigLookupContext(ctx context.Context) context.Context {
//nolint:gocritic // Chat message admission needs daemon-scoped
// deployment-config reads for model config validation.
return dbauthz.AsChatd(ctx)
}
// resolveSendMessageModelConfigID picks the model config ID for a new
// user message.
//
// When requested is uuid.Nil the choice is delegated to
// resolveFallbackModelConfigID, seeded with the chat's last model
// config. Otherwise requested must exist: a missing row yields an
// error wrapping ErrInvalidModelConfigID, and any other lookup failure
// is wrapped and returned.
func resolveSendMessageModelConfigID(
	ctx context.Context,
	store database.Store,
	chat database.Chat,
	requested uuid.UUID,
) (uuid.UUID, error) {
	if requested == uuid.Nil {
		return resolveFallbackModelConfigID(ctx, store, chat.LastModelConfigID)
	}

	_, lookupErr := store.GetChatModelConfigByID(chatdModelConfigLookupContext(ctx), requested)
	switch {
	case lookupErr == nil:
		return requested, nil
	case errors.Is(lookupErr, sql.ErrNoRows):
		return uuid.Nil, xerrors.Errorf("%w: %s", ErrInvalidModelConfigID, requested)
	default:
		return uuid.Nil, xerrors.Errorf("get requested model config %s: %w", requested, lookupErr)
	}
}
// resolveFallbackModelConfigID returns modelConfigID when it is set
// and still exists, and otherwise the ID of the default chat model
// config.
//
// A missing modelConfigID row silently falls through to the default;
// any other lookup error is wrapped and returned. If no default model
// config exists either, an error saying so is returned.
func resolveFallbackModelConfigID(
	ctx context.Context,
	store database.Store,
	modelConfigID uuid.UUID,
) (uuid.UUID, error) {
	lookupCtx := chatdModelConfigLookupContext(ctx)

	if modelConfigID != uuid.Nil {
		_, err := store.GetChatModelConfigByID(lookupCtx, modelConfigID)
		switch {
		case err == nil:
			return modelConfigID, nil
		case !errors.Is(err, sql.ErrNoRows):
			return uuid.Nil, xerrors.Errorf(
				"get chat model config %s: %w",
				modelConfigID,
				err,
			)
		}
		// The row no longer exists: fall back to the default below.
	}

	fallback, err := store.GetDefaultChatModelConfig(lookupCtx)
	switch {
	case err == nil:
		return fallback.ID, nil
	case errors.Is(err, sql.ErrNoRows):
		return uuid.Nil, xerrors.New("no default chat model config is available")
	default:
		return uuid.Nil, xerrors.Errorf("get default chat model config: %w", err)
	}
}
// EditMessage replaces an earlier user message and discards the
// active-history suffix through chatstate.EditMessage. Model-config
// override validation and usage-limit admission run in the same
// transaction as the state-machine transition.
//
// opts.ChatID, a positive opts.EditedMessageID and a non-empty
// opts.Content are required.
//
// On success it returns the replacement message and the chat row
// re-read after the edit, publishes a status-change sidebar watch
// event, and schedules a cleanup of debug rows after the edit point.
// On failure the zero result is returned; notable errors are
// ErrChatArchived, ErrEditedMessageNotFound, ErrEditedMessageNotUser,
// ErrInvalidModelConfigID (wrapped), and the usage-limit error from
// checkUsageLimit.
func (p *Server) EditMessage(
ctx context.Context,
opts EditMessageOptions,
) (EditMessageResult, error) {
if opts.ChatID == uuid.Nil {
return EditMessageResult{}, xerrors.New("chat_id is required")
}
if opts.EditedMessageID <= 0 {
return EditMessageResult{}, xerrors.New("edited_message_id is required")
}
if len(opts.Content) == 0 {
return EditMessageResult{}, xerrors.New("content is required")
}
content, err := chatprompt.MarshalParts(opts.Content)
if err != nil {
return EditMessageResult{}, xerrors.Errorf("marshal message content: %w", err)
}
// editedMsg and editedCutoffT are captured inside the transaction
// and consumed by the post-commit debug cleanup below.
var (
result EditMessageResult
editedMsg database.ChatMessage
editedCutoffT time.Time
)
machine := p.newChatMachine(opts.ChatID)
updateErr := machine.Update(ctx, func(tx *chatstate.Tx) error {
store := tx.Store()
lockedChat, err := store.GetChatByID(ctx, opts.ChatID)
if err != nil {
return xerrors.Errorf("load chat: %w", err)
}
if lockedChat.Archived {
return ErrChatArchived
}
if limitErr := p.checkUsageLimit(ctx, store, lockedChat.OwnerID, uuid.NullUUID{UUID: lockedChat.OrganizationID, Valid: true}); limitErr != nil {
return limitErr
}
// Capture the target message for the post-commit debug
// cleanup hook below. The transition itself revalidates
// chat ownership and user-message constraints.
target, err := store.GetChatMessageByID(ctx, opts.EditedMessageID)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return ErrEditedMessageNotFound
}
return xerrors.Errorf("get edited message: %w", err)
}
// A message that belongs to a different chat is reported the
// same way as a missing one.
if target.ChatID != opts.ChatID {
return ErrEditedMessageNotFound
}
editedMsg = target
// Validate the optional model-config override up front so
// the user sees ErrInvalidModelConfigID instead of a
// foreign-key error from the message-insert path.
var modelOverride uuid.NullUUID
if opts.ModelConfigID != uuid.Nil {
if _, err := store.GetChatModelConfigByID(
chatdModelConfigLookupContext(ctx),
opts.ModelConfigID,
); err != nil {
if errors.Is(err, sql.ErrNoRows) {
return xerrors.Errorf(
"%w: %s",
ErrInvalidModelConfigID,
opts.ModelConfigID,
)
}
return xerrors.Errorf(
"get requested model config %s: %w",
opts.ModelConfigID,
err,
)
}
modelOverride = uuid.NullUUID{UUID: opts.ModelConfigID, Valid: true}
}
editResult, err := tx.EditMessage(chatstate.EditMessageInput{
MessageID: opts.EditedMessageID,
CreatedBy: opts.CreatedBy,
Content: content,
ModelConfigIDOverride: modelOverride,
})
if err != nil {
// Translate the chatstate sentinel into this package's
// exported equivalent.
if errors.Is(err, chatstate.ErrEditedMessageNotUser) {
return ErrEditedMessageNotUser
}
return err
}
result.Message = editResult.ReplacementMessage
// Capture the post-edit chat inside the same transaction so
// the returned chat and the debug-cleanup cutoff use the
// snapshot bump and updated_at stamped by the transition.
refreshed, err := store.GetChatByID(ctx, opts.ChatID)
if err != nil {
return xerrors.Errorf("reload chat after edit: %w", err)
}
result.Chat = refreshed
editedCutoffT = refreshed.UpdatedAt
return nil
})
if updateErr != nil {
// Return the zero result rather than the partially populated one.
return EditMessageResult{}, updateErr
}
// Sidebar watch event keeps the chat list responsive. Stream
// side effects are handled by chat:update consumers.
p.publishChatPubsubEvent(result.Chat, codersdk.ChatWatchEventKindStatusChange, nil)
// Editing can race with an interrupted worker still flushing its
// final debug writes. Run a short bounded retry loop so we converge
// quickly without relying on the much longer stale-finalization
// sweep. Source editCutoff from the DB-stamped updated_at returned
// by the post-edit chat row so the filter uses the same clock that
// stamps replacement-turn debug rows; subtract
// debugCleanupClockSkew so replica clock drift cannot let the retry
// delete a replacement turn's debug rows.
editCutoff := editedCutoffT.Add(-debugCleanupClockSkew)
p.scheduleDebugCleanup(
ctx,
"failed to delete chat debug rows after edit",
[]slog.Field{
slog.F("chat_id", opts.ChatID),
slog.F("edited_message_id", editedMsg.ID),
},
func(cleanupCtx context.Context, debugSvc *chatdebug.Service) error {
// The ID passed is editedMsg.ID-1.
// NOTE(review): presumably so rows tied to the edited message
// itself are also removed — confirm against DeleteAfterMessageID.
_, err := debugSvc.DeleteAfterMessageID(cleanupCtx, opts.ChatID, editedMsg.ID-1, editCutoff)
return err
},
)
return result, nil
}
// ErrArchiveRequiresRootChat is returned by [Server.ArchiveChat] and
// [Server.UnarchiveChat] when the supplied chat is a child chat, that
// is, when its ParentChatID is set. Archive state changes must always
// target the root chat so the whole family flips together.
var ErrArchiveRequiresRootChat = xerrors.New(
"chat archive state can only be changed on the root chat",
)
// ArchiveChat archives a root chat together with every child chat in
// its family, going through the chatstate state machine. The change
// is atomic across the family: all members end up archived or none
// do. Archiving is only permitted from the idle / error execution
// states (W, E0, E1); an active member produces a state conflict that
// the HTTP handler maps to a client error.
//
// Child chats cannot be archived on their own. Passing one returns
// [ErrArchiveRequiresRootChat], which keeps callers from silently
// breaking the parent-implies-child archive invariant. A chat with a
// nil ID is rejected as well.
func (p *Server) ArchiveChat(ctx context.Context, chat database.Chat) error {
	switch {
	case chat.ID == uuid.Nil:
		return xerrors.New("chat_id is required")
	case chat.ParentChatID.Valid:
		return ErrArchiveRequiresRootChat
	}
	const archived = true
	return p.setChatFamilyArchived(ctx, chat, archived, codersdk.ChatWatchEventKindDeleted)
}
// UnarchiveChat unarchives a root chat together with every child chat
// in its family, going through the chatstate state machine. As with
// ArchiveChat the cascade is atomic, and attempts to unarchive a
// child chat are rejected with [ErrArchiveRequiresRootChat]. A chat
// with a nil ID is rejected as well.
func (p *Server) UnarchiveChat(ctx context.Context, chat database.Chat) error {
	switch {
	case chat.ID == uuid.Nil:
		return xerrors.New("chat_id is required")
	case chat.ParentChatID.Valid:
		return ErrArchiveRequiresRootChat
	}
	const archived = false
	return p.setChatFamilyArchived(ctx, chat, archived, codersdk.ChatWatchEventKindCreated)
}
// setChatFamilyArchived sets the archived flag on every chat in
// chat's family through chatstate. The family rows returned by that
// call drive the post-commit work: debug cleanup (only when
// archiving) and one sidebar watch event of kind watchKind per
// family member. It must only be called for root chats; a nil chat
// ID or a child chat is rejected.
//
//nolint:revive // Existing API takes the target archive state as a boolean.
func (p *Server) setChatFamilyArchived(
	ctx context.Context,
	chat database.Chat,
	archived bool,
	watchKind codersdk.ChatWatchEventKind,
) error {
	if chat.ID == uuid.Nil {
		return xerrors.New("chat_id is required")
	}
	if chat.ParentChatID.Valid {
		return ErrArchiveRequiresRootChat
	}

	input := chatstate.SetFamilyArchivedInput{
		RootID:   chat.ID,
		Archived: archived,
	}
	family, err := chatstate.SetFamilyArchived(ctx, p.db, p.pubsub, input)
	if err != nil {
		return err
	}

	// Debug rows are only cleaned up when archiving, not when
	// unarchiving.
	if archived {
		p.scheduleArchiveDebugCleanup(ctx, family)
	}
	p.publishChatPubsubEvents(family, watchKind)
	return nil
}
// DeleteQueued deletes the queued user message identified by
// queuedMessageID from the chat through the chatstate state machine.
// Stream side effects are handled by chat:update consumers rather
// than here. A zero chatID is rejected with a "chat_id is required"
// error; any error from the transition is returned unchanged.
func (p *Server) DeleteQueued(
	ctx context.Context,
	chatID uuid.UUID,
	queuedMessageID int64,
) error {
	if chatID == uuid.Nil {
		return xerrors.New("chat_id is required")
	}

	input := chatstate.DeleteQueuedMessageInput{
		QueuedMessageID: queuedMessageID,
	}
	machine := p.newChatMachine(chatID)
	return machine.Update(ctx, func(tx *chatstate.Tx) error {
		// The transition result is not needed, only its error.
		_, deleteErr := tx.DeleteQueuedMessage(input)
		return deleteErr
	})
}
// PromoteQueued promotes the queued message named in opts through the
// chatstate state machine. From running / interrupting states the
// state machine moves the chat to `interrupting` so the worker can
// drain the in-flight generation before promoting; from idle / error /
// requires action states the user message is inserted into history
// synchronously.
//
// It returns ErrChatArchived when the chat is archived and a
// "chat_id is required" error for a zero chat ID. On success a
// status_change watch event is published using the chat row read
// inside the transaction.
func (p *Server) PromoteQueued(
	ctx context.Context,
	opts PromoteQueuedOptions,
) (PromoteQueuedResult, error) {
	if opts.ChatID == uuid.Nil {
		return PromoteQueuedResult{}, xerrors.New("chat_id is required")
	}

	var (
		out PromoteQueuedResult
		// postState is non-nil once the transaction body has
		// captured the post-transition chat row.
		postState *database.Chat
	)

	machine := p.newChatMachine(opts.ChatID)
	err := machine.Update(ctx, func(tx *chatstate.Tx) error {
		store := tx.Store()

		current, loadErr := store.GetChatByID(ctx, opts.ChatID)
		if loadErr != nil {
			return xerrors.Errorf("load chat: %w", loadErr)
		}
		if current.Archived {
			return ErrChatArchived
		}

		promoted, promoteErr := tx.PromoteQueuedMessage(chatstate.PromoteQueuedMessageInput{
			QueuedMessageID: opts.QueuedMessageID,
		})
		if promoteErr != nil {
			return promoteErr
		}
		if inserted := promoted.InsertedMessage; inserted != nil {
			out.PromotedMessage = *inserted
		}

		// Re-read the chat inside the transaction so the watch event
		// published afterwards carries the snapshot bump and status
		// change produced by the transition itself.
		latest, reloadErr := store.GetChatByID(ctx, opts.ChatID)
		if reloadErr != nil {
			return xerrors.Errorf("reload chat after promote: %w", reloadErr)
		}
		postState = &latest
		return nil
	})
	if err != nil {
		return PromoteQueuedResult{}, err
	}

	if postState != nil {
		p.publishChatPubsubEvent(*postState, codersdk.ChatWatchEventKindStatusChange, nil)
	}
	return out, nil
}
// SubmitToolResultsOptions controls tool result submission via
// [Server.SubmitToolResults].
type SubmitToolResultsOptions struct {
// ChatID identifies the chat whose state machine receives the results.
ChatID uuid.UUID
// UserID is recorded as the creator (CreatedBy) of the transition.
UserID uuid.UUID
// ModelConfigID selects the model config for the transition. When it
// is uuid.Nil, SubmitToolResults falls back to the chat's
// LastModelConfigID.
ModelConfigID uuid.UUID
// Results are the client-provided tool results; each is forwarded as a
// chatstate.ToolResultInput (tool call ID, output, error flag).
Results []codersdk.ToolResult
// DynamicTools is not read by SubmitToolResults in this file.
// NOTE(review): presumably consumed by another caller; confirm.
DynamicTools json.RawMessage
}
// ToolResultValidationError reports that submitted tool results failed
// validation (for example missing, duplicate, or unexpected IDs, or
// output that is not valid JSON).
type ToolResultValidationError struct {
	// Message is the short summary of the failure.
	Message string
	// Detail optionally elaborates on Message; it may be empty.
	Detail string
}

// Error returns Message alone when Detail is empty, otherwise
// "Message: Detail".
func (e *ToolResultValidationError) Error() string {
	if e.Detail == "" {
		return e.Message
	}
	return e.Message + ": " + e.Detail
}
// ToolResultStatusConflictError reports that tool results were
// submitted while the chat was not in the requires_action state.
type ToolResultStatusConflictError struct {
	// ActualStatus is the status the chat had when the submission was
	// rejected.
	ActualStatus database.ChatStatus
}

// Error names both the actual status and the expected
// requires_action status.
func (e *ToolResultStatusConflictError) Error() string {
	const format = "chat status is %q, expected %q"
	return fmt.Sprintf(format, e.ActualStatus, database.ChatStatusRequiresAction)
}
// SubmitToolResults validates and persists client-provided tool
// results and returns the chat to running through the chatstate state
// machine. Validation happens in the same transaction as the
// transition so the assistant message and pending tool calls cannot
// drift between reads.
//
// Errors:
//   - ErrChatArchived when the chat is archived.
//   - *ToolResultStatusConflictError when the transition is not
//     allowed and the chat was not in requires_action.
//   - *ToolResultValidationError when chatstate reports a tool-result
//     validation failure.
//   - any other error from the transaction, unchanged.
//
// On success a status_change watch event is published using the chat
// row read inside the transaction.
func (p *Server) SubmitToolResults(
	ctx context.Context,
	opts SubmitToolResultsOptions,
) error {
	var (
		// conflict is remembered outside the closure so it can be
		// returned as-is regardless of how Update reports the error.
		conflict  *ToolResultStatusConflictError
		postState *database.Chat
	)

	machine := p.newChatMachine(opts.ChatID)
	err := machine.Update(ctx, func(tx *chatstate.Tx) error {
		store := tx.Store()

		current, loadErr := store.GetChatByID(ctx, opts.ChatID)
		if loadErr != nil {
			return xerrors.Errorf("load chat: %w", loadErr)
		}
		if current.Archived {
			return ErrChatArchived
		}

		inputs := make([]chatstate.ToolResultInput, len(opts.Results))
		for i := range opts.Results {
			submitted := opts.Results[i]
			inputs[i] = chatstate.ToolResultInput{
				ToolCallID: submitted.ToolCallID,
				Output:     submitted.Output,
				IsError:    submitted.IsError,
			}
		}

		// Prefer the caller's model config; otherwise reuse the one
		// the chat last ran with.
		modelConfigID := current.LastModelConfigID
		if opts.ModelConfigID != uuid.Nil {
			modelConfigID = opts.ModelConfigID
		}

		_, completeErr := tx.CompleteRequiresAction(chatstate.CompleteRequiresActionInput{
			CreatedBy:     opts.UserID,
			ModelConfigID: modelConfigID,
			Results:       inputs,
		})
		if completeErr != nil {
			isStatusConflict := errors.Is(completeErr, chatstate.ErrTransitionNotAllowed) &&
				!errors.Is(completeErr, chatstate.ErrInvalidState) &&
				current.Status != database.ChatStatusRequiresAction
			if !isStatusConflict {
				return completeErr
			}
			conflict = &ToolResultStatusConflictError{
				ActualStatus: current.Status,
			}
			return conflict
		}

		// Re-read the chat inside the transaction so the watch event
		// carries the snapshot bump and status change produced by the
		// transition itself.
		latest, reloadErr := store.GetChatByID(ctx, opts.ChatID)
		if reloadErr != nil {
			return xerrors.Errorf("reload chat after tool results: %w", reloadErr)
		}
		postState = &latest
		return nil
	})
	if err != nil {
		if conflict != nil {
			return conflict
		}
		return translateToolResultValidationError(err)
	}

	if postState != nil {
		p.publishChatPubsubEvent(*postState, codersdk.ChatWatchEventKindStatusChange, nil)
	}
	return nil
}
// translateToolResultValidationError maps a chatstate tool-result
// validation error onto the legacy chatd.ToolResultValidationError
// shape so HTTP handlers keep their existing response detail. An err
// that is not a chatstate tool-result validation error, or that
// matches none of the known validation sentinels, is returned
// unchanged.
func translateToolResultValidationError(err error) error {
	var validationErr *chatstate.ToolResultValidationError
	if !errors.As(err, &validationErr) {
		return err
	}

	var message, detail string
	switch {
	case xerrors.Is(validationErr, chatstate.ErrToolResultDuplicate):
		message = "Duplicate tool_call_id in results."
		detail = fmt.Sprintf("Duplicate tool call ID %q.", validationErr.ToolCallID)
	case xerrors.Is(validationErr, chatstate.ErrToolResultMissing):
		message = "Missing tool result."
		detail = fmt.Sprintf("Missing result for tool call %q.", validationErr.ToolCallID)
	case xerrors.Is(validationErr, chatstate.ErrToolResultUnexpected):
		message = "Unexpected tool result."
		detail = fmt.Sprintf("No pending tool call with ID %q.", validationErr.ToolCallID)
	case xerrors.Is(validationErr, chatstate.ErrToolResultInvalidJSON):
		message = "Tool result output must be valid JSON."
		detail = fmt.Sprintf("Output for tool call %q is not valid JSON.", validationErr.ToolCallID)
	default:
		// Unknown validation kind: pass the original error through.
		return err
	}
	return &ToolResultValidationError{
		Message: message,
		Detail:  detail,
	}
}
// InterruptChat interrupts execution through the chatstate.Interrupt
// transition. Active runs land in `interrupting`; requires-action
// chats synthesize cancellation messages and return to running.
//
// On success it returns the chat row re-read inside the transaction
// and publishes a status_change watch event for it. On failure it
// returns the caller's chat unchanged together with the error so
// callers can map state conflicts deliberately; idle chats yield a
// chatstate.ErrTransitionNotAllowed wrapper.
func (p *Server) InterruptChat(
	ctx context.Context,
	chat database.Chat,
) (database.Chat, error) {
	if chat.ID == uuid.Nil {
		return chat, xerrors.New("chat_id is required")
	}

	var postState database.Chat
	interrupt := func(tx *chatstate.Tx) error {
		_, interruptErr := tx.Interrupt(chatstate.InterruptInput{
			Reason: "Tool execution interrupted by user",
		})
		if interruptErr != nil {
			return interruptErr
		}
		// Re-read inside the transaction so both the returned chat
		// and the watch event reflect the snapshot bump and status
		// change produced by the transition itself.
		reloaded, reloadErr := tx.Store().GetChatByID(ctx, chat.ID)
		if reloadErr != nil {
			return xerrors.Errorf("reload chat after interrupt: %w", reloadErr)
		}
		postState = reloaded
		return nil
	}

	machine := p.newChatMachine(chat.ID)
	if err := machine.Update(ctx, interrupt); err != nil {
		return chat, err
	}
	p.publishChatPubsubEvent(postState, codersdk.ChatWatchEventKindStatusChange, nil)
	return postState, nil
}
// ReconcileInvalidStateChat recovers a chat stuck in an invalid
// execution-state combination by running the
// chatstate.ReconcileInvalidState transition. The chat ends up in an
// error state (E0/E1); queued messages are preserved and pending
// dynamic-tool calls are closed with synthetic cancellations.
//
// On success it returns the chat row re-read inside the transaction
// and publishes a status_change watch event for it. On failure the
// caller's chat is returned unchanged with the error: a chat that is
// not actually in an invalid state yields a wrapped
// chatstate.ErrTransitionNotAllowed, and a missing chat yields
// chatstate.ErrChatNotFound. Callers map these to deliberate HTTP
// responses.
func (p *Server) ReconcileInvalidStateChat(
	ctx context.Context,
	chat database.Chat,
) (database.Chat, error) {
	if chat.ID == uuid.Nil {
		return chat, xerrors.New("chat_id is required")
	}

	var postState database.Chat
	reconcile := func(tx *chatstate.Tx) error {
		_, reconcileErr := tx.ReconcileInvalidState(chatstate.ReconcileInvalidStateInput{})
		if reconcileErr != nil {
			return reconcileErr
		}
		// Re-read inside the transaction so both the returned chat
		// and the watch event reflect the snapshot bump and status
		// change produced by the transition itself.
		reloaded, reloadErr := tx.Store().GetChatByID(ctx, chat.ID)
		if reloadErr != nil {
			return xerrors.Errorf("reload chat after reconcile: %w", reloadErr)
		}
		postState = reloaded
		return nil
	}

	machine := p.newChatMachine(chat.ID)
	if err := machine.Update(ctx, reconcile); err != nil {
		return chat, err
	}
	p.publishChatPubsubEvent(postState, codersdk.ChatWatchEventKindStatusChange, nil)
	return postState, nil
}
// manualTitleMessageWindowLimit caps how many messages are loaded from
// each end of the chat (oldest-first and newest-first) when building
// the input for manual title generation.
const manualTitleMessageWindowLimit = 50
// ErrManualTitleRegenerationInProgress is returned when the manual
// title lock cannot be acquired for a chat.
var ErrManualTitleRegenerationInProgress = xerrors.New(
"manual title regeneration already in progress",
)
// manualTitleCandidateResult is the outcome of one manual title
// generation attempt, including the metadata needed to record usage.
type manualTitleCandidateResult struct {
// title is the generated title candidate.
title string
// modelConfig is the model config resolved for the generation.
modelConfig database.ChatModelConfig
// usage is the usage reported by the generation call.
usage fantasy.Usage
// activeAPIKeyID is the API key ID used when building the model.
activeAPIKeyID string
// hasMessages is false when the chat had no messages, in which case
// no generation was attempted.
hasMessages bool
}
// manualTitleGenerationError wraps a generation failure that still
// reported usage, carrying the metadata needed to record that usage.
type manualTitleGenerationError struct {
// cause is the underlying generation error.
cause error
modelConfig database.ChatModelConfig
usage fantasy.Usage
activeAPIKeyID string
}
// generatedChatTitle holds the title produced by the detached
// automatic title-generation goroutine. maybeGenerateChatTitle writes
// the generated title here so tests can observe it without a database
// read; the title_change pubsub event it publishes remains the source
// of truth for clients.
type generatedChatTitle struct {
	mu    sync.RWMutex
	title string
}

// Store records title under the write lock. It is a no-op on a nil
// receiver or when title is empty, so an existing title is never
// cleared.
func (t *generatedChatTitle) Store(title string) {
	if t != nil && title != "" {
		t.mu.Lock()
		defer t.mu.Unlock()
		t.title = title
	}
}

// Load returns the stored title and true, or "" and false when the
// receiver is nil or no title has been stored yet.
func (t *generatedChatTitle) Load() (string, bool) {
	if t == nil {
		return "", false
	}
	t.mu.RLock()
	current := t.title
	t.mu.RUnlock()
	return current, current != ""
}
// Error returns the message of the underlying cause.
func (e *manualTitleGenerationError) Error() string {
return e.cause.Error()
}
// Unwrap exposes the underlying cause so errors.Is and errors.As can
// inspect it.
func (e *manualTitleGenerationError) Unwrap() error {
return e.cause
}
// manualTitleLockWorkerID is the sentinel worker ID written to a chat
// row to mark it as locked for a manual title operation.
var manualTitleLockWorkerID = uuid.MustParse(
"00000000-0000-0000-0000-000000000001",
)
// manualTitleLockStaleAfter is how long a manual title lock is treated
// as fresh; older locks no longer block new manual title operations.
const manualTitleLockStaleAfter = time.Minute
// isFreshManualTitleLock reports whether chat currently carries the
// manual title lock marker and that lock is still fresh relative to
// now. The lock's lease time is HeartbeatAt when set, otherwise
// StartedAt; a lock with neither timestamp is never fresh. A lease is
// fresh when it is strictly newer than now minus
// manualTitleLockStaleAfter.
func isFreshManualTitleLock(chat database.Chat, now time.Time) bool {
	heldByManualLock := chat.WorkerID.Valid &&
		chat.WorkerID.UUID == manualTitleLockWorkerID
	if !heldByManualLock {
		return false
	}

	lease := chat.HeartbeatAt
	if !lease.Valid {
		lease = chat.StartedAt
	}
	if !lease.Valid {
		return false
	}

	staleBefore := now.Add(-manualTitleLockStaleAfter)
	return lease.Time.After(staleBefore)
}
// updateChatStatusPreserveUpdatedAt writes the given worker ID and
// lease timestamps onto chat while passing the chat's existing status,
// last error, and updated_at back unchanged. It is used for internal
// lock transitions that must not affect chat recency, because chat
// list ordering uses updated_at.
func updateChatStatusPreserveUpdatedAt(
	ctx context.Context,
	store database.Store,
	chat database.Chat,
	workerID uuid.NullUUID,
	startedAt sql.NullTime,
	heartbeatAt sql.NullTime,
) (database.Chat, error) {
	params := database.UpdateChatStatusPreserveUpdatedAtParams{
		ID: chat.ID,
		// Carried over from the current row so only the lock fields
		// change.
		Status:    chat.Status,
		LastError: chat.LastError,
		UpdatedAt: chat.UpdatedAt,
		// Lock fields supplied by the caller.
		WorkerID:    workerID,
		StartedAt:   startedAt,
		HeartbeatAt: heartbeatAt,
	}
	return store.UpdateChatStatusPreserveUpdatedAt(ctx, params)
}
// acquireManualTitleLock takes the manual title lock on the chat
// inside a transaction that row-locks the chat.
//
// It returns ErrManualTitleRegenerationInProgress when the chat is
// pending, is running without a real worker, or already carries a
// fresh manual title lock. When the chat is running under a real
// worker it returns nil without writing a marker. Otherwise it stamps
// the chat with manualTitleLockWorkerID and the current time as
// StartedAt, leaving updated_at untouched.
func (p *Server) acquireManualTitleLock(ctx context.Context, chatID uuid.UUID) error {
	now := time.Now()
	txOpts := database.DefaultTXOptions().WithID("chat_title_regenerate_lock")

	return p.db.InTx(func(tx database.Store) error {
		locked, err := tx.GetChatByIDForUpdate(ctx, chatID)
		if err != nil {
			return xerrors.Errorf("lock chat for manual title regeneration: %w", err)
		}

		// Only a fresh manual lock or a chat without a real worker
		// should block title regeneration. Running chats with a real
		// worker may regenerate their title concurrently, and last
		// write wins.
		isRunning := locked.Status == database.ChatStatusRunning
		hasRealWorker := isRunning &&
			locked.WorkerID.Valid &&
			locked.WorkerID.UUID != manualTitleLockWorkerID

		switch {
		case locked.Status == database.ChatStatusPending,
			isRunning && !hasRealWorker,
			isFreshManualTitleLock(locked, now):
			return ErrManualTitleRegenerationInProgress
		case hasRealWorker:
			// The worker owns the row; proceed without a marker.
			return nil
		}

		marker := uuid.NullUUID{UUID: manualTitleLockWorkerID, Valid: true}
		startedAt := sql.NullTime{Time: now, Valid: true}
		if _, err := updateChatStatusPreserveUpdatedAt(
			ctx, tx, locked, marker, startedAt, sql.NullTime{},
		); err != nil {
			return xerrors.Errorf("mark chat for manual title regeneration: %w", err)
		}
		return nil
	}, txOpts)
}
// releaseManualTitleLock clears the manual title lock marker from the
// chat, if it is still present. It runs on a context detached from
// ctx's cancellation with a 5 second timeout, so it still runs when
// the caller's context has already been canceled. Rows whose worker
// ID is not the manual lock marker are left untouched. Failures are
// logged as warnings rather than returned.
func (p *Server) releaseManualTitleLock(ctx context.Context, chatID uuid.UUID) {
	cleanupCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), 5*time.Second)
	defer cancel()

	clearMarker := func(tx database.Store) error {
		locked, err := tx.GetChatByIDForUpdate(cleanupCtx, chatID)
		if err != nil {
			return xerrors.Errorf("lock chat to release manual title regeneration: %w", err)
		}

		heldByManualLock := locked.WorkerID.Valid &&
			locked.WorkerID.UUID == manualTitleLockWorkerID
		if !heldByManualLock {
			return nil
		}

		if _, err := updateChatStatusPreserveUpdatedAt(
			cleanupCtx, tx, locked, uuid.NullUUID{}, sql.NullTime{}, sql.NullTime{},
		); err != nil {
			return xerrors.Errorf("clear manual title regeneration marker: %w", err)
		}
		return nil
	}

	err := p.db.InTx(clearMarker, database.DefaultTXOptions().WithID("chat_title_regenerate_unlock"))
	if err == nil {
		return
	}
	p.logger.Warn(cleanupCtx, "failed to release manual title regeneration marker",
		slog.F("chat_id", chatID),
		slog.Error(err),
	)
}
// RegenerateChatTitle regenerates the chat's title from its visible
// messages, persists it when it changes, and broadcasts the update.
//
// The manual title lock is held for the duration of the call; if it
// cannot be acquired the lock error is returned. If the owner's
// provider API keys cannot be resolved, generation proceeds with an
// empty key set. Generation failures are passed through
// recordManualTitleGenerationFailure before being returned.
func (p *Server) RegenerateChatTitle(
	ctx context.Context,
	chat database.Chat,
) (database.Chat, error) {
	// Reuse chatd's scoped auth context for deployment-config lookups
	// while keeping chat ownership authorization at the HTTP layer.
	//nolint:gocritic // Non-admin users need chatd-scoped config reads here.
	chatdCtx := dbauthz.AsChatd(ctx)

	keys, keysErr := p.resolveUserProviderAPIKeys(chatdCtx, chat.OwnerID, uuid.Nil)
	if keysErr != nil {
		keys = chatprovider.ProviderAPIKeys{}
	}

	if lockErr := p.acquireManualTitleLock(ctx, chat.ID); lockErr != nil {
		return database.Chat{}, lockErr
	}
	defer p.releaseManualTitleLock(chatdCtx, chat.ID)

	regenerated, genErr := p.regenerateChatTitleWithStore(chatdCtx, p.db, chat, keys)
	if genErr != nil {
		return database.Chat{}, p.recordManualTitleGenerationFailure(ctx, chat, genErr)
	}
	return regenerated, nil
}
// RenameChatTitle persists a user-supplied chat title while holding
// the manual title lock. It returns the resulting chat row and whether
// a write happened: when newTitle equals the chat's current title the
// current row is returned with wrote=false and nothing is written.
func (p *Server) RenameChatTitle(
	ctx context.Context,
	chat database.Chat,
	newTitle string,
) (updated database.Chat, wrote bool, err error) {
	//nolint:gocritic // Lock release needs chatd-scoped writes.
	chatdCtx := dbauthz.AsChatd(ctx)

	if lockErr := p.acquireManualTitleLock(ctx, chat.ID); lockErr != nil {
		return database.Chat{}, false, lockErr
	}
	defer p.releaseManualTitleLock(chatdCtx, chat.ID)

	// Compare against the stored title, not the caller's copy of the
	// chat.
	current, getErr := p.db.GetChatByID(ctx, chat.ID)
	if getErr != nil {
		return database.Chat{}, false, xerrors.Errorf("get chat for rename: %w", getErr)
	}
	if current.Title == newTitle {
		return current, false, nil
	}

	renamed, updateErr := p.db.UpdateChatTitleByID(ctx, database.UpdateChatTitleByIDParams{
		ID:    chat.ID,
		Title: newTitle,
	})
	if updateErr != nil {
		return database.Chat{}, false, xerrors.Errorf("update chat title: %w", updateErr)
	}
	return renamed, true, nil
}
// PublishTitleChange broadcasts a title_change event for the given chat.
// It only publishes; it does not persist anything.
func (p *Server) PublishTitleChange(chat database.Chat) {
p.publishChatPubsubEvent(chat, codersdk.ChatWatchEventKindTitleChange, nil)
}
// ProposeChatTitle generates a title suggestion from the chat's
// visible messages without persisting it.
//
// The manual title lock is held for the duration of the call; if it
// cannot be acquired the lock error is returned. If the owner's
// provider API keys cannot be resolved, generation proceeds with an
// empty key set. Generation failures are passed through
// recordManualTitleGenerationFailure before being returned.
func (p *Server) ProposeChatTitle(
	ctx context.Context,
	chat database.Chat,
) (string, error) {
	//nolint:gocritic // Non-admin users need chatd-scoped config reads here.
	chatdCtx := dbauthz.AsChatd(ctx)

	keys, keysErr := p.resolveUserProviderAPIKeys(chatdCtx, chat.OwnerID, uuid.Nil)
	if keysErr != nil {
		keys = chatprovider.ProviderAPIKeys{}
	}

	if lockErr := p.acquireManualTitleLock(ctx, chat.ID); lockErr != nil {
		return "", lockErr
	}
	defer p.releaseManualTitleLock(chatdCtx, chat.ID)

	proposed, genErr := p.proposeChatTitleWithStore(chatdCtx, p.db, chat, keys)
	if genErr != nil {
		return "", p.recordManualTitleGenerationFailure(ctx, chat, genErr)
	}
	return proposed, nil
}
// recordManualTitleGenerationFailure records usage for a failed manual
// title generation and returns the error to surface to the caller.
//
// If err does not wrap a *manualTitleGenerationError it is returned
// unchanged and nothing is recorded. Otherwise the usage carried by
// that error is recorded (with an empty title) on a context detached
// from ctx's cancellation with a 5 second timeout. The extracted
// generation error is returned, joined with the recording error when
// recording also fails.
func (p *Server) recordManualTitleGenerationFailure(
	ctx context.Context,
	chat database.Chat,
	err error,
) error {
	var genErr *manualTitleGenerationError
	if !errors.As(err, &genErr) {
		return err
	}

	//nolint:gocritic // Failure accounting still needs chatd-scoped config reads.
	detachedCtx := dbauthz.AsChatd(context.WithoutCancel(ctx))
	recordCtx, cancel := context.WithTimeout(detachedCtx, 5*time.Second)
	defer cancel()

	_, recordErr := recordManualTitleUsage(
		recordCtx,
		p.db,
		chat,
		genErr.modelConfig,
		genErr.usage,
		genErr.activeAPIKeyID,
		"",
	)
	if recordErr == nil {
		return genErr
	}
	return errors.Join(
		genErr,
		xerrors.Errorf("record manual title usage: %w", recordErr),
	)
}
// generateManualTitleCandidate performs only model generation and returns the
// candidate plus accounting metadata. Endpoint-specific commit paths are
// responsible for recording usage and deciding whether to persist the title.
// The context may carry the caller's delegated API key for manual title routes.
//
// A chat with no messages yields a zero result (hasMessages false) and a
// nil error. When generation fails after reporting usage, the error is a
// *manualTitleGenerationError carrying that usage.
func (p *Server) generateManualTitleCandidate(
ctx context.Context,
store database.Store,
chat database.Chat,
keys chatprovider.ProviderAPIKeys,
) (manualTitleCandidateResult, error) {
// Refuse to generate when the usage-limit check fails.
if limitErr := p.checkUsageLimit(ctx, store, chat.OwnerID, uuid.NullUUID{UUID: chat.OrganizationID, Valid: true}); limitErr != nil {
return manualTitleCandidateResult{}, limitErr
}
// Load up to manualTitleMessageWindowLimit messages from the start of
// the chat...
headMessages, err := store.GetChatMessagesByChatIDAscPaginated(
ctx,
database.GetChatMessagesByChatIDAscPaginatedParams{
ChatID: chat.ID,
AfterID: 0,
LimitVal: manualTitleMessageWindowLimit,
},
)
if err != nil {
return manualTitleCandidateResult{}, xerrors.Errorf("get head chat messages: %w", err)
}
// ...and the same number from the end, then merge the two windows
// without duplicates.
tailMessages, err := store.GetChatMessagesByChatIDDescPaginated(
ctx,
database.GetChatMessagesByChatIDDescPaginatedParams{
ChatID: chat.ID,
BeforeID: 0,
LimitVal: manualTitleMessageWindowLimit,
},
)
if err != nil {
return manualTitleCandidateResult{}, xerrors.Errorf("get tail chat messages: %w", err)
}
messages := mergeManualTitleMessages(headMessages, tailMessages)
// Nothing to title: return the zero result so callers can skip
// recording and persistence.
if len(messages) == 0 {
return manualTitleCandidateResult{}, nil
}
modelOpts := modelBuildOptionsFromMessages(messages)
// Manual title routes can run over messages that lack API key attribution.
// Fall back to the authenticated caller's delegated key for AI Gateway routing.
if modelOpts.ActiveAPIKeyID == "" {
if apiKeyID, ok := aibridge.DelegatedAPIKeyIDFromContext(ctx); ok {
modelOpts.ActiveAPIKeyID = apiKeyID
}
}
model, modelConfig, modelKeys, err := p.resolveManualTitleModel(ctx, store, chat, keys, modelOpts)
// Build the result before checking err so the model config and API key
// ID are returned alongside a model-resolution error.
result := manualTitleCandidateResult{
modelConfig: modelConfig,
activeAPIKeyID: modelOpts.ActiveAPIKeyID,
hasMessages: true,
}
if err != nil {
return result, err
}
// Default to the plain context and model with a no-op finish callback;
// these are replaced only when debug recording is enabled for the chat.
titleCtx := ctx
titleModel := model
finishDebugRun := func(error) {}
if debugSvc := p.debugService(); debugSvc != nil && debugSvc.IsEnabled(ctx, chat.ID, chat.OwnerID) {
titleCtx, titleModel, finishDebugRun = p.prepareManualTitleDebugRun(
ctx,
debugSvc,
chat,
modelConfig,
modelKeys,
modelOpts,
messages,
model,
)
}
title, usage, err := generateManualTitle(titleCtx, messages, titleModel)
finishDebugRun(err)
result.title = title
result.usage = usage
if err != nil {
wrappedErr := xerrors.Errorf("generate manual title: %w", err)
// With no usage reported there is nothing to account for, so return
// the plain wrapped error.
if usage == (fantasy.Usage{}) {
return result, wrappedErr
}
// Usage was reported despite the failure: carry it in a typed error
// so the caller can still record it.
return result, &manualTitleGenerationError{
cause: wrappedErr,
modelConfig: modelConfig,
usage: usage,
activeAPIKeyID: modelOpts.ActiveAPIKeyID,
}
}
return result, nil
}
// proposeChatTitleWithStore generates a title candidate for chat and
// records the usage it consumed, without persisting the title (usage
// is recorded with an empty title). It returns "" with a nil error
// when the chat has no messages. Usage is recorded on a context
// detached from ctx's cancellation with a 5 second timeout; a
// recording failure is returned as an error instead of the title.
func (p *Server) proposeChatTitleWithStore(
	ctx context.Context,
	store database.Store,
	chat database.Chat,
	keys chatprovider.ProviderAPIKeys,
) (string, error) {
	candidate, err := p.generateManualTitleCandidate(ctx, store, chat, keys)
	switch {
	case err != nil:
		return "", err
	case !candidate.hasMessages:
		return "", nil
	}

	recordCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), 5*time.Second)
	defer cancel()

	_, recordErr := recordManualTitleUsage(
		recordCtx,
		store,
		chat,
		candidate.modelConfig,
		candidate.usage,
		candidate.activeAPIKeyID,
		"",
	)
	if recordErr != nil {
		return "", xerrors.Errorf("record manual title usage: %w", recordErr)
	}
	return candidate.title, nil
}
// regenerateChatTitleWithStore generates a title candidate for chat,
// records the usage together with the candidate title via
// recordManualTitleUsage, and publishes a title_change event when the
// resulting chat's title differs from the one passed in.
//
// It returns chat unchanged with a nil error when the chat has no
// messages. Recording runs on a context detached from ctx's
// cancellation with a 5 second timeout.
func (p *Server) regenerateChatTitleWithStore(
	ctx context.Context,
	store database.Store,
	chat database.Chat,
	keys chatprovider.ProviderAPIKeys,
) (database.Chat, error) {
	candidate, err := p.generateManualTitleCandidate(ctx, store, chat, keys)
	switch {
	case err != nil:
		return database.Chat{}, err
	case !candidate.hasMessages:
		return chat, nil
	}

	recordCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), 5*time.Second)
	defer cancel()

	updatedChat, recordErr := recordManualTitleUsage(
		recordCtx,
		store,
		chat,
		candidate.modelConfig,
		candidate.usage,
		candidate.activeAPIKeyID,
		candidate.title,
	)
	switch {
	case recordErr == nil:
		// Recorded successfully; continue below.
	case candidate.title != "":
		return database.Chat{}, xerrors.Errorf("record manual title usage and update chat title: %w", recordErr)
	default:
		return database.Chat{}, xerrors.Errorf("record manual title usage: %w", recordErr)
	}

	// Only broadcast when the title actually changed.
	if updatedChat.Title != chat.Title {
		p.publishChatPubsubEvent(updatedChat, codersdk.ChatWatchEventKindTitleChange, nil)
	}
	return updatedChat, nil
}
// prepareManualTitleDebugRun sets up debug recording for one manual
// title generation. It returns the context and model the generation
// should use plus a callback to invoke with the generation error once
// it finishes.
//
// Setup is best-effort: if a debug-aware model cannot be built, a
// warning is logged and fallbackModel is used; if the debug run cannot
// be created, a warning is logged and the original ctx is returned with
// a no-op callback.
func (p *Server) prepareManualTitleDebugRun(
ctx context.Context,
debugSvc *chatdebug.Service,
chat database.Chat,
modelConfig database.ChatModelConfig,
keys chatprovider.ProviderAPIKeys,
modelOpts modelBuildOptions,
messages []database.ChatMessage,
fallbackModel fantasy.LanguageModel,
) (context.Context, fantasy.LanguageModel, func(error)) {
// Defaults used whenever a setup step fails.
titleCtx := ctx
titleModel := fallbackModel
finishDebugRun := func(error) {}
route, routeErr := p.resolveModelRouteForConfig(ctx, chat.OwnerID, modelConfig, keys)
// Build a second model with the same options plus HTTP recording
// enabled.
debugOpts := modelOpts
debugOpts.RecordHTTP = true
var debugModelErr error
var debugModel fantasy.LanguageModel
if routeErr != nil {
debugModelErr = routeErr
} else {
debugModel, debugModelErr = p.newModel(ctx, modelClientRequest{
Chat: chat,
ModelName: modelConfig.Model,
UserAgent: chatprovider.UserAgent(),
ExtraHeaders: chatprovider.CoderHeaders(chat),
}, route, debugOpts)
}
switch {
case debugModelErr != nil:
p.logger.Warn(ctx, "failed to create debug-aware manual title model",
slog.F("chat_id", chat.ID),
slog.F("provider", modelConfig.Provider),
slog.F("model", modelConfig.Model),
slog.Error(debugModelErr),
)
case debugModel == nil:
p.logger.Warn(ctx, "manual title debug model creation returned nil",
slog.F("chat_id", chat.ID),
slog.F("provider", modelConfig.Provider),
slog.F("model", modelConfig.Model),
)
default:
// Only a successfully built debug model replaces the fallback.
titleModel = chatdebug.WrapModel(debugModel, debugSvc, chatdebug.RecorderOptions{
ChatID: chat.ID,
OwnerID: chat.OwnerID,
Provider: modelConfig.Provider,
Model: modelConfig.Model,
})
}
// The history tip is the last message in the merged window, or 0 when
// there are no messages.
var historyTipMessageID int64
if len(messages) > 0 {
historyTipMessageID = messages[len(messages)-1].ID
}
// Derive a first_message label from the first user message.
var firstUserLabel string
for _, msg := range messages {
if msg.Role == database.ChatMessageRoleUser {
if parts, parseErr := chatprompt.ParseContent(msg); parseErr == nil {
firstUserLabel = contentBlocksToText(parts)
}
break
}
}
if firstUserLabel == "" {
firstUserLabel = "Title generation"
}
seedSummary := chatdebug.SeedSummary(
chatdebug.TruncateLabel(firstUserLabel, chatdebug.MaxLabelLength),
)
// Create the run on a context detached from ctx's cancellation and
// bounded to 5 seconds.
createRunCtx, createRunCancel := context.WithTimeout(context.WithoutCancel(ctx), 5*time.Second)
debugRun, createRunErr := debugSvc.CreateRun(createRunCtx, chatdebug.CreateRunParams{
ChatID: chat.ID,
ModelConfigID: modelConfig.ID,
Provider: modelConfig.Provider,
Model: modelConfig.Model,
Kind: chatdebug.KindTitleGeneration,
Status: chatdebug.StatusInProgress,
HistoryTipMessageID: historyTipMessageID,
TriggerMessageID: 0,
Summary: seedSummary,
})
createRunCancel()
if createRunErr != nil {
p.logger.Warn(ctx, "failed to create manual title debug run",
slog.F("chat_id", chat.ID),
slog.F("provider", modelConfig.Provider),
slog.F("model", modelConfig.Model),
slog.Error(createRunErr),
)
// No run exists, so return the no-op finish callback. titleModel may
// still be the wrapped debug model at this point.
return titleCtx, titleModel, finishDebugRun
}
runContext := chatdebugRunContext(debugRun)
titleCtx = chatdebug.ContextWithRun(titleCtx, &runContext)
// Finalize the run with a status classified from the generation error.
// NOTE(review): this uses the caller's ctx, not a detached one; confirm
// FinalizeRun tolerates an already-canceled context.
finishDebugRun = func(generateErr error) {
if finalizeErr := debugSvc.FinalizeRun(ctx, chatdebug.FinalizeRunParams{
RunID: debugRun.ID,
ChatID: debugRun.ChatID,
Status: chatdebug.ClassifyError(generateErr),
SeedSummary: seedSummary,
}); finalizeErr != nil {
p.logger.Warn(ctx, "failed to finalize manual title debug run",
slog.F("chat_id", chat.ID),
slog.F("run_id", debugRun.ID),
slog.Error(finalizeErr),
)
}
}
return titleCtx, titleModel, finishDebugRun
}
// chatdebugRunContext converts a stored debug run row into a
// chatdebug.RunContext. The run ID, chat ID, and kind are always
// copied; each nullable column is copied only when it is valid, so
// unset columns leave the corresponding field at its zero value.
func chatdebugRunContext(run database.ChatDebugRun) chatdebug.RunContext {
	rc := chatdebug.RunContext{
		RunID:  run.ID,
		ChatID: run.ChatID,
		Kind:   chatdebug.RunKind(run.Kind),
	}

	// Chat lineage.
	if run.RootChatID.Valid {
		rc.RootChatID = run.RootChatID.UUID
	}
	if run.ParentChatID.Valid {
		rc.ParentChatID = run.ParentChatID.UUID
	}

	// Model identity.
	if run.ModelConfigID.Valid {
		rc.ModelConfigID = run.ModelConfigID.UUID
	}
	if run.Provider.Valid {
		rc.Provider = run.Provider.String
	}
	if run.Model.Valid {
		rc.Model = run.Model.String
	}

	// Message anchors.
	if run.TriggerMessageID.Valid {
		rc.TriggerMessageID = run.TriggerMessageID.Int64
	}
	if run.HistoryTipMessageID.Valid {
		rc.HistoryTipMessageID = run.HistoryTipMessageID.Int64
	}
	return rc
}
// deriveChatDebugSeed extracts the values used to seed a debug run
// from a message history:
//   - triggerMessageID: the ID of the last user message, or 0 if none.
//   - historyTipMessageID: the ID of the last message, or 0 if empty.
//   - triggerLabel: the text of the last user message, or "" when
//     there is none or its content cannot be parsed.
func deriveChatDebugSeed(messages []database.ChatMessage) (
	triggerMessageID int64,
	historyTipMessageID int64,
	triggerLabel string,
) {
	if count := len(messages); count > 0 {
		historyTipMessageID = messages[count-1].ID
	}

	// Walk backwards to find the most recent user message.
	for idx := len(messages) - 1; idx >= 0; idx-- {
		msg := messages[idx]
		if msg.Role == database.ChatMessageRoleUser {
			triggerMessageID = msg.ID
			parts, parseErr := chatprompt.ParseContent(msg)
			if parseErr == nil {
				triggerLabel = contentBlocksToText(parts)
			}
			break
		}
	}
	return triggerMessageID, historyTipMessageID, triggerLabel
}
// prepareChatTurnDebugRun creates a debug run for one chat turn and
// returns a context carrying that run plus a callback that finalizes
// it. The callback takes the loop error and any recovered panic value.
//
// When debugSvc is nil, or the run cannot be created, the original ctx
// and a no-op callback are returned; creation failures are logged as
// warnings.
func prepareChatTurnDebugRun(
ctx context.Context,
logger slog.Logger,
chat database.Chat,
modelConfig database.ChatModelConfig,
debugSvc *chatdebug.Service,
debugProvider string,
debugModel string,
triggerMessageID int64,
historyTipMessageID int64,
triggerLabel string,
) (context.Context, func(error, any)) {
finishDebugRun := func(error, any) {}
if debugSvc == nil {
return ctx, finishDebugRun
}
seedSummary := chatdebug.SeedSummary(
chatdebug.TruncateLabel(triggerLabel, chatdebug.MaxLabelLength),
)
// Unset root/parent IDs are passed as uuid.Nil.
rootChatID := uuid.Nil
if chat.RootChatID.Valid {
rootChatID = chat.RootChatID.UUID
}
parentChatID := uuid.Nil
if chat.ParentChatID.Valid {
parentChatID = chat.ParentChatID.UUID
}
// Debug instrumentation must never block the user turn. Detach
// from the chat-processing context and bound the insert so a slow
// or locked DB makes debug logging degrade silently rather than
// stalling chat processing. Matches the pattern used by
// prepareManualTitleDebugRun.
createRunCtx, createRunCancel := context.WithTimeout(
context.WithoutCancel(ctx), debugCreateRunTimeout,
)
run, createRunErr := debugSvc.CreateRun(createRunCtx, chatdebug.CreateRunParams{
ChatID: chat.ID,
RootChatID: rootChatID,
ParentChatID: parentChatID,
ModelConfigID: modelConfig.ID,
TriggerMessageID: triggerMessageID,
HistoryTipMessageID: historyTipMessageID,
Kind: chatdebug.KindChatTurn,
Status: chatdebug.StatusInProgress,
Provider: debugProvider,
Model: debugModel,
Summary: seedSummary,
})
createRunCancel()
if createRunErr != nil {
logger.Warn(ctx, "failed to create chat debug run",
slog.F("chat_id", chat.ID),
slog.Error(createRunErr),
)
return ctx, finishDebugRun
}
// Attach the run to the context returned to the caller.
runCtx := chatdebug.ContextWithRun(ctx, &chatdebug.RunContext{
RunID: run.ID,
ChatID: chat.ID,
RootChatID: rootChatID,
ParentChatID: parentChatID,
ModelConfigID: modelConfig.ID,
TriggerMessageID: triggerMessageID,
HistoryTipMessageID: historyTipMessageID,
Kind: chatdebug.KindChatTurn,
Provider: debugProvider,
Model: debugModel,
})
finishDebugRun = func(loopErr error, panicValue any) {
// Start from the generic classification, then override for the
// cases below. A panic takes precedence over any loop error.
status := chatdebug.ClassifyError(loopErr)
switch {
case panicValue != nil:
status = chatdebug.StatusError
case errors.Is(loopErr, chatloop.ErrInterrupted):
status = chatdebug.StatusInterrupted
case errors.Is(loopErr, chatloop.ErrDynamicToolCall):
// Dynamic tool calls are a successful pause; the run completed
// its model round-trip.
status = chatdebug.StatusCompleted
}
if finalizeErr := debugSvc.FinalizeRun(runCtx, chatdebug.FinalizeRunParams{
RunID: run.ID,
ChatID: chat.ID,
Status: status,
SeedSummary: seedSummary,
}); finalizeErr != nil {
logger.Warn(ctx, "failed to finalize chat debug run",
slog.F("chat_id", chat.ID),
slog.F("run_id", run.ID),
slog.Error(finalizeErr),
)
}
}
return runCtx, finishDebugRun
}
// resolveManualTitleModel picks the model used for manual title
// generation and returns it with its config and provider keys. It
// tries, in order:
//  1. the title-generation model override, when one is set;
//  2. the preferred short-text model among the enabled model configs;
//  3. the fallback from resolveFallbackManualTitleModel.
//
// A failure to resolve an override that is set is returned as an
// error. Every other failure in steps 1 and 2 is logged at debug level
// and resolution moves on to the fallback.
func (p *Server) resolveManualTitleModel(
ctx context.Context,
store database.Store,
chat database.Chat,
keys chatprovider.ProviderAPIKeys,
modelOpts modelBuildOptions,
) (fantasy.LanguageModel, database.ChatModelConfig, chatprovider.ProviderAPIKeys, error) {
overrideConfig, overrideModel, overrideKeys, _, overrideSet, overrideErr := p.resolveTitleGenerationModelOverride(
ctx,
chat,
keys,
modelOpts,
)
if overrideErr != nil {
// An override that is set but unusable is a hard failure; there is no
// silent fallback to another model.
if overrideSet {
return nil, database.ChatModelConfig{}, chatprovider.ProviderAPIKeys{}, xerrors.Errorf(
"resolve manual title generation model override: %w",
overrideErr,
)
}
p.logger.Debug(ctx, "failed to resolve title generation model override for manual title",
slog.F("chat_id", chat.ID),
slog.Error(overrideErr),
)
} else if overrideSet {
return overrideModel, overrideConfig, overrideKeys, nil
}
// No override in effect: look for a preferred configured model.
configs, err := store.GetEnabledChatModelConfigs(ctx)
if err != nil {
p.logger.Debug(ctx, "failed to list manual title model configs",
slog.F("chat_id", chat.ID),
slog.Error(err),
)
return p.resolveFallbackManualTitleModel(ctx, chat, keys, modelOpts)
}
config, ok := selectPreferredConfiguredShortTextModelConfig(configs)
if !ok {
return p.resolveFallbackManualTitleModel(ctx, chat, keys, modelOpts)
}
route, err := p.resolveModelRouteForConfig(ctx, chat.OwnerID, config, keys)
if err != nil {
p.logger.Debug(ctx, "manual title preferred model unavailable",
slog.F("chat_id", chat.ID),
slog.F("provider", config.Provider),
slog.F("model", config.Model),
slog.Error(err),
)
return p.resolveFallbackManualTitleModel(ctx, chat, keys, modelOpts)
}
model, err := p.newModel(ctx, modelClientRequest{
Chat: chat,
ModelName: config.Model,
UserAgent: chatprovider.UserAgent(),
ExtraHeaders: chatprovider.CoderHeaders(chat),
}, route, modelOpts)
if err != nil {
p.logger.Debug(ctx, "manual title preferred model unavailable",
slog.F("chat_id", chat.ID),
slog.F("provider", config.Provider),
slog.F("model", config.Model),
slog.Error(err),
)
return p.resolveFallbackManualTitleModel(ctx, chat, keys, modelOpts)
}
return model, config, route.directProviderKeys(), nil
}
// resolveFallbackManualTitleModel builds the manual-title language model from
// the config that resolveModelConfig returns for the chat. On success it
// returns the model, the config it was built from, and the route's direct
// provider keys. Config and model-construction failures are wrapped with
// context; a route resolution failure is returned unchanged.
func (p *Server) resolveFallbackManualTitleModel(
	ctx context.Context,
	chat database.Chat,
	keys chatprovider.ProviderAPIKeys,
	modelOpts modelBuildOptions,
) (fantasy.LanguageModel, database.ChatModelConfig, chatprovider.ProviderAPIKeys, error) {
	// fail pairs the zero-valued results with the given error.
	fail := func(cause error) (fantasy.LanguageModel, database.ChatModelConfig, chatprovider.ProviderAPIKeys, error) {
		return nil, database.ChatModelConfig{}, chatprovider.ProviderAPIKeys{}, cause
	}

	fallbackConfig, configErr := p.resolveModelConfig(ctx, chat)
	if configErr != nil {
		return fail(xerrors.Errorf(
			"resolve fallback manual title model config: %w",
			configErr,
		))
	}

	fallbackRoute, routeErr := p.resolveModelRouteForConfig(ctx, chat.OwnerID, fallbackConfig, keys)
	if routeErr != nil {
		return fail(routeErr)
	}

	request := modelClientRequest{
		Chat:         chat,
		ModelName:    fallbackConfig.Model,
		UserAgent:    chatprovider.UserAgent(),
		ExtraHeaders: chatprovider.CoderHeaders(chat),
	}
	fallbackModel, modelErr := p.newModel(ctx, request, fallbackRoute, modelOpts)
	if modelErr != nil {
		return fail(xerrors.Errorf(
			"create fallback manual title model: %w",
			modelErr,
		))
	}

	return fallbackModel, fallbackConfig, fallbackRoute.directProviderKeys(), nil
}
// mergeManualTitleMessages combines head and tail messages into one slice
// without repeating a message ID. headMessages keep their given order;
// tailMessagesDesc is walked from its last element to its first, so the tail
// is appended in the reverse of the order it was passed in. When an ID occurs
// more than once only its first occurrence is kept.
func mergeManualTitleMessages(
	headMessages []database.ChatMessage,
	tailMessagesDesc []database.ChatMessage,
) []database.ChatMessage {
	total := len(headMessages) + len(tailMessagesDesc)
	merged := make([]database.ChatMessage, 0, total)
	seen := make(map[int64]struct{}, total)

	// firstSighting reports whether id has not been recorded yet, and
	// records it when that is the case.
	firstSighting := func(id int64) bool {
		if _, dup := seen[id]; dup {
			return false
		}
		seen[id] = struct{}{}
		return true
	}

	for i := range headMessages {
		if firstSighting(headMessages[i].ID) {
			merged = append(merged, headMessages[i])
		}
	}
	last := len(tailMessagesDesc) - 1
	for offset := 0; offset <= last; offset++ {
		candidate := tailMessagesDesc[last-offset]
		if firstSighting(candidate.ID) {
			merged = append(merged, candidate)
		}
	}
	return merged
}
// fantasyUsageToChatMessageUsage converts provider usage counters into the
// SDK usage shape. A counter that is zero is left as a nil pointer; only
// non-zero counters are populated.
func fantasyUsageToChatMessageUsage(usage fantasy.Usage) codersdk.ChatMessageUsage {
	// optional maps a zero count to nil and any other count to a pointer.
	optional := func(tokens int64) *int64 {
		if tokens == 0 {
			return nil
		}
		return ptr.Ref(tokens)
	}

	var converted codersdk.ChatMessageUsage
	converted.InputTokens = optional(usage.InputTokens)
	converted.OutputTokens = optional(usage.OutputTokens)
	converted.ReasoningTokens = optional(usage.ReasoningTokens)
	converted.CacheCreationTokens = optional(usage.CacheCreationTokens)
	converted.CacheReadTokens = optional(usage.CacheReadTokens)
	return converted
}
// recordManualTitleUsage persists the token usage of a manual title
// generation call and, when appropriate, applies the new title.
//
// When usage is the zero value and newTitle is empty there is nothing to
// record and the input chat is returned untouched. Otherwise the work runs in
// a single transaction that locks the chat row:
//   - if usage is non-zero, an assistant message with empty content carrying
//     the usage and cost is inserted and immediately soft-deleted, and the
//     chat's last model config ID is written back when it differs from
//     modelConfig.ID;
//   - the title is updated only when newTitle is non-empty, the locked row
//     still has the title the caller passed in, and newTitle differs from it.
//
// It returns the locked chat row, or the row returned by the title update
// when one was performed. On error the returned chat is the zero value.
func recordManualTitleUsage(
ctx context.Context,
store database.Store,
chat database.Chat,
modelConfig database.ChatModelConfig,
usage fantasy.Usage,
activeAPIKeyID string,
newTitle string,
) (database.Chat, error) {
hasUsage := usage != (fantasy.Usage{})
if !hasUsage && newTitle == "" {
return chat, nil
}
// Cost is only computed when there is usage to price; the pricing comes
// from the Cost section of the model config's JSON options.
var totalCostMicros *int64
if hasUsage {
callConfig := codersdk.ChatModelCallConfig{}
if len(modelConfig.Options) > 0 {
if err := json.Unmarshal(modelConfig.Options, &callConfig); err != nil {
return database.Chat{}, xerrors.Errorf("parse model call config: %w", err)
}
}
totalCostMicros = chatcost.CalculateTotalCostMicros(
fantasyUsageToChatMessageUsage(usage),
callConfig.Cost,
)
}
// Use a valid empty JSON array for the content column.
// MarshalParts returns a null NullRawMessage for empty
// slices, which becomes an empty string that PostgreSQL
// rejects as invalid JSON.
content := "[]"
updatedChat := chat
err := store.InTx(func(tx database.Store) error {
// Lock the chat row so the title comparison and the model config
// restore below operate on a consistent view.
lockedChat, err := tx.GetChatByIDForUpdate(ctx, chat.ID)
if err != nil {
return xerrors.Errorf("lock chat for manual title usage: %w", err)
}
updatedChat = lockedChat
if hasUsage {
// Record the usage as a single assistant message with empty content.
messages, err := tx.InsertChatMessages(ctx, database.InsertChatMessagesParams{
ChatID: chat.ID,
CreatedBy: []uuid.UUID{chat.OwnerID},
APIKeyID: []string{activeAPIKeyID},
ModelConfigID: []uuid.UUID{modelConfig.ID},
Role: []database.ChatMessageRole{database.ChatMessageRoleAssistant},
Content: []string{content},
ContentVersion: []int16{chatprompt.CurrentContentVersion},
Visibility: []database.ChatMessageVisibility{database.ChatMessageVisibilityModel},
InputTokens: []int64{usage.InputTokens},
OutputTokens: []int64{usage.OutputTokens},
TotalTokens: []int64{usage.TotalTokens},
ReasoningTokens: []int64{usage.ReasoningTokens},
CacheCreationTokens: []int64{usage.CacheCreationTokens},
CacheReadTokens: []int64{usage.CacheReadTokens},
ContextLimit: []int64{modelConfig.ContextLimit},
Compressed: []bool{false},
TotalCostMicros: []int64{ptr.NilToDefault(totalCostMicros, 0)},
RuntimeMs: []int64{0},
ProviderResponseID: []string{""},
})
if err != nil {
return xerrors.Errorf("insert manual title usage message: %w", err)
}
if len(messages) != 1 {
return xerrors.Errorf("expected 1 manual title usage message, got %d", len(messages))
}
// The row exists only to carry usage, so it is soft-deleted right away.
if err := tx.SoftDeleteChatMessageByID(ctx, messages[0].ID); err != nil {
return xerrors.Errorf("soft delete manual title usage message: %w", err)
}
// Write back the model config ID captured under the lock.
// NOTE(review): presumably InsertChatMessages updates the chat's last
// model config as a side effect, which this undoes; confirm against
// the query.
if lockedChat.LastModelConfigID != modelConfig.ID {
if _, err := tx.UpdateChatLastModelConfigByID(ctx, database.UpdateChatLastModelConfigByIDParams{
ID: chat.ID,
LastModelConfigID: lockedChat.LastModelConfigID,
}); err != nil {
return xerrors.Errorf("restore chat model config after manual title usage: %w", err)
}
}
}
// Skip the title write if the title changed since the caller read the
// chat, or if the new title is empty or identical to the current one.
if newTitle != "" && lockedChat.Title == chat.Title && newTitle != lockedChat.Title {
updatedChat, err = tx.UpdateChatByID(ctx, database.UpdateChatByIDParams{
ID: chat.ID,
Title: newTitle,
})
if err != nil {
return xerrors.Errorf("update chat title: %w", err)
}
}
return nil
}, nil)
if err != nil {
return database.Chat{}, err
}
return updatedChat, nil
}
// chatMessage holds the per-message values that appendMessageFields copies
// into the parallel slices of database.InsertChatMessagesParams. It is built
// with newChatMessage and refined through the with* methods, each of which
// returns a modified copy. The API key ID is not part of this type; it is
// carried by userChatMessage.
type chatMessage struct {
role database.ChatMessageRole
content pqtype.NullRawMessage
visibility database.ChatMessageVisibility
modelConfigID uuid.UUID
createdBy uuid.UUID
contentVersion int16
compressed bool
// Token usage counters, set together by withUsage.
inputTokens int64
outputTokens int64
totalTokens int64
reasoningTokens int64
cacheCreationTokens int64
cacheReadTokens int64
contextLimit int64
totalCostMicros int64
runtimeMs int64
providerResponseID string
}
// userChatMessage is a chatMessage paired with the API key ID that
// appendUserChatMessage writes into the insert params alongside it.
// newUserChatMessage always builds it with the user role.
type userChatMessage struct {
chatMessage
apiKeyID string
}
// withCreatedBy returns a copy of the user message whose embedded chat
// message has createdBy set to id.
func (m userChatMessage) withCreatedBy(id uuid.UUID) userChatMessage {
	m.chatMessage.createdBy = id
	return m
}
// withCompressed returns a copy of the user message whose embedded chat
// message is flagged as compressed.
func (m userChatMessage) withCompressed() userChatMessage {
	m.chatMessage.compressed = true
	return m
}
// newChatMessage returns a chatMessage populated with the given role,
// content, visibility, model config ID, and content version. All remaining
// fields keep their zero values until set through the with* methods.
func newChatMessage(
	role database.ChatMessageRole,
	content pqtype.NullRawMessage,
	visibility database.ChatMessageVisibility,
	modelConfigID uuid.UUID,
	contentVersion int16,
) chatMessage {
	var msg chatMessage
	msg.role = role
	msg.content = content
	msg.visibility = visibility
	msg.modelConfigID = modelConfigID
	msg.contentVersion = contentVersion
	return msg
}
// newUserChatMessage returns a user-role message built through
// newChatMessage and tagged with the given API key ID.
func newUserChatMessage(
	apiKeyID string,
	content pqtype.NullRawMessage,
	visibility database.ChatMessageVisibility,
	modelConfigID uuid.UUID,
	contentVersion int16,
) userChatMessage {
	base := newChatMessage(
		database.ChatMessageRoleUser,
		content,
		visibility,
		modelConfigID,
		contentVersion,
	)
	return userChatMessage{chatMessage: base, apiKeyID: apiKeyID}
}
// withCreatedBy returns a copy of the message with createdBy set to id.
func (m chatMessage) withCreatedBy(id uuid.UUID) chatMessage {
	updated := m
	updated.createdBy = id
	return updated
}
// withCompressed returns a copy of the message flagged as compressed.
func (m chatMessage) withCompressed() chatMessage {
	updated := m
	updated.compressed = true
	return updated
}
// withUsage returns a copy of the message with all six token counters
// replaced by the given values.
func (m chatMessage) withUsage(
	inputTokens, outputTokens, totalTokens, reasoningTokens,
	cacheCreationTokens, cacheReadTokens int64,
) chatMessage {
	updated := m
	updated.inputTokens, updated.outputTokens = inputTokens, outputTokens
	updated.totalTokens, updated.reasoningTokens = totalTokens, reasoningTokens
	updated.cacheCreationTokens, updated.cacheReadTokens = cacheCreationTokens, cacheReadTokens
	return updated
}
// withContextLimit returns a copy of the message with contextLimit set to
// limit.
func (m chatMessage) withContextLimit(limit int64) chatMessage {
	updated := m
	updated.contextLimit = limit
	return updated
}
// withTotalCostMicros returns a copy of the message with totalCostMicros set
// to cost.
func (m chatMessage) withTotalCostMicros(cost int64) chatMessage {
	updated := m
	updated.totalCostMicros = cost
	return updated
}
// withRuntimeMs returns a copy of the message with runtimeMs set to ms.
func (m chatMessage) withRuntimeMs(ms int64) chatMessage {
	updated := m
	updated.runtimeMs = ms
	return updated
}
// withProviderResponseID returns a copy of the message with
// providerResponseID set to id.
func (m chatMessage) withProviderResponseID(id string) chatMessage {
	updated := m
	updated.providerResponseID = id
	return updated
}
// appendMessageFields appends one message to the batch insert params by
// pushing a single element onto each of the parallel per-message slices.
// apiKeyID is supplied separately because it is not stored on chatMessage.
// The message content is written as the string form of its raw JSON bytes.
func appendMessageFields(
	params *database.InsertChatMessagesParams,
	msg chatMessage,
	apiKeyID string,
) {
	// Identity and attribution.
	params.Role = append(params.Role, msg.role)
	params.CreatedBy = append(params.CreatedBy, msg.createdBy)
	params.APIKeyID = append(params.APIKeyID, apiKeyID)
	params.ModelConfigID = append(params.ModelConfigID, msg.modelConfigID)
	params.ProviderResponseID = append(params.ProviderResponseID, msg.providerResponseID)

	// Content and presentation.
	params.Content = append(params.Content, string(msg.content.RawMessage))
	params.ContentVersion = append(params.ContentVersion, msg.contentVersion)
	params.Visibility = append(params.Visibility, msg.visibility)
	params.Compressed = append(params.Compressed, msg.compressed)

	// Token usage.
	params.InputTokens = append(params.InputTokens, msg.inputTokens)
	params.OutputTokens = append(params.OutputTokens, msg.outputTokens)
	params.TotalTokens = append(params.TotalTokens, msg.totalTokens)
	params.ReasoningTokens = append(params.ReasoningTokens, msg.reasoningTokens)
	params.CacheCreationTokens = append(params.CacheCreationTokens, msg.cacheCreationTokens)
	params.CacheReadTokens = append(params.CacheReadTokens, msg.cacheReadTokens)

	// Limits, cost, and timing.
	params.ContextLimit = append(params.ContextLimit, msg.contextLimit)
	params.TotalCostMicros = append(params.TotalCostMicros, msg.totalCostMicros)
	params.RuntimeMs = append(params.RuntimeMs, msg.runtimeMs)
}
// appendChatMessage appends a non-user message to the batch insert params
// with an empty API key ID. It panics when handed a user-role message;
// those must go through appendUserChatMessage.
func appendChatMessage(params *database.InsertChatMessagesParams, msg chatMessage) {
	if msg.role != database.ChatMessageRoleUser {
		appendMessageFields(params, msg, "")
		return
	}
	panic("developer error: use appendUserChatMessage for user-role messages")
}
// appendUserChatMessage appends a user message to the batch insert params,
// recording the API key ID carried by the message.
func appendUserChatMessage(params *database.InsertChatMessagesParams, msg userChatMessage) {
	base, keyID := msg.chatMessage, msg.apiKeyID
	appendMessageFields(params, base, keyID)
}
// BuildSingleChatMessageInsertParams creates batch insert params holding
// exactly one message, built with the shared chat message builder. The
// message is recorded with an empty API key ID. createdBy is applied only
// when it is not uuid.Nil.
func BuildSingleChatMessageInsertParams(
	chatID uuid.UUID,
	role database.ChatMessageRole,
	content pqtype.NullRawMessage,
	visibility database.ChatMessageVisibility,
	modelConfigID uuid.UUID,
	contentVersion int16,
	createdBy uuid.UUID,
) database.InsertChatMessagesParams {
	msg := newChatMessage(role, content, visibility, modelConfigID, contentVersion)
	if createdBy != uuid.Nil {
		msg = msg.withCreatedBy(createdBy)
	}

	params := database.InsertChatMessagesParams{ //nolint:exhaustruct // Fields populated by appendChatMessage.
		ChatID: chatID,
	}
	switch role {
	case database.ChatMessageRoleUser:
		// appendChatMessage panics on user-role messages, so a user role
		// passed here is appended directly with an empty API key ID.
		appendMessageFields(&params, msg, "")
	default:
		appendChatMessage(&params, msg)
	}
	return params
}
// BuildSingleUserChatMessageInsertParams creates batch insert params holding
// exactly one user-role message, recorded together with the given apiKeyID.
// createdBy is applied only when it is not uuid.Nil.
func BuildSingleUserChatMessageInsertParams(
	chatID uuid.UUID,
	apiKeyID string,
	content pqtype.NullRawMessage,
	visibility database.ChatMessageVisibility,
	modelConfigID uuid.UUID,
	contentVersion int16,
	createdBy uuid.UUID,
) database.InsertChatMessagesParams {
	userMsg := newUserChatMessage(apiKeyID, content, visibility, modelConfigID, contentVersion)
	if createdBy != uuid.Nil {
		userMsg = userMsg.withCreatedBy(createdBy)
	}

	params := database.InsertChatMessagesParams{ //nolint:exhaustruct // Fields populated by appendUserChatMessage.
		ChatID: chatID,
	}
	appendUserChatMessage(&params, userMsg)
	return params
}
// Config configures a chat processor. It is consumed by New, which fills in
// defaults for several zero-valued fields as noted below.
type Config struct {
// Logger is the base logger; New derives named child loggers from it.
Logger slog.Logger
Database database.Store
// ReplicaID is used as the worker ID. When it is uuid.Nil, New generates
// a random one.
ReplicaID uuid.UUID
// SubscribeFn, when set, is called by SubscribeAuthorized to obtain
// relay events for a subscriber.
SubscribeFn SubscribeFn
// PendingChatAcquireInterval defaults to
// DefaultPendingChatAcquireInterval when zero.
PendingChatAcquireInterval time.Duration
// MaxChatsPerAcquire defaults to DefaultMaxChatsPerAcquire when zero or
// negative.
MaxChatsPerAcquire int32
// InFlightChatStaleAfter defaults to DefaultInFlightChatStaleAfter when
// zero.
InFlightChatStaleAfter time.Duration
// ChatHeartbeatInterval defaults to DefaultChatHeartbeatInterval when
// zero.
ChatHeartbeatInterval time.Duration
AgentConn AgentConnFunc
AgentInactiveDisconnectTimeout time.Duration
// InstructionLookupTimeout defaults to homeInstructionLookupTimeout when
// zero.
InstructionLookupTimeout time.Duration
CreateWorkspace chattool.CreateWorkspaceFn
StartWorkspace chattool.StartWorkspaceFn
StopWorkspace chattool.StopWorkspaceFn
// Pubsub is required; New panics when it is nil.
Pubsub pubsub.Pubsub
ProviderAPIKeys chatprovider.ProviderAPIKeys
// AllowBYOK is honored only when AllowBYOKSet is true; otherwise New
// treats BYOK as allowed.
AllowBYOK bool
AllowBYOKSet bool
AlwaysEnableDebugLogs bool
WebpushDispatcher webpush.Dispatcher
UsageTracker *workspacestats.UsageTracker
// Clock defaults to a real clock when nil.
Clock quartz.Clock
AIBridgeTransportFactory *atomic.Pointer[aibridge.TransportFactory]
AIGatewayRoutingEnabled bool
// PrometheusRegistry, when non-nil, receives the chat loop metrics, the
// stream state collector, and the auto-archive counter. When nil, New
// uses no-op chat loop metrics.
PrometheusRegistry prometheus.Registerer
// OIDCTokenSource resolves the calling user's OIDC access
// token for MCP servers configured with auth_type=user_oidc.
// May be nil if the deployment has no OIDC provider; servers
// using user_oidc will then send no Authorization header.
OIDCTokenSource mcpclient.UserOIDCTokenSource
// NotificationsEnqueuer defaults to a no-op enqueuer when nil.
NotificationsEnqueuer notifications.Enqueuer
Auditor *atomic.Pointer[audit.Auditor]
}
// New creates a new chat processor. The processor polls for pending
// chats and processes them. It is the caller's responsibility to call Close
// on the returned instance.
//
// Zero-valued intervals, a nil clock, a nil notifications enqueuer, and a nil
// replica ID are replaced with defaults. New panics when cfg.Pubsub is nil or
// when the chat worker cannot be constructed. A failure to subscribe to chat
// config events is logged and does not abort construction. New launches the
// stream janitor goroutine; the chat worker itself is only started by Start.
func New(cfg Config) *Server {
ctx, cancel := context.WithCancel(context.Background())
// Resolve defaults for unset configuration values.
pendingChatAcquireInterval := cfg.PendingChatAcquireInterval
if pendingChatAcquireInterval == 0 {
pendingChatAcquireInterval = DefaultPendingChatAcquireInterval
}
inFlightChatStaleAfter := cfg.InFlightChatStaleAfter
if inFlightChatStaleAfter == 0 {
inFlightChatStaleAfter = DefaultInFlightChatStaleAfter
}
maxChatsPerAcquire := cfg.MaxChatsPerAcquire
if maxChatsPerAcquire <= 0 {
maxChatsPerAcquire = DefaultMaxChatsPerAcquire
}
chatHeartbeatInterval := cfg.ChatHeartbeatInterval
if chatHeartbeatInterval == 0 {
chatHeartbeatInterval = DefaultChatHeartbeatInterval
}
clk := cfg.Clock
if clk == nil {
clk = quartz.NewReal()
}
if cfg.Pubsub == nil {
panic("chatd: Pubsub is nil")
}
ps := cfg.Pubsub
notificationsEnqueuer := cfg.NotificationsEnqueuer
if notificationsEnqueuer == nil {
notificationsEnqueuer = notifications.NewNoopEnqueuer()
}
instructionLookupTimeout := cfg.InstructionLookupTimeout
if instructionLookupTimeout == 0 {
instructionLookupTimeout = homeInstructionLookupTimeout
}
workerID := cfg.ReplicaID
if workerID == uuid.Nil {
workerID = uuid.New()
}
// BYOK is allowed unless the caller explicitly set a value.
allowBYOK := true
if cfg.AllowBYOKSet {
allowBYOK = cfg.AllowBYOK
}
p := &Server{
cancel: cancel,
db: cfg.Database,
workerID: workerID,
logger: cfg.Logger.Named("processor"),
subscribeFn: cfg.SubscribeFn,
agentConnFn: cfg.AgentConn,
agentInactiveDisconnectTimeout: cfg.AgentInactiveDisconnectTimeout,
dialTimeout: defaultDialTimeout,
instructionLookupTimeout: instructionLookupTimeout,
createWorkspaceFn: cfg.CreateWorkspace,
startWorkspaceFn: cfg.StartWorkspace,
stopWorkspaceFn: cfg.StopWorkspace,
pubsub: ps,
webpushDispatcher: cfg.WebpushDispatcher,
providerAPIKeys: cfg.ProviderAPIKeys,
allowBYOK: allowBYOK,
oidcTokenSource: cfg.OIDCTokenSource,
debugSvcFactory: func() *chatdebug.Service {
debugSvc := chatdebug.NewService(
cfg.Database,
cfg.Logger.Named("chatdebug"),
ps,
chatdebug.WithAlwaysEnable(cfg.AlwaysEnableDebugLogs),
)
// Debug runs do not heartbeat during model streams; their
// updated_at is only touched on step/run completion. Use a
// longer stale window so long-running turns are not falsely
// finalized as stale while still executing.
debugSvc.SetStaleAfter(inFlightChatStaleAfter * 3)
return debugSvc
},
aibridgeTransportFactory: cfg.AIBridgeTransportFactory,
aiGatewayRoutingEnabled: cfg.AIGatewayRoutingEnabled,
pendingChatAcquireInterval: pendingChatAcquireInterval,
maxChatsPerAcquire: maxChatsPerAcquire,
inFlightChatStaleAfter: inFlightChatStaleAfter,
chatHeartbeatInterval: chatHeartbeatInterval,
usageTracker: cfg.UsageTracker,
clock: clk,
recordingSem: make(chan struct{}, maxConcurrentRecordingUploads),
}
// Metrics are only registered when a registry is supplied. The
// auto-archive counter stays nil otherwise and is passed to the chat
// worker as such.
var chatAutoArchiveRecords prometheus.Counter
if cfg.PrometheusRegistry != nil {
p.metrics = chatloop.NewMetrics(cfg.PrometheusRegistry)
cfg.PrometheusRegistry.MustRegister(&streamStateCollector{server: p})
chatAutoArchiveRecords = prometheus.NewCounter(prometheus.CounterOpts{
Namespace: "coderd",
Subsystem: "chat_auto_archive",
Name: "records_archived_total",
Help: "Total number of chats archived by the auto-archive job (counting both roots and cascaded children).",
})
cfg.PrometheusRegistry.MustRegister(chatAutoArchiveRecords)
} else {
p.metrics = chatloop.NopMetrics()
}
p.messagePartBuffer = messagepartbuffer.New(messagepartbuffer.Options{Clock: clk})
chatWorker, err := newChatWorker(p, chatWorkerOptions{
WorkerID: workerID,
Store: cfg.Database,
Pubsub: ps,
Logger: cfg.Logger.Named("chatworker"),
Clock: clk,
MessagePartBuffer: p.messagePartBuffer,
AcquisitionInterval: pendingChatAcquireInterval,
AcquisitionBatchSize: maxChatsPerAcquire,
HeartbeatInterval: chatHeartbeatInterval,
// The stale window is truncated to whole seconds.
HeartbeatStaleSeconds: int32(inFlightChatStaleAfter.Seconds()),
NotificationsEnqueuer: notificationsEnqueuer,
Auditor: cfg.Auditor,
AutoArchiveRecords: chatAutoArchiveRecords,
})
if err != nil {
panic("chatd: create chat worker: " + err.Error())
}
p.chatWorker = chatWorker
//nolint:gocritic // The chat processor uses a scoped chatd context.
ctx = dbauthz.AsChatd(ctx)
p.configCache = newChatConfigCache(ctx, cfg.Database, clk)
// Invalidate the matching part of the config cache whenever a chat config
// event arrives over pubsub.
cancelConfigSub, err := p.pubsub.SubscribeWithErr(
coderdpubsub.ChatConfigEventChannel,
coderdpubsub.HandleChatConfigEvent(func(ctx context.Context, ev coderdpubsub.ChatConfigEvent, err error) {
if err != nil {
p.logger.Warn(ctx, "chat config event error", slog.Error(err))
return
}
switch ev.Kind {
case coderdpubsub.ChatConfigEventProviders:
p.configCache.InvalidateProviders()
case coderdpubsub.ChatConfigEventModelConfig:
p.configCache.InvalidateModelConfig(ev.EntityID)
case coderdpubsub.ChatConfigEventUserPrompt:
p.configCache.InvalidateUserPrompt(ev.EntityID)
case coderdpubsub.ChatConfigEventAdvisorConfig:
p.configCache.InvalidateAdvisorConfig()
}
}),
)
if err != nil {
// Construction continues without the subscription; no unsubscribe
// func is stored in that case.
p.logger.Error(ctx, "subscribe to chat config events", slog.Error(err))
} else {
p.configCacheUnsubscribe = cancelConfigSub
}
p.ctx = ctx
// Spawn background goroutines that all servers need.
p.wg.Go(func() { p.streamJanitorLoop(ctx) })
return p
}
// Start starts the chat worker, which runs the background acquire/wake
// loop that picks up pending chats and processes them. Callers that want
// a passive server (e.g. tests) can skip this call; the stream janitor
// goroutine launched by New runs either way. A failure to start the worker
// is logged rather than returned. Start returns its receiver so the call
// can be chained.
func (p *Server) Start() *Server {
	worker := p.chatWorker
	if worker == nil {
		return p
	}
	if startErr := worker.Start(p.ctx); startErr != nil {
		p.logger.Error(p.ctx, "failed to start chat worker", slog.Error(startErr))
	}
	return p
}
// getCachedDurableMessages returns the cached durable message events whose
// message ID is greater than afterID, in cache order. It returns nil when
// afterID lies below the cache's eviction watermark or when no cached entry
// is newer than afterID.
func (p *Server) getCachedDurableMessages(
	chatID uuid.UUID,
	afterID int64,
) []codersdk.ChatStreamEvent {
	state := p.getOrCreateStreamState(chatID)
	state.mu.Lock()
	defer state.mu.Unlock()

	if state.durableEvictedBefore > afterID {
		return nil
	}

	var newer []codersdk.ChatStreamEvent
	for i := range state.durableMessages {
		cached := state.durableMessages[i]
		if cached.Message == nil || cached.Message.ID <= afterID {
			continue
		}
		newer = append(newer, cached)
	}
	return newer
}
// snapshotBufferLocked returns the buffered message_part events that
// belong in a subscriber's initial snapshot.
//
// Only in-progress parts (committedMessageID == 0) are returned. A part
// with a non-zero committedMessageID has been claimed by a durable
// assistant message, and the subscriber obtains that message elsewhere
// (REST snapshot, the initial DB query in SubscribeAuthorized, or pubsub
// catch-up). Returning it here as well would make the client render the
// same content twice: once in the streaming UI and once as a durable
// message.
//
// The view is the same for every caller, independent of cursor or relay
// sentinel, which includes cross-replica relay subscribers whose
// user-facing peers get the durable message via pubsub.
//
// An empty buffer yields nil; a non-empty buffer yields a non-nil slice
// even if every part was filtered out.
//
// The caller must hold the per-chat stream state lock.
func snapshotBufferLocked(buffer []bufferedStreamPart) []codersdk.ChatStreamEvent {
	if len(buffer) == 0 {
		return nil
	}
	pending := make([]codersdk.ChatStreamEvent, 0, len(buffer))
	for i := range buffer {
		if buffer[i].committedMessageID == 0 {
			pending = append(pending, buffer[i].event)
		}
	}
	return pending
}
// subscribeToStream registers a new subscriber on the per-chat in-memory
// stream. It returns, in order: a snapshot of the in-progress message_part
// events, a copy of the current retry phase (nil if there is none), the
// subscriber's live event channel, and a cancel func that unregisters the
// subscriber.
//
// The snapshot leaves out parts already claimed by a committed durable
// assistant message (committedMessageID != 0). Those messages reach the
// subscriber through the REST snapshot, the initial DB query in
// SubscribeAuthorized, or pubsub, so sending their parts again would
// render the same content twice.
//
// The snapshot, the retry copy, and the registration all happen under one
// hold of the per-chat lock.
func (p *Server) subscribeToStream(chatID uuid.UUID) (
	[]codersdk.ChatStreamEvent,
	*codersdk.ChatStreamRetry,
	<-chan codersdk.ChatStreamEvent,
	func(),
) {
	subscriberID := uuid.New()
	events := make(chan codersdk.ChatStreamEvent, 128)
	state := p.getOrCreateStreamState(chatID)

	state.mu.Lock()
	snapshot := snapshotBufferLocked(state.buffer)
	var retry *codersdk.ChatStreamRetry
	if current := state.currentRetry; current != nil {
		// Hand out a copy so the caller never shares the stored value.
		copied := *current
		retry = &copied
	}
	state.subscribers[subscriberID] = events
	state.mu.Unlock()

	unsubscribe := func() {
		state.mu.Lock()
		// Unregister only; the channel is deliberately left open.
		// publishToStream copies subscriber references under the
		// per-chat lock and then sends outside it, so closing here
		// would race with that send and can panic. Once removed the
		// channel is unreachable and gets garbage collected.
		delete(state.subscribers, subscriberID)
		p.cleanupStreamIfIdle(chatID, state)
		state.mu.Unlock()
	}
	return snapshot, retry, events, unsubscribe
}
// getOrCreateStreamState returns the stream state for chatID, atomically
// installing a fresh one when the map has none. A plain Load is tried first
// so a new state is only allocated on a miss. The returned state carries
// its own mutex — callers must lock state.mu before touching it.
func (p *Server) getOrCreateStreamState(chatID uuid.UUID) *chatStreamState {
	entry, found := p.chatStreams.Load(chatID)
	if !found {
		fresh := &chatStreamState{
			subscribers: make(map[uuid.UUID]chan codersdk.ChatStreamEvent),
		}
		// LoadOrStore returns whichever state won if another goroutine
		// stored one in the meantime.
		entry, _ = p.chatStreams.LoadOrStore(chatID, fresh)
	}
	state, _ := entry.(*chatStreamState)
	return state
}
// cleanupStreamIfIdle drops the chat's entry from the sync.Map once it
// has no subscribers, is not buffering, and any grace period kept for
// late-connecting relay subscribers has run out. While the grace window
// is still open it simply returns without rescheduling;
// streamJanitorLoop re-checks on a timer as the backstop.
//
// The caller must hold state.mu. Because the state pointer may have been
// obtained outside that lock (sync.Map.Load or Range), removal goes
// through CompareAndDelete so a stale pointer cannot evict a newer entry
// installed by a racing getOrCreateStreamState. It reports whether the
// state was deleted.
func (p *Server) cleanupStreamIfIdle(chatID uuid.UUID, state *chatStreamState) bool {
	if len(state.subscribers) > 0 || state.buffering {
		return false
	}
	// During the grace period the state stays alive so late-connecting
	// cross-replica relay subscribers can still register against this
	// chat before it is collected.
	if retainedAt := state.bufferRetainedAt; !retainedAt.IsZero() {
		if p.clock.Now().Before(retainedAt.Add(bufferRetainGracePeriod)) {
			return false
		}
	}
	deleted := p.chatStreams.CompareAndDelete(chatID, state)
	if deleted {
		p.workspaceMCPToolsCache.Delete(chatID)
	}
	return deleted
}
// streamJanitorLoop sweeps idle chat stream states on every tick of a
// streamJanitorInterval ticker until ctx is done. It backstops the
// grace-window early return in cleanupStreamIfIdle: without it, a
// subscriber that detaches inside the grace window (the common enterprise
// relay-drain case, relayDrainTimeout = 200ms vs. 5s grace) would pin the
// state forever.
func (p *Server) streamJanitorLoop(ctx context.Context) {
	ticker := p.clock.NewTicker(streamJanitorInterval, "chatd", "stream-janitor")
	defer ticker.Stop()

	done := ctx.Done()
	for {
		select {
		case <-ticker.C:
			p.safeSweepIdleStreams(ctx)
		case <-done:
			return
		}
	}
}
// safeSweepIdleStreams runs sweepIdleStreams with panic recovery. A panic
// during the sweep is logged instead of killing the janitor goroutine,
// which would silently bring back the leak the janitor exists to prevent.
// The next tick retries.
func (p *Server) safeSweepIdleStreams(ctx context.Context) {
	defer func() {
		recovered := recover()
		if recovered == nil {
			return
		}
		p.logger.Error(ctx, "stream janitor sweep panicked, will retry next tick",
			slog.F("panic", recovered))
	}()
	p.sweepIdleStreams()
}
// sweepIdleStreams walks chatStreams once, handing each entry to
// cleanupStreamIfIdle under that entry's lock, and logs how many entries
// were reaped when the count is non-zero. Range may skip entries that
// become reapable concurrently; those are picked up on the next tick.
func (p *Server) sweepIdleStreams() {
	var reaped atomic.Int64
	defer func() {
		count := reaped.Load()
		if count <= 0 {
			return
		}
		p.logger.Info(context.Background(), "reaped idle chat streams", slog.F("count", count))
	}()

	// reapLocked runs the cleanup inside its own function scope so the
	// deferred unlock fires even if cleanupStreamIfIdle panics, rather
	// than leaving state.mu locked for all time.
	reapLocked := func(chatID uuid.UUID, state *chatStreamState) {
		state.mu.Lock()
		defer state.mu.Unlock()
		if p.cleanupStreamIfIdle(chatID, state) {
			reaped.Add(1)
		}
	}

	p.chatStreams.Range(func(key, value any) bool {
		chatID, isChatID := key.(uuid.UUID)
		if !isChatID {
			return true
		}
		state, isState := value.(*chatStreamState)
		if !isState {
			return true
		}
		reapLocked(chatID, state)
		return true
	})
}
// streamSubscriberControlFetchContext returns the context to use for a
// control-path lookup made on behalf of a subscriber. A ctx that already
// has a deadline is returned as is with a no-op cancel func; otherwise a
// child context bounded by chatStreamControlFetchTimeout is returned, so
// the lookup stays tied to the requesting subscriber in both cases.
func streamSubscriberControlFetchContext(ctx context.Context) (context.Context, context.CancelFunc) {
	_, hasDeadline := ctx.Deadline()
	if !hasDeadline {
		return context.WithTimeout(ctx, chatStreamControlFetchTimeout)
	}
	noop := func() {}
	return ctx, noop
}
// subscribeWithInitialError returns a subscription result that carries a
// single error event for chatID with the given message. The event channel
// is already closed, the cancel func is a no-op, and the final bool is
// true.
func subscribeWithInitialError(chatID uuid.UUID, message string) (
	[]codersdk.ChatStreamEvent,
	<-chan codersdk.ChatStreamEvent,
	func(),
	bool,
) {
	closed := make(chan codersdk.ChatStreamEvent)
	close(closed)

	errEvent := codersdk.ChatStreamEvent{
		Type:   codersdk.ChatStreamEventTypeError,
		ChatID: chatID,
		Error:  &codersdk.ChatError{Message: message},
	}
	noop := func() {}
	return []codersdk.ChatStreamEvent{errEvent}, closed, noop, true
}
// Subscribe loads the chat by ID and, when that succeeds, delegates to
// SubscribeAuthorized. It returns the initial events, the live event
// channel, a cancel func, and a bool.
//
// A nil receiver or a not-authorized error from the chat lookup yields
// (nil, nil, nil, false). Any other lookup error is logged and reported
// through subscribeWithInitialError, i.e. as a single error event on an
// already-closed channel with the bool set to true.
func (p *Server) Subscribe(
	ctx context.Context,
	chatID uuid.UUID,
	requestHeader http.Header,
	afterMessageID int64,
) (
	[]codersdk.ChatStreamEvent,
	<-chan codersdk.ChatStreamEvent,
	func(),
	bool,
) {
	if p == nil {
		return nil, nil, nil, false
	}

	chat, err := p.db.GetChatByID(ctx, chatID)
	switch {
	case err == nil:
		return p.SubscribeAuthorized(ctx, chat, requestHeader, afterMessageID)
	case dbauthz.IsNotAuthorizedError(err):
		return nil, nil, nil, false
	default:
		p.logger.Warn(ctx, "failed to load chat for stream subscription",
			slog.F("chat_id", chatID),
			slog.Error(err),
		)
		return subscribeWithInitialError(chatID, "failed to load initial snapshot")
	}
}
// SubscribeAuthorized subscribes an already-authorized chat to merged stream
// updates. The passed chat row proves authorization, but SubscribeAuthorized
// still reloads the chat after the stream subscriptions are armed so the
// initial status and relay setup use fresh state.
func (p *Server) SubscribeAuthorized(
ctx context.Context,
chat database.Chat,
requestHeader http.Header,
afterMessageID int64,
) (
[]codersdk.ChatStreamEvent,
<-chan codersdk.ChatStreamEvent,
func(),
bool,
) {
if p == nil {
return nil, nil, nil, false
}
chatID := chat.ID
// Subscribe to the local stream for message_parts and same-replica
// persisted messages. Capture the current retry phase under the same
// lock so the transient snapshot and subscriber registration reflect
// a single moment in time.
localSnapshot, localRetry, localParts, localCancel := p.subscribeToStream(chatID)
// Merge all event sources.
mergedCtx, mergedCancel := context.WithCancel(ctx)
mergedEvents := make(chan codersdk.ChatStreamEvent, 128)
var allCancels []func()
allCancels = append(allCancels, localCancel)
// Subscribe to pubsub for durable and structured control
// events (status, messages, queue updates, retry, errors).
// If the subscription cannot be established, deliver all local
// events.
//
// This MUST happen before the DB queries below so that any
// notification published between the query and the subscription
// is not lost (subscribe-first-then-query pattern).
notifyCh := make(chan coderdpubsub.ChatStreamNotifyMessage, 10)
errCh := make(chan error, 1)
listener := func(_ context.Context, message []byte, listenErr error) {
if listenErr != nil {
select {
case <-mergedCtx.Done():
case errCh <- listenErr:
}
return
}
var notify coderdpubsub.ChatStreamNotifyMessage
if unmarshalErr := json.Unmarshal(message, &notify); unmarshalErr != nil {
select {
case <-mergedCtx.Done():
case errCh <- xerrors.Errorf("unmarshal chat stream notify: %w", unmarshalErr):
}
return
}
select {
case <-mergedCtx.Done():
case notifyCh <- notify:
}
}
if pubsubCancel, pubsubErr := p.pubsub.SubscribeWithErr(
coderdpubsub.ChatStreamNotifyChannel(chatID),
listener,
); pubsubErr == nil {
allCancels = append(allCancels, pubsubCancel)
} else {
p.logger.Warn(ctx, "failed to subscribe to chat stream notifications",
slog.F("chat_id", chatID),
slog.Error(pubsubErr),
)
}
cancel := func() {
mergedCancel()
for _, cancelFn := range allCancels {
if cancelFn != nil {
cancelFn()
}
}
}
// Re-read the chat after the local/pubsub subscriptions are active so
// the initial status event and any enterprise relay setup use fresh
// state instead of the middleware-loaded row.
refreshCtx, refreshCancel := streamSubscriberControlFetchContext(ctx)
snapshotChat, err := func() (database.Chat, error) {
defer refreshCancel()
//nolint:gocritic // SubscribeAuthorized already validated the
// caller; this refresh only loads the latest status/worker for
// the already-authorized stream subscription.
return p.db.GetChatByID(dbauthz.AsChatd(refreshCtx), chatID)
}()
if err != nil {
p.logger.Warn(ctx, "failed to refresh chat for stream subscription; using stale state",
slog.F("chat_id", chatID),
slog.Error(err),
)
snapshotChat = chat
}
// Build initial snapshot synchronously. The pubsub subscription
// is already active so no notifications can be lost during this
// window.
initialSnapshot := make([]codersdk.ChatStreamEvent, 0)
delivered := map[int64]struct{}{}
// Add local same-replica message_parts to the snapshot. Retry comes
// from state.currentRetry, not the event buffer, so late joiners see
// only the latest phase rather than a stale buffered retry event.
for _, event := range localSnapshot {
if event.Type == codersdk.ChatStreamEventTypeMessagePart {
initialSnapshot = append(initialSnapshot, event)
}
}
var retryEvent *codersdk.ChatStreamEvent
if localRetry != nil {
retryEvent = &codersdk.ChatStreamEvent{
Type: codersdk.ChatStreamEventTypeRetry,
ChatID: chatID,
Retry: localRetry,
}
}
// Load initial messages from DB. When afterMessageID > 0 the
// caller already has messages up to that ID (e.g. from the REST
// endpoint), so we only fetch newer ones to avoid sending
// duplicate data.
messages, err := p.db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
ChatID: chatID,
AfterID: afterMessageID,
})
if err != nil {
p.logger.Error(ctx, "failed to load initial chat messages",
slog.Error(err),
slog.F("chat_id", chatID),
)
initialSnapshot = append(initialSnapshot, codersdk.ChatStreamEvent{
Type: codersdk.ChatStreamEventTypeError,
ChatID: chatID,
Error: &codersdk.ChatError{Message: "failed to load initial snapshot"},
})
} else {
for _, msg := range messages {
sdkMsg := db2sdk.ChatMessage(msg)
initialSnapshot = append(initialSnapshot, codersdk.ChatStreamEvent{
Type: codersdk.ChatStreamEventTypeMessage,
ChatID: chatID,
Message: &sdkMsg,
})
delivered[msg.ID] = struct{}{}
}
}
// Load initial queue. Queue snapshots are intentionally not
// singleflighted because a chat-scoped key cannot distinguish the
// pre- and post-notification queue state.
queueCtx, queueCancel := streamSubscriberControlFetchContext(ctx)
queued, err := p.db.GetChatQueuedMessages(queueCtx, chatID)
queueCancel()
if err != nil {
p.logger.Error(ctx, "failed to load initial queued messages",
slog.Error(err),
slog.F("chat_id", chatID),
)
initialSnapshot = append(initialSnapshot, codersdk.ChatStreamEvent{
Type: codersdk.ChatStreamEventTypeError,
ChatID: chatID,
Error: &codersdk.ChatError{Message: "failed to load initial snapshot"},
})
} else if len(queued) > 0 {
initialSnapshot = append(initialSnapshot, codersdk.ChatStreamEvent{
Type: codersdk.ChatStreamEventTypeQueueUpdate,
ChatID: chatID,
QueuedMessages: db2sdk.ChatQueuedMessages(queued),
})
}
// Include the current chat status in the snapshot so the
// frontend can gate message_part processing correctly from
// the very first batch, without waiting for a separate REST
// query.
statusEvent := codersdk.ChatStreamEvent{
Type: codersdk.ChatStreamEventTypeStatus,
ChatID: chatID,
Status: &codersdk.ChatStreamStatus{
Status: codersdk.ChatStatus(snapshotChat.Status),
},
}
// Prepend so the frontend sees the current stream phases
// before any message_part events.
prefix := []codersdk.ChatStreamEvent{statusEvent}
if retryEvent != nil {
prefix = append(prefix, *retryEvent)
}
initialSnapshot = append(prefix, initialSnapshot...)
// Track the highest durable message ID delivered to this subscriber,
// whether it came from the initial DB snapshot, the same-replica local
// stream, or a later DB/cache catch-up.
lastMessageID := afterMessageID
if len(messages) > 0 {
lastMessageID = messages[len(messages)-1].ID
}
// When an enterprise SubscribeFn is provided, call it to get relay events
// (message_parts from remote replicas). OSS owns pubsub subscription,
// message catch-up, queue updates, and status forwarding; enterprise only
// manages relay dialing.
var relayEvents <-chan codersdk.ChatStreamEvent
var statusNotifications chan StatusNotification
if p.subscribeFn != nil {
statusNotifications = make(chan StatusNotification, 10)
relayEvents = p.subscribeFn(mergedCtx, SubscribeFnParams{
ChatID: chatID,
Chat: snapshotChat,
WorkerID: p.workerID,
StatusNotifications: statusNotifications,
RequestHeader: requestHeader,
DB: p.db,
Logger: p.logger,
})
}
// hasPubsubSubscription is only true when we actually subscribed
// successfully above (allCancels will contain the pubsub
// cancel func in that case).
hasPubsubSubscription := len(allCancels) > 1
//nolint:nestif
go func() {
defer close(mergedEvents)
if statusNotifications != nil {
defer close(statusNotifications)
}
for {
select {
case <-mergedCtx.Done():
return
case psErr := <-errCh:
p.logger.Error(mergedCtx, "chat stream pubsub error",
slog.F("chat_id", chatID),
slog.Error(psErr),
)
select {
case mergedEvents <- codersdk.ChatStreamEvent{
Type: codersdk.ChatStreamEventTypeError,
ChatID: chatID,
Error: &codersdk.ChatError{
Message: psErr.Error(),
},
}:
case <-mergedCtx.Done():
}
return
case notify := <-notifyCh:
// Marker for ENG-2645: subscriber received pubsub notify.
p.logger.Debug(mergedCtx, "stream subscriber received notify",
slog.F("chat_id", chatID),
slog.F("after_message_id", notify.AfterMessageID),
slog.F("status", notify.Status),
slog.F("queue_update", notify.QueueUpdate),
slog.F("last_message_id", lastMessageID),
)
if notify.AfterMessageID > 0 || notify.FullRefresh {
if notify.FullRefresh {
lastMessageID = 0
clear(delivered)
}
var (
deliveredCount int
source string
)
// Notifies can arrive out of order. Rescan from
// min(AfterMessageID, lastMessageID) to cover the gap,
// floored at afterMessageID to respect the subscription
// boundary. The delivered set deduplicates.
lookupAfter := lastMessageID
if !notify.FullRefresh {
lookupAfter = max(afterMessageID, min(notify.AfterMessageID, lastMessageID))
}
cached := p.getCachedDurableMessages(chatID, lookupAfter)
if !notify.FullRefresh && len(cached) > 0 {
for _, event := range cached {
if event.Message == nil {
continue
}
if _, ok := delivered[event.Message.ID]; ok {
continue
}
select {
case <-mergedCtx.Done():
return
case mergedEvents <- event:
}
delivered[event.Message.ID] = struct{}{}
if event.Message.ID > lastMessageID {
lastMessageID = event.Message.ID
}
deliveredCount++
source = "cache"
}
}
// DB pass picks up cross-replica messages the local cache
// cannot have. Delivered set dedupes against the cache pass.
newMessages, msgErr := p.db.GetChatMessagesByChatID(mergedCtx, database.GetChatMessagesByChatIDParams{
ChatID: chatID,
AfterID: lookupAfter,
})
if msgErr != nil {
p.logger.Warn(mergedCtx, "failed to get chat messages after pubsub notification",
slog.F("chat_id", chatID),
slog.Error(msgErr),
)
} else {
for _, msg := range newMessages {
if msg.ID <= lookupAfter {
continue
}
if _, ok := delivered[msg.ID]; ok {
continue
}
sdkMsg := db2sdk.ChatMessage(msg)
select {
case <-mergedCtx.Done():
return
case mergedEvents <- codersdk.ChatStreamEvent{
Type: codersdk.ChatStreamEventTypeMessage,
ChatID: chatID,
Message: &sdkMsg,
}:
}
delivered[msg.ID] = struct{}{}
if msg.ID > lastMessageID {
lastMessageID = msg.ID
}
deliveredCount++
switch source {
case "":
source = "db"
case "cache":
source = "cache+db"
}
}
}
// Marker for ENG-2645: subscriber delivered durable messages.
p.logger.Debug(mergedCtx, "stream subscriber delivered messages",
slog.F("chat_id", chatID),
slog.F("after_message_id", notify.AfterMessageID),
slog.F("lookup_after", lookupAfter),
slog.F("source", source),
slog.F("delivered_count", deliveredCount),
slog.F("last_message_id", lastMessageID),
)
}
if notify.Status != "" {
status := database.ChatStatus(notify.Status)
select {
case <-mergedCtx.Done():
return
case mergedEvents <- codersdk.ChatStreamEvent{
Type: codersdk.ChatStreamEventTypeStatus,
ChatID: chatID,
Status: &codersdk.ChatStreamStatus{Status: codersdk.ChatStatus(status)},
}:
}
// Notify enterprise relay manager if present.
if statusNotifications != nil {
workerID := uuid.Nil
if notify.WorkerID != "" {
if parsed, parseErr := uuid.Parse(notify.WorkerID); parseErr == nil {
workerID = parsed
}
}
select {
case statusNotifications <- StatusNotification{Status: status, WorkerID: workerID}:
case <-mergedCtx.Done():
return
}
}
}
if notify.Retry != nil {
select {
case <-mergedCtx.Done():
return
case mergedEvents <- codersdk.ChatStreamEvent{
Type: codersdk.ChatStreamEventTypeRetry,
ChatID: chatID,
Retry: notify.Retry,
}:
}
}
if notify.ErrorPayload != nil {
select {
case <-mergedCtx.Done():
return
case mergedEvents <- codersdk.ChatStreamEvent{
Type: codersdk.ChatStreamEventTypeError,
ChatID: chatID,
Error: notify.ErrorPayload,
}:
}
} else if notify.Error != "" {
select {
case <-mergedCtx.Done():
return
case mergedEvents <- codersdk.ChatStreamEvent{
Type: codersdk.ChatStreamEventTypeError,
ChatID: chatID,
Error: &codersdk.ChatError{
Message: notify.Error,
},
}:
}
}
if notify.QueueUpdate {
queueCtx, queueCancel := streamSubscriberControlFetchContext(mergedCtx)
queuedMsgs, queueErr := p.db.GetChatQueuedMessages(queueCtx, chatID)
queueCancel()
if queueErr != nil {
p.logger.Warn(mergedCtx, "failed to get queued messages after pubsub notification",
slog.F("chat_id", chatID),
slog.Error(queueErr),
)
} else {
select {
case <-mergedCtx.Done():
return
case mergedEvents <- codersdk.ChatStreamEvent{
Type: codersdk.ChatStreamEventTypeQueueUpdate,
ChatID: chatID,
QueuedMessages: db2sdk.ChatQueuedMessages(queuedMsgs),
}:
}
}
}
case event, ok := <-localParts:
if !ok {
localParts = nil
// Local parts channel closed. If pubsub is
// active we continue with pubsub-driven events.
// Otherwise terminate.
if !hasPubsubSubscription {
return
}
continue
}
if hasPubsubSubscription {
// Forward transient events from local.
// Durable events (messages, queue updates)
// come via pubsub + cache. Status is
// included alongside message_part because
// both travel through the same ordered
// channel: publishStatus is called before
// the first message_part, so FIFO delivery
// guarantees the frontend sees
// status=running before any content.
// Pubsub will deliver a duplicate status
// later; the frontend deduplicates it
// (setChatStatus is idempotent).
// action_required is also transient and
// only published on the local stream, so
// it must be forwarded here.
if event.Type == codersdk.ChatStreamEventTypeMessagePart ||
event.Type == codersdk.ChatStreamEventTypeStatus ||
event.Type == codersdk.ChatStreamEventTypeActionRequired {
select {
case <-mergedCtx.Done():
return
case mergedEvents <- event:
}
}
} else {
// No pubsub subscription: forward all event types.
select {
case <-mergedCtx.Done():
return
case mergedEvents <- event:
}
}
case event, ok := <-relayEvents:
if !ok {
relayEvents = nil
continue
}
select {
case <-mergedCtx.Done():
return
case mergedEvents <- event:
}
}
}
}()
return initialSnapshot, mergedEvents, cancel, true
}
// publishChatPubsubEvents emits one lifecycle event of the given kind for
// every chat in chats. The events carry no diff status.
func (p *Server) publishChatPubsubEvents(chats []database.Chat, kind codersdk.ChatWatchEventKind) {
	for i := range chats {
		p.publishChatPubsubEvent(chats[i], kind, nil)
	}
}
// publishChatPubsubEvent serializes a chat lifecycle event to JSON and
// publishes it on the chat owner's watch channel so that all replicas can
// push updates to watching clients. It does nothing when no pubsub is
// configured. Marshal and publish failures are logged rather than
// returned.
//
// When diffStatus is non-nil it is attached to the published chat.
func (p *Server) publishChatPubsubEvent(chat database.Chat, kind codersdk.ChatWatchEventKind, diffStatus *codersdk.ChatDiffStatus) {
	if p.pubsub == nil {
		return
	}

	// File metadata is deliberately left out of pubsub events so that
	// publishing does not cost an extra DB query. Clients must merge
	// pubsub updates instead of replacing their cached file metadata.
	watchEvent := codersdk.ChatWatchEvent{
		Kind: kind,
		Chat: db2sdk.Chat(chat, nil, nil),
	}
	if diffStatus != nil {
		watchEvent.Chat.DiffStatus = diffStatus
	}

	encoded, marshalErr := json.Marshal(watchEvent)
	if marshalErr != nil {
		p.logger.Error(context.Background(), "failed to marshal chat pubsub event",
			slog.F("chat_id", chat.ID),
			slog.Error(marshalErr),
		)
		return
	}

	channel := coderdpubsub.ChatWatchEventChannel(chat.OwnerID)
	publishErr := p.pubsub.Publish(channel, encoded)
	if publishErr == nil {
		return
	}
	p.logger.Error(context.Background(), "failed to publish chat pubsub event",
		slog.F("chat_id", chat.ID),
		slog.F("kind", kind),
		slog.Error(publishErr),
	)
}
// PublishDiffStatusChange loads the chat and its stored diff status and
// broadcasts a diff_status_change event carrying that status, so that
// watching clients know the diff status changed. The HTTP layer calls
// this after it has updated the diff status in the database.
//
// It returns a wrapped error when either the chat or its diff status
// cannot be read; publish failures are not reported to the caller.
func (p *Server) PublishDiffStatusChange(ctx context.Context, chatID uuid.UUID) error {
	chat, chatErr := p.db.GetChatByID(ctx, chatID)
	if chatErr != nil {
		return xerrors.Errorf("get chat: %w", chatErr)
	}

	diffRow, diffErr := p.db.GetChatDiffStatusByChatID(ctx, chatID)
	if diffErr != nil {
		return xerrors.Errorf("get chat diff status: %w", diffErr)
	}

	converted := db2sdk.ChatDiffStatus(chatID, &diffRow)
	p.publishChatPubsubEvent(chat, codersdk.ChatWatchEventKindDiffStatusChange, &converted)
	return nil
}
// chatFileResolver returns a FileResolver that loads chat files by ID and
// rejects oversize images on capped providers before any upstream
// request is issued. An image is oversize only when it is strictly larger
// than the provider's inline image cap; an image whose size equals the
// cap is accepted.
//
// The resolver returns the database error unchanged when the file lookup
// fails, and a classified, non-retryable config error when an image
// exceeds the cap.
//
// Gotcha: a historical oversize image bricks the chat on a capped
// provider until the user switches providers back, starts a new
// chat, or edits a message above the offending one (which truncates
// the prompt forward). A future change should skip the file with a
// user-facing warning, but that requires altering the FileResolver
// contract.
func (p *Server) chatFileResolver(provider string) chatprompt.FileResolver {
	return func(ctx context.Context, ids []uuid.UUID) (map[uuid.UUID]chatprompt.FileData, error) {
		files, err := p.db.GetChatFilesByIDs(ctx, ids)
		if err != nil {
			return nil, err
		}
		imageCap, hasImageCap := chatprovider.InlineImageCapBytes(provider)
		normalizedProvider := chatprovider.NormalizeProvider(provider)
		result := make(map[uuid.UUID]chatprompt.FileData, len(files))
		for _, f := range files {
			// Strictly greater-than: the cap is the largest size still
			// allowed, matching the "exceeds ... limit" wording below.
			if hasImageCap &&
				strings.HasPrefix(f.Mimetype, "image/") &&
				len(f.Data) > imageCap {
				err := xerrors.Errorf(
					"image attachment %q is %d bytes; %s inline image limit is %d bytes",
					f.Name, len(f.Data),
					chatprovider.ProviderDisplayName(normalizedProvider),
					imageCap,
				)
				// User-facing message stays client-agnostic since
				// older web clients and direct API callers don't
				// auto-resize; the wrapped error above keeps the
				// exact byte count for operator logs.
				return nil, chaterror.WithClassification(err, chaterror.ClassifiedError{
					Kind:     codersdk.ChatErrorKindConfig,
					Provider: normalizedProvider,
					Message: fmt.Sprintf(
						"Image attachment exceeds %s's %s inline image limit. Replace it with a smaller image.",
						chatprovider.ProviderDisplayName(normalizedProvider),
						//nolint:gosec // imageCap is a small positive constant defined in chatprovider.
						humanize.IBytes(uint64(imageCap)),
					),
					Retryable: false,
				})
			}
			result[f.ID] = chatprompt.FileData{
				Name:      f.Name,
				Data:      f.Data,
				MediaType: f.Mimetype,
			}
		}
		return result, nil
	}
}
// trackWorkspaceUsage records workspace activity for a chat: it bumps the
// workspace's last_used_at through the usage tracker and extends the
// workspace's autostop deadline.
//
// When wsID is not yet valid the chat is re-read from the DB so that late
// associations (e.g. create_workspace linking a workspace
// mid-conversation) are picked up. The returned value is the workspace ID
// that was ultimately used; callers should store it so later calls can
// skip the DB lookup once a workspace has been found. If no usage tracker
// is configured, or the re-read fails, wsID is returned unchanged.
func (p *Server) trackWorkspaceUsage(
	ctx context.Context,
	chatID uuid.UUID,
	wsID uuid.NullUUID,
	logger slog.Logger,
) uuid.NullUUID {
	if p.usageTracker == nil {
		return wsID
	}

	if !wsID.Valid {
		refreshed, err := p.db.GetChatByID(ctx, chatID)
		if err != nil {
			logger.Warn(ctx, "failed to re-read chat for workspace association", slog.Error(err))
			return wsID
		}
		wsID = refreshed.WorkspaceID
	}
	if !wsID.Valid {
		// Still no workspace linked to this chat; nothing to track.
		return wsID
	}

	workspaceID := wsID.UUID
	p.usageTracker.Add(workspaceID)

	// Extend the autostop deadline. nextAutostart is passed as
	// time.Time{} because TemplateScheduleStore is not available here;
	// the activity bump logic then defaults to the template's
	// activity_bump duration (typically 1 hour). Chat workspaces are
	// never prebuilds, so no prebuild guard is needed (unlike
	// reporter.go).
	//
	// This runs on every heartbeat (~30s), but the SQL only writes when
	// 5% of the deadline has elapsed — most calls perform a read-only
	// CTE lookup with no UPDATE.
	//
	// Scaling note: 10,000 active chats could produce roughly 333 CTE
	// queries/second. Heartbeating only every Nth query would be a cheap
	// mitigation; left as potential future low-hanging fruit if needed.
	bumpLogger := logger.Named("activity_bump")
	workspacestats.ActivityBumpWorkspace(ctx, bumpLogger, p.db, workspaceID, time.Time{}, workspacestats.ActivityBumpReasonChatHeartbeat)
	return wsID
}
// runChatResult groups the final assistant text, the model, provider and
// fallback settings, and the trigger and history-tip message IDs that
// belong to one chat run.
// NOTE(review): presumably the return value of runChat; the producer and
// consumers are outside this section — confirm before relying on it.
type runChatResult struct {
	// FinalAssistantText is the assistant's final text for the run.
	FinalAssistantText string
	StatusLabelModel fantasy.LanguageModel
	ProviderKeys chatprovider.ProviderAPIKeys
	// Fallback* describe a fallback provider/route/model.
	// NOTE(review): exact fallback semantics are not visible here.
	FallbackProvider string
	FallbackRoute resolvedModelRoute
	FallbackModel string
	ModelBuildOptions modelBuildOptions
	// Message IDs associated with the run (database message IDs,
	// presumably — confirm).
	TriggerMessageID int64
	HistoryTipMessageID int64
}
// activeTurnAPIKeyIDFromMessages scans messages from newest to oldest and
// stops at the first user-role message that is either user-visible or a
// compressed model-only message. It returns that message's API key ID and
// true when the ID is set and non-empty; otherwise, or when no such
// message exists, it returns ("", false).
func activeTurnAPIKeyIDFromMessages(messages []database.ChatMessage) (string, bool) {
	for i := len(messages) - 1; i >= 0; i-- {
		candidate := &messages[i]
		if candidate.Role != database.ChatMessageRoleUser {
			continue
		}

		compressedModelOnly := candidate.Visibility == database.ChatMessageVisibilityModel &&
			candidate.Compressed
		if !(isUserVisibleChatMessage(*candidate) || compressedModelOnly) {
			continue
		}

		// The first qualifying message decides the outcome; older
		// messages are never consulted.
		if key := candidate.APIKeyID; key.Valid && key.String != "" {
			return key.String, true
		}
		return "", false
	}
	return "", false
}
// isUserVisibleChatMessage reports whether the message's visibility is
// "both" or "user".
func isUserVisibleChatMessage(message database.ChatMessage) bool {
	switch message.Visibility {
	case database.ChatMessageVisibilityBoth, database.ChatMessageVisibilityUser:
		return true
	default:
		return false
	}
}
// allToolNames returns the name of every tool in allTools, in order. The
// result is always non-nil.
func allToolNames(allTools []fantasy.AgentTool) []string {
	names := make([]string, len(allTools))
	for i, tool := range allTools {
		names[i] = tool.Info().Name
	}
	return names
}
// isExploreSubagentMode reports whether mode is set and equals the
// explore chat mode.
func isExploreSubagentMode(mode database.NullChatMode) bool {
	if !mode.Valid {
		return false
	}
	return mode.ChatMode == database.ChatModeExplore
}
// filterExternalMCPConfigsForTurn returns the external MCP server configs
// visible on the current turn together with the set of approved config
// IDs. Explore children snapshot this filtered set at spawn time so later
// model overrides cannot widen the external-tool boundary.
//
// Outside plan mode every config is returned and the ID set is nil. In
// plan mode a subagent (valid parentChatID) gets no configs and an empty
// ID set, while a root chat gets only the configs flagged
// AllowInPlanMode.
func filterExternalMCPConfigsForTurn(
	configs []database.MCPServerConfig,
	mode database.NullChatPlanMode,
	parentChatID uuid.NullUUID,
) ([]database.MCPServerConfig, map[uuid.UUID]struct{}) {
	isPlanTurn := mode.Valid && mode.ChatPlanMode == database.ChatPlanModePlan
	switch {
	case !isPlanTurn:
		return configs, nil
	case parentChatID.Valid:
		// Plan-mode subagents do not receive external MCP tools because
		// their trust boundary is narrower than the root chat's.
		return nil, map[uuid.UUID]struct{}{}
	}

	allowed := make([]database.MCPServerConfig, 0, len(configs))
	allowedIDs := make(map[uuid.UUID]struct{})
	for i := range configs {
		if cfg := configs[i]; cfg.AllowInPlanMode {
			allowed = append(allowed, cfg)
			allowedIDs[cfg.ID] = struct{}{}
		}
	}
	return allowed, allowedIDs
}
// builtinPlanToolAllowed reports whether the built-in tool called name may
// be used on a plan-mode turn. One group of tools is allowed for every
// chat, a second group only when isRootChat is true, and everything else
// (including unknown names) is denied.
func builtinPlanToolAllowed(name string, isRootChat bool) bool {
	var allowedForAll, allowedForRoot bool
	switch name {
	case "read_file", "execute", "process_output", "read_skill", "read_skill_file":
		allowedForAll = true
	case "write_file", "edit_files", "list_templates", "read_template",
		"create_workspace", "start_workspace", "stop_workspace", "propose_plan", "spawn_agent",
		"spawn_explore_agent", "wait_agent", "ask_user_question", "attach_file":
		allowedForRoot = true
	case "process_list", "process_signal", "message_agent", "close_agent",
		"spawn_computer_use_agent":
		// Listed explicitly to document that these are denied in plan
		// mode; both flags stay false.
	}
	return allowedForAll || (allowedForRoot && isRootChat)
}
// toolAllowedForTurn reports whether tool may be used on the current
// turn. Outside plan mode every tool is allowed. In plan mode a tool is
// allowed when the built-in plan allowlist admits its name (root-only
// entries require an invalid parentChatID), or when it identifies itself
// as an MCP tool whose server config ID is in approvedMCPConfigIDs.
func toolAllowedForTurn(
	tool fantasy.AgentTool,
	mode database.NullChatPlanMode,
	parentChatID uuid.NullUUID,
	approvedMCPConfigIDs map[uuid.UUID]struct{},
) bool {
	isPlanTurn := mode.Valid && mode.ChatPlanMode == database.ChatPlanModePlan
	if !isPlanTurn {
		return true
	}

	isRootChat := !parentChatID.Valid
	if builtinPlanToolAllowed(tool.Info().Name, isRootChat) {
		return true
	}

	if identified, isMCP := tool.(mcpclient.MCPToolIdentifier); isMCP {
		_, approved := approvedMCPConfigIDs[identified.MCPServerConfigID()]
		return approved
	}
	return false
}
// filterToolsForTurn returns the tools usable on the current turn.
// Outside plan mode allTools is returned as-is; in plan mode a new slice
// is built holding only the tools toolAllowedForTurn admits.
func filterToolsForTurn(
	allTools []fantasy.AgentTool,
	mode database.NullChatPlanMode,
	parentChatID uuid.NullUUID,
	approvedMCPConfigIDs map[uuid.UUID]struct{},
) []fantasy.AgentTool {
	if !(mode.Valid && mode.ChatPlanMode == database.ChatPlanModePlan) {
		return allTools
	}

	kept := make([]fantasy.AgentTool, 0, len(allTools))
	for i := range allTools {
		if !toolAllowedForTurn(allTools[i], mode, parentChatID, approvedMCPConfigIDs) {
			continue
		}
		kept = append(kept, allTools[i])
	}
	return kept
}
// activeToolNamesForTurn returns the names of the tools that
// toolAllowedForTurn admits for this turn. In plan mode that is the
// built-in plan allowlist extended with approved external MCP tools for
// root chats. The result is always non-nil.
func activeToolNamesForTurn(
	allTools []fantasy.AgentTool,
	mode database.NullChatPlanMode,
	parentChatID uuid.NullUUID,
	approvedMCPConfigIDs map[uuid.UUID]struct{},
) []string {
	names := make([]string, 0, len(allTools))
	for i := range allTools {
		candidate := allTools[i]
		if !toolAllowedForTurn(candidate, mode, parentChatID, approvedMCPConfigIDs) {
			continue
		}
		names = append(names, candidate.Info().Name)
	}
	return names
}
// allowedExploreToolNames returns the names of the tools an Explore
// subagent may use: built-in tools the policy table marks true, plus any
// tool that implements mcpclient.MCPToolIdentifier. The result is always
// non-nil.
func allowedExploreToolNames(allTools []fantasy.AgentTool) []string {
	// Policy table for built-in tools. Entries set to false behave the
	// same as missing entries; they are spelled out to document intent.
	explorePolicy := map[string]bool{
		"read_file":         true,
		"write_file":        false,
		"edit_files":        false,
		"execute":           true,
		"process_output":    true,
		"process_list":      false,
		"process_signal":    false,
		"list_templates":    false,
		"read_template":     false,
		"create_workspace":  false,
		"start_workspace":   false,
		"stop_workspace":    false,
		"propose_plan":      false,
		"spawn_agent":       false,
		"wait_agent":        false,
		"message_agent":     false,
		"close_agent":       false,
		"read_skill":        true,
		"read_skill_file":   true,
		"ask_user_question": false,
	}

	names := make([]string, 0, len(allTools))
	for _, tool := range allTools {
		name := tool.Info().Name
		allowed := explorePolicy[name]
		if !allowed {
			// External MCP tools pass through here. They were
			// snapshot-filtered at spawn time on chat.MCPServerIDs.
			// WorkspaceMCPTool does not implement MCPToolIdentifier,
			// so workspace tools are excluded here too, in addition
			// to the structural exclusion in runChat tool assembly.
			_, allowed = tool.(mcpclient.MCPToolIdentifier)
		}
		if allowed {
			names = append(names, name)
		}
	}
	return names
}
// allowedBehaviorToolNames returns the tool names permitted by the chat
// mode: the Explore subset for Explore subagents, every tool name
// otherwise. It only runs on non-plan turns because appendDynamicTools
// returns early for plan mode; within that boundary Explore mode takes
// precedence over the allow-everything default.
func allowedBehaviorToolNames(
	allTools []fantasy.AgentTool,
	chatMode database.NullChatMode,
) []string {
	if !isExploreSubagentMode(chatMode) {
		return allToolNames(allTools)
	}
	return allowedExploreToolNames(allTools)
}
// stopAfterPlanTools returns the set of stop-after tool names for a
// plan-mode turn, or nil outside plan mode. The set always contains
// "propose_plan"; root chats (invalid parentChatID) additionally get
// "ask_user_question".
func stopAfterPlanTools(
	planMode database.NullChatPlanMode,
	parentChatID uuid.NullUUID,
) map[string]struct{} {
	if !(planMode.Valid && planMode.ChatPlanMode == database.ChatPlanModePlan) {
		return nil
	}
	if parentChatID.Valid {
		return map[string]struct{}{
			"propose_plan": {},
		}
	}
	return map[string]struct{}{
		"propose_plan":      {},
		"ask_user_question": {},
	}
}
// stopAfterBehaviorTools returns the stop-after tool set for the turn.
// Explore subagents never get one (nil); every other chat defers to
// stopAfterPlanTools.
func stopAfterBehaviorTools(
	planMode database.NullChatPlanMode,
	chatMode database.NullChatMode,
	parentChatID uuid.NullUUID,
) map[string]struct{} {
	if !isExploreSubagentMode(chatMode) {
		return stopAfterPlanTools(planMode, parentChatID)
	}
	return nil
}
// systemPromptBehaviorContext carries the mode information that
// buildSystemPrompt uses to decide which behavior overlay prompts to
// insert.
type systemPromptBehaviorContext struct {
	// planMode selects the planning overlays when it is valid and set
	// to plan mode.
	planMode database.NullChatPlanMode
	// chatMode selects the Explore subagent overlay when it is the
	// explore mode; that overlay takes precedence over plan mode.
	chatMode database.NullChatMode
	// planModeInstructions is inserted after the planning overlay for
	// root chats when non-empty.
	planModeInstructions string
	// isRootChat chooses between the root planning overlay and the
	// planning subagent overlay.
	isRootChat bool
}
// workspaceSkillsForResolution converts workspace skill metadata into
// skillspkg.Skill values tagged with the workspace source, copying only
// the name and description. It returns nil for an empty input.
func workspaceSkillsForResolution(workspaceSkills []chattool.SkillMeta) []skillspkg.Skill {
	count := len(workspaceSkills)
	if count == 0 {
		return nil
	}
	converted := make([]skillspkg.Skill, count)
	for i, meta := range workspaceSkills {
		converted[i] = skillspkg.Skill{
			Name:        meta.Name,
			Description: meta.Description,
			Source:      skillspkg.SourceWorkspace,
		}
	}
	return converted
}
// mergeTurnSkills converts the workspace skill metadata and merges it
// with the personal skills through skillspkg.MergeSkills.
func mergeTurnSkills(
	personalSkills []skillspkg.Skill,
	workspaceSkills []chattool.SkillMeta,
) []skillspkg.ResolvedSkill {
	fromWorkspace := workspaceSkillsForResolution(workspaceSkills)
	return skillspkg.MergeSkills(personalSkills, fromWorkspace)
}
// buildSystemPrompt applies the system-level prompt injections in their
// canonical order and returns the resulting prompt. Both the initial
// prompt assembly and the ReloadMessages callback use it so the two stay
// in sync.
//
// Insertion order: subagent instruction, instruction, skill index, user
// prompt (each skipped when empty), then the behavior overlay. Explore
// subagents get only the Explore overlay. Otherwise, on plan-mode turns,
// root chats get the planning overlay followed by any plan-mode
// instructions, and subagents get the planning subagent overlay.
func buildSystemPrompt(
	prompt []fantasy.Message,
	subagentInstruction string,
	instruction string,
	resolvedSkills []skillspkg.ResolvedSkill,
	userPrompt string,
	behaviorContext systemPromptBehaviorContext,
) []fantasy.Message {
	// insertIfSet adds text as a system message unless it is empty.
	insertIfSet := func(text string) {
		if text != "" {
			prompt = chatprompt.InsertSystem(prompt, text)
		}
	}

	insertIfSet(subagentInstruction)
	insertIfSet(instruction)
	insertIfSet(chattool.FormatResolvedSkillIndex(resolvedSkills))
	insertIfSet(userPrompt)

	if isExploreSubagentMode(behaviorContext.chatMode) {
		return chatprompt.InsertSystem(prompt, ExploreSubagentOverlayPrompt)
	}

	planTurn := behaviorContext.planMode.Valid &&
		behaviorContext.planMode.ChatPlanMode == database.ChatPlanModePlan
	switch {
	case !planTurn:
		// No behavior overlay outside plan mode.
	case !behaviorContext.isRootChat:
		prompt = chatprompt.InsertSystem(prompt, PlanningSubagentOverlayPrompt)
	default:
		prompt = chatprompt.InsertSystem(prompt, PlanningOverlayPrompt())
		insertIfSet(behaviorContext.planModeInstructions)
	}
	return prompt
}
// removeSkillIndexMessages returns prompt without its skill index
// messages. When prompt contains none, the input slice itself is
// returned; otherwise a new slice is built.
func removeSkillIndexMessages(prompt []fantasy.Message) []fantasy.Message {
	if !slices.ContainsFunc(prompt, isSkillIndexMessage) {
		return prompt
	}
	kept := make([]fantasy.Message, 0, len(prompt))
	for _, message := range prompt {
		if !isSkillIndexMessage(message) {
			kept = append(kept, message)
		}
	}
	return kept
}
// isSkillIndexMessage reports whether message is a system message with
// exactly one text part whose trimmed text starts with the
// available-skills open tag followed by a newline and ends with the
// available-skills close tag.
func isSkillIndexMessage(message fantasy.Message) bool {
	if message.Role != fantasy.MessageRoleSystem {
		return false
	}
	if len(message.Content) != 1 {
		return false
	}
	part, isText := fantasy.AsMessagePart[fantasy.TextPart](message.Content[0])
	if !isText {
		return false
	}
	trimmed := strings.TrimSpace(part.Text)
	if !strings.HasPrefix(trimmed, chattool.AvailableSkillsOpenTag+"\n") {
		return false
	}
	return strings.HasSuffix(trimmed, chattool.AvailableSkillsCloseTag)
}
// rootChatToolsOptions bundles the inputs appendRootChatTools needs to
// build the root-chat tool set.
type rootChatToolsOptions struct {
	// chat supplies the organization, owner and chat IDs handed to the
	// template and workspace tools, and is the chat returned to the
	// subagent tools.
	chat database.Chat
	// modelConfigID is forwarded to the subagent tools.
	modelConfigID uuid.UUID
	// workspaceCtx is updated when a workspace tool reports a chat
	// update and provides the workspace connection for ProposePlan.
	workspaceCtx *turnWorkspaceContext
	// workspaceMu is passed to the create/start/stop workspace tools.
	workspaceMu *sync.Mutex
	// resolvePlanPath and storeFile are passed through to ProposePlan.
	resolvePlanPath func(context.Context) (string, string, error)
	storeFile chattool.StoreFileFunc
	// isPlanModeTurn gates whether ProposePlan is registered at all.
	isPlanModeTurn bool
	// primerCtx scopes the workspace MCP cache primer goroutines
	// that onChatUpdated launches. runChat cancels it before
	// workspaceCtx.close() so an in-flight primer cannot dial a
	// fresh conn after the cached one was released.
	primerCtx context.Context
}
// loadPlanModeInstructions returns the stored plan-mode instructions for
// a plan-mode turn. It returns "" when mode is not plan mode, and also
// when the fetch fails, in which case the failure is logged as a warning.
func (p *Server) loadPlanModeInstructions(
	ctx context.Context,
	mode database.NullChatPlanMode,
	logger slog.Logger,
) string {
	isPlanTurn := mode.Valid && mode.ChatPlanMode == database.ChatPlanModePlan
	if !isPlanTurn {
		return ""
	}

	// Plan-mode instructions live in deployment config, but chat workers do
	// not carry a deployment-config actor during background execution.
	//nolint:gocritic // Required to read deployment config during background chat processing.
	instructions, err := p.db.GetChatPlanModeInstructions(dbauthz.AsSystemRestricted(ctx))
	if err == nil {
		return instructions
	}
	logger.Warn(ctx,
		"failed to fetch plan mode instructions",
		slog.Error(err),
	)
	return ""
}
// userSkillContext returns ctx authorized as a synthetic member-role user
// actor for userID, used when loading that user's personal skills.
//
// Chat turns run asynchronously after admission, so the original request
// actor may no longer be available when a worker loads personal skills;
// the chat owner is synthesized as a member instead of reusing that
// actor. Hardcoding RoleMember is safe because dbauthz enforces
// ResourceUserSkill.WithOwner(userID), so this actor cannot read any
// other user's skills regardless of role. Org scoping is not needed
// because personal skills are user-scoped, not org-scoped.
func userSkillContext(ctx context.Context, userID uuid.UUID) context.Context {
	owner := rbac.Subject{
		Type:  rbac.SubjectTypeUser,
		ID:    userID.String(),
		Roles: rbac.RoleIdentifiers{rbac.RoleMember()},
		Scope: rbac.ScopeAll,
	}.WithCachedASTValue()
	//nolint:gocritic // The synthetic actor is intentional for the reasons above.
	return dbauthz.As(ctx, owner)
}
// fetchPersonalSkillMetadata lists the personal skill metadata of userID
// and returns it as skillspkg.Skill values tagged with the personal
// source. A failed lookup is logged as a warning and yields nil.
func (p *Server) fetchPersonalSkillMetadata(
	ctx context.Context,
	userID uuid.UUID,
	logger slog.Logger,
) []skillspkg.Skill {
	skillCtx := userSkillContext(ctx, userID)
	rows, err := p.db.ListUserSkillMetadataByUserID(skillCtx, userID)
	if err != nil {
		// See package coderd/x/skills (doc.go) for why metadata fetch
		// failures intentionally degrade to an empty personal-skill
		// list instead of failing the chat turn.
		logger.Warn(ctx, "failed to load personal skill metadata",
			slog.F("owner_id", userID),
			slog.Error(err),
		)
		return nil
	}

	skills := make([]skillspkg.Skill, len(rows))
	for i, row := range rows {
		skills[i] = skillspkg.Skill{
			Name:        row.Name,
			Description: row.Description,
			Source:      skillspkg.SourcePersonal,
		}
	}
	return skills
}
// loadPersonalSkillBody loads the personal skill called name for userID
// and parses its markdown content.
//
// It returns skillspkg.ErrSkillNotFound when no such row exists. Any
// other lookup failure, and any parse failure, is logged at error level
// and returned wrapped.
func (p *Server) loadPersonalSkillBody(
	ctx context.Context,
	userID uuid.UUID,
	name string,
) (skillspkg.ParsedSkill, error) {
	lookup := database.GetUserSkillByUserIDAndNameParams{
		UserID: userID,
		Name:   name,
	}
	row, err := p.db.GetUserSkillByUserIDAndName(userSkillContext(ctx, userID), lookup)
	switch {
	case errors.Is(err, sql.ErrNoRows):
		return skillspkg.ParsedSkill{}, skillspkg.ErrSkillNotFound
	case err != nil:
		p.logger.Error(ctx, "load personal skill body failed",
			slog.F("user_id", userID),
			slog.F("name", name),
			slog.Error(err),
		)
		return skillspkg.ParsedSkill{}, xerrors.Errorf("load personal skill body: %w", err)
	}

	parsed, parseErr := skillspkg.ParsePersonalSkillMarkdown([]byte(row.Content))
	if parseErr != nil {
		p.logger.Error(ctx, "parse personal skill body failed",
			slog.F("user_id", userID),
			slog.F("name", name),
			slog.Error(parseErr),
		)
		return skillspkg.ParsedSkill{}, xerrors.Errorf("parse personal skill body: %w", parseErr)
	}
	return parsed, nil
}
// appendRootChatTools appends the root-chat tool set to tools and returns
// the extended slice: the template list/read tools, the workspace
// create/start/stop tools, the ProposePlan tool when opts.isPlanModeTurn
// is set, and finally the tools returned by p.subagentTools.
//
// The workspace tools share one onChatUpdated callback. It re-selects the
// workspace on opts.workspaceCtx, publishes a status-change watch event,
// and, once the current chat snapshot has both a valid workspace ID and a
// valid agent ID, starts an asynchronous workspace MCP cache primer that
// is tracked through p.inflight.
func (p *Server) appendRootChatTools(
	ctx context.Context,
	tools []fantasy.AgentTool,
	opts rootChatToolsOptions,
) []fantasy.AgentTool {
	onChatUpdated := func(updatedChat database.Chat) {
		opts.workspaceCtx.selectWorkspace(updatedChat)
		// Notify the frontend immediately so it can start streaming
		// build logs before the tool completes.
		p.publishChatPubsubEvent(updatedChat, codersdk.ChatWatchEventKindStatusChange, nil)
		// Note: we intentionally do not insert AGENTS.md / workspace
		// context here. Local tool callbacks must not mutate chat
		// history while a local-tool generation task is in flight,
		// because that advances history_version before the tool
		// result is committed and exits the local-tool commit as
		// stale. Workspace context is persisted by the
		// persist_workspace_context generation action in a later
		// pass.
		// Prime the workspace MCP tools cache while the create_workspace
		// or start_workspace tool is still running. The AgentID guard
		// below restricts the primer to the post-ready callback, when
		// the agent is reachable. ListMCPTools may still return an
		// empty list on the first try when the agent's MCP Connect is
		// racing with agent startup; primeWorkspaceMCPCache retries
		// with a short backoff up to workspaceMCPPrimeMaxWait. Priming
		// here lets the next assistant-generation action hit the cache
		// instead of dialing again on a separate timeout budget.
		//
		// Run asynchronously: the tool itself must not block on the
		// primer because the agent may not advertise any MCP tools at
		// all (e.g. minimal templates), in which case the primer waits
		// the full budget before giving up. The next assistant-generation
		// action covers the cache miss path; the primer is purely an
		// optimization that warms the cache while the LLM is thinking.
		// inflight tracking ensures server shutdown still waits for any
		// in-progress primer.
		//
		// Guard on both WorkspaceID and AgentID being valid:
		// create_workspace and start_workspace each fire onChatUpdated
		// twice for a new build (binding before waitForAgentReady;
		// post-ready after it), and stop_workspace fires it with a nil
		// agent. Only the post-ready callback has a live AgentID, so
		// the pre-build and stop-side firings would otherwise spawn a
		// primer goroutine that dials a missing or dying agent and
		// burns the full budget for nothing.
		snapshot := opts.workspaceCtx.currentChatSnapshot()
		if snapshot.WorkspaceID.Valid && snapshot.AgentID.Valid {
			p.inflight.Go(func() {
				p.primeWorkspaceMCPCache(opts.primerCtx, p.logger, snapshot.ID, opts.workspaceCtx)
			})
		}
	}
	// Template and workspace lifecycle tools are always registered for
	// root chats; all three workspace tools share onChatUpdated.
	tools = append(tools,
		chattool.ListTemplates(p.db, opts.chat.OrganizationID, chattool.ListTemplatesOptions{
			OwnerID: opts.chat.OwnerID,
			AllowedTemplateIDs: p.chatTemplateAllowlist,
		}),
		chattool.ReadTemplate(p.db, opts.chat.OrganizationID, chattool.ReadTemplateOptions{
			OwnerID: opts.chat.OwnerID,
			AllowedTemplateIDs: p.chatTemplateAllowlist,
		}),
		chattool.CreateWorkspace(p.db, opts.chat.OrganizationID, opts.chat.ID, chattool.CreateWorkspaceOptions{
			OwnerID: opts.chat.OwnerID,
			CreateFn: p.createWorkspaceFn,
			AgentConnFn: chattool.AgentConnFunc(p.agentConnFn),
			AgentInactiveDisconnectTimeout: p.agentInactiveDisconnectTimeout,
			WorkspaceMu: opts.workspaceMu,
			OnChatUpdated: onChatUpdated,
			Logger: p.logger,
			AllowedTemplateIDs: p.chatTemplateAllowlist,
		}),
		chattool.StartWorkspace(p.db, opts.chat.ID, chattool.StartWorkspaceOptions{
			OwnerID: opts.chat.OwnerID,
			StartFn: p.startWorkspaceFn,
			AgentConnFn: chattool.AgentConnFunc(p.agentConnFn),
			WorkspaceMu: opts.workspaceMu,
			OnChatUpdated: onChatUpdated,
			Logger: p.logger,
		}),
		chattool.StopWorkspace(p.db, opts.chat.ID, chattool.StopWorkspaceOptions{
			OwnerID: opts.chat.OwnerID,
			StopFn: p.stopWorkspaceFn,
			WorkspaceMu: opts.workspaceMu,
			OnChatUpdated: onChatUpdated,
			Logger: p.logger,
		}),
	)
	// ProposePlan is only offered on plan-mode turns.
	if opts.isPlanModeTurn {
		tools = append(tools, chattool.ProposePlan(chattool.ProposePlanOptions{
			GetWorkspaceConn: opts.workspaceCtx.getWorkspaceConn,
			ResolvePlanPath: opts.resolvePlanPath,
			IsPlanTurn: opts.isPlanModeTurn,
			StoreFile: opts.storeFile,
		}))
	}
	// The subagent tools receive a getter that always returns the chat
	// captured in opts.
	return append(tools, p.subagentTools(ctx, func() database.Chat {
		return opts.chat
	}, opts.modelConfigID)...)
}
// appendDynamicTools appends the chat's dynamic tools, decoded from raw,
// to tools. It returns the extended tool slice and the set of dynamic
// tool names that were actually kept.
//
// Explore subagents and plan-mode turns get no dynamic tools: tools is
// returned untouched with a nil name set. A dynamic tool whose name
// collides with an active built-in tool is dropped with a warning, since
// the built-in takes precedence. Errors from parsing the names or
// unmarshalling the definitions are returned wrapped.
func appendDynamicTools(
	ctx context.Context,
	logger slog.Logger,
	tools []fantasy.AgentTool,
	raw pqtype.NullRawMessage,
	planMode database.NullChatPlanMode,
	chatMode database.NullChatMode,
) ([]fantasy.AgentTool, map[string]bool, error) {
	inPlanMode := planMode.Valid && planMode.ChatPlanMode == database.ChatPlanModePlan
	if isExploreSubagentMode(chatMode) || inPlanMode {
		return tools, nil, nil
	}

	requested, err := parseDynamicToolNames(raw)
	if err != nil {
		return nil, nil, xerrors.Errorf("parse dynamic tool names: %w", err)
	}
	if len(requested) == 0 {
		return tools, requested, nil
	}

	var defs []codersdk.DynamicTool
	if raw.Valid {
		if err := json.Unmarshal(raw.RawMessage, &defs); err != nil {
			return nil, nil, xerrors.Errorf("unmarshal dynamic tools: %w", err)
		}
	}

	// Only built-in tools that are active for this chat mode can shadow
	// a dynamic tool.
	active := make(map[string]struct{}, len(tools))
	for _, name := range allowedBehaviorToolNames(tools, chatMode) {
		active[name] = struct{}{}
	}
	for _, tool := range tools {
		name := tool.Info().Name
		_, isActive := active[name]
		if !isActive || !requested[name] {
			continue
		}
		logger.Warn(ctx, "dynamic tool name collides with built-in tool, built-in takes precedence",
			slog.F("tool_name", name))
		delete(requested, name)
	}

	var surviving []codersdk.DynamicTool
	for _, def := range defs {
		if requested[def.Name] {
			surviving = append(surviving, def)
		}
	}
	combined := append(tools, dynamicToolsFromSDK(logger, surviving)...)
	return combined, requested, nil
}
// buildProviderTools returns the provider-native tool definitions (such
// as web search) that the model configuration enables. These tools are
// executed server-side by the LLM provider. A nil options value yields
// nil.
func buildProviderTools(options *codersdk.ChatModelProviderOptions) []chatloop.ProviderTool {
	if options == nil {
		return nil
	}
	// isOn treats an unset flag as disabled.
	isOn := func(flag *bool) bool { return flag != nil && *flag }

	var providerTools []chatloop.ProviderTool

	if anthropicOpts := options.Anthropic; anthropicOpts != nil && isOn(anthropicOpts.WebSearchEnabled) {
		searchTool := anthropic.WebSearchTool(&anthropic.WebSearchToolOptions{
			AllowedDomains: anthropicOpts.AllowedDomains,
			BlockedDomains: anthropicOpts.BlockedDomains,
		})
		providerTools = append(providerTools, chatloop.ProviderTool{Definition: searchTool})
	}

	if searchTool, ok := chatopenai.WebSearchTool(options.OpenAI); ok {
		providerTools = append(providerTools, chatloop.ProviderTool{Definition: searchTool})
	}

	if googleOpts := options.Google; googleOpts != nil && isOn(googleOpts.WebSearchEnabled) {
		providerTools = append(providerTools, chatloop.ProviderTool{
			Definition: fantasy.ProviderDefinedTool{
				ID:   "web_search",
				Name: "web_search",
			},
		})
	}

	return providerTools
}
// persistChatContextSummary is called from the chat loop's compaction
// callback. It stores the compaction outcome as three compressed
// messages inserted in a single transaction:
//
//  1. a model-visible user message carrying result.SystemSummary,
//  2. a user-visible assistant "chat_summarized" tool call, and
//  3. the matching tool result carrying result.SummaryReport and the
//     usage figures.
//
// activeAPIKeyID is stamped onto the summary user message. When empty, it
// falls back to the delegated key in ctx. Nothing is written when either
// the system summary or the summary report is blank.
func (p *Server) persistChatContextSummary(
	ctx context.Context,
	chatID uuid.UUID,
	modelConfigID uuid.UUID,
	activeAPIKeyID string,
	toolCallID string,
	result chatloop.CompactionResult,
) error {
	if strings.TrimSpace(result.SystemSummary) == "" ||
		strings.TrimSpace(result.SummaryReport) == "" {
		return nil
	}

	systemContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{
		codersdk.ChatMessageText(result.SystemSummary),
	})
	if err != nil {
		return xerrors.Errorf("encode system summary: %w", err)
	}

	args, err := json.Marshal(map[string]any{
		"source":            "automatic",
		"threshold_percent": result.ThresholdPercent,
	})
	if err != nil {
		return xerrors.Errorf("encode summary tool args: %w", err)
	}
	assistantContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{
		codersdk.ChatMessageToolCall(toolCallID, "chat_summarized", args),
	})
	if err != nil {
		return xerrors.Errorf("encode summary tool call: %w", err)
	}

	summaryResult, err := json.Marshal(map[string]any{
		"summary":              result.SummaryReport,
		"source":               "automatic",
		"threshold_percent":    result.ThresholdPercent,
		"usage_percent":        result.UsagePercent,
		"context_tokens":       result.ContextTokens,
		"context_limit_tokens": result.ContextLimit,
	})
	if err != nil {
		return xerrors.Errorf("encode summary result payload: %w", err)
	}
	toolResult, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{
		codersdk.ChatMessageToolResult(toolCallID, "chat_summarized", summaryResult, false, false),
	})
	if err != nil {
		return xerrors.Errorf("encode summary tool result: %w", err)
	}

	summaryAPIKeyID := activeAPIKeyID
	if summaryAPIKeyID == "" {
		summaryAPIKeyID, _ = aibridge.DelegatedAPIKeyIDFromContext(ctx)
	}

	// The inserted rows are not consumed here, so the result is dropped
	// instead of being sliced, which would panic on an empty result set.
	return p.db.InTx(func(tx database.Store) error {
		summaryParams := database.InsertChatMessagesParams{ //nolint:exhaustruct // Fields populated by append[User]ChatMessage.
			ChatID: chatID,
		}
		summaryUserMsg := newUserChatMessage(
			summaryAPIKeyID,
			systemContent,
			database.ChatMessageVisibilityModel,
			modelConfigID,
			chatprompt.CurrentContentVersion,
		)
		summaryUserMsg = summaryUserMsg.withCompressed()
		appendUserChatMessage(&summaryParams, summaryUserMsg)
		appendChatMessage(&summaryParams, newChatMessage(
			database.ChatMessageRoleAssistant,
			assistantContent,
			database.ChatMessageVisibilityUser,
			modelConfigID,
			chatprompt.CurrentContentVersion,
		).withCompressed())
		appendChatMessage(&summaryParams, newChatMessage(
			database.ChatMessageRoleTool,
			toolResult,
			database.ChatMessageVisibilityBoth,
			modelConfigID,
			chatprompt.CurrentContentVersion,
		).withCompressed())
		if _, err := tx.InsertChatMessages(ctx, summaryParams); err != nil {
			return xerrors.Errorf("insert summary messages: %w", err)
		}
		return nil
	}, nil)
}
// resolveChatModel resolves everything needed to talk to the model for
// chat: the model config row, the route and its direct provider keys, the
// resolved provider and model names, and a debug-aware language model
// client. A disabled model config is an error. On any failure every
// result other than err is returned as its zero value.
func (p *Server) resolveChatModel(
	ctx context.Context,
	chat database.Chat,
	modelOpts modelBuildOptions,
) (
	model fantasy.LanguageModel,
	dbConfig database.ChatModelConfig,
	keys chatprovider.ProviderAPIKeys,
	route resolvedModelRoute,
	debugEnabled bool,
	resolvedProvider string,
	resolvedModel string,
	err error,
) {
	// fail produces the all-zero result tuple carrying only the error.
	fail := func(cause error) (
		fantasy.LanguageModel,
		database.ChatModelConfig,
		chatprovider.ProviderAPIKeys,
		resolvedModelRoute,
		bool,
		string,
		string,
		error,
	) {
		return nil, database.ChatModelConfig{}, chatprovider.ProviderAPIKeys{}, resolvedModelRoute{}, false, "", "", cause
	}

	dbConfig, err = p.resolveModelConfig(ctx, chat)
	if err != nil {
		return fail(xerrors.Errorf("resolve model config: %w", err))
	}
	if !dbConfig.Enabled {
		return fail(xerrors.Errorf("chat model config %s is disabled", dbConfig.ID))
	}

	route, err = p.resolveModelRouteForConfig(ctx, chat.OwnerID, dbConfig, chatprovider.ProviderAPIKeys{})
	if err != nil {
		return fail(err)
	}
	keys = route.directProviderKeys()

	providerHint, err := route.providerHint()
	if err != nil {
		return fail(err)
	}
	resolvedProvider, resolvedModel, err = chatprovider.ResolveModelWithProviderHint(
		dbConfig.Model,
		providerHint,
	)
	if err != nil {
		return fail(xerrors.Errorf(
			"resolve model metadata: %w", err,
		))
	}

	request := modelClientRequest{
		Chat:         chat,
		ModelName:    dbConfig.Model,
		UserAgent:    chatprovider.UserAgent(),
		ExtraHeaders: chatprovider.CoderHeaders(chat),
	}
	model, debugEnabled, err = p.newDebugAwareModel(ctx, request, route, modelOpts)
	if err != nil {
		return fail(xerrors.Errorf(
			"create model: %w", err,
		))
	}
	return model, dbConfig, keys, route, debugEnabled, resolvedProvider, resolvedModel, nil
}
// aiProviderConfig loads the stored keys for provider and converts the
// provider row plus those keys into a chatprovider.ConfiguredProvider.
func (p *Server) aiProviderConfig(ctx context.Context, provider database.AIProvider) (chatprovider.ConfiguredProvider, error) {
	providerKeys, err := p.db.GetAIProviderKeysByProviderID(ctx, provider.ID)
	if err == nil {
		return p.aiProviderConfigFromKeys(provider, providerKeys)
	}
	return chatprovider.ConfiguredProvider{}, xerrors.Errorf("get AI provider keys: %w", err)
}
// aiProviderConfigFromKeys builds the runtime provider configuration for
// provider from its already-loaded keys. A disabled provider is an error.
// The APIKey field is left empty when no key has a non-empty value.
func (p *Server) aiProviderConfigFromKeys(provider database.AIProvider, keys []database.AIProviderKey) (chatprovider.ConfiguredProvider, error) {
	if !provider.Enabled {
		return chatprovider.ConfiguredProvider{}, xerrors.Errorf("AI provider %s is disabled", provider.ID)
	}

	// GetAIProviderKeysByProviderID orders keys oldest first. chatd consumes
	// one provider-scoped key because runtime provider config has one API key
	// slot, so the first non-empty key wins.
	var selectedKey string
	firstUsable := slices.IndexFunc(keys, func(candidate database.AIProviderKey) bool {
		return candidate.APIKey != ""
	})
	if firstUsable >= 0 {
		selectedKey = keys[firstUsable].APIKey
	}

	configured := chatprovider.ConfiguredProvider{
		ProviderID:                 provider.ID,
		Provider:                   string(provider.Type),
		APIKey:                     selectedKey,
		BaseURL:                    provider.BaseUrl,
		CentralAPIKeyEnabled:       true,
		AllowUserAPIKey:            p.allowBYOK,
		AllowCentralAPIKeyFallback: true,
	}
	return configured, nil
}
// aiProviderConfigs converts providers into runtime provider
// configurations, loading the keys for all of them with a single query.
// The result preserves the order of providers. An empty input yields nil,
// and the first provider that cannot be configured aborts with its error.
func (p *Server) aiProviderConfigs(ctx context.Context, providers []database.AIProvider) ([]chatprovider.ConfiguredProvider, error) {
	count := len(providers)
	if count == 0 {
		return nil, nil
	}

	ids := make([]uuid.UUID, count)
	for i := range providers {
		ids[i] = providers[i].ID
	}
	allKeys, err := p.db.GetAIProviderKeysByProviderIDs(ctx, ids)
	if err != nil {
		return nil, xerrors.Errorf("get AI provider keys: %w", err)
	}

	// Bucket the keys per provider, keeping the order the query returned.
	grouped := make(map[uuid.UUID][]database.AIProviderKey, count)
	for _, providerKey := range allKeys {
		grouped[providerKey.ProviderID] = append(grouped[providerKey.ProviderID], providerKey)
	}

	configs := make([]chatprovider.ConfiguredProvider, 0, count)
	for _, provider := range providers {
		config, err := p.aiProviderConfigFromKeys(provider, grouped[provider.ID])
		if err != nil {
			return nil, err
		}
		configs = append(configs, config)
	}
	return configs, nil
}
// ensureUniqueConfiguredProviderTypes returns an error when two different
// provider IDs normalize to the same provider type, because a lookup by
// type would then be ambiguous. Providers whose type normalizes to the
// empty string are ignored.
func ensureUniqueConfiguredProviderTypes(providers []chatprovider.ConfiguredProvider) error {
	ownerByType := make(map[string]uuid.UUID, len(providers))
	for i := range providers {
		providerType := chatprovider.NormalizeProvider(providers[i].Provider)
		if providerType == "" {
			continue
		}
		id := providers[i].ProviderID
		if previous, dup := ownerByType[providerType]; dup && previous != id {
			return xerrors.Errorf("multiple enabled AI providers use provider type %q; select an AI provider by ID", providerType)
		}
		ownerByType[providerType] = id
	}
	return nil
}
// resolveUserProviderAPIKeysForProvider resolves the API keys ownerID may
// use for a single provider. The user's own key for that provider is
// considered only when BYOK is allowed; a missing user key is not an
// error.
func (p *Server) resolveUserProviderAPIKeysForProvider(
	ctx context.Context,
	ownerID uuid.UUID,
	provider database.AIProvider,
) (chatprovider.ProviderAPIKeys, error) {
	configured, err := p.aiProviderConfig(ctx, provider)
	if err != nil {
		return chatprovider.ProviderAPIKeys{}, err
	}

	userKeys := []chatprovider.UserProviderKey{}
	if p.allowBYOK {
		row, lookupErr := p.db.GetUserAIProviderKeyByProviderID(ctx, database.GetUserAIProviderKeyByProviderIDParams{
			UserID:       ownerID,
			AIProviderID: provider.ID,
		})
		switch {
		case lookupErr == nil:
			userKeys = append(userKeys, chatprovider.UserProviderKey{
				ChatProviderID: row.AIProviderID,
				APIKey:         row.APIKey,
			})
		case !xerrors.Is(lookupErr, sql.ErrNoRows):
			return chatprovider.ProviderAPIKeys{}, xerrors.Errorf("get user AI provider key: %w", lookupErr)
		}
		// sql.ErrNoRows: the user simply has no key for this provider.
	}

	resolved, _ := chatprovider.ResolveUserProviderKeys(
		chatprovider.ProviderAPIKeys{},
		[]chatprovider.ConfiguredProvider{configured},
		userKeys,
	)
	return resolved, nil
}
// resolveUserProviderAPIKeysForProviderType is the keys-only form of
// resolveUserProviderAPIKeysAndProviderForProviderType; the matched
// provider is discarded.
func (p *Server) resolveUserProviderAPIKeysForProviderType(
	ctx context.Context,
	ownerID uuid.UUID,
	providerType string,
) (chatprovider.ProviderAPIKeys, error) {
	resolved, _, err := p.resolveUserProviderAPIKeysAndProviderForProviderType(
		ctx, ownerID, providerType,
	)
	return resolved, err
}
// resolveUserProviderAPIKeysAndProviderForProviderType walks the stored
// AI providers whose normalized type equals providerType and returns the
// keys and provider row of the first one whose keys ownerID can use. When
// no provider qualifies it falls back to resolveUserProviderAPIKeys with
// no selected provider and returns a nil provider.
func (p *Server) resolveUserProviderAPIKeysAndProviderForProviderType(
	ctx context.Context,
	ownerID uuid.UUID,
	providerType string,
) (chatprovider.ProviderAPIKeys, *database.AIProvider, error) {
	var noKeys chatprovider.ProviderAPIKeys

	candidates, err := p.db.GetAIProviders(ctx, database.GetAIProvidersParams{})
	if err != nil {
		return noKeys, nil, xerrors.Errorf("get enabled AI providers: %w", err)
	}

	wantType := chatprovider.NormalizeProvider(providerType)
	for _, candidate := range candidates {
		if chatprovider.NormalizeProvider(string(candidate.Type)) != wantType {
			continue
		}
		candidateKeys, err := p.resolveUserProviderAPIKeysForProvider(ctx, ownerID, candidate)
		if err != nil {
			return noKeys, nil, err
		}
		if userCanUseProviderKeys(candidateKeys, wantType) {
			return candidateKeys, &candidate, nil
		}
	}

	// No matching provider was usable; resolve across all providers.
	fallbackKeys, err := p.resolveUserProviderAPIKeys(ctx, ownerID, uuid.Nil)
	if err != nil {
		return noKeys, nil, err
	}
	return fallbackKeys, nil, nil
}
// resolveUserProviderAPIKeys resolves the provider API keys available to
// ownerID.
//
// When selectedAIProviderID is set, only that provider is consulted.
// Otherwise every enabled provider is considered: provider types must be
// unique, the user's own keys are merged in when BYOK is allowed, and
// keys for provider types that are not enabled are pruned from the
// result.
func (p *Server) resolveUserProviderAPIKeys(
	ctx context.Context,
	ownerID uuid.UUID,
	selectedAIProviderID uuid.UUID,
) (chatprovider.ProviderAPIKeys, error) {
	var noKeys chatprovider.ProviderAPIKeys

	if selectedAIProviderID != uuid.Nil {
		selected, err := p.db.GetAIProviderByID(ctx, selectedAIProviderID)
		if err != nil {
			return noKeys, xerrors.Errorf("get AI provider: %w", err)
		}
		return p.resolveUserProviderAPIKeysForProvider(ctx, ownerID, selected)
	}

	enabled, err := p.configCache.EnabledProviders(ctx)
	if err != nil {
		return noKeys, xerrors.Errorf(
			"get enabled AI providers: %w",
			err,
		)
	}
	configured, err := p.aiProviderConfigs(ctx, enabled)
	if err != nil {
		return noKeys, err
	}
	if err := ensureUniqueConfiguredProviderTypes(configured); err != nil {
		return noKeys, err
	}

	userKeys := []chatprovider.UserProviderKey{}
	if p.allowBYOK {
		rows, err := p.db.GetUserAIProviderKeysByUserID(ctx, ownerID)
		if err != nil {
			return noKeys, xerrors.Errorf(
				"get user AI provider keys: %w",
				err,
			)
		}
		userKeys = make([]chatprovider.UserProviderKey, 0, len(rows))
		for i := range rows {
			userKeys = append(userKeys, chatprovider.UserProviderKey{
				ChatProviderID: rows[i].AIProviderID,
				APIKey:         rows[i].APIKey,
			})
		}
	}

	resolved, _ := chatprovider.ResolveUserProviderKeys(
		p.providerAPIKeys,
		configured,
		userKeys,
	)

	// Drop keys for any provider type that has no enabled provider.
	enabledTypes := make(map[string]struct{}, len(configured))
	for i := range configured {
		if providerType := chatprovider.NormalizeProvider(configured[i].Provider); providerType != "" {
			enabledTypes[providerType] = struct{}{}
		}
	}
	chatprovider.PruneDisabledProviderKeys(&resolved, enabledTypes)
	return resolved, nil
}
// resolveModelConfig returns the model config referenced by the chat's
// LastModelConfigID. When the chat has no such ID, or the referenced
// config no longer exists (e.g. it was deleted), the default model config
// is used instead. An error is returned when no usable config is
// available or a lookup fails for any other reason.
func (p *Server) resolveModelConfig(
	ctx context.Context,
	chat database.Chat,
) (database.ChatModelConfig, error) {
	if chat.LastModelConfigID != uuid.Nil {
		pinned, err := p.configCache.ModelConfigByID(
			ctx, chat.LastModelConfigID,
		)
		switch {
		case err == nil:
			return pinned, nil
		case !xerrors.Is(err, sql.ErrNoRows):
			return database.ChatModelConfig{}, xerrors.Errorf(
				"get chat model config %s: %w",
				chat.LastModelConfigID, err,
			)
		}
		// The pinned config was deleted; continue with the default.
	}

	fallback, err := p.configCache.DefaultModelConfig(ctx)
	switch {
	case err == nil:
		return fallback, nil
	case xerrors.Is(err, sql.ErrNoRows):
		return database.ChatModelConfig{}, xerrors.New(
			"no default chat model config is available",
		)
	default:
		return database.ChatModelConfig{}, xerrors.Errorf(
			"get default chat model config: %w", err,
		)
	}
}
// refreshChatWorkspaceSnapshot reloads chat through loadChat when the
// snapshot has no workspace ID yet. A snapshot that already carries a
// workspace ID, or a nil loadChat, returns chat unchanged. If the reload
// fails, the original snapshot is returned alongside the error.
func refreshChatWorkspaceSnapshot(
	ctx context.Context,
	chat database.Chat,
	loadChat func(context.Context, uuid.UUID) (database.Chat, error),
) (database.Chat, error) {
	needsReload := !chat.WorkspaceID.Valid && loadChat != nil
	if !needsReload {
		return chat, nil
	}
	latest, err := loadChat(ctx, chat.ID)
	if err == nil {
		return latest, nil
	}
	return chat, xerrors.Errorf("reload chat workspace state: %w", err)
}
// contextFileAgentID extracts the workspace agent ID from the most
// recent persisted instruction-file parts. The skill-only sentinel is
// ignored because it does not represent persisted instruction content.
// Returns uuid.Nil, false if no instruction-file parts exist.
//
// messages is scanned from the end so only the newest qualifying message
// has to be decoded; within that message the first qualifying part
// supplies the ID. Messages whose content fails to decode are skipped.
func contextFileAgentID(messages []database.ChatMessage) (uuid.UUID, bool) {
	// Cheap byte-level prefilter so unrelated messages are never decoded.
	marker := []byte(`"context-file"`)

	for i := len(messages) - 1; i >= 0; i-- {
		content := messages[i].Content
		if !content.Valid || !bytes.Contains(content.RawMessage, marker) {
			continue
		}
		var parts []codersdk.ChatMessagePart
		if err := json.Unmarshal(content.RawMessage, &parts); err != nil {
			continue
		}
		for _, part := range parts {
			if part.Type != codersdk.ChatMessagePartTypeContextFile ||
				!part.ContextFileAgentID.Valid ||
				part.ContextFilePath == AgentChatContextSentinelPath {
				continue
			}
			return part.ContextFileAgentID.UUID, true
		}
	}
	return uuid.Nil, false
}
// fetchWorkspaceContext retrieves fresh instruction files and skills
// from the workspace agent without persisting anything. It resolves the
// agent, fetches its context configuration over the workspace
// connection, sanitizes context-file content, and stamps server-side
// metadata onto every part.
//
// It returns the workspace agent, the stamped parts, the discovered
// skills, and whether the workspace connection succeeded. A nil agent
// means the chat has no valid workspace or the agent lookup failed;
// workspaceConnOK is false in that case. workspaceConnOK is also false
// when the connection or the context-config fetch fails, or when
// getWorkspaceConn is nil.
func (p *Server) fetchWorkspaceContext(
	ctx context.Context,
	chat database.Chat,
	getWorkspaceAgent func(context.Context) (database.WorkspaceAgent, error),
	getWorkspaceConn func(context.Context) (workspacesdk.AgentConn, error),
) (agent *database.WorkspaceAgent, agentParts []codersdk.ChatMessagePart, discoveredSkills []chattool.SkillMeta, workspaceConnOK bool) {
	if !chat.WorkspaceID.Valid || getWorkspaceAgent == nil {
		return nil, nil, nil, false
	}
	loadedAgent, agentErr := getWorkspaceAgent(ctx)
	if agentErr != nil {
		return nil, nil, nil, false
	}

	// Prefer the expanded directory, falling back to the raw one.
	directory := cmp.Or(loadedAgent.ExpandedDirectory, loadedAgent.Directory)

	// Fetch context configuration from the agent. Parts arrive
	// pre-populated with context-file and skill entries so we don't need
	// additional round-trips.
	if getWorkspaceConn != nil {
		instructionCtx, cancel := context.WithTimeout(ctx, p.instructionLookupTimeout)
		defer cancel()

		conn, connErr := getWorkspaceConn(instructionCtx)
		if connErr != nil {
			p.logger.Debug(ctx, "failed to resolve workspace connection for instruction files",
				slog.F("chat_id", chat.ID),
				slog.Error(connErr),
			)
		} else if agentCfg, cfgErr := conn.ContextConfig(instructionCtx); cfgErr != nil {
			// Treat a transient ContextConfig failure the same as a
			// failed connection so no sentinel is persisted. The next
			// turn will retry.
			p.logger.Debug(ctx, "failed to fetch context config from agent",
				slog.F("chat_id", chat.ID), slog.Error(cfgErr))
		} else {
			workspaceConnOK = true
			agentParts = agentCfg.Parts
		}
	}

	// Stamp server-side fields and sanitize content. The agent cannot
	// know its own UUID, OS metadata, or directory — those are added
	// here at the trust boundary.
	stampedID := uuid.NullUUID{UUID: loadedAgent.ID, Valid: true}
	for i := range agentParts {
		part := &agentParts[i]
		part.ContextFileAgentID = stampedID
		switch part.Type {
		case codersdk.ChatMessagePartTypeContextFile:
			part.ContextFileContent = SanitizePromptText(part.ContextFileContent)
			part.ContextFileOS = loadedAgent.OperatingSystem
			part.ContextFileDirectory = directory
		case codersdk.ChatMessagePartTypeSkill:
			discoveredSkills = append(discoveredSkills, chattool.SkillMeta{
				Name:        part.SkillName,
				Description: part.SkillDescription,
				Dir:         part.SkillDir,
				MetaFile:    part.ContextFileSkillMetaFile,
			})
		}
	}
	return &loadedAgent, agentParts, discoveredSkills, workspaceConnOK
}
// filterSkillParts returns only the skill parts of parts, in their
// original order. The result is nil when parts contains no skill part.
func filterSkillParts(parts []codersdk.ChatMessagePart) []codersdk.ChatMessagePart {
	var skills []codersdk.ChatMessagePart
	for i := range parts {
		if parts[i].Type != codersdk.ChatMessagePartTypeSkill {
			continue
		}
		skills = append(skills, parts[i])
	}
	return skills
}
// persistInstructionFiles fetches AGENTS.md instruction files and
// skills from the workspace agent, persisting both as message
// parts. This is called once when a workspace is first attached
// to a chat (or when the agent changes). Returns the formatted
// instruction string and skill index for injection into the
// current turn's prompt.
//
// When the agent supplies no instruction content but the workspace
// connection succeeded, a context-file marker message is persisted on a
// best-effort basis (failures are logged, not returned) and only the
// discovered skills are returned. When the connection did not succeed,
// or no agent could be resolved, nothing is persisted and all results
// are empty.
func (p *Server) persistInstructionFiles(
	ctx context.Context,
	chat database.Chat,
	modelConfigID uuid.UUID,
	getWorkspaceAgent func(context.Context) (database.WorkspaceAgent, error),
	getWorkspaceConn func(context.Context) (workspacesdk.AgentConn, error),
) (instruction string, skills []chattool.SkillMeta, err error) {
	agent, agentParts, discoveredSkills, workspaceConnOK := p.fetchWorkspaceContext(
		ctx, chat, getWorkspaceAgent, getWorkspaceConn,
	)
	if agent == nil {
		return "", nil, nil
	}
	agentID := uuid.NullUUID{UUID: agent.ID, Valid: true}

	hasContent := false
	hasContextFilePart := false
	for _, part := range agentParts {
		if part.Type == codersdk.ChatMessagePartTypeContextFile {
			hasContextFilePart = true
			if part.ContextFileContent != "" {
				hasContent = true
			}
		}
	}

	directory := agent.ExpandedDirectory
	if directory == "" {
		directory = agent.Directory
	}
	contextAPIKeyID, _ := aibridge.DelegatedAPIKeyIDFromContext(ctx)

	if !hasContent {
		if !workspaceConnOK {
			return "", nil, nil
		}
		// Make sure the persisted message carries a context-file part
		// for this agent even when the agent reported none.
		if !hasContextFilePart {
			agentParts = append([]codersdk.ChatMessagePart{{
				Type:               codersdk.ChatMessagePartTypeContextFile,
				ContextFileAgentID: agentID,
			}}, agentParts...)
		}
		content, err := chatprompt.MarshalParts(agentParts)
		if err != nil {
			// Best-effort path: report the failure but do not fail the
			// turn.
			p.logger.Warn(ctx, "failed to marshal empty context-file marker",
				slog.F("chat_id", chat.ID),
				slog.Error(err),
			)
			return "", nil, nil
		}
		msgParams := database.InsertChatMessagesParams{ //nolint:exhaustruct // Fields populated by appendUserChatMessage.
			ChatID: chat.ID,
		}
		appendUserChatMessage(&msgParams, newUserChatMessage(
			contextAPIKeyID,
			content,
			database.ChatMessageVisibilityBoth,
			modelConfigID,
			chatprompt.CurrentContentVersion,
		))
		if _, err := p.db.InsertChatMessages(ctx, msgParams); err != nil {
			// Still best-effort: log and continue so the skills are
			// returned for this turn.
			p.logger.Warn(ctx, "failed to persist empty context-file marker",
				slog.F("chat_id", chat.ID),
				slog.Error(err),
			)
		}
		skillParts := filterSkillParts(agentParts)
		p.updateLastInjectedContext(ctx, chat.ID, skillParts)
		return "", discoveredSkills, nil
	}

	content, err := chatprompt.MarshalParts(agentParts)
	if err != nil {
		return "", nil, xerrors.Errorf("marshal context-file parts: %w", err)
	}
	msgParams := database.InsertChatMessagesParams{ //nolint:exhaustruct // Fields populated by appendUserChatMessage.
		ChatID: chat.ID,
	}
	appendUserChatMessage(&msgParams, newUserChatMessage(
		contextAPIKeyID,
		content,
		database.ChatMessageVisibilityBoth,
		modelConfigID,
		chatprompt.CurrentContentVersion,
	))
	if _, err := p.db.InsertChatMessages(ctx, msgParams); err != nil {
		return "", nil, xerrors.Errorf("persist instruction files: %w", err)
	}

	// Work on a copy so StripInternal does not touch the parts used to
	// format this turn's instructions below.
	stripped := make([]codersdk.ChatMessagePart, len(agentParts))
	copy(stripped, agentParts)
	for i := range stripped {
		stripped[i].StripInternal()
	}
	p.updateLastInjectedContext(ctx, chat.ID, stripped)

	return formatSystemInstructions(agent.OperatingSystem, directory, agentParts), discoveredSkills, nil
}
// updateLastInjectedContext stores the injected context parts
// (AGENTS.md files and skills) on the chat row so they are directly
// queryable without scanning messages. Nil parts write an invalid (NULL)
// value. This is best-effort: a failure here is logged but does not
// block the turn.
func (p *Server) updateLastInjectedContext(ctx context.Context, chatID uuid.UUID, parts []codersdk.ChatMessagePart) {
	// The zero value is the invalid/NULL column value.
	var column pqtype.NullRawMessage
	if parts != nil {
		encoded, err := json.Marshal(parts)
		if err != nil {
			p.logger.Warn(ctx, "failed to marshal injected context",
				slog.F("chat_id", chatID),
				slog.Error(err),
			)
			return
		}
		column = pqtype.NullRawMessage{RawMessage: encoded, Valid: true}
	}

	_, err := p.db.UpdateChatLastInjectedContext(ctx, database.UpdateChatLastInjectedContextParams{
		ID:                  chatID,
		LastInjectedContext: column,
	})
	if err == nil {
		return
	}
	p.logger.Warn(ctx, "failed to update injected context",
		slog.F("chat_id", chatID),
		slog.Error(err),
	)
}
// resolveUserCompactionThreshold looks up the user's per-model
// compaction threshold override. It returns the override and true when
// one exists and parses to an integer in 0..100, or 0 and false
// otherwise. Lookup failures other than a missing row are logged.
func (p *Server) resolveUserCompactionThreshold(ctx context.Context, userID uuid.UUID, modelConfigID uuid.UUID) (int32, bool) {
	stored, err := p.db.GetUserChatCompactionThreshold(ctx, database.GetUserChatCompactionThresholdParams{
		UserID: userID,
		Key:    codersdk.CompactionThresholdKey(modelConfigID),
	})
	switch {
	case errors.Is(err, sql.ErrNoRows):
		// No override configured.
		return 0, false
	case err != nil:
		p.logger.Warn(ctx, "failed to fetch compaction threshold override",
			slog.F("user_id", userID),
			slog.F("model_config_id", modelConfigID),
			slog.Error(err),
		)
		return 0, false
	}

	threshold, parseErr := strconv.ParseInt(stored, 10, 32)
	if parseErr != nil {
		return 0, false
	}
	// Range 0..100 must stay in sync with handler validation in
	// coderd/chats.go.
	if threshold < 0 || threshold > 100 {
		return 0, false
	}
	return int32(threshold), true
}
// resolveDeploymentSystemPrompt assembles the deployment-level system
// prompt from the built-in default (when the configuration includes it)
// and the sanitized admin-configured custom prompt, joined by a blank
// line. If the configuration cannot be loaded, the built-in default is
// returned.
func (p *Server) resolveDeploymentSystemPrompt(ctx context.Context) string {
	cfg, err := p.db.GetChatSystemPromptConfig(ctx)
	if err != nil {
		// Fail open: use the built-in default so chats always have
		// some system guidance.
		p.logger.Error(ctx, "failed to fetch chat system prompt configuration, using default", slog.Error(err))
		return DefaultSystemPrompt
	}

	custom := SanitizePromptText(cfg.ChatSystemPrompt)
	if custom == "" && strings.TrimSpace(cfg.ChatSystemPrompt) != "" {
		p.logger.Warn(ctx, "custom system prompt became empty after sanitization, omitting custom portion")
	}

	sections := make([]string, 0, 2)
	if cfg.IncludeDefaultSystemPrompt {
		sections = append(sections, DefaultSystemPrompt)
	}
	if custom != "" {
		sections = append(sections, custom)
	}

	prompt := strings.Join(sections, "\n\n")
	if prompt == "" {
		p.logger.Warn(ctx, "resolved system prompt is empty, no system prompt will be injected into chats")
	}
	return prompt
}
// resolveUserPrompt fetches the user's custom chat prompt and wraps it
// in <user-instructions> tags. It returns the empty string when no
// prompt is set, when the stored prompt is blank, or when the lookup
// fails. Lookup failures other than sql.ErrNoRows are logged so they are
// not mistaken for an unset prompt.
func (p *Server) resolveUserPrompt(ctx context.Context, userID uuid.UUID) string {
	raw, err := p.configCache.UserPrompt(ctx, userID)
	if err != nil {
		// sql.ErrNoRows is the normal "not set" case; anything else is
		// unexpected and worth surfacing.
		if !errors.Is(err, sql.ErrNoRows) {
			p.logger.Warn(ctx, "failed to fetch user chat prompt",
				slog.F("user_id", userID),
				slog.Error(err),
			)
		}
		return ""
	}

	trimmed := strings.TrimSpace(raw)
	if trimmed == "" {
		return ""
	}
	return "<user-instructions>\n" + trimmed + "\n</user-instructions>"
}
// renderPlanPathPrompt substitutes planPathBlock for the plan-path
// placeholder wherever it occurs in the prompt. The prompt is returned
// as-is when the placeholder is absent.
func renderPlanPathPrompt(prompt []fantasy.Message, planPathBlock string) []fantasy.Message {
	rendered, _ := replacePlanPathPlaceholder(prompt, planPathBlock)
	return rendered
}
// replacePlanPathPlaceholder applies replacePlanPathPlaceholderInMessage
// to every message in prompt. It reports whether anything was replaced.
// The input slice is never modified: a clone is made on the first
// replacement, and the original slice is returned when nothing changed.
func replacePlanPathPlaceholder(
	prompt []fantasy.Message,
	planPathBlock string,
) ([]fantasy.Message, bool) {
	// rewritten stays nil until the first message actually changes.
	var rewritten []fantasy.Message
	for idx := range prompt {
		replacement, changed := replacePlanPathPlaceholderInMessage(prompt[idx], planPathBlock)
		if !changed {
			continue
		}
		if rewritten == nil {
			rewritten = slices.Clone(prompt)
		}
		rewritten[idx] = replacement
	}
	if rewritten == nil {
		return prompt, false
	}
	return rewritten, true
}
// replacePlanPathPlaceholderInMessage replaces every occurrence of the
// plan-path placeholder in the text parts of a system message with
// planPathBlock. Non-system messages and messages without the
// placeholder are returned unchanged with false. The content slice of
// the input message is not modified; the returned message carries a
// copy.
func replacePlanPathPlaceholderInMessage(
	message fantasy.Message,
	planPathBlock string,
) (fantasy.Message, bool) {
	if message.Role != fantasy.MessageRoleSystem {
		return message, false
	}

	rewritten := message.Content
	changed := false
	for idx, part := range message.Content {
		text, isText := fantasy.AsMessagePart[fantasy.TextPart](part)
		if !isText {
			continue
		}
		if !strings.Contains(text.Text, defaultSystemPromptPlanPathBlockPlaceholder) {
			continue
		}
		if !changed {
			// Copy on first write so the caller's content is untouched.
			rewritten = slices.Clone(message.Content)
			changed = true
		}
		rewritten[idx] = fantasy.TextPart{Text: strings.ReplaceAll(
			text.Text,
			defaultSystemPromptPlanPathBlockPlaceholder,
			planPathBlock,
		)}
	}
	if !changed {
		return message, false
	}

	message.Content = rewritten
	return message, true
}
// formatPlanPathBlock renders the <plan-file-path> prompt block that
// tells the model which plan file path to use for this chat and which
// path to avoid. The path to avoid is PLAN.md under home when home is
// non-blank, otherwise chattool.LegacySharedPlanPath. A blank chatPath
// yields the empty string.
func formatPlanPathBlock(chatPath, home string) string {
	planPath := strings.TrimSpace(chatPath)
	if planPath == "" {
		return ""
	}

	forbidden := chattool.LegacySharedPlanPath
	if trimmedHome := strings.TrimSpace(home); trimmedHome != "" {
		forbidden = strings.TrimRight(trimmedHome, "/") + "/PLAN.md"
	}

	return "<plan-file-path>\n" +
		"Your plan file path for this chat is: " + planPath + "\n" +
		"Always use this exact path when creating or proposing plan files. Do not use " + forbidden + ".\n" +
		"</plan-file-path>"
}
// parseDynamicToolNames decodes the dynamic tools JSON column into a set
// of tool names, centralizing the repeated pattern of deserializing
// DynamicTools into a name set. A NULL or empty column yields an empty,
// non-nil map.
func parseDynamicToolNames(raw pqtype.NullRawMessage) (map[string]bool, error) {
	hasPayload := raw.Valid && len(raw.RawMessage) > 0
	if !hasPayload {
		return make(map[string]bool), nil
	}

	var defs []codersdk.DynamicTool
	if err := json.Unmarshal(raw.RawMessage, &defs); err != nil {
		return nil, xerrors.Errorf("unmarshal dynamic tools: %w", err)
	}

	nameSet := make(map[string]bool, len(defs))
	for i := range defs {
		nameSet[defs[i].Name] = true
	}
	return nameSet, nil
}
// maybeFinalizeTurnStatusLabelAndPush updates the cached turn status
// label at the end of a turn and optionally sends a web push
// notification. Chats that have a parent chat are skipped entirely.
//
// A waiting status runs the successful-turn finalizer; pending and
// requires-action store the fallback label; an error status clears the
// label and, when web push is configured, pushes lastError (or the
// fallback label when lastError is empty). Any other status clears the
// label.
func (p *Server) maybeFinalizeTurnStatusLabelAndPush(
	ctx context.Context,
	chat database.Chat,
	status database.ChatStatus,
	lastError string,
	runResult runChatResult,
	logger slog.Logger,
) {
	if chat.ParentChatID.Valid {
		return
	}

	switch status {
	case database.ChatStatusWaiting:
		p.finalizeSuccessfulTurnStatusLabelAndPush(ctx, chat, status, runResult, logger)
	case database.ChatStatusPending, database.ChatStatusRequiresAction:
		p.setLastTurnSummaryAsync(ctx, chat, fallbackTurnStatusLabel(status), logger)
	case database.ChatStatusError:
		p.clearLastTurnSummaryAsync(ctx, chat, logger)
		if !p.webpushConfigured() {
			return
		}
		// Prefer the concrete error text over the generic label.
		body := cmp.Or(lastError, fallbackTurnStatusLabel(status))
		p.dispatchPush(ctx, chat, body, status, logger)
	default:
		// New statuses must be classified before they can safely
		// preserve or finalize a cached turn status label.
		p.clearLastTurnSummaryAsync(ctx, chat, logger)
	}
}
// finalizeSuccessfulTurnStatusLabelAndPush finalizes the turn status
// label in the background and, once it has been written, dispatches the
// successful-turn push notification carrying that label.
func (p *Server) finalizeSuccessfulTurnStatusLabelAndPush(
	ctx context.Context,
	chat database.Chat,
	status database.ChatStatus,
	runResult runChatResult,
	logger slog.Logger,
) {
	pushAfterFinalize := func(finalizeCtx context.Context, statusLabel string) {
		p.dispatchSuccessfulTurnPush(finalizeCtx, chat, statusLabel, logger)
	}
	p.finalizeSuccessfulTurnStatusLabelWithAfterFunc(ctx, chat, status, runResult, logger, pushAfterFinalize)
}
// finalizeSuccessfulTurnStatusLabelWithAfterFunc generates the final turn
// status label on a p.inflight goroutine, writes it as the chat's last
// turn summary, and then calls afterFinalize with the same label. The
// background work uses a context detached from ctx's cancellation.
func (p *Server) finalizeSuccessfulTurnStatusLabelWithAfterFunc(
	ctx context.Context,
	chat database.Chat,
	status database.ChatStatus,
	runResult runChatResult,
	logger slog.Logger,
	afterFinalize func(context.Context, string),
) {
	// This helper runs during processChat cleanup, while processChat is
	// still counted in p.inflight. Do not take inflightMu here because
	// drainInflight holds it while waiting.
	p.inflight.Go(func() {
		detached := context.WithoutCancel(ctx)

		label := p.generateFinalTurnStatusLabel(detached, chat, status, runResult, logger)
		logger.Debug(detached, "generated chat turn status label",
			slog.F("chat_id", chat.ID),
			slog.F("status", status),
			slog.F("label_length", len(label)),
		)

		p.updateLastTurnSummary(detached, chat, chat.HistoryVersion, label, logger)
		afterFinalize(detached, label)
	})
}
// generateFinalTurnStatusLabel returns the status label for a finished
// turn. A generated label is attempted only when the status is waiting,
// the run produced non-blank assistant text, and a status-label model is
// available; in every other case, or when generation yields an empty
// label, the fallback label for status is returned.
func (p *Server) generateFinalTurnStatusLabel(
	ctx context.Context,
	chat database.Chat,
	status database.ChatStatus,
	runResult runChatResult,
	logger slog.Logger,
) string {
	if status != database.ChatStatusWaiting || runResult.StatusLabelModel == nil {
		return fallbackTurnStatusLabel(status)
	}
	finalText := strings.TrimSpace(runResult.FinalAssistantText)
	if finalText == "" {
		return fallbackTurnStatusLabel(status)
	}

	generated := p.generateTurnStatusLabel(
		ctx,
		chat,
		status,
		finalText,
		runResult.FallbackProvider,
		runResult.FallbackModel,
		runResult.StatusLabelModel,
		runResult.FallbackRoute,
		runResult.ProviderKeys,
		runResult.ModelBuildOptions,
		logger,
		p.existingDebugService(),
		runResult.TriggerMessageID,
		runResult.HistoryTipMessageID,
	)
	if generated != "" {
		return generated
	}
	return fallbackTurnStatusLabel(status)
}
// dispatchSuccessfulTurnPush sends the web push notification for a turn
// that ended in the waiting status. The body is statusLabel, or the
// fallback label for the waiting status when statusLabel is empty. It
// does nothing when web push is not configured.
func (p *Server) dispatchSuccessfulTurnPush(
	ctx context.Context,
	chat database.Chat,
	statusLabel string,
	logger slog.Logger,
) {
	if !p.webpushConfigured() {
		return
	}
	const pushStatus = database.ChatStatusWaiting
	body := cmp.Or(statusLabel, fallbackTurnStatusLabel(pushStatus))
	p.dispatchPush(ctx, chat, body, pushStatus, logger)
}
// maybeClearLastTurnSummaryAsync clears the cached last turn summary in the
// background, but only for chats without a parent chat. Chats whose
// ParentChatID is set are left untouched.
func (p *Server) maybeClearLastTurnSummaryAsync(
	ctx context.Context,
	chat database.Chat,
	logger slog.Logger,
) {
	if hasParent := chat.ParentChatID.Valid; !hasParent {
		p.clearLastTurnSummaryAsync(ctx, chat, logger)
	}
}
// setLastTurnSummaryAsync schedules a background write of summary as the
// chat's cached last turn summary. A blank summary is routed to
// clearLastTurnSummaryAsync, and no write is scheduled when the trimmed
// summary already matches the value carried on chat.
func (p *Server) setLastTurnSummaryAsync(
	ctx context.Context,
	chat database.Chat,
	summary string,
	logger slog.Logger,
) {
	trimmed := strings.TrimSpace(summary)
	switch {
	case trimmed == "":
		p.clearLastTurnSummaryAsync(ctx, chat, logger)
	case chat.LastTurnSummary.Valid && strings.TrimSpace(chat.LastTurnSummary.String) == trimmed:
		// The cached value is already up to date; nothing to write.
	default:
		// This helper runs during processChat cleanup, while processChat is
		// still counted in p.inflight. Do not take inflightMu here because
		// drainInflight holds it while waiting.
		p.inflight.Go(func() {
			detached := context.WithoutCancel(ctx)
			p.updateLastTurnSummary(detached, chat, chat.HistoryVersion, trimmed, logger)
		})
	}
}
// clearLastTurnSummaryAsync schedules a background write that resets the
// chat's cached last turn summary by passing an empty summary to
// updateLastTurnSummary. The write uses a context detached from ctx's
// cancellation.
func (p *Server) clearLastTurnSummaryAsync(
	ctx context.Context,
	chat database.Chat,
	logger slog.Logger,
) {
	// This helper runs during processChat cleanup, while processChat is
	// still counted in p.inflight. Do not take inflightMu here because
	// drainInflight holds it while waiting.
	p.inflight.Go(func() {
		detached := context.WithoutCancel(ctx)
		const noSummary = ""
		p.updateLastTurnSummary(detached, chat, chat.HistoryVersion, noSummary, logger)
	})
}
// updateLastTurnSummary writes the cached sidebar summary for a chat and,
// when a row was updated, publishes a summary-change event carrying the new
// value. A blank summary is passed to the query as a NULL string.
// expectedHistoryVersion is forwarded to the query; an update that affects
// zero rows is logged as stale and nothing is published. Failures are logged
// and otherwise ignored.
//
// Callers should pass a detached context because this method is used for
// best-effort background cache writes.
func (p *Server) updateLastTurnSummary(
	ctx context.Context,
	chat database.Chat,
	expectedHistoryVersion int64,
	summary string,
	logger slog.Logger,
) {
	trimmed := strings.TrimSpace(summary)
	cached := sql.NullString{String: trimmed, Valid: trimmed != ""}

	//nolint:gocritic // Narrow daemon access for best-effort summary cache writes.
	updateCtx, cancel := context.WithTimeout(dbauthz.AsChatd(ctx), turnStatusLabelWriteTimeout)
	defer cancel()

	params := database.UpdateChatLastTurnSummaryParams{
		ID:                     chat.ID,
		ExpectedHistoryVersion: expectedHistoryVersion,
		LastTurnSummary:        cached,
	}
	affected, err := p.db.UpdateChatLastTurnSummary(updateCtx, params)

	switch {
	case err != nil:
		logger.Warn(updateCtx, "failed to update chat turn summary",
			slog.F("chat_id", chat.ID),
			slog.Error(err),
		)
	case affected == 0 && trimmed != "":
		// Dropping a real summary is logged at a higher level than
		// dropping a clear request.
		logger.Info(updateCtx, "skipped stale chat turn summary update with non-empty summary",
			slog.F("chat_id", chat.ID),
			slog.F("summary_length", len(trimmed)),
			slog.F("expected_history_version", expectedHistoryVersion),
		)
	case affected == 0:
		logger.Debug(updateCtx, "skipped stale chat turn summary update",
			slog.F("chat_id", chat.ID),
			slog.F("expected_history_version", expectedHistoryVersion),
		)
	default:
		// Publish a copy of the chat that reflects the value just written.
		published := chat
		published.LastTurnSummary = cached
		p.publishChatPubsubEvent(published, codersdk.ChatWatchEventKindSummaryChange, nil)
	}
}
// webpushConfigured reports whether a web push dispatcher is set and exposes
// a non-empty public key.
func (p *Server) webpushConfigured() bool {
	if p.webpushDispatcher == nil {
		return false
	}
	publicKey := p.webpushDispatcher.PublicKey()
	return publicKey != ""
}
// dispatchPush sends a web push notification to the chat owner. The chat
// title is used as the notification title, body as its text, and the
// notification data carries the chat's "/agents/<id>" URL. Dispatch failures
// are logged as warnings and not returned; status is only used in that log
// entry.
//
// p.webpushDispatcher is dereferenced without a nil check, so callers must
// verify webpushConfigured first.
func (p *Server) dispatchPush(
	ctx context.Context,
	chat database.Chat,
	body string,
	status database.ChatStatus,
	logger slog.Logger,
) {
	chatURL := fmt.Sprintf("/agents/%s", chat.ID)
	notification := codersdk.WebpushMessage{
		Title: chat.Title,
		Body:  body,
		Icon:  "/favicon.ico",
		Data:  map[string]string{"url": chatURL},
	}

	err := p.webpushDispatcher.Dispatch(ctx, chat.OwnerID, notification)
	if err == nil {
		return
	}
	logger.Warn(ctx, "failed to send chat completion web push",
		slog.F("chat_id", chat.ID),
		slog.F("status", status),
		slog.Error(err),
	)
}
// Close stops the processor and waits for it to finish.
//
// In order, it: runs and clears the config cache unsubscribe hook (if set),
// closes the chat worker (a close error is logged, not returned), closes the
// message part buffer, calls p.cancel, waits on p.wg, and finally drains the
// in-flight operations. It always returns nil.
func (p *Server) Close() error {
	unsubscribe := p.configCacheUnsubscribe
	if unsubscribe != nil {
		// Clear the field before invoking the hook, matching the order
		// callers have relied on.
		p.configCacheUnsubscribe = nil
		unsubscribe()
	}

	if p.chatWorker != nil {
		closeErr := p.chatWorker.Close()
		if closeErr != nil {
			p.logger.Warn(context.Background(), "failed to close chat worker", slog.Error(closeErr))
		}
	}

	if buffer := p.messagePartBuffer; buffer != nil {
		buffer.Close()
	}

	p.cancel()
	p.wg.Wait()
	p.drainInflight()
	return nil
}
// drainInflight waits for all in-flight operations to complete.
// It acquires inflightMu to prevent processOnce from spawning
// new goroutines (via inflight.Add) concurrently with Wait,
// which would violate sync.WaitGroup's contract.
//
// https://pkg.go.dev/sync#WaitGroup.Add
// > Note that calls with a positive delta that occur when the counter is zero must happen before a Wait.
//
// The mutex is released with defer so it is not left locked if Wait panics.
func (p *Server) drainInflight() {
	p.inflightMu.Lock()
	defer p.inflightMu.Unlock()
	p.inflight.Wait()
}
// refreshExpiredMCPTokens returns a copy of tokens in which every token that
// could be refreshed has been replaced by its refreshed version. Only tokens
// whose server config is present in configs, uses the "oauth2" auth type, and
// that carry a refresh token are considered; each of those is handed to
// refreshMCPTokenIfNeeded on its own goroutine.
//
// Tokens that are skipped or that fail to refresh are returned unchanged so
// the caller can still attempt the connection (which will likely fail with
// a 401 for the expired ones). The input slice is never modified.
func (p *Server) refreshExpiredMCPTokens(
	ctx context.Context,
	logger slog.Logger,
	configs []database.MCPServerConfig,
	tokens []database.MCPServerUserToken,
) []database.MCPServerUserToken {
	cfgLookup := make(map[uuid.UUID]database.MCPServerConfig, len(configs))
	for n := range configs {
		cfgLookup[configs[n].ID] = configs[n]
	}

	out := slices.Clone(tokens)
	var group errgroup.Group
	for idx := range out {
		current := out[idx]
		serverCfg, known := cfgLookup[current.MCPServerConfigID]
		eligible := known && serverCfg.AuthType == "oauth2" && current.RefreshToken != ""
		if !eligible {
			continue
		}
		group.Go(func() error {
			fresh, refreshErr := p.refreshMCPTokenIfNeeded(ctx, logger, serverCfg, current)
			if refreshErr == nil {
				// Each goroutine writes only its own slot.
				out[idx] = fresh
				return nil
			}
			logger.Warn(ctx, "failed to refresh MCP oauth2 token",
				slog.F("server_slug", serverCfg.Slug),
				slog.Error(refreshErr),
			)
			return nil
		})
	}
	// Every goroutine returns nil, so Wait is used purely as a barrier.
	_ = group.Wait()
	return out
}
// refreshMCPTokenIfNeeded delegates to mcpclient.RefreshOAuth2Token
// and persists the result to the database when a refresh occurs.
//
// It returns tok unchanged together with the error when the refresh call
// fails, and tok unchanged with a nil error when no refresh took place. After
// a refresh it returns the upserted database row; if the upsert fails, the
// failure is logged and tok is returned with the refreshed credentials
// applied in memory and a nil error.
//
// The logger should carry chat-scoped fields so log lines can be
// correlated with specific chat requests.
func (p *Server) refreshMCPTokenIfNeeded(
	ctx context.Context,
	logger slog.Logger,
	cfg database.MCPServerConfig,
	tok database.MCPServerUserToken,
) (database.MCPServerUserToken, error) {
	refresh, err := mcpclient.RefreshOAuth2Token(ctx, cfg, tok)
	switch {
	case err != nil:
		return tok, err
	case !refresh.Refreshed:
		return tok, nil
	}

	logger.Info(ctx, "refreshed MCP oauth2 token",
		slog.F("server_slug", cfg.Slug),
		slog.F("user_id", tok.UserID),
	)

	// A zero expiry from the refresh result is stored as a NULL time.
	var newExpiry sql.NullTime
	if !refresh.Expiry.IsZero() {
		newExpiry = sql.NullTime{Time: refresh.Expiry, Valid: true}
	}

	upsertParams := database.UpsertMCPServerUserTokenParams{
		MCPServerConfigID: tok.MCPServerConfigID,
		UserID:            tok.UserID,
		AccessToken:       refresh.AccessToken,
		AccessTokenKeyID:  sql.NullString{},
		RefreshToken:      refresh.RefreshToken,
		RefreshTokenKeyID: sql.NullString{},
		TokenType:         refresh.TokenType,
		Expiry:            newExpiry,
	}
	//nolint:gocritic // Chatd needs system-level write access to
	// persist the refreshed OAuth2 token for the user.
	persisted, err := p.db.UpsertMCPServerUserToken(dbauthz.AsSystemRestricted(ctx), upsertParams)
	if err == nil {
		return persisted, nil
	}

	// The provider may have rotated the refresh token,
	// invalidating the old one. Use the new token
	// in-memory so at least this connection succeeds.
	logger.Warn(ctx, "failed to persist refreshed MCP oauth2 token, using in-memory",
		slog.F("server_slug", cfg.Slug),
		slog.Error(err),
	)
	tok.AccessToken = refresh.AccessToken
	tok.RefreshToken = refresh.RefreshToken
	tok.TokenType = refresh.TokenType
	tok.Expiry = newExpiry
	return tok, nil
}