mirror of
https://github.com/coder/coder.git
synced 2026-06-06 14:38:23 +00:00
6585 lines
208 KiB
Go
6585 lines
208 KiB
Go
package chatd
|
|
|
|
import (
|
|
"bytes"
|
|
"cmp"
|
|
"context"
|
|
"database/sql"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"math"
|
|
"net/http"
|
|
"slices"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"charm.land/fantasy"
|
|
"charm.land/fantasy/providers/anthropic"
|
|
"github.com/dustin/go-humanize"
|
|
"github.com/google/uuid"
|
|
"github.com/prometheus/client_golang/prometheus"
|
|
"github.com/shopspring/decimal"
|
|
"github.com/sqlc-dev/pqtype"
|
|
"golang.org/x/sync/errgroup"
|
|
"golang.org/x/xerrors"
|
|
|
|
"cdr.dev/slog/v3"
|
|
"github.com/coder/coder/v2/coderd/aibridge"
|
|
"github.com/coder/coder/v2/coderd/audit"
|
|
"github.com/coder/coder/v2/coderd/database"
|
|
"github.com/coder/coder/v2/coderd/database/db2sdk"
|
|
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
|
"github.com/coder/coder/v2/coderd/database/pubsub"
|
|
"github.com/coder/coder/v2/coderd/notifications"
|
|
coderdpubsub "github.com/coder/coder/v2/coderd/pubsub"
|
|
"github.com/coder/coder/v2/coderd/rbac"
|
|
"github.com/coder/coder/v2/coderd/util/ptr"
|
|
"github.com/coder/coder/v2/coderd/util/xjson"
|
|
"github.com/coder/coder/v2/coderd/webpush"
|
|
"github.com/coder/coder/v2/coderd/workspacestats"
|
|
"github.com/coder/coder/v2/coderd/x/chatd/chatadvisor"
|
|
"github.com/coder/coder/v2/coderd/x/chatd/chatcost"
|
|
"github.com/coder/coder/v2/coderd/x/chatd/chatdebug"
|
|
"github.com/coder/coder/v2/coderd/x/chatd/chaterror"
|
|
"github.com/coder/coder/v2/coderd/x/chatd/chatloop"
|
|
"github.com/coder/coder/v2/coderd/x/chatd/chatopenai"
|
|
"github.com/coder/coder/v2/coderd/x/chatd/chatprompt"
|
|
"github.com/coder/coder/v2/coderd/x/chatd/chatprovider"
|
|
"github.com/coder/coder/v2/coderd/x/chatd/chatstate"
|
|
"github.com/coder/coder/v2/coderd/x/chatd/chattool"
|
|
"github.com/coder/coder/v2/coderd/x/chatd/internal/agentselect"
|
|
"github.com/coder/coder/v2/coderd/x/chatd/mcpclient"
|
|
"github.com/coder/coder/v2/coderd/x/chatd/messagepartbuffer"
|
|
skillspkg "github.com/coder/coder/v2/coderd/x/skills"
|
|
"github.com/coder/coder/v2/codersdk"
|
|
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
|
"github.com/coder/quartz"
|
|
)
|
|
|
|
const (
	// DefaultPendingChatAcquireInterval is how long the server waits, by
	// default, between attempts to acquire pending chats.
	DefaultPendingChatAcquireInterval = time.Second
	// DefaultInFlightChatStaleAfter is the default age past which a running
	// chat counts as stale and should be recovered.
	DefaultInFlightChatStaleAfter = 5 * time.Minute

	homeInstructionLookupTimeout = 5 * time.Second
	planPathLookupTimeout        = 5 * time.Second
	workspaceDialValidationDelay = 5 * time.Second
	// workspaceMCPDiscoveryTimeout has to stay above
	// agent/x/agentmcp.connectTimeout (30s); otherwise chatd would give
	// up before a cold-start agent's first MCP reload can settle.
	workspaceMCPDiscoveryTimeout = 35 * time.Second
	// workspaceMCPPrimeMaxWait is the deadline budget for the cache
	// primer loop that runs after create_workspace / start_workspace
	// report the agent ready. The deadline is only consulted once a
	// discoverWorkspaceMCPTools call has returned, so the real
	// wall-clock time may overshoot by one such call (worst case
	// dialTimeout + workspaceMCPDiscoveryTimeout). In other words the
	// constant limits when another retry may begin, not when a call
	// already in flight has to end. An empty result usually means the
	// agent's MCP Connect is still racing with agent startup; the
	// agent-side budget is agent/x/agentmcp.connectTimeout (30s).
	workspaceMCPPrimeMaxWait = 30 * time.Second
	// workspaceMCPPrimeRetryInterval is the short pause the primer
	// takes before trying again after ListMCPTools returned an empty
	// list with no error.
	workspaceMCPPrimeRetryInterval = 2 * time.Second
	turnStatusLabelWriteTimeout    = 5 * time.Second
	// defaultDialTimeout is the same timeout that ~8 other server-side
	// AgentConn callers use.
	defaultDialTimeout = 30 * time.Second
	// DefaultChatHeartbeatInterval is the default spacing of heartbeat
	// updates written while a chat is being processed.
	DefaultChatHeartbeatInterval = 30 * time.Second
	maxChatSteps                 = 1200

	// RelaySentinelAfterID is the after_id value sent by cross-replica
	// relay subscribers. It tells the peer to skip the durable DB
	// snapshot and deliver buffered message_part events only. Committed
	// parts are already filtered out by the buffer (see
	// snapshotBufferLocked), so the sentinel effectively means "send me
	// any in-progress streaming parts you have; I will receive durable
	// messages through pubsub."
	RelaySentinelAfterID = math.MaxInt64

	// maxConcurrentRecordingUploads limits how many recording
	// stop-and-store operations may run at the same time. Every slot
	// buffers up to MaxRecordingSize + MaxThumbnailSize (110 MB) in
	// memory, which implicitly bounds memory use to about
	// maxConcurrentRecordingUploads * 110 MB.
	maxConcurrentRecordingUploads = 25

	// bufferRetainGracePeriod is how long per-chat stream state
	// outlives the end of processing. Keeping it around lets
	// late-connecting cross-replica relay subscribers register against
	// the live stream before the next worker run begins, which avoids
	// a race between cleanupStreamIfIdle and subscriber registration.
	// The buffer contents no longer matter by then: each part has been
	// claimed by its durable assistant message and is filtered out of
	// the subscriber snapshot.
	bufferRetainGracePeriod = 5 * time.Second
	// chatStreamControlFetchTimeout caps subscriber-owned control-path
	// DB reads for callers that supply no deadline of their own.
	chatStreamControlFetchTimeout = 5 * time.Second

	// streamJanitorInterval is the period between sweepIdleStreams
	// runs. Retention in the worst case is bufferRetainGracePeriod +
	// streamJanitorInterval.
	streamJanitorInterval = 30 * time.Second

	// agentDisconnectedRecoveryThreshold is the minimum time the latest
	// workspace agent has to stay disconnected before chatd proposes
	// the destructive stop/start recovery. It is deliberately longer
	// than the inactive-disconnect timeout so that brief heartbeat gaps
	// do not trigger a workspace restart.
	agentDisconnectedRecoveryThreshold = 90 * time.Second

	// DefaultMaxChatsPerAcquire is the largest number of chats a single
	// processOnce call acquires. Acquiring in batches avoids waiting a
	// whole polling interval between acquisitions when many chats are
	// pending.
	DefaultMaxChatsPerAcquire int32 = 10

	defaultSubagentInstruction = "You are running as a delegated sub-agent chat. Complete the delegated task and provide clear, concise assistant responses for the parent agent."

	// defaultAdvisorMaxOutputTokens is the cap applied to the nested
	// advisor response when the admin config leaves the field unset
	// (or sets it to <= 0). The value is intentionally generous
	// compared with the advisor's concise guidance remit so that short
	// plans are not cut off mid-reasoning.
	defaultAdvisorMaxOutputTokens = 16384
)
|
|
|
|
var (
|
|
errChatHasNoWorkspaceAgent = xerrors.New("workspace has no running agent: the workspace is likely stopped. Use the start_workspace tool to start it")
|
|
errChatAgentDisconnected = xerrors.New(
|
|
"workspace agent has been disconnected for at least 90 seconds " +
|
|
"and cannot execute tools. To recover, call stop_workspace " +
|
|
"to stop the workspace, then start_workspace to start it " +
|
|
"again",
|
|
)
|
|
errChatDialTimeout = xerrors.New(
|
|
"connection to the workspace agent timed out. " +
|
|
"The agent may still be reachable on the next attempt.",
|
|
)
|
|
errChatExternalAgentUnavailable = xerrors.New("external workspace agent unavailable")
|
|
)
|
|
|
|
type chatExternalAgentUnavailableError struct {
|
|
message string
|
|
}
|
|
|
|
func (e chatExternalAgentUnavailableError) Error() string {
|
|
return e.message
|
|
}
|
|
|
|
func (chatExternalAgentUnavailableError) Is(target error) bool {
|
|
return target == errChatExternalAgentUnavailable
|
|
}
|
|
|
|
func newChatExternalAgentUnavailableError(agent database.WorkspaceAgent) error {
|
|
return chatExternalAgentUnavailableError{
|
|
message: chattool.ExternalAgentUnavailableMessage(agent),
|
|
}
|
|
}
|
|
|
|
// Server handles background processing of pending chats.
|
|
type Server struct {
|
|
cancel context.CancelFunc
|
|
ctx context.Context
|
|
wg sync.WaitGroup
|
|
inflight sync.WaitGroup
|
|
inflightMu sync.Mutex
|
|
|
|
db database.Store
|
|
workerID uuid.UUID
|
|
logger slog.Logger
|
|
|
|
subscribeFn SubscribeFn
|
|
|
|
agentConnFn AgentConnFunc
|
|
agentInactiveDisconnectTimeout time.Duration
|
|
dialTimeout time.Duration
|
|
instructionLookupTimeout time.Duration
|
|
createWorkspaceFn chattool.CreateWorkspaceFn
|
|
startWorkspaceFn chattool.StartWorkspaceFn
|
|
stopWorkspaceFn chattool.StopWorkspaceFn
|
|
pubsub pubsub.Pubsub
|
|
webpushDispatcher webpush.Dispatcher
|
|
providerAPIKeys chatprovider.ProviderAPIKeys
|
|
allowBYOK bool
|
|
oidcTokenSource mcpclient.UserOIDCTokenSource
|
|
debugSvc *chatdebug.Service
|
|
debugSvcFactory func() *chatdebug.Service
|
|
debugSvcReady atomic.Bool
|
|
debugSvcInit sync.Once
|
|
configCache *chatConfigCache
|
|
configCacheUnsubscribe func()
|
|
|
|
// chatStreams stores per-chat stream state. Using sync.Map
|
|
// gives each chat independent locking — concurrent chats
|
|
// never contend with each other.
|
|
chatStreams sync.Map // uuid.UUID -> *chatStreamState
|
|
|
|
// workspaceMCPToolsCache caches workspace MCP tool definitions
|
|
// per chat to avoid re-fetching on every turn. The cache is
|
|
// keyed by chat ID and invalidated when the agent changes.
|
|
workspaceMCPToolsCache sync.Map // uuid.UUID -> *cachedWorkspaceMCPTools
|
|
|
|
usageTracker *workspacestats.UsageTracker
|
|
clock quartz.Clock
|
|
metrics *chatloop.Metrics
|
|
chatWorker *chatWorker
|
|
messagePartBuffer *messagepartbuffer.Buffer
|
|
recordingSem chan struct{}
|
|
|
|
aibridgeTransportFactory *atomic.Pointer[aibridge.TransportFactory]
|
|
aiGatewayRoutingEnabled bool
|
|
|
|
// Configuration
|
|
pendingChatAcquireInterval time.Duration
|
|
maxChatsPerAcquire int32
|
|
inFlightChatStaleAfter time.Duration
|
|
chatHeartbeatInterval time.Duration
|
|
}
|
|
|
|
// chatTemplateAllowlist returns the deployment-wide template
|
|
// allowlist as a set of permitted template IDs. The callback
|
|
// signature matches what the chat tools expect. When the
|
|
// allowlist is empty or cannot be loaded the function returns
|
|
// nil, which the tools interpret as "all templates allowed".
|
|
func (p *Server) chatTemplateAllowlist() map[uuid.UUID]bool {
|
|
//nolint:gocritic // AsChatd provides narrowly-scoped daemon
|
|
// access for reading deployment config.
|
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
|
defer cancel()
|
|
//nolint:gocritic // AsChatd provides narrowly-scoped read
|
|
// access to deployment config (the template allowlist).
|
|
ctx = dbauthz.AsChatd(ctx)
|
|
raw, err := p.db.GetChatTemplateAllowlist(ctx)
|
|
if err != nil {
|
|
p.logger.Warn(ctx, "failed to load chat template allowlist", slog.Error(err))
|
|
return nil
|
|
}
|
|
ids, err := xjson.ParseUUIDList(raw)
|
|
if err != nil {
|
|
p.logger.Warn(ctx, "failed to parse chat template allowlist", slog.Error(err))
|
|
return nil
|
|
}
|
|
m := make(map[uuid.UUID]bool, len(ids))
|
|
for _, id := range ids {
|
|
m[id] = true
|
|
}
|
|
return m
|
|
}
|
|
|
|
func (p *Server) loadAdvisorConfig(ctx context.Context, logger slog.Logger) codersdk.AdvisorConfig {
|
|
cfg, err := p.configCache.AdvisorConfig(ctx)
|
|
if err != nil {
|
|
logger.Warn(ctx, "failed to load advisor config", slog.Error(err))
|
|
return codersdk.AdvisorConfig{}
|
|
}
|
|
return cfg
|
|
}
|
|
|
|
// stripAdvisorGuidanceBlock removes any system message whose text content
|
|
// matches chatadvisor.ParentGuidanceBlock after whitespace normalization.
|
|
// The block is meant for the parent agent (it advertises the advisor tool)
|
|
// and would waste context tokens if forwarded to the advisor's nested run.
|
|
func stripAdvisorGuidanceBlock(msgs []fantasy.Message) []fantasy.Message {
|
|
filtered := msgs[:0]
|
|
for _, msg := range msgs {
|
|
if msg.Role == fantasy.MessageRoleSystem && isAdvisorGuidanceMessage(msg) {
|
|
continue
|
|
}
|
|
filtered = append(filtered, msg)
|
|
}
|
|
return filtered
|
|
}
|
|
|
|
func isAdvisorGuidanceMessage(msg fantasy.Message) bool {
|
|
if len(msg.Content) != 1 {
|
|
return false
|
|
}
|
|
text, ok := msg.Content[0].(fantasy.TextPart)
|
|
if !ok {
|
|
return false
|
|
}
|
|
return strings.TrimSpace(text.Text) == strings.TrimSpace(chatadvisor.ParentGuidanceBlock)
|
|
}
|
|
|
|
func (p *Server) resolveAdvisorModelOverride(
|
|
ctx context.Context,
|
|
chat database.Chat,
|
|
advisorCfg codersdk.AdvisorConfig,
|
|
fallbackModel fantasy.LanguageModel,
|
|
fallbackCallConfig codersdk.ChatModelCallConfig,
|
|
providerKeys chatprovider.ProviderAPIKeys,
|
|
modelOpts modelBuildOptions,
|
|
logger slog.Logger,
|
|
) (fantasy.LanguageModel, codersdk.ChatModelCallConfig, error) {
|
|
if advisorCfg.ModelConfigID == uuid.Nil {
|
|
return fallbackModel, fallbackCallConfig, nil
|
|
}
|
|
|
|
// Re-read the override instead of using the cache so disabled models
|
|
// or providers stop routing advisor prompts immediately.
|
|
overrideConfig, err := p.db.GetEnabledChatModelConfigByID(
|
|
ctx,
|
|
advisorCfg.ModelConfigID,
|
|
)
|
|
if err != nil {
|
|
if xerrors.Is(err, sql.ErrNoRows) {
|
|
logger.Warn(
|
|
ctx,
|
|
"advisor model config is disabled or unavailable, continuing with chat model",
|
|
slog.F("model_config_id", advisorCfg.ModelConfigID),
|
|
)
|
|
return fallbackModel, fallbackCallConfig, nil
|
|
}
|
|
logger.Warn(
|
|
ctx,
|
|
"failed to resolve advisor model config, continuing with chat model",
|
|
slog.F("model_config_id", advisorCfg.ModelConfigID),
|
|
slog.Error(err),
|
|
)
|
|
return fallbackModel, fallbackCallConfig, nil
|
|
}
|
|
|
|
overrideCallConfig := codersdk.ChatModelCallConfig{}
|
|
if len(overrideConfig.Options) > 0 {
|
|
if err := json.Unmarshal(overrideConfig.Options, &overrideCallConfig); err != nil {
|
|
logger.Warn(
|
|
ctx,
|
|
"failed to parse advisor model config, continuing with chat model",
|
|
slog.F("model_config_id", advisorCfg.ModelConfigID),
|
|
slog.Error(err),
|
|
)
|
|
return fallbackModel, fallbackCallConfig, nil
|
|
}
|
|
}
|
|
|
|
route, err := p.resolveModelRouteForConfig(
|
|
ctx,
|
|
chat.OwnerID,
|
|
overrideConfig,
|
|
providerKeys,
|
|
)
|
|
if err != nil {
|
|
if p.shouldUseAIGatewayRouting() && overrideConfig.AIProviderID.Valid {
|
|
return nil, codersdk.ChatModelCallConfig{}, xerrors.Errorf("resolve advisor override route: %w", err)
|
|
}
|
|
logger.Warn(
|
|
ctx,
|
|
"failed to resolve advisor override route, continuing with chat model",
|
|
slog.F("model_config_id", advisorCfg.ModelConfigID),
|
|
slog.Error(err),
|
|
)
|
|
return fallbackModel, fallbackCallConfig, nil
|
|
}
|
|
overrideModel, err := p.newModel(ctx, modelClientRequest{
|
|
Chat: chat,
|
|
ModelName: overrideConfig.Model,
|
|
UserAgent: chatprovider.UserAgent(),
|
|
ExtraHeaders: chatprovider.CoderHeaders(chat),
|
|
}, route, modelOpts)
|
|
if err != nil {
|
|
if p.shouldUseAIGatewayRouting() && overrideConfig.AIProviderID.Valid {
|
|
return nil, codersdk.ChatModelCallConfig{}, xerrors.Errorf("create advisor override model: %w", err)
|
|
}
|
|
logger.Warn(
|
|
ctx,
|
|
"failed to create advisor override model, continuing with chat model",
|
|
slog.F("model_config_id", advisorCfg.ModelConfigID),
|
|
slog.Error(err),
|
|
)
|
|
return fallbackModel, fallbackCallConfig, nil
|
|
}
|
|
|
|
return overrideModel, overrideCallConfig, nil
|
|
}
|
|
|
|
func (p *Server) newAdvisorRuntime(
|
|
ctx context.Context,
|
|
chat database.Chat,
|
|
advisorCfg codersdk.AdvisorConfig,
|
|
fallbackModel fantasy.LanguageModel,
|
|
fallbackCallConfig codersdk.ChatModelCallConfig,
|
|
providerKeys chatprovider.ProviderAPIKeys,
|
|
modelOpts modelBuildOptions,
|
|
logger slog.Logger,
|
|
) (*chatadvisor.Runtime, error) {
|
|
advisorModel, advisorCallConfig, err := p.resolveAdvisorModelOverride(
|
|
ctx,
|
|
chat,
|
|
advisorCfg,
|
|
fallbackModel,
|
|
fallbackCallConfig,
|
|
providerKeys,
|
|
modelOpts,
|
|
logger,
|
|
)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
maxUsesPerRun := advisorCfg.MaxUsesPerRun
|
|
switch {
|
|
case maxUsesPerRun == 0:
|
|
// Advisor config treats 0 as unlimited, but the runtime
|
|
// requires a positive bound. maxChatSteps is the
|
|
// effective upper bound because advisor can run at most
|
|
// once per loop step.
|
|
maxUsesPerRun = maxChatSteps
|
|
case maxUsesPerRun < 0:
|
|
logger.Warn(
|
|
ctx,
|
|
"invalid advisor max uses per run, continuing without advisor",
|
|
slog.F("max_uses_per_run", maxUsesPerRun),
|
|
)
|
|
return nil, nil //nolint:nilnil // Nil runtime with nil error means advisor is skipped for this turn.
|
|
}
|
|
|
|
maxOutputTokens := advisorCfg.MaxOutputTokens
|
|
if maxOutputTokens <= 0 {
|
|
maxOutputTokens = defaultAdvisorMaxOutputTokens
|
|
}
|
|
|
|
advisorCallConfig.MaxOutputTokens = ptr.Ref(maxOutputTokens)
|
|
providerOptions := chatprovider.ProviderOptionsFromChatModelConfig(
|
|
advisorModel,
|
|
advisorCallConfig.ProviderOptions,
|
|
)
|
|
|
|
rt, err := chatadvisor.NewRuntime(chatadvisor.RuntimeConfig{
|
|
Model: advisorModel,
|
|
ModelConfig: advisorCallConfig,
|
|
ProviderOptions: providerOptions,
|
|
MaxUsesPerRun: maxUsesPerRun,
|
|
MaxOutputTokens: maxOutputTokens,
|
|
})
|
|
if err != nil {
|
|
logger.Warn(
|
|
ctx,
|
|
"failed to create advisor runtime, continuing without advisor",
|
|
slog.Error(err),
|
|
)
|
|
return nil, nil //nolint:nilnil // Nil runtime with nil error means advisor is skipped for this turn.
|
|
}
|
|
return rt, nil
|
|
}
|
|
|
|
// cachedWorkspaceMCPTools stores workspace MCP tools discovered
|
|
// from a workspace agent, keyed by the agent ID that provided them.
|
|
type cachedWorkspaceMCPTools struct {
|
|
agentID uuid.UUID
|
|
tools []workspacesdk.MCPToolInfo
|
|
}
|
|
|
|
// loadCachedWorkspaceContext checks the MCP tools cache for the
|
|
// given chat and agent. Returns non-nil tools when the cache hits,
|
|
// which signals the caller to skip the slow MCP discovery path.
|
|
func (p *Server) loadCachedWorkspaceContext(
|
|
chatID uuid.UUID,
|
|
agent database.WorkspaceAgent,
|
|
getConn func(context.Context) (workspacesdk.AgentConn, error),
|
|
) []fantasy.AgentTool {
|
|
cached, ok := p.workspaceMCPToolsCache.Load(chatID)
|
|
if !ok {
|
|
return nil
|
|
}
|
|
entry, ok := cached.(*cachedWorkspaceMCPTools)
|
|
if !ok || entry.agentID != agent.ID {
|
|
return nil
|
|
}
|
|
|
|
var tools []fantasy.AgentTool
|
|
invalidate := func() { p.workspaceMCPToolsCache.Delete(chatID) }
|
|
for _, t := range entry.tools {
|
|
tools = append(tools, chattool.NewWorkspaceMCPTool(t, getConn, invalidate))
|
|
}
|
|
|
|
return tools
|
|
}
|
|
|
|
// discoverWorkspaceMCPTools resolves the chat's workspace agent and
|
|
// lists the workspace MCP tools advertised by that agent. Results are
|
|
// cached per chat keyed on the agent ID so subsequent calls hit the
|
|
// cache. Returns nil (and never an error) on every failure mode so the
|
|
// caller can continue without MCP tools.
|
|
//
|
|
// This helper is shared between the initial discovery path and the
|
|
// mid-turn workspace binding path triggered after create_workspace or
|
|
// start_workspace bind a workspace to a chat that started without one.
|
|
func (p *Server) discoverWorkspaceMCPTools(
|
|
ctx context.Context,
|
|
logger slog.Logger,
|
|
chatID uuid.UUID,
|
|
workspaceCtx *turnWorkspaceContext,
|
|
) []fantasy.AgentTool {
|
|
// Fast path: check cache using the in-memory cached agent
|
|
// (ensureWorkspaceAgent is free when already loaded). This
|
|
// avoids a per-turn latest-build DB query on the common
|
|
// subsequent-turn path.
|
|
if agent, agentErr := workspaceCtx.getWorkspaceAgent(ctx); agentErr == nil {
|
|
if tools := p.loadCachedWorkspaceContext(
|
|
chatID, agent, workspaceCtx.getWorkspaceConn,
|
|
); tools != nil {
|
|
return tools
|
|
}
|
|
} // Cache miss, agent changed, or no cache: validate
|
|
// that the workspace still has a live agent before
|
|
// attempting a dial.
|
|
_, _, agentErr := workspaceCtx.workspaceAgentIDForConn(ctx)
|
|
if agentErr != nil {
|
|
if xerrors.Is(agentErr, errChatHasNoWorkspaceAgent) {
|
|
p.workspaceMCPToolsCache.Delete(chatID)
|
|
return nil
|
|
}
|
|
logger.Warn(ctx, "failed to resolve workspace agent for MCP tools",
|
|
slog.Error(agentErr))
|
|
return nil
|
|
}
|
|
|
|
// List workspace MCP tools via the agent conn.
|
|
conn, connErr := workspaceCtx.getWorkspaceConn(ctx)
|
|
if connErr != nil {
|
|
logger.Warn(ctx, "failed to get workspace conn for MCP tools",
|
|
slog.Error(connErr))
|
|
return nil
|
|
}
|
|
listCtx, cancel := context.WithTimeout(ctx, workspaceMCPDiscoveryTimeout)
|
|
defer cancel()
|
|
toolsResp, listErr := conn.ListMCPTools(listCtx)
|
|
if listErr != nil {
|
|
logger.Warn(ctx, "failed to list workspace MCP tools",
|
|
slog.Error(listErr))
|
|
return nil
|
|
}
|
|
// Cache the result for subsequent turns. Skip caching when
|
|
// the list is empty because the agent's MCP Connect may not
|
|
// have finished yet; caching an empty list would hide tools
|
|
// permanently.
|
|
if len(toolsResp.Tools) > 0 {
|
|
if agent, agentErr := workspaceCtx.getWorkspaceAgent(ctx); agentErr == nil {
|
|
p.workspaceMCPToolsCache.Store(chatID, &cachedWorkspaceMCPTools{
|
|
agentID: agent.ID,
|
|
tools: toolsResp.Tools,
|
|
})
|
|
}
|
|
}
|
|
|
|
invalidate := func() { p.workspaceMCPToolsCache.Delete(chatID) }
|
|
tools := make([]fantasy.AgentTool, 0, len(toolsResp.Tools))
|
|
for _, t := range toolsResp.Tools {
|
|
tools = append(tools, chattool.NewWorkspaceMCPTool(t, workspaceCtx.getWorkspaceConn, invalidate))
|
|
}
|
|
return tools
|
|
}
|
|
|
|
// primeWorkspaceMCPCache populates workspaceMCPToolsCache after the
|
|
// create_workspace or start_workspace tool finishes waiting for the
|
|
// workspace agent to become reachable. By the time it runs the agent
|
|
// is already Ready, so a single ListMCPTools call usually succeeds.
|
|
// When the agent's MCP server is still racing with agent startup,
|
|
// ListMCPTools may return an empty list (no error) on the first call;
|
|
// the primer retries with a short backoff up to
|
|
// workspaceMCPPrimeMaxWait so the generation action that follows the
|
|
// tool call sees the workspace MCP tools in the cache and does not need
|
|
// to dial again.
|
|
//
|
|
// Returns silently on every failure mode. The chat continues without
|
|
// workspace MCP tools when the agent does not advertise any within
|
|
// the budget. The next user turn re-runs top-of-turn discovery from
|
|
// scratch.
|
|
func (p *Server) primeWorkspaceMCPCache(
|
|
ctx context.Context,
|
|
logger slog.Logger,
|
|
chatID uuid.UUID,
|
|
workspaceCtx *turnWorkspaceContext,
|
|
) {
|
|
deadline := p.clock.Now().Add(workspaceMCPPrimeMaxWait)
|
|
attempt := 0
|
|
for {
|
|
attempt++
|
|
tools := p.discoverWorkspaceMCPTools(ctx, logger, chatID, workspaceCtx)
|
|
if len(tools) > 0 {
|
|
logger.Debug(ctx, "primed workspace MCP cache",
|
|
slog.F("chat_id", chatID),
|
|
slog.F("tool_count", len(tools)),
|
|
slog.F("attempts", attempt),
|
|
)
|
|
return
|
|
}
|
|
if ctx.Err() != nil {
|
|
return
|
|
}
|
|
if !p.clock.Now().Before(deadline) {
|
|
logger.Debug(ctx,
|
|
"workspace MCP cache primer gave up waiting for tools",
|
|
slog.F("chat_id", chatID),
|
|
slog.F("attempts", attempt),
|
|
)
|
|
return
|
|
}
|
|
timer := p.clock.NewTimer(workspaceMCPPrimeRetryInterval, "chatd", "workspace-mcp-prime")
|
|
select {
|
|
case <-timer.C:
|
|
case <-ctx.Done():
|
|
timer.Stop()
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
type turnWorkspaceContext struct {
|
|
server *Server
|
|
chatStateMu *sync.Mutex
|
|
currentChat *database.Chat
|
|
loadChatSnapshot func(context.Context, uuid.UUID) (database.Chat, error)
|
|
|
|
mu sync.Mutex
|
|
agent database.WorkspaceAgent
|
|
agentLoaded bool
|
|
conn workspacesdk.AgentConn
|
|
releaseConn func()
|
|
cachedWorkspaceID uuid.NullUUID
|
|
}
|
|
|
|
func (c *turnWorkspaceContext) close() {
|
|
c.clearCachedWorkspaceState()
|
|
}
|
|
|
|
func (c *turnWorkspaceContext) clearCachedWorkspaceState() {
|
|
c.mu.Lock()
|
|
releaseConn := c.releaseConn
|
|
c.agent = database.WorkspaceAgent{}
|
|
c.agentLoaded = false
|
|
c.conn = nil
|
|
c.releaseConn = nil
|
|
c.cachedWorkspaceID = uuid.NullUUID{}
|
|
c.mu.Unlock()
|
|
|
|
if releaseConn != nil {
|
|
releaseConn()
|
|
}
|
|
}
|
|
|
|
func (c *turnWorkspaceContext) setCurrentChat(chat database.Chat) {
|
|
c.chatStateMu.Lock()
|
|
*c.currentChat = chat
|
|
c.chatStateMu.Unlock()
|
|
}
|
|
|
|
func (c *turnWorkspaceContext) currentChatSnapshot() database.Chat {
|
|
c.chatStateMu.Lock()
|
|
chatSnapshot := *c.currentChat
|
|
c.chatStateMu.Unlock()
|
|
return chatSnapshot
|
|
}
|
|
|
|
func (c *turnWorkspaceContext) selectWorkspace(chat database.Chat) {
|
|
c.setCurrentChat(chat)
|
|
c.clearCachedWorkspaceState()
|
|
}
|
|
|
|
func (c *turnWorkspaceContext) currentWorkspaceMatches(expected uuid.NullUUID) (database.Chat, bool) {
|
|
chatSnapshot := c.currentChatSnapshot()
|
|
return chatSnapshot, nullUUIDEqual(chatSnapshot.WorkspaceID, expected)
|
|
}
|
|
|
|
func (c *turnWorkspaceContext) trackWorkspaceUsage(ctx context.Context, chatSnapshot database.Chat) {
|
|
if c.server == nil || !chatSnapshot.WorkspaceID.Valid {
|
|
return
|
|
}
|
|
logger := c.server.logger.With(
|
|
slog.F("chat_id", chatSnapshot.ID),
|
|
slog.F("owner_id", chatSnapshot.OwnerID),
|
|
)
|
|
c.server.trackWorkspaceUsage(ctx, chatSnapshot.ID, chatSnapshot.WorkspaceID, logger)
|
|
}
|
|
|
|
func nullUUIDEqual(left, right uuid.NullUUID) bool {
|
|
if left.Valid != right.Valid {
|
|
return false
|
|
}
|
|
if !left.Valid {
|
|
return true
|
|
}
|
|
return left.UUID == right.UUID
|
|
}
|
|
|
|
func (c *turnWorkspaceContext) persistBuildAgentBinding(
|
|
ctx context.Context,
|
|
chatSnapshot database.Chat,
|
|
buildID uuid.UUID,
|
|
agentID uuid.UUID,
|
|
) (database.Chat, error) {
|
|
updatedChat, err := c.server.db.UpdateChatBuildAgentBinding(
|
|
ctx,
|
|
database.UpdateChatBuildAgentBindingParams{
|
|
ID: chatSnapshot.ID,
|
|
BuildID: uuid.NullUUID{
|
|
UUID: buildID,
|
|
Valid: true,
|
|
},
|
|
AgentID: uuid.NullUUID{
|
|
UUID: agentID,
|
|
Valid: true,
|
|
},
|
|
},
|
|
)
|
|
if err != nil {
|
|
return chatSnapshot, xerrors.Errorf(
|
|
"update chat build/agent binding: %w", err,
|
|
)
|
|
}
|
|
c.setCurrentChat(updatedChat)
|
|
return updatedChat, nil
|
|
}
|
|
|
|
func (c *turnWorkspaceContext) getWorkspaceAgent(ctx context.Context) (database.WorkspaceAgent, error) {
|
|
_, agent, err := c.ensureWorkspaceAgent(ctx)
|
|
return agent, err
|
|
}
|
|
|
|
func (c *turnWorkspaceContext) ensureWorkspaceAgent(
|
|
ctx context.Context,
|
|
) (database.Chat, database.WorkspaceAgent, error) {
|
|
c.mu.Lock()
|
|
defer c.mu.Unlock()
|
|
|
|
if c.agentLoaded {
|
|
chatSnapshot := c.currentChatSnapshot()
|
|
if nullUUIDEqual(c.cachedWorkspaceID, chatSnapshot.WorkspaceID) {
|
|
return chatSnapshot, c.agent, nil
|
|
}
|
|
c.agent = database.WorkspaceAgent{}
|
|
c.agentLoaded = false
|
|
}
|
|
|
|
return c.loadWorkspaceAgentLocked(ctx)
|
|
}
|
|
|
|
func (c *turnWorkspaceContext) loadWorkspaceAgentLocked(
|
|
ctx context.Context,
|
|
) (database.Chat, database.WorkspaceAgent, error) {
|
|
chatSnapshot := c.currentChatSnapshot()
|
|
|
|
for attempt := 0; attempt < 2; attempt++ {
|
|
if !chatSnapshot.WorkspaceID.Valid {
|
|
refreshedChat, refreshErr := refreshChatWorkspaceSnapshot(
|
|
ctx,
|
|
chatSnapshot,
|
|
c.loadChatSnapshot,
|
|
)
|
|
if refreshErr != nil {
|
|
return chatSnapshot, database.WorkspaceAgent{}, refreshErr
|
|
}
|
|
if refreshedChat.WorkspaceID.Valid {
|
|
c.setCurrentChat(refreshedChat)
|
|
chatSnapshot = refreshedChat
|
|
}
|
|
}
|
|
|
|
if !chatSnapshot.WorkspaceID.Valid {
|
|
return chatSnapshot, database.WorkspaceAgent{}, xerrors.New("no workspace is associated with this chat. Use the create_workspace tool to create one")
|
|
}
|
|
|
|
if chatSnapshot.AgentID.Valid {
|
|
agent, err := c.server.db.GetWorkspaceAgentByID(ctx, chatSnapshot.AgentID.UUID)
|
|
if err == nil {
|
|
latestChat, workspaceMatches := c.currentWorkspaceMatches(chatSnapshot.WorkspaceID)
|
|
if !workspaceMatches {
|
|
chatSnapshot = latestChat
|
|
continue
|
|
}
|
|
c.agent = agent
|
|
c.agentLoaded = true
|
|
c.cachedWorkspaceID = chatSnapshot.WorkspaceID
|
|
return chatSnapshot, c.agent, nil
|
|
}
|
|
if !xerrors.Is(err, sql.ErrNoRows) {
|
|
c.server.logger.Warn(ctx, "agent binding lookup failed, re-resolving",
|
|
slog.F("agent_id", chatSnapshot.AgentID.UUID),
|
|
slog.Error(err),
|
|
)
|
|
}
|
|
}
|
|
|
|
agents, err := c.server.db.GetWorkspaceAgentsInLatestBuildByWorkspaceID(
|
|
ctx,
|
|
chatSnapshot.WorkspaceID.UUID,
|
|
)
|
|
if err != nil {
|
|
return chatSnapshot, database.WorkspaceAgent{}, xerrors.Errorf(
|
|
"get workspace agents in latest build: %w",
|
|
err,
|
|
)
|
|
}
|
|
if len(agents) == 0 {
|
|
return chatSnapshot, database.WorkspaceAgent{}, errChatHasNoWorkspaceAgent
|
|
}
|
|
selected, err := agentselect.FindChatAgent(agents)
|
|
if err != nil {
|
|
return chatSnapshot, database.WorkspaceAgent{}, xerrors.Errorf(
|
|
"find chat agent: %w",
|
|
err,
|
|
)
|
|
}
|
|
|
|
build, err := c.server.db.GetLatestWorkspaceBuildByWorkspaceID(ctx, chatSnapshot.WorkspaceID.UUID)
|
|
if err != nil {
|
|
return chatSnapshot, database.WorkspaceAgent{}, xerrors.Errorf("get latest workspace build: %w", err)
|
|
}
|
|
|
|
updatedChat, err := c.persistBuildAgentBinding(
|
|
ctx,
|
|
chatSnapshot,
|
|
build.ID,
|
|
selected.ID,
|
|
)
|
|
if err != nil {
|
|
return chatSnapshot, database.WorkspaceAgent{}, err
|
|
}
|
|
|
|
chatSnapshot = updatedChat
|
|
latestChat, workspaceMatches := c.currentWorkspaceMatches(chatSnapshot.WorkspaceID)
|
|
if !workspaceMatches {
|
|
chatSnapshot = latestChat
|
|
continue
|
|
}
|
|
c.agent = selected
|
|
c.agentLoaded = true
|
|
c.cachedWorkspaceID = chatSnapshot.WorkspaceID
|
|
return chatSnapshot, c.agent, nil
|
|
}
|
|
|
|
return chatSnapshot, database.WorkspaceAgent{}, xerrors.New(
|
|
"chat workspace changed while resolving agent",
|
|
)
|
|
}
|
|
|
|
func (c *turnWorkspaceContext) latestWorkspaceAgentID(
|
|
ctx context.Context,
|
|
workspaceID uuid.UUID,
|
|
) (uuid.UUID, error) {
|
|
agents, err := c.server.db.GetWorkspaceAgentsInLatestBuildByWorkspaceID(
|
|
ctx,
|
|
workspaceID,
|
|
)
|
|
if err != nil {
|
|
return uuid.Nil, xerrors.Errorf(
|
|
"get workspace agents in latest build: %w",
|
|
err,
|
|
)
|
|
}
|
|
if len(agents) == 0 {
|
|
return uuid.Nil, errChatHasNoWorkspaceAgent
|
|
}
|
|
selected, err := agentselect.FindChatAgent(agents)
|
|
if err != nil {
|
|
return uuid.Nil, xerrors.Errorf(
|
|
"find chat agent: %w",
|
|
err,
|
|
)
|
|
}
|
|
return selected.ID, nil
|
|
}
|
|
|
|
func (c *turnWorkspaceContext) workspaceAgentIDForConn(
|
|
ctx context.Context,
|
|
) (database.Chat, uuid.UUID, error) {
|
|
for attempt := 0; attempt < 2; attempt++ {
|
|
chatSnapshot := c.currentChatSnapshot()
|
|
if !chatSnapshot.WorkspaceID.Valid || !chatSnapshot.AgentID.Valid {
|
|
updatedChat, agent, err := c.ensureWorkspaceAgent(ctx)
|
|
if err != nil {
|
|
return updatedChat, uuid.Nil, err
|
|
}
|
|
return updatedChat, agent.ID, nil
|
|
}
|
|
|
|
currentAgentID, err := c.latestWorkspaceAgentID(
|
|
ctx,
|
|
chatSnapshot.WorkspaceID.UUID,
|
|
)
|
|
if err != nil {
|
|
if xerrors.Is(err, errChatHasNoWorkspaceAgent) {
|
|
c.clearCachedWorkspaceState()
|
|
}
|
|
return chatSnapshot, uuid.Nil, err
|
|
}
|
|
|
|
latestChat, workspaceMatches := c.currentWorkspaceMatches(
|
|
chatSnapshot.WorkspaceID,
|
|
)
|
|
if !workspaceMatches {
|
|
continue
|
|
}
|
|
return latestChat, currentAgentID, nil
|
|
}
|
|
|
|
chatSnapshot := c.currentChatSnapshot()
|
|
return chatSnapshot, uuid.Nil, xerrors.New(
|
|
"chat workspace changed while resolving agent",
|
|
)
|
|
}
|
|
|
|
// getWorkspaceConnLocked returns the cached connection when it still matches
|
|
// the current workspace. When the workspace changed, it clears the stale
|
|
// cached state and returns the release func for the caller to run after
|
|
// unlocking.
|
|
func (c *turnWorkspaceContext) getWorkspaceConnLocked() (workspacesdk.AgentConn, func()) {
|
|
if c.conn == nil {
|
|
return nil, nil
|
|
}
|
|
|
|
chatSnapshot := c.currentChatSnapshot()
|
|
if nullUUIDEqual(c.cachedWorkspaceID, chatSnapshot.WorkspaceID) {
|
|
return c.conn, nil
|
|
}
|
|
|
|
agentRelease := c.releaseConn
|
|
c.agent = database.WorkspaceAgent{}
|
|
c.agentLoaded = false
|
|
c.conn = nil
|
|
c.releaseConn = nil
|
|
c.cachedWorkspaceID = uuid.NullUUID{}
|
|
return nil, agentRelease
|
|
}
|
|
|
|
// isAgentUnreachable reports whether the given agent row's
|
|
// status is disconnected or timed out. It uses timestamp
|
|
// arithmetic on the row. The "connecting" state is allowed
|
|
// through because it is normal after a fresh workspace build.
|
|
func isAgentUnreachable(now time.Time, agent database.WorkspaceAgent, inactiveTimeout time.Duration) bool {
|
|
status := agent.Status(now, inactiveTimeout)
|
|
return status.Status == database.WorkspaceAgentStatusDisconnected ||
|
|
status.Status == database.WorkspaceAgentStatusTimeout
|
|
}
|
|
|
|
func agentDisconnectedFor(now time.Time, agent database.WorkspaceAgent, inactiveTimeout time.Duration) (time.Duration, bool) {
|
|
status := agent.Status(now, inactiveTimeout)
|
|
if status.Status != database.WorkspaceAgentStatusDisconnected || status.DisconnectedAt == nil {
|
|
return 0, false
|
|
}
|
|
|
|
disconnectedFor := now.Sub(*status.DisconnectedAt)
|
|
if disconnectedFor < 0 {
|
|
disconnectedFor = 0
|
|
}
|
|
return disconnectedFor, true
|
|
}
|
|
|
|
func (c *turnWorkspaceContext) latestWorkspaceAgentNeedsRestart(
|
|
ctx context.Context,
|
|
workspaceID uuid.UUID,
|
|
) (bool, error) {
|
|
agentID, err := c.latestWorkspaceAgentID(ctx, workspaceID)
|
|
if err != nil {
|
|
if xerrors.Is(err, errChatHasNoWorkspaceAgent) {
|
|
return false, err
|
|
}
|
|
c.server.logger.Warn(ctx, "failed to resolve latest agent for timeout classification", slog.Error(err))
|
|
return false, nil
|
|
}
|
|
|
|
agent, err := c.server.db.GetWorkspaceAgentByID(ctx, agentID)
|
|
if err != nil {
|
|
c.server.logger.Warn(ctx, "failed to load latest agent for timeout classification",
|
|
slog.F("agent_id", agentID),
|
|
slog.Error(err),
|
|
)
|
|
return false, nil
|
|
}
|
|
|
|
disconnectedFor, disconnected := agentDisconnectedFor(c.server.clock.Now(), agent, c.server.agentInactiveDisconnectTimeout)
|
|
return disconnected && disconnectedFor >= agentDisconnectedRecoveryThreshold, nil
|
|
}
|
|
|
|
func (c *turnWorkspaceContext) externalAgentError(
|
|
ctx context.Context,
|
|
agent database.WorkspaceAgent,
|
|
fallback error,
|
|
) error {
|
|
isExternal, err := chattool.IsExternalWorkspaceAgent(ctx, c.server.db, agent)
|
|
if err != nil || !isExternal {
|
|
return fallback
|
|
}
|
|
return newChatExternalAgentUnavailableError(agent)
|
|
}
|
|
|
|
func (c *turnWorkspaceContext) externalAgentPreflightError(
|
|
ctx context.Context,
|
|
chatSnapshot database.Chat,
|
|
agent database.WorkspaceAgent,
|
|
) error {
|
|
// Mirror the cache-hit gate: only short-circuit on clearly offline
|
|
// states (Disconnected/Timeout). Connecting is allowed through so
|
|
// an external agent the user just started can still connect inside
|
|
// the normal dial window.
|
|
if !isAgentUnreachable(c.server.clock.Now(), agent, c.server.agentInactiveDisconnectTimeout) {
|
|
return nil
|
|
}
|
|
|
|
isExternal, err := chattool.IsExternalWorkspaceAgent(ctx, c.server.db, agent)
|
|
if err != nil || !isExternal || !chatSnapshot.WorkspaceID.Valid {
|
|
return nil
|
|
}
|
|
|
|
// Stale agent bindings rely on dialWithLazyValidation to discover
|
|
// replacement agents, so only skip the dial when this agent is still
|
|
// the latest selected chat agent for the workspace.
|
|
latestAgentID, err := c.latestWorkspaceAgentID(ctx, chatSnapshot.WorkspaceID.UUID)
|
|
if err != nil || latestAgentID != agent.ID {
|
|
return nil
|
|
}
|
|
return newChatExternalAgentUnavailableError(agent)
|
|
}
|
|
|
|
// getWorkspaceConn returns an agent connection for the chat's current
// workspace, reusing the cached connection when possible and dialing a
// new one otherwise. It fails immediately when the server has no
// agentConnFn configured.
//
// The body makes at most two attempts:
//
//   - Cache hit: the agent row is re-fetched; if it is disconnected the
//     cached state is cleared and the loop retries, otherwise workspace
//     usage is tracked and the cached connection is returned.
//   - Cache miss: any stale connection is released, the agent is
//     resolved via ensureWorkspaceAgent, the external-agent preflight
//     runs, and the agent is dialed under a timeout. If the dial
//     switched to a different agent, the new build/agent binding is
//     persisted and cached. The new connection is only kept when the
//     chat still points at the same workspace; otherwise it is released
//     and the loop retries.
//
// When both attempts are exhausted it returns a "chat workspace
// changed while connecting" error.
func (c *turnWorkspaceContext) getWorkspaceConn(ctx context.Context) (workspacesdk.AgentConn, error) {
	if c.server.agentConnFn == nil {
		return nil, xerrors.New("workspace agent connector is not configured")
	}

	for attempt := 0; attempt < 2; attempt++ {
		c.mu.Lock()
		currentConn, staleRelease := c.getWorkspaceConnLocked()
		// Capture agentID in the same lock section as
		// currentConn to prevent a TOCTOU race with
		// concurrent clearCachedWorkspaceState calls.
		agentID := c.agent.ID
		c.mu.Unlock()

		// Status check on cache hit: re-fetch the agent
		// row so we see the latest heartbeat rather than
		// a potentially stale cached copy.
		if currentConn != nil {
			chatSnapshot := c.currentChatSnapshot()
			if agentID != uuid.Nil {
				freshAgent, err := c.server.db.GetWorkspaceAgentByID(ctx, agentID)
				if err != nil {
					c.server.logger.Warn(ctx, "failed to re-fetch agent for status check",
						slog.F("agent_id", agentID),
						slog.Error(err),
					)
					// On DB error the check re-runs on the
					// next tool call.
				} else if _, disconnected := agentDisconnectedFor(
					c.server.clock.Now(),
					freshAgent,
					c.server.agentInactiveDisconnectTimeout,
				); disconnected {
					// Drop the cached connection and fall through to
					// a fresh dial on the next loop iteration.
					c.clearCachedWorkspaceState()
					continue
				}
			}
			c.trackWorkspaceUsage(ctx, chatSnapshot)
			return currentConn, nil
		}
		// getWorkspaceConnLocked handed back the release func of a
		// connection cached for a previous workspace; run it now that
		// c.mu is no longer held.
		if staleRelease != nil {
			staleRelease()
		}

		chatSnapshot, agent, err := c.ensureWorkspaceAgent(ctx)
		if err != nil {
			return nil, err
		}
		if err := c.externalAgentPreflightError(ctx, chatSnapshot, agent); err != nil {
			return nil, err
		}

		// Wrap the dial in a timeout to bound the time spent
		// waiting for an unreachable agent. The timeout scopes
		// only dialWithLazyValidation, not ensureWorkspaceAgent
		// or the post-dial binding steps.
		dialCtx, dialCancel := context.WithTimeoutCause(ctx, c.server.dialTimeout, errChatDialTimeout)
		dialResult, err := dialWithLazyValidation(
			dialCtx,
			agent.ID,
			chatSnapshot.WorkspaceID.UUID,
			DialFunc(c.server.agentConnFn),
			func(ctx context.Context, workspaceID uuid.UUID) (uuid.UUID, error) {
				return c.latestWorkspaceAgentID(ctx, workspaceID)
			},
			workspaceDialValidationDelay,
		)
		// Cancel right away; context.Cause(dialCtx) below still reports
		// errChatDialTimeout if the deadline fired before this call.
		dialCancel()
		if err != nil {
			if xerrors.Is(err, errChatHasNoWorkspaceAgent) {
				c.clearCachedWorkspaceState()
				return nil, err
			}
			// Surface the dial timeout sentinel only when the
			// parent context is still alive. If the parent was
			// canceled (e.g. ErrInterrupted), its error must
			// propagate unchanged so the chatloop can detect it.
			if ctx.Err() == nil && errors.Is(context.Cause(dialCtx), errChatDialTimeout) {
				c.clearCachedWorkspaceState()
				needsRestart, statusErr := c.latestWorkspaceAgentNeedsRestart(ctx, chatSnapshot.WorkspaceID.UUID)
				if statusErr != nil {
					return nil, statusErr
				}
				// A long-disconnected agent is reported as disconnected
				// rather than as a plain dial timeout.
				if needsRestart {
					return nil, c.externalAgentError(ctx, agent, errChatAgentDisconnected)
				}
				return nil, c.externalAgentError(ctx, agent, errChatDialTimeout)
			}
			return nil, err
		}
		agentConn := dialResult.Conn
		agentRelease := dialResult.Release
		// From here on every early exit must release the freshly
		// dialed connection via agentRelease.
		if dialResult.WasSwitched {
			build, err := c.server.db.GetLatestWorkspaceBuildByWorkspaceID(ctx, chatSnapshot.WorkspaceID.UUID)
			if err != nil {
				if agentRelease != nil {
					agentRelease()
				}
				return nil, xerrors.Errorf("get latest workspace build: %w", err)
			}

			switchedAgent, err := c.server.db.GetWorkspaceAgentByID(ctx, dialResult.AgentID)
			if err != nil {
				if agentRelease != nil {
					agentRelease()
				}
				return nil, xerrors.Errorf("get workspace agent by id: %w", err)
			}

			updatedChat, err := c.persistBuildAgentBinding(
				ctx,
				chatSnapshot,
				build.ID,
				switchedAgent.ID,
			)
			if err != nil {
				if agentRelease != nil {
					agentRelease()
				}
				return nil, err
			}
			chatSnapshot = updatedChat

			// Cache the agent the dial actually connected to.
			c.mu.Lock()
			c.agent = switchedAgent
			c.agentLoaded = true
			c.cachedWorkspaceID = chatSnapshot.WorkspaceID
			c.mu.Unlock()
		}

		// The chat may have moved to another workspace while we were
		// dialing; if so, discard this connection and retry.
		if _, workspaceMatches := c.currentWorkspaceMatches(chatSnapshot.WorkspaceID); !workspaceMatches {
			if agentRelease != nil {
				agentRelease()
			}
			c.clearCachedWorkspaceState()
			continue
		}

		c.mu.Lock()
		if c.conn == nil {
			c.conn = agentConn
			c.releaseConn = agentRelease
			c.cachedWorkspaceID = chatSnapshot.WorkspaceID

			// NOTE(review): ancestorIDs stays nil for chats without a
			// parent, and encoding/json marshals a nil slice as "null",
			// not "[]" (the value used on marshal failure below).
			// Confirm receivers of this header accept both forms.
			var ancestorIDs []string
			if chatSnapshot.ParentChatID.Valid {
				ancestorIDs = append(ancestorIDs, chatSnapshot.ParentChatID.UUID.String())
			}
			ancestorJSON, marshalErr := json.Marshal(ancestorIDs)
			if marshalErr != nil {
				ancestorJSON = []byte("[]")
			}
			agentConn.SetExtraHeaders(http.Header{
				workspacesdk.CoderChatIDHeader:          {chatSnapshot.ID.String()},
				workspacesdk.CoderAncestorChatIDsHeader: {string(ancestorJSON)},
			})

			c.mu.Unlock()
			c.server.logger.Debug(ctx, "set chat headers on agent conn",
				slog.F("chat_id", chatSnapshot.ID),
				slog.F("ancestor_chat_ids", ancestorIDs),
				slog.F("workspace_id", chatSnapshot.WorkspaceID.UUID),
				slog.F("agent_id", dialResult.AgentID),
			)
			c.trackWorkspaceUsage(ctx, chatSnapshot)
			return agentConn, nil
		}
		// A connection was cached while this dial was in flight: keep
		// the cached one and release ours.
		currentConn = c.conn
		c.mu.Unlock()

		if agentRelease != nil {
			agentRelease()
		}
		c.trackWorkspaceUsage(ctx, chatSnapshot)
		return currentConn, nil
	}

	return nil, xerrors.New("chat workspace changed while connecting")
}
|
|
|
|
// AgentConnFunc provides access to workspace agent connections. Given
// an agent ID it returns the connection, a func() value, and an error.
// NOTE(review): the func() is presumably the release callback the
// caller runs when it is done with the connection — confirm against
// the implementations.
type AgentConnFunc func(ctx context.Context, agentID uuid.UUID) (workspacesdk.AgentConn, func(), error)
|
|
|
|
// SubscribeFn replaces the default local-only subscription with a
// multi-replica-aware implementation that merges pubsub notifications,
// remote relay streams, and local parts into a single event channel.
// When set, Subscribe delegates the event-merge goroutine to this
// function instead of using simple local forwarding.
//
// Parameters:
//   - ctx: subscription lifetime context (canceled on unsubscribe).
//   - params: all state needed to build the merged stream.
//
// Returns the merged event channel. Cleanup is driven by ctx
// cancellation — the merge goroutine tears down all relay state
// in its defer when ctx is done.
//
// Set by enterprise for HA deployments. Nil in AGPL single-replica.
type SubscribeFn func(
	ctx context.Context,
	params SubscribeFnParams,
) <-chan codersdk.ChatStreamEvent
|
|
|
|
// StatusNotification informs the enterprise relay manager of chat
// status changes so it can open or close relay connections.
//
// Status is the chat status being reported. WorkerID presumably
// identifies the worker associated with the chat at the time of the
// change — confirm against the code that publishes these.
type StatusNotification struct {
	Status   database.ChatStatus
	WorkerID uuid.UUID
}
|
|
|
|
// SubscribeFnParams carries the state that the enterprise
// SubscribeFn implementation needs from the OSS Subscribe preamble.
//
// StatusNotifications is a receive-only channel of StatusNotification
// values. The remaining fields identify the chat (ChatID, Chat), a
// worker (WorkerID), and hand over the request headers, database
// store and logger for the implementation to use.
// NOTE(review): exact population of each field happens in Subscribe,
// which is not visible alongside this declaration — confirm there.
type SubscribeFnParams struct {
	ChatID              uuid.UUID
	Chat                database.Chat
	WorkerID            uuid.UUID
	StatusNotifications <-chan StatusNotification
	RequestHeader       http.Header
	DB                  database.Store
	Logger              slog.Logger
}
|
|
|
|
// bufferedStreamPart is a buffered message_part event with its
// committed-message linkage. Parts that have not yet been claimed by
// a durable assistant message carry committedMessageID == 0 and are
// considered "in progress"; when an assistant message is published
// every still-in-progress part is claimed by that durable message
// ID, marking the part as redundant for any subscriber that will
// receive the durable message via REST or pubsub.
type bufferedStreamPart struct {
	// event is the buffered message_part stream event itself.
	event codersdk.ChatStreamEvent
	// committedMessageID is the durable assistant message ID that
	// claimed this part, or 0 while the part belongs to the
	// in-progress turn. snapshotBufferLocked drops parts with
	// committedMessageID != 0 because the subscriber will receive
	// the durable message through a different channel (REST snapshot,
	// initial DB query in SubscribeAuthorized, or pubsub).
	committedMessageID int64
}
|
|
|
|
// chatStreamState is the per-chat stream state held as a value in the
// server's chatStreams map. It bundles the buffered in-progress parts,
// cached durable messages, the subscriber channels keyed by UUID, and
// drop/warn bookkeeping for both the buffer and the subscribers.
//
// The metrics collector takes mu while reading buffer and subscribers.
// NOTE(review): mu presumably guards every other field as well —
// confirm against the methods that mutate this struct.
type chatStreamState struct {
	mu                   sync.Mutex
	buffer               []bufferedStreamPart
	buffering            bool
	durableMessages      []codersdk.ChatStreamEvent
	durableEvictedBefore int64 // highest message ID evicted from durable cache
	subscribers          map[uuid.UUID]chan codersdk.ChatStreamEvent
	bufferDropCount      int64
	bufferLastWarnAt     time.Time
	subscriberDropCount  int64
	subscriberLastWarnAt time.Time
	// currentRetry records the current retry phase for late-joining
	// same-replica subscribers. Nil when the stream is not waiting
	// to retry.
	currentRetry *codersdk.ChatStreamRetry
	// bufferRetainedAt records when processing completed and
	// the per-chat stream state entered the post-completion
	// grace window. Zero while buffering is active. When
	// non-zero, cleanupStreamIfIdle skips GC until the grace
	// period expires so cross-replica relay subscribers can
	// register without racing state deletion. The buffer
	// itself does not deliver content here: every part is
	// claimed by a durable assistant message before
	// bufferRetainedAt is set, so snapshotBufferLocked
	// returns no parts during the grace window.
	bufferRetainedAt time.Time
}
|
|
|
|
// streamStateCollector exposes scrape-time gauges derived from
// p.chatStreams. Scrape cost is O(n) with a brief per-state mutex
// held for two len() reads; acceptable at typical scrape cadences.
type streamStateCollector struct {
	// server is the Server whose chatStreams map is walked on each
	// Collect call.
	server *Server
}
|
|
|
|
// Descriptors for the gauges emitted by streamStateCollector. None of
// them carries variable or constant labels.
var (
	// streamsActiveDesc counts chat stream state entries.
	streamsActiveDesc = prometheus.NewDesc(
		"coderd_chatd_streams_active",
		"Current number of chat stream state entries (in-flight plus retained).",
		nil, nil,
	)
	// streamBufferSizeMaxDesc is the largest single-stream buffer length.
	streamBufferSizeMaxDesc = prometheus.NewDesc(
		"coderd_chatd_stream_buffer_size_max",
		"Maximum current buffer length across all chat streams.",
		nil, nil,
	)
	// streamBufferEventsDesc is the total of all stream buffer lengths.
	streamBufferEventsDesc = prometheus.NewDesc(
		"coderd_chatd_stream_buffer_events",
		"Sum of current buffer lengths across all chat streams.",
		nil, nil,
	)
	// streamSubscribersDesc is the total subscriber count.
	streamSubscribersDesc = prometheus.NewDesc(
		"coderd_chatd_stream_subscribers",
		"Current number of chat stream subscribers across all chat streams.",
		nil, nil,
	)
)
|
|
|
|
func (*streamStateCollector) Describe(ch chan<- *prometheus.Desc) {
|
|
ch <- streamsActiveDesc
|
|
ch <- streamBufferSizeMaxDesc
|
|
ch <- streamBufferEventsDesc
|
|
ch <- streamSubscribersDesc
|
|
}
|
|
|
|
func (c *streamStateCollector) Collect(ch chan<- prometheus.Metric) {
|
|
var active, totalEvents, maxBufLen, totalSubs int
|
|
c.server.chatStreams.Range(func(_, v any) bool {
|
|
state, ok := v.(*chatStreamState)
|
|
if !ok {
|
|
return true
|
|
}
|
|
active++
|
|
state.mu.Lock()
|
|
bufLen := len(state.buffer)
|
|
subs := len(state.subscribers)
|
|
state.mu.Unlock()
|
|
totalEvents += bufLen
|
|
totalSubs += subs
|
|
maxBufLen = max(maxBufLen, bufLen)
|
|
return true
|
|
})
|
|
ch <- prometheus.MustNewConstMetric(streamsActiveDesc, prometheus.GaugeValue, float64(active))
|
|
ch <- prometheus.MustNewConstMetric(streamBufferSizeMaxDesc, prometheus.GaugeValue, float64(maxBufLen))
|
|
ch <- prometheus.MustNewConstMetric(streamBufferEventsDesc, prometheus.GaugeValue, float64(totalEvents))
|
|
ch <- prometheus.MustNewConstMetric(streamSubscribersDesc, prometheus.GaugeValue, float64(totalSubs))
|
|
}
|
|
|
|
// Sentinel errors returned by the chat mutation entry points. Callers
// can match them with errors.Is, including when they are wrapped.
var (
	// ErrInvalidModelConfigID indicates the requested model config does not exist.
	ErrInvalidModelConfigID = xerrors.New("invalid model config ID")
	// ErrEditedMessageNotFound indicates the edited message does not exist
	// in the target chat.
	ErrEditedMessageNotFound = xerrors.New("edited message not found")
	// ErrEditedMessageNotUser indicates a non-user message edit attempt.
	ErrEditedMessageNotUser = xerrors.New("only user messages can be edited")
	// ErrChatArchived indicates the chat is archived and cannot
	// accept modifications (messages, edits, promotions, or
	// tool-result submissions).
	ErrChatArchived = xerrors.New("chat is archived")
)
|
|
|
|
// UsageLimitExceededError indicates the user has exceeded their chat spend
// limit.
//
// LimitMicros and ConsumedMicros are amounts in millionths of a dollar
// (the Error method renders them by shifting six decimal places and
// prefixing "$"). PeriodEnd is rendered by Error as the time the limit
// resets.
type UsageLimitExceededError struct {
	LimitMicros    int64
	ConsumedMicros int64
	PeriodEnd      time.Time
}
|
|
|
|
func formatMicrosAsDollars(micros int64) string {
|
|
return "$" + decimal.NewFromInt(micros).Shift(-6).StringFixed(2)
|
|
}
|
|
|
|
func (e *UsageLimitExceededError) Error() string {
|
|
return fmt.Sprintf(
|
|
"usage limit exceeded: spent %s of %s limit, resets at %s",
|
|
formatMicrosAsDollars(e.ConsumedMicros),
|
|
formatMicrosAsDollars(e.LimitMicros),
|
|
e.PeriodEnd.Format(time.RFC3339),
|
|
)
|
|
}
|
|
|
|
// CreateOptions controls chat creation in the shared chat mutation path.
//
// As enforced by CreateChat:
//   - OrganizationID and OwnerID must be non-nil UUIDs.
//   - Title must contain non-whitespace text.
//   - InitialUserContent must be non-empty; it becomes the first user
//     message, attributed to OwnerID and APIKeyID.
//   - ClientType defaults to database.ChatClientTypeApi when left at
//     its zero value and must otherwise be a valid client type.
//   - MCPServerIDs and Labels are replaced by empty values when nil.
//   - SystemPrompt is sanitized and, when non-empty afterwards, stored
//     as a system message.
//   - WorkspaceID.Valid selects which workspace-awareness system
//     message is stored.
//   - DynamicTools is stored as NULL when empty.
//
// The remaining fields are passed through to chatstate.CreateChat.
type CreateOptions struct {
	OrganizationID     uuid.UUID
	OwnerID            uuid.UUID
	WorkspaceID        uuid.NullUUID
	BuildID            uuid.NullUUID
	AgentID            uuid.NullUUID
	ParentChatID       uuid.NullUUID
	RootChatID         uuid.NullUUID
	Title              string
	ModelConfigID      uuid.UUID
	ChatMode           database.NullChatMode
	PlanMode           database.NullChatPlanMode
	ClientType         database.ChatClientType
	SystemPrompt       string
	InitialUserContent []codersdk.ChatMessagePart
	APIKeyID           string
	MCPServerIDs       []uuid.UUID
	Labels             database.StringMap
	DynamicTools       json.RawMessage
}
|
|
|
|
// SendMessageBusyBehavior controls what happens when a chat is already active.
// SendMessage treats the empty string as SendMessageBusyBehaviorQueue
// and rejects any value other than the constants below.
type SendMessageBusyBehavior string

const (
	// SendMessageBusyBehaviorQueue queues user messages while the chat is busy.
	SendMessageBusyBehaviorQueue SendMessageBusyBehavior = "queue"
	// SendMessageBusyBehaviorInterrupt queues the message and
	// interrupts the active run. The queued message is
	// auto-promoted after the interrupted assistant response is
	// persisted, ensuring correct message ordering.
	SendMessageBusyBehaviorInterrupt SendMessageBusyBehavior = "interrupt"
)
|
|
|
|
// SendMessageOptions controls user message insertion with busy-state behavior.
//
// As enforced by SendMessage:
//   - ChatID must be a non-nil UUID and Content must be non-empty.
//   - ModelConfigID, when uuid.Nil, falls back to the chat's last model
//     config (or the default config); otherwise it must exist.
//   - BusyBehavior defaults to SendMessageBusyBehaviorQueue when empty.
//   - PlanMode, when non-nil, replaces the chat's plan mode.
//   - MCPServerIDs, when non-nil, replaces the chat's MCP server IDs,
//     except for explore subagent chats where the update is ignored.
//   - CreatedBy and APIKeyID are attached to the inserted user message.
type SendMessageOptions struct {
	ChatID        uuid.UUID
	CreatedBy     uuid.UUID
	Content       []codersdk.ChatMessagePart
	ModelConfigID uuid.UUID
	APIKeyID      string
	BusyBehavior  SendMessageBusyBehavior
	PlanMode      *database.NullChatPlanMode
	MCPServerIDs  *[]uuid.UUID
}
|
|
|
|
// SendMessageResult contains the outcome of user message processing.
//
// When the message was queued, Queued is true and QueuedMessage is
// set; otherwise Message holds the inserted user message. Chat is the
// chat row reloaded after the transition, inside the same transaction.
type SendMessageResult struct {
	Queued        bool
	QueuedMessage *database.ChatQueuedMessage
	Message       database.ChatMessage
	Chat          database.Chat
}
|
|
|
|
// EditMessageOptions controls user message edits via soft-delete and re-insert.
// EditMessage requires a non-nil ChatID, a positive EditedMessageID,
// and non-empty Content.
type EditMessageOptions struct {
	ChatID          uuid.UUID
	CreatedBy       uuid.UUID
	EditedMessageID int64
	Content         []codersdk.ChatMessagePart
	APIKeyID        string
	// ModelConfigID, when non-zero, overrides the model used for
	// the replacement user message. When set to uuid.Nil the
	// original message's model is preserved.
	ModelConfigID uuid.UUID
}
|
|
|
|
// EditMessageResult contains the replacement user message and chat status.
// Message is the replacement user message; Chat is the chat row that
// carries the resulting status.
type EditMessageResult struct {
	Message database.ChatMessage
	Chat    database.Chat
}
|
|
|
|
// PromoteQueuedOptions controls queued-message promotion.
// QueuedMessageID names the queued message to promote within ChatID.
// NOTE(review): CreatedBy is presumably the user performing the
// promotion — confirm against the promotion entry point.
type PromoteQueuedOptions struct {
	ChatID          uuid.UUID
	CreatedBy       uuid.UUID
	QueuedMessageID int64
}
|
|
|
|
// PromoteQueuedResult contains post-promotion message metadata.
type PromoteQueuedResult struct {
	// PromotedMessage is the inserted user message. For a chat that
	// was running at promote time, the insertion is deferred to the
	// worker's auto-promote and PromotedMessage is the zero value,
	// so callers must not assume it is populated.
	PromotedMessage database.ChatMessage
}
|
|
|
|
// CreateChat creates a chat with its initial history through
|
|
// chatstate.CreateChat. The new chat starts in `running` status per
|
|
// the RFC. Ownership hints wake chat workers.
|
|
func (p *Server) CreateChat(ctx context.Context, opts CreateOptions) (database.Chat, error) {
|
|
if opts.OrganizationID == uuid.Nil {
|
|
return database.Chat{}, xerrors.New("organization_id is required")
|
|
}
|
|
if opts.OwnerID == uuid.Nil {
|
|
return database.Chat{}, xerrors.New("owner_id is required")
|
|
}
|
|
if strings.TrimSpace(opts.Title) == "" {
|
|
return database.Chat{}, xerrors.New("title is required")
|
|
}
|
|
if len(opts.InitialUserContent) == 0 {
|
|
return database.Chat{}, xerrors.New("initial user content is required")
|
|
}
|
|
// Ensure MCPServerIDs is non-nil so pq.Array produces '{}'
|
|
// instead of SQL NULL, which violates the NOT NULL column
|
|
// constraint.
|
|
if opts.MCPServerIDs == nil {
|
|
opts.MCPServerIDs = []uuid.UUID{}
|
|
}
|
|
if opts.Labels == nil {
|
|
opts.Labels = database.StringMap{}
|
|
}
|
|
opts.ClientType = cmp.Or(opts.ClientType, database.ChatClientTypeApi)
|
|
if !opts.ClientType.Valid() {
|
|
return database.Chat{}, xerrors.Errorf("invalid client_type: %q", opts.ClientType)
|
|
}
|
|
// Resolve the deployment prompt before opening the transaction so
|
|
// chat creation does not hold one DB connection while waiting for
|
|
// another pool checkout.
|
|
deploymentPrompt := p.resolveDeploymentSystemPrompt(ctx)
|
|
|
|
// Usage limits gate the create before we touch the state machine.
|
|
if limitErr := p.checkUsageLimit(ctx, p.db, opts.OwnerID, uuid.NullUUID{UUID: opts.OrganizationID, Valid: true}); limitErr != nil {
|
|
return database.Chat{}, limitErr
|
|
}
|
|
|
|
labelsJSON, err := json.Marshal(opts.Labels)
|
|
if err != nil {
|
|
return database.Chat{}, xerrors.Errorf("marshal labels: %w", err)
|
|
}
|
|
|
|
userPrompt := SanitizePromptText(opts.SystemPrompt)
|
|
var workspaceAwareness string
|
|
if opts.WorkspaceID.Valid {
|
|
workspaceAwareness = workspaceAttachedAwareness
|
|
} else {
|
|
workspaceAwareness = workspaceDetachedAwareness
|
|
}
|
|
workspaceAwarenessContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{
|
|
codersdk.ChatMessageText(workspaceAwareness),
|
|
})
|
|
if err != nil {
|
|
return database.Chat{}, xerrors.Errorf("marshal workspace awareness: %w", err)
|
|
}
|
|
userContent, err := chatprompt.MarshalParts(opts.InitialUserContent)
|
|
if err != nil {
|
|
return database.Chat{}, xerrors.Errorf("marshal initial user content: %w", err)
|
|
}
|
|
|
|
initialMessages := make([]chatstate.Message, 0, 4)
|
|
if deploymentPrompt != "" {
|
|
deploymentContent, marshalErr := chatprompt.MarshalParts([]codersdk.ChatMessagePart{
|
|
codersdk.ChatMessageText(deploymentPrompt),
|
|
})
|
|
if marshalErr != nil {
|
|
return database.Chat{}, xerrors.Errorf("marshal deployment system prompt: %w", marshalErr)
|
|
}
|
|
initialMessages = append(initialMessages, systemMessage(deploymentContent, opts.ModelConfigID))
|
|
}
|
|
if userPrompt != "" {
|
|
userPromptContent, marshalErr := chatprompt.MarshalParts([]codersdk.ChatMessagePart{
|
|
codersdk.ChatMessageText(userPrompt),
|
|
})
|
|
if marshalErr != nil {
|
|
return database.Chat{}, xerrors.Errorf("marshal user system prompt: %w", marshalErr)
|
|
}
|
|
initialMessages = append(initialMessages, systemMessage(userPromptContent, opts.ModelConfigID))
|
|
}
|
|
initialMessages = append(initialMessages, systemMessage(workspaceAwarenessContent, opts.ModelConfigID))
|
|
initialMessages = append(initialMessages, userMessageWithAPIKeyID(userContent, opts.ModelConfigID, opts.OwnerID, opts.APIKeyID))
|
|
|
|
result, err := chatstate.CreateChat(ctx, p.db, p.pubsub, chatstate.CreateChatInput{
|
|
OrganizationID: opts.OrganizationID,
|
|
OwnerID: opts.OwnerID,
|
|
WorkspaceID: opts.WorkspaceID,
|
|
BuildID: opts.BuildID,
|
|
AgentID: opts.AgentID,
|
|
ParentChatID: opts.ParentChatID,
|
|
RootChatID: opts.RootChatID,
|
|
LastModelConfigID: opts.ModelConfigID,
|
|
Title: opts.Title,
|
|
Mode: opts.ChatMode,
|
|
PlanMode: opts.PlanMode,
|
|
MCPServerIDs: opts.MCPServerIDs,
|
|
Labels: pqtype.NullRawMessage{
|
|
RawMessage: labelsJSON,
|
|
Valid: true,
|
|
},
|
|
DynamicTools: pqtype.NullRawMessage{
|
|
RawMessage: opts.DynamicTools,
|
|
Valid: len(opts.DynamicTools) > 0,
|
|
},
|
|
ClientType: opts.ClientType,
|
|
InitialMessages: initialMessages,
|
|
})
|
|
if err != nil {
|
|
return database.Chat{}, err
|
|
}
|
|
chat := result.Chat
|
|
if !chat.RootChatID.Valid && !chat.ParentChatID.Valid {
|
|
chat.RootChatID = uuid.NullUUID{UUID: chat.ID, Valid: true}
|
|
}
|
|
|
|
// Publish the sidebar watch event explicitly after chatstate has
|
|
// committed and emitted its own state-machine notifications. The
|
|
// watch endpoint is intentionally outside the RFC refactor scope.
|
|
p.publishChatPubsubEvent(chat, codersdk.ChatWatchEventKindCreated, nil)
|
|
return chat, nil
|
|
}
|
|
|
|
// SendMessage admits a user message through the chatstate.SendMessage
// transition. Pre-transition admission policy (usage limit, plan-mode
// metadata update, MCP server ID update, model-config resolution, queue
// cap) runs inside the same chatstate transaction via tx.Store() so
// everything commits or rolls back together.
//
// It rejects a nil ChatID, empty Content, and unknown BusyBehavior
// values before opening the transaction; an empty BusyBehavior means
// queue. Inside the transaction it returns ErrChatArchived for
// archived chats and the usage-limit error when the owner is over
// their limit. On success it publishes a status-change watch event
// and returns the queued or inserted message together with the
// reloaded chat.
func (p *Server) SendMessage(
	ctx context.Context,
	opts SendMessageOptions,
) (SendMessageResult, error) {
	if opts.ChatID == uuid.Nil {
		return SendMessageResult{}, xerrors.New("chat_id is required")
	}
	if len(opts.Content) == 0 {
		return SendMessageResult{}, xerrors.New("content is required")
	}

	// Default to queueing, then reject anything that is not a known
	// behavior.
	busyBehavior := opts.BusyBehavior
	if busyBehavior == "" {
		busyBehavior = SendMessageBusyBehaviorQueue
	}
	switch busyBehavior {
	case SendMessageBusyBehaviorQueue, SendMessageBusyBehaviorInterrupt:
	default:
		return SendMessageResult{}, xerrors.Errorf("invalid busy behavior %q", opts.BusyBehavior)
	}

	content, err := chatprompt.MarshalParts(opts.Content)
	if err != nil {
		return SendMessageResult{}, xerrors.Errorf("marshal message content: %w", err)
	}

	requestedPlanMode := opts.PlanMode
	requestedMCPServerIDs := opts.MCPServerIDs

	var result SendMessageResult
	machine := p.newChatMachine(opts.ChatID)
	updateErr := machine.Update(ctx, func(tx *chatstate.Tx) error {
		store := tx.Store()
		lockedChat, err := store.GetChatByID(ctx, opts.ChatID)
		if err != nil {
			return xerrors.Errorf("load chat: %w", err)
		}

		if lockedChat.Archived {
			return ErrChatArchived
		}

		// Enforce usage limits before any state-machine work.
		if limitErr := p.checkUsageLimit(ctx, store, lockedChat.OwnerID, uuid.NullUUID{UUID: lockedChat.OrganizationID, Valid: true}); limitErr != nil {
			return limitErr
		}

		// lockedChat is reassigned so the model-config resolution
		// below sees the row as returned by the plan-mode update.
		if requestedPlanMode != nil {
			lockedChat, err = store.UpdateChatPlanModeByID(ctx, database.UpdateChatPlanModeByIDParams{
				PlanMode: *requestedPlanMode,
				ID:       opts.ChatID,
			})
			if err != nil {
				return xerrors.Errorf("update chat plan mode: %w", err)
			}
		}

		modelConfigID, err := resolveSendMessageModelConfigID(
			ctx,
			store,
			lockedChat,
			opts.ModelConfigID,
		)
		if err != nil {
			return err
		}

		// Update MCP server IDs on the chat when explicitly provided.
		// Explore child chats keep the spawn-time snapshot immutable.
		if requestedMCPServerIDs != nil {
			if isExploreSubagentMode(lockedChat.Mode) {
				p.logger.Warn(ctx,
					"ignoring explore subagent mcp server ids update, snapshot is immutable after spawn",
					slog.F("chat_id", opts.ChatID),
				)
			} else {
				lockedChat, err = store.UpdateChatMCPServerIDs(ctx, database.UpdateChatMCPServerIDsParams{
					ID:           opts.ChatID,
					MCPServerIDs: *requestedMCPServerIDs,
				})
				if err != nil {
					return xerrors.Errorf("update chat mcp server ids: %w", err)
				}
			}
		}

		// Queue capacity is enforced inside tx.SendMessage; this
		// wrapper only propagates the typed error.
		sendResult, err := tx.SendMessage(chatstate.SendMessageInput{
			Message:      userMessageWithAPIKeyID(content, modelConfigID, opts.CreatedBy, opts.APIKeyID),
			BusyBehavior: busyBehaviorToChatState(busyBehavior),
		})
		if err != nil {
			return err
		}

		if sendResult.QueuedMessage != nil {
			result.Queued = true
			result.QueuedMessage = sendResult.QueuedMessage
		} else if len(sendResult.InsertedMessages) > 0 {
			// The state machine prepends synthetic tool-result
			// cancellation messages; the user message is always
			// last in the inserted slice.
			result.Message = sendResult.InsertedMessages[len(sendResult.InsertedMessages)-1]
		}
		// Capture the post-transition chat inside the same
		// transaction so the returned chat and the watch event
		// reflect the snapshot bump and status change produced by
		// the transition itself.
		refreshed, err := store.GetChatByID(ctx, opts.ChatID)
		if err != nil {
			return xerrors.Errorf("reload chat after send: %w", err)
		}
		result.Chat = refreshed
		return nil
	})
	if updateErr != nil {
		// Discard any partially filled result when the update failed.
		return SendMessageResult{}, updateErr
	}

	// Sidebar watch event keeps the chat list in sync. Stream side
	// effects are handled by chat:update consumers.
	p.publishChatPubsubEvent(result.Chat, codersdk.ChatWatchEventKindStatusChange, nil)
	return result, nil
}
|
|
|
|
func (p *Server) checkUsageLimit(ctx context.Context, store database.Store, ownerID uuid.UUID, organizationID uuid.NullUUID) error {
|
|
status, err := ResolveUsageLimitStatus(ctx, store, ownerID, organizationID, time.Now())
|
|
if err != nil {
|
|
// Fail open: never block chat due to a limit-resolution failure.
|
|
p.logger.Warn(ctx, "usage limit check failed, allowing message",
|
|
slog.F("owner_id", ownerID),
|
|
slog.Error(err),
|
|
)
|
|
return nil
|
|
}
|
|
if status == nil {
|
|
return nil
|
|
}
|
|
// Block when current spend reaches or exceeds limit (>= ensures
|
|
// the user cannot start new conversations once the limit is hit).
|
|
if status.SpendLimitMicros != nil && status.CurrentSpend >= *status.SpendLimitMicros {
|
|
return &UsageLimitExceededError{
|
|
LimitMicros: *status.SpendLimitMicros,
|
|
ConsumedMicros: status.CurrentSpend,
|
|
PeriodEnd: status.PeriodEnd,
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// chatdModelConfigLookupContext returns ctx wrapped by dbauthz.AsChatd.
// The model-config resolvers in this file pass the result to their
// model config store lookups.
func chatdModelConfigLookupContext(ctx context.Context) context.Context {
	//nolint:gocritic // Chat message admission needs daemon-scoped
	// deployment-config reads for model config validation.
	return dbauthz.AsChatd(ctx)
}
|
|
|
|
func resolveSendMessageModelConfigID(
|
|
ctx context.Context,
|
|
store database.Store,
|
|
chat database.Chat,
|
|
requested uuid.UUID,
|
|
) (uuid.UUID, error) {
|
|
if requested == uuid.Nil {
|
|
return resolveFallbackModelConfigID(ctx, store, chat.LastModelConfigID)
|
|
}
|
|
|
|
chatdCtx := chatdModelConfigLookupContext(ctx)
|
|
if _, err := store.GetChatModelConfigByID(chatdCtx, requested); err != nil {
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return uuid.Nil, xerrors.Errorf(
|
|
"%w: %s",
|
|
ErrInvalidModelConfigID,
|
|
requested,
|
|
)
|
|
}
|
|
return uuid.Nil, xerrors.Errorf(
|
|
"get requested model config %s: %w",
|
|
requested,
|
|
err,
|
|
)
|
|
}
|
|
return requested, nil
|
|
}
|
|
|
|
func resolveFallbackModelConfigID(
|
|
ctx context.Context,
|
|
store database.Store,
|
|
modelConfigID uuid.UUID,
|
|
) (uuid.UUID, error) {
|
|
chatdCtx := chatdModelConfigLookupContext(ctx)
|
|
if modelConfigID != uuid.Nil {
|
|
if _, err := store.GetChatModelConfigByID(chatdCtx, modelConfigID); err == nil {
|
|
return modelConfigID, nil
|
|
} else if !errors.Is(err, sql.ErrNoRows) {
|
|
return uuid.Nil, xerrors.Errorf(
|
|
"get chat model config %s: %w",
|
|
modelConfigID,
|
|
err,
|
|
)
|
|
}
|
|
}
|
|
|
|
defaultConfig, err := store.GetDefaultChatModelConfig(chatdCtx)
|
|
if err != nil {
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return uuid.Nil, xerrors.New("no default chat model config is available")
|
|
}
|
|
return uuid.Nil, xerrors.Errorf("get default chat model config: %w", err)
|
|
}
|
|
return defaultConfig.ID, nil
|
|
}
|
|
|
|
// EditMessage replaces an earlier user message and discards the
|
|
// active-history suffix through chatstate.EditMessage. Model-config
|
|
// override validation and usage-limit admission run in the same
|
|
// transaction as the state-machine transition.
|
|
func (p *Server) EditMessage(
|
|
ctx context.Context,
|
|
opts EditMessageOptions,
|
|
) (EditMessageResult, error) {
|
|
if opts.ChatID == uuid.Nil {
|
|
return EditMessageResult{}, xerrors.New("chat_id is required")
|
|
}
|
|
if opts.EditedMessageID <= 0 {
|
|
return EditMessageResult{}, xerrors.New("edited_message_id is required")
|
|
}
|
|
if len(opts.Content) == 0 {
|
|
return EditMessageResult{}, xerrors.New("content is required")
|
|
}
|
|
|
|
content, err := chatprompt.MarshalParts(opts.Content)
|
|
if err != nil {
|
|
return EditMessageResult{}, xerrors.Errorf("marshal message content: %w", err)
|
|
}
|
|
|
|
var (
|
|
result EditMessageResult
|
|
editedMsg database.ChatMessage
|
|
editedCutoffT time.Time
|
|
)
|
|
machine := p.newChatMachine(opts.ChatID)
|
|
updateErr := machine.Update(ctx, func(tx *chatstate.Tx) error {
|
|
store := tx.Store()
|
|
lockedChat, err := store.GetChatByID(ctx, opts.ChatID)
|
|
if err != nil {
|
|
return xerrors.Errorf("load chat: %w", err)
|
|
}
|
|
if lockedChat.Archived {
|
|
return ErrChatArchived
|
|
}
|
|
if limitErr := p.checkUsageLimit(ctx, store, lockedChat.OwnerID, uuid.NullUUID{UUID: lockedChat.OrganizationID, Valid: true}); limitErr != nil {
|
|
return limitErr
|
|
}
|
|
|
|
// Capture the target message for the post-commit debug
|
|
// cleanup hook below. The transition itself revalidates
|
|
// chat ownership and user-message constraints.
|
|
target, err := store.GetChatMessageByID(ctx, opts.EditedMessageID)
|
|
if err != nil {
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return ErrEditedMessageNotFound
|
|
}
|
|
return xerrors.Errorf("get edited message: %w", err)
|
|
}
|
|
if target.ChatID != opts.ChatID {
|
|
return ErrEditedMessageNotFound
|
|
}
|
|
editedMsg = target
|
|
|
|
// Validate the optional model-config override up front so
|
|
// the user sees ErrInvalidModelConfigID instead of a
|
|
// foreign-key error from the message-insert path.
|
|
var modelOverride uuid.NullUUID
|
|
if opts.ModelConfigID != uuid.Nil {
|
|
if _, err := store.GetChatModelConfigByID(
|
|
chatdModelConfigLookupContext(ctx),
|
|
opts.ModelConfigID,
|
|
); err != nil {
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return xerrors.Errorf(
|
|
"%w: %s",
|
|
ErrInvalidModelConfigID,
|
|
opts.ModelConfigID,
|
|
)
|
|
}
|
|
return xerrors.Errorf(
|
|
"get requested model config %s: %w",
|
|
opts.ModelConfigID,
|
|
err,
|
|
)
|
|
}
|
|
modelOverride = uuid.NullUUID{UUID: opts.ModelConfigID, Valid: true}
|
|
}
|
|
|
|
editResult, err := tx.EditMessage(chatstate.EditMessageInput{
|
|
MessageID: opts.EditedMessageID,
|
|
CreatedBy: opts.CreatedBy,
|
|
Content: content,
|
|
ModelConfigIDOverride: modelOverride,
|
|
})
|
|
if err != nil {
|
|
if errors.Is(err, chatstate.ErrEditedMessageNotUser) {
|
|
return ErrEditedMessageNotUser
|
|
}
|
|
return err
|
|
}
|
|
result.Message = editResult.ReplacementMessage
|
|
// Capture the post-edit chat inside the same transaction so
|
|
// the returned chat and the debug-cleanup cutoff use the
|
|
// snapshot bump and updated_at stamped by the transition.
|
|
refreshed, err := store.GetChatByID(ctx, opts.ChatID)
|
|
if err != nil {
|
|
return xerrors.Errorf("reload chat after edit: %w", err)
|
|
}
|
|
result.Chat = refreshed
|
|
editedCutoffT = refreshed.UpdatedAt
|
|
return nil
|
|
})
|
|
if updateErr != nil {
|
|
return EditMessageResult{}, updateErr
|
|
}
|
|
|
|
// Sidebar watch event keeps the chat list responsive. Stream
|
|
// side effects are handled by chat:update consumers.
|
|
p.publishChatPubsubEvent(result.Chat, codersdk.ChatWatchEventKindStatusChange, nil)
|
|
|
|
// Editing can race with an interrupted worker still flushing its
|
|
// final debug writes. Run a short bounded retry loop so we converge
|
|
// quickly without relying on the much longer stale-finalization
|
|
// sweep. Source editCutoff from the DB-stamped updated_at returned
|
|
// by the post-edit chat row so the filter uses the same clock that
|
|
// stamps replacement-turn debug rows; subtract
|
|
// debugCleanupClockSkew so replica clock drift cannot let the retry
|
|
// delete a replacement turn's debug rows.
|
|
editCutoff := editedCutoffT.Add(-debugCleanupClockSkew)
|
|
p.scheduleDebugCleanup(
|
|
ctx,
|
|
"failed to delete chat debug rows after edit",
|
|
[]slog.Field{
|
|
slog.F("chat_id", opts.ChatID),
|
|
slog.F("edited_message_id", editedMsg.ID),
|
|
},
|
|
func(cleanupCtx context.Context, debugSvc *chatdebug.Service) error {
|
|
_, err := debugSvc.DeleteAfterMessageID(cleanupCtx, opts.ChatID, editedMsg.ID-1, editCutoff)
|
|
return err
|
|
},
|
|
)
|
|
|
|
return result, nil
|
|
}
|
|
|
|
// ErrArchiveRequiresRootChat is returned by [Server.ArchiveChat] and
// [Server.UnarchiveChat] when the supplied chat is a child chat, i.e.
// its ParentChatID is valid. Archive state changes must always target
// the root chat so the whole family flips together.
var ErrArchiveRequiresRootChat = xerrors.New(
	"chat archive state can only be changed on the root chat",
)
|
|
|
|
// ArchiveChat archives a root chat and every child in its family
|
|
// through the chatstate state machine. The transition is atomic over
|
|
// the whole family: either every member is archived or none is. The
|
|
// state machine only permits archive from the idle / error execution
|
|
// states (W, E0, E1); active members cause a state conflict that the
|
|
// HTTP handler maps to a client error.
|
|
//
|
|
// Child chats must not be archived independently. ArchiveChat
|
|
// rejects them with [ErrArchiveRequiresRootChat] so callers cannot
|
|
// silently break the parent-implies-child archive invariant.
|
|
func (p *Server) ArchiveChat(ctx context.Context, chat database.Chat) error {
|
|
if chat.ID == uuid.Nil {
|
|
return xerrors.New("chat_id is required")
|
|
}
|
|
if chat.ParentChatID.Valid {
|
|
return ErrArchiveRequiresRootChat
|
|
}
|
|
return p.setChatFamilyArchived(ctx, chat, true, codersdk.ChatWatchEventKindDeleted)
|
|
}
|
|
|
|
// UnarchiveChat unarchives a root chat and every child in its family
|
|
// through the chatstate state machine. Like ArchiveChat the cascade
|
|
// is atomic; ChildChat unarchive attempts are rejected with
|
|
// [ErrArchiveRequiresRootChat].
|
|
func (p *Server) UnarchiveChat(ctx context.Context, chat database.Chat) error {
|
|
if chat.ID == uuid.Nil {
|
|
return xerrors.New("chat_id is required")
|
|
}
|
|
if chat.ParentChatID.Valid {
|
|
return ErrArchiveRequiresRootChat
|
|
}
|
|
return p.setChatFamilyArchived(ctx, chat, false, codersdk.ChatWatchEventKindCreated)
|
|
}
|
|
|
|
// setChatFamilyArchived applies SetArchived(archived) to every chat
|
|
// in chat's family through chatstate. The transaction-captured
|
|
// family rows feed the post-commit debug cleanup and sidebar watch
|
|
// events. Callers must only invoke this for root chats.
|
|
//
|
|
//nolint:revive // Existing API takes the target archive state as a boolean.
|
|
func (p *Server) setChatFamilyArchived(
|
|
ctx context.Context,
|
|
chat database.Chat,
|
|
archived bool,
|
|
watchKind codersdk.ChatWatchEventKind,
|
|
) error {
|
|
if chat.ID == uuid.Nil {
|
|
return xerrors.New("chat_id is required")
|
|
}
|
|
if chat.ParentChatID.Valid {
|
|
return ErrArchiveRequiresRootChat
|
|
}
|
|
|
|
familyChats, err := chatstate.SetFamilyArchived(
|
|
ctx,
|
|
p.db,
|
|
p.pubsub,
|
|
chatstate.SetFamilyArchivedInput{
|
|
RootID: chat.ID,
|
|
Archived: archived,
|
|
},
|
|
)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if archived {
|
|
p.scheduleArchiveDebugCleanup(ctx, familyChats)
|
|
}
|
|
|
|
p.publishChatPubsubEvents(familyChats, watchKind)
|
|
return nil
|
|
}
|
|
|
|
// DeleteQueued removes a queued user message through the chatstate
|
|
// state machine. Stream side effects are handled by chat:update
|
|
// consumers.
|
|
func (p *Server) DeleteQueued(
|
|
ctx context.Context,
|
|
chatID uuid.UUID,
|
|
queuedMessageID int64,
|
|
) error {
|
|
if chatID == uuid.Nil {
|
|
return xerrors.New("chat_id is required")
|
|
}
|
|
|
|
machine := p.newChatMachine(chatID)
|
|
err := machine.Update(ctx, func(tx *chatstate.Tx) error {
|
|
_, err := tx.DeleteQueuedMessage(chatstate.DeleteQueuedMessageInput{
|
|
QueuedMessageID: queuedMessageID,
|
|
})
|
|
return err
|
|
})
|
|
return err
|
|
}
|
|
|
|
// PromoteQueued promotes a queued message through the chatstate state
|
|
// machine. From running / interrupting states the state machine
|
|
// transitions the chat to `interrupting` so the worker can drain the
|
|
// in-flight generation before promoting; from idle / error / requires
|
|
// action states it inserts the user message into history
|
|
// synchronously.
|
|
func (p *Server) PromoteQueued(
|
|
ctx context.Context,
|
|
opts PromoteQueuedOptions,
|
|
) (PromoteQueuedResult, error) {
|
|
if opts.ChatID == uuid.Nil {
|
|
return PromoteQueuedResult{}, xerrors.New("chat_id is required")
|
|
}
|
|
|
|
var (
|
|
result PromoteQueuedResult
|
|
refreshChat database.Chat
|
|
refreshedOK bool
|
|
)
|
|
machine := p.newChatMachine(opts.ChatID)
|
|
updateErr := machine.Update(ctx, func(tx *chatstate.Tx) error {
|
|
store := tx.Store()
|
|
lockedChat, err := store.GetChatByID(ctx, opts.ChatID)
|
|
if err != nil {
|
|
return xerrors.Errorf("load chat: %w", err)
|
|
}
|
|
if lockedChat.Archived {
|
|
return ErrChatArchived
|
|
}
|
|
|
|
promoteResult, err := tx.PromoteQueuedMessage(chatstate.PromoteQueuedMessageInput{
|
|
QueuedMessageID: opts.QueuedMessageID,
|
|
})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if promoteResult.InsertedMessage != nil {
|
|
result.PromotedMessage = *promoteResult.InsertedMessage
|
|
}
|
|
// Capture the chat inside the transaction so the watch event
|
|
// published below uses the snapshot bump and status change
|
|
// produced by the transition itself.
|
|
refreshed, err := store.GetChatByID(ctx, opts.ChatID)
|
|
if err != nil {
|
|
return xerrors.Errorf("reload chat after promote: %w", err)
|
|
}
|
|
refreshChat = refreshed
|
|
refreshedOK = true
|
|
return nil
|
|
})
|
|
if updateErr != nil {
|
|
return PromoteQueuedResult{}, updateErr
|
|
}
|
|
|
|
if refreshedOK {
|
|
p.publishChatPubsubEvent(refreshChat, codersdk.ChatWatchEventKindStatusChange, nil)
|
|
}
|
|
return result, nil
|
|
}
|
|
|
|
// SubmitToolResultsOptions controls tool result submission via
// [Server.SubmitToolResults].
type SubmitToolResultsOptions struct {
	// ChatID identifies the chat whose state machine receives the results.
	ChatID uuid.UUID
	// UserID is recorded as the creator (CreatedBy) of the submission.
	UserID uuid.UUID
	// ModelConfigID selects the model config for the transition. When it
	// is uuid.Nil, SubmitToolResults uses the chat's LastModelConfigID.
	ModelConfigID uuid.UUID
	// Results are the client-provided tool results to persist.
	Results []codersdk.ToolResult
	// DynamicTools is not read by SubmitToolResults.
	// NOTE(review): confirm whether another caller consumes this field.
	DynamicTools json.RawMessage
}
|
|
|
|
// ToolResultValidationError indicates the submitted tool results
// failed validation (e.g. missing, duplicate, or unexpected IDs,
// or invalid JSON output).
type ToolResultValidationError struct {
	// Message is the short summary of the failure.
	Message string
	// Detail optionally elaborates on Message; it may be empty.
	Detail string
}

// Error implements the error interface. It returns Message alone when
// Detail is empty, and "Message: Detail" otherwise.
func (e *ToolResultValidationError) Error() string {
	if e.Detail == "" {
		return e.Message
	}
	return e.Message + ": " + e.Detail
}
|
|
|
|
// ToolResultStatusConflictError indicates the chat is not in the
|
|
// requires_action state expected for tool result submission.
|
|
type ToolResultStatusConflictError struct {
|
|
ActualStatus database.ChatStatus
|
|
}
|
|
|
|
func (e *ToolResultStatusConflictError) Error() string {
|
|
return fmt.Sprintf(
|
|
"chat status is %q, expected %q",
|
|
e.ActualStatus, database.ChatStatusRequiresAction,
|
|
)
|
|
}
|
|
|
|
// SubmitToolResults validates and persists client-provided tool
|
|
// results, returning the chat to running through the chatstate state
|
|
// machine. Validation runs inside the same transaction as the
|
|
// transition so the assistant message and pending tool calls cannot
|
|
// drift between reads.
|
|
func (p *Server) SubmitToolResults(
|
|
ctx context.Context,
|
|
opts SubmitToolResultsOptions,
|
|
) error {
|
|
var (
|
|
statusConflict *ToolResultStatusConflictError
|
|
refreshChat database.Chat
|
|
refreshedOK bool
|
|
)
|
|
machine := p.newChatMachine(opts.ChatID)
|
|
updateErr := machine.Update(ctx, func(tx *chatstate.Tx) error {
|
|
store := tx.Store()
|
|
locked, err := store.GetChatByID(ctx, opts.ChatID)
|
|
if err != nil {
|
|
return xerrors.Errorf("load chat: %w", err)
|
|
}
|
|
if locked.Archived {
|
|
return ErrChatArchived
|
|
}
|
|
|
|
toolResults := make([]chatstate.ToolResultInput, 0, len(opts.Results))
|
|
for _, r := range opts.Results {
|
|
toolResults = append(toolResults, chatstate.ToolResultInput{
|
|
ToolCallID: r.ToolCallID,
|
|
Output: r.Output,
|
|
IsError: r.IsError,
|
|
})
|
|
}
|
|
modelConfigID := opts.ModelConfigID
|
|
if modelConfigID == uuid.Nil {
|
|
modelConfigID = locked.LastModelConfigID
|
|
}
|
|
if _, err := tx.CompleteRequiresAction(chatstate.CompleteRequiresActionInput{
|
|
CreatedBy: opts.UserID,
|
|
ModelConfigID: modelConfigID,
|
|
Results: toolResults,
|
|
}); err != nil {
|
|
if !errors.Is(err, chatstate.ErrInvalidState) &&
|
|
locked.Status != database.ChatStatusRequiresAction &&
|
|
errors.Is(err, chatstate.ErrTransitionNotAllowed) {
|
|
statusConflict = &ToolResultStatusConflictError{
|
|
ActualStatus: locked.Status,
|
|
}
|
|
return statusConflict
|
|
}
|
|
return err
|
|
}
|
|
// Capture the chat inside the transaction so the watch event
|
|
// uses the snapshot bump and status change produced by the
|
|
// transition itself.
|
|
refreshed, err := store.GetChatByID(ctx, opts.ChatID)
|
|
if err != nil {
|
|
return xerrors.Errorf("reload chat after tool results: %w", err)
|
|
}
|
|
refreshChat = refreshed
|
|
refreshedOK = true
|
|
return nil
|
|
})
|
|
if updateErr != nil {
|
|
if statusConflict != nil {
|
|
return statusConflict
|
|
}
|
|
return translateToolResultValidationError(updateErr)
|
|
}
|
|
|
|
if refreshedOK {
|
|
p.publishChatPubsubEvent(refreshChat, codersdk.ChatWatchEventKindStatusChange, nil)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// translateToolResultValidationError converts a chatstate tool-result
|
|
// validation error into the legacy chatd.ToolResultValidationError
|
|
// shape so HTTP handlers preserve their existing response detail. If
|
|
// err is not a tool-result validation error, it is returned
|
|
// unchanged.
|
|
func translateToolResultValidationError(err error) error {
|
|
var v *chatstate.ToolResultValidationError
|
|
if !errors.As(err, &v) {
|
|
return err
|
|
}
|
|
switch {
|
|
case xerrors.Is(v, chatstate.ErrToolResultDuplicate):
|
|
return &ToolResultValidationError{
|
|
Message: "Duplicate tool_call_id in results.",
|
|
Detail: fmt.Sprintf("Duplicate tool call ID %q.", v.ToolCallID),
|
|
}
|
|
case xerrors.Is(v, chatstate.ErrToolResultMissing):
|
|
return &ToolResultValidationError{
|
|
Message: "Missing tool result.",
|
|
Detail: fmt.Sprintf("Missing result for tool call %q.", v.ToolCallID),
|
|
}
|
|
case xerrors.Is(v, chatstate.ErrToolResultUnexpected):
|
|
return &ToolResultValidationError{
|
|
Message: "Unexpected tool result.",
|
|
Detail: fmt.Sprintf("No pending tool call with ID %q.", v.ToolCallID),
|
|
}
|
|
case xerrors.Is(v, chatstate.ErrToolResultInvalidJSON):
|
|
return &ToolResultValidationError{
|
|
Message: "Tool result output must be valid JSON.",
|
|
Detail: fmt.Sprintf("Output for tool call %q is not valid JSON.", v.ToolCallID),
|
|
}
|
|
}
|
|
return err
|
|
}
|
|
|
|
// InterruptChat interrupts execution through the chatstate.Interrupt
|
|
// transition. Active runs land in `interrupting`; requires-action
|
|
// chats synthesize cancellation messages and return to running.
|
|
//
|
|
// Returns the post-transition chat and an error so callers can map
|
|
// state conflicts deliberately. Idle chats return a
|
|
// chatstate.ErrTransitionNotAllowed wrapper.
|
|
func (p *Server) InterruptChat(
|
|
ctx context.Context,
|
|
chat database.Chat,
|
|
) (database.Chat, error) {
|
|
if chat.ID == uuid.Nil {
|
|
return chat, xerrors.New("chat_id is required")
|
|
}
|
|
|
|
var refreshed database.Chat
|
|
machine := p.newChatMachine(chat.ID)
|
|
err := machine.Update(ctx, func(tx *chatstate.Tx) error {
|
|
if _, err := tx.Interrupt(chatstate.InterruptInput{
|
|
Reason: "Tool execution interrupted by user",
|
|
}); err != nil {
|
|
return err
|
|
}
|
|
// Capture the post-interrupt chat inside the transaction so
|
|
// the returned chat and the watch event reflect the snapshot
|
|
// bump and status change produced by the transition itself.
|
|
latest, err := tx.Store().GetChatByID(ctx, chat.ID)
|
|
if err != nil {
|
|
return xerrors.Errorf("reload chat after interrupt: %w", err)
|
|
}
|
|
refreshed = latest
|
|
return nil
|
|
})
|
|
if err != nil {
|
|
return chat, err
|
|
}
|
|
|
|
p.publishChatPubsubEvent(refreshed, codersdk.ChatWatchEventKindStatusChange, nil)
|
|
return refreshed, nil
|
|
}
|
|
|
|
// ReconcileInvalidStateChat recovers a chat stuck in an invalid
|
|
// execution-state combination by running the
|
|
// chatstate.ReconcileInvalidState transition. The chat lands in an
|
|
// error state (E0/E1); queued messages are preserved and pending
|
|
// dynamic-tool calls are closed with synthetic cancellations.
|
|
//
|
|
// Returns the post-transition chat. When the chat is not actually in an
|
|
// invalid state the transition returns a wrapped
|
|
// chatstate.ErrTransitionNotAllowed; a missing chat returns
|
|
// chatstate.ErrChatNotFound. Callers map these to deliberate HTTP
|
|
// responses.
|
|
func (p *Server) ReconcileInvalidStateChat(
|
|
ctx context.Context,
|
|
chat database.Chat,
|
|
) (database.Chat, error) {
|
|
if chat.ID == uuid.Nil {
|
|
return chat, xerrors.New("chat_id is required")
|
|
}
|
|
|
|
var refreshed database.Chat
|
|
machine := p.newChatMachine(chat.ID)
|
|
err := machine.Update(ctx, func(tx *chatstate.Tx) error {
|
|
if _, err := tx.ReconcileInvalidState(chatstate.ReconcileInvalidStateInput{}); err != nil {
|
|
return err
|
|
}
|
|
// Capture the post-reconcile chat inside the transaction so
|
|
// the returned chat and the watch event reflect the snapshot
|
|
// bump and status change produced by the transition itself.
|
|
latest, err := tx.Store().GetChatByID(ctx, chat.ID)
|
|
if err != nil {
|
|
return xerrors.Errorf("reload chat after reconcile: %w", err)
|
|
}
|
|
refreshed = latest
|
|
return nil
|
|
})
|
|
if err != nil {
|
|
return chat, err
|
|
}
|
|
|
|
p.publishChatPubsubEvent(refreshed, codersdk.ChatWatchEventKindStatusChange, nil)
|
|
return refreshed, nil
|
|
}
|
|
|
|
// manualTitleMessageWindowLimit is the page size used for each of the two
// message queries (oldest-first and newest-first) that
// generateManualTitleCandidate runs to collect input for manual title
// generation.
const manualTitleMessageWindowLimit = 50
|
|
|
|
// ErrManualTitleRegenerationInProgress is returned by
// acquireManualTitleLock when the manual title lock cannot be taken: the
// chat is pending, is running without a real worker, or already holds a
// fresh manual title lock.
var ErrManualTitleRegenerationInProgress = xerrors.New(
	"manual title regeneration already in progress",
)
|
|
|
|
// manualTitleCandidateResult is the outcome of
// generateManualTitleCandidate: the generated title plus the accounting
// metadata callers need to record usage.
type manualTitleCandidateResult struct {
	// title is the title returned by generateManualTitle.
	title string
	// modelConfig is the model config returned by resolveManualTitleModel.
	modelConfig database.ChatModelConfig
	// usage is the usage returned by generateManualTitle.
	usage fantasy.Usage
	// activeAPIKeyID is the API key ID taken from the message-derived
	// model options, or the caller's delegated key when those had none.
	activeAPIKeyID string
	// hasMessages is true once the chat was found to have messages to
	// generate from; the zero result returned for an empty chat leaves it
	// false.
	hasMessages bool
}
|
|
|
|
// manualTitleGenerationError is returned by generateManualTitleCandidate
// when title generation fails but still reports non-zero usage. It keeps
// the accounting metadata alongside the failure so
// recordManualTitleGenerationFailure can still record the usage.
type manualTitleGenerationError struct {
	// cause is the wrapped generation error; Error and Unwrap expose it.
	cause error
	// modelConfig is the model config the failed generation ran against.
	modelConfig database.ChatModelConfig
	// usage is the usage reported by the failed generation.
	usage fantasy.Usage
	// activeAPIKeyID is the API key ID the generation was attributed to.
	activeAPIKeyID string
}
|
|
|
|
// generatedChatTitle carries the title produced by the detached
// automatic title-generation goroutine. maybeGenerateChatTitle stores
// the generated title here so tests can observe it without a database
// read; the title_change pubsub event it publishes remains the source of
// truth for clients.
//
// Both methods tolerate a nil receiver, and access to title is guarded
// by mu.
type generatedChatTitle struct {
	mu    sync.RWMutex
	title string
}

// Store records title as the generated title. Calls on a nil receiver or
// with an empty title are ignored, so a stored title is never cleared.
func (t *generatedChatTitle) Store(title string) {
	if t == nil || title == "" {
		return
	}

	t.mu.Lock()
	defer t.mu.Unlock()
	t.title = title
}

// Load returns the stored title and true when one has been stored. It
// returns "" and false for a nil receiver or when nothing has been stored
// yet.
func (t *generatedChatTitle) Load() (string, bool) {
	if t == nil {
		return "", false
	}

	t.mu.RLock()
	current := t.title
	t.mu.RUnlock()

	// An empty title means nothing was stored; Store never writes "".
	return current, current != ""
}
|
|
|
|
// Error implements the error interface by returning the message of the
// wrapped cause. It dereferences e.cause without a nil check.
func (e *manualTitleGenerationError) Error() string {
	return e.cause.Error()
}
|
|
|
|
// Unwrap returns the wrapped cause so errors.Is and errors.As can inspect
// the underlying generation failure.
func (e *manualTitleGenerationError) Unwrap() error {
	return e.cause
}
|
|
|
|
// manualTitleLockWorkerID is the fixed sentinel worker ID that
// acquireManualTitleLock writes into a chat's WorkerID to mark it as
// held for a manual title operation. releaseManualTitleLock clears the
// worker ID only when it still equals this value.
var manualTitleLockWorkerID = uuid.MustParse(
	"00000000-0000-0000-0000-000000000001",
)
|
|
|
|
// manualTitleLockStaleAfter is how long a manual title lock counts as
// fresh. isFreshManualTitleLock treats a lock whose lease timestamp is
// older than this as stale.
const manualTitleLockStaleAfter = time.Minute
|
|
|
|
func isFreshManualTitleLock(chat database.Chat, now time.Time) bool {
|
|
if !chat.WorkerID.Valid || chat.WorkerID.UUID != manualTitleLockWorkerID {
|
|
return false
|
|
}
|
|
leaseAt := chat.HeartbeatAt
|
|
if !leaseAt.Valid {
|
|
leaseAt = chat.StartedAt
|
|
}
|
|
return leaseAt.Valid && leaseAt.Time.After(now.Add(-manualTitleLockStaleAfter))
|
|
}
|
|
|
|
// updateChatStatusPreserveUpdatedAt applies internal lock transitions without
|
|
// changing chat recency, because chat list ordering uses updated_at.
|
|
func updateChatStatusPreserveUpdatedAt(
|
|
ctx context.Context,
|
|
store database.Store,
|
|
chat database.Chat,
|
|
workerID uuid.NullUUID,
|
|
startedAt sql.NullTime,
|
|
heartbeatAt sql.NullTime,
|
|
) (database.Chat, error) {
|
|
return store.UpdateChatStatusPreserveUpdatedAt(
|
|
ctx,
|
|
database.UpdateChatStatusPreserveUpdatedAtParams{
|
|
ID: chat.ID,
|
|
Status: chat.Status,
|
|
WorkerID: workerID,
|
|
StartedAt: startedAt,
|
|
HeartbeatAt: heartbeatAt,
|
|
LastError: chat.LastError,
|
|
UpdatedAt: chat.UpdatedAt,
|
|
},
|
|
)
|
|
}
|
|
|
|
func (p *Server) acquireManualTitleLock(ctx context.Context, chatID uuid.UUID) error {
|
|
now := time.Now()
|
|
return p.db.InTx(func(tx database.Store) error {
|
|
lockedChat, err := tx.GetChatByIDForUpdate(ctx, chatID)
|
|
if err != nil {
|
|
return xerrors.Errorf("lock chat for manual title regeneration: %w", err)
|
|
}
|
|
// Only a fresh manual lock or a chat without a real worker should
|
|
// block title regeneration. Running chats with a real worker may
|
|
// regenerate their title concurrently, and last write wins.
|
|
hasRealWorker := lockedChat.Status == database.ChatStatusRunning &&
|
|
lockedChat.WorkerID.Valid &&
|
|
lockedChat.WorkerID.UUID != manualTitleLockWorkerID
|
|
if lockedChat.Status == database.ChatStatusPending ||
|
|
(lockedChat.Status == database.ChatStatusRunning && !hasRealWorker) ||
|
|
isFreshManualTitleLock(lockedChat, now) {
|
|
return ErrManualTitleRegenerationInProgress
|
|
}
|
|
if hasRealWorker {
|
|
return nil
|
|
}
|
|
|
|
_, err = updateChatStatusPreserveUpdatedAt(
|
|
ctx,
|
|
tx,
|
|
lockedChat,
|
|
uuid.NullUUID{UUID: manualTitleLockWorkerID, Valid: true},
|
|
sql.NullTime{Time: now, Valid: true},
|
|
sql.NullTime{},
|
|
)
|
|
if err != nil {
|
|
return xerrors.Errorf("mark chat for manual title regeneration: %w", err)
|
|
}
|
|
return nil
|
|
}, database.DefaultTXOptions().WithID("chat_title_regenerate_lock"))
|
|
}
|
|
|
|
func (p *Server) releaseManualTitleLock(ctx context.Context, chatID uuid.UUID) {
|
|
cleanupCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), 5*time.Second)
|
|
defer cancel()
|
|
|
|
err := p.db.InTx(func(tx database.Store) error {
|
|
lockedChat, err := tx.GetChatByIDForUpdate(cleanupCtx, chatID)
|
|
if err != nil {
|
|
return xerrors.Errorf("lock chat to release manual title regeneration: %w", err)
|
|
}
|
|
if !lockedChat.WorkerID.Valid || lockedChat.WorkerID.UUID != manualTitleLockWorkerID {
|
|
return nil
|
|
}
|
|
_, err = updateChatStatusPreserveUpdatedAt(
|
|
cleanupCtx,
|
|
tx,
|
|
lockedChat,
|
|
uuid.NullUUID{},
|
|
sql.NullTime{},
|
|
sql.NullTime{},
|
|
)
|
|
if err != nil {
|
|
return xerrors.Errorf("clear manual title regeneration marker: %w", err)
|
|
}
|
|
return nil
|
|
}, database.DefaultTXOptions().WithID("chat_title_regenerate_unlock"))
|
|
if err != nil {
|
|
p.logger.Warn(cleanupCtx, "failed to release manual title regeneration marker",
|
|
slog.F("chat_id", chatID),
|
|
slog.Error(err),
|
|
)
|
|
}
|
|
}
|
|
|
|
// RegenerateChatTitle regenerates a chat title from the chat's visible
|
|
// messages, persists it when it changes, and broadcasts the update.
|
|
func (p *Server) RegenerateChatTitle(
|
|
ctx context.Context,
|
|
chat database.Chat,
|
|
) (database.Chat, error) {
|
|
// Reuse chatd's scoped auth context for deployment-config lookups while
|
|
// keeping chat ownership authorization at the HTTP layer.
|
|
//nolint:gocritic // Non-admin users need chatd-scoped config reads here.
|
|
chatdCtx := dbauthz.AsChatd(ctx)
|
|
keys, err := p.resolveUserProviderAPIKeys(chatdCtx, chat.OwnerID, uuid.Nil)
|
|
if err != nil {
|
|
keys = chatprovider.ProviderAPIKeys{}
|
|
}
|
|
if err := p.acquireManualTitleLock(ctx, chat.ID); err != nil {
|
|
return database.Chat{}, err
|
|
}
|
|
defer p.releaseManualTitleLock(chatdCtx, chat.ID)
|
|
|
|
updatedChat, err := p.regenerateChatTitleWithStore(
|
|
chatdCtx,
|
|
p.db,
|
|
chat,
|
|
keys,
|
|
)
|
|
if err != nil {
|
|
return database.Chat{}, p.recordManualTitleGenerationFailure(ctx, chat, err)
|
|
}
|
|
return updatedChat, nil
|
|
}
|
|
|
|
// RenameChatTitle persists a user-supplied chat title.
|
|
func (p *Server) RenameChatTitle(
|
|
ctx context.Context,
|
|
chat database.Chat,
|
|
newTitle string,
|
|
) (updated database.Chat, wrote bool, err error) {
|
|
//nolint:gocritic // Lock release needs chatd-scoped writes.
|
|
chatdCtx := dbauthz.AsChatd(ctx)
|
|
if err := p.acquireManualTitleLock(ctx, chat.ID); err != nil {
|
|
return database.Chat{}, false, err
|
|
}
|
|
defer p.releaseManualTitleLock(chatdCtx, chat.ID)
|
|
|
|
currentChat, err := p.db.GetChatByID(ctx, chat.ID)
|
|
if err != nil {
|
|
return database.Chat{}, false, xerrors.Errorf("get chat for rename: %w", err)
|
|
}
|
|
if newTitle == currentChat.Title {
|
|
return currentChat, false, nil
|
|
}
|
|
|
|
updatedChat, err := p.db.UpdateChatTitleByID(ctx, database.UpdateChatTitleByIDParams{
|
|
ID: chat.ID,
|
|
Title: newTitle,
|
|
})
|
|
if err != nil {
|
|
return database.Chat{}, false, xerrors.Errorf("update chat title: %w", err)
|
|
}
|
|
return updatedChat, true, nil
|
|
}
|
|
|
|
// PublishTitleChange broadcasts a title_change event for the given chat.
// It publishes a ChatWatchEventKindTitleChange watch event carrying chat
// and does not touch the database itself.
func (p *Server) PublishTitleChange(chat database.Chat) {
	p.publishChatPubsubEvent(chat, codersdk.ChatWatchEventKindTitleChange, nil)
}
|
|
|
|
// ProposeChatTitle generates a title suggestion from the chat's visible messages without persisting it.
|
|
func (p *Server) ProposeChatTitle(
|
|
ctx context.Context,
|
|
chat database.Chat,
|
|
) (string, error) {
|
|
//nolint:gocritic // Non-admin users need chatd-scoped config reads here.
|
|
chatdCtx := dbauthz.AsChatd(ctx)
|
|
keys, err := p.resolveUserProviderAPIKeys(chatdCtx, chat.OwnerID, uuid.Nil)
|
|
if err != nil {
|
|
keys = chatprovider.ProviderAPIKeys{}
|
|
}
|
|
if err := p.acquireManualTitleLock(ctx, chat.ID); err != nil {
|
|
return "", err
|
|
}
|
|
defer p.releaseManualTitleLock(chatdCtx, chat.ID)
|
|
|
|
title, err := p.proposeChatTitleWithStore(chatdCtx, p.db, chat, keys)
|
|
if err != nil {
|
|
return "", p.recordManualTitleGenerationFailure(ctx, chat, err)
|
|
}
|
|
return title, nil
|
|
}
|
|
|
|
func (p *Server) recordManualTitleGenerationFailure(
|
|
ctx context.Context,
|
|
chat database.Chat,
|
|
err error,
|
|
) error {
|
|
var generationErr *manualTitleGenerationError
|
|
if !errors.As(err, &generationErr) {
|
|
return err
|
|
}
|
|
|
|
//nolint:gocritic // Failure accounting still needs chatd-scoped config reads.
|
|
recordCtx, recordCancel := context.WithTimeout(
|
|
dbauthz.AsChatd(context.WithoutCancel(ctx)),
|
|
5*time.Second,
|
|
)
|
|
defer recordCancel()
|
|
if _, recordErr := recordManualTitleUsage(
|
|
recordCtx,
|
|
p.db,
|
|
chat,
|
|
generationErr.modelConfig,
|
|
generationErr.usage,
|
|
generationErr.activeAPIKeyID,
|
|
"",
|
|
); recordErr != nil {
|
|
return errors.Join(
|
|
generationErr,
|
|
xerrors.Errorf("record manual title usage: %w", recordErr),
|
|
)
|
|
}
|
|
return generationErr
|
|
}
|
|
|
|
// generateManualTitleCandidate performs only model generation and returns the
|
|
// candidate plus accounting metadata. Endpoint-specific commit paths are
|
|
// responsible for recording usage and deciding whether to persist the title.
|
|
// The context may carry the caller's delegated API key for manual title routes.
|
|
func (p *Server) generateManualTitleCandidate(
|
|
ctx context.Context,
|
|
store database.Store,
|
|
chat database.Chat,
|
|
keys chatprovider.ProviderAPIKeys,
|
|
) (manualTitleCandidateResult, error) {
|
|
if limitErr := p.checkUsageLimit(ctx, store, chat.OwnerID, uuid.NullUUID{UUID: chat.OrganizationID, Valid: true}); limitErr != nil {
|
|
return manualTitleCandidateResult{}, limitErr
|
|
}
|
|
|
|
headMessages, err := store.GetChatMessagesByChatIDAscPaginated(
|
|
ctx,
|
|
database.GetChatMessagesByChatIDAscPaginatedParams{
|
|
ChatID: chat.ID,
|
|
AfterID: 0,
|
|
LimitVal: manualTitleMessageWindowLimit,
|
|
},
|
|
)
|
|
if err != nil {
|
|
return manualTitleCandidateResult{}, xerrors.Errorf("get head chat messages: %w", err)
|
|
}
|
|
tailMessages, err := store.GetChatMessagesByChatIDDescPaginated(
|
|
ctx,
|
|
database.GetChatMessagesByChatIDDescPaginatedParams{
|
|
ChatID: chat.ID,
|
|
BeforeID: 0,
|
|
LimitVal: manualTitleMessageWindowLimit,
|
|
},
|
|
)
|
|
if err != nil {
|
|
return manualTitleCandidateResult{}, xerrors.Errorf("get tail chat messages: %w", err)
|
|
}
|
|
messages := mergeManualTitleMessages(headMessages, tailMessages)
|
|
if len(messages) == 0 {
|
|
return manualTitleCandidateResult{}, nil
|
|
}
|
|
modelOpts := modelBuildOptionsFromMessages(messages)
|
|
// Manual title routes can run over messages that lack API key attribution.
|
|
// Fall back to the authenticated caller's delegated key for AI Gateway routing.
|
|
if modelOpts.ActiveAPIKeyID == "" {
|
|
if apiKeyID, ok := aibridge.DelegatedAPIKeyIDFromContext(ctx); ok {
|
|
modelOpts.ActiveAPIKeyID = apiKeyID
|
|
}
|
|
}
|
|
|
|
model, modelConfig, modelKeys, err := p.resolveManualTitleModel(ctx, store, chat, keys, modelOpts)
|
|
result := manualTitleCandidateResult{
|
|
modelConfig: modelConfig,
|
|
activeAPIKeyID: modelOpts.ActiveAPIKeyID,
|
|
hasMessages: true,
|
|
}
|
|
if err != nil {
|
|
return result, err
|
|
}
|
|
|
|
titleCtx := ctx
|
|
titleModel := model
|
|
finishDebugRun := func(error) {}
|
|
if debugSvc := p.debugService(); debugSvc != nil && debugSvc.IsEnabled(ctx, chat.ID, chat.OwnerID) {
|
|
titleCtx, titleModel, finishDebugRun = p.prepareManualTitleDebugRun(
|
|
ctx,
|
|
debugSvc,
|
|
chat,
|
|
modelConfig,
|
|
modelKeys,
|
|
modelOpts,
|
|
messages,
|
|
model,
|
|
)
|
|
}
|
|
|
|
title, usage, err := generateManualTitle(titleCtx, messages, titleModel)
|
|
finishDebugRun(err)
|
|
result.title = title
|
|
result.usage = usage
|
|
if err != nil {
|
|
wrappedErr := xerrors.Errorf("generate manual title: %w", err)
|
|
if usage == (fantasy.Usage{}) {
|
|
return result, wrappedErr
|
|
}
|
|
return result, &manualTitleGenerationError{
|
|
cause: wrappedErr,
|
|
modelConfig: modelConfig,
|
|
usage: usage,
|
|
activeAPIKeyID: modelOpts.ActiveAPIKeyID,
|
|
}
|
|
}
|
|
|
|
return result, nil
|
|
}
|
|
|
|
func (p *Server) proposeChatTitleWithStore(
|
|
ctx context.Context,
|
|
store database.Store,
|
|
chat database.Chat,
|
|
keys chatprovider.ProviderAPIKeys,
|
|
) (string, error) {
|
|
result, err := p.generateManualTitleCandidate(ctx, store, chat, keys)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
if !result.hasMessages {
|
|
return "", nil
|
|
}
|
|
|
|
recordCtx, recordCancel := context.WithTimeout(context.WithoutCancel(ctx), 5*time.Second)
|
|
defer recordCancel()
|
|
if _, recordErr := recordManualTitleUsage(
|
|
recordCtx,
|
|
store,
|
|
chat,
|
|
result.modelConfig,
|
|
result.usage,
|
|
result.activeAPIKeyID,
|
|
"",
|
|
); recordErr != nil {
|
|
return "", xerrors.Errorf("record manual title usage: %w", recordErr)
|
|
}
|
|
return result.title, nil
|
|
}
|
|
|
|
func (p *Server) regenerateChatTitleWithStore(
|
|
ctx context.Context,
|
|
store database.Store,
|
|
chat database.Chat,
|
|
keys chatprovider.ProviderAPIKeys,
|
|
) (database.Chat, error) {
|
|
result, err := p.generateManualTitleCandidate(ctx, store, chat, keys)
|
|
if err != nil {
|
|
return database.Chat{}, err
|
|
}
|
|
if !result.hasMessages {
|
|
return chat, nil
|
|
}
|
|
|
|
recordCtx, recordCancel := context.WithTimeout(context.WithoutCancel(ctx), 5*time.Second)
|
|
defer recordCancel()
|
|
|
|
updatedChat, recordErr := recordManualTitleUsage(
|
|
recordCtx,
|
|
store,
|
|
chat,
|
|
result.modelConfig,
|
|
result.usage,
|
|
result.activeAPIKeyID,
|
|
result.title,
|
|
)
|
|
if recordErr != nil {
|
|
if result.title != "" {
|
|
return database.Chat{}, xerrors.Errorf("record manual title usage and update chat title: %w", recordErr)
|
|
}
|
|
return database.Chat{}, xerrors.Errorf("record manual title usage: %w", recordErr)
|
|
}
|
|
if updatedChat.Title == chat.Title {
|
|
return updatedChat, nil
|
|
}
|
|
|
|
p.publishChatPubsubEvent(updatedChat, codersdk.ChatWatchEventKindTitleChange, nil)
|
|
return updatedChat, nil
|
|
}
|
|
|
|
func (p *Server) prepareManualTitleDebugRun(
|
|
ctx context.Context,
|
|
debugSvc *chatdebug.Service,
|
|
chat database.Chat,
|
|
modelConfig database.ChatModelConfig,
|
|
keys chatprovider.ProviderAPIKeys,
|
|
modelOpts modelBuildOptions,
|
|
messages []database.ChatMessage,
|
|
fallbackModel fantasy.LanguageModel,
|
|
) (context.Context, fantasy.LanguageModel, func(error)) {
|
|
titleCtx := ctx
|
|
titleModel := fallbackModel
|
|
finishDebugRun := func(error) {}
|
|
|
|
route, routeErr := p.resolveModelRouteForConfig(ctx, chat.OwnerID, modelConfig, keys)
|
|
debugOpts := modelOpts
|
|
debugOpts.RecordHTTP = true
|
|
var debugModelErr error
|
|
var debugModel fantasy.LanguageModel
|
|
if routeErr != nil {
|
|
debugModelErr = routeErr
|
|
} else {
|
|
debugModel, debugModelErr = p.newModel(ctx, modelClientRequest{
|
|
Chat: chat,
|
|
ModelName: modelConfig.Model,
|
|
UserAgent: chatprovider.UserAgent(),
|
|
ExtraHeaders: chatprovider.CoderHeaders(chat),
|
|
}, route, debugOpts)
|
|
}
|
|
switch {
|
|
case debugModelErr != nil:
|
|
p.logger.Warn(ctx, "failed to create debug-aware manual title model",
|
|
slog.F("chat_id", chat.ID),
|
|
slog.F("provider", modelConfig.Provider),
|
|
slog.F("model", modelConfig.Model),
|
|
slog.Error(debugModelErr),
|
|
)
|
|
case debugModel == nil:
|
|
p.logger.Warn(ctx, "manual title debug model creation returned nil",
|
|
slog.F("chat_id", chat.ID),
|
|
slog.F("provider", modelConfig.Provider),
|
|
slog.F("model", modelConfig.Model),
|
|
)
|
|
default:
|
|
titleModel = chatdebug.WrapModel(debugModel, debugSvc, chatdebug.RecorderOptions{
|
|
ChatID: chat.ID,
|
|
OwnerID: chat.OwnerID,
|
|
Provider: modelConfig.Provider,
|
|
Model: modelConfig.Model,
|
|
})
|
|
}
|
|
|
|
var historyTipMessageID int64
|
|
if len(messages) > 0 {
|
|
historyTipMessageID = messages[len(messages)-1].ID
|
|
}
|
|
|
|
// Derive a first_message label from the first user message.
|
|
var firstUserLabel string
|
|
for _, msg := range messages {
|
|
if msg.Role == database.ChatMessageRoleUser {
|
|
if parts, parseErr := chatprompt.ParseContent(msg); parseErr == nil {
|
|
firstUserLabel = contentBlocksToText(parts)
|
|
}
|
|
break
|
|
}
|
|
}
|
|
if firstUserLabel == "" {
|
|
firstUserLabel = "Title generation"
|
|
}
|
|
seedSummary := chatdebug.SeedSummary(
|
|
chatdebug.TruncateLabel(firstUserLabel, chatdebug.MaxLabelLength),
|
|
)
|
|
|
|
createRunCtx, createRunCancel := context.WithTimeout(context.WithoutCancel(ctx), 5*time.Second)
|
|
debugRun, createRunErr := debugSvc.CreateRun(createRunCtx, chatdebug.CreateRunParams{
|
|
ChatID: chat.ID,
|
|
ModelConfigID: modelConfig.ID,
|
|
Provider: modelConfig.Provider,
|
|
Model: modelConfig.Model,
|
|
Kind: chatdebug.KindTitleGeneration,
|
|
Status: chatdebug.StatusInProgress,
|
|
HistoryTipMessageID: historyTipMessageID,
|
|
TriggerMessageID: 0,
|
|
Summary: seedSummary,
|
|
})
|
|
createRunCancel()
|
|
if createRunErr != nil {
|
|
p.logger.Warn(ctx, "failed to create manual title debug run",
|
|
slog.F("chat_id", chat.ID),
|
|
slog.F("provider", modelConfig.Provider),
|
|
slog.F("model", modelConfig.Model),
|
|
slog.Error(createRunErr),
|
|
)
|
|
return titleCtx, titleModel, finishDebugRun
|
|
}
|
|
|
|
runContext := chatdebugRunContext(debugRun)
|
|
titleCtx = chatdebug.ContextWithRun(titleCtx, &runContext)
|
|
finishDebugRun = func(generateErr error) {
|
|
if finalizeErr := debugSvc.FinalizeRun(ctx, chatdebug.FinalizeRunParams{
|
|
RunID: debugRun.ID,
|
|
ChatID: debugRun.ChatID,
|
|
Status: chatdebug.ClassifyError(generateErr),
|
|
SeedSummary: seedSummary,
|
|
}); finalizeErr != nil {
|
|
p.logger.Warn(ctx, "failed to finalize manual title debug run",
|
|
slog.F("chat_id", chat.ID),
|
|
slog.F("run_id", debugRun.ID),
|
|
slog.Error(finalizeErr),
|
|
)
|
|
}
|
|
}
|
|
|
|
return titleCtx, titleModel, finishDebugRun
|
|
}
|
|
|
|
func chatdebugRunContext(run database.ChatDebugRun) chatdebug.RunContext {
|
|
runContext := chatdebug.RunContext{
|
|
RunID: run.ID,
|
|
ChatID: run.ChatID,
|
|
Kind: chatdebug.RunKind(run.Kind),
|
|
}
|
|
if run.RootChatID.Valid {
|
|
runContext.RootChatID = run.RootChatID.UUID
|
|
}
|
|
if run.ParentChatID.Valid {
|
|
runContext.ParentChatID = run.ParentChatID.UUID
|
|
}
|
|
if run.ModelConfigID.Valid {
|
|
runContext.ModelConfigID = run.ModelConfigID.UUID
|
|
}
|
|
if run.TriggerMessageID.Valid {
|
|
runContext.TriggerMessageID = run.TriggerMessageID.Int64
|
|
}
|
|
if run.HistoryTipMessageID.Valid {
|
|
runContext.HistoryTipMessageID = run.HistoryTipMessageID.Int64
|
|
}
|
|
if run.Provider.Valid {
|
|
runContext.Provider = run.Provider.String
|
|
}
|
|
if run.Model.Valid {
|
|
runContext.Model = run.Model.String
|
|
}
|
|
return runContext
|
|
}
|
|
|
|
func deriveChatDebugSeed(messages []database.ChatMessage) (
|
|
triggerMessageID int64,
|
|
historyTipMessageID int64,
|
|
triggerLabel string,
|
|
) {
|
|
for i := len(messages) - 1; i >= 0; i-- {
|
|
if messages[i].Role != database.ChatMessageRoleUser {
|
|
continue
|
|
}
|
|
triggerMessageID = messages[i].ID
|
|
if parts, parseErr := chatprompt.ParseContent(messages[i]); parseErr == nil {
|
|
triggerLabel = contentBlocksToText(parts)
|
|
}
|
|
break
|
|
}
|
|
|
|
if len(messages) > 0 {
|
|
historyTipMessageID = messages[len(messages)-1].ID
|
|
}
|
|
|
|
return triggerMessageID, historyTipMessageID, triggerLabel
|
|
}
|
|
|
|
func prepareChatTurnDebugRun(
|
|
ctx context.Context,
|
|
logger slog.Logger,
|
|
chat database.Chat,
|
|
modelConfig database.ChatModelConfig,
|
|
debugSvc *chatdebug.Service,
|
|
debugProvider string,
|
|
debugModel string,
|
|
triggerMessageID int64,
|
|
historyTipMessageID int64,
|
|
triggerLabel string,
|
|
) (context.Context, func(error, any)) {
|
|
finishDebugRun := func(error, any) {}
|
|
if debugSvc == nil {
|
|
return ctx, finishDebugRun
|
|
}
|
|
|
|
seedSummary := chatdebug.SeedSummary(
|
|
chatdebug.TruncateLabel(triggerLabel, chatdebug.MaxLabelLength),
|
|
)
|
|
rootChatID := uuid.Nil
|
|
if chat.RootChatID.Valid {
|
|
rootChatID = chat.RootChatID.UUID
|
|
}
|
|
parentChatID := uuid.Nil
|
|
if chat.ParentChatID.Valid {
|
|
parentChatID = chat.ParentChatID.UUID
|
|
}
|
|
|
|
// Debug instrumentation must never block the user turn. Detach
|
|
// from the chat-processing context and bound the insert so a slow
|
|
// or locked DB makes debug logging degrade silently rather than
|
|
// stalling chat processing. Matches the pattern used by
|
|
// prepareManualTitleDebugRun.
|
|
createRunCtx, createRunCancel := context.WithTimeout(
|
|
context.WithoutCancel(ctx), debugCreateRunTimeout,
|
|
)
|
|
run, createRunErr := debugSvc.CreateRun(createRunCtx, chatdebug.CreateRunParams{
|
|
ChatID: chat.ID,
|
|
RootChatID: rootChatID,
|
|
ParentChatID: parentChatID,
|
|
ModelConfigID: modelConfig.ID,
|
|
TriggerMessageID: triggerMessageID,
|
|
HistoryTipMessageID: historyTipMessageID,
|
|
Kind: chatdebug.KindChatTurn,
|
|
Status: chatdebug.StatusInProgress,
|
|
Provider: debugProvider,
|
|
Model: debugModel,
|
|
Summary: seedSummary,
|
|
})
|
|
createRunCancel()
|
|
if createRunErr != nil {
|
|
logger.Warn(ctx, "failed to create chat debug run",
|
|
slog.F("chat_id", chat.ID),
|
|
slog.Error(createRunErr),
|
|
)
|
|
return ctx, finishDebugRun
|
|
}
|
|
|
|
runCtx := chatdebug.ContextWithRun(ctx, &chatdebug.RunContext{
|
|
RunID: run.ID,
|
|
ChatID: chat.ID,
|
|
RootChatID: rootChatID,
|
|
ParentChatID: parentChatID,
|
|
ModelConfigID: modelConfig.ID,
|
|
TriggerMessageID: triggerMessageID,
|
|
HistoryTipMessageID: historyTipMessageID,
|
|
Kind: chatdebug.KindChatTurn,
|
|
Provider: debugProvider,
|
|
Model: debugModel,
|
|
})
|
|
finishDebugRun = func(loopErr error, panicValue any) {
|
|
status := chatdebug.ClassifyError(loopErr)
|
|
switch {
|
|
case panicValue != nil:
|
|
status = chatdebug.StatusError
|
|
case errors.Is(loopErr, chatloop.ErrInterrupted):
|
|
status = chatdebug.StatusInterrupted
|
|
case errors.Is(loopErr, chatloop.ErrDynamicToolCall):
|
|
// Dynamic tool calls are a successful pause; the run completed
|
|
// its model round-trip.
|
|
status = chatdebug.StatusCompleted
|
|
}
|
|
|
|
if finalizeErr := debugSvc.FinalizeRun(runCtx, chatdebug.FinalizeRunParams{
|
|
RunID: run.ID,
|
|
ChatID: chat.ID,
|
|
Status: status,
|
|
SeedSummary: seedSummary,
|
|
}); finalizeErr != nil {
|
|
logger.Warn(ctx, "failed to finalize chat debug run",
|
|
slog.F("chat_id", chat.ID),
|
|
slog.F("run_id", run.ID),
|
|
slog.Error(finalizeErr),
|
|
)
|
|
}
|
|
}
|
|
|
|
return runCtx, finishDebugRun
|
|
}
|
|
|
|
func (p *Server) resolveManualTitleModel(
|
|
ctx context.Context,
|
|
store database.Store,
|
|
chat database.Chat,
|
|
keys chatprovider.ProviderAPIKeys,
|
|
modelOpts modelBuildOptions,
|
|
) (fantasy.LanguageModel, database.ChatModelConfig, chatprovider.ProviderAPIKeys, error) {
|
|
overrideConfig, overrideModel, overrideKeys, _, overrideSet, overrideErr := p.resolveTitleGenerationModelOverride(
|
|
ctx,
|
|
chat,
|
|
keys,
|
|
modelOpts,
|
|
)
|
|
if overrideErr != nil {
|
|
if overrideSet {
|
|
return nil, database.ChatModelConfig{}, chatprovider.ProviderAPIKeys{}, xerrors.Errorf(
|
|
"resolve manual title generation model override: %w",
|
|
overrideErr,
|
|
)
|
|
}
|
|
p.logger.Debug(ctx, "failed to resolve title generation model override for manual title",
|
|
slog.F("chat_id", chat.ID),
|
|
slog.Error(overrideErr),
|
|
)
|
|
} else if overrideSet {
|
|
return overrideModel, overrideConfig, overrideKeys, nil
|
|
}
|
|
|
|
configs, err := store.GetEnabledChatModelConfigs(ctx)
|
|
if err != nil {
|
|
p.logger.Debug(ctx, "failed to list manual title model configs",
|
|
slog.F("chat_id", chat.ID),
|
|
slog.Error(err),
|
|
)
|
|
return p.resolveFallbackManualTitleModel(ctx, chat, keys, modelOpts)
|
|
}
|
|
|
|
config, ok := selectPreferredConfiguredShortTextModelConfig(configs)
|
|
if !ok {
|
|
return p.resolveFallbackManualTitleModel(ctx, chat, keys, modelOpts)
|
|
}
|
|
|
|
route, err := p.resolveModelRouteForConfig(ctx, chat.OwnerID, config, keys)
|
|
if err != nil {
|
|
p.logger.Debug(ctx, "manual title preferred model unavailable",
|
|
slog.F("chat_id", chat.ID),
|
|
slog.F("provider", config.Provider),
|
|
slog.F("model", config.Model),
|
|
slog.Error(err),
|
|
)
|
|
return p.resolveFallbackManualTitleModel(ctx, chat, keys, modelOpts)
|
|
}
|
|
model, err := p.newModel(ctx, modelClientRequest{
|
|
Chat: chat,
|
|
ModelName: config.Model,
|
|
UserAgent: chatprovider.UserAgent(),
|
|
ExtraHeaders: chatprovider.CoderHeaders(chat),
|
|
}, route, modelOpts)
|
|
if err != nil {
|
|
p.logger.Debug(ctx, "manual title preferred model unavailable",
|
|
slog.F("chat_id", chat.ID),
|
|
slog.F("provider", config.Provider),
|
|
slog.F("model", config.Model),
|
|
slog.Error(err),
|
|
)
|
|
return p.resolveFallbackManualTitleModel(ctx, chat, keys, modelOpts)
|
|
}
|
|
|
|
return model, config, route.directProviderKeys(), nil
|
|
}
|
|
|
|
func (p *Server) resolveFallbackManualTitleModel(
|
|
ctx context.Context,
|
|
chat database.Chat,
|
|
keys chatprovider.ProviderAPIKeys,
|
|
modelOpts modelBuildOptions,
|
|
) (fantasy.LanguageModel, database.ChatModelConfig, chatprovider.ProviderAPIKeys, error) {
|
|
config, err := p.resolveModelConfig(ctx, chat)
|
|
if err != nil {
|
|
return nil, database.ChatModelConfig{}, chatprovider.ProviderAPIKeys{}, xerrors.Errorf(
|
|
"resolve fallback manual title model config: %w",
|
|
err,
|
|
)
|
|
}
|
|
route, err := p.resolveModelRouteForConfig(ctx, chat.OwnerID, config, keys)
|
|
if err != nil {
|
|
return nil, database.ChatModelConfig{}, chatprovider.ProviderAPIKeys{}, err
|
|
}
|
|
model, err := p.newModel(ctx, modelClientRequest{
|
|
Chat: chat,
|
|
ModelName: config.Model,
|
|
UserAgent: chatprovider.UserAgent(),
|
|
ExtraHeaders: chatprovider.CoderHeaders(chat),
|
|
}, route, modelOpts)
|
|
if err != nil {
|
|
return nil, database.ChatModelConfig{}, chatprovider.ProviderAPIKeys{}, xerrors.Errorf(
|
|
"create fallback manual title model: %w",
|
|
err,
|
|
)
|
|
}
|
|
return model, config, route.directProviderKeys(), nil
|
|
}
|
|
|
|
func mergeManualTitleMessages(
|
|
headMessages []database.ChatMessage,
|
|
tailMessagesDesc []database.ChatMessage,
|
|
) []database.ChatMessage {
|
|
merged := make([]database.ChatMessage, 0, len(headMessages)+len(tailMessagesDesc))
|
|
seen := make(map[int64]struct{}, len(headMessages)+len(tailMessagesDesc))
|
|
appendUnique := func(message database.ChatMessage) {
|
|
if _, ok := seen[message.ID]; ok {
|
|
return
|
|
}
|
|
seen[message.ID] = struct{}{}
|
|
merged = append(merged, message)
|
|
}
|
|
for _, message := range headMessages {
|
|
appendUnique(message)
|
|
}
|
|
for i := len(tailMessagesDesc) - 1; i >= 0; i-- {
|
|
appendUnique(tailMessagesDesc[i])
|
|
}
|
|
return merged
|
|
}
|
|
|
|
func fantasyUsageToChatMessageUsage(usage fantasy.Usage) codersdk.ChatMessageUsage {
|
|
var chatUsage codersdk.ChatMessageUsage
|
|
if usage.InputTokens != 0 {
|
|
chatUsage.InputTokens = ptr.Ref(usage.InputTokens)
|
|
}
|
|
if usage.OutputTokens != 0 {
|
|
chatUsage.OutputTokens = ptr.Ref(usage.OutputTokens)
|
|
}
|
|
if usage.ReasoningTokens != 0 {
|
|
chatUsage.ReasoningTokens = ptr.Ref(usage.ReasoningTokens)
|
|
}
|
|
if usage.CacheCreationTokens != 0 {
|
|
chatUsage.CacheCreationTokens = ptr.Ref(usage.CacheCreationTokens)
|
|
}
|
|
if usage.CacheReadTokens != 0 {
|
|
chatUsage.CacheReadTokens = ptr.Ref(usage.CacheReadTokens)
|
|
}
|
|
return chatUsage
|
|
}
|
|
|
|
func recordManualTitleUsage(
|
|
ctx context.Context,
|
|
store database.Store,
|
|
chat database.Chat,
|
|
modelConfig database.ChatModelConfig,
|
|
usage fantasy.Usage,
|
|
activeAPIKeyID string,
|
|
newTitle string,
|
|
) (database.Chat, error) {
|
|
hasUsage := usage != (fantasy.Usage{})
|
|
if !hasUsage && newTitle == "" {
|
|
return chat, nil
|
|
}
|
|
|
|
var totalCostMicros *int64
|
|
if hasUsage {
|
|
callConfig := codersdk.ChatModelCallConfig{}
|
|
if len(modelConfig.Options) > 0 {
|
|
if err := json.Unmarshal(modelConfig.Options, &callConfig); err != nil {
|
|
return database.Chat{}, xerrors.Errorf("parse model call config: %w", err)
|
|
}
|
|
}
|
|
totalCostMicros = chatcost.CalculateTotalCostMicros(
|
|
fantasyUsageToChatMessageUsage(usage),
|
|
callConfig.Cost,
|
|
)
|
|
}
|
|
|
|
// Use a valid empty JSON array for the content column.
|
|
// MarshalParts returns a null NullRawMessage for empty
|
|
// slices, which becomes an empty string that PostgreSQL
|
|
// rejects as invalid JSON.
|
|
content := "[]"
|
|
|
|
updatedChat := chat
|
|
err := store.InTx(func(tx database.Store) error {
|
|
lockedChat, err := tx.GetChatByIDForUpdate(ctx, chat.ID)
|
|
if err != nil {
|
|
return xerrors.Errorf("lock chat for manual title usage: %w", err)
|
|
}
|
|
updatedChat = lockedChat
|
|
if hasUsage {
|
|
messages, err := tx.InsertChatMessages(ctx, database.InsertChatMessagesParams{
|
|
ChatID: chat.ID,
|
|
CreatedBy: []uuid.UUID{chat.OwnerID},
|
|
APIKeyID: []string{activeAPIKeyID},
|
|
ModelConfigID: []uuid.UUID{modelConfig.ID},
|
|
Role: []database.ChatMessageRole{database.ChatMessageRoleAssistant},
|
|
Content: []string{content},
|
|
ContentVersion: []int16{chatprompt.CurrentContentVersion},
|
|
Visibility: []database.ChatMessageVisibility{database.ChatMessageVisibilityModel},
|
|
InputTokens: []int64{usage.InputTokens},
|
|
OutputTokens: []int64{usage.OutputTokens},
|
|
TotalTokens: []int64{usage.TotalTokens},
|
|
ReasoningTokens: []int64{usage.ReasoningTokens},
|
|
CacheCreationTokens: []int64{usage.CacheCreationTokens},
|
|
CacheReadTokens: []int64{usage.CacheReadTokens},
|
|
ContextLimit: []int64{modelConfig.ContextLimit},
|
|
Compressed: []bool{false},
|
|
TotalCostMicros: []int64{ptr.NilToDefault(totalCostMicros, 0)},
|
|
RuntimeMs: []int64{0},
|
|
ProviderResponseID: []string{""},
|
|
})
|
|
if err != nil {
|
|
return xerrors.Errorf("insert manual title usage message: %w", err)
|
|
}
|
|
if len(messages) != 1 {
|
|
return xerrors.Errorf("expected 1 manual title usage message, got %d", len(messages))
|
|
}
|
|
if err := tx.SoftDeleteChatMessageByID(ctx, messages[0].ID); err != nil {
|
|
return xerrors.Errorf("soft delete manual title usage message: %w", err)
|
|
}
|
|
if lockedChat.LastModelConfigID != modelConfig.ID {
|
|
if _, err := tx.UpdateChatLastModelConfigByID(ctx, database.UpdateChatLastModelConfigByIDParams{
|
|
ID: chat.ID,
|
|
LastModelConfigID: lockedChat.LastModelConfigID,
|
|
}); err != nil {
|
|
return xerrors.Errorf("restore chat model config after manual title usage: %w", err)
|
|
}
|
|
}
|
|
}
|
|
if newTitle != "" && lockedChat.Title == chat.Title && newTitle != lockedChat.Title {
|
|
updatedChat, err = tx.UpdateChatByID(ctx, database.UpdateChatByIDParams{
|
|
ID: chat.ID,
|
|
Title: newTitle,
|
|
})
|
|
if err != nil {
|
|
return xerrors.Errorf("update chat title: %w", err)
|
|
}
|
|
}
|
|
return nil
|
|
}, nil)
|
|
if err != nil {
|
|
return database.Chat{}, err
|
|
}
|
|
return updatedChat, nil
|
|
}
|
|
|
|
type chatMessage struct {
|
|
role database.ChatMessageRole
|
|
content pqtype.NullRawMessage
|
|
visibility database.ChatMessageVisibility
|
|
modelConfigID uuid.UUID
|
|
createdBy uuid.UUID
|
|
contentVersion int16
|
|
compressed bool
|
|
inputTokens int64
|
|
outputTokens int64
|
|
totalTokens int64
|
|
reasoningTokens int64
|
|
cacheCreationTokens int64
|
|
cacheReadTokens int64
|
|
contextLimit int64
|
|
totalCostMicros int64
|
|
runtimeMs int64
|
|
providerResponseID string
|
|
}
|
|
|
|
type userChatMessage struct {
|
|
chatMessage
|
|
apiKeyID string
|
|
}
|
|
|
|
func (m userChatMessage) withCreatedBy(id uuid.UUID) userChatMessage {
|
|
m.chatMessage = m.chatMessage.withCreatedBy(id)
|
|
return m
|
|
}
|
|
|
|
func (m userChatMessage) withCompressed() userChatMessage {
|
|
m.chatMessage = m.chatMessage.withCompressed()
|
|
return m
|
|
}
|
|
|
|
func newChatMessage(
|
|
role database.ChatMessageRole,
|
|
content pqtype.NullRawMessage,
|
|
visibility database.ChatMessageVisibility,
|
|
modelConfigID uuid.UUID,
|
|
contentVersion int16,
|
|
) chatMessage {
|
|
return chatMessage{
|
|
role: role,
|
|
content: content,
|
|
visibility: visibility,
|
|
modelConfigID: modelConfigID,
|
|
contentVersion: contentVersion,
|
|
}
|
|
}
|
|
|
|
func newUserChatMessage(
|
|
apiKeyID string,
|
|
content pqtype.NullRawMessage,
|
|
visibility database.ChatMessageVisibility,
|
|
modelConfigID uuid.UUID,
|
|
contentVersion int16,
|
|
) userChatMessage {
|
|
return userChatMessage{
|
|
chatMessage: newChatMessage(
|
|
database.ChatMessageRoleUser,
|
|
content,
|
|
visibility,
|
|
modelConfigID,
|
|
contentVersion,
|
|
),
|
|
apiKeyID: apiKeyID,
|
|
}
|
|
}
|
|
|
|
func (m chatMessage) withCreatedBy(id uuid.UUID) chatMessage {
|
|
m.createdBy = id
|
|
return m
|
|
}
|
|
|
|
func (m chatMessage) withCompressed() chatMessage {
|
|
m.compressed = true
|
|
return m
|
|
}
|
|
|
|
func (m chatMessage) withUsage(
|
|
inputTokens, outputTokens, totalTokens, reasoningTokens,
|
|
cacheCreationTokens, cacheReadTokens int64,
|
|
) chatMessage {
|
|
m.inputTokens = inputTokens
|
|
m.outputTokens = outputTokens
|
|
m.totalTokens = totalTokens
|
|
m.reasoningTokens = reasoningTokens
|
|
m.cacheCreationTokens = cacheCreationTokens
|
|
m.cacheReadTokens = cacheReadTokens
|
|
return m
|
|
}
|
|
|
|
func (m chatMessage) withContextLimit(limit int64) chatMessage {
|
|
m.contextLimit = limit
|
|
return m
|
|
}
|
|
|
|
func (m chatMessage) withTotalCostMicros(cost int64) chatMessage {
|
|
m.totalCostMicros = cost
|
|
return m
|
|
}
|
|
|
|
func (m chatMessage) withRuntimeMs(ms int64) chatMessage {
|
|
m.runtimeMs = ms
|
|
return m
|
|
}
|
|
|
|
func (m chatMessage) withProviderResponseID(id string) chatMessage {
|
|
m.providerResponseID = id
|
|
return m
|
|
}
|
|
|
|
func appendMessageFields(
|
|
params *database.InsertChatMessagesParams,
|
|
msg chatMessage,
|
|
apiKeyID string,
|
|
) {
|
|
params.CreatedBy = append(params.CreatedBy, msg.createdBy)
|
|
params.APIKeyID = append(params.APIKeyID, apiKeyID)
|
|
params.ModelConfigID = append(params.ModelConfigID, msg.modelConfigID)
|
|
params.Role = append(params.Role, msg.role)
|
|
params.Content = append(params.Content, string(msg.content.RawMessage))
|
|
params.ContentVersion = append(params.ContentVersion, msg.contentVersion)
|
|
params.Visibility = append(params.Visibility, msg.visibility)
|
|
params.InputTokens = append(params.InputTokens, msg.inputTokens)
|
|
params.OutputTokens = append(params.OutputTokens, msg.outputTokens)
|
|
params.TotalTokens = append(params.TotalTokens, msg.totalTokens)
|
|
params.ReasoningTokens = append(params.ReasoningTokens, msg.reasoningTokens)
|
|
params.CacheCreationTokens = append(params.CacheCreationTokens, msg.cacheCreationTokens)
|
|
params.CacheReadTokens = append(params.CacheReadTokens, msg.cacheReadTokens)
|
|
params.ContextLimit = append(params.ContextLimit, msg.contextLimit)
|
|
params.Compressed = append(params.Compressed, msg.compressed)
|
|
params.TotalCostMicros = append(params.TotalCostMicros, msg.totalCostMicros)
|
|
params.RuntimeMs = append(params.RuntimeMs, msg.runtimeMs)
|
|
params.ProviderResponseID = append(params.ProviderResponseID, msg.providerResponseID)
|
|
}
|
|
|
|
func appendChatMessage(params *database.InsertChatMessagesParams, msg chatMessage) {
|
|
if msg.role == database.ChatMessageRoleUser {
|
|
panic("developer error: use appendUserChatMessage for user-role messages")
|
|
}
|
|
appendMessageFields(params, msg, "")
|
|
}
|
|
|
|
func appendUserChatMessage(params *database.InsertChatMessagesParams, msg userChatMessage) {
|
|
appendMessageFields(params, msg.chatMessage, msg.apiKeyID)
|
|
}
|
|
|
|
// BuildSingleUserChatMessageInsertParams creates batch insert params for
|
|
// one user message, requiring an apiKeyID for AI Gateway attribution.
|
|
// BuildSingleChatMessageInsertParams creates batch insert params for one
|
|
// non-user message using the shared chat message builder.
|
|
func BuildSingleChatMessageInsertParams(
|
|
chatID uuid.UUID,
|
|
role database.ChatMessageRole,
|
|
content pqtype.NullRawMessage,
|
|
visibility database.ChatMessageVisibility,
|
|
modelConfigID uuid.UUID,
|
|
contentVersion int16,
|
|
createdBy uuid.UUID,
|
|
) database.InsertChatMessagesParams {
|
|
params := database.InsertChatMessagesParams{ //nolint:exhaustruct // Fields populated by appendChatMessage.
|
|
ChatID: chatID,
|
|
}
|
|
msg := newChatMessage(role, content, visibility, modelConfigID, contentVersion)
|
|
if createdBy != uuid.Nil {
|
|
msg = msg.withCreatedBy(createdBy)
|
|
}
|
|
if role == database.ChatMessageRoleUser {
|
|
appendMessageFields(¶ms, msg, "")
|
|
} else {
|
|
appendChatMessage(¶ms, msg)
|
|
}
|
|
return params
|
|
}
|
|
|
|
func BuildSingleUserChatMessageInsertParams(
|
|
chatID uuid.UUID,
|
|
apiKeyID string,
|
|
content pqtype.NullRawMessage,
|
|
visibility database.ChatMessageVisibility,
|
|
modelConfigID uuid.UUID,
|
|
contentVersion int16,
|
|
createdBy uuid.UUID,
|
|
) database.InsertChatMessagesParams {
|
|
params := database.InsertChatMessagesParams{ //nolint:exhaustruct // Fields populated by appendUserChatMessage.
|
|
ChatID: chatID,
|
|
}
|
|
msg := newUserChatMessage(apiKeyID, content, visibility, modelConfigID, contentVersion)
|
|
if createdBy != uuid.Nil {
|
|
msg = msg.withCreatedBy(createdBy)
|
|
}
|
|
appendUserChatMessage(¶ms, msg)
|
|
return params
|
|
}
|
|
|
|
// Config configures a chat processor.
|
|
type Config struct {
|
|
Logger slog.Logger
|
|
Database database.Store
|
|
ReplicaID uuid.UUID
|
|
SubscribeFn SubscribeFn
|
|
PendingChatAcquireInterval time.Duration
|
|
MaxChatsPerAcquire int32
|
|
InFlightChatStaleAfter time.Duration
|
|
ChatHeartbeatInterval time.Duration
|
|
AgentConn AgentConnFunc
|
|
AgentInactiveDisconnectTimeout time.Duration
|
|
InstructionLookupTimeout time.Duration
|
|
CreateWorkspace chattool.CreateWorkspaceFn
|
|
StartWorkspace chattool.StartWorkspaceFn
|
|
StopWorkspace chattool.StopWorkspaceFn
|
|
Pubsub pubsub.Pubsub
|
|
ProviderAPIKeys chatprovider.ProviderAPIKeys
|
|
AllowBYOK bool
|
|
AllowBYOKSet bool
|
|
AlwaysEnableDebugLogs bool
|
|
WebpushDispatcher webpush.Dispatcher
|
|
UsageTracker *workspacestats.UsageTracker
|
|
Clock quartz.Clock
|
|
AIBridgeTransportFactory *atomic.Pointer[aibridge.TransportFactory]
|
|
AIGatewayRoutingEnabled bool
|
|
|
|
PrometheusRegistry prometheus.Registerer
|
|
|
|
// OIDCTokenSource resolves the calling user's OIDC access
|
|
// token for MCP servers configured with auth_type=user_oidc.
|
|
// May be nil if the deployment has no OIDC provider; servers
|
|
// using user_oidc will then send no Authorization header.
|
|
OIDCTokenSource mcpclient.UserOIDCTokenSource
|
|
|
|
NotificationsEnqueuer notifications.Enqueuer
|
|
Auditor *atomic.Pointer[audit.Auditor]
|
|
}
|
|
|
|
// New creates a new chat processor. The processor polls for pending
|
|
// chats and processes them. It is the caller's responsibility to call Close
|
|
// on the returned instance.
|
|
func New(cfg Config) *Server {
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
|
|
pendingChatAcquireInterval := cfg.PendingChatAcquireInterval
|
|
if pendingChatAcquireInterval == 0 {
|
|
pendingChatAcquireInterval = DefaultPendingChatAcquireInterval
|
|
}
|
|
|
|
inFlightChatStaleAfter := cfg.InFlightChatStaleAfter
|
|
if inFlightChatStaleAfter == 0 {
|
|
inFlightChatStaleAfter = DefaultInFlightChatStaleAfter
|
|
}
|
|
|
|
maxChatsPerAcquire := cfg.MaxChatsPerAcquire
|
|
if maxChatsPerAcquire <= 0 {
|
|
maxChatsPerAcquire = DefaultMaxChatsPerAcquire
|
|
}
|
|
|
|
chatHeartbeatInterval := cfg.ChatHeartbeatInterval
|
|
if chatHeartbeatInterval == 0 {
|
|
chatHeartbeatInterval = DefaultChatHeartbeatInterval
|
|
}
|
|
|
|
clk := cfg.Clock
|
|
if clk == nil {
|
|
clk = quartz.NewReal()
|
|
}
|
|
|
|
if cfg.Pubsub == nil {
|
|
panic("chatd: Pubsub is nil")
|
|
}
|
|
ps := cfg.Pubsub
|
|
|
|
notificationsEnqueuer := cfg.NotificationsEnqueuer
|
|
if notificationsEnqueuer == nil {
|
|
notificationsEnqueuer = notifications.NewNoopEnqueuer()
|
|
}
|
|
|
|
instructionLookupTimeout := cfg.InstructionLookupTimeout
|
|
if instructionLookupTimeout == 0 {
|
|
instructionLookupTimeout = homeInstructionLookupTimeout
|
|
}
|
|
|
|
workerID := cfg.ReplicaID
|
|
if workerID == uuid.Nil {
|
|
workerID = uuid.New()
|
|
}
|
|
|
|
allowBYOK := true
|
|
if cfg.AllowBYOKSet {
|
|
allowBYOK = cfg.AllowBYOK
|
|
}
|
|
|
|
p := &Server{
|
|
cancel: cancel,
|
|
db: cfg.Database,
|
|
workerID: workerID,
|
|
logger: cfg.Logger.Named("processor"),
|
|
subscribeFn: cfg.SubscribeFn,
|
|
agentConnFn: cfg.AgentConn,
|
|
agentInactiveDisconnectTimeout: cfg.AgentInactiveDisconnectTimeout,
|
|
dialTimeout: defaultDialTimeout,
|
|
instructionLookupTimeout: instructionLookupTimeout,
|
|
createWorkspaceFn: cfg.CreateWorkspace,
|
|
startWorkspaceFn: cfg.StartWorkspace,
|
|
stopWorkspaceFn: cfg.StopWorkspace,
|
|
pubsub: ps,
|
|
webpushDispatcher: cfg.WebpushDispatcher,
|
|
providerAPIKeys: cfg.ProviderAPIKeys,
|
|
allowBYOK: allowBYOK,
|
|
oidcTokenSource: cfg.OIDCTokenSource,
|
|
debugSvcFactory: func() *chatdebug.Service {
|
|
debugSvc := chatdebug.NewService(
|
|
cfg.Database,
|
|
cfg.Logger.Named("chatdebug"),
|
|
ps,
|
|
chatdebug.WithAlwaysEnable(cfg.AlwaysEnableDebugLogs),
|
|
)
|
|
// Debug runs do not heartbeat during model streams; their
|
|
// updated_at is only touched on step/run completion. Use a
|
|
// longer stale window so long-running turns are not falsely
|
|
// finalized as stale while still executing.
|
|
debugSvc.SetStaleAfter(inFlightChatStaleAfter * 3)
|
|
return debugSvc
|
|
},
|
|
aibridgeTransportFactory: cfg.AIBridgeTransportFactory,
|
|
aiGatewayRoutingEnabled: cfg.AIGatewayRoutingEnabled,
|
|
pendingChatAcquireInterval: pendingChatAcquireInterval,
|
|
maxChatsPerAcquire: maxChatsPerAcquire,
|
|
inFlightChatStaleAfter: inFlightChatStaleAfter,
|
|
chatHeartbeatInterval: chatHeartbeatInterval,
|
|
usageTracker: cfg.UsageTracker,
|
|
clock: clk,
|
|
recordingSem: make(chan struct{}, maxConcurrentRecordingUploads),
|
|
}
|
|
var chatAutoArchiveRecords prometheus.Counter
|
|
if cfg.PrometheusRegistry != nil {
|
|
p.metrics = chatloop.NewMetrics(cfg.PrometheusRegistry)
|
|
cfg.PrometheusRegistry.MustRegister(&streamStateCollector{server: p})
|
|
chatAutoArchiveRecords = prometheus.NewCounter(prometheus.CounterOpts{
|
|
Namespace: "coderd",
|
|
Subsystem: "chat_auto_archive",
|
|
Name: "records_archived_total",
|
|
Help: "Total number of chats archived by the auto-archive job (counting both roots and cascaded children).",
|
|
})
|
|
cfg.PrometheusRegistry.MustRegister(chatAutoArchiveRecords)
|
|
} else {
|
|
p.metrics = chatloop.NopMetrics()
|
|
}
|
|
p.messagePartBuffer = messagepartbuffer.New(messagepartbuffer.Options{Clock: clk})
|
|
chatWorker, err := newChatWorker(p, chatWorkerOptions{
|
|
WorkerID: workerID,
|
|
Store: cfg.Database,
|
|
Pubsub: ps,
|
|
Logger: cfg.Logger.Named("chatworker"),
|
|
Clock: clk,
|
|
MessagePartBuffer: p.messagePartBuffer,
|
|
AcquisitionInterval: pendingChatAcquireInterval,
|
|
AcquisitionBatchSize: maxChatsPerAcquire,
|
|
HeartbeatInterval: chatHeartbeatInterval,
|
|
HeartbeatStaleSeconds: int32(inFlightChatStaleAfter.Seconds()),
|
|
NotificationsEnqueuer: notificationsEnqueuer,
|
|
Auditor: cfg.Auditor,
|
|
AutoArchiveRecords: chatAutoArchiveRecords,
|
|
})
|
|
if err != nil {
|
|
panic("chatd: create chat worker: " + err.Error())
|
|
}
|
|
p.chatWorker = chatWorker
|
|
|
|
//nolint:gocritic // The chat processor uses a scoped chatd context.
|
|
ctx = dbauthz.AsChatd(ctx)
|
|
|
|
p.configCache = newChatConfigCache(ctx, cfg.Database, clk)
|
|
cancelConfigSub, err := p.pubsub.SubscribeWithErr(
|
|
coderdpubsub.ChatConfigEventChannel,
|
|
coderdpubsub.HandleChatConfigEvent(func(ctx context.Context, ev coderdpubsub.ChatConfigEvent, err error) {
|
|
if err != nil {
|
|
p.logger.Warn(ctx, "chat config event error", slog.Error(err))
|
|
return
|
|
}
|
|
switch ev.Kind {
|
|
case coderdpubsub.ChatConfigEventProviders:
|
|
p.configCache.InvalidateProviders()
|
|
case coderdpubsub.ChatConfigEventModelConfig:
|
|
p.configCache.InvalidateModelConfig(ev.EntityID)
|
|
case coderdpubsub.ChatConfigEventUserPrompt:
|
|
p.configCache.InvalidateUserPrompt(ev.EntityID)
|
|
case coderdpubsub.ChatConfigEventAdvisorConfig:
|
|
p.configCache.InvalidateAdvisorConfig()
|
|
}
|
|
}),
|
|
)
|
|
if err != nil {
|
|
p.logger.Error(ctx, "subscribe to chat config events", slog.Error(err))
|
|
} else {
|
|
p.configCacheUnsubscribe = cancelConfigSub
|
|
}
|
|
|
|
p.ctx = ctx
|
|
|
|
// Spawn background goroutines that all servers need.
|
|
p.wg.Go(func() { p.streamJanitorLoop(ctx) })
|
|
|
|
return p
|
|
}
|
|
|
|
// Start runs the background acquire/wake loop that picks up
|
|
// pending chats and processes them. Callers that want a passive
|
|
// server (e.g. tests) can skip this call; heartbeat, stream
|
|
// janitor, and stale recovery still run.
|
|
func (p *Server) Start() *Server {
|
|
if p.chatWorker != nil {
|
|
if err := p.chatWorker.Start(p.ctx); err != nil {
|
|
p.logger.Error(p.ctx, "failed to start chat worker", slog.Error(err))
|
|
}
|
|
}
|
|
return p
|
|
}
|
|
|
|
// getCachedDurableMessages returns cached durable messages with IDs
|
|
// greater than afterID. Returns nil when the cache has no relevant
|
|
// entries.
|
|
func (p *Server) getCachedDurableMessages(
|
|
chatID uuid.UUID,
|
|
afterID int64,
|
|
) []codersdk.ChatStreamEvent {
|
|
state := p.getOrCreateStreamState(chatID)
|
|
state.mu.Lock()
|
|
defer state.mu.Unlock()
|
|
|
|
if afterID < state.durableEvictedBefore {
|
|
return nil
|
|
}
|
|
|
|
var result []codersdk.ChatStreamEvent
|
|
for _, event := range state.durableMessages {
|
|
if event.Message != nil && event.Message.ID > afterID {
|
|
result = append(result, event)
|
|
}
|
|
}
|
|
return result
|
|
}
|
|
|
|
// snapshotBufferLocked returns the buffered message_part events that
|
|
// the caller should receive in their initial snapshot.
|
|
//
|
|
// Parts whose committedMessageID != 0 are dropped: those parts were
|
|
// claimed by a durable assistant message that the subscriber will
|
|
// receive through a different channel (REST snapshot, the initial DB
|
|
// query in SubscribeAuthorized, or pubsub catch-up). Delivering them
|
|
// here would render the same content twice on the client, once in the
|
|
// streaming UI and once as a durable message.
|
|
//
|
|
// Every caller receives the same view: in-progress parts are always
|
|
// delivered and committed parts are always dropped, regardless of
|
|
// cursor or relay sentinel. This keeps the buffer free of duplicate
|
|
// work for every subscriber, including cross-replica relay
|
|
// subscribers whose user-facing peers receive the durable message
|
|
// via pubsub.
|
|
//
|
|
// The caller must hold the per-chat stream state lock.
|
|
func snapshotBufferLocked(buffer []bufferedStreamPart) []codersdk.ChatStreamEvent {
|
|
if len(buffer) == 0 {
|
|
return nil
|
|
}
|
|
snapshot := make([]codersdk.ChatStreamEvent, 0, len(buffer))
|
|
for _, part := range buffer {
|
|
if part.committedMessageID != 0 {
|
|
continue
|
|
}
|
|
snapshot = append(snapshot, part.event)
|
|
}
|
|
return snapshot
|
|
}
|
|
|
|
// subscribeToStream registers a subscriber to the per-chat in-memory
|
|
// stream and returns a snapshot of currently in-progress message_part
|
|
// events plus the current retry phase, the live subscriber channel,
|
|
// and a cancel func.
|
|
//
|
|
// Parts that were claimed by a committed durable assistant message
|
|
// (committedMessageID != 0) are excluded from the snapshot. The
|
|
// subscriber will receive those durable messages through the REST
|
|
// snapshot, the initial DB query in SubscribeAuthorized, or pubsub,
|
|
// so re-delivering their constituent parts here would render the
|
|
// same content twice.
|
|
func (p *Server) subscribeToStream(chatID uuid.UUID) (
|
|
[]codersdk.ChatStreamEvent,
|
|
*codersdk.ChatStreamRetry,
|
|
<-chan codersdk.ChatStreamEvent,
|
|
func(),
|
|
) {
|
|
state := p.getOrCreateStreamState(chatID)
|
|
state.mu.Lock()
|
|
snapshot := snapshotBufferLocked(state.buffer)
|
|
var currentRetry *codersdk.ChatStreamRetry
|
|
if state.currentRetry != nil {
|
|
retryCopy := *state.currentRetry
|
|
currentRetry = &retryCopy
|
|
}
|
|
id := uuid.New()
|
|
ch := make(chan codersdk.ChatStreamEvent, 128)
|
|
state.subscribers[id] = ch
|
|
state.mu.Unlock()
|
|
|
|
cancel := func() {
|
|
state.mu.Lock()
|
|
// Remove the subscriber but do not close the channel.
|
|
// publishToStream copies subscriber references under
|
|
// the per-chat lock then sends outside; closing here
|
|
// races with that send and can panic. The channel
|
|
// becomes unreachable once removed and will be GC'd.
|
|
delete(state.subscribers, id)
|
|
p.cleanupStreamIfIdle(chatID, state)
|
|
state.mu.Unlock()
|
|
}
|
|
|
|
return snapshot, currentRetry, ch, cancel
|
|
}
|
|
|
|
// getOrCreateStreamState returns the per-chat stream state,
|
|
// creating one atomically if it doesn't exist. The returned
|
|
// state has its own mutex — callers must lock state.mu for
|
|
// access.
|
|
func (p *Server) getOrCreateStreamState(chatID uuid.UUID) *chatStreamState {
|
|
if val, ok := p.chatStreams.Load(chatID); ok {
|
|
state, _ := val.(*chatStreamState)
|
|
return state
|
|
}
|
|
val, _ := p.chatStreams.LoadOrStore(chatID, &chatStreamState{
|
|
subscribers: make(map[uuid.UUID]chan codersdk.ChatStreamEvent),
|
|
})
|
|
state, _ := val.(*chatStreamState)
|
|
return state
|
|
}
|
|
|
|
// cleanupStreamIfIdle removes the chat entry from the sync.Map when
|
|
// there are no subscribers, the stream is not buffering, and any
|
|
// grace period for late-connecting relay subscribers has elapsed. If
|
|
// the grace window is still open it returns without rescheduling.
|
|
// streamJanitorLoop is the backstop that re-checks on a timer.
|
|
//
|
|
// The caller must hold state.mu. The state pointer may have been
|
|
// captured outside this lock (sync.Map.Load or Range); we use
|
|
// CompareAndDelete so a stale pointer cannot evict a fresh entry
|
|
// installed by a racing getOrCreateStreamState. Returns true
|
|
// if the state was deleted, false otherwise.
|
|
func (p *Server) cleanupStreamIfIdle(chatID uuid.UUID, state *chatStreamState) bool {
|
|
if state.buffering || len(state.subscribers) > 0 {
|
|
return false
|
|
}
|
|
// Keep stream state alive during the grace period so
|
|
// late-connecting cross-replica relay subscribers can
|
|
// register against this chat before GC.
|
|
if !state.bufferRetainedAt.IsZero() &&
|
|
p.clock.Now().Before(state.bufferRetainedAt.Add(bufferRetainGracePeriod)) {
|
|
return false
|
|
}
|
|
if !p.chatStreams.CompareAndDelete(chatID, state) {
|
|
return false
|
|
}
|
|
p.workspaceMCPToolsCache.Delete(chatID)
|
|
return true
|
|
}
|
|
|
|
// streamJanitorLoop periodically reaps idle chat stream states whose
|
|
// grace period has expired. It is the backstop for the grace-window
|
|
// early-return in cleanupStreamIfIdle; without it, a subscriber that
|
|
// detaches inside grace (the common enterprise relay-drain case,
|
|
// relayDrainTimeout = 200ms vs. 5s grace) pins the state forever.
|
|
func (p *Server) streamJanitorLoop(ctx context.Context) {
|
|
ticker := p.clock.NewTicker(streamJanitorInterval, "chatd", "stream-janitor")
|
|
defer ticker.Stop()
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
case <-ticker.C:
|
|
p.safeSweepIdleStreams(ctx)
|
|
}
|
|
}
|
|
}
|
|
|
|
// safeSweepIdleStreams runs sweepIdleStreams under a panic recovery
|
|
// so an unexpected panic in the sweep cannot kill the janitor
|
|
// goroutine and silently reintroduce the very leak it exists to
|
|
// prevent. The next tick retries.
|
|
func (p *Server) safeSweepIdleStreams(ctx context.Context) {
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
p.logger.Error(ctx, "stream janitor sweep panicked, will retry next tick",
|
|
slog.F("panic", r))
|
|
}
|
|
}()
|
|
p.sweepIdleStreams()
|
|
}
|
|
|
|
// sweepIdleStreams iterates chatStreams once and delegates each entry
|
|
// to cleanupStreamIfIdle. Range may skip entries that become reapable
|
|
// concurrently. Any such entry is reaped on the next tick.
|
|
func (p *Server) sweepIdleStreams() {
|
|
var reaped atomic.Int64
|
|
defer func() {
|
|
if count := reaped.Load(); count > 0 {
|
|
p.logger.Info(context.Background(), "reaped idle chat streams", slog.F("count", count))
|
|
}
|
|
}()
|
|
p.chatStreams.Range(func(key, value any) bool {
|
|
chatID, ok := key.(uuid.UUID)
|
|
if !ok {
|
|
return true
|
|
}
|
|
state, ok := value.(*chatStreamState)
|
|
if !ok {
|
|
return true
|
|
}
|
|
// guard against any panic from cleanupStreamIfIdle locking state.mu for all time
|
|
func() {
|
|
state.mu.Lock()
|
|
defer state.mu.Unlock()
|
|
if p.cleanupStreamIfIdle(chatID, state) {
|
|
reaped.Add(1)
|
|
}
|
|
}()
|
|
return true
|
|
})
|
|
}
|
|
|
|
// streamSubscriberControlFetchContext keeps a control-path lookup tied to the
|
|
// requesting subscriber while applying a fallback timeout when the caller has
|
|
// no deadline.
|
|
func streamSubscriberControlFetchContext(ctx context.Context) (context.Context, context.CancelFunc) {
|
|
if _, ok := ctx.Deadline(); ok {
|
|
return ctx, func() {}
|
|
}
|
|
return context.WithTimeout(ctx, chatStreamControlFetchTimeout)
|
|
}
|
|
|
|
func subscribeWithInitialError(chatID uuid.UUID, message string) (
|
|
[]codersdk.ChatStreamEvent,
|
|
<-chan codersdk.ChatStreamEvent,
|
|
func(),
|
|
bool,
|
|
) {
|
|
events := make(chan codersdk.ChatStreamEvent)
|
|
close(events)
|
|
return []codersdk.ChatStreamEvent{{
|
|
Type: codersdk.ChatStreamEventTypeError,
|
|
ChatID: chatID,
|
|
Error: &codersdk.ChatError{Message: message},
|
|
}}, events, func() {}, true
|
|
}
|
|
|
|
func (p *Server) Subscribe(
|
|
ctx context.Context,
|
|
chatID uuid.UUID,
|
|
requestHeader http.Header,
|
|
afterMessageID int64,
|
|
) (
|
|
[]codersdk.ChatStreamEvent,
|
|
<-chan codersdk.ChatStreamEvent,
|
|
func(),
|
|
bool,
|
|
) {
|
|
if p == nil {
|
|
return nil, nil, nil, false
|
|
}
|
|
|
|
chat, err := p.db.GetChatByID(ctx, chatID)
|
|
if err != nil {
|
|
if dbauthz.IsNotAuthorizedError(err) {
|
|
return nil, nil, nil, false
|
|
}
|
|
p.logger.Warn(ctx, "failed to load chat for stream subscription",
|
|
slog.F("chat_id", chatID),
|
|
slog.Error(err),
|
|
)
|
|
return subscribeWithInitialError(chatID, "failed to load initial snapshot")
|
|
}
|
|
return p.SubscribeAuthorized(ctx, chat, requestHeader, afterMessageID)
|
|
}
|
|
|
|
// SubscribeAuthorized subscribes an already-authorized chat to merged stream
|
|
// updates. The passed chat row proves authorization, but SubscribeAuthorized
|
|
// still reloads the chat after the stream subscriptions are armed so the
|
|
// initial status and relay setup use fresh state.
|
|
func (p *Server) SubscribeAuthorized(
|
|
ctx context.Context,
|
|
chat database.Chat,
|
|
requestHeader http.Header,
|
|
afterMessageID int64,
|
|
) (
|
|
[]codersdk.ChatStreamEvent,
|
|
<-chan codersdk.ChatStreamEvent,
|
|
func(),
|
|
bool,
|
|
) {
|
|
if p == nil {
|
|
return nil, nil, nil, false
|
|
}
|
|
chatID := chat.ID
|
|
|
|
// Subscribe to the local stream for message_parts and same-replica
|
|
// persisted messages. Capture the current retry phase under the same
|
|
// lock so the transient snapshot and subscriber registration reflect
|
|
// a single moment in time.
|
|
localSnapshot, localRetry, localParts, localCancel := p.subscribeToStream(chatID)
|
|
|
|
// Merge all event sources.
|
|
mergedCtx, mergedCancel := context.WithCancel(ctx)
|
|
mergedEvents := make(chan codersdk.ChatStreamEvent, 128)
|
|
|
|
var allCancels []func()
|
|
allCancels = append(allCancels, localCancel)
|
|
|
|
// Subscribe to pubsub for durable and structured control
|
|
// events (status, messages, queue updates, retry, errors).
|
|
// If the subscription cannot be established, deliver all local
|
|
// events.
|
|
//
|
|
// This MUST happen before the DB queries below so that any
|
|
// notification published between the query and the subscription
|
|
// is not lost (subscribe-first-then-query pattern).
|
|
notifyCh := make(chan coderdpubsub.ChatStreamNotifyMessage, 10)
|
|
errCh := make(chan error, 1)
|
|
listener := func(_ context.Context, message []byte, listenErr error) {
|
|
if listenErr != nil {
|
|
select {
|
|
case <-mergedCtx.Done():
|
|
case errCh <- listenErr:
|
|
}
|
|
return
|
|
}
|
|
var notify coderdpubsub.ChatStreamNotifyMessage
|
|
if unmarshalErr := json.Unmarshal(message, ¬ify); unmarshalErr != nil {
|
|
select {
|
|
case <-mergedCtx.Done():
|
|
case errCh <- xerrors.Errorf("unmarshal chat stream notify: %w", unmarshalErr):
|
|
}
|
|
return
|
|
}
|
|
select {
|
|
case <-mergedCtx.Done():
|
|
case notifyCh <- notify:
|
|
}
|
|
}
|
|
|
|
if pubsubCancel, pubsubErr := p.pubsub.SubscribeWithErr(
|
|
coderdpubsub.ChatStreamNotifyChannel(chatID),
|
|
listener,
|
|
); pubsubErr == nil {
|
|
allCancels = append(allCancels, pubsubCancel)
|
|
} else {
|
|
p.logger.Warn(ctx, "failed to subscribe to chat stream notifications",
|
|
slog.F("chat_id", chatID),
|
|
slog.Error(pubsubErr),
|
|
)
|
|
}
|
|
|
|
cancel := func() {
|
|
mergedCancel()
|
|
for _, cancelFn := range allCancels {
|
|
if cancelFn != nil {
|
|
cancelFn()
|
|
}
|
|
}
|
|
}
|
|
|
|
// Re-read the chat after the local/pubsub subscriptions are active so
|
|
// the initial status event and any enterprise relay setup use fresh
|
|
// state instead of the middleware-loaded row.
|
|
refreshCtx, refreshCancel := streamSubscriberControlFetchContext(ctx)
|
|
snapshotChat, err := func() (database.Chat, error) {
|
|
defer refreshCancel()
|
|
//nolint:gocritic // SubscribeAuthorized already validated the
|
|
// caller; this refresh only loads the latest status/worker for
|
|
// the already-authorized stream subscription.
|
|
return p.db.GetChatByID(dbauthz.AsChatd(refreshCtx), chatID)
|
|
}()
|
|
if err != nil {
|
|
p.logger.Warn(ctx, "failed to refresh chat for stream subscription; using stale state",
|
|
slog.F("chat_id", chatID),
|
|
slog.Error(err),
|
|
)
|
|
snapshotChat = chat
|
|
}
|
|
|
|
// Build initial snapshot synchronously. The pubsub subscription
|
|
// is already active so no notifications can be lost during this
|
|
// window.
|
|
initialSnapshot := make([]codersdk.ChatStreamEvent, 0)
|
|
delivered := map[int64]struct{}{}
|
|
// Add local same-replica message_parts to the snapshot. Retry comes
|
|
// from state.currentRetry, not the event buffer, so late joiners see
|
|
// only the latest phase rather than a stale buffered retry event.
|
|
for _, event := range localSnapshot {
|
|
if event.Type == codersdk.ChatStreamEventTypeMessagePart {
|
|
initialSnapshot = append(initialSnapshot, event)
|
|
}
|
|
}
|
|
|
|
var retryEvent *codersdk.ChatStreamEvent
|
|
if localRetry != nil {
|
|
retryEvent = &codersdk.ChatStreamEvent{
|
|
Type: codersdk.ChatStreamEventTypeRetry,
|
|
ChatID: chatID,
|
|
Retry: localRetry,
|
|
}
|
|
}
|
|
|
|
// Load initial messages from DB. When afterMessageID > 0 the
|
|
// caller already has messages up to that ID (e.g. from the REST
|
|
// endpoint), so we only fetch newer ones to avoid sending
|
|
// duplicate data.
|
|
messages, err := p.db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
|
ChatID: chatID,
|
|
AfterID: afterMessageID,
|
|
})
|
|
if err != nil {
|
|
p.logger.Error(ctx, "failed to load initial chat messages",
|
|
slog.Error(err),
|
|
slog.F("chat_id", chatID),
|
|
)
|
|
initialSnapshot = append(initialSnapshot, codersdk.ChatStreamEvent{
|
|
Type: codersdk.ChatStreamEventTypeError,
|
|
ChatID: chatID,
|
|
Error: &codersdk.ChatError{Message: "failed to load initial snapshot"},
|
|
})
|
|
} else {
|
|
for _, msg := range messages {
|
|
sdkMsg := db2sdk.ChatMessage(msg)
|
|
initialSnapshot = append(initialSnapshot, codersdk.ChatStreamEvent{
|
|
Type: codersdk.ChatStreamEventTypeMessage,
|
|
ChatID: chatID,
|
|
Message: &sdkMsg,
|
|
})
|
|
delivered[msg.ID] = struct{}{}
|
|
}
|
|
}
|
|
|
|
// Load initial queue. Queue snapshots are intentionally not
|
|
// singleflighted because a chat-scoped key cannot distinguish the
|
|
// pre- and post-notification queue state.
|
|
queueCtx, queueCancel := streamSubscriberControlFetchContext(ctx)
|
|
queued, err := p.db.GetChatQueuedMessages(queueCtx, chatID)
|
|
queueCancel()
|
|
if err != nil {
|
|
p.logger.Error(ctx, "failed to load initial queued messages",
|
|
slog.Error(err),
|
|
slog.F("chat_id", chatID),
|
|
)
|
|
initialSnapshot = append(initialSnapshot, codersdk.ChatStreamEvent{
|
|
Type: codersdk.ChatStreamEventTypeError,
|
|
ChatID: chatID,
|
|
Error: &codersdk.ChatError{Message: "failed to load initial snapshot"},
|
|
})
|
|
} else if len(queued) > 0 {
|
|
initialSnapshot = append(initialSnapshot, codersdk.ChatStreamEvent{
|
|
Type: codersdk.ChatStreamEventTypeQueueUpdate,
|
|
ChatID: chatID,
|
|
QueuedMessages: db2sdk.ChatQueuedMessages(queued),
|
|
})
|
|
}
|
|
|
|
// Include the current chat status in the snapshot so the
|
|
// frontend can gate message_part processing correctly from
|
|
// the very first batch, without waiting for a separate REST
|
|
// query.
|
|
statusEvent := codersdk.ChatStreamEvent{
|
|
Type: codersdk.ChatStreamEventTypeStatus,
|
|
ChatID: chatID,
|
|
Status: &codersdk.ChatStreamStatus{
|
|
Status: codersdk.ChatStatus(snapshotChat.Status),
|
|
},
|
|
}
|
|
// Prepend so the frontend sees the current stream phases
|
|
// before any message_part events.
|
|
prefix := []codersdk.ChatStreamEvent{statusEvent}
|
|
if retryEvent != nil {
|
|
prefix = append(prefix, *retryEvent)
|
|
}
|
|
initialSnapshot = append(prefix, initialSnapshot...)
|
|
|
|
// Track the highest durable message ID delivered to this subscriber,
|
|
// whether it came from the initial DB snapshot, the same-replica local
|
|
// stream, or a later DB/cache catch-up.
|
|
lastMessageID := afterMessageID
|
|
if len(messages) > 0 {
|
|
lastMessageID = messages[len(messages)-1].ID
|
|
}
|
|
|
|
// When an enterprise SubscribeFn is provided, call it to get relay events
|
|
// (message_parts from remote replicas). OSS owns pubsub subscription,
|
|
// message catch-up, queue updates, and status forwarding; enterprise only
|
|
// manages relay dialing.
|
|
var relayEvents <-chan codersdk.ChatStreamEvent
|
|
var statusNotifications chan StatusNotification
|
|
if p.subscribeFn != nil {
|
|
statusNotifications = make(chan StatusNotification, 10)
|
|
relayEvents = p.subscribeFn(mergedCtx, SubscribeFnParams{
|
|
ChatID: chatID,
|
|
Chat: snapshotChat,
|
|
WorkerID: p.workerID,
|
|
StatusNotifications: statusNotifications,
|
|
RequestHeader: requestHeader,
|
|
DB: p.db,
|
|
Logger: p.logger,
|
|
})
|
|
}
|
|
// hasPubsubSubscription is only true when we actually subscribed
|
|
// successfully above (allCancels will contain the pubsub
|
|
// cancel func in that case).
|
|
hasPubsubSubscription := len(allCancels) > 1
|
|
|
|
//nolint:nestif
|
|
go func() {
|
|
defer close(mergedEvents)
|
|
if statusNotifications != nil {
|
|
defer close(statusNotifications)
|
|
}
|
|
for {
|
|
select {
|
|
case <-mergedCtx.Done():
|
|
return
|
|
case psErr := <-errCh:
|
|
p.logger.Error(mergedCtx, "chat stream pubsub error",
|
|
slog.F("chat_id", chatID),
|
|
slog.Error(psErr),
|
|
)
|
|
select {
|
|
case mergedEvents <- codersdk.ChatStreamEvent{
|
|
Type: codersdk.ChatStreamEventTypeError,
|
|
ChatID: chatID,
|
|
Error: &codersdk.ChatError{
|
|
Message: psErr.Error(),
|
|
},
|
|
}:
|
|
case <-mergedCtx.Done():
|
|
}
|
|
return
|
|
case notify := <-notifyCh:
|
|
// Marker for ENG-2645: subscriber received pubsub notify.
|
|
p.logger.Debug(mergedCtx, "stream subscriber received notify",
|
|
slog.F("chat_id", chatID),
|
|
slog.F("after_message_id", notify.AfterMessageID),
|
|
slog.F("status", notify.Status),
|
|
slog.F("queue_update", notify.QueueUpdate),
|
|
slog.F("last_message_id", lastMessageID),
|
|
)
|
|
if notify.AfterMessageID > 0 || notify.FullRefresh {
|
|
if notify.FullRefresh {
|
|
lastMessageID = 0
|
|
clear(delivered)
|
|
}
|
|
var (
|
|
deliveredCount int
|
|
source string
|
|
)
|
|
// Notifies can arrive out of order. Rescan from
|
|
// min(AfterMessageID, lastMessageID) to cover the gap,
|
|
// floored at afterMessageID to respect the subscription
|
|
// boundary. The delivered set deduplicates.
|
|
lookupAfter := lastMessageID
|
|
if !notify.FullRefresh {
|
|
lookupAfter = max(afterMessageID, min(notify.AfterMessageID, lastMessageID))
|
|
}
|
|
cached := p.getCachedDurableMessages(chatID, lookupAfter)
|
|
if !notify.FullRefresh && len(cached) > 0 {
|
|
for _, event := range cached {
|
|
if event.Message == nil {
|
|
continue
|
|
}
|
|
if _, ok := delivered[event.Message.ID]; ok {
|
|
continue
|
|
}
|
|
select {
|
|
case <-mergedCtx.Done():
|
|
return
|
|
case mergedEvents <- event:
|
|
}
|
|
delivered[event.Message.ID] = struct{}{}
|
|
if event.Message.ID > lastMessageID {
|
|
lastMessageID = event.Message.ID
|
|
}
|
|
deliveredCount++
|
|
source = "cache"
|
|
}
|
|
}
|
|
// DB pass picks up cross-replica messages the local cache
|
|
// cannot have. Delivered set dedupes against the cache pass.
|
|
newMessages, msgErr := p.db.GetChatMessagesByChatID(mergedCtx, database.GetChatMessagesByChatIDParams{
|
|
ChatID: chatID,
|
|
AfterID: lookupAfter,
|
|
})
|
|
if msgErr != nil {
|
|
p.logger.Warn(mergedCtx, "failed to get chat messages after pubsub notification",
|
|
slog.F("chat_id", chatID),
|
|
slog.Error(msgErr),
|
|
)
|
|
} else {
|
|
for _, msg := range newMessages {
|
|
if msg.ID <= lookupAfter {
|
|
continue
|
|
}
|
|
if _, ok := delivered[msg.ID]; ok {
|
|
continue
|
|
}
|
|
sdkMsg := db2sdk.ChatMessage(msg)
|
|
select {
|
|
case <-mergedCtx.Done():
|
|
return
|
|
case mergedEvents <- codersdk.ChatStreamEvent{
|
|
Type: codersdk.ChatStreamEventTypeMessage,
|
|
ChatID: chatID,
|
|
Message: &sdkMsg,
|
|
}:
|
|
}
|
|
delivered[msg.ID] = struct{}{}
|
|
if msg.ID > lastMessageID {
|
|
lastMessageID = msg.ID
|
|
}
|
|
deliveredCount++
|
|
switch source {
|
|
case "":
|
|
source = "db"
|
|
case "cache":
|
|
source = "cache+db"
|
|
}
|
|
}
|
|
}
|
|
// Marker for ENG-2645: subscriber delivered durable messages.
|
|
p.logger.Debug(mergedCtx, "stream subscriber delivered messages",
|
|
slog.F("chat_id", chatID),
|
|
slog.F("after_message_id", notify.AfterMessageID),
|
|
slog.F("lookup_after", lookupAfter),
|
|
slog.F("source", source),
|
|
slog.F("delivered_count", deliveredCount),
|
|
slog.F("last_message_id", lastMessageID),
|
|
)
|
|
}
|
|
if notify.Status != "" {
|
|
status := database.ChatStatus(notify.Status)
|
|
select {
|
|
case <-mergedCtx.Done():
|
|
return
|
|
case mergedEvents <- codersdk.ChatStreamEvent{
|
|
Type: codersdk.ChatStreamEventTypeStatus,
|
|
ChatID: chatID,
|
|
Status: &codersdk.ChatStreamStatus{Status: codersdk.ChatStatus(status)},
|
|
}:
|
|
}
|
|
// Notify enterprise relay manager if present.
|
|
if statusNotifications != nil {
|
|
workerID := uuid.Nil
|
|
if notify.WorkerID != "" {
|
|
if parsed, parseErr := uuid.Parse(notify.WorkerID); parseErr == nil {
|
|
workerID = parsed
|
|
}
|
|
}
|
|
select {
|
|
case statusNotifications <- StatusNotification{Status: status, WorkerID: workerID}:
|
|
case <-mergedCtx.Done():
|
|
return
|
|
}
|
|
}
|
|
}
|
|
if notify.Retry != nil {
|
|
select {
|
|
case <-mergedCtx.Done():
|
|
return
|
|
case mergedEvents <- codersdk.ChatStreamEvent{
|
|
Type: codersdk.ChatStreamEventTypeRetry,
|
|
ChatID: chatID,
|
|
Retry: notify.Retry,
|
|
}:
|
|
}
|
|
}
|
|
if notify.ErrorPayload != nil {
|
|
select {
|
|
case <-mergedCtx.Done():
|
|
return
|
|
case mergedEvents <- codersdk.ChatStreamEvent{
|
|
Type: codersdk.ChatStreamEventTypeError,
|
|
ChatID: chatID,
|
|
Error: notify.ErrorPayload,
|
|
}:
|
|
}
|
|
} else if notify.Error != "" {
|
|
select {
|
|
case <-mergedCtx.Done():
|
|
return
|
|
case mergedEvents <- codersdk.ChatStreamEvent{
|
|
Type: codersdk.ChatStreamEventTypeError,
|
|
ChatID: chatID,
|
|
Error: &codersdk.ChatError{
|
|
Message: notify.Error,
|
|
},
|
|
}:
|
|
}
|
|
}
|
|
if notify.QueueUpdate {
|
|
queueCtx, queueCancel := streamSubscriberControlFetchContext(mergedCtx)
|
|
queuedMsgs, queueErr := p.db.GetChatQueuedMessages(queueCtx, chatID)
|
|
queueCancel()
|
|
if queueErr != nil {
|
|
p.logger.Warn(mergedCtx, "failed to get queued messages after pubsub notification",
|
|
slog.F("chat_id", chatID),
|
|
slog.Error(queueErr),
|
|
)
|
|
} else {
|
|
select {
|
|
case <-mergedCtx.Done():
|
|
return
|
|
case mergedEvents <- codersdk.ChatStreamEvent{
|
|
Type: codersdk.ChatStreamEventTypeQueueUpdate,
|
|
ChatID: chatID,
|
|
QueuedMessages: db2sdk.ChatQueuedMessages(queuedMsgs),
|
|
}:
|
|
}
|
|
}
|
|
}
|
|
case event, ok := <-localParts:
|
|
if !ok {
|
|
localParts = nil
|
|
// Local parts channel closed. If pubsub is
|
|
// active we continue with pubsub-driven events.
|
|
// Otherwise terminate.
|
|
if !hasPubsubSubscription {
|
|
return
|
|
}
|
|
continue
|
|
}
|
|
if hasPubsubSubscription {
|
|
// Forward transient events from local.
|
|
// Durable events (messages, queue updates)
|
|
// come via pubsub + cache. Status is
|
|
// included alongside message_part because
|
|
// both travel through the same ordered
|
|
// channel: publishStatus is called before
|
|
// the first message_part, so FIFO delivery
|
|
// guarantees the frontend sees
|
|
// status=running before any content.
|
|
// Pubsub will deliver a duplicate status
|
|
// later; the frontend deduplicates it
|
|
// (setChatStatus is idempotent).
|
|
// action_required is also transient and
|
|
// only published on the local stream, so
|
|
// it must be forwarded here.
|
|
if event.Type == codersdk.ChatStreamEventTypeMessagePart ||
|
|
event.Type == codersdk.ChatStreamEventTypeStatus ||
|
|
event.Type == codersdk.ChatStreamEventTypeActionRequired {
|
|
select {
|
|
case <-mergedCtx.Done():
|
|
return
|
|
case mergedEvents <- event:
|
|
}
|
|
}
|
|
} else {
|
|
// No pubsub subscription: forward all event types.
|
|
select {
|
|
case <-mergedCtx.Done():
|
|
return
|
|
case mergedEvents <- event:
|
|
}
|
|
}
|
|
case event, ok := <-relayEvents:
|
|
if !ok {
|
|
relayEvents = nil
|
|
continue
|
|
}
|
|
select {
|
|
case <-mergedCtx.Done():
|
|
return
|
|
case mergedEvents <- event:
|
|
}
|
|
}
|
|
}
|
|
}()
|
|
|
|
return initialSnapshot, mergedEvents, cancel, true
|
|
}
|
|
|
|
// publishChatPubsubEvents broadcasts a lifecycle event for each affected chat.
|
|
func (p *Server) publishChatPubsubEvents(chats []database.Chat, kind codersdk.ChatWatchEventKind) {
|
|
for _, chat := range chats {
|
|
p.publishChatPubsubEvent(chat, kind, nil)
|
|
}
|
|
}
|
|
|
|
// publishChatPubsubEvent broadcasts a chat lifecycle event via PostgreSQL
|
|
// pubsub so that all replicas can push updates to watching clients.
|
|
func (p *Server) publishChatPubsubEvent(chat database.Chat, kind codersdk.ChatWatchEventKind, diffStatus *codersdk.ChatDiffStatus) {
|
|
if p.pubsub == nil {
|
|
return
|
|
}
|
|
// diffStatus is applied below. File metadata is intentionally
|
|
// omitted from pubsub events to avoid an extra DB query per
|
|
// publish. Clients must merge pubsub updates, not replace
|
|
// cached file metadata.
|
|
sdkChat := db2sdk.Chat(chat, nil, nil)
|
|
if diffStatus != nil {
|
|
sdkChat.DiffStatus = diffStatus
|
|
}
|
|
event := codersdk.ChatWatchEvent{
|
|
Kind: kind,
|
|
Chat: sdkChat,
|
|
}
|
|
payload, err := json.Marshal(event)
|
|
if err != nil {
|
|
p.logger.Error(context.Background(), "failed to marshal chat pubsub event",
|
|
slog.F("chat_id", chat.ID),
|
|
slog.Error(err),
|
|
)
|
|
return
|
|
}
|
|
if err := p.pubsub.Publish(coderdpubsub.ChatWatchEventChannel(chat.OwnerID), payload); err != nil {
|
|
p.logger.Error(context.Background(), "failed to publish chat pubsub event",
|
|
slog.F("chat_id", chat.ID),
|
|
slog.F("kind", kind),
|
|
slog.Error(err),
|
|
)
|
|
}
|
|
}
|
|
|
|
// PublishDiffStatusChange broadcasts a diff_status_change event for
|
|
// the given chat so that watching clients know to re-fetch the diff
|
|
// status. This is called from the HTTP layer after the diff status
|
|
// is updated in the database.
|
|
func (p *Server) PublishDiffStatusChange(ctx context.Context, chatID uuid.UUID) error {
|
|
chat, err := p.db.GetChatByID(ctx, chatID)
|
|
if err != nil {
|
|
return xerrors.Errorf("get chat: %w", err)
|
|
}
|
|
|
|
dbStatus, err := p.db.GetChatDiffStatusByChatID(ctx, chatID)
|
|
if err != nil {
|
|
return xerrors.Errorf("get chat diff status: %w", err)
|
|
}
|
|
|
|
sdkStatus := db2sdk.ChatDiffStatus(chatID, &dbStatus)
|
|
p.publishChatPubsubEvent(chat, codersdk.ChatWatchEventKindDiffStatusChange, &sdkStatus)
|
|
return nil
|
|
}
|
|
|
|
// Rejects oversize images on capped providers before any upstream
|
|
// request is issued.
|
|
//
|
|
// Gotcha: a historical oversize image bricks the chat on a capped
|
|
// provider until the user switches providers back, starts a new
|
|
// chat, or edits a message above the offending one (which truncates
|
|
// the prompt forward). A future change should skip the file with a
|
|
// user-facing warning, but that requires altering the FileResolver
|
|
// contract.
|
|
func (p *Server) chatFileResolver(provider string) chatprompt.FileResolver {
|
|
return func(ctx context.Context, ids []uuid.UUID) (map[uuid.UUID]chatprompt.FileData, error) {
|
|
files, err := p.db.GetChatFilesByIDs(ctx, ids)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
imageCap, hasImageCap := chatprovider.InlineImageCapBytes(provider)
|
|
normalizedProvider := chatprovider.NormalizeProvider(provider)
|
|
result := make(map[uuid.UUID]chatprompt.FileData, len(files))
|
|
for _, f := range files {
|
|
if hasImageCap &&
|
|
strings.HasPrefix(f.Mimetype, "image/") &&
|
|
len(f.Data) >= imageCap {
|
|
err := xerrors.Errorf(
|
|
"image attachment %q is %d bytes; %s inline image limit is %d bytes",
|
|
f.Name, len(f.Data),
|
|
chatprovider.ProviderDisplayName(normalizedProvider),
|
|
imageCap,
|
|
)
|
|
// User-facing message stays client-agnostic since
|
|
// older web clients and direct API callers don't
|
|
// auto-resize; the wrapped error above keeps the
|
|
// exact byte count for operator logs.
|
|
return nil, chaterror.WithClassification(err, chaterror.ClassifiedError{
|
|
Kind: codersdk.ChatErrorKindConfig,
|
|
Provider: normalizedProvider,
|
|
Message: fmt.Sprintf(
|
|
"Image attachment exceeds %s's %s inline image limit. Replace it with a smaller image.",
|
|
chatprovider.ProviderDisplayName(normalizedProvider),
|
|
//nolint:gosec // imageCap is a small positive constant defined in chatprovider.
|
|
humanize.IBytes(uint64(imageCap)),
|
|
),
|
|
Retryable: false,
|
|
})
|
|
}
|
|
result[f.ID] = chatprompt.FileData{
|
|
Name: f.Name,
|
|
Data: f.Data,
|
|
MediaType: f.Mimetype,
|
|
}
|
|
}
|
|
return result, nil
|
|
}
|
|
}
|
|
|
|
// trackWorkspaceUsage bumps the workspace's last_used_at via the
|
|
// usage tracker and extends the workspace's autostop deadline. If
|
|
// wsID is not yet valid, it re-reads the chat from the DB to pick
|
|
// up late associations (e.g. create_workspace linking a workspace
|
|
// mid-conversation). The caller should store the returned value so
|
|
// that subsequent calls skip the DB lookup once a workspace has
|
|
// been found.
|
|
func (p *Server) trackWorkspaceUsage(
|
|
ctx context.Context,
|
|
chatID uuid.UUID,
|
|
wsID uuid.NullUUID,
|
|
logger slog.Logger,
|
|
) uuid.NullUUID {
|
|
if p.usageTracker == nil {
|
|
return wsID
|
|
}
|
|
if !wsID.Valid {
|
|
latest, err := p.db.GetChatByID(ctx, chatID)
|
|
if err != nil {
|
|
logger.Warn(ctx, "failed to re-read chat for workspace association", slog.Error(err))
|
|
return wsID
|
|
}
|
|
wsID = latest.WorkspaceID
|
|
}
|
|
if wsID.Valid {
|
|
p.usageTracker.Add(wsID.UUID)
|
|
// Bump the workspace autostop deadline. We pass time.Time{}
|
|
// for nextAutostart since we don't have access to
|
|
// TemplateScheduleStore here. The activity bump logic
|
|
// defaults to the template's activity_bump duration
|
|
// (typically 1 hour). Chat workspaces are never prebuilds,
|
|
// so no prebuild guard is needed (unlike reporter.go).
|
|
//
|
|
// This fires every heartbeat (~30s) but the SQL only
|
|
// writes when 5% of the deadline has elapsed — most calls
|
|
// perform a read-only CTE lookup with no UPDATE.
|
|
//
|
|
// Scaling note: for 10,000 active chats, this could lead to
|
|
// approx. 333 CTE queries/second. A cheap fix for this could
|
|
// be to heartbeat every Nth query. Leaving as potential future
|
|
// low-hanging fruit if needed.
|
|
workspacestats.ActivityBumpWorkspace(ctx, logger.Named("activity_bump"), p.db, wsID.UUID, time.Time{}, workspacestats.ActivityBumpReasonChatHeartbeat)
|
|
}
|
|
return wsID
|
|
}
|
|
|
|
type runChatResult struct {
|
|
FinalAssistantText string
|
|
StatusLabelModel fantasy.LanguageModel
|
|
ProviderKeys chatprovider.ProviderAPIKeys
|
|
FallbackProvider string
|
|
FallbackRoute resolvedModelRoute
|
|
FallbackModel string
|
|
ModelBuildOptions modelBuildOptions
|
|
TriggerMessageID int64
|
|
HistoryTipMessageID int64
|
|
}
|
|
|
|
func activeTurnAPIKeyIDFromMessages(messages []database.ChatMessage) (string, bool) {
|
|
for i := len(messages) - 1; i >= 0; i-- {
|
|
message := messages[i]
|
|
if message.Role != database.ChatMessageRoleUser {
|
|
continue
|
|
}
|
|
if !isUserVisibleChatMessage(message) &&
|
|
!(message.Visibility == database.ChatMessageVisibilityModel && message.Compressed) {
|
|
continue
|
|
}
|
|
if !message.APIKeyID.Valid || message.APIKeyID.String == "" {
|
|
return "", false
|
|
}
|
|
return message.APIKeyID.String, true
|
|
}
|
|
return "", false
|
|
}
|
|
|
|
func isUserVisibleChatMessage(message database.ChatMessage) bool {
|
|
return message.Visibility == database.ChatMessageVisibilityBoth ||
|
|
message.Visibility == database.ChatMessageVisibilityUser
|
|
}
|
|
|
|
func allToolNames(allTools []fantasy.AgentTool) []string {
|
|
toolNames := make([]string, 0, len(allTools))
|
|
for _, tool := range allTools {
|
|
toolNames = append(toolNames, tool.Info().Name)
|
|
}
|
|
return toolNames
|
|
}
|
|
|
|
func isExploreSubagentMode(mode database.NullChatMode) bool {
|
|
return mode.Valid && mode.ChatMode == database.ChatModeExplore
|
|
}
|
|
|
|
// filterExternalMCPConfigsForTurn returns the external MCP server configs
|
|
// visible on the current turn. Explore children snapshot this filtered set at
|
|
// spawn time so later model overrides cannot widen the external-tool boundary.
|
|
func filterExternalMCPConfigsForTurn(
|
|
configs []database.MCPServerConfig,
|
|
mode database.NullChatPlanMode,
|
|
parentChatID uuid.NullUUID,
|
|
) ([]database.MCPServerConfig, map[uuid.UUID]struct{}) {
|
|
if !mode.Valid || mode.ChatPlanMode != database.ChatPlanModePlan {
|
|
return configs, nil
|
|
}
|
|
if parentChatID.Valid {
|
|
// Plan-mode subagents do not receive external MCP tools because
|
|
// their trust boundary is narrower than the root chat's.
|
|
return nil, map[uuid.UUID]struct{}{}
|
|
}
|
|
|
|
filtered := make([]database.MCPServerConfig, 0, len(configs))
|
|
approvedIDs := make(map[uuid.UUID]struct{})
|
|
for _, cfg := range configs {
|
|
if !cfg.AllowInPlanMode {
|
|
continue
|
|
}
|
|
filtered = append(filtered, cfg)
|
|
approvedIDs[cfg.ID] = struct{}{}
|
|
}
|
|
return filtered, approvedIDs
|
|
}
|
|
|
|
// builtinPlanToolAllowed reports whether the built-in tool called name may
// be used on a plan-mode turn. Read-only tools are always allowed, a
// second group is allowed only for root chats, and everything else
// (explicitly listed or unknown) is denied.
func builtinPlanToolAllowed(name string, isRootChat bool) bool {
	var allowed bool
	switch name {
	case "read_file", "execute", "process_output", "read_skill", "read_skill_file":
		// Available to root chats and subagents alike.
		allowed = true
	case "write_file", "edit_files", "list_templates", "read_template",
		"create_workspace", "start_workspace", "stop_workspace", "propose_plan", "spawn_agent",
		"spawn_explore_agent", "wait_agent", "ask_user_question", "attach_file":
		// Root-chat only.
		allowed = isRootChat
	case "process_list", "process_signal", "message_agent", "close_agent",
		"spawn_computer_use_agent":
		// Explicitly denied in plan mode.
	default:
		// Unknown names are denied.
	}
	return allowed
}
|
|
|
|
func toolAllowedForTurn(
|
|
tool fantasy.AgentTool,
|
|
mode database.NullChatPlanMode,
|
|
parentChatID uuid.NullUUID,
|
|
approvedMCPConfigIDs map[uuid.UUID]struct{},
|
|
) bool {
|
|
if !mode.Valid || mode.ChatPlanMode != database.ChatPlanModePlan {
|
|
return true
|
|
}
|
|
if builtinPlanToolAllowed(tool.Info().Name, !parentChatID.Valid) {
|
|
return true
|
|
}
|
|
mcpTool, ok := tool.(mcpclient.MCPToolIdentifier)
|
|
if !ok {
|
|
return false
|
|
}
|
|
_, approved := approvedMCPConfigIDs[mcpTool.MCPServerConfigID()]
|
|
return approved
|
|
}
|
|
|
|
func filterToolsForTurn(
|
|
allTools []fantasy.AgentTool,
|
|
mode database.NullChatPlanMode,
|
|
parentChatID uuid.NullUUID,
|
|
approvedMCPConfigIDs map[uuid.UUID]struct{},
|
|
) []fantasy.AgentTool {
|
|
if !mode.Valid || mode.ChatPlanMode != database.ChatPlanModePlan {
|
|
return allTools
|
|
}
|
|
|
|
filtered := make([]fantasy.AgentTool, 0, len(allTools))
|
|
for _, tool := range allTools {
|
|
if toolAllowedForTurn(tool, mode, parentChatID, approvedMCPConfigIDs) {
|
|
filtered = append(filtered, tool)
|
|
}
|
|
}
|
|
return filtered
|
|
}
|
|
|
|
// activeToolNamesForTurn extends the built-in plan allowlist with approved
|
|
// external MCP tools for root plan-mode chats.
|
|
func activeToolNamesForTurn(
|
|
allTools []fantasy.AgentTool,
|
|
mode database.NullChatPlanMode,
|
|
parentChatID uuid.NullUUID,
|
|
approvedMCPConfigIDs map[uuid.UUID]struct{},
|
|
) []string {
|
|
toolNames := make([]string, 0, len(allTools))
|
|
for _, tool := range allTools {
|
|
if toolAllowedForTurn(tool, mode, parentChatID, approvedMCPConfigIDs) {
|
|
toolNames = append(toolNames, tool.Info().Name)
|
|
}
|
|
}
|
|
return toolNames
|
|
}
|
|
|
|
func allowedExploreToolNames(allTools []fantasy.AgentTool) []string {
|
|
builtinExplorePolicy := map[string]bool{
|
|
"read_file": true,
|
|
"write_file": false,
|
|
"edit_files": false,
|
|
"execute": true,
|
|
"process_output": true,
|
|
"process_list": false,
|
|
"process_signal": false,
|
|
"list_templates": false,
|
|
"read_template": false,
|
|
"create_workspace": false,
|
|
"start_workspace": false,
|
|
"stop_workspace": false,
|
|
"propose_plan": false,
|
|
"spawn_agent": false,
|
|
"wait_agent": false,
|
|
"message_agent": false,
|
|
"close_agent": false,
|
|
"read_skill": true,
|
|
"read_skill_file": true,
|
|
"ask_user_question": false,
|
|
}
|
|
|
|
toolNames := make([]string, 0, len(allTools))
|
|
for _, tool := range allTools {
|
|
name := tool.Info().Name
|
|
if builtinExplorePolicy[name] {
|
|
toolNames = append(toolNames, name)
|
|
continue
|
|
}
|
|
// External MCP tools pass through here. They were snapshot-filtered
|
|
// at spawn time on chat.MCPServerIDs. WorkspaceMCPTool does not
|
|
// implement MCPToolIdentifier, so workspace tools are excluded
|
|
// here too, in addition to the structural exclusion in runChat
|
|
// tool assembly.
|
|
if _, ok := tool.(mcpclient.MCPToolIdentifier); ok {
|
|
toolNames = append(toolNames, name)
|
|
}
|
|
}
|
|
return toolNames
|
|
}
|
|
|
|
// allowedBehaviorToolNames runs only on non-plan turns because
|
|
// appendDynamicTools returns early for plan mode. Within that boundary,
|
|
// Explore mode wins over the default behavior that allows all tools.
|
|
func allowedBehaviorToolNames(
|
|
allTools []fantasy.AgentTool,
|
|
chatMode database.NullChatMode,
|
|
) []string {
|
|
if isExploreSubagentMode(chatMode) {
|
|
return allowedExploreToolNames(allTools)
|
|
}
|
|
return allToolNames(allTools)
|
|
}
|
|
|
|
func stopAfterPlanTools(
|
|
planMode database.NullChatPlanMode,
|
|
parentChatID uuid.NullUUID,
|
|
) map[string]struct{} {
|
|
if !planMode.Valid || planMode.ChatPlanMode != database.ChatPlanModePlan {
|
|
return nil
|
|
}
|
|
stopTools := map[string]struct{}{
|
|
"propose_plan": {},
|
|
}
|
|
if !parentChatID.Valid {
|
|
stopTools["ask_user_question"] = struct{}{}
|
|
}
|
|
return stopTools
|
|
}
|
|
|
|
func stopAfterBehaviorTools(
|
|
planMode database.NullChatPlanMode,
|
|
chatMode database.NullChatMode,
|
|
parentChatID uuid.NullUUID,
|
|
) map[string]struct{} {
|
|
if isExploreSubagentMode(chatMode) {
|
|
return nil
|
|
}
|
|
return stopAfterPlanTools(planMode, parentChatID)
|
|
}
|
|
|
|
type systemPromptBehaviorContext struct {
|
|
planMode database.NullChatPlanMode
|
|
chatMode database.NullChatMode
|
|
planModeInstructions string
|
|
isRootChat bool
|
|
}
|
|
|
|
func workspaceSkillsForResolution(workspaceSkills []chattool.SkillMeta) []skillspkg.Skill {
|
|
if len(workspaceSkills) == 0 {
|
|
return nil
|
|
}
|
|
resolved := make([]skillspkg.Skill, 0, len(workspaceSkills))
|
|
for _, skill := range workspaceSkills {
|
|
resolved = append(resolved, skillspkg.Skill{
|
|
Name: skill.Name,
|
|
Description: skill.Description,
|
|
Source: skillspkg.SourceWorkspace,
|
|
})
|
|
}
|
|
return resolved
|
|
}
|
|
|
|
func mergeTurnSkills(
|
|
personalSkills []skillspkg.Skill,
|
|
workspaceSkills []chattool.SkillMeta,
|
|
) []skillspkg.ResolvedSkill {
|
|
return skillspkg.MergeSkills(
|
|
personalSkills,
|
|
workspaceSkillsForResolution(workspaceSkills),
|
|
)
|
|
}
|
|
|
|
// buildSystemPrompt applies system-level prompt injections in the
|
|
// canonical order. It is used by both the initial prompt assembly
|
|
// and the ReloadMessages callback to keep them in sync.
|
|
func buildSystemPrompt(
|
|
prompt []fantasy.Message,
|
|
subagentInstruction string,
|
|
instruction string,
|
|
resolvedSkills []skillspkg.ResolvedSkill,
|
|
userPrompt string,
|
|
behaviorContext systemPromptBehaviorContext,
|
|
) []fantasy.Message {
|
|
if subagentInstruction != "" {
|
|
prompt = chatprompt.InsertSystem(prompt, subagentInstruction)
|
|
}
|
|
if instruction != "" {
|
|
prompt = chatprompt.InsertSystem(prompt, instruction)
|
|
}
|
|
if skillIndex := chattool.FormatResolvedSkillIndex(resolvedSkills); skillIndex != "" {
|
|
prompt = chatprompt.InsertSystem(prompt, skillIndex)
|
|
}
|
|
if userPrompt != "" {
|
|
prompt = chatprompt.InsertSystem(prompt, userPrompt)
|
|
}
|
|
if isExploreSubagentMode(behaviorContext.chatMode) {
|
|
prompt = chatprompt.InsertSystem(prompt, ExploreSubagentOverlayPrompt)
|
|
return prompt
|
|
}
|
|
isPlanModeTurn := behaviorContext.planMode.Valid && behaviorContext.planMode.ChatPlanMode == database.ChatPlanModePlan
|
|
if isPlanModeTurn {
|
|
if behaviorContext.isRootChat {
|
|
prompt = chatprompt.InsertSystem(prompt, PlanningOverlayPrompt())
|
|
if behaviorContext.planModeInstructions != "" {
|
|
prompt = chatprompt.InsertSystem(prompt, behaviorContext.planModeInstructions)
|
|
}
|
|
} else {
|
|
prompt = chatprompt.InsertSystem(prompt, PlanningSubagentOverlayPrompt)
|
|
}
|
|
}
|
|
return prompt
|
|
}
|
|
|
|
func removeSkillIndexMessages(prompt []fantasy.Message) []fantasy.Message {
|
|
out := make([]fantasy.Message, 0, len(prompt))
|
|
removed := false
|
|
for _, message := range prompt {
|
|
if isSkillIndexMessage(message) {
|
|
removed = true
|
|
continue
|
|
}
|
|
out = append(out, message)
|
|
}
|
|
if !removed {
|
|
return prompt
|
|
}
|
|
return out
|
|
}
|
|
|
|
func isSkillIndexMessage(message fantasy.Message) bool {
|
|
if message.Role != fantasy.MessageRoleSystem || len(message.Content) != 1 {
|
|
return false
|
|
}
|
|
textPart, ok := fantasy.AsMessagePart[fantasy.TextPart](message.Content[0])
|
|
if !ok {
|
|
return false
|
|
}
|
|
text := strings.TrimSpace(textPart.Text)
|
|
return strings.HasPrefix(text, chattool.AvailableSkillsOpenTag+"\n") && strings.HasSuffix(text, chattool.AvailableSkillsCloseTag)
|
|
}
|
|
|
|
type rootChatToolsOptions struct {
|
|
chat database.Chat
|
|
modelConfigID uuid.UUID
|
|
workspaceCtx *turnWorkspaceContext
|
|
workspaceMu *sync.Mutex
|
|
resolvePlanPath func(context.Context) (string, string, error)
|
|
storeFile chattool.StoreFileFunc
|
|
isPlanModeTurn bool
|
|
// primerCtx scopes the workspace MCP cache primer goroutines
|
|
// that onChatUpdated launches. runChat cancels it before
|
|
// workspaceCtx.close() so an in-flight primer cannot dial a
|
|
// fresh conn after the cached one was released.
|
|
primerCtx context.Context
|
|
}
|
|
|
|
func (p *Server) loadPlanModeInstructions(
|
|
ctx context.Context,
|
|
mode database.NullChatPlanMode,
|
|
logger slog.Logger,
|
|
) string {
|
|
if !mode.Valid || mode.ChatPlanMode != database.ChatPlanModePlan {
|
|
return ""
|
|
}
|
|
|
|
// Plan-mode instructions live in deployment config, but chat workers do
|
|
// not carry a deployment-config actor during background execution.
|
|
//nolint:gocritic // Required to read deployment config during background chat processing.
|
|
systemCtx := dbauthz.AsSystemRestricted(ctx)
|
|
fetched, err := p.db.GetChatPlanModeInstructions(systemCtx)
|
|
if err != nil {
|
|
logger.Warn(ctx,
|
|
"failed to fetch plan mode instructions",
|
|
slog.Error(err),
|
|
)
|
|
return ""
|
|
}
|
|
|
|
return fetched
|
|
}
|
|
|
|
func userSkillContext(ctx context.Context, userID uuid.UUID) context.Context {
|
|
actor := rbac.Subject{
|
|
Type: rbac.SubjectTypeUser,
|
|
ID: userID.String(),
|
|
Roles: rbac.RoleIdentifiers{rbac.RoleMember()},
|
|
Scope: rbac.ScopeAll,
|
|
}.WithCachedASTValue()
|
|
// Chat turns run asynchronously after admission, so the original request
|
|
// actor may no longer be available when a worker loads personal skills.
|
|
// We synthesize the chat owner as a member instead of reusing that actor.
|
|
// Hardcoding RoleMember is safe because dbauthz enforces
|
|
// ResourceUserSkill.WithOwner(userID), so this actor cannot read any other
|
|
// user's skills regardless of role. Org scoping is not needed because
|
|
// personal skills are user-scoped, not org-scoped.
|
|
//nolint:gocritic // The synthetic actor is intentional for the reasons above.
|
|
return dbauthz.As(ctx, actor)
|
|
}
|
|
|
|
func (p *Server) fetchPersonalSkillMetadata(
|
|
ctx context.Context,
|
|
userID uuid.UUID,
|
|
logger slog.Logger,
|
|
) []skillspkg.Skill {
|
|
rows, err := p.db.ListUserSkillMetadataByUserID(userSkillContext(ctx, userID), userID)
|
|
// See package coderd/x/skills (doc.go) for why metadata fetch failures
|
|
// intentionally degrade to an empty personal-skill list instead of
|
|
// failing the chat turn.
|
|
if err != nil {
|
|
logger.Warn(ctx, "failed to load personal skill metadata",
|
|
slog.F("owner_id", userID),
|
|
slog.Error(err),
|
|
)
|
|
return nil
|
|
}
|
|
|
|
personalSkills := make([]skillspkg.Skill, 0, len(rows))
|
|
for _, row := range rows {
|
|
personalSkills = append(personalSkills, skillspkg.Skill{
|
|
Name: row.Name,
|
|
Description: row.Description,
|
|
Source: skillspkg.SourcePersonal,
|
|
})
|
|
}
|
|
return personalSkills
|
|
}
|
|
|
|
func (p *Server) loadPersonalSkillBody(
|
|
ctx context.Context,
|
|
userID uuid.UUID,
|
|
name string,
|
|
) (skillspkg.ParsedSkill, error) {
|
|
row, err := p.db.GetUserSkillByUserIDAndName(
|
|
userSkillContext(ctx, userID),
|
|
database.GetUserSkillByUserIDAndNameParams{
|
|
UserID: userID,
|
|
Name: name,
|
|
},
|
|
)
|
|
if err != nil {
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return skillspkg.ParsedSkill{}, skillspkg.ErrSkillNotFound
|
|
}
|
|
p.logger.Error(ctx, "load personal skill body failed",
|
|
slog.F("user_id", userID),
|
|
slog.F("name", name),
|
|
slog.Error(err),
|
|
)
|
|
return skillspkg.ParsedSkill{}, xerrors.Errorf("load personal skill body: %w", err)
|
|
}
|
|
|
|
parsed, err := skillspkg.ParsePersonalSkillMarkdown([]byte(row.Content))
|
|
if err != nil {
|
|
p.logger.Error(ctx, "parse personal skill body failed",
|
|
slog.F("user_id", userID),
|
|
slog.F("name", name),
|
|
slog.Error(err),
|
|
)
|
|
return skillspkg.ParsedSkill{}, xerrors.Errorf("parse personal skill body: %w", err)
|
|
}
|
|
return parsed, nil
|
|
}
|
|
|
|
func (p *Server) appendRootChatTools(
|
|
ctx context.Context,
|
|
tools []fantasy.AgentTool,
|
|
opts rootChatToolsOptions,
|
|
) []fantasy.AgentTool {
|
|
onChatUpdated := func(updatedChat database.Chat) {
|
|
opts.workspaceCtx.selectWorkspace(updatedChat)
|
|
// Notify the frontend immediately so it can start streaming
|
|
// build logs before the tool completes.
|
|
p.publishChatPubsubEvent(updatedChat, codersdk.ChatWatchEventKindStatusChange, nil)
|
|
|
|
// Note: we intentionally do not insert AGENTS.md / workspace
|
|
// context here. Local tool callbacks must not mutate chat
|
|
// history while a local-tool generation task is in flight,
|
|
// because that advances history_version before the tool
|
|
// result is committed and exits the local-tool commit as
|
|
// stale. Workspace context is persisted by the
|
|
// persist_workspace_context generation action in a later
|
|
// pass.
|
|
|
|
// Prime the workspace MCP tools cache while the create_workspace
|
|
// or start_workspace tool is still running. The AgentID guard
|
|
// below restricts the primer to the post-ready callback, when
|
|
// the agent is reachable. ListMCPTools may still return an
|
|
// empty list on the first try when the agent's MCP Connect is
|
|
// racing with agent startup; primeWorkspaceMCPCache retries
|
|
// with a short backoff up to workspaceMCPPrimeMaxWait. Priming
|
|
// here lets the next assistant-generation action hit the cache
|
|
// instead of dialing again on a separate timeout budget.
|
|
//
|
|
// Run asynchronously: the tool itself must not block on the
|
|
// primer because the agent may not advertise any MCP tools at
|
|
// all (e.g. minimal templates), in which case the primer waits
|
|
// the full budget before giving up. The next assistant-generation
|
|
// action covers the cache miss path; the primer is purely an
|
|
// optimization that warms the cache while the LLM is thinking.
|
|
// inflight tracking ensures server shutdown still waits for any
|
|
// in-progress primer.
|
|
//
|
|
// Guard on both WorkspaceID and AgentID being valid:
|
|
// create_workspace and start_workspace each fire onChatUpdated
|
|
// twice for a new build (binding before waitForAgentReady;
|
|
// post-ready after it), and stop_workspace fires it with a nil
|
|
// agent. Only the post-ready callback has a live AgentID, so
|
|
// the pre-build and stop-side firings would otherwise spawn a
|
|
// primer goroutine that dials a missing or dying agent and
|
|
// burns the full budget for nothing.
|
|
snapshot := opts.workspaceCtx.currentChatSnapshot()
|
|
if snapshot.WorkspaceID.Valid && snapshot.AgentID.Valid {
|
|
p.inflight.Go(func() {
|
|
p.primeWorkspaceMCPCache(opts.primerCtx, p.logger, snapshot.ID, opts.workspaceCtx)
|
|
})
|
|
}
|
|
}
|
|
|
|
tools = append(tools,
|
|
chattool.ListTemplates(p.db, opts.chat.OrganizationID, chattool.ListTemplatesOptions{
|
|
OwnerID: opts.chat.OwnerID,
|
|
AllowedTemplateIDs: p.chatTemplateAllowlist,
|
|
}),
|
|
chattool.ReadTemplate(p.db, opts.chat.OrganizationID, chattool.ReadTemplateOptions{
|
|
OwnerID: opts.chat.OwnerID,
|
|
AllowedTemplateIDs: p.chatTemplateAllowlist,
|
|
}),
|
|
chattool.CreateWorkspace(p.db, opts.chat.OrganizationID, opts.chat.ID, chattool.CreateWorkspaceOptions{
|
|
OwnerID: opts.chat.OwnerID,
|
|
CreateFn: p.createWorkspaceFn,
|
|
AgentConnFn: chattool.AgentConnFunc(p.agentConnFn),
|
|
AgentInactiveDisconnectTimeout: p.agentInactiveDisconnectTimeout,
|
|
WorkspaceMu: opts.workspaceMu,
|
|
OnChatUpdated: onChatUpdated,
|
|
Logger: p.logger,
|
|
AllowedTemplateIDs: p.chatTemplateAllowlist,
|
|
}),
|
|
chattool.StartWorkspace(p.db, opts.chat.ID, chattool.StartWorkspaceOptions{
|
|
OwnerID: opts.chat.OwnerID,
|
|
StartFn: p.startWorkspaceFn,
|
|
AgentConnFn: chattool.AgentConnFunc(p.agentConnFn),
|
|
WorkspaceMu: opts.workspaceMu,
|
|
OnChatUpdated: onChatUpdated,
|
|
Logger: p.logger,
|
|
}),
|
|
chattool.StopWorkspace(p.db, opts.chat.ID, chattool.StopWorkspaceOptions{
|
|
OwnerID: opts.chat.OwnerID,
|
|
StopFn: p.stopWorkspaceFn,
|
|
WorkspaceMu: opts.workspaceMu,
|
|
OnChatUpdated: onChatUpdated,
|
|
Logger: p.logger,
|
|
}),
|
|
)
|
|
if opts.isPlanModeTurn {
|
|
tools = append(tools, chattool.ProposePlan(chattool.ProposePlanOptions{
|
|
GetWorkspaceConn: opts.workspaceCtx.getWorkspaceConn,
|
|
ResolvePlanPath: opts.resolvePlanPath,
|
|
IsPlanTurn: opts.isPlanModeTurn,
|
|
StoreFile: opts.storeFile,
|
|
}))
|
|
}
|
|
|
|
return append(tools, p.subagentTools(ctx, func() database.Chat {
|
|
return opts.chat
|
|
}, opts.modelConfigID)...)
|
|
}
|
|
|
|
func appendDynamicTools(
|
|
ctx context.Context,
|
|
logger slog.Logger,
|
|
tools []fantasy.AgentTool,
|
|
raw pqtype.NullRawMessage,
|
|
planMode database.NullChatPlanMode,
|
|
chatMode database.NullChatMode,
|
|
) ([]fantasy.AgentTool, map[string]bool, error) {
|
|
if isExploreSubagentMode(chatMode) || (planMode.Valid && planMode.ChatPlanMode == database.ChatPlanModePlan) {
|
|
return tools, nil, nil
|
|
}
|
|
|
|
dynamicToolNames, err := parseDynamicToolNames(raw)
|
|
if err != nil {
|
|
return nil, nil, xerrors.Errorf("parse dynamic tool names: %w", err)
|
|
}
|
|
if len(dynamicToolNames) == 0 {
|
|
return tools, dynamicToolNames, nil
|
|
}
|
|
|
|
var dynamicToolDefs []codersdk.DynamicTool
|
|
if raw.Valid {
|
|
if err := json.Unmarshal(raw.RawMessage, &dynamicToolDefs); err != nil {
|
|
return nil, nil, xerrors.Errorf("unmarshal dynamic tools: %w", err)
|
|
}
|
|
}
|
|
|
|
activeToolNames := make(map[string]struct{}, len(tools))
|
|
for _, name := range allowedBehaviorToolNames(tools, chatMode) {
|
|
activeToolNames[name] = struct{}{}
|
|
}
|
|
for _, t := range tools {
|
|
info := t.Info()
|
|
if _, active := activeToolNames[info.Name]; !active {
|
|
continue
|
|
}
|
|
if dynamicToolNames[info.Name] {
|
|
logger.Warn(ctx, "dynamic tool name collides with built-in tool, built-in takes precedence",
|
|
slog.F("tool_name", info.Name))
|
|
delete(dynamicToolNames, info.Name)
|
|
}
|
|
}
|
|
|
|
var filteredDefs []codersdk.DynamicTool
|
|
for _, dt := range dynamicToolDefs {
|
|
if dynamicToolNames[dt.Name] {
|
|
filteredDefs = append(filteredDefs, dt)
|
|
}
|
|
}
|
|
|
|
return append(tools, dynamicToolsFromSDK(logger, filteredDefs)...), dynamicToolNames, nil
|
|
}
|
|
|
|
// buildProviderTools creates provider-native tool definitions
|
|
// (like web search) based on the model configuration. These
|
|
// tools are executed server-side by the LLM provider.
|
|
func buildProviderTools(options *codersdk.ChatModelProviderOptions) []chatloop.ProviderTool {
|
|
var tools []chatloop.ProviderTool
|
|
|
|
if options == nil {
|
|
return nil
|
|
}
|
|
|
|
if options.Anthropic != nil && options.Anthropic.WebSearchEnabled != nil && *options.Anthropic.WebSearchEnabled {
|
|
tools = append(tools, chatloop.ProviderTool{
|
|
Definition: anthropic.WebSearchTool(&anthropic.WebSearchToolOptions{
|
|
AllowedDomains: options.Anthropic.AllowedDomains,
|
|
BlockedDomains: options.Anthropic.BlockedDomains,
|
|
}),
|
|
})
|
|
}
|
|
|
|
if tool, ok := chatopenai.WebSearchTool(options.OpenAI); ok {
|
|
tools = append(tools, chatloop.ProviderTool{
|
|
Definition: tool,
|
|
})
|
|
}
|
|
|
|
if options.Google != nil && options.Google.WebSearchEnabled != nil && *options.Google.WebSearchEnabled {
|
|
tools = append(tools, chatloop.ProviderTool{
|
|
Definition: fantasy.ProviderDefinedTool{
|
|
ID: "web_search",
|
|
Name: "web_search",
|
|
},
|
|
})
|
|
}
|
|
|
|
return tools
|
|
}
|
|
|
|
// persistChatContextSummary is called from the chat loop's compaction
|
|
// callback. activeAPIKeyID is stamped onto the summary user message. When
|
|
// empty, it falls back to the delegated key in ctx.
|
|
func (p *Server) persistChatContextSummary(
|
|
ctx context.Context,
|
|
chatID uuid.UUID,
|
|
modelConfigID uuid.UUID,
|
|
activeAPIKeyID string,
|
|
toolCallID string,
|
|
result chatloop.CompactionResult,
|
|
) error {
|
|
if strings.TrimSpace(result.SystemSummary) == "" ||
|
|
strings.TrimSpace(result.SummaryReport) == "" {
|
|
return nil
|
|
}
|
|
|
|
systemContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{
|
|
codersdk.ChatMessageText(result.SystemSummary),
|
|
})
|
|
if err != nil {
|
|
return xerrors.Errorf("encode system summary: %w", err)
|
|
}
|
|
|
|
args, err := json.Marshal(map[string]any{
|
|
"source": "automatic",
|
|
"threshold_percent": result.ThresholdPercent,
|
|
})
|
|
if err != nil {
|
|
return xerrors.Errorf("encode summary tool args: %w", err)
|
|
}
|
|
|
|
assistantContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{
|
|
codersdk.ChatMessageToolCall(toolCallID, "chat_summarized", args),
|
|
})
|
|
if err != nil {
|
|
return xerrors.Errorf("encode summary tool call: %w", err)
|
|
}
|
|
|
|
summaryResult, err := json.Marshal(map[string]any{
|
|
"summary": result.SummaryReport,
|
|
"source": "automatic",
|
|
"threshold_percent": result.ThresholdPercent,
|
|
"usage_percent": result.UsagePercent,
|
|
"context_tokens": result.ContextTokens,
|
|
"context_limit_tokens": result.ContextLimit,
|
|
})
|
|
if err != nil {
|
|
return xerrors.Errorf("encode summary result payload: %w", err)
|
|
}
|
|
toolResult, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{
|
|
codersdk.ChatMessageToolResult(toolCallID, "chat_summarized", summaryResult, false, false),
|
|
})
|
|
if err != nil {
|
|
return xerrors.Errorf("encode summary tool result: %w", err)
|
|
}
|
|
|
|
summaryAPIKeyID := activeAPIKeyID
|
|
if summaryAPIKeyID == "" {
|
|
summaryAPIKeyID, _ = aibridge.DelegatedAPIKeyIDFromContext(ctx)
|
|
}
|
|
|
|
var insertedMessages []database.ChatMessage
|
|
txErr := p.db.InTx(func(tx database.Store) error {
|
|
summaryParams := database.InsertChatMessagesParams{ //nolint:exhaustruct // Fields populated by append[User]ChatMessage.
|
|
ChatID: chatID,
|
|
}
|
|
|
|
summaryUserMsg := newUserChatMessage(
|
|
summaryAPIKeyID,
|
|
systemContent,
|
|
database.ChatMessageVisibilityModel,
|
|
modelConfigID,
|
|
chatprompt.CurrentContentVersion,
|
|
)
|
|
summaryUserMsg = summaryUserMsg.withCompressed()
|
|
appendUserChatMessage(&summaryParams, summaryUserMsg)
|
|
|
|
appendChatMessage(&summaryParams, newChatMessage(
|
|
database.ChatMessageRoleAssistant,
|
|
assistantContent,
|
|
database.ChatMessageVisibilityUser,
|
|
modelConfigID,
|
|
chatprompt.CurrentContentVersion,
|
|
).withCompressed())
|
|
|
|
appendChatMessage(&summaryParams, newChatMessage(
|
|
database.ChatMessageRoleTool,
|
|
toolResult,
|
|
database.ChatMessageVisibilityBoth,
|
|
modelConfigID,
|
|
chatprompt.CurrentContentVersion,
|
|
).withCompressed())
|
|
|
|
allInserted, txErr := tx.InsertChatMessages(ctx, summaryParams)
|
|
if txErr != nil {
|
|
return xerrors.Errorf("insert summary messages: %w", txErr)
|
|
}
|
|
insertedMessages = allInserted[1:]
|
|
return nil
|
|
}, nil)
|
|
if txErr != nil {
|
|
return txErr
|
|
}
|
|
|
|
_ = insertedMessages
|
|
return nil
|
|
}
|
|
|
|
func (p *Server) resolveChatModel(
|
|
ctx context.Context,
|
|
chat database.Chat,
|
|
modelOpts modelBuildOptions,
|
|
) (
|
|
model fantasy.LanguageModel,
|
|
dbConfig database.ChatModelConfig,
|
|
keys chatprovider.ProviderAPIKeys,
|
|
route resolvedModelRoute,
|
|
debugEnabled bool,
|
|
resolvedProvider string,
|
|
resolvedModel string,
|
|
err error,
|
|
) {
|
|
dbConfig, err = p.resolveModelConfig(ctx, chat)
|
|
if err != nil {
|
|
return nil, database.ChatModelConfig{}, chatprovider.ProviderAPIKeys{}, resolvedModelRoute{}, false, "", "", xerrors.Errorf("resolve model config: %w", err)
|
|
}
|
|
|
|
if !dbConfig.Enabled {
|
|
return nil, database.ChatModelConfig{}, chatprovider.ProviderAPIKeys{}, resolvedModelRoute{}, false, "", "", xerrors.Errorf("chat model config %s is disabled", dbConfig.ID)
|
|
}
|
|
|
|
route, err = p.resolveModelRouteForConfig(ctx, chat.OwnerID, dbConfig, chatprovider.ProviderAPIKeys{})
|
|
if err != nil {
|
|
return nil, database.ChatModelConfig{}, chatprovider.ProviderAPIKeys{}, resolvedModelRoute{}, false, "", "", err
|
|
}
|
|
keys = route.directProviderKeys()
|
|
|
|
providerHint, err := route.providerHint()
|
|
if err != nil {
|
|
return nil, database.ChatModelConfig{}, chatprovider.ProviderAPIKeys{}, resolvedModelRoute{}, false, "", "", err
|
|
}
|
|
resolvedProvider, resolvedModel, err = chatprovider.ResolveModelWithProviderHint(
|
|
dbConfig.Model,
|
|
providerHint,
|
|
)
|
|
if err != nil {
|
|
return nil, database.ChatModelConfig{}, chatprovider.ProviderAPIKeys{}, resolvedModelRoute{}, false, "", "", xerrors.Errorf(
|
|
"resolve model metadata: %w", err,
|
|
)
|
|
}
|
|
|
|
model, debugEnabled, err = p.newDebugAwareModel(ctx, modelClientRequest{
|
|
Chat: chat,
|
|
ModelName: dbConfig.Model,
|
|
UserAgent: chatprovider.UserAgent(),
|
|
ExtraHeaders: chatprovider.CoderHeaders(chat),
|
|
}, route, modelOpts)
|
|
if err != nil {
|
|
return nil, database.ChatModelConfig{}, chatprovider.ProviderAPIKeys{}, resolvedModelRoute{}, false, "", "", xerrors.Errorf(
|
|
"create model: %w", err,
|
|
)
|
|
}
|
|
return model, dbConfig, keys, route, debugEnabled, resolvedProvider, resolvedModel, nil
|
|
}
|
|
|
|
func (p *Server) aiProviderConfig(ctx context.Context, provider database.AIProvider) (chatprovider.ConfiguredProvider, error) {
|
|
keys, err := p.db.GetAIProviderKeysByProviderID(ctx, provider.ID)
|
|
if err != nil {
|
|
return chatprovider.ConfiguredProvider{}, xerrors.Errorf("get AI provider keys: %w", err)
|
|
}
|
|
return p.aiProviderConfigFromKeys(provider, keys)
|
|
}
|
|
|
|
func (p *Server) aiProviderConfigFromKeys(provider database.AIProvider, keys []database.AIProviderKey) (chatprovider.ConfiguredProvider, error) {
|
|
if !provider.Enabled {
|
|
return chatprovider.ConfiguredProvider{}, xerrors.Errorf("AI provider %s is disabled", provider.ID)
|
|
}
|
|
apiKey := ""
|
|
// GetAIProviderKeysByProviderID orders keys oldest first. chatd consumes
|
|
// one provider-scoped key because runtime provider config has one API key slot.
|
|
for _, key := range keys {
|
|
if key.APIKey != "" {
|
|
apiKey = key.APIKey
|
|
break
|
|
}
|
|
}
|
|
return chatprovider.ConfiguredProvider{
|
|
ProviderID: provider.ID,
|
|
Provider: string(provider.Type),
|
|
APIKey: apiKey,
|
|
BaseURL: provider.BaseUrl,
|
|
CentralAPIKeyEnabled: true,
|
|
AllowUserAPIKey: p.allowBYOK,
|
|
AllowCentralAPIKeyFallback: true,
|
|
}, nil
|
|
}
|
|
|
|
func (p *Server) aiProviderConfigs(ctx context.Context, providers []database.AIProvider) ([]chatprovider.ConfiguredProvider, error) {
|
|
if len(providers) == 0 {
|
|
return nil, nil
|
|
}
|
|
providerIDs := make([]uuid.UUID, 0, len(providers))
|
|
for _, provider := range providers {
|
|
providerIDs = append(providerIDs, provider.ID)
|
|
}
|
|
keys, err := p.db.GetAIProviderKeysByProviderIDs(ctx, providerIDs)
|
|
if err != nil {
|
|
return nil, xerrors.Errorf("get AI provider keys: %w", err)
|
|
}
|
|
keysByProviderID := make(map[uuid.UUID][]database.AIProviderKey, len(providers))
|
|
for _, key := range keys {
|
|
keysByProviderID[key.ProviderID] = append(keysByProviderID[key.ProviderID], key)
|
|
}
|
|
configuredProviders := make([]chatprovider.ConfiguredProvider, 0, len(providers))
|
|
for _, provider := range providers {
|
|
configuredProvider, err := p.aiProviderConfigFromKeys(provider, keysByProviderID[provider.ID])
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
configuredProviders = append(configuredProviders, configuredProvider)
|
|
}
|
|
return configuredProviders, nil
|
|
}
|
|
|
|
func ensureUniqueConfiguredProviderTypes(providers []chatprovider.ConfiguredProvider) error {
|
|
seen := make(map[string]uuid.UUID, len(providers))
|
|
for _, provider := range providers {
|
|
normalizedProvider := chatprovider.NormalizeProvider(provider.Provider)
|
|
if normalizedProvider == "" {
|
|
continue
|
|
}
|
|
if existingProviderID, ok := seen[normalizedProvider]; ok && existingProviderID != provider.ProviderID {
|
|
return xerrors.Errorf("multiple enabled AI providers use provider type %q; select an AI provider by ID", normalizedProvider)
|
|
}
|
|
seen[normalizedProvider] = provider.ProviderID
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (p *Server) resolveUserProviderAPIKeysForProvider(
|
|
ctx context.Context,
|
|
ownerID uuid.UUID,
|
|
provider database.AIProvider,
|
|
) (chatprovider.ProviderAPIKeys, error) {
|
|
configuredProvider, err := p.aiProviderConfig(ctx, provider)
|
|
if err != nil {
|
|
return chatprovider.ProviderAPIKeys{}, err
|
|
}
|
|
userKeys := []chatprovider.UserProviderKey{}
|
|
if p.allowBYOK {
|
|
userKey, err := p.db.GetUserAIProviderKeyByProviderID(ctx, database.GetUserAIProviderKeyByProviderIDParams{
|
|
UserID: ownerID,
|
|
AIProviderID: provider.ID,
|
|
})
|
|
if err != nil && !xerrors.Is(err, sql.ErrNoRows) {
|
|
return chatprovider.ProviderAPIKeys{}, xerrors.Errorf("get user AI provider key: %w", err)
|
|
}
|
|
if err == nil {
|
|
userKeys = append(userKeys, chatprovider.UserProviderKey{
|
|
ChatProviderID: userKey.AIProviderID,
|
|
APIKey: userKey.APIKey,
|
|
})
|
|
}
|
|
}
|
|
keys, _ := chatprovider.ResolveUserProviderKeys(
|
|
chatprovider.ProviderAPIKeys{},
|
|
[]chatprovider.ConfiguredProvider{configuredProvider},
|
|
userKeys,
|
|
)
|
|
return keys, nil
|
|
}
|
|
|
|
func (p *Server) resolveUserProviderAPIKeysForProviderType(
|
|
ctx context.Context,
|
|
ownerID uuid.UUID,
|
|
providerType string,
|
|
) (chatprovider.ProviderAPIKeys, error) {
|
|
keys, _, err := p.resolveUserProviderAPIKeysAndProviderForProviderType(ctx, ownerID, providerType)
|
|
return keys, err
|
|
}
|
|
|
|
func (p *Server) resolveUserProviderAPIKeysAndProviderForProviderType(
|
|
ctx context.Context,
|
|
ownerID uuid.UUID,
|
|
providerType string,
|
|
) (chatprovider.ProviderAPIKeys, *database.AIProvider, error) {
|
|
providers, err := p.db.GetAIProviders(ctx, database.GetAIProvidersParams{})
|
|
if err != nil {
|
|
return chatprovider.ProviderAPIKeys{}, nil, xerrors.Errorf("get enabled AI providers: %w", err)
|
|
}
|
|
normalizedProviderType := chatprovider.NormalizeProvider(providerType)
|
|
for _, provider := range providers {
|
|
if chatprovider.NormalizeProvider(string(provider.Type)) != normalizedProviderType {
|
|
continue
|
|
}
|
|
keys, err := p.resolveUserProviderAPIKeysForProvider(ctx, ownerID, provider)
|
|
if err != nil {
|
|
return chatprovider.ProviderAPIKeys{}, nil, err
|
|
}
|
|
if userCanUseProviderKeys(keys, normalizedProviderType) {
|
|
return keys, &provider, nil
|
|
}
|
|
}
|
|
keys, err := p.resolveUserProviderAPIKeys(ctx, ownerID, uuid.Nil)
|
|
if err != nil {
|
|
return chatprovider.ProviderAPIKeys{}, nil, err
|
|
}
|
|
return keys, nil, nil
|
|
}
|
|
|
|
func (p *Server) resolveUserProviderAPIKeys(
|
|
ctx context.Context,
|
|
ownerID uuid.UUID,
|
|
selectedAIProviderID uuid.UUID,
|
|
) (chatprovider.ProviderAPIKeys, error) {
|
|
if selectedAIProviderID != uuid.Nil {
|
|
provider, err := p.db.GetAIProviderByID(ctx, selectedAIProviderID)
|
|
if err != nil {
|
|
return chatprovider.ProviderAPIKeys{}, xerrors.Errorf("get AI provider: %w", err)
|
|
}
|
|
return p.resolveUserProviderAPIKeysForProvider(ctx, ownerID, provider)
|
|
}
|
|
|
|
providers, err := p.configCache.EnabledProviders(ctx)
|
|
if err != nil {
|
|
return chatprovider.ProviderAPIKeys{}, xerrors.Errorf(
|
|
"get enabled AI providers: %w",
|
|
err,
|
|
)
|
|
}
|
|
configuredProviders, err := p.aiProviderConfigs(ctx, providers)
|
|
if err != nil {
|
|
return chatprovider.ProviderAPIKeys{}, err
|
|
}
|
|
if err := ensureUniqueConfiguredProviderTypes(configuredProviders); err != nil {
|
|
return chatprovider.ProviderAPIKeys{}, err
|
|
}
|
|
|
|
userKeys := []chatprovider.UserProviderKey{}
|
|
if p.allowBYOK {
|
|
userKeyRows, err := p.db.GetUserAIProviderKeysByUserID(ctx, ownerID)
|
|
if err != nil {
|
|
return chatprovider.ProviderAPIKeys{}, xerrors.Errorf(
|
|
"get user AI provider keys: %w",
|
|
err,
|
|
)
|
|
}
|
|
userKeys = make([]chatprovider.UserProviderKey, 0, len(userKeyRows))
|
|
for _, userKey := range userKeyRows {
|
|
userKeys = append(userKeys, chatprovider.UserProviderKey{
|
|
ChatProviderID: userKey.AIProviderID,
|
|
APIKey: userKey.APIKey,
|
|
})
|
|
}
|
|
}
|
|
|
|
keys, _ := chatprovider.ResolveUserProviderKeys(
|
|
p.providerAPIKeys,
|
|
configuredProviders,
|
|
userKeys,
|
|
)
|
|
enabledProviders := make(map[string]struct{}, len(configuredProviders))
|
|
for _, provider := range configuredProviders {
|
|
normalizedProvider := chatprovider.NormalizeProvider(provider.Provider)
|
|
if normalizedProvider == "" {
|
|
continue
|
|
}
|
|
enabledProviders[normalizedProvider] = struct{}{}
|
|
}
|
|
chatprovider.PruneDisabledProviderKeys(&keys, enabledProviders)
|
|
return keys, nil
|
|
}
|
|
|
|
// resolveModelConfig looks up the chat's model config by its
|
|
// LastModelConfigID. If the referenced config no longer exists
|
|
// (e.g. it was deleted), it falls back to the default model
|
|
// config. Returns an error when no usable config is available.
|
|
func (p *Server) resolveModelConfig(
|
|
ctx context.Context,
|
|
chat database.Chat,
|
|
) (database.ChatModelConfig, error) {
|
|
if chat.LastModelConfigID != uuid.Nil {
|
|
modelConfig, err := p.configCache.ModelConfigByID(
|
|
ctx, chat.LastModelConfigID,
|
|
)
|
|
if err == nil {
|
|
return modelConfig, nil
|
|
}
|
|
if !xerrors.Is(err, sql.ErrNoRows) {
|
|
return database.ChatModelConfig{}, xerrors.Errorf(
|
|
"get chat model config %s: %w",
|
|
chat.LastModelConfigID, err,
|
|
)
|
|
}
|
|
// Model config was deleted, fall through to default.
|
|
}
|
|
|
|
defaultConfig, err := p.configCache.DefaultModelConfig(ctx)
|
|
if err != nil {
|
|
if xerrors.Is(err, sql.ErrNoRows) {
|
|
return database.ChatModelConfig{}, xerrors.New(
|
|
"no default chat model config is available",
|
|
)
|
|
}
|
|
return database.ChatModelConfig{}, xerrors.Errorf(
|
|
"get default chat model config: %w", err,
|
|
)
|
|
}
|
|
return defaultConfig, nil
|
|
}
|
|
|
|
func refreshChatWorkspaceSnapshot(
|
|
ctx context.Context,
|
|
chat database.Chat,
|
|
loadChat func(context.Context, uuid.UUID) (database.Chat, error),
|
|
) (database.Chat, error) {
|
|
if chat.WorkspaceID.Valid || loadChat == nil {
|
|
return chat, nil
|
|
}
|
|
|
|
refreshedChat, err := loadChat(ctx, chat.ID)
|
|
if err != nil {
|
|
return chat, xerrors.Errorf("reload chat workspace state: %w", err)
|
|
}
|
|
|
|
return refreshedChat, nil
|
|
}
|
|
|
|
// contextFileAgentID extracts the workspace agent ID from the most
|
|
// recent persisted instruction-file parts. The skill-only sentinel is
|
|
// ignored because it does not represent persisted instruction content.
|
|
// Returns uuid.Nil, false if no instruction-file parts exist.
|
|
func contextFileAgentID(messages []database.ChatMessage) (uuid.UUID, bool) {
|
|
var lastID uuid.UUID
|
|
found := false
|
|
for _, msg := range messages {
|
|
if !msg.Content.Valid || !bytes.Contains(msg.Content.RawMessage, []byte(`"context-file"`)) {
|
|
continue
|
|
}
|
|
var parts []codersdk.ChatMessagePart
|
|
if err := json.Unmarshal(msg.Content.RawMessage, &parts); err != nil {
|
|
continue
|
|
}
|
|
for _, p := range parts {
|
|
if p.Type != codersdk.ChatMessagePartTypeContextFile ||
|
|
!p.ContextFileAgentID.Valid ||
|
|
p.ContextFilePath == AgentChatContextSentinelPath {
|
|
continue
|
|
}
|
|
lastID = p.ContextFileAgentID.UUID
|
|
found = true
|
|
break
|
|
}
|
|
}
|
|
return lastID, found
|
|
}
|
|
|
|
// fetchWorkspaceContext retrieves fresh instruction files and
|
|
// skills from the workspace agent without persisting. It handles
|
|
// agent connection, context configuration fetching, content
|
|
// sanitization, and metadata stamping. Returns the workspace
|
|
// agent, the stamped parts, discovered skills, and whether the
|
|
// workspace connection succeeded. A nil agent means the chat has
|
|
// no valid workspace or the agent lookup failed;
|
|
// workspaceConnOK is false in that case.
|
|
func (p *Server) fetchWorkspaceContext(
|
|
ctx context.Context,
|
|
chat database.Chat,
|
|
getWorkspaceAgent func(context.Context) (database.WorkspaceAgent, error),
|
|
getWorkspaceConn func(context.Context) (workspacesdk.AgentConn, error),
|
|
) (agent *database.WorkspaceAgent, agentParts []codersdk.ChatMessagePart, discoveredSkills []chattool.SkillMeta, workspaceConnOK bool) {
|
|
if !chat.WorkspaceID.Valid || getWorkspaceAgent == nil {
|
|
return nil, nil, nil, false
|
|
}
|
|
|
|
loadedAgent, agentErr := getWorkspaceAgent(ctx)
|
|
if agentErr != nil {
|
|
return nil, nil, nil, false
|
|
}
|
|
|
|
directory := loadedAgent.ExpandedDirectory
|
|
if directory == "" {
|
|
directory = loadedAgent.Directory
|
|
}
|
|
|
|
// Fetch context configuration from the agent. Parts
|
|
// arrive pre-populated with context-file and skill entries
|
|
// so we don't need additional round-trips.
|
|
if getWorkspaceConn != nil {
|
|
instructionCtx, cancel := context.WithTimeout(ctx, p.instructionLookupTimeout)
|
|
defer cancel()
|
|
|
|
conn, connErr := getWorkspaceConn(instructionCtx)
|
|
if connErr != nil {
|
|
p.logger.Debug(ctx, "failed to resolve workspace connection for instruction files",
|
|
slog.F("chat_id", chat.ID),
|
|
slog.Error(connErr),
|
|
)
|
|
} else {
|
|
workspaceConnOK = true
|
|
|
|
agentCfg, cfgErr := conn.ContextConfig(instructionCtx)
|
|
if cfgErr != nil {
|
|
p.logger.Debug(ctx, "failed to fetch context config from agent",
|
|
slog.F("chat_id", chat.ID), slog.Error(cfgErr))
|
|
// Treat a transient ContextConfig failure the
|
|
// same as a failed connection so no sentinel is
|
|
// persisted. The next turn will retry.
|
|
workspaceConnOK = false
|
|
} else {
|
|
agentParts = agentCfg.Parts
|
|
}
|
|
}
|
|
}
|
|
|
|
// Stamp server-side fields and sanitize content. The
|
|
// agent cannot know its own UUID, OS metadata, or
|
|
// directory — those are added here at the trust boundary.
|
|
agentID := uuid.NullUUID{UUID: loadedAgent.ID, Valid: true}
|
|
|
|
for i := range agentParts {
|
|
agentParts[i].ContextFileAgentID = agentID
|
|
switch agentParts[i].Type {
|
|
case codersdk.ChatMessagePartTypeContextFile:
|
|
agentParts[i].ContextFileContent = SanitizePromptText(agentParts[i].ContextFileContent)
|
|
agentParts[i].ContextFileOS = loadedAgent.OperatingSystem
|
|
agentParts[i].ContextFileDirectory = directory
|
|
case codersdk.ChatMessagePartTypeSkill:
|
|
discoveredSkills = append(discoveredSkills, chattool.SkillMeta{
|
|
Name: agentParts[i].SkillName,
|
|
Description: agentParts[i].SkillDescription,
|
|
Dir: agentParts[i].SkillDir,
|
|
MetaFile: agentParts[i].ContextFileSkillMetaFile,
|
|
})
|
|
}
|
|
}
|
|
|
|
return &loadedAgent, agentParts, discoveredSkills, workspaceConnOK
|
|
}
|
|
|
|
func filterSkillParts(parts []codersdk.ChatMessagePart) []codersdk.ChatMessagePart {
|
|
var filtered []codersdk.ChatMessagePart
|
|
for _, part := range parts {
|
|
if part.Type == codersdk.ChatMessagePartTypeSkill {
|
|
filtered = append(filtered, part)
|
|
}
|
|
}
|
|
return filtered
|
|
}
|
|
|
|
// persistInstructionFiles fetches AGENTS.md instruction files and
|
|
// skills from the workspace agent, persisting both as message
|
|
// parts. This is called once when a workspace is first attached
|
|
// to a chat (or when the agent changes). Returns the formatted
|
|
// instruction string and skill index for injection into the
|
|
// current turn's prompt.
|
|
func (p *Server) persistInstructionFiles(
|
|
ctx context.Context,
|
|
chat database.Chat,
|
|
modelConfigID uuid.UUID,
|
|
getWorkspaceAgent func(context.Context) (database.WorkspaceAgent, error),
|
|
getWorkspaceConn func(context.Context) (workspacesdk.AgentConn, error),
|
|
) (instruction string, skills []chattool.SkillMeta, err error) {
|
|
agent, agentParts, discoveredSkills, workspaceConnOK := p.fetchWorkspaceContext(
|
|
ctx, chat, getWorkspaceAgent, getWorkspaceConn,
|
|
)
|
|
if agent == nil {
|
|
return "", nil, nil
|
|
}
|
|
|
|
agentID := uuid.NullUUID{UUID: agent.ID, Valid: true}
|
|
hasContent := false
|
|
hasContextFilePart := false
|
|
for _, part := range agentParts {
|
|
if part.Type == codersdk.ChatMessagePartTypeContextFile {
|
|
hasContextFilePart = true
|
|
if part.ContextFileContent != "" {
|
|
hasContent = true
|
|
}
|
|
}
|
|
}
|
|
directory := agent.ExpandedDirectory
|
|
if directory == "" {
|
|
directory = agent.Directory
|
|
}
|
|
|
|
contextAPIKeyID, _ := aibridge.DelegatedAPIKeyIDFromContext(ctx)
|
|
if !hasContent {
|
|
if !workspaceConnOK {
|
|
return "", nil, nil
|
|
}
|
|
if !hasContextFilePart {
|
|
agentParts = append([]codersdk.ChatMessagePart{{
|
|
Type: codersdk.ChatMessagePartTypeContextFile,
|
|
ContextFileAgentID: agentID,
|
|
}}, agentParts...)
|
|
}
|
|
content, err := chatprompt.MarshalParts(agentParts)
|
|
if err != nil {
|
|
return "", nil, nil
|
|
}
|
|
msgParams := database.InsertChatMessagesParams{ //nolint:exhaustruct // Fields populated by appendUserChatMessage.
|
|
ChatID: chat.ID,
|
|
}
|
|
appendUserChatMessage(&msgParams, newUserChatMessage(
|
|
contextAPIKeyID,
|
|
content,
|
|
database.ChatMessageVisibilityBoth,
|
|
modelConfigID,
|
|
chatprompt.CurrentContentVersion,
|
|
))
|
|
_, _ = p.db.InsertChatMessages(ctx, msgParams)
|
|
skillParts := filterSkillParts(agentParts)
|
|
p.updateLastInjectedContext(ctx, chat.ID, skillParts)
|
|
return "", discoveredSkills, nil
|
|
}
|
|
content, err := chatprompt.MarshalParts(agentParts)
|
|
if err != nil {
|
|
return "", nil, xerrors.Errorf("marshal context-file parts: %w", err)
|
|
}
|
|
|
|
msgParams := database.InsertChatMessagesParams{ //nolint:exhaustruct // Fields populated by appendUserChatMessage.
|
|
ChatID: chat.ID,
|
|
}
|
|
appendUserChatMessage(&msgParams, newUserChatMessage(
|
|
contextAPIKeyID,
|
|
content,
|
|
database.ChatMessageVisibilityBoth,
|
|
modelConfigID,
|
|
chatprompt.CurrentContentVersion,
|
|
))
|
|
if _, err := p.db.InsertChatMessages(ctx, msgParams); err != nil {
|
|
return "", nil, xerrors.Errorf("persist instruction files: %w", err)
|
|
}
|
|
stripped := make([]codersdk.ChatMessagePart, len(agentParts))
|
|
copy(stripped, agentParts)
|
|
for i := range stripped {
|
|
stripped[i].StripInternal()
|
|
}
|
|
p.updateLastInjectedContext(ctx, chat.ID, stripped)
|
|
|
|
return formatSystemInstructions(agent.OperatingSystem, directory, agentParts), discoveredSkills, nil
|
|
}
|
|
|
|
// updateLastInjectedContext persists the injected context
|
|
// parts (AGENTS.md files and skills) on the chat row so they
|
|
// are directly queryable without scanning messages. This is
|
|
// best-effort — a failure here is logged but does not block
|
|
// the turn.
|
|
func (p *Server) updateLastInjectedContext(ctx context.Context, chatID uuid.UUID, parts []codersdk.ChatMessagePart) {
|
|
param := pqtype.NullRawMessage{Valid: false}
|
|
if parts != nil {
|
|
raw, err := json.Marshal(parts)
|
|
if err != nil {
|
|
p.logger.Warn(ctx, "failed to marshal injected context",
|
|
slog.F("chat_id", chatID),
|
|
slog.Error(err),
|
|
)
|
|
return
|
|
}
|
|
param = pqtype.NullRawMessage{RawMessage: raw, Valid: true}
|
|
}
|
|
if _, err := p.db.UpdateChatLastInjectedContext(ctx, database.UpdateChatLastInjectedContextParams{
|
|
ID: chatID,
|
|
LastInjectedContext: param,
|
|
}); err != nil {
|
|
p.logger.Warn(ctx, "failed to update injected context",
|
|
slog.F("chat_id", chatID),
|
|
slog.Error(err),
|
|
)
|
|
}
|
|
}
|
|
|
|
// resolveUserCompactionThreshold looks up the user's per-model
|
|
// compaction threshold override. Returns the override value and
|
|
// true if one exists and is valid, or 0 and false otherwise.
|
|
func (p *Server) resolveUserCompactionThreshold(ctx context.Context, userID uuid.UUID, modelConfigID uuid.UUID) (int32, bool) {
|
|
raw, err := p.db.GetUserChatCompactionThreshold(ctx, database.GetUserChatCompactionThresholdParams{
|
|
UserID: userID,
|
|
Key: codersdk.CompactionThresholdKey(modelConfigID),
|
|
})
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return 0, false
|
|
}
|
|
if err != nil {
|
|
p.logger.Warn(ctx, "failed to fetch compaction threshold override",
|
|
slog.F("user_id", userID),
|
|
slog.F("model_config_id", modelConfigID),
|
|
slog.Error(err),
|
|
)
|
|
return 0, false
|
|
}
|
|
// Range 0..100 must stay in sync with handler validation in
|
|
// coderd/chats.go.
|
|
val, err := strconv.ParseInt(raw, 10, 32)
|
|
if err != nil || val < 0 || val > 100 {
|
|
return 0, false
|
|
}
|
|
return int32(val), true
|
|
}
|
|
|
|
// resolveDeploymentSystemPrompt builds the deployment-level system
|
|
// prompt from the built-in default and the admin-configured custom
|
|
// prompt stored in site_configs.
|
|
func (p *Server) resolveDeploymentSystemPrompt(ctx context.Context) string {
|
|
config, err := p.db.GetChatSystemPromptConfig(ctx)
|
|
if err != nil {
|
|
// Fail open: use the built-in default so chats always have
|
|
// some system guidance.
|
|
p.logger.Error(ctx, "failed to fetch chat system prompt configuration, using default", slog.Error(err))
|
|
return DefaultSystemPrompt
|
|
}
|
|
|
|
sanitizedCustom := SanitizePromptText(config.ChatSystemPrompt)
|
|
if sanitizedCustom == "" && strings.TrimSpace(config.ChatSystemPrompt) != "" {
|
|
p.logger.Warn(ctx, "custom system prompt became empty after sanitization, omitting custom portion")
|
|
}
|
|
|
|
var parts []string
|
|
if config.IncludeDefaultSystemPrompt {
|
|
parts = append(parts, DefaultSystemPrompt)
|
|
}
|
|
if sanitizedCustom != "" {
|
|
parts = append(parts, sanitizedCustom)
|
|
}
|
|
result := strings.Join(parts, "\n\n")
|
|
if result == "" {
|
|
p.logger.Warn(ctx, "resolved system prompt is empty, no system prompt will be injected into chats")
|
|
}
|
|
return result
|
|
}
|
|
|
|
// resolveUserPrompt fetches the user's custom chat prompt from the
|
|
// database and wraps it in <user-instructions> tags. Returns empty
|
|
// string if no prompt is set.
|
|
func (p *Server) resolveUserPrompt(ctx context.Context, userID uuid.UUID) string {
|
|
raw, err := p.configCache.UserPrompt(ctx, userID)
|
|
if err != nil {
|
|
// sql.ErrNoRows is the normal "not set" case.
|
|
return ""
|
|
}
|
|
trimmed := strings.TrimSpace(raw)
|
|
if trimmed == "" {
|
|
return ""
|
|
}
|
|
return "<user-instructions>\n" + trimmed + "\n</user-instructions>"
|
|
}
|
|
|
|
// renderPlanPathPrompt fills the plan-path placeholder when it is
|
|
// present in the prompt.
|
|
func renderPlanPathPrompt(prompt []fantasy.Message, planPathBlock string) []fantasy.Message {
|
|
prompt, _ = replacePlanPathPlaceholder(prompt, planPathBlock)
|
|
return prompt
|
|
}
|
|
|
|
func replacePlanPathPlaceholder(
|
|
prompt []fantasy.Message,
|
|
planPathBlock string,
|
|
) ([]fantasy.Message, bool) {
|
|
var updatedPrompt []fantasy.Message
|
|
replaced := false
|
|
for i, message := range prompt {
|
|
updatedMessage, ok := replacePlanPathPlaceholderInMessage(message, planPathBlock)
|
|
if !ok {
|
|
continue
|
|
}
|
|
if updatedPrompt == nil {
|
|
updatedPrompt = slices.Clone(prompt)
|
|
}
|
|
updatedPrompt[i] = updatedMessage
|
|
replaced = true
|
|
}
|
|
if !replaced {
|
|
return prompt, false
|
|
}
|
|
return updatedPrompt, true
|
|
}
|
|
|
|
func replacePlanPathPlaceholderInMessage(
|
|
message fantasy.Message,
|
|
planPathBlock string,
|
|
) (fantasy.Message, bool) {
|
|
if message.Role != fantasy.MessageRoleSystem {
|
|
return message, false
|
|
}
|
|
|
|
content := slices.Clone(message.Content)
|
|
replaced := false
|
|
for i, part := range content {
|
|
textPart, ok := fantasy.AsMessagePart[fantasy.TextPart](part)
|
|
if !ok || !strings.Contains(textPart.Text, defaultSystemPromptPlanPathBlockPlaceholder) {
|
|
continue
|
|
}
|
|
replaced = true
|
|
content[i] = fantasy.TextPart{Text: strings.ReplaceAll(
|
|
textPart.Text,
|
|
defaultSystemPromptPlanPathBlockPlaceholder,
|
|
planPathBlock,
|
|
)}
|
|
}
|
|
if !replaced {
|
|
return message, false
|
|
}
|
|
message.Content = content
|
|
return message, true
|
|
}
|
|
|
|
func formatPlanPathBlock(chatPath, home string) string {
|
|
chatPath = strings.TrimSpace(chatPath)
|
|
if chatPath == "" {
|
|
return ""
|
|
}
|
|
|
|
avoidPlanPath := chattool.LegacySharedPlanPath
|
|
home = strings.TrimSpace(home)
|
|
if home != "" {
|
|
avoidPlanPath = strings.TrimRight(home, "/") + "/PLAN.md"
|
|
}
|
|
|
|
var b strings.Builder
|
|
_, _ = b.WriteString("<plan-file-path>\n")
|
|
_, _ = b.WriteString("Your plan file path for this chat is: ")
|
|
_, _ = b.WriteString(chatPath)
|
|
_, _ = b.WriteString("\n")
|
|
_, _ = b.WriteString("Always use this exact path when creating or proposing plan files. Do not use ")
|
|
_, _ = b.WriteString(avoidPlanPath)
|
|
_, _ = b.WriteString(".\n")
|
|
_, _ = b.WriteString("</plan-file-path>")
|
|
return b.String()
|
|
}
|
|
|
|
// parseDynamicToolNames unmarshals the dynamic tools JSON column
|
|
// and returns a map of tool names. This centralizes the repeated
|
|
// pattern of deserializing DynamicTools into a name set.
|
|
func parseDynamicToolNames(raw pqtype.NullRawMessage) (map[string]bool, error) {
|
|
if !raw.Valid || len(raw.RawMessage) == 0 {
|
|
return make(map[string]bool), nil
|
|
}
|
|
var tools []codersdk.DynamicTool
|
|
if err := json.Unmarshal(raw.RawMessage, &tools); err != nil {
|
|
return nil, xerrors.Errorf("unmarshal dynamic tools: %w", err)
|
|
}
|
|
names := make(map[string]bool, len(tools))
|
|
for _, t := range tools {
|
|
names[t.Name] = true
|
|
}
|
|
return names, nil
|
|
}
|
|
|
|
// maybeFinalizeTurnStatusLabelAndPush updates the cached turn status label
|
|
// for parent chats and optionally sends a web push notification.
|
|
func (p *Server) maybeFinalizeTurnStatusLabelAndPush(
|
|
ctx context.Context,
|
|
chat database.Chat,
|
|
status database.ChatStatus,
|
|
lastError string,
|
|
runResult runChatResult,
|
|
logger slog.Logger,
|
|
) {
|
|
if chat.ParentChatID.Valid {
|
|
return
|
|
}
|
|
|
|
switch status {
|
|
case database.ChatStatusWaiting:
|
|
p.finalizeSuccessfulTurnStatusLabelAndPush(ctx, chat, status, runResult, logger)
|
|
|
|
case database.ChatStatusPending:
|
|
p.setLastTurnSummaryAsync(ctx, chat, fallbackTurnStatusLabel(status), logger)
|
|
|
|
case database.ChatStatusError:
|
|
p.clearLastTurnSummaryAsync(ctx, chat, logger)
|
|
if p.webpushConfigured() {
|
|
pushBody := fallbackTurnStatusLabel(status)
|
|
if lastError != "" {
|
|
pushBody = lastError
|
|
}
|
|
p.dispatchPush(ctx, chat, pushBody, status, logger)
|
|
}
|
|
|
|
case database.ChatStatusRequiresAction:
|
|
p.setLastTurnSummaryAsync(ctx, chat, fallbackTurnStatusLabel(status), logger)
|
|
|
|
default:
|
|
// New statuses must be classified before they can safely
|
|
// preserve or finalize a cached turn status label.
|
|
p.clearLastTurnSummaryAsync(ctx, chat, logger)
|
|
}
|
|
}
|
|
|
|
func (p *Server) finalizeSuccessfulTurnStatusLabelAndPush(
|
|
ctx context.Context,
|
|
chat database.Chat,
|
|
status database.ChatStatus,
|
|
runResult runChatResult,
|
|
logger slog.Logger,
|
|
) {
|
|
p.finalizeSuccessfulTurnStatusLabelWithAfterFunc(ctx, chat, status, runResult, logger, func(finalizeCtx context.Context, statusLabel string) {
|
|
p.dispatchSuccessfulTurnPush(finalizeCtx, chat, statusLabel, logger)
|
|
})
|
|
}
|
|
|
|
func (p *Server) finalizeSuccessfulTurnStatusLabelWithAfterFunc(
|
|
ctx context.Context,
|
|
chat database.Chat,
|
|
status database.ChatStatus,
|
|
runResult runChatResult,
|
|
logger slog.Logger,
|
|
afterFinalize func(context.Context, string),
|
|
) {
|
|
// This helper runs during processChat cleanup, while processChat is
|
|
// still counted in p.inflight. Do not take inflightMu here because
|
|
// drainInflight holds it while waiting.
|
|
p.inflight.Go(func() {
|
|
finalizeCtx := context.WithoutCancel(ctx)
|
|
statusLabel := p.generateFinalTurnStatusLabel(finalizeCtx, chat, status, runResult, logger)
|
|
logger.Debug(finalizeCtx, "generated chat turn status label",
|
|
slog.F("chat_id", chat.ID),
|
|
slog.F("status", status),
|
|
slog.F("label_length", len(statusLabel)),
|
|
)
|
|
|
|
p.updateLastTurnSummary(finalizeCtx, chat, chat.HistoryVersion, statusLabel, logger)
|
|
|
|
afterFinalize(finalizeCtx, statusLabel)
|
|
})
|
|
}
|
|
|
|
func (p *Server) generateFinalTurnStatusLabel(
|
|
ctx context.Context,
|
|
chat database.Chat,
|
|
status database.ChatStatus,
|
|
runResult runChatResult,
|
|
logger slog.Logger,
|
|
) string {
|
|
if status != database.ChatStatusWaiting {
|
|
return fallbackTurnStatusLabel(status)
|
|
}
|
|
|
|
assistantText := strings.TrimSpace(runResult.FinalAssistantText)
|
|
if assistantText == "" || runResult.StatusLabelModel == nil {
|
|
return fallbackTurnStatusLabel(status)
|
|
}
|
|
|
|
statusLabel := p.generateTurnStatusLabel(
|
|
ctx,
|
|
chat,
|
|
status,
|
|
assistantText,
|
|
runResult.FallbackProvider,
|
|
runResult.FallbackModel,
|
|
runResult.StatusLabelModel,
|
|
runResult.FallbackRoute,
|
|
runResult.ProviderKeys,
|
|
runResult.ModelBuildOptions,
|
|
logger,
|
|
p.existingDebugService(),
|
|
runResult.TriggerMessageID,
|
|
runResult.HistoryTipMessageID,
|
|
)
|
|
if statusLabel == "" {
|
|
return fallbackTurnStatusLabel(status)
|
|
}
|
|
return statusLabel
|
|
}
|
|
|
|
func (p *Server) dispatchSuccessfulTurnPush(
|
|
ctx context.Context,
|
|
chat database.Chat,
|
|
statusLabel string,
|
|
logger slog.Logger,
|
|
) {
|
|
if !p.webpushConfigured() {
|
|
return
|
|
}
|
|
pushBody := fallbackTurnStatusLabel(database.ChatStatusWaiting)
|
|
if statusLabel != "" {
|
|
pushBody = statusLabel
|
|
}
|
|
p.dispatchPush(ctx, chat, pushBody, database.ChatStatusWaiting, logger)
|
|
}
|
|
|
|
func (p *Server) maybeClearLastTurnSummaryAsync(
|
|
ctx context.Context,
|
|
chat database.Chat,
|
|
logger slog.Logger,
|
|
) {
|
|
if chat.ParentChatID.Valid {
|
|
return
|
|
}
|
|
p.clearLastTurnSummaryAsync(ctx, chat, logger)
|
|
}
|
|
|
|
func (p *Server) setLastTurnSummaryAsync(
|
|
ctx context.Context,
|
|
chat database.Chat,
|
|
summary string,
|
|
logger slog.Logger,
|
|
) {
|
|
summary = strings.TrimSpace(summary)
|
|
if summary == "" {
|
|
p.clearLastTurnSummaryAsync(ctx, chat, logger)
|
|
return
|
|
}
|
|
if chat.LastTurnSummary.Valid && strings.TrimSpace(chat.LastTurnSummary.String) == summary {
|
|
return
|
|
}
|
|
// This helper runs during processChat cleanup, while processChat is
|
|
// still counted in p.inflight. Do not take inflightMu here because
|
|
// drainInflight holds it while waiting.
|
|
p.inflight.Go(func() {
|
|
p.updateLastTurnSummary(context.WithoutCancel(ctx), chat, chat.HistoryVersion, summary, logger)
|
|
})
|
|
}
|
|
|
|
func (p *Server) clearLastTurnSummaryAsync(
|
|
ctx context.Context,
|
|
chat database.Chat,
|
|
logger slog.Logger,
|
|
) {
|
|
// This helper runs during processChat cleanup, while processChat is
|
|
// still counted in p.inflight. Do not take inflightMu here because
|
|
// drainInflight holds it while waiting.
|
|
p.inflight.Go(func() {
|
|
p.updateLastTurnSummary(context.WithoutCancel(ctx), chat, chat.HistoryVersion, "", logger)
|
|
})
|
|
}
|
|
|
|
// updateLastTurnSummary writes the cached sidebar summary for a chat.
|
|
// Callers should pass a detached context because this method is used for
|
|
// best-effort background cache writes.
|
|
func (p *Server) updateLastTurnSummary(
|
|
ctx context.Context,
|
|
chat database.Chat,
|
|
expectedHistoryVersion int64,
|
|
summary string,
|
|
logger slog.Logger,
|
|
) {
|
|
summary = strings.TrimSpace(summary)
|
|
lastTurnSummary := sql.NullString{String: summary, Valid: summary != ""}
|
|
|
|
//nolint:gocritic // Narrow daemon access for best-effort summary cache writes.
|
|
updateCtx := dbauthz.AsChatd(ctx)
|
|
updateCtx, cancel := context.WithTimeout(updateCtx, turnStatusLabelWriteTimeout)
|
|
defer cancel()
|
|
|
|
affected, err := p.db.UpdateChatLastTurnSummary(updateCtx, database.UpdateChatLastTurnSummaryParams{
|
|
ID: chat.ID,
|
|
ExpectedHistoryVersion: expectedHistoryVersion,
|
|
LastTurnSummary: lastTurnSummary,
|
|
})
|
|
if err != nil {
|
|
logger.Warn(updateCtx, "failed to update chat turn summary",
|
|
slog.F("chat_id", chat.ID),
|
|
slog.Error(err),
|
|
)
|
|
return
|
|
}
|
|
if affected == 0 {
|
|
if summary != "" {
|
|
logger.Info(updateCtx, "skipped stale chat turn summary update with non-empty summary",
|
|
slog.F("chat_id", chat.ID),
|
|
slog.F("summary_length", len(summary)),
|
|
slog.F("expected_history_version", expectedHistoryVersion),
|
|
)
|
|
return
|
|
}
|
|
logger.Debug(updateCtx, "skipped stale chat turn summary update",
|
|
slog.F("chat_id", chat.ID),
|
|
slog.F("expected_history_version", expectedHistoryVersion),
|
|
)
|
|
return
|
|
}
|
|
|
|
updatedChat := chat
|
|
updatedChat.LastTurnSummary = lastTurnSummary
|
|
p.publishChatPubsubEvent(updatedChat, codersdk.ChatWatchEventKindSummaryChange, nil)
|
|
}
|
|
|
|
func (p *Server) webpushConfigured() bool {
|
|
return p.webpushDispatcher != nil && p.webpushDispatcher.PublicKey() != ""
|
|
}
|
|
|
|
func (p *Server) dispatchPush(
|
|
ctx context.Context,
|
|
chat database.Chat,
|
|
body string,
|
|
status database.ChatStatus,
|
|
logger slog.Logger,
|
|
) {
|
|
pushMsg := codersdk.WebpushMessage{
|
|
Title: chat.Title,
|
|
Body: body,
|
|
Icon: "/favicon.ico",
|
|
Data: map[string]string{"url": fmt.Sprintf("/agents/%s", chat.ID)},
|
|
}
|
|
if err := p.webpushDispatcher.Dispatch(ctx, chat.OwnerID, pushMsg); err != nil {
|
|
logger.Warn(ctx, "failed to send chat completion web push",
|
|
slog.F("chat_id", chat.ID),
|
|
slog.F("status", status),
|
|
slog.Error(err),
|
|
)
|
|
}
|
|
}
|
|
|
|
// Close stops the processor and waits for it to finish.
|
|
func (p *Server) Close() error {
|
|
if unsub := p.configCacheUnsubscribe; unsub != nil {
|
|
p.configCacheUnsubscribe = nil
|
|
unsub()
|
|
}
|
|
if p.chatWorker != nil {
|
|
if err := p.chatWorker.Close(); err != nil {
|
|
p.logger.Warn(context.Background(), "failed to close chat worker", slog.Error(err))
|
|
}
|
|
}
|
|
if p.messagePartBuffer != nil {
|
|
p.messagePartBuffer.Close()
|
|
}
|
|
p.cancel()
|
|
p.wg.Wait()
|
|
p.drainInflight()
|
|
return nil
|
|
}
|
|
|
|
// drainInflight waits for all in-flight operations to complete.
|
|
// It acquires inflightMu to prevent processOnce from spawning
|
|
// new goroutines (via inflight.Add) concurrently with Wait,
|
|
// which would violate sync.WaitGroup's contract.
|
|
//
|
|
// https://pkg.go.dev/sync#WaitGroup.Add
|
|
// > Note that calls with a positive delta that occur when the counter is zero must happen before a Wait.
|
|
func (p *Server) drainInflight() {
|
|
p.inflightMu.Lock()
|
|
p.inflight.Wait()
|
|
p.inflightMu.Unlock()
|
|
}
|
|
|
|
// refreshExpiredMCPTokens checks each MCP OAuth2 token and refreshes
|
|
// any that are expired (or about to expire). Tokens without a
|
|
// refresh_token or that fail to refresh are returned unchanged so the
|
|
// caller can still attempt the connection (which will likely fail with
|
|
// a 401 for the expired ones).
|
|
func (p *Server) refreshExpiredMCPTokens(
|
|
ctx context.Context,
|
|
logger slog.Logger,
|
|
configs []database.MCPServerConfig,
|
|
tokens []database.MCPServerUserToken,
|
|
) []database.MCPServerUserToken {
|
|
configsByID := make(map[uuid.UUID]database.MCPServerConfig, len(configs))
|
|
for _, cfg := range configs {
|
|
configsByID[cfg.ID] = cfg
|
|
}
|
|
|
|
result := slices.Clone(tokens)
|
|
|
|
var eg errgroup.Group
|
|
for i, tok := range result {
|
|
cfg, ok := configsByID[tok.MCPServerConfigID]
|
|
if !ok || cfg.AuthType != "oauth2" {
|
|
continue
|
|
}
|
|
if tok.RefreshToken == "" {
|
|
continue
|
|
}
|
|
|
|
eg.Go(func() error {
|
|
refreshed, err := p.refreshMCPTokenIfNeeded(ctx, logger, cfg, tok)
|
|
if err != nil {
|
|
logger.Warn(ctx, "failed to refresh MCP oauth2 token",
|
|
slog.F("server_slug", cfg.Slug),
|
|
slog.Error(err),
|
|
)
|
|
return nil
|
|
}
|
|
result[i] = refreshed
|
|
return nil
|
|
})
|
|
}
|
|
_ = eg.Wait()
|
|
|
|
return result
|
|
}
|
|
|
|
// refreshMCPTokenIfNeeded delegates to mcpclient.RefreshOAuth2Token
|
|
// and persists the result to the database when a refresh occurs.
|
|
// The logger should carry chat-scoped fields so log lines can be
|
|
// correlated with specific chat requests.
|
|
func (p *Server) refreshMCPTokenIfNeeded(
|
|
ctx context.Context,
|
|
logger slog.Logger,
|
|
cfg database.MCPServerConfig,
|
|
tok database.MCPServerUserToken,
|
|
) (database.MCPServerUserToken, error) {
|
|
result, err := mcpclient.RefreshOAuth2Token(ctx, cfg, tok)
|
|
if err != nil {
|
|
return tok, err
|
|
}
|
|
|
|
if !result.Refreshed {
|
|
return tok, nil
|
|
}
|
|
|
|
logger.Info(ctx, "refreshed MCP oauth2 token",
|
|
slog.F("server_slug", cfg.Slug),
|
|
slog.F("user_id", tok.UserID),
|
|
)
|
|
|
|
var expiry sql.NullTime
|
|
if !result.Expiry.IsZero() {
|
|
expiry = sql.NullTime{Time: result.Expiry, Valid: true}
|
|
}
|
|
|
|
//nolint:gocritic // Chatd needs system-level write access to
|
|
// persist the refreshed OAuth2 token for the user.
|
|
updated, err := p.db.UpsertMCPServerUserToken(
|
|
dbauthz.AsSystemRestricted(ctx),
|
|
database.UpsertMCPServerUserTokenParams{
|
|
MCPServerConfigID: tok.MCPServerConfigID,
|
|
UserID: tok.UserID,
|
|
AccessToken: result.AccessToken,
|
|
AccessTokenKeyID: sql.NullString{},
|
|
RefreshToken: result.RefreshToken,
|
|
RefreshTokenKeyID: sql.NullString{},
|
|
TokenType: result.TokenType,
|
|
Expiry: expiry,
|
|
},
|
|
)
|
|
if err != nil {
|
|
// The provider may have rotated the refresh token,
|
|
// invalidating the old one. Use the new token
|
|
// in-memory so at least this connection succeeds.
|
|
logger.Warn(ctx, "failed to persist refreshed MCP oauth2 token, using in-memory",
|
|
slog.F("server_slug", cfg.Slug),
|
|
slog.Error(err),
|
|
)
|
|
tok.AccessToken = result.AccessToken
|
|
tok.RefreshToken = result.RefreshToken
|
|
tok.TokenType = result.TokenType
|
|
tok.Expiry = expiry
|
|
return tok, nil
|
|
}
|
|
|
|
return updated, nil
|
|
}
|