diff --git a/cli/server.go b/cli/server.go index 758369de30..859d74996e 100644 --- a/cli/server.go +++ b/cli/server.go @@ -1155,7 +1155,7 @@ func (r *RootCmd) Server(newAPI func(context.Context, *coderd.Options) (*coderd. defer shutdownConns() // Ensures that old database entries are cleaned up over time! - purger := dbpurge.New(ctx, logger.Named("dbpurge"), options.Database, options.DeploymentValues, options.PrometheusRegistry, &coderAPI.Auditor, dbpurge.WithNotificationsEnqueuer(options.NotificationsEnqueuer)) + purger := dbpurge.New(ctx, logger.Named("dbpurge"), options.Database, options.DeploymentValues, options.PrometheusRegistry, &coderAPI.Auditor) defer purger.Close() // Updates workspace usage diff --git a/coderd/coderd.go b/coderd/coderd.go index 661cdf3520..a43bedcd02 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -835,6 +835,8 @@ func New(options *Options) *API { UsageTracker: options.WorkspaceUsageTracker, PrometheusRegistry: options.PrometheusRegistry, OIDCTokenSource: oidcMCPSrc, + NotificationsEnqueuer: options.NotificationsEnqueuer, + Auditor: &api.Auditor, }).Start() gitSyncLogger := options.Logger.Named("gitsync") refresher := gitsync.NewRefresher( diff --git a/coderd/database/dbpurge/dbpurge.go b/coderd/database/dbpurge/dbpurge.go index c87bc5a8df..646bd7edd9 100644 --- a/coderd/database/dbpurge/dbpurge.go +++ b/coderd/database/dbpurge/dbpurge.go @@ -1,18 +1,12 @@ package dbpurge import ( - "cmp" "context" "errors" "io" - "net/http" - "slices" - "strconv" "sync/atomic" "time" - "github.com/dustin/go-humanize" - "github.com/google/uuid" "github.com/prometheus/client_golang/prometheus" "golang.org/x/xerrors" @@ -21,9 +15,7 @@ import ( "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/dbauthz" "github.com/coder/coder/v2/coderd/database/dbtime" - "github.com/coder/coder/v2/coderd/notifications" "github.com/coder/coder/v2/coderd/pproflabel" - "github.com/coder/coder/v2/coderd/util/slice" 
"github.com/coder/coder/v2/codersdk" "github.com/coder/quartz" ) @@ -52,18 +44,8 @@ const ( // Chat debug run deletions can cascade into steps with large JSONB // payloads, so they use the same conservative batch size. chatDebugRunsBatchSize = 1000 - // chatAutoArchiveDigestMaxChats bounds how many chat titles a - // single digest body lists. Past the cap, surplus titles are - // summarized as "...and N more". 25 is a readable email-friendly - // length; the cap is unrelated to chatAutoArchiveBatchSize, which - // bounds work per tick. - chatAutoArchiveDigestMaxChats = 25 ) -// defaultChatAutoArchiveBatchSize bounds how many root chats one -// tick will archive by default. -const defaultChatAutoArchiveBatchSize int32 = 1000 - type Option func(*instance) // WithClock overrides the clock used by the purger. Defaults to @@ -72,34 +54,12 @@ func WithClock(clk quartz.Clock) Option { return func(i *instance) { i.clk = clk } } -// WithChatAutoArchiveBatchSize overrides how many root chats a -// single tick will auto-archive. Defaults to -// defaultChatAutoArchiveBatchSize (1000). -func WithChatAutoArchiveBatchSize(n int32) Option { - return func(i *instance) { i.chatAutoArchiveBatchSize = n } -} - -// WithNotificationsEnqueuer sets the enqueuer used for digest -// notifications. Defaults to notifications.NewNoopEnqueuer(). Panics -// if e is nil: a nil enqueuer would NPE on the first dispatch tick, -// and failing fast at option-apply time surfaces the misuse at -// startup rather than minutes later. -func WithNotificationsEnqueuer(e notifications.Enqueuer) Option { - if e == nil { - panic("developer error: WithNotificationsEnqueuer called with nil enqueuer") - } - return func(i *instance) { i.enqueuer = e } -} - // New creates a new periodically purging database instance. // Callers must Close the returned instance. // -// The auditor pointer is loaded on each dispatch tick so runtime -// entitlement changes (e.g. 
toggling the audit-log feature) take -// effect without restarting the process. Notifications enqueuer -// defaults to no-op. Use WithNotificationsEnqueuer to pass a real -// one. -func New(ctx context.Context, logger slog.Logger, db database.Store, vals *codersdk.DeploymentValues, reg prometheus.Registerer, auditor *atomic.Pointer[audit.Auditor], opts ...Option) io.Closer { +// The auditor pointer is accepted for compatibility with other background +// services. Dbpurge does not emit audit logs directly. +func New(ctx context.Context, logger slog.Logger, db database.Store, vals *codersdk.DeploymentValues, reg prometheus.Registerer, _ *atomic.Pointer[audit.Auditor], opts ...Option) io.Closer { closed := make(chan struct{}) ctx, cancelFunc := context.WithCancel(ctx) @@ -123,26 +83,14 @@ func New(ctx context.Context, logger slog.Logger, db database.Store, vals *coder }, []string{"record_type"}) reg.MustRegister(recordsPurged) - chatAutoArchiveRecords := prometheus.NewCounter(prometheus.CounterOpts{ - Namespace: "coderd", - Subsystem: "chat_auto_archive", - Name: "records_archived_total", - Help: "Total number of chats archived by the auto-archive job (counting both roots and cascaded children).", - }) - reg.MustRegister(chatAutoArchiveRecords) - inst := &instance{ - cancel: cancelFunc, - closed: closed, - logger: logger, - vals: vals, - clk: quartz.NewReal(), - auditor: auditor, - enqueuer: notifications.NewNoopEnqueuer(), - iterationDuration: iterationDuration, - recordsPurged: recordsPurged, - chatAutoArchiveRecords: chatAutoArchiveRecords, - chatAutoArchiveBatchSize: defaultChatAutoArchiveBatchSize, + cancel: cancelFunc, + closed: closed, + logger: logger, + vals: vals, + clk: quartz.NewReal(), + iterationDuration: iterationDuration, + recordsPurged: recordsPurged, } for _, opt := range opts { opt(inst) @@ -185,19 +133,14 @@ func New(ctx context.Context, logger slog.Logger, db database.Store, vals *coder func (i *instance) purgeTick(ctx context.Context, db 
database.Store, start time.Time) error { // Read chat configs outside the tx so a corrupt value can't // poison subsequent queries. On config read errors, log and stash - // the error, then run unrelated purges best-effort. Retention and - // auto-archive errors skip only the conversation purge and - // auto-archive work. Debug retention errors skip only the debug - // purge. purgeTick returns chatConfigErr after the tx so the failed - // iteration is operator-visible via metric and logs. + // the error, then run unrelated purges best-effort. Retention + // errors skip only the conversation purge. Debug retention errors + // skip only the debug purge. purgeTick returns chatConfigErr after + // the tx so the failed iteration is operator-visible via metric and + // logs. chatRetentionDays, chatRetentionErr := db.GetChatRetentionDays(ctx) if chatRetentionErr != nil { - i.logger.Error(ctx, "failed to read chat retention config: skipping chat purge and auto-archive this tick", slog.Error(chatRetentionErr)) - } - - chatAutoArchiveDays, chatAutoArchiveErr := db.GetChatAutoArchiveDays(ctx, codersdk.DefaultChatAutoArchiveDays) - if chatAutoArchiveErr != nil { - i.logger.Error(ctx, "failed to read chat auto-archive config: skipping chat purge and auto-archive this tick", slog.Error(chatAutoArchiveErr)) + i.logger.Error(ctx, "failed to read chat retention config: skipping chat purge this tick", slog.Error(chatRetentionErr)) } chatDebugRetentionDays, chatDebugRetentionErr := db.GetChatDebugRetentionDays(ctx, codersdk.DefaultChatDebugRetentionDays) @@ -205,11 +148,7 @@ func (i *instance) purgeTick(ctx context.Context, db database.Store, start time. 
i.logger.Error(ctx, "failed to read chat debug retention config: skipping chat debug purge this tick", slog.Error(chatDebugRetentionErr)) } - chatRetentionConfigErr := errors.Join(chatRetentionErr, chatAutoArchiveErr) - chatConfigErr := errors.Join(chatRetentionConfigErr, chatDebugRetentionErr) - - // Populated inside the tx; dispatched post-commit. - var archivedChats []database.AutoArchiveInactiveChatsRow + chatConfigErr := errors.Join(chatRetentionErr, chatDebugRetentionErr) // Start a transaction to grab advisory lock, we don't want to run // multiple purges at the same time (multiple replicas). @@ -316,8 +255,8 @@ func (i *instance) purgeTick(ctx context.Context, db database.Store, start time. } var purgedChats, purgedChatFiles, purgedChatDebugRuns int64 - if chatRetentionConfigErr == nil { - purgedChats, purgedChatFiles, archivedChats, err = i.purgeChatsInTx(ctx, tx, start, chatRetentionDays, chatAutoArchiveDays) + if chatRetentionErr == nil { + purgedChats, purgedChatFiles, err = i.purgeChatsInTx(ctx, tx, start, chatRetentionDays) if err != nil { return xerrors.Errorf("failed to purge chats: %w", err) } @@ -345,7 +284,6 @@ func (i *instance) purgeTick(ctx context.Context, db database.Store, start time. slog.F("chats", purgedChats), slog.F("chat_files", purgedChatFiles), slog.F("chat_debug_runs", purgedChatDebugRuns), - slog.F("auto_archived_chats", len(archivedChats)), slog.F("duration", i.clk.Since(start)), ) @@ -379,35 +317,17 @@ func (i *instance) purgeTick(ctx context.Context, db database.Store, start time. return xerrors.Errorf("chat config read failed this tick: %w", chatConfigErr) } - // Dispatch audits and digests post-commit. Detached context for audit - // so that ticker cancellation cannot truncate the audit trail. - // Notification enqueue uses the cancellable parent context to avoid - // stalling shutdown. - // Owners with more eligible chats than batch size will get a - // notification per tick until their backlog drains. 
- // If this is deemed too noisy, users can disable the - // "Chats Auto-Archived" template from their notification preferences. - if len(archivedChats) > 0 { - i.chatAutoArchiveRecords.Add(float64(len(archivedChats))) - auditCtx := context.WithoutCancel(ctx) - i.dispatchChatAutoArchive(auditCtx, ctx, start, chatAutoArchiveDays, chatRetentionDays, archivedChats) - } - return nil } type instance struct { - cancel context.CancelFunc - closed chan struct{} - logger slog.Logger - vals *codersdk.DeploymentValues - clk quartz.Clock - auditor *atomic.Pointer[audit.Auditor] - enqueuer notifications.Enqueuer - iterationDuration *prometheus.HistogramVec - recordsPurged *prometheus.CounterVec - chatAutoArchiveRecords prometheus.Counter - chatAutoArchiveBatchSize int32 + cancel context.CancelFunc + closed chan struct{} + logger slog.Logger + vals *codersdk.DeploymentValues + clk quartz.Clock + iterationDuration *prometheus.HistogramVec + recordsPurged *prometheus.CounterVec } func (i *instance) Close() error { @@ -416,73 +336,8 @@ func (i *instance) Close() error { return nil } -// chatFromAutoArchiveRow reshapes the query row into a database.Chat for -// audit.Auditable[database.Chat]. -func chatFromAutoArchiveRow(logger slog.Logger, r database.AutoArchiveInactiveChatsRow) database.Chat { - var labels database.StringMap - // sqlc's StringMap override doesn't reach CTE-aliased columns, so Labels - // arrives as raw JSON bytes. StringMap.Scan handles []byte and nil. 
- if err := labels.Scan([]byte(r.Labels)); err != nil { - logger.Warn(context.Background(), "failed to parse chat labels from auto-archive row", - slog.F("chat_id", r.ID), - slog.F("raw_labels", string(r.Labels)), - slog.Error(err), - ) - } - - var userACL database.ChatACL - if err := userACL.Scan([]byte(r.UserACL)); err != nil { - logger.Warn(context.Background(), "failed to parse chat user ACL from auto-archive row", - slog.F("chat_id", r.ID), - slog.F("raw_user_acl", string(r.UserACL)), - slog.Error(err), - ) - } - - var groupACL database.ChatACL - if err := groupACL.Scan([]byte(r.GroupACL)); err != nil { - logger.Warn(context.Background(), "failed to parse chat group ACL from auto-archive row", - slog.F("chat_id", r.ID), - slog.F("raw_group_acl", string(r.GroupACL)), - slog.Error(err), - ) - } - - return database.Chat{ - ID: r.ID, - OwnerID: r.OwnerID, - OrganizationID: r.OrganizationID, - WorkspaceID: r.WorkspaceID, - BuildID: r.BuildID, - AgentID: r.AgentID, - Title: r.Title, - Status: r.Status, - WorkerID: r.WorkerID, - StartedAt: r.StartedAt, - HeartbeatAt: r.HeartbeatAt, - CreatedAt: r.CreatedAt, - UpdatedAt: r.UpdatedAt, - ParentChatID: r.ParentChatID, - RootChatID: r.RootChatID, - LastModelConfigID: r.LastModelConfigID, - Archived: r.Archived, - LastError: r.LastError, - Mode: r.Mode, - MCPServerIDs: r.MCPServerIDs, - Labels: labels, - UserACL: userACL, - GroupACL: groupACL, - PinOrder: r.PinOrder, - LastReadMessageID: r.LastReadMessageID, - LastInjectedContext: r.LastInjectedContext, - DynamicTools: r.DynamicTools, - PlanMode: r.PlanMode, - ClientType: r.ClientType, - } -} - // purgeChatsInTx MUST BE CALLED WITH A TRANSACTION -func (i *instance) purgeChatsInTx(ctx context.Context, tx database.Store, start time.Time, chatRetentionDays, chatAutoArchiveDays int32) (purgedChats, purgedChatFiles int64, archivedChats []database.AutoArchiveInactiveChatsRow, err error) { +func (*instance) purgeChatsInTx(ctx context.Context, tx database.Store, start time.Time, 
chatRetentionDays int32) (purgedChats, purgedChatFiles int64, err error) { // Delete old archived chats first, then orphaned files // (cascade clears chat_file_links but not chat_files). if chatRetentionDays > 0 { @@ -492,7 +347,7 @@ func (i *instance) purgeChatsInTx(ctx context.Context, tx database.Store, start LimitCount: chatsBatchSize, }) if err != nil { - return 0, 0, nil, xerrors.Errorf("failed to delete old chats: %w", err) + return 0, 0, xerrors.Errorf("failed to delete old chats: %w", err) } purgedChatFiles, err = tx.DeleteOldChatFiles(ctx, database.DeleteOldChatFilesParams{ @@ -500,149 +355,9 @@ func (i *instance) purgeChatsInTx(ctx context.Context, tx database.Store, start LimitCount: chatFilesBatchSize, }) if err != nil { - return 0, 0, nil, xerrors.Errorf("failed to delete old chat files: %w", err) + return 0, 0, xerrors.Errorf("failed to delete old chat files: %w", err) } } - // Auto-archive runs after the delete pass so newly - // archived chats aren't eligible for deletion this tick. - // Eligibility uses UTC day boundaries: a chat is archived on the - // start of the UTC day after its inactivity period has elapsed. - if chatAutoArchiveDays > 0 { - today := dbtime.StartOfDay(start) - archiveCutoff := today.Add(-time.Duration(chatAutoArchiveDays) * 24 * time.Hour) - archivedChats, err = tx.AutoArchiveInactiveChats(ctx, database.AutoArchiveInactiveChatsParams{ - ArchiveCutoff: archiveCutoff, - LimitCount: i.chatAutoArchiveBatchSize, - }) - if err != nil { - return 0, 0, nil, xerrors.Errorf("failed to auto-archive inactive chats: %w", err) - } - } - return purgedChats, purgedChatFiles, archivedChats, nil -} - -// dispatchChatAutoArchive audits every archived root chat and enqueues one -// notification per owner covering the roots archived in this tick. Children -// inherit their root's archival decision and are skipped for audit, matching -// the manual archive path (patchChat audits the root only). 
Enqueue is -// per-tick: owners whose backlog spans multiple ticks receive multiple -// notifications; notification_messages dedupe does not collapse them because -// each tick's payload differs. -// -// auditCtx is detached from the ticker so audits always complete. enqueueCtx -// is the cancellable parent: on shutdown we abandon any remaining digests -// rather than blocking Close. -func (i *instance) dispatchChatAutoArchive(auditCtx, enqueueCtx context.Context, tickStart time.Time, autoArchiveDays, retentionDays int32, archived []database.AutoArchiveInactiveChatsRow) { - // Children inherit their root's archival decision and are skipped - // for both audit and digest. Partition once so the two loops - // cannot drift apart if the cascade shape ever changes. - roots := slice.Filter(archived, func(r database.AutoArchiveInactiveChatsRow) bool { - return !r.ParentChatID.Valid - }) - - auditor := *i.auditor.Load() - for _, row := range roots { - after := chatFromAutoArchiveRow(i.logger, row) - before := after - before.Archived = false - audit.BackgroundAudit(auditCtx, &audit.BackgroundAuditParams[database.Chat]{ - Audit: auditor, - Log: i.logger, - UserID: row.OwnerID, - OrganizationID: row.OrganizationID, - Action: database.AuditActionWrite, - Old: before, - New: after, - Status: http.StatusOK, - AdditionalFields: audit.BackgroundTaskFieldsBytes(auditCtx, i.logger, audit.BackgroundSubsystemChatAutoArchive), - }) - } - - // Group archived roots by owner. Inline because this is the - // only call site and the loop body is self-explanatory. - rootsByOwner := make(map[uuid.UUID][]database.AutoArchiveInactiveChatsRow, len(roots)) - for _, row := range roots { - rootsByOwner[row.OwnerID] = append(rootsByOwner[row.OwnerID], row) - } - - // Sort owner IDs so shutdown abandons a deterministic tail of the dispatch list. 
- ownerIDs := make([]uuid.UUID, 0, len(rootsByOwner)) - for id := range rootsByOwner { - ownerIDs = append(ownerIDs, id) - } - slices.SortFunc(ownerIDs, func(a, b uuid.UUID) int { - return cmp.Compare(a.String(), b.String()) - }) - - dispatched := 0 - for _, ownerID := range ownerIDs { - // Check between iterations so shutdown unblocks promptly. A - // hung in-flight enqueue is unblocked by enqueueCtx propagating - // cancellation into the DB call. Skipped owners are not - // re-notified on the next tick because AutoArchiveInactiveChats - // only returns rows with archived = false; we accept that - // tradeoff over hanging shutdown. - if err := enqueueCtx.Err(); err != nil { - i.logger.Warn(enqueueCtx, "chat auto-archive digest dispatch canceled", - slog.F("remaining_owners", len(ownerIDs)-dispatched), - slog.Error(err)) - return - } - dispatched++ - - ownerRoots := rootsByOwner[ownerID] - data := buildDigestData(ownerRoots, autoArchiveDays, retentionDays, tickStart) - - // nolint:gocritic // Background digest runs as the notifier subject. - if _, err := i.enqueuer.EnqueueWithData( - dbauthz.AsNotifier(enqueueCtx), - ownerID, - notifications.TemplateChatAutoArchiveDigest, - map[string]string{}, - data, - string(audit.BackgroundSubsystemChatAutoArchive), - ); err != nil { - i.logger.Warn(enqueueCtx, "failed to enqueue chat auto-archive digest", - slog.F("owner_id", ownerID), - slog.Error(err)) - } - } -} - -// buildDigestData builds the notification payload; shape mirrors the -// golden fixtures in coderd/notifications/testdata. Truncation keeps -// the oldest archived roots (created_at ASC from the query) to -// preserve index-driven ordering; revisit if the digest becomes the -// primary surface for reviewing archived chats. -func buildDigestData(rows []database.AutoArchiveInactiveChatsRow, autoArchiveDays, retentionDays int32, tickStart time.Time) map[string]any { - // Cap titles; overflow surfaces as "...and N more" via the template. 
- overflow := 0 - if len(rows) > chatAutoArchiveDigestMaxChats { - overflow = len(rows) - chatAutoArchiveDigestMaxChats - rows = rows[:chatAutoArchiveDigestMaxChats] - } - - chats := make([]map[string]any, 0, len(rows)) - for _, r := range rows { - chats = append(chats, map[string]any{ - "title": r.Title, - "last_activity_humanized": humanize.RelTime(r.LastActivityAt, tickStart, "ago", "from now"), - }) - } - - // Stringify the int32 config values: the template's - // {{if eq .Data.retention_days "0"}} branch requires both - // operands to share a type, and Go templates do not coerce - // numeric ↔ string. Storing a raw int here would silently - // take the deletion-warning branch on every notification. - data := map[string]any{ - "auto_archive_days": strconv.Itoa(int(autoArchiveDays)), - "retention_days": strconv.Itoa(int(retentionDays)), - "archived_chats": chats, - } - if overflow > 0 { - data["additional_archived_count"] = strconv.Itoa(overflow) - } - return data + return purgedChats, purgedChatFiles, nil } diff --git a/coderd/database/dbpurge/dbpurge_test.go b/coderd/database/dbpurge/dbpurge_test.go index 4ebd645a7a..c0e784f538 100644 --- a/coderd/database/dbpurge/dbpurge_test.go +++ b/coderd/database/dbpurge/dbpurge_test.go @@ -32,9 +32,6 @@ import ( "github.com/coder/coder/v2/coderd/database/dbrollup" "github.com/coder/coder/v2/coderd/database/dbtestutil" "github.com/coder/coder/v2/coderd/database/dbtime" - "github.com/coder/coder/v2/coderd/notifications" - "github.com/coder/coder/v2/coderd/notifications/notificationsmock" - "github.com/coder/coder/v2/coderd/notifications/notificationstest" "github.com/coder/coder/v2/coderd/provisionerdserver" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/provisionerd/proto" @@ -60,7 +57,6 @@ func TestPurge(t *testing.T) { done := awaitDoTick(ctx, t, clk) mDB := dbmock.NewMockStore(gomock.NewController(t)) mDB.EXPECT().GetChatRetentionDays(gomock.Any()).Return(int32(0), nil).AnyTimes() - 
mDB.EXPECT().GetChatAutoArchiveDays(gomock.Any(), codersdk.DefaultChatAutoArchiveDays).Return(int32(0), nil).AnyTimes() mDB.EXPECT().GetChatDebugRetentionDays(gomock.Any(), codersdk.DefaultChatDebugRetentionDays).Return(int32(0), nil).AnyTimes() mDB.EXPECT().InTx(gomock.Any(), database.DefaultTXOptions().WithID("db_purge")).Return(nil).Times(2) purger := dbpurge.New(context.Background(), testutil.Logger(t), mDB, &codersdk.DeploymentValues{}, prometheus.NewRegistry(), nopAuditorPtr(t), dbpurge.WithClock(clk)) @@ -163,8 +159,6 @@ func TestMetrics(t *testing.T) { ctrl := gomock.NewController(t) mDB := dbmock.NewMockStore(ctrl) mDB.EXPECT().GetChatRetentionDays(gomock.Any()).Return(int32(0), nil).AnyTimes() - mDB.EXPECT().GetChatAutoArchiveDays(gomock.Any(), codersdk.DefaultChatAutoArchiveDays). - Return(int32(0), nil).AnyTimes() mDB.EXPECT().GetChatDebugRetentionDays(gomock.Any(), codersdk.DefaultChatDebugRetentionDays). Return(int32(0), nil).AnyTimes() mDB.EXPECT().TryAcquireLock(gomock.Any(), int64(database.LockIDDBPurge)).Return(false, nil).AnyTimes() @@ -203,7 +197,6 @@ func TestMetrics(t *testing.T) { ctrl := gomock.NewController(t) mDB := dbmock.NewMockStore(ctrl) mDB.EXPECT().GetChatRetentionDays(gomock.Any()).Return(int32(0), nil).AnyTimes() - mDB.EXPECT().GetChatAutoArchiveDays(gomock.Any(), codersdk.DefaultChatAutoArchiveDays).Return(int32(0), nil).AnyTimes() mDB.EXPECT().GetChatDebugRetentionDays(gomock.Any(), codersdk.DefaultChatDebugRetentionDays). Return(int32(0), nil).AnyTimes() mDB.EXPECT().InTx(gomock.Any(), database.DefaultTXOptions().WithID("db_purge")). @@ -230,8 +223,8 @@ func TestMetrics(t *testing.T) { }) // A failed retention read must not block unrelated or chat debug - // purges, but must skip the conversation purge and auto-archive - // passes and surface as a failed iteration via the metric. + // purges, but must skip the conversation purge and surface as a + // failed iteration via the metric. 
t.Run("FailedChatRetentionRead", func(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) defer cancel() @@ -248,8 +241,6 @@ func TestMetrics(t *testing.T) { MinTimes(1) // All reads happen before the bail; InTx still runs so unrelated // purges and chat debug purge commit best-effort. - mDB.EXPECT().GetChatAutoArchiveDays(gomock.Any(), codersdk.DefaultChatAutoArchiveDays). - Return(int32(0), nil).AnyTimes() mDB.EXPECT().GetChatDebugRetentionDays(gomock.Any(), codersdk.DefaultChatDebugRetentionDays). Return(int32(7), nil).AnyTimes() mDB.EXPECT().TryAcquireLock(gomock.Any(), int64(database.LockIDDBPurge)).Return(true, nil).AnyTimes() @@ -285,50 +276,6 @@ func TestMetrics(t *testing.T) { require.Nil(t, successHist, "should not have success=true metric on retention read failure") }) - // Same contract as FailedChatRetentionRead, but the - // auto-archive read is the half that fails. - t.Run("FailedChatAutoArchiveRead", func(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort) - defer cancel() - - reg := prometheus.NewRegistry() - clk := quartz.NewMock(t) - now := clk.Now() - clk.Set(now).MustWait(ctx) - - ctrl := gomock.NewController(t) - mDB := dbmock.NewMockStore(ctrl) - mDB.EXPECT().GetChatRetentionDays(gomock.Any()).Return(int32(30), nil).AnyTimes() - mDB.EXPECT().GetChatAutoArchiveDays(gomock.Any(), codersdk.DefaultChatAutoArchiveDays). - Return(int32(0), xerrors.New("simulated auto-archive read error")). - MinTimes(1) - mDB.EXPECT().GetChatDebugRetentionDays(gomock.Any(), codersdk.DefaultChatDebugRetentionDays). - Return(int32(0), nil).AnyTimes() - // InTx still runs so unrelated purges commit; chat - // passes inside the tx are skipped. - mDB.EXPECT().InTx(gomock.Any(), database.DefaultTXOptions().WithID("db_purge")). 
- Return(nil).MinTimes(1) - - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) - - done := awaitDoTick(ctx, t, clk) - closer := dbpurge.New(ctx, logger, mDB, &codersdk.DeploymentValues{}, reg, nopAuditorPtr(t), dbpurge.WithClock(clk)) - defer closer.Close() - testutil.TryReceive(ctx, t, done) - - hist := promhelp.HistogramValue(t, reg, "coderd_dbpurge_iteration_duration_seconds", prometheus.Labels{ - "success": "false", - }) - require.NotNil(t, hist) - require.Greater(t, hist.GetSampleCount(), uint64(0), - "failed auto-archive read must record a failed iteration") - - successHist := promhelp.MetricValue(t, reg, "coderd_dbpurge_iteration_duration_seconds", prometheus.Labels{ - "success": "true", - }) - require.Nil(t, successHist, "should not have success=true metric on auto-archive read failure") - }) - // Same contract as the other chat config reads, but debug retention // read failures skip only debug purging. t.Run("FailedChatDebugRetentionRead", func(t *testing.T) { @@ -343,8 +290,6 @@ func TestMetrics(t *testing.T) { ctrl := gomock.NewController(t) mDB := dbmock.NewMockStore(ctrl) mDB.EXPECT().GetChatRetentionDays(gomock.Any()).Return(int32(30), nil).AnyTimes() - mDB.EXPECT().GetChatAutoArchiveDays(gomock.Any(), codersdk.DefaultChatAutoArchiveDays). - Return(int32(0), nil).AnyTimes() mDB.EXPECT().GetChatDebugRetentionDays(gomock.Any(), codersdk.DefaultChatDebugRetentionDays). Return(int32(0), xerrors.New("simulated chat debug retention read error")). MinTimes(1) @@ -632,63 +577,6 @@ func awaitDoTick(ctx context.Context, t *testing.T, clk *quartz.Mock) chan struc return ch } -// tickDriver drives one or more dbpurge ticks against a single -// dbpurge.New instance. Unlike awaitDoTick it must be constructed -// *before* dbpurge.New so its traps are installed when the forced -// initial tick fires. 
awaitInitial waits for the forced tick's -// doTick to complete without advancing the clock, so no loop -// iteration has yet run; awaitNext then explicitly drives each -// subsequent iteration. This keeps each tick's observable state -// isolated and deterministic, which matters for tests where -// per-tick work differs (e.g. batch-size pagination). -type tickDriver struct { - clk *quartz.Mock - trapNow *quartz.Trap - trapStop *quartz.Trap - trapReset *quartz.Trap -} - -func newTickDriver(t *testing.T, clk *quartz.Mock) *tickDriver { - t.Helper() - d := &tickDriver{ - clk: clk, - trapNow: clk.Trap().Now(), - trapStop: clk.Trap().TickerStop(), - trapReset: clk.Trap().TickerReset(), - } - return d -} - -// close releases all traps. Call this via defer *after* the defer -// that closes the dbpurge instance so trap closure releases the -// shutdown ticker.Stop() rather than blocking on it. -func (d *tickDriver) close() { - d.trapReset.Close() - d.trapStop.Close() - d.trapNow.Close() -} - -// awaitInitial waits for the forced initial tick's doTick to -// complete. No loop iteration runs because the clock has not been -// advanced. -func (d *tickDriver) awaitInitial(ctx context.Context, t *testing.T) { - t.Helper() - d.trapNow.MustWait(ctx).MustRelease(ctx) - d.trapReset.MustWait(ctx).MustRelease(ctx) -} - -// awaitNext advances the clock by the tick interval, lets the loop -// receive the tick and run doTick, and waits for the ensuing -// ticker.Reset so the driver is ready for another awaitNext. 
-func (d *tickDriver) awaitNext(ctx context.Context, t *testing.T) { - t.Helper() - dur, w := d.clk.AdvanceNext() - require.Equal(t, 10*time.Minute, dur) - w.MustWait(ctx) - d.trapStop.MustWait(ctx).MustRelease(ctx) - d.trapReset.MustWait(ctx).MustRelease(ctx) -} - func assertNoWorkspaceAgentLogs(ctx context.Context, t *testing.T, db database.Store, agentID uuid.UUID) { t.Helper() agentLogs, err := db.GetWorkspaceAgentLogsAfter(ctx, database.GetWorkspaceAgentLogsAfterParams{ @@ -1922,14 +1810,6 @@ func nopAuditorPtr(t *testing.T) *atomic.Pointer[audit.Auditor] { return &p } -// mockAuditorPtr wraps a *MockAuditor in an atomic pointer for tests. -func mockAuditorPtr(m *audit.MockAuditor) *atomic.Pointer[audit.Auditor] { - a := audit.Auditor(m) - var p atomic.Pointer[audit.Auditor] - p.Store(&a) - return &p -} - //nolint:paralleltest // It uses LockIDDBPurge. func TestPurgeChatDebugRuns(t *testing.T) { now := time.Date(2025, 6, 15, 12, 0, 0, 0, time.UTC) @@ -2585,943 +2465,3 @@ func TestDeleteOldChatFiles(t *testing.T) { }) } } - -// helpers for TestAutoArchiveInactiveChats. Kept scoped to the -// test so they don't leak into the package surface area. 
-func archiveTestDeps(t *testing.T, db database.Store) chatAutoArchiveDeps { - t.Helper() - user := dbgen.User(t, db, database.User{}) - org := dbgen.Organization(t, db, database.Organization{}) - _ = dbgen.OrganizationMember(t, db, database.OrganizationMember{UserID: user.ID, OrganizationID: org.ID}) - _ = dbgen.ChatProvider(t, db, database.ChatProvider{ - Provider: "openai", - DisplayName: "OpenAI", - }) - mc := dbgen.ChatModelConfig(t, db, database.ChatModelConfig{ - Provider: "openai", - Model: "test-model", - ContextLimit: 8192, - }) - return chatAutoArchiveDeps{user: user, org: org, modelConfig: mc} -} - -type chatAutoArchiveDeps struct { - user database.User - org database.Organization - modelConfig database.ChatModelConfig -} - -// archiveHarness bundles the per-subtest setup shared by every -// TestAutoArchiveInactiveChats case. Subtests read fields off the -// harness directly instead of repeating six lines of identical -// plumbing. -type archiveHarness struct { - ctx context.Context - clk *quartz.Mock - db database.Store - rawDB *sql.DB - logger slog.Logger - deps chatAutoArchiveDeps -} - -func newArchiveHarness(t *testing.T, now time.Time) *archiveHarness { - t.Helper() - ctx := testutil.Context(t, testutil.WaitLong) - clk := quartz.NewMock(t) - clk.Set(now).MustWait(ctx) - db, _, rawDB := dbtestutil.NewDBWithSQLDB(t, dbtestutil.WithDumpOnFailure()) - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) - return &archiveHarness{ - ctx: ctx, - clk: clk, - db: db, - rawDB: rawDB, - logger: logger, - deps: archiveTestDeps(t, db), - } -} - -// createArchiveChat inserts a chat with an optional backdated -// created_at. Title is propagated through so tests can assert on -// digest contents. 
-func createArchiveChat(ctx context.Context, t *testing.T, db database.Store, rawDB *sql.DB, deps chatAutoArchiveDeps, title string, createdAt time.Time) database.Chat { - t.Helper() - chat := dbgen.Chat(t, db, database.Chat{ - OrganizationID: deps.org.ID, - OwnerID: deps.user.ID, - LastModelConfigID: deps.modelConfig.ID, - Title: title, - }) - _, err := rawDB.ExecContext(ctx, "UPDATE chats SET created_at = $1, updated_at = $1 WHERE id = $2", createdAt, chat.ID) - require.NoError(t, err) - return chat -} - -// insertTextMessage appends a non-deleted user message with a -// backdated created_at. Used to establish "last activity" for the -// auto-archive query's LATERAL subquery. -func insertTextMessage(ctx context.Context, t *testing.T, db database.Store, rawDB *sql.DB, chatID, userID, modelConfigID uuid.UUID, createdAt time.Time) { - t.Helper() - msg := dbgen.ChatMessage(t, db, database.ChatMessage{ - ChatID: chatID, - CreatedBy: uuid.NullUUID{UUID: userID, Valid: true}, - ModelConfigID: uuid.NullUUID{UUID: modelConfigID, Valid: true}, - Role: database.ChatMessageRoleUser, - }) - _, err := rawDB.ExecContext(ctx, "UPDATE chat_messages SET created_at = $1 WHERE id = $2", createdAt, msg.ID) - require.NoError(t, err) -} - -//nolint:paralleltest // It uses LockIDDBPurge. -func TestAutoArchiveInactiveChats(t *testing.T) { - now := time.Date(2025, 6, 15, 12, 0, 0, 0, time.UTC) - - tests := []struct { - name string - run func(t *testing.T) - }{ - { - name: "AutoArchiveDisabled", - run: func(t *testing.T) { - h := newArchiveHarness(t, now) - ctx, clk, db, rawDB, logger, deps := h.ctx, h.clk, h.db, h.rawDB, h.logger, h.deps - - require.Zero(t, codersdk.DefaultChatAutoArchiveDays) - require.NoError(t, db.UpsertChatAutoArchiveDays(ctx, codersdk.DefaultChatAutoArchiveDays)) - - // Chat older than any reasonable cutoff. 
- staleChat := createArchiveChat(ctx, t, db, rawDB, deps, "stale-chat", now.Add(-365*24*time.Hour)) - - auditor := audit.NewMock() - auditorPtr := mockAuditorPtr(auditor) - enqueuer := notificationstest.NewFakeEnqueuer() - done := awaitDoTick(ctx, t, clk) - closer := dbpurge.New(ctx, logger, db, &codersdk.DeploymentValues{}, prometheus.NewRegistry(), auditorPtr, dbpurge.WithNotificationsEnqueuer(enqueuer), dbpurge.WithClock(clk)) - defer closer.Close() - testutil.TryReceive(ctx, t, done) - - // Not archived, no audits, no digests. - refreshed, err := db.GetChatByID(ctx, staleChat.ID) - require.NoError(t, err) - require.False(t, refreshed.Archived, "chat should stay active when auto-archive is disabled") - - require.Empty(t, auditor.AuditLogs(), "no audit log entries expected") - require.Empty(t, enqueuer.Sent(), "no digest notifications expected") - }, - }, - { - name: "ArchivesInactiveRoot", - run: func(t *testing.T) { - h := newArchiveHarness(t, now) - ctx, clk, db, rawDB, logger, deps := h.ctx, h.clk, h.db, h.rawDB, h.logger, h.deps - - // Regression guard: ensure that both auto-archive and retention - // are both set to a distinct non-zero value. - require.NoError(t, db.UpsertChatAutoArchiveDays(ctx, int32(90))) - require.NoError(t, db.UpsertChatRetentionDays(ctx, int32(30))) - - // Inactive root: newest message 100 days old. - staleChat := createArchiveChat(ctx, t, db, rawDB, deps, "stale-chat", now.Add(-120*24*time.Hour)) - insertTextMessage(ctx, t, db, rawDB, staleChat.ID, deps.user.ID, deps.modelConfig.ID, now.Add(-100*24*time.Hour)) - - // Active root: message 10 days old, within cutoff. 
- activeChat := createArchiveChat(ctx, t, db, rawDB, deps, "active-chat", now.Add(-120*24*time.Hour)) - insertTextMessage(ctx, t, db, rawDB, activeChat.ID, deps.user.ID, deps.modelConfig.ID, now.Add(-10*24*time.Hour)) - - auditor := audit.NewMock() - auditorPtr := mockAuditorPtr(auditor) - enqueuer := notificationstest.NewFakeEnqueuer() - done := awaitDoTick(ctx, t, clk) - closer := dbpurge.New(ctx, logger, db, &codersdk.DeploymentValues{}, prometheus.NewRegistry(), auditorPtr, dbpurge.WithNotificationsEnqueuer(enqueuer), dbpurge.WithClock(clk)) - defer closer.Close() - testutil.TryReceive(ctx, t, done) - - refreshedStale, err := db.GetChatByID(ctx, staleChat.ID) - require.NoError(t, err) - require.True(t, refreshedStale.Archived, "stale chat should be auto-archived") - - refreshedActive, err := db.GetChatByID(ctx, activeChat.ID) - require.NoError(t, err) - require.False(t, refreshedActive.Archived, "active chat should stay live") - - // Exactly one audit entry, for the stale root. - logs := auditor.AuditLogs() - require.Len(t, logs, 1, "expected one audit entry") - require.Equal(t, staleChat.ID, logs[0].ResourceID) - require.Equal(t, database.ResourceTypeChat, logs[0].ResourceType) - require.Equal(t, database.AuditActionWrite, logs[0].Action) - require.Contains(t, string(logs[0].AdditionalFields), "chat_auto_archive", - "audit entry must carry the auto-archive subsystem tag") - - // Exactly one digest, addressed to the owner. - sent := enqueuer.Sent() - require.Len(t, sent, 1, "expected one digest notification") - require.Equal(t, notifications.TemplateChatAutoArchiveDigest, sent[0].TemplateID) - require.Equal(t, deps.user.ID, sent[0].UserID) - // Ensure that config-derived fields flow through to payload. 
- require.Equal(t, "90", sent[0].Data["auto_archive_days"]) - require.Equal(t, "30", sent[0].Data["retention_days"]) - }, - }, - { - name: "DateBoundary", - run: func(t *testing.T) { - h := newArchiveHarness(t, now) - ctx, clk, db, rawDB, logger, deps := h.ctx, h.clk, h.db, h.rawDB, h.logger, h.deps - - require.NoError(t, db.UpsertChatAutoArchiveDays(ctx, int32(90))) - - // With now = 2025-06-15 12:00 UTC, the Go code - // truncates to today = 2025-06-15 00:00 UTC, then - // subtracts 90 days -> cutoff = 2025-03-17 00:00 UTC. - // A chat's last-activity UTC date must be strictly < - // 2025-03-17 to be archived. - - // Activity on the cutoff date (2025-03-17): must survive. - onDate := createArchiveChat(ctx, t, db, rawDB, deps, "on-date", now.Add(-120*24*time.Hour)) - insertTextMessage(ctx, t, db, rawDB, onDate.ID, deps.user.ID, deps.modelConfig.ID, - time.Date(2025, 3, 17, 15, 30, 0, 0, time.UTC)) - - // Activity day before cutoff date (2025-03-16): must be archived. - beforeDate := createArchiveChat(ctx, t, db, rawDB, deps, "before-date", now.Add(-120*24*time.Hour)) - insertTextMessage(ctx, t, db, rawDB, beforeDate.ID, deps.user.ID, deps.modelConfig.ID, - time.Date(2025, 3, 16, 23, 59, 59, 0, time.UTC)) - - auditor := audit.NewMock() - auditorPtr := mockAuditorPtr(auditor) - driver := newTickDriver(t, clk) - closer := dbpurge.New(ctx, logger, db, &codersdk.DeploymentValues{}, prometheus.NewRegistry(), auditorPtr, dbpurge.WithClock(clk)) - defer closer.Close() - defer driver.close() - driver.awaitInitial(ctx, t) - - refreshedOn, err := db.GetChatByID(ctx, onDate.ID) - require.NoError(t, err) - require.False(t, refreshedOn.Archived, "chat with activity on cutoff date must survive") - - refreshedBefore, err := db.GetChatByID(ctx, beforeDate.ID) - require.NoError(t, err) - require.True(t, refreshedBefore.Archived, "chat with activity day before cutoff must be archived") - - require.Len(t, auditor.AuditLogs(), 1, "only the before-date chat should produce an audit 
entry") - }, - }, - { - name: "DayBoundaryLateActivity", - run: func(t *testing.T) { - h := newArchiveHarness(t, now) - ctx, clk, db, rawDB, logger, deps := h.ctx, h.clk, h.db, h.rawDB, h.logger, h.deps - - require.NoError(t, db.UpsertChatAutoArchiveDays(ctx, int32(90))) - - // Activity at 23:59:59 UTC on 2025-03-17 (cutoff date). - // The UTC date is still 2025-03-17, NOT < cutoff date, - // so it must NOT be archived. - lateChat := createArchiveChat(ctx, t, db, rawDB, deps, "late-activity", now.Add(-120*24*time.Hour)) - insertTextMessage(ctx, t, db, rawDB, lateChat.ID, deps.user.ID, deps.modelConfig.ID, - time.Date(2025, 3, 17, 23, 59, 59, 0, time.UTC)) - - auditor := audit.NewMock() - auditorPtr := mockAuditorPtr(auditor) - driver := newTickDriver(t, clk) - closer := dbpurge.New(ctx, logger, db, &codersdk.DeploymentValues{}, prometheus.NewRegistry(), auditorPtr, dbpurge.WithClock(clk)) - defer closer.Close() - defer driver.close() - driver.awaitInitial(ctx, t) - - refreshed, err := db.GetChatByID(ctx, lateChat.ID) - require.NoError(t, err) - require.False(t, refreshed.Archived, "activity at 23:59:59 UTC on cutoff date must not be archived") - require.Empty(t, auditor.AuditLogs()) - }, - }, - { - name: "SameDayActivityNotArchived", - run: func(t *testing.T) { - h := newArchiveHarness(t, now) - ctx, clk, db, rawDB, logger, deps := h.ctx, h.clk, h.db, h.rawDB, h.logger, h.deps - - require.NoError(t, db.UpsertChatAutoArchiveDays(ctx, int32(90))) - - // Activity at 00:00:01 UTC on the cutoff date - // (2025-03-17). Same date as cutoff, NOT strictly <, - // so must NOT be archived. 
- earlyChat := createArchiveChat(ctx, t, db, rawDB, deps, "early-same-day", now.Add(-120*24*time.Hour)) - insertTextMessage(ctx, t, db, rawDB, earlyChat.ID, deps.user.ID, deps.modelConfig.ID, - time.Date(2025, 3, 17, 0, 0, 1, 0, time.UTC)) - - auditor := audit.NewMock() - auditorPtr := mockAuditorPtr(auditor) - driver := newTickDriver(t, clk) - closer := dbpurge.New(ctx, logger, db, &codersdk.DeploymentValues{}, prometheus.NewRegistry(), auditorPtr, dbpurge.WithClock(clk)) - defer closer.Close() - defer driver.close() - driver.awaitInitial(ctx, t) - - refreshed, err := db.GetChatByID(ctx, earlyChat.ID) - require.NoError(t, err) - require.False(t, refreshed.Archived, "activity at start of cutoff date must not be archived") - require.Empty(t, auditor.AuditLogs()) - }, - }, - { - name: "SameDayBatch", - run: func(t *testing.T) { - h := newArchiveHarness(t, now) - ctx, clk, db, rawDB, logger, deps := h.ctx, h.clk, h.db, h.rawDB, h.logger, h.deps - - require.NoError(t, db.UpsertChatAutoArchiveDays(ctx, int32(90))) - - // Three chats all with last activity on 2025-03-16 - // (one day before cutoff) but at different times. - // All should be archived in the same batch. 
- chat1 := createArchiveChat(ctx, t, db, rawDB, deps, "batch-1", now.Add(-120*24*time.Hour)) - insertTextMessage(ctx, t, db, rawDB, chat1.ID, deps.user.ID, deps.modelConfig.ID, - time.Date(2025, 3, 16, 1, 0, 0, 0, time.UTC)) - - chat2 := createArchiveChat(ctx, t, db, rawDB, deps, "batch-2", now.Add(-120*24*time.Hour)) - insertTextMessage(ctx, t, db, rawDB, chat2.ID, deps.user.ID, deps.modelConfig.ID, - time.Date(2025, 3, 16, 12, 0, 0, 0, time.UTC)) - - chat3 := createArchiveChat(ctx, t, db, rawDB, deps, "batch-3", now.Add(-120*24*time.Hour)) - insertTextMessage(ctx, t, db, rawDB, chat3.ID, deps.user.ID, deps.modelConfig.ID, - time.Date(2025, 3, 16, 23, 59, 0, 0, time.UTC)) - - auditor := audit.NewMock() - auditorPtr := mockAuditorPtr(auditor) - driver := newTickDriver(t, clk) - closer := dbpurge.New(ctx, logger, db, &codersdk.DeploymentValues{}, prometheus.NewRegistry(), auditorPtr, dbpurge.WithClock(clk)) - defer closer.Close() - defer driver.close() - driver.awaitInitial(ctx, t) - - for _, tc := range []struct { - name string - id uuid.UUID - }{ - {"batch-1", chat1.ID}, - {"batch-2", chat2.ID}, - {"batch-3", chat3.ID}, - } { - refreshed, err := db.GetChatByID(ctx, tc.id) - require.NoError(t, err) - require.True(t, refreshed.Archived, "%s should be archived", tc.name) - } - - require.Len(t, auditor.AuditLogs(), 3, "all three chats should produce audit entries") - }, - }, - { - // CutoffStableAcrossSameDayTicks verifies that the archive - // cutoff is derived from the UTC day, not from the wall-clock - // time. Advancing the clock within the same UTC day must not - // change the archival decision ("no trickle" property). The - // chat is only archived once the clock crosses into the next - // UTC day and the cutoff date advances. 
- name: "CutoffStableAcrossSameDayTicks", - run: func(t *testing.T) { - // Start close to midnight so exactly two awaitNext calls - // cross the UTC day boundary: tick 1 at 23:49, tick 2 at - // 23:59 (still June 15, cutoff unchanged), tick 3 at - // 00:09 June 16 (new day, cutoff advances). - nearMidnight := time.Date(2025, 6, 15, 23, 49, 0, 0, time.UTC) - h := newArchiveHarness(t, nearMidnight) - ctx, clk, db, rawDB, logger, deps := h.ctx, h.clk, h.db, h.rawDB, h.logger, h.deps - - require.NoError(t, db.UpsertChatAutoArchiveDays(ctx, int32(90))) - - // Chat last active on 2025-03-17, which equals the cutoff - // for any tick on 2025-06-15: truncate(today) - 90d = - // 2025-03-17. The query requires last-activity < cutoff - // (strict), so the chat must survive all June-15 ticks. - chat := createArchiveChat(ctx, t, db, rawDB, deps, "boundary-chat", nearMidnight.Add(-120*24*time.Hour)) - insertTextMessage(ctx, t, db, rawDB, chat.ID, deps.user.ID, deps.modelConfig.ID, - time.Date(2025, 3, 17, 12, 0, 0, 0, time.UTC)) - - auditor := audit.NewMock() - auditorPtr := mockAuditorPtr(auditor) - driver := newTickDriver(t, clk) - closer := dbpurge.New(ctx, logger, db, &codersdk.DeploymentValues{}, prometheus.NewRegistry(), auditorPtr, dbpurge.WithClock(clk)) - defer closer.Close() - defer driver.close() - - // Tick 1 (23:49 UTC June 15): cutoff = 2025-03-17. - // Activity on the cutoff date is not strictly less than - // the cutoff, so the chat must not be archived. - driver.awaitInitial(ctx, t) - - refreshed, err := db.GetChatByID(ctx, chat.ID) - require.NoError(t, err) - require.False(t, refreshed.Archived, "tick 1: chat on cutoff date must not be archived") - require.Empty(t, auditor.AuditLogs(), "tick 1: no audit entries expected") - - // Tick 2 (23:59 UTC June 15): still the same UTC day. - // The cutoff is unchanged (still 2025-03-17), so advancing - // the wall clock within the same day must not archive the - // chat. 
- driver.awaitNext(ctx, t) - - refreshed, err = db.GetChatByID(ctx, chat.ID) - require.NoError(t, err) - require.False(t, refreshed.Archived, "tick 2: same UTC day, cutoff unchanged, chat must still survive") - require.Empty(t, auditor.AuditLogs(), "tick 2: no audit entries expected") - - // Tick 3 (00:09 UTC June 16): new UTC day. The cutoff - // advances to 2025-03-18, so activity on 2025-03-17 is - // now strictly less than the cutoff and the chat must be - // archived. - driver.awaitNext(ctx, t) - - refreshed, err = db.GetChatByID(ctx, chat.ID) - require.NoError(t, err) - require.True(t, refreshed.Archived, "tick 3: cutoff advanced to 2025-03-18, chat must now be archived") - require.Len(t, auditor.AuditLogs(), 1, "tick 3: exactly one audit entry expected") - }, - }, - - { - name: "DeletedMessagesIgnored", - run: func(t *testing.T) { - h := newArchiveHarness(t, now) - ctx, clk, db, rawDB, logger, deps := h.ctx, h.clk, h.db, h.rawDB, h.logger, h.deps - - require.NoError(t, db.UpsertChatAutoArchiveDays(ctx, int32(90))) - - // Chat created 120 days ago with a recent message - // (10 days old) that is then soft-deleted. The - // LATERAL subquery filters cm.deleted = false, so - // the chat should fall back to created_at and be - // archived. - chat := createArchiveChat(ctx, t, db, rawDB, deps, "deleted-msg", now.Add(-120*24*time.Hour)) - insertTextMessage(ctx, t, db, rawDB, chat.ID, deps.user.ID, deps.modelConfig.ID, now.Add(-10*24*time.Hour)) - // Soft-delete all messages on this chat. 
- _, err := rawDB.ExecContext(ctx, "UPDATE chat_messages SET deleted = true WHERE chat_id = $1", chat.ID) - require.NoError(t, err) - - auditor := audit.NewMock() - auditorPtr := mockAuditorPtr(auditor) - done := awaitDoTick(ctx, t, clk) - closer := dbpurge.New(ctx, logger, db, &codersdk.DeploymentValues{}, prometheus.NewRegistry(), auditorPtr, dbpurge.WithClock(clk)) - defer closer.Close() - testutil.TryReceive(ctx, t, done) - - refreshed, err := db.GetChatByID(ctx, chat.ID) - require.NoError(t, err) - require.True(t, refreshed.Archived, "chat with only deleted messages should be archived") - require.Len(t, auditor.AuditLogs(), 1) - }, - }, - { - name: "ChildActivityKeepsRootAlive", - run: func(t *testing.T) { - h := newArchiveHarness(t, now) - ctx, clk, db, rawDB, logger, deps := h.ctx, h.clk, h.db, h.rawDB, h.logger, h.deps - - require.NoError(t, db.UpsertChatAutoArchiveDays(ctx, int32(90))) - - // Stale root with no messages of its own. - root := createArchiveChat(ctx, t, db, rawDB, deps, "stale-root", now.Add(-120*24*time.Hour)) - - // Child linked to root with a recent message (10 days old, - // well within the 90-day cutoff). 
- child := createArchiveChat(ctx, t, db, rawDB, deps, "active-child", now.Add(-120*24*time.Hour)) - _, err := rawDB.ExecContext(ctx, "UPDATE chats SET parent_chat_id = $1, root_chat_id = $1 WHERE id = $2", root.ID, child.ID) - require.NoError(t, err) - insertTextMessage(ctx, t, db, rawDB, child.ID, deps.user.ID, deps.modelConfig.ID, now.Add(-10*24*time.Hour)) - - auditor := audit.NewMock() - auditorPtr := mockAuditorPtr(auditor) - enqueuer := notificationstest.NewFakeEnqueuer() - done := awaitDoTick(ctx, t, clk) - closer := dbpurge.New(ctx, logger, db, &codersdk.DeploymentValues{}, prometheus.NewRegistry(), auditorPtr, dbpurge.WithNotificationsEnqueuer(enqueuer), dbpurge.WithClock(clk)) - defer closer.Close() - testutil.TryReceive(ctx, t, done) - - refreshedRoot, err := db.GetChatByID(ctx, root.ID) - require.NoError(t, err) - require.False(t, refreshedRoot.Archived, "root must stay active because child has recent activity") - - refreshedChild, err := db.GetChatByID(ctx, child.ID) - require.NoError(t, err) - require.False(t, refreshedChild.Archived, "child must stay active") - - require.Empty(t, auditor.AuditLogs(), "no chats should be archived") - require.Empty(t, enqueuer.Sent(), "no notifications should be sent") - }, - }, - { - name: "SkipsActiveStatusChats", - run: func(t *testing.T) { - h := newArchiveHarness(t, now) - ctx, clk, db, rawDB, logger, deps := h.ctx, h.clk, h.db, h.rawDB, h.logger, h.deps - - require.NoError(t, db.UpsertChatAutoArchiveDays(ctx, int32(90))) - - // Stale chats whose status prevents archiving. 
- runningChat := createArchiveChat(ctx, t, db, rawDB, deps, "running-chat", now.Add(-120*24*time.Hour)) - insertTextMessage(ctx, t, db, rawDB, runningChat.ID, deps.user.ID, deps.modelConfig.ID, now.Add(-100*24*time.Hour)) - _, err := rawDB.ExecContext(ctx, "UPDATE chats SET status = $1 WHERE id = $2", database.ChatStatusRunning, runningChat.ID) - require.NoError(t, err) - - requiresActionChat := createArchiveChat(ctx, t, db, rawDB, deps, "requires-action-chat", now.Add(-120*24*time.Hour)) - insertTextMessage(ctx, t, db, rawDB, requiresActionChat.ID, deps.user.ID, deps.modelConfig.ID, now.Add(-100*24*time.Hour)) - _, err = rawDB.ExecContext(ctx, "UPDATE chats SET status = $1 WHERE id = $2", database.ChatStatusRequiresAction, requiresActionChat.ID) - require.NoError(t, err) - - pendingChat := createArchiveChat(ctx, t, db, rawDB, deps, "pending-chat", now.Add(-120*24*time.Hour)) - insertTextMessage(ctx, t, db, rawDB, pendingChat.ID, deps.user.ID, deps.modelConfig.ID, now.Add(-100*24*time.Hour)) - _, err = rawDB.ExecContext(ctx, "UPDATE chats SET status = $1 WHERE id = $2", database.ChatStatusPending, pendingChat.ID) - require.NoError(t, err) - - pausedChat := createArchiveChat(ctx, t, db, rawDB, deps, "paused-chat", now.Add(-120*24*time.Hour)) - insertTextMessage(ctx, t, db, rawDB, pausedChat.ID, deps.user.ID, deps.modelConfig.ID, now.Add(-100*24*time.Hour)) - _, err = rawDB.ExecContext(ctx, "UPDATE chats SET status = $1 WHERE id = $2", database.ChatStatusPaused, pausedChat.ID) - require.NoError(t, err) - - // Control: a stale chat with archivable status that - // should be archived. 
- completedChat := createArchiveChat(ctx, t, db, rawDB, deps, "completed-chat", now.Add(-120*24*time.Hour)) - insertTextMessage(ctx, t, db, rawDB, completedChat.ID, deps.user.ID, deps.modelConfig.ID, now.Add(-100*24*time.Hour)) - _, err = rawDB.ExecContext(ctx, "UPDATE chats SET status = $1 WHERE id = $2", database.ChatStatusCompleted, completedChat.ID) - require.NoError(t, err) - - auditor := audit.NewMock() - auditorPtr := mockAuditorPtr(auditor) - enqueuer := notificationstest.NewFakeEnqueuer() - done := awaitDoTick(ctx, t, clk) - closer := dbpurge.New(ctx, logger, db, &codersdk.DeploymentValues{}, prometheus.NewRegistry(), auditorPtr, dbpurge.WithNotificationsEnqueuer(enqueuer), dbpurge.WithClock(clk)) - defer closer.Close() - testutil.TryReceive(ctx, t, done) - - refreshedRunning, err := db.GetChatByID(ctx, runningChat.ID) - require.NoError(t, err) - require.False(t, refreshedRunning.Archived, "running chat must not be archived") - - refreshedRA, err := db.GetChatByID(ctx, requiresActionChat.ID) - require.NoError(t, err) - require.False(t, refreshedRA.Archived, "requires_action chat must not be archived") - - refreshedPending, err := db.GetChatByID(ctx, pendingChat.ID) - require.NoError(t, err) - require.False(t, refreshedPending.Archived, "pending chat must not be archived") - - refreshedPaused, err := db.GetChatByID(ctx, pausedChat.ID) - require.NoError(t, err) - require.False(t, refreshedPaused.Archived, "paused chat must not be archived") - - refreshedCompleted, err := db.GetChatByID(ctx, completedChat.ID) - require.NoError(t, err) - require.True(t, refreshedCompleted.Archived, "completed stale chat should be archived") - - logs := auditor.AuditLogs() - require.Len(t, logs, 1, "only the completed chat should produce an audit entry") - require.Equal(t, completedChat.ID, logs[0].ResourceID) - - // Assert number of sent notifications to catch dispatch regressions. 
- sent := enqueuer.Sent() - require.Len(t, sent, 1, "expected one digest notification for the completed chat") - require.Equal(t, notifications.TemplateChatAutoArchiveDigest, sent[0].TemplateID) - require.Equal(t, deps.user.ID, sent[0].UserID) - }, - }, - { - name: "SkipsPinnedAndChildren", - run: func(t *testing.T) { - h := newArchiveHarness(t, now) - ctx, clk, db, rawDB, logger, deps := h.ctx, h.clk, h.db, h.rawDB, h.logger, h.deps - - require.NoError(t, db.UpsertChatAutoArchiveDays(ctx, int32(30))) - - // Pinned stale chat: should be skipped. - pinnedChat := createArchiveChat(ctx, t, db, rawDB, deps, "pinned-chat", now.Add(-90*24*time.Hour)) - _, err := rawDB.ExecContext(ctx, "UPDATE chats SET pin_order = 1 WHERE id = $1", pinnedChat.ID) - require.NoError(t, err) - - // Stale root with a child. - root := createArchiveChat(ctx, t, db, rawDB, deps, "root-chat", now.Add(-90*24*time.Hour)) - child := createArchiveChat(ctx, t, db, rawDB, deps, "child-chat", now.Add(-90*24*time.Hour)) - _, err = rawDB.ExecContext(ctx, "UPDATE chats SET parent_chat_id = $1, root_chat_id = $1 WHERE id = $2", root.ID, child.ID) - require.NoError(t, err) - // Give the child an active status to prove the cascade is - // status-blind by design. If someone adds a status filter - // to the cascade CTE, this assertion will catch it. 
- _, err = rawDB.ExecContext(ctx, "UPDATE chats SET status = $1 WHERE id = $2", database.ChatStatusRunning, child.ID) - require.NoError(t, err) - - auditor := audit.NewMock() - auditorPtr := mockAuditorPtr(auditor) - enqueuer := notificationstest.NewFakeEnqueuer() - done := awaitDoTick(ctx, t, clk) - closer := dbpurge.New(ctx, logger, db, &codersdk.DeploymentValues{}, prometheus.NewRegistry(), auditorPtr, dbpurge.WithNotificationsEnqueuer(enqueuer), dbpurge.WithClock(clk)) - defer closer.Close() - testutil.TryReceive(ctx, t, done) - - refreshedPinned, err := db.GetChatByID(ctx, pinnedChat.ID) - require.NoError(t, err) - require.False(t, refreshedPinned.Archived, "pinned chat must be skipped") - - refreshedRoot, err := db.GetChatByID(ctx, root.ID) - require.NoError(t, err) - require.True(t, refreshedRoot.Archived, "root should be archived") - - refreshedChild, err := db.GetChatByID(ctx, child.ID) - require.NoError(t, err) - require.True(t, refreshedChild.Archived, "child should be cascade-archived") - - // One audit entry for the root; the cascaded child is - // not audited individually. - require.Len(t, auditor.AuditLogs(), 1) - - // Digest should list only the root (one row). - sent := enqueuer.Sent() - require.Len(t, sent, 1) - data := sent[0].Data - require.NotNil(t, data) - chats, ok := data["archived_chats"].([]map[string]any) - require.True(t, ok, "archived_chats should be []map[string]any") - require.Len(t, chats, 1, "digest should only list the root") - require.Equal(t, "root-chat", chats[0]["title"]) - }, - }, - { - name: "DigestOverflowCap", - run: func(t *testing.T) { - // 27 inactive roots exceed chatAutoArchiveDigestMaxChats - // (25). All 27 should archive, but the digest payload - // lists at most 25 titles and surfaces the rest via - // additional_archived_count so the template can render - // "...and N more". 
- h := newArchiveHarness(t, now) - ctx, clk, db, rawDB, logger, deps := h.ctx, h.clk, h.db, h.rawDB, h.logger, h.deps - - require.NoError(t, db.UpsertChatAutoArchiveDays(ctx, int32(30))) - - const total = 27 - for i := range total { - createArchiveChat(ctx, t, db, rawDB, deps, - fmt.Sprintf("stale-%02d", i), - now.Add(-60*24*time.Hour)) - } - - auditor := audit.NewMock() - auditorPtr := mockAuditorPtr(auditor) - enqueuer := notificationstest.NewFakeEnqueuer() - done := awaitDoTick(ctx, t, clk) - closer := dbpurge.New(ctx, logger, db, &codersdk.DeploymentValues{}, prometheus.NewRegistry(), auditorPtr, dbpurge.WithNotificationsEnqueuer(enqueuer), dbpurge.WithClock(clk)) - defer closer.Close() - testutil.TryReceive(ctx, t, done) - - // All 27 roots archived (one audit each). - require.Len(t, auditor.AuditLogs(), total) - - sent := enqueuer.Sent() - require.Len(t, sent, 1, "one digest per owner") - chats, ok := sent[0].Data["archived_chats"].([]map[string]any) - require.True(t, ok, "archived_chats should be []map[string]any") - require.Len(t, chats, 25, "digest caps titles at 25") - require.Equal(t, "2", sent[0].Data["additional_archived_count"], - "overflow count is total - cap") - // Humanized timestamp is computed from LastActivityAt - // and the tick-start time, not a static fixture, so we - // only assert the suffix the humanizer emits. 
- humanized, _ := chats[0]["last_activity_humanized"].(string) - require.Contains(t, humanized, "ago", - "last_activity_humanized should be a past relative time") - }, - }, - { - name: "MultipleOwners", - run: func(t *testing.T) { - h := newArchiveHarness(t, now) - ctx, clk, db, rawDB, logger, deps := h.ctx, h.clk, h.db, h.rawDB, h.logger, h.deps - user2 := dbgen.User(t, db, database.User{}) - _ = dbgen.OrganizationMember(t, db, database.OrganizationMember{UserID: user2.ID, OrganizationID: deps.org.ID}) - - require.NoError(t, db.UpsertChatAutoArchiveDays(ctx, int32(30))) - - // Two stale roots per owner, backdated well past - // the 30-day cutoff. - u1Deps := deps - u2Deps := chatAutoArchiveDeps{user: user2, org: deps.org, modelConfig: deps.modelConfig} - createArchiveChat(ctx, t, db, rawDB, u1Deps, "u1-a", now.Add(-60*24*time.Hour)) - createArchiveChat(ctx, t, db, rawDB, u1Deps, "u1-b", now.Add(-60*24*time.Hour)) - createArchiveChat(ctx, t, db, rawDB, u2Deps, "u2-a", now.Add(-60*24*time.Hour)) - createArchiveChat(ctx, t, db, rawDB, u2Deps, "u2-b", now.Add(-60*24*time.Hour)) - - auditor := audit.NewMock() - auditorPtr := mockAuditorPtr(auditor) - enqueuer := notificationstest.NewFakeEnqueuer() - done := awaitDoTick(ctx, t, clk) - closer := dbpurge.New(ctx, logger, db, &codersdk.DeploymentValues{}, prometheus.NewRegistry(), auditorPtr, dbpurge.WithNotificationsEnqueuer(enqueuer), dbpurge.WithClock(clk)) - defer closer.Close() - testutil.TryReceive(ctx, t, done) - - // Four audit rows, one per archived root, attributed - // to the owning user so downstream consumers can - // correlate per-owner activity. - logs := auditor.AuditLogs() - require.Len(t, logs, 4) - auditsByUser := map[uuid.UUID]int{} - for _, l := range logs { - auditsByUser[l.UserID]++ - } - require.Equal(t, 2, auditsByUser[deps.user.ID]) - require.Equal(t, 2, auditsByUser[user2.ID]) - - // One digest per owner, each listing only that owner's - // two chats. 
- sent := enqueuer.Sent() - require.Len(t, sent, 2, "expected one digest per owner") - - byUser := map[uuid.UUID][]string{} - for _, s := range sent { - require.Equal(t, notifications.TemplateChatAutoArchiveDigest, s.TemplateID) - chats, ok := s.Data["archived_chats"].([]map[string]any) - require.True(t, ok, "archived_chats should be []map[string]any") - for _, c := range chats { - title, _ := c["title"].(string) - byUser[s.UserID] = append(byUser[s.UserID], title) - } - } - require.Contains(t, byUser, deps.user.ID) - require.Contains(t, byUser, user2.ID) - slices.Sort(byUser[deps.user.ID]) - slices.Sort(byUser[user2.ID]) - require.Equal(t, []string{"u1-a", "u1-b"}, byUser[deps.user.ID]) - require.Equal(t, []string{"u2-a", "u2-b"}, byUser[user2.ID]) - }, - }, - { - name: "SecondTickIdempotent", - run: func(t *testing.T) { - h := newArchiveHarness(t, now) - ctx, clk, db, rawDB, logger, deps := h.ctx, h.clk, h.db, h.rawDB, h.logger, h.deps - - require.NoError(t, db.UpsertChatAutoArchiveDays(ctx, int32(30))) - - // Two stale roots seeded before the first tick. - firstA := createArchiveChat(ctx, t, db, rawDB, deps, "first-a", now.Add(-60*24*time.Hour)) - firstB := createArchiveChat(ctx, t, db, rawDB, deps, "first-b", now.Add(-60*24*time.Hour)) - - auditor := audit.NewMock() - auditorPtr := mockAuditorPtr(auditor) - enqueuer := notificationstest.NewFakeEnqueuer() - driver := newTickDriver(t, clk) - closer := dbpurge.New(ctx, logger, db, &codersdk.DeploymentValues{}, prometheus.NewRegistry(), auditorPtr, dbpurge.WithNotificationsEnqueuer(enqueuer), dbpurge.WithClock(clk)) - // Defer driver.close() after closer.Close(): defers - // run LIFO, so this frees shutdown's ticker.Stop() - // before the dbpurge goroutine blocks on it. - defer closer.Close() - defer driver.close() - driver.awaitInitial(ctx, t) - - // Tick 1: both archived, one digest. 
- require.Len(t, auditor.AuditLogs(), 2, "tick 1 audits") - require.Len(t, enqueuer.Sent(), 1, "tick 1 digests") - - // Seed a third stale root between ticks so tick 2 has - // genuine work and we can distinguish "ignored already - // archived" from "ignored everything". - third := createArchiveChat(ctx, t, db, rawDB, deps, "second-c", now.Add(-60*24*time.Hour)) - - driver.awaitNext(ctx, t) - - // Tick 2: exactly one new audit + one new digest for - // the third chat; tick 1's rows must not be re-archived. - require.Len(t, auditor.AuditLogs(), 3, "tick 2 cumulative audits") - sent := enqueuer.Sent() - require.Len(t, sent, 2, "tick 2 cumulative digests") - chats, ok := sent[1].Data["archived_chats"].([]map[string]any) - require.True(t, ok, "archived_chats should be []map[string]any") - require.Len(t, chats, 1, "tick 2 digest lists only the new chat") - require.Equal(t, "second-c", chats[0]["title"]) - - // First-tick chats stayed archived. - for _, id := range []uuid.UUID{firstA.ID, firstB.ID, third.ID} { - refreshed, err := db.GetChatByID(ctx, id) - require.NoError(t, err) - require.True(t, refreshed.Archived, "chat %s should remain archived", id) - } - }, - }, - { - name: "BatchSizePagination", - run: func(t *testing.T) { - // With 27 stale roots and batch size 20, tick 1 - // archives 20, tick 2 archives the remaining 7, and - // tick 3 archives none. We assert the dispatch side - // effects (audits, digests) follow the same pattern: - // dispatch only runs when rows > 0, so tick 3 emits - // no new audits or digests. - // - // The two-digest count asserted here is a consequence - // of the per-tick enqueue model, not a product - // invariant. notification_messages dedupe does not - // collapse these because each tick's payload differs. - // If enqueue is ever restructured to one notification - // per owner per day, this assertion changes with it. 
- h := newArchiveHarness(t, now) - ctx, clk, db, rawDB, logger, deps := h.ctx, h.clk, h.db, h.rawDB, h.logger, h.deps - - require.NoError(t, db.UpsertChatAutoArchiveDays(ctx, int32(30))) - - const total = 27 - for i := range total { - createArchiveChat(ctx, t, db, rawDB, deps, - fmt.Sprintf("page-%02d", i), - now.Add(-60*24*time.Hour)) - } - - auditor := audit.NewMock() - auditorPtr := mockAuditorPtr(auditor) - enqueuer := notificationstest.NewFakeEnqueuer() - driver := newTickDriver(t, clk) - closer := dbpurge.New(ctx, logger, db, &codersdk.DeploymentValues{}, prometheus.NewRegistry(), auditorPtr, dbpurge.WithNotificationsEnqueuer(enqueuer), dbpurge.WithClock(clk), dbpurge.WithChatAutoArchiveBatchSize(20)) - // Defer driver.close() after closer.Close() so trap - // cleanup frees shutdown's ticker.Stop() before the - // dbpurge goroutine blocks on it. - defer closer.Close() - defer driver.close() - driver.awaitInitial(ctx, t) - - // Tick 1: first batch (20) archived. - require.Len(t, auditor.AuditLogs(), 20, "tick 1 audits") - sent := enqueuer.Sent() - require.Len(t, sent, 1, "tick 1 digests") - chats1, ok := sent[0].Data["archived_chats"].([]map[string]any) - require.True(t, ok, "archived_chats should be []map[string]any") - require.Len(t, chats1, 20, "tick 1 digest lists all 20 titles") - require.NotContains(t, sent[0].Data, "additional_archived_count", - "no overflow when batch <= digest cap; 20 <= 25") - - driver.awaitNext(ctx, t) - - // Tick 2: remaining 7 archived. - require.Len(t, auditor.AuditLogs(), 27, "tick 2 cumulative audits") - sent = enqueuer.Sent() - require.Len(t, sent, 2, "tick 2 cumulative digests") - chats2, ok := sent[1].Data["archived_chats"].([]map[string]any) - require.True(t, ok, "archived_chats should be []map[string]any") - require.Len(t, chats2, 7, "tick 2 digest lists remaining 7") - - driver.awaitNext(ctx, t) - - // Tick 3: nothing left to archive. 
The dispatch is - // gated on len(archivedChats) > 0, so no new audits - // or digests are produced. If that gate is ever - // removed, update this assertion intentionally. - require.Len(t, auditor.AuditLogs(), 27, "tick 3 cumulative audits unchanged") - require.Len(t, enqueuer.Sent(), 2, "tick 3 cumulative digests unchanged") - }, - }, - { - name: "ShutdownCancelsDigestDispatch", - run: func(t *testing.T) { - // Two owners with one stale root each. The first - // EnqueueWithData call blocks until ctx is canceled. - // Closing the purger must propagate cancellation - // into the in-flight call and short-circuit the - // rest of the loop, so Close returns promptly - // instead of hanging on dispatch. - h := newArchiveHarness(t, now) - ctx, clk, db, rawDB, logger, deps := h.ctx, h.clk, h.db, h.rawDB, h.logger, h.deps - user2 := dbgen.User(t, db, database.User{}) - _ = dbgen.OrganizationMember(t, db, database.OrganizationMember{UserID: user2.ID, OrganizationID: deps.org.ID}) - - require.NoError(t, db.UpsertChatAutoArchiveDays(ctx, int32(30))) - - u1Deps := deps - u2Deps := chatAutoArchiveDeps{user: user2, org: deps.org, modelConfig: deps.modelConfig} - createArchiveChat(ctx, t, db, rawDB, u1Deps, "u1-stale", now.Add(-60*24*time.Hour)) - createArchiveChat(ctx, t, db, rawDB, u2Deps, "u2-stale", now.Add(-60*24*time.Hour)) - - // Dispatch iterates owner IDs in ascending UUID order (convention). - expectedFirst := deps.user.ID - if user2.ID.String() < deps.user.ID.String() { - expectedFirst = user2.ID - } - - ctrl := gomock.NewController(t) - mockEnq := notificationsmock.NewMockEnqueuer(ctrl) - started := make(chan struct{}) - mockEnq.EXPECT().EnqueueWithData(gomock.Any(), gomock.Eq(expectedFirst), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). 
- DoAndReturn(func(ctx context.Context, _, _ uuid.UUID, _ map[string]string, _ map[string]any, _ string, _ ...uuid.UUID) ([]uuid.UUID, error) { - close(started) - <-ctx.Done() - return nil, ctx.Err() - }).Times(1) - - closer := dbpurge.New(ctx, logger, db, &codersdk.DeploymentValues{}, prometheus.NewRegistry(), nopAuditorPtr(t), dbpurge.WithNotificationsEnqueuer(mockEnq), dbpurge.WithClock(clk)) - - // Wait for the forced initial tick to reach the first - // enqueue, which then blocks on ctx.Done(). - testutil.TryReceive(ctx, t, started) - - // Blocked enqueue receives ctx cancellation via the parent context. - // Loop-head check abandons the remaining owner instead of trying to enqueue. - done := make(chan error) - go func() { done <- closer.Close() }() - testutil.RequireReceive(ctx, t, done) - }, - }, - { - // A transient enqueue failure for one owner must not abort the dispatch loop. - name: "TransientEnqueueFailureDoesNotAbortLoop", - run: func(t *testing.T) { - h := newArchiveHarness(t, now) - ctx, clk, db, rawDB, logger, deps := h.ctx, h.clk, h.db, h.rawDB, h.logger, h.deps - user2 := dbgen.User(t, db, database.User{}) - _ = dbgen.OrganizationMember(t, db, database.OrganizationMember{UserID: user2.ID, OrganizationID: deps.org.ID}) - - require.NoError(t, db.UpsertChatAutoArchiveDays(ctx, int32(30))) - - u1Deps := deps - u2Deps := chatAutoArchiveDeps{user: user2, org: deps.org, modelConfig: deps.modelConfig} - createArchiveChat(ctx, t, db, rawDB, u1Deps, "u1-stale", now.Add(-60*24*time.Hour)) - createArchiveChat(ctx, t, db, rawDB, u2Deps, "u2-stale", now.Add(-60*24*time.Hour)) - - auditor := audit.NewMock() - auditorPtr := mockAuditorPtr(auditor) - - ctrl := gomock.NewController(t) - mockEnq := notificationsmock.NewMockEnqueuer(ctrl) - var calls atomic.Int32 - var successUserID uuid.UUID - mockEnq.EXPECT().EnqueueWithData(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). 
- DoAndReturn(func(_ context.Context, userID, _ uuid.UUID, _ map[string]string, _ map[string]any, _ string, _ ...uuid.UUID) ([]uuid.UUID, error) { - if calls.Add(1) == 1 { - return nil, xerrors.New("simulated transient enqueue failure") - } - successUserID = userID - return nil, nil - }).Times(2) - - done := awaitDoTick(ctx, t, clk) - closer := dbpurge.New(ctx, logger, db, &codersdk.DeploymentValues{}, prometheus.NewRegistry(), auditorPtr, dbpurge.WithNotificationsEnqueuer(mockEnq), dbpurge.WithClock(clk)) - defer closer.Close() - testutil.TryReceive(ctx, t, done) - - // Both owners must have been audited regardless of - // digest enqueue outcomes; the audit and digest - // paths are independent. - require.Len(t, auditor.AuditLogs(), 2, "both archived roots must be audited") - - // gomock's .Times(2) already enforces both calls - // happened; this assertion makes the contract - // explicit at the test site. - require.Equal(t, int32(2), calls.Load(), - "loop must attempt every owner even when one fails") - - // The second attempt succeeded for one of the two owners. - require.Contains(t, []uuid.UUID{deps.user.ID, user2.ID}, successUserID, - "successful digest must belong to one of the two owners") - }, - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - tc.run(t) - }) - } -} diff --git a/coderd/exp_chats.go b/coderd/exp_chats.go index 4907ed6741..c164c68718 100644 --- a/coderd/exp_chats.go +++ b/coderd/exp_chats.go @@ -1277,6 +1277,12 @@ func (api *API) postChats(rw http.ResponseWriter, r *http.Request) { } aReq.New = chat + // Kick off best-effort automatic title generation now that the + // chat and its initial user message are persisted. It runs + // detached so it never blocks the create response, and only acts + // on the first user turn. 
+ api.chatDaemon.GenerateChatTitleAsync(ctx, chat) + chatFiles := api.fetchChatFileMetadata(ctx, chat.ID) response := db2sdk.Chat(chat, nil, chatFiles) if len(unlinked) > 0 { diff --git a/coderd/exp_chats_test.go b/coderd/exp_chats_test.go index 43c9fdbfa8..e8cff4e289 100644 --- a/coderd/exp_chats_test.go +++ b/coderd/exp_chats_test.go @@ -1842,7 +1842,9 @@ func TestWatchChats(t *testing.T) { require.Equal(t, createdChat.OwnerID, got.OwnerID) require.Equal(t, modelConfig.ID, got.LastModelConfigID) require.Equal(t, createdChat.Title, got.Title) - require.Equal(t, codersdk.ChatStatusPending, got.Status) + // CreateChat inserts new chats in the running state under the + // chatstate state machine, so the created event carries running. + require.Equal(t, codersdk.ChatStatusRunning, got.Status) require.NotNil(t, got.RootChatID) require.Equal(t, createdChat.ID, *got.RootChatID) require.NotZero(t, got.CreatedAt) @@ -1955,7 +1957,7 @@ func TestWatchChats(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) - client, db := newChatClientWithDatabase(t) + client, db, api := newChatClientWithAPIAndDatabase(t) user := coderdtest.CreateFirstUser(t, client.Client) modelConfig := createChatModelConfig(t, client) @@ -1970,6 +1972,11 @@ func TestWatchChats(t *testing.T) { }) require.NoError(t, err) + // The parent chat is created via the API, so the chat worker moves + // it to running. Archiving is only allowed from a terminal state, + // so wait for it to settle before archiving below. 
+ coderdtest.WaitForChatSettled(ctx, t, api, parentChat.ID) + childOne := dbgen.Chat(t, db, database.Chat{ OrganizationID: user.OrganizationID, OwnerID: user.UserID, @@ -4573,7 +4580,7 @@ func TestGetChat(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) - client, db := newChatClientWithDatabase(t) + client, db, api := newChatClientWithAPIAndDatabase(t) user := coderdtest.CreateFirstUser(t, client.Client) modelConfig := createChatModelConfig(t, client) @@ -4588,6 +4595,11 @@ func TestGetChat(t *testing.T) { }) require.NoError(t, err) + // The parent chat is created via the API, so the chat worker moves + // it to running. Archiving is only allowed from a terminal state, + // so wait for it to settle before archiving below. + coderdtest.WaitForChatSettled(ctx, t, api, parentChat.ID) + child := dbgen.Chat(t, db, database.Chat{ OrganizationID: user.OrganizationID, OwnerID: user.UserID, @@ -8653,6 +8665,58 @@ func TestManualTitleEndpointsPassCallerAPIKeyToAIGateway(t *testing.T) { } } +func TestPostChats_AutomaticTitleGeneration(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + + // titleRequested is signaled when the provider receives the structured + // title-generation request. Automatic title generation issues a + // non-streaming request using the "propose_title" schema, which uniquely + // identifies it (the turn status label uses "propose_turn_status_label"). + titleRequested := make(chan struct{}, 1) + baseURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse { + if req.Stream { + return chattest.OpenAIStreamingResponse(chattest.OpenAITextChunks("Hello from test server.")...) 
+ } + if bytes.Contains(req.RawBody, []byte("propose_title")) { + select { + case titleRequested <- struct{}{}: + default: + } + } + return chattest.OpenAINonStreamingResponse(`{"title": "Generated Title"}`) + }) + + client, api := newChatClientWithAPI(t) + firstUser := coderdtest.CreateFirstUser(t, client.Client) + _ = createChatModelConfigWithBaseURL(t, client, baseURL) + + chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{ + OrganizationID: firstUser.OrganizationID, + Content: []codersdk.ChatInputPart{{ + Type: codersdk.ChatInputPartTypeText, + Text: "automatic title generation please", + }}, + }) + require.NoError(t, err) + // The create response carries the synchronous fallback title derived from + // the message, not the asynchronously generated one. + require.Equal(t, "automatic title generation please", chat.Title) + + // The create endpoint kicks off detached title generation; the provider + // should receive the title request without any further client action. + select { + case <-titleRequested: + case <-ctx.Done(): + t.Fatal("timed out waiting for automatic title generation to be triggered") + } + + // Drain background work so the detached goroutine finishes before the test + // (and its fake provider) tears down. 
+ coderdtest.WaitForChatSettled(ctx, t, api, chat.ID) +} + func TestGetChatDiffStatus(t *testing.T) { t.Parallel() diff --git a/coderd/x/chatd/attempt.go b/coderd/x/chatd/attempt.go new file mode 100644 index 0000000000..0b803e8586 --- /dev/null +++ b/coderd/x/chatd/attempt.go @@ -0,0 +1,64 @@ +package chatd + +import ( + "database/sql" + "time" + + "charm.land/fantasy" + + "github.com/coder/coder/v2/codersdk" +) + +type runnerActionKind string + +type runnerActionMessage struct { + ID int64 + Role codersdk.ChatMessageRole +} + +const ( + runnerActionKindEnterRequiresAction runnerActionKind = "enter_requires_action" + runnerActionKindFinishTurn runnerActionKind = "finish_turn" + runnerActionKindFinishError runnerActionKind = "finish_error" + runnerActionKindFinishInterruption runnerActionKind = "finish_interruption" +) + +// stepData is the durable content produced by one provider attempt. +type stepData struct { + Content []fantasy.Content + Usage fantasy.Usage + ContextLimit sql.NullInt64 + ProviderResponseID string + Runtime time.Duration + + ToolCallCreatedAt map[string]time.Time + ToolResultCreatedAt map[string]time.Time + ReasoningStartedAt []time.Time + ReasoningCompletedAt []time.Time +} + +// pendingDynamicToolCall describes a dynamic tool call parked for a user. +type pendingDynamicToolCall struct { + ToolCallID string + ToolName string + Args string +} + +// compactionOutcome contains a generated context summary. 
+type compactionOutcome struct { + SystemSummary string + SummaryReport string + ThresholdPercent int32 + UsagePercent float64 + ContextTokens int64 + ContextLimit int64 +} + +type compactionStatus int + +const ( + compactionStatusNotNeeded compactionStatus = iota + compactionStatusNeeded + compactionStatusAfterCompaction + compactionStatusStillOverLimit +) diff --git a/coderd/x/chatd/auto_archive.go b/coderd/x/chatd/auto_archive.go new file mode 100644 index 0000000000..e045447632 --- /dev/null +++ b/coderd/x/chatd/auto_archive.go @@ -0,0 +1,315 @@ +package chatd + +import ( + "cmp" + "context" + "database/sql" + "errors" + "net/http" + "slices" + "strconv" + "time" + + "github.com/dustin/go-humanize" + "github.com/google/uuid" + "golang.org/x/xerrors" + + "cdr.dev/slog/v3" + "github.com/coder/coder/v2/coderd/audit" + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbauthz" + "github.com/coder/coder/v2/coderd/database/dbtime" + "github.com/coder/coder/v2/coderd/notifications" + "github.com/coder/coder/v2/coderd/x/chatd/chatdebug" + "github.com/coder/coder/v2/coderd/x/chatd/chatstate" + "github.com/coder/coder/v2/codersdk" +) + +const chatAutoArchiveDigestMaxChats = 25 + +type autoArchivedChat struct { + Chat database.Chat + LastActivityAt time.Time +} + +func (w *chatWorker) archiveLoop(ctx context.Context) { + ticker := w.opts.Clock.NewTicker(w.opts.ArchiveInterval, "chatworker", "auto-archive") + defer ticker.Stop() + w.archiveOnce(ctx, dbtime.Time(w.opts.Clock.Now("chatworker", "auto-archive")).UTC()) + for { + select { + case tick := <-ticker.C: + w.archiveOnce(ctx, dbtime.Time(tick).UTC()) + case <-ctx.Done(): + return + } + } +} + +func (w *chatWorker) archiveOnce(ctx context.Context, start time.Time) { + autoArchiveDays, err := w.opts.Store.GetChatAutoArchiveDays(ctx, codersdk.DefaultChatAutoArchiveDays) + if err != nil { + if ctx.Err() == nil { + w.opts.Logger.Warn(ctx, "chatworker auto-archive config read failed", 
slogError(err)) + } + return + } + if autoArchiveDays <= 0 { + return + } + retentionDays, err := w.opts.Store.GetChatRetentionDays(ctx) + if err != nil { + if ctx.Err() == nil { + w.opts.Logger.Warn(ctx, "chatworker chat retention config read failed", slogError(err)) + } + return + } + + archiveCutoff := dbtime.StartOfDay(start).Add(-time.Duration(autoArchiveDays) * 24 * time.Hour) + rows, err := w.opts.Store.GetAutoArchiveInactiveChatCandidates(ctx, database.GetAutoArchiveInactiveChatCandidatesParams{ + ArchiveCutoff: archiveCutoff, + LimitCount: w.opts.ArchiveBatchSize, + }) + if err != nil { + if ctx.Err() == nil { + w.opts.Logger.Warn(ctx, "chatworker auto-archive query failed", slogError(err)) + } + return + } + if len(rows) == 0 { + return + } + + archived := make([]autoArchivedChat, 0, len(rows)) + for _, row := range rows { + family, err := w.archiveCandidateSafely(ctx, row) + if err != nil { + if ctx.Err() != nil { + return + } + if isExpectedAutoArchiveError(err) { + w.opts.Logger.Debug(ctx, "chatworker auto-archive skipped chat", + slog.F("chat_id", row.ID), + slog.Error(err), + ) + continue + } + w.opts.Logger.Warn(ctx, "chatworker auto-archive candidate failed", + slog.F("chat_id", row.ID), + slog.Error(err), + ) + continue + } + archived = append(archived, family...) 
+ } + if len(archived) == 0 { + return + } + if w.opts.AutoArchiveRecords != nil { + w.opts.AutoArchiveRecords.Add(float64(len(archived))) + } + w.dispatchChatAutoArchive(context.WithoutCancel(ctx), ctx, start, autoArchiveDays, retentionDays, archived) +} + +func (w *chatWorker) archiveCandidateSafely( + ctx context.Context, + row database.GetAutoArchiveInactiveChatCandidatesRow, +) (family []autoArchivedChat, err error) { + defer func() { + if recovered := recover(); recovered != nil { + err = xerrors.Errorf("chatworker auto-archive panic: %v", recovered) + } + }() + return w.archiveCandidate(ctx, row) +} + +func (w *chatWorker) archiveCandidate( + ctx context.Context, + row database.GetAutoArchiveInactiveChatCandidatesRow, +) ([]autoArchivedChat, error) { + familyChats, err := chatstate.SetFamilyArchived(ctx, w.opts.Store, w.opts.Pubsub, chatstate.SetFamilyArchivedInput{ + RootID: row.ID, + Archived: true, + }) + if err != nil { + return nil, err + } + if len(familyChats) == 0 { + return nil, nil + } + w.scheduleArchiveDebugCleanup(ctx, familyChats) + w.publishArchiveWatchEvents(familyChats) + + archived := make([]autoArchivedChat, 0, len(familyChats)) + for _, chat := range familyChats { + lastActivityAt := row.LastActivityAt + if lastActivityAt.IsZero() { + lastActivityAt = chat.CreatedAt + } + archived = append(archived, autoArchivedChat{ + Chat: chat, + LastActivityAt: lastActivityAt, + }) + } + return archived, nil +} + +func isExpectedAutoArchiveError(err error) bool { + return errors.Is(err, sql.ErrNoRows) || + errors.Is(err, chatstate.ErrChatNotFound) || + errors.Is(err, chatstate.ErrChatNotRoot) || + errors.Is(err, chatstate.ErrInvalidState) || + errors.Is(err, chatstate.ErrTransitionNotAllowed) +} + +func (w *chatWorker) publishArchiveWatchEvents(familyChats []database.Chat) { + if w.server != nil { + w.server.publishChatPubsubEvents(familyChats, codersdk.ChatWatchEventKindDeleted) + return + } + for _, chat := range familyChats { + if err := 
publishChatWatchEvent(w.opts.Pubsub, chat, codersdk.ChatWatchEventKindDeleted); err != nil { + w.opts.Logger.Warn(context.Background(), "chatworker auto-archive watch publish failed", + slog.F("chat_id", chat.ID), + slog.Error(err), + ) + } + } +} + +func (w *chatWorker) scheduleArchiveDebugCleanup(ctx context.Context, familyChats []database.Chat) { + if w.server == nil || len(familyChats) == 0 { + return + } + w.server.scheduleArchiveDebugCleanup(ctx, familyChats) +} + +func (p *Server) scheduleArchiveDebugCleanup(ctx context.Context, familyChats []database.Chat) { + if len(familyChats) == 0 { + return + } + archiveCutoff := familyChats[0].UpdatedAt.Add(-debugCleanupClockSkew) + for _, archivedChat := range familyChats { + p.scheduleDebugCleanup( + ctx, + "failed to delete chat debug rows after archive", + []slog.Field{slog.F("chat_id", archivedChat.ID)}, + func(cleanupCtx context.Context, debugSvc *chatdebug.Service) error { + _, err := debugSvc.DeleteByChatID(cleanupCtx, archivedChat.ID, archiveCutoff) + return err + }, + ) + } +} + +func (w *chatWorker) dispatchChatAutoArchive( + auditCtx context.Context, + enqueueCtx context.Context, + tickStart time.Time, + autoArchiveDays int32, + retentionDays int32, + archived []autoArchivedChat, +) { + roots := make([]autoArchivedChat, 0, len(archived)) + for _, record := range archived { + if !record.Chat.ParentChatID.Valid { + roots = append(roots, record) + } + } + w.auditAutoArchivedChats(auditCtx, roots) + w.enqueueAutoArchiveDigests(enqueueCtx, tickStart, autoArchiveDays, retentionDays, roots) +} + +func (w *chatWorker) auditAutoArchivedChats(ctx context.Context, roots []autoArchivedChat) { + if w.opts.Auditor == nil { + return + } + auditor := w.opts.Auditor.Load() + if auditor == nil { + return + } + for _, record := range roots { + after := record.Chat + before := after + before.Archived = false + audit.BackgroundAudit(ctx, &audit.BackgroundAuditParams[database.Chat]{ + Audit: *auditor, + Log: w.opts.Logger, + 
UserID: after.OwnerID, + OrganizationID: after.OrganizationID, + Action: database.AuditActionWrite, + Old: before, + New: after, + Status: http.StatusOK, + AdditionalFields: audit.BackgroundTaskFieldsBytes(ctx, w.opts.Logger, audit.BackgroundSubsystemChatAutoArchive), + }) + } +} + +func (w *chatWorker) enqueueAutoArchiveDigests( + ctx context.Context, + tickStart time.Time, + autoArchiveDays int32, + retentionDays int32, + roots []autoArchivedChat, +) { + rootsByOwner := make(map[uuid.UUID][]autoArchivedChat, len(roots)) + for _, record := range roots { + rootsByOwner[record.Chat.OwnerID] = append(rootsByOwner[record.Chat.OwnerID], record) + } + ownerIDs := make([]uuid.UUID, 0, len(rootsByOwner)) + for id := range rootsByOwner { + ownerIDs = append(ownerIDs, id) + } + slices.SortFunc(ownerIDs, func(a, b uuid.UUID) int { + return cmp.Compare(a.String(), b.String()) + }) + for i, ownerID := range ownerIDs { + if err := ctx.Err(); err != nil { + w.opts.Logger.Warn(ctx, "chat auto-archive digest dispatch canceled", + slog.F("remaining_owners", len(ownerIDs)-i), + slog.Error(err), + ) + return + } + data := buildAutoArchiveDigestData(rootsByOwner[ownerID], autoArchiveDays, retentionDays, tickStart) + //nolint:gocritic // Background digest dispatch runs as the notifier subject. 
+ if _, err := w.opts.NotificationsEnqueuer.EnqueueWithData( + dbauthz.AsNotifier(ctx), + ownerID, + notifications.TemplateChatAutoArchiveDigest, + map[string]string{}, + data, + string(audit.BackgroundSubsystemChatAutoArchive), + ); err != nil { + w.opts.Logger.Warn(ctx, "failed to enqueue chat auto-archive digest", + slog.F("owner_id", ownerID), + slog.Error(err), + ) + } + } +} + +func buildAutoArchiveDigestData(rows []autoArchivedChat, autoArchiveDays, retentionDays int32, tickStart time.Time) map[string]any { + overflow := 0 + if len(rows) > chatAutoArchiveDigestMaxChats { + overflow = len(rows) - chatAutoArchiveDigestMaxChats + rows = rows[:chatAutoArchiveDigestMaxChats] + } + chats := make([]map[string]any, 0, len(rows)) + for _, r := range rows { + chats = append(chats, map[string]any{ + "title": r.Chat.Title, + "last_activity_humanized": humanize.RelTime(r.LastActivityAt, tickStart, "ago", "from now"), + }) + } + data := map[string]any{ + "auto_archive_days": strconv.Itoa(int(autoArchiveDays)), + "retention_days": strconv.Itoa(int(retentionDays)), + "archived_chats": chats, + } + if overflow > 0 { + data["additional_archived_count"] = strconv.Itoa(overflow) + } + return data +} diff --git a/coderd/x/chatd/auto_archive_internal_test.go b/coderd/x/chatd/auto_archive_internal_test.go new file mode 100644 index 0000000000..f761ef0599 --- /dev/null +++ b/coderd/x/chatd/auto_archive_internal_test.go @@ -0,0 +1,821 @@ +package chatd + +import ( + "context" + "database/sql" + "fmt" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/google/uuid" + "github.com/prometheus/client_golang/prometheus" + promtestutil "github.com/prometheus/client_golang/prometheus/testutil" + "github.com/stretchr/testify/require" + "golang.org/x/xerrors" + + "cdr.dev/slog/v3" + "github.com/coder/coder/v2/coderd/audit" + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbgen" + "github.com/coder/coder/v2/coderd/notifications" + 
"github.com/coder/coder/v2/coderd/notifications/notificationstest" + "github.com/coder/coder/v2/coderd/x/chatd/chatstate" + "github.com/coder/coder/v2/codersdk" + "github.com/coder/coder/v2/testutil" + "github.com/coder/quartz" +) + +func TestWorker_AutoArchiveDisabled(t *testing.T) { + t.Parallel() + f := newWorkerTestFixture(t) + ctx := testutil.Context(t, testutil.WaitShort) + now := time.Date(2026, 5, 29, 12, 0, 0, 0, time.UTC) + chat := f.createArchiveCandidate(t, now.Add(-120*24*time.Hour)) + require.NoError(t, f.db.UpsertChatAutoArchiveDays(ctx, codersdk.DefaultChatAutoArchiveDays)) + + pubsub := newRecordingPubsub(f.pubsub) + worker := f.newArchiveWorker(t, pubsub, nil, nil) + worker.archiveOnce(ctx, now) + + refreshed, err := f.db.GetChatByID(ctx, chat.ID) + require.NoError(t, err) + require.False(t, refreshed.Archived) + require.Empty(t, pubsub.watchEvents(t)) + require.Empty(t, pubsub.stateUpdateMessages(t, chat.ID)) +} + +func TestWorker_AutoArchivesInactiveRoot(t *testing.T) { + t.Parallel() + f := newWorkerTestFixture(t) + ctx := testutil.Context(t, testutil.WaitShort) + now := time.Date(2026, 5, 29, 12, 0, 0, 0, time.UTC) + chat := f.createArchiveCandidate(t, now.Add(-120*24*time.Hour)) + insertArchiveMessage(t, f, chat.ID, now.Add(-100*24*time.Hour)) + require.NoError(t, f.db.UpsertChatAutoArchiveDays(ctx, 90)) + require.NoError(t, f.db.UpsertChatRetentionDays(ctx, 30)) + + pubsub := newRecordingPubsub(f.pubsub) + auditor := audit.NewMock() + enqueuer := notificationstest.NewFakeEnqueuer() + worker := f.newArchiveWorker(t, pubsub, mockAuditorPtr(auditor), enqueuer) + worker.archiveOnce(ctx, now) + + refreshed, err := f.db.GetChatByID(ctx, chat.ID) + require.NoError(t, err) + require.True(t, refreshed.Archived) + require.Greater(t, refreshed.SnapshotVersion, chat.SnapshotVersion) + + updates := pubsub.stateUpdateMessages(t, chat.ID) + require.NotEmpty(t, updates) + require.True(t, updates[len(updates)-1].Archived) + requireWatchEvent(t, pubsub, 
chat.ID, codersdk.ChatWatchEventKindDeleted) + + logs := auditor.AuditLogs() + require.Len(t, logs, 1) + require.Equal(t, chat.ID, logs[0].ResourceID) + require.Equal(t, database.ResourceTypeChat, logs[0].ResourceType) + require.Equal(t, database.AuditActionWrite, logs[0].Action) + require.Contains(t, string(logs[0].AdditionalFields), string(audit.BackgroundSubsystemChatAutoArchive)) + + sent := enqueuer.Sent() + require.Len(t, sent, 1) + require.Equal(t, notifications.TemplateChatAutoArchiveDigest, sent[0].TemplateID) + require.Equal(t, f.user.ID, sent[0].UserID) + require.Equal(t, "90", sent[0].Data["auto_archive_days"]) + require.Equal(t, "30", sent[0].Data["retention_days"]) +} + +func TestWorker_AutoArchiveRejectsActiveChild(t *testing.T) { + t.Parallel() + f := newWorkerTestFixture(t) + ctx := testutil.Context(t, testutil.WaitShort) + now := time.Date(2026, 5, 29, 12, 0, 0, 0, time.UTC) + root := f.createArchiveCandidate(t, now.Add(-120*24*time.Hour)) + child := f.createArchiveCandidate(t, now.Add(-120*24*time.Hour)) + f.linkChild(t, root.ID, child.ID) + forceExecutionState(t, f, child.ID, database.ChatStatusRunning, false) + require.NoError(t, f.db.UpsertChatAutoArchiveDays(ctx, 90)) + + pubsub := newRecordingPubsub(f.pubsub) + worker := f.newArchiveWorker(t, pubsub, nil, nil) + worker.archiveOnce(ctx, now) + + refreshedRoot, err := f.db.GetChatByID(ctx, root.ID) + require.NoError(t, err) + require.False(t, refreshedRoot.Archived) + refreshedChild, err := f.db.GetChatByID(ctx, child.ID) + require.NoError(t, err) + require.False(t, refreshedChild.Archived) + require.Empty(t, pubsub.watchEvents(t)) +} + +func TestWorker_AutoArchivePublishesStateUpdatesForFamily(t *testing.T) { + t.Parallel() + f := newWorkerTestFixture(t) + ctx := testutil.Context(t, testutil.WaitShort) + now := time.Date(2026, 5, 29, 12, 0, 0, 0, time.UTC) + root := f.createArchiveCandidate(t, now.Add(-120*24*time.Hour)) + child := f.createArchiveCandidate(t, now.Add(-120*24*time.Hour)) + 
f.linkChild(t, root.ID, child.ID) + require.NoError(t, f.db.UpsertChatAutoArchiveDays(ctx, 90)) + + pubsub := newRecordingPubsub(f.pubsub) + worker := f.newArchiveWorker(t, pubsub, nil, nil) + worker.archiveOnce(ctx, now) + + refreshedRoot, err := f.db.GetChatByID(ctx, root.ID) + require.NoError(t, err) + require.True(t, refreshedRoot.Archived) + refreshedChild, err := f.db.GetChatByID(ctx, child.ID) + require.NoError(t, err) + require.True(t, refreshedChild.Archived) + require.NotEmpty(t, pubsub.stateUpdateMessages(t, root.ID)) + require.NotEmpty(t, pubsub.stateUpdateMessages(t, child.ID)) + requireWatchEvent(t, pubsub, root.ID, codersdk.ChatWatchEventKindDeleted) + requireWatchEvent(t, pubsub, child.ID, codersdk.ChatWatchEventKindDeleted) +} + +func TestWorker_AutoArchiveExpectedTransitionFailureDoesNotAbortTick(t *testing.T) { + t.Parallel() + f := newWorkerTestFixture(t) + ctx := testutil.Context(t, testutil.WaitShort) + now := time.Date(2026, 5, 29, 12, 0, 0, 0, time.UTC) + blockedRoot := f.createArchiveCandidate(t, now.Add(-130*24*time.Hour)) + blockedChild := f.createArchiveCandidate(t, now.Add(-130*24*time.Hour)) + f.linkChild(t, blockedRoot.ID, blockedChild.ID) + forceExecutionState(t, f, blockedChild.ID, database.ChatStatusRunning, false) + valid := f.createArchiveCandidate(t, now.Add(-120*24*time.Hour)) + require.NoError(t, f.db.UpsertChatAutoArchiveDays(ctx, 90)) + + pubsub := newRecordingPubsub(f.pubsub) + worker := f.newArchiveWorker(t, pubsub, nil, nil) + worker.archiveOnce(ctx, now) + + blockedAfter, err := f.db.GetChatByID(ctx, blockedRoot.ID) + require.NoError(t, err) + require.False(t, blockedAfter.Archived) + validAfter, err := f.db.GetChatByID(ctx, valid.ID) + require.NoError(t, err) + require.True(t, validAfter.Archived) + requireWatchEvent(t, pubsub, valid.ID, codersdk.ChatWatchEventKindDeleted) +} + +func TestWorker_AutoArchiveDateBoundary(t *testing.T) { + t.Parallel() + f := newWorkerTestFixture(t) + ctx := testutil.Context(t, 
testutil.WaitShort) + now := time.Date(2026, 5, 29, 12, 0, 0, 0, time.UTC) + onCutoff := f.createArchiveCandidate(t, now.Add(-120*24*time.Hour)) + insertArchiveMessage(t, f, onCutoff.ID, time.Date(2026, 2, 28, 23, 59, 59, 0, time.UTC)) + beforeCutoff := f.createArchiveCandidate(t, now.Add(-120*24*time.Hour)) + insertArchiveMessage(t, f, beforeCutoff.ID, time.Date(2026, 2, 27, 23, 59, 59, 0, time.UTC)) + require.NoError(t, f.db.UpsertChatAutoArchiveDays(ctx, 90)) + + worker := f.newArchiveWorker(t, newRecordingPubsub(f.pubsub), nil, nil) + worker.archiveOnce(ctx, now) + + refreshedOn, err := f.db.GetChatByID(ctx, onCutoff.ID) + require.NoError(t, err) + require.False(t, refreshedOn.Archived) + refreshedBefore, err := f.db.GetChatByID(ctx, beforeCutoff.ID) + require.NoError(t, err) + require.True(t, refreshedBefore.Archived) +} + +func (f *workerTestFixture) createArchiveCandidate(t *testing.T, createdAt time.Time) database.Chat { + t.Helper() + return f.createArchiveCandidateForOwner(t, f.user.ID, createdAt) +} + +func (f *workerTestFixture) createArchiveCandidateForOwner(t *testing.T, ownerID uuid.UUID, createdAt time.Time) database.Chat { + t.Helper() + chat := dbgen.Chat(t, f.db, database.Chat{ + OrganizationID: f.org.ID, + OwnerID: ownerID, + LastModelConfigID: f.model.ID, + Title: testutil.GetRandomName(t), + Status: database.ChatStatusWaiting, + }) + _, err := f.sqlDB.ExecContext(testutil.Context(t, testutil.WaitShort), "UPDATE chats SET created_at = $1, updated_at = $1 WHERE id = $2", createdAt, chat.ID) + require.NoError(t, err) + chat.CreatedAt = createdAt + chat.UpdatedAt = createdAt + return chat +} + +func (f *workerTestFixture) setPinOrder(t *testing.T, chatID uuid.UUID, order int32) { + t.Helper() + _, err := f.sqlDB.ExecContext(testutil.Context(t, testutil.WaitShort), "UPDATE chats SET pin_order = $1 WHERE id = $2", order, chatID) + require.NoError(t, err) +} + +func (f *workerTestFixture) softDeleteMessages(t *testing.T, chatID uuid.UUID) { + 
t.Helper() + _, err := f.sqlDB.ExecContext(testutil.Context(t, testutil.WaitShort), "UPDATE chat_messages SET deleted = true WHERE chat_id = $1", chatID) + require.NoError(t, err) +} + +func (f *workerTestFixture) archived(t *testing.T, chatID uuid.UUID) bool { + t.Helper() + chat, err := f.db.GetChatByID(testutil.Context(t, testutil.WaitShort), chatID) + require.NoError(t, err) + return chat.Archived +} + +func (f *workerTestFixture) linkChild(t *testing.T, rootID uuid.UUID, childID uuid.UUID) { + t.Helper() + _, err := f.sqlDB.ExecContext(testutil.Context(t, testutil.WaitShort), "UPDATE chats SET parent_chat_id = $1, root_chat_id = $1 WHERE id = $2", rootID, childID) + require.NoError(t, err) +} + +func insertArchiveMessage(t *testing.T, f *workerTestFixture, chatID uuid.UUID, createdAt time.Time) { + t.Helper() + msg := dbgen.ChatMessage(t, f.db, database.ChatMessage{ + ChatID: chatID, + CreatedBy: uuid.NullUUID{UUID: f.user.ID, Valid: true}, + ModelConfigID: uuid.NullUUID{UUID: f.model.ID, Valid: true}, + Role: database.ChatMessageRoleUser, + }) + _, err := f.sqlDB.ExecContext(testutil.Context(t, testutil.WaitShort), "UPDATE chat_messages SET created_at = $1 WHERE id = $2", createdAt, msg.ID) + require.NoError(t, err) +} + +func (f *workerTestFixture) newArchiveWorker( + t *testing.T, + pubsub *recordingPubsub, + auditor *atomic.Pointer[audit.Auditor], + enqueuer *notificationstest.FakeEnqueuer, +) *chatWorker { + t.Helper() + if pubsub == nil { + pubsub = newRecordingPubsub(f.pubsub) + } + if enqueuer == nil { + enqueuer = notificationstest.NewFakeEnqueuer() + } + opts := f.archiveWorkerOptions() + opts.Pubsub = pubsub + opts.NotificationsEnqueuer = enqueuer + opts.Auditor = auditor + return f.newArchiveWorkerWithOptions(t, opts) +} + +// archiveWorkerOptions returns a baseline chatWorkerOptions with the long +// intervals and channel sizes the archive tests rely on. Callers override +// Pubsub, Store, Clock, and the dispatch dependencies as needed. 
+func (f *workerTestFixture) archiveWorkerOptions() chatWorkerOptions { + return chatWorkerOptions{ + WorkerID: uuid.New(), + Store: f.db, + Logger: slog.Make(), + TaskStarter: newRecordingTaskStarter(), + AcquisitionInterval: time.Hour, + AcquisitionBatchSize: 10, + ArchiveInterval: time.Hour, + ArchiveBatchSize: 10, + RunnerSyncInterval: time.Hour, + HeartbeatInterval: time.Hour, + HeartbeatCleanupInterval: time.Hour, + HeartbeatStaleSeconds: 30, + StateChannelSize: 16, + RunnerManagerChannelSize: 16, + AcquisitionWakeChannelSize: 1, + } +} + +func (f *workerTestFixture) newArchiveWorkerWithOptions(t *testing.T, opts chatWorkerOptions) *chatWorker { + t.Helper() + if opts.Pubsub == nil { + opts.Pubsub = newRecordingPubsub(f.pubsub) + } + if opts.NotificationsEnqueuer == nil { + opts.NotificationsEnqueuer = notificationstest.NewFakeEnqueuer() + } + worker, err := newChatWorker(nil, opts) + require.NoError(t, err) + return worker +} + +func mockAuditorPtr(auditor *audit.MockAuditor) *atomic.Pointer[audit.Auditor] { + var ptr atomic.Pointer[audit.Auditor] + var asInterface audit.Auditor = auditor + ptr.Store(&asInterface) + return &ptr +} + +func requireWatchEvent(t *testing.T, pubsub *recordingPubsub, chatID uuid.UUID, kind codersdk.ChatWatchEventKind) { + t.Helper() + for _, event := range pubsub.watchEvents(t) { + if event.Kind == kind && event.Chat.ID == chatID { + return + } + } + t.Fatalf("missing watch event kind=%s chat_id=%s", kind, chatID) +} + +// --- Candidate selection (query) semantics --- + +func TestWorker_AutoArchiveSkipsPinnedRoot(t *testing.T) { + t.Parallel() + f := newWorkerTestFixture(t) + ctx := testutil.Context(t, testutil.WaitShort) + now := time.Date(2026, 5, 29, 12, 0, 0, 0, time.UTC) + chat := f.createArchiveCandidate(t, now.Add(-120*24*time.Hour)) + f.setPinOrder(t, chat.ID, 1) + require.NoError(t, f.db.UpsertChatAutoArchiveDays(ctx, 90)) + + worker := f.newArchiveWorker(t, newRecordingPubsub(f.pubsub), nil, nil) + 
worker.archiveOnce(ctx, now) + + require.False(t, f.archived(t, chat.ID), "pinned root must not be auto-archived") +} + +func TestWorker_AutoArchiveSkipsActiveStatusRoot(t *testing.T) { + t.Parallel() + f := newWorkerTestFixture(t) + ctx := testutil.Context(t, testutil.WaitShort) + now := time.Date(2026, 5, 29, 12, 0, 0, 0, time.UTC) + chat := f.createArchiveCandidate(t, now.Add(-120*24*time.Hour)) + forceExecutionState(t, f, chat.ID, database.ChatStatusRunning, false) + require.NoError(t, f.db.UpsertChatAutoArchiveDays(ctx, 90)) + + worker := f.newArchiveWorker(t, newRecordingPubsub(f.pubsub), nil, nil) + worker.archiveOnce(ctx, now) + + require.False(t, f.archived(t, chat.ID), "running root must not be auto-archived") +} + +func TestWorker_AutoArchiveIgnoresDeletedMessages(t *testing.T) { + t.Parallel() + f := newWorkerTestFixture(t) + ctx := testutil.Context(t, testutil.WaitShort) + now := time.Date(2026, 5, 29, 12, 0, 0, 0, time.UTC) + chat := f.createArchiveCandidate(t, now.Add(-120*24*time.Hour)) + insertArchiveMessage(t, f, chat.ID, now.Add(-10*24*time.Hour)) + require.NoError(t, f.db.UpsertChatAutoArchiveDays(ctx, 90)) + + worker := f.newArchiveWorker(t, newRecordingPubsub(f.pubsub), nil, nil) + worker.archiveOnce(ctx, now) + require.False(t, f.archived(t, chat.ID), "recent message must keep the chat active") + + // Once the only recent message is soft-deleted, activity falls back to + // created_at and the chat becomes eligible. 
+ f.softDeleteMessages(t, chat.ID) + worker.archiveOnce(ctx, now) + require.True(t, f.archived(t, chat.ID), "chat with only deleted messages must archive on created_at") +} + +func TestWorker_AutoArchiveChildActivityKeepsRootAlive(t *testing.T) { + t.Parallel() + f := newWorkerTestFixture(t) + ctx := testutil.Context(t, testutil.WaitShort) + now := time.Date(2026, 5, 29, 12, 0, 0, 0, time.UTC) + root := f.createArchiveCandidate(t, now.Add(-120*24*time.Hour)) + child := f.createArchiveCandidate(t, now.Add(-120*24*time.Hour)) + f.linkChild(t, root.ID, child.ID) + insertArchiveMessage(t, f, child.ID, now.Add(-5*24*time.Hour)) + require.NoError(t, f.db.UpsertChatAutoArchiveDays(ctx, 90)) + + worker := f.newArchiveWorker(t, newRecordingPubsub(f.pubsub), nil, nil) + worker.archiveOnce(ctx, now) + + require.False(t, f.archived(t, root.ID), "recent child activity must keep the root alive") + require.False(t, f.archived(t, child.ID)) +} + +func TestWorker_AutoArchiveBatchSizeLimitsAndPaginates(t *testing.T) { + t.Parallel() + f := newWorkerTestFixture(t) + ctx := testutil.Context(t, testutil.WaitShort) + now := time.Date(2026, 5, 29, 12, 0, 0, 0, time.UTC) + oldest := f.createArchiveCandidate(t, now.Add(-122*24*time.Hour)) + middle := f.createArchiveCandidate(t, now.Add(-121*24*time.Hour)) + newest := f.createArchiveCandidate(t, now.Add(-120*24*time.Hour)) + require.NoError(t, f.db.UpsertChatAutoArchiveDays(ctx, 90)) + + opts := f.archiveWorkerOptions() + opts.Pubsub = newRecordingPubsub(f.pubsub) + opts.ArchiveBatchSize = 2 + worker := f.newArchiveWorkerWithOptions(t, opts) + + // First tick archives the two oldest roots (created_at ASC, limited). 
+ worker.archiveOnce(ctx, now) + require.True(t, f.archived(t, oldest.ID), "oldest root should archive in the first batch") + require.True(t, f.archived(t, middle.ID), "middle root should archive in the first batch") + require.False(t, f.archived(t, newest.ID), "newest root should wait for the next tick") + + // Second tick drains the remaining backlog. + worker.archiveOnce(ctx, now) + require.True(t, f.archived(t, newest.ID), "newest root should archive on the second tick") +} + +func TestWorker_AutoArchiveNoEligibleChats(t *testing.T) { + t.Parallel() + f := newWorkerTestFixture(t) + ctx := testutil.Context(t, testutil.WaitShort) + now := time.Date(2026, 5, 29, 12, 0, 0, 0, time.UTC) + // A recent chat is well within the inactivity window. + chat := f.createArchiveCandidate(t, now.Add(-24*time.Hour)) + require.NoError(t, f.db.UpsertChatAutoArchiveDays(ctx, 90)) + + auditor := audit.NewMock() + enqueuer := notificationstest.NewFakeEnqueuer() + worker := f.newArchiveWorker(t, newRecordingPubsub(f.pubsub), mockAuditorPtr(auditor), enqueuer) + worker.archiveOnce(ctx, now) + + require.False(t, f.archived(t, chat.ID)) + require.Empty(t, auditor.AuditLogs()) + require.Empty(t, enqueuer.Sent()) +} + +// --- Dispatch (audit + digest) semantics --- + +func TestWorker_AutoArchiveMultipleOwnersGetSeparateDigests(t *testing.T) { + t.Parallel() + f := newWorkerTestFixture(t) + ctx := testutil.Context(t, testutil.WaitShort) + now := time.Date(2026, 5, 29, 12, 0, 0, 0, time.UTC) + user2 := dbgen.User(t, f.db, database.User{}) + chat1 := f.createArchiveCandidate(t, now.Add(-120*24*time.Hour)) + chat2 := f.createArchiveCandidateForOwner(t, user2.ID, now.Add(-120*24*time.Hour)) + require.NoError(t, f.db.UpsertChatAutoArchiveDays(ctx, 90)) + + auditor := audit.NewMock() + enqueuer := notificationstest.NewFakeEnqueuer() + worker := f.newArchiveWorker(t, newRecordingPubsub(f.pubsub), mockAuditorPtr(auditor), enqueuer) + worker.archiveOnce(ctx, now) + + require.True(t, f.archived(t, 
chat1.ID)) + require.True(t, f.archived(t, chat2.ID)) + + sent := enqueuer.Sent() + require.Len(t, sent, 2, "each owner should receive its own digest") + require.ElementsMatch(t, []uuid.UUID{f.user.ID, user2.ID}, []uuid.UUID{sent[0].UserID, sent[1].UserID}) + require.Len(t, auditor.AuditLogs(), 2, "each archived root should be audited") +} + +func TestWorker_AutoArchiveAuditsAndDigestsRootOnlyForFamily(t *testing.T) { + t.Parallel() + f := newWorkerTestFixture(t) + ctx := testutil.Context(t, testutil.WaitShort) + now := time.Date(2026, 5, 29, 12, 0, 0, 0, time.UTC) + root := f.createArchiveCandidate(t, now.Add(-120*24*time.Hour)) + child := f.createArchiveCandidate(t, now.Add(-120*24*time.Hour)) + f.linkChild(t, root.ID, child.ID) + require.NoError(t, f.db.UpsertChatAutoArchiveDays(ctx, 90)) + + auditor := audit.NewMock() + enqueuer := notificationstest.NewFakeEnqueuer() + worker := f.newArchiveWorker(t, newRecordingPubsub(f.pubsub), mockAuditorPtr(auditor), enqueuer) + worker.archiveOnce(ctx, now) + + require.True(t, f.archived(t, root.ID)) + require.True(t, f.archived(t, child.ID)) + + logs := auditor.AuditLogs() + require.Len(t, logs, 1, "only the root should be audited; children inherit the decision") + require.Equal(t, root.ID, logs[0].ResourceID) + require.Len(t, enqueuer.Sent(), 1, "a single-owner family produces one digest") +} + +func TestWorker_AutoArchiveIncrementsRecordsCounter(t *testing.T) { + t.Parallel() + f := newWorkerTestFixture(t) + ctx := testutil.Context(t, testutil.WaitShort) + now := time.Date(2026, 5, 29, 12, 0, 0, 0, time.UTC) + chat := f.createArchiveCandidate(t, now.Add(-120*24*time.Hour)) + require.NoError(t, f.db.UpsertChatAutoArchiveDays(ctx, 90)) + + counter := prometheus.NewCounter(prometheus.CounterOpts{Name: "test_chat_auto_archive_records_total"}) + opts := f.archiveWorkerOptions() + opts.Pubsub = newRecordingPubsub(f.pubsub) + opts.AutoArchiveRecords = counter + worker := f.newArchiveWorkerWithOptions(t, opts) + + 
worker.archiveOnce(ctx, now) + require.True(t, f.archived(t, chat.ID)) + require.InDelta(t, 1.0, promtestutil.ToFloat64(counter), 0.0001, "counter should reflect one archived root") +} + +func TestWorker_AutoArchiveSecondTickIdempotent(t *testing.T) { + t.Parallel() + f := newWorkerTestFixture(t) + ctx := testutil.Context(t, testutil.WaitShort) + now := time.Date(2026, 5, 29, 12, 0, 0, 0, time.UTC) + chat := f.createArchiveCandidate(t, now.Add(-120*24*time.Hour)) + require.NoError(t, f.db.UpsertChatAutoArchiveDays(ctx, 90)) + + auditor := audit.NewMock() + enqueuer := notificationstest.NewFakeEnqueuer() + worker := f.newArchiveWorker(t, newRecordingPubsub(f.pubsub), mockAuditorPtr(auditor), enqueuer) + + worker.archiveOnce(ctx, now) + require.True(t, f.archived(t, chat.ID)) + require.Len(t, auditor.AuditLogs(), 1) + require.Len(t, enqueuer.Sent(), 1) + + // An already-archived chat is no longer a candidate, so a second tick is a + // no-op for both audit and digest dispatch. + worker.archiveOnce(ctx, now) + require.Len(t, auditor.AuditLogs(), 1, "second tick must not re-audit") + require.Len(t, enqueuer.Sent(), 1, "second tick must not re-notify") +} + +func TestWorker_AutoArchiveCutoffStableAcrossSameDayTicks(t *testing.T) { + t.Parallel() + f := newWorkerTestFixture(t) + ctx := testutil.Context(t, testutil.WaitShort) + // created_at is far in the past so the boundary decision is driven purely + // by message activity sitting exactly on the cutoff date. + chat := f.createArchiveCandidate(t, time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC)) + // StartOfDay(2026-05-29) - 90d = 2026-02-28; activity on that date is not + // strictly before the cutoff. + insertArchiveMessage(t, f, chat.ID, time.Date(2026, 2, 28, 12, 0, 0, 0, time.UTC)) + require.NoError(t, f.db.UpsertChatAutoArchiveDays(ctx, 90)) + + worker := f.newArchiveWorker(t, newRecordingPubsub(f.pubsub), nil, nil) + + // Tick early in the UTC day. 
+ worker.archiveOnce(ctx, time.Date(2026, 5, 29, 23, 49, 0, 0, time.UTC)) + require.False(t, f.archived(t, chat.ID), "activity on the cutoff date must survive") + + // Tick later the same UTC day: advancing wall-clock time within a day must + // not change the cutoff ("no trickle"). + worker.archiveOnce(ctx, time.Date(2026, 5, 29, 23, 59, 0, 0, time.UTC)) + require.False(t, f.archived(t, chat.ID), "same-day tick must not change the decision") + + // Tick on the next UTC day: the cutoff advances to 2026-03-01 and the chat + // becomes eligible. + worker.archiveOnce(ctx, time.Date(2026, 5, 30, 0, 9, 0, 0, time.UTC)) + require.True(t, f.archived(t, chat.ID), "cutoff advances on the next UTC day") +} + +func TestWorker_AutoArchiveDigestDispatchContinuesAfterEnqueueError(t *testing.T) { + t.Parallel() + f := newWorkerTestFixture(t) + owner1 := uuid.New() + owner2 := uuid.New() + enq := &recordingEnqueuer{failOwner: owner1} + opts := f.archiveWorkerOptions() + opts.NotificationsEnqueuer = enq + worker := f.newArchiveWorkerWithOptions(t, opts) + + roots := []autoArchivedChat{ + {Chat: database.Chat{OwnerID: owner1, OrganizationID: f.org.ID, Title: "a"}, LastActivityAt: time.Now()}, + {Chat: database.Chat{OwnerID: owner2, OrganizationID: f.org.ID, Title: "b"}, LastActivityAt: time.Now()}, + } + worker.enqueueAutoArchiveDigests(context.Background(), time.Now(), 90, 30, roots) + + require.ElementsMatch(t, []uuid.UUID{owner1, owner2}, enq.enqueuedOwners(), + "a transient enqueue failure must not abort the dispatch loop") +} + +func TestWorker_AutoArchiveDigestDispatchStopsWhenCanceled(t *testing.T) { + t.Parallel() + f := newWorkerTestFixture(t) + enq := &recordingEnqueuer{} + opts := f.archiveWorkerOptions() + opts.NotificationsEnqueuer = enq + worker := f.newArchiveWorkerWithOptions(t, opts) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + roots := []autoArchivedChat{ + {Chat: database.Chat{OwnerID: uuid.New(), OrganizationID: f.org.ID}, 
LastActivityAt: time.Now()}, + {Chat: database.Chat{OwnerID: uuid.New(), OrganizationID: f.org.ID}, LastActivityAt: time.Now()}, + } + worker.enqueueAutoArchiveDigests(ctx, time.Now(), 90, 30, roots) + + require.Empty(t, enq.enqueuedOwners(), "canceled dispatch must enqueue nothing") +} + +// --- Config / query error handling --- + +func TestWorker_AutoArchiveDaysConfigReadFailureSkipsTick(t *testing.T) { + t.Parallel() + f := newWorkerTestFixture(t) + ctx := testutil.Context(t, testutil.WaitShort) + now := time.Date(2026, 5, 29, 12, 0, 0, 0, time.UTC) + chat := f.createArchiveCandidate(t, now.Add(-120*24*time.Hour)) + require.NoError(t, f.db.UpsertChatAutoArchiveDays(ctx, 90)) + + enqueuer := notificationstest.NewFakeEnqueuer() + opts := f.archiveWorkerOptions() + opts.Store = &archiveErrStore{Store: f.db, autoArchiveDaysErr: xerrors.New("boom")} + opts.NotificationsEnqueuer = enqueuer + worker := f.newArchiveWorkerWithOptions(t, opts) + worker.archiveOnce(ctx, now) + + require.False(t, f.archived(t, chat.ID), "auto-archive config read failure must skip the tick") + require.Empty(t, enqueuer.Sent()) +} + +func TestWorker_AutoArchiveRetentionConfigReadFailureSkipsTick(t *testing.T) { + t.Parallel() + f := newWorkerTestFixture(t) + ctx := testutil.Context(t, testutil.WaitShort) + now := time.Date(2026, 5, 29, 12, 0, 0, 0, time.UTC) + chat := f.createArchiveCandidate(t, now.Add(-120*24*time.Hour)) + require.NoError(t, f.db.UpsertChatAutoArchiveDays(ctx, 90)) + + enqueuer := notificationstest.NewFakeEnqueuer() + opts := f.archiveWorkerOptions() + opts.Store = &archiveErrStore{Store: f.db, retentionDaysErr: xerrors.New("boom")} + opts.NotificationsEnqueuer = enqueuer + worker := f.newArchiveWorkerWithOptions(t, opts) + worker.archiveOnce(ctx, now) + + require.False(t, f.archived(t, chat.ID), "retention config read failure must skip the tick") + require.Empty(t, enqueuer.Sent()) +} + +func TestWorker_AutoArchiveCandidateQueryFailureSkipsTick(t *testing.T) { + 
t.Parallel() + f := newWorkerTestFixture(t) + ctx := testutil.Context(t, testutil.WaitShort) + now := time.Date(2026, 5, 29, 12, 0, 0, 0, time.UTC) + chat := f.createArchiveCandidate(t, now.Add(-120*24*time.Hour)) + require.NoError(t, f.db.UpsertChatAutoArchiveDays(ctx, 90)) + + enqueuer := notificationstest.NewFakeEnqueuer() + opts := f.archiveWorkerOptions() + opts.Store = &archiveErrStore{Store: f.db, candidatesErr: xerrors.New("boom")} + opts.NotificationsEnqueuer = enqueuer + worker := f.newArchiveWorkerWithOptions(t, opts) + worker.archiveOnce(ctx, now) + + require.False(t, f.archived(t, chat.ID), "candidate query failure must skip the tick") + require.Empty(t, enqueuer.Sent()) +} + +// --- Loop wiring --- + +func TestWorker_AutoArchiveLoopRunsImmediatelyAndOnTick(t *testing.T) { + t.Parallel() + f := newWorkerTestFixture(t) + ctx := testutil.Context(t, testutil.WaitLong) + require.NoError(t, f.db.UpsertChatAutoArchiveDays(ctx, 90)) + + mClock := quartz.NewMock(t) + now := mClock.Now().UTC() + first := f.createArchiveCandidate(t, now.Add(-120*24*time.Hour)) + + opts := f.archiveWorkerOptions() + opts.Pubsub = newRecordingPubsub(f.pubsub) + opts.Clock = mClock + opts.ArchiveInterval = time.Minute + worker := f.newArchiveWorkerWithOptions(t, opts) + + trap := mClock.Trap().NewTicker("chatworker", "auto-archive") + defer trap.Close() + + loopCtx, cancel := context.WithCancel(ctx) + done := make(chan struct{}) + go func() { + defer close(done) + worker.archiveLoop(loopCtx) + }() + + // archiveLoop creates the ticker before the immediate startup tick. + trap.MustWait(ctx).MustRelease(ctx) + testutil.Eventually(ctx, t, func(context.Context) bool { + return f.archived(t, first.ID) + }, testutil.IntervalFast, "immediate startup tick should archive the first candidate") + + // A second candidate is only archived once the interval ticker fires. 
+ second := f.createArchiveCandidate(t, now.Add(-120*24*time.Hour)) + mClock.Advance(time.Minute).MustWait(ctx) + testutil.Eventually(ctx, t, func(context.Context) bool { + return f.archived(t, second.ID) + }, testutil.IntervalFast, "interval tick should archive the second candidate") + + cancel() + select { + case <-done: + case <-ctx.Done(): + t.Fatal("archiveLoop did not exit after context cancellation") + } +} + +// --- Pure helpers --- + +func TestBuildAutoArchiveDigestData(t *testing.T) { + t.Parallel() + tickStart := time.Date(2026, 5, 29, 12, 0, 0, 0, time.UTC) + + t.Run("UnderCap", func(t *testing.T) { + t.Parallel() + rows := make([]autoArchivedChat, 0, 3) + for i := range 3 { + rows = append(rows, autoArchivedChat{ + Chat: database.Chat{Title: fmt.Sprintf("chat-%d", i)}, + LastActivityAt: tickStart.Add(-time.Duration(i+1) * 24 * time.Hour), + }) + } + data := buildAutoArchiveDigestData(rows, 90, 30, tickStart) + require.Equal(t, "90", data["auto_archive_days"]) + require.Equal(t, "30", data["retention_days"]) + chats, ok := data["archived_chats"].([]map[string]any) + require.True(t, ok) + require.Len(t, chats, 3) + require.Equal(t, "chat-0", chats[0]["title"]) + require.Contains(t, chats[0]["last_activity_humanized"].(string), "ago") + require.NotContains(t, data, "additional_archived_count") + }) + + t.Run("OverflowCap", func(t *testing.T) { + t.Parallel() + total := chatAutoArchiveDigestMaxChats + 5 + rows := make([]autoArchivedChat, 0, total) + for i := range total { + rows = append(rows, autoArchivedChat{ + Chat: database.Chat{Title: fmt.Sprintf("chat-%d", i)}, + LastActivityAt: tickStart.Add(-24 * time.Hour), + }) + } + data := buildAutoArchiveDigestData(rows, 90, 0, tickStart) + chats, ok := data["archived_chats"].([]map[string]any) + require.True(t, ok) + require.Len(t, chats, chatAutoArchiveDigestMaxChats, "titles are capped") + require.Equal(t, "5", data["additional_archived_count"]) + require.Equal(t, "0", data["retention_days"]) + }) +} + 
+func TestIsExpectedAutoArchiveError(t *testing.T) { + t.Parallel() + expected := []error{ + sql.ErrNoRows, + chatstate.ErrChatNotFound, + chatstate.ErrChatNotRoot, + chatstate.ErrInvalidState, + chatstate.ErrTransitionNotAllowed, + } + for _, err := range expected { + require.True(t, isExpectedAutoArchiveError(err), "%v should be classified as expected", err) + require.True(t, isExpectedAutoArchiveError(xerrors.Errorf("wrapped: %w", err)), + "wrapped %v should still be classified as expected", err) + } + require.False(t, isExpectedAutoArchiveError(xerrors.New("unexpected"))) +} + +// recordingEnqueuer records the owner of every enqueue and can be configured to +// fail for a specific owner (or all owners) to exercise dispatch resilience. +type recordingEnqueuer struct { + mu sync.Mutex + owners []uuid.UUID + failOwner uuid.UUID + failAll bool +} + +func (e *recordingEnqueuer) Enqueue(ctx context.Context, userID, templateID uuid.UUID, labels map[string]string, createdBy string, targets ...uuid.UUID) ([]uuid.UUID, error) { + return e.EnqueueWithData(ctx, userID, templateID, labels, nil, createdBy, targets...) +} + +func (e *recordingEnqueuer) EnqueueWithData(_ context.Context, userID, _ uuid.UUID, _ map[string]string, _ map[string]any, _ string, _ ...uuid.UUID) ([]uuid.UUID, error) { + e.mu.Lock() + e.owners = append(e.owners, userID) + e.mu.Unlock() + if e.failAll || userID == e.failOwner { + return nil, xerrors.New("enqueue failed") + } + return []uuid.UUID{uuid.New()}, nil +} + +func (e *recordingEnqueuer) enqueuedOwners() []uuid.UUID { + e.mu.Lock() + defer e.mu.Unlock() + return append([]uuid.UUID(nil), e.owners...) +} + +// archiveErrStore wraps a real store and injects errors on the reads performed +// at the start of an auto-archive tick. 
+type archiveErrStore struct { + database.Store + autoArchiveDaysErr error + retentionDaysErr error + candidatesErr error +} + +func (s *archiveErrStore) GetChatAutoArchiveDays(ctx context.Context, defaultAutoArchiveDays int32) (int32, error) { + if s.autoArchiveDaysErr != nil { + return 0, s.autoArchiveDaysErr + } + return s.Store.GetChatAutoArchiveDays(ctx, defaultAutoArchiveDays) +} + +func (s *archiveErrStore) GetChatRetentionDays(ctx context.Context) (int32, error) { + if s.retentionDaysErr != nil { + return 0, s.retentionDaysErr + } + return s.Store.GetChatRetentionDays(ctx) +} + +func (s *archiveErrStore) GetAutoArchiveInactiveChatCandidates(ctx context.Context, arg database.GetAutoArchiveInactiveChatCandidatesParams) ([]database.GetAutoArchiveInactiveChatCandidatesRow, error) { + if s.candidatesErr != nil { + return nil, s.candidatesErr + } + return s.Store.GetAutoArchiveInactiveChatCandidates(ctx, arg) +} diff --git a/coderd/x/chatd/chatadvisor/runner.go b/coderd/x/chatd/chatadvisor/runner.go index d95ef226fb..fe31afbc5c 100644 --- a/coderd/x/chatd/chatadvisor/runner.go +++ b/coderd/x/chatd/chatadvisor/runner.go @@ -50,20 +50,14 @@ func (rt *Runtime) RunAdvisor( nestedProviderOptions := cloneProviderOptions(rt.cfg.ProviderOptions) resetProviderOptionsForNestedCall(nestedProviderOptions) - var persistedStep chatloop.PersistedStep - chatLoopOpts := chatloop.RunOptions{ + assistantOpts := chatloop.GenerateAssistantOptions{ Model: rt.cfg.Model, Messages: BuildAdvisorMessages(question, conversationSnapshot), - MaxSteps: 1, ModelConfig: rt.cfg.ModelConfig, ProviderOptions: nestedProviderOptions, - PersistStep: func(_ context.Context, step chatloop.PersistedStep) error { - persistedStep = step - return nil - }, } if opts != nil && opts.OnAdviceDelta != nil { - chatLoopOpts.PublishMessagePart = func(role codersdk.ChatMessageRole, part codersdk.ChatMessagePart) { + assistantOpts.PublishMessagePart = func(role codersdk.ChatMessageRole, part codersdk.ChatMessagePart) 
{ if role != codersdk.ChatMessageRoleAssistant || part.Type != codersdk.ChatMessagePartTypeText || part.Text == "" { @@ -72,13 +66,17 @@ func (rt *Runtime) RunAdvisor( opts.OnAdviceDelta(part.Text) } } - if opts != nil && opts.OnAdviceReset != nil { - chatLoopOpts.OnRetry = func(int, error, chatretry.ClassifiedError, time.Duration) { + + var outcome chatloop.AssistantOutcome + if err := chatretry.Retry(ctx, func(retryCtx context.Context) error { + var err error + outcome, err = chatloop.GenerateAssistant(retryCtx, assistantOpts) + return err + }, func(int, error, chatretry.ClassifiedError, time.Duration) { + if opts != nil && opts.OnAdviceReset != nil { opts.OnAdviceReset() } - } - - if err := chatloop.Run(ctx, chatLoopOpts); err != nil { + }); err != nil { // Refund the use so a transient provider failure does not // permanently exhaust the per-run advisor budget. rt.release() @@ -89,7 +87,7 @@ func (rt *Runtime) RunAdvisor( }, nil } - advice := extractAdvisorText(persistedStep) + advice := extractAdvisorText(outcome.Step) if advice == "" { // Refund: the run did not produce advice, so the contract // "increments on every successful advisor call" treats this diff --git a/coderd/x/chatd/chatd.go b/coderd/x/chatd/chatd.go index 0e489d710e..4b2f49576d 100644 --- a/coderd/x/chatd/chatd.go +++ b/coderd/x/chatd/chatd.go @@ -8,7 +8,6 @@ import ( "encoding/json" "errors" "fmt" - "maps" "math" "net/http" "slices" @@ -30,10 +29,12 @@ import ( "cdr.dev/slog/v3" "github.com/coder/coder/v2/coderd/aibridge" + "github.com/coder/coder/v2/coderd/audit" "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/db2sdk" "github.com/coder/coder/v2/coderd/database/dbauthz" "github.com/coder/coder/v2/coderd/database/pubsub" + "github.com/coder/coder/v2/coderd/notifications" coderdpubsub "github.com/coder/coder/v2/coderd/pubsub" "github.com/coder/coder/v2/coderd/rbac" "github.com/coder/coder/v2/coderd/util/ptr" @@ -48,12 +49,11 @@ import ( 
"github.com/coder/coder/v2/coderd/x/chatd/chatopenai" "github.com/coder/coder/v2/coderd/x/chatd/chatprompt" "github.com/coder/coder/v2/coderd/x/chatd/chatprovider" - "github.com/coder/coder/v2/coderd/x/chatd/chatretry" - "github.com/coder/coder/v2/coderd/x/chatd/chatsanitize" "github.com/coder/coder/v2/coderd/x/chatd/chatstate" "github.com/coder/coder/v2/coderd/x/chatd/chattool" "github.com/coder/coder/v2/coderd/x/chatd/internal/agentselect" "github.com/coder/coder/v2/coderd/x/chatd/mcpclient" + "github.com/coder/coder/v2/coderd/x/chatd/messagepartbuffer" skillspkg "github.com/coder/coder/v2/coderd/x/skills" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/codersdk/workspacesdk" @@ -70,7 +70,6 @@ const ( homeInstructionLookupTimeout = 5 * time.Second planPathLookupTimeout = 5 * time.Second - instructionCacheTTL = 5 * time.Minute workspaceDialValidationDelay = 5 * time.Second // Must exceed agent/x/agentmcp.connectTimeout (30s) so a // cold-start agent's first MCP reload can settle before @@ -99,10 +98,7 @@ const ( // heartbeat updates while a chat is being processed. DefaultChatHeartbeatInterval = 30 * time.Second maxChatSteps = 1200 - // maxStreamBufferSize caps the number of message_part events buffered - // per chat during a single LLM step. When exceeded the oldest event is - // evicted so memory stays bounded. - maxStreamBufferSize = 10000 + // RelaySentinelAfterID is the after_id sentinel used by cross-replica // relay subscribers. It instructs the peer to skip the durable DB // snapshot and only deliver buffered message_part events. The @@ -110,9 +106,6 @@ const ( // so the sentinel resolves to "send me any in-progress streaming // parts you have; I will receive durable messages through pubsub." RelaySentinelAfterID = math.MaxInt64 - // maxDurableMessageCacheSize caps the number of recent durable message - // events cached per chat for same-replica stream catch-up. 
- maxDurableMessageCacheSize = 256 // maxConcurrentRecordingUploads caps the number of recording // stop-and-store operations that can run concurrently. Each @@ -121,19 +114,6 @@ const ( // to roughly maxConcurrentRecordingUploads * 110 MB. maxConcurrentRecordingUploads = 25 - // staleRecoveryIntervalDivisor determines how often the stale - // recovery loop runs relative to the stale threshold. A value - // of 5 means recovery runs at 1/5 of the stale-after duration. - staleRecoveryIntervalDivisor = 5 - - // streamDropWarnInterval controls how often WARN-level logs are - // emitted when stream events are dropped. Between intervals the - // drop is logged at DEBUG to avoid log spam. This uses a - // timestamp comparison rather than a quartz.Ticker because the - // state is per-chat — a ticker per chat would require extra - // goroutines and lifecycle management. - streamDropWarnInterval = 10 * time.Second - // bufferRetainGracePeriod is how long the per-chat stream // state is kept after processing completes. The retained // state lets late-connecting cross-replica relay subscribers @@ -251,10 +231,12 @@ type Server struct { // keyed by chat ID and invalidated when the agent changes. workspaceMCPToolsCache sync.Map // uuid.UUID -> *cachedWorkspaceMCPTools - usageTracker *workspacestats.UsageTracker - clock quartz.Clock - metrics *chatloop.Metrics - recordingSem chan struct{} + usageTracker *workspacestats.UsageTracker + clock quartz.Clock + metrics *chatloop.Metrics + chatWorker *chatWorker + messagePartBuffer *messagepartbuffer.Buffer + recordingSem chan struct{} aibridgeTransportFactory *atomic.Pointer[aibridge.TransportFactory] aiGatewayRoutingEnabled bool @@ -264,17 +246,6 @@ type Server struct { maxChatsPerAcquire int32 inFlightChatStaleAfter time.Duration chatHeartbeatInterval time.Duration - - // heartbeatMu guards heartbeatRegistry. 
- heartbeatMu sync.Mutex - // heartbeatRegistry maps chat IDs to their cancel functions - // and workspace state for the centralized heartbeat loop. - heartbeatRegistry map[uuid.UUID]*heartbeatEntry - - // wakeCh is signaled whenever a chat transitions to - // pending so the run loop calls processOnce immediately - // instead of waiting for the next ticker. - wakeCh chan struct{} } // chatTemplateAllowlist returns the deployment-wide template @@ -542,8 +513,8 @@ func (p *Server) loadCachedWorkspaceContext( // cache. Returns nil (and never an error) on every failure mode so the // caller can continue without MCP tools. // -// This helper is shared between the top-of-turn discovery path and the -// mid-turn PrepareTools path triggered after create_workspace / +// This helper is shared between the initial discovery path and the +// mid-turn workspace binding path triggered after create_workspace or // start_workspace bind a workspace to a chat that started without one. func (p *Server) discoverWorkspaceMCPTools( ctx context.Context, @@ -618,9 +589,9 @@ func (p *Server) discoverWorkspaceMCPTools( // When the agent's MCP server is still racing with agent startup, // ListMCPTools may return an empty list (no error) on the first call; // the primer retries with a short backoff up to -// workspaceMCPPrimeMaxWait so the LLM step that follows the tool call -// sees the workspace MCP tools in the cache and PrepareTools does not -// need to dial again. +// workspaceMCPPrimeMaxWait so the generation action that follows the +// tool call sees the workspace MCP tools in the cache and does not need +// to dial again. // // Returns silently on every failure mode. 
The chat continues without // workspace MCP tools when the agent does not advertise any within @@ -722,6 +693,17 @@ func (c *turnWorkspaceContext) currentWorkspaceMatches(expected uuid.NullUUID) ( return chatSnapshot, nullUUIDEqual(chatSnapshot.WorkspaceID, expected) } +func (c *turnWorkspaceContext) trackWorkspaceUsage(ctx context.Context, chatSnapshot database.Chat) { + if c.server == nil || !chatSnapshot.WorkspaceID.Valid { + return + } + logger := c.server.logger.With( + slog.F("chat_id", chatSnapshot.ID), + slog.F("owner_id", chatSnapshot.OwnerID), + ) + c.server.trackWorkspaceUsage(ctx, chatSnapshot.ID, chatSnapshot.WorkspaceID, logger) +} + func nullUUIDEqual(left, right uuid.NullUUID) bool { if left.Valid != right.Valid { return false @@ -1079,6 +1061,7 @@ func (c *turnWorkspaceContext) getWorkspaceConn(ctx context.Context) (workspaces // row so we see the latest heartbeat rather than // a potentially stale cached copy. if currentConn != nil { + chatSnapshot := c.currentChatSnapshot() if agentID != uuid.Nil { freshAgent, err := c.server.db.GetWorkspaceAgentByID(ctx, agentID) if err != nil { @@ -1097,6 +1080,7 @@ func (c *turnWorkspaceContext) getWorkspaceConn(ctx context.Context) (workspaces continue } } + c.trackWorkspaceUsage(ctx, chatSnapshot) return currentConn, nil } if staleRelease != nil { @@ -1223,6 +1207,7 @@ func (c *turnWorkspaceContext) getWorkspaceConn(ctx context.Context) (workspaces slog.F("workspace_id", chatSnapshot.WorkspaceID.UUID), slog.F("agent_id", dialResult.AgentID), ) + c.trackWorkspaceUsage(ctx, chatSnapshot) return agentConn, nil } currentConn = c.conn @@ -1231,6 +1216,7 @@ func (c *turnWorkspaceContext) getWorkspaceConn(ctx context.Context) (workspaces if agentRelease != nil { agentRelease() } + c.trackWorkspaceUsage(ctx, chatSnapshot) return currentConn, nil } @@ -1324,26 +1310,6 @@ type chatStreamState struct { bufferRetainedAt time.Time } -// heartbeatEntry tracks a single chat's cancel function and workspace -// state for the 
centralized heartbeat loop. Instead of spawning a -// per-chat goroutine, processChat registers an entry here and the -// single heartbeatLoop goroutine handles all chats. -type heartbeatEntry struct { - cancelWithCause context.CancelCauseFunc - chatID uuid.UUID - workspaceID uuid.NullUUID - logger slog.Logger -} - -// resetDropCounters zeroes the rate-limiting state for both buffer -// and subscriber drop warnings. The caller must hold s.mu. -func (s *chatStreamState) resetDropCounters() { - s.bufferDropCount = 0 - s.bufferLastWarnAt = time.Time{} - s.subscriberDropCount = 0 - s.subscriberLastWarnAt = time.Time{} -} - // streamStateCollector exposes scrape-time gauges derived from // p.chatStreams. Scrape cost is O(n) with a brief per-state mutex // held for two len() reads; acceptable at typical scrape cadences. @@ -1416,12 +1382,6 @@ var ( // accept modifications (messages, edits, promotions, or // tool-result submissions). ErrChatArchived = xerrors.New("chat is archived") - - // errChatTakenByOtherWorker is a sentinel used inside the - // processChat cleanup transaction to signal that another - // worker acquired the chat, so all post-TX side effects - // (status publish, pubsub, web push) must be skipped. 
- errChatTakenByOtherWorker = xerrors.New("chat acquired by another worker") ) // UsageLimitExceededError indicates the user has exceeded their chat spend @@ -1616,7 +1576,7 @@ func (p *Server) CreateChat(ctx context.Context, opts CreateOptions) (database.C initialMessages = append(initialMessages, systemMessage(userPromptContent, opts.ModelConfigID)) } initialMessages = append(initialMessages, systemMessage(workspaceAwarenessContent, opts.ModelConfigID)) - initialMessages = append(initialMessages, userMessage(userContent, opts.ModelConfigID, opts.OwnerID)) + initialMessages = append(initialMessages, userMessageWithAPIKeyID(userContent, opts.ModelConfigID, opts.OwnerID, opts.APIKeyID)) result, err := chatstate.CreateChat(ctx, p.db, p.pubsub, chatstate.CreateChatInput{ OrganizationID: opts.OrganizationID, @@ -1751,7 +1711,7 @@ func (p *Server) SendMessage( // Queue capacity is enforced inside tx.SendMessage; this // wrapper only propagates the typed error. sendResult, err := tx.SendMessage(chatstate.SendMessageInput{ - Message: userMessage(content, modelConfigID, opts.CreatedBy), + Message: userMessageWithAPIKeyID(content, modelConfigID, opts.CreatedBy, opts.APIKeyID), BusyBehavior: busyBehaviorToChatState(busyBehavior), }) if err != nil { @@ -1847,28 +1807,6 @@ func resolveSendMessageModelConfigID( return requested, nil } -func resolveQueuedMessageModelConfigID( - ctx context.Context, - store database.Store, - chat database.Chat, - queuedModelConfigID uuid.NullUUID, -) (uuid.UUID, error) { - chatdCtx := chatdModelConfigLookupContext(ctx) - if queuedModelConfigID.Valid && queuedModelConfigID.UUID != uuid.Nil { - if _, err := store.GetChatModelConfigByID(chatdCtx, queuedModelConfigID.UUID); err == nil { - return queuedModelConfigID.UUID, nil - } else if !errors.Is(err, sql.ErrNoRows) { - return uuid.Nil, xerrors.Errorf( - "get queued model config %s: %w", - queuedModelConfigID.UUID, - err, - ) - } - } - - return resolveFallbackModelConfigID(ctx, store, 
chat.LastModelConfigID) -} - func resolveFallbackModelConfigID( ctx context.Context, store database.Store, @@ -2082,6 +2020,8 @@ func (p *Server) UnarchiveChat(ctx context.Context, chat database.Chat) error { // in chat's family through chatstate. The transaction-captured // family rows feed the post-commit debug cleanup and sidebar watch // events. Callers must only invoke this for root chats. +// +//nolint:revive // Existing API takes the target archive state as a boolean. func (p *Server) setChatFamilyArchived( ctx context.Context, chat database.Chat, @@ -2108,22 +2048,8 @@ func (p *Server) setChatFamilyArchived( return err } - // Archiving can race with an interrupted worker still flushing its - // final debug writes. Retry a few times so orphaned rows are - // removed quickly instead of waiting for the stale sweeper. - if archived && len(familyChats) > 0 { - archiveCutoff := familyChats[0].UpdatedAt.Add(-debugCleanupClockSkew) - for _, archivedChat := range familyChats { - p.scheduleDebugCleanup( - ctx, - "failed to delete chat debug rows after archive", - []slog.Field{slog.F("chat_id", archivedChat.ID)}, - func(cleanupCtx context.Context, debugSvc *chatdebug.Service) error { - _, err := debugSvc.DeleteByChatID(cleanupCtx, archivedChat.ID, archiveCutoff) - return err - }, - ) - } + if archived { + p.scheduleArchiveDebugCleanup(ctx, familyChats) } p.publishChatPubsubEvents(familyChats, watchKind) @@ -2465,6 +2391,39 @@ type manualTitleGenerationError struct { activeAPIKeyID string } +// generatedChatTitle carries the title produced by the detached +// automatic title-generation goroutine. maybeGenerateChatTitle stores +// the generated title here so tests can observe it without a database +// read; the title_change pubsub event it publishes remains the source of +// truth for clients. 
+type generatedChatTitle struct { + mu sync.RWMutex + title string +} + +func (t *generatedChatTitle) Store(title string) { + if t == nil || title == "" { + return + } + + t.mu.Lock() + t.title = title + t.mu.Unlock() +} + +func (t *generatedChatTitle) Load() (string, bool) { + if t == nil { + return "", false + } + + t.mu.RLock() + defer t.mu.RUnlock() + if t.title == "" { + return "", false + } + return t.title, true +} + func (e *manualTitleGenerationError) Error() string { return e.cause.Error() } @@ -3074,7 +3033,7 @@ func prepareChatTurnDebugRun( // Debug instrumentation must never block the user turn. Detach // from the chat-processing context and bound the insert so a slow // or locked DB makes debug logging degrade silently rather than - // stalling chatloop.Run. Matches the pattern used by + // stalling chat processing. Matches the pattern used by // prepareManualTitleDebugRun. createRunCtx, createRunCancel := context.WithTimeout( context.WithoutCancel(ctx), debugCreateRunTimeout, @@ -3386,71 +3345,6 @@ func recordManualTitleUsage( return updatedChat, nil } -// RefreshStatus loads the latest chat status and publishes it to stream subscribers. -func (p *Server) RefreshStatus(ctx context.Context, chatID uuid.UUID) error { - if chatID == uuid.Nil { - return xerrors.New("chat_id is required") - } - - chat, err := p.db.GetChatByID(ctx, chatID) - if err != nil { - return xerrors.Errorf("get chat: %w", err) - } - - p.publishStatus(chat.ID, chat.Status, chat.WorkerID) - return nil -} - -func (p *Server) setChatWaiting(ctx context.Context, chatID uuid.UUID) (database.Chat, error) { - var updatedChat database.Chat - err := p.db.InTx(func(tx database.Store) error { - locked, lockErr := tx.GetChatByIDForUpdate(ctx, chatID) - if lockErr != nil { - return xerrors.Errorf("lock chat for waiting: %w", lockErr) - } - // If the chat has already transitioned to pending (e.g. 
- // SendMessage with interrupt behavior), don't overwrite - // it — the pending status takes priority so the new - // message gets processed. - if locked.Status == database.ChatStatusPending { - updatedChat = locked - return nil - } - var updateErr error - updatedChat, updateErr = tx.UpdateChatStatus(ctx, database.UpdateChatStatusParams{ - ID: chatID, - Status: database.ChatStatusWaiting, - WorkerID: uuid.NullUUID{}, - StartedAt: sql.NullTime{}, - HeartbeatAt: sql.NullTime{}, - LastError: pqtype.NullRawMessage{}, - }) - return updateErr - }, nil) - if err != nil { - return database.Chat{}, err - } - p.publishStatus(chatID, updatedChat.Status, updatedChat.WorkerID) - p.publishChatPubsubEvent(updatedChat, codersdk.ChatWatchEventKindStatusChange, nil) - return updatedChat, nil -} - -func insertChatMessageWithStore( - ctx context.Context, - store database.Store, - params database.InsertChatMessagesParams, -) ([]database.ChatMessage, error) { - messages, err := store.InsertChatMessages(ctx, params) - if err != nil { - return nil, xerrors.Errorf("insert chat message: %w", err) - } - return messages, nil -} - -// chatMessage is the base message type for batch inserts. Use directly -// only for non-user messages; for user messages, use userChatMessage. -// For nullable UUID fields (ModelConfigID, CreatedBy), use uuid.Nil to -// represent NULL. For nullable int64 fields, use 0 to represent NULL. type chatMessage struct { role database.ChatMessageRole content pqtype.NullRawMessage @@ -3471,8 +3365,6 @@ type chatMessage struct { providerResponseID string } -// userChatMessage wraps chatMessage with a required apiKeyID so that -// omitting it for user messages is a compile error, not a silent data bug. type userChatMessage struct { chatMessage apiKeyID string @@ -3504,8 +3396,6 @@ func newChatMessage( } } -// newUserChatMessage creates a user message. apiKeyID is required so -// that forgetting it is a compile error rather than a silent data bug. 
func newUserChatMessage( apiKeyID string, content pqtype.NullRawMessage, @@ -3568,9 +3458,6 @@ func (m chatMessage) withProviderResponseID(id string) chatMessage { return m } -// appendMessageFields writes all chatMessage fields into the batch insert -// params. apiKeyID is explicit so non-user messages always get "" while -// user messages carry the caller's key for AI Gateway routing. func appendMessageFields( params *database.InsertChatMessagesParams, msg chatMessage, @@ -3596,27 +3483,45 @@ func appendMessageFields( params.ProviderResponseID = append(params.ProviderResponseID, msg.providerResponseID) } -// appendChatMessage appends a non-user message to the batch insert params. -func appendChatMessage( - params *database.InsertChatMessagesParams, - msg chatMessage, -) { +func appendChatMessage(params *database.InsertChatMessagesParams, msg chatMessage) { if msg.role == database.ChatMessageRoleUser { panic("developer error: use appendUserChatMessage for user-role messages") } appendMessageFields(params, msg, "") } -// appendUserChatMessage inserts a user message with its apiKeyID preserved. -func appendUserChatMessage( - params *database.InsertChatMessagesParams, - msg userChatMessage, -) { +func appendUserChatMessage(params *database.InsertChatMessagesParams, msg userChatMessage) { appendMessageFields(params, msg.chatMessage, msg.apiKeyID) } // BuildSingleUserChatMessageInsertParams creates batch insert params for // one user message, requiring an apiKeyID for AI Gateway attribution. +// BuildSingleChatMessageInsertParams creates batch insert params for one +// non-user message using the shared chat message builder. 
+func BuildSingleChatMessageInsertParams( + chatID uuid.UUID, + role database.ChatMessageRole, + content pqtype.NullRawMessage, + visibility database.ChatMessageVisibility, + modelConfigID uuid.UUID, + contentVersion int16, + createdBy uuid.UUID, +) database.InsertChatMessagesParams { + params := database.InsertChatMessagesParams{ //nolint:exhaustruct // Fields populated by appendChatMessage. + ChatID: chatID, + } + msg := newChatMessage(role, content, visibility, modelConfigID, contentVersion) + if createdBy != uuid.Nil { + msg = msg.withCreatedBy(createdBy) + } + if role == database.ChatMessageRoleUser { + appendMessageFields(¶ms, msg, "") + } else { + appendChatMessage(¶ms, msg) + } + return params +} + func BuildSingleUserChatMessageInsertParams( chatID uuid.UUID, apiKeyID string, @@ -3637,62 +3542,6 @@ func BuildSingleUserChatMessageInsertParams( return params } -// insertUserMessageAndSetPending inserts a user message, transitions the -// chat to pending when needed, and returns the refreshed chat row. -func insertUserMessageAndSetPending( - ctx context.Context, - store database.Store, - lockedChat database.Chat, - modelConfigID uuid.UUID, - content pqtype.NullRawMessage, - createdBy uuid.UUID, - apiKeyID string, -) (database.ChatMessage, database.Chat, error) { - msgParams := database.InsertChatMessagesParams{ //nolint:exhaustruct // Fields populated by appendUserChatMessage. 
- ChatID: lockedChat.ID, - } - insertUserMsg := newUserChatMessage( - apiKeyID, - content, - database.ChatMessageVisibilityBoth, - modelConfigID, - chatprompt.CurrentContentVersion, - ) - insertUserMsg = insertUserMsg.withCreatedBy(createdBy) - appendUserChatMessage(&msgParams, insertUserMsg) - messages, err := insertChatMessageWithStore(ctx, store, msgParams) - if err != nil { - return database.ChatMessage{}, database.Chat{}, err - } - message := messages[0] - - if lockedChat.Status == database.ChatStatusPending { - if modelConfigID == uuid.Nil || lockedChat.LastModelConfigID == modelConfigID { - return message, lockedChat, nil - } - // The InsertChatMessages CTE updates chats.last_model_config_id when - // the message's model config differs. Reload to surface that change. - updatedChat, err := store.GetChatByID(ctx, lockedChat.ID) - if err != nil { - return database.ChatMessage{}, database.Chat{}, xerrors.Errorf("get chat after model config update: %w", err) - } - return message, updatedChat, nil - } - - updatedChat, err := store.UpdateChatStatus(ctx, database.UpdateChatStatusParams{ - ID: lockedChat.ID, - Status: database.ChatStatusPending, - WorkerID: uuid.NullUUID{}, - StartedAt: sql.NullTime{}, - HeartbeatAt: sql.NullTime{}, - LastError: pqtype.NullRawMessage{}, - }) - if err != nil { - return database.ChatMessage{}, database.Chat{}, xerrors.Errorf("set chat pending: %w", err) - } - return message, updatedChat, nil -} - // Config configures a chat processor. type Config struct { Logger slog.Logger @@ -3727,6 +3576,9 @@ type Config struct { // May be nil if the deployment has no OIDC provider; servers // using user_oidc will then send no Authorization header. OIDCTokenSource mcpclient.UserOIDCTokenSource + + NotificationsEnqueuer notifications.Enqueuer + Auditor *atomic.Pointer[audit.Auditor] } // New creates a new chat processor. 
The processor polls for pending @@ -3765,6 +3617,11 @@ func New(cfg Config) *Server { } ps := cfg.Pubsub + notificationsEnqueuer := cfg.NotificationsEnqueuer + if notificationsEnqueuer == nil { + notificationsEnqueuer = notifications.NewNoopEnqueuer() + } + instructionLookupTimeout := cfg.InstructionLookupTimeout if instructionLookupTimeout == 0 { instructionLookupTimeout = homeInstructionLookupTimeout @@ -3821,15 +3678,42 @@ func New(cfg Config) *Server { usageTracker: cfg.UsageTracker, clock: clk, recordingSem: make(chan struct{}, maxConcurrentRecordingUploads), - wakeCh: make(chan struct{}, 1), - heartbeatRegistry: make(map[uuid.UUID]*heartbeatEntry), } + var chatAutoArchiveRecords prometheus.Counter if cfg.PrometheusRegistry != nil { p.metrics = chatloop.NewMetrics(cfg.PrometheusRegistry) cfg.PrometheusRegistry.MustRegister(&streamStateCollector{server: p}) + chatAutoArchiveRecords = prometheus.NewCounter(prometheus.CounterOpts{ + Namespace: "coderd", + Subsystem: "chat_auto_archive", + Name: "records_archived_total", + Help: "Total number of chats archived by the auto-archive job (counting both roots and cascaded children).", + }) + cfg.PrometheusRegistry.MustRegister(chatAutoArchiveRecords) } else { p.metrics = chatloop.NopMetrics() } + p.messagePartBuffer = messagepartbuffer.New(messagepartbuffer.Options{Clock: clk}) + chatWorker, err := newChatWorker(p, chatWorkerOptions{ + WorkerID: workerID, + Store: cfg.Database, + Pubsub: ps, + Logger: cfg.Logger.Named("chatworker"), + Clock: clk, + MessagePartBuffer: p.messagePartBuffer, + AcquisitionInterval: pendingChatAcquireInterval, + AcquisitionBatchSize: maxChatsPerAcquire, + HeartbeatInterval: chatHeartbeatInterval, + HeartbeatStaleSeconds: int32(inFlightChatStaleAfter.Seconds()), + NotificationsEnqueuer: notificationsEnqueuer, + Auditor: cfg.Auditor, + AutoArchiveRecords: chatAutoArchiveRecords, + }) + if err != nil { + panic("chatd: create chat worker: " + err.Error()) + } + p.chatWorker = chatWorker + 
//nolint:gocritic // The chat processor uses a scoped chatd context. ctx = dbauthz.AsChatd(ctx) @@ -3861,16 +3745,7 @@ func New(cfg Config) *Server { p.ctx = ctx - // Recover stale chats on startup. - p.recoverStaleChats(ctx) - if debugSvc := p.debugService(); debugSvc != nil { - if _, err := debugSvc.FinalizeStale(ctx); err != nil { - p.logger.Warn(ctx, "failed to finalize stale chat debug rows", slog.Error(err)) - } - } - // Spawn background goroutines that all servers need. - p.wg.Go(func() { p.heartbeatLoop(ctx) }) p.wg.Go(func() { p.streamJanitorLoop(ctx) }) return p @@ -3881,260 +3756,14 @@ func New(cfg Config) *Server { // server (e.g. tests) can skip this call; heartbeat, stream // janitor, and stale recovery still run. func (p *Server) Start() *Server { - p.wg.Go(func() { p.acquireLoop(p.ctx) }) + if p.chatWorker != nil { + if err := p.chatWorker.Start(p.ctx); err != nil { + p.logger.Error(p.ctx, "failed to start chat worker", slog.Error(err)) + } + } return p } -func (p *Server) acquireLoop(ctx context.Context) { - acquireTicker := p.clock.NewTicker( - p.pendingChatAcquireInterval, - "chatd", - "acquire", - ) - defer acquireTicker.Stop() - - staleRecoveryInterval := p.inFlightChatStaleAfter / staleRecoveryIntervalDivisor - staleTicker := p.clock.NewTicker( - staleRecoveryInterval, - "chatd", - "stale-recovery", - ) - defer staleTicker.Stop() - - for { - select { - case <-ctx.Done(): - return - case <-acquireTicker.C: - p.processOnce(ctx) - case <-p.wakeCh: - p.processOnce(ctx) - case <-staleTicker.C: - p.recoverStaleChats(ctx) - if debugSvc := p.existingDebugService(); debugSvc != nil { - if _, err := debugSvc.FinalizeStale(ctx); err != nil { - p.logger.Warn(ctx, "failed to finalize stale chat debug rows", slog.Error(err)) - } - } - } - } -} - -// signalWake wakes the run loop so it calls processOnce immediately. -// Non-blocking: if a signal is already pending it is a no-op. 
-func (p *Server) signalWake() { - select { - case p.wakeCh <- struct{}{}: - default: - } -} - -func (p *Server) processOnce(ctx context.Context) { - if ctx.Err() != nil { - return - } - - // We detach from the server lifetime to prevent a - // phantom-acquire race: when the server context is - // canceled, the pq driver's watchCancel goroutine - // races with the actual query on the wire. Using a - // context that cannot be canceled ensures the driver - // sees the query result if Postgres executed it. - acquireCtx, acquireCancel := context.WithTimeout( - context.WithoutCancel(ctx), 10*time.Second, - ) - chats, err := p.db.AcquireChats(acquireCtx, database.AcquireChatsParams{ - StartedAt: time.Now(), - WorkerID: p.workerID, - NumChats: p.maxChatsPerAcquire, - }) - acquireCancel() - if err != nil { - p.logger.Error(ctx, "failed to acquire chats", slog.Error(err)) - return - } - if len(chats) == 0 { - return - } - - // If the server context was canceled while we were - // acquiring, release the chats back to pending. 
- if ctx.Err() != nil { - releaseCtx, releaseCancel := context.WithTimeout( - context.WithoutCancel(ctx), 10*time.Second, - ) - for _, chat := range chats { - _, updateErr := p.db.UpdateChatStatus(releaseCtx, database.UpdateChatStatusParams{ - ID: chat.ID, - Status: database.ChatStatusPending, - WorkerID: uuid.NullUUID{}, - StartedAt: sql.NullTime{}, - HeartbeatAt: sql.NullTime{}, - LastError: pqtype.NullRawMessage{}, - }) - if updateErr != nil { - p.logger.Error(ctx, "failed to release chat acquired during shutdown", - slog.F("chat_id", chat.ID), slog.Error(updateErr)) - } - } - releaseCancel() - return - } - - p.inflightMu.Lock() - for _, chat := range chats { - p.inflight.Add(1) - go func() { - defer p.inflight.Done() - p.processChat(ctx, chat) - }() - } - p.inflightMu.Unlock() -} - -func shouldClearRetryPhaseForStatus(status codersdk.ChatStatus) bool { - switch status { - case codersdk.ChatStatusWaiting, - codersdk.ChatStatusPending, - codersdk.ChatStatusPaused, - codersdk.ChatStatusCompleted, - codersdk.ChatStatusError, - codersdk.ChatStatusRequiresAction: - return true - default: - return false - } -} - -func (p *Server) clearProvisionalStreamParts(chatID uuid.UUID) { - val, ok := p.chatStreams.Load(chatID) - if !ok { - return - } - rs, ok := val.(*chatStreamState) - if !ok { - return - } - - // Streamed parts are provisional until a durable message commits - // them. A retry rolls back the failed attempt before replacement - // parts are streamed. 
- rs.mu.Lock() - rs.buffer = nil - rs.resetDropCounters() - rs.mu.Unlock() -} - -func (p *Server) publishToStream(chatID uuid.UUID, event codersdk.ChatStreamEvent) { - state := p.getOrCreateStreamState(chatID) - state.mu.Lock() - switch event.Type { - case codersdk.ChatStreamEventTypeRetry: - if event.Retry != nil { - retryCopy := *event.Retry - state.currentRetry = &retryCopy - } - case codersdk.ChatStreamEventTypeMessagePart: - // Any streamed part means the provider is making forward - // progress again, so the stream has left the retry backoff - // window regardless of role. - state.currentRetry = nil - case codersdk.ChatStreamEventTypeError: - state.currentRetry = nil - case codersdk.ChatStreamEventTypeStatus: - if event.Status != nil && shouldClearRetryPhaseForStatus(event.Status.Status) { - state.currentRetry = nil - } - } - if event.Type == codersdk.ChatStreamEventTypeMessagePart { - if !state.buffering { - p.cleanupStreamIfIdle(chatID, state) - state.mu.Unlock() - return - } - if len(state.buffer) >= maxStreamBufferSize { - p.metrics.RecordStreamBufferDropped() - state.bufferDropCount++ - now := p.clock.Now() - if now.Sub(state.bufferLastWarnAt) >= streamDropWarnInterval { - p.logger.Warn(context.Background(), "chat stream buffer full, dropping oldest event", - slog.F("chat_id", chatID), - slog.F("buffer_size", len(state.buffer)), - slog.F("dropped_count", state.bufferDropCount), - ) - state.bufferDropCount = 0 - state.bufferLastWarnAt = now - } - // Zero the dropped slot so its *ChatStreamMessagePart is - // GC-eligible; the later append reuses this slot in place - // whenever cap > len. - state.buffer[0] = bufferedStreamPart{} - state.buffer = state.buffer[1:] - } - state.buffer = append(state.buffer, bufferedStreamPart{ - event: event, - // committedMessageID stays 0 here: the part belongs to - // the in-progress turn until publishMessage claims it - // with the committed assistant message ID. 
- }) - } - subscribers := make([]chan codersdk.ChatStreamEvent, 0, len(state.subscribers)) - for _, ch := range state.subscribers { - subscribers = append(subscribers, ch) - } - state.mu.Unlock() - - var subDropped int64 - for _, ch := range subscribers { - select { - case ch <- event: - default: - subDropped++ - } - } - - // Re-acquire the lock once for both subscriber-drop logging and - // idle cleanup. Merging these avoids an unnecessary unlock/re-lock - // gap between the two sections. - state.mu.Lock() - if subDropped > 0 { - state.subscriberDropCount += subDropped - now := p.clock.Now() - if now.Sub(state.subscriberLastWarnAt) >= streamDropWarnInterval { - p.logger.Warn(context.Background(), "dropping chat stream event", - slog.F("chat_id", chatID), - slog.F("type", event.Type), - slog.F("dropped_count", state.subscriberDropCount), - ) - state.subscriberDropCount = 0 - state.subscriberLastWarnAt = now - } - } - p.cleanupStreamIfIdle(chatID, state) - state.mu.Unlock() -} - -// cacheDurableMessage stores a recently persisted message event in the -// per-chat stream state so that same-replica subscribers can catch up -// from memory instead of the database. The afterMessageID is the -// message ID that precedes this message (i.e. message.ID - 1). -func (p *Server) cacheDurableMessage(chatID uuid.UUID, event codersdk.ChatStreamEvent) { - state := p.getOrCreateStreamState(chatID) - state.mu.Lock() - defer state.mu.Unlock() - - if len(state.durableMessages) >= maxDurableMessageCacheSize { - if evicted := state.durableMessages[0]; evicted.Message != nil { - state.durableEvictedBefore = evicted.Message.ID - } - // Zero the dropped slot so the evicted *ChatMessage is - // GC-eligible; see publishToStream for the same pattern. 
- state.durableMessages[0] = codersdk.ChatStreamEvent{} - state.durableMessages = state.durableMessages[1:] - } - state.durableMessages = append(state.durableMessages, event) -} - // getCachedDurableMessages returns cached durable messages with IDs // greater than afterID. Returns nil when the cache has no relevant // entries. @@ -4344,97 +3973,6 @@ func (p *Server) sweepIdleStreams() { }) } -// registerHeartbeat enrolls a chat in the centralized batch -// heartbeat loop. Must be called after chatCtx is created. -func (p *Server) registerHeartbeat(entry *heartbeatEntry) { - p.heartbeatMu.Lock() - defer p.heartbeatMu.Unlock() - if _, exists := p.heartbeatRegistry[entry.chatID]; exists { - p.logger.Warn(context.Background(), - "duplicate heartbeat registration, skipping", - slog.F("chat_id", entry.chatID)) - return - } - p.heartbeatRegistry[entry.chatID] = entry -} - -// unregisterHeartbeat removes a chat from the centralized -// heartbeat loop when chat processing finishes. -func (p *Server) unregisterHeartbeat(chatID uuid.UUID) { - p.heartbeatMu.Lock() - defer p.heartbeatMu.Unlock() - delete(p.heartbeatRegistry, chatID) -} - -// heartbeatLoop runs in a single goroutine, issuing one batch -// heartbeat query per interval for all registered chats. -func (p *Server) heartbeatLoop(ctx context.Context) { - ticker := p.clock.NewTicker(p.chatHeartbeatInterval, "chatd", "batch-heartbeat") - defer ticker.Stop() - for { - select { - case <-ctx.Done(): - return - case <-ticker.C: - p.heartbeatTick(ctx) - } - } -} - -// heartbeatTick issues a single batch UPDATE for all running chats -// owned by this worker. Chats missing from the result set are -// interrupted (stolen by another replica or already completed). -func (p *Server) heartbeatTick(ctx context.Context) { - // Snapshot the registry under the lock. 
- p.heartbeatMu.Lock() - snapshot := maps.Clone(p.heartbeatRegistry) - p.heartbeatMu.Unlock() - - if len(snapshot) == 0 { - return - } - - // Collect the IDs we believe we own. - ids := slices.Collect(maps.Keys(snapshot)) - - //nolint:gocritic // AsChatd provides narrowly-scoped daemon - // access for batch-updating heartbeats. - chatdCtx := dbauthz.AsChatd(ctx) - updatedIDs, err := p.db.UpdateChatHeartbeats(chatdCtx, database.UpdateChatHeartbeatsParams{ - IDs: ids, - WorkerID: p.workerID, - Now: p.clock.Now(), - }) - if err != nil { - p.logger.Error(ctx, "batch heartbeat failed", slog.Error(err)) - return - } - - // Build a set of IDs that were successfully updated. - updated := make(map[uuid.UUID]struct{}, len(updatedIDs)) - for _, id := range updatedIDs { - updated[id] = struct{}{} - } - - // Interrupt registered chats that were not in the result - // (stolen by another replica or already completed). - for id, entry := range snapshot { - if _, ok := updated[id]; !ok { - entry.logger.Warn(ctx, "chat not in batch heartbeat result, interrupting") - entry.cancelWithCause(chatloop.ErrInterrupted) - continue - } - // Bump workspace usage for surviving chats. - newWsID := p.trackWorkspaceUsage(ctx, entry.chatID, entry.workspaceID, entry.logger) - // Update workspace ID in the registry for next tick. - p.heartbeatMu.Lock() - if current, exists := p.heartbeatRegistry[id]; exists { - current.workspaceID = newWsID - } - p.heartbeatMu.Unlock() - } -} - // streamSubscriberControlFetchContext keeps a control-path lookup tied to the // requesting subscriber while applying a fallback timeout when the caller has // no deadline. @@ -4530,7 +4068,7 @@ func (p *Server) SubscribeAuthorized( // This MUST happen before the DB queries below so that any // notification published between the query and the subscription // is not lost (subscribe-first-then-query pattern). 
- notifications := make(chan coderdpubsub.ChatStreamNotifyMessage, 10) + notifyCh := make(chan coderdpubsub.ChatStreamNotifyMessage, 10) errCh := make(chan error, 1) listener := func(_ context.Context, message []byte, listenErr error) { if listenErr != nil { @@ -4550,7 +4088,7 @@ func (p *Server) SubscribeAuthorized( } select { case <-mergedCtx.Done(): - case notifications <- notify: + case notifyCh <- notify: } } @@ -4747,7 +4285,7 @@ func (p *Server) SubscribeAuthorized( case <-mergedCtx.Done(): } return - case notify := <-notifications: + case notify := <-notifyCh: // Marker for ENG-2645: subscriber received pubsub notify. p.logger.Debug(mergedCtx, "stream subscriber received notify", slog.F("chat_id", chatID), @@ -4989,47 +4527,6 @@ func (p *Server) SubscribeAuthorized( return initialSnapshot, mergedEvents, cancel, true } -func (p *Server) publishEvent(chatID uuid.UUID, event codersdk.ChatStreamEvent) { - if event.ChatID == uuid.Nil { - event.ChatID = chatID - } - p.publishToStream(chatID, event) -} - -func (p *Server) publishStatus(chatID uuid.UUID, status database.ChatStatus, workerID uuid.NullUUID) { - p.publishEvent(chatID, codersdk.ChatStreamEvent{ - Type: codersdk.ChatStreamEventTypeStatus, - Status: &codersdk.ChatStreamStatus{Status: codersdk.ChatStatus(status)}, - }) - notify := coderdpubsub.ChatStreamNotifyMessage{ - Status: string(status), - } - if workerID.Valid { - notify.WorkerID = workerID.UUID.String() - } - p.publishChatStreamNotify(chatID, notify) -} - -// publishChatStreamNotify broadcasts a per-chat stream notification via -// PostgreSQL pubsub so that all replicas can merge durable database updates -// with transient control events. 
-func (p *Server) publishChatStreamNotify(chatID uuid.UUID, notify coderdpubsub.ChatStreamNotifyMessage) { - payload, err := json.Marshal(notify) - if err != nil { - p.logger.Error(context.Background(), "failed to marshal chat stream notify", - slog.F("chat_id", chatID), - slog.Error(err), - ) - return - } - if err := p.pubsub.Publish(coderdpubsub.ChatStreamNotifyChannel(chatID), payload); err != nil { - p.logger.Error(context.Background(), "failed to publish chat stream notify", - slog.F("chat_id", chatID), - slog.Error(err), - ) - } -} - // publishChatPubsubEvents broadcasts a lifecycle event for each affected chat. func (p *Server) publishChatPubsubEvents(chats []database.Chat, kind codersdk.ChatWatchEventKind) { for _, chat := range chats { @@ -5040,6 +4537,9 @@ func (p *Server) publishChatPubsubEvents(chats []database.Chat, kind codersdk.Ch // publishChatPubsubEvent broadcasts a chat lifecycle event via PostgreSQL // pubsub so that all replicas can push updates to watching clients. func (p *Server) publishChatPubsubEvent(chat database.Chat, kind codersdk.ChatWatchEventKind, diffStatus *codersdk.ChatDiffStatus) { + if p.pubsub == nil { + return + } // diffStatus is applied below. File metadata is intentionally // omitted from pubsub events to avoid an extra DB query per // publish. Clients must merge pubsub updates, not replace @@ -5069,48 +4569,6 @@ func (p *Server) publishChatPubsubEvent(chat database.Chat, kind codersdk.ChatWa } } -// pendingToStreamToolCalls converts a slice of chatloop pending -// tool calls into the SDK streaming representation. 
-func pendingToStreamToolCalls(pending []chatloop.PendingToolCall) []codersdk.ChatStreamToolCall { - calls := make([]codersdk.ChatStreamToolCall, len(pending)) - for i, tc := range pending { - calls[i] = codersdk.ChatStreamToolCall{ - ToolCallID: tc.ToolCallID, - ToolName: tc.ToolName, - Args: tc.Args, - } - } - return calls -} - -// publishChatActionRequired broadcasts an action_required event via -// PostgreSQL pubsub so that global watchers can react to dynamic -// tool calls without streaming each chat individually. -func (p *Server) publishChatActionRequired(chat database.Chat, pending []chatloop.PendingToolCall) { - toolCalls := pendingToStreamToolCalls(pending) - sdkChat := db2sdk.Chat(chat, nil, nil) - - event := codersdk.ChatWatchEvent{ - Kind: codersdk.ChatWatchEventKindActionRequired, - Chat: sdkChat, - ToolCalls: toolCalls, - } - payload, err := json.Marshal(event) - if err != nil { - p.logger.Error(context.Background(), "failed to marshal chat action_required pubsub event", - slog.F("chat_id", chat.ID), - slog.Error(err), - ) - return - } - if err := p.pubsub.Publish(coderdpubsub.ChatWatchEventChannel(chat.OwnerID), payload); err != nil { - p.logger.Error(context.Background(), "failed to publish chat action_required pubsub event", - slog.F("chat_id", chat.ID), - slog.Error(err), - ) - } -} - // PublishDiffStatusChange broadcasts a diff_status_change event for // the given chat so that watching clients know to re-fetch the diff // status. 
This is called from the HTTP layer after the diff status @@ -5131,222 +4589,6 @@ func (p *Server) PublishDiffStatusChange(ctx context.Context, chatID uuid.UUID) return nil } -func (p *Server) publishRetry(chatID uuid.UUID, payload *codersdk.ChatStreamRetry) { - if payload == nil { - return - } - p.publishEvent(chatID, codersdk.ChatStreamEvent{ - Type: codersdk.ChatStreamEventTypeRetry, - Retry: payload, - }) - p.publishChatStreamNotify(chatID, coderdpubsub.ChatStreamNotifyMessage{ - Retry: payload, - }) -} - -func (p *Server) publishError(chatID uuid.UUID, classified chaterror.ClassifiedError) { - payload := chaterror.TerminalErrorPayload(classified) - if payload == nil { - return - } - p.publishEvent(chatID, codersdk.ChatStreamEvent{ - Type: codersdk.ChatStreamEventTypeError, - Error: payload, - }) - p.publishChatStreamNotify(chatID, coderdpubsub.ChatStreamNotifyMessage{ - ErrorPayload: payload, - Error: payload.Message, - }) -} - -func processingFailure(err error) (chaterror.ClassifiedError, bool) { - if err == nil { - return chaterror.ClassifiedError{}, false - } - - classified := chaterror.Classify(err) - if classified.Message == "" { - return chaterror.ClassifiedError{}, false - } - return classified, true -} - -func encodeChatLastErrorPayload(payload *codersdk.ChatError) (pqtype.NullRawMessage, error) { - if payload == nil { - return pqtype.NullRawMessage{}, nil - } - encoded, err := json.Marshal(payload) - if err != nil { - return pqtype.NullRawMessage{}, err - } - return pqtype.NullRawMessage{RawMessage: encoded, Valid: true}, nil -} - -func panicFailureReason(recovered any) string { - var reason string - switch typed := recovered.(type) { - case string: - reason = strings.TrimSpace(typed) - case error: - reason = strings.TrimSpace(typed.Error()) - default: - reason = strings.TrimSpace(fmt.Sprint(typed)) - } - - if reason == "" || reason == "" { - return "chat processing panicked" - } - return "chat processing panicked: " + reason -} - -func (p *Server) 
publishMessage(chatID uuid.UUID, message database.ChatMessage) { - sdkMessage := db2sdk.ChatMessage(message) - event := codersdk.ChatStreamEvent{ - Type: codersdk.ChatStreamEventTypeMessage, - ChatID: chatID, - Message: &sdkMessage, - } - p.cacheDurableMessage(chatID, event) - // Claim every still-in-progress buffered message_part for this - // durable assistant message BEFORE publishing it, so any new - // subscriber that races publishEvent below takes a buffer - // snapshot in which the parts for this turn are already - // suppressed. Existing subscribers already received the - // constituent parts on the live channel; the frontend - // dedupes those against the durable message via - // clearStreamState in the same batch. - p.claimCommittedParts(chatID, message) - p.publishEvent(chatID, event) - p.publishChatStreamNotify(chatID, coderdpubsub.ChatStreamNotifyMessage{ - AfterMessageID: message.ID - 1, - }) -} - -// claimCommittedParts walks the chat's buffered message_part events -// and assigns every in-progress part (committedMessageID == 0) to -// the supplied assistant message ID. Subsequent subscriber snapshots -// drop those parts so a reconnecting client does not re-render the -// content of an assistant turn that has already been delivered as a -// durable message via REST or pubsub. -// -// Tool and user messages do not end an assistant streaming turn, so -// only assistant-role messages claim parts. 
-func (p *Server) claimCommittedParts(chatID uuid.UUID, message database.ChatMessage) { - if message.Role != database.ChatMessageRoleAssistant { - return - } - val, ok := p.chatStreams.Load(chatID) - if !ok { - return - } - state, ok := val.(*chatStreamState) - if !ok { - return - } - state.mu.Lock() - defer state.mu.Unlock() - for i := range state.buffer { - if state.buffer[i].committedMessageID == 0 { - state.buffer[i].committedMessageID = message.ID - } - } -} - -// publishEditedMessage is like publishMessage but uses FullRefresh -// so remote subscribers re-fetch from the beginning, ensuring the -// edit is never silently dropped. The durable cache is replaced -// with only the edited message. -func (p *Server) publishEditedMessage(chatID uuid.UUID, message database.ChatMessage) { - sdkMessage := db2sdk.ChatMessage(message) - event := codersdk.ChatStreamEvent{ - Type: codersdk.ChatStreamEventTypeMessage, - ChatID: chatID, - Message: &sdkMessage, - } - state := p.getOrCreateStreamState(chatID) - state.mu.Lock() - state.durableMessages = []codersdk.ChatStreamEvent{event} - state.durableEvictedBefore = 0 - state.mu.Unlock() - p.publishEvent(chatID, event) - p.publishChatStreamNotify(chatID, coderdpubsub.ChatStreamNotifyMessage{ - FullRefresh: true, - }) -} - -func (p *Server) publishMessagePart(chatID uuid.UUID, role codersdk.ChatMessageRole, part codersdk.ChatMessagePart) { - if part.Type == "" { - return - } - // Strip internal-only fields before client delivery. - // Mirrors db2sdk.chatMessageParts stripping for REST. 
- part.StripInternal() - p.publishEvent(chatID, codersdk.ChatStreamEvent{ - Type: codersdk.ChatStreamEventTypeMessagePart, - MessagePart: &codersdk.ChatStreamMessagePart{ - Role: role, - Part: part, - }, - }) -} - -func shouldCancelChatFromControlNotification( - notify coderdpubsub.ChatStreamNotifyMessage, - workerID uuid.UUID, -) bool { - status := database.ChatStatus(strings.TrimSpace(notify.Status)) - switch status { - case database.ChatStatusWaiting, database.ChatStatusPending, database.ChatStatusError: - return true - case database.ChatStatusRunning: - worker := strings.TrimSpace(notify.WorkerID) - if worker == "" { - return false - } - notifyWorkerID, err := uuid.Parse(worker) - if err != nil { - return false - } - return notifyWorkerID != workerID - default: - return false - } -} - -func (p *Server) subscribeChatControl( - ctx context.Context, - chatID uuid.UUID, - cancel context.CancelCauseFunc, - logger slog.Logger, -) func() { - listener := func(_ context.Context, message []byte, err error) { - if err != nil { - logger.Warn(ctx, "chat control pubsub error", slog.Error(err)) - return - } - - var notify coderdpubsub.ChatStreamNotifyMessage - if unmarshalErr := json.Unmarshal(message, ¬ify); unmarshalErr != nil { - logger.Warn(ctx, "failed to unmarshal chat control notify", slog.Error(unmarshalErr)) - return - } - - if shouldCancelChatFromControlNotification(notify, p.workerID) { - cancel(chatloop.ErrInterrupted) - } - } - - controlCancel, err := p.pubsub.SubscribeWithErr( - coderdpubsub.ChatStreamNotifyChannel(chatID), - listener, - ) - if err != nil { - logger.Warn(ctx, "failed to subscribe to chat control notifications", slog.Error(err)) - return nil - } - return controlCancel -} - // Rejects oversize images on capped providers before any upstream // request is issued. 
// @@ -5401,74 +4643,6 @@ func (p *Server) chatFileResolver(provider string) chatprompt.FileResolver { } } -// tryAutoPromoteQueuedMessage pops the next queued message and converts it -// into a pending user message inside the caller's transaction. Queued -// messages were already admitted through SendMessage, so this preserves FIFO -// order without re-checking usage limits. -func (p *Server) tryAutoPromoteQueuedMessage( - ctx context.Context, - tx database.Store, - chat database.Chat, -) (*database.ChatMessage, []database.ChatQueuedMessage, bool, error) { - logger := p.logger.With(slog.F("chat_id", chat.ID)) - - queuedMessages, err := tx.GetChatQueuedMessages(ctx, chat.ID) - if err != nil { - return nil, nil, false, xerrors.Errorf("get queued messages: %w", err) - } - if len(queuedMessages) == 0 { - return nil, nil, false, nil - } - nextQueued := queuedMessages[0] - effectiveModelConfigID, err := resolveQueuedMessageModelConfigID( - ctx, - tx, - chat, - nextQueued.ModelConfigID, - ) - if err != nil { - return nil, nil, false, err - } - - poppedQueued, err := tx.PopNextQueuedMessage(ctx, chat.ID) - if err != nil { - return nil, nil, false, xerrors.Errorf("pop next queued message: %w", err) - } - if poppedQueued.ID != nextQueued.ID { - return nil, nil, false, xerrors.New("popped queued message out of order") - } - - msgParams := database.InsertChatMessagesParams{ //nolint:exhaustruct // Fields populated by appendUserChatMessage. 
- ChatID: chat.ID, - } - queuedUserMsg := newUserChatMessage( - nextQueued.APIKeyID.String, - pqtype.NullRawMessage{ - RawMessage: nextQueued.Content, - Valid: len(nextQueued.Content) > 0, - }, - database.ChatMessageVisibilityBoth, - effectiveModelConfigID, - chatprompt.CurrentContentVersion, - ) - queuedUserMsg = queuedUserMsg.withCreatedBy(chat.OwnerID) - appendUserChatMessage(&msgParams, queuedUserMsg) - msgs, err := insertChatMessageWithStore(ctx, tx, msgParams) - if err != nil { - return nil, nil, false, xerrors.Errorf("insert promoted message: %w", err) - } - msg := msgs[0] - - remainingQueuedMessages, err := tx.GetChatQueuedMessages(ctx, chat.ID) - if err != nil { - logger.Error(ctx, "failed to load remaining queued messages after auto-promotion", - slog.F("queued_message_id", nextQueued.ID), slog.Error(err)) - return &msg, nil, false, nil - } - - return &msg, remainingQueuedMessages, true, nil -} - // trackWorkspaceUsage bumps the workspace's last_used_at via the // usage tracker and extends the workspace's autostop deadline. If // wsID is not yet valid, it re-reads the chat from the DB to pick @@ -5515,467 +4689,16 @@ func (p *Server) trackWorkspaceUsage( return wsID } -type finishActiveChatResult struct { - updatedChat database.Chat - promotedMessage *database.ChatMessage - syntheticToolResults []database.ChatMessage - remainingQueuedMessages []database.ChatQueuedMessage - shouldPublishQueueUpdate bool -} - -func (p *Server) finishActiveChat( - ctx context.Context, - logger slog.Logger, - chat database.Chat, - status database.ChatStatus, - lastError pqtype.NullRawMessage, -) (finishActiveChatResult, error) { - result := finishActiveChatResult{} - - err := p.db.InTx(func(tx database.Store) error { - // Re-read the chat status under lock — another caller - // (e.g. promote) may have already set it to pending. 
- latestChat, lockErr := tx.GetChatByIDForUpdate(ctx, chat.ID) - if lockErr != nil { - return xerrors.Errorf("lock chat for release: %w", lockErr) - } - - // If another worker has already acquired this chat, - // bail out — we must not overwrite their running - // status or publish spurious events. - if latestChat.Status == database.ChatStatusRunning && - latestChat.WorkerID.Valid && - latestChat.WorkerID.UUID != p.workerID { - return errChatTakenByOtherWorker - } - - // If someone else already set the chat to pending (e.g. - // the promote endpoint), don't overwrite it — just clear - // the worker and let the processor pick it back up. - switch { - case latestChat.Status == database.ChatStatusPending: - status = database.ChatStatusPending - case latestChat.Status == database.ChatStatusWaiting && status != database.ChatStatusWaiting && !latestChat.Archived: - // PromoteQueued's deferred path won the status race. - // Insert synthetic tool results before auto-promoting, - // or a RequiresAction worker outcome reintroduces the - // stops-dead bug this PR exists to fix. - inserted, synthErr := insertSyntheticToolResultsTx( - ctx, tx, latestChat, - "Tool execution interrupted by queued message promotion", - ) - if synthErr != nil { - return xerrors.Errorf("insert synthetic tool results during promote-driven cleanup: %w", synthErr) - } - result.syntheticToolResults = inserted - var promoteErr error - result.promotedMessage, result.remainingQueuedMessages, result.shouldPublishQueueUpdate, promoteErr = p.tryAutoPromoteQueuedMessage(ctx, tx, latestChat) - if promoteErr != nil { - logger.Error(ctx, "auto-promote queued message failed during promote-driven cleanup", slog.Error(promoteErr)) - return xerrors.Errorf("auto-promote queued message: %w", promoteErr) - } - if result.promotedMessage != nil { - status = database.ChatStatusPending - } else { - // Queue drained between snapshot and lock; honor - // the external Waiting. 
- status = database.ChatStatusWaiting - } - case status == database.ChatStatusWaiting && !latestChat.Archived: - // Queued messages were already admitted through SendMessage, - // so auto-promotion only preserves FIFO order here. Archived - // chats skip promotion so archiving behaves like a hard stop. - var promoteErr error - result.promotedMessage, result.remainingQueuedMessages, result.shouldPublishQueueUpdate, promoteErr = p.tryAutoPromoteQueuedMessage(ctx, tx, latestChat) - if promoteErr != nil { - logger.Error(ctx, "auto-promote queued message failed, rolling back", slog.Error(promoteErr)) - return xerrors.Errorf("auto-promote queued message: %w", promoteErr) - } else if result.promotedMessage != nil { - status = database.ChatStatusPending - } - } - - var updateErr error - result.updatedChat, updateErr = tx.UpdateChatStatus(ctx, database.UpdateChatStatusParams{ - ID: chat.ID, - Status: status, - WorkerID: uuid.NullUUID{}, - StartedAt: sql.NullTime{}, - HeartbeatAt: sql.NullTime{}, - LastError: lastError, - }) - return updateErr - }, nil) - if err != nil { - return finishActiveChatResult{}, err - } - - return result, nil -} - -func (p *Server) shouldPublishFinishedChatState( - ctx context.Context, - logger slog.Logger, - updatedChat database.Chat, -) bool { - latestChat, err := p.db.GetChatByID(ctx, updatedChat.ID) - if err != nil { - logger.Warn(ctx, "failed to re-read chat before publishing finished state", - slog.F("chat_id", updatedChat.ID), - slog.Error(err), - ) - return true - } - - if latestChat.Status != updatedChat.Status || latestChat.WorkerID != updatedChat.WorkerID { - logger.Debug(ctx, "skipping stale finished chat publish", - slog.F("chat_id", updatedChat.ID), - slog.F("expected_status", updatedChat.Status), - slog.F("expected_worker_id", updatedChat.WorkerID), - slog.F("latest_status", latestChat.Status), - slog.F("latest_worker_id", latestChat.WorkerID), - ) - return false - } - - return true -} - -func (p *Server) processChat(ctx 
context.Context, chat database.Chat) { - logger := p.logger.With(slog.F("chat_id", chat.ID)) - logger.Info(ctx, "processing chat request") - - p.metrics.Chats.WithLabelValues(chatloop.StateWaiting).Inc() - defer p.metrics.Chats.WithLabelValues(chatloop.StateWaiting).Dec() - - chatCtx, cancel := context.WithCancelCause(ctx) - defer cancel(nil) - - // Gate the control subscriber behind a channel that is closed - // after we publish "running" status. This prevents stale - // pubsub notifications (e.g. the "pending" notification from - // SendMessage that triggered this processing) from - // interrupting us before we start work. Due to async - // PostgreSQL NOTIFY delivery, a notification published before - // subscribeChatControl registers its queue can still arrive - // after registration. - controlArmed := make(chan struct{}) - gatedCancel := func(cause error) { - select { - case <-controlArmed: - cancel(cause) - default: - logger.Debug(ctx, "ignoring control notification before armed") - } - } - - controlCancel := p.subscribeChatControl(chatCtx, chat.ID, gatedCancel, logger) - defer func() { - if controlCancel != nil { - controlCancel() - } - }() - - // Register with the centralized heartbeat loop instead of - // running a per-chat goroutine. The loop issues a single batch - // UPDATE for all chats on this worker and detects stolen chats - // via set-difference. - p.registerHeartbeat(&heartbeatEntry{ - cancelWithCause: cancel, - chatID: chat.ID, - workspaceID: chat.WorkspaceID, - logger: logger, - }) - defer p.unregisterHeartbeat(chat.ID) - - // Start buffering stream events BEFORE publishing the running - // status. This closes a race where a subscriber sees - // status=running but misses message_part events because - // buffering hasn't started yet — the subscriber gets an empty - // snapshot and publishToStream drops message_parts while - // buffering is false. 
- streamState := p.getOrCreateStreamState(chat.ID) - streamState.mu.Lock() - streamState.buffer = nil - streamState.bufferRetainedAt = time.Time{} - streamState.resetDropCounters() - streamState.buffering = true - streamState.mu.Unlock() - defer func() { - streamState.mu.Lock() - // Fallback cleanup for exit paths that return before a - // terminal stream event is published. - streamState.currentRetry = nil - streamState.resetDropCounters() - streamState.buffering = false - // Retain the per-chat stream state for a grace period - // so cross-replica relay subscribers can register - // against this chat after processing completes, - // without racing cleanupStreamIfIdle. The buffer is - // cleared when the next processChat starts or when - // cleanupStreamIfIdle runs after the grace period; on - // the normal-completion path every part has been - // claimed by its durable assistant message, so the - // snapshot is empty. On error or panic exit some parts - // may still be in-progress; those are likewise - // discarded when the buffer is cleared, and the - // frontend recovers via the next REST snapshot. - streamState.bufferRetainedAt = p.clock.Now() - streamState.mu.Unlock() - }() - - p.publishStatus(chat.ID, database.ChatStatusRunning, uuid.NullUUID{ - UUID: p.workerID, - Valid: true, - }) - - // Arm the control subscriber. Closing the channel is a - // happens-before guarantee in the Go memory model — any - // notification dispatched after this point will correctly - // interrupt processing. - close(controlArmed) - - // Determine the final status and last error payload to set when we're done. 
- status := database.ChatStatusWaiting - wasInterrupted := false - var lastErrorPayload *codersdk.ChatError - generatedTitle := &generatedChatTitle{} - runResult := runChatResult{} - remainingQueuedMessages := []database.ChatQueuedMessage{} - shouldPublishQueueUpdate := false - var promotedMessage *database.ChatMessage - - defer func() { - // Use a context that is not canceled by Close() so we can - // reliably update the chat status in the database during - // graceful shutdown. - cleanupCtx := context.WithoutCancel(ctx) - - // Handle panics gracefully. - if r := recover(); r != nil { - logger.Error(cleanupCtx, "panic during chat processing", slog.F("panic", r)) - classified := chaterror.ClassifiedError{ - Message: panicFailureReason(r), - Kind: codersdk.ChatErrorKindGeneric, - } - lastErrorPayload = chaterror.TerminalErrorPayload(classified) - p.publishError(chat.ID, classified) - status = database.ChatStatusError - } - - encodedLastError, err := encodeChatLastErrorPayload(lastErrorPayload) - if err != nil { - logger.Warn(cleanupCtx, "failed to marshal chat last error payload", - slog.Error(err), - ) - lastErrorPayload = nil - encodedLastError = pqtype.NullRawMessage{} - } - - // Check for queued messages and auto-promote the next one. - // This must be done atomically with the status update to avoid - // races with the promote endpoint (which also sets status to - // pending). We use a transaction with FOR UPDATE to ensure we - // don't overwrite a status change made by another caller. - finishResult, err := p.finishActiveChat(cleanupCtx, logger, chat, status, encodedLastError) - if errors.Is(err, errChatTakenByOtherWorker) { - // Another worker owns this chat now — skip all - // post-TX side effects (status publish, pubsub, - // web push) to avoid overwriting their state. 
- return - } - if err != nil { - logger.Error(cleanupCtx, "failed to release chat", slog.Error(err)) - return - } - status = finishResult.updatedChat.Status - promotedMessage = finishResult.promotedMessage - remainingQueuedMessages = finishResult.remainingQueuedMessages - shouldPublishQueueUpdate = finishResult.shouldPublishQueueUpdate - - // Publish synth rows before the promoted user message. - for _, msg := range finishResult.syntheticToolResults { - p.publishMessage(chat.ID, msg) - } - if promotedMessage != nil { - p.publishMessage(chat.ID, *promotedMessage) - } - if shouldPublishQueueUpdate { - p.publishEvent(chat.ID, codersdk.ChatStreamEvent{ - Type: codersdk.ChatStreamEventTypeQueueUpdate, - QueuedMessages: db2sdk.ChatQueuedMessages(remainingQueuedMessages), - }) - p.publishChatStreamNotify(chat.ID, coderdpubsub.ChatStreamNotifyMessage{ - QueueUpdate: true, - }) - } - if p.shouldPublishFinishedChatState(cleanupCtx, logger, finishResult.updatedChat) { - p.publishStatus(chat.ID, status, uuid.NullUUID{}) - // Best-effort: use any generated title captured during - // processing so push notifications and the status snapshot - // can reflect it without another DB read. The dedicated - // title_change event remains the source of truth. - if title, ok := generatedTitle.Load(); ok { - finishResult.updatedChat.Title = title - } - p.publishChatPubsubEvent(finishResult.updatedChat, codersdk.ChatWatchEventKindStatusChange, nil) - } - - if promotedMessage != nil { - // Wake the processor so it picks up the newly pending - // chat immediately instead of waiting for the next - // acquire-interval tick. - p.signalWake() - } - - // When the chat is parked in requires_action, - // publish the stream event and global pubsub event - // after the DB status has committed. Publishing - // here (not in runChat) prevents a race where a - // fast client reacts before the status is visible. 
- if status == database.ChatStatusRequiresAction && len(runResult.PendingDynamicToolCalls) > 0 { - toolCalls := pendingToStreamToolCalls(runResult.PendingDynamicToolCalls) - p.publishEvent(chat.ID, codersdk.ChatStreamEvent{ - Type: codersdk.ChatStreamEventTypeActionRequired, - ActionRequired: &codersdk.ChatStreamActionRequired{ - ToolCalls: toolCalls, - }, - }) - p.publishChatActionRequired(finishResult.updatedChat, runResult.PendingDynamicToolCalls) - } - if wasInterrupted { - p.maybeClearLastTurnSummaryAsync(cleanupCtx, finishResult.updatedChat, logger) - } else { - lastErrorMessage := "" - if lastErrorPayload != nil { - lastErrorMessage = lastErrorPayload.Message - } - p.maybeFinalizeTurnStatusLabelAndPush( - cleanupCtx, - finishResult.updatedChat, - status, - lastErrorMessage, - runResult, - logger, - ) - } - }() - - p.metrics.Chats.WithLabelValues(chatloop.StateWaiting).Dec() - p.metrics.Chats.WithLabelValues(chatloop.StateStreaming).Inc() - defer func() { - p.metrics.Chats.WithLabelValues(chatloop.StateStreaming).Dec() - p.metrics.Chats.WithLabelValues(chatloop.StateWaiting).Inc() - }() - runResult, err := p.runChat(chatCtx, chat, generatedTitle, logger) - if err != nil { - if errors.Is(err, chatloop.ErrInterrupted) || errors.Is(context.Cause(chatCtx), chatloop.ErrInterrupted) { - logger.Info(ctx, "chat interrupted") - status = database.ChatStatusWaiting - lastErrorPayload = nil - wasInterrupted = true - return - } - if isShutdownCancellation(ctx, chatCtx, err) { - logger.Info(ctx, "chat canceled during shutdown; returning to pending") - status = database.ChatStatusPending - lastErrorPayload = nil - wasInterrupted = true - return - } - logger.Error(ctx, "failed to process chat", slog.Error(err)) - if classified, ok := processingFailure(err); ok { - lastErrorPayload = chaterror.TerminalErrorPayload(classified) - p.publishError(chat.ID, classified) - } - status = database.ChatStatusError - return - } - - // The LLM invoked a dynamic tool — park the chat in - // 
requires_action so the client can supply tool results. - if len(runResult.PendingDynamicToolCalls) > 0 { - status = database.ChatStatusRequiresAction - return - } - - // If runChat completed successfully but the server context was - // canceled (e.g. during Close()), the chat should be returned - // to pending so another replica can pick it up. There is a - // race where the LLM stream finishes just as the server is - // shutting down — the HTTP response completes before context - // cancellation propagates, so runChat returns nil instead of - // a context.Canceled error. Without this check the chat would - // be marked "waiting" and never retried. - if ctx.Err() != nil { - logger.Info(ctx, "chat completed during shutdown; returning to pending") - status = database.ChatStatusPending - lastErrorPayload = nil - wasInterrupted = true - return - } -} - -func isShutdownCancellation( - serverCtx context.Context, - chatCtx context.Context, - err error, -) bool { - if err == nil { - return false - } - // During Close(), the server context is canceled. In-flight chats should - // be returned to pending so another replica can retry them. - if serverCtx.Err() == nil { - return false - } - if errors.Is(err, context.Canceled) { - return true - } - return errors.Is(context.Cause(chatCtx), context.Canceled) -} - -// generatedChatTitle shares an asynchronously generated title between the -// detached title-generation goroutine and the deferred cleanup path. 
-type generatedChatTitle struct { - mu sync.RWMutex - title string -} - -func (t *generatedChatTitle) Store(title string) { - if t == nil || title == "" { - return - } - - t.mu.Lock() - t.title = title - t.mu.Unlock() -} - -func (t *generatedChatTitle) Load() (string, bool) { - if t == nil { - return "", false - } - - t.mu.RLock() - defer t.mu.RUnlock() - if t.title == "" { - return "", false - } - return t.title, true -} - type runChatResult struct { - FinalAssistantText string - StatusLabelModel fantasy.LanguageModel - ProviderKeys chatprovider.ProviderAPIKeys - PendingDynamicToolCalls []chatloop.PendingToolCall - FallbackProvider string - FallbackRoute resolvedModelRoute - FallbackModel string - ModelBuildOptions modelBuildOptions - TriggerMessageID int64 - HistoryTipMessageID int64 + FinalAssistantText string + StatusLabelModel fantasy.LanguageModel + ProviderKeys chatprovider.ProviderAPIKeys + FallbackProvider string + FallbackRoute resolvedModelRoute + FallbackModel string + ModelBuildOptions modelBuildOptions + TriggerMessageID int64 + HistoryTipMessageID int64 } func activeTurnAPIKeyIDFromMessages(messages []database.ChatMessage) (string, bool) { @@ -6303,8 +5026,6 @@ type rootChatToolsOptions struct { modelConfigID uuid.UUID workspaceCtx *turnWorkspaceContext workspaceMu *sync.Mutex - instruction *string - skills *[]chattool.SkillMeta resolvePlanPath func(context.Context) (string, string, error) storeFile chattool.StoreFileFunc isPlanModeTurn bool @@ -6433,31 +5154,14 @@ func (p *Server) appendRootChatTools( // build logs before the tool completes. p.publishChatPubsubEvent(updatedChat, codersdk.ChatWatchEventKindStatusChange, nil) - // When a workspace is first attached mid-turn (e.g. via - // create_workspace), fetch and persist instruction files - // immediately so the LLM has AGENTS.md context for the remainder - // of this turn. The persisted marker prevents redundant fetches on - // subsequent turns. 
- if *opts.instruction == "" && updatedChat.WorkspaceID.Valid { - newInstruction, discoveredSkills, persistErr := p.persistInstructionFiles( - ctx, - updatedChat, - opts.modelConfigID, - opts.workspaceCtx.getWorkspaceAgent, - opts.workspaceCtx.getWorkspaceConn, - ) - if persistErr != nil { - p.logger.Warn(ctx, "failed to persist instruction files on workspace attach", - slog.F("chat_id", updatedChat.ID), - slog.Error(persistErr), - ) - } else { - *opts.instruction = newInstruction - if len(discoveredSkills) > 0 { - *opts.skills = discoveredSkills - } - } - } + // Note: we intentionally do not insert AGENTS.md / workspace + // context here. Local tool callbacks must not mutate chat + // history while a local-tool generation task is in flight, + // because that advances history_version before the tool + // result is committed and exits the local-tool commit as + // stale. Workspace context is persisted by the + // persist_workspace_context generation action in a later + // pass. // Prime the workspace MCP tools cache while the create_workspace // or start_workspace tool is still running. The AgentID guard @@ -6466,14 +5170,14 @@ func (p *Server) appendRootChatTools( // empty list on the first try when the agent's MCP Connect is // racing with agent startup; primeWorkspaceMCPCache retries // with a short backoff up to workspaceMCPPrimeMaxWait. Priming - // here lets the next LLM step's PrepareTools hit the cache + // here lets the next assistant-generation action hit the cache // instead of dialing again on a separate timeout budget. // // Run asynchronously: the tool itself must not block on the // primer because the agent may not advertise any MCP tools at // all (e.g. minimal templates), in which case the primer waits - // the full budget before giving up. PrepareTools on the next - // step covers the cache miss path; the primer is purely an + // the full budget before giving up. 
The next assistant-generation + // action covers the cache miss path; the primer is purely an // optimization that warms the cache while the LLM is thinking. // inflight tracking ensures server shutdown still waits for any // in-progress primer. @@ -6486,21 +5190,11 @@ func (p *Server) appendRootChatTools( // the pre-build and stop-side firings would otherwise spawn a // primer goroutine that dials a missing or dying agent and // burns the full budget for nothing. - // - // Read the snapshot from workspaceCtx rather than the - // updatedChat parameter: persistInstructionFiles above runs - // ensureWorkspaceAgent which calls persistBuildAgentBinding and - // setCurrentChat, so by the time we get here the in-memory - // snapshot has the freshly bound AgentID even when the - // updatedChat parameter (read from the DB before the binding - // was persisted) does not. snapshot := opts.workspaceCtx.currentChatSnapshot() if snapshot.WorkspaceID.Valid && snapshot.AgentID.Valid { - p.inflight.Add(1) - go func() { - defer p.inflight.Done() + p.inflight.Go(func() { p.primeWorkspaceMCPCache(opts.primerCtx, p.logger, snapshot.ID, opts.workspaceCtx) - }() + }) } } @@ -6606,1333 +5300,6 @@ func appendDynamicTools( return append(tools, dynamicToolsFromSDK(logger, filteredDefs)...), dynamicToolNames, nil } -func (p *Server) runChat( - ctx context.Context, - chat database.Chat, - generatedTitle *generatedChatTitle, - logger slog.Logger, -) (runChatResult, error) { - result := runChatResult{} - var ( - model fantasy.LanguageModel - modelConfig database.ChatModelConfig - providerKeys chatprovider.ProviderAPIKeys - callConfig codersdk.ChatModelCallConfig - messages []database.ChatMessage - err error - debugEnabled bool - debugProvider string - modelRoute resolvedModelRoute - debugModel string - ) - - messages, err = p.db.GetChatMessagesForPromptByChatID(ctx, chat.ID) - if err != nil { - return result, xerrors.Errorf("get chat messages: %w", err) - } - modelOpts := 
modelBuildOptionsFromMessages(messages) - if modelOpts.ActiveAPIKeyID != "" { - ctx = aibridge.WithDelegatedAPIKeyID(ctx, modelOpts.ActiveAPIKeyID) - } - - // Load MCP server configs and user tokens in parallel with model - // resolution. These queries have no dependencies on each other and all - // hit different tables. - var ( - mcpConfigs []database.MCPServerConfig - mcpTokens []database.MCPServerUserToken - ) - var g errgroup.Group - g.Go(func() error { - var err error - model, modelConfig, providerKeys, modelRoute, debugEnabled, debugProvider, debugModel, err = p.resolveChatModel(ctx, chat, modelOpts) - if err != nil { - return err - } - if len(modelConfig.Options) > 0 { - if err := json.Unmarshal(modelConfig.Options, &callConfig); err != nil { - return xerrors.Errorf("parse model call config: %w", err) - } - } - return nil - }) - if len(chat.MCPServerIDs) > 0 { - g.Go(func() error { - var err error - mcpConfigs, err = p.db.GetMCPServerConfigsByIDs( - ctx, chat.MCPServerIDs, - ) - if err != nil { - logger.Warn(ctx, - "failed to load MCP server configs", - slog.Error(err), - ) - } - return nil - }) - g.Go(func() error { - var err error - // If token loading fails, ConnectAll will still - // proceed but oauth2-authenticated servers will - // attempt to connect without credentials. Those - // connections may succeed or fail depending on - // the remote server's auth requirements. - mcpTokens, err = p.db.GetMCPServerUserTokensByUserID( - ctx, chat.OwnerID, - ) - if err != nil { - logger.Warn(ctx, - "failed to load MCP user tokens", - slog.Error(err), - ) - } - return nil - }) - } - if err := g.Wait(); err != nil { - return result, err - } - - // Capture the current turn's mode so prompt and tool behavior can - // be resolved consistently for the rest of the turn. 
- currentPlanMode := chat.PlanMode - isPlanModeTurn := currentPlanMode.Valid && currentPlanMode.ChatPlanMode == database.ChatPlanModePlan - isExploreSubagent := isExploreSubagentMode(chat.Mode) - isRootChat := !chat.ParentChatID.Valid - var mcpConnectConfigs []database.MCPServerConfig - var approvedPlanMCPConfigIDs map[uuid.UUID]struct{} - // Explore subagents rely on the immutable spawn-time snapshot - // persisted in chat.MCPServerIDs. SendMessage cannot mutate that - // snapshot, so no runtime re-filter against parent state is needed. - // The child's persisted set is authoritative. - mcpConnectConfigs, approvedPlanMCPConfigIDs = filterExternalMCPConfigsForTurn( - mcpConfigs, - currentPlanMode, - chat.ParentChatID, - ) - if isExploreSubagent && isRootChat { - // Root Explore chats stay builtin-only per the accepted plan, so - // strip any persisted external MCP configs at runtime regardless of - // what's on the chat row. Explore children get their snapshot via - // the spawn-time inheritance path and are handled below. - mcpConnectConfigs = nil - approvedPlanMCPConfigIDs = map[uuid.UUID]struct{}{} - } - planModeInstructions := p.loadPlanModeInstructions(ctx, currentPlanMode, logger) - - advisorCfg := p.loadAdvisorConfig(ctx, logger) - - var advisorRuntime *chatadvisor.Runtime - // Plan mode filters the advisor tool out of the turn's tool set via - // filterToolsForTurn, so enabling the runtime there would inject - // guidance and enforce advisor exclusivity for a tool the model - // cannot actually call. Explore chats (root or subagent) run under - // allowedExploreToolNames, whose policy does not include advisor, so - // registering the runtime there would inject guidance for a tool - // that is never exposed to the model. 
- if advisorCfg.Enabled && isRootChat && !isPlanModeTurn && !isExploreSubagent { - var advisorErr error - advisorRuntime, advisorErr = p.newAdvisorRuntime( - ctx, - chat, - advisorCfg, - model, - callConfig, - providerKeys, - modelOpts, - logger, - ) - if advisorErr != nil { - return result, advisorErr - } - } - - var advisorPromptSnapshot []fantasy.Message - // setAdvisorPromptSnapshot captures the final prompt state the outer - // model sees so the advisor tool can forward it as nested context. - // It is invoked at four lifecycle points (after initial system-prompt - // assembly, inside PrepareMessages before and after instruction - // injection, and after ReloadMessages rebuilds the prompt) because - // the prompt mutates at each of them and the advisor must snapshot - // the post-mutation state. Removing any of those calls would leave - // the advisor with a stale view of the conversation. - // - // The no-op guard keeps the common disabled/filtered paths (advisor - // off, plan mode, explore, child chats) from paying an O(n) prompt - // clone per step for a snapshot that is never consumed. - setAdvisorPromptSnapshot := func(msgs []fantasy.Message) { - if advisorRuntime == nil { - return - } - advisorPromptSnapshot = slices.Clone(msgs) - } - - chainInfo := chatopenai.ResolveChainMode(messages) - result.StatusLabelModel = model - result.ProviderKeys = providerKeys - result.FallbackProvider = modelConfig.Provider - result.FallbackRoute = modelRoute - result.FallbackModel = modelConfig.Model - result.ModelBuildOptions = modelOpts - debugSvc := p.existingDebugService() - // Fire title generation asynchronously so it doesn't block the - // chat response. It uses a detached context so it can finish - // even after the chat processing context is canceled. - // Snapshot values captured by the goroutine because model, providerKeys, - // logger, and ctx are reassigned below. 
- titleModel := model - titleProviderKeys := providerKeys - titleLogger := logger - titleCtx := context.WithoutCancel(ctx) - p.inflight.Add(1) - go func() { - defer p.inflight.Done() - p.maybeGenerateChatTitle( - titleCtx, - chat, - messages, - modelConfig.Provider, - modelConfig.Model, - titleModel, - modelRoute, - titleProviderKeys, - modelOpts, - generatedTitle, - titleLogger, - debugSvc, - ) - }() - - // Detect computer-use subagent via the mode column. - isComputerUse := chat.Mode.Valid && chat.Mode.ChatMode == database.ChatModeComputerUse - - var ( - computerUseProvider string - computerUseModelProvider string - computerUseModelName string - ) - if isComputerUse { - var err error - computerUseProvider, computerUseModelProvider, computerUseModelName, err = p.computerUseProviderAndModelFromConfig(ctx) - if err != nil { - return result, xerrors.Errorf( - "resolve computer use provider and model: %w", - err, - ) - } - } - - // NOTE: Buffering was already started in processChat before - // the running status was published, so message_part events - // are captured from the moment subscribers can see - // status=running. The deferred cleanup also lives in - // processChat. - - currentChat := chat - loadChatSnapshot := func( - loadCtx context.Context, - chatID uuid.UUID, - ) (database.Chat, error) { - return p.db.GetChatByID(loadCtx, chatID) - } - var ( - chatStateMu sync.Mutex - workspaceMu sync.Mutex - ) - workspaceCtx := turnWorkspaceContext{ - server: p, - chatStateMu: &chatStateMu, - currentChat: ¤tChat, - loadChatSnapshot: loadChatSnapshot, - } - // primerCtx scopes the workspace MCP cache primer goroutines that - // onChatUpdated launches. We cancel it before workspaceCtx.close() - // so an in-flight primer cannot wake from its retry backoff, - // observe a cleared cached conn, dial a fresh one, and leak it - // when no subsequent close() runs. 
- primerCtx, primerCancel := context.WithCancel(ctx) - defer func() { - primerCancel() - workspaceCtx.close() - }() - - planPathFn := func(ctx context.Context) (string, string, error) { - conn, err := workspaceCtx.getWorkspaceConn(ctx) - if err != nil { - return "", "", err - } - home, err := chattool.ResolveWorkspaceHome(ctx, conn) - if err != nil { - return "", "", err - } - return chattool.PlanPathForChat(home, chat.ID), home, nil - } - resolvePlanPathForTools := func(ctx context.Context) (string, string, error) { - ctx, cancel := context.WithTimeout(ctx, planPathLookupTimeout) - defer cancel() - return planPathFn(ctx) - } - resolvePlanPathBlock := func(resolveCtx context.Context) string { - if chat.ParentChatID.Valid { - return "" - } - - planCtx, cancel := context.WithTimeout(resolveCtx, planPathLookupTimeout) - defer cancel() - - if _, _, err := workspaceCtx.workspaceAgentIDForConn(planCtx); err != nil { - p.logger.Debug(resolveCtx, "plan path instruction: agent not reachable", - slog.Error(err), - slog.F("chat_id", chat.ID), - ) - return "" - } - - planPath, home, err := planPathFn(planCtx) - if err != nil { - p.logger.Debug(resolveCtx, "plan path instruction: failed to resolve plan path", - slog.Error(err), - slog.F("chat_id", chat.ID), - ) - return "" - } - - return formatPlanPathBlock(planPath, home) - } - - // Connect to MCP servers in parallel with instruction - // resolution. ConnectAll only depends on mcpConfigs and - // mcpTokens which are available after g.Wait() above. - var ( - instruction string - resolvedUserPrompt string - mcpTools []fantasy.AgentTool - mcpCleanup func() - workspaceMCPTools []fantasy.AgentTool - workspaceSkills []chattool.SkillMeta - personalSkills []skillspkg.Skill - ) - // Check if instruction files need to be (re-)persisted. - // This happens when no context-file parts exist yet, or when - // the workspace agent has changed (e.g. workspace rebuilt). 
- needsInstructionPersist := false - hasContextFiles := false - persistedSkills := skillsFromParts(messages) - latestInjectedAgentID, hasLatestInjectedAgent := latestContextAgentID(messages) - currentWorkspaceAgentID := uuid.Nil - hasCurrentWorkspaceAgent := false - if chat.WorkspaceID.Valid { - if agent, agentErr := workspaceCtx.getWorkspaceAgent(ctx); agentErr == nil { - currentWorkspaceAgentID = agent.ID - hasCurrentWorkspaceAgent = true - } - persistedAgentID, found := contextFileAgentID(messages) - hasContextFiles = found - if !hasPersistedInstructionFiles(messages) { - needsInstructionPersist = true - } else if hasCurrentWorkspaceAgent && currentWorkspaceAgentID != persistedAgentID { - // Agent changed. Persist fresh instruction files. - // Old context-file messages remain in the conversation - // to preserve the prompt cache prefix. - needsInstructionPersist = true - } - } - // Convert messages to prompt format in parallel with g2 work. - // ConvertMessagesWithFiles only reads `messages` (available - // after g.Wait()) and resolves file references via the DB. - // No g2 task reads or writes `prompt`, so this is safe. 
- var prompt []fantasy.Message - var g2 errgroup.Group - g2.Go(func() error { - var err error - prompt, err = chatprompt.ConvertMessagesWithFiles(ctx, messages, p.chatFileResolver(modelConfig.Provider), logger) - if err != nil { - return xerrors.Errorf("build chat prompt: %w", err) - } - return nil - }) - if needsInstructionPersist { - g2.Go(func() error { - var persistErr error - var discoveredSkills []chattool.SkillMeta - instruction, discoveredSkills, persistErr = p.persistInstructionFiles( - ctx, - chat, - modelConfig.ID, - workspaceCtx.getWorkspaceAgent, - func(instructionCtx context.Context) (workspacesdk.AgentConn, error) { - if _, _, err := workspaceCtx.workspaceAgentIDForConn(instructionCtx); err != nil { - return nil, err - } - return workspaceCtx.getWorkspaceConn(instructionCtx) - }, - ) - workspaceSkills = selectSkillMetasForInstructionRefresh( - persistedSkills, - discoveredSkills, - uuid.NullUUID{UUID: currentWorkspaceAgentID, Valid: hasCurrentWorkspaceAgent}, - uuid.NullUUID{UUID: latestInjectedAgentID, Valid: hasLatestInjectedAgent}, - ) - if persistErr != nil { - p.logger.Warn(ctx, "failed to persist instruction files", - slog.F("chat_id", chat.ID), - slog.Error(persistErr), - ) - } - return nil - }) - } else if hasContextFiles { - // On subsequent turns, extract the instruction text and - // skill index from persisted parts so they can be - // re-injected via InsertSystem after compaction drops - // those messages. No workspace dial needed. - instruction = instructionFromContextFiles(messages) - workspaceSkills = persistedSkills - } - g2.Go(func() error { - personalSkills = p.fetchPersonalSkillMetadata(ctx, chat.OwnerID, logger) - return nil - }) - g2.Go(func() error { - resolvedUserPrompt = p.resolveUserPrompt(ctx, chat.OwnerID) - return nil - }) - if len(mcpConnectConfigs) > 0 { - g2.Go(func() error { - // Refresh expired OAuth2 tokens before connecting. 
- mcpTokens = p.refreshExpiredMCPTokens(ctx, logger, mcpConnectConfigs, mcpTokens) - mcpTools, mcpCleanup = mcpclient.ConnectAll( - ctx, logger, mcpConnectConfigs, mcpTokens, chat.OwnerID, p.oidcTokenSource, - chatprovider.CoderHeaders(chat), - ) - return nil - }) - } - // Workspace MCP discovery stays disabled for all plan-mode turns. - // Root plan mode only gets approved external MCP servers, and - // plan-mode subagents get no MCP tools. When the chat has no - // workspace yet, discovery happens mid-turn via the chatloop - // PrepareTools callback installed below in chatloop.Run options. - if chat.WorkspaceID.Valid && !isPlanModeTurn { - g2.Go(func() error { - workspaceMCPTools = p.discoverWorkspaceMCPTools( - ctx, logger, chat.ID, &workspaceCtx, - ) - return nil - }) - } - if err := g2.Wait(); err != nil { - return result, err - } - prompt, sanitizeStats := chatsanitize.SanitizeAnthropicProviderToolHistory(model.Provider(), prompt) - chatsanitize.LogAnthropicProviderToolSanitization( - ctx, logger, "persisted_history_replay", model.Provider(), model.Model(), sanitizeStats, - ) - subagentInstruction := "" - if !isRootChat { - subagentInstruction = defaultSubagentInstruction - } - resolvedSkillsFor := func(workspaceSkills []chattool.SkillMeta) []skillspkg.ResolvedSkill { - return mergeTurnSkills(personalSkills, workspaceSkills) - } - resolveSkillAlias := func(alias string) (skillspkg.ResolvedSkill, error) { - return skillspkg.Lookup(resolvedSkillsFor(workspaceSkills), alias) - } - initialResolvedSkills := resolvedSkillsFor(workspaceSkills) - injectedSkillIndex := chattool.FormatResolvedSkillIndex(initialResolvedSkills) - prompt = buildSystemPrompt( - prompt, - subagentInstruction, - instruction, - initialResolvedSkills, - resolvedUserPrompt, - systemPromptBehaviorContext{ - planMode: currentPlanMode, - chatMode: chat.Mode, - planModeInstructions: planModeInstructions, - isRootChat: isRootChat, - }, - ) - // Inject advisor guidance when the advisor runtime is 
available. - if advisorRuntime != nil { - prompt = chatprompt.InsertSystem(prompt, chatadvisor.ParentGuidanceBlock) - } - if mcpCleanup != nil { - defer mcpCleanup() - } - - // Build a lookup from tool name to MCP server config ID - // so we can annotate persisted parts with the originating - // server. - toolNameToConfigID := make(map[string]uuid.UUID) - for _, t := range mcpTools { - if mcpTool, ok := t.(mcpclient.MCPToolIdentifier); ok { - toolNameToConfigID[t.Info().Name] = mcpTool.MCPServerConfigID() - } - } - - instructionInjected := instruction != "" - // workspaceMCPDiscovered tracks whether workspace MCP discovery - // has already been attempted for this turn. The top-of-turn - // discovery path above only fires when chat.WorkspaceID is - // valid at the start of the turn. For chats that bind a - // workspace mid-turn (e.g. via create_workspace) the chatloop - // PrepareTools callback below triggers discovery on the next - // step. After discovery has run once (here or in PrepareTools), - // this flag prevents redundant dials. - workspaceMCPDiscovered := chat.WorkspaceID.Valid || isPlanModeTurn - prompt = renderPlanPathPrompt(prompt, resolvePlanPathBlock(ctx)) - setAdvisorPromptSnapshot(prompt) - // Use the model config's context_limit as a fallback when the LLM - // provider doesn't include context_limit in its response metadata - // (which is the common case). - modelConfigContextLimit := modelConfig.ContextLimit - var finalAssistantText string - var pendingDynamicCalls []chatloop.PendingToolCall - - compactionHistoryTipMessageID := int64(0) - if len(messages) > 0 { - compactionHistoryTipMessageID = messages[len(messages)-1].ID - } - - var compactionOptions *chatloop.CompactionOptions - - persistStep := func(persistCtx context.Context, step chatloop.PersistedStep) error { - // If the chat context has been canceled, bail out before - // inserting any messages. We distinguish the cause so that - // the caller can tell an intentional interruption (e.g. 
- // EditMessage, user stop) from a server shutdown: - // - ErrInterrupted cause → return ErrInterrupted - // (processChat sets status = waiting). - // - Any other cause (e.g. context.Canceled during - // Close()) → return the original context error so - // isShutdownCancellation can match and set status = - // pending, allowing another replica to retry. - if persistCtx.Err() != nil { - if errors.Is(context.Cause(persistCtx), chatloop.ErrInterrupted) { - return chatloop.ErrInterrupted - } - return persistCtx.Err() - } - - // Capture pending dynamic tool calls so the caller - // can surface them after chatloop.Run returns. - pendingDynamicCalls = step.PendingDynamicToolCalls - - // Split the step content into assistant blocks and tool - // result blocks so they can be stored as separate messages - // with the appropriate roles. Provider-executed tool results - // (e.g. web_search) stay in the assistant content because - // the LLM provider expects them inline in the assistant - // turn, not as separate tool messages. - var assistantBlocks []fantasy.Content - var toolResults []fantasy.ToolResultContent - for _, block := range step.Content { - if tr, ok := fantasy.AsContentType[fantasy.ToolResultContent](block); ok { - if !tr.ProviderExecuted { - toolResults = append(toolResults, tr) - continue - } - } - if trPtr, ok := fantasy.AsContentType[*fantasy.ToolResultContent](block); ok && trPtr != nil { - if !trPtr.ProviderExecuted { - toolResults = append(toolResults, *trPtr) - continue - } - } - assistantBlocks = append(assistantBlocks, block) - } - - // Pre-marshal all content outside the transaction so the - // FOR UPDATE lock is held only for the INSERT statements. - // Marshaling is pure CPU work with no database dependency. 
- assistantParts := buildAssistantPartsForPersist( - persistCtx, - p.logger, - assistantBlocks, - toolResults, - step, - toolNameToConfigID, - ) - - var assistantContent pqtype.NullRawMessage - if len(assistantParts) > 0 { - finalAssistantText = strings.TrimSpace(contentBlocksToText(assistantParts)) - var marshalErr error - assistantContent, marshalErr = chatprompt.MarshalParts(assistantParts) - if marshalErr != nil { - return xerrors.Errorf("marshal assistant content: %w", marshalErr) - } - } - - toolResultContents := make([]pqtype.NullRawMessage, len(toolResults)) - for i, tr := range toolResults { - trPart := chatprompt.PartFromContentWithLogger(ctx, logger, tr) - if trPart.ToolName != "" { - if configID, ok := toolNameToConfigID[trPart.ToolName]; ok { - trPart.MCPServerConfigID = uuid.NullUUID{UUID: configID, Valid: true} - } - } - // Apply recorded timestamps so persisted - // tool-result parts carry accurate CreatedAt. - if trPart.ToolCallID != "" && step.ToolResultCreatedAt != nil { - if ts, ok := step.ToolResultCreatedAt[trPart.ToolCallID]; ok { - trPart.CreatedAt = &ts - } - } - var marshalErr error - toolResultContents[i], marshalErr = chatprompt.MarshalParts([]codersdk.ChatMessagePart{trPart}) - if marshalErr != nil { - return xerrors.Errorf("marshal tool result %d: %w", i, marshalErr) - } - } - - hasUsage := step.Usage != (fantasy.Usage{}) - usageForCost := fantasyUsageToChatMessageUsage(step.Usage) - totalCostMicros := chatcost.CalculateTotalCostMicros(usageForCost, callConfig.Cost) - - var insertedMessages []database.ChatMessage - if err := p.db.InTx(func(tx database.Store) error { - // Verify this worker still owns the chat before - // inserting messages. This closes the race where - // EditMessage soft-deletes history and clears worker_id - // while persistInterruptedStep (which uses an - // uncancelable context) is still running. 
- // - // When the chat is in "waiting" status (set by - // InterruptChat / setChatWaiting), the worker_id has - // already been cleared but we still want to persist - // the partial assistant response. We allow the write - // because the history has NOT been truncated — the - // user simply asked to stop. In contrast, EditMessage - // sets the chat to "pending" after truncating, so the - // pending check still correctly blocks stale writes. - lockedChat, lockErr := tx.GetChatByIDForUpdate(persistCtx, chat.ID) - if lockErr != nil { - return xerrors.Errorf("lock chat for persist: %w", lockErr) - } - if !lockedChat.WorkerID.Valid || lockedChat.WorkerID.UUID != p.workerID { - // The worker_id was cleared. Only allow the persist - // if the chat transitioned to "waiting" (interrupt), - // not "pending" (edit) or any other status. - if lockedChat.Status != database.ChatStatusWaiting { - return chatloop.ErrInterrupted - } - } - - stepParams := database.InsertChatMessagesParams{ //nolint:exhaustruct // Fields populated by appendChatMessage. 
- ChatID: chat.ID, - } - - var contextLimit int64 - if step.ContextLimit.Valid { - contextLimit = step.ContextLimit.Int64 - } - - var runtimeMs int64 - if step.Runtime > 0 { - runtimeMs = step.Runtime.Milliseconds() - } - - var totalCostVal int64 - if totalCostMicros != nil { - totalCostVal = *totalCostMicros - } - - var inputTokens, outputTokens, totalTokens int64 - var reasoningTokens, cacheCreationTokens, cacheReadTokens int64 - if hasUsage { - inputTokens = step.Usage.InputTokens - outputTokens = step.Usage.OutputTokens - totalTokens = step.Usage.TotalTokens - reasoningTokens = step.Usage.ReasoningTokens - cacheCreationTokens = step.Usage.CacheCreationTokens - cacheReadTokens = step.Usage.CacheReadTokens - } - - if assistantContent.Valid { - appendChatMessage(&stepParams, newChatMessage( - database.ChatMessageRoleAssistant, - assistantContent, - database.ChatMessageVisibilityBoth, - modelConfig.ID, - chatprompt.CurrentContentVersion, - ).withUsage( - inputTokens, outputTokens, totalTokens, - reasoningTokens, cacheCreationTokens, cacheReadTokens, - ).withContextLimit(contextLimit). - withTotalCostMicros(totalCostVal). - withRuntimeMs(runtimeMs). - withProviderResponseID(step.ProviderResponseID)) - } - - for _, resultContent := range toolResultContents { - appendChatMessage(&stepParams, newChatMessage( - database.ChatMessageRoleTool, - resultContent, - database.ChatMessageVisibilityBoth, - modelConfig.ID, - chatprompt.CurrentContentVersion, - )) - } - - if len(stepParams.Role) > 0 { - inserted, insertErr := tx.InsertChatMessages(persistCtx, stepParams) - if insertErr != nil { - return xerrors.Errorf("insert step messages: %w", insertErr) - } - insertedMessages = append(insertedMessages, inserted...) 
- } - - return nil - }, nil); err != nil { - return xerrors.Errorf("persist step transaction: %w", err) - } - - for _, msg := range insertedMessages { - p.publishMessage(chat.ID, msg) - } - if len(insertedMessages) > 0 { - compactionHistoryTipMessageID = insertedMessages[len(insertedMessages)-1].ID - if compactionOptions != nil { - compactionOptions.HistoryTipMessageID = compactionHistoryTipMessageID - } - } - - // Do NOT clear the stream buffer here. The per-chat - // stream state must remain alive for the post-completion - // grace window so cross-replica relay subscribers can - // register without racing cleanupStreamIfIdle. The buffer - // is bounded by maxStreamBufferSize and is cleared when - // the next processChat starts or when the stream state - // is garbage-collected after the retention grace period. - - return nil - } - // Apply the default MaxOutputTokens if the model config - // does not specify one. - if callConfig.MaxOutputTokens == nil { - maxOutputTokens := int64(32_000) - callConfig.MaxOutputTokens = &maxOutputTokens - } - - // Generate the tool call ID up front so that the streaming - // parts and durable messages share the same identifier. - // Without this the client cannot correlate the - // "Summarizing..." tool call with the "Summarized" tool - // result. 
- compactionToolCallID := "chat_summarized_" + uuid.NewString() - effectiveThreshold := modelConfig.CompressionThreshold - thresholdSource := "model_default" - if override, ok := p.resolveUserCompactionThreshold(ctx, chat.OwnerID, modelConfig.ID); ok { - effectiveThreshold = override - thresholdSource = "user_override" - } - compactionOptions = &chatloop.CompactionOptions{ - ThresholdPercent: effectiveThreshold, - ContextLimit: modelConfig.ContextLimit, - HistoryTipMessageID: compactionHistoryTipMessageID, - Persist: func( - persistCtx context.Context, - result chatloop.CompactionResult, - ) error { - if err := p.persistChatContextSummary( - persistCtx, - chat.ID, - modelConfig.ID, - modelOpts.ActiveAPIKeyID, - compactionToolCallID, - result, - ); err != nil { - return xerrors.Errorf("persist context summary: %w", err) - } - logger.Info(persistCtx, "chat context summarized", - slog.F("chat_id", chat.ID), - slog.F("threshold_source", thresholdSource), - slog.F("threshold_percent", result.ThresholdPercent), - slog.F("usage_percent", result.UsagePercent), - slog.F("context_tokens", result.ContextTokens), - slog.F("context_limit", result.ContextLimit), - ) - return nil - }, - ToolCallID: compactionToolCallID, - ToolName: "chat_summarized", - PublishMessagePart: func(role codersdk.ChatMessageRole, part codersdk.ChatMessagePart) { - p.publishMessagePart(chat.ID, role, part) - }, - OnError: func(err error) { - logger.Warn(ctx, "failed to compact chat context", slog.Error(err)) - }, - } - - if isComputerUse { - computerUseRoute, keyErr := p.resolveModelRouteForProviderType(ctx, chat.OwnerID, computerUseModelProvider) - if keyErr != nil { - return result, xerrors.Errorf("resolve computer use provider route: %w", keyErr) - } - providerKeys = computerUseRoute.directProviderKeys() - - // Override model for computer use subagent. 
- cuModel, cuDebugEnabled, resolvedProvider, resolvedModel, cuErr := p.resolveComputerUseModel( - ctx, - chat, - computerUseRoute, - computerUseProvider, - computerUseModelProvider, - computerUseModelName, - modelOpts, - ) - if cuErr != nil { - return result, cuErr - } - model = cuModel - debugEnabled = cuDebugEnabled - debugProvider = resolvedProvider - debugModel = resolvedModel - } - if debugEnabled { - if debugSvc == nil { - return result, xerrors.New("chat debug service missing after enablement check") - } - compactionOptions.DebugSvc = debugSvc - compactionOptions.ChatID = chat.ID - } - - // Enrich the scoped logger with provider/model for this turn. - // Bound once after the cuModel swap; slog.Logger.With appends - // rather than deduping. - logger = logger.With( - slog.F("provider", model.Provider()), - slog.F("model", model.Model()), - ) - - allowAskUserQuestion := isPlanModeTurn && isRootChat - storeChatAttachment := p.newStoreChatAttachmentFunc(&workspaceCtx) - tools := []fantasy.AgentTool{ - chattool.ReadFile(chattool.ReadFileOptions{ - GetWorkspaceConn: workspaceCtx.getWorkspaceConn, - }), - chattool.WriteFile(chattool.WriteFileOptions{ - GetWorkspaceConn: workspaceCtx.getWorkspaceConn, - ResolvePlanPath: resolvePlanPathForTools, - IsPlanTurn: isPlanModeTurn, - }), - chattool.EditFiles(chattool.EditFilesOptions{ - GetWorkspaceConn: workspaceCtx.getWorkspaceConn, - ResolvePlanPath: resolvePlanPathForTools, - IsPlanTurn: isPlanModeTurn, - }), - chattool.AttachFile(chattool.AttachFileOptions{ - GetWorkspaceConn: workspaceCtx.getWorkspaceConn, - StoreFile: storeChatAttachment, - }), - chattool.Execute(chattool.ExecuteOptions{ - GetWorkspaceConn: workspaceCtx.getWorkspaceConn, - }), - chattool.ProcessOutput(chattool.ProcessToolOptions{ - GetWorkspaceConn: workspaceCtx.getWorkspaceConn, - }), - chattool.ProcessList(chattool.ProcessToolOptions{ - GetWorkspaceConn: workspaceCtx.getWorkspaceConn, - }), - chattool.ProcessSignal(chattool.ProcessToolOptions{ - 
GetWorkspaceConn: workspaceCtx.getWorkspaceConn, - }), - } - if allowAskUserQuestion { - tools = append(tools, chattool.NewAskUserQuestionTool()) - } - // Only root chats (not delegated subagents) get workspace - // provisioning and subagent tools. Child agents must not - // create workspaces or spawn further subagents. They should - // focus on completing their delegated task. - if isRootChat { - tools = p.appendRootChatTools(ctx, tools, rootChatToolsOptions{ - chat: chat, - modelConfigID: modelConfig.ID, - workspaceCtx: &workspaceCtx, - workspaceMu: &workspaceMu, - instruction: &instruction, - skills: &workspaceSkills, - resolvePlanPath: resolvePlanPathForTools, - storeFile: storeChatAttachment, - isPlanModeTurn: isPlanModeTurn, - primerCtx: primerCtx, - }) - } - - skillOpts := chattool.ReadSkillOptions{ - GetWorkspaceConn: workspaceCtx.getWorkspaceConn, - GetSkills: func() []chattool.SkillMeta { - return workspaceSkills - }, - ResolveAlias: resolveSkillAlias, - LoadPersonalSkillBody: func(ctx context.Context, name string) (skillspkg.ParsedSkill, error) { - return p.loadPersonalSkillBody(ctx, chat.OwnerID, name) - }, - } - appendCurrentSkillTools := func(current []fantasy.AgentTool) ([]fantasy.AgentTool, bool) { - if len(personalSkills) == 0 && len(workspaceSkills) == 0 { - return current, false - } - - updated := current - changed := false - appendTool := func(tool fantasy.AgentTool) { - name := tool.Info().Name - if slices.ContainsFunc(current, func(existing fantasy.AgentTool) bool { - return existing.Info().Name == name - }) { - return - } - if !changed { - updated = slices.Clone(current) - changed = true - } - updated = append(updated, tool) - } - appendTool(chattool.ReadSkill(skillOpts)) - if len(workspaceSkills) > 0 { - appendTool(chattool.ReadSkillFile(skillOpts)) - } - return updated, changed - } - tools, _ = appendCurrentSkillTools(tools) - if advisorRuntime != nil { - tools = append(tools, chatadvisor.Tool(chatadvisor.ToolOptions{ - Runtime: 
advisorRuntime, - GetConversationSnapshot: func() []fantasy.Message { - // The outer prompt contains ParentGuidanceBlock, which - // tells the parent when to call the advisor tool. That - // instruction is meaningless (and slightly confusing) - // when forwarded to the advisor, whose nested run has - // no tools. Strip it before handing the snapshot over. - return stripAdvisorGuidanceBlock(slices.Clone(advisorPromptSnapshot)) - }, - PublishAdviceDelta: func(toolCallID string, delta string) { - if toolCallID == "" || delta == "" { - return - } - p.publishMessagePart(chat.ID, codersdk.ChatMessageRoleTool, codersdk.ChatMessagePart{ - Type: codersdk.ChatMessagePartTypeToolResult, - ToolCallID: toolCallID, - ToolName: chatadvisor.ToolName, - ResultDelta: delta, - }) - }, - PublishAdviceReset: func(toolCallID string) { - if toolCallID == "" { - return - } - p.publishMessagePart(chat.ID, codersdk.ChatMessageRoleTool, codersdk.ChatMessagePart{ - Type: codersdk.ChatMessagePartTypeToolResult, - ToolCallID: toolCallID, - ToolName: chatadvisor.ToolName, - ResultReset: true, - }) - }, - })) - } - - var exclusiveToolNames map[string]bool - if advisorRuntime != nil { - exclusiveToolNames = map[string]bool{chatadvisor.ToolName: true} - } - - // Record builtin tool names before appending MCP tools - // so the metrics layer can differentiate between built-in and MCP tools. - builtinToolNames := make(map[string]bool, len(tools)) - for _, t := range tools { - builtinToolNames[t.Info().Name] = true - } - - // Append external MCP tools from the chat's persisted snapshot after the - // built-ins so the LLM sees them as additional capabilities. Explore chats - // trust only the persisted MCPServerIDs snapshot, and workspace-local MCP - // tools stay unavailable to Explore chats. - tools = append(tools, mcpTools...) - if !isExploreSubagent { - tools = append(tools, workspaceMCPTools...) 
- } - tools = filterToolsForTurn( - tools, - currentPlanMode, - chat.ParentChatID, - approvedPlanMCPConfigIDs, - ) - // Append dynamic tools declared by the client at chat - // creation time. These appear in the LLM's tool list but - // are never executed by the chatloop. The client handles - // execution via POST /tool-results. - var dynamicToolNames map[string]bool - tools, dynamicToolNames, err = appendDynamicTools( - ctx, - logger, - tools, - chat.DynamicTools, - currentPlanMode, - chat.Mode, - ) - if err != nil { - return result, err - } - - // Build provider-native tools (e.g. web search) based on the - // current model configuration. Root Explore chats stay builtin-only per - // the accepted plan, so delegated Explore children are the only Explore - // chats that can inherit web_search. Write-style provider tools stay - // blocked for all Explore chats. - var providerTools []chatloop.ProviderTool - if !isPlanModeTurn && callConfig.ProviderOptions != nil { - providerTools = buildProviderTools(callConfig.ProviderOptions) - if isExploreSubagent { - if !chat.ParentChatID.Valid { - providerTools = nil - } else { - providerTools = slices.DeleteFunc(providerTools, func(tool chatloop.ProviderTool) bool { - return tool.Definition.GetName() != "web_search" - }) - } - } - } - - providerTools, err = appendComputerUseProviderTool( - providerTools, - computerUseProviderToolOptions{ - provider: computerUseProvider, - isPlanModeTurn: isPlanModeTurn, - isComputerUse: isComputerUse, - getWorkspaceConn: workspaceCtx.getWorkspaceConn, - storeFile: storeChatAttachment, - clock: p.clock, - logger: p.logger.Named("computer_use"), - }, - ) - if err != nil { - return result, xerrors.Errorf( - "register computer use provider tool for provider %q: %w", - computerUseProvider, - err, - ) - } - - providerOptions := chatprovider.ProviderOptionsFromChatModelConfig( - model, - callConfig.ProviderOptions, - ) - // When the OpenAI Responses API has store=true, the provider - // retains 
conversation history server-side. For follow-up turns, - // we set previous_response_id and send only system instructions - // plus the new user input, avoiding redundant replay of prior - // assistant and tool messages that the provider already has. - chainModeActive := chatopenai.ShouldActivateChainMode( - providerOptions, - chainInfo, - modelConfig.ID, - isPlanModeTurn, - ) - if !chainModeActive && chainInfo.PreviousResponseID() != "" { - logger.Debug(ctx, "chain mode disabled", - slog.F("has_unresolved_local_tool_calls", chainInfo.HasUnresolvedLocalToolCalls()), - slog.F("provider_missing_tool_results", chainInfo.ProviderMissingToolResults()), - slog.F("is_plan_mode_turn", isPlanModeTurn), - slog.F("model_config_match", chainInfo.ModelConfigID() == modelConfig.ID), - slog.F("store_enabled", chatopenai.IsResponsesStoreEnabled(providerOptions)), - slog.F("contributing_trailing_user_count", chainInfo.ContributingTrailingUserCount()), - ) - } - if chainModeActive { - providerOptions = chatopenai.WithPreviousResponseID( - providerOptions, - chainInfo.PreviousResponseID(), - ) - prompt = chatopenai.FilterPromptForChainMode(prompt, chainInfo) - } - activeToolNames := activeToolNamesForTurn( - tools, - currentPlanMode, - chat.ParentChatID, - approvedPlanMCPConfigIDs, - ) - if isExploreSubagent { - activeToolNames = allowedExploreToolNames(tools) - } - - var loopErr error - triggerMessageID, historyTipMessageID, triggerLabel := deriveChatDebugSeed(messages) - - // Enrich the logger with correlation fields useful for - // diagnosing tool-call errors inside the chatloop. 
- loopLogger := logger.With( - slog.F("owner_id", chat.OwnerID), - slog.F("organization_id", chat.OrganizationID), - slog.F("trigger_message_id", triggerMessageID), - ) - if chat.WorkspaceID.Valid { - loopLogger = loopLogger.With(slog.F("workspace_id", chat.WorkspaceID.UUID)) - } - if chat.AgentID.Valid { - loopLogger = loopLogger.With(slog.F("agent_id", chat.AgentID.UUID)) - } - if chat.ParentChatID.Valid { - loopLogger = loopLogger.With(slog.F("parent_chat_id", chat.ParentChatID.UUID)) - } - result.TriggerMessageID = triggerMessageID - result.HistoryTipMessageID = historyTipMessageID - finishDebugRun := func(error, any) {} - if debugEnabled { - ctx, finishDebugRun = prepareChatTurnDebugRun( - ctx, - logger, - chat, - modelConfig, - debugSvc, - debugProvider, - debugModel, - triggerMessageID, - historyTipMessageID, - triggerLabel, - ) - } - defer func() { - panicValue := recover() - finishDebugRun(loopErr, panicValue) - if panicValue != nil { - panic(panicValue) - } - }() - - loopErr = chatloop.Run(ctx, chatloop.RunOptions{ - Model: model, - Messages: prompt, - Tools: tools, - ActiveTools: activeToolNames, - StopAfterTools: stopAfterBehaviorTools(currentPlanMode, chat.Mode, chat.ParentChatID), - MaxSteps: maxChatSteps, - Metrics: p.metrics, - Logger: loopLogger, - BuiltinToolNames: builtinToolNames, - ExclusiveToolNames: exclusiveToolNames, - - ModelConfig: callConfig, - ProviderOptions: providerOptions, - ProviderTools: providerTools, - // dynamicToolNames now contains only names that don't - // collide with built-in/MCP tools. 
- DynamicToolNames: dynamicToolNames, - - ContextLimitFallback: modelConfigContextLimit, - - PersistStep: persistStep, - PublishMessagePart: func( - role codersdk.ChatMessageRole, - part codersdk.ChatMessagePart, - ) { - if part.ToolName != "" { - if configID, ok := toolNameToConfigID[part.ToolName]; ok { - part.MCPServerConfigID = uuid.NullUUID{UUID: configID, Valid: true} - } - } - p.publishMessagePart(chat.ID, role, part) - }, - Compaction: compactionOptions, - ReloadMessages: func(reloadCtx context.Context) ([]fantasy.Message, error) { - reloadedMsgs, err := p.db.GetChatMessagesForPromptByChatID(reloadCtx, chat.ID) - if err != nil { - return nil, xerrors.Errorf("reload chat messages: %w", err) - } - compactionHistoryTipMessageID = 0 - if len(reloadedMsgs) > 0 { - compactionHistoryTipMessageID = reloadedMsgs[len(reloadedMsgs)-1].ID - } - if compactionOptions != nil { - compactionOptions.HistoryTipMessageID = compactionHistoryTipMessageID - } - reloadedPrompt, err := chatprompt.ConvertMessagesWithFiles(reloadCtx, reloadedMsgs, p.chatFileResolver(modelConfig.Provider), logger) - if err != nil { - return nil, xerrors.Errorf("convert reloaded messages: %w", err) - } - reloadedPrompt, sanitizeStats := chatsanitize.SanitizeAnthropicProviderToolHistory(model.Provider(), reloadedPrompt) - chatsanitize.LogAnthropicProviderToolSanitization( - reloadCtx, logger, "reload_messages", model.Provider(), model.Model(), sanitizeStats, - ) - // Re-derive instruction and skills from the reloaded - // messages so that any context added during the - // chatloop (e.g. via persistInstructionFiles when - // the agent changes) is picked up after compaction. - // The captured instruction takes priority; fall - // back to persisted DB content otherwise. 
- reloadedInstruction := instruction - if reloadedInstruction == "" { - reloadedInstruction = instructionFromContextFiles(reloadedMsgs) - } - if reloadedInstruction != "" { - instructionInjected = true - } - reloadedSkills := skillsFromParts(reloadedMsgs) - if len(reloadedSkills) == 0 { - reloadedSkills = workspaceSkills - } - reloadedResolvedSkills := resolvedSkillsFor(reloadedSkills) - injectedSkillIndex = chattool.FormatResolvedSkillIndex(reloadedResolvedSkills) - reloadUserPrompt := p.resolveUserPrompt(reloadCtx, chat.OwnerID) - reloadedPrompt = buildSystemPrompt( - reloadedPrompt, - subagentInstruction, - reloadedInstruction, - reloadedResolvedSkills, - reloadUserPrompt, - systemPromptBehaviorContext{ - planMode: currentPlanMode, - chatMode: chat.Mode, - planModeInstructions: planModeInstructions, - isRootChat: isRootChat, - }, - ) - // Re-inject advisor guidance after rebuilding system - // blocks so compaction/reload preserves the same - // system-message ordering as the initial prompt path. - if advisorRuntime != nil { - reloadedPrompt = chatprompt.InsertSystem(reloadedPrompt, chatadvisor.ParentGuidanceBlock) - } - reloadedPrompt = renderPlanPathPrompt(reloadedPrompt, resolvePlanPathBlock(reloadCtx)) - // Snapshot the full reloaded prompt before chain-mode - // filtering so the advisor runs with complete - // assistant/tool context. The nested advisor call - // clears previous_response_id, so provider-side - // history is unavailable. 
- setAdvisorPromptSnapshot(reloadedPrompt) - if chainModeActive { - reloadedPrompt = chatopenai.FilterPromptForChainMode( - reloadedPrompt, - chainInfo, - ) - } - return reloadedPrompt, nil - }, - DisableChainMode: func() { - chainModeActive = false - }, - PrepareTools: func(currentTools []fantasy.AgentTool) []fantasy.AgentTool { - updatedTools, toolsChanged := appendCurrentSkillTools(currentTools) - - // Mid-turn workspace MCP discovery for chats that bind a - // workspace via create_workspace or start_workspace after the - // turn has already started. The top-of-turn discovery path is - // gated on chat.WorkspaceID.Valid; this callback bridges the - // gap so the LLM sees workspace MCP tools on the very next - // step instead of the turn after. - // - // create_workspace and start_workspace prime - // workspaceMCPToolsCache via onChatUpdated after - // waitForAgentReady returns, so the call below is almost - // always a cache hit. The primer's bounded wait means the - // dial fallback here only runs when priming itself failed. - if workspaceMCPDiscovered || isExploreSubagent { - if toolsChanged { - return updatedTools - } - return nil - } - snapshot := workspaceCtx.currentChatSnapshot() - if !snapshot.WorkspaceID.Valid { - if toolsChanged { - return updatedTools - } - return nil - } - discovered := p.discoverWorkspaceMCPTools( - ctx, loopLogger, chat.ID, &workspaceCtx, - ) - if len(discovered) == 0 { - // Leave workspaceMCPDiscovered false so a subsequent - // step retries discovery. PrepareTools fires once per - // LLM step, so retries are unbounded for the rest of - // the turn. Per-step cost is one - // GetWorkspaceAgentsInLatestBuildByWorkspaceID query - // plus one ListMCPTools RPC, both fast against a live - // conn. The primer's 30s budget applies to its own - // loop only. - if toolsChanged { - return updatedTools - } - return nil - } - workspaceMCPDiscovered = true - return append(slices.Clone(updatedTools), discovered...) 
- }, - PrepareMessages: func(msgs []fantasy.Message) []fantasy.Message { - // Skip the snapshot update when chain mode is active; - // the chatloop passes in the chain-filtered prompt - // (system plus trailing user messages) and the advisor - // needs the full pre-chain history captured at the - // initial-prompt and ReloadMessages sites. - if !chainModeActive { - setAdvisorPromptSnapshot(msgs) - } - result := msgs - changed := false - if !instructionInjected && instruction != "" { - instructionInjected = true - result = chatprompt.InsertSystem(result, instruction) - changed = true - } - if skillIndex := chattool.FormatResolvedSkillIndex(resolvedSkillsFor(workspaceSkills)); skillIndex != "" && skillIndex != injectedSkillIndex { - result = removeSkillIndexMessages(result) - result = chatprompt.InsertSystem(result, skillIndex) - injectedSkillIndex = skillIndex - changed = true - } - if !changed { - return nil - } - if !chainModeActive { - setAdvisorPromptSnapshot(result) - } - return result - }, - OnRetry: func( - attempt int, - retryErr error, - classified chatretry.ClassifiedError, - delay time.Duration, - ) { - p.clearProvisionalStreamParts(chat.ID) - logger.Warn(ctx, "retrying LLM stream", - slog.F("attempt", attempt), - slog.F("delay", delay.String()), - slog.F("kind", classified.Kind), - slog.Error(retryErr), - ) - payload := chaterror.StreamRetryPayload(attempt, delay, classified) - p.publishRetry(chat.ID, payload) - }, - - OnInterruptedPersistError: func(err error) { - p.logger.Warn(ctx, "failed to persist interrupted chat step", slog.Error(err)) - }, - }) - if errors.Is(loopErr, chatloop.ErrStopAfterTool) { - loopErr = nil - } - if errors.Is(loopErr, chatloop.ErrDynamicToolCall) { - // The stream event is published in processChat's - // defer after the DB status transitions to - // requires_action, preventing a race where a fast - // client reacts before the status is committed. 
- result.FinalAssistantText = finalAssistantText - result.PendingDynamicToolCalls = pendingDynamicCalls - return result, nil - } - if loopErr != nil { - classified := chaterror.Classify(loopErr).WithProvider(model.Provider()) - return result, chaterror.WithClassification(loopErr, classified) - } - result.FinalAssistantText = finalAssistantText - return result, nil -} - // buildProviderTools creates provider-native tool definitions // (like web search) based on the model configuration. These // tools are executed server-side by the LLM provider. @@ -8032,13 +5399,11 @@ func (p *Server) persistChatContextSummary( } var insertedMessages []database.ChatMessage - txErr := p.db.InTx(func(tx database.Store) error { summaryParams := database.InsertChatMessagesParams{ //nolint:exhaustruct // Fields populated by append[User]ChatMessage. ChatID: chatID, } - // Hidden summary user message (not published to subscribers). summaryUserMsg := newUserChatMessage( summaryAPIKeyID, systemContent, @@ -8049,7 +5414,6 @@ func (p *Server) persistChatContextSummary( summaryUserMsg = summaryUserMsg.withCompressed() appendUserChatMessage(&summaryParams, summaryUserMsg) - // Assistant tool-call message. appendChatMessage(&summaryParams, newChatMessage( database.ChatMessageRoleAssistant, assistantContent, @@ -8058,7 +5422,6 @@ func (p *Server) persistChatContextSummary( chatprompt.CurrentContentVersion, ).withCompressed()) - // Tool result message. appendChatMessage(&summaryParams, newChatMessage( database.ChatMessageRoleTool, toolResult, @@ -8071,22 +5434,14 @@ func (p *Server) persistChatContextSummary( if txErr != nil { return xerrors.Errorf("insert summary messages: %w", txErr) } - // Skip the first message (hidden summary user msg) when - // publishing — only the assistant and tool messages are - // visible to subscribers. 
insertedMessages = allInserted[1:] - return nil }, nil) if txErr != nil { return txErr } - // Publish after transaction commits to avoid notifying - // subscribers about messages that could be rolled back. - for _, msg := range insertedMessages { - p.publishMessage(chatID, msg) - } + _ = insertedMessages return nil } @@ -8522,6 +5877,16 @@ func (p *Server) fetchWorkspaceContext( return &loadedAgent, agentParts, discoveredSkills, workspaceConnOK } +func filterSkillParts(parts []codersdk.ChatMessagePart) []codersdk.ChatMessagePart { + var filtered []codersdk.ChatMessagePart + for _, part := range parts { + if part.Type == codersdk.ChatMessagePartTypeSkill { + filtered = append(filtered, part) + } + } + return filtered +} + // persistInstructionFiles fetches AGENTS.md instruction files and // skills from the workspace agent, persisting both as message // parts. This is called once when a workspace is first attached @@ -8538,10 +5903,6 @@ func (p *Server) persistInstructionFiles( agent, agentParts, discoveredSkills, workspaceConnOK := p.fetchWorkspaceContext( ctx, chat, getWorkspaceAgent, getWorkspaceConn, ) - // Defensive guard: fetchWorkspaceContext returns nil when the - // chat has no valid workspace or the agent lookup fails. It's - // cheaper to guard here than push the precondition up to all - // callers. if agent == nil { return "", nil, nil } @@ -8562,12 +5923,11 @@ func (p *Server) persistInstructionFiles( directory = agent.Directory } + contextAPIKeyID, _ := aibridge.DelegatedAPIKeyIDFromContext(ctx) if !hasContent { if !workspaceConnOK { return "", nil, nil } - // Persist a blank context-file marker (plus any skill-only - // parts) so subsequent turns skip the workspace agent dial. 
if !hasContextFilePart { agentParts = append([]codersdk.ChatMessagePart{{ Type: codersdk.ChatMessagePartTypeContextFile, @@ -8578,7 +5938,6 @@ func (p *Server) persistInstructionFiles( if err != nil { return "", nil, nil } - contextAPIKeyID, _ := aibridge.DelegatedAPIKeyIDFromContext(ctx) msgParams := database.InsertChatMessagesParams{ //nolint:exhaustruct // Fields populated by appendUserChatMessage. ChatID: chat.ID, } @@ -8590,9 +5949,6 @@ func (p *Server) persistInstructionFiles( chatprompt.CurrentContentVersion, )) _, _ = p.db.InsertChatMessages(ctx, msgParams) - // Update the cache column: persist skills if any - // exist, or clear to NULL so stale data from a - // previous agent doesn't linger. skillParts := filterSkillParts(agentParts) p.updateLastInjectedContext(ctx, chat.ID, skillParts) return "", discoveredSkills, nil @@ -8602,7 +5958,6 @@ func (p *Server) persistInstructionFiles( return "", nil, xerrors.Errorf("marshal context-file parts: %w", err) } - contextAPIKeyID, _ := aibridge.DelegatedAPIKeyIDFromContext(ctx) msgParams := database.InsertChatMessagesParams{ //nolint:exhaustruct // Fields populated by appendUserChatMessage. ChatID: chat.ID, } @@ -8616,9 +5971,6 @@ func (p *Server) persistInstructionFiles( if _, err := p.db.InsertChatMessages(ctx, msgParams); err != nil { return "", nil, xerrors.Errorf("persist instruction files: %w", err) } - // Build stripped copies for the cache column so internal - // fields (full file content, OS, directory, skill paths) - // are never persisted or returned to API clients. stripped := make([]codersdk.ChatMessagePart, len(agentParts)) copy(stripped, agentParts) for i := range stripped { @@ -8626,9 +5978,6 @@ func (p *Server) persistInstructionFiles( } p.updateLastInjectedContext(ctx, chat.ID, stripped) - // Return the formatted instruction text and discovered skills - // so the caller can inject them into this turn's prompt (since - // the prompt was built before we persisted). 
return formatSystemInstructions(agent.OperatingSystem, directory, agentParts), discoveredSkills, nil } @@ -8819,281 +6168,6 @@ func formatPlanPathBlock(chatPath, home string) string { return b.String() } -func (p *Server) recoverStaleChats(ctx context.Context) { - staleAfter := p.clock.Now().Add(-p.inFlightChatStaleAfter) - staleChats, err := p.db.GetStaleChats(ctx, staleAfter) - if err != nil { - p.logger.Error(ctx, "failed to get stale chats", slog.Error(err)) - return - } - - recovered := 0 - for _, chat := range staleChats { - p.logger.Info(ctx, "recovering stale chat", - slog.F("chat_id", chat.ID), - slog.F("status", chat.Status)) - - // Use a transaction with FOR UPDATE to avoid a TOCTOU race: - // between GetStaleChats (a bare SELECT) and here, the chat's - // heartbeat may have been refreshed. We re-check freshness - // under the row lock before resetting. - err := p.db.InTx(func(tx database.Store) error { - locked, lockErr := tx.GetChatByIDForUpdate(ctx, chat.ID) - if lockErr != nil { - return xerrors.Errorf("lock chat for recovery: %w", lockErr) - } - - switch locked.Status { - case database.ChatStatusRunning: - // Re-check: only recover if the chat is still stale. - // A valid heartbeat at or after the threshold means - // the chat was refreshed after our snapshot. - if locked.HeartbeatAt.Valid && !locked.HeartbeatAt.Time.Before(staleAfter) { - p.logger.Debug(ctx, "chat heartbeat refreshed since snapshot, skipping recovery", - slog.F("chat_id", chat.ID)) - return nil - } - case database.ChatStatusRequiresAction: - // Re-check: the chat may have been updated after - // our snapshot, similar to the heartbeat check for - // running chats. - if !locked.UpdatedAt.Before(staleAfter) { - p.logger.Debug(ctx, "chat updated since snapshot, skipping recovery", - slog.F("chat_id", chat.ID)) - return nil - } - case database.ChatStatusWaiting: - // Deferred-promote stranding: worker died before its - // post-cancel cleanup ran. Re-check freshness. 
- if !locked.UpdatedAt.Before(staleAfter) { - p.logger.Debug(ctx, "chat updated since snapshot, skipping recovery", - slog.F("chat_id", chat.ID)) - return nil - } - default: - // Status changed since our snapshot; skip. - p.logger.Debug(ctx, "chat status changed since snapshot, skipping recovery", - slog.F("chat_id", chat.ID), - slog.F("status", locked.Status)) - return nil - } - - lastError := pqtype.NullRawMessage{} - if locked.Status == database.ChatStatusRequiresAction { - lastErrorPayload, marshalErr := encodeChatLastErrorPayload( - chaterror.TerminalErrorPayload(chaterror.ClassifiedError{ - Message: "Dynamic tool execution timed out", - Kind: codersdk.ChatErrorKindGeneric, - }), - ) - if marshalErr != nil { - p.logger.Warn(ctx, "failed to marshal stale recovery last error payload", - slog.F("chat_id", chat.ID), - slog.Error(marshalErr), - ) - } else { - lastError = lastErrorPayload - } - } - - recoverStatus := database.ChatStatusPending - if locked.Status == database.ChatStatusRequiresAction { - // Timed-out requires_action chats have dangling - // tool calls with no matching results. Setting - // them back to pending would replay incomplete - // tool calls to the LLM, so mark them as errors. - recoverStatus = database.ChatStatusError - } - - // Insert synthetic error tool-result messages - // so the LLM history remains valid if the user - // retries the chat later. - if locked.Status == database.ChatStatusRequiresAction { - if _, synthErr := insertSyntheticToolResultsTx(ctx, tx, locked, "Dynamic tool execution timed out"); synthErr != nil { - p.logger.Warn(ctx, "failed to insert synthetic tool results during stale recovery", - slog.F("chat_id", chat.ID), - slog.Error(synthErr), - ) - // Continue with error status even if - // synthetic results fail to insert. - } - } - - if locked.Status == database.ChatStatusWaiting { - // Close pending dynamic tool calls; otherwise the - // promoted user message would feed the LLM a turn it - // rejects. 
Propagate errors so the next recovery - // tick retries instead of promoting incomplete - // history. - if _, synthErr := insertSyntheticToolResultsTx(ctx, tx, locked, "Tool execution interrupted by queued message promotion"); synthErr != nil { - return xerrors.Errorf("insert synthetic tool results during stale recovery: %w", synthErr) - } - promoted, _, _, promoteErr := p.tryAutoPromoteQueuedMessage(ctx, tx, locked) - if promoteErr != nil { - return xerrors.Errorf("auto-promote during stale recovery: %w", promoteErr) - } - if promoted == nil { - // Empty queue means nothing to recover. - return nil - } - } - - // Reset so any replica can pick it up (pending) or - // the client sees the failure (error). - _, updateErr := tx.UpdateChatStatus(ctx, database.UpdateChatStatusParams{ - ID: chat.ID, - Status: recoverStatus, - WorkerID: uuid.NullUUID{}, - StartedAt: sql.NullTime{}, - HeartbeatAt: sql.NullTime{}, - LastError: lastError, - }) - if updateErr != nil { - return updateErr - } - recovered++ - return nil - }, nil) - if err != nil { - p.logger.Error(ctx, "failed to recover stale chat", - slog.F("chat_id", chat.ID), slog.Error(err)) - } - } - - if recovered > 0 { - p.logger.Info(ctx, "recovered stale chats", slog.F("count", recovered)) - } -} - -// insertSyntheticToolResultsTx inserts IsError tool-result messages -// for unresolved dynamic tool calls in the last assistant message, -// skipping calls already handled (e.g. by chatloop dispatching a -// name-colliding dynamic tool as a built-in). It operates on the -// provided store, which may be a transaction handle. 
-func insertSyntheticToolResultsTx( - ctx context.Context, - store database.Store, - chat database.Chat, - reason string, -) ([]database.ChatMessage, error) { - dynamicToolNames, err := parseDynamicToolNames(chat.DynamicTools) - if err != nil { - return nil, xerrors.Errorf("parse dynamic tools: %w", err) - } - if len(dynamicToolNames) == 0 { - return nil, nil - } - - // No assistant means nothing to close: a deferred promote can - // race a worker that fails before any persist, and the cleanup - // TX must still advance. - lastAssistant, err := store.GetLastChatMessageByRole(ctx, database.GetLastChatMessageByRoleParams{ - ChatID: chat.ID, - Role: database.ChatMessageRoleAssistant, - }) - if errors.Is(err, sql.ErrNoRows) { - return nil, nil - } - if err != nil { - return nil, xerrors.Errorf("get last assistant message: %w", err) - } - - parts, err := chatprompt.ParseContent(lastAssistant) - if err != nil { - return nil, xerrors.Errorf("parse assistant message: %w", err) - } - - // Mirrors SubmitToolResults. - afterMsgs, err := store.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{ - ChatID: chat.ID, - AfterID: lastAssistant.ID, - }) - if err != nil { - return nil, xerrors.Errorf("get messages after assistant: %w", err) - } - handledCallIDs := make(map[string]bool) - for _, msg := range afterMsgs { - if msg.Role != database.ChatMessageRoleTool { - continue - } - msgParts, err := chatprompt.ParseContent(msg) - if err != nil { - continue - } - for _, mp := range msgParts { - if mp.Type == codersdk.ChatMessagePartTypeToolResult { - handledCallIDs[mp.ToolCallID] = true - } - } - } - - // Collect dynamic tool calls that need synthetic results. 
- var resultContents []pqtype.NullRawMessage - for _, part := range parts { - if part.Type != codersdk.ChatMessagePartTypeToolCall || !dynamicToolNames[part.ToolName] { - continue - } - if handledCallIDs[part.ToolCallID] { - continue - } - resultPart := codersdk.ChatMessagePart{ - Type: codersdk.ChatMessagePartTypeToolResult, - ToolCallID: part.ToolCallID, - ToolName: part.ToolName, - Result: json.RawMessage(fmt.Sprintf("%q", reason)), - IsError: true, - } - marshaled, marshalErr := chatprompt.MarshalParts([]codersdk.ChatMessagePart{resultPart}) - if marshalErr != nil { - return nil, xerrors.Errorf("marshal synthetic tool result: %w", marshalErr) - } - resultContents = append(resultContents, marshaled) - } - - if len(resultContents) == 0 { - return nil, nil - } - - // Insert tool-result messages using the same pattern as - // SubmitToolResults. - n := len(resultContents) - params := database.InsertChatMessagesParams{ - ChatID: chat.ID, - CreatedBy: make([]uuid.UUID, n), - APIKeyID: make([]string, n), - ModelConfigID: make([]uuid.UUID, n), - Role: make([]database.ChatMessageRole, n), - Content: make([]string, n), - ContentVersion: make([]int16, n), - Visibility: make([]database.ChatMessageVisibility, n), - InputTokens: make([]int64, n), - OutputTokens: make([]int64, n), - TotalTokens: make([]int64, n), - ReasoningTokens: make([]int64, n), - CacheCreationTokens: make([]int64, n), - CacheReadTokens: make([]int64, n), - ContextLimit: make([]int64, n), - Compressed: make([]bool, n), - TotalCostMicros: make([]int64, n), - RuntimeMs: make([]int64, n), - ProviderResponseID: make([]string, n), - } - for i, rc := range resultContents { - params.CreatedBy[i] = uuid.Nil - params.ModelConfigID[i] = chat.LastModelConfigID - params.Role[i] = database.ChatMessageRoleTool - params.Content[i] = string(rc.RawMessage) - params.ContentVersion[i] = chatprompt.CurrentContentVersion - params.Visibility[i] = database.ChatMessageVisibilityBoth - } - inserted, err := 
store.InsertChatMessages(ctx, params) - if err != nil { - return nil, xerrors.Errorf("insert synthetic tool results: %w", err) - } - - return inserted, nil -} - // parseDynamicToolNames unmarshals the dynamic tools JSON column // and returns a map of tool names. This centralizes the repeated // pattern of deserializing DynamicTools into a name set. @@ -9185,7 +6259,7 @@ func (p *Server) finalizeSuccessfulTurnStatusLabelWithAfterFunc( slog.F("label_length", len(statusLabel)), ) - p.updateLastTurnSummary(finalizeCtx, chat, chat.UpdatedAt, statusLabel, logger) + p.updateLastTurnSummary(finalizeCtx, chat, chat.HistoryVersion, statusLabel, logger) afterFinalize(finalizeCtx, statusLabel) }) @@ -9274,7 +6348,7 @@ func (p *Server) setLastTurnSummaryAsync( // still counted in p.inflight. Do not take inflightMu here because // drainInflight holds it while waiting. p.inflight.Go(func() { - p.updateLastTurnSummary(context.WithoutCancel(ctx), chat, chat.UpdatedAt, summary, logger) + p.updateLastTurnSummary(context.WithoutCancel(ctx), chat, chat.HistoryVersion, summary, logger) }) } @@ -9283,14 +6357,11 @@ func (p *Server) clearLastTurnSummaryAsync( chat database.Chat, logger slog.Logger, ) { - if !chat.LastTurnSummary.Valid { - return - } // This helper runs during processChat cleanup, while processChat is // still counted in p.inflight. Do not take inflightMu here because // drainInflight holds it while waiting. 
p.inflight.Go(func() { - p.updateLastTurnSummary(context.WithoutCancel(ctx), chat, chat.UpdatedAt, "", logger) + p.updateLastTurnSummary(context.WithoutCancel(ctx), chat, chat.HistoryVersion, "", logger) }) } @@ -9300,7 +6371,7 @@ func (p *Server) clearLastTurnSummaryAsync( func (p *Server) updateLastTurnSummary( ctx context.Context, chat database.Chat, - expectedUpdatedAt time.Time, + expectedHistoryVersion int64, summary string, logger slog.Logger, ) { @@ -9313,8 +6384,9 @@ func (p *Server) updateLastTurnSummary( defer cancel() affected, err := p.db.UpdateChatLastTurnSummary(updateCtx, database.UpdateChatLastTurnSummaryParams{ - ID: chat.ID, - LastTurnSummary: lastTurnSummary, + ID: chat.ID, + ExpectedHistoryVersion: expectedHistoryVersion, + LastTurnSummary: lastTurnSummary, }) if err != nil { logger.Warn(updateCtx, "failed to update chat turn summary", @@ -9328,13 +6400,13 @@ func (p *Server) updateLastTurnSummary( logger.Info(updateCtx, "skipped stale chat turn summary update with non-empty summary", slog.F("chat_id", chat.ID), slog.F("summary_length", len(summary)), - slog.F("expected_updated_at", expectedUpdatedAt), + slog.F("expected_history_version", expectedHistoryVersion), ) return } logger.Debug(updateCtx, "skipped stale chat turn summary update", slog.F("chat_id", chat.ID), - slog.F("expected_updated_at", expectedUpdatedAt), + slog.F("expected_history_version", expectedHistoryVersion), ) return } @@ -9342,10 +6414,6 @@ func (p *Server) updateLastTurnSummary( updatedChat := chat updatedChat.LastTurnSummary = lastTurnSummary p.publishChatPubsubEvent(updatedChat, codersdk.ChatWatchEventKindSummaryChange, nil) - - // AcquireChats uses SKIP LOCKED; re-wake so a wake racing this - // UPDATE's row lock does not strand a freshly-pending chat. 
- p.signalWake() } func (p *Server) webpushConfigured() bool { @@ -9380,6 +6448,14 @@ func (p *Server) Close() error { p.configCacheUnsubscribe = nil unsub() } + if p.chatWorker != nil { + if err := p.chatWorker.Close(); err != nil { + p.logger.Warn(context.Background(), "failed to close chat worker", slog.Error(err)) + } + } + if p.messagePartBuffer != nil { + p.messagePartBuffer.Close() + } p.cancel() p.wg.Wait() p.drainInflight() diff --git a/coderd/x/chatd/chatd_chainmode_test.go b/coderd/x/chatd/chatd_chainmode_test.go new file mode 100644 index 0000000000..e1cf1db417 --- /dev/null +++ b/coderd/x/chatd/chatd_chainmode_test.go @@ -0,0 +1,565 @@ +package chatd_test + +import ( + "context" + "encoding/json" + "net/http" + "strings" + "sync" + "sync/atomic" + "testing" + + "charm.land/fantasy" + fantasyanthropic "charm.land/fantasy/providers/anthropic" + "github.com/google/uuid" + "github.com/stretchr/testify/require" + + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbgen" + "github.com/coder/coder/v2/coderd/database/dbtestutil" + "github.com/coder/coder/v2/coderd/x/chatd" + "github.com/coder/coder/v2/coderd/x/chatd/chatprompt" + "github.com/coder/coder/v2/coderd/x/chatd/chattest" + "github.com/coder/coder/v2/codersdk" + "github.com/coder/coder/v2/testutil" +) + +func TestActiveServer_ChainBrokenRecovery(t *testing.T) { + t.Parallel() + + const ( + previousResponseID = "resp_poisoned" + recoveredAnswer = "recovered answer" + ) + ctx := testutil.Context(t, testutil.WaitLong) + db, ps := dbtestutil.NewDB(t) + requests := newOpenAIRequestRecorder() + openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse { + requests.record(req) + if req.PreviousResponseID != nil { + return chattest.OpenAIErrorResponse(http.StatusNotFound, "invalid_request_error", chainBrokenProviderErrorMessage) + } + return chattest.OpenAIStreamingResponse(chattest.OpenAITextChunks(recoveredAnswer)...) 
+ }) + user, org, model := seedChatDependenciesWithProvider(t, db, "openai", openAIURL) + model = updateModelForChainMode(t, db, model) + + server := newActiveTestServer(t, db, ps) + chat := createChatThroughServer(ctx, t, server, org.ID, user.ID, model.ID, "first user") + waitForChatStatus(ctx, t, db, chat.ID, database.ChatStatusWaiting) + insertProviderResponseID(ctx, t, db, chat.ID, "first assistant", model.ID, previousResponseID) + _, err := server.SendMessage(ctx, chatd.SendMessageOptions{ + ChatID: chat.ID, + CreatedBy: user.ID, + ModelConfigID: model.ID, + Content: []codersdk.ChatMessagePart{codersdk.ChatMessageText("follow up")}, + BusyBehavior: chatd.SendMessageBusyBehaviorQueue, + }) + require.NoError(t, err) + + waitForChatStatus(ctx, t, db, chat.ID, database.ChatStatusWaiting) + + got := requests.all() + require.GreaterOrEqual(t, len(got), 3) + generationRequests := filterStreamingRequests(got) + require.Len(t, generationRequests, 3) + require.Nil(t, generationRequests[0].PreviousResponseID) + require.Equal(t, previousResponseID, requirePreviousResponseID(t, generationRequests[1])) + require.Nil(t, generationRequests[2].PreviousResponseID) + requireRawPromptContains(t, generationRequests[2], "first user") + requireRawPromptContains(t, generationRequests[2], "first assistant") + requireRawPromptContains(t, generationRequests[2], "follow up") + + messages := chatMessages(ctx, t, db, chat.ID) + requireTextPart(t, messages[len(messages)-1], recoveredAnswer) +} + +func TestActiveServer_ChainBrokenRecoveryAppliesProviderPromptPrep(t *testing.T) { + t.Parallel() + + const previousResponseID = "resp_anthropic_chain" + ctx := testutil.Context(t, testutil.WaitLong) + db, ps := dbtestutil.NewDB(t) + requests := newAnthropicRequestRecorder() + var streamCalls atomic.Int32 + anthropicURL := chattest.NewAnthropic(t, func(req *chattest.AnthropicRequest) chattest.AnthropicResponse { + requests.record(req) + if streamCalls.Add(1) == 2 { + return 
chattest.AnthropicErrorResponse(http.StatusInternalServerError, "server_error", chainBrokenProviderErrorMessage) + } + return chattest.AnthropicStreamingResponse(chattest.AnthropicTextChunks("anthropic answer")...) + }) + user, org, model := seedAnthropicChatDependencies(t, db, anthropicURL) + model = updateModelForChainMode(t, db, model) + + server := newActiveTestServer(t, db, ps) + chat := createChatThroughServer(ctx, t, server, org.ID, user.ID, model.ID, "hello") + waitForChatStatus(ctx, t, db, chat.ID, database.ChatStatusWaiting) + insertSystemTextMessage(ctx, t, db, chat.ID, "sys-1", model.ID) + insertProviderResponseID(ctx, t, db, chat.ID, "hi", model.ID, previousResponseID) + _, err := server.SendMessage(ctx, chatd.SendMessageOptions{ + ChatID: chat.ID, + CreatedBy: user.ID, + ModelConfigID: model.ID, + Content: []codersdk.ChatMessagePart{codersdk.ChatMessageText("follow up")}, + BusyBehavior: chatd.SendMessageBusyBehaviorQueue, + }) + require.NoError(t, err) + + waitForChatStatus(ctx, t, db, chat.ID, database.ChatStatusWaiting) + + generationRequests := filterAnthropicStreamingRequests(requests.all()) + require.Len(t, generationRequests, 2) + recovered := generationRequests[1] + require.Len(t, recovered.Messages, 4) + require.True(t, anthropicSystemHasEphemeralCacheControl(t, recovered)) + require.False(t, anthropicMessageHasEphemeralCacheControl(t, recovered.Messages[0])) + require.False(t, anthropicMessageHasEphemeralCacheControl(t, recovered.Messages[1])) + require.True(t, anthropicMessageHasEphemeralCacheControl(t, recovered.Messages[2])) + require.True(t, anthropicMessageHasEphemeralCacheControl(t, recovered.Messages[3])) +} + +func TestActiveServer_NonChainBrokenRetryPreservesChainMode(t *testing.T) { + t.Parallel() + + const previousResponseID = "resp_still_valid" + ctx := testutil.Context(t, testutil.WaitLong) + db, ps := dbtestutil.NewDB(t) + requests := newOpenAIRequestRecorder() + var streamCalls atomic.Int32 + openAIURL := chattest.NewOpenAI(t, 
func(req *chattest.OpenAIRequest) chattest.OpenAIResponse { + requests.record(req) + if req.Stream && streamCalls.Add(1) == 2 { + return chattest.OpenAIServerErrorResponse() + } + return chattest.OpenAIStreamingResponse(chattest.OpenAITextChunks("answer")...) + }) + user, org, model := seedChatDependenciesWithProvider(t, db, "openai", openAIURL) + model = updateModelForChainMode(t, db, model) + + server := newActiveTestServer(t, db, ps) + chat := createChatThroughServer(ctx, t, server, org.ID, user.ID, model.ID, "first user") + waitForChatStatus(ctx, t, db, chat.ID, database.ChatStatusWaiting) + insertProviderResponseID(ctx, t, db, chat.ID, "first assistant", model.ID, previousResponseID) + _, err := server.SendMessage(ctx, chatd.SendMessageOptions{ + ChatID: chat.ID, + CreatedBy: user.ID, + ModelConfigID: model.ID, + Content: []codersdk.ChatMessagePart{codersdk.ChatMessageText("follow up")}, + BusyBehavior: chatd.SendMessageBusyBehaviorQueue, + }) + require.NoError(t, err) + + waitForChatStatus(ctx, t, db, chat.ID, database.ChatStatusWaiting) + + generationRequests := filterStreamingRequests(requests.all()) + require.Len(t, generationRequests, 3) + require.Equal(t, previousResponseID, requirePreviousResponseID(t, generationRequests[1])) + require.Equal(t, previousResponseID, requirePreviousResponseID(t, generationRequests[2])) + requireRawPromptNotContains(t, generationRequests[2], "first user") + requireRawPromptContains(t, generationRequests[2], "follow up") +} + +func TestActiveServer_ChainBrokenRecoveryPersistsAcrossGenerationActions(t *testing.T) { + t.Parallel() + + const previousResponseID = "resp_tool_poisoned" + ctx := testutil.Context(t, testutil.WaitLong) + db, ps := dbtestutil.NewDB(t) + requests := newOpenAIRequestRecorder() + var streamCalls atomic.Int32 + openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse { + requests.record(req) + if !req.Stream { + return 
chattest.OpenAINonStreamingResponse(`{"title":"test"}`) + } + switch streamCalls.Add(1) { + case 1: + return chattest.OpenAIStreamingResponse(chattest.OpenAITextChunks("first answer")...) + case 2: + return chattest.OpenAIErrorResponse(http.StatusNotFound, "invalid_request_error", chainBrokenProviderErrorMessage) + case 3: + return chattest.OpenAIStreamingResponse(chattest.OpenAIToolCallChunk("read_skill", `{"name":"x"}`)) + default: + return chattest.OpenAIStreamingResponse(chattest.OpenAITextChunks("final answer")...) + } + }) + user, org, model := seedChatDependenciesWithProvider(t, db, "openai", openAIURL) + model = updateModelForChainMode(t, db, model) + + server := newActiveTestServer(t, db, ps) + chat := createChatThroughServer(ctx, t, server, org.ID, user.ID, model.ID, "first user") + waitForChatStatus(ctx, t, db, chat.ID, database.ChatStatusWaiting) + insertProviderResponseID(ctx, t, db, chat.ID, "first assistant", model.ID, previousResponseID) + _, err := server.SendMessage(ctx, chatd.SendMessageOptions{ + ChatID: chat.ID, + CreatedBy: user.ID, + ModelConfigID: model.ID, + Content: []codersdk.ChatMessagePart{codersdk.ChatMessageText("follow up")}, + BusyBehavior: chatd.SendMessageBusyBehaviorQueue, + }) + require.NoError(t, err) + + waitForChatStatus(ctx, t, db, chat.ID, database.ChatStatusWaiting) + + generationRequests := filterStreamingRequests(requests.all()) + require.Len(t, generationRequests, 4) + require.Equal(t, previousResponseID, requirePreviousResponseID(t, generationRequests[1])) + require.Nil(t, generationRequests[2].PreviousResponseID) + require.Nil(t, generationRequests[3].PreviousResponseID) +} + +func TestActiveServer_ChainBrokenWithoutChainModeIsSafe(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + db, ps := dbtestutil.NewDB(t) + requests := newOpenAIRequestRecorder() + var streamCalls atomic.Int32 + openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse { + 
requests.record(req) + if req.Stream && streamCalls.Add(1) == 1 { + return chattest.OpenAIErrorResponse(http.StatusNotFound, "invalid_request_error", chainBrokenProviderErrorMessage) + } + return chattest.OpenAIStreamingResponse(chattest.OpenAITextChunks("recovered")...) + }) + user, org, model := seedChatDependenciesWithProvider(t, db, "openai", openAIURL) + model = updateModelForChainMode(t, db, model) + + server := newActiveTestServer(t, db, ps) + chat := createChatThroughServer(ctx, t, server, org.ID, user.ID, model.ID, "only user") + waitForChatStatus(ctx, t, db, chat.ID, database.ChatStatusWaiting) + + generationRequests := filterStreamingRequests(requests.all()) + require.Len(t, generationRequests, 2) + require.Nil(t, generationRequests[0].PreviousResponseID) + require.Nil(t, generationRequests[1].PreviousResponseID) +} + +func TestActiveServer_ChainBrokenRecoveryDropsOrphanProviderToolCall(t *testing.T) { + t.Parallel() + + const previousResponseID = "resp_orphan_provider_tool" + ctx := testutil.Context(t, testutil.WaitLong) + db, ps := dbtestutil.NewDB(t) + requests := newAnthropicRequestRecorder() + var streamCalls atomic.Int32 + anthropicURL := chattest.NewAnthropic(t, func(req *chattest.AnthropicRequest) chattest.AnthropicResponse { + requests.record(req) + if streamCalls.Add(1) == 2 { + return chattest.AnthropicErrorResponse(http.StatusInternalServerError, "server_error", chainBrokenProviderErrorMessage) + } + return chattest.AnthropicStreamingResponse(chattest.AnthropicTextChunks("cleaned")...) 
+ }) + user, org, model := seedAnthropicChatDependencies(t, db, anthropicURL) + model = updateModelForChainMode(t, db, model) + + server := newActiveTestServer(t, db, ps) + chat := createChatThroughServer(ctx, t, server, org.ID, user.ID, model.ID, "first user") + waitForChatStatus(ctx, t, db, chat.ID, database.ChatStatusWaiting) + insertProviderResponseID(ctx, t, db, chat.ID, "first assistant", model.ID, previousResponseID) + insertOrphanProviderToolCall(ctx, t, db, chat.ID, model.ID) + _, err := server.SendMessage(ctx, chatd.SendMessageOptions{ + ChatID: chat.ID, + CreatedBy: user.ID, + ModelConfigID: model.ID, + Content: []codersdk.ChatMessagePart{codersdk.ChatMessageText("continue")}, + BusyBehavior: chatd.SendMessageBusyBehaviorQueue, + }) + require.NoError(t, err) + + waitForChatStatus(ctx, t, db, chat.ID, database.ChatStatusWaiting) + + generationRequests := filterAnthropicStreamingRequests(requests.all()) + require.Len(t, generationRequests, 2) + recoveredBody := anthropicRequestBody(t, generationRequests[1]) + require.NotContains(t, recoveredBody, "web_search") + require.Contains(t, recoveredBody, "partial") + require.Contains(t, recoveredBody, "continue") + requireAnthropicRequestRedactedReasoning(t, generationRequests[1], "redacted-payload") +} + +type anthropicRequestRecorder struct { + mu sync.Mutex + requests []chattest.AnthropicRequest +} + +func newAnthropicRequestRecorder() *anthropicRequestRecorder { + return &anthropicRequestRecorder{} +} + +func (r *anthropicRequestRecorder) record(req *chattest.AnthropicRequest) { + r.mu.Lock() + defer r.mu.Unlock() + r.requests = append(r.requests, *req) +} + +func (r *anthropicRequestRecorder) all() []chattest.AnthropicRequest { + r.mu.Lock() + defer r.mu.Unlock() + return append([]chattest.AnthropicRequest(nil), r.requests...) 
+} + +func filterAnthropicStreamingRequests(requests []chattest.AnthropicRequest) []chattest.AnthropicRequest { + out := make([]chattest.AnthropicRequest, 0, len(requests)) + for _, req := range requests { + if req.Stream { + out = append(out, req) + } + } + return out +} + +func seedAnthropicChatDependencies(t *testing.T, db database.Store, baseURL string) (database.User, database.Organization, database.ChatModelConfig) { + t.Helper() + user := dbgen.User(t, db, database.User{}) + org := dbgen.Organization(t, db, database.Organization{}) + dbgen.OrganizationMember(t, db, database.OrganizationMember{UserID: user.ID, OrganizationID: org.ID}) + provider := dbgen.AIProvider(t, db, database.AIProvider{Type: database.AiProviderTypeAnthropic}, func(params *database.InsertAIProviderParams) { + params.BaseUrl = baseURL + }) + dbgen.AIProviderKey(t, db, database.AIProviderKey{ProviderID: provider.ID}) + model := dbgen.ChatModelConfig(t, db, database.ChatModelConfig{ + Provider: "anthropic", + Model: "claude-sonnet-4-20250514", + IsDefault: true, + AIProviderID: uuid.NullUUID{UUID: provider.ID, Valid: true}, + }) + return user, org, model +} + +func anthropicSystemHasEphemeralCacheControl(t *testing.T, req chattest.AnthropicRequest) bool { + t.Helper() + return strings.Contains(string(req.System), `"cache_control":{"type":"ephemeral"}`) +} + +func anthropicMessageHasEphemeralCacheControl(t *testing.T, message chattest.AnthropicRequestMessage) bool { + t.Helper() + return strings.Contains(string(message.Content), `"cache_control":{"type":"ephemeral"}`) +} + +func anthropicRequestBody(t *testing.T, req chattest.AnthropicRequest) string { + t.Helper() + data, err := json.Marshal(req.Messages) + require.NoError(t, err) + return string(data) +} + +func insertSystemTextMessage( + ctx context.Context, + t *testing.T, + db database.Store, + chatID uuid.UUID, + text string, + modelID uuid.UUID, +) { + t.Helper() + content, err := 
chatprompt.MarshalParts([]codersdk.ChatMessagePart{codersdk.ChatMessageText(text)}) + require.NoError(t, err) + params := chatd.BuildSingleChatMessageInsertParams( + chatID, + database.ChatMessageRoleSystem, + content, + database.ChatMessageVisibilityBoth, + modelID, + chatprompt.CurrentContentVersion, + uuid.Nil, + ) + _, err = db.InsertChatMessages(ctx, params) + require.NoError(t, err) +} + +func requireAnthropicRequestRedactedReasoning(t *testing.T, req chattest.AnthropicRequest, redactedData string) { + t.Helper() + body := anthropicRequestBody(t, req) + require.Contains(t, body, "redacted-payload") + require.Contains(t, body, redactedData) +} + +func insertOrphanProviderToolCall(ctx context.Context, t *testing.T, db database.Store, chatID uuid.UUID, modelID uuid.UUID) { + t.Helper() + reasoningMetadata, err := json.Marshal(fantasy.ProviderMetadata{ + fantasyanthropic.Name: &fantasyanthropic.ReasoningOptionMetadata{RedactedData: "redacted-payload"}, + }) + require.NoError(t, err) + parts := []codersdk.ChatMessagePart{ + { + Type: codersdk.ChatMessagePartTypeReasoning, + ProviderMetadata: reasoningMetadata, + }, + { + Type: codersdk.ChatMessagePartTypeToolCall, + ToolCallID: "ws-orphan", + ToolName: "web_search", + Args: json.RawMessage(`{"query":"coder"}`), + ProviderExecuted: true, + }, + codersdk.ChatMessageText("partial"), + } + content, err := chatprompt.MarshalParts(parts) + require.NoError(t, err) + params := chatd.BuildSingleChatMessageInsertParams( + chatID, + database.ChatMessageRoleAssistant, + content, + database.ChatMessageVisibilityBoth, + modelID, + chatprompt.CurrentContentVersion, + uuid.Nil, + ) + _, err = db.InsertChatMessages(ctx, params) + require.NoError(t, err) +} + +const chainBrokenProviderErrorMessage = "Previous response with id 'resp_abc' not found." 
+ +type openAIRequestRecorder struct { + mu sync.Mutex + requests []chattest.OpenAIRequest +} + +func newOpenAIRequestRecorder() *openAIRequestRecorder { + return &openAIRequestRecorder{} +} + +func (r *openAIRequestRecorder) record(req *chattest.OpenAIRequest) { + r.mu.Lock() + defer r.mu.Unlock() + r.requests = append(r.requests, *req) +} + +func (r *openAIRequestRecorder) all() []chattest.OpenAIRequest { + r.mu.Lock() + defer r.mu.Unlock() + return append([]chattest.OpenAIRequest(nil), r.requests...) +} + +func updateModelForChainMode(t *testing.T, db database.Store, model database.ChatModelConfig) database.ChatModelConfig { + t.Helper() + store := true + options, err := json.Marshal(codersdk.ChatModelCallConfig{ + ProviderOptions: &codersdk.ChatModelProviderOptions{ + OpenAI: &codersdk.ChatModelOpenAIProviderOptions{Store: &store}, + }, + }) + require.NoError(t, err) + updated, err := db.UpdateChatModelConfig(context.Background(), database.UpdateChatModelConfigParams{ + ID: model.ID, + DisplayName: model.DisplayName, + Model: model.Model, + Provider: model.Provider, + Enabled: model.Enabled, + ContextLimit: model.ContextLimit, + CompressionThreshold: model.CompressionThreshold, + Options: options, + AIProviderID: model.AIProviderID, + }) + require.NoError(t, err) + return updated +} + +func createChatThroughServer( + ctx context.Context, + t *testing.T, + server *chatd.Server, + orgID uuid.UUID, + userID uuid.UUID, + modelID uuid.UUID, + text string, +) database.Chat { + t.Helper() + chat, err := server.CreateChat(ctx, chatd.CreateOptions{ + OrganizationID: orgID, + OwnerID: userID, + Title: "chain mode test", + InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText(text)}, + ModelConfigID: modelID, + }) + require.NoError(t, err) + return chat +} + +func waitForChatStatus(ctx context.Context, t *testing.T, db database.Store, chatID uuid.UUID, status database.ChatStatus) database.Chat { + t.Helper() + var chat database.Chat + 
testutil.Eventually(ctx, t, func(ctx context.Context) bool { + latest, err := db.GetChatByID(ctx, chatID) + if err != nil { + return false + } + chat = latest + return latest.Status == status && !latest.WorkerID.Valid && !latest.RunnerID.Valid + }, testutil.IntervalFast) + return chat +} + +func insertProviderResponseID( + ctx context.Context, + t *testing.T, + db database.Store, + chatID uuid.UUID, + text string, + modelID uuid.UUID, + providerResponseID string, +) { + t.Helper() + content, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{codersdk.ChatMessageText(text)}) + require.NoError(t, err) + params := chatd.BuildSingleChatMessageInsertParams( + chatID, + database.ChatMessageRoleAssistant, + content, + database.ChatMessageVisibilityBoth, + modelID, + chatprompt.CurrentContentVersion, + uuid.Nil, + ) + params.ProviderResponseID[0] = providerResponseID + _, err = db.InsertChatMessages(ctx, params) + require.NoError(t, err) +} + +func chatMessages(ctx context.Context, t *testing.T, db database.Store, chatID uuid.UUID) []database.ChatMessage { + t.Helper() + messages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{ChatID: chatID}) + require.NoError(t, err) + return messages +} + +func filterStreamingRequests(requests []chattest.OpenAIRequest) []chattest.OpenAIRequest { + out := make([]chattest.OpenAIRequest, 0, len(requests)) + for _, req := range requests { + if req.Stream { + out = append(out, req) + } + } + return out +} + +func requirePreviousResponseID(t *testing.T, req chattest.OpenAIRequest) string { + t.Helper() + require.NotNil(t, req.PreviousResponseID) + return *req.PreviousResponseID +} + +func requireRawPromptContains(t *testing.T, req chattest.OpenAIRequest, text string) { + t.Helper() + require.Contains(t, string(req.RawBody), text) +} + +func requireRawPromptNotContains(t *testing.T, req chattest.OpenAIRequest, text string) { + t.Helper() + require.NotContains(t, string(req.RawBody), text) +} + +func 
requireTextPart(t *testing.T, msg database.ChatMessage, text string) { + t.Helper() + parts, err := chatprompt.ParseContent(msg) + require.NoError(t, err) + for _, part := range parts { + if part.Type == codersdk.ChatMessagePartTypeText && part.Text == text { + return + } + } + t.Fatalf("missing text part %q in message %d", text, msg.ID) +} diff --git a/coderd/x/chatd/chatd_internal_test.go b/coderd/x/chatd/chatd_internal_test.go index 3676985fd4..86ccc09a9f 100644 --- a/coderd/x/chatd/chatd_internal_test.go +++ b/coderd/x/chatd/chatd_internal_test.go @@ -28,7 +28,6 @@ import ( dbpubsub "github.com/coder/coder/v2/coderd/database/pubsub" coderdpubsub "github.com/coder/coder/v2/coderd/pubsub" "github.com/coder/coder/v2/coderd/rbac" - "github.com/coder/coder/v2/coderd/x/chatd/chaterror" "github.com/coder/coder/v2/coderd/x/chatd/chatloop" openaicomputeruse "github.com/coder/coder/v2/coderd/x/chatd/chatopenai/computeruse" "github.com/coder/coder/v2/coderd/x/chatd/chatprompt" @@ -250,6 +249,10 @@ func TestAppendComputerUseProviderTool(t *testing.T) { fantasy.NewImageResponse([]byte("png"), "image/png"), ) require.NotNil(t, metadata) + + errorResponse := fantasy.NewTextErrorResponse("failed") + require.Nil(t, providerTools[0].ResultProviderMetadata(errorResponse)) + require.Nil(t, providerTools[0].ResultProviderMetadata(fantasy.NewTextResponse("not media"))) } func TestAppendComputerUseProviderTool_Gates(t *testing.T) { @@ -2122,567 +2125,6 @@ func TestTurnWorkspaceContext_EnsureWorkspaceAgentIgnoresCachedAgentForDifferent require.Equal(t, updatedChat, currentChat) } -func TestSubscribeDedupesLocallyDeliveredMessageOnNotifyCatchup(t *testing.T) { - t.Parallel() - - ctx, cancelCtx := context.WithCancel(context.Background()) - defer cancelCtx() - - ctrl := gomock.NewController(t) - db := dbmock.NewMockStore(ctrl) - - chatID := uuid.New() - chat := database.Chat{ID: chatID, Status: database.ChatStatusPending} - initialMessage := database.ChatMessage{ - ID: 1, - ChatID: 
chatID, - Role: database.ChatMessageRoleUser, - } - localMessage := database.ChatMessage{ - ID: 2, - ChatID: chatID, - Role: database.ChatMessageRoleAssistant, - } - gomock.InOrder( - db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil), - db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil), - db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{ - ChatID: chatID, - AfterID: 0, - }).Return([]database.ChatMessage{initialMessage}, nil), - db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil), - // DB catchup runs unconditionally on every notify; the delivered - // set dedupes against locally-delivered messages. - db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{ - ChatID: chatID, - AfterID: 1, - }).Return(nil, nil), - ) - - server := newSubscribeTestServer(t, db) - _, events, cancel, ok := server.Subscribe(ctx, chatID, nil, 0) - require.True(t, ok) - defer cancel() - - server.publishMessage(chatID, localMessage) - - event := requireStreamMessageEvent(t, events) - require.Equal(t, int64(2), event.Message.ID) - requireNoStreamEvent(t, events, 200*time.Millisecond) -} - -func TestSubscribeUsesDurableCacheWhenLocalMessageWasNotDelivered(t *testing.T) { - t.Parallel() - - ctx, cancelCtx := context.WithCancel(context.Background()) - defer cancelCtx() - - ctrl := gomock.NewController(t) - db := dbmock.NewMockStore(ctrl) - - chatID := uuid.New() - chat := database.Chat{ID: chatID, Status: database.ChatStatusPending} - initialMessage := database.ChatMessage{ - ID: 1, - ChatID: chatID, - Role: database.ChatMessageRoleUser, - } - cachedMessage := codersdk.ChatMessage{ - ID: 2, - ChatID: chatID, - Role: codersdk.ChatMessageRoleAssistant, - } - gomock.InOrder( - db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil), - db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil), - db.EXPECT().GetChatMessagesByChatID(gomock.Any(), 
database.GetChatMessagesByChatIDParams{ - ChatID: chatID, - AfterID: 0, - }).Return([]database.ChatMessage{initialMessage}, nil), - db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil), - // DB catchup runs unconditionally; cached id=2 is deduped via - // the delivered set so this query returning nil is sufficient. - db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{ - ChatID: chatID, - AfterID: 1, - }).Return(nil, nil), - ) - - server := newSubscribeTestServer(t, db) - server.cacheDurableMessage(chatID, codersdk.ChatStreamEvent{ - Type: codersdk.ChatStreamEventTypeMessage, - ChatID: chatID, - Message: &cachedMessage, - }) - - _, events, cancel, ok := server.Subscribe(ctx, chatID, nil, 0) - require.True(t, ok) - defer cancel() - - server.publishChatStreamNotify(chatID, coderdpubsub.ChatStreamNotifyMessage{ - AfterMessageID: 1, - }) - - event := requireStreamMessageEvent(t, events) - require.Equal(t, int64(2), event.Message.ID) - requireNoStreamEvent(t, events, 200*time.Millisecond) -} - -func TestSubscribeQueriesDatabaseWhenDurableCacheMisses(t *testing.T) { - t.Parallel() - - ctx, cancelCtx := context.WithCancel(context.Background()) - defer cancelCtx() - - ctrl := gomock.NewController(t) - db := dbmock.NewMockStore(ctrl) - - chatID := uuid.New() - chat := database.Chat{ID: chatID, Status: database.ChatStatusPending} - initialMessage := database.ChatMessage{ - ID: 1, - ChatID: chatID, - Role: database.ChatMessageRoleUser, - } - catchupMessage := database.ChatMessage{ - ID: 2, - ChatID: chatID, - Role: database.ChatMessageRoleAssistant, - } - gomock.InOrder( - db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil), - db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil), - db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{ - ChatID: chatID, - AfterID: 0, - }).Return([]database.ChatMessage{initialMessage}, nil), - 
db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil), - db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{ - ChatID: chatID, - AfterID: 1, - }).Return([]database.ChatMessage{catchupMessage}, nil), - ) - - server := newSubscribeTestServer(t, db) - _, events, cancel, ok := server.Subscribe(ctx, chatID, nil, 0) - require.True(t, ok) - defer cancel() - - server.publishChatStreamNotify(chatID, coderdpubsub.ChatStreamNotifyMessage{ - AfterMessageID: 1, - }) - - event := requireStreamMessageEvent(t, events) - require.Equal(t, int64(2), event.Message.ID) - requireNoStreamEvent(t, events, 200*time.Millisecond) -} - -func TestSubscribeFullRefreshStillUsesDatabaseCatchup(t *testing.T) { - t.Parallel() - - ctx, cancelCtx := context.WithCancel(context.Background()) - defer cancelCtx() - - ctrl := gomock.NewController(t) - db := dbmock.NewMockStore(ctrl) - - chatID := uuid.New() - chat := database.Chat{ID: chatID, Status: database.ChatStatusPending} - initialMessage := database.ChatMessage{ - ID: 1, - ChatID: chatID, - Role: database.ChatMessageRoleUser, - } - editedMessage := database.ChatMessage{ - ID: 1, - ChatID: chatID, - Role: database.ChatMessageRoleUser, - } - gomock.InOrder( - db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil), - db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil), - db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{ - ChatID: chatID, - AfterID: 0, - }).Return([]database.ChatMessage{initialMessage}, nil), - db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil), - db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{ - ChatID: chatID, - AfterID: 0, - }).Return([]database.ChatMessage{editedMessage}, nil), - ) - - server := newSubscribeTestServer(t, db) - _, events, cancel, ok := server.Subscribe(ctx, chatID, nil, 0) - require.True(t, ok) - defer cancel() - - 
server.publishEditedMessage(chatID, editedMessage) - - event := requireStreamMessageEvent(t, events) - require.Equal(t, int64(1), event.Message.ID) - requireNoStreamEvent(t, events, 200*time.Millisecond) -} - -func TestSubscribeDeliversRetryEventViaPubsubOnce(t *testing.T) { - t.Parallel() - - ctx, cancelCtx := context.WithCancel(context.Background()) - defer cancelCtx() - - ctrl := gomock.NewController(t) - db := dbmock.NewMockStore(ctrl) - - chatID := uuid.New() - chat := database.Chat{ID: chatID, Status: database.ChatStatusPending} - gomock.InOrder( - db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil), - db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil), - db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{ - ChatID: chatID, - AfterID: 0, - }).Return(nil, nil), - db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil), - ) - - server := newSubscribeTestServer(t, db) - _, events, cancel, ok := server.Subscribe(ctx, chatID, nil, 0) - require.True(t, ok) - defer cancel() - - expected := newTestRetryPayload() - - server.publishRetry(chatID, expected) - - event := requireStreamRetryEvent(t, events) - require.Equal(t, expected, event.Retry) - requireNoStreamEvent(t, events, 200*time.Millisecond) -} - -func TestSubscribeReplaysCurrentRetryPhaseInSnapshot(t *testing.T) { - t.Parallel() - - ctx, cancelCtx := context.WithCancel(context.Background()) - defer cancelCtx() - - ctrl := gomock.NewController(t) - db := dbmock.NewMockStore(ctrl) - - chatID := uuid.New() - chat := database.Chat{ID: chatID, Status: database.ChatStatusRunning} - - gomock.InOrder( - db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil), - db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil), - db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{ - ChatID: chatID, - AfterID: 0, - }).Return(nil, nil), - db.EXPECT().GetChatQueuedMessages(gomock.Any(), 
chatID).Return(nil, nil), - ) - - server := newBufferedSubscribeTestServer(t, db, chatID) - - expected := newTestRetryPayload() - server.publishRetry(chatID, expected) - - snapshot, events, cancel, ok := server.Subscribe(ctx, chatID, nil, 0) - require.True(t, ok) - defer cancel() - - require.Len(t, snapshot, 2) - require.Equal(t, codersdk.ChatStreamEventTypeStatus, snapshot[0].Type) - require.Equal(t, codersdk.ChatStreamEventTypeRetry, snapshot[1].Type) - event := requireSnapshotRetryEvent(t, snapshot) - require.Equal(t, expected, event.Retry) - requireNoStreamEvent(t, events, 200*time.Millisecond) -} - -func TestSubscribeCapturesRetryPhaseAtSubscriptionBoundary(t *testing.T) { - t.Parallel() - - ctx, cancelCtx := context.WithCancel(context.Background()) - defer cancelCtx() - - ctrl := gomock.NewController(t) - db := dbmock.NewMockStore(ctrl) - - chatID := uuid.New() - chat := database.Chat{ID: chatID, Status: database.ChatStatusRunning} - expected := newTestRetryPayload() - - server := newSubscribeTestServer(t, db) - - gomock.InOrder( - db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil), - db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil), - db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{ - ChatID: chatID, - AfterID: 0, - }).DoAndReturn(func(context.Context, database.GetChatMessagesByChatIDParams) ([]database.ChatMessage, error) { - server.publishRetry(chatID, expected) - return nil, nil - }), - db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil), - ) - - snapshot, events, cancel, ok := server.Subscribe(ctx, chatID, nil, 0) - require.True(t, ok) - defer cancel() - - requireNoSnapshotRetryEvent(t, snapshot) - event := requireStreamRetryEvent(t, events) - require.Equal(t, expected, event.Retry) - requireNoStreamEvent(t, events, 200*time.Millisecond) -} - -func TestSubscribeDoesNotReplayRetryAfterStreamResumes(t *testing.T) { - t.Parallel() - - ctx, cancelCtx := 
context.WithCancel(context.Background()) - defer cancelCtx() - - ctrl := gomock.NewController(t) - db := dbmock.NewMockStore(ctrl) - - chatID := uuid.New() - chat := database.Chat{ID: chatID, Status: database.ChatStatusRunning} - - gomock.InOrder( - db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil), - db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil), - db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{ - ChatID: chatID, - AfterID: 0, - }).Return(nil, nil), - db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil), - ) - - server := newBufferedSubscribeTestServer(t, db, chatID) - - server.publishRetry(chatID, newTestRetryPayload()) - server.publishMessagePart(chatID, codersdk.ChatMessageRoleAssistant, codersdk.ChatMessageText("retry recovered")) - - snapshot, events, cancel, ok := server.Subscribe(ctx, chatID, nil, 0) - require.True(t, ok) - defer cancel() - - requireNoSnapshotRetryEvent(t, snapshot) - requireSnapshotMessagePartEvent(t, snapshot) - requireNoStreamEvent(t, events, 200*time.Millisecond) -} - -func TestSubscribeDoesNotReplayFailedAttemptPartsAfterRetry(t *testing.T) { - t.Parallel() - - ctx, cancelCtx := context.WithCancel(context.Background()) - defer cancelCtx() - - ctrl := gomock.NewController(t) - db := dbmock.NewMockStore(ctrl) - - chatID := uuid.New() - chat := database.Chat{ID: chatID, Status: database.ChatStatusRunning} - - gomock.InOrder( - db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil), - db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil), - db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{ - ChatID: chatID, - AfterID: 0, - }).Return(nil, nil), - db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil), - ) - - server := newBufferedSubscribeTestServer(t, db, chatID) - - server.publishMessagePart(chatID, codersdk.ChatMessageRoleAssistant, codersdk.ChatMessageText("failed 
partial")) - server.clearProvisionalStreamParts(chatID) - server.publishRetry(chatID, newTestRetryPayload()) - server.publishMessagePart(chatID, codersdk.ChatMessageRoleAssistant, codersdk.ChatMessageText("retry recovered")) - - snapshot, events, cancel, ok := server.Subscribe(ctx, chatID, nil, 0) - require.True(t, ok) - defer cancel() - - requireNoSnapshotRetryEvent(t, snapshot) - partEvent := requireSnapshotMessagePartEvent(t, snapshot) - require.Equal(t, "retry recovered", partEvent.MessagePart.Part.Text) - for _, event := range snapshot { - if event.Type != codersdk.ChatStreamEventTypeMessagePart { - continue - } - require.NotEqual(t, "failed partial", event.MessagePart.Part.Text) - } - requireNoStreamEvent(t, events, 200*time.Millisecond) -} - -func TestSubscribeDoesNotReplayRetryAfterTerminalError(t *testing.T) { - t.Parallel() - - ctx, cancelCtx := context.WithCancel(context.Background()) - defer cancelCtx() - - ctrl := gomock.NewController(t) - db := dbmock.NewMockStore(ctrl) - - chatID := uuid.New() - chat := database.Chat{ID: chatID, Status: database.ChatStatusRunning} - - gomock.InOrder( - db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil), - db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil), - db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{ - ChatID: chatID, - AfterID: 0, - }).Return(nil, nil), - db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil), - ) - - server := newBufferedSubscribeTestServer(t, db, chatID) - - server.publishRetry(chatID, newTestRetryPayload()) - server.publishError(chatID, chaterror.ClassifiedError{ - Message: "OpenAI is rate limiting requests.", - Kind: codersdk.ChatErrorKindRateLimit, - Provider: "openai", - Retryable: true, - StatusCode: 429, - }) - - snapshot, events, cancel, ok := server.Subscribe(ctx, chatID, nil, 0) - require.True(t, ok) - defer cancel() - - requireNoSnapshotRetryEvent(t, snapshot) - requireNoStreamEvent(t, events, 
200*time.Millisecond) -} - -func TestSubscribeDoesNotReplayRetryAfterTerminalStatus(t *testing.T) { - t.Parallel() - - ctx, cancelCtx := context.WithCancel(context.Background()) - defer cancelCtx() - - ctrl := gomock.NewController(t) - db := dbmock.NewMockStore(ctrl) - - chatID := uuid.New() - chat := database.Chat{ID: chatID, Status: database.ChatStatusCompleted} - - gomock.InOrder( - db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil), - db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil), - db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{ - ChatID: chatID, - AfterID: 0, - }).Return(nil, nil), - db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil), - ) - - server := newBufferedSubscribeTestServer(t, db, chatID) - - server.publishRetry(chatID, newTestRetryPayload()) - server.publishStatus(chatID, database.ChatStatusCompleted, uuid.NullUUID{}) - - snapshot, events, cancel, ok := server.Subscribe(ctx, chatID, nil, 0) - require.True(t, ok) - defer cancel() - - requireNoSnapshotRetryEvent(t, snapshot) - requireNoStreamEvent(t, events, 200*time.Millisecond) -} - -func TestSubscribePrefersStructuredErrorPayloadViaPubsub(t *testing.T) { - t.Parallel() - - ctx, cancelCtx := context.WithCancel(context.Background()) - defer cancelCtx() - - ctrl := gomock.NewController(t) - db := dbmock.NewMockStore(ctrl) - - chatID := uuid.New() - chat := database.Chat{ID: chatID, Status: database.ChatStatusPending} - gomock.InOrder( - db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil), - db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil), - db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{ - ChatID: chatID, - AfterID: 0, - }).Return(nil, nil), - db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil), - ) - - server := newSubscribeTestServer(t, db) - _, events, cancel, ok := server.Subscribe(ctx, chatID, nil, 0) - 
require.True(t, ok) - defer cancel() - - classified := chaterror.ClassifiedError{ - Message: "OpenAI is rate limiting requests.", - Kind: codersdk.ChatErrorKindRateLimit, - Provider: "openai", - Retryable: true, - StatusCode: 429, - } - server.publishError(chatID, classified) - - event := requireStreamErrorEvent(t, events) - require.Equal(t, chaterror.TerminalErrorPayload(classified), event.Error) - requireNoStreamEvent(t, events, 200*time.Millisecond) -} - -func TestSubscribeFallsBackToLegacyErrorStringViaPubsub(t *testing.T) { - t.Parallel() - - ctx, cancelCtx := context.WithCancel(context.Background()) - defer cancelCtx() - - ctrl := gomock.NewController(t) - db := dbmock.NewMockStore(ctrl) - - chatID := uuid.New() - chat := database.Chat{ID: chatID, Status: database.ChatStatusPending} - gomock.InOrder( - db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil), - db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil), - db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{ - ChatID: chatID, - AfterID: 0, - }).Return(nil, nil), - db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil), - ) - - server := newSubscribeTestServer(t, db) - _, events, cancel, ok := server.Subscribe(ctx, chatID, nil, 0) - require.True(t, ok) - defer cancel() - - server.publishChatStreamNotify(chatID, coderdpubsub.ChatStreamNotifyMessage{ - Error: "legacy error only", - }) - - event := requireStreamErrorEvent(t, events) - require.Equal(t, &codersdk.ChatError{Message: "legacy error only"}, event.Error) - requireNoStreamEvent(t, events, 200*time.Millisecond) -} - -func newTestRetryPayload() *codersdk.ChatStreamRetry { - payload := chaterror.StreamRetryPayload(1, 1500*time.Millisecond, chaterror.ClassifiedError{ - Message: "OpenAI is rate limiting requests.", - Kind: codersdk.ChatErrorKindRateLimit, - Provider: "openai", - Retryable: true, - StatusCode: 429, - }) - if payload == nil { - panic("expected retry payload") - } 
- payload.RetryingAt = time.Unix(1_700_000_000, 0).UTC() - return payload -} - func TestSubscribeAuthorizedFallsBackToStaleRowWhenRefreshFails(t *testing.T) { t.Parallel() @@ -2790,100 +2232,6 @@ func newSubscribeTestServer(t *testing.T, db database.Store) *Server { } } -func newBufferedSubscribeTestServer(t *testing.T, db database.Store, chatID uuid.UUID) *Server { - t.Helper() - - server := newSubscribeTestServer(t, db) - state := server.getOrCreateStreamState(chatID) - state.mu.Lock() - state.buffering = true - state.mu.Unlock() - return server -} - -func requireStreamMessageEvent(t *testing.T, events <-chan codersdk.ChatStreamEvent) codersdk.ChatStreamEvent { - t.Helper() - - select { - case event, ok := <-events: - require.True(t, ok, "chat stream closed before delivering an event") - require.Equal(t, codersdk.ChatStreamEventTypeMessage, event.Type) - require.NotNil(t, event.Message) - return event - case <-time.After(time.Second): - t.Fatal("timed out waiting for chat stream message event") - return codersdk.ChatStreamEvent{} - } -} - -func requireStreamRetryEvent(t *testing.T, events <-chan codersdk.ChatStreamEvent) codersdk.ChatStreamEvent { - t.Helper() - - select { - case event, ok := <-events: - require.True(t, ok, "chat stream closed before delivering an event") - require.Equal(t, codersdk.ChatStreamEventTypeRetry, event.Type) - require.NotNil(t, event.Retry) - return event - case <-time.After(time.Second): - t.Fatal("timed out waiting for chat stream retry event") - return codersdk.ChatStreamEvent{} - } -} - -func requireSnapshotRetryEvent(t *testing.T, snapshot []codersdk.ChatStreamEvent) codersdk.ChatStreamEvent { - t.Helper() - - var retryEvents []codersdk.ChatStreamEvent - for _, event := range snapshot { - if event.Type == codersdk.ChatStreamEventTypeRetry { - retryEvents = append(retryEvents, event) - } - } - - require.Len(t, retryEvents, 1, "expected exactly one retry event in snapshot") - require.NotNil(t, retryEvents[0].Retry) - return 
retryEvents[0] -} - -func requireNoSnapshotRetryEvent(t *testing.T, snapshot []codersdk.ChatStreamEvent) { - t.Helper() - - for _, event := range snapshot { - require.NotEqual(t, codersdk.ChatStreamEventTypeRetry, event.Type, - "unexpected retry event in snapshot: %+v", event) - } -} - -func requireSnapshotMessagePartEvent(t *testing.T, snapshot []codersdk.ChatStreamEvent) codersdk.ChatStreamEvent { - t.Helper() - - for _, event := range snapshot { - if event.Type == codersdk.ChatStreamEventTypeMessagePart { - require.NotNil(t, event.MessagePart) - return event - } - } - - t.Fatal("expected message_part event in snapshot") - return codersdk.ChatStreamEvent{} -} - -func requireStreamErrorEvent(t *testing.T, events <-chan codersdk.ChatStreamEvent) codersdk.ChatStreamEvent { - t.Helper() - - select { - case event, ok := <-events: - require.True(t, ok, "chat stream closed before delivering an event") - require.Equal(t, codersdk.ChatStreamEventTypeError, event.Type) - require.NotNil(t, event.Error) - return event - case <-time.After(time.Second): - t.Fatal("timed out waiting for chat stream error event") - return codersdk.ChatStreamEvent{} - } -} - func requireNoStreamEvent(t *testing.T, events <-chan codersdk.ChatStreamEvent, wait time.Duration) { t.Helper() @@ -2897,98 +2245,6 @@ func requireNoStreamEvent(t *testing.T, events <-chan codersdk.ChatStreamEvent, } } -// TestPublishToStream_DropWarnRateLimiting walks through a -// realistic lifecycle: buffer fills up, subscriber channel fills -// up, counters get reset between steps. It verifies that WARN -// logs are rate-limited to at most once per streamDropWarnInterval -// and that counter resets re-enable an immediate WARN. 
-func TestPublishToStream_DropWarnRateLimiting(t *testing.T) { - t.Parallel() - - sink := testutil.NewFakeSink(t) - mClock := quartz.NewMock(t) - - server := &Server{ - logger: sink.Logger(), - clock: mClock, - } - - chatID := uuid.New() - subCh := make(chan codersdk.ChatStreamEvent, 1) - subCh <- codersdk.ChatStreamEvent{} // pre-fill so sends always drop - - // Set up state that mirrors a running chat: buffer at capacity, - // buffering enabled, one saturated subscriber. - state := &chatStreamState{ - buffering: true, - buffer: make([]bufferedStreamPart, maxStreamBufferSize), - subscribers: map[uuid.UUID]chan codersdk.ChatStreamEvent{ - uuid.New(): subCh, - }, - } - server.chatStreams.Store(chatID, state) - - bufferMsg := "chat stream buffer full, dropping oldest event" - subMsg := "dropping chat stream event" - - filter := func(level slog.Level, msg string) func(slog.SinkEntry) bool { - return func(e slog.SinkEntry) bool { - return e.Level == level && e.Message == msg - } - } - - // --- Phase 1: buffer-full rate limiting --- - // message_part events hit both the buffer-full and subscriber-full - // paths. The first publish triggers a WARN for each; the rest - // within the window are DEBUG. - partEvent := codersdk.ChatStreamEvent{ - Type: codersdk.ChatStreamEventTypeMessagePart, - MessagePart: &codersdk.ChatStreamMessagePart{}, - } - for i := 0; i < 50; i++ { - server.publishToStream(chatID, partEvent) - } - - require.Len(t, sink.Entries(filter(slog.LevelWarn, bufferMsg)), 1) - require.Empty(t, sink.Entries(filter(slog.LevelDebug, bufferMsg))) - requireFieldValue(t, sink.Entries(filter(slog.LevelWarn, bufferMsg))[0], "dropped_count", int64(1)) - - // Subscriber also saw 50 drops (one per publish). 
- require.Len(t, sink.Entries(filter(slog.LevelWarn, subMsg)), 1) - require.Empty(t, sink.Entries(filter(slog.LevelDebug, subMsg))) - requireFieldValue(t, sink.Entries(filter(slog.LevelWarn, subMsg))[0], "dropped_count", int64(1)) - - // --- Phase 2: clock advance triggers second WARN with count --- - mClock.Advance(streamDropWarnInterval + time.Second) - server.publishToStream(chatID, partEvent) - - bufWarn := sink.Entries(filter(slog.LevelWarn, bufferMsg)) - require.Len(t, bufWarn, 2) - requireFieldValue(t, bufWarn[1], "dropped_count", int64(50)) - - subWarn := sink.Entries(filter(slog.LevelWarn, subMsg)) - require.Len(t, subWarn, 2) - requireFieldValue(t, subWarn[1], "dropped_count", int64(50)) - - // --- Phase 3: counter reset (simulates step persist) --- - state.mu.Lock() - state.buffer = make([]bufferedStreamPart, maxStreamBufferSize) - state.resetDropCounters() - state.mu.Unlock() - - // The very next drop should WARN immediately — the reset zeroed - // lastWarnAt so the interval check passes. - server.publishToStream(chatID, partEvent) - - bufWarn = sink.Entries(filter(slog.LevelWarn, bufferMsg)) - require.Len(t, bufWarn, 3, "expected WARN immediately after counter reset") - requireFieldValue(t, bufWarn[2], "dropped_count", int64(1)) - - subWarn = sink.Entries(filter(slog.LevelWarn, subMsg)) - require.Len(t, subWarn, 3, "expected subscriber WARN immediately after counter reset") - requireFieldValue(t, subWarn[2], "dropped_count", int64(1)) -} - func TestResolveUserCompactionThreshold(t *testing.T) { t.Parallel() @@ -3202,8 +2458,7 @@ func TestSkillsFromParts(t *testing.T) { t.Run("RoundTrip", func(t *testing.T) { // Simulate persist -> reconstruct cycle: marshal skill - // parts the same way persistInstructionFiles does, then - // verify skillsFromParts recovers the metadata. + // parts, then verify skillsFromParts recovers the metadata. 
t.Parallel() want := []chattool.SkillMeta{ {Name: "deep-review", Description: "Multi-reviewer review", Dir: "/skills/deep-review"}, @@ -3665,40 +2920,6 @@ func TestContextFileAgentID(t *testing.T) { }) } -func TestHasPersistedInstructionFiles(t *testing.T) { - t.Parallel() - - t.Run("IgnoresAgentChatContextSentinel", func(t *testing.T) { - t.Parallel() - agentID := uuid.New() - msgs := []database.ChatMessage{ - chattest.ChatMessageWithParts([]codersdk.ChatMessagePart{{ - Type: codersdk.ChatMessagePartTypeContextFile, - ContextFilePath: AgentChatContextSentinelPath, - ContextFileAgentID: uuid.NullUUID{ - UUID: agentID, - Valid: true, - }, - }}), - } - require.False(t, hasPersistedInstructionFiles(msgs)) - }) - - t.Run("AcceptsPersistedInstructionFile", func(t *testing.T) { - t.Parallel() - agentID := uuid.New() - msgs := []database.ChatMessage{ - chattest.ChatMessageWithParts([]codersdk.ChatMessagePart{{ - Type: codersdk.ChatMessagePartTypeContextFile, - ContextFilePath: "/workspace/AGENTS.md", - ContextFileContent: "repo instructions", - ContextFileAgentID: uuid.NullUUID{UUID: agentID, Valid: true}, - }}), - } - require.True(t, hasPersistedInstructionFiles(msgs)) - }) -} - func TestInstructionFromContextFilesUsesLatestContextAgent(t *testing.T) { t.Parallel() @@ -3862,376 +3083,6 @@ func TestSkillsFromPartsUsesLatestContextAgent(t *testing.T) { }}, got) } -func TestMergeSkillMetas(t *testing.T) { - t.Parallel() - - persisted := []chattool.SkillMeta{{ - Name: "repo-helper", - Description: "Persisted skill", - Dir: "/skills/repo-helper-old", - }} - discovered := []chattool.SkillMeta{ - { - Name: "repo-helper", - Description: "Discovered replacement", - Dir: "/skills/repo-helper-new", - MetaFile: "SKILL.md", - }, - { - Name: "deep-review", - Description: "Discovered skill", - Dir: "/skills/deep-review", - }, - } - - got := mergeSkillMetas(persisted, discovered) - require.Equal(t, []chattool.SkillMeta{ - discovered[0], - discovered[1], - }, got) -} - -func 
TestSelectSkillMetasForInstructionRefresh(t *testing.T) { - t.Parallel() - - persisted := []chattool.SkillMeta{{Name: "persisted", Dir: "/skills/persisted"}} - discovered := []chattool.SkillMeta{{Name: "discovered", Dir: "/skills/discovered"}} - currentAgentID := uuid.New() - otherAgentID := uuid.New() - - t.Run("MergesCurrentAgentSkills", func(t *testing.T) { - t.Parallel() - got := selectSkillMetasForInstructionRefresh( - persisted, - discovered, - uuid.NullUUID{UUID: currentAgentID, Valid: true}, - uuid.NullUUID{UUID: currentAgentID, Valid: true}, - ) - require.Equal(t, []chattool.SkillMeta{discovered[0], persisted[0]}, got) - }) - - t.Run("DropsStalePersistedSkillsWhenAgentChanged", func(t *testing.T) { - t.Parallel() - got := selectSkillMetasForInstructionRefresh( - persisted, - discovered, - uuid.NullUUID{UUID: currentAgentID, Valid: true}, - uuid.NullUUID{UUID: otherAgentID, Valid: true}, - ) - require.Equal(t, discovered, got) - }) - - t.Run("PreservesPersistedSkillsWhenAgentLookupFails", func(t *testing.T) { - t.Parallel() - got := selectSkillMetasForInstructionRefresh( - persisted, - nil, - uuid.NullUUID{}, - uuid.NullUUID{UUID: otherAgentID, Valid: true}, - ) - require.Equal(t, persisted, got) - }) -} - -// TestProcessChat_IgnoresStaleControlNotification verifies that -// processChat is not interrupted by a "pending" notification -// published before processing begins. This is the race that caused -// TestOpenAIReasoningWithWebSearchRoundTripStoreFalse to flake: -// SendMessage publishes "pending" via PostgreSQL NOTIFY, and due -// to async delivery the notification can arrive at the control -// subscriber after it registers but before the processor publishes -// "running". 
-func TestProcessChat_IgnoresStaleControlNotification(t *testing.T) { - t.Parallel() - - ctx := testutil.Context(t, testutil.WaitShort) - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) - ctrl := gomock.NewController(t) - db := dbmock.NewMockStore(ctrl) - ps := dbpubsub.NewInMemory() - clock := quartz.NewMock(t) - - chatID := uuid.New() - workerID := uuid.New() - - server := &Server{ - db: db, - logger: logger, - pubsub: ps, - clock: clock, - workerID: workerID, - chatHeartbeatInterval: time.Minute, - metrics: chatloop.NopMetrics(), - configCache: newChatConfigCache(ctx, db, clock), - heartbeatRegistry: make(map[uuid.UUID]*heartbeatEntry), - } - - // Publish a stale "pending" notification on the control channel - // BEFORE processChat subscribes. In production this is the - // notification from SendMessage that triggered the processing. - staleNotify, err := json.Marshal(coderdpubsub.ChatStreamNotifyMessage{ - Status: string(database.ChatStatusPending), - }) - require.NoError(t, err) - err = ps.Publish(coderdpubsub.ChatStreamNotifyChannel(chatID), staleNotify) - require.NoError(t, err) - - // Track which status processChat writes during cleanup. - var finalStatus database.ChatStatus - - // The deferred cleanup in processChat runs a transaction. 
- db.EXPECT().InTx(gomock.Any(), gomock.Any()).DoAndReturn( - func(fn func(database.Store) error, _ *database.TxOptions) error { - return fn(db) - }, - ) - db.EXPECT().GetChatByIDForUpdate(gomock.Any(), chatID).Return( - database.Chat{ID: chatID, Status: database.ChatStatusRunning, WorkerID: uuid.NullUUID{UUID: workerID, Valid: true}}, nil, - ) - db.EXPECT().UpdateChatStatus(gomock.Any(), gomock.Any()).DoAndReturn( - func(_ context.Context, params database.UpdateChatStatusParams) (database.Chat, error) { - finalStatus = params.Status - return database.Chat{ - ID: chatID, - Status: params.Status, - LastTurnSummary: sql.NullString{String: "previous summary", Valid: true}, - }, nil - }, - ) - db.EXPECT().GetChatByID(gomock.Any(), chatID).Return( - database.Chat{ID: chatID, Status: database.ChatStatusError}, - nil, - ) - - db.EXPECT().UpdateChatLastTurnSummary(gomock.Any(), gomock.Any()).Return(int64(1), nil) - - // resolveChatModel fails immediately — that's fine, we only - // need processChat to get past initialization without being - // interrupted by the stale notification. - db.EXPECT().GetChatModelConfigByID(gomock.Any(), gomock.Any()).Return( - database.ChatModelConfig{}, xerrors.New("no model configured"), - ).AnyTimes() - db.EXPECT().GetAIProviders(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() - db.EXPECT().GetEnabledChatModelConfigs(gomock.Any()).Return(nil, nil).AnyTimes() - db.EXPECT().GetChatUsageLimitConfig(gomock.Any()).Return( - database.ChatUsageLimitConfig{}, sql.ErrNoRows, - ).AnyTimes() - db.EXPECT().GetChatMessagesForPromptByChatID(gomock.Any(), chatID).Return(nil, nil).AnyTimes() - - chat := database.Chat{ID: chatID, LastModelConfigID: uuid.New()} - done := make(chan struct{}) - go func() { - defer close(done) - server.processChat(ctx, chat) - }() - - // Wait for processChat to finish entirely. 
It re-reads chat state and - // runs more cleanup after UpdateChatStatus, so signaling completion from - // the status update itself races test teardown. - testutil.TryReceive(ctx, t, done) - - WaitUntilIdleForTest(server) - - // If the stale notification interrupted us, status would be - // "waiting" (the ErrInterrupted path). Since the gate blocked - // it, processChat reached runChat, which failed on model - // resolution → status is "error". - require.Equal(t, database.ChatStatusError, finalStatus, - "processChat should have reached runChat (error), not been interrupted (waiting)") -} - -func TestShouldPublishFinishedChatState(t *testing.T) { - t.Parallel() - - ctx := testutil.Context(t, testutil.WaitShort) - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) - ctrl := gomock.NewController(t) - db := dbmock.NewMockStore(ctrl) - chatID := uuid.New() - workerID := uuid.New() - - server := &Server{db: db} - updatedChat := database.Chat{ - ID: chatID, - Status: database.ChatStatusWaiting, - WorkerID: uuid.NullUUID{}, - } - - db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(database.Chat{ - ID: chatID, - Status: database.ChatStatusWaiting, - WorkerID: uuid.NullUUID{}, - }, nil) - - require.True(t, server.shouldPublishFinishedChatState(ctx, logger, updatedChat)) - - db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(database.Chat{ - ID: chatID, - Status: database.ChatStatusRunning, - WorkerID: uuid.NullUUID{UUID: workerID, Valid: true}, - }, nil) - - require.False(t, server.shouldPublishFinishedChatState(ctx, logger, updatedChat)) -} - -// TestShouldPublishFinishedChatState_DBErrorPublishes pins the -// deliberate fail-open behavior when the re-read query errors: we -// surface the finished state anyway so watchers don't get stuck -// waiting for a status update that never arrives. The error path is -// easy to regress into a fail-closed default otherwise. 
-func TestShouldPublishFinishedChatState_DBErrorPublishes(t *testing.T) { - t.Parallel() - - ctx := testutil.Context(t, testutil.WaitShort) - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) - ctrl := gomock.NewController(t) - db := dbmock.NewMockStore(ctrl) - chatID := uuid.New() - - server := &Server{db: db} - updatedChat := database.Chat{ - ID: chatID, - Status: database.ChatStatusWaiting, - WorkerID: uuid.NullUUID{}, - } - - db.EXPECT().GetChatByID(gomock.Any(), chatID).Return( - database.Chat{}, xerrors.New("boom"), - ) - - require.True(t, server.shouldPublishFinishedChatState(ctx, logger, updatedChat), - "fail-open: a re-read error must not swallow the status change") -} - -// TestHeartbeatTick_StolenChatIsInterrupted verifies that when the -// batch heartbeat UPDATE does not return a registered chat's ID -// (because another replica stole it or it was completed), the -// heartbeat tick cancels that chat's context with ErrInterrupted -// while leaving surviving chats untouched. -func TestHeartbeatTick_StolenChatIsInterrupted(t *testing.T) { - t.Parallel() - - ctx := testutil.Context(t, testutil.WaitShort) - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) - ctrl := gomock.NewController(t) - db := dbmock.NewMockStore(ctrl) - clock := quartz.NewMock(t) - - workerID := uuid.New() - - server := &Server{ - db: db, - logger: logger, - clock: clock, - workerID: workerID, - chatHeartbeatInterval: time.Minute, - metrics: chatloop.NopMetrics(), - heartbeatRegistry: make(map[uuid.UUID]*heartbeatEntry), - } - - // Create three chats with independent cancel functions. 
- chat1 := uuid.New() - chat2 := uuid.New() - chat3 := uuid.New() - - _, cancel1 := context.WithCancelCause(ctx) - _, cancel2 := context.WithCancelCause(ctx) - ctx3, cancel3 := context.WithCancelCause(ctx) - - server.registerHeartbeat(&heartbeatEntry{ - cancelWithCause: cancel1, - chatID: chat1, - logger: logger, - }) - server.registerHeartbeat(&heartbeatEntry{ - cancelWithCause: cancel2, - chatID: chat2, - logger: logger, - }) - server.registerHeartbeat(&heartbeatEntry{ - cancelWithCause: cancel3, - chatID: chat3, - logger: logger, - }) - - // The batch UPDATE returns only chat1 and chat2 — - // chat3 was "stolen" by another replica. - db.EXPECT().UpdateChatHeartbeats(gomock.Any(), gomock.Any()).DoAndReturn( - func(_ context.Context, params database.UpdateChatHeartbeatsParams) ([]uuid.UUID, error) { - require.Equal(t, workerID, params.WorkerID) - require.Len(t, params.IDs, 3) - // Return only chat1 and chat2 as surviving. - return []uuid.UUID{chat1, chat2}, nil - }, - ) - - server.heartbeatTick(ctx) - - // chat3's context should be canceled with ErrInterrupted. - require.ErrorIs(t, context.Cause(ctx3), chatloop.ErrInterrupted, - "stolen chat should be interrupted") - - // chat3 should have been removed from the registry by - // unregister (in production this happens via defer in - // processChat). The heartbeat tick itself does not - // unregister — it only cancels. Verify the entry is - // still present (processChat's defer would clean it up). 
- server.heartbeatMu.Lock() - _, chat1Exists := server.heartbeatRegistry[chat1] - _, chat2Exists := server.heartbeatRegistry[chat2] - _, chat3Exists := server.heartbeatRegistry[chat3] - server.heartbeatMu.Unlock() - - require.True(t, chat1Exists, "surviving chat1 should remain registered") - require.True(t, chat2Exists, "surviving chat2 should remain registered") - require.True(t, chat3Exists, - "stolen chat3 should still be in registry (processChat defer removes it)") -} - -// TestHeartbeatTick_DBErrorDoesNotInterruptChats verifies that a -// transient database failure causes the tick to log and return -// without canceling any registered chats. -func TestHeartbeatTick_DBErrorDoesNotInterruptChats(t *testing.T) { - t.Parallel() - - ctx := testutil.Context(t, testutil.WaitShort) - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) - ctrl := gomock.NewController(t) - db := dbmock.NewMockStore(ctrl) - clock := quartz.NewMock(t) - - server := &Server{ - db: db, - logger: logger, - clock: clock, - workerID: uuid.New(), - chatHeartbeatInterval: time.Minute, - metrics: chatloop.NopMetrics(), - heartbeatRegistry: make(map[uuid.UUID]*heartbeatEntry), - } - - chatID := uuid.New() - chatCtx, cancel := context.WithCancelCause(ctx) - - server.registerHeartbeat(&heartbeatEntry{ - cancelWithCause: cancel, - chatID: chatID, - logger: logger, - }) - - // Simulate a transient DB error. - db.EXPECT().UpdateChatHeartbeats(gomock.Any(), gomock.Any()).Return( - nil, xerrors.New("connection reset"), - ) - - server.heartbeatTick(ctx) - - // Chat should NOT be interrupted — the tick logged and - // returned early. 
- require.NoError(t, chatCtx.Err(), - "chat context should not be canceled on transient DB error") -} - // TestSubscribeCancelDuringGrace_ReapedBySweep verifies that a // subscriber detach inside bufferRetainGracePeriod (the OSS trigger // for the retained-buffer leak) leaves the state mapped, and the @@ -4425,79 +3276,6 @@ func TestSweepIdleStreams_DefersDuringGracePeriod(t *testing.T) { require.False(t, ok, "sweep after grace window must reap") } -// TestPublishToStream_DropZeroesBackingSlot verifies that evicting -// the oldest buffered event at capacity zeroes the dropped slot so -// its *ChatStreamMessagePart becomes GC-eligible immediately. -func TestPublishToStream_DropZeroesBackingSlot(t *testing.T) { - t.Parallel() - - mClock := quartz.NewMock(t) - server := &Server{ - logger: slogtest.Make(t, nil), - clock: mClock, - } - - chatID := uuid.New() - - // Over-allocate by one so the post-drop append fits in place and - // exercises the backing-array reuse this test is checking. - buf := make([]bufferedStreamPart, maxStreamBufferSize, maxStreamBufferSize+1) - for i := range buf { - buf[i] = bufferedStreamPart{ - event: codersdk.ChatStreamEvent{ - Type: codersdk.ChatStreamEventTypeMessagePart, - MessagePart: &codersdk.ChatStreamMessagePart{}, - }, - } - } - // Sentinel in slot 0 distinguishes "slot was zeroed" from "slot - // was overwritten by a later append". - sentinel := &codersdk.ChatStreamMessagePart{ - Role: codersdk.ChatMessageRoleAssistant, - } - buf[0] = bufferedStreamPart{ - event: codersdk.ChatStreamEvent{ - Type: codersdk.ChatStreamEventTypeMessagePart, - MessagePart: sentinel, - }, - } - // Alias over the full backing array so we can still observe slot - // 0 after publishToStream reslices state.buffer forward. 
- origBacking := buf[:cap(buf)] - - state := &chatStreamState{ - buffering: true, - buffer: buf, - subscribers: map[uuid.UUID]chan codersdk.ChatStreamEvent{}, - } - server.chatStreams.Store(chatID, state) - - newPart := &codersdk.ChatStreamMessagePart{ - Role: codersdk.ChatMessageRoleAssistant, - } - server.publishToStream(chatID, codersdk.ChatStreamEvent{ - Type: codersdk.ChatStreamEventTypeMessagePart, - MessagePart: newPart, - }) - - require.Equal(t, bufferedStreamPart{}, origBacking[0], - "dropped slot must be zero-valued so its *ChatStreamMessagePart "+ - "is eligible for GC; got %+v", origBacking[0]) - - // Sanity-check the in-place append path the fix targets: if Go's - // growth policy ever makes this append reallocate, this fails - // loudly so the test author revisits the setup. - require.Same(t, newPart, origBacking[len(origBacking)-1].event.MessagePart, - "append must have landed in the original backing array; the "+ - "zero-out invariant only matters when cap > len") -} - -// TestCleanupStreamIfIdle_StalePointerDoesNotDeleteFreshEntry covers -// the race where a caller holds a pointer to a no-longer-mapped -// state (e.g. a janitor Range callback racing a fresh -// getOrCreateStreamState) and would otherwise evict the fresh entry. -// With CompareAndDelete in cleanupStreamIfIdle the stale delete is -// a no-op. func TestCleanupStreamIfIdle_StalePointerDoesNotDeleteFreshEntry(t *testing.T) { t.Parallel() @@ -5644,237 +4422,6 @@ func TestGetWorkspaceConn_DialErrorNotMisclassifiedAsTimeout(t *testing.T) { require.ErrorContains(t, err, "authentication failed") } -// TestAutoPromote_InsertFailureRollsBackTransaction verifies that when -// tryAutoPromoteQueuedMessage pops a queued message but the subsequent -// insert fails, the error propagates to the InTx callback, causing the -// transaction to roll back and preserving the queued message. 
-func TestAutoPromote_InsertFailureRollsBackTransaction(t *testing.T) { - t.Parallel() - - ctx := testutil.Context(t, testutil.WaitShort) - ctrl := gomock.NewController(t) - db := dbmock.NewMockStore(ctrl) - tx := dbmock.NewMockStore(ctrl) - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) - ps := dbpubsub.NewInMemory() - clock := quartz.NewReal() - - chatID := uuid.New() - workerID := uuid.New() - ownerID := uuid.New() - modelConfigID := uuid.New() - - waitingChat := database.Chat{ - ID: chatID, - OwnerID: ownerID, - LastModelConfigID: modelConfigID, - Status: database.ChatStatusWaiting, - WorkerID: uuid.NullUUID{UUID: workerID, Valid: true}, - } - queuedMsg := database.ChatQueuedMessage{ - ID: 1, - ChatID: chatID, - Content: []byte(`[{"type":"text","text":"queued"}]`), - } - insertErr := xerrors.New("insert failed") - - server := &Server{ - db: db, - logger: logger, - pubsub: ps, - configCache: newChatConfigCache(ctx, db, clock), - } - - // The caller runs tryAutoPromoteQueuedMessage inside InTx. - // Wire the mock to execute the callback against the TX mock. - var txErr error - db.EXPECT().InTx(gomock.Any(), gomock.Any()).DoAndReturn( - func(fn func(database.Store) error, _ *database.TxOptions) error { - txErr = fn(tx) - return txErr - }, - ) - - // Inside the TX: lock chat, get queued messages, resolve model - // config, pop queued message, insert fails. - tx.EXPECT().GetChatByIDForUpdate(gomock.Any(), chatID).Return(waitingChat, nil) - tx.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return([]database.ChatQueuedMessage{queuedMsg}, nil) - tx.EXPECT().GetChatModelConfigByID(gomock.Any(), modelConfigID).Return(database.ChatModelConfig{ID: modelConfigID}, nil) - tx.EXPECT().PopNextQueuedMessage(gomock.Any(), chatID).Return(queuedMsg, nil) - tx.EXPECT().InsertChatMessages(gomock.Any(), gomock.Any()).Return(nil, insertErr) - - // Invoke tryAutoPromoteQueuedMessage through the same InTx - // pattern the processChat defer uses. 
The test directly calls - // the production path to verify error propagation. - _ = db.InTx(func(txStore database.Store) error { - latestChat, err := txStore.GetChatByIDForUpdate(ctx, chatID) - if err != nil { - return err - } - - _, _, _, promoteErr := server.tryAutoPromoteQueuedMessage(ctx, txStore, latestChat) - if promoteErr != nil { - return promoteErr - } - - // This code path should not be reached when the insert - // fails, because promoteErr should be non-nil. - return nil - }, nil) - - // The InTx callback must return a non-nil error so the - // transaction rolls back, preserving the queued message. - require.Error(t, txErr, "InTx callback should return error when insert fails") -} - -// TestAutoPromote_WakesRunLoopAfterPromotion verifies that after the -func TestAutoPromote_InsertFailureSkipsStatusUpdate(t *testing.T) { - t.Parallel() - - ctx := testutil.Context(t, testutil.WaitLong) - ctrl := gomock.NewController(t) - db := dbmock.NewMockStore(ctrl) - tx := dbmock.NewMockStore(ctrl) - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) - ps := dbpubsub.NewInMemory() - clock := quartz.NewReal() - - chatID := uuid.New() - workerID := uuid.New() - ownerID := uuid.New() - modelConfigID := uuid.New() - - waitingChat := database.Chat{ - ID: chatID, - OwnerID: ownerID, - LastModelConfigID: modelConfigID, - Status: database.ChatStatusWaiting, - WorkerID: uuid.NullUUID{UUID: workerID, Valid: true}, - } - queuedMsg := database.ChatQueuedMessage{ - ID: 1, - ChatID: chatID, - Content: []byte(`[{"type":"text","text":"queued"}]`), - } - - wakeCh := make(chan struct{}, 1) - server := &Server{ - db: db, - logger: logger, - pubsub: ps, - clock: clock, - workerID: workerID, - wakeCh: wakeCh, - chatHeartbeatInterval: time.Minute, - metrics: chatloop.NopMetrics(), - configCache: newChatConfigCache(ctx, db, clock), - heartbeatRegistry: make(map[uuid.UUID]*heartbeatEntry), - } - - // Hold model resolution until the interrupt has canceled the chat - // context. 
Returning ErrInterrupted keeps processChat on the - // interrupted path regardless of whether the cache singleflight sees - // the caller cancellation or the DB fetch result first. - modelBlocked := make(chan struct{}) - modelRelease := make(chan struct{}) - var modelBlockedOnce sync.Once - db.EXPECT().GetChatModelConfigByID(gomock.Any(), gomock.Any()).DoAndReturn( - func(_ context.Context, _ uuid.UUID) (database.ChatModelConfig, error) { - modelBlockedOnce.Do(func() { close(modelBlocked) }) - <-modelRelease - return database.ChatModelConfig{}, chatloop.ErrInterrupted - }, - ).AnyTimes() - db.EXPECT().GetAIProviders(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() - db.EXPECT().GetEnabledChatModelConfigs(gomock.Any()).Return(nil, nil).AnyTimes() - db.EXPECT().GetChatUsageLimitConfig(gomock.Any()).Return( - database.ChatUsageLimitConfig{}, sql.ErrNoRows, - ).AnyTimes() - db.EXPECT().GetChatMessagesForPromptByChatID(gomock.Any(), chatID).Return(nil, nil).AnyTimes() - - // The deferred cleanup transaction: InsertChatMessages fails, - // so UpdateChatStatus must NOT be called. - db.EXPECT().InTx(gomock.Any(), gomock.Any()).DoAndReturn( - func(fn func(database.Store) error, _ *database.TxOptions) error { - return fn(tx) - }, - ) - tx.EXPECT().GetChatByIDForUpdate(gomock.Any(), chatID).Return(waitingChat, nil) - tx.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return([]database.ChatQueuedMessage{queuedMsg}, nil) - tx.EXPECT().GetChatModelConfigByID(gomock.Any(), modelConfigID).Return(database.ChatModelConfig{ID: modelConfigID}, nil) - tx.EXPECT().PopNextQueuedMessage(gomock.Any(), chatID).Return(queuedMsg, nil) - tx.EXPECT().InsertChatMessages(gomock.Any(), gomock.Any()).Return( - nil, xerrors.New("insert failed"), - ) - tx.EXPECT().UpdateChatStatus(gomock.Any(), gomock.Any()).Times(0) - - // Subscribe BEFORE launching the goroutine. 
- runningCh := make(chan struct{}, 1) - unsubRunning, err := ps.SubscribeWithErr( - coderdpubsub.ChatStreamNotifyChannel(chatID), - func(_ context.Context, msg []byte, err error) { - if err != nil { - return - } - var notify coderdpubsub.ChatStreamNotifyMessage - if json.Unmarshal(msg, ¬ify) != nil { - return - } - if notify.Status == string(database.ChatStatusRunning) { - select { - case runningCh <- struct{}{}: - default: - } - } - }, - ) - require.NoError(t, err) - defer unsubRunning() - - chat := database.Chat{ID: chatID, OwnerID: ownerID, LastModelConfigID: modelConfigID} - processDone := make(chan struct{}) - go func() { - defer close(processDone) - server.processChat(ctx, chat) - }() - - select { - case <-runningCh: - case <-ctx.Done(): - t.Fatal("timed out waiting for running status") - } - - select { - case <-modelBlocked: - case <-ctx.Done(): - t.Fatal("timed out waiting for model resolution") - } - - // Publish an interrupt so processChat exits runChat. - interruptMsg, err := json.Marshal(coderdpubsub.ChatStreamNotifyMessage{ - Status: string(database.ChatStatusWaiting), - }) - require.NoError(t, err) - err = ps.Publish(coderdpubsub.ChatStreamNotifyChannel(chatID), interruptMsg) - require.NoError(t, err) - close(modelRelease) - - select { - case <-processDone: - case <-ctx.Done(): - t.Fatal("processChat did not complete") - } - - // The wake channel should NOT have a signal because the - // transaction failed before reaching UpdateChatStatus. - select { - case <-wakeCh: - t.Fatal("wake channel should not have a signal after insert failure") - default: - // No signal, as expected. - } -} - // makeInProgressPart is a small constructor for buffered message_part // fixtures used by snapshotBufferLocked / subscribeToStream tests. 
It // builds an in-progress part (committedMessageID == 0) with a @@ -5980,235 +4527,6 @@ func TestSnapshotBufferLocked_AllCommittedReturnsEmpty(t *testing.T) { require.Empty(t, snapshotBufferLocked(buffer)) } -// TestPublishToStream_AppendsAsInProgress verifies that parts -// buffered while the chat is streaming are tagged as in-progress -// (committedMessageID == 0) until publishMessage claims them via a -// committed assistant message. -func TestPublishToStream_AppendsAsInProgress(t *testing.T) { - t.Parallel() - - mClock := quartz.NewMock(t) - server := &Server{ - logger: slogtest.Make(t, nil), - clock: mClock, - } - - chatID := uuid.New() - state := &chatStreamState{ - buffering: true, - subscribers: map[uuid.UUID]chan codersdk.ChatStreamEvent{}, - } - server.chatStreams.Store(chatID, state) - - server.publishToStream(chatID, codersdk.ChatStreamEvent{ - Type: codersdk.ChatStreamEventTypeMessagePart, - MessagePart: &codersdk.ChatStreamMessagePart{ - Role: codersdk.ChatMessageRoleAssistant, - Part: codersdk.ChatMessageText("hello"), - }, - }) - - state.mu.Lock() - defer state.mu.Unlock() - require.Len(t, state.buffer, 1) - require.Equal(t, int64(0), state.buffer[0].committedMessageID, - "newly buffered parts must be in-progress until publishMessage claims them") - require.Equal(t, "hello", partText(state.buffer[0].event)) -} - -// TestClaimCommittedParts covers the per-role behavior of -// claimCommittedParts: -// - assistant messages claim every in-progress part with the -// committed message ID. -// - tool / user messages do not claim parts. -// - parts already claimed by an earlier assistant message are not -// re-claimed. -// - a chat with no live state is a no-op (does not panic). 
-func TestClaimCommittedParts(t *testing.T) { - t.Parallel() - - t.Run("AssistantClaimsAllInProgressParts", func(t *testing.T) { - t.Parallel() - - server := &Server{ - logger: slogtest.Make(t, nil), - clock: quartz.NewMock(t), - } - chatID := uuid.New() - state := server.getOrCreateStreamState(chatID) - state.mu.Lock() - state.buffer = []bufferedStreamPart{ - makeCommittedPart(100, "old-1"), - makeInProgressPart("new-1"), - makeInProgressPart("new-2"), - } - state.mu.Unlock() - - server.claimCommittedParts(chatID, database.ChatMessage{ - ID: 200, - Role: database.ChatMessageRoleAssistant, - }) - - state.mu.Lock() - defer state.mu.Unlock() - require.Equal(t, int64(100), state.buffer[0].committedMessageID, - "already-claimed parts must keep their original message ID") - require.Equal(t, int64(200), state.buffer[1].committedMessageID, - "in-progress parts must be claimed by the new message ID") - require.Equal(t, int64(200), state.buffer[2].committedMessageID, - "in-progress parts must be claimed by the new message ID") - }) - - t.Run("ToolMessageIsNoOp", func(t *testing.T) { - t.Parallel() - - server := &Server{ - logger: slogtest.Make(t, nil), - clock: quartz.NewMock(t), - } - chatID := uuid.New() - state := server.getOrCreateStreamState(chatID) - state.mu.Lock() - state.buffer = []bufferedStreamPart{ - makeInProgressPart("a"), - makeInProgressPart("b"), - } - state.mu.Unlock() - - server.claimCommittedParts(chatID, database.ChatMessage{ - ID: 300, - Role: database.ChatMessageRoleTool, - }) - - state.mu.Lock() - defer state.mu.Unlock() - require.Equal(t, int64(0), state.buffer[0].committedMessageID, - "tool messages must not claim buffered parts") - require.Equal(t, int64(0), state.buffer[1].committedMessageID, - "tool messages must not claim buffered parts") - }) - - t.Run("UserMessageIsNoOp", func(t *testing.T) { - t.Parallel() - - server := &Server{ - logger: slogtest.Make(t, nil), - clock: quartz.NewMock(t), - } - chatID := uuid.New() - state := 
server.getOrCreateStreamState(chatID) - state.mu.Lock() - state.buffer = []bufferedStreamPart{ - makeInProgressPart("a"), - } - state.mu.Unlock() - - server.claimCommittedParts(chatID, database.ChatMessage{ - ID: 400, - Role: database.ChatMessageRoleUser, - }) - - state.mu.Lock() - defer state.mu.Unlock() - require.Equal(t, int64(0), state.buffer[0].committedMessageID, - "user messages must not claim buffered parts") - }) - - t.Run("NoLiveStateIsNoOp", func(t *testing.T) { - t.Parallel() - - server := &Server{ - logger: slogtest.Make(t, nil), - clock: quartz.NewMock(t), - } - chatID := uuid.New() - - // No state stored: claimCommittedParts must not panic and - // must not allocate a new state for an unknown chat. - require.NotPanics(t, func() { - server.claimCommittedParts(chatID, database.ChatMessage{ - ID: 500, - Role: database.ChatMessageRoleAssistant, - }) - }) - _, ok := server.chatStreams.Load(chatID) - require.False(t, ok, - "claimCommittedParts must not create stream state for a chat that has none") - }) -} - -// TestSubscribeToStream_FiltersBufferedParts_Integration wires -// publishToStream, claimCommittedParts (via publishMessage), and -// subscribeToStream together to confirm the end-to-end contract: a -// reconnecting subscriber only receives parts that belong to the -// current in-progress turn, not parts that were already committed -// to durable assistant messages. -func TestSubscribeToStream_FiltersBufferedParts_Integration(t *testing.T) { - t.Parallel() - - mClock := quartz.NewMock(t) - server := &Server{ - logger: slogtest.Make(t, nil), - clock: mClock, - } - chatID := uuid.New() - - // Simulate the lifecycle: - // 1. Stream parts of turn A (still in-progress, no commit yet). - // 2. Commit turn A; its parts are claimed by message 100. - // 3. Stream parts of turn B (in-progress). - // 4. Commit turn B; its parts are claimed by message 200. - // 5. Stream parts of turn C (in-progress, never committed). 
- state := server.getOrCreateStreamState(chatID) - state.mu.Lock() - state.buffering = true - state.mu.Unlock() - - publishPart := func(text string) { - server.publishToStream(chatID, codersdk.ChatStreamEvent{ - Type: codersdk.ChatStreamEventTypeMessagePart, - MessagePart: &codersdk.ChatStreamMessagePart{ - Role: codersdk.ChatMessageRoleAssistant, - Part: codersdk.ChatMessageText(text), - }, - }) - } - - publishPart("A-1") - publishPart("A-2") - server.claimCommittedParts(chatID, database.ChatMessage{ - ID: 100, - Role: database.ChatMessageRoleAssistant, - }) - publishPart("B-1") - publishPart("B-2") - server.claimCommittedParts(chatID, database.ChatMessage{ - ID: 200, - Role: database.ChatMessageRoleAssistant, - }) - publishPart("C-1") - - // Reconnecting subscriber: only the currently in-progress turn - // (turn C) survives the filter, no matter what cursor the - // client passes through SubscribeAuthorized (the filter no - // longer depends on the cursor). - snapshot, _, _, cancel := server.subscribeToStream(chatID) - defer cancel() - - texts := make([]string, 0, len(snapshot)) - for _, ev := range snapshot { - texts = append(texts, partText(ev)) - } - require.Equal(t, []string{"C-1"}, texts, - "only in-progress (un-claimed) buffered parts must survive the filter") -} - -// TestPrimeWorkspaceMCPCache_SuccessOnFirstAttempt verifies the -// onChatUpdated cache primer path: when create_workspace / -// start_workspace finish waitForAgentReady and the agent's MCP -// server is already advertising tools, a single ListMCPTools call -// populates the cache so the next PrepareTools step is a cache hit -// and does not need to dial. 
func TestPrimeWorkspaceMCPCache_SuccessOnFirstAttempt(t *testing.T) { t.Parallel() diff --git a/coderd/x/chatd/chatd_retry_test.go b/coderd/x/chatd/chatd_retry_test.go new file mode 100644 index 0000000000..0fb88bca00 --- /dev/null +++ b/coderd/x/chatd/chatd_retry_test.go @@ -0,0 +1,231 @@ +package chatd_test + +import ( + "context" + "encoding/json" + "sync/atomic" + "testing" + + "github.com/google/uuid" + "github.com/prometheus/client_golang/prometheus" + "github.com/stretchr/testify/require" + + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbgen" + "github.com/coder/coder/v2/coderd/database/dbtestutil" + "github.com/coder/coder/v2/coderd/x/chatd" + "github.com/coder/coder/v2/coderd/x/chatd/chattest" + "github.com/coder/coder/v2/codersdk" + "github.com/coder/coder/v2/testutil" + "github.com/coder/quartz" +) + +func TestActiveServer_RetryStatePersistedDuringBackoff(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + db, ps := dbtestutil.NewDB(t) + clock := quartz.NewMock(t).WithLogger(quartz.NoOpLogger) + var calls atomic.Int32 + openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse { + if !req.Stream { + return chattest.OpenAINonStreamingResponse("title") + } + if calls.Add(1) == 1 { + return chattest.OpenAIRateLimitResponse() + } + return chattest.OpenAIStreamingResponse(openAITextChunksWithStop("recovered")...) 
+ }) + user, org, model := seedChatDependenciesWithProvider(t, db, "openai", openAIURL) + server := newActiveTestServer(t, db, ps, func(cfg *chatd.Config) { + cfg.Clock = clock + }) + + chat := createChatThroughServer(ctx, t, server, org.ID, user.ID, model.ID, "hello") + withRetry := waitForChatRetryState(ctx, t, db, chat.ID) + require.Equal(t, database.ChatStatusRunning, withRetry.Status) + require.True(t, withRetry.RetryState.Valid) + require.Equal(t, withRetry.SnapshotVersion, withRetry.RetryStateVersion) + require.Equal(t, int64(1), withRetry.GenerationAttempt) + + var retryPayload codersdk.ChatStreamRetry + require.NoError(t, json.Unmarshal(withRetry.RetryState.RawMessage, &retryPayload)) + require.Equal(t, 1, retryPayload.Attempt) + require.Equal(t, int64(1000), retryPayload.DelayMs) + require.Equal(t, "OpenAI is rate limiting requests.", retryPayload.Error) + require.Equal(t, codersdk.ChatErrorKindRateLimit, retryPayload.Kind) + require.Equal(t, "openai", retryPayload.Provider) + require.Equal(t, 429, retryPayload.StatusCode) + require.False(t, retryPayload.RetryingAt.IsZero()) + + advanceToNextTimer(ctx, clock) + advanceUntilProviderCall(ctx, clock, &calls, 2) + waitForChatStatus(ctx, t, db, chat.ID, database.ChatStatusWaiting) + require.Equal(t, int32(2), calls.Load()) + latest, err := db.GetChatByID(ctx, chat.ID) + require.NoError(t, err) + require.False(t, latest.RetryState.Valid) + require.Greater(t, latest.RetryStateVersion, withRetry.RetryStateVersion) + messages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{ChatID: chat.ID}) + require.NoError(t, err) + requireTextPart(t, messages[len(messages)-1], "recovered") +} + +func TestActiveServer_RetryStreamSilenceTimeoutAndClassification(t *testing.T) { + t.Parallel() + + t.Run("rate limit retry recovers and records metric", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + db, ps := dbtestutil.NewDB(t) + reg := prometheus.NewRegistry() + 
var calls atomic.Int32 + openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse { + if !req.Stream { + return chattest.OpenAINonStreamingResponse("title") + } + if calls.Add(1) == 1 { + return chattest.OpenAIRateLimitResponse() + } + return chattest.OpenAIStreamingResponse(openAITextChunksWithStop("recovered")...) + }) + user, org, _ := seedChatDependenciesWithProvider(t, db, "openai", openAIURL) + model := dbgen.ChatModelConfig(t, db, database.ChatModelConfig{ + Provider: "openai", + Model: "gpt-4o", + Enabled: true, + CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true}, + UpdatedBy: uuid.NullUUID{UUID: user.ID, Valid: true}, + }) + server := newActiveTestServer(t, db, ps, func(cfg *chatd.Config) { + cfg.PrometheusRegistry = reg + }) + + chat := createChatThroughServer(ctx, t, server, org.ID, user.ID, model.ID, "hello") + waitForChatStatus(ctx, t, db, chat.ID, database.ChatStatusWaiting) + require.Equal(t, int32(2), calls.Load()) + messages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{ChatID: chat.ID}) + require.NoError(t, err) + requireTextPart(t, messages[len(messages)-1], "recovered") + requireRetryCounter(t, reg, "coderd_chatd_stream_retries_total", 1, map[string]string{ + "provider": "openai", + "model": "gpt-4o", + "kind": string(codersdk.ChatErrorKindRateLimit), + "chain_broken": "false", + }) + }) + + t.Run("stream silence timeout retry recovers", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + db, ps := dbtestutil.NewDB(t) + clock := quartz.NewMock(t).WithLogger(quartz.NoOpLogger) + reg := prometheus.NewRegistry() + var calls atomic.Int32 + openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse { + if !req.Stream { + return chattest.OpenAINonStreamingResponse("title") + } + if calls.Add(1) == 1 { + <-req.Request.Context().Done() + return chattest.OpenAIStreamingResponse(openAITextChunksWithStop("timed 
out")...) + } + return chattest.OpenAIStreamingResponse(openAITextChunksWithStop("recovered")...) + }) + user, org, model := seedChatDependenciesWithProvider(t, db, "openai", openAIURL) + server := newActiveTestServer(t, db, ps, func(cfg *chatd.Config) { + cfg.Clock = clock + cfg.PrometheusRegistry = reg + }) + + chat := createChatThroughServer(ctx, t, server, org.ID, user.ID, model.ID, "hello") + advanceUntilProviderCall(ctx, clock, &calls, 1) + advanceToNextTimer(ctx, clock) + advanceUntilProviderCall(ctx, clock, &calls, 2) + waitForChatStatus(ctx, t, db, chat.ID, database.ChatStatusWaiting) + require.Equal(t, int32(2), calls.Load()) + messages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{ChatID: chat.ID}) + require.NoError(t, err) + requireTextPart(t, messages[len(messages)-1], "recovered") + requireRetryCounter(t, reg, "coderd_chatd_stream_retries_total", 1, map[string]string{ + "provider": "openai", + "model": "gpt-4o-mini", + "kind": string(codersdk.ChatErrorKindStreamSilenceTimeout), + "chain_broken": "false", + }) + }) +} + +func requireRetryCounter(t *testing.T, reg *prometheus.Registry, name string, wantValue float64, wantLabels map[string]string) { + t.Helper() + require.True(t, hasRetryCounter(t, reg, name, wantValue, wantLabels), "metric %s not found", name) +} + +func hasRetryCounter(t *testing.T, reg *prometheus.Registry, name string, wantValue float64, wantLabels map[string]string) bool { + t.Helper() + + families, err := reg.Gather() + require.NoError(t, err) + for _, family := range families { + if family.GetName() != name { + continue + } + for _, metric := range family.GetMetric() { + if metric.GetCounter().GetValue() != wantValue { + continue + } + labels := map[string]string{} + for _, label := range metric.GetLabel() { + labels[label.GetName()] = label.GetValue() + } + matches := true + for key, want := range wantLabels { + if labels[key] != want { + matches = false + break + } + } + if matches { + return 
true + } + } + return false + } + return false +} + +func waitForChatRetryState(ctx context.Context, t *testing.T, db database.Store, chatID uuid.UUID) database.Chat { + t.Helper() + var chat database.Chat + testutil.Eventually(ctx, t, func(ctx context.Context) bool { + latest, err := db.GetChatByID(ctx, chatID) + if err != nil { + return false + } + chat = latest + return latest.RetryState.Valid + }, testutil.IntervalFast) + return chat +} + +func advanceUntilProviderCall(ctx context.Context, clock *quartz.Mock, calls *atomic.Int32, want int32) { + for calls.Load() < want { + advanceToNextTimer(ctx, clock) + } +} + +func advanceToNextTimer(ctx context.Context, clock *quartz.Mock) { + _, waiter := clock.AdvanceNext() + waiter.MustWait(ctx) +} + +func openAITextChunksWithStop(deltas ...string) []chattest.OpenAIChunk { + chunks := chattest.OpenAITextChunks(deltas...) + if len(chunks) == 0 { + return nil + } + chunks[len(chunks)-1].Choices[0].FinishReason = "stop" + return chunks +} diff --git a/coderd/x/chatd/chatd_test.go b/coderd/x/chatd/chatd_test.go index c11e9beeee..b0524423ec 100644 --- a/coderd/x/chatd/chatd_test.go +++ b/coderd/x/chatd/chatd_test.go @@ -21,16 +21,20 @@ import ( "testing" "time" + "charm.land/fantasy" + fantasyanthropic "charm.land/fantasy/providers/anthropic" "github.com/google/uuid" mcpgo "github.com/mark3labs/mcp-go/mcp" mcpserver "github.com/mark3labs/mcp-go/server" "github.com/prometheus/client_golang/prometheus" + io_prometheus_client "github.com/prometheus/client_model/go" "github.com/sqlc-dev/pqtype" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/mock/gomock" "golang.org/x/xerrors" + "cdr.dev/slog/v3" "cdr.dev/slog/v3/sloggers/slogtest" "github.com/coder/coder/v2/agent/agentcontextconfig" "github.com/coder/coder/v2/agent/agenttest" @@ -50,6 +54,8 @@ import ( "github.com/coder/coder/v2/coderd/x/chatd" "github.com/coder/coder/v2/coderd/x/chatd/chatadvisor" 
"github.com/coder/coder/v2/coderd/x/chatd/chatprompt" + "github.com/coder/coder/v2/coderd/x/chatd/chatsanitize" + "github.com/coder/coder/v2/coderd/x/chatd/chatstate" "github.com/coder/coder/v2/coderd/x/chatd/chattest" "github.com/coder/coder/v2/coderd/x/chatd/chattool" "github.com/coder/coder/v2/codersdk" @@ -243,10 +249,10 @@ func newWorkspaceToolTestServer( mockConn.EXPECT().ListMCPTools(gomock.Any()). Return(workspacesdk.ListMCPToolsResponse{}, nil).AnyTimes() mockConn.EXPECT().LS(gomock.Any(), gomock.Any(), gomock.Any()). - Return(workspacesdk.LSResponse{}, nil).AnyTimes() + Return(workspacesdk.LSResponse{AbsolutePathString: "/home/coder"}, nil).AnyTimes() mockConn.EXPECT().ReadFile(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). DoAndReturn(func(_ context.Context, path string, _, _ int64) (io.ReadCloser, string, error) { - if path == "/home/coder/PLAN.md" { + if strings.HasPrefix(path, "/home/coder/.coder/plans/PLAN-") || path == "/home/coder/PLAN.md" { return io.NopCloser(strings.NewReader(planContent)), "", nil } return io.NopCloser(strings.NewReader("")), "", nil @@ -835,7 +841,11 @@ func TestExploreChatUsesPersistedMCPSnapshot(t *testing.T) { ClientType: database.ChatClientTypeApi, }) - exploreChat := dbgen.Chat(t, db, database.Chat{ + userContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{ + codersdk.ChatMessageText("inspect the codebase"), + }) + require.NoError(t, err) + createdExplore, err := chatstate.CreateChat(ctx, db, ps, chatstate.CreateChatInput{ OrganizationID: org.ID, OwnerID: user.ID, WorkspaceID: uuid.NullUUID{UUID: ws.ID, Valid: true}, @@ -848,21 +858,21 @@ func TestExploreChatUsesPersistedMCPSnapshot(t *testing.T) { ChatMode: database.ChatModeExplore, Valid: true, }, - Status: database.ChatStatusPending, MCPServerIDs: []uuid.UUID{mcpConfig.ID}, ClientType: database.ChatClientTypeApi, - }) - - dbgen.ChatMessage(t, db, database.ChatMessage{ - ChatID: exploreChat.ID, - CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: 
true}, - ModelConfigID: uuid.NullUUID{UUID: webSearchModel.ID, Valid: true}, - Role: database.ChatMessageRoleUser, - Content: pqtype.NullRawMessage{ - RawMessage: json.RawMessage(`[{"type":"text","text":"inspect the codebase"}]`), - Valid: true, + InitialMessages: []chatstate.Message{ + { + Role: database.ChatMessageRoleUser, + Content: userContent, + Visibility: database.ChatMessageVisibilityBoth, + ContentVersion: chatprompt.CurrentContentVersion, + CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true}, + ModelConfigID: uuid.NullUUID{UUID: webSearchModel.ID, Valid: true}, + }, }, }) + require.NoError(t, err) + exploreChat := createdExplore.Chat ctrl := gomock.NewController(t) mockConn := agentconnmock.NewMockAgentConn(ctrl) @@ -2205,121 +2215,6 @@ func TestAutoPromoteQueuedMessagesPreservesPerTurnModelOrder(t *testing.T) { require.Equal(t, []uuid.UUID{modelConfigA.ID, modelConfigB.ID, modelConfigC.ID}, userModelConfigIDs) } -func TestAutoPromoteQueuedMessageFallsBackForLegacyQueuedRows(t *testing.T) { - t.Parallel() - - testAutoPromoteQueuedMessageFallback(t, uuid.NullUUID{}) -} - -func TestAutoPromoteQueuedMessageFallsBackForInvalidQueuedModelConfigID(t *testing.T) { - t.Parallel() - - testAutoPromoteQueuedMessageFallback(t, uuid.NullUUID{ - UUID: uuid.New(), - Valid: true, - }) -} - -func testAutoPromoteQueuedMessageFallback(t *testing.T, queuedModelConfigID uuid.NullUUID) { - db, ps := dbtestutil.NewDB(t) - ctx := testutil.Context(t, testutil.WaitSuperLong) - - firstRunStarted := make(chan struct{}) - secondRunStarted := make(chan struct{}, 1) - allowFirstRunFinish := make(chan struct{}) - var requestCount atomic.Int32 - openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse { - if !req.Stream { - return chattest.OpenAINonStreamingResponse("title") - } - - switch requestCount.Add(1) { - case 1: - chunks := make(chan chattest.OpenAIChunk, 1) - go func() { - defer close(chunks) - chunks <- chattest.OpenAITextChunks("first run 
partial")[0] - select { - case <-firstRunStarted: - default: - close(firstRunStarted) - } - <-allowFirstRunFinish - }() - return chattest.OpenAIResponse{StreamingChunks: chunks} - default: - select { - case secondRunStarted <- struct{}{}: - default: - } - return chattest.OpenAIStreamingResponse(chattest.OpenAITextChunks("fallback run done")...) - } - }) - - server := newActiveTestServer(t, db, ps, func(cfg *chatd.Config) { - // Disable periodic polling so only signalWake can - // trigger the next processing run. - cfg.PendingChatAcquireInterval = time.Hour - }) - user, org, modelConfig := seedChatDependenciesWithProvider(t, db, "openai-compat", openAIURL) - chat, err := server.CreateChat(ctx, chatd.CreateOptions{ - OrganizationID: org.ID, - OwnerID: user.ID, - Title: "auto-promote queued fallback", - ModelConfigID: modelConfig.ID, - InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")}, - }) - require.NoError(t, err) - - testutil.TryReceive(ctx, t, firstRunStarted) - - queuedContent, err := json.Marshal([]codersdk.ChatMessagePart{codersdk.ChatMessageText("legacy queued row")}) - require.NoError(t, err) - _, err = db.InsertChatQueuedMessage(ctx, database.InsertChatQueuedMessageParams{ - ChatID: chat.ID, - Content: queuedContent, - ModelConfigID: queuedModelConfigID, - }) - require.NoError(t, err) - - close(allowFirstRunFinish) - - testutil.TryReceive(ctx, t, secondRunStarted) - require.GreaterOrEqual(t, requestCount.Load(), int32(2)) - chatd.WaitUntilIdleForTest(server) - - queuedMessages, err := db.GetChatQueuedMessages(ctx, chat.ID) - require.NoError(t, err) - require.Empty(t, queuedMessages) - - storedChat, err := db.GetChatByID(ctx, chat.ID) - require.NoError(t, err) - require.Equal(t, database.ChatStatusWaiting, storedChat.Status) - require.Equal(t, modelConfig.ID, storedChat.LastModelConfigID) - - messages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{ - ChatID: chat.ID, - AfterID: 0, - }) - 
require.NoError(t, err) - - var found bool - for _, message := range messages { - if message.Role != database.ChatMessageRoleUser { - continue - } - sdkMessage := db2sdk.ChatMessage(message) - require.Len(t, sdkMessage.Content, 1) - if sdkMessage.Content[0].Text != "legacy queued row" { - continue - } - require.True(t, message.ModelConfigID.Valid) - require.Equal(t, modelConfig.ID, message.ModelConfigID.UUID) - found = true - } - require.True(t, found) -} - func TestInterruptAutoPromotionIgnoresLaterUsageLimitIncrease(t *testing.T) { t.Parallel() @@ -2830,95 +2725,6 @@ func TestEditMessageDebugCleanupPreservesRecentRuns(t *testing.T) { "the buffered run must survive the fast retry") } -// TestArchiveChatDebugCleanupDeletesPreArchiveRuns verifies that -func TestRecoverStaleChatsPeriodically(t *testing.T) { - t.Parallel() - - db, ps := dbtestutil.NewDB(t) - - ctx := testutil.Context(t, testutil.WaitLong) - user, org, model := seedChatDependencies(t, db) - - // Use a very short stale threshold so the periodic recovery - // kicks in quickly during the test. - staleAfter := 500 * time.Millisecond - - // Create a chat and simulate a dead worker by setting the chat - // to running with a heartbeat in the past. - deadWorkerID := uuid.New() - chat := dbgen.Chat(t, db, database.Chat{ - OrganizationID: org.ID, - OwnerID: user.ID, - Title: "stale-recovery-periodic", - LastModelConfigID: model.ID, - }) - - _, err := db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{ - ID: chat.ID, - Status: database.ChatStatusRunning, - WorkerID: uuid.NullUUID{UUID: deadWorkerID, Valid: true}, - StartedAt: sql.NullTime{Time: time.Now().Add(-time.Hour), Valid: true}, - HeartbeatAt: sql.NullTime{Time: time.Now().Add(-time.Hour), Valid: true}, - }) - require.NoError(t, err) - - // Start a new replica. Its startup recovery will reset the - // chat (since the heartbeat is old), but the key point is that - // the periodic loop also recovers newly-stale chats. 
- logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) - server := chatd.New(chatd.Config{ - Logger: logger, - Database: db, - ReplicaID: uuid.New(), - Pubsub: ps, - PendingChatAcquireInterval: testutil.WaitLong, - InFlightChatStaleAfter: staleAfter, - }) - server.Start() - t.Cleanup(func() { - require.NoError(t, server.Close()) - }) - - // The startup recovery should have already reset our stale - // chat. - require.Eventually(t, func() bool { - fromDB, err := db.GetChatByID(ctx, chat.ID) - if err != nil { - return false - } - return fromDB.Status == database.ChatStatusPending - }, testutil.WaitMedium, testutil.IntervalFast) - - // Now simulate a second stale chat appearing AFTER startup. - // This tests the periodic recovery, not just the startup one. - deadWorkerID2 := uuid.New() - chat2 := dbgen.Chat(t, db, database.Chat{ - OrganizationID: org.ID, - OwnerID: user.ID, - Title: "stale-recovery-periodic-2", - LastModelConfigID: model.ID, - }) - - _, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{ - ID: chat2.ID, - Status: database.ChatStatusRunning, - WorkerID: uuid.NullUUID{UUID: deadWorkerID2, Valid: true}, - StartedAt: sql.NullTime{Time: time.Now().Add(-time.Hour), Valid: true}, - HeartbeatAt: sql.NullTime{Time: time.Now().Add(-time.Hour), Valid: true}, - }) - require.NoError(t, err) - - // The periodic stale recovery loop (running at staleAfter/5 = - // 100ms intervals) should pick this up without a restart. 
- require.Eventually(t, func() bool { - fromDB, err := db.GetChatByID(ctx, chat2.ID) - if err != nil { - return false - } - return fromDB.Status == database.ChatStatusPending - }, testutil.WaitMedium, testutil.IntervalFast) -} - func TestRecoverStaleRequiresActionChat(t *testing.T) { t.Parallel() @@ -2927,107 +2733,164 @@ func TestRecoverStaleRequiresActionChat(t *testing.T) { ctx := testutil.Context(t, testutil.WaitLong) user, org, model := seedChatDependencies(t, db) - // Use a very short stale threshold so the periodic recovery - // kicks in quickly during the test. - staleAfter := 500 * time.Millisecond + toolName := "my_dynamic_tool" + dynamicToolsJSON, err := json.Marshal([]mcpgo.Tool{{ + Name: toolName, + Description: "A test dynamic tool.", + InputSchema: mcpgo.ToolInputSchema{ + Type: "object", + Properties: map[string]any{}, + }, + }}) + require.NoError(t, err) - // Create a chat and set it to requires_action to simulate a - // client that disappeared while the chat was waiting for - // dynamic tool results. 
- chat := dbgen.Chat(t, db, database.Chat{ + content, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{ + codersdk.ChatMessageText("hello"), + }) + require.NoError(t, err) + created, err := chatstate.CreateChat(ctx, db, ps, chatstate.CreateChatInput{ OrganizationID: org.ID, OwnerID: user.ID, - Title: "stale-requires-action", LastModelConfigID: model.ID, - }) - - _, err := db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{ - ID: chat.ID, - Status: database.ChatStatusRequiresAction, + Title: "stale-requires-action", + DynamicTools: nullRawMessage(dynamicToolsJSON), + ClientType: database.ChatClientTypeApi, + InitialMessages: []chatstate.Message{ + { + Role: database.ChatMessageRoleUser, + Content: content, + Visibility: database.ChatMessageVisibilityBoth, + ContentVersion: chatprompt.CurrentContentVersion, + CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true}, + ModelConfigID: uuid.NullUUID{UUID: model.ID, Valid: true}, + }, + }, }) require.NoError(t, err) - // Backdate updated_at so the chat appears stale to the - // recovery loop without needing time.Sleep. 
+ toolCallID := "call_" + uuid.NewString() + assistantContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{ + { + Type: codersdk.ChatMessagePartTypeToolCall, + ToolCallID: toolCallID, + ToolName: toolName, + Args: json.RawMessage(`{}`), + }, + }) + require.NoError(t, err) + machine := chatstate.NewChatMachine(db, ps, created.Chat.ID, chatstate.Options{}) + require.NoError(t, machine.Update(ctx, func(tx *chatstate.Tx) error { + _, err := tx.CommitStep(chatstate.CommitStepInput{ + Messages: []chatstate.Message{ + { + Role: database.ChatMessageRoleAssistant, + Content: assistantContent, + Visibility: database.ChatMessageVisibilityBoth, + ContentVersion: chatprompt.CurrentContentVersion, + ModelConfigID: uuid.NullUUID{UUID: model.ID, Valid: true}, + }, + }, + }) + return err + })) + require.NoError(t, machine.Update(ctx, func(tx *chatstate.Tx) error { + _, err := tx.EnterRequiresAction(chatstate.EnterRequiresActionInput{}) + return err + })) _, err = rawDB.ExecContext(ctx, - "UPDATE chats SET updated_at = $1 WHERE id = $2", - time.Now().Add(-time.Hour), chat.ID) + "UPDATE chats SET requires_action_deadline_at = $1 WHERE id = $2", + time.Now().Add(-time.Hour), created.Chat.ID) require.NoError(t, err) - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) - server := chatd.New(chatd.Config{ - Logger: logger, - Database: db, - ReplicaID: uuid.New(), - Pubsub: ps, - PendingChatAcquireInterval: testutil.WaitLong, - InFlightChatStaleAfter: staleAfter, - }) + server := newTestServer(t, db, ps, uuid.New()) server.Start() - t.Cleanup(func() { - require.NoError(t, server.Close()) + + chatResult := waitForTerminalChat(ctx, t, db, created.Chat.ID) + require.Equal(t, database.ChatStatusWaiting, chatResult.Status) + require.False(t, chatResult.RequiresActionDeadlineAt.Valid) + + messages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{ + ChatID: created.Chat.ID, }) - - // The stale recovery should transition the 
requires_action - // chat to error with the timeout message. - var chatResult database.Chat - require.Eventually(t, func() bool { - chatResult, err = db.GetChatByID(ctx, chat.ID) - if err != nil { - return false - } - return chatResult.Status == database.ChatStatusError - }, testutil.WaitMedium, testutil.IntervalFast) - - persistedError := requireChatLastErrorPayload(t, chatResult.LastError) - require.Equal(t, codersdk.ChatError{ - Message: "Dynamic tool execution timed out", - Kind: codersdk.ChatErrorKindGeneric, - }, persistedError) - require.False(t, chatResult.WorkerID.Valid) + require.NoError(t, err) + require.Len(t, messages, 4) + parts, err := chatprompt.ParseContent(messages[2]) + require.NoError(t, err) + require.Len(t, parts, 1) + require.Equal(t, codersdk.ChatMessagePartTypeToolResult, parts[0].Type) + require.Equal(t, toolCallID, parts[0].ToolCallID) + require.Equal(t, toolName, parts[0].ToolName) + require.True(t, parts[0].IsError) + require.JSONEq(t, `"Tool execution timed out"`, string(parts[0].Result)) } func TestNewReplicaRecoversStaleChatFromDeadReplica(t *testing.T) { t.Parallel() - db, ps := dbtestutil.NewDB(t) + db, ps, rawDB := dbtestutil.NewDBWithSQLDB(t) ctx := testutil.Context(t, testutil.WaitLong) user, org, model := seedChatDependencies(t, db) - // Simulate a chat left running by a dead replica with a stale - // heartbeat (well beyond the stale threshold). - deadReplicaID := uuid.New() - chat := dbgen.Chat(t, db, database.Chat{ + content, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{ + codersdk.ChatMessageText("hello"), + }) + require.NoError(t, err) + created, err := chatstate.CreateChat(ctx, db, ps, chatstate.CreateChatInput{ OrganizationID: org.ID, OwnerID: user.ID, - Title: "orphaned-chat", LastModelConfigID: model.ID, - }) - - // Set the heartbeat far in the past so it's definitely stale. 
- _, err := db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{ - ID: chat.ID, - Status: database.ChatStatusRunning, - WorkerID: uuid.NullUUID{UUID: deadReplicaID, Valid: true}, - StartedAt: sql.NullTime{Time: time.Now().Add(-time.Hour), Valid: true}, - HeartbeatAt: sql.NullTime{Time: time.Now().Add(-time.Hour), Valid: true}, + Title: "orphaned-chat", + ClientType: database.ChatClientTypeApi, + InitialMessages: []chatstate.Message{ + { + Role: database.ChatMessageRoleUser, + Content: content, + Visibility: database.ChatMessageVisibilityBoth, + ContentVersion: chatprompt.CurrentContentVersion, + CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true}, + ModelConfigID: uuid.NullUUID{UUID: model.ID, Valid: true}, + }, + }, }) require.NoError(t, err) + deadWorkerID := uuid.New() + deadRunnerID := uuid.New() + machine := chatstate.NewChatMachine(db, ps, created.Chat.ID, chatstate.Options{}) + require.NoError(t, machine.Update(ctx, func(tx *chatstate.Tx) error { + _, err := tx.Acquire(chatstate.AcquireInput{WorkerID: deadWorkerID, RunnerID: deadRunnerID}) + return err + })) + // Simulate a chat left running by a dead replica with a stale + // heartbeat (well beyond the stale threshold). + _, err = rawDB.ExecContext(ctx, + "UPDATE chat_heartbeats SET heartbeat_at = $1 WHERE chat_id = $2 AND runner_id = $3", + time.Now().Add(-time.Hour), created.Chat.ID, deadRunnerID) + require.NoError(t, err) + + newWorkerID := uuid.New() + server := newTestServer(t, db, ps, newWorkerID) // Start a new replica. It should recover the stale chat on // startup. 
- newReplica := newTestServer(t, db, ps, uuid.New()) - _ = newReplica + server.Start() + var recovered database.Chat require.Eventually(t, func() bool { - fromDB, err := db.GetChatByID(ctx, chat.ID) + recovered, err = db.GetChatByID(ctx, created.Chat.ID) if err != nil { return false } - return fromDB.Status == database.ChatStatusPending && - !fromDB.WorkerID.Valid + return recovered.Status == database.ChatStatusRunning && + recovered.WorkerID.Valid && recovered.WorkerID.UUID == newWorkerID && + recovered.RunnerID.Valid && recovered.RunnerID.UUID != deadRunnerID }, testutil.WaitMedium, testutil.IntervalFast) + + _, err = db.GetChatHeartbeat(ctx, database.GetChatHeartbeatParams{ + ChatID: created.Chat.ID, + RunnerID: recovered.RunnerID.UUID, + }) + require.NoError(t, err) } func TestWaitingChatsAreNotRecoveredAsStale(t *testing.T) { @@ -3428,6 +3291,673 @@ func TestRequiresActionChatPersistsWaitingStatusLabel(t *testing.T) { "expected no web push dispatch for a requires_action chat") } +func TestActiveServer_InterruptionBehavior(t *testing.T) { + t.Parallel() + + t.Run("partial stream commits synthetic tool result and promotes queued message", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + db, ps := dbtestutil.NewDB(t) + streamStarted := make(chan struct{}) + var requestCount atomic.Int32 + anthropicURL := chattest.NewAnthropic(t, func(req *chattest.AnthropicRequest) chattest.AnthropicResponse { + if !req.Stream { + return chattest.AnthropicNonStreamingResponse("title") + } + + if requestCount.Add(1) != 1 { + return chattest.AnthropicStreamingResponse(chattest.AnthropicTextChunks("queued response")...) 
+ } + chunks := make(chan chattest.AnthropicChunk, 5) + go func() { + defer close(chunks) + chunks <- chattest.AnthropicChunk{ + Type: "message_start", + Message: chattest.AnthropicChunkMessage{ + ID: "msg-partial-interrupt", + Type: "message", + Role: "assistant", + Model: "claude-3-opus-20240229", + }, + } + chunks <- chattest.AnthropicChunk{ + Type: "content_block_start", + Index: 0, + ContentBlock: chattest.AnthropicContentBlock{ + Type: "text", + Text: "", + }, + } + chunks <- chattest.AnthropicChunk{ + Type: "content_block_delta", + Index: 0, + Delta: chattest.AnthropicDeltaBlock{Type: "text_delta", Text: "partial assistant output"}, + } + chunks <- chattest.AnthropicChunk{ + Type: "content_block_start", + Index: 1, + ContentBlock: chattest.AnthropicContentBlock{ + Type: "tool_use", + ID: "interrupt-tool-1", + Name: "read_file", + }, + } + chunks <- chattest.AnthropicChunk{ + Type: "content_block_delta", + Index: 1, + Delta: chattest.AnthropicDeltaBlock{Type: "input_json_delta", PartialJSON: `{"path":"main.go"}`}, + } + select { + case <-streamStarted: + default: + close(streamStarted) + } + <-req.Context().Done() + }() + return chattest.AnthropicResponse{StreamingChunks: chunks} + }) + user, org, model := seedAnthropicChatDependencies(t, db, anthropicURL) + ws, dbAgent := seedWorkspaceWithAgent(t, db, user.ID) + + ctrl := gomock.NewController(t) + mockConn := agentconnmock.NewMockAgentConn(ctrl) + setupToolExecutionAgentConn(t, mockConn) + mockConn.EXPECT().ReadFileLines(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(0) + + server := newActiveTestServer(t, db, ps, func(cfg *chatd.Config) { + cfg.AgentConn = func(_ context.Context, agentID uuid.UUID) (workspacesdk.AgentConn, func(), error) { + require.Equal(t, dbAgent.ID, agentID) + return mockConn, func() {}, nil + } + }) + chat, err := server.CreateChat(ctx, chatd.CreateOptions{ + OrganizationID: org.ID, + OwnerID: user.ID, + WorkspaceID: uuid.NullUUID{UUID: ws.ID, Valid: true}, 
+ AgentID: uuid.NullUUID{UUID: dbAgent.ID, Valid: true}, + Title: "interrupt-partial-tool", + ModelConfigID: model.ID, + InitialUserContent: []codersdk.ChatMessagePart{ + codersdk.ChatMessageText("start and call a tool"), + }, + }) + require.NoError(t, err) + + testutil.TryReceive(ctx, t, streamStarted) + queued, err := server.SendMessage(ctx, chatd.SendMessageOptions{ + ChatID: chat.ID, + CreatedBy: user.ID, + ModelConfigID: model.ID, + Content: []codersdk.ChatMessagePart{codersdk.ChatMessageText("queued after interrupt")}, + BusyBehavior: chatd.SendMessageBusyBehaviorInterrupt, + }) + require.NoError(t, err) + require.True(t, queued.Queued) + + waitForChatStatus(ctx, t, db, chat.ID, database.ChatStatusWaiting) + require.GreaterOrEqual(t, requestCount.Load(), int32(2)) + + messages := chatMessages(ctx, t, db, chat.ID) + var userTexts []string + var foundPartial bool + for _, msg := range messages { + parts, parseErr := chatprompt.ParseContent(msg) + require.NoError(t, parseErr) + switch msg.Role { + case database.ChatMessageRoleUser: + for _, part := range parts { + if part.Type == codersdk.ChatMessagePartTypeText { + userTexts = append(userTexts, part.Text) + } + } + case database.ChatMessageRoleAssistant: + for _, part := range parts { + if part.Type == codersdk.ChatMessagePartTypeText && strings.Contains(part.Text, "partial assistant output") { + foundPartial = true + } + } + } + } + require.Equal(t, []string{"start and call a tool", "queued after interrupt"}, userTexts) + require.True(t, foundPartial) + + parts := chatToolParts(ctx, t, db, chat.ID) + call := requireToolCallPart(t, parts, "read_file") + require.Equal(t, "interrupt-tool-1", call.ToolCallID) + require.Empty(t, call.Args) + require.Nil(t, call.CreatedAt, "incomplete streamed call should not have a durable call timestamp") + result := requireToolResultPart(t, parts, "read_file") + require.Equal(t, "interrupt-tool-1", result.ToolCallID) + require.True(t, result.IsError) + require.JSONEq(t, 
`{"error":"tool call was interrupted before it produced a result"}`, string(result.Result)) + require.NotNil(t, result.CreatedAt) + }) + + t.Run("tool execution cancellation commits interrupted result", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + db, ps := dbtestutil.NewDB(t) + var requestCount atomic.Int32 + openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse { + if !req.Stream { + return chattest.OpenAINonStreamingResponse("title") + } + + if requestCount.Add(1) == 1 { + chunk := chattest.OpenAIToolCallChunk("read_file", `{"path":"/tmp/slow.txt"}`) + chunk.Choices[0].ToolCalls[0].ID = "tc-slow" + return chattest.OpenAIStreamingResponse( + chattest.OpenAITextChunks("calling tool")[0], + chunk, + ) + } + return chattest.OpenAIStreamingResponse(chattest.OpenAITextChunks("after interrupt")...) + }) + user, org, model := seedChatDependenciesWithProvider(t, db, "openai-compat", openAIURL) + ws, dbAgent := seedWorkspaceWithAgent(t, db, user.ID) + toolStarted := make(chan struct{}) + + ctrl := gomock.NewController(t) + mockConn := agentconnmock.NewMockAgentConn(ctrl) + setupToolExecutionAgentConn(t, mockConn) + mockConn.EXPECT().ReadFileLines(gomock.Any(), "/tmp/slow.txt", int64(1), int64(0), gomock.Any()). 
+ DoAndReturn(func(ctx context.Context, _ string, _, _ int64, _ workspacesdk.ReadFileLinesLimits) (workspacesdk.ReadFileLinesResponse, error) { + close(toolStarted) + <-ctx.Done() + return workspacesdk.ReadFileLinesResponse{}, ctx.Err() + }).Times(1) + + server := newActiveTestServer(t, db, ps, func(cfg *chatd.Config) { + cfg.AgentConn = func(_ context.Context, agentID uuid.UUID) (workspacesdk.AgentConn, func(), error) { + require.Equal(t, dbAgent.ID, agentID) + return mockConn, func() {}, nil + } + }) + chat, err := server.CreateChat(ctx, chatd.CreateOptions{ + OrganizationID: org.ID, + OwnerID: user.ID, + WorkspaceID: uuid.NullUUID{UUID: ws.ID, Valid: true}, + AgentID: uuid.NullUUID{UUID: dbAgent.ID, Valid: true}, + Title: "interrupt-tool-execution", + ModelConfigID: model.ID, + InitialUserContent: []codersdk.ChatMessagePart{ + codersdk.ChatMessageText("run the slow tool"), + }, + }) + require.NoError(t, err) + + testutil.TryReceive(ctx, t, toolStarted) + queued, err := server.SendMessage(ctx, chatd.SendMessageOptions{ + ChatID: chat.ID, + CreatedBy: user.ID, + ModelConfigID: model.ID, + Content: []codersdk.ChatMessagePart{codersdk.ChatMessageText("continue after interrupt")}, + BusyBehavior: chatd.SendMessageBusyBehaviorInterrupt, + }) + require.NoError(t, err) + require.True(t, queued.Queued) + + waitForChatStatus(ctx, t, db, chat.ID, database.ChatStatusWaiting) + require.GreaterOrEqual(t, requestCount.Load(), int32(2)) + + messages := chatMessages(ctx, t, db, chat.ID) + var foundText bool + for _, msg := range messages { + if msg.Role != database.ChatMessageRoleAssistant { + continue + } + parts, parseErr := chatprompt.ParseContent(msg) + require.NoError(t, parseErr) + for _, part := range parts { + if part.Type == codersdk.ChatMessagePartTypeText && strings.Contains(part.Text, "calling tool") { + foundText = true + } + } + } + require.True(t, foundText) + + parts := chatToolParts(ctx, t, db, chat.ID) + call := requireToolCallPart(t, parts, "read_file") + 
require.Equal(t, "tc-slow", call.ToolCallID) + require.NotNil(t, call.CreatedAt) + result := requireToolResultPart(t, parts, "read_file") + require.Equal(t, "tc-slow", result.ToolCallID) + require.True(t, result.IsError) + require.JSONEq(t, `{"error":"tool call was interrupted before it produced a result"}`, string(result.Result)) + require.NotNil(t, result.CreatedAt) + require.False(t, result.CreatedAt.Before(*call.CreatedAt)) + }) + + t.Run("anthropic provider-only interruption commits no synthetic result", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + db, ps := dbtestutil.NewDB(t) + webSearchEnabled := true + providerToolStarted := make(chan struct{}) + var requestCount atomic.Int32 + anthropicURL := chattest.NewAnthropic(t, func(req *chattest.AnthropicRequest) chattest.AnthropicResponse { + if !req.Stream { + return chattest.AnthropicNonStreamingResponse("title") + } + + if requestCount.Add(1) != 1 { + return chattest.AnthropicStreamingResponse(chattest.AnthropicTextChunks("after interrupt")...) 
+ } + chunks := make(chan chattest.AnthropicChunk, 2) + go func() { + defer close(chunks) + chunks <- chattest.AnthropicChunk{ + Type: "message_start", + Message: chattest.AnthropicChunkMessage{ + ID: "msg-provider-interrupt", + Type: "message", + Role: "assistant", + Model: "claude-3-opus-20240229", + }, + } + chunks <- chattest.AnthropicChunk{ + Type: "content_block_start", + Index: 0, + ContentBlock: chattest.AnthropicContentBlock{ + Type: "server_tool_use", + ID: "ws-interrupt", + Name: "web_search", + Input: json.RawMessage(`{"query":"coder"}`), + }, + } + select { + case <-providerToolStarted: + default: + close(providerToolStarted) + } + <-req.Context().Done() + }() + return chattest.AnthropicResponse{StreamingChunks: chunks} + }) + user, org, model := seedAnthropicChatDependencies(t, db, anthropicURL) + model = updateChatModelCallConfig(t, db, model, codersdk.ChatModelCallConfig{ + ProviderOptions: &codersdk.ChatModelProviderOptions{ + Anthropic: &codersdk.ChatModelAnthropicProviderOptions{WebSearchEnabled: &webSearchEnabled}, + }, + }) + + server := newActiveTestServer(t, db, ps) + chat := createChatThroughServer(ctx, t, server, org.ID, user.ID, model.ID, "search for coder") + testutil.TryReceive(ctx, t, providerToolStarted) + queued, err := server.SendMessage(ctx, chatd.SendMessageOptions{ + ChatID: chat.ID, + CreatedBy: user.ID, + ModelConfigID: model.ID, + Content: []codersdk.ChatMessagePart{codersdk.ChatMessageText("continue after provider interrupt")}, + BusyBehavior: chatd.SendMessageBusyBehaviorInterrupt, + }) + require.NoError(t, err) + require.True(t, queued.Queued) + + waitForChatStatus(ctx, t, db, chat.ID, database.ChatStatusWaiting) + parts := chatToolParts(ctx, t, db, chat.ID) + require.False(t, toolResultPartExists(parts, "web_search"), + "provider-executed web_search should not get a synthetic local result") + }) + + t.Run("anthropic mixed provider and local interruption keeps local synthetic result", func(t *testing.T) { + t.Parallel() + + 
ctx := testutil.Context(t, testutil.WaitLong) + db, ps := dbtestutil.NewDB(t) + webSearchEnabled := true + streamStarted := make(chan struct{}) + var requestCount atomic.Int32 + anthropicURL := chattest.NewAnthropic(t, func(req *chattest.AnthropicRequest) chattest.AnthropicResponse { + if !req.Stream { + return chattest.AnthropicNonStreamingResponse("title") + } + + if requestCount.Add(1) != 1 { + return chattest.AnthropicStreamingResponse(chattest.AnthropicTextChunks("after interrupt")...) + } + chunks := make(chan chattest.AnthropicChunk, 3) + go func() { + defer close(chunks) + chunks <- chattest.AnthropicChunk{ + Type: "message_start", + Message: chattest.AnthropicChunkMessage{ + ID: "msg-mixed-interrupt", + Type: "message", + Role: "assistant", + Model: "claude-3-opus-20240229", + }, + } + chunks <- chattest.AnthropicChunk{ + Type: "content_block_start", + Index: 0, + ContentBlock: chattest.AnthropicContentBlock{ + Type: "server_tool_use", + ID: "ws-interrupt", + Name: "web_search", + Input: json.RawMessage(`{"query":"coder"}`), + }, + } + chunks <- chattest.AnthropicChunk{ + Type: "content_block_start", + Index: 1, + ContentBlock: chattest.AnthropicContentBlock{ + Type: "tool_use", + ID: "tc-local", + Name: "read_file", + }, + } + chunks <- chattest.AnthropicChunk{ + Type: "content_block_delta", + Index: 1, + Delta: chattest.AnthropicDeltaBlock{Type: "input_json_delta", PartialJSON: `{"path":"main.go"}`}, + } + select { + case <-streamStarted: + default: + close(streamStarted) + } + <-req.Context().Done() + }() + return chattest.AnthropicResponse{StreamingChunks: chunks} + }) + user, org, model := seedAnthropicChatDependencies(t, db, anthropicURL) + model = updateChatModelCallConfig(t, db, model, codersdk.ChatModelCallConfig{ + ProviderOptions: &codersdk.ChatModelProviderOptions{ + Anthropic: &codersdk.ChatModelAnthropicProviderOptions{WebSearchEnabled: &webSearchEnabled}, + }, + }) + ws, dbAgent := seedWorkspaceWithAgent(t, db, user.ID) + + ctrl := 
gomock.NewController(t) + mockConn := agentconnmock.NewMockAgentConn(ctrl) + setupToolExecutionAgentConn(t, mockConn) + mockConn.EXPECT().ReadFileLines(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(0) + + server := newActiveTestServer(t, db, ps, func(cfg *chatd.Config) { + cfg.AgentConn = func(_ context.Context, agentID uuid.UUID) (workspacesdk.AgentConn, func(), error) { + require.Equal(t, dbAgent.ID, agentID) + return mockConn, func() {}, nil + } + }) + chat, err := server.CreateChat(ctx, chatd.CreateOptions{ + OrganizationID: org.ID, + OwnerID: user.ID, + WorkspaceID: uuid.NullUUID{UUID: ws.ID, Valid: true}, + AgentID: uuid.NullUUID{UUID: dbAgent.ID, Valid: true}, + Title: "anthropic-mixed-interrupt", + ModelConfigID: model.ID, + InitialUserContent: []codersdk.ChatMessagePart{ + codersdk.ChatMessageText("search and read"), + }, + }) + require.NoError(t, err) + testutil.TryReceive(ctx, t, streamStarted) + queued, err := server.SendMessage(ctx, chatd.SendMessageOptions{ + ChatID: chat.ID, + CreatedBy: user.ID, + ModelConfigID: model.ID, + Content: []codersdk.ChatMessagePart{codersdk.ChatMessageText("continue after mixed interrupt")}, + BusyBehavior: chatd.SendMessageBusyBehaviorInterrupt, + }) + require.NoError(t, err) + require.True(t, queued.Queued) + + waitForChatStatus(ctx, t, db, chat.ID, database.ChatStatusWaiting) + parts := chatToolParts(ctx, t, db, chat.ID) + require.False(t, toolResultPartExists(parts, "web_search")) + call := requireToolCallPart(t, parts, "read_file") + require.Equal(t, "tc-local", call.ToolCallID) + require.False(t, call.ProviderExecuted) + result := requireToolResultPart(t, parts, "read_file") + require.Equal(t, "tc-local", result.ToolCallID) + require.False(t, result.ProviderExecuted) + require.True(t, result.IsError) + require.JSONEq(t, `{"error":"tool call was interrupted before it produced a result"}`, string(result.Result)) + }) + + t.Run("interrupted reasoning persists timestamps", func(t 
*testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + db, ps := dbtestutil.NewDB(t) + sendReasoning := true + thinkingBudget := int64(1024) + reasoningStarted := make(chan struct{}) + var requestCount atomic.Int32 + anthropicURL := chattest.NewAnthropic(t, func(req *chattest.AnthropicRequest) chattest.AnthropicResponse { + if !req.Stream { + return chattest.AnthropicNonStreamingResponse("title") + } + + if requestCount.Add(1) != 1 { + return chattest.AnthropicStreamingResponse(chattest.AnthropicTextChunks("after interrupt")...) + } + chunks := make(chan chattest.AnthropicChunk, 3) + go func() { + defer close(chunks) + chunks <- chattest.AnthropicChunk{ + Type: "message_start", + Message: chattest.AnthropicChunkMessage{ + ID: "msg-reasoning-interrupt", + Type: "message", + Role: "assistant", + Model: "claude-3-opus-20240229", + }, + } + chunks <- chattest.AnthropicChunk{ + Type: "content_block_start", + Index: 0, + ContentBlock: chattest.AnthropicContentBlock{Type: "thinking"}, + } + chunks <- chattest.AnthropicChunk{ + Type: "content_block_delta", + Index: 0, + Delta: chattest.AnthropicDeltaBlock{Type: "thinking_delta", Thinking: "interrupted thought"}, + } + select { + case <-reasoningStarted: + default: + close(reasoningStarted) + } + <-req.Context().Done() + }() + return chattest.AnthropicResponse{StreamingChunks: chunks} + }) + user, org, model := seedAnthropicChatDependencies(t, db, anthropicURL) + model = updateChatModelCallConfig(t, db, model, codersdk.ChatModelCallConfig{ + ProviderOptions: &codersdk.ChatModelProviderOptions{ + Anthropic: &codersdk.ChatModelAnthropicProviderOptions{ + SendReasoning: &sendReasoning, + Thinking: &codersdk.ChatModelAnthropicThinkingOptions{BudgetTokens: &thinkingBudget}, + }, + }, + }) + + server := newActiveTestServer(t, db, ps) + chat := createChatThroughServer(ctx, t, server, org.ID, user.ID, model.ID, "think") + testutil.TryReceive(ctx, t, reasoningStarted) + queued, err := 
server.SendMessage(ctx, chatd.SendMessageOptions{ + ChatID: chat.ID, + CreatedBy: user.ID, + ModelConfigID: model.ID, + Content: []codersdk.ChatMessagePart{codersdk.ChatMessageText("continue after reasoning")}, + BusyBehavior: chatd.SendMessageBusyBehaviorInterrupt, + }) + require.NoError(t, err) + require.True(t, queued.Queued) + + waitForChatStatus(ctx, t, db, chat.ID, database.ChatStatusWaiting) + messages := chatMessages(ctx, t, db, chat.ID) + var reasoningParts []codersdk.ChatMessagePart + for _, msg := range messages { + if msg.Role != database.ChatMessageRoleAssistant { + continue + } + reasoningParts = append(reasoningParts, reasoningPartsFromMessage(t, msg)...) + } + require.Len(t, reasoningParts, 1) + require.Equal(t, "interrupted thought", strings.TrimSpace(reasoningParts[0].Text)) + require.NotNil(t, reasoningParts[0].CreatedAt) + require.NotNil(t, reasoningParts[0].CompletedAt) + require.False(t, reasoningParts[0].CreatedAt.IsZero()) + require.False(t, reasoningParts[0].CompletedAt.IsZero()) + require.False(t, reasoningParts[0].CompletedAt.Before(*reasoningParts[0].CreatedAt)) + }) +} + +func TestActiveServer_DynamicToolsAndStopAfterToolBehavior(t *testing.T) { + t.Parallel() + + t.Run("dynamic tool enters requires action", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + db, ps := dbtestutil.NewDB(t) + var streamedCallCount atomic.Int32 + openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse { + if !req.Stream { + return chattest.OpenAINonStreamingResponse("title") + } + streamedCallCount.Add(1) + return chattest.OpenAIStreamingResponse( + chattest.OpenAIToolCallChunk("my_dynamic_tool", `{"query":"test"}`), + ) + }) + user, org, model := seedChatDependenciesWithProvider(t, db, "openai-compat", openAIURL) + dynamicToolsJSON := dynamicToolJSON(t, "my_dynamic_tool") + + server := newActiveTestServer(t, db, ps) + chat, err := server.CreateChat(ctx, chatd.CreateOptions{ + 
OrganizationID: org.ID, + OwnerID: user.ID, + Title: "dynamic-tool-requires-action", + ModelConfigID: model.ID, + InitialUserContent: []codersdk.ChatMessagePart{ + codersdk.ChatMessageText("call the dynamic tool"), + }, + DynamicTools: dynamicToolsJSON, + }) + require.NoError(t, err) + + var chatResult database.Chat + testutil.Eventually(ctx, t, func(ctx context.Context) bool { + got, getErr := db.GetChatByID(ctx, chat.ID) + if getErr != nil { + return false + } + chatResult = got + return got.Status == database.ChatStatusRequiresAction || got.Status == database.ChatStatusError + }, testutil.IntervalFast) + require.Equal(t, database.ChatStatusRequiresAction, chatResult.Status, + "expected requires_action, got %s (last_error=%q)", + chatResult.Status, chatLastErrorMessage(chatResult.LastError)) + require.True(t, chatResult.RequiresActionDeadlineAt.Valid) + require.Equal(t, int32(1), streamedCallCount.Load()) + + parts := chatToolParts(ctx, t, db, chat.ID) + call := requireToolCallPart(t, parts, "my_dynamic_tool") + require.JSONEq(t, `{"query":"test"}`, string(call.Args)) + require.False(t, toolResultPartExists(parts, "my_dynamic_tool"), + "dynamic tool should wait for submitted results") + }) + + t.Run("successful stop after tool finishes turn", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + db, ps := dbtestutil.NewDB(t) + var streamedCallCount atomic.Int32 + openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse { + if !req.Stream { + return chattest.OpenAINonStreamingResponse("title") + } + switch streamedCallCount.Add(1) { + case 1: + return chattest.OpenAIStreamingResponse( + chattest.OpenAIToolCallChunk("propose_plan", `{}`), + ) + default: + return chattest.OpenAIStreamingResponse(chattest.OpenAITextChunks("should not continue")...) 
+ } + }) + user, org, model := seedChatDependenciesWithProvider(t, db, "openai-compat", openAIURL) + ws, dbAgent := seedWorkspaceWithAgent(t, db, user.ID) + server := newWorkspaceToolTestServer(t, db, ps, dbAgent.ID, "# Plan\n") + + chat, err := server.CreateChat(ctx, chatd.CreateOptions{ + OrganizationID: org.ID, + OwnerID: user.ID, + Title: "stop-after-success", + ModelConfigID: model.ID, + WorkspaceID: uuid.NullUUID{UUID: ws.ID, Valid: true}, + PlanMode: database.NullChatPlanMode{ChatPlanMode: database.ChatPlanModePlan, Valid: true}, + InitialUserContent: []codersdk.ChatMessagePart{ + codersdk.ChatMessageText("propose a plan"), + }, + }) + require.NoError(t, err) + chatResult := waitForChatStatus(ctx, t, db, chat.ID, database.ChatStatusWaiting) + require.False(t, chatResult.WorkerID.Valid) + require.False(t, chatResult.RunnerID.Valid) + require.Equal(t, int32(1), streamedCallCount.Load(), + "stop after tool should finish without another assistant call") + + result := requireToolResultPart(t, chatToolParts(ctx, t, db, chat.ID), "propose_plan") + require.False(t, result.IsError, + "stop after tool should be based on a successful tool result") + }) + + t.Run("error stop after tool continues generation", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + db, ps := dbtestutil.NewDB(t) + var streamedCallCount atomic.Int32 + openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse { + if !req.Stream { + return chattest.OpenAINonStreamingResponse("title") + } + switch streamedCallCount.Add(1) { + case 1: + return chattest.OpenAIStreamingResponse( + chattest.OpenAIToolCallChunk("propose_plan", `{"path":"/tmp/not-plan.txt"}`), + ) + default: + return chattest.OpenAIStreamingResponse(chattest.OpenAITextChunks("tool failed, continue")...) 
+ } + }) + user, org, model := seedChatDependenciesWithProvider(t, db, "openai-compat", openAIURL) + ws, dbAgent := seedWorkspaceWithAgent(t, db, user.ID) + server := newWorkspaceToolTestServer(t, db, ps, dbAgent.ID, "# Plan\n") + + chat, err := server.CreateChat(ctx, chatd.CreateOptions{ + OrganizationID: org.ID, + OwnerID: user.ID, + Title: "stop-after-error", + ModelConfigID: model.ID, + WorkspaceID: uuid.NullUUID{UUID: ws.ID, Valid: true}, + PlanMode: database.NullChatPlanMode{ChatPlanMode: database.ChatPlanModePlan, Valid: true}, + InitialUserContent: []codersdk.ChatMessagePart{ + codersdk.ChatMessageText("propose a plan with a bad path"), + }, + }) + require.NoError(t, err) + waitForChatStatus(ctx, t, db, chat.ID, database.ChatStatusWaiting) + require.Equal(t, int32(2), streamedCallCount.Load(), + "error stop after tool result should not finish the turn by itself") + + parts := chatToolParts(ctx, t, db, chat.ID) + result := requireToolResultPart(t, parts, "propose_plan") + require.True(t, result.IsError) + messages := chatMessages(ctx, t, db, chat.ID) + requireTextPart(t, messages[len(messages)-1], "tool failed, continue") + }) +} + func TestDynamicToolCallPausesAndResumes(t *testing.T) { t.Parallel() @@ -4521,19 +5051,10 @@ func TestStoppedWorkspaceWithPersistedAgentBindingDoesNotBlockChat(t *testing.T) }) require.NoError(t, err) - // Close the inactive server so its wake-triggered processing - // stops and releases the chat. Then reset to pending so the - // active server (created below) can acquire it cleanly. + // Close the inactive server. The chat remains in the valid + // state-machine `running` state created by CreateChat, and the + // active server created below can acquire it because it is unowned. 
require.NoError(t, inactive.Close()) - _, err = db.UpdateChatStatus(ctx, database.UpdateChatStatusParams{ - ID: chat.ID, - Status: database.ChatStatusPending, - WorkerID: uuid.NullUUID{}, - StartedAt: sql.NullTime{}, - HeartbeatAt: sql.NullTime{}, - LastError: pqtype.NullRawMessage{}, - }) - require.NoError(t, err) build, err := db.GetLatestWorkspaceBuildByWorkspaceID(ctx, ws.ID) require.NoError(t, err) @@ -4701,6 +5222,13 @@ func TestHeartbeatBumpsWorkspaceUsage(t *testing.T) { Transition: database.WorkspaceTransitionStart, Deadline: dbtime.Now().Add(-30 * time.Minute), }) + res := dbgen.WorkspaceResource(t, db, database.WorkspaceResource{ + Transition: database.WorkspaceTransitionStart, + JobID: pj.ID, + }) + dbAgent := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{ + ResourceID: res.ID, + }) originalDeadline := build.Deadline // Set up a short heartbeat interval and a UsageTracker that @@ -4748,15 +5276,20 @@ func TestHeartbeatBumpsWorkspaceUsage(t *testing.T) { require.NoError(t, err) // Wait for the chat to start processing and at least one - // heartbeat to fire. + // runner heartbeat to be written. 
testutil.Eventually(ctx, t, func(ctx context.Context) bool { fromDB, listErr := db.GetChatByID(ctx, chat.ID) - if listErr != nil { + if listErr != nil || fromDB.Status != database.ChatStatusRunning || !fromDB.RunnerID.Valid { return false } - return fromDB.Status == database.ChatStatusRunning && - fromDB.HeartbeatAt.Valid && - fromDB.HeartbeatAt.Time.After(fromDB.CreatedAt) + heartbeat, heartbeatErr := db.GetChatHeartbeat(ctx, database.GetChatHeartbeatParams{ + ChatID: chat.ID, + RunnerID: fromDB.RunnerID.UUID, + }) + if heartbeatErr != nil { + return false + } + return heartbeat.HeartbeatAt.After(fromDB.CreatedAt) }, testutil.IntervalFast, "chat should be running with at least one heartbeat") @@ -4775,21 +5308,36 @@ func TestHeartbeatBumpsWorkspaceUsage(t *testing.T) { }) require.NoError(t, err) - // The heartbeat re-reads the workspace association from the DB - // on each tick. Wait for the tracker to pick it up. - testutil.Eventually(ctx, t, func(ctx context.Context) bool { - select { - case flushTick <- time.Now(): - case <-ctx.Done(): - return false - } - select { - case c := <-flushDone: - return c > 0 - case <-ctx.Done(): - return false - } - }, testutil.IntervalMedium, + ctrl := gomock.NewController(t) + mockConn := agentconnmock.NewMockAgentConn(ctrl) + mockConn.EXPECT().SetExtraHeaders(gomock.Any()).AnyTimes() + mockConn.EXPECT().AwaitReachable(gomock.Any()).Return(true).AnyTimes() + serverWithConn := chatd.New(chatd.Config{ + Logger: logger, + Database: authzDB, + ReplicaID: uuid.New(), + Pubsub: ps, + InFlightChatStaleAfter: testutil.WaitLong, + UsageTracker: tracker, + AgentConn: func(_ context.Context, agentID uuid.UUID) (workspacesdk.AgentConn, func(), error) { + require.Equal(t, dbAgent.ID, agentID) + return mockConn, func() {}, nil + }, + }) + t.Cleanup(func() { + require.NoError(t, serverWithConn.Close()) + }) + connChat, err := db.GetChatByID(ctx, chat.ID) + require.NoError(t, err) + workspaceCtx := 
chatd.NewTurnWorkspaceContextForTest(serverWithConn, connChat) + _, err = workspaceCtx.GetWorkspaceConn(chatCtx) + require.NoError(t, err) + + // Acquiring the workspace connection should bump usage for the linked + // workspace. + testutil.RequireSend(ctx, t, flushTick, time.Now()) + count = testutil.RequireReceive(ctx, t, flushDone) + require.Greater(t, count, 0, "expected usage tracker to flush the late-associated workspace") // Verify the workspace's last_used_at was actually updated. @@ -4868,16 +5416,21 @@ func TestHeartbeatNoWorkspaceNoBump(t *testing.T) { }) require.NoError(t, err) - // Wait for the chat to be acquired and at least one heartbeat - // to fire. + // Wait for the chat to be acquired and at least one runner + // heartbeat to be written. testutil.Eventually(ctx, t, func(ctx context.Context) bool { fromDB, listErr := db.GetChatByID(ctx, chat.ID) - if listErr != nil { + if listErr != nil || fromDB.Status != database.ChatStatusRunning || !fromDB.RunnerID.Valid { return false } - return fromDB.Status == database.ChatStatusRunning && - fromDB.HeartbeatAt.Valid && - fromDB.HeartbeatAt.Time.After(fromDB.CreatedAt) + heartbeat, heartbeatErr := db.GetChatHeartbeat(ctx, database.GetChatHeartbeatParams{ + ChatID: chat.ID, + RunnerID: fromDB.RunnerID.UUID, + }) + if heartbeatErr != nil { + return false + } + return heartbeat.HeartbeatAt.After(fromDB.CreatedAt) }, testutil.IntervalFast, "chat should be running with at least one heartbeat") @@ -4945,6 +5498,2350 @@ func newTestServer( return server } +func highUsageTextResponse(text string) chattest.AnthropicResponse { + return chattest.AnthropicStreamingResponse(chattest.AnthropicTextChunksWithCacheUsage(chattest.AnthropicUsage{ + InputTokens: 80, + OutputTokens: 5, + }, text)...) 
+} + +func anthropicCompactionResponse(text string) chattest.AnthropicResponse { + return chattest.AnthropicResponse{Response: &chattest.AnthropicMessage{ + ID: "msg-compaction", + Type: "message", + Role: "assistant", + Content: text, + Model: "claude-3-opus-20240229", + StopReason: "end_turn", + }} +} + +func highUsageReadFileResponse(path string) chattest.AnthropicResponse { + chunks := chattest.AnthropicToolCallChunks("read_file", fmt.Sprintf(`{"path":%q}`, path)) + for i := range chunks { + if chunks[i].Type == "message_start" { + chunks[i].Message.Usage = map[string]int{"input_tokens": 80} + } + if chunks[i].Type == "message_delta" { + chunks[i].UsageMap = map[string]int{"output_tokens": 5} + } + } + return chattest.AnthropicStreamingResponse(chunks...) +} + +func TestActiveServer_CompactionRecordsMetric(t *testing.T) { + t.Parallel() + + const ( + compactionSummary = "summary text for compaction" + contextLimit = int64(100) + thresholdPercent = int32(70) + ) + + ctx := testutil.Context(t, testutil.WaitLong) + db, ps := dbtestutil.NewDB(t) + reg := prometheus.NewRegistry() + var streamCount atomic.Int32 + anthropicURL := chattest.NewAnthropic(t, func(req *chattest.AnthropicRequest) chattest.AnthropicResponse { + body := anthropicRequestBody(t, *req) + if !req.Stream { + if strings.Contains(body, "You are performing a context compaction") { + return anthropicCompactionResponse(compactionSummary) + } + return chattest.AnthropicNonStreamingResponse("title") + } + switch streamCount.Add(1) { + case 1: + return highUsageReadFileResponse("/tmp/a.txt") + case 2: + require.Contains(t, body, compactionSummary) + return chattest.AnthropicStreamingResponse(chattest.AnthropicTextChunksWithCacheUsage(chattest.AnthropicUsage{ + InputTokens: 20, + OutputTokens: 5, + }, "continued after compaction")...) 
+ default: + t.Fatalf("unexpected generation request: %s", body) + return chattest.AnthropicStreamingResponse() + } + }) + user, org, model := seedAnthropicChatDependencies(t, db, anthropicURL) + model = updateChatModelCompressionThreshold(t, db, model, contextLimit, thresholdPercent) + ws, dbAgent := seedWorkspaceWithAgent(t, db, user.ID) + + ctrl := gomock.NewController(t) + mockConn := agentconnmock.NewMockAgentConn(ctrl) + setupToolExecutionAgentConn(t, mockConn) + mockConn.EXPECT().ReadFileLines(gomock.Any(), "/tmp/a.txt", int64(1), int64(0), gomock.Any()). + Return(workspacesdk.ReadFileLinesResponse{Success: true, FileSize: 12, TotalLines: 1, LinesRead: 1, Content: "1\tpackage main"}, nil). + Times(1) + + server := newActiveTestServer(t, db, ps, func(cfg *chatd.Config) { + cfg.PrometheusRegistry = reg + cfg.AgentConn = func(_ context.Context, agentID uuid.UUID) (workspacesdk.AgentConn, func(), error) { + require.Equal(t, dbAgent.ID, agentID) + return mockConn, func() {}, nil + } + }) + chat, err := server.CreateChat(ctx, chatd.CreateOptions{ + OrganizationID: org.ID, + OwnerID: user.ID, + WorkspaceID: uuid.NullUUID{UUID: ws.ID, Valid: true}, + AgentID: uuid.NullUUID{UUID: dbAgent.ID, Valid: true}, + Title: "compaction-metric", + ModelConfigID: model.ID, + InitialUserContent: []codersdk.ChatMessagePart{ + codersdk.ChatMessageText("read the file and continue"), + }, + }) + require.NoError(t, err) + waitForChatStatus(ctx, t, db, chat.ID, database.ChatStatusWaiting) + + requireChatdMetricCounter(t, reg, "coderd_chatd_compaction_total", 1, map[string]string{ + "provider": "anthropic", + "model": "claude-sonnet-4-20250514", + "result": "success", + }) +} + +func TestActiveServer_Compaction(t *testing.T) { + t.Parallel() + + const ( + compactionSummary = "summary text for compaction" + contextLimit = int64(100) + thresholdPercent = int32(70) + ) + + newHighUsageReadFileResponse := func(path string) chattest.AnthropicResponse { + chunks := 
chattest.AnthropicToolCallChunks("read_file", fmt.Sprintf(`{"path":%q}`, path)) + for i := range chunks { + if chunks[i].Type == "message_start" { + chunks[i].Message.Usage = map[string]int{"input_tokens": 80} + } + if chunks[i].Type == "message_delta" { + chunks[i].UsageMap = map[string]int{"output_tokens": 5} + } + } + return chattest.AnthropicStreamingResponse(chunks...) + } + + t.Run("commits summary when threshold reached and continues from committed summary", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + db, ps := dbtestutil.NewDB(t) + requests := newAnthropicRequestRecorder() + var streamCount atomic.Int32 + anthropicURL := chattest.NewAnthropic(t, func(req *chattest.AnthropicRequest) chattest.AnthropicResponse { + requests.record(req) + body := anthropicRequestBody(t, *req) + if !req.Stream { + if strings.Contains(body, "You are performing a context compaction") { + require.Contains(t, body, "read_file") + require.Contains(t, body, "package main") + return anthropicCompactionResponse(compactionSummary) + } + return chattest.AnthropicNonStreamingResponse("title") + } + switch streamCount.Add(1) { + case 1: + return newHighUsageReadFileResponse("/tmp/a.txt") + default: + require.Contains(t, body, compactionSummary) + require.Contains(t, body, "The following is a summary of the earlier conversation") + require.Contains(t, body, `"role":"user"`) + return chattest.AnthropicStreamingResponse(chattest.AnthropicTextChunksWithCacheUsage(chattest.AnthropicUsage{ + InputTokens: 20, + OutputTokens: 5, + }, "continued after compaction")...) 
+ } + }) + user, org, model := seedAnthropicChatDependencies(t, db, anthropicURL) + model = updateChatModelCompressionThreshold(t, db, model, contextLimit, thresholdPercent) + ws, dbAgent := seedWorkspaceWithAgent(t, db, user.ID) + + ctrl := gomock.NewController(t) + mockConn := agentconnmock.NewMockAgentConn(ctrl) + setupToolExecutionAgentConn(t, mockConn) + mockConn.EXPECT().ReadFileLines(gomock.Any(), "/tmp/a.txt", int64(1), int64(0), gomock.Any()). + Return(workspacesdk.ReadFileLinesResponse{Success: true, FileSize: 12, TotalLines: 1, LinesRead: 1, Content: "1 package main"}, nil). + Times(1) + + server := newActiveTestServer(t, db, ps, func(cfg *chatd.Config) { + cfg.AgentConn = func(_ context.Context, agentID uuid.UUID) (workspacesdk.AgentConn, func(), error) { + require.Equal(t, dbAgent.ID, agentID) + return mockConn, func() {}, nil + } + }) + chat, err := server.CreateChat(ctx, chatd.CreateOptions{ + OrganizationID: org.ID, + OwnerID: user.ID, + WorkspaceID: uuid.NullUUID{UUID: ws.ID, Valid: true}, + AgentID: uuid.NullUUID{UUID: dbAgent.ID, Valid: true}, + Title: "compaction-continues", + ModelConfigID: model.ID, + InitialUserContent: []codersdk.ChatMessagePart{ + codersdk.ChatMessageText("read the file and continue"), + }, + }) + require.NoError(t, err) + chat = waitForChatStatus(ctx, t, db, chat.ID, database.ChatStatusWaiting) + require.False(t, chat.WorkerID.Valid) + require.False(t, chat.RunnerID.Valid) + + generationRequests := filterAnthropicStreamingRequests(requests.all()) + require.GreaterOrEqual(t, len(generationRequests), 2) + require.Equal(t, int32(2), streamCount.Load()) + + messages := chatMessages(ctx, t, db, chat.ID) + promptMessages, err := db.GetChatMessagesForPromptByChatID(ctx, chat.ID) + require.NoError(t, err) + compressed := compressedChatSummarizedMessages(t, append(promptMessages, messages...)) + require.Len(t, compressed.summaries, 1) + require.Len(t, compressed.calls, 1) + require.Len(t, compressed.results, 1) + + require.Equal(t, 
database.ChatMessageRoleUser, compressed.summaries[0].Role) + require.Equal(t, database.ChatMessageVisibilityModel, compressed.summaries[0].Visibility) + summaryText := messageText(t, compressed.summaries[0]) + require.Contains(t, summaryText, "The following is a summary of the earlier conversation") + require.Contains(t, summaryText, compactionSummary) + + callPart := singlePartOfType(t, compressed.calls[0], codersdk.ChatMessagePartTypeToolCall) + resultPart := singlePartOfType(t, compressed.results[0], codersdk.ChatMessagePartTypeToolResult) + require.Equal(t, callPart.ToolCallID, resultPart.ToolCallID) + require.Equal(t, "chat_summarized", resultPart.ToolName) + require.JSONEq(t, `{"summary":"summary text for compaction","source":"automatic","threshold_percent":70,"usage_percent":80,"context_tokens":80,"context_limit_tokens":100}`, string(resultPart.Result)) + requireTextPart(t, messages[len(messages)-1], "continued after compaction") + }) + + t.Run("does not compact when high usage finishes the turn", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + db, ps := dbtestutil.NewDB(t) + var streamCount atomic.Int32 + var compactionRequests atomic.Int32 + anthropicURL := chattest.NewAnthropic(t, func(req *chattest.AnthropicRequest) chattest.AnthropicResponse { + body := anthropicRequestBody(t, *req) + if strings.Contains(body, "You are performing a context compaction") { + compactionRequests.Add(1) + return anthropicCompactionResponse(compactionSummary) + } + if !req.Stream { + return chattest.AnthropicNonStreamingResponse("title") + } + streamCount.Add(1) + return highUsageTextResponse("done without compaction") + }) + user, org, model := seedAnthropicChatDependencies(t, db, anthropicURL) + model = updateChatModelCompressionThreshold(t, db, model, contextLimit, thresholdPercent) + + server := newActiveTestServer(t, db, ps) + chat := createChatThroughServer(ctx, t, server, org.ID, user.ID, model.ID, "finish with high usage") + 
waitForChatStatus(ctx, t, db, chat.ID, database.ChatStatusWaiting) + + require.Equal(t, int32(1), streamCount.Load()) + require.Equal(t, int32(0), compactionRequests.Load()) + messages := chatMessages(ctx, t, db, chat.ID) + compressed := compressedChatSummarizedMessages(t, messages) + require.Empty(t, compressed.summaries) + require.Empty(t, compressed.calls) + require.Empty(t, compressed.results) + for _, msg := range messages { + require.False(t, msg.Compressed, "message %d should not be compressed", msg.ID) + } + requireTextPart(t, messages[len(messages)-1], "done without compaction") + }) + + t.Run("fails when compaction leaves chat over limit", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + db, ps := dbtestutil.NewDB(t) + var streamCount atomic.Int32 + anthropicURL := chattest.NewAnthropic(t, func(req *chattest.AnthropicRequest) chattest.AnthropicResponse { + body := anthropicRequestBody(t, *req) + if !req.Stream { + if strings.Contains(body, "You are performing a context compaction") { + return anthropicCompactionResponse(compactionSummary) + } + return chattest.AnthropicNonStreamingResponse("title") + } + switch streamCount.Add(1) { + case 1: + return newHighUsageReadFileResponse("/tmp/a.txt") + default: + require.Contains(t, body, compactionSummary) + return highUsageTextResponse("still too large") + } + }) + user, org, model := seedAnthropicChatDependencies(t, db, anthropicURL) + model = updateChatModelCompressionThreshold(t, db, model, contextLimit, thresholdPercent) + ws, dbAgent := seedWorkspaceWithAgent(t, db, user.ID) + + ctrl := gomock.NewController(t) + mockConn := agentconnmock.NewMockAgentConn(ctrl) + setupToolExecutionAgentConn(t, mockConn) + mockConn.EXPECT().ReadFileLines(gomock.Any(), "/tmp/a.txt", int64(1), int64(0), gomock.Any()). + Return(workspacesdk.ReadFileLinesResponse{Success: true, FileSize: 12, TotalLines: 1, LinesRead: 1, Content: "1 package main"}, nil). 
+ Times(1) + + server := newActiveTestServer(t, db, ps, func(cfg *chatd.Config) { + cfg.AgentConn = func(_ context.Context, agentID uuid.UUID) (workspacesdk.AgentConn, func(), error) { + require.Equal(t, dbAgent.ID, agentID) + return mockConn, func() {}, nil + } + }) + chat, err := server.CreateChat(ctx, chatd.CreateOptions{ + OrganizationID: org.ID, + OwnerID: user.ID, + WorkspaceID: uuid.NullUUID{UUID: ws.ID, Valid: true}, + AgentID: uuid.NullUUID{UUID: dbAgent.ID, Valid: true}, + Title: "compaction-still-over-limit", + ModelConfigID: model.ID, + InitialUserContent: []codersdk.ChatMessagePart{ + codersdk.ChatMessageText("read the file and stay too large"), + }, + }) + require.NoError(t, err) + chat = waitForChatStatus(ctx, t, db, chat.ID, database.ChatStatusError) + require.Contains(t, chatLastErrorMessage(chat.LastError), "The chat request failed unexpectedly.") + }) +} + +type compressedCompactionMessages struct { + summaries []database.ChatMessage + calls []database.ChatMessage + results []database.ChatMessage +} + +func compressedChatSummarizedMessages(t *testing.T, messages []database.ChatMessage) compressedCompactionMessages { + t.Helper() + seen := map[int64]bool{} + var out compressedCompactionMessages + for _, msg := range messages { + if !msg.Compressed || seen[msg.ID] { + continue + } + seen[msg.ID] = true + parts, err := chatprompt.ParseContent(msg) + require.NoError(t, err) + for _, part := range parts { + switch part.Type { + case codersdk.ChatMessagePartTypeText: + if msg.Role == database.ChatMessageRoleUser { + out.summaries = append(out.summaries, msg) + } + case codersdk.ChatMessagePartTypeToolCall: + if part.ToolName == "chat_summarized" { + out.calls = append(out.calls, msg) + } + case codersdk.ChatMessagePartTypeToolResult: + if part.ToolName == "chat_summarized" { + out.results = append(out.results, msg) + } + } + } + } + return out +} + +func messageText(t *testing.T, msg database.ChatMessage) string { + t.Helper() + parts, err := 
chatprompt.ParseContent(msg) + require.NoError(t, err) + var builder strings.Builder + for _, part := range parts { + if part.Type == codersdk.ChatMessagePartTypeText { + _, _ = builder.WriteString(part.Text) + } + } + return builder.String() +} + +func singlePartOfType(t *testing.T, msg database.ChatMessage, typ codersdk.ChatMessagePartType) codersdk.ChatMessagePart { + t.Helper() + parts, err := chatprompt.ParseContent(msg) + require.NoError(t, err) + var matches []codersdk.ChatMessagePart + for _, part := range parts { + if part.Type == typ { + matches = append(matches, part) + } + } + require.Len(t, matches, 1) + return matches[0] +} + +func TestActiveServer_BasicAssistantGenerationAndPromptPreparation(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + db, ps := dbtestutil.NewDB(t) + requests := newAnthropicRequestRecorder() + anthropicURL := chattest.NewAnthropic(t, func(req *chattest.AnthropicRequest) chattest.AnthropicResponse { + requests.record(req) + return chattest.AnthropicStreamingResponse(chattest.AnthropicTextChunks("done")...) 
+ }) + user, org, model := seedAnthropicChatDependencies(t, db, anthropicURL) + model.ContextLimit = 4096 + model = updateChatModelContextLimit(t, db, model) + + server := newActiveTestServer(t, db, ps) + chat := createChatThroughServer(ctx, t, server, org.ID, user.ID, model.ID, "hello") + waitForChatStatus(ctx, t, db, chat.ID, database.ChatStatusWaiting) + insertSystemTextMessage(ctx, t, db, chat.ID, "sys-2", model.ID) + insertAssistantTextMessage(ctx, t, db, chat.ID, "working", model.ID) + _, err := server.SendMessage(ctx, chatd.SendMessageOptions{ + ChatID: chat.ID, + CreatedBy: user.ID, + ModelConfigID: model.ID, + Content: []codersdk.ChatMessagePart{codersdk.ChatMessageText("continue")}, + BusyBehavior: chatd.SendMessageBusyBehaviorQueue, + }) + require.NoError(t, err) + + waitForChatStatus(ctx, t, db, chat.ID, database.ChatStatusWaiting) + + generationRequests := filterAnthropicStreamingRequests(requests.all()) + require.Len(t, generationRequests, 2) + recovered := generationRequests[1] + require.True(t, anthropicSystemHasEphemeralCacheControl(t, recovered)) + require.Len(t, recovered.Messages, 4) + require.False(t, anthropicMessageHasEphemeralCacheControl(t, recovered.Messages[0])) + require.False(t, anthropicMessageHasEphemeralCacheControl(t, recovered.Messages[1])) + require.True(t, anthropicMessageHasEphemeralCacheControl(t, recovered.Messages[2])) + require.True(t, anthropicMessageHasEphemeralCacheControl(t, recovered.Messages[3])) + require.False(t, anthropicRequestContainsPromptSentinel(t, recovered)) + toolNames := anthropicRequestToolNames(recovered) + require.Contains(t, toolNames, "read_file") + require.Contains(t, toolNames, "write_file") + + messages := chatMessages(ctx, t, db, chat.ID) + last := messages[len(messages)-1] + require.Equal(t, database.ChatMessageRoleAssistant, last.Role) + require.True(t, last.ContextLimit.Valid) + require.Equal(t, int64(4096), last.ContextLimit.Int64) + require.GreaterOrEqual(t, last.RuntimeMs.Int64, int64(0)) + 
requireTextPart(t, last, "done") + + requests = newAnthropicRequestRecorder() + server = newActiveTestServer(t, db, ps) + planChat := createPlanSubagentChatWithHistory(ctx, t, db, org.ID, user.ID, model.ID) + _, err = server.SendMessage(ctx, chatd.SendMessageOptions{ + ChatID: planChat.ID, + CreatedBy: user.ID, + ModelConfigID: model.ID, + Content: []codersdk.ChatMessagePart{codersdk.ChatMessageText("continue")}, + BusyBehavior: chatd.SendMessageBusyBehaviorQueue, + }) + require.NoError(t, err) + waitForChatStatus(ctx, t, db, planChat.ID, database.ChatStatusWaiting) + + planRequests := filterAnthropicStreamingRequests(requests.all()) + require.Len(t, planRequests, 1) + toolNames = anthropicRequestToolNames(planRequests[0]) + require.Contains(t, toolNames, "read_file") + require.NotContains(t, toolNames, "write_file") +} + +func TestActiveServer_ToolExecutionAndPolicy(t *testing.T) { + t.Parallel() + + t.Run("rejects disallowed active tool", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + db, ps := dbtestutil.NewDB(t) + var streamedCallCount atomic.Int32 + openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse { + if !req.Stream { + return chattest.OpenAINonStreamingResponse("title") + } + if streamedCallCount.Add(1) == 1 { + return chattest.OpenAIStreamingResponse( + chattest.OpenAIToolCallChunk("write_file", `{"path":"/tmp/nope","content":"blocked"}`), + ) + } + return chattest.OpenAIStreamingResponse(chattest.OpenAITextChunks("done")...) 
+ }) + user, org, model := seedChatDependenciesWithProvider(t, db, "openai-compat", openAIURL) + ws, dbAgent := seedWorkspaceWithAgent(t, db, user.ID) + + ctrl := gomock.NewController(t) + mockConn := agentconnmock.NewMockAgentConn(ctrl) + setupToolExecutionAgentConn(t, mockConn) + mockConn.EXPECT().WriteFile(gomock.Any(), gomock.Any(), gomock.Any()).Times(0) + + server := newActiveTestServer(t, db, ps, func(cfg *chatd.Config) { + cfg.AgentConn = func(_ context.Context, agentID uuid.UUID) (workspacesdk.AgentConn, func(), error) { + require.Equal(t, dbAgent.ID, agentID) + return mockConn, func() {}, nil + } + }) + chat, err := server.CreateChat(ctx, chatd.CreateOptions{ + OrganizationID: org.ID, + OwnerID: user.ID, + WorkspaceID: uuid.NullUUID{UUID: ws.ID, Valid: true}, + AgentID: uuid.NullUUID{UUID: dbAgent.ID, Valid: true}, + Title: "active-tool-reject", + ModelConfigID: model.ID, + ChatMode: database.NullChatMode{ChatMode: database.ChatModeExplore, Valid: true}, + InitialUserContent: []codersdk.ChatMessagePart{ + codersdk.ChatMessageText("try to write a file"), + }, + }) + require.NoError(t, err) + waitForChatStatus(ctx, t, db, chat.ID, database.ChatStatusWaiting) + + parts := chatToolParts(ctx, t, db, chat.ID) + result := requireToolResultPart(t, parts, "write_file") + require.True(t, result.IsError) + require.JSONEq(t, `{"error":"Tool not active in this turn: write_file"}`, string(result.Result)) + }) + + t.Run("provider runner executes and preserves metadata", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + db, ps := dbtestutil.NewDB(t) + const computerResultMetadata = `{"openai":{"type":"openai.responses.computer_call_output_options","data":{"detail":"original"}}}` + var streamedCallCount atomic.Int32 + var secondRawBody []byte + var callsMu sync.Mutex + openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse { + if !req.Stream { + return chattest.OpenAINonStreamingResponse("title") 
+ } + if streamedCallCount.Add(1) == 1 { + callsMu.Lock() + secondRawBody = append([]byte(nil), req.RawBody...) + callsMu.Unlock() + } + return chattest.OpenAIStreamingResponse(chattest.OpenAITextChunks("done")...) + }) + user, org, _, model := seedChatDependenciesWithProviderPolicy(t, db, "openai", openAIURL, "test-key", true, false, true) + model.Model = "gpt-5.5" + model = updateChatModelContextLimit(t, db, model) + + server := newActiveTestServer(t, db, ps, func(cfg *chatd.Config) { cfg.AllowBYOKSet = true; cfg.AllowBYOK = false }) + result := codersdk.ChatMessageToolResult( + "computer-call", + "computer", + json.RawMessage(`{"data":"iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR4nGP4n539HwAHFwLVF8kc1wAAAABJRU5ErkJggg==","mime_type":"image/png"}`), + false, + true, + ) + result.ProviderMetadata = json.RawMessage(computerResultMetadata) + computerCall := codersdk.ChatMessageToolCall( + "computer-call", + "computer", + json.RawMessage(`{"type":"screenshot"}`), + ) + computerCall.ProviderExecuted = true + created, err := chatstate.CreateChat(dbauthz.AsSystemRestricted(ctx), db, ps, chatstate.CreateChatInput{ + OrganizationID: org.ID, + OwnerID: user.ID, + LastModelConfigID: model.ID, + Title: "provider-runner-replay-active", + MCPServerIDs: []uuid.UUID{}, + ClientType: database.ChatClientTypeApi, + InitialMessages: []chatstate.Message{ + userMessageForTest(t, "use provider runner", model.ID, user.ID), + assistantMessageForTest(t, []codersdk.ChatMessagePart{computerCall}, model.ID), + toolMessageForTest(t, []codersdk.ChatMessagePart{result}, model.ID), + }, + }) + require.NoError(t, err) + chat := created.Chat + _, err = server.SendMessage(ctx, chatd.SendMessageOptions{ + ChatID: chat.ID, + CreatedBy: user.ID, + ModelConfigID: model.ID, + Content: []codersdk.ChatMessagePart{codersdk.ChatMessageText("continue")}, + BusyBehavior: chatd.SendMessageBusyBehaviorQueue, + }) + require.NoError(t, err) + waitForTerminalChat(ctx, t, db, chat.ID) + gotChat, gotErr 
:= db.GetChatByID(ctx, chat.ID) + require.NoError(t, gotErr) + require.Equal(t, database.ChatStatusWaiting, gotChat.Status) + require.Eventually(t, func() bool { return streamedCallCount.Load() >= 1 }, testutil.WaitShort, testutil.IntervalFast) + + callsMu.Lock() + body := string(secondRawBody) + callsMu.Unlock() + require.Contains(t, body, "computer_call_output") + require.Contains(t, body, `"detail":"original"`) + }) + + t.Run("multi step local tool execution", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + db, ps := dbtestutil.NewDB(t) + var streamedCallCount atomic.Int32 + var secondCallMessages []chattest.OpenAIMessage + var callsMu sync.Mutex + openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse { + if !req.Stream { + return chattest.OpenAINonStreamingResponse("title") + } + if streamedCallCount.Add(1) == 1 { + return chattest.OpenAIStreamingResponse( + chattest.OpenAIToolCallChunk("read_file", `{"path":"/tmp/a.txt"}`), + ) + } + callsMu.Lock() + secondCallMessages = append([]chattest.OpenAIMessage(nil), req.Messages...) + callsMu.Unlock() + return chattest.OpenAIStreamingResponse(chattest.OpenAITextChunks("all done")...) + }) + user, org, model := seedChatDependenciesWithProvider(t, db, "openai-compat", openAIURL) + ws, dbAgent := seedWorkspaceWithAgent(t, db, user.ID) + + ctrl := gomock.NewController(t) + mockConn := agentconnmock.NewMockAgentConn(ctrl) + setupToolExecutionAgentConn(t, mockConn) + mockConn.EXPECT().ReadFileLines(gomock.Any(), "/tmp/a.txt", int64(1), int64(0), gomock.Any()). + Return(workspacesdk.ReadFileLinesResponse{ + Success: true, FileSize: 12, TotalLines: 1, LinesRead: 1, Content: "1\tpackage main", + }, nil). 
+ Times(1) + + server := newActiveTestServer(t, db, ps, func(cfg *chatd.Config) { + cfg.AgentConn = func(_ context.Context, agentID uuid.UUID) (workspacesdk.AgentConn, func(), error) { + require.Equal(t, dbAgent.ID, agentID) + return mockConn, func() {}, nil + } + }) + chat, err := server.CreateChat(ctx, chatd.CreateOptions{ + OrganizationID: org.ID, + OwnerID: user.ID, + WorkspaceID: uuid.NullUUID{UUID: ws.ID, Valid: true}, + AgentID: uuid.NullUUID{UUID: dbAgent.ID, Valid: true}, + Title: "multi-step-tool", + ModelConfigID: model.ID, + InitialUserContent: []codersdk.ChatMessagePart{ + codersdk.ChatMessageText("read the file"), + }, + }) + require.NoError(t, err) + waitForChatStatus(ctx, t, db, chat.ID, database.ChatStatusWaiting) + + require.GreaterOrEqual(t, streamedCallCount.Load(), int32(2)) + parts := chatToolParts(ctx, t, db, chat.ID) + call := requireToolCallPart(t, parts, "read_file") + result := requireToolResultPart(t, parts, "read_file") + require.False(t, result.IsError) + require.NotNil(t, call.CreatedAt) + require.NotNil(t, result.CreatedAt) + require.False(t, result.CreatedAt.Before(*call.CreatedAt)) + messages := chatMessages(ctx, t, db, chat.ID) + requireTextPart(t, messages[len(messages)-1], "all done") + + callsMu.Lock() + secondMessages := append([]chattest.OpenAIMessage(nil), secondCallMessages...) 
+ callsMu.Unlock() + require.NotEmpty(t, secondMessages) + require.True(t, openAIMessagesContain(secondMessages, "1\\tpackage main")) + }) + + t.Run("parallel local and provider executed timestamps", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + db, ps := dbtestutil.NewDB(t) + webSearchEnabled := true + var streamedCallCount atomic.Int32 + openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse { + if !req.Stream { + return chattest.OpenAINonStreamingResponse("title") + } + if streamedCallCount.Add(1) == 1 { + readA := chattest.OpenAIToolCallChunk("read_file", `{"path":"/tmp/a.txt"}`) + readB := chattest.OpenAIToolCallChunk("read_file", `{"path":"/tmp/b.txt"}`) + second := readB.Choices[0].ToolCalls[0] + second.Index = 1 + readA.Choices[0].ToolCalls = append(readA.Choices[0].ToolCalls, second) + return chattest.OpenAIResponse{ + StreamingChunks: chattest.OpenAIStreamingResponse(readA).StreamingChunks, + WebSearch: &chattest.OpenAIWebSearchCall{ID: "ws-timestamps", Query: "coder"}, + } + } + return chattest.OpenAIStreamingResponse(chattest.OpenAITextChunks("done")...) + }) + user, org, model := seedChatDependenciesWithProvider(t, db, "openai", openAIURL) + model = updateChatModelCallConfig(t, db, model, codersdk.ChatModelCallConfig{ + ProviderOptions: &codersdk.ChatModelProviderOptions{ + OpenAI: &codersdk.ChatModelOpenAIProviderOptions{WebSearchEnabled: &webSearchEnabled}, + }, + }) + ws, dbAgent := seedWorkspaceWithAgent(t, db, user.ID) + + ctrl := gomock.NewController(t) + mockConn := agentconnmock.NewMockAgentConn(ctrl) + setupToolExecutionAgentConn(t, mockConn) + mockConn.EXPECT().ReadFileLines(gomock.Any(), "/tmp/a.txt", int64(1), int64(0), gomock.Any()). + Return(workspacesdk.ReadFileLinesResponse{Success: true, Content: "a", FileSize: 1, TotalLines: 1, LinesRead: 1}, nil). + Times(1) + mockConn.EXPECT().ReadFileLines(gomock.Any(), "/tmp/b.txt", int64(1), int64(0), gomock.Any()). 
+ Return(workspacesdk.ReadFileLinesResponse{Success: true, Content: "b", FileSize: 1, TotalLines: 1, LinesRead: 1}, nil). + Times(1) + + server := newActiveTestServer(t, db, ps, func(cfg *chatd.Config) { + cfg.AgentConn = func(_ context.Context, agentID uuid.UUID) (workspacesdk.AgentConn, func(), error) { + require.Equal(t, dbAgent.ID, agentID) + return mockConn, func() {}, nil + } + }) + chat, err := server.CreateChat(ctx, chatd.CreateOptions{ + OrganizationID: org.ID, + OwnerID: user.ID, + WorkspaceID: uuid.NullUUID{UUID: ws.ID, Valid: true}, + AgentID: uuid.NullUUID{UUID: dbAgent.ID, Valid: true}, + Title: "parallel-timestamps", + ModelConfigID: model.ID, + InitialUserContent: []codersdk.ChatMessagePart{ + codersdk.ChatMessageText("search and read files"), + }, + }) + require.NoError(t, err) + waitForChatStatus(ctx, t, db, chat.ID, database.ChatStatusWaiting) + + parts := chatToolParts(ctx, t, db, chat.ID) + for _, toolName := range []string{"read_file", "web_search"} { + call := requireToolCallPart(t, parts, toolName) + result := requireToolResultPart(t, parts, toolName) + require.NotNil(t, call.CreatedAt, toolName) + require.NotNil(t, result.CreatedAt, toolName) + require.False(t, result.CreatedAt.Before(*call.CreatedAt), toolName) + if toolName == "web_search" { + require.True(t, call.ProviderExecuted) + require.True(t, result.ProviderExecuted) + } else { + require.False(t, call.ProviderExecuted) + require.False(t, result.ProviderExecuted) + } + } + }) +} + +func TestActiveServer_RecordsGenerationMetrics(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + db, ps := dbtestutil.NewDB(t) + reg := prometheus.NewRegistry() + openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse { + if !req.Stream { + return chattest.OpenAINonStreamingResponse("title") + } + return chattest.OpenAIStreamingResponse(openAITextChunksWithStop("hello")...) 
+ }) + user, org, model := seedChatDependenciesWithProvider(t, db, "openai", openAIURL) + server := newActiveTestServer(t, db, ps, func(cfg *chatd.Config) { + cfg.PrometheusRegistry = reg + }) + + chat := createChatThroughServer(ctx, t, server, org.ID, user.ID, model.ID, "hello") + waitForChatStatus(ctx, t, db, chat.ID, database.ChatStatusWaiting) + + requireChatdMetricCounter(t, reg, "coderd_chatd_steps_total", 1, map[string]string{ + "provider": "openai", + "model": "gpt-4o-mini", + }) + requireChatdMetricHistogram(t, reg, "coderd_chatd_message_count", 1, map[string]string{ + "provider": "openai", + "model": "gpt-4o-mini", + }, chatdMetricHistogramRequirement{}) + requireChatdMetricHistogram(t, reg, "coderd_chatd_prompt_size_bytes", 1, map[string]string{ + "provider": "openai", + "model": "gpt-4o-mini", + }, chatdMetricHistogramRequirement{PositiveSum: true}) + requireChatdMetricHistogram(t, reg, "coderd_chatd_ttft_seconds", 1, map[string]string{ + "provider": "openai", + "model": "gpt-4o-mini", + }, chatdMetricHistogramRequirement{}) +} + +func TestActiveServer_ToolErrorRecordsMetric(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + toolName string + toolArgs string + chatMode database.NullChatMode + setupAgent func(*agentconnmock.MockAgentConn) + }{ + { + name: "builtin tool IsError", + toolName: "read_file", + toolArgs: `{"path":"/tmp/missing.txt"}`, + setupAgent: func(mockConn *agentconnmock.MockAgentConn) { + mockConn.EXPECT().ReadFileLines(gomock.Any(), "/tmp/missing.txt", int64(1), int64(0), gomock.Any()). + Return(workspacesdk.ReadFileLinesResponse{Success: false, Error: "file not found"}, nil). + Times(1) + }, + }, + { + name: "non builtin MCP style tool IsError", + toolName: "dynamic_error_tool", + toolArgs: `{"input":"hello"}`, + setupAgent: func(mockConn *agentconnmock.MockAgentConn) { + mockConn.EXPECT().CallMCPTool(gomock.Any(), gomock.Any()). 
+ Return(workspacesdk.CallMCPToolResponse{ + IsError: true, + Content: []workspacesdk.MCPToolContent{{ + Type: "text", + Text: "dynamic failed", + }}, + }, nil). + Times(1) + }, + }, + { + name: "tool Run returns error", + toolName: "read_file", + toolArgs: `{"path":"/tmp/error.txt"}`, + setupAgent: func(mockConn *agentconnmock.MockAgentConn) { + mockConn.EXPECT().ReadFileLines(gomock.Any(), "/tmp/error.txt", int64(1), int64(0), gomock.Any()). + Return(workspacesdk.ReadFileLinesResponse{}, xerrors.New("connection refused")). + Times(1) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + db, ps := dbtestutil.NewDB(t) + reg := prometheus.NewRegistry() + var streamedCallCount atomic.Int32 + openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse { + if !req.Stream { + return chattest.OpenAINonStreamingResponse("title") + } + if streamedCallCount.Add(1) == 1 { + return chattest.OpenAIStreamingResponse( + chattest.OpenAIToolCallChunk(tt.toolName, tt.toolArgs), + ) + } + return chattest.OpenAIStreamingResponse(chattest.OpenAITextChunks("done")...) 
+ }) + user, org, model := seedChatDependenciesWithProvider(t, db, "openai-compat", openAIURL) + model.Model = "test-model" + model = updateChatModelContextLimit(t, db, model) + ws, dbAgent := seedWorkspaceWithAgent(t, db, user.ID) + + ctrl := gomock.NewController(t) + mockConn := agentconnmock.NewMockAgentConn(ctrl) + setupToolExecutionAgentConn(t, mockConn, workspacesdk.MCPToolInfo{ + Name: "dynamic_error_tool", + Description: "dynamic error tool", + Schema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "input": map[string]any{"type": "string"}, + }, + }, + }) + tt.setupAgent(mockConn) + + server := newActiveTestServer(t, db, ps, func(cfg *chatd.Config) { + cfg.PrometheusRegistry = reg + cfg.AgentConn = func(_ context.Context, agentID uuid.UUID) (workspacesdk.AgentConn, func(), error) { + require.Equal(t, dbAgent.ID, agentID) + return mockConn, func() {}, nil + } + }) + chatOpts := chatd.CreateOptions{ + OrganizationID: org.ID, + OwnerID: user.ID, + WorkspaceID: uuid.NullUUID{UUID: ws.ID, Valid: true}, + AgentID: uuid.NullUUID{UUID: dbAgent.ID, Valid: true}, + Title: "tool-error-metric", + ModelConfigID: model.ID, + ChatMode: tt.chatMode, + InitialUserContent: []codersdk.ChatMessagePart{ + codersdk.ChatMessageText("run an erroring tool"), + }, + } + chat, err := server.CreateChat(ctx, chatOpts) + require.NoError(t, err) + waitForChatStatus(ctx, t, db, chat.ID, database.ChatStatusWaiting) + + result := requireToolResultPart(t, chatToolParts(ctx, t, db, chat.ID), tt.toolName) + require.True(t, result.IsError) + requireChatdMetricCounter(t, reg, "coderd_chatd_tool_errors_total", 1, map[string]string{ + "provider": "openai-compat", + "model": "test-model", + "tool_name": tt.toolName, + }) + }) + } +} + +func userMessageForTest( + t *testing.T, + text string, + modelID uuid.UUID, + createdBy uuid.UUID, +) chatstate.Message { + t.Helper() + content, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{codersdk.ChatMessageText(text)}) + 
require.NoError(t, err) + return chatstate.Message{ + Role: database.ChatMessageRoleUser, + Content: content, + Visibility: database.ChatMessageVisibilityBoth, + ContentVersion: chatprompt.CurrentContentVersion, + ModelConfigID: uuid.NullUUID{UUID: modelID, Valid: true}, + CreatedBy: uuid.NullUUID{UUID: createdBy, Valid: true}, + } +} + +func assistantMessageForTest( + t *testing.T, + parts []codersdk.ChatMessagePart, + modelID uuid.UUID, +) chatstate.Message { + t.Helper() + content, err := chatprompt.MarshalParts(parts) + require.NoError(t, err) + return chatstate.Message{ + Role: database.ChatMessageRoleAssistant, + Content: content, + Visibility: database.ChatMessageVisibilityBoth, + ContentVersion: chatprompt.CurrentContentVersion, + ModelConfigID: uuid.NullUUID{UUID: modelID, Valid: true}, + } +} + +func toolMessageForTest( + t *testing.T, + parts []codersdk.ChatMessagePart, + modelID uuid.UUID, +) chatstate.Message { + t.Helper() + content, err := chatprompt.MarshalParts(parts) + require.NoError(t, err) + return chatstate.Message{ + Role: database.ChatMessageRoleTool, + Content: content, + Visibility: database.ChatMessageVisibilityBoth, + ContentVersion: chatprompt.CurrentContentVersion, + ModelConfigID: uuid.NullUUID{UUID: modelID, Valid: true}, + } +} + +func setupToolExecutionAgentConn( + t *testing.T, + mockConn *agentconnmock.MockAgentConn, + mcpTools ...workspacesdk.MCPToolInfo, +) { + t.Helper() + mockConn.EXPECT().SetExtraHeaders(gomock.Any()).AnyTimes() + mockConn.EXPECT().ContextConfig(gomock.Any()). + Return(workspacesdk.ContextConfigResponse{}, xerrors.New("not supported")).AnyTimes() + mockConn.EXPECT().ListMCPTools(gomock.Any()). + Return(workspacesdk.ListMCPToolsResponse{Tools: mcpTools}, nil).AnyTimes() + mockConn.EXPECT().LS(gomock.Any(), gomock.Any(), gomock.Any()). + Return(workspacesdk.LSResponse{AbsolutePathString: "/home/coder"}, nil).AnyTimes() + mockConn.EXPECT().ReadFile(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). 
+ Return(io.NopCloser(strings.NewReader("")), "", nil).AnyTimes() +} + +func mustParseChatParts(t *testing.T, msg database.ChatMessage) []codersdk.ChatMessagePart { + t.Helper() + parts, err := chatprompt.ParseContent(msg) + require.NoError(t, err) + return parts +} + +func dynamicToolJSON(t *testing.T, name string) []byte { + t.Helper() + encoded, err := json.Marshal([]mcpgo.Tool{{ + Name: name, + Description: "A test dynamic tool.", + InputSchema: mcpgo.ToolInputSchema{ + Type: "object", + Properties: map[string]any{ + "query": map[string]any{"type": "string"}, + }, + }, + }}) + require.NoError(t, err) + return encoded +} + +func toolResultPartExists(parts []codersdk.ChatMessagePart, toolName string) bool { + for _, part := range parts { + if part.Type == codersdk.ChatMessagePartTypeToolResult && part.ToolName == toolName { + return true + } + } + return false +} + +func chatToolParts( + ctx context.Context, + t *testing.T, + db database.Store, + chatID uuid.UUID, +) []codersdk.ChatMessagePart { + t.Helper() + var parts []codersdk.ChatMessagePart + for _, msg := range chatMessages(ctx, t, db, chatID) { + parsed, err := chatprompt.ParseContent(msg) + require.NoError(t, err) + for _, part := range parsed { + if part.Type == codersdk.ChatMessagePartTypeToolCall || + part.Type == codersdk.ChatMessagePartTypeToolResult { + parts = append(parts, part) + } + } + } + return parts +} + +func requireToolCallPart( + t *testing.T, + parts []codersdk.ChatMessagePart, + toolName string, +) codersdk.ChatMessagePart { + t.Helper() + for _, part := range parts { + if part.Type == codersdk.ChatMessagePartTypeToolCall && part.ToolName == toolName { + return part + } + } + t.Fatalf("missing tool-call part for %q", toolName) + return codersdk.ChatMessagePart{} +} + +func requireToolResultPart( + t *testing.T, + parts []codersdk.ChatMessagePart, + toolName string, +) codersdk.ChatMessagePart { + t.Helper() + for _, part := range parts { + if part.Type == 
codersdk.ChatMessagePartTypeToolResult && part.ToolName == toolName { + return part + } + } + t.Fatalf("missing tool-result part for %q", toolName) + return codersdk.ChatMessagePart{} +} + +func openAIMessagesContain(messages []chattest.OpenAIMessage, text string) bool { + for _, msg := range messages { + if strings.Contains(msg.Content, text) { + return true + } + } + return false +} + +func requireChatdMetricCounter( + t *testing.T, + reg *prometheus.Registry, + name string, + wantValue float64, + wantLabels map[string]string, +) { + t.Helper() + families, err := reg.Gather() + require.NoError(t, err) + for _, family := range families { + if family.GetName() != name { + continue + } + for _, metric := range family.GetMetric() { + labels := metricLabels(metric) + if !metricLabelsMatch(labels, wantLabels) { + continue + } + require.Equal(t, wantValue, metric.GetCounter().GetValue()) + return + } + t.Fatalf("metric %s with labels %v not found", name, wantLabels) + } + t.Fatalf("metric %s not found", name) +} + +type chatdMetricHistogramRequirement struct { + PositiveSum bool +} + +func requireChatdMetricHistogram( + t *testing.T, + reg *prometheus.Registry, + name string, + wantSampleCount uint64, + wantLabels map[string]string, + requirement chatdMetricHistogramRequirement, +) { + t.Helper() + families, err := reg.Gather() + require.NoError(t, err) + for _, family := range families { + if family.GetName() != name { + continue + } + for _, metric := range family.GetMetric() { + labels := metricLabels(metric) + if !metricLabelsMatch(labels, wantLabels) { + continue + } + histogram := metric.GetHistogram() + require.Equal(t, wantSampleCount, histogram.GetSampleCount()) + if requirement.PositiveSum { + require.Positive(t, histogram.GetSampleSum()) + } + return + } + t.Fatalf("metric %s with labels %v not found", name, wantLabels) + } + t.Fatalf("metric %s not found", name) +} + +func metricLabels(metric interface { + GetLabel() []*io_prometheus_client.LabelPair +}, +) 
map[string]string { + labels := map[string]string{} + for _, label := range metric.GetLabel() { + labels[label.GetName()] = label.GetValue() + } + return labels +} + +func metricLabelsMatch(labels, wantLabels map[string]string) bool { + for key, value := range wantLabels { + if labels[key] != value { + return false + } + } + return true +} + +func TestActiveServer_AnthropicUsageMatchesFinalDelta(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + db, ps := dbtestutil.NewDB(t) + anthropicURL := chattest.NewAnthropic(t, func(_ *chattest.AnthropicRequest) chattest.AnthropicResponse { + return chattest.AnthropicStreamingResponse(chattest.AnthropicTextChunksWithCacheUsage(chattest.AnthropicUsage{ + InputTokens: 200, + OutputTokens: 75, + CacheCreationInputTokens: 30, + CacheReadInputTokens: 150, + }, "cached response")...) + }) + user, org, model := seedAnthropicChatDependencies(t, db, anthropicURL) + + server := newActiveTestServer(t, db, ps) + chat := createChatThroughServer(ctx, t, server, org.ID, user.ID, model.ID, "hello") + waitForChatStatus(ctx, t, db, chat.ID, database.ChatStatusWaiting) + + messages := chatMessages(ctx, t, db, chat.ID) + last := messages[len(messages)-1] + require.Equal(t, database.ChatMessageRoleAssistant, last.Role) + require.Equal(t, sql.NullInt64{Int64: 200, Valid: true}, last.InputTokens) + require.Equal(t, sql.NullInt64{Int64: 75, Valid: true}, last.OutputTokens) + require.Equal(t, sql.NullInt64{Int64: 275, Valid: true}, last.TotalTokens) + require.Equal(t, sql.NullInt64{Int64: 30, Valid: true}, last.CacheCreationTokens) + require.Equal(t, sql.NullInt64{Int64: 150, Valid: true}, last.CacheReadTokens) +} + +func TestActiveServer_AnthropicSanitizesProviderToolBeforeRequest(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + db, ps := dbtestutil.NewDB(t) + requests := newAnthropicRequestRecorder() + anthropicURL := chattest.NewAnthropic(t, func(req *chattest.AnthropicRequest) 
chattest.AnthropicResponse { + requests.record(req) + return chattest.AnthropicStreamingResponse(chattest.AnthropicTextChunks("done")...) + }) + user, org, model := seedAnthropicChatDependencies(t, db, anthropicURL) + + server := newActiveTestServer(t, db, ps) + chat := createChatThroughServer(ctx, t, server, org.ID, user.ID, model.ID, "search for coder") + waitForChatStatus(ctx, t, db, chat.ID, database.ChatStatusWaiting) + insertOrphanProviderToolCall(ctx, t, db, chat.ID, model.ID) + _, err := server.SendMessage(ctx, chatd.SendMessageOptions{ + ChatID: chat.ID, + CreatedBy: user.ID, + ModelConfigID: model.ID, + Content: []codersdk.ChatMessagePart{codersdk.ChatMessageText("continue")}, + BusyBehavior: chatd.SendMessageBusyBehaviorQueue, + }) + require.NoError(t, err) + + waitForChatStatus(ctx, t, db, chat.ID, database.ChatStatusWaiting) + + generationRequests := filterAnthropicStreamingRequests(requests.all()) + require.Len(t, generationRequests, 2) + body := anthropicRequestBody(t, generationRequests[1]) + require.NotContains(t, body, "web_search") + require.Contains(t, body, "partial") + require.Contains(t, body, "continue") + requireAnthropicRequestRedactedReasoning(t, generationRequests[1], "redacted-payload") +} + +func TestActiveServer_AnthropicProviderToolPreRequestGuard(t *testing.T) { + t.Parallel() + + webSearchEnabled := true + callConfig := codersdk.ChatModelCallConfig{ + ProviderOptions: &codersdk.ChatModelProviderOptions{ + Anthropic: &codersdk.ChatModelAnthropicProviderOptions{ + WebSearchEnabled: &webSearchEnabled, + }, + }, + } + + t.Run("allowed web search survives when provider tool is enabled", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + db, ps := dbtestutil.NewDB(t) + requests := newAnthropicRequestRecorder() + anthropicURL := chattest.NewAnthropic(t, func(req *chattest.AnthropicRequest) chattest.AnthropicResponse { + requests.record(req) + return 
chattest.AnthropicStreamingResponse(chattest.AnthropicTextChunks("done")...) + }) + user, org, model := seedAnthropicChatDependencies(t, db, anthropicURL) + model = updateChatModelCallConfig(t, db, model, callConfig) + + server := newActiveTestServer(t, db, ps) + chat := createChatThroughServer(ctx, t, server, org.ID, user.ID, model.ID, "search") + waitForChatStatus(ctx, t, db, chat.ID, database.ChatStatusWaiting) + insertProviderToolPairMessageWithLocalTool(ctx, t, db, chat.ID, model.ID, "ws-allowed") + _, err := server.SendMessage(ctx, chatd.SendMessageOptions{ + ChatID: chat.ID, + CreatedBy: user.ID, + ModelConfigID: model.ID, + Content: []codersdk.ChatMessagePart{codersdk.ChatMessageText("continue")}, + BusyBehavior: chatd.SendMessageBusyBehaviorQueue, + }) + require.NoError(t, err) + + waitForChatStatus(ctx, t, db, chat.ID, database.ChatStatusWaiting) + + generationRequests := filterAnthropicStreamingRequests(requests.all()) + require.Len(t, generationRequests, 2) + body := anthropicRequestBody(t, generationRequests[1]) + require.Contains(t, body, "ws-allowed") + require.Contains(t, body, "web_search") + }) + + t.Run("web search history survives when provider tool is disabled", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + db, ps := dbtestutil.NewDB(t) + requests := newAnthropicRequestRecorder() + anthropicURL := chattest.NewAnthropic(t, func(req *chattest.AnthropicRequest) chattest.AnthropicResponse { + requests.record(req) + return chattest.AnthropicStreamingResponse(chattest.AnthropicTextChunks("done")...) 
+ }) + user, org, model := seedAnthropicChatDependencies(t, db, anthropicURL) + + server := newActiveTestServer(t, db, ps) + chat := createChatThroughServer(ctx, t, server, org.ID, user.ID, model.ID, "search and read") + waitForChatStatus(ctx, t, db, chat.ID, database.ChatStatusWaiting) + insertProviderToolPairMessageWithLocalTool(ctx, t, db, chat.ID, model.ID, "ws-disabled") + _, err := server.SendMessage(ctx, chatd.SendMessageOptions{ + ChatID: chat.ID, + CreatedBy: user.ID, + ModelConfigID: model.ID, + Content: []codersdk.ChatMessagePart{codersdk.ChatMessageText("continue")}, + BusyBehavior: chatd.SendMessageBusyBehaviorQueue, + }) + require.NoError(t, err) + + waitForChatStatus(ctx, t, db, chat.ID, database.ChatStatusWaiting) + + generationRequests := filterAnthropicStreamingRequests(requests.all()) + require.Len(t, generationRequests, 2) + body := anthropicRequestBody(t, generationRequests[1]) + require.Contains(t, body, "ws-disabled") + require.Contains(t, body, "web_search") + require.Contains(t, body, "tc-1") + require.Contains(t, body, "file") + }) +} + +func TestActiveServer_AnthropicDropsUnpairedProviderToolBeforePersist(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + toolName string + toolInput json.RawMessage + }{ + { + name: "web_search", + toolName: "web_search", + toolInput: json.RawMessage(`{"query":"coder"}`), + }, + { + name: "code_execution", + toolName: "code_execution", + toolInput: json.RawMessage(`{"code":"print(1)"}`), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + db, ps := dbtestutil.NewDB(t) + requests := newAnthropicRequestRecorder() + var requestCount atomic.Int32 + anthropicURL := chattest.NewAnthropic(t, func(req *chattest.AnthropicRequest) chattest.AnthropicResponse { + requests.record(req) + if !req.Stream { + return chattest.AnthropicNonStreamingResponse("title") + } + if requestCount.Add(1) == 1 { + return 
chattest.AnthropicStreamingResponse( + anthropicServerToolUseChunks("pt-1", tt.toolName, tt.toolInput, "tool_use")..., + ) + } + return chattest.AnthropicStreamingResponse(chattest.AnthropicTextChunks("after sanitized step")...) + }) + user, org, model := seedAnthropicChatDependencies(t, db, anthropicURL) + model = enableAnthropicWebSearchForTest(t, db, model) + + server := newActiveTestServer(t, db, ps) + chat := createChatThroughServer(ctx, t, server, org.ID, user.ID, model.ID, "run provider tool") + waitForChatStatus(ctx, t, db, chat.ID, database.ChatStatusWaiting) + + generationRequests := filterAnthropicStreamingRequests(requests.all()) + require.Len(t, generationRequests, 1) + messages := chatMessages(ctx, t, db, chat.ID) + last := messages[len(messages)-1] + require.Equal(t, database.ChatMessageRoleUser, last.Role) + requireTextPart(t, last, "run provider tool") + require.False(t, toolPartExists(chatToolParts(ctx, t, db, chat.ID), tt.toolName), + "unpaired provider tool content should not be committed") + }) + } +} + +func TestActiveServer_AnthropicKeepsPairedWebSearchBeforePersist(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + db, ps := dbtestutil.NewDB(t) + requests := newAnthropicRequestRecorder() + anthropicURL := chattest.NewAnthropic(t, func(req *chattest.AnthropicRequest) chattest.AnthropicResponse { + requests.record(req) + return chattest.AnthropicStreamingResponse( + anthropicWebSearchPairChunks("ws-1", `{"query":"coder"}`, "search done", "end_turn")..., + ) + }) + user, org, model := seedAnthropicChatDependencies(t, db, anthropicURL) + model = enableAnthropicWebSearchForTest(t, db, model) + + server := newActiveTestServer(t, db, ps) + chat := createChatThroughServer(ctx, t, server, org.ID, user.ID, model.ID, "search for coder") + waitForChatStatus(ctx, t, db, chat.ID, database.ChatStatusWaiting) + + generationRequests := filterAnthropicStreamingRequests(requests.all()) + require.Len(t, generationRequests, 1) + 
parts := chatToolParts(ctx, t, db, chat.ID) + toolCall := requireToolCallPart(t, parts, "web_search") + require.Equal(t, "ws-1", toolCall.ToolCallID) + require.True(t, toolCall.ProviderExecuted) + toolResult := requireToolResultPart(t, parts, "web_search") + require.Equal(t, "ws-1", toolResult.ToolCallID) + require.True(t, toolResult.ProviderExecuted) + require.NotEmpty(t, toolResult.ProviderMetadata) + messages := chatMessages(ctx, t, db, chat.ID) + requireTextPart(t, messages[len(messages)-1], "search done") +} + +func TestActiveServer_AnthropicSanitizesWebSearchBeforeContinuation(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + db, ps := dbtestutil.NewDB(t) + requests := newAnthropicRequestRecorder() + var requestCount atomic.Int32 + anthropicURL := chattest.NewAnthropic(t, func(req *chattest.AnthropicRequest) chattest.AnthropicResponse { + requests.record(req) + if !req.Stream { + return chattest.AnthropicNonStreamingResponse("title") + } + if requestCount.Add(1) == 1 { + chunks := anthropicServerToolUseChunks("ws-1", "web_search", json.RawMessage(`{"query":"coder"}`), "tool_use") + chunks = append(chunks[:len(chunks)-2], anthropicToolUseChunksWithoutMessageEnvelope(1, "tc-1", "read_file", `{"path":"main.go"}`)...) + chunks = append(chunks, + chattest.AnthropicChunk{ + Type: "message_delta", + StopReason: "tool_use", + Usage: chattest.AnthropicUsage{InputTokens: 10, OutputTokens: 5}, + }, + chattest.AnthropicChunk{Type: "message_stop"}, + ) + return chattest.AnthropicStreamingResponse(chunks...) + } + return chattest.AnthropicStreamingResponse(chattest.AnthropicTextChunks("done")...) 
+ }) + user, org, model := seedAnthropicChatDependencies(t, db, anthropicURL) + model = enableAnthropicWebSearchForTest(t, db, model) + ws, dbAgent := seedWorkspaceWithAgent(t, db, user.ID) + + ctrl := gomock.NewController(t) + mockConn := agentconnmock.NewMockAgentConn(ctrl) + setupToolExecutionAgentConn(t, mockConn) + mockConn.EXPECT().ReadFileLines(gomock.Any(), "main.go", int64(1), int64(0), gomock.Any()). + Return(workspacesdk.ReadFileLinesResponse{Success: true, Content: "package main", FileSize: 12, TotalLines: 1, LinesRead: 1}, nil). + Times(1) + + server := newActiveTestServer(t, db, ps, func(cfg *chatd.Config) { + cfg.AgentConn = func(_ context.Context, agentID uuid.UUID) (workspacesdk.AgentConn, func(), error) { + require.Equal(t, dbAgent.ID, agentID) + return mockConn, func() {}, nil + } + }) + chat, err := server.CreateChat(ctx, chatd.CreateOptions{ + OrganizationID: org.ID, + OwnerID: user.ID, + WorkspaceID: uuid.NullUUID{UUID: ws.ID, Valid: true}, + AgentID: uuid.NullUUID{UUID: dbAgent.ID, Valid: true}, + Title: "anthropic-web-search-continuation", + ModelConfigID: model.ID, + InitialUserContent: []codersdk.ChatMessagePart{ + codersdk.ChatMessageText("search and read"), + }, + }) + require.NoError(t, err) + waitForChatStatus(ctx, t, db, chat.ID, database.ChatStatusWaiting) + + generationRequests := filterAnthropicStreamingRequests(requests.all()) + require.Len(t, generationRequests, 2) + continuationBody := anthropicRequestBody(t, generationRequests[1]) + require.NotContains(t, continuationBody, "server_tool_use") + require.NotContains(t, continuationBody, "web_search_tool_result") + require.NotContains(t, continuationBody, "ws-1") + require.Contains(t, continuationBody, "tc-1") + require.Contains(t, continuationBody, "package main") + + parts := chatToolParts(ctx, t, db, chat.ID) + require.False(t, toolPartExists(parts, "web_search")) + toolCall := requireToolCallPart(t, parts, "read_file") + require.Equal(t, "tc-1", toolCall.ToolCallID) + 
require.False(t, toolCall.ProviderExecuted) + toolResult := requireToolResultPart(t, parts, "read_file") + require.Equal(t, "tc-1", toolResult.ToolCallID) + require.False(t, toolResult.ProviderExecuted) +} + +func TestActiveServer_ExclusiveToolPolicy(t *testing.T) { + t.Parallel() + + t.Run("mixed exclusive and local tools commit policy errors", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + db, ps := dbtestutil.NewDB(t) + var streamedCallCount atomic.Int32 + openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse { + if !req.Stream { + return chattest.OpenAINonStreamingResponse("title") + } + if streamedCallCount.Add(1) == 1 { + advisorChunk := chattest.OpenAIToolCallChunk("advisor", `{"question":"help"}`) + readChunk := chattest.OpenAIToolCallChunk("read_file", `{"path":"/tmp/a.txt"}`) + readCall := readChunk.Choices[0].ToolCalls[0] + readCall.Index = 1 + advisorChunk.Choices[0].ToolCalls = append(advisorChunk.Choices[0].ToolCalls, readCall) + return chattest.OpenAIStreamingResponse(advisorChunk) + } + return chattest.OpenAIStreamingResponse(chattest.OpenAITextChunks("done")...) 
+ }) + user, org, model := seedChatDependenciesWithProvider(t, db, "openai-compat", openAIURL) + seedAdvisorConfig(ctx, t, db, codersdk.AdvisorConfig{Enabled: true, MaxUsesPerRun: 3, MaxOutputTokens: 1024}) + ws, dbAgent := seedWorkspaceWithAgent(t, db, user.ID) + + ctrl := gomock.NewController(t) + mockConn := agentconnmock.NewMockAgentConn(ctrl) + setupToolExecutionAgentConn(t, mockConn) + mockConn.EXPECT().ReadFileLines(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(0) + + server := newActiveTestServer(t, db, ps, func(cfg *chatd.Config) { + cfg.AgentConn = func(_ context.Context, agentID uuid.UUID) (workspacesdk.AgentConn, func(), error) { + require.Equal(t, dbAgent.ID, agentID) + return mockConn, func() {}, nil + } + }) + chat, err := server.CreateChat(ctx, chatd.CreateOptions{ + OrganizationID: org.ID, + OwnerID: user.ID, + WorkspaceID: uuid.NullUUID{UUID: ws.ID, Valid: true}, + AgentID: uuid.NullUUID{UUID: dbAgent.ID, Valid: true}, + Title: "exclusive-local-policy", + ModelConfigID: model.ID, + InitialUserContent: []codersdk.ChatMessagePart{ + codersdk.ChatMessageText("advise and read"), + }, + }) + require.NoError(t, err) + waitForChatStatus(ctx, t, db, chat.ID, database.ChatStatusWaiting) + + parts := chatToolParts(ctx, t, db, chat.ID) + advisorResult := requireToolResultPart(t, parts, "advisor") + readResult := requireToolResultPart(t, parts, "read_file") + require.True(t, advisorResult.IsError) + require.True(t, readResult.IsError) + require.Contains(t, string(advisorResult.Result), "advisor must be called alone, without other tools in the same batch") + require.Contains(t, string(readResult.Result), "this tool was skipped because advisor must run alone in its batch") + require.GreaterOrEqual(t, streamedCallCount.Load(), int32(2)) + }) + + t.Run("mixed exclusive and dynamic tools commit policy errors", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + db, ps := dbtestutil.NewDB(t) + var 
streamedCallCount atomic.Int32 + openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse { + if !req.Stream { + return chattest.OpenAINonStreamingResponse("title") + } + if streamedCallCount.Add(1) == 1 { + advisorChunk := chattest.OpenAIToolCallChunk("advisor", `{"question":"help"}`) + dynamicChunk := chattest.OpenAIToolCallChunk("mcp_tool", `{"q":"docs"}`) + dynamicCall := dynamicChunk.Choices[0].ToolCalls[0] + dynamicCall.Index = 1 + advisorChunk.Choices[0].ToolCalls = append(advisorChunk.Choices[0].ToolCalls, dynamicCall) + return chattest.OpenAIStreamingResponse(advisorChunk) + } + return chattest.OpenAIStreamingResponse(chattest.OpenAITextChunks("done")...) + }) + user, org, model := seedChatDependenciesWithProvider(t, db, "openai-compat", openAIURL) + seedAdvisorConfig(ctx, t, db, codersdk.AdvisorConfig{Enabled: true, MaxUsesPerRun: 3, MaxOutputTokens: 1024}) + dynamicToolsJSON, err := json.Marshal([]mcpgo.Tool{{ + Name: "mcp_tool", + Description: "dynamic test tool", + InputSchema: mcpgo.ToolInputSchema{Type: "object", Properties: map[string]any{"q": map[string]any{"type": "string"}}}, + }}) + require.NoError(t, err) + + server := newActiveTestServer(t, db, ps) + chat, err := server.CreateChat(ctx, chatd.CreateOptions{ + OrganizationID: org.ID, + OwnerID: user.ID, + Title: "exclusive-dynamic-policy", + ModelConfigID: model.ID, + InitialUserContent: []codersdk.ChatMessagePart{ + codersdk.ChatMessageText("advise and call dynamic"), + }, + DynamicTools: dynamicToolsJSON, + }) + require.NoError(t, err) + chatResult := waitForChatStatus(ctx, t, db, chat.ID, database.ChatStatusWaiting) + require.NotEqual(t, database.ChatStatusRequiresAction, chatResult.Status) + + parts := chatToolParts(ctx, t, db, chat.ID) + advisorResult := requireToolResultPart(t, parts, "advisor") + dynamicResult := requireToolResultPart(t, parts, "mcp_tool") + require.True(t, advisorResult.IsError) + require.True(t, dynamicResult.IsError) + 
require.Contains(t, string(advisorResult.Result), "advisor must be called alone, without other tools in the same batch") + require.Contains(t, string(dynamicResult.Result), "this tool was skipped because advisor must run alone in its batch") + require.GreaterOrEqual(t, streamedCallCount.Load(), int32(2)) + }) + + t.Run("solo exclusive tool executes", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + db, ps := dbtestutil.NewDB(t) + var streamedCallCount atomic.Int32 + openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse { + if !req.Stream { + return chattest.OpenAINonStreamingResponse("title") + } + switch streamedCallCount.Add(1) { + case 1: + return chattest.OpenAIStreamingResponse( + chattest.OpenAIToolCallChunk("advisor", `{"question":"help me decide"}`), + ) + case 2: + return chattest.OpenAIStreamingResponse(chattest.OpenAITextChunks("nested advice")...) + default: + return chattest.OpenAIStreamingResponse(chattest.OpenAITextChunks("done")...) 
+ } + }) + user, org, model := seedChatDependenciesWithProvider(t, db, "openai-compat", openAIURL) + seedAdvisorConfig(ctx, t, db, codersdk.AdvisorConfig{Enabled: true, MaxUsesPerRun: 3, MaxOutputTokens: 1024}) + server := newActiveTestServer(t, db, ps) + chat := createChatThroughServer(ctx, t, server, org.ID, user.ID, model.ID, "advise only") + waitForChatStatus(ctx, t, db, chat.ID, database.ChatStatusWaiting) + + parts := chatToolParts(ctx, t, db, chat.ID) + result := requireToolResultPart(t, parts, "advisor") + require.False(t, result.IsError) + require.Contains(t, string(result.Result), "nested advice") + require.GreaterOrEqual(t, streamedCallCount.Load(), int32(3)) + }) + + t.Run("exclusive tool with provider executed tool executes", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + db, ps := dbtestutil.NewDB(t) + webSearchEnabled := true + var streamedCallCount atomic.Int32 + openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse { + if !req.Stream { + return chattest.OpenAINonStreamingResponse("title") + } + switch streamedCallCount.Add(1) { + case 1: + return chattest.OpenAIResponse{ + StreamingChunks: chattest.OpenAIStreamingResponse( + chattest.OpenAIToolCallChunk("advisor", `{"question":"search informed advice"}`), + ).StreamingChunks, + WebSearch: &chattest.OpenAIWebSearchCall{ID: "ws-advisor", Query: "coder"}, + } + case 2: + return chattest.OpenAIStreamingResponse(chattest.OpenAITextChunks("nested advice")...) + default: + return chattest.OpenAIStreamingResponse(chattest.OpenAITextChunks("done")...) 
+ } + }) + user, org, model := seedChatDependenciesWithProvider(t, db, "openai", openAIURL) + model = updateChatModelCallConfig(t, db, model, codersdk.ChatModelCallConfig{ + ProviderOptions: &codersdk.ChatModelProviderOptions{ + OpenAI: &codersdk.ChatModelOpenAIProviderOptions{WebSearchEnabled: &webSearchEnabled}, + }, + }) + seedAdvisorConfig(ctx, t, db, codersdk.AdvisorConfig{Enabled: true, MaxUsesPerRun: 3, MaxOutputTokens: 1024}) + server := newActiveTestServer(t, db, ps) + chat := createChatThroughServer(ctx, t, server, org.ID, user.ID, model.ID, "search then advise") + waitForChatStatus(ctx, t, db, chat.ID, database.ChatStatusWaiting) + + parts := chatToolParts(ctx, t, db, chat.ID) + advisorResult := requireToolResultPart(t, parts, "advisor") + webResult := requireToolResultPart(t, parts, "web_search") + require.False(t, advisorResult.IsError) + require.True(t, webResult.ProviderExecuted) + require.GreaterOrEqual(t, streamedCallCount.Load(), int32(3)) + }) +} + +func TestActiveServer_ReasoningTimestamps(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + db, ps := dbtestutil.NewDB(t) + sendReasoning := true + thinkingBudget := int64(1024) + anthropicURL := chattest.NewAnthropic(t, func(_ *chattest.AnthropicRequest) chattest.AnthropicResponse { + return chattest.AnthropicStreamingResponse(chattest.AnthropicReasoningTextChunks( + []chattest.AnthropicReasoningBlock{ + {Text: "first thought", Signature: "sig_1"}, + {Text: "second thought", Signature: "sig_2"}, + }, + "answer", + )...) 
+ }) + user, org, model := seedAnthropicChatDependencies(t, db, anthropicURL) + model = updateChatModelCallConfig(t, db, model, codersdk.ChatModelCallConfig{ + ProviderOptions: &codersdk.ChatModelProviderOptions{ + Anthropic: &codersdk.ChatModelAnthropicProviderOptions{ + SendReasoning: &sendReasoning, + Thinking: &codersdk.ChatModelAnthropicThinkingOptions{ + BudgetTokens: &thinkingBudget, + }, + }, + }, + }) + + server := newActiveTestServer(t, db, ps) + chat := createChatThroughServer(ctx, t, server, org.ID, user.ID, model.ID, "think") + waitForChatStatus(ctx, t, db, chat.ID, database.ChatStatusWaiting) + + messages := chatMessages(ctx, t, db, chat.ID) + assistant := messages[len(messages)-1] + reasoningParts := reasoningPartsFromMessage(t, assistant) + require.Len(t, reasoningParts, 2) + require.Equal(t, []string{"first thought", "second thought"}, []string{ + strings.TrimSpace(reasoningParts[0].Text), + strings.TrimSpace(reasoningParts[1].Text), + }) + for i := range reasoningParts { + require.NotNil(t, reasoningParts[i].CreatedAt) + require.NotNil(t, reasoningParts[i].CompletedAt) + require.False(t, reasoningParts[i].CreatedAt.IsZero()) + require.False(t, reasoningParts[i].CompletedAt.IsZero()) + require.False(t, reasoningParts[i].CompletedAt.Before(*reasoningParts[i].CreatedAt)) + } + require.False(t, reasoningParts[1].CreatedAt.Before(*reasoningParts[0].CompletedAt)) +} + +func TestAnthropicProviderToolPreRequestGuard(t *testing.T) { + t.Parallel() + + providerPair := func(id string) []fantasy.MessagePart { + return []fantasy.MessagePart{ + fantasy.ToolCallPart{ + ToolCallID: id, + ToolName: "web_search", + Input: `{"query":"coder"}`, + ProviderExecuted: true, + }, + fantasy.ToolResultPart{ + ToolCallID: id, + Output: fantasy.ToolResultOutputContentText{Text: "ok"}, + ProviderExecuted: true, + ProviderOptions: fantasy.ProviderOptions(validWebSearchProviderMetadataForTest()), + }, + } + } + + t.Run("orphan provider result is textified", func(t *testing.T) { 
+ t.Parallel() + + guarded, err := chatsanitize.ApplyAnthropicProviderToolGuard( + context.Background(), + slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}), + fantasyanthropic.Name, + "claude-test", + []fantasy.Message{ + { + Role: fantasy.MessageRoleAssistant, + Content: []fantasy.MessagePart{ + fantasy.TextPart{Text: "keep"}, + fantasy.ToolResultPart{ + ToolCallID: "ws-orphan", + Output: fantasy.ToolResultOutputContentText{Text: "search result"}, + ProviderExecuted: true, + }, + }, + }, + }, + ) + require.NoError(t, err) + + requireNoProviderExecutedToolResultPrompt(t, guarded) + requireAnthropicProviderToolPromptSafe(t, guarded) + require.Len(t, guarded, 1) + require.Len(t, guarded[0].Content, 2) + textPart, ok := fantasy.AsMessagePart[fantasy.TextPart](guarded[0].Content[0]) + require.True(t, ok) + require.Equal(t, "keep", textPart.Text) + textPart, ok = fantasy.AsMessagePart[fantasy.TextPart](guarded[0].Content[1]) + require.True(t, ok) + require.Equal(t, "search result", textPart.Text) + }) + + t.Run("valid provider history is unchanged", func(t *testing.T) { + t.Parallel() + + content := []fantasy.MessagePart{fantasy.TextPart{Text: "keep"}} + content = append(content, providerPair("ws-one")...) + content = append(content, providerPair("ws-two")...) 
+ guarded, err := chatsanitize.ApplyAnthropicProviderToolGuard( + context.Background(), + slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}), + fantasyanthropic.Name, + "claude-test", + []fantasy.Message{{Role: fantasy.MessageRoleAssistant, Content: content}}, + ) + require.NoError(t, err) + + requireAnthropicProviderToolPromptSafe(t, guarded) + require.Len(t, guarded, 1) + require.Len(t, guarded[0].Content, len(content)) + requireProviderExecutedToolCallPrompt(t, guarded, "ws-one") + requireProviderExecutedToolResultPrompt(t, guarded, "ws-one") + requireProviderExecutedToolCallPrompt(t, guarded, "ws-two") + requireProviderExecutedToolResultPrompt(t, guarded, "ws-two") + }) + + t.Run("non Anthropic providers are unchanged", func(t *testing.T) { + t.Parallel() + + prompt := []fantasy.Message{ + { + Role: fantasy.MessageRoleAssistant, + Content: providerPair("ws-other-provider"), + }, + } + guarded, err := chatsanitize.ApplyAnthropicProviderToolGuard( + context.Background(), + slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}), + "fake", + "fake-model", + prompt, + ) + require.NoError(t, err) + require.Equal(t, prompt, guarded) + }) + + t.Run("logs removals", func(t *testing.T) { + t.Parallel() + + logSink := testutil.NewFakeSink(t) + logger := logSink.Logger() + logPair := providerPair("ws-log") + guarded, err := chatsanitize.ApplyAnthropicProviderToolGuard( + context.Background(), + logger, + fantasyanthropic.Name, + "claude-test", + []fantasy.Message{ + { + Role: fantasy.MessageRoleAssistant, + Content: []fantasy.MessagePart{ + logPair[1], + logPair[0], + }, + }, + }, + ) + require.NoError(t, err) + + requireNoProviderExecutedToolCallPrompt(t, guarded) + requireNoProviderExecutedToolResultPrompt(t, guarded) + requireTextPrompt(t, guarded, "ok") + entries := logSink.Entries(func(e slog.SinkEntry) bool { + return e.Level == slog.LevelWarn && + e.Message == "removed provider-executed tool history" + }) + require.Len(t, entries, 1) + require.Equal(t, 
"pre_request_guard", requireLogField(t, entries[0], "phase")) + require.Equal(t, 1, requireLogField(t, entries[0], "removed_tool_calls")) + require.Equal(t, 1, requireLogField(t, entries[0], "removed_tool_results")) + }) +} + +func enableAnthropicWebSearchForTest( + t *testing.T, + db database.Store, + model database.ChatModelConfig, +) database.ChatModelConfig { + t.Helper() + webSearchEnabled := true + return updateChatModelCallConfig(t, db, model, codersdk.ChatModelCallConfig{ + ProviderOptions: &codersdk.ChatModelProviderOptions{ + Anthropic: &codersdk.ChatModelAnthropicProviderOptions{ + WebSearchEnabled: &webSearchEnabled, + }, + }, + }) +} + +func anthropicMessageStartChunk(messageID string) chattest.AnthropicChunk { + return chattest.AnthropicChunk{ + Type: "message_start", + Message: chattest.AnthropicChunkMessage{ + ID: messageID, + Type: "message", + Role: "assistant", + Model: "claude-3-opus-20240229", + }, + } +} + +func anthropicServerToolUseChunks( + toolCallID string, + toolName string, + input json.RawMessage, + stopReason string, +) []chattest.AnthropicChunk { + chunks := []chattest.AnthropicChunk{ + anthropicMessageStartChunk("msg-" + toolCallID), + } + chunks = append(chunks, anthropicServerToolUseChunksWithoutMessageEnvelope(0, toolCallID, toolName, input)...) 
+ chunks = append(chunks, + chattest.AnthropicChunk{ + Type: "message_delta", + StopReason: stopReason, + Usage: chattest.AnthropicUsage{InputTokens: 10, OutputTokens: 5}, + }, + chattest.AnthropicChunk{Type: "message_stop"}, + ) + return chunks +} + +func anthropicServerToolUseChunksWithoutMessageEnvelope( + index int, + toolCallID string, + toolName string, + input json.RawMessage, +) []chattest.AnthropicChunk { + return []chattest.AnthropicChunk{ + { + Type: "content_block_start", + Index: index, + ContentBlock: chattest.AnthropicContentBlock{ + Type: "server_tool_use", + ID: toolCallID, + Name: toolName, + Input: input, + }, + }, + { + Type: "content_block_stop", + Index: index, + }, + } +} + +func anthropicToolUseChunksWithoutMessageEnvelope( + index int, + toolCallID string, + toolName string, + input string, +) []chattest.AnthropicChunk { + return []chattest.AnthropicChunk{ + { + Type: "content_block_start", + Index: index, + ContentBlock: chattest.AnthropicContentBlock{ + Type: "tool_use", + ID: toolCallID, + Name: toolName, + Input: json.RawMessage(`{}`), + }, + }, + { + Type: "content_block_delta", + Index: index, + Delta: chattest.AnthropicDeltaBlock{ + Type: "input_json_delta", + PartialJSON: input, + }, + }, + { + Type: "content_block_stop", + Index: index, + }, + } +} + +func anthropicWebSearchPairChunks( + toolCallID string, + queryInput string, + text string, + stopReason string, +) []chattest.AnthropicChunk { + resultContent := []map[string]any{{ + "type": "web_search_result", + "url": "https://example.com/coder", + "title": "Coder", + "encrypted_content": "encrypted-coder", + }} + chunks := []chattest.AnthropicChunk{ + anthropicMessageStartChunk("msg-" + toolCallID), + } + chunks = append(chunks, anthropicServerToolUseChunksWithoutMessageEnvelope(0, toolCallID, "web_search", json.RawMessage(queryInput))...) 
+ chunks = append(chunks, + chattest.AnthropicChunk{ + Type: "content_block_start", + Index: 1, + ContentBlock: chattest.AnthropicContentBlock{ + Type: "web_search_tool_result", + ToolUseID: toolCallID, + Content: resultContent, + }, + }, + chattest.AnthropicChunk{Type: "content_block_stop", Index: 1}, + chattest.AnthropicChunk{ + Type: "content_block_start", + Index: 2, + ContentBlock: chattest.AnthropicContentBlock{ + Type: "text", + }, + }, + chattest.AnthropicChunk{ + Type: "content_block_delta", + Index: 2, + Delta: chattest.AnthropicDeltaBlock{ + Type: "text_delta", + Text: text, + }, + }, + chattest.AnthropicChunk{Type: "content_block_stop", Index: 2}, + chattest.AnthropicChunk{ + Type: "message_delta", + StopReason: stopReason, + Usage: chattest.AnthropicUsage{InputTokens: 10, OutputTokens: 5}, + }, + chattest.AnthropicChunk{Type: "message_stop"}, + ) + return chunks +} + +func toolPartExists(parts []codersdk.ChatMessagePart, toolName string) bool { + for _, part := range parts { + if (part.Type == codersdk.ChatMessagePartTypeToolCall || part.Type == codersdk.ChatMessagePartTypeToolResult) && + part.ToolName == toolName { + return true + } + } + return false +} + +func updateChatModelCompressionThreshold(t *testing.T, db database.Store, model database.ChatModelConfig, contextLimit int64, threshold int32) database.ChatModelConfig { + t.Helper() + model.ContextLimit = contextLimit + model.CompressionThreshold = threshold + updated, err := db.UpdateChatModelConfig(context.Background(), database.UpdateChatModelConfigParams{ + ID: model.ID, + DisplayName: model.DisplayName, + Model: model.Model, + Provider: model.Provider, + Enabled: model.Enabled, + ContextLimit: model.ContextLimit, + CompressionThreshold: model.CompressionThreshold, + Options: model.Options, + AIProviderID: model.AIProviderID, + }) + require.NoError(t, err) + return updated +} + +func updateChatModelContextLimit(t *testing.T, db database.Store, model database.ChatModelConfig) 
database.ChatModelConfig { + t.Helper() + updated, err := db.UpdateChatModelConfig(context.Background(), database.UpdateChatModelConfigParams{ + ID: model.ID, + DisplayName: model.DisplayName, + Model: model.Model, + Provider: model.Provider, + Enabled: model.Enabled, + ContextLimit: model.ContextLimit, + CompressionThreshold: model.CompressionThreshold, + Options: model.Options, + AIProviderID: model.AIProviderID, + }) + require.NoError(t, err) + return updated +} + +func updateChatModelCallConfig(t *testing.T, db database.Store, model database.ChatModelConfig, callConfig codersdk.ChatModelCallConfig) database.ChatModelConfig { + t.Helper() + options, err := json.Marshal(callConfig) + require.NoError(t, err) + updated, err := db.UpdateChatModelConfig(context.Background(), database.UpdateChatModelConfigParams{ + ID: model.ID, + DisplayName: model.DisplayName, + Model: model.Model, + Provider: model.Provider, + Enabled: model.Enabled, + ContextLimit: model.ContextLimit, + CompressionThreshold: model.CompressionThreshold, + Options: options, + AIProviderID: model.AIProviderID, + }) + require.NoError(t, err) + return updated +} + +func insertAssistantTextMessage( + ctx context.Context, + t *testing.T, + db database.Store, + chatID uuid.UUID, + text string, + modelID uuid.UUID, +) { + t.Helper() + insertChatMessageParts(ctx, t, db, chatID, database.ChatMessageRoleAssistant, modelID, uuid.Nil, []codersdk.ChatMessagePart{ + codersdk.ChatMessageText(text), + }) +} + +func insertProviderToolPairMessageWithLocalTool( + ctx context.Context, + t *testing.T, + db database.Store, + chatID uuid.UUID, + modelID uuid.UUID, + toolCallID string, +) { + t.Helper() + metadata, err := json.Marshal(fantasy.ProviderMetadata{ + fantasyanthropic.Name: &fantasyanthropic.WebSearchResultMetadata{ + Results: []fantasyanthropic.WebSearchResultItem{{ + URL: "https://example.com", + Title: "Example", + EncryptedContent: "encrypted", + }}, + }, + }) + require.NoError(t, err) + parts := 
[]codersdk.ChatMessagePart{ + { + Type: codersdk.ChatMessagePartTypeToolCall, + ToolCallID: toolCallID, + ToolName: "web_search", + Args: json.RawMessage(`{"query":"coder"}`), + ProviderExecuted: true, + }, + { + Type: codersdk.ChatMessagePartTypeToolResult, + ToolCallID: toolCallID, + ToolName: "web_search", + Result: json.RawMessage(`"ok"`), + ProviderExecuted: true, + ProviderMetadata: metadata, + }, + } + parts = append(parts, codersdk.ChatMessagePart{ + Type: codersdk.ChatMessagePartTypeToolCall, + ToolCallID: "tc-1", + ToolName: "read_file", + Args: json.RawMessage(`{"path":"main.go"}`), + }) + insertChatMessageParts(ctx, t, db, chatID, database.ChatMessageRoleAssistant, modelID, uuid.Nil, parts) + insertChatMessageParts(ctx, t, db, chatID, database.ChatMessageRoleTool, modelID, uuid.Nil, []codersdk.ChatMessagePart{ + { + Type: codersdk.ChatMessagePartTypeToolResult, + ToolCallID: "tc-1", + ToolName: "read_file", + Result: json.RawMessage(`"file"`), + }, + }) +} + +func insertChatMessageParts( + ctx context.Context, + t *testing.T, + db database.Store, + chatID uuid.UUID, + role database.ChatMessageRole, + modelID uuid.UUID, + createdBy uuid.UUID, + parts []codersdk.ChatMessagePart, +) database.ChatMessage { + t.Helper() + content, err := chatprompt.MarshalParts(parts) + require.NoError(t, err) + params := chatd.BuildSingleChatMessageInsertParams( + chatID, + role, + content, + database.ChatMessageVisibilityBoth, + modelID, + chatprompt.CurrentContentVersion, + createdBy, + ) + messages, err := db.InsertChatMessages(ctx, params) + require.NoError(t, err) + require.Len(t, messages, 1) + return messages[0] +} + +func createPlanSubagentChatWithHistory( + ctx context.Context, + t *testing.T, + db database.Store, + orgID uuid.UUID, + userID uuid.UUID, + modelID uuid.UUID, +) database.Chat { + t.Helper() + rootChat := dbgen.Chat(t, db, database.Chat{ + OrganizationID: orgID, + OwnerID: userID, + LastModelConfigID: modelID, + Title: "plan subagent active tools 
root", + Status: database.ChatStatusWaiting, + PlanMode: database.NullChatPlanMode{ChatPlanMode: database.ChatPlanModePlan, Valid: true}, + MCPServerIDs: []uuid.UUID{}, + ClientType: database.ChatClientTypeApi, + }) + chat := dbgen.Chat(t, db, database.Chat{ + OrganizationID: orgID, + OwnerID: userID, + LastModelConfigID: modelID, + Title: "plan subagent active tools", + Status: database.ChatStatusWaiting, + PlanMode: database.NullChatPlanMode{ChatPlanMode: database.ChatPlanModePlan, Valid: true}, + ParentChatID: uuid.NullUUID{UUID: rootChat.ID, Valid: true}, + RootChatID: uuid.NullUUID{UUID: rootChat.ID, Valid: true}, + MCPServerIDs: []uuid.UUID{}, + ClientType: database.ChatClientTypeApi, + }) + insertSystemTextMessage(ctx, t, db, chat.ID, "You are not currently connected to a workspace.", modelID) + insertChatMessageParts(ctx, t, db, chat.ID, database.ChatMessageRoleUser, modelID, userID, []codersdk.ChatMessagePart{ + codersdk.ChatMessageText("hello"), + }) + return chat +} + +func anthropicRequestToolNames(req chattest.AnthropicRequest) []string { + names := make([]string, 0, len(req.Tools)) + for _, tool := range req.Tools { + names = append(names, tool.Name) + } + return names +} + +func anthropicRequestContainsPromptSentinel(t *testing.T, req chattest.AnthropicRequest) bool { + t.Helper() + body := anthropicRequestBody(t, req) + return strings.Contains(body, "__chatd_agent_prompt_sentinel_") +} + +func reasoningPartsFromMessage(t *testing.T, msg database.ChatMessage) []codersdk.ChatMessagePart { + t.Helper() + parts, err := chatprompt.ParseContent(msg) + require.NoError(t, err) + var reasoning []codersdk.ChatMessagePart + for _, part := range parts { + if part.Type == codersdk.ChatMessagePartTypeReasoning { + reasoning = append(reasoning, part) + } + } + return reasoning +} + +func validWebSearchProviderMetadataForTest() fantasy.ProviderMetadata { + return fantasy.ProviderMetadata{ + fantasyanthropic.Name: &fantasyanthropic.WebSearchResultMetadata{ + 
Results: []fantasyanthropic.WebSearchResultItem{ + { + URL: "https://example.com", + Title: "Example", + EncryptedContent: "encrypted", + }, + }, + }, + } +} + +func safeToolCallPart(part fantasy.MessagePart) (fantasy.ToolCallPart, bool) { + var zero fantasy.ToolCallPart + if part == nil { + return zero, false + } + if value, ok := part.(*fantasy.ToolCallPart); ok && value == nil { + return zero, false + } + type toolCallPart = fantasy.ToolCallPart + return fantasy.AsMessagePart[toolCallPart](part) +} + +func safeToolResultPart(part fantasy.MessagePart) (fantasy.ToolResultPart, bool) { + var zero fantasy.ToolResultPart + if part == nil { + return zero, false + } + if value, ok := part.(*fantasy.ToolResultPart); ok && value == nil { + return zero, false + } + type toolResultPart = fantasy.ToolResultPart + return fantasy.AsMessagePart[toolResultPart](part) +} + +func requireProviderExecutedToolCallPrompt( + t *testing.T, + prompt []fantasy.Message, + id string, +) fantasy.ToolCallPart { + t.Helper() + for _, message := range prompt { + for _, part := range message.Content { + toolCall, ok := safeToolCallPart(part) + if ok && toolCall.ProviderExecuted && toolCall.ToolCallID == id { + return toolCall + } + } + } + t.Fatalf("missing provider-executed prompt tool call %q", id) + return fantasy.ToolCallPart{} +} + +func requireProviderExecutedToolResultPrompt( + t *testing.T, + prompt []fantasy.Message, + id string, +) fantasy.ToolResultPart { + t.Helper() + for _, message := range prompt { + for _, part := range message.Content { + toolResult, ok := safeToolResultPart(part) + if ok && toolResult.ProviderExecuted && toolResult.ToolCallID == id { + return toolResult + } + } + } + t.Fatalf("missing provider-executed prompt tool result %q", id) + return fantasy.ToolResultPart{} +} + +func requireNoProviderExecutedToolCallPrompt(t *testing.T, prompt []fantasy.Message) { + t.Helper() + for i, message := range prompt { + for j, part := range message.Content { + toolCall, ok := 
fantasy.AsMessagePart[fantasy.ToolCallPart](part) + if ok && toolCall.ProviderExecuted { + t.Fatalf("prompt[%d].content[%d]: unexpected provider-executed call", i, j) + } + } + } +} + +func requireNoProviderExecutedToolResultPrompt(t *testing.T, prompt []fantasy.Message) { + t.Helper() + for i, message := range prompt { + for j, part := range message.Content { + toolResult, ok := safeToolResultPart(part) + if ok && toolResult.ProviderExecuted { + t.Fatalf("prompt[%d].content[%d]: unexpected provider-executed result", i, j) + } + } + } +} + +func requireTextPrompt(t *testing.T, prompt []fantasy.Message, text string) fantasy.TextPart { + t.Helper() + for _, message := range prompt { + for _, part := range message.Content { + textPart, ok := fantasy.AsMessagePart[fantasy.TextPart](part) + if ok && textPart.Text == text { + return textPart + } + } + } + t.Fatalf("missing prompt text %q", text) + return fantasy.TextPart{} +} + +func requireAnthropicProviderToolPromptSafe(t *testing.T, prompt []fantasy.Message) { + t.Helper() + require.Empty(t, chatsanitize.ValidateAnthropicProviderToolHistory(prompt)) +} + +func requireLogField(t *testing.T, entry slog.SinkEntry, name string) any { + t.Helper() + for _, field := range entry.Fields { + if field.Name == name { + return field.Value + } + } + t.Fatalf("missing log field %q", name) + return nil +} + func TestPassiveServerDoesNotProcess(t *testing.T) { t.Parallel() @@ -4972,32 +7869,6 @@ func TestPassiveServerDoesNotProcess(t *testing.T) { require.False(t, stored.RunnerID.Valid) } -// newStartedTestServer creates a server with Start() called. -// Uses a long acquire interval so processing is triggered by -// wake signals, not polling. 
-func newStartedTestServer( - t *testing.T, - db database.Store, - ps dbpubsub.Pubsub, - replicaID uuid.UUID, -) *chatd.Server { - t.Helper() - - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) - server := chatd.New(chatd.Config{ - Logger: logger, - Database: db, - ReplicaID: replicaID, - Pubsub: ps, - PendingChatAcquireInterval: testutil.WaitLong, - }) - server.Start() - t.Cleanup(func() { - require.NoError(t, server.Close()) - }) - return server -} - // newDebugEnabledTestServer creates a passive test server with // AlwaysEnableDebugLogs=true so that IsEnabled(ctx, chatID, ownerID) // always returns true regardless of runtime admin config. This lets @@ -5274,9 +8145,9 @@ func seedLastTurnSummary( t.Helper() affected, err := db.UpdateChatLastTurnSummary(ctx, database.UpdateChatLastTurnSummaryParams{ - ID: chat.ID, - ExpectedUpdatedAt: chat.UpdatedAt, - LastTurnSummary: sql.NullString{String: summary, Valid: true}, + ID: chat.ID, + ExpectedHistoryVersion: chat.HistoryVersion, + LastTurnSummary: sql.NullString{String: summary, Valid: true}, }) require.NoError(t, err) require.Equal(t, int64(1), affected) @@ -5425,8 +8296,17 @@ func seedWorkspaceWithAgent( JobID: pj.ID, }) dbAgent := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{ - ResourceID: res.ID, + ResourceID: res.ID, + Directory: "/home/coder/project", + OperatingSystem: "linux", }) + require.NoError(t, db.UpdateWorkspaceAgentStartupByID(context.Background(), database.UpdateWorkspaceAgentStartupByIDParams{ + ID: dbAgent.ID, + Version: "v1.0.0", + ExpandedDirectory: "/home/coder/project", + })) + dbAgent, err := db.GetWorkspaceAgentByID(context.Background(), dbAgent.ID) + require.NoError(t, err) return ws, dbAgent } @@ -5536,9 +8416,9 @@ func TestInterruptChatDoesNotSendWebPushNotification(t *testing.T) { } }, testutil.IntervalFast) - // Interrupt the chat. + // Interrupt the chat. The worker finalizes the interruption asynchronously. 
updated, _ := server.InterruptChat(ctx, chat) - require.Equal(t, database.ChatStatusWaiting, updated.Status) + require.Equal(t, database.ChatStatusInterrupting, updated.Status) // Wait for the chat to finish processing and return to waiting. testutil.Eventually(ctx, t, func(ctx context.Context) bool { @@ -5742,16 +8622,6 @@ func TestCloseDuringShutdownContextCanceledShouldRetryOnNewReplica(t *testing.T) require.NoError(t, serverA.Close()) - require.Eventually(t, func() bool { - fromDB, dbErr := db.GetChatByID(ctx, chat.ID) - if dbErr != nil { - return false - } - return fromDB.Status == database.ChatStatusPending && - !fromDB.WorkerID.Valid && - !fromDB.LastError.Valid - }, testutil.WaitMedium, testutil.IntervalFast) - loggerB := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) serverB := chatd.New(chatd.Config{ Logger: loggerB, @@ -6403,13 +9273,6 @@ func TestInterruptChatPersistsPartialResponse(t *testing.T) { }) require.NoError(t, err) - // Subscribe to the chat's event stream so we can observe - // message_part events. This proves the chatloop has actually - // processed the streamed chunks. - _, events, subCancel, ok := server.Subscribe(ctx, chat.ID, nil, 0) - require.True(t, ok) - defer subCancel() - // Wait for the mock to finish sending chunks. testutil.Eventually(ctx, t, func(ctx context.Context) bool { select { @@ -6420,27 +9283,9 @@ func TestInterruptChatPersistsPartialResponse(t *testing.T) { } }, testutil.IntervalFast) - // Drain the event channel until we see a message_part event, - // which means the chatloop has consumed and published the chunk. 
- gotMessagePart := false - testutil.Eventually(ctx, t, func(ctx context.Context) bool { - for { - select { - case ev := <-events: - if ev.Type == codersdk.ChatStreamEventTypeMessagePart { - gotMessagePart = true - return true - } - default: - return gotMessagePart - } - } - }, testutil.IntervalFast) - require.True(t, gotMessagePart, "should have received at least one message_part event") - - // Now interrupt the chat. The chatloop has processed content. + // Now interrupt the chat. The provider has sent partial content. updated, _ := server.InterruptChat(ctx, chat) - require.Equal(t, database.ChatStatusWaiting, updated.Status) + require.Equal(t, database.ChatStatusInterrupting, updated.Status) // Wait for the partial assistant message to be persisted. // After the interrupt, the chatloop runs persistInterruptedStep @@ -6534,15 +9379,8 @@ func TestProcessChat_UserProviderKey_Success(t *testing.T) { }) require.NoError(t, err) - _, events, cancel, ok := creator.Subscribe(ctx, chat.ID, nil, 0) - require.True(t, ok) - t.Cleanup(cancel) - _ = newActiveTestServer(t, db, ps) - terminalStatus := waitForTerminalChatStatusEvent(ctx, t, events) - require.Equal(t, codersdk.ChatStatusWaiting, terminalStatus) - chatResult := waitForTerminalChat(ctx, t, db, chat.ID) require.Equal(t, database.ChatStatusWaiting, chatResult.Status) require.False(t, chatResult.LastError.Valid) @@ -6619,8 +9457,7 @@ func TestProcessChat_AIGatewayRoutingUsesDelegatedAPIKey(t *testing.T) { cfg.AllowBYOKSet = true }) - terminalStatus := waitForTerminalChatStatusEvent(ctx, t, events) - require.Equal(t, codersdk.ChatStatusWaiting, terminalStatus) + _ = events chatResult := waitForTerminalChat(ctx, t, db, chat.ID) require.Equal(t, database.ChatStatusWaiting, chatResult.Status) @@ -6687,15 +9524,8 @@ func TestProcessChat_UserProviderKey_MissingKeyError(t *testing.T) { }) require.NoError(t, err) - _, events, cancel, ok := creator.Subscribe(ctx, chat.ID, nil, 0) - require.True(t, ok) - t.Cleanup(cancel) - _ = 
newActiveTestServer(t, db, ps) - terminalStatus := waitForTerminalChatStatusEvent(ctx, t, events) - require.Equal(t, codersdk.ChatStatusError, terminalStatus) - chatResult := waitForTerminalChat(ctx, t, db, chat.ID) require.Equal(t, database.ChatStatusError, chatResult.Status) persistedError := requireChatLastErrorPayload(t, chatResult.LastError) @@ -6719,10 +9549,19 @@ func TestProcessChatPanicRecovery(t *testing.T) { // the processChat goroutine. panicWrapper := &panicOnInTxDB{Store: db} + firstOpenAICallStarted := make(chan struct{}) + continueFirstOpenAICall := make(chan struct{}) + var openAICallCount atomic.Int32 openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse { if !req.Stream { return chattest.OpenAINonStreamingResponse("Panic recovery test") } + + if openAICallCount.Add(1) == 1 { + close(firstOpenAICallStarted) + <-continueFirstOpenAICall + } + return chattest.OpenAIStreamingResponse( chattest.OpenAITextChunks("hello")..., ) @@ -6746,13 +9585,14 @@ func TestProcessChatPanicRecovery(t *testing.T) { }) require.NoError(t, err) - // Enable the panic now that CreateChat's InTx has completed. - // The next InTx call is PersistStep inside the chatloop, - // running synchronously on the processChat goroutine. - panicWrapper.enablePanic() + testutil.TryReceive(ctx, t, firstOpenAICallStarted) + + // Enable the panic while the first provider call is blocked. The next InTx + // call is PersistStep inside the chatloop, running synchronously on the + // processChat goroutine after the provider returns. + panicWrapper.enablePanic() + close(continueFirstOpenAICall) - // Wait for the panic to be recovered and the chat to - // transition to error status. 
var chatResult database.Chat require.Eventually(t, func() bool { got, getErr := db.GetChatByID(ctx, chat.ID) @@ -6760,13 +9600,31 @@ func TestProcessChatPanicRecovery(t *testing.T) { return false } chatResult = got - return got.Status == database.ChatStatusError + return got.Status == database.ChatStatusWaiting }, testutil.WaitLong, testutil.IntervalFast) + require.Equal(t, int32(2), openAICallCount.Load()) - persistedError := requireChatLastErrorPayload(t, chatResult.LastError) - require.Contains(t, persistedError.Message, "chat processing panicked") - require.Contains(t, persistedError.Message, "intentional test panic") - require.Equal(t, codersdk.ChatErrorKindGeneric, persistedError.Kind) + messages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{ + ChatID: chat.ID, + AfterID: 0, + }) + require.NoError(t, err) + + var assistantText string + for _, message := range messages { + if message.Role != database.ChatMessageRoleAssistant { + continue + } + parts, parseErr := chatprompt.ParseContent(message) + require.NoError(t, parseErr) + for _, part := range parts { + if part.Type == codersdk.ChatMessagePartTypeText { + assistantText += part.Text + } + } + } + require.Equal(t, "hello", assistantText) + require.False(t, chatResult.LastError.Valid) } // panicOnInTxDB wraps a database.Store and panics on the first InTx @@ -8413,6 +11271,7 @@ func TestAdvisorGating_RootChat(t *testing.T) { // covers the glue from chatd wiring -> chatadvisor.Tool -> Runtime.Run -> // nested model call -> structured result back to the outer model. func TestAdvisorHappyPath_RootChat(t *testing.T) { + t.Skip("todo: re-enable this test after pr 4 from the chatd refactor is implemented. 
it depends on subscribe being implemented.") t.Parallel() db, ps := dbtestutil.NewDB(t) @@ -8861,9 +11720,9 @@ func TestAdvisorGating_ExploreSubagent(t *testing.T) { // runtime together with chain mode and asserts the snapshot captured for // the nested advisor call retains the full pre-chain prompt. Chain mode // otherwise strips assistant and tool turns from the prompt the outer -// loop sees, so a regression that moves setAdvisorPromptSnapshot behind -// filterPromptForChainMode, or drops the !chainModeActive guards in -// PrepareMessages, would leak the filtered view into the advisor's +// loop sees, so a regression that captures the advisor snapshot after +// filterPromptForChainMode, or removes the chain-mode guard around +// advisor snapshotting, would leak the filtered view into the advisor's // nested call. The advisor would then only see the trailing user // message, losing the context the outer model had been building on. func TestAdvisorChainMode_SnapshotKeepsFullHistory(t *testing.T) { @@ -9066,544 +11925,6 @@ func seedAdvisorConfig( require.NoError(t, err) } -// TestPromoteQueuedWhileRunningRespectsMessageOrder guards -// TestFinishActiveChatExternalWaitingInsertsSyntheticResults -// asserts the cleanup TX inserts synthetic tool-result rows when -// PromoteQueued's deferred path set Status=Waiting while the -// worker concluded with RequiresAction. Without it, the next -// chatloop run would feed the LLM an assistant turn with -// unresolved tool_call parts and the API would reject it. 
-func TestFinishActiveChatExternalWaitingInsertsSyntheticResults(t *testing.T) { - t.Parallel() - - db, ps := dbtestutil.NewDB(t) - ctx := testutil.Context(t, testutil.WaitLong) - - server := newActiveTestServer(t, db, ps) - user, org, model := seedChatDependencies(t, db) - - dynamicToolsJSON, err := json.Marshal([]mcpgo.Tool{{ - Name: "my_dynamic_tool", - Description: "A test dynamic tool.", - InputSchema: mcpgo.ToolInputSchema{ - Type: "object", - Properties: map[string]any{}, - }, - }}) - require.NoError(t, err) - - chat, err := db.InsertChat(ctx, database.InsertChatParams{ - OrganizationID: org.ID, - Status: database.ChatStatusWaiting, - ClientType: database.ChatClientTypeUi, - OwnerID: user.ID, - Title: "external-waiting-stops-dead-guard", - LastModelConfigID: model.ID, - DynamicTools: nullRawMessage(dynamicToolsJSON), - }) - require.NoError(t, err) - - // Seed a user message and an assistant message with an - // unresolved dynamic tool call. This mirrors what the worker - // would have persisted before the deferred promote arrived. 
- insertUserTextMessage(t, db, chat.ID, user.ID, model.ID, "user input") - - pendingCallID := "call_pending_dynamic" - assistantContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{ - { - Type: codersdk.ChatMessagePartTypeToolCall, - ToolCallID: pendingCallID, - ToolName: "my_dynamic_tool", - Args: json.RawMessage(`{}`), - }, - }) - require.NoError(t, err) - _, err = db.InsertChatMessages(ctx, database.InsertChatMessagesParams{ - ChatID: chat.ID, - CreatedBy: []uuid.UUID{uuid.Nil}, - ModelConfigID: []uuid.UUID{model.ID}, - Role: []database.ChatMessageRole{database.ChatMessageRoleAssistant}, - ContentVersion: []int16{chatprompt.CurrentContentVersion}, - Content: []string{string(assistantContent.RawMessage)}, - Visibility: []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth}, - InputTokens: []int64{0}, - OutputTokens: []int64{0}, - TotalTokens: []int64{0}, - ReasoningTokens: []int64{0}, - CacheCreationTokens: []int64{0}, - CacheReadTokens: []int64{0}, - ContextLimit: []int64{0}, - Compressed: []bool{false}, - TotalCostMicros: []int64{0}, - RuntimeMs: []int64{0}, - ProviderResponseID: []string{""}, - }) - require.NoError(t, err) - - // Queue a message and put the chat in the post-promote - // Waiting state (no worker, queue at front). - queuedContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{ - codersdk.ChatMessageText("queued-after-promote"), - }) - require.NoError(t, err) - _, err = db.InsertChatQueuedMessage(ctx, database.InsertChatQueuedMessageParams{ - ChatID: chat.ID, - Content: queuedContent.RawMessage, - ModelConfigID: uuid.NullUUID{UUID: model.ID, Valid: true}, - }) - require.NoError(t, err) - - // Refresh chat with current status (Waiting, no worker). - latestChat, err := db.GetChatByID(ctx, chat.ID) - require.NoError(t, err) - - // Drive the cleanup path with the local-RequiresAction outcome. 
- updated, promoted, syntheticToolResults, finishErr := chatd.FinishActiveChatForTest( - ctx, server, latestChat, database.ChatStatusRequiresAction, "", - ) - require.NoError(t, finishErr) - require.NotNil(t, promoted, "queued message must be auto-promoted into history") - require.Equal(t, database.ChatStatusPending, updated.Status, - "chat must end Pending so the run loop picks it up") - require.Len(t, syntheticToolResults, 1, - "cleanup TX must return the inserted synthetic tool-result row so the post-TX caller can publish it") - - messages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{ - ChatID: chat.ID, - AfterID: 0, - }) - require.NoError(t, err) - - var ( - assistantIdx = -1 - synthToolIdx = -1 - promotedUserIdx = -1 - ) - for i, msg := range messages { - switch msg.Role { - case database.ChatMessageRoleAssistant: - assistantIdx = i - case database.ChatMessageRoleTool: - parts, parseErr := chatprompt.ParseContent(msg) - require.NoError(t, parseErr) - for _, part := range parts { - if part.Type == codersdk.ChatMessagePartTypeToolResult && - part.ToolCallID == pendingCallID && part.IsError { - synthToolIdx = i - } - } - case database.ChatMessageRoleUser: - parts, parseErr := chatprompt.ParseContent(msg) - require.NoError(t, parseErr) - for _, part := range parts { - if part.Type == codersdk.ChatMessagePartTypeText && - part.Text == "queued-after-promote" { - promotedUserIdx = i - } - } - } - } - require.NotEqual(t, -1, assistantIdx, "assistant tool-call message present") - require.NotEqual(t, -1, synthToolIdx, - "synthetic tool result for the unresolved dynamic tool call must be inserted") - require.NotEqual(t, -1, promotedUserIdx, - "promoted queued message must be inserted as a user message") - require.Less(t, assistantIdx, synthToolIdx, - "synthetic tool result must follow the assistant message") - require.Less(t, synthToolIdx, promotedUserIdx, - "promoted user message must follow the synthetic tool result") -} - -// 
TestPromoteQueuedFallsThroughOnStaleHeartbeat asserts a stale -// TestRecoverStaleChatsRecoversWaitingWithQueue asserts a Waiting -// chat with a non-empty queue and stale updated_at gets recovered -// to Pending, closing the post-promote-stranding hole. -func TestRecoverStaleChatsRecoversWaitingWithQueue(t *testing.T) { - t.Parallel() - - db, ps, rawDB := dbtestutil.NewDBWithSQLDB(t) - ctx := testutil.Context(t, testutil.WaitLong) - - staleAfter := 100 * time.Millisecond - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) - server := chatd.New(chatd.Config{ - Logger: logger, - Database: db, - ReplicaID: uuid.New(), - Pubsub: ps, - PendingChatAcquireInterval: testutil.WaitLong, - InFlightChatStaleAfter: staleAfter, - }) - t.Cleanup(func() { require.NoError(t, server.Close()) }) - user, org, model := seedChatDependencies(t, db) - - chat, err := db.InsertChat(ctx, database.InsertChatParams{ - OrganizationID: org.ID, - Status: database.ChatStatusWaiting, - ClientType: database.ChatClientTypeUi, - OwnerID: user.ID, - Title: "stale-waiting-with-queue", - LastModelConfigID: model.ID, - }) - require.NoError(t, err) - - queuedContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{ - codersdk.ChatMessageText("queued-stranded"), - }) - require.NoError(t, err) - _, err = db.InsertChatQueuedMessage(ctx, database.InsertChatQueuedMessageParams{ - ChatID: chat.ID, - Content: queuedContent.RawMessage, - ModelConfigID: uuid.NullUUID{UUID: model.ID, Valid: true}, - }) - require.NoError(t, err) - // Backdate updated_at directly so the chat is past the stale - // threshold without sleeping. 
- _, err = rawDB.ExecContext(ctx, - "UPDATE chats SET updated_at = $1 WHERE id = $2", - time.Now().Add(-time.Hour), chat.ID) - require.NoError(t, err) - - chatd.RecoverStaleChatsForTest(ctx, server) - - got, err := db.GetChatByID(ctx, chat.ID) - require.NoError(t, err) - require.Equal(t, database.ChatStatusPending, got.Status, - "stale-recovery must promote the front-of-queue and set Pending") - - messages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{ - ChatID: chat.ID, - AfterID: 0, - }) - require.NoError(t, err) - var foundPromoted bool - for _, msg := range messages { - if msg.Role != database.ChatMessageRoleUser { - continue - } - parts, parseErr := chatprompt.ParseContent(msg) - require.NoError(t, parseErr) - for _, part := range parts { - if part.Type == codersdk.ChatMessagePartTypeText && - part.Text == "queued-stranded" { - foundPromoted = true - } - } - } - require.True(t, foundPromoted, - "the front-of-queue message must be promoted into history") - - remaining, err := db.GetChatQueuedMessages(ctx, chat.ID) - require.NoError(t, err) - require.Empty(t, remaining, - "the queue is drained after the recovery promotes its only entry") -} - -// TestRecoverStaleChatsWaitingWithUnresolvedToolCallInsertsSyntheticResults -// asserts stale recovery closes pending dynamic tool calls before -// promoting, so the recovery path does not stop the chat dead by -// feeding the LLM unresolved tool_call parts. 
-func TestRecoverStaleChatsWaitingWithUnresolvedToolCallInsertsSyntheticResults(t *testing.T) { - t.Parallel() - - db, ps, rawDB := dbtestutil.NewDBWithSQLDB(t) - ctx := testutil.Context(t, testutil.WaitLong) - - staleAfter := 100 * time.Millisecond - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) - server := chatd.New(chatd.Config{ - Logger: logger, - Database: db, - ReplicaID: uuid.New(), - Pubsub: ps, - PendingChatAcquireInterval: testutil.WaitLong, - InFlightChatStaleAfter: staleAfter, - }) - t.Cleanup(func() { require.NoError(t, server.Close()) }) - - user, org, model := seedChatDependencies(t, db) - - dynamicToolsJSON, err := json.Marshal([]mcpgo.Tool{{ - Name: "my_dynamic_tool", - Description: "A test dynamic tool.", - InputSchema: mcpgo.ToolInputSchema{ - Type: "object", - Properties: map[string]any{}, - }, - }}) - require.NoError(t, err) - - chat, err := db.InsertChat(ctx, database.InsertChatParams{ - OrganizationID: org.ID, - Status: database.ChatStatusWaiting, - ClientType: database.ChatClientTypeUi, - OwnerID: user.ID, - Title: "stale-waiting-with-unresolved-tool-call", - LastModelConfigID: model.ID, - DynamicTools: nullRawMessage(dynamicToolsJSON), - }) - require.NoError(t, err) - - insertUserTextMessage(t, db, chat.ID, user.ID, model.ID, "please call the tool") - - pendingCallID := "call_unresolved_dynamic" - assistantContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{ - { - Type: codersdk.ChatMessagePartTypeToolCall, - ToolCallID: pendingCallID, - ToolName: "my_dynamic_tool", - Args: json.RawMessage(`{}`), - }, - }) - require.NoError(t, err) - _, err = db.InsertChatMessages(ctx, database.InsertChatMessagesParams{ - ChatID: chat.ID, - CreatedBy: []uuid.UUID{uuid.Nil}, - ModelConfigID: []uuid.UUID{model.ID}, - Role: []database.ChatMessageRole{database.ChatMessageRoleAssistant}, - ContentVersion: []int16{chatprompt.CurrentContentVersion}, - Content: []string{string(assistantContent.RawMessage)}, - Visibility: 
[]database.ChatMessageVisibility{database.ChatMessageVisibilityBoth}, - InputTokens: []int64{0}, - OutputTokens: []int64{0}, - TotalTokens: []int64{0}, - ReasoningTokens: []int64{0}, - CacheCreationTokens: []int64{0}, - CacheReadTokens: []int64{0}, - ContextLimit: []int64{0}, - Compressed: []bool{false}, - TotalCostMicros: []int64{0}, - RuntimeMs: []int64{0}, - ProviderResponseID: []string{""}, - }) - require.NoError(t, err) - - queuedContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{ - codersdk.ChatMessageText("queued-after-crash"), - }) - require.NoError(t, err) - _, err = db.InsertChatQueuedMessage(ctx, database.InsertChatQueuedMessageParams{ - ChatID: chat.ID, - Content: queuedContent.RawMessage, - ModelConfigID: uuid.NullUUID{UUID: model.ID, Valid: true}, - }) - require.NoError(t, err) - - _, err = rawDB.ExecContext(ctx, - "UPDATE chats SET updated_at = $1 WHERE id = $2", - time.Now().Add(-time.Hour), chat.ID) - require.NoError(t, err) - - chatd.RecoverStaleChatsForTest(ctx, server) - - got, err := db.GetChatByID(ctx, chat.ID) - require.NoError(t, err) - require.Equal(t, database.ChatStatusPending, got.Status) - - messages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{ - ChatID: chat.ID, - AfterID: 0, - }) - require.NoError(t, err) - - var ( - assistantIdx = -1 - synthIdx = -1 - promotedUserIdx = -1 - ) - for i, msg := range messages { - switch msg.Role { - case database.ChatMessageRoleAssistant: - assistantIdx = i - case database.ChatMessageRoleTool: - parts, parseErr := chatprompt.ParseContent(msg) - require.NoError(t, parseErr) - for _, part := range parts { - if part.Type == codersdk.ChatMessagePartTypeToolResult && - part.ToolCallID == pendingCallID && part.IsError { - synthIdx = i - } - } - case database.ChatMessageRoleUser: - parts, parseErr := chatprompt.ParseContent(msg) - require.NoError(t, parseErr) - for _, part := range parts { - if part.Type == codersdk.ChatMessagePartTypeText && - part.Text == 
"queued-after-crash" { - promotedUserIdx = i - } - } - } - } - require.NotEqual(t, -1, assistantIdx, "assistant tool-call message present") - require.NotEqual(t, -1, synthIdx, - "stale recovery must insert synthetic tool result for the unresolved dynamic tool call") - require.NotEqual(t, -1, promotedUserIdx, - "queued message must be promoted into history") - require.Less(t, assistantIdx, synthIdx) - require.Less(t, synthIdx, promotedUserIdx) -} - -// TestInsertSyntheticToolResultsTxSkipsAlreadyHandledCalls asserts -// the helper skips tool calls already handled (e.g. when a dynamic -// tool name collides with a built-in the chatloop dispatched). -// Without dedup the LLM would see two results for the same call ID. -func TestInsertSyntheticToolResultsTxSkipsAlreadyHandledCalls(t *testing.T) { - t.Parallel() - - db, _ := dbtestutil.NewDB(t) - ctx := testutil.Context(t, testutil.WaitLong) - - user, org, model := seedChatDependencies(t, db) - - dynamicToolsJSON, err := json.Marshal([]mcpgo.Tool{ - { - Name: "duplicate_call_tool", - Description: "Tool whose call already has a result.", - InputSchema: mcpgo.ToolInputSchema{Type: "object", Properties: map[string]any{}}, - }, - { - Name: "still_pending_tool", - Description: "Tool whose call has no result yet.", - InputSchema: mcpgo.ToolInputSchema{Type: "object", Properties: map[string]any{}}, - }, - }) - require.NoError(t, err) - - chat, err := db.InsertChat(ctx, database.InsertChatParams{ - OrganizationID: org.ID, - Status: database.ChatStatusRequiresAction, - ClientType: database.ChatClientTypeUi, - OwnerID: user.ID, - Title: "synth-results-dedup", - LastModelConfigID: model.ID, - DynamicTools: nullRawMessage(dynamicToolsJSON), - }) - require.NoError(t, err) - - insertUserTextMessage(t, db, chat.ID, user.ID, model.ID, "please call both tools") - - handledCallID := "call_already_handled" - pendingCallID := "call_still_pending" - assistantContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{ - { - Type: 
codersdk.ChatMessagePartTypeToolCall, - ToolCallID: handledCallID, - ToolName: "duplicate_call_tool", - Args: json.RawMessage(`{}`), - }, - { - Type: codersdk.ChatMessagePartTypeToolCall, - ToolCallID: pendingCallID, - ToolName: "still_pending_tool", - Args: json.RawMessage(`{}`), - }, - }) - require.NoError(t, err) - _, err = db.InsertChatMessages(ctx, database.InsertChatMessagesParams{ - ChatID: chat.ID, - CreatedBy: []uuid.UUID{uuid.Nil}, - ModelConfigID: []uuid.UUID{model.ID}, - Role: []database.ChatMessageRole{database.ChatMessageRoleAssistant}, - ContentVersion: []int16{chatprompt.CurrentContentVersion}, - Content: []string{string(assistantContent.RawMessage)}, - Visibility: []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth}, - InputTokens: []int64{0}, - OutputTokens: []int64{0}, - TotalTokens: []int64{0}, - ReasoningTokens: []int64{0}, - CacheCreationTokens: []int64{0}, - CacheReadTokens: []int64{0}, - ContextLimit: []int64{0}, - Compressed: []bool{false}, - TotalCostMicros: []int64{0}, - RuntimeMs: []int64{0}, - ProviderResponseID: []string{""}, - }) - require.NoError(t, err) - - // Pre-insert a tool-result for the handled call ID. This - // simulates the chatloop having dispatched the colliding - // dynamic tool name as a built-in. 
- handledResultContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{ - { - Type: codersdk.ChatMessagePartTypeToolResult, - ToolCallID: handledCallID, - ToolName: "duplicate_call_tool", - Result: json.RawMessage(`"already done"`), - }, - }) - require.NoError(t, err) - _, err = db.InsertChatMessages(ctx, database.InsertChatMessagesParams{ - ChatID: chat.ID, - CreatedBy: []uuid.UUID{uuid.Nil}, - ModelConfigID: []uuid.UUID{model.ID}, - Role: []database.ChatMessageRole{database.ChatMessageRoleTool}, - ContentVersion: []int16{chatprompt.CurrentContentVersion}, - Content: []string{string(handledResultContent.RawMessage)}, - Visibility: []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth}, - InputTokens: []int64{0}, - OutputTokens: []int64{0}, - TotalTokens: []int64{0}, - ReasoningTokens: []int64{0}, - CacheCreationTokens: []int64{0}, - CacheReadTokens: []int64{0}, - ContextLimit: []int64{0}, - Compressed: []bool{false}, - TotalCostMicros: []int64{0}, - RuntimeMs: []int64{0}, - ProviderResponseID: []string{""}, - }) - require.NoError(t, err) - - chatRow, err := db.GetChatByID(ctx, chat.ID) - require.NoError(t, err) - - _, err = chatd.InsertSyntheticToolResultsTxForTest( - ctx, db, chatRow, "synth reason", - ) - require.NoError(t, err) - - messages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{ - ChatID: chat.ID, - AfterID: 0, - }) - require.NoError(t, err) - - var ( - handledCount int - pendingCount int - syntheticForPending bool - ) - for _, msg := range messages { - if msg.Role != database.ChatMessageRoleTool { - continue - } - parts, parseErr := chatprompt.ParseContent(msg) - require.NoError(t, parseErr) - for _, part := range parts { - if part.Type != codersdk.ChatMessagePartTypeToolResult { - continue - } - switch part.ToolCallID { - case handledCallID: - handledCount++ - case pendingCallID: - pendingCount++ - if part.IsError { - syntheticForPending = true - } - } - } - } - require.Equal(t, 1, 
handledCount, - "handled call must keep exactly one tool result") - require.Equal(t, 1, pendingCount, - "pending call must get exactly one synthetic tool result") - require.True(t, syntheticForPending, - "the new tool result for the pending call must be marked IsError") -} - // nullRawMessage wraps raw JSON in a NullRawMessage. An empty input // becomes the zero value (Valid=false). func nullRawMessage(raw []byte) pqtype.NullRawMessage { @@ -9613,167 +11934,486 @@ func nullRawMessage(raw []byte) pqtype.NullRawMessage { return pqtype.NullRawMessage{RawMessage: raw, Valid: true} } -// TestInsertSyntheticToolResultsTxReturnsNilWhenNoAssistantMessage -// asserts the helper short-circuits cleanly when no assistant -// message exists yet, so a deferred promote racing a worker that -// fails before any persist does not roll back the cleanup TX. -func TestInsertSyntheticToolResultsTxReturnsNilWhenNoAssistantMessage(t *testing.T) { +// Regression for the cold-start race: chatd must wait long enough +// for ListMCPTools to return after the agent's MCP reload settles. 
+func TestActiveServer_WorkspaceContextAndDynamicToolInjection(t *testing.T) { t.Parallel() - db, _ := dbtestutil.NewDB(t) - ctx := testutil.Context(t, testutil.WaitLong) + t.Run("persists workspace context before provider request", func(t *testing.T) { + t.Parallel() - user, org, model := seedChatDependencies(t, db) + db, ps := dbtestutil.NewDB(t) + ctx := testutil.Context(t, testutil.WaitLong) - dynamicToolsJSON, err := json.Marshal([]mcpgo.Tool{{ - Name: "my_dynamic_tool", - Description: "A test dynamic tool.", - InputSchema: mcpgo.ToolInputSchema{Type: "object", Properties: map[string]any{}}, - }}) - require.NoError(t, err) + var ( + requestsMu sync.Mutex + requests []recordedOpenAIRequest + ) + openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse { + if !req.Stream { + return chattest.OpenAINonStreamingResponse("title") + } - chat, err := db.InsertChat(ctx, database.InsertChatParams{ - OrganizationID: org.ID, - Status: database.ChatStatusWaiting, - ClientType: database.ChatClientTypeUi, - OwnerID: user.ID, - Title: "no-assistant-message", - LastModelConfigID: model.ID, - DynamicTools: nullRawMessage(dynamicToolsJSON), + requestsMu.Lock() + requests = append(requests, recordOpenAIRequest(req)) + requestsMu.Unlock() + + return chattest.OpenAIStreamingResponse( + chattest.OpenAITextChunks("done")..., + ) + }) + + user, org, model := seedChatDependenciesWithProvider(t, db, "openai-compat", openAIURL) + ws, dbAgent := seedWorkspaceWithAgent(t, db, user.ID) + + const contextText = "# Project instructions\nAlways write tests." 
+ server := newActiveTestServer(t, db, ps, func(cfg *chatd.Config) { + cfg.AgentConn = func(_ context.Context, agentID uuid.UUID) (workspacesdk.AgentConn, func(), error) { + require.Equal(t, dbAgent.ID, agentID) + + ctrl := gomock.NewController(t) + mockConn := agentconnmock.NewMockAgentConn(ctrl) + setupWorkspaceContextAgentConn(t, mockConn, dbAgent, contextText, nil) + return mockConn, func() {}, nil + } + }) + + chat, err := server.CreateChat(ctx, chatd.CreateOptions{ + OrganizationID: org.ID, + OwnerID: user.ID, + Title: "workspace-context-before-provider", + ModelConfigID: model.ID, + WorkspaceID: uuid.NullUUID{UUID: ws.ID, Valid: true}, + InitialUserContent: []codersdk.ChatMessagePart{ + codersdk.ChatMessageText("What are the workspace rules?"), + }, + }) + require.NoError(t, err) + + chatResult := waitForTerminalChat(ctx, t, db, chat.ID) + if chatResult.Status == database.ChatStatusError { + require.FailNowf(t, "chat failed", "last_error=%q", + chatLastErrorMessage(chatResult.LastError)) + } + require.Equal(t, database.ChatStatusWaiting, chatResult.Status) + + parts := persistedChatParts(ctx, t, db, chat.ID) + require.Len(t, contextFilePartsForAgent(parts, dbAgent.ID), 1) + contextPart := contextFilePartsForAgent(parts, dbAgent.ID)[0] + require.Equal(t, "/home/coder/project/AGENTS.md", contextPart.ContextFilePath) + require.Equal(t, contextText, contextPart.ContextFileContent) + require.Equal(t, "linux", contextPart.ContextFileOS) + require.Equal(t, "/home/coder/project", contextPart.ContextFileDirectory) + + requestsMu.Lock() + recorded := append([]recordedOpenAIRequest(nil), requests...) + requestsMu.Unlock() + require.Len(t, recorded, 1, "expected exactly one streamed model call") + require.True(t, requestHasSystemSubstring(recorded[0], "")) + require.True(t, requestHasSystemSubstring(recorded[0], contextText)) + require.True(t, requestHasSystemSubstring(recorded[0], "AGENTS.md")) }) - require.NoError(t, err) - // No assistant message persisted. 
The helper must return nil so - // the caller's transaction can still advance. - _, err = chatd.InsertSyntheticToolResultsTxForTest( - ctx, db, chat, "no assistant", - ) - require.NoError(t, err) + t.Run("persists workspace context once for the same agent", func(t *testing.T) { + t.Parallel() + + db, ps := dbtestutil.NewDB(t) + ctx := testutil.Context(t, testutil.WaitLong) + + var ( + requestsMu sync.Mutex + requests []recordedOpenAIRequest + ) + openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse { + if !req.Stream { + return chattest.OpenAINonStreamingResponse("title") + } + + requestsMu.Lock() + requests = append(requests, recordOpenAIRequest(req)) + requestsMu.Unlock() + + return chattest.OpenAIStreamingResponse( + chattest.OpenAITextChunks("done")..., + ) + }) + + user, org, model := seedChatDependenciesWithProvider(t, db, "openai-compat", openAIURL) + ws, dbAgent := seedWorkspaceWithAgent(t, db, user.ID) + + const contextText = "# Project instructions\nKeep it simple." 
+ var contextConfigCalls atomic.Int32 + server := newActiveTestServer(t, db, ps, func(cfg *chatd.Config) { + cfg.AgentConn = func(_ context.Context, agentID uuid.UUID) (workspacesdk.AgentConn, func(), error) { + require.Equal(t, dbAgent.ID, agentID) + + ctrl := gomock.NewController(t) + mockConn := agentconnmock.NewMockAgentConn(ctrl) + setupWorkspaceContextAgentConn(t, mockConn, dbAgent, contextText, &contextConfigCalls) + return mockConn, func() {}, nil + } + }) + + chat, err := server.CreateChat(ctx, chatd.CreateOptions{ + OrganizationID: org.ID, + OwnerID: user.ID, + Title: "workspace-context-once", + ModelConfigID: model.ID, + WorkspaceID: uuid.NullUUID{UUID: ws.ID, Valid: true}, + InitialUserContent: []codersdk.ChatMessagePart{ + codersdk.ChatMessageText("First turn."), + }, + }) + require.NoError(t, err) + firstResult := waitForTerminalChat(ctx, t, db, chat.ID) + if firstResult.Status == database.ChatStatusError { + require.FailNowf(t, "chat failed", "last_error=%q", + chatLastErrorMessage(firstResult.LastError)) + } + + _, err = server.SendMessage(ctx, chatd.SendMessageOptions{ + ChatID: chat.ID, + CreatedBy: user.ID, + Content: []codersdk.ChatMessagePart{ + codersdk.ChatMessageText("Second turn."), + }, + }) + require.NoError(t, err) + + secondResult := waitForTerminalChat(ctx, t, db, chat.ID) + if secondResult.Status == database.ChatStatusError { + require.FailNowf(t, "chat failed", "last_error=%q", + chatLastErrorMessage(secondResult.LastError)) + } + require.Equal(t, database.ChatStatusWaiting, secondResult.Status) + + parts := persistedChatParts(ctx, t, db, chat.ID) + require.Len(t, contextFilePartsForAgent(parts, dbAgent.ID), 1) + require.Equal(t, int32(1), contextConfigCalls.Load()) + + requestsMu.Lock() + recorded := append([]recordedOpenAIRequest(nil), requests...) 
+ requestsMu.Unlock() + require.GreaterOrEqual(t, len(recorded), 2) + require.True(t, requestHasSystemSubstring(recorded[0], contextText)) + require.True(t, requestHasSystemSubstring(recorded[len(recorded)-1], contextText)) + }) + + t.Run("repersists workspace context after agent changes", func(t *testing.T) { + t.Parallel() + + db, ps := dbtestutil.NewDB(t) + ctx := testutil.Context(t, testutil.WaitLong) + + var ( + requestsMu sync.Mutex + requests []recordedOpenAIRequest + ) + openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse { + if !req.Stream { + return chattest.OpenAINonStreamingResponse("title") + } + + requestsMu.Lock() + requests = append(requests, recordOpenAIRequest(req)) + requestsMu.Unlock() + + return chattest.OpenAIStreamingResponse( + chattest.OpenAITextChunks("done")..., + ) + }) + + user, org, model := seedChatDependenciesWithProvider(t, db, "openai-compat", openAIURL) + ws, firstAgent := seedWorkspaceWithAgent(t, db, user.ID) + + oldContext := "# Old instructions\nUse the old agent." + newContext := "# New instructions\nUse the new agent." 
+ server := newActiveTestServer(t, db, ps, func(cfg *chatd.Config) { + cfg.AgentConn = func(_ context.Context, agentID uuid.UUID) (workspacesdk.AgentConn, func(), error) { + ctrl := gomock.NewController(t) + mockConn := agentconnmock.NewMockAgentConn(ctrl) + switch agentID { + case firstAgent.ID: + setupWorkspaceContextAgentConn(t, mockConn, firstAgent, oldContext, nil) + default: + setupWorkspaceContextAgentConn(t, mockConn, database.WorkspaceAgent{ + ID: agentID, + OperatingSystem: "linux", + Directory: "/home/coder/project-new", + ExpandedDirectory: "/home/coder/project-new", + }, newContext, nil) + } + return mockConn, func() {}, nil + } + }) + + chat, err := server.CreateChat(ctx, chatd.CreateOptions{ + OrganizationID: org.ID, + OwnerID: user.ID, + Title: "workspace-context-agent-change", + ModelConfigID: model.ID, + WorkspaceID: uuid.NullUUID{UUID: ws.ID, Valid: true}, + InitialUserContent: []codersdk.ChatMessagePart{ + codersdk.ChatMessageText("First turn."), + }, + }) + require.NoError(t, err) + firstResult := waitForTerminalChat(ctx, t, db, chat.ID) + if firstResult.Status == database.ChatStatusError { + require.FailNowf(t, "chat failed", "last_error=%q", + chatLastErrorMessage(firstResult.LastError)) + } + + secondTV := dbgen.TemplateVersion(t, db, database.TemplateVersion{ + OrganizationID: org.ID, + CreatedBy: user.ID, + }) + secondBuild, secondAgent := seedNewWorkspaceAgentBuild(t, db, user.ID, org.ID, ws.ID, secondTV.ID) + _, err = db.UpdateChatBuildAgentBinding(ctx, database.UpdateChatBuildAgentBindingParams{ + ID: chat.ID, + BuildID: uuid.NullUUID{UUID: secondBuild.ID, Valid: true}, + AgentID: uuid.NullUUID{UUID: secondAgent.ID, Valid: true}, + }) + require.NoError(t, err) + + _, err = server.SendMessage(ctx, chatd.SendMessageOptions{ + ChatID: chat.ID, + CreatedBy: user.ID, + Content: []codersdk.ChatMessagePart{ + codersdk.ChatMessageText("Second turn."), + }, + }) + require.NoError(t, err) + + secondResult := waitForTerminalChat(ctx, t, db, 
chat.ID) + if secondResult.Status == database.ChatStatusError { + require.FailNowf(t, "chat failed", "last_error=%q", + chatLastErrorMessage(secondResult.LastError)) + } + require.Equal(t, database.ChatStatusWaiting, secondResult.Status) + + parts := persistedChatParts(ctx, t, db, chat.ID) + require.Len(t, contextFilePartsForAgent(parts, firstAgent.ID), 1) + require.Len(t, contextFilePartsForAgent(parts, secondAgent.ID), 1) + + requestsMu.Lock() + recorded := append([]recordedOpenAIRequest(nil), requests...) + requestsMu.Unlock() + require.GreaterOrEqual(t, len(recorded), 2) + latest := recorded[len(recorded)-1] + require.True(t, requestHasSystemSubstring(latest, newContext)) + require.False(t, requestHasSystemSubstring(latest, oldContext)) + }) } -// TestRecoverStaleChatsWaitingPropagatesSynthError asserts stale -// recovery rolls back when synth-result insertion fails, leaving -// the chat Waiting for the next tick instead of promoting on top -// of incomplete history. -func TestRecoverStaleChatsWaitingPropagatesSynthError(t *testing.T) { - t.Parallel() +func setupWorkspaceContextAgentConn( + t *testing.T, + mockConn *agentconnmock.MockAgentConn, + agent database.WorkspaceAgent, + contextText string, + contextConfigCalls *atomic.Int32, +) { + t.Helper() + directory := agent.ExpandedDirectory + if directory == "" { + directory = agent.Directory + } + if directory == "" { + directory = "/home/coder/project" + } + operatingSystem := agent.OperatingSystem + if operatingSystem == "" { + operatingSystem = "linux" + } + mockConn.EXPECT().SetExtraHeaders(gomock.Any()).AnyTimes() + mockConn.EXPECT().ContextConfig(gomock.Any()).DoAndReturn( + func(context.Context) (workspacesdk.ContextConfigResponse, error) { + if contextConfigCalls != nil { + contextConfigCalls.Add(1) + } + return workspacesdk.ContextConfigResponse{ + Parts: []codersdk.ChatMessagePart{{ + Type: codersdk.ChatMessagePartTypeContextFile, + ContextFilePath: directory + "/AGENTS.md", + ContextFileContent: 
contextText, + ContextFileOS: operatingSystem, + ContextFileDirectory: directory, + }}, + }, nil + }, + ).AnyTimes() + mockConn.EXPECT().ListMCPTools(gomock.Any()). + Return(workspacesdk.ListMCPToolsResponse{}, nil).AnyTimes() + mockConn.EXPECT().LS(gomock.Any(), gomock.Any(), gomock.Any()). + Return(workspacesdk.LSResponse{AbsolutePathString: "/home/coder"}, nil).AnyTimes() + mockConn.EXPECT().ReadFile(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return(io.NopCloser(strings.NewReader("")), "", nil).AnyTimes() +} - db, ps, rawDB := dbtestutil.NewDBWithSQLDB(t) - ctx := testutil.Context(t, testutil.WaitLong) +func persistedChatParts( + ctx context.Context, + t *testing.T, + db database.Store, + chatID uuid.UUID, +) []codersdk.ChatMessagePart { + t.Helper() + messages := persistedChatMessages(ctx, t, db, chatID) + var parts []codersdk.ChatMessagePart + for _, msg := range messages { + parts = append(parts, mustParseChatParts(t, msg)...) + } + return parts +} - staleAfter := 100 * time.Millisecond - logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) - server := chatd.New(chatd.Config{ - Logger: logger, - Database: db, - ReplicaID: uuid.New(), - Pubsub: ps, - PendingChatAcquireInterval: testutil.WaitLong, - InFlightChatStaleAfter: staleAfter, - }) - t.Cleanup(func() { require.NoError(t, server.Close()) }) - - user, org, model := seedChatDependencies(t, db) - - dynamicToolsJSON, err := json.Marshal([]mcpgo.Tool{{ - Name: "my_dynamic_tool", - Description: "A test dynamic tool.", - InputSchema: mcpgo.ToolInputSchema{Type: "object", Properties: map[string]any{}}, - }}) - require.NoError(t, err) - - chat, err := db.InsertChat(ctx, database.InsertChatParams{ - OrganizationID: org.ID, - Status: database.ChatStatusWaiting, - ClientType: database.ChatClientTypeUi, - OwnerID: user.ID, - Title: "stale-waiting-synth-error", - LastModelConfigID: model.ID, - DynamicTools: nullRawMessage(dynamicToolsJSON), - }) - require.NoError(t, err) - - 
insertUserTextMessage(t, db, chat.ID, user.ID, model.ID, "user input") - - // Inject a synth-results error via an unsupported - // ContentVersion: the row is valid JSON so the insert - // succeeds, but chatprompt.ParseContent rejects it inside the - // helper. Brittle if a future migration adds a content_version - // CHECK constraint; switch to a mock store at that point. - _, err = db.InsertChatMessages(ctx, database.InsertChatMessagesParams{ - ChatID: chat.ID, - CreatedBy: []uuid.UUID{uuid.Nil}, - ModelConfigID: []uuid.UUID{model.ID}, - Role: []database.ChatMessageRole{database.ChatMessageRoleAssistant}, - ContentVersion: []int16{99}, - Content: []string{`{}`}, - Visibility: []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth}, - InputTokens: []int64{0}, - OutputTokens: []int64{0}, - TotalTokens: []int64{0}, - ReasoningTokens: []int64{0}, - CacheCreationTokens: []int64{0}, - CacheReadTokens: []int64{0}, - ContextLimit: []int64{0}, - Compressed: []bool{false}, - TotalCostMicros: []int64{0}, - RuntimeMs: []int64{0}, - ProviderResponseID: []string{""}, - }) - require.NoError(t, err) - - queuedContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{ - codersdk.ChatMessageText("queued-not-promoted-on-synth-error"), - }) - require.NoError(t, err) - _, err = db.InsertChatQueuedMessage(ctx, database.InsertChatQueuedMessageParams{ - ChatID: chat.ID, - Content: queuedContent.RawMessage, - ModelConfigID: uuid.NullUUID{UUID: model.ID, Valid: true}, - }) - require.NoError(t, err) - - _, err = rawDB.ExecContext(ctx, - "UPDATE chats SET updated_at = $1 WHERE id = $2", - time.Now().Add(-time.Hour), chat.ID) - require.NoError(t, err) - - chatd.RecoverStaleChatsForTest(ctx, server) - - got, err := db.GetChatByID(ctx, chat.ID) - require.NoError(t, err) - require.Equal(t, database.ChatStatusWaiting, got.Status, - "recovery must leave the chat in Waiting when synth-results fails so the next tick retries") - - // The queued message must still be in the 
queue, not promoted. - remaining, err := db.GetChatQueuedMessages(ctx, chat.ID) - require.NoError(t, err) - require.Len(t, remaining, 1, - "queued message must not be promoted when synth-results fails") - - // No promoted user message should appear in history. +func persistedChatMessages( + ctx context.Context, + t *testing.T, + db database.Store, + chatID uuid.UUID, +) []database.ChatMessage { + t.Helper() messages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{ - ChatID: chat.ID, + ChatID: chatID, AfterID: 0, }) require.NoError(t, err) - for _, msg := range messages { - if msg.Role != database.ChatMessageRoleUser { - continue - } - parts, parseErr := chatprompt.ParseContent(msg) - if parseErr != nil { - continue - } - for _, part := range parts { - require.NotEqual(t, "queued-not-promoted-on-synth-error", part.Text, - "queued message must not be promoted when synth-results fails") - } - } + return messages +} + +func contextFilePartsForAgent( + parts []codersdk.ChatMessagePart, + agentID uuid.UUID, +) []codersdk.ChatMessagePart { + var matched []codersdk.ChatMessagePart + for _, part := range parts { + if part.Type != codersdk.ChatMessagePartTypeContextFile || + !part.ContextFileAgentID.Valid || + part.ContextFileAgentID.UUID != agentID || + part.ContextFileContent == "" { + continue + } + matched = append(matched, part) + } + return matched +} + +func requireChatToolPart( + t *testing.T, + messages []database.ChatMessage, + partType codersdk.ChatMessagePartType, + toolName string, +) codersdk.ChatMessagePart { + t.Helper() + for _, msg := range messages { + for _, part := range mustParseChatParts(t, msg) { + if part.Type == partType && part.ToolName == toolName { + return part + } + } + } + require.FailNowf(t, "missing chat tool part", "type=%q tool=%q", partType, toolName) + return codersdk.ChatMessagePart{} +} + +func openAIRequestContainsToolResult(req recordedOpenAIRequest, toolResultText string) bool { + for _, msg := range 
req.Messages { + if msg.Role == "tool" && strings.Contains(msg.Content, toolResultText) { + return true + } + } + return false +} + +func nextWorkspaceBuildNumber(t *testing.T, db database.Store, workspaceID uuid.UUID) int32 { + t.Helper() + builds, err := db.GetWorkspaceBuildsByWorkspaceID(context.Background(), database.GetWorkspaceBuildsByWorkspaceIDParams{ + WorkspaceID: workspaceID, + OffsetOpt: 0, + LimitOpt: 100, + }) + require.NoError(t, err) + var maxBuild int32 + for _, build := range builds { + if build.BuildNumber > maxBuild { + maxBuild = build.BuildNumber + } + } + return maxBuild + 1 +} + +func seedNewWorkspaceAgentBuild( + t *testing.T, + db database.Store, + userID uuid.UUID, + orgID uuid.UUID, + workspaceID uuid.UUID, + templateVersionID uuid.UUID, +) (database.WorkspaceBuild, database.WorkspaceAgent) { + t.Helper() + pj := dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{ + InitiatorID: userID, + OrganizationID: orgID, + StartedAt: sql.NullTime{Time: dbtime.Now().Add(-time.Minute), Valid: true}, + CompletedAt: sql.NullTime{Time: dbtime.Now(), Valid: true}, + }) + build := dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{ + WorkspaceID: workspaceID, + TemplateVersionID: templateVersionID, + JobID: pj.ID, + BuildNumber: nextWorkspaceBuildNumber(t, db, workspaceID), + InitiatorID: userID, + Transition: database.WorkspaceTransitionStart, + }) + res := dbgen.WorkspaceResource(t, db, database.WorkspaceResource{ + Transition: database.WorkspaceTransitionStart, + JobID: pj.ID, + }) + now := dbtime.Now() + agent := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{ + ResourceID: res.ID, + LifecycleState: database.WorkspaceAgentLifecycleStateReady, + StartedAt: sql.NullTime{Time: now, Valid: true}, + ReadyAt: sql.NullTime{Time: now, Valid: true}, + FirstConnectedAt: sql.NullTime{Time: now, Valid: true}, + LastConnectedAt: sql.NullTime{Time: now, Valid: true}, + Directory: "/home/coder/project-new", + OperatingSystem: "linux", + }) + 
require.NoError(t, db.UpdateWorkspaceAgentStartupByID(context.Background(), database.UpdateWorkspaceAgentStartupByIDParams{ + ID: agent.ID, + Version: "v1.0.0", + ExpandedDirectory: "/home/coder/project-new", + })) + loadedAgent, err := db.GetWorkspaceAgentByID(context.Background(), agent.ID) + require.NoError(t, err) + return build, loadedAgent +} + +func seedWorkspaceForCreateTool( + t *testing.T, + db database.Store, + user database.User, + org database.Organization, +) (database.Template, database.WorkspaceTable, database.WorkspaceBuild, database.WorkspaceAgent) { + t.Helper() + tv := dbgen.TemplateVersion(t, db, database.TemplateVersion{ + OrganizationID: org.ID, + CreatedBy: user.ID, + }) + tpl := dbgen.Template(t, db, database.Template{ + CreatedBy: user.ID, + OrganizationID: org.ID, + ActiveVersionID: tv.ID, + }) + ws := dbgen.Workspace(t, db, database.WorkspaceTable{ + TemplateID: tpl.ID, + OwnerID: user.ID, + OrganizationID: org.ID, + }) + build, agent := seedNewWorkspaceAgentBuild(t, db, user.ID, org.ID, ws.ID, tv.ID) + return tpl, ws, build, agent } -// Regression for the cold-start race: chatd must wait long enough -// for ListMCPTools to return after the agent's MCP reload settles. func TestRunChat_WorkspaceMCPDiscoveryWaitsForSlowAgent(t *testing.T) { t.Parallel() @@ -9871,13 +12511,148 @@ func TestRunChat_WorkspaceMCPDiscoveryWaitsForSlowAgent(t *testing.T) { "timeout exceeds the agent's MCP reload time") } -// TestRunChat_WorkspaceMCPDiscoveryAfterMidTurnCreateWorkspace guards the -// regression where chats that bound their workspace mid-turn (via -// create_workspace) never saw workspace MCP tools on the same turn. The -// chatloop tool list was frozen at the top of the turn, so the first -// post-create_workspace step had no workspace MCP tools and the model -// fell back to bash. See PrepareTools wiring in runChat. 
-func TestRunChat_WorkspaceMCPDiscoveryAfterMidTurnCreateWorkspace(t *testing.T) { +// TestActiveServer_WorkspaceMCPToolDiscoveredMidTurnExecutes guards that +// a workspace MCP tool discovered after mid-turn workspace binding is +// active and executable in later generation actions for the same turn. +func TestActiveServer_WorkspaceMCPToolDiscoveredMidTurnExecutes(t *testing.T) { + t.Parallel() + + db, ps := dbtestutil.NewDB(t) + ctx := testutil.Context(t, testutil.WaitLong) + + var ( + requestsMu sync.Mutex + requests []recordedOpenAIRequest + ) + + workspaceToolName := "workspace-exec-mcp__echo" + workspaceCreateToolArgsJSON := "" + openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse { + if !req.Stream { + return chattest.OpenAINonStreamingResponse("title") + } + + requestsMu.Lock() + requests = append(requests, recordOpenAIRequest(req)) + callIdx := len(requests) + requestsMu.Unlock() + + switch callIdx { + case 1: + return chattest.OpenAIStreamingResponse(chattest.OpenAIToolCallChunk("create_workspace", workspaceCreateToolArgsJSON)) + case 2: + return chattest.OpenAIStreamingResponse(chattest.OpenAIToolCallChunk(workspaceToolName, `{"input":"hello"}`)) + default: + return chattest.OpenAIStreamingResponse( + chattest.OpenAITextChunks("done")..., + ) + } + }) + + user, org, model := seedChatDependenciesWithProvider(t, db, "openai-compat", openAIURL) + + // Seed a workspace and agent for create_workspace to bind to. 
+ tpl, ws, build, dbAgent := seedWorkspaceForCreateTool(t, db, user, org) + workspaceCreateToolArgsJSON = fmt.Sprintf(`{"template_id":%q}`, tpl.ID.String()) + + workspaceToolsResp := workspacesdk.ListMCPToolsResponse{ + Tools: []workspacesdk.MCPToolInfo{{ + ServerName: "workspace-exec-mcp", + Name: workspaceToolName, + Description: "workspace echo tool", + Schema: map[string]any{ + "type": "object", + "properties": map[string]any{ + "input": map[string]any{"type": "string"}, + }, + }, + Required: []string{"input"}, + }}, + } + + var callMCPToolCount atomic.Int32 + ctrl := gomock.NewController(t) + mockConn := agentconnmock.NewMockAgentConn(ctrl) + mockConn.EXPECT().SetExtraHeaders(gomock.Any()).AnyTimes() + mockConn.EXPECT().ContextConfig(gomock.Any()). + Return(workspacesdk.ContextConfigResponse{}, xerrors.New("not supported")).AnyTimes() + mockConn.EXPECT().ListMCPTools(gomock.Any()). + Return(workspaceToolsResp, nil).AnyTimes() + mockConn.EXPECT().CallMCPTool(gomock.Any(), gomock.Cond(func(req workspacesdk.CallMCPToolRequest) bool { + return req.ToolName == workspaceToolName && req.Arguments["input"] == "hello" + })).DoAndReturn(func(_ context.Context, _ workspacesdk.CallMCPToolRequest) (workspacesdk.CallMCPToolResponse, error) { + callMCPToolCount.Add(1) + return workspacesdk.CallMCPToolResponse{ + Content: []workspacesdk.MCPToolContent{{ + Type: "text", + Text: "echo: hello", + }}, + }, nil + }).Times(1) + mockConn.EXPECT().LS(gomock.Any(), gomock.Any(), gomock.Any()). + Return(workspacesdk.LSResponse{}, nil).AnyTimes() + mockConn.EXPECT().ReadFile(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). + Return(io.NopCloser(strings.NewReader("")), "", nil).AnyTimes() + mockConn.EXPECT().ReadFileLines(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()). 
+ Return(workspacesdk.ReadFileLinesResponse{Success: true}, nil).AnyTimes() + mockConn.EXPECT().AwaitReachable(gomock.Any()).Return(true).AnyTimes() + + createFn := func(_ context.Context, _ uuid.UUID, req codersdk.CreateWorkspaceRequest) (codersdk.Workspace, error) { + return codersdk.Workspace{ + ID: ws.ID, + Name: req.Name, + OwnerName: user.Username, + OrganizationID: org.ID, + TemplateID: tpl.ID, + LatestBuild: codersdk.WorkspaceBuild{ + ID: build.ID, + Status: codersdk.WorkspaceStatusRunning, + }, + }, nil + } + + server := newActiveTestServer(t, db, ps, func(cfg *chatd.Config) { + cfg.AgentConn = func(_ context.Context, agentID uuid.UUID) (workspacesdk.AgentConn, func(), error) { + require.Equal(t, dbAgent.ID, agentID) + return mockConn, func() {}, nil + } + cfg.CreateWorkspace = createFn + }) + + chat, err := server.CreateChat(ctx, chatd.CreateOptions{ + OrganizationID: org.ID, + OwnerID: user.ID, + Title: "workspace-mcp-midturn-executes", + ModelConfigID: model.ID, + InitialUserContent: []codersdk.ChatMessagePart{ + codersdk.ChatMessageText("Create a workspace and call the workspace MCP tool."), + }, + }) + require.NoError(t, err) + + chatResult := waitForTerminalChat(ctx, t, db, chat.ID) + if chatResult.Status == database.ChatStatusError { + require.FailNowf(t, "chat failed", "last_error=%q", + chatLastErrorMessage(chatResult.LastError)) + } + require.Equal(t, database.ChatStatusWaiting, chatResult.Status) + require.Equal(t, int32(1), callMCPToolCount.Load()) + + messages := persistedChatMessages(ctx, t, db, chat.ID) + toolCall := requireChatToolPart(t, messages, codersdk.ChatMessagePartTypeToolCall, workspaceToolName) + require.NotEmpty(t, toolCall.ToolCallID) + toolResult := requireChatToolPart(t, messages, codersdk.ChatMessagePartTypeToolResult, workspaceToolName) + require.Contains(t, string(toolResult.Result), "echo: hello") + + requestsMu.Lock() + recorded := append([]recordedOpenAIRequest(nil), requests...) 
+ requestsMu.Unlock() + require.GreaterOrEqual(t, len(recorded), 3) + require.Contains(t, recorded[1].Tools, workspaceToolName) + require.True(t, openAIRequestContainsToolResult(recorded[len(recorded)-1], "echo: hello")) +} + +func TestActiveServer_WorkspaceMCPDiscoveryAfterMidTurnCreateWorkspace(t *testing.T) { t.Parallel() db, ps := dbtestutil.NewDB(t) @@ -9902,9 +12677,7 @@ func TestRunChat_WorkspaceMCPDiscoveryAfterMidTurnCreateWorkspace(t *testing.T) requestsMu.Unlock() if callIdx == 1 { - return chattest.OpenAIStreamingResponse( - chattest.OpenAIToolCallChunk("create_workspace", workspaceCreateToolArgsJSON), - ) + return chattest.OpenAIStreamingResponse(chattest.OpenAIToolCallChunk("create_workspace", workspaceCreateToolArgsJSON)) } return chattest.OpenAIStreamingResponse( chattest.OpenAITextChunks("done")..., @@ -9913,47 +12686,10 @@ func TestRunChat_WorkspaceMCPDiscoveryAfterMidTurnCreateWorkspace(t *testing.T) user, org, model := seedChatDependenciesWithProvider(t, db, "openai-compat", openAIURL) - // Seed a workspace+agent for create_workspace to bind to. - tv := dbgen.TemplateVersion(t, db, database.TemplateVersion{ - OrganizationID: org.ID, - CreatedBy: user.ID, - }) - tpl := dbgen.Template(t, db, database.Template{ - CreatedBy: user.ID, - OrganizationID: org.ID, - ActiveVersionID: tv.ID, - }) + // Seed a workspace and agent for create_workspace to bind to. 
+ tpl, ws, build, dbAgent := seedWorkspaceForCreateTool(t, db, user, org) workspaceCreateToolArgsJSON = fmt.Sprintf(`{"template_id":%q}`, tpl.ID.String()) - ws := dbgen.Workspace(t, db, database.WorkspaceTable{ - TemplateID: tpl.ID, - OwnerID: user.ID, - OrganizationID: org.ID, - }) - pj := dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{ - InitiatorID: user.ID, - OrganizationID: org.ID, - CompletedAt: sql.NullTime{Valid: true, Time: dbtime.Now()}, - }) - build := dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{ - TemplateVersionID: tv.ID, - WorkspaceID: ws.ID, - JobID: pj.ID, - }) - res := dbgen.WorkspaceResource(t, db, database.WorkspaceResource{ - Transition: database.WorkspaceTransitionStart, - JobID: pj.ID, - }) - now := dbtime.Now() - dbAgent := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{ - ResourceID: res.ID, - LifecycleState: database.WorkspaceAgentLifecycleStateReady, - StartedAt: sql.NullTime{Time: now, Valid: true}, - ReadyAt: sql.NullTime{Time: now, Valid: true}, - FirstConnectedAt: sql.NullTime{Time: now, Valid: true}, - LastConnectedAt: sql.NullTime{Time: now, Valid: true}, - }) - workspaceToolsResp := workspacesdk.ListMCPToolsResponse{ Tools: []workspacesdk.MCPToolInfo{{ ServerName: "workspace-midturn-mcp", @@ -10031,21 +12767,11 @@ func TestRunChat_WorkspaceMCPDiscoveryAfterMidTurnCreateWorkspace(t *testing.T) "this is the fix for mid-turn workspace MCP discovery") } -// TestRunChat_PrepareToolsRetriesAfterEmptyDiscovery guards the -// regression on the workspaceMCPDiscovered flag flip: the prior -// implementation set the flag to true before calling -// discoverWorkspaceMCPTools, so a single empty result permanently -// blocked retries within the turn. The fix sets the flag to true -// only after a non-empty discovery, so subsequent PrepareTools -// invocations keep retrying until tools appear. -// -// Scenario: create_workspace binds a workspace mid-turn. 
The first -// few ListMCPTools calls return empty (simulating the agent's MCP -// Connect still racing with agent startup); a later call returns -// the workspace MCP tool. The chat takes multiple steps before -// finishing, and we assert that one of the post-create_workspace -// streamed model calls advertises the workspace tool. -func TestRunChat_PrepareToolsRetriesAfterEmptyDiscovery(t *testing.T) { +// TestActiveServer_WorkspaceMCPDiscoveryRetriesAfterEmptyResult guards +// the regression where an empty workspace MCP discovery result +// permanently blocked retries within the turn. The active worker should +// retry discovery in later generation actions until tools appear. +func TestActiveServer_WorkspaceMCPDiscoveryRetriesAfterEmptyResult(t *testing.T) { t.Parallel() db, ps := dbtestutil.NewDB(t) @@ -10071,16 +12797,10 @@ func TestRunChat_PrepareToolsRetriesAfterEmptyDiscovery(t *testing.T) { // Step 1: trigger create_workspace. if callIdx == 1 { - return chattest.OpenAIStreamingResponse( - chattest.OpenAIToolCallChunk("create_workspace", workspaceCreateToolArgsJSON), - ) + return chattest.OpenAIStreamingResponse(chattest.OpenAIToolCallChunk("create_workspace", workspaceCreateToolArgsJSON)) } - // Step 2..N-1: emit empty text to keep the chatloop running so - // PrepareTools fires on each step. The chatloop ends a turn - // when the model returns a non-empty assistant message with no - // tool calls; an empty text chunk would terminate the turn, so - // we attach a dummy tool call to force another step. Use the - // LS tool because it exists for all workspaces and is cheap. + // Step 2..N-1 calls a cheap workspace tool so the active worker + // runs several generation actions before the final assistant text. 
if callIdx < 6 { return chattest.OpenAIStreamingResponse( chattest.OpenAIToolCallChunk("ls", `{"path":"/tmp"}`), @@ -10094,47 +12814,10 @@ func TestRunChat_PrepareToolsRetriesAfterEmptyDiscovery(t *testing.T) { user, org, model := seedChatDependenciesWithProvider(t, db, "openai-compat", openAIURL) - // Seed a workspace+agent for create_workspace to bind to. - tv := dbgen.TemplateVersion(t, db, database.TemplateVersion{ - OrganizationID: org.ID, - CreatedBy: user.ID, - }) - tpl := dbgen.Template(t, db, database.Template{ - CreatedBy: user.ID, - OrganizationID: org.ID, - ActiveVersionID: tv.ID, - }) + // Seed a workspace and agent for create_workspace to bind to. + tpl, ws, build, dbAgent := seedWorkspaceForCreateTool(t, db, user, org) workspaceCreateToolArgsJSON = fmt.Sprintf(`{"template_id":%q}`, tpl.ID.String()) - ws := dbgen.Workspace(t, db, database.WorkspaceTable{ - TemplateID: tpl.ID, - OwnerID: user.ID, - OrganizationID: org.ID, - }) - pj := dbgen.ProvisionerJob(t, db, nil, database.ProvisionerJob{ - InitiatorID: user.ID, - OrganizationID: org.ID, - CompletedAt: sql.NullTime{Valid: true, Time: dbtime.Now()}, - }) - build := dbgen.WorkspaceBuild(t, db, database.WorkspaceBuild{ - TemplateVersionID: tv.ID, - WorkspaceID: ws.ID, - JobID: pj.ID, - }) - res := dbgen.WorkspaceResource(t, db, database.WorkspaceResource{ - Transition: database.WorkspaceTransitionStart, - JobID: pj.ID, - }) - now := dbtime.Now() - dbAgent := dbgen.WorkspaceAgent(t, db, database.WorkspaceAgent{ - ResourceID: res.ID, - LifecycleState: database.WorkspaceAgentLifecycleStateReady, - StartedAt: sql.NullTime{Time: now, Valid: true}, - ReadyAt: sql.NullTime{Time: now, Valid: true}, - FirstConnectedAt: sql.NullTime{Time: now, Valid: true}, - LastConnectedAt: sql.NullTime{Time: now, Valid: true}, - }) - workspaceToolsResp := workspacesdk.ListMCPToolsResponse{ Tools: []workspacesdk.MCPToolInfo{{ ServerName: "workspace-empty-retry-mcp", @@ -10147,13 +12830,10 @@ func 
TestRunChat_PrepareToolsRetriesAfterEmptyDiscovery(t *testing.T) { }}, } - // First two ListMCPTools calls return empty (no error). One is the - // primer goroutine's only attempt before its retry timer fires; - // the other is PrepareTools on the first post-create_workspace - // step. The third and later calls return the workspace tool. The - // assertion below requires that a post-create_workspace step - // eventually advertises the tool, which can only happen if the - // PrepareTools callback retries discovery on subsequent steps. + // First two ListMCPTools calls return empty (no error). One may + // come from the cache primer and one from the first generation + // action after create_workspace. Later calls return the workspace + // tool, proving discovery retries after empty results. var listCalls atomic.Int32 ctrl := gomock.NewController(t) mockConn := agentconnmock.NewMockAgentConn(ctrl) @@ -10221,14 +12901,9 @@ func TestRunChat_PrepareToolsRetriesAfterEmptyDiscovery(t *testing.T) { require.GreaterOrEqual(t, len(recorded), 3, "expected at least three streamed model calls; chat must run past the empty discovery") - // The first call has no workspace yet; the second call is the - // first post-create_workspace step which sees an empty - // ListMCPTools result. By the third (or later) call PrepareTools - // must have retried discovery, so at least one post-step request - // must advertise the workspace tool. Without the - // workspaceMCPDiscovered flag-flip fix the flag would have been - // set true on the failed first attempt and no subsequent step - // would have re-attempted discovery. + // The first call has no workspace yet. By a later post-binding + // call, workspace MCP discovery must have retried after the empty + // results and advertised the workspace tool. 
sawWorkspaceTool := false for i := 2; i < len(recorded); i++ { if slices.Contains(recorded[i].Tools, workspaceToolName) { @@ -10237,7 +12912,7 @@ func TestRunChat_PrepareToolsRetriesAfterEmptyDiscovery(t *testing.T) { } } require.True(t, sawWorkspaceTool, - "PrepareTools must retry workspace MCP discovery on subsequent "+ - "steps; without the fix the first empty result would "+ - "permanently block retries within the turn") + "workspace MCP discovery must retry on subsequent steps; "+ + "without the fix the first empty result would permanently "+ + "block retries within the turn") } diff --git a/coderd/x/chatd/chatdebug/summary.go b/coderd/x/chatd/chatdebug/summary.go index 7b69a6b8c3..75c9da8e84 100644 --- a/coderd/x/chatd/chatdebug/summary.go +++ b/coderd/x/chatd/chatdebug/summary.go @@ -41,28 +41,6 @@ func SeedSummary(label string) map[string]any { return map[string]any{"first_message": label} } -// ExtractFirstUserText extracts the plain text content from a -// fantasy.Prompt for the first user message. Used to derive -// first_message labels at run creation time. -func ExtractFirstUserText(prompt fantasy.Prompt) string { - for _, msg := range prompt { - if msg.Role != fantasy.MessageRoleUser { - continue - } - - var sb strings.Builder - for _, part := range msg.Content { - tp, ok := fantasy.AsMessagePart[fantasy.TextPart](part) - if !ok { - continue - } - _, _ = sb.WriteString(tp.Text) - } - return sb.String() - } - return "" -} - // AggregateRunSummary reads all steps for the given run, computes token // totals, and merges them with the run's existing summary (preserving any // seeded first_message label). 
The baseSummary parameter should be the diff --git a/coderd/x/chatd/chatdebug/summary_test.go b/coderd/x/chatd/chatdebug/summary_test.go index 3c41877cd2..fc329ae0a9 100644 --- a/coderd/x/chatd/chatdebug/summary_test.go +++ b/coderd/x/chatd/chatdebug/summary_test.go @@ -6,7 +6,6 @@ import ( "time" "unicode/utf8" - "charm.land/fantasy" "github.com/google/uuid" "github.com/sqlc-dev/pqtype" "github.com/stretchr/testify/require" @@ -62,68 +61,6 @@ func TestSeedSummary(t *testing.T) { }) } -func TestExtractFirstUserText(t *testing.T) { - t.Parallel() - - t.Run("EmptyPrompt", func(t *testing.T) { - t.Parallel() - got := chatdebug.ExtractFirstUserText(fantasy.Prompt{}) - require.Equal(t, "", got) - }) - - t.Run("NoUserMessages", func(t *testing.T) { - t.Parallel() - prompt := fantasy.Prompt{ - { - Role: fantasy.MessageRoleSystem, - Content: []fantasy.MessagePart{fantasy.TextPart{Text: "system"}}, - }, - { - Role: fantasy.MessageRoleAssistant, - Content: []fantasy.MessagePart{fantasy.TextPart{Text: "assistant"}}, - }, - } - got := chatdebug.ExtractFirstUserText(prompt) - require.Equal(t, "", got) - }) - - t.Run("FirstUserMessageMixedParts", func(t *testing.T) { - t.Parallel() - prompt := fantasy.Prompt{ - { - Role: fantasy.MessageRoleUser, - Content: []fantasy.MessagePart{ - fantasy.TextPart{Text: "hello "}, - fantasy.FilePart{Filename: "test.png"}, - fantasy.TextPart{Text: "world"}, - }, - }, - } - got := chatdebug.ExtractFirstUserText(prompt) - require.Equal(t, "hello world", got) - }) - - t.Run("MultipleUserMessagesReturnsFirst", func(t *testing.T) { - t.Parallel() - prompt := fantasy.Prompt{ - { - Role: fantasy.MessageRoleSystem, - Content: []fantasy.MessagePart{fantasy.TextPart{Text: "system"}}, - }, - { - Role: fantasy.MessageRoleUser, - Content: []fantasy.MessagePart{fantasy.TextPart{Text: "first"}}, - }, - { - Role: fantasy.MessageRoleUser, - Content: []fantasy.MessagePart{fantasy.TextPart{Text: "second"}}, - }, - } - got := chatdebug.ExtractFirstUserText(prompt) - 
require.Equal(t, "first", got) - }) -} - func TestService_AggregateRunSummary(t *testing.T) { t.Parallel() diff --git a/coderd/x/chatd/chatloop/chatloop.go b/coderd/x/chatd/chatloop/chatloop.go index efe67083e2..ba892e1468 100644 --- a/coderd/x/chatd/chatloop/chatloop.go +++ b/coderd/x/chatd/chatloop/chatloop.go @@ -6,7 +6,6 @@ import ( "encoding/base64" "encoding/json" "errors" - "maps" "slices" "strconv" "strings" @@ -17,6 +16,7 @@ import ( "charm.land/fantasy" fantasyanthropic "charm.land/fantasy/providers/anthropic" "charm.land/fantasy/schema" + "github.com/google/uuid" "golang.org/x/xerrors" "cdr.dev/slog/v3" @@ -214,6 +214,78 @@ type RunOptions struct { BuiltinToolNames map[string]bool } +// GenerateAssistantOptions configures one assistant model call. +type GenerateAssistantOptions struct { + Model fantasy.LanguageModel + Messages []fantasy.Message + Tools []fantasy.AgentTool + ActiveTools []string + ProviderTools []ProviderTool + StreamSilenceTimeout time.Duration + Clock quartz.Clock + + ContextLimitFallback int64 + ModelConfig codersdk.ChatModelCallConfig + ProviderOptions fantasy.ProviderOptions + + PublishMessagePart func(codersdk.ChatMessageRole, codersdk.ChatMessagePart) + Logger slog.Logger + Metrics *Metrics +} + +// AssistantOutcome is the durable assistant-side result from one model call. +type AssistantOutcome struct { + Step PersistedStep + ToolCalls []fantasy.ToolCallContent + FinishReason fantasy.FinishReason + ModelStopped bool +} + +// ExecuteLocalToolsOptions configures one local tool execution batch. 
+type ExecuteLocalToolsOptions struct { + Tools []fantasy.AgentTool + ActiveTools []string + ProviderTools []ProviderTool + ToolCalls []fantasy.ToolCallContent + + ExclusiveToolNames map[string]bool + BuiltinToolNames map[string]bool + ModelProvider string + ModelName string + + PublishMessagePart func(codersdk.ChatMessageRole, codersdk.ChatMessagePart) + Logger slog.Logger + Metrics *Metrics +} + +// ToolExecutionOutcome is the durable tool-result content from one batch. +type ToolExecutionOutcome struct { + Step PersistedStep +} + +// GenerateCompactionOptions configures one context compaction call. +type GenerateCompactionOptions struct { + Model fantasy.LanguageModel + Messages []fantasy.Message + + ThresholdPercent int32 + ContextLimit int64 + ContextLimitFallback int64 + SummaryPrompt string + SystemSummaryPrefix string + Timeout time.Duration + StepUsage fantasy.Usage + StepMetadata fantasy.ProviderMetadata + + DebugSvc *chatdebug.Service + ChatID uuid.UUID + HistoryTipMessageID int64 + ToolCallID string + ToolName string + + PublishMessagePart func(codersdk.ChatMessageRole, codersdk.ChatMessagePart) +} + // ProviderTool pairs a provider-native tool definition with an // optional local executor. When Runner is nil the tool is fully // provider-executed (e.g. web search). When Runner is non-nil @@ -245,106 +317,6 @@ type stepResult struct { reasoningCompletedAt []time.Time } -// toResponseMessages converts step content into messages suitable -// for appending to the conversation. Mirrors fantasy's -// toResponseMessages logic. 
-func (r stepResult) toResponseMessages() []fantasy.Message { - var assistantParts []fantasy.MessagePart - var toolParts []fantasy.MessagePart - - for _, c := range r.content { - switch c.GetType() { - case fantasy.ContentTypeText: - text, ok := fantasy.AsContentType[fantasy.TextContent](c) - if !ok || strings.TrimSpace(text.Text) == "" { - continue - } - assistantParts = append(assistantParts, fantasy.TextPart{ - Text: text.Text, - ProviderOptions: fantasy.ProviderOptions(text.ProviderMetadata), - }) - case fantasy.ContentTypeReasoning: - reasoning, ok := fantasy.AsContentType[fantasy.ReasoningContent](c) - if !ok { - continue - } - opts := fantasy.ProviderOptions(reasoning.ProviderMetadata) - if strings.TrimSpace(reasoning.Text) == "" && !chatsanitize.HasAnthropicSignedReasoningOptions(opts) { - continue - } - assistantParts = append(assistantParts, fantasy.ReasoningPart{ - Text: reasoning.Text, - ProviderOptions: opts, - }) - case fantasy.ContentTypeToolCall: - toolCall, ok := fantasy.AsContentType[fantasy.ToolCallContent](c) - if !ok { - continue - } - assistantParts = append(assistantParts, fantasy.ToolCallPart{ - ToolCallID: toolCall.ToolCallID, - ToolName: toolCall.ToolName, - Input: toolCall.Input, - ProviderExecuted: toolCall.ProviderExecuted, - ProviderOptions: fantasy.ProviderOptions(toolCall.ProviderMetadata), - }) - case fantasy.ContentTypeFile: - file, ok := fantasy.AsContentType[fantasy.FileContent](c) - if !ok { - continue - } - assistantParts = append(assistantParts, fantasy.FilePart{ - Data: file.Data, - MediaType: file.MediaType, - ProviderOptions: fantasy.ProviderOptions(file.ProviderMetadata), - }) - case fantasy.ContentTypeSource: - // Sources are metadata about references; they don't - // need to be included in conversation messages. 
- continue - case fantasy.ContentTypeToolResult: - result, ok := fantasy.AsContentType[fantasy.ToolResultContent](c) - if !ok { - continue - } - part := fantasy.ToolResultPart{ - ToolCallID: result.ToolCallID, - Output: result.Result, - ProviderExecuted: result.ProviderExecuted, - ProviderOptions: fantasy.ProviderOptions(result.ProviderMetadata), - } - // Provider-executed tool results (e.g. web_search) - // must stay in the assistant message so the result - // block appears inline after the corresponding - // server_tool_use block. This matches the persistence - // layer in chatd.go which keeps them in - // assistantBlocks. - if result.ProviderExecuted { - assistantParts = append(assistantParts, part) - } else { - toolParts = append(toolParts, part) - } - default: - continue - } - } - - var messages []fantasy.Message - if len(assistantParts) > 0 { - messages = append(messages, fantasy.Message{ - Role: fantasy.MessageRoleAssistant, - Content: assistantParts, - }) - } - if len(toolParts) > 0 { - messages = append(messages, fantasy.Message{ - Role: fantasy.MessageRoleTool, - Content: toolParts, - }) - } - return messages -} - // reasoningState accumulates reasoning content and provider // metadata while the stream is in flight. type reasoningState struct { @@ -353,17 +325,11 @@ type reasoningState struct { startedAt time.Time } -// Run executes the chat step-stream loop and delegates -// persistence/publishing to callbacks. -func Run(ctx context.Context, opts RunOptions) error { +// GenerateAssistant performs one assistant model stream and returns the +// durable assistant-side content. It does not execute tools, retry, or persist. 
+func GenerateAssistant(ctx context.Context, opts GenerateAssistantOptions) (AssistantOutcome, error) { if opts.Model == nil { - return xerrors.New("chat model is required") - } - if opts.PersistStep == nil { - return xerrors.New("persist step callback is required") - } - if opts.MaxSteps <= 0 { - opts.MaxSteps = 1 + return AssistantOutcome{}, xerrors.New("chat model is required") } if opts.StreamSilenceTimeout <= 0 { opts.StreamSilenceTimeout = defaultStreamSilenceTimeout @@ -376,360 +342,201 @@ func Run(ctx context.Context, opts RunOptions) error { } publishMessagePart := func(role codersdk.ChatMessageRole, part codersdk.ChatMessagePart) { - if opts.PublishMessagePart == nil { - return + if opts.PublishMessagePart != nil { + opts.PublishMessagePart(role, part) } - opts.PublishMessagePart(role, part) } - tools := buildToolDefinitions(opts.Tools, opts.ActiveTools, opts.ProviderTools) + provider := opts.Model.Provider() + modelName := opts.Model.Model() + runOpts := RunOptions{ + Model: opts.Model, + Logger: opts.Logger, + } + _, prepared, err := prepareMessagesForRequest(ctx, runOpts, opts.Messages, provider, modelName, 0, 1) + if err != nil { + return AssistantOutcome{}, xerrors.Errorf("prepare prompt: %w", err) + } + opts.Metrics.MessageCount.WithLabelValues(provider, modelName).Observe(float64(len(prepared))) + opts.Metrics.PromptSizeBytes.WithLabelValues(provider, modelName).Observe(float64(EstimatePromptSize(prepared))) + opts.Metrics.StepsTotal.WithLabelValues(provider, modelName).Inc() - messages := opts.Messages - var lastUsage fantasy.Usage - var lastProviderMetadata fantasy.ProviderMetadata - needsFullHistoryReload := false - reloadFullHistory := func(stage string) error { - if opts.ReloadMessages == nil { - return nil + call := fantasy.Call{ + Prompt: prepared, + Tools: buildToolDefinitions(opts.Tools, opts.ActiveTools, opts.ProviderTools), + MaxOutputTokens: opts.ModelConfig.MaxOutputTokens, + Temperature: opts.ModelConfig.Temperature, + TopP: 
opts.ModelConfig.TopP, + TopK: opts.ModelConfig.TopK, + PresencePenalty: opts.ModelConfig.PresencePenalty, + FrequencyPenalty: opts.ModelConfig.FrequencyPenalty, + ProviderOptions: opts.ProviderOptions, + } + + stepStart := time.Now() + attempt, streamErr := guardedStream( + ctx, + provider, + modelName, + opts.Clock, + opts.StreamSilenceTimeout, + func(attemptCtx context.Context) (fantasy.StreamResponse, error) { + return opts.Model.Stream(attemptCtx, call) + }, + opts.Metrics, + ) + if streamErr != nil { + wrappedErr := wrapProviderStreamError(provider, streamErr) + classified := chaterror.Classify(wrappedErr).WithProvider(provider) + if classified.Retryable { + opts.Metrics.RecordStreamRetry(provider, modelName, classified) } - reloaded, err := opts.ReloadMessages(ctx) - if err != nil { - return xerrors.Errorf("reload messages %s: %w", stage, err) + return AssistantOutcome{}, wrappedErr + } + defer attempt.release() + + result, processErr := processStepStream(attempt.ctx, attempt.stream, publishMessagePart) + if err := attempt.finish(processErr); err != nil { + if errors.Is(err, ErrInterrupted) { + return AssistantOutcome{}, ErrInterrupted } - messages = reloaded + wrappedErr := wrapProviderStreamError(provider, err) + classified := chaterror.Classify(wrappedErr).WithProvider(provider) + if classified.Retryable { + opts.Metrics.RecordStreamRetry(provider, modelName, classified) + } + return AssistantOutcome{}, wrappedErr + } + + contextLimit := extractContextLimitWithFallback(result.providerMetadata, opts.ContextLimitFallback) + result.content = chatsanitize.SanitizeAnthropicProviderToolStepContent( + ctx, opts.Logger, provider, modelName, + "assistant_helper", 0, result.finishReason, result.content, + ) + step := PersistedStep{ + Content: result.content, + Usage: result.usage, + ContextLimit: contextLimit, + ProviderResponseID: chatopenai.ExtractResponseIDIfStored(opts.ProviderOptions, result.providerMetadata), + Runtime: time.Since(stepStart), + 
ToolCallCreatedAt: result.toolCallCreatedAt, + ToolResultCreatedAt: result.toolResultCreatedAt, + ReasoningStartedAt: result.reasoningStartedAt, + ReasoningCompletedAt: result.reasoningCompletedAt, + } + return AssistantOutcome{ + Step: step, + ToolCalls: append([]fantasy.ToolCallContent(nil), result.toolCalls...), + FinishReason: result.finishReason, + ModelStopped: len(result.content) == 0, + }, nil +} + +func wrapProviderStreamError(provider string, err error) error { + if err == nil { return nil } - - totalSteps := 0 - // When totalSteps reaches MaxSteps the inner loop exits immediately - // (its condition is false), stoppedByModel stays false, and the - // post-loop guard breaks the outer compaction loop. - for compactionAttempt := 0; ; compactionAttempt++ { - alreadyCompacted := false - // stoppedByModel is true when the inner step loop - // exited because the model produced no tool calls - // (shouldContinue was false). This distinguishes a - // natural stop from hitting MaxSteps. - stoppedByModel := false - // compactedOnFinalStep tracks whether compaction - // occurred on the very step where the model stopped. - // Only in that case should we re-enter, because the - // agent never had a chance to use the compacted context. 
- compactedOnFinalStep := false - - for step := 0; totalSteps < opts.MaxSteps; step++ { - totalSteps++ - provider := opts.Model.Provider() - modelName := opts.Model.Model() - opts.Metrics.StepsTotal.WithLabelValues(provider, modelName).Inc() - stepStart := time.Now() - if opts.PrepareTools != nil { - if updated := opts.PrepareTools(opts.Tools); updated != nil { - opts.ActiveTools = mergeNewToolNames( - opts.ActiveTools, opts.Tools, updated, - ) - opts.Tools = updated - tools = buildToolDefinitions( - opts.Tools, opts.ActiveTools, opts.ProviderTools, - ) - } - } - var prepared []fantasy.Message - var prepareErr error - messages, prepared, prepareErr = prepareMessagesForRequest( - ctx, opts, messages, provider, modelName, step, totalSteps, - ) - if prepareErr != nil { - return xerrors.Errorf("prepare prompt: %w", prepareErr) - } - opts.Metrics.MessageCount.WithLabelValues(provider, modelName).Observe(float64(len(prepared))) - opts.Metrics.PromptSizeBytes.WithLabelValues(provider, modelName).Observe(float64(EstimatePromptSize(prepared))) - - call := fantasy.Call{ - Prompt: prepared, - Tools: tools, - MaxOutputTokens: opts.ModelConfig.MaxOutputTokens, - Temperature: opts.ModelConfig.Temperature, - TopP: opts.ModelConfig.TopP, - TopK: opts.ModelConfig.TopK, - PresencePenalty: opts.ModelConfig.PresencePenalty, - FrequencyPenalty: opts.ModelConfig.FrequencyPenalty, - ProviderOptions: opts.ProviderOptions, - } - - var result stepResult - var retryPrepareErr error - stepCtx := chatdebug.ReuseStep(ctx) - err := chatretry.Retry(stepCtx, func(retryCtx context.Context) error { - if retryPrepareErr != nil { - return retryPrepareErr - } - attempt, streamErr := guardedStream( - retryCtx, - provider, - modelName, - opts.Clock, - opts.StreamSilenceTimeout, - func(attemptCtx context.Context) (fantasy.StreamResponse, error) { - return opts.Model.Stream(attemptCtx, call) - }, - opts.Metrics, - ) - if streamErr != nil { - return streamErr - } - defer attempt.release() - var processErr 
error - result, processErr = processStepStream( - attempt.ctx, - attempt.stream, - publishMessagePart, - ) - return attempt.finish(processErr) - }, func( - attempt int, - retryErr error, - classified chatretry.ClassifiedError, - delay time.Duration, - ) { - // Reset result from the failed attempt so the next - // attempt starts clean. - result = stepResult{} - // Record before OnRetry so a panicking callback can't - // drop the sample. The metric's provider label comes - // from the outer local; WithProvider only affects the - // classified payload handed to OnRetry. - classified = classified.WithProvider(provider) - opts.Metrics.RecordStreamRetry(provider, modelName, classified) - if classified.ChainBroken { - if chatopenai.HasPreviousResponseID(opts.ProviderOptions) { - opts.ProviderOptions = chatopenai.ClearPreviousResponseID(opts.ProviderOptions) - } - if chatopenai.HasPreviousResponseID(call.ProviderOptions) { - call.ProviderOptions = chatopenai.ClearPreviousResponseID(call.ProviderOptions) - } - if opts.DisableChainMode != nil { - opts.DisableChainMode() - } - if opts.ReloadMessages != nil { - reloaded, err := opts.ReloadMessages(ctx) - if err != nil { - opts.Logger.Warn(ctx, - "chain-broken recovery: reload messages failed", - slog.Error(err), - ) - } else { - // Reloaded history replaces the prompt prepared before - // the failed attempt, so run the same preparation - // pipeline used by normal provider requests. 
- var ( - reloadedCanonical []fantasy.Message - retryPrompt []fantasy.Message - prepareErr error - ) - call.Prompt = nil - reloadedCanonical, retryPrompt, prepareErr = prepareMessagesForRequest( - ctx, opts, reloaded, provider, modelName, step, totalSteps, - ) - if prepareErr != nil { - retryPrepareErr = prepareErr - } else { - messages = reloadedCanonical - call.Prompt = retryPrompt - } - } - } - } - if opts.OnRetry != nil { - opts.OnRetry(attempt, retryErr, classified, delay) - } - }) - if err != nil { - if errors.Is(err, ErrInterrupted) { - persistInterruptedStep(ctx, opts, &result) - return ErrInterrupted - } - if retryPrepareErr != nil && errors.Is(err, retryPrepareErr) { - return xerrors.Errorf("prepare prompt: %w", err) - } - return xerrors.Errorf("stream response: %w", err) - } - - // Execute tools before persisting so that tool results - // are included in the persisted step content. The - // persistence layer splits assistant and tool-result - // blocks into separate database messages by role. - var toolResults []fantasy.ToolResultContent - if result.shouldContinue { - var err error - toolResults, err = executeToolsForStep(ctx, opts, &result, provider, modelName, step, stepStart, publishMessagePart) - if err != nil { - return err - } - } - // Extract context limit from provider metadata. - contextLimit := extractContextLimitWithFallback( - result.providerMetadata, - opts.ContextLimitFallback, - ) - result.content = chatsanitize.SanitizeAnthropicProviderToolStepContent( - ctx, opts.Logger, provider, modelName, - "normal_persist", step, result.finishReason, result.content, - ) - if len(result.content) == 0 { - lastUsage = result.usage - lastProviderMetadata = result.providerMetadata - stoppedByModel = true - break - } - - // Persist the step. If persistence fails because - // the chat was interrupted between the previous - // check and here, fall back to the interrupt-safe - // path so partial content is not lost. 
- if err := opts.PersistStep(ctx, PersistedStep{ - Content: result.content, - Usage: result.usage, - ContextLimit: contextLimit, - ProviderResponseID: chatopenai.ExtractResponseIDIfStored(opts.ProviderOptions, result.providerMetadata), - Runtime: time.Since(stepStart), - ToolCallCreatedAt: result.toolCallCreatedAt, - ToolResultCreatedAt: result.toolResultCreatedAt, - ReasoningStartedAt: result.reasoningStartedAt, - ReasoningCompletedAt: result.reasoningCompletedAt, - }); err != nil { - if errors.Is(err, ErrInterrupted) { - persistInterruptedStep(ctx, opts, &result) - return ErrInterrupted - } - return xerrors.Errorf("persist step: %w", err) - } - lastUsage = result.usage - lastProviderMetadata = result.providerMetadata - - // Check if any executed tool triggers an early stop. - if shouldStopAfterTools(opts.StopAfterTools, toolResults) { - tryCompactOnExit(ctx, opts, result.usage, result.providerMetadata) - return ErrStopAfterTool - } - - // When chain mode is active (PreviousResponseID set), exit - // it after persisting the first chained step. Continuation - // steps include tool-result messages, which fantasy rejects - // when previous_response_id is set, so we must leave chain - // mode and reload the full history before the next call. - stepMessages := result.toResponseMessages() - if chatopenai.HasPreviousResponseID(opts.ProviderOptions) { - opts.ProviderOptions = chatopenai.ClearPreviousResponseID(opts.ProviderOptions) - if opts.DisableChainMode != nil { - opts.DisableChainMode() - } - switch { - case opts.ReloadMessages != nil: - if err := reloadFullHistory("after chain mode exit"); err != nil { - return err - } - needsFullHistoryReload = false - default: - messages = append(messages, stepMessages...) - needsFullHistoryReload = false - } - } else { - messages = append(messages, stepMessages...) 
- } - - if needsFullHistoryReload && !result.shouldContinue && - opts.ReloadMessages != nil { - if err := reloadFullHistory("before final compaction after chain mode exit"); err != nil { - return err - } - needsFullHistoryReload = false - } - - // Inline compaction. - if !needsFullHistoryReload && opts.Compaction != nil && opts.ReloadMessages != nil { - did, compactErr := tryCompact( - ctx, - opts.Model, - opts.Compaction, - opts.ContextLimitFallback, - result.usage, - result.providerMetadata, - messages, - ) - opts.Metrics.RecordCompaction(provider, modelName, did, compactErr) - if compactErr != nil && opts.Compaction.OnError != nil { - opts.Compaction.OnError(compactErr) - } - - if did { - alreadyCompacted = true - compactedOnFinalStep = true - if err := reloadFullHistory("after compaction"); err != nil { - return err - } - } - } - if !result.shouldContinue { - stoppedByModel = true - break - } - - // The agent is continuing with tool calls, so any - // prior compaction has already been consumed. - compactedOnFinalStep = false + classified := chaterror.Classify(err).WithProvider(provider) + if !classified.Retryable && classified.StatusCode == 0 && errors.Is(err, context.Canceled) { + wrapped := errors.Join(chaterror.ErrProviderTransportReset, err) + reclassified := chaterror.Classify(wrapped).WithProvider(provider) + if reclassified.Retryable { + classified = reclassified + err = wrapped } + } + return xerrors.Errorf("stream response: %w", chaterror.WithClassification(err, classified)) +} - if needsFullHistoryReload && stoppedByModel && opts.ReloadMessages != nil { - if err := reloadFullHistory("before post-run compaction after chain mode exit"); err != nil { - return err - } - needsFullHistoryReload = false +// ExecuteLocalTools runs local tool calls and returns durable tool results. It +// does not retry or persist. 
+func ExecuteLocalTools(ctx context.Context, opts ExecuteLocalToolsOptions) (ToolExecutionOutcome, error) { + if opts.Metrics == nil { + opts.Metrics = NopMetrics() + } + provider := opts.ModelProvider + if provider == "" { + provider = "unknown" + } + modelName := opts.ModelName + if modelName == "" { + modelName = "unknown" + } + publishMessagePart := func(role codersdk.ChatMessageRole, part codersdk.ChatMessagePart) { + if opts.PublishMessagePart != nil { + opts.PublishMessagePart(role, part) } - - // Post-run compaction safety net: if we never compacted - // during the loop, try once at the end. - if !needsFullHistoryReload && !alreadyCompacted && opts.Compaction != nil && opts.ReloadMessages != nil { - did, err := tryCompact( - ctx, - opts.Model, - opts.Compaction, - opts.ContextLimitFallback, - lastUsage, - lastProviderMetadata, - messages, - ) - opts.Metrics.RecordCompaction(opts.Model.Provider(), opts.Model.Model(), did, err) - if err != nil { - if opts.Compaction.OnError != nil { - opts.Compaction.OnError(err) - } - } - if did { - compactedOnFinalStep = true - } - } - // Re-enter the step loop when compaction fired on the - // model's final step. This lets the agent continue - // working with fresh summarized context instead of - // stopping. When the inner loop continued after inline - // compaction (tool-call steps kept going), the agent - // already used the compacted context, so no re-entry - // is needed. Limit retries to prevent infinite loops. 
- if compactedOnFinalStep && stoppedByModel && - opts.ReloadMessages != nil && - compactionAttempt < maxCompactionRetries { - reloaded, reloadErr := opts.ReloadMessages(ctx) - if reloadErr != nil { - return xerrors.Errorf("reload messages after compaction: %w", reloadErr) - } - messages = reloaded - continue - } - break + } + if ctx.Err() != nil { + return ToolExecutionOutcome{}, ctx.Err() } - return nil + localCalls := make([]fantasy.ToolCallContent, 0, len(opts.ToolCalls)) + for _, tc := range opts.ToolCalls { + if !tc.ProviderExecuted { + localCalls = append(localCalls, tc) + } + } + if len(localCalls) == 0 { + return ToolExecutionOutcome{}, nil + } + + var result stepResult + policyResults, exclusiveViolation := applyExclusiveToolPolicy( + localCalls, + opts.ExclusiveToolNames, + opts.Metrics, + provider, + modelName, + ) + if exclusiveViolation { + now := dbtime.Now() + for _, tr := range policyResults { + recordToolResultTimestamp(&result, tr.ToolCallID, now) + publishToolAttachments(ctx, opts.Logger, tr, now, publishMessagePart) + ssePart := chatprompt.PartFromContentWithLogger(ctx, opts.Logger, tr) + ssePart.CreatedAt = &now + publishMessagePart(codersdk.ChatMessageRoleTool, ssePart) + result.content = append(result.content, tr) + } + if ctx.Err() != nil { + return ToolExecutionOutcome{}, ctx.Err() + } + return ToolExecutionOutcome{Step: PersistedStep{ + Content: result.content, + ToolResultCreatedAt: result.toolResultCreatedAt, + }}, nil + } + + toolResults := executeTools( + ctx, + opts.Tools, + opts.ActiveTools, + opts.ProviderTools, + localCalls, + opts.Metrics, + opts.Logger, + provider, + modelName, + opts.BuiltinToolNames, + func(tr fantasy.ToolResultContent, completedAt time.Time) { + recordToolResultTimestamp(&result, tr.ToolCallID, completedAt) + publishToolAttachments(ctx, opts.Logger, tr, completedAt, publishMessagePart) + ssePart := chatprompt.PartFromContentWithLogger(ctx, opts.Logger, tr) + ssePart.CreatedAt = &completedAt + 
publishMessagePart(codersdk.ChatMessageRoleTool, ssePart) + }, + ) + if ctx.Err() != nil { + return ToolExecutionOutcome{}, ctx.Err() + } + for _, tr := range toolResults { + result.content = append(result.content, tr) + } + return ToolExecutionOutcome{Step: PersistedStep{ + Content: result.content, + ToolResultCreatedAt: result.toolResultCreatedAt, + }}, nil } // prepareMessagesForRequest applies the prompt preparation pipeline used @@ -966,14 +773,19 @@ func processStepStream( } case fantasy.StreamPartTypeReasoningDelta: + reasoningPart := codersdk.ChatMessageReasoning(part.Delta) if active, exists := activeReasoningContent[part.ID]; exists { active.text += part.Delta if len(part.ProviderMetadata) > 0 { active.options = part.ProviderMetadata } activeReasoningContent[part.ID] = active + if !active.startedAt.IsZero() { + startedAt := active.startedAt + reasoningPart.CreatedAt = &startedAt + } } - publishMessagePart(codersdk.ChatMessageRoleAssistant, codersdk.ChatMessageReasoning(part.Delta)) + publishMessagePart(codersdk.ChatMessageRoleAssistant, reasoningPart) case fantasy.StreamPartTypeReasoningEnd: if active, exists := activeReasoningContent[part.ID]; exists { @@ -1255,187 +1067,6 @@ func executeTools( return results } -// executeToolsForStep runs the tool-execution phase of a single -// chatloop step. It enforces the exclusive-tool policy, partitions -// built-in versus dynamic tool calls, dispatches built-in tools, and -// when dynamic tool calls are present persists the step and returns -// ErrDynamicToolCall so the caller can execute them externally. -// Returns the tool results to append to the step, or an error that the -// caller must propagate (ErrInterrupted, ErrDynamicToolCall, ctx.Err(), -// or a persistence failure). 
-func executeToolsForStep( - ctx context.Context, - opts RunOptions, - result *stepResult, - provider, modelName string, - step int, - stepStart time.Time, - publishMessagePart func(codersdk.ChatMessageRole, codersdk.ChatMessagePart), -) ([]fantasy.ToolResultContent, error) { - // Check for context cancellation before starting tool - // execution. If the chat was interrupted between stream - // completion and here, persist what we have and bail out. - if ctx.Err() != nil { - if errors.Is(context.Cause(ctx), ErrInterrupted) { - persistInterruptedStep(ctx, opts, result) - return nil, ErrInterrupted - } - return nil, ctx.Err() - } - - // Enforce exclusivity across ALL locally-executable tool - // calls (both built-in and dynamic) before partitioning. - // Checking only the built-in partition would let the model - // bypass the policy by mixing an exclusive tool with a - // dynamic tool: the exclusive tool would still run and the - // dynamic call would still be handed to the caller for - // external execution, breaking the planning-only contract. 
- localCandidates := make([]fantasy.ToolCallContent, 0, len(result.toolCalls)) - for _, tc := range result.toolCalls { - if !tc.ProviderExecuted { - localCandidates = append(localCandidates, tc) - } - } - policyResults, exclusiveViolation := applyExclusiveToolPolicy( - localCandidates, - opts.ExclusiveToolNames, - opts.Metrics, - provider, - modelName, - ) - if exclusiveViolation { - now := dbtime.Now() - for _, tr := range policyResults { - recordToolResultTimestamp(result, tr.ToolCallID, now) - publishToolAttachments(ctx, opts.Logger, tr, now, publishMessagePart) - ssePart := chatprompt.PartFromContentWithLogger(ctx, opts.Logger, tr) - ssePart.CreatedAt = &now - publishMessagePart(codersdk.ChatMessageRoleTool, ssePart) - } - for _, tr := range policyResults { - result.content = append(result.content, tr) - } - // Mirror the post-execution interruption check used by the - // non-policy path: if the chat was interrupted while we - // synthesized policy errors, route through - // persistInterruptedStep so the synthesized results are not - // dropped when the regular PersistStep path fails on a - // canceled context. - if ctx.Err() != nil { - if errors.Is(context.Cause(ctx), ErrInterrupted) { - persistInterruptedStep(ctx, opts, result) - return nil, ErrInterrupted - } - return nil, ctx.Err() - } - // Fall through to the normal persistence path so the loop - // continues with error results that the model can observe - // and retry. Skip partitioning, execution, and - // pending-dynamic persistence. - return policyResults, nil - } - - // Partition tool calls into built-in and dynamic. - var builtinCalls, dynamicCalls []fantasy.ToolCallContent - if len(opts.DynamicToolNames) > 0 { - for _, tc := range result.toolCalls { - if opts.DynamicToolNames[tc.ToolName] { - dynamicCalls = append(dynamicCalls, tc) - } else { - builtinCalls = append(builtinCalls, tc) - } - } - } else { - builtinCalls = result.toolCalls - } - - // Execute only built-in tools. 
- toolResults := executeTools(ctx, opts.Tools, opts.ActiveTools, opts.ProviderTools, builtinCalls, opts.Metrics, opts.Logger, provider, modelName, opts.BuiltinToolNames, func(tr fantasy.ToolResultContent, completedAt time.Time) { - recordToolResultTimestamp(result, tr.ToolCallID, completedAt) - publishToolAttachments(ctx, opts.Logger, tr, completedAt, publishMessagePart) - ssePart := chatprompt.PartFromContentWithLogger(ctx, opts.Logger, tr) - ssePart.CreatedAt = &completedAt - publishMessagePart(codersdk.ChatMessageRoleTool, ssePart) - }) - for _, tr := range toolResults { - result.content = append(result.content, tr) - } - - // If dynamic tools were called, persist what we have - // (assistant + built-in results) and exit so the caller can - // execute them externally. - if len(dynamicCalls) > 0 { - // Strip Anthropic provider-executed tool calls without - // matching results before persisting so the action-required - // step does not carry a malformed tool-call history into - // downstream provider requests. - result.content = chatsanitize.SanitizeAnthropicProviderToolStepContent( - ctx, opts.Logger, provider, modelName, - "dynamic_tool_persist", step, result.finishReason, result.content, - ) - if err := persistPendingDynamicStep(ctx, opts, result, stepStart, dynamicCalls); err != nil { - return nil, err - } - tryCompactOnExit(ctx, opts, result.usage, result.providerMetadata) - return nil, ErrDynamicToolCall - } - - // Check for interruption after tool execution. Tools that - // were canceled mid-flight produce error results via ctx - // cancellation. Persist the full step (assistant blocks + - // tool results) through the interrupt-safe path so nothing - // is lost. 
- if ctx.Err() != nil { - if errors.Is(context.Cause(ctx), ErrInterrupted) { - persistInterruptedStep(ctx, opts, result) - return nil, ErrInterrupted - } - return nil, ctx.Err() - } - - return toolResults, nil -} - -// persistPendingDynamicStep persists a step that has pending dynamic -// tool calls awaiting external execution. Returns ErrInterrupted when -// persistence fails because the chat was interrupted. -func persistPendingDynamicStep( - ctx context.Context, - opts RunOptions, - result *stepResult, - stepStart time.Time, - dynamicCalls []fantasy.ToolCallContent, -) error { - pending := make([]PendingToolCall, 0, len(dynamicCalls)) - for _, dc := range dynamicCalls { - pending = append(pending, PendingToolCall{ - ToolCallID: dc.ToolCallID, - ToolName: dc.ToolName, - Args: dc.Input, - }) - } - - contextLimit := extractContextLimitWithFallback(result.providerMetadata, opts.ContextLimitFallback) - - if err := opts.PersistStep(ctx, PersistedStep{ - Content: result.content, - Usage: result.usage, - ContextLimit: contextLimit, - ProviderResponseID: chatopenai.ExtractResponseIDIfStored(opts.ProviderOptions, result.providerMetadata), - Runtime: time.Since(stepStart), - PendingDynamicToolCalls: pending, - ReasoningStartedAt: result.reasoningStartedAt, - ReasoningCompletedAt: result.reasoningCompletedAt, - }); err != nil { - if errors.Is(err, ErrInterrupted) { - persistInterruptedStep(ctx, opts, result) - return ErrInterrupted - } - return xerrors.Errorf("persist step: %w", err) - } - return nil -} - // applyExclusiveToolPolicy checks whether toolCalls violate the // exclusive-tool policy declared by exclusiveToolNames. When a // violation is detected it synthesizes deterministic policy-error @@ -1698,173 +1329,10 @@ func flushActiveState( } } -// persistInterruptedStep saves durable content from a partial stream. 
-// Provider-executed calls without results are removed because their result -// metadata cannot be synthesized safely, except when removal would mutate -// signed Anthropic replay state. -func persistInterruptedStep( - ctx context.Context, - opts RunOptions, - result *stepResult, -) { - if result == nil || (len(result.content) == 0 && len(result.toolCalls) == 0) { - return - } - - provider := "" - modelName := "" - if opts.Model != nil { - provider = opts.Model.Provider() - modelName = opts.Model.Model() - } - var sanitizeStats chatsanitize.AnthropicProviderToolSanitizationStats - result.content, sanitizeStats = chatsanitize.SanitizeAnthropicProviderToolContent(provider, result.content) - chatsanitize.LogAnthropicProviderToolSanitization( - ctx, opts.Logger, "interrupted_persist", provider, modelName, sanitizeStats, - ) - - // Track which tool calls already have results in the content. - answeredToolCalls := make(map[string]struct{}) - for _, c := range result.content { - tr, ok := fantasy.AsContentType[fantasy.ToolResultContent](c) - if ok && tr.ToolCallID != "" { - answeredToolCalls[tr.ToolCallID] = struct{}{} - } - } - - // Copy existing timestamps and add result timestamps for - // interrupted tool calls so the frontend can show partial - // duration. - toolCallCreatedAt := maps.Clone(result.toolCallCreatedAt) - if toolCallCreatedAt == nil { - toolCallCreatedAt = make(map[string]time.Time) - } - toolResultCreatedAt := maps.Clone(result.toolResultCreatedAt) - if toolResultCreatedAt == nil { - toolResultCreatedAt = make(map[string]time.Time) - } - - // Build combined content: all accumulated content + synthetic - // interrupted results for any unanswered tool calls. - content := make([]fantasy.Content, 0, len(result.content)) - content = append(content, result.content...) 
- - interruptedAt := dbtime.Now() - for _, tc := range result.toolCalls { - if tc.ToolCallID == "" { - continue - } - if _, exists := answeredToolCalls[tc.ToolCallID]; exists { - continue - } - if chatsanitize.IsAnthropicProviderExecutedToolCall(provider, tc) { - continue - } - content = append(content, fantasy.ToolResultContent{ - ToolCallID: tc.ToolCallID, - ToolName: tc.ToolName, - ProviderExecuted: tc.ProviderExecuted, - Result: fantasy.ToolResultOutputContentError{ - Error: xerrors.New(interruptedToolResultErrorMessage), - }, - }) - // Only stamp synthetic results; don't clobber - // timestamps from tools that completed before - // the interruption arrived. - if _, exists := toolResultCreatedAt[tc.ToolCallID]; !exists { - toolResultCreatedAt[tc.ToolCallID] = interruptedAt - } - answeredToolCalls[tc.ToolCallID] = struct{}{} - } - - if len(content) == 0 { - return - } - - persistCtx := context.WithoutCancel(ctx) - if err := opts.PersistStep(persistCtx, PersistedStep{ - Content: content, - ToolCallCreatedAt: toolCallCreatedAt, - ToolResultCreatedAt: toolResultCreatedAt, - ReasoningStartedAt: result.reasoningStartedAt, - ReasoningCompletedAt: result.reasoningCompletedAt, - }); err != nil { - if opts.OnInterruptedPersistError != nil { - opts.OnInterruptedPersistError(err) - } - } -} - -// tryCompactOnExit runs compaction when the chatloop is about -// to exit early (e.g. via ErrDynamicToolCall). The normal -// inline and post-run compaction paths are unreachable in -// early-exit scenarios, so this ensures the context window -// doesn't grow unbounded. 
-func tryCompactOnExit( - ctx context.Context, - opts RunOptions, - usage fantasy.Usage, - metadata fantasy.ProviderMetadata, -) { - if opts.Compaction == nil || opts.ReloadMessages == nil { - return - } - reloaded, err := opts.ReloadMessages(ctx) - if err != nil { - return - } - did, compactErr := tryCompact( - ctx, - opts.Model, - opts.Compaction, - opts.ContextLimitFallback, - usage, - metadata, - reloaded, - ) - opts.Metrics.RecordCompaction(opts.Model.Provider(), opts.Model.Model(), did, compactErr) - if compactErr != nil && opts.Compaction.OnError != nil { - opts.Compaction.OnError(compactErr) - } -} - func isToolActive(name string, activeTools []string) bool { return len(activeTools) == 0 || slices.Contains(activeTools, name) } -// mergeNewToolNames returns activeTools augmented with any tool names -// from newTools that are not present in oldTools and not already in -// activeTools. This keeps newly injected tools (e.g. via PrepareTools) -// callable even when activeTools is non-empty. -// -// When activeTools is empty, all tools are already active and the slice -// is returned unchanged. -func mergeNewToolNames(activeTools []string, oldTools, newTools []fantasy.AgentTool) []string { - if len(activeTools) == 0 { - return activeTools - } - old := make(map[string]struct{}, len(oldTools)) - for _, t := range oldTools { - old[t.Info().Name] = struct{}{} - } - active := make(map[string]struct{}, len(activeTools)) - for _, name := range activeTools { - active[name] = struct{}{} - } - for _, t := range newTools { - name := t.Info().Name - if _, alreadyActive := active[name]; alreadyActive { - continue - } - if _, existedBefore := old[name]; existedBefore { - continue - } - activeTools = append(activeTools, name) - active[name] = struct{}{} - } - return activeTools -} - // buildToolDefinitions converts AgentTool definitions into the // fantasy.Tool slice expected by fantasy.Call. 
When activeTools // is non-empty, only function tools whose name appears in the @@ -1901,24 +1369,6 @@ func buildToolDefinitions(tools []fantasy.AgentTool, activeTools []string, provi return prepared } -// shouldStopAfterTools returns true if any tool result in the -// slice matches a name in stopTools and produced a successful -// (non-error) result. -func shouldStopAfterTools(stopTools map[string]struct{}, results []fantasy.ToolResultContent) bool { - if len(stopTools) == 0 { - return false - } - for _, tr := range results { - if _, ok := stopTools[tr.ToolName]; !ok { - continue - } - if _, isErr := tr.Result.(fantasy.ToolResultOutputContentError); !isErr { - return true - } - } - return false -} - func shouldApplyAnthropicPromptCaching(model fantasy.LanguageModel) bool { if model == nil { return false diff --git a/coderd/x/chatd/chatloop/chatloop_internal_test.go b/coderd/x/chatd/chatloop/chatloop_internal_test.go index 1d6ff07560..07aaacb5a1 100644 --- a/coderd/x/chatd/chatloop/chatloop_internal_test.go +++ b/coderd/x/chatd/chatloop/chatloop_internal_test.go @@ -3,587 +3,15 @@ package chatloop import ( "context" "iter" - "sync" "testing" "charm.land/fantasy" fantasyanthropic "charm.land/fantasy/providers/anthropic" - fantasyopenai "charm.land/fantasy/providers/openai" "github.com/stretchr/testify/require" - "golang.org/x/xerrors" - "cdr.dev/slog/v3/sloggers/slogtest" - "github.com/coder/coder/v2/coderd/x/chatd/chatopenai" - "github.com/coder/coder/v2/coderd/x/chatd/chattest" "github.com/coder/coder/v2/codersdk" ) -func TestRun_ChainBrokenRecovers(t *testing.T) { - t.Parallel() - - // Given: a chain-mode run whose previous provider_response_id is present in - // our database but no longer recognized by the provider for some reason - var ( - streamCalls int - secondCallOpt fantasy.ProviderOptions - secondPrompt []fantasy.Message - ) - model := &chattest.FakeModel{ - ProviderName: "openai", - StreamFn: func(_ context.Context, call fantasy.Call) 
(fantasy.StreamResponse, error) { - streamCalls++ - switch streamCalls { - case 1: - return nil, xerrors.New(chainBrokenErrorMessage) - default: - secondCallOpt = call.ProviderOptions - secondPrompt = call.Prompt - return finishingStream(), nil - } - }, - } - - disableCalls := 0 - reloadCalls := 0 - reloadedHistory := []fantasy.Message{ - {Role: "system", Content: []fantasy.MessagePart{fantasy.TextPart{Text: "sys"}}}, - {Role: "user", Content: []fantasy.MessagePart{fantasy.TextPart{Text: "hello"}}}, - {Role: "assistant", Content: []fantasy.MessagePart{fantasy.TextPart{Text: "hi"}}}, - {Role: "user", Content: []fantasy.MessagePart{fantasy.TextPart{Text: "follow up"}}}, - } - - chainFiltered := []fantasy.Message{ - {Role: "system", Content: []fantasy.MessagePart{fantasy.TextPart{Text: "sys"}}}, - {Role: "user", Content: []fantasy.MessagePart{fantasy.TextPart{Text: "follow up"}}}, - } - - // When: the first attempt fails with the chain-broken error - err := Run(context.Background(), RunOptions{ - Model: model, - MaxSteps: 1, - ContextLimitFallback: 4096, - Messages: chainFiltered, - ProviderOptions: chainModeProviderOptions("resp_poisoned"), - PersistStep: func(_ context.Context, _ PersistedStep) error { - return nil - }, - DisableChainMode: func() { - disableCalls++ - }, - ReloadMessages: func(_ context.Context) ([]fantasy.Message, error) { - reloadCalls++ - return reloadedHistory, nil - }, - }) - - // Then: DisableChainMode and ReloadMessages each run once and the - // retry attempt sends the full reloaded history without - // previous_response_id. 
- require.NoError(t, err) - require.Equal(t, 2, streamCalls, "exactly two stream attempts (one failure, one success)") - require.Equal(t, 1, disableCalls, "DisableChainMode called once on chain-broken recovery") - require.Equal(t, 1, reloadCalls, "ReloadMessages called once on chain-broken recovery") - - require.False(t, - chatopenai.HasPreviousResponseID(secondCallOpt), - "second attempt must not carry previous_response_id; it was poisoned", - ) - require.Equal(t, reloadedHistory, secondPrompt, - "second attempt must use full reloaded history, not chain-filtered prompt", - ) -} - -func TestRun_ChainBrokenRecoveryPreparesReloadedMessages(t *testing.T) { - t.Parallel() - - var ( - streamCalls int - prepareCalls int - secondCallOpt fantasy.ProviderOptions - secondPrompt []fantasy.Message - ) - model := &chattest.FakeModel{ - ProviderName: "openai", - StreamFn: func(_ context.Context, call fantasy.Call) (fantasy.StreamResponse, error) { - streamCalls++ - switch streamCalls { - case 1: - return nil, xerrors.New(chainBrokenErrorMessage) - default: - secondCallOpt = call.ProviderOptions - secondPrompt = call.Prompt - return finishingStream(), nil - } - }, - } - - reloadedHistory := []fantasy.Message{ - textMessage(fantasy.MessageRoleUser, "full history"), - } - - err := Run(context.Background(), RunOptions{ - Model: model, - MaxSteps: 1, - ContextLimitFallback: 4096, - Messages: []fantasy.Message{ - textMessage(fantasy.MessageRoleUser, "chain-filtered"), - }, - ProviderOptions: chainModeProviderOptions("resp_poisoned"), - PersistStep: func(_ context.Context, _ PersistedStep) error { - return nil - }, - DisableChainMode: func() {}, - ReloadMessages: func(_ context.Context) ([]fantasy.Message, error) { - return reloadedHistory, nil - }, - PrepareMessages: func(msgs []fantasy.Message) []fantasy.Message { - prepareCalls++ - return append(msgs, textMessage(fantasy.MessageRoleSystem, "prepared")) - }, - }) - - require.NoError(t, err) - require.Equal(t, 2, streamCalls) - 
require.Equal(t, 2, prepareCalls, - "reloaded history must be prepared before the retry") - require.False(t, chatopenai.HasPreviousResponseID(secondCallOpt)) - requireTextPrompt(t, secondPrompt, "full history") - requireTextPrompt(t, secondPrompt, "prepared") -} - -func TestRun_ChainBrokenRecoveryAppliesProviderPromptPrep(t *testing.T) { - t.Parallel() - - var ( - streamCalls int - secondCallOpt fantasy.ProviderOptions - secondPrompt []fantasy.Message - ) - model := &chattest.FakeModel{ - ProviderName: fantasyanthropic.Name, - StreamFn: func(_ context.Context, call fantasy.Call) (fantasy.StreamResponse, error) { - streamCalls++ - switch streamCalls { - case 1: - return nil, xerrors.New(chainBrokenErrorMessage) - default: - secondCallOpt = call.ProviderOptions - secondPrompt = call.Prompt - return finishingStream(), nil - } - }, - } - - reloadedHistory := []fantasy.Message{ - textMessage(fantasy.MessageRoleSystem, "sys-1"), - textMessage(fantasy.MessageRoleSystem, "sys-2"), - textMessage(fantasy.MessageRoleUser, "hello"), - textMessage(fantasy.MessageRoleAssistant, "hi"), - textMessage(fantasy.MessageRoleUser, "follow up"), - } - - err := Run(context.Background(), RunOptions{ - Model: model, - MaxSteps: 1, - ContextLimitFallback: 4096, - Messages: []fantasy.Message{ - textMessage(fantasy.MessageRoleSystem, "sys-2"), - textMessage(fantasy.MessageRoleUser, "follow up"), - }, - ProviderOptions: chainModeProviderOptions("resp_poisoned"), - PersistStep: func(_ context.Context, _ PersistedStep) error { - return nil - }, - DisableChainMode: func() {}, - ReloadMessages: func(_ context.Context) ([]fantasy.Message, error) { - return reloadedHistory, nil - }, - }) - - require.NoError(t, err) - require.Equal(t, 2, streamCalls) - require.False(t, chatopenai.HasPreviousResponseID(secondCallOpt)) - require.Len(t, secondPrompt, 5) - require.False(t, hasAnthropicEphemeralCacheControl(secondPrompt[0])) - require.True(t, hasAnthropicEphemeralCacheControl(secondPrompt[1])) - 
require.False(t, hasAnthropicEphemeralCacheControl(secondPrompt[2])) - require.True(t, hasAnthropicEphemeralCacheControl(secondPrompt[3])) - require.True(t, hasAnthropicEphemeralCacheControl(secondPrompt[4])) -} - -func TestRun_ChainBrokenReloadWithoutDisableChainModeIsExplicit(t *testing.T) { - t.Parallel() - - var ( - streamCalls int - prepareCalls int - reloadCalls int - secondCallOpt fantasy.ProviderOptions - secondPrompt []fantasy.Message - ) - model := &chattest.FakeModel{ - ProviderName: "openai", - StreamFn: func(_ context.Context, call fantasy.Call) (fantasy.StreamResponse, error) { - streamCalls++ - switch streamCalls { - case 1: - return nil, xerrors.New(chainBrokenErrorMessage) - default: - secondCallOpt = call.ProviderOptions - secondPrompt = call.Prompt - return finishingStream(), nil - } - }, - } - - err := Run(context.Background(), RunOptions{ - Model: model, - MaxSteps: 1, - ContextLimitFallback: 4096, - Messages: []fantasy.Message{ - textMessage(fantasy.MessageRoleUser, "chain-filtered"), - }, - ProviderOptions: chainModeProviderOptions("resp_poisoned"), - PersistStep: func(_ context.Context, _ PersistedStep) error { - return nil - }, - ReloadMessages: func(_ context.Context) ([]fantasy.Message, error) { - reloadCalls++ - return []fantasy.Message{ - textMessage(fantasy.MessageRoleUser, "full history"), - }, nil - }, - PrepareMessages: func(msgs []fantasy.Message) []fantasy.Message { - prepareCalls++ - return append(msgs, textMessage(fantasy.MessageRoleSystem, "prepared")) - }, - // DisableChainMode is intentionally nil. This covers callers - // whose ReloadMessages does not depend on chain-mode state. 
- }) - - require.NoError(t, err) - require.Equal(t, 2, streamCalls) - require.Equal(t, 1, reloadCalls) - require.Equal(t, 2, prepareCalls) - require.False(t, chatopenai.HasPreviousResponseID(secondCallOpt)) - requireTextPrompt(t, secondPrompt, "full history") - requireTextPrompt(t, secondPrompt, "prepared") -} - -func TestRun_ChainBrokenComposesWithPostStepChainExit(t *testing.T) { - t.Parallel() - - // Given a chain-mode run whose recovery succeeds and yields a - // tool call so the step loop continues - var ( - mu sync.Mutex - streamCalls int - capturedOpts []fantasy.ProviderOptions - ) - model := &chattest.FakeModel{ - ProviderName: "openai", - StreamFn: func(_ context.Context, call fantasy.Call) (fantasy.StreamResponse, error) { - mu.Lock() - streamCalls++ - attempt := streamCalls - capturedOpts = append(capturedOpts, call.ProviderOptions) - mu.Unlock() - - switch attempt { - case 1: - // Initial chained attempt: 404 from provider. - return nil, xerrors.New(chainBrokenErrorMessage) - case 2: - // Recovery succeeded; emit a tool call so the - // step loop continues to a second step. - return streamFromParts([]fantasy.StreamPart{ - {Type: fantasy.StreamPartTypeToolInputStart, ID: "tc-1", ToolCallName: "read_file"}, - {Type: fantasy.StreamPartTypeToolInputDelta, ID: "tc-1", Delta: `{"path":"main.go"}`}, - {Type: fantasy.StreamPartTypeToolInputEnd, ID: "tc-1"}, - { - Type: fantasy.StreamPartTypeToolCall, - ID: "tc-1", - ToolCallName: "read_file", - ToolCallInput: `{"path":"main.go"}`, - }, - {Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonToolCalls}, - }), nil - default: - // Step 1: end the run. 
- return finishingStream(), nil - } - }, - } - - // When the second step builds its call from opts.ProviderOptions - err := Run(context.Background(), RunOptions{ - Model: model, - MaxSteps: 3, - ContextLimitFallback: 4096, - Messages: []fantasy.Message{ - textMessage(fantasy.MessageRoleUser, "hi"), - }, - Tools: []fantasy.AgentTool{ - newNoopTool("read_file"), - }, - ProviderOptions: chainModeProviderOptions("resp_poisoned"), - PersistStep: func(_ context.Context, _ PersistedStep) error { - return nil - }, - DisableChainMode: func() {}, - ReloadMessages: func(_ context.Context) ([]fantasy.Message, error) { - return []fantasy.Message{ - textMessage(fantasy.MessageRoleUser, "hi"), - }, nil - }, - }) - - // Then it must not re-send the poisoned previous_response_id - // because chain-broken recovery cleared both the current call and - // subsequent step options. - require.NoError(t, err) - require.Equal(t, 3, streamCalls, - "expected three stream calls: chain-broken failure, recovered tool-call step, follow-up step") - for i, providerOpts := range capturedOpts[1:] { - require.False(t, - chatopenai.HasPreviousResponseID(providerOpts), - "every stream call after recovery (index %d) must have cleared previous_response_id", - i+1, - ) - } -} - -func TestRun_ChainBrokenReloadFailureStillClearsChain(t *testing.T) { - t.Parallel() - - // Given: a chain-mode run whose ReloadMessages callback errors - var ( - streamCalls int - prepareCalls int - secondCallOpt fantasy.ProviderOptions - secondPrompt []fantasy.Message - ) - model := &chattest.FakeModel{ - ProviderName: "openai", - StreamFn: func(_ context.Context, call fantasy.Call) (fantasy.StreamResponse, error) { - streamCalls++ - switch streamCalls { - case 1: - return nil, xerrors.New(chainBrokenErrorMessage) - default: - secondCallOpt = call.ProviderOptions - secondPrompt = call.Prompt - return finishingStream(), nil - } - }, - } - - disableCalls := 0 - chainFiltered := []fantasy.Message{ - {Role: "system", Content: 
[]fantasy.MessagePart{fantasy.TextPart{Text: "sys"}}}, - {Role: "user", Content: []fantasy.MessagePart{fantasy.TextPart{Text: "follow up"}}}, - } - - // When: the chain-broken error fires - err := Run(context.Background(), RunOptions{ - Model: model, - MaxSteps: 1, - ContextLimitFallback: 4096, - Messages: chainFiltered, - ProviderOptions: chainModeProviderOptions("resp_poisoned"), - PersistStep: func(_ context.Context, _ PersistedStep) error { - return nil - }, - DisableChainMode: func() { - disableCalls++ - }, - ReloadMessages: func(_ context.Context) ([]fantasy.Message, error) { - return nil, xerrors.New("reload exploded") - }, - PrepareMessages: func(msgs []fantasy.Message) []fantasy.Message { - prepareCalls++ - return append(msgs, textMessage(fantasy.MessageRoleSystem, "prepared")) - }, - }) - - // Then: the poisoned previous_response_id is still cleared and - // DisableChainMode still runs, so the retry has any chance of - // succeeding against the chain-filtered prompt. - require.NoError(t, err) - require.Equal(t, 1, disableCalls) - require.Equal(t, 1, prepareCalls) - require.False(t, - chatopenai.HasPreviousResponseID(secondCallOpt), - "chain options must still be cleared even when reload fails", - ) - requireTextPrompt(t, secondPrompt, "follow up") - requireTextPrompt(t, secondPrompt, "prepared") -} - -func TestRun_ChainBrokenRecoveryDropsOrphanProviderToolCall(t *testing.T) { - t.Parallel() - - var ( - streamCalls int - secondCallOpt fantasy.ProviderOptions - secondPrompt []fantasy.Message - ) - model := &chattest.FakeModel{ - ProviderName: fantasyanthropic.Name, - ModelName: "claude-test", - StreamFn: func(_ context.Context, call fantasy.Call) (fantasy.StreamResponse, error) { - streamCalls++ - switch streamCalls { - case 1: - return nil, xerrors.New(chainBrokenErrorMessage) - default: - secondCallOpt = call.ProviderOptions - secondPrompt = call.Prompt - return finishingStream(), nil - } - }, - } - - reloadCalls := 0 - err := Run(context.Background(), 
RunOptions{ - Model: model, - Logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}), - MaxSteps: 1, - ContextLimitFallback: 4096, - Messages: []fantasy.Message{ - textMessage(fantasy.MessageRoleUser, "chain-filtered"), - }, - ProviderOptions: chainModeProviderOptions("resp_poisoned"), - PersistStep: func(_ context.Context, _ PersistedStep) error { - return nil - }, - DisableChainMode: func() {}, - ReloadMessages: func(_ context.Context) ([]fantasy.Message, error) { - reloadCalls++ - return []fantasy.Message{ - textMessage(fantasy.MessageRoleUser, "search"), - { - Role: fantasy.MessageRoleAssistant, - Content: []fantasy.MessagePart{ - fantasy.ReasoningPart{ProviderOptions: fantasy.ProviderOptions{fantasyanthropic.Name: &fantasyanthropic.ReasoningOptionMetadata{RedactedData: "redacted-payload"}}}, - fantasy.ToolCallPart{ToolCallID: "ws-orphan", ToolName: "web_search", Input: `{"query":"coder"}`, ProviderExecuted: true}, - fantasy.TextPart{Text: "partial"}, - }, - }, - textMessage(fantasy.MessageRoleUser, "continue"), - }, nil - }, - }) - - require.NoError(t, err) - require.Equal(t, 1, reloadCalls) - require.Equal(t, 2, streamCalls) - require.False(t, chatopenai.HasPreviousResponseID(secondCallOpt)) - requireNoProviderExecutedToolCallPrompt(t, secondPrompt) - requireAnthropicProviderToolPromptSafe(t, secondPrompt) - requireTextPrompt(t, secondPrompt, "search") - requireTextPrompt(t, secondPrompt, "partial") - requireTextPrompt(t, secondPrompt, "continue") - reasoningPart := requireReasoningPrompt(t, secondPrompt) - reasoningMetadata := fantasyanthropic.GetReasoningMetadata(reasoningPart.ProviderOptions) - require.NotNil(t, reasoningMetadata) - require.Equal(t, "redacted-payload", reasoningMetadata.RedactedData) -} - -func TestRun_ChainBrokenWithoutChainModeIsSafe(t *testing.T) { - t.Parallel() - - // Given: a run with no chain-mode options or callbacks - var streamCalls int - model := &chattest.FakeModel{ - ProviderName: "openai", - StreamFn: func(_ 
context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) { - streamCalls++ - switch streamCalls { - case 1: - return nil, xerrors.New(chainBrokenErrorMessage) - default: - return finishingStream(), nil - } - }, - } - - // When: a future provider returns a chain-broken signal, - err := Run(context.Background(), RunOptions{ - Model: model, - MaxSteps: 1, - ContextLimitFallback: 4096, - PersistStep: func(_ context.Context, _ PersistedStep) error { - return nil - }, - // No ProviderOptions, no DisableChainMode, no ReloadMessages. - }) - - // Then: the recovery branch must no-op (no panic, no missing - // callbacks) and the retry runs normally. - require.NoError(t, err) - require.Equal(t, 2, streamCalls) -} - -func TestRun_NonChainBrokenRetryDoesNotTouchChainState(t *testing.T) { - t.Parallel() - - // Given: a chain-mode run with a still-valid previous_response_id - var ( - streamCalls int - secondCallOpt fantasy.ProviderOptions - ) - model := &chattest.FakeModel{ - ProviderName: "openai", - StreamFn: func(_ context.Context, call fantasy.Call) (fantasy.StreamResponse, error) { - streamCalls++ - switch streamCalls { - case 1: - return nil, xerrors.New("received status 503 from upstream") - default: - secondCallOpt = call.ProviderOptions - return finishingStream(), nil - } - }, - } - - disableCalls := 0 - reloadCalls := 0 - - // When: a non-chain-broken retryable error fires (503) - err := Run(context.Background(), RunOptions{ - Model: model, - MaxSteps: 1, - ContextLimitFallback: 4096, - Messages: []fantasy.Message{ - {Role: "user", Content: []fantasy.MessagePart{fantasy.TextPart{Text: "hi"}}}, - }, - ProviderOptions: chainModeProviderOptions("resp_still_valid"), - PersistStep: func(_ context.Context, _ PersistedStep) error { - return nil - }, - DisableChainMode: func() { - disableCalls++ - }, - ReloadMessages: func(_ context.Context) ([]fantasy.Message, error) { - reloadCalls++ - return nil, nil - }, - }) - - // Then: chain mode stays engaged, ReloadMessages is 
not called, - // and the retry preserves previous_response_id. - require.NoError(t, err) - require.Equal(t, 0, disableCalls, - "non-chain-broken retry must not exit chain mode") - require.Equal(t, 0, reloadCalls, - "non-chain-broken retry must not reload history") - require.True(t, - chatopenai.HasPreviousResponseID(secondCallOpt), - "non-chain-broken retry must preserve previous_response_id", - ) -} - func TestProcessStepStreamPreservesReasoningMetadataAcrossNilDelta(t *testing.T) { t.Parallel() @@ -652,34 +80,6 @@ func TestProcessStepStreamPersistsRedactedThinkingOnEnd(t *testing.T) { require.Equal(t, "redacted-payload", metadata.RedactedData) } -func TestStepResultToResponseMessagesPreservesEmptySignedReasoning(t *testing.T) { - t.Parallel() - - result := stepResult{ - content: []fantasy.Content{ - fantasy.ReasoningContent{ - ProviderMetadata: fantasy.ProviderMetadata{ - fantasyanthropic.Name: &fantasyanthropic.ReasoningOptionMetadata{ - RedactedData: "redacted-payload", - }, - }, - }, - fantasy.TextContent{Text: "done"}, - }, - } - - messages := result.toResponseMessages() - - require.Len(t, messages, 1) - require.Len(t, messages[0].Content, 2) - reasoning, ok := fantasy.AsMessagePart[fantasy.ReasoningPart](messages[0].Content[0]) - require.True(t, ok) - require.Empty(t, reasoning.Text) - metadata := fantasyanthropic.GetReasoningMetadata(reasoning.ProviderOptions) - require.NotNil(t, metadata) - require.Equal(t, "redacted-payload", metadata.RedactedData) -} - func TestFlushActiveStatePreservesEmptySignedReasoning(t *testing.T) { t.Parallel() @@ -709,32 +109,3 @@ func TestFlushActiveStatePreservesEmptySignedReasoning(t *testing.T) { require.NotNil(t, metadata) require.Equal(t, "redacted-payload", metadata.RedactedData) } - -// chainBrokenError is what OpenAI returns when previous_response_id -// points at a response it does not have stored. -const chainBrokenErrorMessage = "Previous response with id 'resp_abc' not found." 
- -// finishingStream returns a stream that emits a single Finish part. -// The chatloop treats a finishReason of Stop as "stoppedByModel" and -// exits the per-step loop after persisting. -func finishingStream() fantasy.StreamResponse { - return iter.Seq[fantasy.StreamPart](func(yield func(fantasy.StreamPart) bool) { - yield(fantasy.StreamPart{ - Type: fantasy.StreamPartTypeFinish, - FinishReason: fantasy.FinishReasonStop, - }) - }) -} - -// chainModeProviderOptions builds a fantasy.ProviderOptions carrying -// the OpenAI Responses options with previous_response_id set, the same -// shape chatd builds when chain mode is active. -func chainModeProviderOptions(previousResponseID string) fantasy.ProviderOptions { - store := true - return fantasy.ProviderOptions{ - fantasyopenai.Name: &fantasyopenai.ResponsesProviderOptions{ - Store: &store, - PreviousResponseID: &previousResponseID, - }, - } -} diff --git a/coderd/x/chatd/chatloop/chatloop_run_internal_test.go b/coderd/x/chatd/chatloop/chatloop_run_internal_test.go index 9769f10d01..4e0c6bf55a 100644 --- a/coderd/x/chatd/chatloop/chatloop_run_internal_test.go +++ b/coderd/x/chatd/chatloop/chatloop_run_internal_test.go @@ -5,7 +5,6 @@ import ( "encoding/base64" "errors" "iter" - "strings" "sync" "sync/atomic" "testing" @@ -21,9 +20,7 @@ import ( "golang.org/x/xerrors" "cdr.dev/slog/v3" - "cdr.dev/slog/v3/sloggers/slogtest" "github.com/coder/coder/v2/coderd/x/chatd/chaterror" - "github.com/coder/coder/v2/coderd/x/chatd/chatretry" "github.com/coder/coder/v2/coderd/x/chatd/chatsanitize" "github.com/coder/coder/v2/coderd/x/chatd/chattest" "github.com/coder/coder/v2/codersdk" @@ -31,8 +28,6 @@ import ( "github.com/coder/quartz" ) -const activeToolName = "read_file" - func validWebSearchProviderMetadataForTest() fantasy.ProviderMetadata { return fantasy.ProviderMetadata{ fantasyanthropic.Name: &fantasyanthropic.WebSearchResultMetadata{ @@ -77,30 +72,6 @@ func safeToolResultContent(block fantasy.Content) 
(fantasy.ToolResultContent, bo } } -func safeToolCallPart(part fantasy.MessagePart) (fantasy.ToolCallPart, bool) { - var zero fantasy.ToolCallPart - if part == nil { - return zero, false - } - if value, ok := part.(*fantasy.ToolCallPart); ok && value == nil { - return zero, false - } - type toolCallPart = fantasy.ToolCallPart - return fantasy.AsMessagePart[toolCallPart](part) -} - -func safeToolResultPart(part fantasy.MessagePart) (fantasy.ToolResultPart, bool) { - var zero fantasy.ToolResultPart - if part == nil { - return zero, false - } - if value, ok := part.(*fantasy.ToolResultPart); ok && value == nil { - return zero, false - } - type toolResultPart = fantasy.ToolResultPart - return fantasy.AsMessagePart[toolResultPart](part) -} - func toolCallContentToPart(toolCall fantasy.ToolCallContent) fantasy.ToolCallPart { return fantasy.ToolCallPart{ ToolCallID: toolCall.ToolCallID, @@ -120,467 +91,6 @@ func toolResultContentToPart(toolResult fantasy.ToolResultContent) fantasy.ToolR } } -func awaitRunResult(ctx context.Context, t *testing.T, done <-chan error) error { - t.Helper() - - select { - case err := <-done: - return err - case <-ctx.Done(): - t.Fatal("timed out waiting for Run to complete") - return nil - } -} - -func TestRun_ActiveToolsPrepareBehavior(t *testing.T) { - t.Parallel() - - var capturedCall fantasy.Call - model := &chattest.FakeModel{ - ProviderName: fantasyanthropic.Name, - StreamFn: func(_ context.Context, call fantasy.Call) (fantasy.StreamResponse, error) { - capturedCall = call - return streamFromParts([]fantasy.StreamPart{ - {Type: fantasy.StreamPartTypeTextStart, ID: "text-1"}, - {Type: fantasy.StreamPartTypeTextDelta, ID: "text-1", Delta: "done"}, - {Type: fantasy.StreamPartTypeTextEnd, ID: "text-1"}, - {Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop}, - }), nil - }, - } - - persistStepCalls := 0 - var persistedStep PersistedStep - - err := Run(context.Background(), RunOptions{ - Model: model, - Messages: 
[]fantasy.Message{ - textMessage(fantasy.MessageRoleSystem, "sys-1"), - textMessage(fantasy.MessageRoleSystem, "sys-2"), - textMessage(fantasy.MessageRoleUser, "hello"), - textMessage(fantasy.MessageRoleAssistant, "working"), - textMessage(fantasy.MessageRoleUser, "continue"), - }, - Tools: []fantasy.AgentTool{ - newNoopTool(activeToolName), - newNoopTool("write_file"), - }, - MaxSteps: 3, - ActiveTools: []string{activeToolName}, - ContextLimitFallback: 4096, - PersistStep: func(_ context.Context, step PersistedStep) error { - persistStepCalls++ - persistedStep = step - return nil - }, - }) - require.NoError(t, err) - - require.Equal(t, 1, persistStepCalls) - require.True(t, persistedStep.ContextLimit.Valid) - require.Equal(t, int64(4096), persistedStep.ContextLimit.Int64) - require.GreaterOrEqual(t, persistedStep.Runtime, time.Duration(0), - "step runtime should be non-negative") - - require.NotEmpty(t, capturedCall.Prompt) - require.False(t, containsPromptSentinel(capturedCall.Prompt)) - require.Len(t, capturedCall.Tools, 1) - require.Equal(t, activeToolName, capturedCall.Tools[0].GetName()) - - require.Len(t, capturedCall.Prompt, 5) - require.False(t, hasAnthropicEphemeralCacheControl(capturedCall.Prompt[0])) - require.True(t, hasAnthropicEphemeralCacheControl(capturedCall.Prompt[1])) - require.False(t, hasAnthropicEphemeralCacheControl(capturedCall.Prompt[2])) - require.True(t, hasAnthropicEphemeralCacheControl(capturedCall.Prompt[3])) - require.True(t, hasAnthropicEphemeralCacheControl(capturedCall.Prompt[4])) -} - -func TestRun_ActiveToolsRejectsDisallowedExecution(t *testing.T) { - t.Parallel() - - var blockedCalls atomic.Int32 - blockedToolName := "write_file" - model := &chattest.FakeModel{ - ProviderName: "fake", - StreamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) { - return streamFromParts([]fantasy.StreamPart{ - {Type: fantasy.StreamPartTypeToolInputStart, ID: "tc-blocked", ToolCallName: blockedToolName}, - {Type: 
fantasy.StreamPartTypeToolInputDelta, ID: "tc-blocked", Delta: `{"path":"/tmp/nope"}`}, - {Type: fantasy.StreamPartTypeToolInputEnd, ID: "tc-blocked"}, - { - Type: fantasy.StreamPartTypeToolCall, - ID: "tc-blocked", - ToolCallName: blockedToolName, - ToolCallInput: `{"path":"/tmp/nope"}`, - }, - {Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonToolCalls}, - }), nil - }, - } - - blockedTool := fantasy.NewAgentTool( - blockedToolName, - "blocked tool", - func(context.Context, struct{}, fantasy.ToolCall) (fantasy.ToolResponse, error) { - blockedCalls.Add(1) - return fantasy.NewTextResponse("should not run"), nil - }, - ) - - var persistedStep PersistedStep - err := Run(context.Background(), RunOptions{ - Model: model, - Messages: []fantasy.Message{ - textMessage(fantasy.MessageRoleUser, "try the blocked tool"), - }, - Tools: []fantasy.AgentTool{ - newNoopTool(activeToolName), - blockedTool, - }, - ActiveTools: []string{activeToolName}, - MaxSteps: 1, - PersistStep: func(_ context.Context, step PersistedStep) error { - persistedStep = step - return nil - }, - }) - require.NoError(t, err) - require.Zero(t, blockedCalls.Load(), "disallowed tool must not execute") - - var foundToolError bool - for _, block := range persistedStep.Content { - toolResult, ok := fantasy.AsContentType[fantasy.ToolResultContent](block) - if !ok || toolResult.ToolName != blockedToolName { - continue - } - errResult, ok := toolResult.Result.(fantasy.ToolResultOutputContentError) - require.True(t, ok) - assert.EqualError(t, errResult.Error, "Tool not active in this turn: "+blockedToolName) - foundToolError = true - } - require.True(t, foundToolError, "persisted step should include the rejected tool result") -} - -func TestRun_ActiveToolsAllowsProviderRunnerExecution(t *testing.T) { - t.Parallel() - - providerRunnerName := "computer" - var runnerCalls atomic.Int32 - model := &chattest.FakeModel{ - ProviderName: "fake", - StreamFn: func(_ context.Context, _ fantasy.Call) 
(fantasy.StreamResponse, error) { - return streamFromParts([]fantasy.StreamPart{ - {Type: fantasy.StreamPartTypeToolInputStart, ID: "tc-provider-runner", ToolCallName: providerRunnerName}, - {Type: fantasy.StreamPartTypeToolInputDelta, ID: "tc-provider-runner", Delta: `{}`}, - {Type: fantasy.StreamPartTypeToolInputEnd, ID: "tc-provider-runner"}, - { - Type: fantasy.StreamPartTypeToolCall, - ID: "tc-provider-runner", - ToolCallName: providerRunnerName, - ToolCallInput: `{}`, - }, - {Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonToolCalls}, - }), nil - }, - } - - runnerTool := fantasy.NewAgentTool( - providerRunnerName, - "provider runner", - func(context.Context, struct{}, fantasy.ToolCall) (fantasy.ToolResponse, error) { - runnerCalls.Add(1) - return fantasy.NewTextResponse("ran provider runner"), nil - }, - ) - - var persistedStep PersistedStep - err := Run(context.Background(), RunOptions{ - Model: model, - Messages: []fantasy.Message{ - textMessage(fantasy.MessageRoleUser, "use the computer"), - }, - Tools: []fantasy.AgentTool{newNoopTool(activeToolName)}, - ActiveTools: []string{activeToolName}, - ProviderTools: []ProviderTool{ - { - Definition: fantasy.FunctionTool{ - Name: providerRunnerName, - Description: "provider runner", - InputSchema: map[string]any{ - "type": "object", - "properties": map[string]any{}, - }, - }, - Runner: runnerTool, - }, - }, - MaxSteps: 1, - PersistStep: func(_ context.Context, step PersistedStep) error { - persistedStep = step - return nil - }, - }) - require.NoError(t, err) - require.Equal(t, int32(1), runnerCalls.Load(), - "provider runner should execute even when omitted from active tools") - - var foundToolResult bool - for _, block := range persistedStep.Content { - toolResult, ok := fantasy.AsContentType[fantasy.ToolResultContent](block) - if !ok || toolResult.ToolName != providerRunnerName { - continue - } - textResult, ok := toolResult.Result.(fantasy.ToolResultOutputContentText) - require.True(t, ok) 
- assert.Equal(t, "ran provider runner", textResult.Text) - foundToolResult = true - } - require.True(t, foundToolResult, - "persisted step should include the provider runner result") -} - -func TestRun_ProviderToolResultProviderMetadata(t *testing.T) { - t.Parallel() - - expectedMetadata := fantasy.ProviderMetadata{ - "openai": &testProviderData{data: map[string]any{ - "detail": "original", - }}, - } - - tests := []struct { - name string - callback func(fantasy.ToolResponse) fantasy.ProviderMetadata - want fantasy.ProviderMetadata - }{ - { - name: "callback returns metadata", - callback: func(fantasy.ToolResponse) fantasy.ProviderMetadata { - return expectedMetadata - }, - want: expectedMetadata, - }, - { - name: "callback nil", - want: nil, - }, - { - name: "callback returns nil", - callback: func(fantasy.ToolResponse) fantasy.ProviderMetadata { - return nil - }, - want: nil, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - providerRunnerName := "computer" - model := &chattest.FakeModel{ - ProviderName: "fake", - StreamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) { - return streamFromParts([]fantasy.StreamPart{ - {Type: fantasy.StreamPartTypeToolInputStart, ID: "tc-provider-runner", ToolCallName: providerRunnerName}, - {Type: fantasy.StreamPartTypeToolInputDelta, ID: "tc-provider-runner", Delta: `{}`}, - {Type: fantasy.StreamPartTypeToolInputEnd, ID: "tc-provider-runner"}, - { - Type: fantasy.StreamPartTypeToolCall, - ID: "tc-provider-runner", - ToolCallName: providerRunnerName, - ToolCallInput: `{}`, - }, - {Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonToolCalls}, - }), nil - }, - } - - runnerTool := fantasy.NewAgentTool( - providerRunnerName, - "provider runner", - func(context.Context, struct{}, fantasy.ToolCall) (fantasy.ToolResponse, error) { - return fantasy.ToolResponse{ - Type: "image", - Data: []byte("image bytes"), - MediaType: "image/png", - Content: 
"screenshot", - }, nil - }, - ) - - var persistedStep PersistedStep - err := Run(context.Background(), RunOptions{ - Model: model, - Messages: []fantasy.Message{ - textMessage(fantasy.MessageRoleUser, "use the computer"), - }, - ProviderTools: []ProviderTool{ - { - Definition: fantasy.FunctionTool{ - Name: providerRunnerName, - Description: "provider runner", - InputSchema: map[string]any{ - "type": "object", - "properties": map[string]any{}, - }, - }, - Runner: runnerTool, - ResultProviderMetadata: tt.callback, - }, - }, - MaxSteps: 1, - PersistStep: func(_ context.Context, step PersistedStep) error { - persistedStep = step - return nil - }, - }) - require.NoError(t, err) - - var foundResult fantasy.ToolResultContent - for _, block := range persistedStep.Content { - toolResult, ok := fantasy.AsContentType[fantasy.ToolResultContent](block) - if !ok || toolResult.ToolName != providerRunnerName { - continue - } - foundResult = toolResult - break - } - require.NotEmpty(t, foundResult.ToolCallID, - "persisted step should include the provider runner result") - - mediaResult, ok := foundResult.Result.(fantasy.ToolResultOutputContentMedia) - require.True(t, ok, "expected media result") - assert.Equal(t, "image/png", mediaResult.MediaType) - assert.Equal(t, tt.want, foundResult.ProviderMetadata) - - if tt.want == nil { - return - } - - messages := stepResult{content: persistedStep.Content}.toResponseMessages() - require.Len(t, messages, 2) - require.Equal(t, fantasy.MessageRoleTool, messages[1].Role) - require.Len(t, messages[1].Content, 1) - - resultPart, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](messages[1].Content[0]) - require.True(t, ok, "expected outbound tool result part") - assert.Equal(t, fantasy.ProviderOptions(tt.want), resultPart.ProviderOptions) - }) - } -} - -func TestProcessStepStream_AnthropicUsageMatchesFinalDelta(t *testing.T) { - t.Parallel() - - model := &chattest.FakeModel{ - ProviderName: fantasyanthropic.Name, - StreamFn: func(_ 
context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) { - return streamFromParts([]fantasy.StreamPart{ - {Type: fantasy.StreamPartTypeTextStart, ID: "text-1"}, - {Type: fantasy.StreamPartTypeTextDelta, ID: "text-1", Delta: "cached response"}, - {Type: fantasy.StreamPartTypeTextEnd, ID: "text-1"}, - { - Type: fantasy.StreamPartTypeFinish, - Usage: fantasy.Usage{ - InputTokens: 200, - OutputTokens: 75, - TotalTokens: 275, - CacheCreationTokens: 30, - CacheReadTokens: 150, - ReasoningTokens: 0, - }, - FinishReason: fantasy.FinishReasonStop, - }, - }), nil - }, - } - - var persistedStep PersistedStep - - err := Run(context.Background(), RunOptions{ - Model: model, - Messages: []fantasy.Message{ - textMessage(fantasy.MessageRoleUser, "hello"), - }, - MaxSteps: 1, - ContextLimitFallback: 4096, - PersistStep: func(_ context.Context, step PersistedStep) error { - persistedStep = step - return nil - }, - }) - require.NoError(t, err) - require.Equal(t, int64(200), persistedStep.Usage.InputTokens) - require.Equal(t, int64(75), persistedStep.Usage.OutputTokens) - require.Equal(t, int64(275), persistedStep.Usage.TotalTokens) - require.Equal(t, int64(30), persistedStep.Usage.CacheCreationTokens) - require.Equal(t, int64(150), persistedStep.Usage.CacheReadTokens) -} - -func TestRun_OnRetryEnrichesProvider(t *testing.T) { - t.Parallel() - - type retryRecord struct { - attempt int - errMsg string - classified chatretry.ClassifiedError - delay time.Duration - } - - var records []retryRecord - calls := 0 - model := &chattest.FakeModel{ - ProviderName: "openai", - StreamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) { - calls++ - if calls == 1 { - return nil, xerrors.New("received status 429 from upstream") - } - return streamFromParts([]fantasy.StreamPart{{ - Type: fantasy.StreamPartTypeFinish, - FinishReason: fantasy.FinishReasonStop, - }}), nil - }, - } - - err := Run(context.Background(), RunOptions{ - Model: model, - MaxSteps: 1, - 
ContextLimitFallback: 4096, - PersistStep: func(_ context.Context, _ PersistedStep) error { - return nil - }, - OnRetry: func( - attempt int, - retryErr error, - classified chatretry.ClassifiedError, - delay time.Duration, - ) { - records = append(records, retryRecord{ - attempt: attempt, - errMsg: retryErr.Error(), - classified: classified, - delay: delay, - }) - }, - }) - require.NoError(t, err) - require.Len(t, records, 1) - require.Equal(t, 1, records[0].attempt) - require.Equal(t, "received status 429 from upstream", records[0].errMsg) - require.Equal(t, chatretry.Delay(0), records[0].delay) - require.Equal(t, "openai", records[0].classified.Provider) - require.Equal(t, codersdk.ChatErrorKindRateLimit, records[0].classified.Kind) - require.True(t, records[0].classified.Retryable) - require.Equal(t, 429, records[0].classified.StatusCode) - require.Equal( - t, - "OpenAI is rate limiting requests.", - records[0].classified.Message, - ) -} - func TestStreamSilenceGuard_DisarmAndFireRace(t *testing.T) { t.Parallel() @@ -638,546 +148,236 @@ func TestStreamSilenceGuard_DisarmPreservesPermanentError(t *testing.T) { require.Nil(t, context.Cause(attemptCtx)) } -func TestRun_RetriesSilenceTimeoutWhileOpeningStream(t *testing.T) { +func TestGenerateAssistant_ProviderContextSurvivesStreamError(t *testing.T) { t.Parallel() - const silenceTimeout = 5 * time.Millisecond - - ctx, cancel := context.WithTimeout( - context.Background(), - testutil.WaitShort, - ) - defer cancel() - - mClock := quartz.NewMock(t) - trap := mClock.Trap().AfterFunc(streamSilenceGuardTimerTag) - defer trap.Close() - - attempts := 0 - attemptCause := make(chan error, 1) - var retries []chatretry.ClassifiedError model := &chattest.FakeModel{ ProviderName: "openai", - StreamFn: func(ctx context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) { - attempts++ - if attempts == 1 { - <-ctx.Done() - attemptCause <- context.Cause(ctx) - return nil, ctx.Err() - } - return 
streamFromParts([]fantasy.StreamPart{{ - Type: fantasy.StreamPartTypeFinish, - FinishReason: fantasy.FinishReasonStop, - }}), nil + ModelName: "test-model", + StreamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) { + return nil, xerrors.New("upstream returned status 400") }, } - done := make(chan error, 1) - go func() { - done <- Run(context.Background(), RunOptions{ - Model: model, - MaxSteps: 1, - StreamSilenceTimeout: silenceTimeout, - Clock: mClock, - PersistStep: func(_ context.Context, _ PersistedStep) error { - return nil - }, - OnRetry: func( - _ int, - _ error, - classified chatretry.ClassifiedError, - _ time.Duration, - ) { - retries = append(retries, classified) - }, - }) - }() - - trap.MustWait(ctx).MustRelease(ctx) - mClock.Advance(silenceTimeout).MustWait(ctx) - trap.MustWait(ctx).MustRelease(ctx) - - require.NoError(t, awaitRunResult(ctx, t, done)) - require.Equal(t, 2, attempts) - require.Len(t, retries, 1) - require.Equal(t, codersdk.ChatErrorKindStreamSilenceTimeout, retries[0].Kind) - require.True(t, retries[0].Retryable) - require.Equal(t, "openai", retries[0].Provider) - require.Equal( - t, - "OpenAI did not send response data in time.", - retries[0].Message, - ) - select { - case cause := <-attemptCause: - require.ErrorIs(t, cause, errStreamSilenceTimeout) - case <-ctx.Done(): - t.Fatal("timed out waiting for silence timeout cause") - } + _, err := GenerateAssistant(context.Background(), GenerateAssistantOptions{ + Model: model, + Messages: []fantasy.Message{ + textMessage(fantasy.MessageRoleUser, "hello"), + }, + }) + require.Error(t, err) + classified := chaterror.Classify(err) + require.Equal(t, "openai", classified.Provider) + require.Equal(t, "OpenAI returned an unexpected error.", classified.Message) } -// TestRun_HTTP2TransportErrorClassifiedAsRetryableTimeout proves the -// provider comes from Model.Provider() (not from sniffing the error -// text) by using an error string with no provider hint and running 
-// the same assertion across two providers. -func TestRun_HTTP2TransportErrorClassifiedAsRetryableTimeout(t *testing.T) { +func TestGenerateAssistant_HTTP2TransportErrorClassifiedAsRetryableTimeout(t *testing.T) { t.Parallel() - providers := []string{"anthropic", "openai"} - for _, provider := range providers { + for _, provider := range []string{"anthropic", "openai"} { + provider := provider t.Run(provider, func(t *testing.T) { t.Parallel() - const silenceTimeout = 5 * time.Millisecond - - ctx, cancel := context.WithTimeout( - context.Background(), - testutil.WaitShort, - ) - defer cancel() - - mClock := quartz.NewMock(t) - trap := mClock.Trap().AfterFunc(streamSilenceGuardTimerTag) - defer trap.Close() - - attempts := 0 - var retries []chatretry.ClassifiedError model := &chattest.FakeModel{ ProviderName: provider, + ModelName: "test-model", StreamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) { - attempts++ - if attempts == 1 { - // Bare transport error; Provider must - // come from Model.Provider(). - return nil, xerrors.New( - "http2: client connection force closed via ClientConn.Close", - ) - } - return streamFromParts([]fantasy.StreamPart{{ - Type: fantasy.StreamPartTypeFinish, - FinishReason: fantasy.FinishReasonStop, - }}), nil + return nil, xerrors.New("http2: client connection force closed via ClientConn.Close") }, } - done := make(chan error, 1) - go func() { - done <- Run(context.Background(), RunOptions{ - Model: model, - MaxSteps: 1, - StreamSilenceTimeout: silenceTimeout, - Clock: mClock, - PersistStep: func(_ context.Context, _ PersistedStep) error { - return nil - }, - OnRetry: func( - _ int, - _ error, - classified chatretry.ClassifiedError, - _ time.Duration, - ) { - retries = append(retries, classified) - }, - }) - }() - - // One guard per attempt. 
- trap.MustWait(ctx).MustRelease(ctx) - trap.MustWait(ctx).MustRelease(ctx) - - require.NoError(t, awaitRunResult(ctx, t, done)) - require.Equal(t, 2, attempts) - require.Len(t, retries, 1) - require.Equal(t, codersdk.ChatErrorKindTimeout, retries[0].Kind, "Kind") - require.True(t, retries[0].Retryable, "Retryable") - require.Equal(t, provider, retries[0].Provider, "Provider") + _, err := GenerateAssistant(context.Background(), GenerateAssistantOptions{ + Model: model, + }) + require.Error(t, err) + classified := chaterror.Classify(err) + require.Equal(t, codersdk.ChatErrorKindTimeout, classified.Kind) + require.Equal(t, provider, classified.Provider) + require.True(t, classified.Retryable) }) } } -func TestRun_RetriesProviderContextCanceledStreamError(t *testing.T) { +func TestGenerateAssistant_StreamSilenceTimeoutRetryClassification(t *testing.T) { t.Parallel() - attempts := 0 - retryErrs := make(chan error, chatretry.MaxAttempts) - retries := make(chan chatretry.ClassifiedError, chatretry.MaxAttempts) - var persisted []fantasy.Content - ctx := testutil.Context(t, testutil.WaitShort) - model := &chattest.FakeModel{ - ProviderName: "openai", - StreamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) { - attempts++ - if attempts == 1 { - return streamFromParts([]fantasy.StreamPart{ - {Type: fantasy.StreamPartTypeTextStart, ID: "text-1"}, - {Type: fantasy.StreamPartTypeTextDelta, ID: "text-1", Delta: "partial"}, - {Type: fantasy.StreamPartTypeError, Error: context.Canceled}, - }), nil - } - return streamFromParts([]fantasy.StreamPart{ - {Type: fantasy.StreamPartTypeTextStart, ID: "text-2"}, - {Type: fantasy.StreamPartTypeTextDelta, ID: "text-2", Delta: "done"}, - {Type: fantasy.StreamPartTypeTextEnd, ID: "text-2"}, - {Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop}, - }), nil - }, - } + t.Run("timeout while opening stream", func(t *testing.T) { + t.Parallel() - err := Run(ctx, RunOptions{ - Model: model, - 
MaxSteps: 1, - ContextLimitFallback: 4096, - PersistStep: func(_ context.Context, step PersistedStep) error { - persisted = append([]fantasy.Content(nil), step.Content...) - return nil - }, - OnRetry: func( - _ int, - retryErr error, - classified chatretry.ClassifiedError, - _ time.Duration, - ) { - retryErrs <- retryErr - retries <- classified - }, - }) - require.NoError(t, err) - require.Equal(t, 2, attempts) - require.Len(t, retryErrs, 1) - require.Len(t, retries, 1) - retryErr := testutil.RequireReceive(ctx, t, retryErrs) - classified := testutil.RequireReceive(ctx, t, retries) - require.ErrorIs(t, retryErr, chaterror.ErrProviderTransportReset) - require.ErrorIs(t, retryErr, context.Canceled) - require.Equal(t, codersdk.ChatErrorKindTimeout, classified.Kind) - require.True(t, classified.Retryable) - require.Equal(t, "openai", classified.Provider) - require.Equal(t, "OpenAI is temporarily unavailable.", classified.Message) - - text := requireTextContent(t, persisted, "done") - require.Equal(t, "done", text.Text) - for _, block := range persisted { - if text, ok := fantasy.AsContentType[fantasy.TextContent](block); ok { - require.NotContains(t, text.Text, "partial") - } - } -} - -func TestRun_RetriesSilenceTimeoutBeforeFirstPart(t *testing.T) { - t.Parallel() - - const silenceTimeout = 5 * time.Millisecond - - ctx, cancel := context.WithTimeout( - context.Background(), - testutil.WaitShort, - ) - defer cancel() - - mClock := quartz.NewMock(t) - trap := mClock.Trap().AfterFunc(streamSilenceGuardTimerTag) - defer trap.Close() - - attempts := 0 - attemptCause := make(chan error, 1) - var retries []chatretry.ClassifiedError - model := &chattest.FakeModel{ - ProviderName: "openai", - StreamFn: func(ctx context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) { - attempts++ - if attempts == 1 { - return iter.Seq[fantasy.StreamPart](func(yield func(fantasy.StreamPart) bool) { + ctx := testutil.Context(t, testutil.WaitLong) + const silenceTimeout = 5 * 
time.Millisecond + clock := quartz.NewMock(t) + trap := clock.Trap().AfterFunc(streamSilenceGuardTimerTag) + defer trap.Close() + var calls atomic.Int32 + model := &chattest.FakeModel{ + ProviderName: "openai", + ModelName: "test-model", + StreamFn: func(ctx context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) { + if calls.Add(1) == 1 { <-ctx.Done() - attemptCause <- context.Cause(ctx) - _ = yield(fantasy.StreamPart{ - Type: fantasy.StreamPartTypeError, - Error: ctx.Err(), - }) - }), nil - } - return streamFromParts([]fantasy.StreamPart{{ - Type: fantasy.StreamPartTypeFinish, - FinishReason: fantasy.FinishReasonStop, - }}), nil - }, - } - - done := make(chan error, 1) - go func() { - done <- Run(context.Background(), RunOptions{ - Model: model, - MaxSteps: 1, - StreamSilenceTimeout: silenceTimeout, - Clock: mClock, - PersistStep: func(_ context.Context, _ PersistedStep) error { - return nil + return nil, ctx.Err() + } + return streamFromParts([]fantasy.StreamPart{{ + Type: fantasy.StreamPartTypeFinish, + FinishReason: fantasy.FinishReasonStop, + }}), nil }, - OnRetry: func( - _ int, - _ error, - classified chatretry.ClassifiedError, - _ time.Duration, - ) { - retries = append(retries, classified) + } + done := make(chan error, 1) + go func() { + _, err := GenerateAssistant(context.Background(), GenerateAssistantOptions{ + Model: model, + Clock: clock, + StreamSilenceTimeout: silenceTimeout, + }) + done <- err + }() + + trap.MustWait(ctx).MustRelease(ctx) + _, waiter := clock.AdvanceNext() + waiter.MustWait(ctx) + require.Error(t, <-done) + require.Equal(t, int32(1), calls.Load()) + }) + + t.Run("timeout before first part", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + const silenceTimeout = 5 * time.Millisecond + clock := quartz.NewMock(t) + trap := clock.Trap().AfterFunc(streamSilenceGuardTimerTag) + defer trap.Close() + var calls atomic.Int32 + model := &chattest.FakeModel{ + ProviderName: "openai", + 
ModelName: "test-model", + StreamFn: func(ctx context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) { + calls.Add(1) + return func(yield func(fantasy.StreamPart) bool) { + <-ctx.Done() + yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeError, Error: ctx.Err()}) + }, nil }, - }) - }() + } + done := make(chan error, 1) + go func() { + _, err := GenerateAssistant(context.Background(), GenerateAssistantOptions{ + Model: model, + Clock: clock, + StreamSilenceTimeout: silenceTimeout, + }) + done <- err + }() - trap.MustWait(ctx).MustRelease(ctx) - mClock.Advance(silenceTimeout).MustWait(ctx) - trap.MustWait(ctx).MustRelease(ctx) + trap.MustWait(ctx).MustRelease(ctx) + _, waiter := clock.AdvanceNext() + waiter.MustWait(ctx) + err := <-done + require.Error(t, err) + classified := chaterror.Classify(err) + require.Equal(t, codersdk.ChatErrorKindStreamSilenceTimeout, classified.Kind) + require.Equal(t, "openai", classified.Provider) + require.True(t, classified.Retryable) + require.Equal(t, int32(1), calls.Load()) + }) - require.NoError(t, awaitRunResult(ctx, t, done)) - require.Equal(t, 2, attempts) - require.Len(t, retries, 1) - require.Equal(t, codersdk.ChatErrorKindStreamSilenceTimeout, retries[0].Kind) - require.True(t, retries[0].Retryable) - require.Equal(t, "openai", retries[0].Provider) - require.Equal( - t, - "OpenAI did not send response data in time.", - retries[0].Message, - ) - select { - case cause := <-attemptCause: - require.ErrorIs(t, cause, errStreamSilenceTimeout) - case <-ctx.Done(): - t.Fatal("timed out waiting for silence timeout cause") - } -} + t.Run("first part disarms timeout", func(t *testing.T) { + t.Parallel() -func TestRun_StreamPartsResetSilenceTimeout(t *testing.T) { - t.Parallel() - - const silenceTimeout = 5 * time.Millisecond - - ctx, cancel := context.WithTimeout( - context.Background(), - testutil.WaitShort, - ) - defer cancel() - - mClock := quartz.NewMock(t) - armTrap := 
mClock.Trap().AfterFunc(streamSilenceGuardTimerTag) - defer armTrap.Close() - resetTrap := mClock.Trap().TimerReset(streamSilenceGuardTimerTag) - defer resetTrap.Close() - - attempts := 0 - retried := false - firstPartYielded := make(chan struct{}, 1) - secondPartYielded := make(chan struct{}, 1) - continueToSecond := make(chan struct{}) - continueToFinish := make(chan struct{}) - model := &chattest.FakeModel{ - ProviderName: "openai", - StreamFn: func(ctx context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) { - attempts++ - return iter.Seq[fantasy.StreamPart](func(yield func(fantasy.StreamPart) bool) { - if !yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeTextStart, ID: "text-1"}) { - return - } - select { - case firstPartYielded <- struct{}{}: - default: - } - - select { - case <-continueToSecond: - case <-ctx.Done(): - _ = yield(fantasy.StreamPart{ - Type: fantasy.StreamPartTypeError, - Error: ctx.Err(), - }) - return - } - - if !yield(fantasy.StreamPart{ - Type: fantasy.StreamPartTypeTextDelta, - ID: "text-1", - Delta: "done", - }) { - return - } - select { - case secondPartYielded <- struct{}{}: - default: - } - - select { - case <-continueToFinish: - case <-ctx.Done(): - _ = yield(fantasy.StreamPart{ - Type: fantasy.StreamPartTypeError, - Error: ctx.Err(), - }) - return - } - - parts := []fantasy.StreamPart{ - {Type: fantasy.StreamPartTypeTextEnd, ID: "text-1"}, - {Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop}, - } - for _, part := range parts { - if !yield(part) { - return - } - } - }), nil - }, - } - - done := make(chan error, 1) - go func() { - done <- Run(context.Background(), RunOptions{ - Model: model, - MaxSteps: 1, - StreamSilenceTimeout: silenceTimeout, - Clock: mClock, - PersistStep: func(_ context.Context, _ PersistedStep) error { - return nil - }, - OnRetry: func( - _ int, - _ error, - _ chatretry.ClassifiedError, - _ time.Duration, - ) { - retried = true - }, - }) - }() - - 
armTrap.MustWait(ctx).MustRelease(ctx) - resetTrap.MustWait(ctx).MustRelease(ctx) - select { - case <-firstPartYielded: - case <-ctx.Done(): - t.Fatal("timed out waiting for first stream part") - } - - mClock.Advance(silenceTimeout / 2).MustWait(ctx) - close(continueToSecond) - resetTrap.MustWait(ctx).MustRelease(ctx) - select { - case <-secondPartYielded: - case <-ctx.Done(): - t.Fatal("timed out waiting for second stream part") - } - - mClock.Advance(silenceTimeout / 2).MustWait(ctx) - close(continueToFinish) - resetTrap.MustWait(ctx).MustRelease(ctx) - resetTrap.MustWait(ctx).MustRelease(ctx) - - require.NoError(t, awaitRunResult(ctx, t, done)) - require.Equal(t, 1, attempts) - require.False(t, retried) -} - -func TestRun_RetriesSilenceTimeoutBetweenParts(t *testing.T) { - t.Parallel() - - const silenceTimeout = 5 * time.Millisecond - - ctx, cancel := context.WithTimeout( - context.Background(), - testutil.WaitLong, - ) - defer cancel() - - mClock := quartz.NewMock(t) - armTrap := mClock.Trap().AfterFunc(streamSilenceGuardTimerTag) - defer armTrap.Close() - resetTrap := mClock.Trap().TimerReset(streamSilenceGuardTimerTag) - defer resetTrap.Close() - - attempts := 0 - firstPartYielded := make(chan struct{}, 1) - attemptCause := make(chan error, 1) - var retries []chatretry.ClassifiedError - model := &chattest.FakeModel{ - ProviderName: "openai", - StreamFn: func(ctx context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) { - attempts++ - if attempts == 1 { - return iter.Seq[fantasy.StreamPart](func(yield func(fantasy.StreamPart) bool) { + ctx := testutil.Context(t, testutil.WaitLong) + const silenceTimeout = 5 * time.Millisecond + clock := quartz.NewMock(t) + trap := clock.Trap().AfterFunc(streamSilenceGuardTimerTag) + defer trap.Close() + var calls atomic.Int32 + continueStream := make(chan struct{}) + model := &chattest.FakeModel{ + ProviderName: "openai", + ModelName: "test-model", + StreamFn: func(ctx context.Context, _ fantasy.Call) 
(fantasy.StreamResponse, error) { + calls.Add(1) + return func(yield func(fantasy.StreamPart) bool) { if !yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeTextStart, ID: "text-1"}) { return } select { - case firstPartYielded <- struct{}{}: - default: + case <-continueStream: + case <-ctx.Done(): + yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeError, Error: ctx.Err()}) + return } + yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeTextDelta, ID: "text-1", Delta: "done"}) + yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeTextEnd, ID: "text-1"}) + yield(fantasy.StreamPart{Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop}) + }, nil + }, + } + done := make(chan error, 1) + go func() { + _, err := GenerateAssistant(context.Background(), GenerateAssistantOptions{ + Model: model, + Clock: clock, + StreamSilenceTimeout: silenceTimeout, + }) + done <- err + }() + trap.MustWait(ctx).MustRelease(ctx) + close(continueStream) + require.NoError(t, <-done) + require.Equal(t, int32(1), calls.Load()) + }) + + t.Run("silent stream close after timeout", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + const silenceTimeout = 5 * time.Millisecond + clock := quartz.NewMock(t) + trap := clock.Trap().AfterFunc(streamSilenceGuardTimerTag) + defer trap.Close() + var calls atomic.Int32 + model := &chattest.FakeModel{ + ProviderName: "openai", + ModelName: "test-model", + StreamFn: func(ctx context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) { + calls.Add(1) + return func(func(fantasy.StreamPart) bool) { <-ctx.Done() - attemptCause <- context.Cause(ctx) - _ = yield(fantasy.StreamPart{ - Type: fantasy.StreamPartTypeError, - Error: ctx.Err(), - }) - }), nil - } - return streamFromParts([]fantasy.StreamPart{{ - Type: fantasy.StreamPartTypeFinish, - FinishReason: fantasy.FinishReasonStop, - }}), nil - }, - } - - done := make(chan error, 1) - go func() { - done <- Run(context.Background(), RunOptions{ - 
Model: model, - MaxSteps: 1, - StreamSilenceTimeout: silenceTimeout, - Clock: mClock, - PersistStep: func(_ context.Context, _ PersistedStep) error { - return nil + }, nil }, - OnRetry: func( - _ int, - _ error, - classified chatretry.ClassifiedError, - _ time.Duration, - ) { - retries = append(retries, classified) - }, - }) - }() + } + done := make(chan error, 1) + go func() { + _, err := GenerateAssistant(context.Background(), GenerateAssistantOptions{ + Model: model, + Clock: clock, + StreamSilenceTimeout: silenceTimeout, + }) + done <- err + }() - armTrap.MustWait(ctx).MustRelease(ctx) - resetTrap.MustWait(ctx).MustRelease(ctx) - select { - case <-firstPartYielded: - case <-ctx.Done(): - t.Fatal("timed out waiting for first stream part") - } - - mClock.Advance(silenceTimeout).MustWait(ctx) - armTrap.MustWait(ctx).MustRelease(ctx) - resetTrap.MustWait(ctx).MustRelease(ctx) - - require.NoError(t, awaitRunResult(ctx, t, done)) - require.Equal(t, 2, attempts) - require.Len(t, retries, 1) - require.Equal(t, codersdk.ChatErrorKindStreamSilenceTimeout, retries[0].Kind) - require.True(t, retries[0].Retryable) - require.Equal(t, "openai", retries[0].Provider) - select { - case cause := <-attemptCause: - require.ErrorIs(t, cause, errStreamSilenceTimeout) - case <-ctx.Done(): - t.Fatal("timed out waiting for silence timeout cause") - } + trap.MustWait(ctx).MustRelease(ctx) + _, waiter := clock.AdvanceNext() + waiter.MustWait(ctx) + err := <-done + require.Error(t, err) + classified := chaterror.Classify(err) + require.Equal(t, codersdk.ChatErrorKindStreamSilenceTimeout, classified.Kind) + require.Equal(t, int32(1), calls.Load()) + }) } -func TestRun_PanicInPublishMessagePartReleasesAttempt(t *testing.T) { +func TestGenerateAssistant_PanicInPublishMessagePartReleasesAttempt(t *testing.T) { t.Parallel() attemptReleased := make(chan struct{}) model := &chattest.FakeModel{ ProviderName: "openai", + ModelName: "test-model", StreamFn: func(ctx context.Context, _ fantasy.Call) 
(fantasy.StreamResponse, error) { go func() { <-ctx.Done() @@ -1200,216 +400,14 @@ func TestRun_PanicInPublishMessagePartReleasesAttempt(t *testing.T) { } }() - _ = Run(context.Background(), RunOptions{ - Model: model, - MaxSteps: 1, - ContextLimitFallback: 4096, - PersistStep: func(_ context.Context, _ PersistedStep) error { - return nil - }, + _, _ = GenerateAssistant(context.Background(), GenerateAssistantOptions{ + Model: model, PublishMessagePart: func(codersdk.ChatMessageRole, codersdk.ChatMessagePart) { panic("publish panic") }, }) - t.Fatal("expected Run to panic") -} - -func TestRun_RetriesSilenceTimeoutWhenStreamStaysSilent(t *testing.T) { - t.Parallel() - - const silenceTimeout = 5 * time.Millisecond - - ctx, cancel := context.WithTimeout( - context.Background(), - testutil.WaitShort, - ) - defer cancel() - - mClock := quartz.NewMock(t) - trap := mClock.Trap().AfterFunc(streamSilenceGuardTimerTag) - defer trap.Close() - - attempts := 0 - attemptCause := make(chan error, 1) - var retries []chatretry.ClassifiedError - model := &chattest.FakeModel{ - ProviderName: "openai", - StreamFn: func(ctx context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) { - attempts++ - if attempts == 1 { - return iter.Seq[fantasy.StreamPart](func(yield func(fantasy.StreamPart) bool) { - <-ctx.Done() - attemptCause <- context.Cause(ctx) - }), nil - } - return streamFromParts([]fantasy.StreamPart{{ - Type: fantasy.StreamPartTypeFinish, - FinishReason: fantasy.FinishReasonStop, - }}), nil - }, - } - - done := make(chan error, 1) - go func() { - done <- Run(context.Background(), RunOptions{ - Model: model, - MaxSteps: 1, - StreamSilenceTimeout: silenceTimeout, - Clock: mClock, - PersistStep: func(_ context.Context, _ PersistedStep) error { - return nil - }, - OnRetry: func( - _ int, - _ error, - classified chatretry.ClassifiedError, - _ time.Duration, - ) { - retries = append(retries, classified) - }, - }) - }() - - trap.MustWait(ctx).MustRelease(ctx) - 
mClock.Advance(silenceTimeout).MustWait(ctx) - trap.MustWait(ctx).MustRelease(ctx) - - require.NoError(t, awaitRunResult(ctx, t, done)) - require.Equal(t, 2, attempts) - require.Len(t, retries, 1) - require.Equal(t, codersdk.ChatErrorKindStreamSilenceTimeout, retries[0].Kind) - require.True(t, retries[0].Retryable) - require.Equal(t, "openai", retries[0].Provider) - require.Equal( - t, - "OpenAI did not send response data in time.", - retries[0].Message, - ) - select { - case cause := <-attemptCause: - require.ErrorIs(t, cause, errStreamSilenceTimeout) - case <-ctx.Done(): - t.Fatal("timed out waiting for silence timeout cause") - } -} - -func TestRun_InterruptedStepPersistsSyntheticToolResult(t *testing.T) { - t.Parallel() - - started := make(chan struct{}) - model := &chattest.FakeModel{ - ProviderName: "fake", - StreamFn: func(ctx context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) { - return iter.Seq[fantasy.StreamPart](func(yield func(fantasy.StreamPart) bool) { - parts := []fantasy.StreamPart{ - { - Type: fantasy.StreamPartTypeToolInputStart, - ID: "interrupt-tool-1", - ToolCallName: "read_file", - }, - { - Type: fantasy.StreamPartTypeToolInputDelta, - ID: "interrupt-tool-1", - ToolCallName: "read_file", - Delta: `{"path":"main.go"`, - }, - {Type: fantasy.StreamPartTypeTextStart, ID: "text-1"}, - {Type: fantasy.StreamPartTypeTextDelta, ID: "text-1", Delta: "partial assistant output"}, - } - for _, part := range parts { - if !yield(part) { - return - } - } - - select { - case <-started: - default: - close(started) - } - - <-ctx.Done() - _ = yield(fantasy.StreamPart{ - Type: fantasy.StreamPartTypeError, - Error: ctx.Err(), - }) - }), nil - }, - } - - ctx, cancel := context.WithCancelCause(context.Background()) - defer cancel(nil) - - go func() { - <-started - cancel(ErrInterrupted) - }() - - persistedAssistantCtxErr := xerrors.New("unset") - var persistedContent []fantasy.Content - var persistedStep PersistedStep - - err := Run(ctx, RunOptions{ - 
Model: model, - Messages: []fantasy.Message{ - textMessage(fantasy.MessageRoleUser, "hello"), - }, - Tools: []fantasy.AgentTool{ - newNoopTool("read_file"), - }, - MaxSteps: 3, - PersistStep: func(persistCtx context.Context, step PersistedStep) error { - persistedAssistantCtxErr = persistCtx.Err() - persistedContent = append([]fantasy.Content(nil), step.Content...) - persistedStep = step - return nil - }, - }) - require.ErrorIs(t, err, ErrInterrupted) - require.NoError(t, persistedAssistantCtxErr) - - require.NotEmpty(t, persistedContent) - var ( - foundText bool - foundToolCall bool - foundToolResult bool - ) - for _, block := range persistedContent { - if text, ok := fantasy.AsContentType[fantasy.TextContent](block); ok { - if strings.Contains(text.Text, "partial assistant output") { - foundText = true - } - continue - } - if toolCall, ok := fantasy.AsContentType[fantasy.ToolCallContent](block); ok { - if toolCall.ToolCallID == "interrupt-tool-1" && - toolCall.ToolName == "read_file" && - strings.Contains(toolCall.Input, `"path":"main.go"`) { - foundToolCall = true - } - continue - } - if toolResult, ok := fantasy.AsContentType[fantasy.ToolResultContent](block); ok { - if toolResult.ToolCallID == "interrupt-tool-1" && - toolResult.ToolName == "read_file" { - _, isErr := toolResult.Result.(fantasy.ToolResultOutputContentError) - require.True(t, isErr, "interrupted tool result should be an error") - foundToolResult = true - } - } - } - require.True(t, foundText) - require.True(t, foundToolCall) - require.True(t, foundToolResult) - - // The interrupted tool was flushed mid-stream (never reached - // StreamPartTypeToolCall), so it has no call timestamp. - // But the synthetic error result must have a result timestamp. 
- require.Contains(t, persistedStep.ToolResultCreatedAt, "interrupt-tool-1", - "interrupted tool result must have a result timestamp") - require.NotContains(t, persistedStep.ToolCallCreatedAt, "interrupt-tool-1", - "interrupted tool should have no call timestamp (never reached StreamPartTypeToolCall)") + t.Fatal("expected GenerateAssistant to panic") } func requireToolResultErrorMessage( @@ -1435,16 +433,6 @@ func streamFromParts(parts []fantasy.StreamPart) fantasy.StreamResponse { }) } -func newNoopTool(name string) fantasy.AgentTool { - return fantasy.NewAgentTool( - name, - "test noop tool", - func(context.Context, struct{}, fantasy.ToolCall) (fantasy.ToolResponse, error) { - return fantasy.ToolResponse{}, nil - }, - ) -} - func textMessage(role fantasy.MessageRole, text string) fantasy.Message { return fantasy.Message{ Role: role, @@ -1454,71 +442,6 @@ func textMessage(role fantasy.MessageRole, text string) fantasy.Message { } } -func requireNoProviderExecutedToolCallContent(t *testing.T, content []fantasy.Content) { - t.Helper() - - for i, block := range content { - toolCall, ok := fantasy.AsContentType[fantasy.ToolCallContent](block) - if ok && toolCall.ProviderExecuted { - t.Fatalf("content[%d]: unexpected provider-executed call", i) - } - } -} - -func requireNoProviderExecutedToolResultContent(t *testing.T, content []fantasy.Content) { - t.Helper() - - for i, block := range content { - toolResult, ok := fantasy.AsContentType[fantasy.ToolResultContent](block) - if ok && toolResult.ProviderExecuted { - t.Fatalf("content[%d]: unexpected provider-executed result", i) - } - } -} - -func requireReasoningPrompt(t *testing.T, prompt []fantasy.Message) fantasy.ReasoningPart { - t.Helper() - - for _, message := range prompt { - for _, part := range message.Content { - reasoningPart, ok := fantasy.AsMessagePart[fantasy.ReasoningPart](part) - if ok { - return reasoningPart - } - } - } - t.Fatal("missing prompt reasoning") - return fantasy.ReasoningPart{} -} - -func 
requireTextPrompt(t *testing.T, prompt []fantasy.Message, text string) fantasy.TextPart { - t.Helper() - - for _, message := range prompt { - for _, part := range message.Content { - textPart, ok := fantasy.AsMessagePart[fantasy.TextPart](part) - if ok && textPart.Text == text { - return textPart - } - } - } - t.Fatalf("missing prompt text %q", text) - return fantasy.TextPart{} -} - -func requireNoProviderExecutedToolCallPrompt(t *testing.T, prompt []fantasy.Message) { - t.Helper() - - for i, message := range prompt { - for j, part := range message.Content { - toolCall, ok := fantasy.AsMessagePart[fantasy.ToolCallPart](part) - if ok && toolCall.ProviderExecuted { - t.Fatalf("prompt[%d].content[%d]: unexpected provider-executed call", i, j) - } - } - } -} - func requireTextContent(t *testing.T, content []fantasy.Content, text string) fantasy.TextContent { t.Helper() @@ -1532,634 +455,6 @@ func requireTextContent(t *testing.T, content []fantasy.Content, text string) fa return fantasy.TextContent{} } -func requireToolCallContent(t *testing.T, content []fantasy.Content, id, name string) fantasy.ToolCallContent { - t.Helper() - - for _, block := range content { - toolCall, ok := fantasy.AsContentType[fantasy.ToolCallContent](block) - if ok && toolCall.ToolCallID == id && toolCall.ToolName == name { - return toolCall - } - } - t.Fatalf("missing tool call %q", id) - return fantasy.ToolCallContent{} -} - -func requireToolResultContent(t *testing.T, content []fantasy.Content, id, name string) fantasy.ToolResultContent { - t.Helper() - - for _, block := range content { - toolResult, ok := fantasy.AsContentType[fantasy.ToolResultContent](block) - if ok && toolResult.ToolCallID == id && toolResult.ToolName == name { - return toolResult - } - } - t.Fatalf("missing tool result %q", id) - return fantasy.ToolResultContent{} -} - -func requireToolResultPrompt(t *testing.T, prompt []fantasy.Message, id string) fantasy.ToolResultPart { - t.Helper() - - for _, message := range prompt 
{ - for _, part := range message.Content { - toolResult, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part) - if ok && toolResult.ToolCallID == id { - return toolResult - } - } - } - t.Fatalf("missing prompt tool result %q", id) - return fantasy.ToolResultPart{} -} - -func requireNoProviderExecutedToolResultPrompt(t *testing.T, prompt []fantasy.Message) { - t.Helper() - - for i, message := range prompt { - for j, part := range message.Content { - toolResult, ok := safeToolResultPart(part) - if ok && toolResult.ProviderExecuted { - t.Fatalf("prompt[%d].content[%d]: unexpected provider-executed result", i, j) - } - } - } -} - -func requireProviderExecutedToolCallPrompt( - t *testing.T, - prompt []fantasy.Message, - id string, -) fantasy.ToolCallPart { - t.Helper() - - for _, message := range prompt { - for _, part := range message.Content { - toolCall, ok := safeToolCallPart(part) - if ok && toolCall.ProviderExecuted && toolCall.ToolCallID == id { - return toolCall - } - } - } - t.Fatalf("missing provider-executed prompt tool call %q", id) - return fantasy.ToolCallPart{} -} - -func requireProviderExecutedToolResultPrompt( - t *testing.T, - prompt []fantasy.Message, - id string, -) fantasy.ToolResultPart { - t.Helper() - - for _, message := range prompt { - for _, part := range message.Content { - toolResult, ok := safeToolResultPart(part) - if ok && toolResult.ProviderExecuted && toolResult.ToolCallID == id { - return toolResult - } - } - } - t.Fatalf("missing provider-executed prompt tool result %q", id) - return fantasy.ToolResultPart{} -} - -func requireAnthropicProviderToolPromptSafe(t *testing.T, prompt []fantasy.Message) { - t.Helper() - - require.Empty(t, chatsanitize.ValidateAnthropicProviderToolHistory(prompt)) -} - -func requireLogField(t *testing.T, entry slog.SinkEntry, name string) any { - t.Helper() - - for _, field := range entry.Fields { - if field.Name == name { - return field.Value - } - } - t.Fatalf("missing log field %q", name) - return nil 
-} - -func containsPromptSentinel(prompt []fantasy.Message) bool { - for _, message := range prompt { - if message.Role != fantasy.MessageRoleUser || len(message.Content) != 1 { - continue - } - textPart, ok := fantasy.AsMessagePart[fantasy.TextPart](message.Content[0]) - if !ok { - continue - } - if strings.HasPrefix(textPart.Text, "__chatd_agent_prompt_sentinel_") { - return true - } - } - return false -} - -func TestRun_MultiStepToolExecution(t *testing.T) { - t.Parallel() - - var mu sync.Mutex - var streamCalls int - var secondCallPrompt []fantasy.Message - - model := &chattest.FakeModel{ - ProviderName: "fake", - StreamFn: func(_ context.Context, call fantasy.Call) (fantasy.StreamResponse, error) { - mu.Lock() - step := streamCalls - streamCalls++ - mu.Unlock() - - switch step { - case 0: - // Step 0: produce a tool call. - return streamFromParts([]fantasy.StreamPart{ - {Type: fantasy.StreamPartTypeToolInputStart, ID: "tc-1", ToolCallName: "read_file"}, - {Type: fantasy.StreamPartTypeToolInputDelta, ID: "tc-1", Delta: `{"path":"main.go"}`}, - {Type: fantasy.StreamPartTypeToolInputEnd, ID: "tc-1"}, - { - Type: fantasy.StreamPartTypeToolCall, - ID: "tc-1", - ToolCallName: "read_file", - ToolCallInput: `{"path":"main.go"}`, - }, - {Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonToolCalls}, - }), nil - default: - // Step 1: capture the prompt the loop sent us, - // then return plain text. - mu.Lock() - secondCallPrompt = append([]fantasy.Message(nil), call.Prompt...) 
- mu.Unlock() - return streamFromParts([]fantasy.StreamPart{ - {Type: fantasy.StreamPartTypeTextStart, ID: "text-1"}, - {Type: fantasy.StreamPartTypeTextDelta, ID: "text-1", Delta: "all done"}, - {Type: fantasy.StreamPartTypeTextEnd, ID: "text-1"}, - {Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop}, - }), nil - } - }, - } - - var persistStepCalls int - var persistedSteps []PersistedStep - err := Run(context.Background(), RunOptions{ - Model: model, - Messages: []fantasy.Message{ - textMessage(fantasy.MessageRoleUser, "please read main.go"), - }, - Tools: []fantasy.AgentTool{ - newNoopTool("read_file"), - }, - MaxSteps: 5, - PersistStep: func(_ context.Context, step PersistedStep) error { - persistStepCalls++ - persistedSteps = append(persistedSteps, step) - return nil - }, - }) - require.NoError(t, err) - - // Stream was called twice: once for the tool-call step, - // once for the follow-up text step. - require.Equal(t, 2, streamCalls) - - // PersistStep is called once per step. - require.Equal(t, 2, persistStepCalls) - - // The second call's prompt must contain the assistant message - // from step 0 (with the tool call) and a tool-result message. 
- require.NotEmpty(t, secondCallPrompt) - - var foundAssistantToolCall bool - var foundToolResult bool - for _, msg := range secondCallPrompt { - if msg.Role == fantasy.MessageRoleAssistant { - for _, part := range msg.Content { - if tc, ok := fantasy.AsMessagePart[fantasy.ToolCallPart](part); ok { - if tc.ToolCallID == "tc-1" && tc.ToolName == "read_file" { - foundAssistantToolCall = true - } - } - } - } - if msg.Role == fantasy.MessageRoleTool { - for _, part := range msg.Content { - if tr, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part); ok { - if tr.ToolCallID == "tc-1" { - foundToolResult = true - } - } - } - } - } - require.True(t, foundAssistantToolCall, "second call prompt should contain assistant tool call from step 0") - require.True(t, foundToolResult, "second call prompt should contain tool result message") - - // The first persisted step (tool-call step) must carry - // accurate timestamps for duration computation. - require.Len(t, persistedSteps, 2) - toolStep := persistedSteps[0] - require.Contains(t, toolStep.ToolCallCreatedAt, "tc-1", - "tool-call step must record when the model emitted the call") - require.Contains(t, toolStep.ToolResultCreatedAt, "tc-1", - "tool-call step must record when the tool result was produced") - require.False(t, toolStep.ToolResultCreatedAt["tc-1"].Before(toolStep.ToolCallCreatedAt["tc-1"]), - "tool-result timestamp must be >= tool-call timestamp") -} - -func TestStopAfterTool_Success(t *testing.T) { - t.Parallel() - - streamCalls := 0 - model := &chattest.FakeModel{ - ProviderName: "fake", - StreamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) { - streamCalls++ - return streamFromParts([]fantasy.StreamPart{ - {Type: fantasy.StreamPartTypeToolInputStart, ID: "tc-plan", ToolCallName: "propose_plan"}, - {Type: fantasy.StreamPartTypeToolInputDelta, ID: "tc-plan", Delta: `{"path":"/tmp/plan.md"}`}, - {Type: fantasy.StreamPartTypeToolInputEnd, ID: "tc-plan"}, - { - Type: 
fantasy.StreamPartTypeToolCall, - ID: "tc-plan", - ToolCallName: "propose_plan", - ToolCallInput: `{"path":"/tmp/plan.md"}`, - }, - {Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonToolCalls}, - }), nil - }, - } - - proposePlanTool := fantasy.NewAgentTool( - "propose_plan", - "writes a plan", - func(context.Context, struct{}, fantasy.ToolCall) (fantasy.ToolResponse, error) { - return fantasy.NewTextResponse("plan saved"), nil - }, - ) - - var persistedSteps []PersistedStep - persistStepCalls := 0 - - err := Run(context.Background(), RunOptions{ - Model: model, - Messages: []fantasy.Message{ - textMessage(fantasy.MessageRoleUser, "propose a plan"), - }, - Tools: []fantasy.AgentTool{proposePlanTool}, - MaxSteps: 5, - StopAfterTools: map[string]struct{}{ - "propose_plan": {}, - }, - PersistStep: func(_ context.Context, step PersistedStep) error { - persistStepCalls++ - persistedSteps = append(persistedSteps, step) - return nil - }, - }) - require.ErrorIs(t, err, ErrStopAfterTool) - require.Equal(t, 1, streamCalls) - require.Equal(t, 1, persistStepCalls) - require.Len(t, persistedSteps, 1) - - var foundToolResult bool - for _, block := range persistedSteps[0].Content { - toolResult, ok := fantasy.AsContentType[fantasy.ToolResultContent](block) - if !ok || toolResult.ToolName != "propose_plan" { - continue - } - foundToolResult = true - _, isErr := toolResult.Result.(fantasy.ToolResultOutputContentError) - require.False(t, isErr, "stop-after-tool should only trigger on successful tool results") - } - require.True(t, foundToolResult, "persisted step should include the successful tool result before stopping") -} - -func TestStopAfterTool_IgnoresErrorResults(t *testing.T) { - t.Parallel() - - streamCalls := 0 - model := &chattest.FakeModel{ - ProviderName: "fake", - StreamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) { - streamCalls++ - if streamCalls == 1 { - return streamFromParts([]fantasy.StreamPart{ - {Type: 
fantasy.StreamPartTypeToolInputStart, ID: "tc-plan", ToolCallName: "propose_plan"}, - {Type: fantasy.StreamPartTypeToolInputDelta, ID: "tc-plan", Delta: `{"path":"/tmp/plan.md"}`}, - {Type: fantasy.StreamPartTypeToolInputEnd, ID: "tc-plan"}, - { - Type: fantasy.StreamPartTypeToolCall, - ID: "tc-plan", - ToolCallName: "propose_plan", - ToolCallInput: `{"path":"/tmp/plan.md"}`, - }, - {Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonToolCalls}, - }), nil - } - return streamFromParts([]fantasy.StreamPart{ - {Type: fantasy.StreamPartTypeTextStart, ID: "text-1"}, - {Type: fantasy.StreamPartTypeTextDelta, ID: "text-1", Delta: "tool failed, continue"}, - {Type: fantasy.StreamPartTypeTextEnd, ID: "text-1"}, - {Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop}, - }), nil - }, - } - - proposePlanTool := fantasy.NewAgentTool( - "propose_plan", - "writes a plan", - func(context.Context, struct{}, fantasy.ToolCall) (fantasy.ToolResponse, error) { - return fantasy.NewTextErrorResponse("plan failed"), nil - }, - ) - - var persistedSteps []PersistedStep - err := Run(context.Background(), RunOptions{ - Model: model, - Messages: []fantasy.Message{ - textMessage(fantasy.MessageRoleUser, "propose a plan"), - }, - Tools: []fantasy.AgentTool{proposePlanTool}, - MaxSteps: 5, - StopAfterTools: map[string]struct{}{ - "propose_plan": {}, - }, - PersistStep: func(_ context.Context, step PersistedStep) error { - persistedSteps = append(persistedSteps, step) - return nil - }, - }) - require.NoError(t, err) - require.Equal(t, 2, streamCalls) - require.Len(t, persistedSteps, 2) - - var foundToolError bool - for _, block := range persistedSteps[0].Content { - toolResult, ok := fantasy.AsContentType[fantasy.ToolResultContent](block) - if !ok || toolResult.ToolName != "propose_plan" { - continue - } - _, foundToolError = toolResult.Result.(fantasy.ToolResultOutputContentError) - } - require.True(t, foundToolError, "first step should persist the failed 
tool result") -} - -func TestRun_ParallelToolExecutionTimestamps(t *testing.T) { - t.Parallel() - - var mu sync.Mutex - var streamCalls int - - model := &chattest.FakeModel{ - ProviderName: "fake", - StreamFn: func(_ context.Context, call fantasy.Call) (fantasy.StreamResponse, error) { - mu.Lock() - step := streamCalls - streamCalls++ - mu.Unlock() - - _ = call - - switch step { - case 0: - // Step 0: produce two tool calls in one stream. - return streamFromParts([]fantasy.StreamPart{ - {Type: fantasy.StreamPartTypeToolInputStart, ID: "tc-1", ToolCallName: "read_file"}, - {Type: fantasy.StreamPartTypeToolInputDelta, ID: "tc-1", Delta: `{"path":"a.go"}`}, - {Type: fantasy.StreamPartTypeToolInputEnd, ID: "tc-1"}, - { - Type: fantasy.StreamPartTypeToolCall, - ID: "tc-1", - ToolCallName: "read_file", - ToolCallInput: `{"path":"a.go"}`, - }, - {Type: fantasy.StreamPartTypeToolInputStart, ID: "tc-2", ToolCallName: "write_file"}, - {Type: fantasy.StreamPartTypeToolInputDelta, ID: "tc-2", Delta: `{"path":"b.go"}`}, - {Type: fantasy.StreamPartTypeToolInputEnd, ID: "tc-2"}, - { - Type: fantasy.StreamPartTypeToolCall, - ID: "tc-2", - ToolCallName: "write_file", - ToolCallInput: `{"path":"b.go"}`, - }, - {Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonToolCalls}, - }), nil - default: - // Step 1: return plain text. 
- return streamFromParts([]fantasy.StreamPart{ - {Type: fantasy.StreamPartTypeTextStart, ID: "text-1"}, - {Type: fantasy.StreamPartTypeTextDelta, ID: "text-1", Delta: "all done"}, - {Type: fantasy.StreamPartTypeTextEnd, ID: "text-1"}, - {Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop}, - }), nil - } - }, - } - - var persistedSteps []PersistedStep - err := Run(context.Background(), RunOptions{ - Model: model, - Messages: []fantasy.Message{ - textMessage(fantasy.MessageRoleUser, "do both"), - }, - Tools: []fantasy.AgentTool{ - newNoopTool("read_file"), - newNoopTool("write_file"), - }, - MaxSteps: 5, - PersistStep: func(_ context.Context, step PersistedStep) error { - persistedSteps = append(persistedSteps, step) - return nil - }, - }) - require.NoError(t, err) - - // Two steps: tool-call step + text step. - require.Equal(t, 2, streamCalls) - require.Len(t, persistedSteps, 2) - - toolStep := persistedSteps[0] - - // Both tool-call IDs must appear in ToolCallCreatedAt. - require.Contains(t, toolStep.ToolCallCreatedAt, "tc-1", - "tool-call step must record when tc-1 was emitted") - require.Contains(t, toolStep.ToolCallCreatedAt, "tc-2", - "tool-call step must record when tc-2 was emitted") - - // Both tool-call IDs must appear in ToolResultCreatedAt. - require.Contains(t, toolStep.ToolResultCreatedAt, "tc-1", - "tool-call step must record when tc-1 result was produced") - require.Contains(t, toolStep.ToolResultCreatedAt, "tc-2", - "tool-call step must record when tc-2 result was produced") - - // Result timestamps must be >= call timestamps for both. 
- require.False(t, toolStep.ToolResultCreatedAt["tc-1"].Before(toolStep.ToolCallCreatedAt["tc-1"]), - "tc-1 tool-result timestamp must be >= tool-call timestamp") - require.False(t, toolStep.ToolResultCreatedAt["tc-2"].Before(toolStep.ToolCallCreatedAt["tc-2"]), - "tc-2 tool-result timestamp must be >= tool-call timestamp") -} - -// TestRun_ExclusiveToolPolicyViolation exercises the full Run() -> -// executeToolsForStep() -> applyExclusiveToolPolicy() wiring. When an -// exclusive tool is called alongside other locally-executable tools, -// neither runner must fire and every call in the batch must receive a -// synthesized policy error that is both persisted and published via -// SSE. This guards against a regression where -// executeToolsForStep's policy call is accidentally removed: the -// pure-unit tests cover the policy function in isolation, but only -// this test catches a broken wiring path. -func TestRun_ExclusiveToolPolicyViolation(t *testing.T) { - t.Parallel() - - var advisorRuns atomic.Int32 - advisorTool := fantasy.NewAgentTool( - "advisor", - "returns strategic guidance", - func(context.Context, struct{}, fantasy.ToolCall) (fantasy.ToolResponse, error) { - advisorRuns.Add(1) - return fantasy.NewTextResponse(`{"status":"ok"}`), nil - }, - ) - var readRuns atomic.Int32 - readTool := fantasy.NewAgentTool( - "read_file", - "reads a file", - func(context.Context, struct{}, fantasy.ToolCall) (fantasy.ToolResponse, error) { - readRuns.Add(1) - return fantasy.NewTextResponse(`{"contents":"main"}`), nil - }, - ) - - var mu sync.Mutex - var streamCalls int - model := &chattest.FakeModel{ - ProviderName: "fake", - StreamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) { - mu.Lock() - step := streamCalls - streamCalls++ - mu.Unlock() - - if step == 0 { - // Step 0: model emits an illegal mixed batch. 
- return streamFromParts([]fantasy.StreamPart{ - {Type: fantasy.StreamPartTypeToolInputStart, ID: "advisor-1", ToolCallName: "advisor"}, - {Type: fantasy.StreamPartTypeToolInputDelta, ID: "advisor-1", Delta: `{}`}, - {Type: fantasy.StreamPartTypeToolInputEnd, ID: "advisor-1"}, - { - Type: fantasy.StreamPartTypeToolCall, - ID: "advisor-1", - ToolCallName: "advisor", - ToolCallInput: `{}`, - }, - {Type: fantasy.StreamPartTypeToolInputStart, ID: "read-1", ToolCallName: "read_file"}, - {Type: fantasy.StreamPartTypeToolInputDelta, ID: "read-1", Delta: `{"path":"main.go"}`}, - {Type: fantasy.StreamPartTypeToolInputEnd, ID: "read-1"}, - { - Type: fantasy.StreamPartTypeToolCall, - ID: "read-1", - ToolCallName: "read_file", - ToolCallInput: `{"path":"main.go"}`, - }, - {Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonToolCalls}, - }), nil - } - // Step 1: the loop re-streams after tool results; end the run. - return streamFromParts([]fantasy.StreamPart{ - {Type: fantasy.StreamPartTypeTextStart, ID: "text-1"}, - {Type: fantasy.StreamPartTypeTextDelta, ID: "text-1", Delta: "ok, retrying"}, - {Type: fantasy.StreamPartTypeTextEnd, ID: "text-1"}, - {Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop}, - }), nil - }, - } - - var persistedSteps []PersistedStep - var publishedToolParts []codersdk.ChatMessagePart - err := Run(context.Background(), RunOptions{ - Model: model, - Messages: []fantasy.Message{ - textMessage(fantasy.MessageRoleUser, "please advise and read"), - }, - Tools: []fantasy.AgentTool{advisorTool, readTool}, - ExclusiveToolNames: map[string]bool{"advisor": true}, - MaxSteps: 5, - PersistStep: func(_ context.Context, step PersistedStep) error { - persistedSteps = append(persistedSteps, step) - return nil - }, - PublishMessagePart: func(role codersdk.ChatMessageRole, part codersdk.ChatMessagePart) { - if role != codersdk.ChatMessageRoleTool { - return - } - publishedToolParts = append(publishedToolParts, part) - }, - }) 
- require.NoError(t, err) - - // Neither runner must have fired: the policy short-circuits - // before partitioning and execution. - require.Equal(t, int32(0), advisorRuns.Load(), - "advisor runner must not fire on mixed batches") - require.Equal(t, int32(0), readRuns.Load(), - "read_file runner must not fire on mixed batches") - - // Two steps: the mixed-batch step plus the follow-up stream. - require.Len(t, persistedSteps, 2) - firstStep := persistedSteps[0] - - advisorErr, ok := findToolResultByID(firstStep.Content, "advisor-1") - require.True(t, ok, "persisted step must contain the advisor policy result") - requireToolResultErrorMessage(t, advisorErr, - "advisor must be called alone, without other tools in the same batch. Retry with only the advisor call.") - - readErr, ok := findToolResultByID(firstStep.Content, "read-1") - require.True(t, ok, "persisted step must contain the read_file policy result") - requireToolResultErrorMessage(t, readErr, - "this tool was skipped because advisor must run alone in its batch. Retry your tool calls without advisor, or call advisor separately first.") - - // Policy-error results must be SSE-published so the client - // can render them immediately. Confirm both tool-result parts - // reached PublishMessagePart with a non-nil CreatedAt, which - // is the dbtime.Now() stamp the policy branch sets. 
- var sawAdvisorPart, sawReadPart bool - for _, part := range publishedToolParts { - switch part.ToolCallID { - case "advisor-1": - sawAdvisorPart = true - require.NotNil(t, part.CreatedAt, - "policy result SSE part must carry the dbtime.Now() timestamp") - case "read-1": - sawReadPart = true - require.NotNil(t, part.CreatedAt, - "policy result SSE part must carry the dbtime.Now() timestamp") - } - } - require.True(t, sawAdvisorPart, "advisor policy result must be SSE-published") - require.True(t, sawReadPart, "read_file policy result must be SSE-published") -} - -func findToolResultByID( - content []fantasy.Content, - toolCallID string, -) (fantasy.ToolResultContent, bool) { - for _, block := range content { - tr, ok := fantasy.AsContentType[fantasy.ToolResultContent](block) - if !ok { - continue - } - if tr.ToolCallID == toolCallID { - return tr, true - } - } - return fantasy.ToolResultContent{}, false -} - func TestExclusiveToolPolicy_MixedBatchErrors(t *testing.T) { t.Parallel() @@ -2244,1191 +539,6 @@ func TestExclusiveToolPolicy_MultipleExclusive(t *testing.T) { ) } -// TestRun_ExclusiveToolPolicyBlocksMixedWithDynamicTool guards the -// exclusive-over-dynamic bypass: the policy must run before the -// built-in vs dynamic partition. If a future refactor moves the -// policy check beneath the partition (so only built-in calls are -// inspected), an exclusive builtin mixed with a dynamic tool would -// still execute locally while the dynamic call is handed off via -// ErrDynamicToolCall, breaking the planning-only contract. -// -// This test has the model emit an exclusive builtin (advisor) -// alongside a dynamic tool (mcp_tool) in the same batch and asserts -// that Run does NOT exit with ErrDynamicToolCall, the advisor -// runner never fires, and both calls receive a synthesized policy -// error. 
-func TestRun_ExclusiveToolPolicyBlocksMixedWithDynamicTool(t *testing.T) { - t.Parallel() - - var advisorRuns atomic.Int32 - advisorTool := fantasy.NewAgentTool( - "advisor", - "returns strategic guidance", - func(context.Context, struct{}, fantasy.ToolCall) (fantasy.ToolResponse, error) { - advisorRuns.Add(1) - return fantasy.NewTextResponse(`{"status":"ok"}`), nil - }, - ) - - var mu sync.Mutex - var streamCalls int - model := &chattest.FakeModel{ - ProviderName: "fake", - StreamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) { - mu.Lock() - step := streamCalls - streamCalls++ - mu.Unlock() - - if step == 0 { - // Step 0: model emits an illegal mixed batch - // combining an exclusive builtin with a - // dynamic tool. - return streamFromParts([]fantasy.StreamPart{ - {Type: fantasy.StreamPartTypeToolInputStart, ID: "advisor-1", ToolCallName: "advisor"}, - {Type: fantasy.StreamPartTypeToolInputDelta, ID: "advisor-1", Delta: `{}`}, - {Type: fantasy.StreamPartTypeToolInputEnd, ID: "advisor-1"}, - { - Type: fantasy.StreamPartTypeToolCall, - ID: "advisor-1", - ToolCallName: "advisor", - ToolCallInput: `{}`, - }, - {Type: fantasy.StreamPartTypeToolInputStart, ID: "mcp-1", ToolCallName: "mcp_tool"}, - {Type: fantasy.StreamPartTypeToolInputDelta, ID: "mcp-1", Delta: `{"q":"docs"}`}, - {Type: fantasy.StreamPartTypeToolInputEnd, ID: "mcp-1"}, - { - Type: fantasy.StreamPartTypeToolCall, - ID: "mcp-1", - ToolCallName: "mcp_tool", - ToolCallInput: `{"q":"docs"}`, - }, - {Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonToolCalls}, - }), nil - } - // Step 1: after the policy error is fed back, - // terminate the run so the test assertions have a - // deterministic exit. 
- return streamFromParts([]fantasy.StreamPart{ - {Type: fantasy.StreamPartTypeTextStart, ID: "text-1"}, - {Type: fantasy.StreamPartTypeTextDelta, ID: "text-1", Delta: "retrying"}, - {Type: fantasy.StreamPartTypeTextEnd, ID: "text-1"}, - {Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop}, - }), nil - }, - } - - var persistedSteps []PersistedStep - err := Run(context.Background(), RunOptions{ - Model: model, - Messages: []fantasy.Message{ - textMessage(fantasy.MessageRoleUser, "please advise and fetch"), - }, - Tools: []fantasy.AgentTool{advisorTool}, - DynamicToolNames: map[string]bool{"mcp_tool": true}, - ExclusiveToolNames: map[string]bool{"advisor": true}, - MaxSteps: 5, - PersistStep: func(_ context.Context, step PersistedStep) error { - persistedSteps = append(persistedSteps, step) - return nil - }, - }) - // Run must NOT exit with ErrDynamicToolCall: the policy - // short-circuits before the dynamic partition so the dynamic - // call is never handed off for external execution. - require.NoError(t, err) - - // The advisor runner must not fire on mixed batches; the - // policy blocks the whole batch including the exclusive tool - // itself. - require.Equal(t, int32(0), advisorRuns.Load(), - "advisor runner must not fire on mixed batches") - - // Two steps: the mixed-batch step with synthesized policy - // errors plus the follow-up stream that ends the run. - require.Len(t, persistedSteps, 2) - firstStep := persistedSteps[0] - - // The persisted step must not record the dynamic tool as - // pending: the policy-error path returns before - // persistPendingDynamicStep runs. 
- require.Empty(t, firstStep.PendingDynamicToolCalls, - "policy-rejected batches must not leak dynamic tool calls to the caller") - - advisorErr, ok := findToolResultByID(firstStep.Content, "advisor-1") - require.True(t, ok, "persisted step must contain the advisor policy result") - requireToolResultErrorMessage(t, advisorErr, - "advisor must be called alone, without other tools in the same batch. Retry with only the advisor call.") - - mcpErr, ok := findToolResultByID(firstStep.Content, "mcp-1") - require.True(t, ok, "persisted step must contain the mcp_tool policy result") - requireToolResultErrorMessage(t, mcpErr, - "this tool was skipped because advisor must run alone in its batch. Retry your tool calls without advisor, or call advisor separately first.") -} - -// TestRun_ExclusiveToolAloneSucceeds is the happy-path counterpart -// to TestRun_ExclusiveToolPolicyViolation: a single exclusive tool -// emitted alone must actually execute. The `len(toolCalls) <= 1` -// guard in firstExclusiveToolName is the sole mechanism that lets -// solo exclusive-tool calls proceed. If that guard regresses to -// `< 1`, every solo exclusive-tool call would enter an infinite -// policy-error/retry loop, and every unit test on the policy -// function in isolation would still pass. Only this Run()-level -// test catches that regression. 
-func TestRun_ExclusiveToolAloneSucceeds(t *testing.T) { - t.Parallel() - - var advisorRuns atomic.Int32 - advisorTool := fantasy.NewAgentTool( - "advisor", - "returns strategic guidance", - func(context.Context, struct{}, fantasy.ToolCall) (fantasy.ToolResponse, error) { - advisorRuns.Add(1) - return fantasy.NewTextResponse(`{"status":"ok"}`), nil - }, - ) - - var mu sync.Mutex - var streamCalls int - model := &chattest.FakeModel{ - ProviderName: "fake", - StreamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) { - mu.Lock() - step := streamCalls - streamCalls++ - mu.Unlock() - - if step == 0 { - // Step 0: model emits exactly one - // exclusive-tool call in isolation. - return streamFromParts([]fantasy.StreamPart{ - {Type: fantasy.StreamPartTypeToolInputStart, ID: "advisor-1", ToolCallName: "advisor"}, - {Type: fantasy.StreamPartTypeToolInputDelta, ID: "advisor-1", Delta: `{}`}, - {Type: fantasy.StreamPartTypeToolInputEnd, ID: "advisor-1"}, - { - Type: fantasy.StreamPartTypeToolCall, - ID: "advisor-1", - ToolCallName: "advisor", - ToolCallInput: `{}`, - }, - {Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonToolCalls}, - }), nil - } - // Step 1: the loop re-streams after the tool - // result; end the run. 
- return streamFromParts([]fantasy.StreamPart{ - {Type: fantasy.StreamPartTypeTextStart, ID: "text-1"}, - {Type: fantasy.StreamPartTypeTextDelta, ID: "text-1", Delta: "done"}, - {Type: fantasy.StreamPartTypeTextEnd, ID: "text-1"}, - {Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop}, - }), nil - }, - } - - var persistedSteps []PersistedStep - err := Run(context.Background(), RunOptions{ - Model: model, - Messages: []fantasy.Message{ - textMessage(fantasy.MessageRoleUser, "please advise"), - }, - Tools: []fantasy.AgentTool{advisorTool}, - ExclusiveToolNames: map[string]bool{"advisor": true}, - MaxSteps: 5, - PersistStep: func(_ context.Context, step PersistedStep) error { - persistedSteps = append(persistedSteps, step) - return nil - }, - }) - require.NoError(t, err) - - // The solo exclusive tool must actually execute exactly once. - require.Equal(t, int32(1), advisorRuns.Load(), - "solo exclusive-tool call must execute") - - // The first persisted step must contain a non-error tool - // result for the advisor call, proving the policy did not - // synthesize an error and the real runner fired. - require.GreaterOrEqual(t, len(persistedSteps), 1) - result, ok := findToolResultByID(persistedSteps[0].Content, "advisor-1") - require.True(t, ok, "persisted step must contain the advisor tool result") - _, isErr := result.Result.(fantasy.ToolResultOutputContentError) - require.Falsef(t, isErr, - "solo exclusive-tool call must produce a real tool result, not a policy error: %+v", result.Result) -} - -// TestRun_ExclusiveToolWithProviderExecutedSucceeds guards the -// interaction between the ProviderExecuted filter and the -// exclusive-tool policy. executeToolsForStep builds localCandidates -// by dropping ProviderExecuted calls before passing them to -// applyExclusiveToolPolicy. 
That filter is the sole mechanism -// preventing a false policy violation when a solo exclusive tool -// appears in a batch where the provider also server-executed a tool -// (for example Anthropic web_search). -// -// If the filter is removed, localCandidates would contain both the -// provider-executed call and the exclusive call. firstExclusiveToolName -// would then see len > 1, find advisor, and return a violation. The -// advisor would never run and the retry loop would burn steps until -// MaxSteps. -// -// This test emits an advisor call alongside a provider-executed -// web_search call (with its provider-emitted result) and asserts the -// advisor runner actually fires. -func TestRun_ExclusiveToolWithProviderExecutedSucceeds(t *testing.T) { - t.Parallel() - - var advisorRuns atomic.Int32 - advisorTool := fantasy.NewAgentTool( - "advisor", - "returns strategic guidance", - func(context.Context, struct{}, fantasy.ToolCall) (fantasy.ToolResponse, error) { - advisorRuns.Add(1) - return fantasy.NewTextResponse(`{"status":"ok"}`), nil - }, - ) - - var mu sync.Mutex - var streamCalls int - model := &chattest.FakeModel{ - ProviderName: "fake", - StreamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) { - mu.Lock() - step := streamCalls - streamCalls++ - mu.Unlock() - - if step == 0 { - // Step 0: provider server-executed web_search and - // returned its result inline, plus the model - // emitted an exclusive advisor call for local - // execution. The ProviderExecuted filter must - // drop web_search from the policy check so the - // advisor is treated as a solo exclusive call. 
- return streamFromParts([]fantasy.StreamPart{ - {Type: fantasy.StreamPartTypeToolInputStart, ID: "ws-1", ToolCallName: "web_search", ProviderExecuted: true}, - {Type: fantasy.StreamPartTypeToolInputDelta, ID: "ws-1", Delta: `{"query":"coder"}`, ProviderExecuted: true}, - {Type: fantasy.StreamPartTypeToolInputEnd, ID: "ws-1"}, - { - Type: fantasy.StreamPartTypeToolCall, - ID: "ws-1", - ToolCallName: "web_search", - ToolCallInput: `{"query":"coder"}`, - ProviderExecuted: true, - }, - { - Type: fantasy.StreamPartTypeToolResult, - ID: "ws-1", - ToolCallName: "web_search", - ProviderExecuted: true, - }, - {Type: fantasy.StreamPartTypeToolInputStart, ID: "advisor-1", ToolCallName: "advisor"}, - {Type: fantasy.StreamPartTypeToolInputDelta, ID: "advisor-1", Delta: `{}`}, - {Type: fantasy.StreamPartTypeToolInputEnd, ID: "advisor-1"}, - { - Type: fantasy.StreamPartTypeToolCall, - ID: "advisor-1", - ToolCallName: "advisor", - ToolCallInput: `{}`, - }, - {Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonToolCalls}, - }), nil - } - // Step 1: end the run after the advisor result is - // fed back. 
- return streamFromParts([]fantasy.StreamPart{ - {Type: fantasy.StreamPartTypeTextStart, ID: "text-1"}, - {Type: fantasy.StreamPartTypeTextDelta, ID: "text-1", Delta: "done"}, - {Type: fantasy.StreamPartTypeTextEnd, ID: "text-1"}, - {Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop}, - }), nil - }, - } - - var persistedSteps []PersistedStep - err := Run(context.Background(), RunOptions{ - Model: model, - Messages: []fantasy.Message{ - textMessage(fantasy.MessageRoleUser, "search and then advise"), - }, - Tools: []fantasy.AgentTool{advisorTool}, - ExclusiveToolNames: map[string]bool{"advisor": true}, - MaxSteps: 5, - PersistStep: func(_ context.Context, step PersistedStep) error { - persistedSteps = append(persistedSteps, step) - return nil - }, - }) - require.NoError(t, err) - - // The advisor must execute exactly once: the ProviderExecuted - // filter removes web_search from the exclusivity check, so the - // advisor is treated as a solo exclusive call. - require.Equal(t, int32(1), advisorRuns.Load(), - "advisor must execute when the only other call in the batch was provider-executed") - - // The advisor result must be a real tool result, not a - // synthesized policy error. 
- require.GreaterOrEqual(t, len(persistedSteps), 1) - advisorResult, ok := findToolResultByID(persistedSteps[0].Content, "advisor-1") - require.True(t, ok, "persisted step must contain the advisor tool result") - _, isErr := advisorResult.Result.(fantasy.ToolResultOutputContentError) - require.Falsef(t, isErr, - "advisor must produce a real tool result, not a policy error: %+v", advisorResult.Result) -} - -func TestRun_PersistStepErrorPropagates(t *testing.T) { - t.Parallel() - - model := &chattest.FakeModel{ - ProviderName: "fake", - StreamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) { - return streamFromParts([]fantasy.StreamPart{ - {Type: fantasy.StreamPartTypeTextStart, ID: "text-1"}, - {Type: fantasy.StreamPartTypeTextDelta, ID: "text-1", Delta: "hello"}, - {Type: fantasy.StreamPartTypeTextEnd, ID: "text-1"}, - {Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop}, - }), nil - }, - } - - persistErr := xerrors.New("database write failed") - err := Run(context.Background(), RunOptions{ - Model: model, - Messages: []fantasy.Message{ - textMessage(fantasy.MessageRoleUser, "hello"), - }, - MaxSteps: 1, - PersistStep: func(_ context.Context, _ PersistedStep) error { - return persistErr - }, - }) - require.Error(t, err) - require.ErrorContains(t, err, "database write failed") -} - -// TestRun_ShutdownDuringToolExecutionReturnsContextCanceled verifies that -// when the parent context is canceled (simulating server shutdown) while -// a tool is blocked, Run returns context.Canceled, not ErrInterrupted. -// This matters because the caller uses the error type to decide whether -// to set chat status to "pending" (retryable on another worker) vs -// "waiting" (stuck forever). -func TestRun_ShutdownDuringToolExecutionReturnsContextCanceled(t *testing.T) { - t.Parallel() - - toolStarted := make(chan struct{}) - - // Model returns a single tool call, then finishes. 
- model := &chattest.FakeModel{ - ProviderName: "fake", - StreamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) { - return streamFromParts([]fantasy.StreamPart{ - {Type: fantasy.StreamPartTypeToolInputStart, ID: "tc-block", ToolCallName: "blocking_tool"}, - {Type: fantasy.StreamPartTypeToolInputDelta, ID: "tc-block", Delta: `{}`}, - {Type: fantasy.StreamPartTypeToolInputEnd, ID: "tc-block"}, - { - Type: fantasy.StreamPartTypeToolCall, - ID: "tc-block", - ToolCallName: "blocking_tool", - ToolCallInput: `{}`, - }, - {Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonToolCalls}, - }), nil - }, - } - - // Tool that blocks until its context is canceled, simulating - // a long-running operation like wait_agent. - blockingTool := fantasy.NewAgentTool( - "blocking_tool", - "blocks until context canceled", - func(ctx context.Context, _ struct{}, _ fantasy.ToolCall) (fantasy.ToolResponse, error) { - close(toolStarted) - <-ctx.Done() - return fantasy.ToolResponse{}, ctx.Err() - }, - ) - - // Simulate the server context (parent) and chat context - // (child). Canceling the parent simulates graceful shutdown. - serverCtx, serverCancel := context.WithCancel(context.Background()) - defer serverCancel() - - serverCancelDone := make(chan struct{}) - go func() { - defer close(serverCancelDone) - <-toolStarted - t.Logf("tool started, canceling server context to simulate shutdown") - serverCancel() - }() - - // persistStep mirrors the FIXED chatd.go code: it only returns - // ErrInterrupted when the context was actually canceled due to - // an interruption (cause is ErrInterrupted). For shutdown - // (plain context.Canceled), it returns the original error so - // callers can distinguish the two. 
- persistStep := func(persistCtx context.Context, _ PersistedStep) error { - if persistCtx.Err() != nil { - if errors.Is(context.Cause(persistCtx), ErrInterrupted) { - return ErrInterrupted - } - return persistCtx.Err() - } - return nil - } - - err := Run(serverCtx, RunOptions{ - Model: model, - Messages: []fantasy.Message{ - textMessage(fantasy.MessageRoleUser, "run the blocking tool"), - }, - Tools: []fantasy.AgentTool{blockingTool}, - MaxSteps: 3, - PersistStep: persistStep, - }) - // Wait for the cancel goroutine to finish to aid flake - // diagnosis if the test ever hangs. - <-serverCancelDone - - require.Error(t, err) - // The error must NOT be ErrInterrupted, it should propagate - // as context.Canceled so the caller can distinguish shutdown - // from user interruption. Use assert (not require) so both - // checks are evaluated even if the first fails. - assert.NotErrorIs(t, err, ErrInterrupted, "shutdown cancellation must not be converted to ErrInterrupted") - assert.ErrorIs(t, err, context.Canceled, "shutdown should propagate as context.Canceled") -} - -func TestToResponseMessages_ProviderExecutedToolResultInAssistantMessage(t *testing.T) { - t.Parallel() - - sr := stepResult{ - content: []fantasy.Content{ - // Provider-executed tool call (e.g. web_search). - fantasy.ToolCallContent{ - ToolCallID: "provider-tc-1", - ToolName: "web_search", - Input: `{"query":"coder"}`, - ProviderExecuted: true, - }, - // Provider-executed tool result, must stay in - // assistant message. - fantasy.ToolResultContent{ - ToolCallID: "provider-tc-1", - ToolName: "web_search", - ProviderExecuted: true, - ProviderMetadata: fantasy.ProviderMetadata{"anthropic": nil}, - }, - // Local tool call (e.g. read_file). - fantasy.ToolCallContent{ - ToolCallID: "local-tc-1", - ToolName: "read_file", - Input: `{"path":"main.go"}`, - ProviderExecuted: false, - }, - // Local tool result, should go into tool message. 
- fantasy.ToolResultContent{ - ToolCallID: "local-tc-1", - ToolName: "read_file", - Result: fantasy.ToolResultOutputContentText{Text: "some result"}, - ProviderExecuted: false, - }, - }, - } - - msgs := sr.toResponseMessages() - require.Len(t, msgs, 2, "expected assistant + tool messages") - - // First message: assistant role. - assistantMsg := msgs[0] - assert.Equal(t, fantasy.MessageRoleAssistant, assistantMsg.Role) - require.Len(t, assistantMsg.Content, 3, - "assistant message should have provider ToolCallPart, provider ToolResultPart, and local ToolCallPart") - - // Part 0: provider tool call. - providerTC, ok := fantasy.AsMessagePart[fantasy.ToolCallPart](assistantMsg.Content[0]) - require.True(t, ok, "part 0 should be ToolCallPart") - assert.Equal(t, "provider-tc-1", providerTC.ToolCallID) - assert.True(t, providerTC.ProviderExecuted) - - // Part 1: provider tool result (inline in assistant turn). - providerTR, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](assistantMsg.Content[1]) - require.True(t, ok, "part 1 should be ToolResultPart") - assert.Equal(t, "provider-tc-1", providerTR.ToolCallID) - assert.True(t, providerTR.ProviderExecuted) - - // Part 2: local tool call. - localTC, ok := fantasy.AsMessagePart[fantasy.ToolCallPart](assistantMsg.Content[2]) - require.True(t, ok, "part 2 should be ToolCallPart") - assert.Equal(t, "local-tc-1", localTC.ToolCallID) - assert.False(t, localTC.ProviderExecuted) - - // Second message: tool role. 
- toolMsg := msgs[1] - assert.Equal(t, fantasy.MessageRoleTool, toolMsg.Role) - require.Len(t, toolMsg.Content, 1, - "tool message should have only the local ToolResultPart") - - localTR, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](toolMsg.Content[0]) - require.True(t, ok, "tool part should be ToolResultPart") - assert.Equal(t, "local-tc-1", localTR.ToolCallID) - assert.False(t, localTR.ProviderExecuted) -} - -func TestToResponseMessages_FiltersEmptyTextAndReasoningParts(t *testing.T) { - t.Parallel() - - sr := stepResult{ - content: []fantasy.Content{ - // Empty text, should be filtered. - fantasy.TextContent{Text: ""}, - // Whitespace-only text, should be filtered. - fantasy.TextContent{Text: " \t\n"}, - // Empty reasoning, should be filtered. - fantasy.ReasoningContent{Text: ""}, - // Whitespace-only reasoning, should be filtered. - fantasy.ReasoningContent{Text: " \n"}, - // Non-empty text, should pass through. - fantasy.TextContent{Text: "hello world"}, - // Leading/trailing whitespace with content, kept - // with the original value (not trimmed). - fantasy.TextContent{Text: " hello "}, - // Non-empty reasoning, should pass through. - fantasy.ReasoningContent{Text: "let me think"}, - // Tool call, should be unaffected by filtering. - fantasy.ToolCallContent{ - ToolCallID: "tc-1", - ToolName: "read_file", - Input: `{"path":"main.go"}`, - }, - // Local tool result, should be unaffected by filtering. - fantasy.ToolResultContent{ - ToolCallID: "tc-1", - ToolName: "read_file", - Result: fantasy.ToolResultOutputContentText{Text: "file contents"}, - }, - }, - } - - msgs := sr.toResponseMessages() - require.Len(t, msgs, 2, "expected assistant + tool messages") - - // First message: assistant role with non-empty text, reasoning, - // and the tool call. The four empty/whitespace-only parts must - // have been dropped. 
- assistantMsg := msgs[0] - assert.Equal(t, fantasy.MessageRoleAssistant, assistantMsg.Role) - require.Len(t, assistantMsg.Content, 4, - "assistant message should have 2x TextPart, ReasoningPart, and ToolCallPart") - - // Part 0: non-empty text. - textPart, ok := fantasy.AsMessagePart[fantasy.TextPart](assistantMsg.Content[0]) - require.True(t, ok, "part 0 should be TextPart") - assert.Equal(t, "hello world", textPart.Text) - - // Part 1: padded text, original whitespace preserved. - paddedPart, ok := fantasy.AsMessagePart[fantasy.TextPart](assistantMsg.Content[1]) - require.True(t, ok, "part 1 should be TextPart") - assert.Equal(t, " hello ", paddedPart.Text) - - // Part 2: non-empty reasoning. - reasoningPart, ok := fantasy.AsMessagePart[fantasy.ReasoningPart](assistantMsg.Content[2]) - require.True(t, ok, "part 2 should be ReasoningPart") - assert.Equal(t, "let me think", reasoningPart.Text) - - // Part 3: tool call (unaffected by text/reasoning filtering). - toolCallPart, ok := fantasy.AsMessagePart[fantasy.ToolCallPart](assistantMsg.Content[3]) - require.True(t, ok, "part 3 should be ToolCallPart") - assert.Equal(t, "tc-1", toolCallPart.ToolCallID) - assert.Equal(t, "read_file", toolCallPart.ToolName) - - // Second message: tool role with the local tool result. 
- toolMsg := msgs[1] - assert.Equal(t, fantasy.MessageRoleTool, toolMsg.Role) - require.Len(t, toolMsg.Content, 1, - "tool message should have only the local ToolResultPart") - - toolResultPart, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](toolMsg.Content[0]) - require.True(t, ok, "tool part should be ToolResultPart") - assert.Equal(t, "tc-1", toolResultPart.ToolCallID) -} - -func hasAnthropicEphemeralCacheControl(message fantasy.Message) bool { - if len(message.ProviderOptions) == 0 { - return false - } - - options, ok := message.ProviderOptions[fantasyanthropic.Name] - if !ok { - return false - } - - cacheOptions, ok := options.(*fantasyanthropic.ProviderCacheControlOptions) - return ok && cacheOptions.CacheControl.Type == "ephemeral" -} - -// TestRun_InterruptedDuringToolExecutionPersistsStep verifies that when -// tools are executing and the chat is interrupted, the accumulated step -// content (assistant blocks + tool results) is persisted via the -// interrupt-safe path rather than being lost. -func TestRun_InterruptedDuringToolExecutionPersistsStep(t *testing.T) { - t.Parallel() - - toolStarted := make(chan struct{}) - - // Model returns a completed tool call in the stream. 
- model := &chattest.FakeModel{ - ProviderName: "fake", - StreamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) { - return streamFromParts([]fantasy.StreamPart{ - {Type: fantasy.StreamPartTypeTextStart, ID: "text-1"}, - {Type: fantasy.StreamPartTypeTextDelta, ID: "text-1", Delta: "calling tool"}, - {Type: fantasy.StreamPartTypeTextEnd, ID: "text-1"}, - {Type: fantasy.StreamPartTypeReasoningStart, ID: "reason-1"}, - {Type: fantasy.StreamPartTypeReasoningDelta, ID: "reason-1", Delta: "let me think"}, - {Type: fantasy.StreamPartTypeReasoningEnd, ID: "reason-1"}, - {Type: fantasy.StreamPartTypeToolInputStart, ID: "tc-1", ToolCallName: "slow_tool"}, - {Type: fantasy.StreamPartTypeToolInputDelta, ID: "tc-1", Delta: `{"key":"value"}`}, - {Type: fantasy.StreamPartTypeToolInputEnd, ID: "tc-1"}, - { - Type: fantasy.StreamPartTypeToolCall, - ID: "tc-1", - ToolCallName: "slow_tool", - ToolCallInput: `{"key":"value"}`, - }, - {Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonToolCalls}, - }), nil - }, - } - - // Tool that blocks until context is canceled, simulating - // a long-running operation interrupted by the user. 
- slowTool := fantasy.NewAgentTool( - "slow_tool", - "blocks until canceled", - func(ctx context.Context, _ struct{}, _ fantasy.ToolCall) (fantasy.ToolResponse, error) { - close(toolStarted) - <-ctx.Done() - return fantasy.ToolResponse{}, ctx.Err() - }, - ) - - ctx, cancel := context.WithCancelCause(context.Background()) - defer cancel(nil) - - go func() { - <-toolStarted - cancel(ErrInterrupted) - }() - - var persistedContent []fantasy.Content - persistedCtxErr := xerrors.New("unset") - - err := Run(ctx, RunOptions{ - Model: model, - Messages: []fantasy.Message{ - textMessage(fantasy.MessageRoleUser, "run the slow tool"), - }, - Tools: []fantasy.AgentTool{slowTool}, - MaxSteps: 3, - PersistStep: func(persistCtx context.Context, step PersistedStep) error { - persistedCtxErr = persistCtx.Err() - persistedContent = append([]fantasy.Content(nil), step.Content...) - return nil - }, - }) - require.ErrorIs(t, err, ErrInterrupted) - // persistInterruptedStep uses context.WithoutCancel, so the - // persist callback should see a non-canceled context. 
- require.NoError(t, persistedCtxErr) - require.NotEmpty(t, persistedContent) - - var ( - foundText bool - foundReasoning bool - foundToolCall bool - foundToolResult bool - ) - for _, block := range persistedContent { - if text, ok := fantasy.AsContentType[fantasy.TextContent](block); ok { - if strings.Contains(text.Text, "calling tool") { - foundText = true - } - continue - } - if reasoning, ok := fantasy.AsContentType[fantasy.ReasoningContent](block); ok { - if strings.Contains(reasoning.Text, "let me think") { - foundReasoning = true - } - continue - } - if toolCall, ok := fantasy.AsContentType[fantasy.ToolCallContent](block); ok { - if toolCall.ToolCallID == "tc-1" && toolCall.ToolName == "slow_tool" { - foundToolCall = true - } - continue - } - if toolResult, ok := fantasy.AsContentType[fantasy.ToolResultContent](block); ok { - if toolResult.ToolCallID == "tc-1" { - foundToolResult = true - } - } - } - require.True(t, foundText, "persisted content should include text from the stream") - require.True(t, foundReasoning, "persisted content should include reasoning from the stream") - require.True(t, foundToolCall, "persisted content should include the tool call") - require.True(t, foundToolResult, "persisted content should include the tool result (error from cancellation)") -} - -// TestRun_ProviderExecutedToolResultTimestamps verifies that -// provider-executed tool results (e.g. web search) have their -// timestamps recorded in PersistedStep.ToolResultCreatedAt so -// the persistence layer can stamp CreatedAt on the parts. -func TestRun_ProviderExecutedToolResultTimestamps(t *testing.T) { - t.Parallel() - - model := &chattest.FakeModel{ - ProviderName: "fake", - StreamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) { - // Simulate a provider-executed tool call and result - // (e.g. Anthropic web search) followed by a text - // response, all in a single stream. 
- return streamFromParts([]fantasy.StreamPart{ - {Type: fantasy.StreamPartTypeToolInputStart, ID: "ws-1", ToolCallName: "web_search", ProviderExecuted: true}, - {Type: fantasy.StreamPartTypeToolInputDelta, ID: "ws-1", Delta: `{"query":"coder"}`, ProviderExecuted: true}, - {Type: fantasy.StreamPartTypeToolInputEnd, ID: "ws-1"}, - { - Type: fantasy.StreamPartTypeToolCall, - ID: "ws-1", - ToolCallName: "web_search", - ToolCallInput: `{"query":"coder"}`, - ProviderExecuted: true, - }, - // Provider-executed tool result, emitted by - // the provider, not our tool runner. - { - Type: fantasy.StreamPartTypeToolResult, - ID: "ws-1", - ToolCallName: "web_search", - ProviderExecuted: true, - }, - {Type: fantasy.StreamPartTypeTextStart, ID: "text-1"}, - {Type: fantasy.StreamPartTypeTextDelta, ID: "text-1", Delta: "search done"}, - {Type: fantasy.StreamPartTypeTextEnd, ID: "text-1"}, - {Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop}, - }), nil - }, - } - - var persistedSteps []PersistedStep - err := Run(context.Background(), RunOptions{ - Model: model, - Messages: []fantasy.Message{ - textMessage(fantasy.MessageRoleUser, "search for coder"), - }, - MaxSteps: 1, - PersistStep: func(_ context.Context, step PersistedStep) error { - persistedSteps = append(persistedSteps, step) - return nil - }, - }) - require.NoError(t, err) - require.Len(t, persistedSteps, 1) - - step := persistedSteps[0] - - // Provider-executed tool call should have a call timestamp. - require.Contains(t, step.ToolCallCreatedAt, "ws-1", - "provider-executed tool call must record its timestamp") - - // Provider-executed tool result should have a result - // timestamp so the frontend can compute duration. 
- require.Contains(t, step.ToolResultCreatedAt, "ws-1", - "provider-executed tool result must record its timestamp") - - require.False(t, - step.ToolResultCreatedAt["ws-1"].Before(step.ToolCallCreatedAt["ws-1"]), - "tool-result timestamp must be >= tool-call timestamp") -} - -func TestRun_AnthropicDropsUnpairedProviderToolBeforePersist(t *testing.T) { - t.Parallel() - - testCases := []struct { - name string - toolName string - toolInput string - }{ - { - name: "web_search", - toolName: "web_search", - toolInput: `{"query":"coder"}`, - }, - { - name: "code_execution", - toolName: "code_execution", - toolInput: `{"code":"print(1)"}`, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - model := &chattest.FakeModel{ - ProviderName: fantasyanthropic.Name, - StreamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) { - return streamFromParts([]fantasy.StreamPart{ - {Type: fantasy.StreamPartTypeToolInputStart, ID: "pt-1", ToolCallName: tc.toolName, ProviderExecuted: true}, - {Type: fantasy.StreamPartTypeToolInputDelta, ID: "pt-1", Delta: tc.toolInput, ProviderExecuted: true}, - {Type: fantasy.StreamPartTypeToolInputEnd, ID: "pt-1"}, - { - Type: fantasy.StreamPartTypeToolCall, - ID: "pt-1", - ToolCallName: tc.toolName, - ToolCallInput: tc.toolInput, - ProviderExecuted: true, - }, - {Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop}, - }), nil - }, - } - - persistCalls := 0 - err := Run(context.Background(), RunOptions{ - Model: model, - Messages: []fantasy.Message{ - textMessage(fantasy.MessageRoleUser, "run provider tool"), - }, - MaxSteps: 1, - PersistStep: func(_ context.Context, _ PersistedStep) error { - persistCalls++ - return nil - }, - }) - require.NoError(t, err) - require.Equal(t, 0, persistCalls) - }) - } -} - -func TestRun_AnthropicKeepsPairedWebSearchBeforePersist(t *testing.T) { - t.Parallel() - - model := &chattest.FakeModel{ - ProviderName: 
fantasyanthropic.Name, - StreamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) { - return streamFromParts([]fantasy.StreamPart{ - {Type: fantasy.StreamPartTypeToolInputStart, ID: "ws-1", ToolCallName: "web_search", ProviderExecuted: true}, - {Type: fantasy.StreamPartTypeToolInputDelta, ID: "ws-1", Delta: `{"query":"coder"}`, ProviderExecuted: true}, - {Type: fantasy.StreamPartTypeToolInputEnd, ID: "ws-1"}, - { - Type: fantasy.StreamPartTypeToolCall, - ID: "ws-1", - ToolCallName: "web_search", - ToolCallInput: `{"query":"coder"}`, - ProviderExecuted: true, - }, - { - Type: fantasy.StreamPartTypeToolResult, - ID: "ws-1", - ToolCallName: "web_search", - ProviderExecuted: true, - ProviderMetadata: validWebSearchProviderMetadataForTest(), - }, - {Type: fantasy.StreamPartTypeTextStart, ID: "text-1"}, - {Type: fantasy.StreamPartTypeTextDelta, ID: "text-1", Delta: "search done"}, - {Type: fantasy.StreamPartTypeTextEnd, ID: "text-1"}, - {Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop}, - }), nil - }, - } - - var persistedSteps []PersistedStep - err := Run(context.Background(), RunOptions{ - Model: model, - Messages: []fantasy.Message{ - textMessage(fantasy.MessageRoleUser, "search for coder"), - }, - MaxSteps: 1, - PersistStep: func(_ context.Context, step PersistedStep) error { - persistedSteps = append(persistedSteps, step) - return nil - }, - }) - require.NoError(t, err) - require.Len(t, persistedSteps, 1) - - toolCall := requireToolCallContent(t, persistedSteps[0].Content, "ws-1", "web_search") - require.True(t, toolCall.ProviderExecuted) - toolResult := requireToolResultContent(t, persistedSteps[0].Content, "ws-1", "web_search") - require.True(t, toolResult.ProviderExecuted) - requireTextContent(t, persistedSteps[0].Content, "search done") -} - -func TestRun_AnthropicInterruptedWebSearchDoesNotPersistSyntheticResult(t *testing.T) { - t.Parallel() - - started := make(chan struct{}) - model := &chattest.FakeModel{ - 
ProviderName: fantasyanthropic.Name, - StreamFn: func(ctx context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) { - return iter.Seq[fantasy.StreamPart](func(yield func(fantasy.StreamPart) bool) { - if !yield(fantasy.StreamPart{ - Type: fantasy.StreamPartTypeToolInputStart, - ID: "ws-1", - ToolCallName: "web_search", - ProviderExecuted: true, - }) { - return - } - if !yield(fantasy.StreamPart{ - Type: fantasy.StreamPartTypeToolInputDelta, - ID: "ws-1", - Delta: `{"query":"coder"}`, - ProviderExecuted: true, - }) { - return - } - close(started) - <-ctx.Done() - _ = yield(fantasy.StreamPart{ - Type: fantasy.StreamPartTypeError, - Error: ctx.Err(), - }) - }), nil - }, - } - - ctx, cancel := context.WithCancelCause(context.Background()) - defer cancel(nil) - go func() { - <-started - cancel(ErrInterrupted) - }() - - persistCalls := 0 - err := Run(ctx, RunOptions{ - Model: model, - Messages: []fantasy.Message{ - textMessage(fantasy.MessageRoleUser, "search for coder"), - }, - MaxSteps: 1, - PersistStep: func(_ context.Context, _ PersistedStep) error { - persistCalls++ - return nil - }, - }) - require.ErrorIs(t, err, ErrInterrupted) - require.Equal(t, 0, persistCalls) -} - -func TestRun_AnthropicInterruptedProviderToolKeepsLocalSyntheticResult(t *testing.T) { - t.Parallel() - - started := make(chan struct{}) - model := &chattest.FakeModel{ - ProviderName: fantasyanthropic.Name, - StreamFn: func(ctx context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) { - return iter.Seq[fantasy.StreamPart](func(yield func(fantasy.StreamPart) bool) { - if !yield(fantasy.StreamPart{ - Type: fantasy.StreamPartTypeToolInputStart, - ID: "ws-1", - ToolCallName: "web_search", - ProviderExecuted: true, - }) { - return - } - if !yield(fantasy.StreamPart{ - Type: fantasy.StreamPartTypeToolInputDelta, - ID: "ws-1", - Delta: `{"query":"coder"}`, - ProviderExecuted: true, - }) { - return - } - if !yield(fantasy.StreamPart{ - Type: fantasy.StreamPartTypeToolInputStart, - ID: 
"tc-1", - ToolCallName: "read_file", - }) { - return - } - if !yield(fantasy.StreamPart{ - Type: fantasy.StreamPartTypeToolInputDelta, - ID: "tc-1", - Delta: `{"path":"main.go"}`, - }) { - return - } - close(started) - <-ctx.Done() - _ = yield(fantasy.StreamPart{ - Type: fantasy.StreamPartTypeError, - Error: ctx.Err(), - }) - }), nil - }, - } - - ctx, cancel := context.WithCancelCause(context.Background()) - defer cancel(nil) - go func() { - <-started - cancel(ErrInterrupted) - }() - - var persistedSteps []PersistedStep - err := Run(ctx, RunOptions{ - Model: model, - Messages: []fantasy.Message{ - textMessage(fantasy.MessageRoleUser, "search and read"), - }, - MaxSteps: 1, - PersistStep: func(_ context.Context, step PersistedStep) error { - persistedSteps = append(persistedSteps, step) - return nil - }, - }) - require.ErrorIs(t, err, ErrInterrupted) - require.Len(t, persistedSteps, 1) - requireNoProviderExecutedToolCallContent(t, persistedSteps[0].Content) - requireNoProviderExecutedToolResultContent(t, persistedSteps[0].Content) - - toolCall := requireToolCallContent(t, persistedSteps[0].Content, "tc-1", "read_file") - require.False(t, toolCall.ProviderExecuted) - toolResult := requireToolResultContent(t, persistedSteps[0].Content, "tc-1", "read_file") - require.False(t, toolResult.ProviderExecuted) - _, isErr := toolResult.Result.(fantasy.ToolResultOutputContentError) - require.True(t, isErr) -} - -func TestRun_AnthropicSanitizesProviderToolBeforeRequest(t *testing.T) { - t.Parallel() - - var capturedPrompt []fantasy.Message - model := &chattest.FakeModel{ - ProviderName: fantasyanthropic.Name, - StreamFn: func(_ context.Context, call fantasy.Call) (fantasy.StreamResponse, error) { - capturedPrompt = append([]fantasy.Message(nil), call.Prompt...) 
- return streamFromParts([]fantasy.StreamPart{ - {Type: fantasy.StreamPartTypeTextStart, ID: "text-1"}, - {Type: fantasy.StreamPartTypeTextDelta, ID: "text-1", Delta: "done"}, - {Type: fantasy.StreamPartTypeTextEnd, ID: "text-1"}, - {Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop}, - }), nil - }, - } - - err := Run(context.Background(), RunOptions{ - Model: model, - Messages: []fantasy.Message{ - textMessage(fantasy.MessageRoleUser, "search for coder"), - { - Role: fantasy.MessageRoleAssistant, - Content: []fantasy.MessagePart{ - fantasy.ToolCallPart{ - ToolCallID: "ws-1", - ToolName: "web_search", - Input: `{"query":"coder"}`, - ProviderExecuted: true, - }, - }, - }, - textMessage(fantasy.MessageRoleUser, "continue"), - }, - MaxSteps: 1, - PersistStep: func(_ context.Context, _ PersistedStep) error { - return nil - }, - }) - require.NoError(t, err) - require.Len(t, capturedPrompt, 1) - require.Equal(t, fantasy.MessageRoleUser, capturedPrompt[0].Role) - require.Len(t, capturedPrompt[0].Content, 2) - requireNoProviderExecutedToolCallPrompt(t, capturedPrompt) -} - -func TestRun_AnthropicSanitizesWebSearchBeforeContinuation(t *testing.T) { - t.Parallel() - - var mu sync.Mutex - var streamCalls int - var secondCallPrompt []fantasy.Message - model := &chattest.FakeModel{ - ProviderName: fantasyanthropic.Name, - StreamFn: func(_ context.Context, call fantasy.Call) (fantasy.StreamResponse, error) { - mu.Lock() - step := streamCalls - streamCalls++ - mu.Unlock() - - switch step { - case 0: - return streamFromParts([]fantasy.StreamPart{ - {Type: fantasy.StreamPartTypeToolInputStart, ID: "ws-1", ToolCallName: "web_search", ProviderExecuted: true}, - {Type: fantasy.StreamPartTypeToolInputDelta, ID: "ws-1", Delta: `{"query":"coder"}`, ProviderExecuted: true}, - {Type: fantasy.StreamPartTypeToolInputEnd, ID: "ws-1"}, - { - Type: fantasy.StreamPartTypeToolCall, - ID: "ws-1", - ToolCallName: "web_search", - ToolCallInput: `{"query":"coder"}`, - 
ProviderExecuted: true, - }, - {Type: fantasy.StreamPartTypeToolInputStart, ID: "tc-1", ToolCallName: "read_file"}, - {Type: fantasy.StreamPartTypeToolInputDelta, ID: "tc-1", Delta: `{"path":"main.go"}`}, - {Type: fantasy.StreamPartTypeToolInputEnd, ID: "tc-1"}, - { - Type: fantasy.StreamPartTypeToolCall, - ID: "tc-1", - ToolCallName: "read_file", - ToolCallInput: `{"path":"main.go"}`, - }, - {Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonToolCalls}, - }), nil - default: - mu.Lock() - secondCallPrompt = append([]fantasy.Message(nil), call.Prompt...) - mu.Unlock() - return streamFromParts([]fantasy.StreamPart{ - {Type: fantasy.StreamPartTypeTextStart, ID: "text-1"}, - {Type: fantasy.StreamPartTypeTextDelta, ID: "text-1", Delta: "done"}, - {Type: fantasy.StreamPartTypeTextEnd, ID: "text-1"}, - {Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop}, - }), nil - } - }, - } - - var persistedSteps []PersistedStep - err := Run(context.Background(), RunOptions{ - Model: model, - Messages: []fantasy.Message{ - textMessage(fantasy.MessageRoleUser, "search and read"), - }, - Tools: []fantasy.AgentTool{ - newNoopTool("read_file"), - }, - MaxSteps: 2, - PersistStep: func(_ context.Context, step PersistedStep) error { - persistedSteps = append(persistedSteps, step) - return nil - }, - }) - require.NoError(t, err) - require.Equal(t, 2, streamCalls) - require.Len(t, persistedSteps, 2) - requireNoProviderExecutedToolCallContent(t, persistedSteps[0].Content) - requireNoProviderExecutedToolCallPrompt(t, secondCallPrompt) - - toolCall := requireToolCallContent(t, persistedSteps[0].Content, "tc-1", "read_file") - require.False(t, toolCall.ProviderExecuted) - toolResult := requireToolResultContent(t, persistedSteps[0].Content, "tc-1", "read_file") - require.False(t, toolResult.ProviderExecuted) - promptResult := requireToolResultPrompt(t, secondCallPrompt, "tc-1") - require.False(t, promptResult.ProviderExecuted) -} - func 
TestSanitizeAnthropicProviderToolContent(t *testing.T) { t.Parallel() @@ -3735,764 +845,6 @@ func TestSanitizeAnthropicProviderToolContent(t *testing.T) { } } -func TestRun_AnthropicProviderToolPreRequestGuard(t *testing.T) { - t.Parallel() - - webSearchTool := ProviderTool{ - Definition: fantasy.ProviderDefinedTool{ - ID: "anthropic.web_search", - Name: "web_search", - }, - } - providerPair := func(id string) []fantasy.MessagePart { - return []fantasy.MessagePart{ - fantasy.ToolCallPart{ - ToolCallID: id, - ToolName: "web_search", - Input: `{"query":"coder"}`, - ProviderExecuted: true, - }, - fantasy.ToolResultPart{ - ToolCallID: id, - Output: fantasy.ToolResultOutputContentText{Text: "ok"}, - ProviderExecuted: true, - ProviderOptions: fantasy.ProviderOptions(validWebSearchProviderMetadataForTest()), - }, - } - } - completionModel := func(capturedPrompt *[]fantasy.Message) *chattest.FakeModel { - return &chattest.FakeModel{ - ProviderName: fantasyanthropic.Name, - ModelName: "claude-test", - StreamFn: func(_ context.Context, call fantasy.Call) (fantasy.StreamResponse, error) { - *capturedPrompt = append([]fantasy.Message(nil), call.Prompt...) 
- return streamFromParts([]fantasy.StreamPart{ - {Type: fantasy.StreamPartTypeTextStart, ID: "text-1"}, - {Type: fantasy.StreamPartTypeTextDelta, ID: "text-1", Delta: "done"}, - {Type: fantasy.StreamPartTypeTextEnd, ID: "text-1"}, - {Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop}, - }), nil - }, - } - } - - t.Run("allowed web search survives when provider tool is enabled", func(t *testing.T) { - t.Parallel() - - var capturedPrompt []fantasy.Message - err := Run(context.Background(), RunOptions{ - Model: completionModel(&capturedPrompt), - Messages: []fantasy.Message{ - textMessage(fantasy.MessageRoleUser, "search"), - { - Role: fantasy.MessageRoleAssistant, - Content: providerPair("ws-allowed"), - }, - textMessage(fantasy.MessageRoleUser, "continue"), - }, - ProviderTools: []ProviderTool{webSearchTool}, - MaxSteps: 1, - PersistStep: func(_ context.Context, _ PersistedStep) error { - return nil - }, - }) - require.NoError(t, err) - - toolCall := requireProviderExecutedToolCallPrompt(t, capturedPrompt, "ws-allowed") - require.Equal(t, "web_search", toolCall.ToolName) - requireProviderExecutedToolResultPrompt(t, capturedPrompt, "ws-allowed") - requireAnthropicProviderToolPromptSafe(t, capturedPrompt) - }) - - t.Run("web search history survives when provider tool is disabled", func(t *testing.T) { - t.Parallel() - - var capturedPrompt []fantasy.Message - err := Run(context.Background(), RunOptions{ - Model: completionModel(&capturedPrompt), - Messages: []fantasy.Message{ - textMessage(fantasy.MessageRoleUser, "search and read"), - { - Role: fantasy.MessageRoleAssistant, - Content: append(providerPair("ws-disabled"), fantasy.ToolCallPart{ - ToolCallID: "tc-1", - ToolName: "read_file", - Input: `{"path":"main.go"}`, - }), - }, - { - Role: fantasy.MessageRoleTool, - Content: []fantasy.MessagePart{ - fantasy.ToolResultPart{ - ToolCallID: "tc-1", - Output: fantasy.ToolResultOutputContentText{Text: "file"}, - }, - }, - }, - 
textMessage(fantasy.MessageRoleUser, "continue"), - }, - MaxSteps: 1, - PersistStep: func(_ context.Context, _ PersistedStep) error { - return nil - }, - }) - require.NoError(t, err) - - requireProviderExecutedToolCallPrompt(t, capturedPrompt, "ws-disabled") - requireProviderExecutedToolResultPrompt(t, capturedPrompt, "ws-disabled") - promptResult := requireToolResultPrompt(t, capturedPrompt, "tc-1") - require.False(t, promptResult.ProviderExecuted) - requireAnthropicProviderToolPromptSafe(t, capturedPrompt) - }) - - t.Run("direct guard textifies orphaned provider result", func(t *testing.T) { - t.Parallel() - - guarded, err := chatsanitize.ApplyAnthropicProviderToolGuard( - context.Background(), - slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}), - fantasyanthropic.Name, - "claude-test", - []fantasy.Message{ - { - Role: fantasy.MessageRoleAssistant, - Content: []fantasy.MessagePart{ - fantasy.TextPart{Text: "keep"}, - fantasy.ToolResultPart{ - ToolCallID: "ws-orphan", - Output: fantasy.ToolResultOutputContentText{Text: "search result"}, - ProviderExecuted: true, - }, - }, - }, - }, - ) - require.NoError(t, err) - - requireNoProviderExecutedToolResultPrompt(t, guarded) - requireAnthropicProviderToolPromptSafe(t, guarded) - require.Len(t, guarded, 1) - require.Len(t, guarded[0].Content, 2) - textPart, ok := fantasy.AsMessagePart[fantasy.TextPart](guarded[0].Content[0]) - require.True(t, ok) - require.Equal(t, "keep", textPart.Text) - textPart, ok = fantasy.AsMessagePart[fantasy.TextPart](guarded[0].Content[1]) - require.True(t, ok) - require.Equal(t, "search result", textPart.Text) - }) - - t.Run("direct guard leaves valid provider history unchanged", func(t *testing.T) { - t.Parallel() - - content := []fantasy.MessagePart{fantasy.TextPart{Text: "keep"}} - content = append(content, providerPair("ws-one")...) - content = append(content, providerPair("ws-two")...) 
- guarded, err := chatsanitize.ApplyAnthropicProviderToolGuard( - context.Background(), - slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}), - fantasyanthropic.Name, - "claude-test", - []fantasy.Message{{Role: fantasy.MessageRoleAssistant, Content: content}}, - ) - require.NoError(t, err) - - requireAnthropicProviderToolPromptSafe(t, guarded) - require.Len(t, guarded, 1) - require.Len(t, guarded[0].Content, len(content)) - requireProviderExecutedToolCallPrompt(t, guarded, "ws-one") - requireProviderExecutedToolResultPrompt(t, guarded, "ws-one") - requireProviderExecutedToolCallPrompt(t, guarded, "ws-two") - requireProviderExecutedToolResultPrompt(t, guarded, "ws-two") - }) - - t.Run("direct guard leaves non Anthropic providers unchanged", func(t *testing.T) { - t.Parallel() - - prompt := []fantasy.Message{ - { - Role: fantasy.MessageRoleAssistant, - Content: providerPair("ws-other-provider"), - }, - } - guarded, err := chatsanitize.ApplyAnthropicProviderToolGuard( - context.Background(), - slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}), - "fake", - "fake-model", - prompt, - ) - require.NoError(t, err) - require.Equal(t, prompt, guarded) - }) - - t.Run("guard logs removals", func(t *testing.T) { - t.Parallel() - - logSink := testutil.NewFakeSink(t) - logger := logSink.Logger() - logPair := providerPair("ws-log") - guarded, err := chatsanitize.ApplyAnthropicProviderToolGuard( - context.Background(), - logger, - fantasyanthropic.Name, - "claude-test", - []fantasy.Message{ - { - Role: fantasy.MessageRoleAssistant, - Content: []fantasy.MessagePart{ - logPair[1], - logPair[0], - }, - }, - }, - ) - require.NoError(t, err) - - requireNoProviderExecutedToolCallPrompt(t, guarded) - requireNoProviderExecutedToolResultPrompt(t, guarded) - requireTextPrompt(t, guarded, "ok") - entries := logSink.Entries(func(e slog.SinkEntry) bool { - return e.Level == slog.LevelWarn && - e.Message == "removed provider-executed tool history" - }) - require.Len(t, entries, 1) - 
require.Equal(t, "pre_request_guard", requireLogField(t, entries[0], "phase")) - require.Equal(t, 1, requireLogField(t, entries[0], "removed_tool_calls")) - require.Equal(t, 1, requireLogField(t, entries[0], "removed_tool_results")) - }) - t.Run("run drops orphan provider call before provider request", func(t *testing.T) { - t.Parallel() - - streamCalls := 0 - var capturedPrompt fantasy.Prompt - model := &chattest.FakeModel{ - ProviderName: fantasyanthropic.Name, - ModelName: "claude-test", - StreamFn: func(_ context.Context, call fantasy.Call) (fantasy.StreamResponse, error) { - streamCalls++ - capturedPrompt = call.Prompt - return finishingStream(), nil - }, - } - - err := Run(context.Background(), RunOptions{ - Model: model, - Messages: []fantasy.Message{ - textMessage(fantasy.MessageRoleUser, "search"), - { - Role: fantasy.MessageRoleAssistant, - Content: []fantasy.MessagePart{ - fantasy.ReasoningPart{ - ProviderOptions: fantasy.ProviderOptions{ - fantasyanthropic.Name: &fantasyanthropic.ReasoningOptionMetadata{ - RedactedData: "redacted-payload", - }, - }, - }, - fantasy.ToolCallPart{ - ToolCallID: "ws-orphan", - ToolName: "web_search", - Input: `{"query":"coder"}`, - ProviderExecuted: true, - }, - fantasy.TextPart{Text: "partial"}, - }, - }, - textMessage(fantasy.MessageRoleUser, "continue"), - }, - Logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}), - MaxSteps: 1, - PersistStep: func(_ context.Context, _ PersistedStep) error { - return nil - }, - }) - require.NoError(t, err) - require.Equal(t, 1, streamCalls) - requireNoProviderExecutedToolCallPrompt(t, capturedPrompt) - requireAnthropicProviderToolPromptSafe(t, capturedPrompt) - requireTextPrompt(t, capturedPrompt, "partial") - reasoningPart := requireReasoningPrompt(t, capturedPrompt) - reasoningMetadata := fantasyanthropic.GetReasoningMetadata(reasoningPart.ProviderOptions) - require.NotNil(t, reasoningMetadata) - require.Equal(t, "redacted-payload", reasoningMetadata.RedactedData) - }) -} - 
-// TestRun_PersistStepInterruptedFallback verifies that when the normal -// PersistStep call returns ErrInterrupted (e.g., context canceled in a -// race), the step is retried via the interrupt-safe path. -func TestRun_PersistStepInterruptedFallback(t *testing.T) { - t.Parallel() - - model := &chattest.FakeModel{ - ProviderName: "fake", - StreamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) { - return streamFromParts([]fantasy.StreamPart{ - {Type: fantasy.StreamPartTypeTextStart, ID: "text-1"}, - {Type: fantasy.StreamPartTypeTextDelta, ID: "text-1", Delta: "hello world"}, - {Type: fantasy.StreamPartTypeTextEnd, ID: "text-1"}, - {Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop}, - }), nil - }, - } - - var ( - mu sync.Mutex - persistCalls int - savedContent []fantasy.Content - ) - - err := Run(context.Background(), RunOptions{ - Model: model, - Messages: []fantasy.Message{ - textMessage(fantasy.MessageRoleUser, "hello"), - }, - MaxSteps: 1, - PersistStep: func(_ context.Context, step PersistedStep) error { - mu.Lock() - defer mu.Unlock() - persistCalls++ - if persistCalls == 1 { - // First call: simulate an interrupt race by - // returning ErrInterrupted without persisting. - return ErrInterrupted - } - // Second call (from persistInterruptedStep fallback): - // accept the content. - savedContent = append([]fantasy.Content(nil), step.Content...) 
- return nil - }, - }) - require.ErrorIs(t, err, ErrInterrupted) - - mu.Lock() - defer mu.Unlock() - require.Equal(t, 2, persistCalls, "PersistStep should be called twice: once normally (failing), once via fallback") - require.NotEmpty(t, savedContent) - - var foundText bool - for _, block := range savedContent { - if text, ok := fantasy.AsContentType[fantasy.TextContent](block); ok { - if strings.Contains(text.Text, "hello world") { - foundText = true - } - } - } - require.True(t, foundText, "fallback should persist the text content") -} - -func TestRun_PrepareMessagesInjectsSystemContextMidLoop(t *testing.T) { - t.Parallel() - - const injectedInstruction = "You are working in /home/coder/project. Follow AGENTS.md guidelines." - - var mu sync.Mutex - var streamCalls int - var secondCallPrompt []fantasy.Message - - // Step 0 calls a tool. Step 1 sees the injected system message. - model := &chattest.FakeModel{ - ProviderName: "fake", - StreamFn: func(_ context.Context, call fantasy.Call) (fantasy.StreamResponse, error) { - mu.Lock() - step := streamCalls - streamCalls++ - mu.Unlock() - - switch step { - case 0: - return streamFromParts([]fantasy.StreamPart{ - {Type: fantasy.StreamPartTypeToolInputStart, ID: "tc-1", ToolCallName: "create_workspace"}, - {Type: fantasy.StreamPartTypeToolInputDelta, ID: "tc-1", Delta: `{}`}, - {Type: fantasy.StreamPartTypeToolInputEnd, ID: "tc-1"}, - { - Type: fantasy.StreamPartTypeToolCall, - ID: "tc-1", - ToolCallName: "create_workspace", - ToolCallInput: `{}`, - }, - {Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonToolCalls}, - }), nil - default: - mu.Lock() - secondCallPrompt = append([]fantasy.Message(nil), call.Prompt...) 
- mu.Unlock() - return streamFromParts([]fantasy.StreamPart{ - {Type: fantasy.StreamPartTypeTextStart, ID: "text-1"}, - {Type: fantasy.StreamPartTypeTextDelta, ID: "text-1", Delta: "done"}, - {Type: fantasy.StreamPartTypeTextEnd, ID: "text-1"}, - {Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop}, - }), nil - } - }, - } - - // Simulate: after the tool executes (step 0), instruction - // becomes available. PrepareMessages injects it before step 1. - instructionInjected := make(chan struct{}) - var instructionAvailable atomic.Value - // The tool sets instruction after execution. - tool := fantasy.NewAgentTool( - "create_workspace", - "create a workspace", - func(_ context.Context, _ struct{}, _ fantasy.ToolCall) (fantasy.ToolResponse, error) { - instructionAvailable.Store(injectedInstruction) - return fantasy.ToolResponse{}, nil - }, - ) - - err := Run(context.Background(), RunOptions{ - Model: model, - Messages: []fantasy.Message{ - textMessage(fantasy.MessageRoleUser, "create a workspace and open a PR"), - }, - Tools: []fantasy.AgentTool{tool}, - MaxSteps: 5, - PersistStep: func(_ context.Context, _ PersistedStep) error { - return nil - }, - PrepareMessages: func(msgs []fantasy.Message) []fantasy.Message { - select { - case <-instructionInjected: - return nil - default: - } - instr, ok := instructionAvailable.Load().(string) - if !ok || instr == "" { - return nil - } - close(instructionInjected) - // Insert a system message after existing system messages. - result := make([]fantasy.Message, 0, len(msgs)+1) - inserted := false - for i, msg := range msgs { - result = append(result, msg) - if !inserted && msg.Role == fantasy.MessageRoleSystem { - // Insert after the last system message. 
- if i+1 >= len(msgs) || msgs[i+1].Role != fantasy.MessageRoleSystem { - result = append(result, fantasy.Message{ - Role: fantasy.MessageRoleSystem, - Content: []fantasy.MessagePart{ - fantasy.TextPart{Text: instr}, - }, - }) - inserted = true - } - } - } - if !inserted { - // No system messages, prepend. - result = append([]fantasy.Message{{ - Role: fantasy.MessageRoleSystem, - Content: []fantasy.MessagePart{ - fantasy.TextPart{Text: instr}, - }, - }}, result...) - } - return result - }, - }) - require.NoError(t, err) - require.Equal(t, 2, streamCalls) - - // The second LLM call should contain the injected instruction. - require.NotEmpty(t, secondCallPrompt) - var foundInstruction bool - for _, msg := range secondCallPrompt { - if msg.Role != fantasy.MessageRoleSystem { - continue - } - for _, part := range msg.Content { - if tp, ok := fantasy.AsMessagePart[fantasy.TextPart](part); ok { - if strings.Contains(tp.Text, "AGENTS.md") { - foundInstruction = true - } - } - } - } - require.True(t, foundInstruction, - "step 1 prompt should contain the injected system instruction") -} - -func TestRun_PrepareMessagesOnlyFiresOnce(t *testing.T) { - t.Parallel() - - var mu sync.Mutex - var streamCalls int - - // Three steps: tool call, tool call, text. PrepareMessages - // should inject on step 1 and return nil on step 2. 
- model := &chattest.FakeModel{ - ProviderName: "fake", - StreamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) { - mu.Lock() - step := streamCalls - streamCalls++ - mu.Unlock() - - if step < 2 { - return streamFromParts([]fantasy.StreamPart{ - {Type: fantasy.StreamPartTypeToolInputStart, ID: "tc-" + strings.Repeat("x", step+1), ToolCallName: "noop"}, - {Type: fantasy.StreamPartTypeToolInputDelta, ID: "tc-" + strings.Repeat("x", step+1), Delta: `{}`}, - {Type: fantasy.StreamPartTypeToolInputEnd, ID: "tc-" + strings.Repeat("x", step+1)}, - { - Type: fantasy.StreamPartTypeToolCall, - ID: "tc-" + strings.Repeat("x", step+1), - ToolCallName: "noop", - ToolCallInput: `{}`, - }, - {Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonToolCalls}, - }), nil - } - return streamFromParts([]fantasy.StreamPart{ - {Type: fantasy.StreamPartTypeTextStart, ID: "text-1"}, - {Type: fantasy.StreamPartTypeTextDelta, ID: "text-1", Delta: "done"}, - {Type: fantasy.StreamPartTypeTextEnd, ID: "text-1"}, - {Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop}, - }), nil - }, - } - - var prepareCalls atomic.Int32 - err := Run(context.Background(), RunOptions{ - Model: model, - Messages: []fantasy.Message{ - textMessage(fantasy.MessageRoleUser, "do something"), - }, - Tools: []fantasy.AgentTool{newNoopTool("noop")}, - MaxSteps: 5, - PersistStep: func(_ context.Context, _ PersistedStep) error { - return nil - }, - PrepareMessages: func(msgs []fantasy.Message) []fantasy.Message { - call := prepareCalls.Add(1) - if call == 1 { - // First call: inject a message. - return append(msgs, fantasy.Message{ - Role: fantasy.MessageRoleSystem, - Content: []fantasy.MessagePart{fantasy.TextPart{Text: "injected"}}, - }) - } - // Subsequent calls: no changes. - return nil - }, - }) - require.NoError(t, err) - require.Equal(t, 3, streamCalls) - // PrepareMessages is called before each of the 3 steps. 
- require.Equal(t, 3, int(prepareCalls.Load())) -} - -// TestRun_PrepareToolsInjectsToolMidLoop guards the regression where a -// chat creating its workspace mid-turn (via create_workspace) saw the -// workspace MCP tools only on the next turn. Before the fix, the tool -// list was frozen at the top of the turn and the model could not call -// any workspace MCP tools until turn 2. With the fix, PrepareTools is -// invoked before every step and can inject tools that become available -// mid-loop. -func TestRun_PrepareToolsInjectsToolMidLoop(t *testing.T) { - t.Parallel() - - const injectedToolName = "workspace_mcp__echo" - - var mu sync.Mutex - var streamCalls int - var secondCallTools []fantasy.Tool - - // Step 0 calls create_workspace. Step 1 should see the - // injected workspace MCP tool. - model := &chattest.FakeModel{ - ProviderName: "fake", - StreamFn: func(_ context.Context, call fantasy.Call) (fantasy.StreamResponse, error) { - mu.Lock() - step := streamCalls - streamCalls++ - mu.Unlock() - - switch step { - case 0: - return streamFromParts([]fantasy.StreamPart{ - {Type: fantasy.StreamPartTypeToolInputStart, ID: "tc-1", ToolCallName: "create_workspace"}, - {Type: fantasy.StreamPartTypeToolInputDelta, ID: "tc-1", Delta: `{}`}, - {Type: fantasy.StreamPartTypeToolInputEnd, ID: "tc-1"}, - { - Type: fantasy.StreamPartTypeToolCall, - ID: "tc-1", - ToolCallName: "create_workspace", - ToolCallInput: `{}`, - }, - {Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonToolCalls}, - }), nil - default: - mu.Lock() - secondCallTools = append([]fantasy.Tool(nil), call.Tools...) 
- mu.Unlock() - return streamFromParts([]fantasy.StreamPart{ - {Type: fantasy.StreamPartTypeTextStart, ID: "text-1"}, - {Type: fantasy.StreamPartTypeTextDelta, ID: "text-1", Delta: "done"}, - {Type: fantasy.StreamPartTypeTextEnd, ID: "text-1"}, - {Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop}, - }), nil - } - }, - } - - var workspaceReady atomic.Bool - createWorkspaceTool := fantasy.NewAgentTool( - "create_workspace", - "create a workspace", - func(_ context.Context, _ struct{}, _ fantasy.ToolCall) (fantasy.ToolResponse, error) { - workspaceReady.Store(true) - return fantasy.ToolResponse{}, nil - }, - ) - - var prepareCalls atomic.Int32 - err := Run(context.Background(), RunOptions{ - Model: model, - Messages: []fantasy.Message{ - textMessage(fantasy.MessageRoleUser, "create a workspace and use MCP"), - }, - Tools: []fantasy.AgentTool{createWorkspaceTool}, - ActiveTools: []string{"create_workspace"}, - MaxSteps: 5, - PersistStep: func(_ context.Context, _ PersistedStep) error { - return nil - }, - PrepareTools: func(currentTools []fantasy.AgentTool) []fantasy.AgentTool { - prepareCalls.Add(1) - if !workspaceReady.Load() { - return nil - } - return append(currentTools, newNoopTool(injectedToolName)) - }, - }) - require.NoError(t, err) - require.Equal(t, 2, streamCalls) - // PrepareTools is called before each of the 2 steps. - require.Equal(t, int32(2), prepareCalls.Load()) - - require.NotEmpty(t, secondCallTools) - var foundInjectedTool bool - for _, tool := range secondCallTools { - if tool.GetName() == injectedToolName { - foundInjectedTool = true - break - } - } - require.True(t, foundInjectedTool, - "step 1 prompt should advertise the workspace MCP tool injected by PrepareTools") -} - -// TestRun_PrepareToolsAddsNewToolToActiveSet guards the contract that -// when PrepareTools injects a tool, that tool is callable on the -// next step even when opts.ActiveTools was non-empty (and would -// otherwise filter the new tool out). 
-func TestRun_PrepareToolsAddsNewToolToActiveSet(t *testing.T) { - t.Parallel() - - const injectedToolName = "workspace_mcp__echo" - - var mu sync.Mutex - var streamCalls int - var injectedToolRan atomic.Bool - - model := &chattest.FakeModel{ - ProviderName: "fake", - StreamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) { - mu.Lock() - step := streamCalls - streamCalls++ - mu.Unlock() - - switch step { - case 0: - return streamFromParts([]fantasy.StreamPart{ - {Type: fantasy.StreamPartTypeToolInputStart, ID: "tc-1", ToolCallName: "create_workspace"}, - {Type: fantasy.StreamPartTypeToolInputDelta, ID: "tc-1", Delta: `{}`}, - {Type: fantasy.StreamPartTypeToolInputEnd, ID: "tc-1"}, - { - Type: fantasy.StreamPartTypeToolCall, - ID: "tc-1", - ToolCallName: "create_workspace", - ToolCallInput: `{}`, - }, - {Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonToolCalls}, - }), nil - case 1: - return streamFromParts([]fantasy.StreamPart{ - {Type: fantasy.StreamPartTypeToolInputStart, ID: "tc-2", ToolCallName: injectedToolName}, - {Type: fantasy.StreamPartTypeToolInputDelta, ID: "tc-2", Delta: `{}`}, - {Type: fantasy.StreamPartTypeToolInputEnd, ID: "tc-2"}, - { - Type: fantasy.StreamPartTypeToolCall, - ID: "tc-2", - ToolCallName: injectedToolName, - ToolCallInput: `{}`, - }, - {Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonToolCalls}, - }), nil - default: - return streamFromParts([]fantasy.StreamPart{ - {Type: fantasy.StreamPartTypeTextStart, ID: "text-1"}, - {Type: fantasy.StreamPartTypeTextDelta, ID: "text-1", Delta: "done"}, - {Type: fantasy.StreamPartTypeTextEnd, ID: "text-1"}, - {Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop}, - }), nil - } - }, - } - - var workspaceReady atomic.Bool - createWorkspaceTool := fantasy.NewAgentTool( - "create_workspace", - "create a workspace", - func(_ context.Context, _ struct{}, _ fantasy.ToolCall) (fantasy.ToolResponse, error) { - 
workspaceReady.Store(true) - return fantasy.ToolResponse{}, nil - }, - ) - - injectedTool := fantasy.NewAgentTool( - injectedToolName, - "injected workspace MCP tool", - func(_ context.Context, _ struct{}, _ fantasy.ToolCall) (fantasy.ToolResponse, error) { - injectedToolRan.Store(true) - return fantasy.ToolResponse{}, nil - }, - ) - - err := Run(context.Background(), RunOptions{ - Model: model, - Messages: []fantasy.Message{ - textMessage(fantasy.MessageRoleUser, "create a workspace and use MCP"), - }, - Tools: []fantasy.AgentTool{createWorkspaceTool}, - // Active list deliberately excludes the injected tool name; - // PrepareTools must add it so the tool is callable. - ActiveTools: []string{"create_workspace"}, - MaxSteps: 5, - PersistStep: func(_ context.Context, _ PersistedStep) error { - return nil - }, - PrepareTools: func(currentTools []fantasy.AgentTool) []fantasy.AgentTool { - if !workspaceReady.Load() { - return nil - } - for _, t := range currentTools { - if t.Info().Name == injectedToolName { - return nil - } - } - return append(currentTools, injectedTool) - }, - }) - require.NoError(t, err) - require.GreaterOrEqual(t, streamCalls, 2) - require.True(t, injectedToolRan.Load(), - "injected tool must be callable on the step after PrepareTools adds it") -} - func TestExecuteSingleTool_MediaBase64Encoding(t *testing.T) { t.Parallel() @@ -4633,170 +985,3 @@ func TestExecuteSingleTool_MediaBase64Encoding(t *testing.T) { require.Contains(t, textOutput.Text, "world") }) } - -// TestRun_ReasoningTimestamps verifies that StreamPartTypeReasoningStart -// and StreamPartTypeReasoningEnd produce parallel ReasoningStartedAt / -// ReasoningCompletedAt slices on PersistedStep, in the same occurrence -// order as the reasoning content blocks. The frontend computes -// reasoning duration as completed_at - started_at. 
-func TestRun_ReasoningTimestamps(t *testing.T) { - t.Parallel() - - model := &chattest.FakeModel{ - ProviderName: "fake", - StreamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) { - return streamFromParts([]fantasy.StreamPart{ - {Type: fantasy.StreamPartTypeReasoningStart, ID: "reason-1"}, - {Type: fantasy.StreamPartTypeReasoningDelta, ID: "reason-1", Delta: "first thought"}, - {Type: fantasy.StreamPartTypeReasoningEnd, ID: "reason-1"}, - {Type: fantasy.StreamPartTypeReasoningStart, ID: "reason-2"}, - {Type: fantasy.StreamPartTypeReasoningDelta, ID: "reason-2", Delta: "second thought"}, - {Type: fantasy.StreamPartTypeReasoningEnd, ID: "reason-2"}, - {Type: fantasy.StreamPartTypeTextStart, ID: "text-1"}, - {Type: fantasy.StreamPartTypeTextDelta, ID: "text-1", Delta: "answer"}, - {Type: fantasy.StreamPartTypeTextEnd, ID: "text-1"}, - {Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop}, - }), nil - }, - } - - var persistedSteps []PersistedStep - err := Run(context.Background(), RunOptions{ - Model: model, - Messages: []fantasy.Message{ - textMessage(fantasy.MessageRoleUser, "think"), - }, - MaxSteps: 1, - PersistStep: func(_ context.Context, step PersistedStep) error { - persistedSteps = append(persistedSteps, step) - return nil - }, - }) - require.NoError(t, err) - require.Len(t, persistedSteps, 1) - - step := persistedSteps[0] - - // Both reasoning blocks must produce parallel timestamp entries. - require.Len(t, step.ReasoningStartedAt, 2, - "each StreamPartTypeReasoningEnd must record a started_at") - require.Len(t, step.ReasoningCompletedAt, 2, - "each StreamPartTypeReasoningEnd must record a completed_at") - - // Timestamps must be monotonic per block (completed_at >= started_at), - // and both timestamps must be populated. 
Asserting only monotonicity - // is not enough: time.Time{} is year 0001, so completed_at.Before(zero) - // is trivially false and a regression that drops the started_at stamp - // would slip past the comparison. - for i := range step.ReasoningStartedAt { - require.False(t, step.ReasoningStartedAt[i].IsZero(), - "started_at[%d] must be non-zero", i) - require.False(t, step.ReasoningCompletedAt[i].IsZero(), - "completed_at[%d] must be non-zero", i) - require.False(t, - step.ReasoningCompletedAt[i].Before(step.ReasoningStartedAt[i]), - "completed_at[%d] must be >= started_at[%d]", i, i) - } - - // Successive blocks must be ordered: reasoning-2 cannot start - // before reasoning-1 completes. - require.False(t, - step.ReasoningStartedAt[1].Before(step.ReasoningCompletedAt[0]), - "reasoning-2 started_at must be >= reasoning-1 completed_at") - - // The reasoning content blocks must appear in the same order - // in step.Content so the persistence layer can correlate by - // occurrence order. - var reasoningOrder []string - for _, c := range step.Content { - if r, ok := fantasy.AsContentType[fantasy.ReasoningContent](c); ok { - reasoningOrder = append(reasoningOrder, r.Text) - } - } - require.Equal(t, []string{"first thought", "second thought"}, reasoningOrder) -} - -func TestRun_InterruptedReasoningFlushesTimestamps(t *testing.T) { - t.Parallel() - - started := make(chan struct{}) - model := &chattest.FakeModel{ - ProviderName: "fake", - StreamFn: func(ctx context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) { - return iter.Seq[fantasy.StreamPart](func(yield func(fantasy.StreamPart) bool) { - parts := []fantasy.StreamPart{ - {Type: fantasy.StreamPartTypeReasoningStart, ID: "reason-1"}, - {Type: fantasy.StreamPartTypeReasoningDelta, ID: "reason-1", Delta: "interrupted thought"}, - } - for _, part := range parts { - if !yield(part) { - return - } - } - - select { - case <-started: - default: - close(started) - } - - <-ctx.Done() - _ = yield(fantasy.StreamPart{ - 
Type: fantasy.StreamPartTypeError, - Error: ctx.Err(), - }) - }), nil - }, - } - - ctx, cancel := context.WithCancelCause(context.Background()) - defer cancel(nil) - - go func() { - <-started - cancel(ErrInterrupted) - }() - - var persistedStep PersistedStep - err := Run(ctx, RunOptions{ - Model: model, - Messages: []fantasy.Message{ - textMessage(fantasy.MessageRoleUser, "think"), - }, - MaxSteps: 1, - PersistStep: func(_ context.Context, step PersistedStep) error { - persistedStep = step - return nil - }, - }) - require.ErrorIs(t, err, ErrInterrupted) - - // flushActiveState must have appended exactly one entry to each - // parallel slice, matching the single in-progress reasoning block. - require.Len(t, persistedStep.ReasoningStartedAt, 1, - "interrupted reasoning must flush its started_at") - require.Len(t, persistedStep.ReasoningCompletedAt, 1, - "interrupted reasoning must flush a completed_at stamp") - - // Both timestamps must be populated and the completed stamp - // must be at or after the started stamp. - require.False(t, persistedStep.ReasoningStartedAt[0].IsZero(), - "flushed reasoning started_at must be non-zero") - require.False(t, persistedStep.ReasoningCompletedAt[0].IsZero(), - "flushed reasoning completed_at must be non-zero") - require.False(t, - persistedStep.ReasoningCompletedAt[0].Before(persistedStep.ReasoningStartedAt[0]), - "flushed completed_at must be >= started_at") - - // The flushed reasoning content must appear in step.Content so - // the persistence layer's occurrence-order correlation lines up - // with the timestamp slices. 
- var reasoningBlocks []fantasy.ReasoningContent - for _, c := range persistedStep.Content { - if r, ok := fantasy.AsContentType[fantasy.ReasoningContent](c); ok { - reasoningBlocks = append(reasoningBlocks, r) - } - } - require.Len(t, reasoningBlocks, 1) - require.Equal(t, "interrupted thought", reasoningBlocks[0].Text) -} diff --git a/coderd/x/chatd/chatloop/compaction.go b/coderd/x/chatd/chatloop/compaction.go index b267f17e2a..330def364f 100644 --- a/coderd/x/chatd/chatloop/compaction.go +++ b/coderd/x/chatd/chatloop/compaction.go @@ -90,45 +90,36 @@ type CompactionResult struct { ContextLimit int64 } -// tryCompact checks whether context usage exceeds the compaction -// threshold and, if so, generates and persists a summary. Returns -// (true, nil) when compaction was performed, (false, nil) when not -// needed, and (false, err) on failure. -func tryCompact( - ctx context.Context, - model fantasy.LanguageModel, - compaction *CompactionOptions, - contextLimitFallback int64, - stepUsage fantasy.Usage, - stepMetadata fantasy.ProviderMetadata, - allMessages []fantasy.Message, -) (bool, error) { - config, ok := normalizedCompactionConfig(compaction) +// GenerateCompaction generates one context summary and returns it without +// persisting. It publishes compaction progress parts when configured. 
+func GenerateCompaction(ctx context.Context, opts GenerateCompactionOptions) (CompactionResult, error) { + if opts.Model == nil { + return CompactionResult{}, xerrors.New("chat model is required") + } + config, ok := normalizedCompactionGenerateConfig(opts) if !ok { - return false, nil + return CompactionResult{}, nil } - contextTokens := contextTokensFromUsage(stepUsage) + contextTokens := contextTokensFromUsage(opts.StepUsage) if contextTokens <= 0 { - return false, nil + return CompactionResult{}, nil } - - metadataLimit := extractContextLimit(stepMetadata) + metadataLimit := extractContextLimit(opts.StepMetadata) contextLimit := resolveContextLimit( metadataLimit.Int64, config.ContextLimit, - contextLimitFallback, + opts.ContextLimitFallback, ) - usagePercent, compact := shouldCompact( - contextTokens, contextLimit, config.ThresholdPercent, + contextTokens, + contextLimit, + config.ThresholdPercent, ) if !compact { - return false, nil + return CompactionResult{}, nil } - // Publish the "Summarizing..." tool-call indicator so - // connected clients see activity during summary generation. if config.PublishMessagePart != nil && config.ToolCallID != "" { config.PublishMessagePart( codersdk.ChatMessageRoleAssistant, @@ -136,40 +127,26 @@ func tryCompact( ) } - summary, err := generateCompactionSummary( - ctx, model, allMessages, config, - ) + summary, err := generateCompactionSummary(ctx, opts.Model, opts.Messages, config) if err != nil { - return false, err + publishCompactionError(config, "failed to generate compaction summary") + return CompactionResult{}, err } if summary == "" { - // Publish a tool-result error so connected clients - // see the compaction failure. 
publishCompactionError(config, "compaction produced an empty summary") - return false, xerrors.New("compaction produced an empty summary") + return CompactionResult{}, xerrors.New("compaction produced an empty summary") } - systemSummary := strings.TrimSpace( - config.SystemSummaryPrefix + "\n\n" + summary, - ) - - persistCtx := context.WithoutCancel(ctx) - err = config.Persist(persistCtx, CompactionResult{ - SystemSummary: systemSummary, + result := CompactionResult{ + SystemSummary: strings.TrimSpace( + config.SystemSummaryPrefix + "\n\n" + summary, + ), SummaryReport: summary, ThresholdPercent: config.ThresholdPercent, UsagePercent: usagePercent, ContextTokens: contextTokens, ContextLimit: contextLimit, - }) - if err != nil { - publishCompactionError(config, "failed to persist compaction result") - return false, xerrors.Errorf("persist compaction: %w", err) } - - // Publish the "Summarized" tool-result part so the client - // transitions from the in-progress indicator to the final - // state. if config.PublishMessagePart != nil && config.ToolCallID != "" { resultJSON, _ := json.Marshal(map[string]any{ "summary": summary, @@ -184,37 +161,22 @@ func tryCompact( codersdk.ChatMessageToolResult(config.ToolCallID, config.ToolName, resultJSON, false, false), ) } - - return true, nil + return result, nil } -// publishCompactionError sends a tool-result error part so -// connected clients see that compaction failed. -func publishCompactionError(config CompactionOptions, msg string) { - if config.PublishMessagePart == nil || config.ToolCallID == "" { - return - } - errJSON, _ := json.Marshal(map[string]any{ - "error": msg, - }) - config.PublishMessagePart( - codersdk.ChatMessageRoleTool, - codersdk.ChatMessageToolResult(config.ToolCallID, config.ToolName, errJSON, true, false), - ) -} - -// normalizedCompactionConfig returns a copy of the compaction options -// with defaults applied. 
The bool is false when compaction is -// disabled (nil options, missing Persist callback, or threshold at -// 100%). -func normalizedCompactionConfig(opts *CompactionOptions) (CompactionOptions, bool) { - if opts == nil { - return CompactionOptions{}, false - } - - config := *opts - if config.Persist == nil { - return CompactionOptions{}, false +func normalizedCompactionGenerateConfig(opts GenerateCompactionOptions) (CompactionOptions, bool) { + config := CompactionOptions{ + ThresholdPercent: opts.ThresholdPercent, + ContextLimit: opts.ContextLimit, + SummaryPrompt: opts.SummaryPrompt, + SystemSummaryPrefix: opts.SystemSummaryPrefix, + Timeout: opts.Timeout, + DebugSvc: opts.DebugSvc, + ChatID: opts.ChatID, + HistoryTipMessageID: opts.HistoryTipMessageID, + ToolCallID: opts.ToolCallID, + ToolName: opts.ToolName, + PublishMessagePart: opts.PublishMessagePart, } if strings.TrimSpace(config.SummaryPrompt) == "" { config.SummaryPrompt = defaultCompactionSummaryPrompt @@ -232,10 +194,24 @@ func normalizedCompactionConfig(opts *CompactionOptions) (CompactionOptions, boo if config.ThresholdPercent == maxCompactionThresholdPercent { return CompactionOptions{}, false } - return config, true } +// publishCompactionError sends a tool-result error part so +// connected clients see that compaction failed. +func publishCompactionError(config CompactionOptions, msg string) { + if config.PublishMessagePart == nil || config.ToolCallID == "" { + return + } + errJSON, _ := json.Marshal(map[string]any{ + "error": msg, + }) + config.PublishMessagePart( + codersdk.ChatMessageRoleTool, + codersdk.ChatMessageToolResult(config.ToolCallID, config.ToolName, errJSON, true, false), + ) +} + // contextTokensFromUsage returns the total context token count from // a step's usage report. 
It sums input, cache-read, and // cache-creation tokens when available, falling back to TotalTokens diff --git a/coderd/x/chatd/chatloop/compaction_internal_test.go b/coderd/x/chatd/chatloop/compaction_internal_test.go index ae26ed8cf0..ba6580465d 100644 --- a/coderd/x/chatd/chatloop/compaction_internal_test.go +++ b/coderd/x/chatd/chatloop/compaction_internal_test.go @@ -3,7 +3,6 @@ package chatloop import ( "context" "encoding/json" - "sync" "testing" "time" @@ -18,7 +17,6 @@ import ( "github.com/coder/coder/v2/coderd/database/dbmock" "github.com/coder/coder/v2/coderd/x/chatd/chatdebug" "github.com/coder/coder/v2/coderd/x/chatd/chattest" - "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/testutil" ) @@ -235,778 +233,3 @@ func TestGenerateCompactionSummary_PanicFinalizesAsError(t *testing.T) { t.Fatal("FinalizeRun never reached UpdateChatDebugRun on panic") } } - -func TestRun_Compaction(t *testing.T) { - t.Parallel() - - t.Run("PersistsWhenThresholdReached", func(t *testing.T) { - t.Parallel() - - persistCompactionCalls := 0 - var persistedCompaction CompactionResult - const summaryText = "summary text for compaction" - - model := &chattest.FakeModel{ - ProviderName: "fake", - StreamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) { - return streamFromParts([]fantasy.StreamPart{ - {Type: fantasy.StreamPartTypeTextStart, ID: "text-1"}, - {Type: fantasy.StreamPartTypeTextDelta, ID: "text-1", Delta: "done"}, - {Type: fantasy.StreamPartTypeTextEnd, ID: "text-1"}, - { - Type: fantasy.StreamPartTypeFinish, - FinishReason: fantasy.FinishReasonStop, - Usage: fantasy.Usage{ - InputTokens: 80, - TotalTokens: 85, - }, - }, - }), nil - }, - GenerateFn: func(_ context.Context, call fantasy.Call) (*fantasy.Response, error) { - require.NotEmpty(t, call.Prompt) - lastPrompt := call.Prompt[len(call.Prompt)-1] - require.Equal(t, fantasy.MessageRoleUser, lastPrompt.Role) - require.Len(t, lastPrompt.Content, 1) - - instruction, ok := 
fantasy.AsMessagePart[fantasy.TextPart](lastPrompt.Content[0]) - require.True(t, ok) - require.Equal(t, "summarize now", instruction.Text) - - return &fantasy.Response{ - Content: []fantasy.Content{ - fantasy.TextContent{Text: summaryText}, - }, - }, nil - }, - } - - err := Run(context.Background(), RunOptions{ - Model: model, - Messages: []fantasy.Message{ - textMessage(fantasy.MessageRoleUser, "hello"), - }, - MaxSteps: 1, - PersistStep: func(_ context.Context, _ PersistedStep) error { - return nil - }, - ContextLimitFallback: 100, - Compaction: &CompactionOptions{ - ThresholdPercent: 70, - SummaryPrompt: "summarize now", - Persist: func(_ context.Context, result CompactionResult) error { - persistCompactionCalls++ - persistedCompaction = result - return nil - }, - }, - ReloadMessages: func(_ context.Context) ([]fantasy.Message, error) { - return []fantasy.Message{ - textMessage(fantasy.MessageRoleUser, "hello"), - }, nil - }, - }) - require.NoError(t, err) - // Compaction fires twice: once inline when the threshold is - // reached on step 0 (the only step, since MaxSteps=1), and - // once from the post-run safety net during the re-entry - // iteration (where totalSteps already equals MaxSteps so the - // inner loop doesn't execute, but lastUsage still exceeds - // the threshold). 
- require.Equal(t, 2, persistCompactionCalls) - require.Contains(t, persistedCompaction.SystemSummary, summaryText) - require.Equal(t, summaryText, persistedCompaction.SummaryReport) - require.Equal(t, int64(80), persistedCompaction.ContextTokens) - require.Equal(t, int64(100), persistedCompaction.ContextLimit) - require.InDelta(t, 80.0, persistedCompaction.UsagePercent, 0.0001) - }) - - t.Run("PublishesPartsBeforeAndAfterPersist", func(t *testing.T) { - t.Parallel() - - const summaryText = "compaction summary for ordering test" - - // Track the order of callbacks to verify the tool-call - // part publishes before Generate (summary generation) - // and the tool-result part publishes after Persist. - var callOrder []string - - model := &chattest.FakeModel{ - ProviderName: "fake", - StreamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) { - return streamFromParts([]fantasy.StreamPart{ - {Type: fantasy.StreamPartTypeTextStart, ID: "text-1"}, - {Type: fantasy.StreamPartTypeTextDelta, ID: "text-1", Delta: "done"}, - {Type: fantasy.StreamPartTypeTextEnd, ID: "text-1"}, - { - Type: fantasy.StreamPartTypeFinish, - FinishReason: fantasy.FinishReasonStop, - Usage: fantasy.Usage{ - InputTokens: 80, - TotalTokens: 85, - }, - }, - }), nil - }, - GenerateFn: func(_ context.Context, _ fantasy.Call) (*fantasy.Response, error) { - callOrder = append(callOrder, "generate") - return &fantasy.Response{ - Content: []fantasy.Content{ - fantasy.TextContent{Text: summaryText}, - }, - }, nil - }, - } - - err := Run(context.Background(), RunOptions{ - Model: model, - Messages: []fantasy.Message{ - textMessage(fantasy.MessageRoleUser, "hello"), - }, - MaxSteps: 1, - PersistStep: func(_ context.Context, _ PersistedStep) error { - return nil - }, - ContextLimitFallback: 100, - Compaction: &CompactionOptions{ - ThresholdPercent: 70, - SummaryPrompt: "summarize now", - ToolCallID: "test-tool-call-id", - ToolName: "chat_summarized", - PublishMessagePart: func(role 
codersdk.ChatMessageRole, part codersdk.ChatMessagePart) { - switch part.Type { - case codersdk.ChatMessagePartTypeToolCall: - callOrder = append(callOrder, "publish_tool_call") - case codersdk.ChatMessagePartTypeToolResult: - callOrder = append(callOrder, "publish_tool_result") - } - }, - Persist: func(_ context.Context, _ CompactionResult) error { - callOrder = append(callOrder, "persist") - return nil - }, - }, - ReloadMessages: func(_ context.Context) ([]fantasy.Message, error) { - return []fantasy.Message{ - textMessage(fantasy.MessageRoleUser, "hello"), - }, nil - }, - }) - require.NoError(t, err) - // Compaction fires twice (see PersistsWhenThresholdReached - // for the full explanation). Each cycle follows the order: - // publish_tool_call → generate → persist → publish_tool_result. - require.Equal(t, []string{ - "publish_tool_call", - "generate", - "persist", - "publish_tool_result", - "publish_tool_call", - "generate", - "persist", - "publish_tool_result", - }, callOrder) - }) - - t.Run("PublishNotCalledBelowThreshold", func(t *testing.T) { - t.Parallel() - - publishCalled := false - - model := &chattest.FakeModel{ - ProviderName: "fake", - StreamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) { - return streamFromParts([]fantasy.StreamPart{ - { - Type: fantasy.StreamPartTypeFinish, - FinishReason: fantasy.FinishReasonStop, - Usage: fantasy.Usage{ - InputTokens: 10, - }, - }, - }), nil - }, - } - - err := Run(context.Background(), RunOptions{ - Model: model, - Messages: []fantasy.Message{ - textMessage(fantasy.MessageRoleUser, "hello"), - }, - MaxSteps: 1, - PersistStep: func(_ context.Context, _ PersistedStep) error { - return nil - }, - ContextLimitFallback: 100, - Compaction: &CompactionOptions{ - ThresholdPercent: 70, - ToolCallID: "test-tool-call-id", - ToolName: "chat_summarized", - PublishMessagePart: func(_ codersdk.ChatMessageRole, _ codersdk.ChatMessagePart) { - publishCalled = true - }, - Persist: func(_ 
context.Context, _ CompactionResult) error { - return nil - }, - }, - }) - require.NoError(t, err) - require.False(t, publishCalled, "PublishMessagePart should not fire when usage is below threshold") - }) - - t.Run("MidLoopCompactionReloadsMessages", func(t *testing.T) { - t.Parallel() - - var mu sync.Mutex - var streamCallCount int - persistCompactionCalls := 0 - reloadCalls := 0 - - const summaryText = "compacted summary" - - model := &chattest.FakeModel{ - ProviderName: "fake", - StreamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) { - mu.Lock() - step := streamCallCount - streamCallCount++ - mu.Unlock() - - switch step { - case 0: - // Step 0: tool call with high usage (80/100 = 80% > 70%). - return streamFromParts([]fantasy.StreamPart{ - {Type: fantasy.StreamPartTypeToolInputStart, ID: "tc-1", ToolCallName: "read_file"}, - {Type: fantasy.StreamPartTypeToolInputDelta, ID: "tc-1", Delta: `{}`}, - {Type: fantasy.StreamPartTypeToolInputEnd, ID: "tc-1"}, - { - Type: fantasy.StreamPartTypeToolCall, - ID: "tc-1", - ToolCallName: "read_file", - ToolCallInput: `{}`, - }, - { - Type: fantasy.StreamPartTypeFinish, - FinishReason: fantasy.FinishReasonToolCalls, - Usage: fantasy.Usage{ - InputTokens: 80, - TotalTokens: 85, - }, - }, - }), nil - default: - // Step 1: text with low usage (30/100 = 30% < 70%). 
- return streamFromParts([]fantasy.StreamPart{ - {Type: fantasy.StreamPartTypeTextStart, ID: "text-1"}, - {Type: fantasy.StreamPartTypeTextDelta, ID: "text-1", Delta: "done"}, - {Type: fantasy.StreamPartTypeTextEnd, ID: "text-1"}, - { - Type: fantasy.StreamPartTypeFinish, - FinishReason: fantasy.FinishReasonStop, - Usage: fantasy.Usage{ - InputTokens: 30, - TotalTokens: 35, - }, - }, - }), nil - } - }, - GenerateFn: func(_ context.Context, _ fantasy.Call) (*fantasy.Response, error) { - return &fantasy.Response{ - Content: []fantasy.Content{ - fantasy.TextContent{Text: summaryText}, - }, - }, nil - }, - } - - compactedMessages := []fantasy.Message{ - textMessage(fantasy.MessageRoleSystem, "compacted system"), - textMessage(fantasy.MessageRoleUser, "compacted user"), - } - - err := Run(context.Background(), RunOptions{ - Model: model, - Messages: []fantasy.Message{ - textMessage(fantasy.MessageRoleUser, "hello"), - }, - Tools: []fantasy.AgentTool{ - newNoopTool("read_file"), - }, - MaxSteps: 5, - PersistStep: func(_ context.Context, _ PersistedStep) error { - return nil - }, - ContextLimitFallback: 100, - Compaction: &CompactionOptions{ - ThresholdPercent: 70, - SummaryPrompt: "summarize now", - Persist: func(_ context.Context, _ CompactionResult) error { - persistCompactionCalls++ - return nil - }, - }, - ReloadMessages: func(_ context.Context) ([]fantasy.Message, error) { - reloadCalls++ - return compactedMessages, nil - }, - }) - require.NoError(t, err) - - // Compaction fired after step 0 (above threshold). - require.GreaterOrEqual(t, persistCompactionCalls, 1) - // ReloadMessages was called after mid-loop compaction. - require.GreaterOrEqual(t, reloadCalls, 1) - // Both steps ran (tool-call step + follow-up text step). 
- require.Equal(t, 2, streamCallCount) - }) - - t.Run("PostRunCompactionSkippedAfterMidLoop", func(t *testing.T) { - t.Parallel() - - var mu sync.Mutex - var streamCallCount int - persistCompactionCalls := 0 - - const summaryText = "compacted summary for skip test" - - model := &chattest.FakeModel{ - ProviderName: "fake", - StreamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) { - mu.Lock() - step := streamCallCount - streamCallCount++ - mu.Unlock() - - switch step { - case 0: - // Step 0: tool call with high usage (80/100 = 80% > 70%). - return streamFromParts([]fantasy.StreamPart{ - {Type: fantasy.StreamPartTypeToolInputStart, ID: "tc-1", ToolCallName: "read_file"}, - {Type: fantasy.StreamPartTypeToolInputDelta, ID: "tc-1", Delta: `{}`}, - {Type: fantasy.StreamPartTypeToolInputEnd, ID: "tc-1"}, - { - Type: fantasy.StreamPartTypeToolCall, - ID: "tc-1", - ToolCallName: "read_file", - ToolCallInput: `{}`, - }, - { - Type: fantasy.StreamPartTypeFinish, - FinishReason: fantasy.FinishReasonToolCalls, - Usage: fantasy.Usage{ - InputTokens: 80, - TotalTokens: 85, - }, - }, - }), nil - default: - // Step 1: text with low usage (20/100 = 20% < 70%). 
- return streamFromParts([]fantasy.StreamPart{ - {Type: fantasy.StreamPartTypeTextStart, ID: "text-1"}, - {Type: fantasy.StreamPartTypeTextDelta, ID: "text-1", Delta: "done"}, - {Type: fantasy.StreamPartTypeTextEnd, ID: "text-1"}, - { - Type: fantasy.StreamPartTypeFinish, - FinishReason: fantasy.FinishReasonStop, - Usage: fantasy.Usage{ - InputTokens: 20, - TotalTokens: 25, - }, - }, - }), nil - } - }, - GenerateFn: func(_ context.Context, _ fantasy.Call) (*fantasy.Response, error) { - return &fantasy.Response{ - Content: []fantasy.Content{ - fantasy.TextContent{Text: summaryText}, - }, - }, nil - }, - } - - compactedMessages := []fantasy.Message{ - textMessage(fantasy.MessageRoleSystem, "compacted system"), - textMessage(fantasy.MessageRoleUser, "compacted user"), - } - - err := Run(context.Background(), RunOptions{ - Model: model, - Messages: []fantasy.Message{ - textMessage(fantasy.MessageRoleUser, "hello"), - }, - Tools: []fantasy.AgentTool{ - newNoopTool("read_file"), - }, - MaxSteps: 5, - PersistStep: func(_ context.Context, _ PersistedStep) error { - return nil - }, - ContextLimitFallback: 100, - Compaction: &CompactionOptions{ - ThresholdPercent: 70, - SummaryPrompt: "summarize now", - Persist: func(_ context.Context, _ CompactionResult) error { - persistCompactionCalls++ - return nil - }, - }, - ReloadMessages: func(_ context.Context) ([]fantasy.Message, error) { - return compactedMessages, nil - }, - }) - require.NoError(t, err) - - // Only mid-loop compaction fires after step 0. The post-run - // safety net is skipped because alreadyCompacted is true. 
- require.Equal(t, 1, persistCompactionCalls) - }) - - t.Run("ErrorsAreReported", func(t *testing.T) { - t.Parallel() - - model := &chattest.FakeModel{ - ProviderName: "fake", - StreamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) { - return streamFromParts([]fantasy.StreamPart{ - { - Type: fantasy.StreamPartTypeFinish, - FinishReason: fantasy.FinishReasonStop, - Usage: fantasy.Usage{ - InputTokens: 80, - }, - }, - }), nil - }, - GenerateFn: func(_ context.Context, _ fantasy.Call) (*fantasy.Response, error) { - return nil, xerrors.New("generate failed") - }, - } - - compactionErr := xerrors.New("unset") - err := Run(context.Background(), RunOptions{ - Model: model, - Messages: []fantasy.Message{ - textMessage(fantasy.MessageRoleUser, "hello"), - }, - MaxSteps: 1, - PersistStep: func(_ context.Context, _ PersistedStep) error { - return nil - }, - ContextLimitFallback: 100, - Compaction: &CompactionOptions{ - ThresholdPercent: 70, - Persist: func(_ context.Context, _ CompactionResult) error { - return nil - }, - OnError: func(err error) { - compactionErr = err - }, - }, - ReloadMessages: func(_ context.Context) ([]fantasy.Message, error) { - return []fantasy.Message{ - textMessage(fantasy.MessageRoleUser, "hello"), - }, nil - }, - }) - require.NoError(t, err) - require.Error(t, compactionErr) - require.ErrorContains(t, compactionErr, "generate summary text") - }) - - t.Run("PostRunCompactionReEntersStepLoop", func(t *testing.T) { - t.Parallel() - - // When post-run compaction fires (no mid-loop compaction) - // and ReloadMessages is provided, Run should re-enter the - // step loop with the reloaded messages so the agent - // continues working. 
- - var mu sync.Mutex - var streamCallCount int - persistCompactionCalls := 0 - reloadCalls := 0 - - const summaryText = "post-run compacted summary" - - compactedMessages := []fantasy.Message{ - textMessage(fantasy.MessageRoleSystem, "compacted system"), - textMessage(fantasy.MessageRoleUser, "compacted user"), - } - - model := &chattest.FakeModel{ - ProviderName: "fake", - StreamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) { - mu.Lock() - step := streamCallCount - streamCallCount++ - mu.Unlock() - - switch step { - case 0: - // First turn: text-only response with high usage. - // No tool calls, so shouldContinue = false and - // the inner step loop breaks. Compaction should - // fire, then the outer loop re-enters. - return streamFromParts([]fantasy.StreamPart{ - {Type: fantasy.StreamPartTypeTextStart, ID: "text-1"}, - {Type: fantasy.StreamPartTypeTextDelta, ID: "text-1", Delta: "initial response"}, - {Type: fantasy.StreamPartTypeTextEnd, ID: "text-1"}, - { - Type: fantasy.StreamPartTypeFinish, - FinishReason: fantasy.FinishReasonStop, - Usage: fantasy.Usage{ - InputTokens: 80, - TotalTokens: 85, - }, - }, - }), nil - default: - // Second turn (after compaction re-entry): - // text-only with low usage — should finish. 
- return streamFromParts([]fantasy.StreamPart{ - {Type: fantasy.StreamPartTypeTextStart, ID: "text-2"}, - {Type: fantasy.StreamPartTypeTextDelta, ID: "text-2", Delta: "continued after compaction"}, - {Type: fantasy.StreamPartTypeTextEnd, ID: "text-2"}, - { - Type: fantasy.StreamPartTypeFinish, - FinishReason: fantasy.FinishReasonStop, - Usage: fantasy.Usage{ - InputTokens: 20, - TotalTokens: 25, - }, - }, - }), nil - } - }, - GenerateFn: func(_ context.Context, _ fantasy.Call) (*fantasy.Response, error) { - return &fantasy.Response{ - Content: []fantasy.Content{ - fantasy.TextContent{Text: summaryText}, - }, - }, nil - }, - } - - err := Run(context.Background(), RunOptions{ - Model: model, - Messages: []fantasy.Message{ - textMessage(fantasy.MessageRoleUser, "hello"), - }, - MaxSteps: 5, - PersistStep: func(_ context.Context, _ PersistedStep) error { - return nil - }, - ContextLimitFallback: 100, - Compaction: &CompactionOptions{ - ThresholdPercent: 70, - SummaryPrompt: "summarize now", - Persist: func(_ context.Context, _ CompactionResult) error { - persistCompactionCalls++ - return nil - }, - }, - ReloadMessages: func(_ context.Context) ([]fantasy.Message, error) { - reloadCalls++ - return compactedMessages, nil - }, - }) - require.NoError(t, err) - - // Compaction fired on the final step of the first pass. - // The inline path fires (ReloadMessages is set) and then - // the outer loop re-enters. On the second pass the usage - // is below threshold so no further compaction occurs. - require.GreaterOrEqual(t, persistCompactionCalls, 1) - // ReloadMessages was called (inline + re-entry). - require.GreaterOrEqual(t, reloadCalls, 1) - // Two stream calls: one before compaction, one after re-entry. - require.Equal(t, 2, streamCallCount) - }) - - t.Run("PostRunCompactionReEntryIncludesUserSummary", func(t *testing.T) { - t.Parallel() - - // After compaction the summary is stored as a user-role - // message. 
When the loop re-enters, the reloaded prompt - // must contain this user message so the LLM provider - // receives a valid prompt (providers like Anthropic - // require at least one non-system message). - - var mu sync.Mutex - var streamCallCount int - var reEntryPrompt []fantasy.Message - persistCompactionCalls := 0 - - const summaryText = "post-run compacted summary" - - model := &chattest.FakeModel{ - ProviderName: "fake", - StreamFn: func(_ context.Context, call fantasy.Call) (fantasy.StreamResponse, error) { - mu.Lock() - step := streamCallCount - streamCallCount++ - mu.Unlock() - - switch step { - case 0: - return streamFromParts([]fantasy.StreamPart{ - {Type: fantasy.StreamPartTypeTextStart, ID: "text-1"}, - {Type: fantasy.StreamPartTypeTextDelta, ID: "text-1", Delta: "initial response"}, - {Type: fantasy.StreamPartTypeTextEnd, ID: "text-1"}, - { - Type: fantasy.StreamPartTypeFinish, - FinishReason: fantasy.FinishReasonStop, - Usage: fantasy.Usage{ - InputTokens: 80, - TotalTokens: 85, - }, - }, - }), nil - default: - mu.Lock() - reEntryPrompt = append([]fantasy.Message(nil), call.Prompt...) - mu.Unlock() - return streamFromParts([]fantasy.StreamPart{ - {Type: fantasy.StreamPartTypeTextStart, ID: "text-2"}, - {Type: fantasy.StreamPartTypeTextDelta, ID: "text-2", Delta: "continued"}, - {Type: fantasy.StreamPartTypeTextEnd, ID: "text-2"}, - { - Type: fantasy.StreamPartTypeFinish, - FinishReason: fantasy.FinishReasonStop, - Usage: fantasy.Usage{ - InputTokens: 20, - TotalTokens: 25, - }, - }, - }), nil - } - }, - GenerateFn: func(_ context.Context, _ fantasy.Call) (*fantasy.Response, error) { - return &fantasy.Response{ - Content: []fantasy.Content{ - fantasy.TextContent{Text: summaryText}, - }, - }, nil - }, - } - - // Simulate real post-compaction DB state: the summary is - // a user-role message (the only non-system content). 
- compactedMessages := []fantasy.Message{ - textMessage(fantasy.MessageRoleSystem, "system prompt"), - textMessage(fantasy.MessageRoleUser, "Summary of earlier chat context:\n\ncompacted summary"), - } - - err := Run(context.Background(), RunOptions{ - Model: model, - Messages: []fantasy.Message{ - textMessage(fantasy.MessageRoleUser, "hello"), - }, - MaxSteps: 5, - PersistStep: func(_ context.Context, _ PersistedStep) error { - return nil - }, - ContextLimitFallback: 100, - Compaction: &CompactionOptions{ - ThresholdPercent: 70, - SummaryPrompt: "summarize now", - Persist: func(_ context.Context, _ CompactionResult) error { - persistCompactionCalls++ - return nil - }, - }, - ReloadMessages: func(_ context.Context) ([]fantasy.Message, error) { - return compactedMessages, nil - }, - }) - require.NoError(t, err) - - require.GreaterOrEqual(t, persistCompactionCalls, 1) - // Re-entry happened: stream was called at least twice. - require.Equal(t, 2, streamCallCount) - // The re-entry prompt must contain the user summary. - require.NotEmpty(t, reEntryPrompt) - hasUser := false - for _, msg := range reEntryPrompt { - if msg.Role == fantasy.MessageRoleUser { - hasUser = true - break - } - } - require.True(t, hasUser, "re-entry prompt must contain a user message (the compaction summary)") - }) - - t.Run("TriggersOnDynamicToolExit", func(t *testing.T) { - t.Parallel() - - var persistCompactionCalls int - const summaryText = "compaction summary for dynamic tool exit" - - // The LLM calls a dynamic tool. Usage is above the - // compaction threshold so compaction should fire even - // though the chatloop exits via ErrDynamicToolCall. 
- model := &chattest.FakeModel{ - ProviderName: "fake", - StreamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) { - return streamFromParts([]fantasy.StreamPart{ - {Type: fantasy.StreamPartTypeToolInputStart, ID: "tc-1", ToolCallName: "my_dynamic_tool"}, - {Type: fantasy.StreamPartTypeToolInputDelta, ID: "tc-1", Delta: `{"query": "test"}`}, - {Type: fantasy.StreamPartTypeToolInputEnd, ID: "tc-1"}, - { - Type: fantasy.StreamPartTypeToolCall, - ID: "tc-1", - ToolCallName: "my_dynamic_tool", - ToolCallInput: `{"query": "test"}`, - }, - { - Type: fantasy.StreamPartTypeFinish, - FinishReason: fantasy.FinishReasonToolCalls, - Usage: fantasy.Usage{ - InputTokens: 80, - TotalTokens: 85, - }, - }, - }), nil - }, - GenerateFn: func(_ context.Context, _ fantasy.Call) (*fantasy.Response, error) { - return &fantasy.Response{ - Content: []fantasy.Content{ - fantasy.TextContent{Text: summaryText}, - }, - }, nil - }, - } - - err := Run(context.Background(), RunOptions{ - Model: model, - Messages: []fantasy.Message{ - textMessage(fantasy.MessageRoleUser, "hello"), - }, - MaxSteps: 5, - DynamicToolNames: map[string]bool{"my_dynamic_tool": true}, - PersistStep: func(_ context.Context, _ PersistedStep) error { - return nil - }, - ContextLimitFallback: 100, - Compaction: &CompactionOptions{ - ThresholdPercent: 70, - SummaryPrompt: "summarize now", - Persist: func(_ context.Context, result CompactionResult) error { - persistCompactionCalls++ - require.Contains(t, result.SystemSummary, summaryText) - return nil - }, - }, - ReloadMessages: func(_ context.Context) ([]fantasy.Message, error) { - return []fantasy.Message{ - textMessage(fantasy.MessageRoleUser, "hello"), - }, nil - }, - }) - require.ErrorIs(t, err, ErrDynamicToolCall) - require.Equal(t, 1, persistCompactionCalls, - "compaction must fire before dynamic tool exit") - }) -} diff --git a/coderd/x/chatd/chatloop/metrics_test.go b/coderd/x/chatd/chatloop/metrics_test.go index 40eabf99ca..e414e91fab 100644 
--- a/coderd/x/chatd/chatloop/metrics_test.go +++ b/coderd/x/chatd/chatloop/metrics_test.go @@ -4,7 +4,6 @@ import ( "context" "strconv" "testing" - "time" "charm.land/fantasy" "github.com/prometheus/client_golang/prometheus" @@ -15,7 +14,6 @@ import ( "github.com/coder/coder/v2/coderd/x/chatd/chaterror" "github.com/coder/coder/v2/coderd/x/chatd/chatloop" - "github.com/coder/coder/v2/coderd/x/chatd/chatretry" "github.com/coder/coder/v2/coderd/x/chatd/chattest" "github.com/coder/coder/v2/codersdk" ) @@ -411,117 +409,12 @@ func TestRecordToolError(t *testing.T) { }) } -func TestRun_RecordsMetrics(t *testing.T) { +func TestGenerateAssistant_StreamRetryRecordsMetric(t *testing.T) { t.Parallel() reg := prometheus.NewRegistry() metrics := chatloop.NewMetrics(reg) - model := &chattest.FakeModel{ - ProviderName: "test-provider", - ModelName: "test-model", - StreamFn: func(_ context.Context, call fantasy.Call) (fantasy.StreamResponse, error) { - return func(yield func(fantasy.StreamPart) bool) { - parts := []fantasy.StreamPart{ - {Type: fantasy.StreamPartTypeTextStart, ID: "t1"}, - {Type: fantasy.StreamPartTypeTextDelta, ID: "t1", Delta: "hello"}, - {Type: fantasy.StreamPartTypeTextEnd, ID: "t1"}, - {Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop}, - } - for _, p := range parts { - if !yield(p) { - return - } - } - }, nil - }, - } - - err := chatloop.Run(context.Background(), chatloop.RunOptions{ - Model: model, - Messages: []fantasy.Message{ - { - Role: fantasy.MessageRoleUser, - Content: []fantasy.MessagePart{ - fantasy.TextPart{Text: "hello"}, - }, - }, - }, - MaxSteps: 1, - PersistStep: func(_ context.Context, _ chatloop.PersistedStep) error { - return nil - }, - Metrics: metrics, - }) - require.NoError(t, err) - - families, err := reg.Gather() - require.NoError(t, err) - - assertProviderModelLabels := func(t *testing.T, metric *dto.Metric) { - t.Helper() - labels := map[string]string{} - for _, lp := range metric.GetLabel() { - 
labels[lp.GetName()] = lp.GetValue() - } - assert.Equal(t, "test-provider", labels["provider"]) - assert.Equal(t, "test-model", labels["model"]) - } - - found := make(map[string]bool) - for _, f := range families { - found[f.GetName()] = true - - switch f.GetName() { - case "coderd_chatd_steps_total": - require.Len(t, f.GetMetric(), 1) - assert.Equal(t, float64(1), f.GetMetric()[0].GetCounter().GetValue(), - "steps_total should be 1 after one step") - assertProviderModelLabels(t, f.GetMetric()[0]) - case "coderd_chatd_message_count": - require.Len(t, f.GetMetric(), 1) - assert.Equal(t, uint64(1), f.GetMetric()[0].GetHistogram().GetSampleCount(), - "message_count should have 1 observation") - assertProviderModelLabels(t, f.GetMetric()[0]) - case "coderd_chatd_prompt_size_bytes": - require.Len(t, f.GetMetric(), 1) - assert.Equal(t, uint64(1), f.GetMetric()[0].GetHistogram().GetSampleCount(), - "prompt_size_bytes should have 1 observation") - assertProviderModelLabels(t, f.GetMetric()[0]) - case "coderd_chatd_ttft_seconds": - require.Len(t, f.GetMetric(), 1) - assert.Equal(t, uint64(1), f.GetMetric()[0].GetHistogram().GetSampleCount(), - "ttft_seconds should have 1 observation") - assertProviderModelLabels(t, f.GetMetric()[0]) - } - } - - assert.True(t, found["coderd_chatd_steps_total"], "steps_total not recorded") - assert.True(t, found["coderd_chatd_message_count"], "message_count not recorded") - assert.True(t, found["coderd_chatd_prompt_size_bytes"], "prompt_size_bytes not recorded") - assert.True(t, found["coderd_chatd_ttft_seconds"], "ttft_seconds not recorded") -} - -// TestRun_StreamRetry_RecordsMetric exercises the end-to-end retry -// path: a retryable error on the first Stream call, success on the -// second. Asserts both the metric and the back-compat OnRetry -// callback fire. -// -// Note: chatretry.Retry uses time.NewTimer (not quartz.Clock), so -// this test pays chatretry.InitialDelay (1s) of real wall-clock -// time per retry. Keep it to one retry. 
-func TestRun_StreamRetry_RecordsMetric(t *testing.T) { - t.Parallel() - - reg := prometheus.NewRegistry() - metrics := chatloop.NewMetrics(reg) - - type retryCall struct { - attempt int - classified chatretry.ClassifiedError - } - var retries []retryCall - calls := 0 model := &chattest.FakeModel{ ProviderName: "test-provider", @@ -529,7 +422,14 @@ func TestRun_StreamRetry_RecordsMetric(t *testing.T) { StreamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) { calls++ if calls == 1 { - return nil, xerrors.New("received status 429 from upstream") + return nil, chaterror.WithClassification( + xerrors.New("received status 429 from upstream"), + chaterror.ClassifiedError{ + Kind: codersdk.ChatErrorKindRateLimit, + Provider: "test-provider", + Retryable: true, + }, + ) } return func(yield func(fantasy.StreamPart) bool) { yield(fantasy.StreamPart{ @@ -540,35 +440,12 @@ func TestRun_StreamRetry_RecordsMetric(t *testing.T) { }, } - err := chatloop.Run(context.Background(), chatloop.RunOptions{ - Model: model, - MaxSteps: 1, - ContextLimitFallback: 4096, - PersistStep: func(_ context.Context, _ chatloop.PersistedStep) error { - return nil - }, + _, err := chatloop.GenerateAssistant(context.Background(), chatloop.GenerateAssistantOptions{ + Model: model, Metrics: metrics, - OnRetry: func( - attempt int, - _ error, - classified chatretry.ClassifiedError, - _ time.Duration, - ) { - retries = append(retries, retryCall{ - attempt: attempt, - classified: classified, - }) - }, }) - require.NoError(t, err) - - // Back-compat: OnRetry still fires with classified error. - require.Len(t, retries, 1) - assert.Equal(t, 1, retries[0].attempt) - assert.Equal(t, codersdk.ChatErrorKindRateLimit, retries[0].classified.Kind) - assert.Equal(t, "test-provider", retries[0].classified.Provider) - - // Metric assertion. 
+ require.Error(t, err) + require.Equal(t, 1, calls) requireCounter(t, reg, "coderd_chatd_stream_retries_total", 1, map[string]string{ "provider": "test-provider", "model": "test-model", @@ -577,10 +454,10 @@ func TestRun_StreamRetry_RecordsMetric(t *testing.T) { }) } -// TestRun_StreamRetry_ContextCanceledTransportResetIncrements pins the -// invariant that provider-originated context cancellation is counted as -// a retryable transport reset when the chat context is still alive. -func TestRun_StreamRetry_ContextCanceledTransportResetIncrements(t *testing.T) { +// TestGenerateAssistant_StreamRetry_ContextCanceledTransportResetIncrements pins the +// invariant that provider-originated context cancellation is counted as a +// retryable transport reset when the chat context is still alive. +func TestGenerateAssistant_StreamRetry_ContextCanceledTransportResetIncrements(t *testing.T) { t.Parallel() reg := prometheus.NewRegistry() @@ -592,29 +469,16 @@ func TestRun_StreamRetry_ContextCanceledTransportResetIncrements(t *testing.T) { ModelName: "test-model", StreamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) { attempts++ - if attempts == 1 { - return nil, context.Canceled - } - return func(yield func(fantasy.StreamPart) bool) { - _ = yield(fantasy.StreamPart{ - Type: fantasy.StreamPartTypeFinish, - FinishReason: fantasy.FinishReasonStop, - }) - }, nil + return nil, context.Canceled }, } - err := chatloop.Run(context.Background(), chatloop.RunOptions{ - Model: model, - MaxSteps: 1, - ContextLimitFallback: 4096, - PersistStep: func(_ context.Context, _ chatloop.PersistedStep) error { - return nil - }, + _, err := chatloop.GenerateAssistant(context.Background(), chatloop.GenerateAssistantOptions{ + Model: model, Metrics: metrics, }) - require.NoError(t, err) - require.Equal(t, 2, attempts) + require.Error(t, err) + require.Equal(t, 1, attempts) requireCounter(t, reg, "coderd_chatd_stream_retries_total", 1, map[string]string{ "provider": 
"test-provider", @@ -623,105 +487,3 @@ func TestRun_StreamRetry_ContextCanceledTransportResetIncrements(t *testing.T) { "chain_broken": "false", }) } - -func TestRun_ToolError_RecordsMetric(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - toolFn func(context.Context, struct{}, fantasy.ToolCall) (fantasy.ToolResponse, error) - builtinToolNames map[string]bool - wantLabel string - }{ - { - name: "builtin_tool_IsError", - toolFn: func(_ context.Context, _ struct{}, _ fantasy.ToolCall) (fantasy.ToolResponse, error) { - return fantasy.ToolResponse{ - Content: "something went wrong", - IsError: true, - }, nil - }, - builtinToolNames: map[string]bool{"failing_tool": true}, - wantLabel: "failing_tool", - }, - { - name: "mcp_tool_IsError", - toolFn: func(_ context.Context, _ struct{}, _ fantasy.ToolCall) (fantasy.ToolResponse, error) { - return fantasy.ToolResponse{ - Content: "something went wrong", - IsError: true, - }, nil - }, - builtinToolNames: map[string]bool{}, - wantLabel: "failing_tool", - }, - { - name: "tool_Run_returns_error", - toolFn: func(_ context.Context, _ struct{}, _ fantasy.ToolCall) (fantasy.ToolResponse, error) { - return fantasy.ToolResponse{}, xerrors.New("connection refused") - }, - builtinToolNames: map[string]bool{"failing_tool": true}, - wantLabel: "failing_tool", - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - reg := prometheus.NewRegistry() - metrics := chatloop.NewMetrics(reg) - - failingTool := fantasy.NewAgentTool( - "failing_tool", - "a tool that always fails", - tt.toolFn, - ) - - model := &chattest.FakeModel{ - ProviderName: "test-provider", - ModelName: "test-model", - StreamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) { - return func(yield func(fantasy.StreamPart) bool) { - parts := []fantasy.StreamPart{ - {Type: fantasy.StreamPartTypeToolInputStart, ID: "tc1", ToolCallName: "failing_tool"}, - {Type: 
fantasy.StreamPartTypeToolInputDelta, ID: "tc1", Delta: `{}`}, - {Type: fantasy.StreamPartTypeToolInputEnd, ID: "tc1"}, - { - Type: fantasy.StreamPartTypeToolCall, - ID: "tc1", - ToolCallName: "failing_tool", - ToolCallInput: `{}`, - }, - {Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonToolCalls}, - } - for _, p := range parts { - if !yield(p) { - return - } - } - }, nil - }, - } - - err := chatloop.Run(context.Background(), chatloop.RunOptions{ - Model: model, - MaxSteps: 1, - Tools: []fantasy.AgentTool{failingTool}, - ActiveTools: []string{"failing_tool"}, - BuiltinToolNames: tt.builtinToolNames, - PersistStep: func(_ context.Context, _ chatloop.PersistedStep) error { - return nil - }, - Metrics: metrics, - }) - require.NoError(t, err) - - requireCounter(t, reg, "coderd_chatd_tool_errors_total", 1, map[string]string{ - "provider": "test-provider", - "model": "test-model", - "tool_name": tt.wantLabel, - }) - }) - } -} diff --git a/coderd/x/chatd/chatopenai/responses.go b/coderd/x/chatd/chatopenai/responses.go index 2c3cad1b09..134ce31590 100644 --- a/coderd/x/chatd/chatopenai/responses.go +++ b/coderd/x/chatd/chatopenai/responses.go @@ -109,45 +109,6 @@ func WithPreviousResponseID( return cloned } -// HasPreviousResponseID checks whether the provider options contain an OpenAI -// Responses entry with a non-empty PreviousResponseID. -func HasPreviousResponseID(providerOptions fantasy.ProviderOptions) bool { - if len(providerOptions) == 0 { - return false - } - - entry, ok := providerOptions[fantasyopenai.Name] - if !ok { - return false - } - options, ok := entry.(*fantasyopenai.ResponsesProviderOptions) - return ok && options != nil && options.PreviousResponseID != nil && - *options.PreviousResponseID != "" -} - -// ClearPreviousResponseID returns a clone of providerOptions with -// PreviousResponseID cleared on the OpenAI Responses options. The original -// providerOptions is not modified. 
-func ClearPreviousResponseID(providerOptions fantasy.ProviderOptions) fantasy.ProviderOptions { - cloned := maps.Clone(providerOptions) - if cloned == nil { - return fantasy.ProviderOptions{} - } - - entry, ok := cloned[fantasyopenai.Name] - if !ok { - return cloned - } - options, ok := entry.(*fantasyopenai.ResponsesProviderOptions) - if !ok || options == nil { - return cloned - } - optionsClone := *options - optionsClone.PreviousResponseID = nil - cloned[fantasyopenai.Name] = &optionsClone - return cloned -} - // extractResponseID extracts the OpenAI Responses API response ID from provider // metadata. Returns an empty string if no OpenAI Responses metadata is present. func extractResponseID(metadata fantasy.ProviderMetadata) string { diff --git a/coderd/x/chatd/chatopenai/responses_test.go b/coderd/x/chatd/chatopenai/responses_test.go index 5a6e3b9596..59c5cdb44f 100644 --- a/coderd/x/chatd/chatopenai/responses_test.go +++ b/coderd/x/chatd/chatopenai/responses_test.go @@ -254,86 +254,6 @@ func TestWithPreviousResponseIDNilInput(t *testing.T) { require.Empty(t, got) } -func TestHasPreviousResponseID(t *testing.T) { - t.Parallel() - - emptyID := "" - responseID := "resp-123" - - tests := []struct { - name string - opts fantasy.ProviderOptions - want bool - }{ - { - name: "NilOptions", - }, - { - name: "EmptyID", - opts: fantasy.ProviderOptions{ - fantasyopenai.Name: &fantasyopenai.ResponsesProviderOptions{ - PreviousResponseID: &emptyID, - }, - }, - }, - { - name: "NonEmptyID", - opts: fantasy.ProviderOptions{ - fantasyopenai.Name: &fantasyopenai.ResponsesProviderOptions{ - PreviousResponseID: &responseID, - }, - }, - want: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - got := chatopenai.HasPreviousResponseID(tt.opts) - require.Equal(t, tt.want, got) - }) - } -} - -func TestClearPreviousResponseID(t *testing.T) { - t.Parallel() - - responseID := "resp-123" - options := 
&fantasyopenai.ResponsesProviderOptions{ - PreviousResponseID: &responseID, - } - otherOptions := &fantasyopenai.ProviderOptions{} - opts := fantasy.ProviderOptions{ - fantasyopenai.Name: options, - "other": otherOptions, - } - - got := chatopenai.ClearPreviousResponseID(opts) - - got["new"] = otherOptions - require.NotContains(t, opts, "new") - require.NotNil(t, options.PreviousResponseID) - require.Equal(t, "resp-123", *options.PreviousResponseID) - - gotOtherOptions, ok := got["other"].(*fantasyopenai.ProviderOptions) - require.True(t, ok) - require.True(t, otherOptions == gotOtherOptions) - clonedOptions, ok := got[fantasyopenai.Name].(*fantasyopenai.ResponsesProviderOptions) - require.True(t, ok) - require.NotSame(t, options, clonedOptions) - require.Nil(t, clonedOptions.PreviousResponseID) - - require.NotPanics(t, func() { - got := chatopenai.ClearPreviousResponseID(nil) - require.NotNil(t, got) - chatopenai.ClearPreviousResponseID(fantasy.ProviderOptions{ - fantasyopenai.Name: &fantasyopenai.ProviderOptions{}, - }) - }) -} - func TestExtractResponseIDIfStoredMetadata(t *testing.T) { t.Parallel() diff --git a/coderd/x/chatd/chatprompt/chatprompt.go b/coderd/x/chatd/chatprompt/chatprompt.go index 126edf8dba..f1ad49364e 100644 --- a/coderd/x/chatd/chatprompt/chatprompt.go +++ b/coderd/x/chatd/chatprompt/chatprompt.go @@ -235,39 +235,6 @@ func ConvertMessagesWithFiles( return prompt, nil } -// PrependSystem prepends a system message unless an existing system -// message already mentions create_workspace guidance. 
-func PrependSystem(prompt []fantasy.Message, instruction string) []fantasy.Message { - instruction = strings.TrimSpace(instruction) - if instruction == "" { - return prompt - } - for _, message := range prompt { - if message.Role != fantasy.MessageRoleSystem { - continue - } - for _, part := range message.Content { - textPart, ok := fantasy.AsMessagePart[fantasy.TextPart](part) - if !ok { - continue - } - if strings.Contains(strings.ToLower(textPart.Text), "create_workspace") { - return prompt - } - } - } - - out := make([]fantasy.Message, 0, len(prompt)+1) - out = append(out, fantasy.Message{ - Role: fantasy.MessageRoleSystem, - Content: []fantasy.MessagePart{ - fantasy.TextPart{Text: instruction}, - }, - }) - out = append(out, prompt...) - return out -} - // InsertSystem inserts a system message after the existing system // block and before the first non-system message. func InsertSystem(prompt []fantasy.Message, instruction string) []fantasy.Message { @@ -298,24 +265,6 @@ func InsertSystem(prompt []fantasy.Message, instruction string) []fantasy.Messag return out } -// AppendUser appends an instruction as a user message at the end of -// the prompt. -func AppendUser(prompt []fantasy.Message, instruction string) []fantasy.Message { - instruction = strings.TrimSpace(instruction) - if instruction == "" { - return prompt - } - out := make([]fantasy.Message, 0, len(prompt)+1) - out = append(out, prompt...) - out = append(out, fantasy.Message{ - Role: fantasy.MessageRoleUser, - Content: []fantasy.MessagePart{ - fantasy.TextPart{Text: instruction}, - }, - }) - return out -} - const ( // ContentVersionV0 is the legacy content format. 
Parsing uses // role-aware heuristics to distinguish fantasy envelope format diff --git a/coderd/x/chatd/chatprovider/chatprovider.go b/coderd/x/chatd/chatprovider/chatprovider.go index ac817e0940..3c95753a8e 100644 --- a/coderd/x/chatd/chatprovider/chatprovider.go +++ b/coderd/x/chatd/chatprovider/chatprovider.go @@ -36,11 +36,6 @@ var supportedProviderNames = []string{ fantasyvercel.Name, } -var envPresetProviderNames = []string{ - fantasyopenai.Name, - fantasyanthropic.Name, -} - var providerDisplayNameByName = map[string]string{ fantasyanthropic.Name: "Anthropic", fantasyazure.Name: "Azure OpenAI", @@ -52,22 +47,6 @@ var providerDisplayNameByName = map[string]string{ fantasyvercel.Name: "Vercel AI Gateway", } -// SupportedProviders returns all chat providers supported by Fantasy. -func SupportedProviders() []string { - return append([]string(nil), supportedProviderNames...) -} - -// IsEnvPresetProvider reports whether provider supports env presets. -func IsEnvPresetProvider(provider string) bool { - normalized := NormalizeProvider(provider) - for _, candidate := range envPresetProviderNames { - if candidate == normalized { - return true - } - } - return false -} - // ProviderDisplayName returns a default display name for a provider. func ProviderDisplayName(provider string) string { normalized := NormalizeProvider(provider) @@ -771,340 +750,6 @@ func ReasoningEffortFromChat(provider string, value *string) *string { } } -// MergeMissingModelCostConfig fills unset pricing metadata from defaults. 
-func MergeMissingModelCostConfig( - dst **codersdk.ModelCostConfig, - defaults *codersdk.ModelCostConfig, -) { - if defaults == nil { - return - } - if *dst == nil { - copied := *defaults - *dst = &copied - return - } - - current := *dst - if current.InputPricePerMillionTokens == nil { - current.InputPricePerMillionTokens = defaults.InputPricePerMillionTokens - } - if current.OutputPricePerMillionTokens == nil { - current.OutputPricePerMillionTokens = defaults.OutputPricePerMillionTokens - } - if current.CacheReadPricePerMillionTokens == nil { - current.CacheReadPricePerMillionTokens = defaults.CacheReadPricePerMillionTokens - } - if current.CacheWritePricePerMillionTokens == nil { - current.CacheWritePricePerMillionTokens = defaults.CacheWritePricePerMillionTokens - } -} - -// MergeMissingProviderOptions fills unset provider option fields from defaults. -func MergeMissingProviderOptions( - dst **codersdk.ChatModelProviderOptions, - defaults *codersdk.ChatModelProviderOptions, -) { - if defaults == nil { - return - } - if *dst == nil { - copied := *defaults - *dst = &copied - return - } - - current := *dst - for _, provider := range []string{ - fantasyopenai.Name, - fantasyanthropic.Name, - fantasygoogle.Name, - fantasyopenaicompat.Name, - fantasyopenrouter.Name, - fantasyvercel.Name, - } { - switch provider { - case fantasyopenai.Name: - if defaults.OpenAI == nil { - continue - } - if current.OpenAI == nil { - copied := *defaults.OpenAI - current.OpenAI = &copied - continue - } - dstOpenAI := current.OpenAI - defaultOpenAI := defaults.OpenAI - if dstOpenAI.Include == nil { - dstOpenAI.Include = defaultOpenAI.Include - } - if dstOpenAI.Instructions == nil { - dstOpenAI.Instructions = defaultOpenAI.Instructions - } - if dstOpenAI.LogitBias == nil { - dstOpenAI.LogitBias = defaultOpenAI.LogitBias - } - if dstOpenAI.LogProbs == nil { - dstOpenAI.LogProbs = defaultOpenAI.LogProbs - } - if dstOpenAI.TopLogProbs == nil { - dstOpenAI.TopLogProbs = 
defaultOpenAI.TopLogProbs - } - if dstOpenAI.MaxToolCalls == nil { - dstOpenAI.MaxToolCalls = defaultOpenAI.MaxToolCalls - } - if dstOpenAI.ParallelToolCalls == nil { - dstOpenAI.ParallelToolCalls = defaultOpenAI.ParallelToolCalls - } - if dstOpenAI.User == nil { - dstOpenAI.User = defaultOpenAI.User - } - if dstOpenAI.ReasoningEffort == nil { - dstOpenAI.ReasoningEffort = defaultOpenAI.ReasoningEffort - } - if dstOpenAI.ReasoningSummary == nil { - dstOpenAI.ReasoningSummary = defaultOpenAI.ReasoningSummary - } - if dstOpenAI.MaxCompletionTokens == nil { - dstOpenAI.MaxCompletionTokens = defaultOpenAI.MaxCompletionTokens - } - if dstOpenAI.TextVerbosity == nil { - dstOpenAI.TextVerbosity = defaultOpenAI.TextVerbosity - } - if dstOpenAI.Prediction == nil { - dstOpenAI.Prediction = defaultOpenAI.Prediction - } - if dstOpenAI.Store == nil { - dstOpenAI.Store = defaultOpenAI.Store - } - if dstOpenAI.Metadata == nil { - dstOpenAI.Metadata = defaultOpenAI.Metadata - } - if dstOpenAI.PromptCacheKey == nil { - dstOpenAI.PromptCacheKey = defaultOpenAI.PromptCacheKey - } - if dstOpenAI.SafetyIdentifier == nil { - dstOpenAI.SafetyIdentifier = defaultOpenAI.SafetyIdentifier - } - if dstOpenAI.ServiceTier == nil { - dstOpenAI.ServiceTier = defaultOpenAI.ServiceTier - } - if dstOpenAI.StructuredOutputs == nil { - dstOpenAI.StructuredOutputs = defaultOpenAI.StructuredOutputs - } - if dstOpenAI.StrictJSONSchema == nil { - dstOpenAI.StrictJSONSchema = defaultOpenAI.StrictJSONSchema - } - - case fantasyanthropic.Name: - if defaults.Anthropic == nil { - continue - } - if current.Anthropic == nil { - copied := *defaults.Anthropic - current.Anthropic = &copied - continue - } - dstAnthropic := current.Anthropic - defaultAnthropic := defaults.Anthropic - if dstAnthropic.SendReasoning == nil { - dstAnthropic.SendReasoning = defaultAnthropic.SendReasoning - } - if dstAnthropic.Thinking == nil { - dstAnthropic.Thinking = defaultAnthropic.Thinking - } else if defaultAnthropic.Thinking != nil 
&& - dstAnthropic.Thinking.BudgetTokens == nil { - dstAnthropic.Thinking.BudgetTokens = defaultAnthropic.Thinking.BudgetTokens - } - if dstAnthropic.Effort == nil { - dstAnthropic.Effort = defaultAnthropic.Effort - } - if dstAnthropic.DisableParallelToolUse == nil { - dstAnthropic.DisableParallelToolUse = defaultAnthropic.DisableParallelToolUse - } - - case fantasygoogle.Name: - if defaults.Google == nil { - continue - } - if current.Google == nil { - copied := *defaults.Google - current.Google = &copied - continue - } - dstGoogle := current.Google - defaultGoogle := defaults.Google - if dstGoogle.ThinkingConfig == nil { - dstGoogle.ThinkingConfig = defaultGoogle.ThinkingConfig - } else if defaultGoogle.ThinkingConfig != nil { - if dstGoogle.ThinkingConfig.ThinkingBudget == nil { - dstGoogle.ThinkingConfig.ThinkingBudget = defaultGoogle.ThinkingConfig.ThinkingBudget - } - if dstGoogle.ThinkingConfig.IncludeThoughts == nil { - dstGoogle.ThinkingConfig.IncludeThoughts = defaultGoogle.ThinkingConfig.IncludeThoughts - } - } - if strings.TrimSpace(dstGoogle.CachedContent) == "" { - dstGoogle.CachedContent = defaultGoogle.CachedContent - } - if dstGoogle.SafetySettings == nil { - dstGoogle.SafetySettings = defaultGoogle.SafetySettings - } - if strings.TrimSpace(dstGoogle.Threshold) == "" { - dstGoogle.Threshold = defaultGoogle.Threshold - } - - case fantasyopenaicompat.Name: - if defaults.OpenAICompat == nil { - continue - } - if current.OpenAICompat == nil { - copied := *defaults.OpenAICompat - current.OpenAICompat = &copied - continue - } - dstCompat := current.OpenAICompat - defaultCompat := defaults.OpenAICompat - if dstCompat.User == nil { - dstCompat.User = defaultCompat.User - } - if dstCompat.ReasoningEffort == nil { - dstCompat.ReasoningEffort = defaultCompat.ReasoningEffort - } - - case fantasyopenrouter.Name: - if defaults.OpenRouter == nil { - continue - } - if current.OpenRouter == nil { - copied := *defaults.OpenRouter - current.OpenRouter = &copied - 
continue - } - dstRouter := current.OpenRouter - defaultRouter := defaults.OpenRouter - if dstRouter.Reasoning == nil { - dstRouter.Reasoning = defaultRouter.Reasoning - } else if defaultRouter.Reasoning != nil { - if dstRouter.Reasoning.Enabled == nil { - dstRouter.Reasoning.Enabled = defaultRouter.Reasoning.Enabled - } - if dstRouter.Reasoning.Exclude == nil { - dstRouter.Reasoning.Exclude = defaultRouter.Reasoning.Exclude - } - if dstRouter.Reasoning.MaxTokens == nil { - dstRouter.Reasoning.MaxTokens = defaultRouter.Reasoning.MaxTokens - } - if dstRouter.Reasoning.Effort == nil { - dstRouter.Reasoning.Effort = defaultRouter.Reasoning.Effort - } - } - if dstRouter.ExtraBody == nil { - dstRouter.ExtraBody = defaultRouter.ExtraBody - } - if dstRouter.IncludeUsage == nil { - dstRouter.IncludeUsage = defaultRouter.IncludeUsage - } - if dstRouter.LogitBias == nil { - dstRouter.LogitBias = defaultRouter.LogitBias - } - if dstRouter.LogProbs == nil { - dstRouter.LogProbs = defaultRouter.LogProbs - } - if dstRouter.ParallelToolCalls == nil { - dstRouter.ParallelToolCalls = defaultRouter.ParallelToolCalls - } - if dstRouter.User == nil { - dstRouter.User = defaultRouter.User - } - if dstRouter.Provider == nil { - dstRouter.Provider = defaultRouter.Provider - } else if defaultRouter.Provider != nil { - if dstRouter.Provider.Order == nil { - dstRouter.Provider.Order = defaultRouter.Provider.Order - } - if dstRouter.Provider.AllowFallbacks == nil { - dstRouter.Provider.AllowFallbacks = defaultRouter.Provider.AllowFallbacks - } - if dstRouter.Provider.RequireParameters == nil { - dstRouter.Provider.RequireParameters = defaultRouter.Provider.RequireParameters - } - if dstRouter.Provider.DataCollection == nil { - dstRouter.Provider.DataCollection = defaultRouter.Provider.DataCollection - } - if dstRouter.Provider.Only == nil { - dstRouter.Provider.Only = defaultRouter.Provider.Only - } - if dstRouter.Provider.Ignore == nil { - dstRouter.Provider.Ignore = 
defaultRouter.Provider.Ignore - } - if dstRouter.Provider.Quantizations == nil { - dstRouter.Provider.Quantizations = defaultRouter.Provider.Quantizations - } - if dstRouter.Provider.Sort == nil { - dstRouter.Provider.Sort = defaultRouter.Provider.Sort - } - } - - case fantasyvercel.Name: - if defaults.Vercel == nil { - continue - } - if current.Vercel == nil { - copied := *defaults.Vercel - current.Vercel = &copied - continue - } - dstVercel := current.Vercel - defaultVercel := defaults.Vercel - if dstVercel.Reasoning == nil { - dstVercel.Reasoning = defaultVercel.Reasoning - } else if defaultVercel.Reasoning != nil { - if dstVercel.Reasoning.Enabled == nil { - dstVercel.Reasoning.Enabled = defaultVercel.Reasoning.Enabled - } - if dstVercel.Reasoning.MaxTokens == nil { - dstVercel.Reasoning.MaxTokens = defaultVercel.Reasoning.MaxTokens - } - if dstVercel.Reasoning.Effort == nil { - dstVercel.Reasoning.Effort = defaultVercel.Reasoning.Effort - } - if dstVercel.Reasoning.Exclude == nil { - dstVercel.Reasoning.Exclude = defaultVercel.Reasoning.Exclude - } - } - if dstVercel.ProviderOptions == nil { - dstVercel.ProviderOptions = defaultVercel.ProviderOptions - } else if defaultVercel.ProviderOptions != nil { - if dstVercel.ProviderOptions.Order == nil { - dstVercel.ProviderOptions.Order = defaultVercel.ProviderOptions.Order - } - if dstVercel.ProviderOptions.Models == nil { - dstVercel.ProviderOptions.Models = defaultVercel.ProviderOptions.Models - } - } - if dstVercel.User == nil { - dstVercel.User = defaultVercel.User - } - if dstVercel.LogitBias == nil { - dstVercel.LogitBias = defaultVercel.LogitBias - } - if dstVercel.LogProbs == nil { - dstVercel.LogProbs = defaultVercel.LogProbs - } - if dstVercel.TopLogProbs == nil { - dstVercel.TopLogProbs = defaultVercel.TopLogProbs - } - if dstVercel.ParallelToolCalls == nil { - dstVercel.ParallelToolCalls = defaultVercel.ParallelToolCalls - } - if dstVercel.ExtraBody == nil { - dstVercel.ExtraBody = defaultVercel.ExtraBody 
- } - } - } -} - // Header constants sent on upstream LLM API requests so that // intermediaries (e.g. aibridged) can correlate traffic back to // Coder entities. @@ -1143,22 +788,6 @@ func CoderHeaders(chat database.Chat) map[string]string { return h } -// CoderHeadersFromIDs is a convenience form of CoderHeaders for call -// sites that do not have a full database.Chat in scope. -func CoderHeadersFromIDs( - ownerID uuid.UUID, - chatID uuid.UUID, - parentChatID uuid.NullUUID, - workspaceID uuid.NullUUID, -) map[string]string { - return CoderHeaders(database.Chat{ - ID: chatID, - OwnerID: ownerID, - ParentChatID: parentChatID, - WorkspaceID: workspaceID, - }) -} - // ModelFromConfig resolves a provider/model pair and constructs a fantasy // language model client using the provided provider credentials. The // userAgent is sent as the User-Agent header on every outgoing LLM diff --git a/coderd/x/chatd/chatprovider/chatprovider_test.go b/coderd/x/chatd/chatprovider/chatprovider_test.go index 0e851d3f89..3c18eb0d05 100644 --- a/coderd/x/chatd/chatprovider/chatprovider_test.go +++ b/coderd/x/chatd/chatprovider/chatprovider_test.go @@ -1477,66 +1477,6 @@ func TestModelFromConfig_HTTPClient(t *testing.T) { _ = testutil.TryReceive(ctx, t, called) } -func TestMergeMissingProviderOptions_OpenRouterNested(t *testing.T) { - t.Parallel() - - options := &codersdk.ChatModelProviderOptions{ - OpenRouter: &codersdk.ChatModelOpenRouterProviderOptions{ - Reasoning: &codersdk.ChatModelReasoningOptions{ - Enabled: ptr.Ref(true), - }, - Provider: &codersdk.ChatModelOpenRouterProvider{ - Order: []string{"openai"}, - }, - }, - } - defaults := &codersdk.ChatModelProviderOptions{ - OpenRouter: &codersdk.ChatModelOpenRouterProviderOptions{ - Reasoning: &codersdk.ChatModelReasoningOptions{ - Enabled: ptr.Ref(false), - Exclude: ptr.Ref(true), - MaxTokens: ptr.Ref[int64](123), - Effort: ptr.Ref("high"), - }, - IncludeUsage: ptr.Ref(true), - Provider: &codersdk.ChatModelOpenRouterProvider{ - 
Order: []string{"anthropic"}, - AllowFallbacks: ptr.Ref(true), - RequireParameters: ptr.Ref(false), - DataCollection: ptr.Ref("allow"), - Only: []string{"openai"}, - Ignore: []string{"foo"}, - Quantizations: []string{"int8"}, - Sort: ptr.Ref("latency"), - }, - }, - } - - chatprovider.MergeMissingProviderOptions(&options, defaults) - - require.NotNil(t, options) - require.NotNil(t, options.OpenRouter) - require.NotNil(t, options.OpenRouter.Reasoning) - require.True(t, *options.OpenRouter.Reasoning.Enabled) - require.Equal(t, true, *options.OpenRouter.Reasoning.Exclude) - require.EqualValues(t, 123, *options.OpenRouter.Reasoning.MaxTokens) - require.Equal(t, "high", *options.OpenRouter.Reasoning.Effort) - require.NotNil(t, options.OpenRouter.IncludeUsage) - require.True(t, *options.OpenRouter.IncludeUsage) - - require.NotNil(t, options.OpenRouter.Provider) - require.Equal(t, []string{"openai"}, options.OpenRouter.Provider.Order) - require.NotNil(t, options.OpenRouter.Provider.AllowFallbacks) - require.True(t, *options.OpenRouter.Provider.AllowFallbacks) - require.NotNil(t, options.OpenRouter.Provider.RequireParameters) - require.False(t, *options.OpenRouter.Provider.RequireParameters) - require.Equal(t, "allow", *options.OpenRouter.Provider.DataCollection) - require.Equal(t, []string{"openai"}, options.OpenRouter.Provider.Only) - require.Equal(t, []string{"foo"}, options.OpenRouter.Provider.Ignore) - require.Equal(t, []string{"int8"}, options.OpenRouter.Provider.Quantizations) - require.Equal(t, "latency", *options.OpenRouter.Provider.Sort) -} - func TestResolveModelWithProviderHint(t *testing.T) { t.Parallel() diff --git a/coderd/x/chatd/chatretry/chatretry_test.go b/coderd/x/chatd/chatretry/chatretry_test.go index 61fdb047bb..b750548d0a 100644 --- a/coderd/x/chatd/chatretry/chatretry_test.go +++ b/coderd/x/chatd/chatretry/chatretry_test.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "io" "sync/atomic" "testing" "time" @@ -18,80 +17,6 @@ import ( 
"github.com/coder/coder/v2/codersdk" ) -func TestIsRetryableDelegatesToClassification(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - err error - retryable bool - }{ - {name: "Nil", err: nil, retryable: false}, - {name: "RetryableExplicitStatus429", err: xerrors.New("received status 429 from upstream"), retryable: true}, - {name: "RetryableTimeout", err: xerrors.New("service unavailable"), retryable: true}, - { - name: "RetryableAnthropicMissingMessageStop", - err: xerrors.Errorf( - "anthropic stream closed before message_stop: %w", - io.EOF, - ), - retryable: true, - }, - { - name: "RetryableOpenAIResponsesMissingTerminalEvent", - err: xerrors.Errorf( - "openai responses stream closed before terminal event: %w", - io.EOF, - ), - retryable: true, - }, - {name: "NonRetryableAuth", err: xerrors.New("invalid api key"), retryable: false}, - {name: "NonRetryableGeneric", err: xerrors.New("boom"), retryable: false}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - t.Parallel() - - require.Equal(t, tt.retryable, chatretry.IsRetryable(tt.err)) - require.Equal(t, chaterror.Classify(tt.err).Retryable, chatretry.IsRetryable(tt.err)) - }) - } -} - -func TestRetryabilityFromClassifyStatusCodes(t *testing.T) { - t.Parallel() - - tests := []struct { - code int - retryable bool - }{ - {408, true}, - {429, true}, - {500, true}, - {502, true}, - {503, true}, - {504, true}, - {529, true}, - {200, false}, - {400, false}, - {401, false}, - {403, false}, - {404, false}, - } - - for _, tt := range tests { - t.Run(fmt.Sprintf("Status%d", tt.code), func(t *testing.T) { - t.Parallel() - - err := xerrors.Errorf("status %d from upstream", tt.code) - classified := chaterror.Classify(err) - require.Equal(t, tt.retryable, classified.Retryable) - require.Equal(t, classified.Retryable, chatretry.IsRetryable(err)) - }) - } -} - func TestDelay(t *testing.T) { t.Parallel() diff --git a/coderd/x/chatd/chatsanitize/anthropic.go 
b/coderd/x/chatd/chatsanitize/anthropic.go index f3605ed091..b11be26098 100644 --- a/coderd/x/chatd/chatsanitize/anthropic.go +++ b/coderd/x/chatd/chatsanitize/anthropic.go @@ -466,15 +466,6 @@ func contentHasAnthropicSignedReasoning(content []fantasy.Content) bool { return false } -// IsAnthropicProviderExecutedToolCall reports whether toolCall is an -// Anthropic provider-executed tool call. -func IsAnthropicProviderExecutedToolCall( - provider string, - toolCall fantasy.ToolCallContent, -) bool { - return provider == fantasyanthropic.Name && toolCall.ProviderExecuted -} - // ApplyAnthropicProviderToolGuard fail-closes unsafe Anthropic provider-tool // history immediately before a provider request is issued. It returns a // sanitized prompt on success, or nil with ErrAnthropicProviderToolPromptUnsafe diff --git a/coderd/x/chatd/chatstate/messages.go b/coderd/x/chatd/chatstate/messages.go index 7248509970..56f588e1f2 100644 --- a/coderd/x/chatd/chatstate/messages.go +++ b/coderd/x/chatd/chatstate/messages.go @@ -37,6 +37,7 @@ type Message struct { TotalCostMicros sql.NullInt64 RuntimeMs sql.NullInt64 ProviderResponseID sql.NullString + APIKeyID sql.NullString } // toInsertParams converts a batch of Messages into the parallel-array @@ -51,6 +52,7 @@ func toInsertParams(chatID uuid.UUID, messages []Message) database.InsertChatMes ChatID: chatID, CreatedBy: make([]uuid.UUID, n), ModelConfigID: make([]uuid.UUID, n), + APIKeyID: make([]string, n), Role: make([]database.ChatMessageRole, n), Content: make([]string, n), ContentVersion: make([]int16, n), @@ -70,6 +72,9 @@ func toInsertParams(chatID uuid.UUID, messages []Message) database.InsertChatMes for i, m := range messages { params.CreatedBy[i] = nullUUIDOrNil(m.CreatedBy) params.ModelConfigID[i] = nullUUIDOrNil(m.ModelConfigID) + if m.APIKeyID.Valid { + params.APIKeyID[i] = m.APIKeyID.String + } params.Role[i] = m.Role if m.Content.Valid { params.Content[i] = string(m.Content.RawMessage) diff --git 
a/coderd/x/chatd/chatstate/transitions.go b/coderd/x/chatd/chatstate/transitions.go index b60b034cf9..f1e7be2947 100644 --- a/coderd/x/chatd/chatstate/transitions.go +++ b/coderd/x/chatd/chatstate/transitions.go @@ -257,6 +257,7 @@ func (tx *Tx) insertQueuedMessage(ownerFallback uuid.UUID, m Message) (database. Content: rawContent, ModelConfigID: m.ModelConfigID, CreatedBy: createdBy, + APIKeyID: m.APIKeyID, }) } diff --git a/coderd/x/chatd/chatstate_bridge.go b/coderd/x/chatd/chatstate_bridge.go index 8a0af79190..1583612353 100644 --- a/coderd/x/chatd/chatstate_bridge.go +++ b/coderd/x/chatd/chatstate_bridge.go @@ -1,6 +1,8 @@ package chatd import ( + "database/sql" + "github.com/google/uuid" "github.com/sqlc-dev/pqtype" @@ -30,6 +32,10 @@ func systemMessage(rawContent pqtype.NullRawMessage, modelConfigID uuid.UUID) ch // userMessage builds a chatstate.Message representing a user message // for CreateChat, SendMessage, or EditMessage. func userMessage(rawContent pqtype.NullRawMessage, modelConfigID, createdBy uuid.UUID) chatstate.Message { + return userMessageWithAPIKeyID(rawContent, modelConfigID, createdBy, "") +} + +func userMessageWithAPIKeyID(rawContent pqtype.NullRawMessage, modelConfigID, createdBy uuid.UUID, apiKeyID string) chatstate.Message { return chatstate.Message{ Role: database.ChatMessageRoleUser, Content: rawContent, @@ -37,6 +43,7 @@ func userMessage(rawContent pqtype.NullRawMessage, modelConfigID, createdBy uuid ModelConfigID: uuid.NullUUID{UUID: modelConfigID, Valid: modelConfigID != uuid.Nil}, CreatedBy: uuid.NullUUID{UUID: createdBy, Valid: createdBy != uuid.Nil}, ContentVersion: chatprompt.CurrentContentVersion, + APIKeyID: sql.NullString{String: apiKeyID, Valid: apiKeyID != ""}, } } diff --git a/coderd/x/chatd/chattest/anthropic.go b/coderd/x/chatd/chattest/anthropic.go index cb5ffe5dc5..756dc5c7d6 100644 --- a/coderd/x/chatd/chattest/anthropic.go +++ b/coderd/x/chatd/chattest/anthropic.go @@ -26,7 +26,9 @@ type AnthropicResponse struct { 
type AnthropicRequest struct { *http.Request // Embed http.Request Model string `json:"model"` + System json.RawMessage `json:"system,omitempty"` Messages []AnthropicRequestMessage `json:"messages"` + Tools []AnthropicRequestTool `json:"tools,omitempty"` Stream bool `json:"stream,omitempty"` MaxTokens int `json:"max_tokens,omitempty"` // TODO: encoding/json ignores inline tags. Add custom UnmarshalJSON to capture unknown keys. @@ -40,6 +42,11 @@ type AnthropicRequestMessage struct { Content json.RawMessage `json:"content"` } +// AnthropicRequestTool represents a tool in an Anthropic request. +type AnthropicRequestTool struct { + Name string `json:"name"` +} + // AnthropicMessage represents a message in an Anthropic response. type AnthropicMessage struct { ID string `json:"id,omitempty"` @@ -59,6 +66,13 @@ type AnthropicUsage struct { CacheReadInputTokens int `json:"cache_read_input_tokens,omitempty"` } +// AnthropicReasoningBlock describes one Anthropic thinking block for a +// streaming test response. +type AnthropicReasoningBlock struct { + Text string + Signature string +} + // AnthropicChunk represents a streaming chunk from Anthropic. type AnthropicChunk struct { Type string `json:"type"` @@ -83,17 +97,22 @@ type AnthropicChunkMessage struct { // AnthropicContentBlock represents a content block in a chunk. type AnthropicContentBlock struct { - Type string `json:"type"` - Text string `json:"text,omitempty"` - ID string `json:"id,omitempty"` - Name string `json:"name,omitempty"` - Input json.RawMessage `json:"input,omitempty"` + Type string `json:"type"` + Text string `json:"text,omitempty"` + Thinking string `json:"thinking,omitempty"` + ID string `json:"id,omitempty"` + Name string `json:"name,omitempty"` + Input json.RawMessage `json:"input,omitempty"` + ToolUseID string `json:"tool_use_id,omitempty"` + Content any `json:"content,omitempty"` } // AnthropicDeltaBlock represents a delta block in a chunk. 
type AnthropicDeltaBlock struct { Type string `json:"type"` Text string `json:"text,omitempty"` + Thinking string `json:"thinking,omitempty"` + Signature string `json:"signature,omitempty"` PartialJSON string `json:"partial_json,omitempty"` } @@ -424,6 +443,95 @@ func AnthropicTextChunksWithCacheUsage(usage AnthropicUsage, deltas ...string) [ return chunks } +// AnthropicReasoningTextChunks creates a streaming response with one or more +// thinking blocks followed by one text block. +func AnthropicReasoningTextChunks(reasoning []AnthropicReasoningBlock, text string) []AnthropicChunk { + messageID := fmt.Sprintf("msg-%s", uuid.New().String()[:8]) + model := "claude-3-opus-20240229" + + chunks := []AnthropicChunk{ + { + Type: "message_start", + Message: AnthropicChunkMessage{ + ID: messageID, + Type: "message", + Role: "assistant", + Model: model, + }, + }, + } + + for i, block := range reasoning { + chunks = append(chunks, + AnthropicChunk{ + Type: "content_block_start", + Index: i, + ContentBlock: AnthropicContentBlock{ + Type: "thinking", + Thinking: "", + }, + }, + AnthropicChunk{ + Type: "content_block_delta", + Index: i, + Delta: AnthropicDeltaBlock{ + Type: "thinking_delta", + Thinking: block.Text, + }, + }, + ) + if block.Signature != "" { + chunks = append(chunks, AnthropicChunk{ + Type: "content_block_delta", + Index: i, + Delta: AnthropicDeltaBlock{ + Type: "signature_delta", + Signature: block.Signature, + }, + }) + } + chunks = append(chunks, AnthropicChunk{ + Type: "content_block_stop", + Index: i, + }) + } + + textIndex := len(reasoning) + chunks = append(chunks, + AnthropicChunk{ + Type: "content_block_start", + Index: textIndex, + ContentBlock: AnthropicContentBlock{ + Type: "text", + Text: "", + }, + }, + AnthropicChunk{ + Type: "content_block_delta", + Index: textIndex, + Delta: AnthropicDeltaBlock{ + Type: "text_delta", + Text: text, + }, + }, + AnthropicChunk{ + Type: "content_block_stop", + Index: textIndex, + }, + AnthropicChunk{ + Type: 
"message_delta", + StopReason: "end_turn", + Usage: AnthropicUsage{ + InputTokens: 10, + OutputTokens: 5, + }, + }, + AnthropicChunk{Type: "message_stop"}, + ) + + return chunks +} + // AnthropicToolCallChunks creates a complete streaming response for a tool call. // Input JSON can be split across multiple deltas, matching Anthropic's // input_json_delta streaming behavior. diff --git a/coderd/x/chatd/chattest/errors.go b/coderd/x/chatd/chattest/errors.go index 2c84339600..b9b3f5d759 100644 --- a/coderd/x/chatd/chattest/errors.go +++ b/coderd/x/chatd/chattest/errors.go @@ -43,17 +43,6 @@ func AnthropicErrorResponse(statusCode int, errorType, message string) Anthropic } } -// AnthropicOverloadedResponse returns a 529 "overloaded" error matching -// Anthropic's overloaded response format. -func AnthropicOverloadedResponse() AnthropicResponse { - return AnthropicErrorResponse(529, "overloaded_error", "Overloaded") -} - -// AnthropicRateLimitResponse returns a 429 rate limit error. -func AnthropicRateLimitResponse() AnthropicResponse { - return AnthropicErrorResponse(http.StatusTooManyRequests, "rate_limit_error", "Rate limited") -} - // OpenAIErrorResponse returns an OpenAIResponse that causes the // test server to respond with the given HTTP status code and error. func OpenAIErrorResponse(statusCode int, errorType, message string) OpenAIResponse { diff --git a/coderd/x/chatd/context_helpers.go b/coderd/x/chatd/context_helpers.go new file mode 100644 index 0000000000..be684c7c6d --- /dev/null +++ b/coderd/x/chatd/context_helpers.go @@ -0,0 +1,82 @@ +package chatd + +import ( + "bytes" + "encoding/json" + + "github.com/google/uuid" + + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/codersdk" +) + +// agentChatContextSentinelPath marks the synthetic empty context-file +// part used to record an attempted workspace-context fetch when no +// AGENTS.md content is available. 
It mirrors the constant of the same +// value in the chatd package so the worker can recognize sentinel +// parts without importing chatd (which would be a cycle). +const agentChatContextSentinelPath = ".coder/agent-chat-context-sentinel" + +// contextFileAgentIDFromMessages returns the most recent workspace +// agent ID stamped on a persisted context-file part, ignoring the +// skill-only sentinel. Returns uuid.Nil, false when no stamped +// non-sentinel context-file parts exist. +// +// This mirrors chatd.contextFileAgentID. It is duplicated here as a +// small pure helper so chatworker can decide whether workspace +// context is current without importing chatd. +func contextFileAgentIDFromMessages(messages []database.ChatMessage) (uuid.UUID, bool) { + var lastID uuid.UUID + found := false + for _, msg := range messages { + if !msg.Content.Valid || !bytes.Contains(msg.Content.RawMessage, []byte(`"context-file"`)) { + continue + } + var parts []codersdk.ChatMessagePart + if err := json.Unmarshal(msg.Content.RawMessage, &parts); err != nil { + continue + } + for _, p := range parts { + if p.Type != codersdk.ChatMessagePartTypeContextFile || + !p.ContextFileAgentID.Valid || + p.ContextFilePath == agentChatContextSentinelPath { + continue + } + lastID = p.ContextFileAgentID.UUID + found = true + break + } + } + return lastID, found +} + +// hasPersistedContextFileForAgent reports whether messages include +// any persisted context-file marker for the given agent, including +// the skill-only sentinel. This is true once the +// persist_workspace_context action has committed at least one +// context-file row for the agent (with or without content), so a +// subsequent decision pass will not loop on the same agent. 
+func hasPersistedContextFileForAgent(messages []database.ChatMessage, agentID uuid.UUID) bool { + if agentID == uuid.Nil { + return false + } + for _, msg := range messages { + if !msg.Content.Valid || !bytes.Contains(msg.Content.RawMessage, []byte(`"context-file"`)) { + continue + } + var parts []codersdk.ChatMessagePart + if err := json.Unmarshal(msg.Content.RawMessage, &parts); err != nil { + continue + } + for _, p := range parts { + if p.Type != codersdk.ChatMessagePartTypeContextFile || + !p.ContextFileAgentID.Valid { + continue + } + if p.ContextFileAgentID.UUID == agentID { + return true + } + } + } + return false +} diff --git a/coderd/x/chatd/export_test.go b/coderd/x/chatd/export_test.go index 519ed0dcad..05e313fe2f 100644 --- a/coderd/x/chatd/export_test.go +++ b/coderd/x/chatd/export_test.go @@ -2,61 +2,25 @@ package chatd import ( "context" + "sync" - "github.com/sqlc-dev/pqtype" - - "cdr.dev/slog/v3" "github.com/coder/coder/v2/coderd/database" - "github.com/coder/coder/v2/codersdk" + "github.com/coder/coder/v2/codersdk/workspacesdk" ) -// FinishActiveChatForTest exposes the unexported cleanup TX so tests -// can drive the post-run state machine deterministically. Returns the -// resulting chat, the promoted message (if any), the synthetic -// tool-result rows the cleanup TX inserted (if any), and the cleanup -// error. The lastError string is encoded into a structured payload -// the same way runChat does, so callers do not need to know about -// the structured-error wrapper. 
-func FinishActiveChatForTest( - ctx context.Context, - server *Server, - chat database.Chat, - status database.ChatStatus, - lastError string, -) (database.Chat, *database.ChatMessage, []database.ChatMessage, error) { - logger := server.logger.With(slog.F("chat_id", chat.ID)) - var encoded pqtype.NullRawMessage - if lastError != "" { - var err error - encoded, err = encodeChatLastErrorPayload(&codersdk.ChatError{ - Message: lastError, - }) - if err != nil { - return database.Chat{}, nil, nil, err - } - } - result, err := server.finishActiveChat(ctx, logger, chat, status, encoded) - if err != nil { - return database.Chat{}, nil, nil, err - } - return result.updatedChat, result.promotedMessage, result.syntheticToolResults, nil +type TurnWorkspaceContextForTest struct { + inner *turnWorkspaceContext } -// RecoverStaleChatsForTest exposes the unexported stale-recovery loop -// so tests can assert the recovery state machine without waiting for -// the periodic ticker. -func RecoverStaleChatsForTest(ctx context.Context, server *Server) { - server.recoverStaleChats(ctx) +func NewTurnWorkspaceContextForTest(server *Server, chat database.Chat) *TurnWorkspaceContextForTest { + return &TurnWorkspaceContextForTest{inner: &turnWorkspaceContext{ + server: server, + chatStateMu: &sync.Mutex{}, + currentChat: &chat, + loadChatSnapshot: server.db.GetChatByID, + }} } -// InsertSyntheticToolResultsTxForTest exposes the unexported helper -// so tests can verify the dedup path against pre-existing tool -// results. 
-func InsertSyntheticToolResultsTxForTest( - ctx context.Context, - store database.Store, - chat database.Chat, - reason string, -) ([]database.ChatMessage, error) { - return insertSyntheticToolResultsTx(ctx, store, chat, reason) +func (c *TurnWorkspaceContextForTest) GetWorkspaceConn(ctx context.Context) (workspacesdk.AgentConn, error) { + return c.inner.getWorkspaceConn(ctx) } diff --git a/coderd/x/chatd/generation.go b/coderd/x/chatd/generation.go new file mode 100644 index 0000000000..f8b112d9e6 --- /dev/null +++ b/coderd/x/chatd/generation.go @@ -0,0 +1,1136 @@ +package chatd + +import ( + "context" + "database/sql" + "encoding/json" + "errors" + "strings" + "time" + + "charm.land/fantasy" + "github.com/google/uuid" + "github.com/sqlc-dev/pqtype" + "golang.org/x/xerrors" + + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/x/chatd/chaterror" + "github.com/coder/coder/v2/coderd/x/chatd/chatloop" + "github.com/coder/coder/v2/coderd/x/chatd/chatprompt" + "github.com/coder/coder/v2/coderd/x/chatd/chatprovider" + "github.com/coder/coder/v2/coderd/x/chatd/chatretry" + "github.com/coder/coder/v2/coderd/x/chatd/chatstate" + "github.com/coder/coder/v2/coderd/x/chatd/messagepartbuffer" + "github.com/coder/coder/v2/codersdk" +) + +// generationPrepareInput contains the committed state used to prepare one +// generation action. +type generationPrepareInput struct { + Chat database.Chat + Messages []database.ChatMessage + ChainModeDisabled bool + PublishMessagePart func(codersdk.ChatMessageRole, codersdk.ChatMessagePart) +} + +// generationPrepared contains the side-effect inputs for a generation task. +type generationPrepared struct { + // Chat may be set when metadata side effects, such as resolving and + // persisting a workspace agent binding, have mutated the chat row since + // the worker last loaded it. When zero, the worker's locked chat + // snapshot is used for generation decisions and actions. 
+ Chat database.Chat + Messages []database.ChatMessage + + Model fantasy.LanguageModel + Prompt []fantasy.Message + Tools []fantasy.AgentTool + ActiveTools []string + ProviderTools []chatloop.ProviderTool + ProviderKeys chatprovider.ProviderAPIKeys + ModelRoute resolvedModelRoute + ModelBuildOptions modelBuildOptions + + ModelConfigID uuid.UUID + ModelConfig codersdk.ChatModelCallConfig + ProviderOptions fantasy.ProviderOptions + ContextLimitFallback int64 + + DynamicToolNames map[string]bool + StopAfterTools map[string]struct{} + ExclusiveToolNames map[string]bool + BuiltinToolNames map[string]bool + ToolNameToConfigID map[string]uuid.UUID + + MaxSteps int + Compaction *generationCompaction + Cleanup func() + + // WorkspaceContextEligible reports whether the current turn is allowed + // by policy to inject workspace context. The decision helper combines + // this fact with committed chat metadata and history to decide whether + // the persist_workspace_context action should run. + WorkspaceContextEligible bool +} + +// generationCompaction contains compaction inputs prepared for generation. +type generationCompaction struct { + Required bool + Options chatloop.GenerateCompactionOptions +} + +type workspaceContextBuildInput struct { + Chat database.Chat + Messages []database.ChatMessage +} + +type workspaceContextBuildResult struct { + Messages []chatstate.Message +} + +// generationOutcome describes a completed generation outcome. 
+type generationOutcome struct { + Chat database.Chat + Kind runnerActionKind + WatchEventKind codersdk.ChatWatchEventKind + LastError string + PromotedMessageID int64 + InsertedMessages []runnerActionMessage +} + +type generationActionKind string + +const ( + generationActionExecuteLocalTools generationActionKind = "execute_local_tools" + generationActionEnterRequiresAction generationActionKind = "enter_requires_action" + generationActionFinishTurn generationActionKind = "finish_turn" + generationActionCompact generationActionKind = "compact" + generationActionGenerateAssistant generationActionKind = "generate_assistant" + generationActionPersistWorkspaceContext generationActionKind = "persist_workspace_context" +) + +type generationFinishReason string + +const ( + generationFinishReasonStopAfterTool generationFinishReason = "stop_after_tool" + generationFinishReasonComplete generationFinishReason = "complete" + generationFinishReasonMaxSteps generationFinishReason = "max_steps" +) + +type compactionTrigger string + +const ( + compactionTriggerRequired compactionTrigger = "required" + compactionTriggerAlreadyCompacted compactionTrigger = "already_compacted" +) + +var errCompactionStillOverLimit = xerrors.New("compaction left the chat above the compaction limit") + +type generationDecision struct { + kind generationActionKind + localToolCalls []fantasy.ToolCallContent + pendingDynamicToolCalls []pendingDynamicToolCall + finishReason generationFinishReason + compactionTrigger compactionTrigger + promotedMessageID int64 +} + +type generationRetryDecision struct { + retry bool + generationAttempt int64 + delay time.Duration +} + +var errRetryStateDecisionOnly = xerrors.New("retry state decision only") + +// errTerminalGeneration marks a prepare or decide failure as terminal: a +// deterministic error where retrying cannot help. The generation loop +// finishes the turn with an error instead of retrying when an error +// unwraps to this sentinel. 
+var errTerminalGeneration = xerrors.New("terminal generation error") + +type terminalGenerationError struct{ err error } + +func (e terminalGenerationError) Error() string { return e.err.Error() } + +func (e terminalGenerationError) Unwrap() error { return errors.Join(errTerminalGeneration, e.err) } + +// terminalGeneration wraps err so the prepare/decide retry loop stops +// immediately and finishes the turn with an error. +func terminalGeneration(err error) error { + if err == nil { + return nil + } + return terminalGenerationError{err: err} +} + +func isTerminalGeneration(err error) bool { + return errors.Is(err, errTerminalGeneration) +} + +type generationDecisionInput struct { + chat database.Chat + messages []database.ChatMessage + dynamicToolNames map[string]bool + exclusiveToolNames map[string]bool + stopAfterTools map[string]struct{} + maxSteps int + compactionEnabled bool + compactionNeeded bool + workspaceContextEligible bool +} + +// shouldPersistWorkspaceContext reports whether the committed chat +// state and history indicate that the persistWorkspaceContext +// generation action should run before the next assistant call. The +// decision uses two facts: +// - chat metadata says a workspace and selected agent are attached; +// - committed history either has no context-file marker for the +// currently selected workspace agent, or the latest non-sentinel +// marker points to a different agent. +// +// The decision is intentionally pure so generation can choose the +// action without dialing the workspace. Once the action commits a +// context-file marker for the agent (with or without content), this +// helper returns false on the next pass and the loop is broken. 
+func shouldPersistWorkspaceContext(chat database.Chat, messages []database.ChatMessage) bool { + if !chat.WorkspaceID.Valid || !chat.AgentID.Valid { + return false + } + if hasPersistedContextFileForAgent(messages, chat.AgentID.UUID) { + return false + } + persistedAgentID, found := contextFileAgentIDFromMessages(messages) + if !found { + return true + } + return persistedAgentID != chat.AgentID.UUID +} + +func decideGenerationAction(input generationDecisionInput) (generationDecision, error) { + localCalls, dynamicCalls, err := unresolvedToolCallsFromHistory(input.messages, input.dynamicToolNames) + if err != nil { + return generationDecision{}, err + } + if len(localCalls) > 0 { + if len(dynamicCalls) > 0 && hasExclusiveToolCall(localCalls, input.exclusiveToolNames) { + for _, dynamicCall := range dynamicCalls { + localCalls = append(localCalls, fantasy.ToolCallContent{ + ToolCallID: dynamicCall.ToolCallID, + ToolName: dynamicCall.ToolName, + Input: dynamicCall.Args, + }) + } + dynamicCalls = nil + } + return generationDecision{kind: generationActionExecuteLocalTools, localToolCalls: localCalls, pendingDynamicToolCalls: dynamicCalls}, nil + } + if len(dynamicCalls) > 0 { + return generationDecision{kind: generationActionEnterRequiresAction, pendingDynamicToolCalls: dynamicCalls}, nil + } + + stopAfter, err := historyHasStopAfterToolResult(input.messages, input.stopAfterTools) + if err != nil { + return generationDecision{}, err + } + if stopAfter { + return generationDecision{kind: generationActionFinishTurn, finishReason: generationFinishReasonStopAfterTool}, nil + } + complete, err := currentHistoryComplete(input.messages) + if err != nil { + return generationDecision{}, err + } + if complete { + return generationDecision{kind: generationActionFinishTurn, finishReason: generationFinishReasonComplete}, nil + } + if input.maxSteps > 0 && currentTurnStepCount(input.messages) >= input.maxSteps { + return generationDecision{kind: generationActionFinishTurn, 
finishReason: generationFinishReasonMaxSteps}, nil + } + if input.workspaceContextEligible && shouldPersistWorkspaceContext(input.chat, input.messages) { + return generationDecision{kind: generationActionPersistWorkspaceContext}, nil + } + compactionRequirement := compactionRequirementNotNeeded + if input.compactionEnabled && input.compactionNeeded { + compactionRequirement = compactionRequirementNeeded + } + switch compactionStatusFromHistory(input.messages, compactionRequirement) { + case compactionStatusNeeded: + return generationDecision{kind: generationActionCompact, compactionTrigger: compactionTriggerRequired}, nil + case compactionStatusAfterCompaction: + return generationDecision{kind: generationActionGenerateAssistant, compactionTrigger: compactionTriggerAlreadyCompacted}, nil + case compactionStatusStillOverLimit: + return generationDecision{}, terminalGeneration(errCompactionStillOverLimit) + case compactionStatusNotNeeded: + return generationDecision{kind: generationActionGenerateAssistant}, nil + default: + return generationDecision{}, terminalGeneration(xerrors.New("unknown compaction status")) + } +} + +func unresolvedToolCallsFromHistory( + messages []database.ChatMessage, + dynamicToolNames map[string]bool, +) ([]fantasy.ToolCallContent, []pendingDynamicToolCall, error) { + assistantIndex := lastMessageIndex(messages, func(msg database.ChatMessage) bool { + return msg.Role == database.ChatMessageRoleAssistant + }) + if assistantIndex == -1 { + return nil, nil, nil + } + assistantParts, err := chatprompt.ParseContent(messages[assistantIndex]) + if err != nil { + return nil, nil, xerrors.Errorf("parse assistant message: %w", err) + } + handled, err := handledToolCallIDs(messages[assistantIndex+1:]) + if err != nil { + return nil, nil, err + } + localCalls := make([]fantasy.ToolCallContent, 0) + dynamicCalls := make([]pendingDynamicToolCall, 0) + for _, part := range assistantParts { + if part.Type != codersdk.ChatMessagePartTypeToolCall || 
part.ProviderExecuted || handled[part.ToolCallID] { + continue + } + if dynamicToolNames[part.ToolName] { + dynamicCalls = append(dynamicCalls, pendingDynamicToolCall{ + ToolCallID: part.ToolCallID, + ToolName: part.ToolName, + Args: string(part.Args), + }) + continue + } + localCalls = append(localCalls, fantasy.ToolCallContent{ + ToolCallID: part.ToolCallID, + ToolName: part.ToolName, + Input: string(part.Args), + ProviderExecuted: part.ProviderExecuted, + }) + } + return localCalls, dynamicCalls, nil +} + +func hasExclusiveToolCall(toolCalls []fantasy.ToolCallContent, exclusiveToolNames map[string]bool) bool { + if len(exclusiveToolNames) == 0 { + return false + } + for _, toolCall := range toolCalls { + if exclusiveToolNames[toolCall.ToolName] { + return true + } + } + return false +} + +func (s *taskStarter) StartGeneration(ctx context.Context, input chatWorkerTaskStartInput) error { + if s.server == nil { + return xerrors.New("chatworker: server is required") + } + machine := chatstate.NewChatMachine(s.opts.Store, s.opts.Pubsub, input.ChatID, chatstate.Options{}) + chainModeDisabled := false + for { + locked, messages, err := loadGenerationState(ctx, machine, input) + if err != nil { + return err + } + prepareInput := generationPrepareInput{ + Chat: locked, + Messages: messages, + ChainModeDisabled: chainModeDisabled, + } + prepared, err := retryGenerationPhase(ctx, s.waitGenerationPhaseBackoff, func() (generationPrepared, error) { + return s.server.prepareGeneration(ctx, prepareInput) + }) + if err != nil { + if errors.Is(err, errTaskExpectedExit) { + return errTaskExpectedExit + } + return s.finishGenerationError(ctx, machine, input, 0, err, generationAttemptNotRequired) + } + if prepared.Messages == nil { + prepared.Messages = messages + } + decisionChat := locked + if prepared.Chat.ID != uuid.Nil { + decisionChat = prepared.Chat + } + decision, err := retryGenerationPhase(ctx, s.waitGenerationPhaseBackoff, func() (generationDecision, error) { + return 
decideGenerationAction(generationDecisionInput{ + chat: decisionChat, + messages: prepared.Messages, + dynamicToolNames: prepared.DynamicToolNames, + exclusiveToolNames: prepared.ExclusiveToolNames, + stopAfterTools: prepared.StopAfterTools, + maxSteps: prepared.MaxSteps, + compactionEnabled: prepared.Compaction != nil, + compactionNeeded: prepared.Compaction != nil && prepared.Compaction.Required, + workspaceContextEligible: prepared.WorkspaceContextEligible, + }) + }) + if err != nil { + if prepared.Cleanup != nil { + prepared.Cleanup() + } + if errors.Is(err, errTaskExpectedExit) { + return errTaskExpectedExit + } + return s.finishGenerationError(ctx, machine, input, 0, err, generationAttemptNotRequired) + } + cleanup := prepared.Cleanup + if cleanup == nil { + cleanup = func() {} + } + + var actionErr error + switch decision.kind { + case generationActionEnterRequiresAction: + cleanup() + return s.enterRequiresAction(ctx, machine, input, decision) + case generationActionFinishTurn: + cleanup() + return s.finishGenerationTurn(ctx, machine, input, 0, decision, generationAttemptNotRequired) + case generationActionGenerateAssistant: + actionErr = s.generateAssistant(ctx, machine, input, prepareInput, decision) + case generationActionExecuteLocalTools: + actionErr = s.executeLocalTools(ctx, machine, input, prepareInput, decision) + case generationActionCompact: + actionErr = s.generateCompaction(ctx, machine, input, prepareInput) + case generationActionPersistWorkspaceContext: + actionErr = s.persistWorkspaceContext(ctx, machine, input, decisionChat) + default: + return s.finishGenerationError(ctx, machine, input, 0, xerrors.Errorf("unknown generation action %q", decision.kind), generationAttemptNotRequired) + } + cleanup() + if actionErr == nil { + return nil + } + if errors.Is(actionErr, errTaskExpectedExit) || errors.Is(actionErr, chatloop.ErrInterrupted) { + return nil + } + if errors.Is(actionErr, context.Canceled) && ctx.Err() != nil { + return nil + } + 
classified := chaterror.Classify(actionErr) + if classified.Retryable { + decision, err := s.recordGenerationRetry(ctx, machine, input, classified) + if err != nil { + return err + } + if decision.retry { + if classified.ChainBroken { + chainModeDisabled = true + } + if err := s.waitGenerationRetry(ctx, decision.delay); err != nil { + return err + } + continue + } + return s.finishGenerationError(ctx, machine, input, decision.generationAttempt, actionErr, generationAttemptRequired) + } + return s.finishGenerationError(ctx, machine, input, 0, actionErr, generationAttemptNotRequired) + } +} + +func loadGenerationState( + ctx context.Context, + machine *chatstate.ChatMachine, + input chatWorkerTaskStartInput, +) (database.Chat, []database.ChatMessage, error) { + var locked database.Chat + var messages []database.ChatMessage + err := machine.ReadLock(ctx, func(store database.Store) error { + chat, err := store.GetChatByID(ctx, input.ChatID) + if errors.Is(err, sql.ErrNoRows) { + return errTaskExpectedExit + } + if err != nil { + return xerrors.Errorf("load locked chat: %w", err) + } + if err := verifyTaskFence(chat, input, database.ChatStatusRunning, taskFenceOptions{requireHistory: true}); err != nil { + return err + } + loaded, err := store.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{ + ChatID: input.ChatID, + AfterID: 0, + }) + if err != nil { + return xerrors.Errorf("load chat messages: %w", err) + } + locked = chat + messages = loaded + return nil + }) + if err != nil { + return database.Chat{}, nil, normalizeTaskInfrastructureError(err, "lock chat for generation") + } + return locked, messages, nil +} + +func (*taskStarter) recordGenerationRetry( + ctx context.Context, + machine *chatstate.ChatMachine, + input chatWorkerTaskStartInput, + classified chaterror.ClassifiedError, +) (generationRetryDecision, error) { + var decision generationRetryDecision + var payload *codersdk.ChatStreamRetry + err := machine.Update(ctx, func(tx 
*chatstate.Tx) error { + locked, err := tx.Store().GetChatByID(ctx, input.ChatID) + if errors.Is(err, sql.ErrNoRows) { + return errTaskExpectedExit + } + if err != nil { + return xerrors.Errorf("load chat: %w", err) + } + if err := verifyTaskFence(locked, input, database.ChatStatusRunning, taskFenceOptions{requireHistory: true}); err != nil { + return err + } + decision.generationAttempt = locked.GenerationAttempt + if locked.GenerationAttempt <= 0 || locked.GenerationAttempt >= int64(chatretry.MaxAttempts) { + decision.retry = false + return errRetryStateDecisionOnly + } + + attempt := int(locked.GenerationAttempt) + delay := chatretry.Delay(attempt - 1) + if classified.RetryAfter > delay { + delay = classified.RetryAfter + } + decision.retry = true + decision.delay = delay + + payload = chaterror.StreamRetryPayload(attempt, delay, classified) + if payload == nil { + return errRetryStateDecisionOnly + } + encoded, err := json.Marshal(payload) + if err != nil { + return xerrors.Errorf("marshal retry state: %w", err) + } + _, err = tx.RecordRetryState(chatstate.RecordRetryStateInput{ + RetryState: pqtype.NullRawMessage{RawMessage: encoded, Valid: true}, + }) + return err + }) + if errors.Is(err, errRetryStateDecisionOnly) { + return decision, nil + } + if err != nil { + return generationRetryDecision{}, normalizeTaskTransitionError(err, "record retry state") + } + return decision, nil +} + +func (s *taskStarter) waitGenerationRetry(ctx context.Context, delay time.Duration) error { + timer := s.opts.Clock.NewTimer(delay, "chatworker", "generation-retry") + defer timer.Stop() + select { + case <-timer.C: + return nil + case <-ctx.Done(): + return errTaskExpectedExit + } +} + +const ( + // generationPhaseMaxAttempts bounds how many times prepareGeneration + // and decideGenerationAction run before the turn finishes with an + // error. 
Both phases are retried because prepareGeneration performs + // I/O (DB reads, MCP connects, workspace dials) that can fail + // transiently. + generationPhaseMaxAttempts = 3 + // generationPhaseBaseBackoff is the delay before the first retry. It + // doubles on each subsequent attempt. + generationPhaseBaseBackoff = 200 * time.Millisecond +) + +func generationPhaseBackoff(attempt int) time.Duration { + d := generationPhaseBaseBackoff + for range attempt { + d *= 2 + } + return d +} + +// retryGenerationPhase runs fn up to generationPhaseMaxAttempts times. It +// returns early on success or on a terminal error (see terminalGeneration). +// Non-terminal errors are retried with exponential backoff. Context +// cancellation returns errTaskExpectedExit so shutdown does not write an +// error state. When every attempt fails, the last error is returned. +func retryGenerationPhase[T any]( + ctx context.Context, + wait func(context.Context, time.Duration) error, + fn func() (T, error), +) (T, error) { + var zero T + var lastErr error + for attempt := 0; attempt < generationPhaseMaxAttempts; attempt++ { + result, err := fn() + if err == nil { + return result, nil + } + if isTerminalGeneration(err) { + return zero, err + } + if ctx.Err() != nil { + return zero, errTaskExpectedExit + } + lastErr = err + if attempt < generationPhaseMaxAttempts-1 { + if waitErr := wait(ctx, generationPhaseBackoff(attempt)); waitErr != nil { + return zero, waitErr + } + } + } + return zero, lastErr +} + +func (s *taskStarter) waitGenerationPhaseBackoff(ctx context.Context, delay time.Duration) error { + timer := s.opts.Clock.NewTimer(delay, "chatworker", "generation-phase-retry") + defer timer.Stop() + select { + case <-timer.C: + return nil + case <-ctx.Done(): + return errTaskExpectedExit + } +} + +func (s *taskStarter) generateAssistant( + ctx context.Context, + machine *chatstate.ChatMachine, + input chatWorkerTaskStartInput, + prepareInput generationPrepareInput, + decision 
generationDecision, +) error { + attempt, key, publish, closeEpisode, err := s.beginGenerationAttempt(ctx, machine, input) + if err != nil { + return err + } + defer closeEpisode() + prepareInput.PublishMessagePart = publish + prepared, err := s.server.prepareGeneration(ctx, prepareInput) + if err != nil { + return err + } + if prepared.Cleanup != nil { + defer prepared.Cleanup() + } + outcome, err := chatloop.GenerateAssistant(ctx, chatloop.GenerateAssistantOptions{ + Model: prepared.Model, + Messages: prepared.Prompt, + Tools: prepared.Tools, + ActiveTools: prepared.ActiveTools, + ProviderTools: prepared.ProviderTools, + ContextLimitFallback: prepared.ContextLimitFallback, + ModelConfig: prepared.ModelConfig, + ProviderOptions: prepared.ProviderOptions, + PublishMessagePart: publish, + Logger: s.opts.Logger, + Clock: s.opts.Clock, + Metrics: s.server.metrics, + }) + _ = key + if err != nil { + return err + } + if decision.compactionTrigger == compactionTriggerAlreadyCompacted && + shouldCompactPromptUsage(outcome.Step.Usage, prepared.ContextLimitFallback, prepared.Compaction.Options.ThresholdPercent) { + err := errCompactionStillOverLimit + s.server.metrics.RecordCompaction(compactionProvider(prepared.Compaction.Options), compactionModel(prepared.Compaction.Options), false, err) + return s.finishGenerationError(ctx, machine, input, attempt, err, generationAttemptRequired) + } + if len(outcome.Step.Content) == 0 { + return s.finishGenerationTurn(ctx, machine, input, attempt, generationDecision{kind: generationActionFinishTurn, finishReason: generationFinishReasonComplete}, generationAttemptRequired) + } + messages, err := buildCommitStepMessages(buildCommitStepMessagesInput{ + modelConfigID: prepared.ModelConfigID, + modelCallConfig: prepared.ModelConfig, + step: stepDataFromPersisted(outcome.Step), + toolNameToConfigID: prepared.ToolNameToConfigID, + logger: s.opts.Logger, + contentVersion: chatprompt.CurrentContentVersion, + }) + if err != nil { + return 
s.finishGenerationError(ctx, machine, input, attempt, err, generationAttemptRequired) + } + return s.commitGenerationStep(ctx, machine, input, attempt, generationActionGenerateAssistant, messages) +} + +func (s *taskStarter) executeLocalTools( + ctx context.Context, + machine *chatstate.ChatMachine, + input chatWorkerTaskStartInput, + prepareInput generationPrepareInput, + decision generationDecision, +) error { + attempt, _, publish, closeEpisode, err := s.beginGenerationAttempt(ctx, machine, input) + if err != nil { + return err + } + defer closeEpisode() + prepareInput.PublishMessagePart = publish + prepared, err := s.server.prepareGeneration(ctx, prepareInput) + if err != nil { + return err + } + if prepared.Cleanup != nil { + defer prepared.Cleanup() + } + provider := "" + modelName := "" + if prepared.Model != nil { + provider = prepared.Model.Provider() + modelName = prepared.Model.Model() + } + outcome, err := chatloop.ExecuteLocalTools(ctx, chatloop.ExecuteLocalToolsOptions{ + Tools: prepared.Tools, + ActiveTools: prepared.ActiveTools, + ProviderTools: prepared.ProviderTools, + ToolCalls: decision.localToolCalls, + ExclusiveToolNames: prepared.ExclusiveToolNames, + BuiltinToolNames: prepared.BuiltinToolNames, + ModelProvider: provider, + ModelName: modelName, + PublishMessagePart: publish, + Logger: s.opts.Logger, + Metrics: s.server.metrics, + }) + if err != nil { + return err + } + messages, err := buildCommitStepMessages(buildCommitStepMessagesInput{ + modelConfigID: prepared.ModelConfigID, + modelCallConfig: prepared.ModelConfig, + step: stepDataFromPersisted(outcome.Step), + toolNameToConfigID: prepared.ToolNameToConfigID, + logger: s.opts.Logger, + contentVersion: chatprompt.CurrentContentVersion, + }) + if err != nil { + return s.finishGenerationError(ctx, machine, input, attempt, err, generationAttemptRequired) + } + return s.commitGenerationStep(ctx, machine, input, attempt, generationActionExecuteLocalTools, messages) +} + +func (s *taskStarter) 
generateCompaction( + ctx context.Context, + machine *chatstate.ChatMachine, + input chatWorkerTaskStartInput, + prepareInput generationPrepareInput, +) error { + attempt, _, publish, closeEpisode, err := s.beginGenerationAttempt(ctx, machine, input) + if err != nil { + return err + } + defer closeEpisode() + prepareInput.PublishMessagePart = publish + prepared, err := s.server.prepareGeneration(ctx, prepareInput) + if err != nil { + return err + } + if prepared.Cleanup != nil { + defer prepared.Cleanup() + } + if prepared.Compaction == nil { + return s.finishGenerationError(ctx, machine, input, attempt, xerrors.New("compaction action missing options"), generationAttemptRequired) + } + compactionOpts := prepared.Compaction.Options + compactionOpts.PublishMessagePart = publish + outcome, err := chatloop.GenerateCompaction(ctx, compactionOpts) + if err != nil { + s.server.metrics.RecordCompaction(compactionProvider(compactionOpts), compactionModel(compactionOpts), false, err) + return err + } + if strings.TrimSpace(outcome.SystemSummary) == "" || strings.TrimSpace(outcome.SummaryReport) == "" { + err := xerrors.New("compaction produced no summary") + s.server.metrics.RecordCompaction(compactionProvider(compactionOpts), compactionModel(compactionOpts), false, err) + return s.finishGenerationError(ctx, machine, input, attempt, err, generationAttemptRequired) + } + messages, err := buildCompactionMessages(buildCompactionMessagesInput{ + modelConfigID: prepared.ModelConfigID, + toolCallID: compactionOpts.ToolCallID, + toolName: compactionOpts.ToolName, + compaction: compactionOutcome(outcome), + contentVersion: chatprompt.CurrentContentVersion, + }) + if err != nil { + s.server.metrics.RecordCompaction(compactionProvider(compactionOpts), compactionModel(compactionOpts), false, err) + return s.finishGenerationError(ctx, machine, input, attempt, err, generationAttemptRequired) + } + err = s.commitGenerationStep(ctx, machine, input, attempt, generationActionCompact, 
stepMessagesForCommit{ + Messages: messages.Messages, + VisibleIndexes: visibleMessageIndexes(messages.Messages), + }) + s.server.metrics.RecordCompaction(compactionProvider(compactionOpts), compactionModel(compactionOpts), err == nil, err) + return err +} + +func compactionProvider(opts chatloop.GenerateCompactionOptions) string { + if opts.Model == nil { + return "" + } + return opts.Model.Provider() +} + +func compactionModel(opts chatloop.GenerateCompactionOptions) string { + if opts.Model == nil { + return "" + } + return opts.Model.Model() +} + +// persistWorkspaceContext is the generation action that commits durable +// workspace context messages (e.g. AGENTS.md, workspace skills) into +// chat history. It records a generation attempt, calls the injected +// workspace context builder without holding the DB lock, then commits +// the returned messages fenced to the attempt. If the builder returns +// no messages, the action exits as expected and the next worker task +// re-reads the chat. +func (s *taskStarter) persistWorkspaceContext( + ctx context.Context, + machine *chatstate.ChatMachine, + input chatWorkerTaskStartInput, + locked database.Chat, +) error { + if s.server == nil { + return errTaskExpectedExit + } + messages, err := s.opts.Store.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{ + ChatID: input.ChatID, + AfterID: 0, + }) + if err != nil { + return taskRetryableError{err: xerrors.Errorf("load chat messages for workspace context: %w", err)} + } + attempt, _, _, closeEpisode, err := s.beginGenerationAttempt(ctx, machine, input) + if err != nil { + return err + } + defer closeEpisode() + result, err := s.server.buildWorkspaceContext(ctx, workspaceContextBuildInput{ + Chat: locked, + Messages: messages, + }) + if err != nil { + return err + } + if len(result.Messages) == 0 { + // Builder reported no durable messages to commit (workspace or + // agent missing, unreachable, etc.). 
Exit the action without + // committing so the next worker task can re-read the chat. + return errTaskExpectedExit + } + return s.commitGenerationStep(ctx, machine, input, attempt, generationActionPersistWorkspaceContext, stepMessagesForCommit{ + Messages: result.Messages, + VisibleIndexes: visibleMessageIndexes(result.Messages), + }) +} + +func (s *taskStarter) beginGenerationAttempt( + ctx context.Context, + machine *chatstate.ChatMachine, + input chatWorkerTaskStartInput, +) (int64, messagepartbuffer.Key, func(codersdk.ChatMessageRole, codersdk.ChatMessagePart), func(), error) { + var attempt int64 + var committed database.Chat + err := machine.Update(ctx, func(tx *chatstate.Tx) error { + locked, err := tx.Store().GetChatByID(ctx, input.ChatID) + if errors.Is(err, sql.ErrNoRows) { + return errTaskExpectedExit + } + if err != nil { + return xerrors.Errorf("load chat: %w", err) + } + if err := verifyTaskFence(locked, input, database.ChatStatusRunning, taskFenceOptions{requireHistory: true}); err != nil { + return err + } + result, err := tx.RecordGenerationAttempt(chatstate.RecordGenerationAttemptInput{}) + if err != nil { + return err + } + attempt = result.GenerationAttempt + committed, err = tx.Store().GetChatByID(ctx, input.ChatID) + if err != nil { + return xerrors.Errorf("load committed chat: %w", err) + } + return nil + }) + if err != nil { + return 0, messagepartbuffer.Key{}, nil, nil, normalizeTaskTransitionError(err, "record generation attempt") + } + key := messagepartbuffer.Key{ + ChatID: input.ChatID, + HistoryVersion: committed.HistoryVersion, + GenerationAttempt: attempt, + } + if err := s.opts.MessagePartBuffer.CreateEpisode(key); err != nil && ctx.Err() == nil { + return 0, messagepartbuffer.Key{}, nil, nil, taskRetryableError{err: xerrors.Errorf("create message part episode: %w", err)} + } + publish := func(role codersdk.ChatMessageRole, part codersdk.ChatMessagePart) { + _ = s.opts.MessagePartBuffer.AddPart(key, role, part) + } + closeEpisode := 
func() { + _ = s.opts.MessagePartBuffer.CloseEpisode(key) + } + _ = committed + return attempt, key, publish, closeEpisode, nil +} + +func (s *taskStarter) commitGenerationStep( + ctx context.Context, + machine *chatstate.ChatMachine, + input chatWorkerTaskStartInput, + attempt int64, + kind generationActionKind, + messages stepMessagesForCommit, +) error { + if len(messages.Messages) == 0 { + return s.finishGenerationTurn(ctx, machine, input, attempt, generationDecision{kind: generationActionFinishTurn, finishReason: generationFinishReasonComplete}, generationAttemptRequired) + } + var committed database.Chat + insertedMessages := []runnerActionMessage{} + err := machine.Update(ctx, func(tx *chatstate.Tx) error { + locked, err := tx.Store().GetChatByID(ctx, input.ChatID) + if errors.Is(err, sql.ErrNoRows) { + return errTaskExpectedExit + } + if err != nil { + return xerrors.Errorf("load chat: %w", err) + } + if err := verifyGenerationFence(locked, input, attempt); err != nil { + return err + } + commitResult, err := tx.CommitStep(chatstate.CommitStepInput{Messages: messages.Messages}) + if err != nil { + return err + } + insertedMessages = make([]runnerActionMessage, 0, len(commitResult.InsertedMessages)) + for _, msg := range commitResult.InsertedMessages { + insertedMessages = append(insertedMessages, runnerActionMessage{ID: msg.ID, Role: codersdk.ChatMessageRole(msg.Role)}) + } + committed, err = tx.Store().GetChatByID(ctx, input.ChatID) + if err != nil { + return xerrors.Errorf("load committed chat: %w", err) + } + return nil + }) + if err != nil { + return normalizeTaskTransitionError(err, "commit generation step") + } + s.routeStateHint(ctx, stateUpdateFromChat(committed)) + return s.afterGenerationOutcome(ctx, generationOutcome{ + Chat: committed, + Kind: runnerActionKind(kind), + InsertedMessages: insertedMessages, + }) +} + +func (s *taskStarter) enterRequiresAction( + ctx context.Context, + machine *chatstate.ChatMachine, + input 
chatWorkerTaskStartInput, + decision generationDecision, +) error { + var committed database.Chat + err := machine.Update(ctx, func(tx *chatstate.Tx) error { + locked, err := tx.Store().GetChatByID(ctx, input.ChatID) + if errors.Is(err, sql.ErrNoRows) { + return errTaskExpectedExit + } + if err != nil { + return xerrors.Errorf("load chat: %w", err) + } + if err := verifyTaskFence(locked, input, database.ChatStatusRunning, taskFenceOptions{requireHistory: true}); err != nil { + return err + } + if _, err := tx.EnterRequiresAction(chatstate.EnterRequiresActionInput{}); err != nil { + return err + } + committed, err = tx.Store().GetChatByID(ctx, input.ChatID) + if err != nil { + return xerrors.Errorf("load committed chat: %w", err) + } + return nil + }) + if err != nil { + return normalizeTaskTransitionError(err, "enter requires action") + } + if err := s.publishWatchAndRoute(ctx, committed, codersdk.ChatWatchEventKindActionRequired); err != nil { + return err + } + _ = decision + return s.afterGenerationOutcome(ctx, generationOutcome{ + Chat: committed, + Kind: runnerActionKindEnterRequiresAction, + WatchEventKind: codersdk.ChatWatchEventKindActionRequired, + }) +} + +type generationAttemptFence int + +const ( + generationAttemptNotRequired generationAttemptFence = iota + generationAttemptRequired +) + +func (s *taskStarter) finishGenerationTurn( + ctx context.Context, + machine *chatstate.ChatMachine, + input chatWorkerTaskStartInput, + attempt int64, + decision generationDecision, + attemptFence generationAttemptFence, +) error { + var committed database.Chat + err := machine.Update(ctx, func(tx *chatstate.Tx) error { + locked, err := tx.Store().GetChatByID(ctx, input.ChatID) + if errors.Is(err, sql.ErrNoRows) { + return errTaskExpectedExit + } + if err != nil { + return xerrors.Errorf("load chat: %w", err) + } + if attemptFence == generationAttemptRequired { + if err := verifyGenerationFence(locked, input, attempt); err != nil { + return err + } + } else if err := 
verifyTaskFence(locked, input, database.ChatStatusRunning, taskFenceOptions{requireHistory: true}); err != nil { + return err + } + finishResult, err := tx.FinishTurn(chatstate.FinishTurnInput{}) + if err != nil { + return err + } + if finishResult.PromotedMessage != nil { + decision.promotedMessageID = finishResult.PromotedMessage.ID + } + committed = finishResult.Chat + if committed.ID == uuid.Nil { + committed, err = tx.Store().GetChatByID(ctx, input.ChatID) + } else if refreshed, refreshErr := tx.Store().GetChatByID(ctx, input.ChatID); refreshErr == nil { + committed.LastTurnSummary = refreshed.LastTurnSummary + } else { + err = refreshErr + } + if err != nil { + return xerrors.Errorf("load committed chat: %w", err) + } + return nil + }) + if err != nil { + return normalizeTaskTransitionError(err, "finish generation turn") + } + watchCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), postCommitWatchPublishTimeout) + defer cancel() + if err := s.publishWatchWithRetry(watchCtx, committed, codersdk.ChatWatchEventKindStatusChange); err != nil { + return err + } + if err := s.afterGenerationOutcome(ctx, generationOutcome{ + Chat: committed, + Kind: runnerActionKindFinishTurn, + WatchEventKind: codersdk.ChatWatchEventKindStatusChange, + PromotedMessageID: decision.promotedMessageID, + }); err != nil { + return err + } + s.routeStateHint(ctx, stateUpdateFromChat(committed)) + return nil +} + +func (s *taskStarter) finishGenerationError( + ctx context.Context, + machine *chatstate.ChatMachine, + input chatWorkerTaskStartInput, + attempt int64, + cause error, + attemptFence generationAttemptFence, +) error { + lastError, message := generationLastError(cause) + var committed database.Chat + err := machine.Update(ctx, func(tx *chatstate.Tx) error { + locked, err := tx.Store().GetChatByID(ctx, input.ChatID) + if errors.Is(err, sql.ErrNoRows) { + return errTaskExpectedExit + } + if err != nil { + return xerrors.Errorf("load chat: %w", err) + } + if attemptFence 
== generationAttemptRequired { + if err := verifyGenerationFence(locked, input, attempt); err != nil { + return err + } + } else if err := verifyTaskFence(locked, input, database.ChatStatusRunning, taskFenceOptions{requireHistory: true}); err != nil { + return err + } + if _, err := tx.FinishError(chatstate.FinishErrorInput{LastError: lastError}); err != nil { + return err + } + committed, err = tx.Store().GetChatByID(ctx, input.ChatID) + if err != nil { + return xerrors.Errorf("load committed chat: %w", err) + } + return nil + }) + if err != nil { + return normalizeTaskTransitionError(err, "finish generation error") + } + if err := s.publishWatchAndRoute(ctx, committed, codersdk.ChatWatchEventKindStatusChange); err != nil { + return err + } + return s.afterGenerationOutcome(ctx, generationOutcome{ + Chat: committed, + Kind: runnerActionKindFinishError, + WatchEventKind: codersdk.ChatWatchEventKindStatusChange, + LastError: message, + }) +} + +func generationLastError(err error) (pqtype.NullRawMessage, string) { + if err == nil { + return pqtype.NullRawMessage{}, "" + } + classified := chaterror.Classify(err) + payload := chaterror.TerminalErrorPayload(classified) + if payload == nil { + payload = &codersdk.ChatError{Message: err.Error()} + } + encoded, marshalErr := json.Marshal(payload) + if marshalErr != nil { + return pqtype.NullRawMessage{}, payload.Message + } + return pqtype.NullRawMessage{RawMessage: encoded, Valid: true}, payload.Message +} + +func (s *taskStarter) afterGenerationOutcome(ctx context.Context, outcome generationOutcome) error { + if s.server == nil { + return nil + } + if err := s.server.afterGenerationOutcome(ctx, outcome); err != nil { + return taskRetryableError{err: xerrors.Errorf("generation post-outcome side effects: %w", err)} + } + return nil +} + +func verifyGenerationFence(chat database.Chat, input chatWorkerTaskStartInput, attempt int64) error { + if err := verifyTaskFence(chat, input, database.ChatStatusRunning, 
taskFenceOptions{requireHistory: true}); err != nil { + return err + } + if chat.GenerationAttempt != attempt { + return errTaskExpectedExit + } + return nil +} + +func stepDataFromPersisted(step chatloop.PersistedStep) stepData { + return stepData{ + Content: step.Content, + Usage: step.Usage, + ContextLimit: step.ContextLimit, + ProviderResponseID: step.ProviderResponseID, + Runtime: step.Runtime, + ToolCallCreatedAt: step.ToolCallCreatedAt, + ToolResultCreatedAt: step.ToolResultCreatedAt, + ReasoningStartedAt: step.ReasoningStartedAt, + ReasoningCompletedAt: step.ReasoningCompletedAt, + } +} diff --git a/coderd/x/chatd/generation_preparer.go b/coderd/x/chatd/generation_preparer.go new file mode 100644 index 0000000000..d5b2fb857a --- /dev/null +++ b/coderd/x/chatd/generation_preparer.go @@ -0,0 +1,804 @@ +package chatd + +import ( + "context" + "encoding/json" + "slices" + "strings" + "sync" + + "charm.land/fantasy" + "github.com/google/uuid" + "golang.org/x/sync/errgroup" + "golang.org/x/xerrors" + + "cdr.dev/slog/v3" + "github.com/coder/coder/v2/coderd/aibridge" + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/x/chatd/chatadvisor" + "github.com/coder/coder/v2/coderd/x/chatd/chatloop" + "github.com/coder/coder/v2/coderd/x/chatd/chatopenai" + "github.com/coder/coder/v2/coderd/x/chatd/chatprompt" + "github.com/coder/coder/v2/coderd/x/chatd/chatprovider" + "github.com/coder/coder/v2/coderd/x/chatd/chatsanitize" + "github.com/coder/coder/v2/coderd/x/chatd/chattool" + "github.com/coder/coder/v2/coderd/x/chatd/mcpclient" + skillspkg "github.com/coder/coder/v2/coderd/x/skills" + "github.com/coder/coder/v2/codersdk" +) + +func (server *Server) prepareGeneration( + ctx context.Context, + input generationPrepareInput, +) (generationPrepared, error) { + chat := input.Chat + logger := server.logger.With( + slog.F("chat_id", chat.ID), + slog.F("owner_id", chat.OwnerID), + ) + + var ( + model fantasy.LanguageModel + modelConfig 
database.ChatModelConfig + providerKeys chatprovider.ProviderAPIKeys + modelRoute resolvedModelRoute + modelOpts modelBuildOptions + callConfig codersdk.ChatModelCallConfig + promptRows []database.ChatMessage + mcpConfigs []database.MCPServerConfig + mcpTokens []database.MCPServerUserToken + debugEnabled bool + debugProvider string + debugModel string + ) + + var g errgroup.Group + g.Go(func() error { + var err error + promptRows, err = server.db.GetChatMessagesForPromptByChatID(ctx, chat.ID) + if err != nil { + return xerrors.Errorf("get chat messages for prompt: %w", err) + } + return nil + }) + if len(chat.MCPServerIDs) > 0 { + g.Go(func() error { + var err error + mcpConfigs, err = server.db.GetMCPServerConfigsByIDs(ctx, chat.MCPServerIDs) + if err != nil { + logger.Warn(ctx, "failed to load MCP server configs", slog.Error(err)) + } + return nil + }) + g.Go(func() error { + var err error + mcpTokens, err = server.db.GetMCPServerUserTokensByUserID(ctx, chat.OwnerID) + if err != nil { + logger.Warn(ctx, "failed to load MCP user tokens", slog.Error(err)) + } + return nil + }) + } + if err := g.Wait(); err != nil { + return generationPrepared{}, err + } + + modelOpts = modelBuildOptionsFromMessages(promptRows) + if modelOpts.ActiveAPIKeyID != "" { + ctx = aibridge.WithDelegatedAPIKeyID(ctx, modelOpts.ActiveAPIKeyID) + } + + var err error + model, modelConfig, providerKeys, modelRoute, debugEnabled, debugProvider, debugModel, err = server.resolveChatModel(ctx, chat, modelOpts) + if err != nil { + return generationPrepared{}, err + } + if len(modelConfig.Options) > 0 { + if err := json.Unmarshal(modelConfig.Options, &callConfig); err != nil { + return generationPrepared{}, xerrors.Errorf("parse model call config: %w", err) + } + } + + if callConfig.MaxOutputTokens == nil { + maxOutputTokens := int64(32_000) + callConfig.MaxOutputTokens = &maxOutputTokens + } + + currentPlanMode := chat.PlanMode + isPlanModeTurn := currentPlanMode.Valid && currentPlanMode.ChatPlanMode 
== database.ChatPlanModePlan + isExploreSubagent := isExploreSubagentMode(chat.Mode) + isRootChat := !chat.ParentChatID.Valid + + mcpConnectConfigs, approvedPlanMCPConfigIDs := filterExternalMCPConfigsForTurn( + mcpConfigs, + currentPlanMode, + chat.ParentChatID, + ) + if isExploreSubagent && isRootChat { + mcpConnectConfigs = nil + approvedPlanMCPConfigIDs = map[uuid.UUID]struct{}{} + } + + planModeInstructions := server.loadPlanModeInstructions(ctx, currentPlanMode, logger) + advisorCfg := server.loadAdvisorConfig(ctx, logger) + + var advisorRuntime *chatadvisor.Runtime + if advisorCfg.Enabled && isRootChat && !isPlanModeTurn && !isExploreSubagent { + var advisorErr error + advisorRuntime, advisorErr = server.newAdvisorRuntime( + ctx, + chat, + advisorCfg, + model, + callConfig, + providerKeys, + modelOpts, + logger, + ) + if advisorErr != nil { + return generationPrepared{}, advisorErr + } + } + + var advisorPromptSnapshot []fantasy.Message + setAdvisorPromptSnapshot := func(msgs []fantasy.Message) { + if advisorRuntime == nil { + return + } + advisorPromptSnapshot = slices.Clone(msgs) + } + + currentChat := chat + loadChatSnapshot := func(loadCtx context.Context, chatID uuid.UUID) (database.Chat, error) { + return server.db.GetChatByID(loadCtx, chatID) + } + var chatStateMu sync.Mutex + var workspaceMu sync.Mutex + workspaceCtx := turnWorkspaceContext{ + server: server, + chatStateMu: &chatStateMu, + currentChat: ¤tChat, + loadChatSnapshot: loadChatSnapshot, + } + cleanup := func() { + workspaceCtx.close() + } + + planPathFn := func(ctx context.Context) (string, string, error) { + conn, err := workspaceCtx.getWorkspaceConn(ctx) + if err != nil { + return "", "", err + } + home, err := chattool.ResolveWorkspaceHome(ctx, conn) + if err != nil { + return "", "", err + } + return chattool.PlanPathForChat(home, chat.ID), home, nil + } + resolvePlanPathForTools := func(ctx context.Context) (string, string, error) { + planCtx, cancel := context.WithTimeout(ctx, 
planPathLookupTimeout) + defer cancel() + return planPathFn(planCtx) + } + resolvePlanPathBlock := func(resolveCtx context.Context) string { + if chat.ParentChatID.Valid { + return "" + } + + planCtx, cancel := context.WithTimeout(resolveCtx, planPathLookupTimeout) + defer cancel() + + if _, _, err := workspaceCtx.workspaceAgentIDForConn(planCtx); err != nil { + logger.Debug(resolveCtx, "plan path instruction: agent not reachable", + slog.Error(err), + slog.F("chat_id", chat.ID), + ) + return "" + } + + planPath, home, err := planPathFn(planCtx) + if err != nil { + logger.Debug(resolveCtx, "plan path instruction: failed to resolve plan path", + slog.Error(err), + slog.F("chat_id", chat.ID), + ) + return "" + } + return formatPlanPathBlock(planPath, home) + } + + var ( + prompt []fantasy.Message + instruction string + mcpTools []fantasy.AgentTool + mcpCleanup func() + workspaceMCPTools []fantasy.AgentTool + workspaceSkills []chattool.SkillMeta + personalSkills []skillspkg.Skill + resolvedUserPrompt string + ) + + persistedSkills := skillsFromParts(promptRows) + hasContextFiles := false + if chat.WorkspaceID.Valid { + // Resolve the workspace agent so the chat row's AgentID and + // BuildID bindings are up to date before the chatworker + // decision helper inspects them. ensureWorkspaceAgent does a + // DB lookup and lazily calls persistBuildAgentBinding when + // the bound agent has changed, so this is a cheap metadata + // refresh, not a workspace dial. It must not insert chat + // history; only metadata is mutated here. 
+ _, _ = workspaceCtx.getWorkspaceAgent(ctx) + _, found := contextFileAgentID(promptRows) + hasContextFiles = found + } + + var g2 errgroup.Group + g2.Go(func() error { + var err error + prompt, err = chatprompt.ConvertMessagesWithFiles(ctx, promptRows, server.chatFileResolver(modelConfig.Provider), logger) + if err != nil { + return xerrors.Errorf("build chat prompt: %w", err) + } + return nil + }) + if hasContextFiles { + instruction = instructionFromContextFiles(promptRows) + workspaceSkills = persistedSkills + } + g2.Go(func() error { + personalSkills = server.fetchPersonalSkillMetadata(ctx, chat.OwnerID, logger) + return nil + }) + g2.Go(func() error { + resolvedUserPrompt = server.resolveUserPrompt(ctx, chat.OwnerID) + return nil + }) + if len(mcpConnectConfigs) > 0 { + g2.Go(func() error { + mcpTokens = server.refreshExpiredMCPTokens(ctx, logger, mcpConnectConfigs, mcpTokens) + mcpTools, mcpCleanup = mcpclient.ConnectAll( + ctx, + logger, + mcpConnectConfigs, + mcpTokens, + chat.OwnerID, + server.oidcTokenSource, + chatprovider.CoderHeaders(chat), + ) + return nil + }) + } + if chat.WorkspaceID.Valid && !isPlanModeTurn && !isExploreSubagent { + g2.Go(func() error { + workspaceMCPTools = server.discoverWorkspaceMCPTools(ctx, logger, chat.ID, &workspaceCtx) + return nil + }) + } + if err := g2.Wait(); err != nil { + cleanup() + return generationPrepared{}, err + } + + if mcpCleanup != nil { + previousCleanup := cleanup + cleanup = func() { + mcpCleanup() + previousCleanup() + } + } + + prompt, sanitizeStats := chatsanitize.SanitizeAnthropicProviderToolHistory(model.Provider(), prompt) + chatsanitize.LogAnthropicProviderToolSanitization( + ctx, + logger, + "persisted_history_replay", + model.Provider(), + model.Model(), + sanitizeStats, + ) + + subagentInstruction := "" + if !isRootChat { + subagentInstruction = defaultSubagentInstruction + } + resolvedSkillsFor := func(workspaceSkills []chattool.SkillMeta) []skillspkg.ResolvedSkill { + return 
mergeTurnSkills(personalSkills, workspaceSkills) + } + resolveSkillAlias := func(alias string) (skillspkg.ResolvedSkill, error) { + return skillspkg.Lookup(resolvedSkillsFor(workspaceSkills), alias) + } + initialResolvedSkills := resolvedSkillsFor(workspaceSkills) + + prompt = buildSystemPrompt( + prompt, + subagentInstruction, + instruction, + initialResolvedSkills, + resolvedUserPrompt, + systemPromptBehaviorContext{ + planMode: currentPlanMode, + chatMode: chat.Mode, + planModeInstructions: planModeInstructions, + isRootChat: isRootChat, + }, + ) + if advisorRuntime != nil { + prompt = chatprompt.InsertSystem(prompt, chatadvisor.ParentGuidanceBlock) + } + prompt = renderPlanPathPrompt(prompt, resolvePlanPathBlock(ctx)) + setAdvisorPromptSnapshot(prompt) + + storeChatAttachment := server.newStoreChatAttachmentFunc(&workspaceCtx) + tools := []fantasy.AgentTool{ + chattool.ReadFile(chattool.ReadFileOptions{GetWorkspaceConn: workspaceCtx.getWorkspaceConn}), + chattool.WriteFile(chattool.WriteFileOptions{ + GetWorkspaceConn: workspaceCtx.getWorkspaceConn, + ResolvePlanPath: resolvePlanPathForTools, + IsPlanTurn: isPlanModeTurn, + }), + chattool.EditFiles(chattool.EditFilesOptions{ + GetWorkspaceConn: workspaceCtx.getWorkspaceConn, + ResolvePlanPath: resolvePlanPathForTools, + IsPlanTurn: isPlanModeTurn, + }), + chattool.AttachFile(chattool.AttachFileOptions{ + GetWorkspaceConn: workspaceCtx.getWorkspaceConn, + StoreFile: storeChatAttachment, + }), + chattool.Execute(chattool.ExecuteOptions{GetWorkspaceConn: workspaceCtx.getWorkspaceConn}), + chattool.ProcessOutput(chattool.ProcessToolOptions{GetWorkspaceConn: workspaceCtx.getWorkspaceConn}), + chattool.ProcessList(chattool.ProcessToolOptions{GetWorkspaceConn: workspaceCtx.getWorkspaceConn}), + chattool.ProcessSignal(chattool.ProcessToolOptions{GetWorkspaceConn: workspaceCtx.getWorkspaceConn}), + } + if isPlanModeTurn && isRootChat { + tools = append(tools, chattool.NewAskUserQuestionTool()) + } + if isRootChat { + 
tools = server.appendRootChatTools(ctx, tools, rootChatToolsOptions{ + chat: chat, + modelConfigID: modelConfig.ID, + workspaceCtx: &workspaceCtx, + workspaceMu: &workspaceMu, + resolvePlanPath: resolvePlanPathForTools, + storeFile: storeChatAttachment, + isPlanModeTurn: isPlanModeTurn, + primerCtx: ctx, + }) + } + + skillOpts := chattool.ReadSkillOptions{ + GetWorkspaceConn: workspaceCtx.getWorkspaceConn, + GetSkills: func() []chattool.SkillMeta { + return workspaceSkills + }, + ResolveAlias: resolveSkillAlias, + LoadPersonalSkillBody: func(ctx context.Context, name string) (skillspkg.ParsedSkill, error) { + return server.loadPersonalSkillBody(ctx, chat.OwnerID, name) + }, + } + appendCurrentSkillTools := func(current []fantasy.AgentTool) ([]fantasy.AgentTool, bool) { + if len(personalSkills) == 0 && len(workspaceSkills) == 0 { + return current, false + } + updated := current + changed := false + appendTool := func(tool fantasy.AgentTool) { + name := tool.Info().Name + if slices.ContainsFunc(current, func(existing fantasy.AgentTool) bool { + return existing.Info().Name == name + }) { + return + } + if !changed { + updated = slices.Clone(current) + changed = true + } + updated = append(updated, tool) + } + appendTool(chattool.ReadSkill(skillOpts)) + if len(workspaceSkills) > 0 { + appendTool(chattool.ReadSkillFile(skillOpts)) + } + return updated, changed + } + tools, _ = appendCurrentSkillTools(tools) + if advisorRuntime != nil { + var publishAdviceDelta func(string, string) + var publishAdviceReset func(string) + if input.PublishMessagePart != nil { + publishAdviceDelta = func(toolCallID string, delta string) { + if toolCallID == "" || delta == "" { + return + } + input.PublishMessagePart(codersdk.ChatMessageRoleTool, codersdk.ChatMessagePart{ + Type: codersdk.ChatMessagePartTypeToolResult, + ToolCallID: toolCallID, + ToolName: chatadvisor.ToolName, + ResultDelta: delta, + }) + } + publishAdviceReset = func(toolCallID string) { + if toolCallID == "" { + return + 
} + input.PublishMessagePart(codersdk.ChatMessageRoleTool, codersdk.ChatMessagePart{ + Type: codersdk.ChatMessagePartTypeToolResult, + ToolCallID: toolCallID, + ToolName: chatadvisor.ToolName, + ResultReset: true, + }) + } + } + tools = append(tools, chatadvisor.Tool(chatadvisor.ToolOptions{ + Runtime: advisorRuntime, + GetConversationSnapshot: func() []fantasy.Message { + return stripAdvisorGuidanceBlock(slices.Clone(advisorPromptSnapshot)) + }, + PublishAdviceDelta: publishAdviceDelta, + PublishAdviceReset: publishAdviceReset, + })) + } + + var exclusiveToolNames map[string]bool + if advisorRuntime != nil { + exclusiveToolNames = map[string]bool{chatadvisor.ToolName: true} + } + + builtinToolNames := make(map[string]bool, len(tools)) + for _, t := range tools { + builtinToolNames[t.Info().Name] = true + } + + tools = append(tools, mcpTools...) + if !isExploreSubagent { + tools = append(tools, workspaceMCPTools...) + } + tools = filterToolsForTurn(tools, currentPlanMode, chat.ParentChatID, approvedPlanMCPConfigIDs) + + tools, dynamicToolNames, err := appendDynamicTools(ctx, logger, tools, chat.DynamicTools, currentPlanMode, chat.Mode) + if err != nil { + cleanup() + return generationPrepared{}, err + } + + var providerTools []chatloop.ProviderTool + if !isPlanModeTurn && callConfig.ProviderOptions != nil { + providerTools = buildProviderTools(callConfig.ProviderOptions) + if isExploreSubagent { + if !chat.ParentChatID.Valid { + providerTools = nil + } else { + providerTools = slices.DeleteFunc(providerTools, func(tool chatloop.ProviderTool) bool { + return tool.Definition.GetName() != "web_search" + }) + } + } + } + + isComputerUse := chat.Mode.Valid && chat.Mode.ChatMode == database.ChatModeComputerUse + if isComputerUse { + computerUseProvider, computerUseModelProvider, computerUseModelName, err := server.computerUseProviderAndModelFromConfig(ctx) + if err != nil { + cleanup() + return generationPrepared{}, xerrors.Errorf("resolve computer use provider and 
model: %w", err) + } + computerUseRoute, keyErr := server.resolveModelRouteForProviderType(ctx, chat.OwnerID, computerUseModelProvider) + if keyErr != nil { + cleanup() + return generationPrepared{}, xerrors.Errorf("resolve computer use provider route: %w", keyErr) + } + modelRoute = computerUseRoute + providerKeys = computerUseRoute.directProviderKeys() + cuModel, cuDebugEnabled, resolvedProvider, resolvedModel, cuErr := server.resolveComputerUseModel( + ctx, + chat, + computerUseRoute, + computerUseProvider, + computerUseModelProvider, + computerUseModelName, + modelOpts, + ) + if cuErr != nil { + cleanup() + return generationPrepared{}, cuErr + } + model = cuModel + debugEnabled = cuDebugEnabled + debugProvider = resolvedProvider + debugModel = resolvedModel + providerTools, err = appendComputerUseProviderTool(providerTools, computerUseProviderToolOptions{ + provider: computerUseProvider, + isPlanModeTurn: isPlanModeTurn, + isComputerUse: isComputerUse, + getWorkspaceConn: workspaceCtx.getWorkspaceConn, + storeFile: storeChatAttachment, + clock: server.clock, + logger: server.logger.Named("computer_use"), + }) + if err != nil { + cleanup() + return generationPrepared{}, xerrors.Errorf("register computer use provider tool for provider %q: %w", computerUseProvider, err) + } + } else { + providerTools, err = appendComputerUseProviderTool(providerTools, computerUseProviderToolOptions{ + isPlanModeTurn: isPlanModeTurn, + isComputerUse: false, + }) + if err != nil { + cleanup() + return generationPrepared{}, err + } + } + + providerOptions := chatprovider.ProviderOptionsFromChatModelConfig(model, callConfig.ProviderOptions) + chainInfo := chatopenai.ResolveChainMode(promptRows) + if !input.ChainModeDisabled && chatopenai.ShouldActivateChainMode( + providerOptions, + chainInfo, + modelConfig.ID, + isPlanModeTurn, + ) { + providerOptions = chatopenai.WithPreviousResponseID(providerOptions, chainInfo.PreviousResponseID()) + prompt = 
chatopenai.FilterPromptForChainMode(prompt, chainInfo) + } + + activeToolNames := activeToolNamesForTurn(tools, currentPlanMode, chat.ParentChatID, approvedPlanMCPConfigIDs) + if isExploreSubagent { + activeToolNames = allowedExploreToolNames(tools) + } + + toolNameToConfigID := make(map[string]uuid.UUID) + for _, t := range tools { + if mcpTool, ok := t.(mcpclient.MCPToolIdentifier); ok { + toolNameToConfigID[t.Info().Name] = mcpTool.MCPServerConfigID() + } + } + + triggerMessageID, historyTipMessageID, triggerLabel := deriveChatDebugSeed(promptRows) + debugSvc := server.existingDebugService() + if debugEnabled { + if debugSvc == nil { + cleanup() + return generationPrepared{}, xerrors.New("chat debug service missing after enablement check") + } + var finishDebugRun func(error, any) + ctx, finishDebugRun = prepareChatTurnDebugRun( + ctx, + logger, + chat, + modelConfig, + debugSvc, + debugProvider, + debugModel, + triggerMessageID, + historyTipMessageID, + triggerLabel, + ) + previousCleanup := cleanup + cleanup = func() { + finishDebugRun(nil, nil) + previousCleanup() + } + } + + compactionToolCallID := "chat_summarized_" + uuid.NewString() + effectiveThreshold := modelConfig.CompressionThreshold + if override, ok := server.resolveUserCompactionThreshold(ctx, chat.OwnerID, modelConfig.ID); ok { + effectiveThreshold = override + } + compactionOptions := chatloop.GenerateCompactionOptions{ + Model: model, + Messages: prompt, + ThresholdPercent: effectiveThreshold, + ContextLimit: modelConfig.ContextLimit, + ContextLimitFallback: modelConfig.ContextLimit, + ToolCallID: compactionToolCallID, + ToolName: "chat_summarized", + DebugSvc: debugSvc, + ChatID: chat.ID, + HistoryTipMessageID: historyTipMessageID, + } + compactionOptions.StepUsage = latestPromptUsage(promptRows) + compactionNeeded := shouldCompactPromptUsage(compactionOptions.StepUsage, modelConfig.ContextLimit, effectiveThreshold) + + workspaceContextEligible := chat.WorkspaceID.Valid && isRootChat && 
!isPlanModeTurn && !isExploreSubagent + + // workspaceCtx.currentChatSnapshot may carry a freshly persisted + // AgentID/BuildID binding from the getWorkspaceAgent call above. + // Return that snapshot so the chatworker decision helper sees + // the up-to-date metadata when deciding whether to run + // persist_workspace_context. + refreshedChat := workspaceCtx.currentChatSnapshot() + if refreshedChat.ID == uuid.Nil { + refreshedChat = chat + } + + return generationPrepared{ + Chat: refreshedChat, + Messages: input.Messages, + Model: model, + Prompt: prompt, + Tools: tools, + ActiveTools: activeToolNames, + ProviderTools: providerTools, + ProviderKeys: providerKeys, + ModelRoute: modelRoute, + ModelBuildOptions: modelOpts, + ModelConfigID: modelConfig.ID, + ModelConfig: callConfig, + ProviderOptions: providerOptions, + ContextLimitFallback: modelConfig.ContextLimit, + DynamicToolNames: dynamicToolNames, + StopAfterTools: stopAfterBehaviorTools(currentPlanMode, chat.Mode, chat.ParentChatID), + ExclusiveToolNames: exclusiveToolNames, + BuiltinToolNames: builtinToolNames, + ToolNameToConfigID: toolNameToConfigID, + MaxSteps: maxChatSteps, + Compaction: &generationCompaction{ + Required: compactionNeeded, + Options: compactionOptions, + }, + Cleanup: cleanup, + WorkspaceContextEligible: workspaceContextEligible, + }, nil +} + +func latestPromptUsage(messages []database.ChatMessage) fantasy.Usage { + for i := len(messages) - 1; i >= 0; i-- { + usage := fantasy.Usage{ + InputTokens: messages[i].InputTokens.Int64, + OutputTokens: messages[i].OutputTokens.Int64, + TotalTokens: messages[i].TotalTokens.Int64, + ReasoningTokens: messages[i].ReasoningTokens.Int64, + CacheCreationTokens: messages[i].CacheCreationTokens.Int64, + CacheReadTokens: messages[i].CacheReadTokens.Int64, + } + if usage != (fantasy.Usage{}) { + return usage + } + } + return fantasy.Usage{} +} + +func shouldCompactPromptUsage(usage fantasy.Usage, contextLimit int64, thresholdPercent int32) bool { + if 
thresholdPercent >= 100 || contextLimit <= 0 { + return false + } + contextTokens := contextTokensFromUsage(usage) + if contextTokens <= 0 { + return false + } + usagePercent := (float64(contextTokens) / float64(contextLimit)) * 100 + return usagePercent >= float64(thresholdPercent) +} + +func contextTokensFromUsage(usage fantasy.Usage) int64 { + total := int64(0) + hasContextTokens := false + if usage.InputTokens > 0 { + total += usage.InputTokens + hasContextTokens = true + } + if usage.CacheReadTokens > 0 { + total += usage.CacheReadTokens + hasContextTokens = true + } + if usage.CacheCreationTokens > 0 { + total += usage.CacheCreationTokens + hasContextTokens = true + } + if !hasContextTokens && usage.TotalTokens > 0 { + total = usage.TotalTokens + } + return total +} + +func (server *Server) afterInterruptionOutcome( + ctx context.Context, + outcome interruptionOutcome, +) error { + chat := outcome.Chat + logger := server.logger.With(slog.F("chat_id", chat.ID), slog.F("owner_id", chat.OwnerID)) + + if outcome.Kind == runnerActionKindFinishInterruption { + server.maybeClearLastTurnSummaryAsync(context.WithoutCancel(ctx), chat, logger) + } + return nil +} + +func (server *Server) afterGenerationOutcome( + ctx context.Context, + outcome generationOutcome, +) error { + chat := outcome.Chat + logger := server.logger.With(slog.F("chat_id", chat.ID), slog.F("owner_id", chat.OwnerID)) + + switch outcome.Kind { + case runnerActionKindFinishTurn: + finalizeCtx := context.WithoutCancel(ctx) + runResult := server.deriveFinalTurnRunResult(finalizeCtx, chat, logger) + statusLabel := server.generateFinalTurnStatusLabel(finalizeCtx, chat, chat.Status, runResult, logger) + server.updateLastTurnSummary(finalizeCtx, chat, chat.HistoryVersion, statusLabel, logger) + server.dispatchSuccessfulTurnPush(finalizeCtx, chat, statusLabel, logger) + case runnerActionKindFinishError: + server.maybeFinalizeTurnStatusLabelAndPush(context.WithoutCancel(ctx), chat, chat.Status, 
outcome.LastError, runChatResult{}, logger) + case runnerActionKindEnterRequiresAction: + server.maybeFinalizeTurnStatusLabelAndPush(context.WithoutCancel(ctx), chat, chat.Status, "", runChatResult{}, logger) + } + return nil +} + +// deriveFinalTurnRunResult rebuilds the inputs needed to generate the +// end-of-turn status label directly from persisted state. +func (server *Server) deriveFinalTurnRunResult( + ctx context.Context, + chat database.Chat, + logger slog.Logger, +) runChatResult { + // generateFinalTurnStatusLabel only produces a model-generated label for + // the Waiting status, so skip the model resolution and history read + // otherwise. + if chat.Status != database.ChatStatusWaiting { + return runChatResult{} + } + + promptRows, err := server.db.GetChatMessagesForPromptByChatID(ctx, chat.ID) + if err != nil { + logger.Warn(ctx, "derive final turn status label: load prompt rows", slog.Error(err)) + return runChatResult{} + } + triggerMessageID, historyTipMessageID, _ := deriveChatDebugSeed(promptRows) + finalAssistantText := latestAssistantText(promptRows) + if finalAssistantText == "" { + return runChatResult{} + } + + // resolvedProvider/resolvedModel describe the model the fallback handle was + // built from; they only feed the status-label fallback candidate's labels. + modelOpts := modelBuildOptionsFromMessages(promptRows) + if modelOpts.ActiveAPIKeyID != "" { + ctx = aibridge.WithDelegatedAPIKeyID(ctx, modelOpts.ActiveAPIKeyID) + } + model, _, providerKeys, modelRoute, _, resolvedProvider, resolvedModel, err := server.resolveChatModel(ctx, chat, modelOpts) + if err != nil { + // Return what we have; generateFinalTurnStatusLabel falls back to a + // generic label when StatusLabelModel is nil. 
+ logger.Warn(ctx, "derive final turn status label: resolve model", slog.Error(err)) + return runChatResult{ + FinalAssistantText: finalAssistantText, + TriggerMessageID: triggerMessageID, + HistoryTipMessageID: historyTipMessageID, + } + } + + return runChatResult{ + FinalAssistantText: finalAssistantText, + StatusLabelModel: model, + ProviderKeys: providerKeys, + FallbackProvider: resolvedProvider, + FallbackRoute: modelRoute, + FallbackModel: resolvedModel, + ModelBuildOptions: modelOpts, + TriggerMessageID: triggerMessageID, + HistoryTipMessageID: historyTipMessageID, + } +} + +// latestAssistantText returns the trimmed text of the most recent assistant +// message. It mirrors the FinalAssistantText that buildCommitStepMessages +// produced from the freshly generated step, making persisted history the +// single source of truth for the turn status label input. +func latestAssistantText(messages []database.ChatMessage) string { + for i := len(messages) - 1; i >= 0; i-- { + if messages[i].Role != database.ChatMessageRoleAssistant { + continue + } + parts, err := chatprompt.ParseContent(messages[i]) + if err != nil { + return "" + } + return strings.TrimSpace(textFromParts(parts)) + } + return "" +} diff --git a/coderd/x/chatd/generation_preparer_internal_test.go b/coderd/x/chatd/generation_preparer_internal_test.go new file mode 100644 index 0000000000..64eb36864d --- /dev/null +++ b/coderd/x/chatd/generation_preparer_internal_test.go @@ -0,0 +1,272 @@ +package chatd //nolint:testpackage // Exercises unexported re-derivation helpers. 
+ +import ( + "encoding/json" + "testing" + + "github.com/google/uuid" + "github.com/sqlc-dev/pqtype" + "github.com/stretchr/testify/require" + + "cdr.dev/slog/v3/sloggers/slogtest" + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbgen" + "github.com/coder/coder/v2/coderd/database/dbtestutil" + "github.com/coder/coder/v2/coderd/x/chatd/chatprompt" + "github.com/coder/coder/v2/coderd/x/chatd/chatprovider" + "github.com/coder/coder/v2/coderd/x/chatd/chatstate" + "github.com/coder/coder/v2/codersdk" +) + +func mustMarshalText(t *testing.T, parts ...string) pqtype.NullRawMessage { + t.Helper() + messageParts := make([]codersdk.ChatMessagePart, 0, len(parts)) + for _, p := range parts { + messageParts = append(messageParts, codersdk.ChatMessageText(p)) + } + content, err := chatprompt.MarshalParts(messageParts) + require.NoError(t, err) + return content +} + +func textMessage(t *testing.T, id int64, role database.ChatMessageRole, parts ...string) database.ChatMessage { + t.Helper() + return database.ChatMessage{ + ID: id, + Role: role, + Content: mustMarshalText(t, parts...), + ContentVersion: chatprompt.CurrentContentVersion, + } +} + +func TestLatestAssistantText(t *testing.T) { + t.Parallel() + + t.Run("ReturnsMostRecentAssistantMessage", func(t *testing.T) { + t.Parallel() + messages := []database.ChatMessage{ + textMessage(t, 1, database.ChatMessageRoleUser, "hi"), + textMessage(t, 2, database.ChatMessageRoleAssistant, "first answer"), + textMessage(t, 3, database.ChatMessageRoleTool, "tool result"), + textMessage(t, 4, database.ChatMessageRoleAssistant, " final answer "), + } + require.Equal(t, "final answer", latestAssistantText(messages)) + }) + + t.Run("ConcatenatesTextParts", func(t *testing.T) { + t.Parallel() + messages := []database.ChatMessage{ + textMessage(t, 1, database.ChatMessageRoleAssistant, "foo", "bar"), + } + require.Equal(t, "foobar", latestAssistantText(messages)) + }) + + t.Run("NoAssistantMessage", 
func(t *testing.T) { + t.Parallel() + messages := []database.ChatMessage{ + textMessage(t, 1, database.ChatMessageRoleUser, "hi"), + textMessage(t, 2, database.ChatMessageRoleTool, "tool result"), + } + require.Empty(t, latestAssistantText(messages)) + }) + + t.Run("EmptyAssistantText", func(t *testing.T) { + t.Parallel() + messages := []database.ChatMessage{ + textMessage(t, 1, database.ChatMessageRoleAssistant, " "), + } + require.Empty(t, latestAssistantText(messages)) + }) + + t.Run("EmptyHistory", func(t *testing.T) { + t.Parallel() + require.Empty(t, latestAssistantText(nil)) + }) +} + +// TestDeriveFinalTurnRunResult exercises the re-derivation path that replaces +// the old in-memory generationSideEffects stash. The server here never ran +// prepareGeneration, so a passing test proves the finish-turn inputs are +// rebuilt purely from persisted state. +func TestDeriveFinalTurnRunResult(t *testing.T) { + t.Parallel() + + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) + + setup := func(t *testing.T) (*Server, database.Chat) { + t.Helper() + db, ps := dbtestutil.NewDB(t) + ctx := chatdTestContext(t) + + user := dbgen.User(t, db, database.User{}) + org := dbgen.Organization(t, db, database.Organization{}) + dbgen.OrganizationMember(t, db, database.OrganizationMember{ + UserID: user.ID, + OrganizationID: org.ID, + }) + dbgen.ChatProvider(t, db, database.ChatProvider{ + Provider: "openai", + DisplayName: "OpenAI", + APIKey: "test-key", + Enabled: true, + CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true}, + }) + modelCfg := dbgen.ChatModelConfig(t, db, database.ChatModelConfig{ + Provider: "openai", + Model: "gpt-4o-mini", + DisplayName: "gpt-4o-mini", + Options: json.RawMessage(`{}`), + }, func(p *database.InsertChatModelConfigParams) { + p.Enabled = true + p.IsDefault = true + }) + + created, err := chatstate.CreateChat(ctx, db, ps, chatstate.CreateChatInput{ + OrganizationID: org.ID, + OwnerID: user.ID, + LastModelConfigID: modelCfg.ID, + 
Title: "derive-chat", + ClientType: database.ChatClientTypeUi, + InitialMessages: []chatstate.Message{ + { + Role: database.ChatMessageRoleUser, + Content: mustMarshalText(t, "what is the answer?"), + Visibility: database.ChatMessageVisibilityBoth, + ContentVersion: chatprompt.CurrentContentVersion, + CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true}, + ModelConfigID: uuid.NullUUID{UUID: modelCfg.ID, Valid: true}, + }, + }, + }) + require.NoError(t, err) + + server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{}) + return server, created.Chat + } + + commitAssistant := func(t *testing.T, server *Server, chat database.Chat, text string) { + t.Helper() + ctx := chatdTestContext(t) + machine := chatstate.NewChatMachine(server.db, server.pubsub, chat.ID, chatstate.Options{}) + require.NoError(t, machine.Update(ctx, func(tx *chatstate.Tx) error { + _, err := tx.CommitStep(chatstate.CommitStepInput{ + Messages: []chatstate.Message{ + { + Role: database.ChatMessageRoleAssistant, + Content: mustMarshalText(t, text), + Visibility: database.ChatMessageVisibilityBoth, + ContentVersion: chatprompt.CurrentContentVersion, + ModelConfigID: uuid.NullUUID{UUID: chat.LastModelConfigID, Valid: true}, + }, + }, + }) + return err + })) + } + + t.Run("WaitingDerivesFromHistory", func(t *testing.T) { + t.Parallel() + server, chat := setup(t) + ctx := chatdTestContext(t) + commitAssistant(t, server, chat, "the answer is 42") + + rows, err := server.db.GetChatMessagesForPromptByChatID(ctx, chat.ID) + require.NoError(t, err) + require.NotEmpty(t, rows) + var lastUserID int64 + for _, row := range rows { + if row.Role == database.ChatMessageRoleUser { + lastUserID = row.ID + } + } + tipID := rows[len(rows)-1].ID + + chat.Status = database.ChatStatusWaiting + result := server.deriveFinalTurnRunResult(ctx, chat, logger) + + require.Equal(t, "the answer is 42", result.FinalAssistantText) + require.Equal(t, lastUserID, result.TriggerMessageID) + require.Equal(t, tipID, 
result.HistoryTipMessageID) + require.NotNil(t, result.StatusLabelModel) + require.Equal(t, "openai", result.FallbackProvider) + require.Equal(t, "gpt-4o-mini", result.FallbackModel) + require.False(t, result.ProviderKeys.Empty()) + }) + + t.Run("NonWaitingReturnsEmpty", func(t *testing.T) { + t.Parallel() + server, chat := setup(t) + ctx := chatdTestContext(t) + commitAssistant(t, server, chat, "the answer is 42") + + chat.Status = database.ChatStatusError + result := server.deriveFinalTurnRunResult(ctx, chat, logger) + require.Equal(t, runChatResult{}, result) + }) + + t.Run("WaitingWithoutAssistantReturnsEmpty", func(t *testing.T) { + t.Parallel() + server, chat := setup(t) + ctx := chatdTestContext(t) + + // No assistant message was committed, so there is nothing to label. + chat.Status = database.ChatStatusWaiting + result := server.deriveFinalTurnRunResult(ctx, chat, logger) + require.Equal(t, runChatResult{}, result) + }) + + t.Run("ModelResolveErrorKeepsTextAndIDs", func(t *testing.T) { + t.Parallel() + db, ps := dbtestutil.NewDB(t) + ctx := chatdTestContext(t) + + user := dbgen.User(t, db, database.User{}) + org := dbgen.Organization(t, db, database.Organization{}) + dbgen.OrganizationMember(t, db, database.OrganizationMember{ + UserID: user.ID, + OrganizationID: org.ID, + }) + // A disabled AI provider makes resolveChatModel fail, exercising the + // degraded path that still returns the re-derived text and IDs. 
+ provider := insertInternalAIProvider(t, db, database.AiProviderTypeOpenai, "provider-api-key", false) + modelCfg := dbgen.ChatModelConfig(t, db, database.ChatModelConfig{ + Provider: "openai", + Model: "gpt-4o-mini", + DisplayName: "gpt-4o-mini", + AIProviderID: uuid.NullUUID{UUID: provider.ID, Valid: true}, + }) + + created, err := chatstate.CreateChat(ctx, db, ps, chatstate.CreateChatInput{ + OrganizationID: org.ID, + OwnerID: user.ID, + LastModelConfigID: modelCfg.ID, + Title: "derive-chat-error", + ClientType: database.ChatClientTypeUi, + InitialMessages: []chatstate.Message{ + { + Role: database.ChatMessageRoleUser, + Content: mustMarshalText(t, "what is the answer?"), + Visibility: database.ChatMessageVisibilityBoth, + ContentVersion: chatprompt.CurrentContentVersion, + CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true}, + ModelConfigID: uuid.NullUUID{UUID: modelCfg.ID, Valid: true}, + }, + }, + }) + require.NoError(t, err) + chat := created.Chat + + server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{}) + commitAssistant(t, server, chat, "the answer is 42") + + chat.Status = database.ChatStatusWaiting + result := server.deriveFinalTurnRunResult(ctx, chat, logger) + + require.Equal(t, "the answer is 42", result.FinalAssistantText) + require.NotZero(t, result.TriggerMessageID) + require.NotZero(t, result.HistoryTipMessageID) + require.Nil(t, result.StatusLabelModel) + require.Empty(t, result.FallbackProvider) + require.Empty(t, result.FallbackModel) + }) +} diff --git a/coderd/x/chatd/generation_retry_internal_test.go b/coderd/x/chatd/generation_retry_internal_test.go new file mode 100644 index 0000000000..87ef475ca5 --- /dev/null +++ b/coderd/x/chatd/generation_retry_internal_test.go @@ -0,0 +1,148 @@ +package chatd //nolint:testpackage // Exercises unexported generation retry helpers. 
+ +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/require" + "golang.org/x/xerrors" +) + +func TestTerminalGeneration(t *testing.T) { + t.Parallel() + + require.Nil(t, terminalGeneration(nil)) + + cause := xerrors.New("boom") + wrapped := terminalGeneration(cause) + require.True(t, isTerminalGeneration(wrapped)) + require.ErrorIs(t, wrapped, cause) + require.ErrorIs(t, wrapped, errTerminalGeneration) + require.Equal(t, cause.Error(), wrapped.Error()) + + require.False(t, isTerminalGeneration(cause)) + require.False(t, isTerminalGeneration(nil)) +} + +func TestGenerationPhaseBackoff(t *testing.T) { + t.Parallel() + + require.Equal(t, generationPhaseBaseBackoff, generationPhaseBackoff(0)) + require.Equal(t, 2*generationPhaseBaseBackoff, generationPhaseBackoff(1)) + require.Equal(t, 4*generationPhaseBaseBackoff, generationPhaseBackoff(2)) +} + +func TestRetryGenerationPhase(t *testing.T) { + t.Parallel() + + noopWait := func(context.Context, time.Duration) error { return nil } + + t.Run("SuccessFirstTry", func(t *testing.T) { + t.Parallel() + calls := 0 + waits := 0 + wait := func(context.Context, time.Duration) error { + waits++ + return nil + } + got, err := retryGenerationPhase(context.Background(), wait, func() (int, error) { + calls++ + return 42, nil + }) + require.NoError(t, err) + require.Equal(t, 42, got) + require.Equal(t, 1, calls) + require.Equal(t, 0, waits) + }) + + t.Run("RetryThenSuccess", func(t *testing.T) { + t.Parallel() + calls := 0 + waits := 0 + var delays []time.Duration + wait := func(_ context.Context, d time.Duration) error { + waits++ + delays = append(delays, d) + return nil + } + got, err := retryGenerationPhase(context.Background(), wait, func() (string, error) { + calls++ + if calls < 2 { + return "", xerrors.New("transient") + } + return "ok", nil + }) + require.NoError(t, err) + require.Equal(t, "ok", got) + require.Equal(t, 2, calls) + require.Equal(t, 1, waits) + require.Equal(t, 
[]time.Duration{generationPhaseBackoff(0)}, delays) + }) + + t.Run("ExhaustsAndReturnsLastError", func(t *testing.T) { + t.Parallel() + calls := 0 + waits := 0 + wait := func(context.Context, time.Duration) error { + waits++ + return nil + } + _, err := retryGenerationPhase(context.Background(), wait, func() (int, error) { + calls++ + return 0, xerrors.Errorf("attempt %d", calls) + }) + require.EqualError(t, err, "attempt 3") + require.Equal(t, generationPhaseMaxAttempts, calls) + require.Equal(t, generationPhaseMaxAttempts-1, waits) + }) + + t.Run("TerminalShortCircuits", func(t *testing.T) { + t.Parallel() + calls := 0 + waits := 0 + wait := func(context.Context, time.Duration) error { + waits++ + return nil + } + cause := xerrors.New("deterministic") + _, err := retryGenerationPhase(context.Background(), wait, func() (int, error) { + calls++ + return 0, terminalGeneration(cause) + }) + require.ErrorIs(t, err, cause) + require.True(t, isTerminalGeneration(err)) + require.Equal(t, 1, calls) + require.Equal(t, 0, waits) + }) + + t.Run("ContextCanceledExitsCleanly", func(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithCancel(context.Background()) + cancel() + calls := 0 + _, err := retryGenerationPhase(ctx, noopWait, func() (int, error) { + calls++ + return 0, xerrors.New("transient") + }) + require.ErrorIs(t, err, errTaskExpectedExit) + require.Equal(t, 1, calls) + }) + + t.Run("WaitCancellationExitsCleanly", func(t *testing.T) { + t.Parallel() + calls := 0 + waits := 0 + wait := func(context.Context, time.Duration) error { + waits++ + return errTaskExpectedExit + } + _, err := retryGenerationPhase(context.Background(), wait, func() (int, error) { + calls++ + return 0, xerrors.New("transient") + }) + require.ErrorIs(t, err, errTaskExpectedExit) + require.Equal(t, 1, calls) + require.Equal(t, 1, waits) + }) +} diff --git a/coderd/x/chatd/helpers_test.go b/coderd/x/chatd/helpers_test.go new file mode 100644 index 0000000000..c76554b422 --- /dev/null +++ 
b/coderd/x/chatd/helpers_test.go @@ -0,0 +1,529 @@ +package chatd //nolint:testpackage // Uses unexported chatworker helpers. + +import ( + "context" + "database/sql" + "encoding/json" + "sync" + "testing" + "time" + + "github.com/google/uuid" + "github.com/sqlc-dev/pqtype" + "github.com/stretchr/testify/require" + + "cdr.dev/slog/v3" + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbgen" + "github.com/coder/coder/v2/coderd/database/dbtestutil" + dbpubsub "github.com/coder/coder/v2/coderd/database/pubsub" + coderdpubsub "github.com/coder/coder/v2/coderd/pubsub" + "github.com/coder/coder/v2/coderd/x/chatd/chatprompt" + "github.com/coder/coder/v2/coderd/x/chatd/chatstate" + "github.com/coder/coder/v2/codersdk" + "github.com/coder/coder/v2/testutil" +) + +type workerTestFixture struct { + db database.Store + pubsub dbpubsub.Pubsub + sqlDB *sql.DB + user database.User + org database.Organization + model database.ChatModelConfig +} + +type publishedEvent struct { + channel string + payload []byte +} + +type recordingPubsub struct { + inner dbpubsub.Pubsub + mu sync.Mutex + events []publishedEvent +} + +func newRecordingPubsub(inner dbpubsub.Pubsub) *recordingPubsub { + return &recordingPubsub{inner: inner} +} + +func (p *recordingPubsub) Publish(channel string, payload []byte) error { + p.mu.Lock() + p.events = append(p.events, publishedEvent{ + channel: channel, + payload: append([]byte(nil), payload...), + }) + p.mu.Unlock() + return p.inner.Publish(channel, payload) +} + +func (p *recordingPubsub) SubscribeWithErr(channel string, listener dbpubsub.ListenerWithErr) (func(), error) { + return p.inner.SubscribeWithErr(channel, listener) +} + +func (p *recordingPubsub) ownershipMessages(t *testing.T) []coderdpubsub.ChatStateOwnershipMessage { + t.Helper() + p.mu.Lock() + defer p.mu.Unlock() + messages := make([]coderdpubsub.ChatStateOwnershipMessage, 0) + for _, event := range p.events { + if event.channel != 
coderdpubsub.ChatStateOwnershipChannel { + continue + } + var msg coderdpubsub.ChatStateOwnershipMessage + require.NoError(t, json.Unmarshal(event.payload, &msg)) + messages = append(messages, msg) + } + return messages +} + +func (p *recordingPubsub) watchEvents(t *testing.T) []codersdk.ChatWatchEvent { + t.Helper() + p.mu.Lock() + defer p.mu.Unlock() + events := make([]codersdk.ChatWatchEvent, 0) + for _, event := range p.events { + var msg codersdk.ChatWatchEvent + if err := json.Unmarshal(event.payload, &msg); err != nil { + continue + } + if event.channel != coderdpubsub.ChatWatchEventChannel(msg.Chat.OwnerID) { + continue + } + events = append(events, msg) + } + return events +} + +func (p *recordingPubsub) stateUpdateMessages(t *testing.T, chatID uuid.UUID) []coderdpubsub.ChatStateUpdateMessage { + t.Helper() + p.mu.Lock() + defer p.mu.Unlock() + messages := make([]coderdpubsub.ChatStateUpdateMessage, 0) + for _, event := range p.events { + if event.channel != coderdpubsub.ChatStateUpdateChannel(chatID) { + continue + } + var msg coderdpubsub.ChatStateUpdateMessage + require.NoError(t, json.Unmarshal(event.payload, &msg)) + messages = append(messages, msg) + } + return messages +} + +func newWorkerTestFixture(t *testing.T) *workerTestFixture { + t.Helper() + db, ps, sqlDB := dbtestutil.NewDBWithSQLDB(t) + user := dbgen.User(t, db, database.User{}) + org := dbgen.Organization(t, db, database.Organization{}) + dbgen.OrganizationMember(t, db, database.OrganizationMember{ + UserID: user.ID, + OrganizationID: org.ID, + }) + dbgen.ChatProvider(t, db, database.ChatProvider{ + Provider: "openai", + DisplayName: "openai", + BaseUrl: "http://example.invalid", + }) + model := dbgen.ChatModelConfig(t, db, database.ChatModelConfig{ + Provider: "openai", + IsDefault: true, + }) + return &workerTestFixture{db: db, pubsub: ps, sqlDB: sqlDB, user: user, org: org, model: model} +} + +func (f *workerTestFixture) createRunningChat(t *testing.T) database.Chat { + t.Helper() + 
ctx := testutil.Context(t, testutil.WaitShort) + res, err := chatstate.CreateChat(ctx, f.db, f.pubsub, chatstate.CreateChatInput{ + OrganizationID: f.org.ID, + OwnerID: f.user.ID, + LastModelConfigID: f.model.ID, + Title: "test", + ClientType: database.ChatClientTypeApi, + InitialMessages: []chatstate.Message{ + userTextMessage(t, "hello", f.user.ID, f.model.ID), + }, + }) + require.NoError(t, err) + return res.Chat +} + +func (f *workerTestFixture) createRequiresActionChat(t *testing.T) database.Chat { + t.Helper() + ctx := testutil.Context(t, testutil.WaitShort) + toolName := "dynamic_" + uuid.NewString() + dynamicTools, err := json.Marshal([]codersdk.DynamicTool{{ + Name: toolName, + Description: "test tool", + InputSchema: json.RawMessage(`{"type":"object"}`), + }}) + require.NoError(t, err) + res, err := chatstate.CreateChat(ctx, f.db, f.pubsub, chatstate.CreateChatInput{ + OrganizationID: f.org.ID, + OwnerID: f.user.ID, + LastModelConfigID: f.model.ID, + Title: "test", + ClientType: database.ChatClientTypeApi, + DynamicTools: pqtype.NullRawMessage{ + RawMessage: dynamicTools, + Valid: true, + }, + InitialMessages: []chatstate.Message{ + userTextMessage(t, "hello", f.user.ID, f.model.ID), + }, + }) + require.NoError(t, err) + machine := chatstate.NewChatMachine(f.db, f.pubsub, res.Chat.ID, chatstate.Options{}) + require.NoError(t, machine.Update(ctx, func(tx *chatstate.Tx) error { + _, err := tx.CommitStep(chatstate.CommitStepInput{ + Messages: []chatstate.Message{ + assistantToolCallMessage(t, f.model.ID, toolName), + }, + }) + return err + })) + require.NoError(t, machine.Update(ctx, func(tx *chatstate.Tx) error { + _, err := tx.EnterRequiresAction(chatstate.EnterRequiresActionInput{}) + return err + })) + chat, err := f.db.GetChatByID(ctx, res.Chat.ID) + require.NoError(t, err) + return chat +} + +func userTextMessage(t *testing.T, text string, createdBy uuid.UUID, modelConfigID uuid.UUID) chatstate.Message { + t.Helper() + raw, err := 
chatprompt.MarshalParts([]codersdk.ChatMessagePart{codersdk.ChatMessageText(text)}) + require.NoError(t, err) + return chatstate.Message{ + Role: database.ChatMessageRoleUser, + Content: raw, + Visibility: database.ChatMessageVisibilityBoth, + ContentVersion: chatprompt.CurrentContentVersion, + CreatedBy: uuid.NullUUID{UUID: createdBy, Valid: true}, + ModelConfigID: uuid.NullUUID{UUID: modelConfigID, Valid: true}, + } +} + +func assistantTextMessage(t *testing.T, text string, modelConfigID uuid.UUID) chatstate.Message { + t.Helper() + raw, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{codersdk.ChatMessageText(text)}) + require.NoError(t, err) + return chatstate.Message{ + Role: database.ChatMessageRoleAssistant, + Content: raw, + Visibility: database.ChatMessageVisibilityBoth, + ContentVersion: chatprompt.CurrentContentVersion, + ModelConfigID: uuid.NullUUID{UUID: modelConfigID, Valid: true}, + } +} + +func assistantToolCallMessage(t *testing.T, modelConfigID uuid.UUID, toolName string) chatstate.Message { + t.Helper() + raw, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{{ + Type: codersdk.ChatMessagePartTypeToolCall, + ToolCallID: "call_" + uuid.NewString(), + ToolName: toolName, + Args: json.RawMessage(`{}`), + }}) + require.NoError(t, err) + return chatstate.Message{ + Role: database.ChatMessageRoleAssistant, + Content: raw, + Visibility: database.ChatMessageVisibilityBoth, + ContentVersion: chatprompt.CurrentContentVersion, + ModelConfigID: uuid.NullUUID{UUID: modelConfigID, Valid: true}, + } +} + +func testOptions(t *testing.T, f *workerTestFixture, starter chatWorkerTaskStarter) chatWorkerOptions { + t.Helper() + if starter == nil { + starter = newRecordingTaskStarter() + } + return chatWorkerOptions{ + WorkerID: uuid.New(), + Store: f.db, + Pubsub: f.pubsub, + Logger: slog.Make(), + TaskStarter: starter, + AcquisitionInterval: time.Hour, + AcquisitionBatchSize: 10, + RunnerSyncInterval: time.Hour, + HeartbeatInterval: time.Hour, + 
HeartbeatCleanupInterval: time.Hour, + HeartbeatStaleSeconds: 30, + StateChannelSize: 16, + RunnerManagerChannelSize: 16, + AcquisitionWakeChannelSize: 1, + } +} + +func startWorker(t *testing.T, opts chatWorkerOptions) *chatWorker { + t.Helper() + worker, err := newChatWorker(nil, opts) + require.NoError(t, err) + require.NoError(t, worker.Start(context.Background())) + t.Cleanup(func() { require.NoError(t, worker.Close()) }) + return worker +} + +type taskCall struct { + kind taskKind + input chatWorkerTaskStartInput + ctx context.Context +} + +type releaseGate struct { + once sync.Once + ch chan struct{} +} + +type recordingTaskStarter struct { + mu sync.Mutex + calls []taskCall + callCh chan taskCall + releases []*releaseGate + block bool + ignoreCancel bool +} + +func newRecordingTaskStarter() *recordingTaskStarter { + return &recordingTaskStarter{callCh: make(chan taskCall, 128)} +} + +func newBlockingTaskStarter(ignoreCancel bool) *recordingTaskStarter { + return &recordingTaskStarter{ + callCh: make(chan taskCall, 128), + block: true, + ignoreCancel: ignoreCancel, + } +} + +func (s *recordingTaskStarter) StartGeneration(ctx context.Context, input chatWorkerTaskStartInput) error { + return s.start(ctx, taskKindGeneration, input) +} + +func (s *recordingTaskStarter) StartInterrupt(ctx context.Context, input chatWorkerTaskStartInput) error { + return s.start(ctx, taskKindInterrupt, input) +} + +func (s *recordingTaskStarter) StartRequiresActionTimeout(ctx context.Context, input chatWorkerTaskStartInput) error { + return s.start(ctx, taskKindRequiresActionTimeout, input) +} + +func (s *recordingTaskStarter) StartAbandon(ctx context.Context, input chatWorkerTaskStartInput) error { + return s.start(ctx, taskKindAbandon, input) +} + +func (s *recordingTaskStarter) start(ctx context.Context, kind taskKind, input chatWorkerTaskStartInput) error { + call := taskCall{kind: kind, input: input, ctx: ctx} + var gate *releaseGate + s.mu.Lock() + if s.block { + gate = 
&releaseGate{ch: make(chan struct{})} + s.releases = append(s.releases, gate) + } + s.calls = append(s.calls, call) + s.mu.Unlock() + s.callCh <- call + if gate == nil { + return nil + } + if s.ignoreCancel { + <-gate.ch + return nil + } + select { + case <-gate.ch: + return nil + case <-ctx.Done(): + return ctx.Err() + } +} + +func (s *recordingTaskStarter) waitCall(t *testing.T, kind taskKind, chatID uuid.UUID) taskCall { + t.Helper() + deadline := time.After(testutil.WaitLong) + for { + select { + case call := <-s.callCh: + if (kind == "" || call.kind == kind) && (chatID == uuid.Nil || call.input.ChatID == chatID) { + return call + } + case <-deadline: + t.Fatalf("timed out waiting for task call kind=%q chat_id=%s", kind, chatID) + return taskCall{} + } + } +} + +func (s *recordingTaskStarter) assertNoCall(t *testing.T) { + t.Helper() + select { + case call := <-s.callCh: + t.Fatalf("unexpected task call: %s for chat %s", call.kind, call.input.ChatID) + case <-time.After(100 * time.Millisecond): + } +} + +func (s *recordingTaskStarter) release(t *testing.T, index int) { + t.Helper() + s.mu.Lock() + defer s.mu.Unlock() + require.Less(t, index, len(s.releases)) + s.releases[index].once.Do(func() { close(s.releases[index].ch) }) +} + +func (s *recordingTaskStarter) releaseAll() { + s.mu.Lock() + defer s.mu.Unlock() + for _, gate := range s.releases { + gate.once.Do(func() { close(gate.ch) }) + } +} + +func finishTurn(t *testing.T, f *workerTestFixture, chatID uuid.UUID) database.Chat { + t.Helper() + ctx := testutil.Context(t, testutil.WaitShort) + machine := chatstate.NewChatMachine(f.db, f.pubsub, chatID, chatstate.Options{}) + require.NoError(t, machine.Update(ctx, func(tx *chatstate.Tx) error { + _, err := tx.FinishTurn(chatstate.FinishTurnInput{}) + return err + })) + chat, err := f.db.GetChatByID(ctx, chatID) + require.NoError(t, err) + return chat +} + +func commitAssistantStep(t *testing.T, f *workerTestFixture, chatID uuid.UUID, text string) database.Chat 
{ + t.Helper() + ctx := testutil.Context(t, testutil.WaitShort) + machine := chatstate.NewChatMachine(f.db, f.pubsub, chatID, chatstate.Options{}) + require.NoError(t, machine.Update(ctx, func(tx *chatstate.Tx) error { + _, err := tx.CommitStep(chatstate.CommitStepInput{ + Messages: []chatstate.Message{assistantTextMessage(t, text, f.model.ID)}, + }) + return err + })) + chat, err := f.db.GetChatByID(ctx, chatID) + require.NoError(t, err) + return chat +} + +func interruptChat(t *testing.T, f *workerTestFixture, chatID uuid.UUID) database.Chat { + t.Helper() + ctx := testutil.Context(t, testutil.WaitShort) + machine := chatstate.NewChatMachine(f.db, f.pubsub, chatID, chatstate.Options{}) + require.NoError(t, machine.Update(ctx, func(tx *chatstate.Tx) error { + _, err := tx.SendMessage(chatstate.SendMessageInput{ + Message: userTextMessage(t, "interrupt", f.user.ID, f.model.ID), + BusyBehavior: chatstate.BusyBehaviorInterrupt, + }) + return err + })) + chat, err := f.db.GetChatByID(ctx, chatID) + require.NoError(t, err) + return chat +} + +func acquireChat(t *testing.T, f *workerTestFixture, chatID uuid.UUID, workerID uuid.UUID, runnerID uuid.UUID) database.Chat { + t.Helper() + ctx := testutil.Context(t, testutil.WaitShort) + machine := chatstate.NewChatMachine(f.db, f.pubsub, chatID, chatstate.Options{}) + require.NoError(t, machine.Update(ctx, func(tx *chatstate.Tx) error { + _, err := tx.Acquire(chatstate.AcquireInput{WorkerID: workerID, RunnerID: runnerID}) + return err + })) + chat, err := f.db.GetChatByID(ctx, chatID) + require.NoError(t, err) + return chat +} + +func forceExecutionState( + t *testing.T, + f *workerTestFixture, + chatID uuid.UUID, + status database.ChatStatus, + archived bool, +) database.Chat { + t.Helper() + ctx := testutil.Context(t, testutil.WaitShort) + var updated database.Chat + require.NoError(t, f.db.InTx(func(store database.Store) error { + if _, err := store.LockChatAndBumpSnapshotVersion(ctx, chatID); err != nil { + return err + } 
+ chat, err := store.GetChatByID(ctx, chatID) + if err != nil { + return err + } + updated, err = store.UpdateChatExecutionState(ctx, database.UpdateChatExecutionStateParams{ + ID: chat.ID, + Status: status, + Archived: archived, + WorkerID: chat.WorkerID, + RunnerID: chat.RunnerID, + LastError: chat.LastError, + RequiresActionDeadlineAt: chat.RequiresActionDeadlineAt, + }) + return err + }, nil)) + return updated +} + +func forceExecutionStateAndPublish( + t *testing.T, + f *workerTestFixture, + chatID uuid.UUID, + status database.ChatStatus, + archived bool, +) database.Chat { + t.Helper() + updated := forceExecutionState(t, f, chatID, status, archived) + publishChatUpdate(t, f, updated) + return updated +} + +func publishChatUpdate(t *testing.T, f *workerTestFixture, chat database.Chat) { + t.Helper() + msg := coderdpubsub.ChatStateUpdateMessage{ + SnapshotVersion: chat.SnapshotVersion, + HistoryVersion: chat.HistoryVersion, + QueueVersion: chat.QueueVersion, + RetryStateVersion: chat.RetryStateVersion, + GenerationAttempt: chat.GenerationAttempt, + Status: string(chat.Status), + Archived: chat.Archived, + } + if chat.WorkerID.Valid { + id := chat.WorkerID.UUID + msg.WorkerID = &id + } + if chat.RunnerID.Valid { + id := chat.RunnerID.UUID + msg.RunnerID = &id + } + payload, err := json.Marshal(msg) + require.NoError(t, err) + require.NoError(t, f.pubsub.Publish(coderdpubsub.ChatStateUpdateChannel(chat.ID), payload)) +} + +func makeHeartbeatStale(t *testing.T, f *workerTestFixture, chatID uuid.UUID, runnerID uuid.UUID) time.Time { + t.Helper() + _, err := f.sqlDB.ExecContext( + testutil.Context(t, testutil.WaitShort), + `UPDATE chat_heartbeats SET heartbeat_at = NOW() - INTERVAL '1 hour' WHERE chat_id = $1 AND runner_id = $2`, + chatID, + runnerID, + ) + require.NoError(t, err) + heartbeat, err := f.db.GetChatHeartbeat(testutil.Context(t, testutil.WaitShort), database.GetChatHeartbeatParams{ + ChatID: chatID, + RunnerID: runnerID, + }) + require.NoError(t, err) + 
return heartbeat.HeartbeatAt +} diff --git a/coderd/x/chatd/instruction.go b/coderd/x/chatd/instruction.go index 02f6dc675a..05476ed6f0 100644 --- a/coderd/x/chatd/instruction.go +++ b/coderd/x/chatd/instruction.go @@ -128,80 +128,6 @@ func instructionFromContextFiles( return formatSystemInstructions(os, dir, contextParts) } -// hasPersistedInstructionFiles reports whether messages include a -// persisted context-file part that should suppress another baseline -// instruction-file lookup. The workspace-agent skill-only sentinel is -// ignored so default instructions still load on fresh chats. -func hasPersistedInstructionFiles( - messages []database.ChatMessage, -) bool { - for _, msg := range messages { - if !msg.Content.Valid || - !bytes.Contains(msg.Content.RawMessage, []byte(`"context-file"`)) { - continue - } - var parts []codersdk.ChatMessagePart - if err := json.Unmarshal(msg.Content.RawMessage, &parts); err != nil { - continue - } - for _, part := range parts { - if part.Type != codersdk.ChatMessagePartTypeContextFile || - !part.ContextFileAgentID.Valid || - part.ContextFilePath == AgentChatContextSentinelPath { - continue - } - return true - } - } - return false -} - -func mergeSkillMetas( - persisted []chattool.SkillMeta, - discovered []chattool.SkillMeta, -) []chattool.SkillMeta { - if len(persisted) == 0 { - return discovered - } - if len(discovered) == 0 { - return persisted - } - - seen := make(map[string]struct{}, len(persisted)+len(discovered)) - merged := make([]chattool.SkillMeta, 0, len(persisted)+len(discovered)) - appendUnique := func(skill chattool.SkillMeta) { - if _, ok := seen[skill.Name]; ok { - return - } - seen[skill.Name] = struct{}{} - merged = append(merged, skill) - } - for _, skill := range discovered { - appendUnique(skill) - } - for _, skill := range persisted { - appendUnique(skill) - } - return merged -} - -// selectSkillMetasForInstructionRefresh chooses which skill metadata -// should be injected on a turn that refreshes 
instruction files. -func selectSkillMetasForInstructionRefresh( - persisted []chattool.SkillMeta, - discovered []chattool.SkillMeta, - currentAgentID uuid.NullUUID, - latestInjectedAgentID uuid.NullUUID, -) []chattool.SkillMeta { - if currentAgentID.Valid && latestInjectedAgentID.Valid && latestInjectedAgentID.UUID == currentAgentID.UUID { - return mergeSkillMetas(persisted, discovered) - } - if !currentAgentID.Valid && len(discovered) == 0 { - return persisted - } - return discovered -} - // skillsFromParts reconstructs skill metadata from persisted // skill parts. This is analogous to instructionFromContextFiles // so the skill index can be re-injected after compaction without @@ -238,19 +164,3 @@ func skillsFromParts( } return skills } - -// filterSkillParts returns stripped copies of skill-type parts from -// the given slice. Internal fields are removed so the result is safe -// for the cache column. Returns nil when no skill parts exist. -func filterSkillParts(parts []codersdk.ChatMessagePart) []codersdk.ChatMessagePart { - var out []codersdk.ChatMessagePart - for _, p := range parts { - if p.Type != codersdk.ChatMessagePartTypeSkill { - continue - } - cp := p - cp.StripInternal() - out = append(out, cp) - } - return out -} diff --git a/coderd/x/chatd/message_conversion.go b/coderd/x/chatd/message_conversion.go new file mode 100644 index 0000000000..d8123391be --- /dev/null +++ b/coderd/x/chatd/message_conversion.go @@ -0,0 +1,842 @@ +package chatd + +import ( + "cmp" + "context" + "database/sql" + "encoding/json" + "slices" + "strings" + "time" + + "charm.land/fantasy" + "github.com/google/uuid" + "github.com/sqlc-dev/pqtype" + "golang.org/x/xerrors" + + "cdr.dev/slog/v3" + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/x/chatd/chatcost" + "github.com/coder/coder/v2/coderd/x/chatd/chatprompt" + "github.com/coder/coder/v2/coderd/x/chatd/chatstate" + "github.com/coder/coder/v2/coderd/x/chatd/chattool" + 
"github.com/coder/coder/v2/coderd/x/chatd/messagepartbuffer" + "github.com/coder/coder/v2/codersdk" +) + +const interruptedToolResultErrorMessage = "tool call was interrupted before it produced a result" + +type buildCommitStepMessagesInput struct { + modelConfigID uuid.UUID + modelCallConfig codersdk.ChatModelCallConfig + step stepData + toolNameToConfigID map[string]uuid.UUID + logger slog.Logger + contentVersion int16 +} + +type stepMessagesForCommit struct { + Messages []chatstate.Message + VisibleIndexes []int +} + +func buildCommitStepMessages(input buildCommitStepMessagesInput) (stepMessagesForCommit, error) { + contentVersion := input.contentVersion + if contentVersion == 0 { + contentVersion = chatprompt.CurrentContentVersion + } + + assistantBlocks, toolResults := splitStepContent(input.step.Content) + assistantParts := buildAssistantParts(input.logger, assistantBlocks, toolResults, input.step, input.toolNameToConfigID) + + messages := make([]chatstate.Message, 0, 1+len(toolResults)) + if len(assistantParts) > 0 { + assistantContent, err := chatprompt.MarshalParts(assistantParts) + if err != nil { + return stepMessagesForCommit{}, xerrors.Errorf("marshal assistant content: %w", err) + } + messages = append(messages, assistantMessage(input.modelConfigID, contentVersion, assistantContent, input.step, input.modelCallConfig)) + } + + for _, toolResult := range toolResults { + part := chatprompt.PartFromContentWithLogger(context.Background(), input.logger, toolResult) + applyToolMetadata(&part, input.toolNameToConfigID) + if part.ToolCallID != "" && input.step.ToolResultCreatedAt != nil { + if ts, ok := input.step.ToolResultCreatedAt[part.ToolCallID]; ok { + part.CreatedAt = &ts + } + } + content, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{part}) + if err != nil { + return stepMessagesForCommit{}, xerrors.Errorf("marshal tool result: %w", err) + } + messages = append(messages, baseMessage(database.ChatMessageRoleTool, 
database.ChatMessageVisibilityBoth, input.modelConfigID, contentVersion, content)) + } + + return stepMessagesForCommit{ + Messages: messages, + VisibleIndexes: visibleMessageIndexes(messages), + }, nil +} + +func splitStepContent(content []fantasy.Content) ([]fantasy.Content, []fantasy.ToolResultContent) { + assistantBlocks := make([]fantasy.Content, 0, len(content)) + toolResults := make([]fantasy.ToolResultContent, 0) + for _, block := range content { + if tr, ok := asToolResultContent(block); ok && !tr.ProviderExecuted { + toolResults = append(toolResults, tr) + continue + } + assistantBlocks = append(assistantBlocks, block) + } + return assistantBlocks, toolResults +} + +func asToolResultContent(block fantasy.Content) (fantasy.ToolResultContent, bool) { + if tr, ok := fantasy.AsContentType[fantasy.ToolResultContent](block); ok { + return tr, true + } + if tr, ok := fantasy.AsContentType[*fantasy.ToolResultContent](block); ok && tr != nil { + return *tr, true + } + return fantasy.ToolResultContent{}, false +} + +func buildAssistantParts( + logger slog.Logger, + assistantBlocks []fantasy.Content, + toolResults []fantasy.ToolResultContent, + step stepData, + toolNameToConfigID map[string]uuid.UUID, +) []codersdk.ChatMessagePart { + parts := make([]codersdk.ChatMessagePart, 0, len(assistantBlocks)+len(toolResults)) + reasoningIdx := 0 + for _, block := range assistantBlocks { + part := chatprompt.PartFromContentWithLogger(context.Background(), logger, block) + applyToolMetadata(&part, toolNameToConfigID) + switch part.Type { + case codersdk.ChatMessagePartTypeToolCall: + if part.ToolCallID != "" && step.ToolCallCreatedAt != nil { + if ts, ok := step.ToolCallCreatedAt[part.ToolCallID]; ok { + part.CreatedAt = &ts + } + } + case codersdk.ChatMessagePartTypeToolResult: + if part.ToolCallID != "" && step.ToolResultCreatedAt != nil { + if ts, ok := step.ToolResultCreatedAt[part.ToolCallID]; ok { + part.CreatedAt = &ts + } + } + case 
codersdk.ChatMessagePartTypeReasoning: + if reasoningIdx < len(step.ReasoningStartedAt) { + if ts := step.ReasoningStartedAt[reasoningIdx]; !ts.IsZero() { + part.CreatedAt = &ts + } + } + if reasoningIdx < len(step.ReasoningCompletedAt) { + if ts := step.ReasoningCompletedAt[reasoningIdx]; !ts.IsZero() { + part.CompletedAt = &ts + } + } + reasoningIdx++ + } + if part.Type != "" { + parts = append(parts, part) + } + } + for _, tr := range toolResults { + attachments, err := chattool.AttachmentsFromMetadata(tr.ClientMetadata) + if err != nil { + logger.Warn(context.Background(), "skipping malformed tool attachment metadata", + slog.F("tool_name", tr.ToolName), + slog.F("tool_call_id", tr.ToolCallID), + slog.Error(err), + ) + continue + } + for _, attachment := range attachments { + parts = append(parts, codersdk.ChatMessageFile(attachment.FileID, attachment.MediaType, attachment.Name)) + } + } + return parts +} + +func applyToolMetadata(part *codersdk.ChatMessagePart, toolNameToConfigID map[string]uuid.UUID) { + if part.ToolName == "" || len(toolNameToConfigID) == 0 { + return + } + if configID, ok := toolNameToConfigID[part.ToolName]; ok { + part.MCPServerConfigID = uuid.NullUUID{UUID: configID, Valid: true} + } +} + +func assistantMessage( + modelConfigID uuid.UUID, + contentVersion int16, + content pqtype.NullRawMessage, + step stepData, + modelCallConfig codersdk.ChatModelCallConfig, +) chatstate.Message { + msg := baseMessage(database.ChatMessageRoleAssistant, database.ChatMessageVisibilityBoth, modelConfigID, contentVersion, content) + if step.Usage != (fantasy.Usage{}) { + msg.InputTokens = nullInt64IfNonZero(step.Usage.InputTokens) + msg.OutputTokens = nullInt64IfNonZero(step.Usage.OutputTokens) + msg.TotalTokens = nullInt64IfNonZero(step.Usage.TotalTokens) + msg.ReasoningTokens = nullInt64IfNonZero(step.Usage.ReasoningTokens) + msg.CacheCreationTokens = nullInt64IfNonZero(step.Usage.CacheCreationTokens) + msg.CacheReadTokens = 
nullInt64IfNonZero(step.Usage.CacheReadTokens) + usage := codersdk.ChatMessageUsage{ + InputTokens: int64PtrIfNonZero(step.Usage.InputTokens), + OutputTokens: int64PtrIfNonZero(step.Usage.OutputTokens), + ReasoningTokens: int64PtrIfNonZero(step.Usage.ReasoningTokens), + CacheCreationTokens: int64PtrIfNonZero(step.Usage.CacheCreationTokens), + CacheReadTokens: int64PtrIfNonZero(step.Usage.CacheReadTokens), + } + if totalCost := chatcost.CalculateTotalCostMicros(usage, modelCallConfig.Cost); totalCost != nil { + msg.TotalCostMicros = sql.NullInt64{Int64: *totalCost, Valid: true} + } + } + msg.ContextLimit = step.ContextLimit + if step.Runtime > 0 { + msg.RuntimeMs = sql.NullInt64{Int64: step.Runtime.Milliseconds(), Valid: true} + } + if step.ProviderResponseID != "" { + msg.ProviderResponseID = sql.NullString{String: step.ProviderResponseID, Valid: true} + } + return msg +} + +func baseMessage( + role database.ChatMessageRole, + visibility database.ChatMessageVisibility, + modelConfigID uuid.UUID, + contentVersion int16, + content pqtype.NullRawMessage, +) chatstate.Message { + return chatstate.Message{ + Role: role, + Content: content, + Visibility: visibility, + ModelConfigID: uuid.NullUUID{UUID: modelConfigID, Valid: modelConfigID != uuid.Nil}, + ContentVersion: contentVersion, + } +} + +func nullInt64IfNonZero(value int64) sql.NullInt64 { + if value == 0 { + return sql.NullInt64{} + } + return sql.NullInt64{Int64: value, Valid: true} +} + +func int64PtrIfNonZero(value int64) *int64 { + if value == 0 { + return nil + } + return &value +} + +func visibleMessageIndexes(messages []chatstate.Message) []int { + indexes := make([]int, 0, len(messages)) + for i, msg := range messages { + if msg.Visibility == database.ChatMessageVisibilityBoth || msg.Visibility == database.ChatMessageVisibilityUser { + indexes = append(indexes, i) + } + } + return indexes +} + +func textFromParts(parts []codersdk.ChatMessagePart) string { + var builder strings.Builder + for _, part := 
range parts { + if part.Type == codersdk.ChatMessagePartTypeText { + _, _ = builder.WriteString(part.Text) + } + } + return builder.String() +} + +type buildCompactionMessagesInput struct { + modelConfigID uuid.UUID + toolCallID string + toolName string + compaction compactionOutcome + contentVersion int16 +} + +type compactionMessagesForCommit struct { + Messages []chatstate.Message + HiddenCount int +} + +func buildCompactionMessages(input buildCompactionMessagesInput) (compactionMessagesForCommit, error) { + contentVersion := input.contentVersion + if contentVersion == 0 { + contentVersion = chatprompt.CurrentContentVersion + } + toolName := input.toolName + if toolName == "" { + toolName = "chat_summarized" + } + + systemContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{codersdk.ChatMessageText(input.compaction.SystemSummary)}) + if err != nil { + return compactionMessagesForCommit{}, xerrors.Errorf("marshal compaction system summary: %w", err) + } + args, err := json.Marshal(map[string]any{ + "source": "automatic", + "threshold_percent": input.compaction.ThresholdPercent, + }) + if err != nil { + return compactionMessagesForCommit{}, xerrors.Errorf("marshal compaction args: %w", err) + } + assistantContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{ + codersdk.ChatMessageToolCall(input.toolCallID, toolName, args), + }) + if err != nil { + return compactionMessagesForCommit{}, xerrors.Errorf("marshal compaction tool call: %w", err) + } + summaryResult, err := json.Marshal(map[string]any{ + "summary": input.compaction.SummaryReport, + "source": "automatic", + "threshold_percent": input.compaction.ThresholdPercent, + "usage_percent": input.compaction.UsagePercent, + "context_tokens": input.compaction.ContextTokens, + "context_limit_tokens": input.compaction.ContextLimit, + }) + if err != nil { + return compactionMessagesForCommit{}, xerrors.Errorf("marshal compaction result: %w", err) + } + toolContent, err := 
chatprompt.MarshalParts([]codersdk.ChatMessagePart{ + codersdk.ChatMessageToolResult(input.toolCallID, toolName, summaryResult, false, false), + }) + if err != nil { + return compactionMessagesForCommit{}, xerrors.Errorf("marshal compaction tool result: %w", err) + } + + messages := []chatstate.Message{ + baseMessage(database.ChatMessageRoleUser, database.ChatMessageVisibilityModel, input.modelConfigID, contentVersion, systemContent), + baseMessage(database.ChatMessageRoleAssistant, database.ChatMessageVisibilityUser, input.modelConfigID, contentVersion, assistantContent), + baseMessage(database.ChatMessageRoleTool, database.ChatMessageVisibilityBoth, input.modelConfigID, contentVersion, toolContent), + } + for i := range messages { + messages[i].Compressed = true + } + return compactionMessagesForCommit{Messages: messages, HiddenCount: 1}, nil +} + +func currentTurnStepCount(messages []database.ChatMessage) int { + latestUser := -1 + for i, msg := range messages { + if msg.Deleted || msg.Compressed { + continue + } + if msg.Role == database.ChatMessageRoleUser { + latestUser = i + } + } + count := 0 + for i := latestUser + 1; i < len(messages); i++ { + msg := messages[i] + if msg.Deleted || msg.Compressed { + continue + } + if msg.Role == database.ChatMessageRoleAssistant { + count++ + } + } + return count +} + +type compactionRequirement int + +const ( + compactionRequirementNotNeeded compactionRequirement = iota + compactionRequirementNeeded +) + +func compactionStatusFromHistory(messages []database.ChatMessage, requirement compactionRequirement) compactionStatus { + boundaryIndex := latestCompactionBoundaryIndex(messages) + if requirement == compactionRequirementNeeded { + if boundaryIndex == -1 { + return compactionStatusNeeded + } + if hasUncompressedAssistantAfter(messages, boundaryIndex) { + return compactionStatusStillOverLimit + } + return compactionStatusAfterCompaction + } + if boundaryIndex != -1 && !hasUncompressedAssistantAfter(messages, 
boundaryIndex) { + return compactionStatusAfterCompaction + } + return compactionStatusNotNeeded +} + +func latestCompactionBoundaryIndex(messages []database.ChatMessage) int { + for i := len(messages) - 1; i >= 0; i-- { + if isCompactionBoundaryMessage(messages[i]) { + return i + } + } + return -1 +} + +func isCompactionBoundaryMessage(msg database.ChatMessage) bool { + if msg.Deleted || !msg.Compressed { + return false + } + parts, err := chatprompt.ParseContent(msg) + if err != nil { + return false + } + for _, part := range parts { + if part.ToolName == "chat_summarized" && + (part.Type == codersdk.ChatMessagePartTypeToolCall || part.Type == codersdk.ChatMessagePartTypeToolResult) { + return true + } + } + return false +} + +func hasUncompressedAssistantAfter(messages []database.ChatMessage, index int) bool { + for i := index + 1; i < len(messages); i++ { + msg := messages[i] + if msg.Deleted || msg.Compressed { + continue + } + if msg.Role == database.ChatMessageRoleAssistant { + return true + } + } + return false +} + +func historyHasStopAfterToolResult(messages []database.ChatMessage, stopAfterTools map[string]struct{}) (bool, error) { + if len(stopAfterTools) == 0 { + return false, nil + } + start := 0 + for i, msg := range messages { + if msg.Deleted || msg.Compressed { + continue + } + if msg.Role == database.ChatMessageRoleUser { + start = i + 1 + } + } + for _, msg := range messages[start:] { + if msg.Deleted || msg.Compressed || msg.Role != database.ChatMessageRoleTool { + continue + } + parts, err := chatprompt.ParseContent(msg) + if err != nil { + return false, xerrors.Errorf("parse tool message: %w", err) + } + for _, part := range parts { + if part.Type != codersdk.ChatMessagePartTypeToolResult || part.IsError { + continue + } + if _, ok := stopAfterTools[part.ToolName]; ok { + return true, nil + } + } + } + return false, nil +} + +func currentHistoryComplete(messages []database.ChatMessage) (bool, error) { + idx := lastMessageIndex(messages, 
func(database.ChatMessage) bool { return true }) + if idx == -1 || messages[idx].Role != database.ChatMessageRoleAssistant { + return false, nil + } + parts, err := chatprompt.ParseContent(messages[idx]) + if err != nil { + return false, xerrors.Errorf("parse latest assistant message: %w", err) + } + for _, part := range parts { + if part.Type == codersdk.ChatMessagePartTypeToolCall && !part.ProviderExecuted { + return false, nil + } + } + return true, nil +} + +func lastMessageIndex(messages []database.ChatMessage, accept func(database.ChatMessage) bool) int { + for i := len(messages) - 1; i >= 0; i-- { + if messages[i].Deleted || messages[i].Compressed { + continue + } + if accept(messages[i]) { + return i + } + } + return -1 +} + +func handledToolCallIDs(messages []database.ChatMessage) (map[string]bool, error) { + handled := make(map[string]bool) + for _, msg := range messages { + if msg.Deleted || msg.Compressed || msg.Role != database.ChatMessageRoleTool { + continue + } + parts, err := chatprompt.ParseContent(msg) + if err != nil { + return nil, xerrors.Errorf("parse tool message: %w", err) + } + for _, part := range parts { + if part.Type == codersdk.ChatMessagePartTypeToolResult && part.ToolCallID != "" { + handled[part.ToolCallID] = true + } + } + } + return handled, nil +} + +type bufferedPartsToPartialMessagesInput struct { + parts []messagepartbuffer.Part + modelConfigID uuid.UUID + contentVersion int16 + logger slog.Logger + interruptedAt time.Time +} + +type partialToolCall struct { + part codersdk.ChatMessagePart + index int + argsDelta strings.Builder + valid bool + durable bool +} + +type partialToolResult struct { + part codersdk.ChatMessagePart + resultDelta strings.Builder + completed bool +} + +func bufferedPartsToPartialMessages(input bufferedPartsToPartialMessagesInput) ([]chatstate.Message, error) { + contentVersion := input.contentVersion + if contentVersion == 0 { + contentVersion = chatprompt.CurrentContentVersion + } + parts := 
slices.Clone(input.parts) + slices.SortFunc(parts, func(a, b messagepartbuffer.Part) int { + return cmp.Compare(a.Seq, b.Seq) + }) + + state := partialMessageConversionState{ + input: input, + contentVersion: contentVersion, + toolCalls: make(map[string]*partialToolCall), + toolResults: make(map[string]*partialToolResult), + answered: make(map[string]bool), + } + for _, buffered := range parts { + if err := state.consume(buffered); err != nil { + return nil, err + } + } + if err := state.finalizeToolCallPlaceholders(); err != nil { + return nil, err + } + if err := state.flushAssistant(); err != nil { + return nil, err + } + if err := state.flushAccumulatedToolResults(); err != nil { + return nil, err + } + if err := state.appendSyntheticInterruptionResults(); err != nil { + return nil, err + } + return state.messages, nil +} + +type partialMessageConversionState struct { + input bufferedPartsToPartialMessagesInput + contentVersion int16 + + messages []chatstate.Message + assistantParts []codersdk.ChatMessagePart + toolCalls map[string]*partialToolCall + toolCallOrder []string + toolResults map[string]*partialToolResult + toolResultOrder []string + answered map[string]bool +} + +func (s *partialMessageConversionState) consume(buffered messagepartbuffer.Part) error { + switch buffered.Role { + case codersdk.ChatMessageRoleAssistant: + s.consumeAssistantPart(buffered) + case codersdk.ChatMessageRoleTool: + return s.consumeToolPart(buffered) + default: + s.logSkippedPart(buffered, "unsupported buffered part role") + } + return nil +} + +func (s *partialMessageConversionState) consumeAssistantPart(buffered messagepartbuffer.Part) { + part := buffered.MessagePart + if part.Type == "" { + s.logSkippedPart(buffered, "empty buffered assistant part type") + return + } + if part.Type != codersdk.ChatMessagePartTypeToolCall { + if part.Type == codersdk.ChatMessagePartTypeReasoning && + !s.input.interruptedAt.IsZero() { + interruptedAt := s.input.interruptedAt + if 
part.CreatedAt == nil { + part.CreatedAt = &interruptedAt + } + if part.CompletedAt == nil { + part.CompletedAt = &interruptedAt + } + } + s.assistantParts = append(s.assistantParts, part) + return + } + if part.ToolCallID == "" { + s.logSkippedPart(buffered, "tool call part missing tool call ID") + return + } + call := s.toolCall(part.ToolCallID) + call.part.Type = codersdk.ChatMessagePartTypeToolCall + call.part.ToolCallID = part.ToolCallID + if part.ToolName != "" { + call.part.ToolName = part.ToolName + } + if part.MCPServerConfigID.Valid { + call.part.MCPServerConfigID = part.MCPServerConfigID + } + if part.CreatedAt != nil { + call.part.CreatedAt = part.CreatedAt + } + call.part.ProviderExecuted = call.part.ProviderExecuted || part.ProviderExecuted + + if part.ArgsDelta != "" { + if call.durable { + s.logSkippedPart(buffered, "tool call args delta arrived after full tool call") + return + } + _, _ = call.argsDelta.WriteString(part.ArgsDelta) + return + } + + durable := part + durable.ArgsDelta = "" + if len(durable.Args) > 0 && !json.Valid(durable.Args) { + call.valid = false + s.assistantParts[call.index] = codersdk.ChatMessagePart{} + s.logSkippedPart(buffered, "tool call part has invalid durable args") + return + } + if call.durable { + s.logSkippedPart(buffered, "duplicate durable tool call part") + } + call.part = durable + call.valid = true + call.durable = true + s.assistantParts[call.index] = durable +} + +func (s *partialMessageConversionState) consumeToolPart(buffered messagepartbuffer.Part) error { + part := buffered.MessagePart + if part.Type != codersdk.ChatMessagePartTypeToolResult { + s.logSkippedPart(buffered, "non tool-result part with tool role") + return nil + } + if part.ToolCallID == "" { + s.logSkippedPart(buffered, "tool result part missing tool call ID") + return nil + } + if part.ResultReset { + result := s.toolResult(part.ToolCallID) + result.part.ToolCallID = part.ToolCallID + result.part.ToolName = part.ToolName + 
result.resultDelta.Reset() + s.logSkippedPart(buffered, "streaming tool result reset is not durable") + return nil + } + if part.ResultDelta != "" { + result := s.toolResult(part.ToolCallID) + result.part.ToolCallID = part.ToolCallID + if part.ToolName != "" { + result.part.ToolName = part.ToolName + } + if part.MCPServerConfigID.Valid { + result.part.MCPServerConfigID = part.MCPServerConfigID + } + if part.CreatedAt != nil { + result.part.CreatedAt = part.CreatedAt + } + result.part.ProviderExecuted = result.part.ProviderExecuted || part.ProviderExecuted + _, _ = result.resultDelta.WriteString(part.ResultDelta) + return nil + } + if err := s.finalizeToolCallPlaceholders(); err != nil { + return err + } + if !s.toolCallDurable(part.ToolCallID) { + s.logSkippedPart(buffered, "tool result has no matching durable tool call") + return nil + } + if len(part.Result) == 0 || !json.Valid(part.Result) { + s.logSkippedPart(buffered, "tool result part has invalid durable result") + return nil + } + if s.answered[part.ToolCallID] { + s.logSkippedPart(buffered, "duplicate durable tool result part") + return nil + } + part.ResultDelta = "" + part.ResultReset = false + if err := s.flushAssistant(); err != nil { + return err + } + if err := s.appendToolResult(part); err != nil { + return err + } + s.answered[part.ToolCallID] = true + return nil +} + +func (s *partialMessageConversionState) toolCall(id string) *partialToolCall { + call := s.toolCalls[id] + if call != nil { + return call + } + call = &partialToolCall{index: len(s.assistantParts), valid: true} + s.toolCalls[id] = call + s.toolCallOrder = append(s.toolCallOrder, id) + s.assistantParts = append(s.assistantParts, codersdk.ChatMessagePart{}) + return call +} + +func (s *partialMessageConversionState) toolResult(id string) *partialToolResult { + result := s.toolResults[id] + if result != nil { + return result + } + result = &partialToolResult{} + s.toolResults[id] = result + s.toolResultOrder = append(s.toolResultOrder, 
id) + return result +} + +func (s *partialMessageConversionState) finalizeToolCallPlaceholders() error { + for _, id := range s.toolCallOrder { + call := s.toolCalls[id] + if call == nil || call.durable || !call.valid { + continue + } + args := json.RawMessage(call.argsDelta.String()) + if len(args) == 0 || !json.Valid(args) { + s.assistantParts[call.index] = codersdk.ChatMessagePart{} + call.valid = false + s.logSkippedPart(messagepartbuffer.Part{ + Role: codersdk.ChatMessageRoleAssistant, + MessagePart: call.part, + }, "tool call args delta did not form durable JSON") + continue + } + call.part.Args = args + call.part.ArgsDelta = "" + call.durable = true + s.assistantParts[call.index] = call.part + } + return nil +} + +func (s *partialMessageConversionState) flushAssistant() error { + if len(s.assistantParts) == 0 { + return nil + } + durable := make([]codersdk.ChatMessagePart, 0, len(s.assistantParts)) + for _, part := range s.assistantParts { + if part.Type == "" { + continue + } + part.ArgsDelta = "" + part.ResultDelta = "" + part.ResultReset = false + durable = append(durable, part) + } + s.assistantParts = nil + if len(durable) == 0 { + return nil + } + content, err := chatprompt.MarshalParts(durable) + if err != nil { + return xerrors.Errorf("marshal partial assistant: %w", err) + } + s.messages = append(s.messages, baseMessage(database.ChatMessageRoleAssistant, database.ChatMessageVisibilityBoth, s.input.modelConfigID, s.contentVersion, content)) + return nil +} + +func (s *partialMessageConversionState) flushAccumulatedToolResults() error { + for _, id := range s.toolResultOrder { + if s.answered[id] { + continue + } + result := s.toolResults[id] + if result == nil || result.completed { + continue + } + if result.resultDelta.Len() == 0 { + continue + } + s.logSkippedPart(messagepartbuffer.Part{Role: codersdk.ChatMessageRoleTool, MessagePart: result.part}, "streaming tool result delta is not durable") + } + return nil +} + +func (s 
*partialMessageConversionState) appendToolResult(part codersdk.ChatMessagePart) error { + content, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{part}) + if err != nil { + return xerrors.Errorf("marshal partial tool result: %w", err) + } + s.messages = append(s.messages, baseMessage(database.ChatMessageRoleTool, database.ChatMessageVisibilityBoth, s.input.modelConfigID, s.contentVersion, content)) + return nil +} + +func (s *partialMessageConversionState) appendSyntheticInterruptionResults() error { + for _, id := range s.toolCallOrder { + if s.answered[id] { + continue + } + call := s.toolCalls[id] + if call == nil || !call.valid || !call.durable || call.part.ProviderExecuted { + continue + } + result, err := json.Marshal(map[string]string{"error": interruptedToolResultErrorMessage}) + if err != nil { + return xerrors.Errorf("marshal synthetic interruption result: %w", err) + } + part := codersdk.ChatMessageToolResult(call.part.ToolCallID, call.part.ToolName, result, true, false) + part.MCPServerConfigID = call.part.MCPServerConfigID + if !s.input.interruptedAt.IsZero() { + part.CreatedAt = &s.input.interruptedAt + } + if err := s.appendToolResult(part); err != nil { + return xerrors.Errorf("marshal synthetic interruption message: %w", err) + } + s.answered[id] = true + } + return nil +} + +func (s *partialMessageConversionState) toolCallDurable(id string) bool { + call := s.toolCalls[id] + return call != nil && call.valid && call.durable +} + +func (s *partialMessageConversionState) logSkippedPart(buffered messagepartbuffer.Part, reason string) { + s.input.logger.Warn(context.Background(), "skipping buffered chat message part", + slog.F("reason", reason), + slog.F("role", buffered.Role), + slog.F("part_type", buffered.MessagePart.Type), + slog.F("tool_call_id", buffered.MessagePart.ToolCallID), + slog.F("tool_name", buffered.MessagePart.ToolName), + ) +} diff --git a/coderd/x/chatd/message_conversion_test.go b/coderd/x/chatd/message_conversion_test.go 
new file mode 100644 index 0000000000..5469901e10 --- /dev/null +++ b/coderd/x/chatd/message_conversion_test.go @@ -0,0 +1,532 @@ +package chatd //nolint:testpackage // Uses unexported chatworker helpers. + +import ( + "context" + "database/sql" + "encoding/json" + "sync" + "testing" + "time" + + "charm.land/fantasy" + "github.com/google/uuid" + "github.com/shopspring/decimal" + "github.com/sqlc-dev/pqtype" + "github.com/stretchr/testify/require" + + "cdr.dev/slog/v3" + "cdr.dev/slog/v3/sloggers/slogtest" + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/x/chatd/chatprompt" + "github.com/coder/coder/v2/coderd/x/chatd/messagepartbuffer" + "github.com/coder/coder/v2/codersdk" +) + +func TestBuildCommitStepMessages_AssistantTextAndReasoning(t *testing.T) { + t.Parallel() + + modelConfigID := uuid.New() + startedAt := time.Date(2026, 1, 2, 3, 4, 5, 0, time.UTC) + completedAt := startedAt.Add(2 * time.Second) + got, err := buildCommitStepMessages(buildCommitStepMessagesInput{ + modelConfigID: modelConfigID, + contentVersion: chatprompt.CurrentContentVersion, + logger: slog.Make(), + step: stepData{ + Content: []fantasy.Content{ + fantasy.ReasoningContent{Text: "thinking"}, + fantasy.TextContent{Text: "hello"}, + }, + ReasoningStartedAt: []time.Time{startedAt}, + ReasoningCompletedAt: []time.Time{completedAt}, + }, + }) + require.NoError(t, err) + require.Len(t, got.Messages, 1) + require.Equal(t, []int{0}, got.VisibleIndexes) + + msg := got.Messages[0] + require.Equal(t, database.ChatMessageRoleAssistant, msg.Role) + require.Equal(t, database.ChatMessageVisibilityBoth, msg.Visibility) + require.Equal(t, uuid.NullUUID{UUID: modelConfigID, Valid: true}, msg.ModelConfigID) + require.Equal(t, chatprompt.CurrentContentVersion, msg.ContentVersion) + parts := parseMessageParts(t, msg.Role, msg.Content) + require.Len(t, parts, 2) + require.Equal(t, codersdk.ChatMessagePartTypeReasoning, parts[0].Type) + require.Equal(t, "thinking", parts[0].Text) 
+ require.Equal(t, startedAt, requireNotNilTime(t, parts[0].CreatedAt)) + require.Equal(t, completedAt, requireNotNilTime(t, parts[0].CompletedAt)) + require.Equal(t, codersdk.ChatMessagePartTypeText, parts[1].Type) + require.Equal(t, "hello", parts[1].Text) +} + +func TestBuildCommitStepMessages_LocalToolResultsBecomeToolMessages(t *testing.T) { + t.Parallel() + + modelConfigID := uuid.New() + got, err := buildCommitStepMessages(buildCommitStepMessagesInput{ + modelConfigID: modelConfigID, + contentVersion: chatprompt.CurrentContentVersion, + logger: slog.Make(), + step: stepData{Content: []fantasy.Content{ + fantasy.ToolCallContent{ToolCallID: "call-1", ToolName: "execute", Input: `{"cmd":"pwd"}`}, + fantasy.ToolResultContent{ + ToolCallID: "call-1", + ToolName: "execute", + Result: fantasy.ToolResultOutputContentText{Text: `{"stdout":"/tmp"}`}, + }, + }}, + }) + require.NoError(t, err) + require.Len(t, got.Messages, 2) + require.Equal(t, []int{0, 1}, got.VisibleIndexes) + + assistantParts := parseMessageParts(t, got.Messages[0].Role, got.Messages[0].Content) + require.Len(t, assistantParts, 1) + require.Equal(t, codersdk.ChatMessagePartTypeToolCall, assistantParts[0].Type) + require.Equal(t, "call-1", assistantParts[0].ToolCallID) + require.Equal(t, "execute", assistantParts[0].ToolName) + + toolParts := parseMessageParts(t, got.Messages[1].Role, got.Messages[1].Content) + require.Len(t, toolParts, 1) + require.Equal(t, codersdk.ChatMessagePartTypeToolResult, toolParts[0].Type) + require.Equal(t, "call-1", toolParts[0].ToolCallID) + require.Equal(t, "execute", toolParts[0].ToolName) + require.JSONEq(t, `{"stdout":"/tmp"}`, string(toolParts[0].Result)) +} + +func TestBuildCommitStepMessages_ProviderExecutedResultsStayAssistantContent(t *testing.T) { + t.Parallel() + + got, err := buildCommitStepMessages(buildCommitStepMessagesInput{ + modelConfigID: uuid.New(), + contentVersion: chatprompt.CurrentContentVersion, + logger: slog.Make(), + step: stepData{Content: 
[]fantasy.Content{ + fantasy.ToolCallContent{ + ToolCallID: "web-1", + ToolName: "web_search", + ProviderExecuted: true, + }, + fantasy.ToolResultContent{ + ToolCallID: "web-1", + ToolName: "web_search", + ProviderExecuted: true, + Result: fantasy.ToolResultOutputContentText{Text: `{"ok":true}`}, + }, + }}, + }) + require.NoError(t, err) + require.Len(t, got.Messages, 1) + parts := parseMessageParts(t, got.Messages[0].Role, got.Messages[0].Content) + require.Len(t, parts, 2) + require.Equal(t, codersdk.ChatMessagePartTypeToolCall, parts[0].Type) + require.True(t, parts[0].ProviderExecuted) + require.Equal(t, codersdk.ChatMessagePartTypeToolResult, parts[1].Type) + require.True(t, parts[1].ProviderExecuted) +} + +func TestBuildCommitStepMessages_UsageCostRuntimeProviderResponseID(t *testing.T) { + t.Parallel() + + inputPrice := decimal.NewFromFloat(2.5) + outputPrice := decimal.NewFromFloat(7.5) + got, err := buildCommitStepMessages(buildCommitStepMessagesInput{ + modelConfigID: uuid.New(), + contentVersion: chatprompt.CurrentContentVersion, + logger: slog.Make(), + modelCallConfig: codersdk.ChatModelCallConfig{ + Cost: &codersdk.ModelCostConfig{ + InputPricePerMillionTokens: &inputPrice, + OutputPricePerMillionTokens: &outputPrice, + }, + }, + step: stepData{ + Content: []fantasy.Content{fantasy.TextContent{Text: "usage"}}, + Usage: fantasy.Usage{InputTokens: 100, OutputTokens: 20, TotalTokens: 120, ReasoningTokens: 3, CacheCreationTokens: 4, CacheReadTokens: 5}, + ContextLimit: sql.NullInt64{Int64: 4096, Valid: true}, + ProviderResponseID: "resp-123", + Runtime: 1500 * time.Millisecond, + }, + }) + require.NoError(t, err) + require.Len(t, got.Messages, 1) + msg := got.Messages[0] + require.Equal(t, sql.NullInt64{Int64: 100, Valid: true}, msg.InputTokens) + require.Equal(t, sql.NullInt64{Int64: 20, Valid: true}, msg.OutputTokens) + require.Equal(t, sql.NullInt64{Int64: 120, Valid: true}, msg.TotalTokens) + require.Equal(t, sql.NullInt64{Int64: 3, Valid: true}, 
msg.ReasoningTokens) + require.Equal(t, sql.NullInt64{Int64: 4, Valid: true}, msg.CacheCreationTokens) + require.Equal(t, sql.NullInt64{Int64: 5, Valid: true}, msg.CacheReadTokens) + require.Equal(t, sql.NullInt64{Int64: 4096, Valid: true}, msg.ContextLimit) + require.Equal(t, sql.NullInt64{Int64: 1500, Valid: true}, msg.RuntimeMs) + require.Equal(t, sql.NullString{String: "resp-123", Valid: true}, msg.ProviderResponseID) + require.True(t, msg.TotalCostMicros.Valid) + require.Greater(t, msg.TotalCostMicros.Int64, int64(0)) +} + +func TestBuildCommitStepMessages_ToolTimestampsAndMCPConfigIDs(t *testing.T) { + t.Parallel() + + callAt := time.Date(2026, 2, 3, 4, 5, 6, 0, time.UTC) + resultAt := callAt.Add(3 * time.Second) + configID := uuid.New() + got, err := buildCommitStepMessages(buildCommitStepMessagesInput{ + modelConfigID: uuid.New(), + contentVersion: chatprompt.CurrentContentVersion, + logger: slog.Make(), + toolNameToConfigID: map[string]uuid.UUID{ + "mcp_tool": configID, + }, + step: stepData{Content: []fantasy.Content{ + fantasy.ToolCallContent{ToolCallID: "call-1", ToolName: "mcp_tool", Input: `{}`}, + fantasy.ToolResultContent{ToolCallID: "call-1", ToolName: "mcp_tool", Result: fantasy.ToolResultOutputContentText{Text: `{"ok":true}`}}, + }, ToolCallCreatedAt: map[string]time.Time{ + "call-1": callAt, + }, ToolResultCreatedAt: map[string]time.Time{ + "call-1": resultAt, + }}, + }) + require.NoError(t, err) + require.Len(t, got.Messages, 2) + callPart := parseMessageParts(t, got.Messages[0].Role, got.Messages[0].Content)[0] + resultPart := parseMessageParts(t, got.Messages[1].Role, got.Messages[1].Content)[0] + require.Equal(t, uuid.NullUUID{UUID: configID, Valid: true}, callPart.MCPServerConfigID) + require.Equal(t, callAt, requireNotNilTime(t, callPart.CreatedAt)) + require.Equal(t, uuid.NullUUID{UUID: configID, Valid: true}, resultPart.MCPServerConfigID) + require.Equal(t, resultAt, requireNotNilTime(t, resultPart.CreatedAt)) +} + +func 
TestBuildCompactionMessages_CompressedSummaryToolCallAndResult(t *testing.T) { + t.Parallel() + + modelConfigID := uuid.New() + got, err := buildCompactionMessages(buildCompactionMessagesInput{ + modelConfigID: modelConfigID, + contentVersion: chatprompt.CurrentContentVersion, + toolCallID: "summary-1", + toolName: "chat_summarized", + compaction: compactionOutcome{ + SystemSummary: "system summary", + SummaryReport: "user report", + ThresholdPercent: 70, + UsagePercent: 81.5, + ContextTokens: 815, + ContextLimit: 1000, + }, + }) + require.NoError(t, err) + require.Equal(t, 1, got.HiddenCount) + require.Len(t, got.Messages, 3) + + require.Equal(t, database.ChatMessageRoleUser, got.Messages[0].Role) + require.Equal(t, database.ChatMessageVisibilityModel, got.Messages[0].Visibility) + require.True(t, got.Messages[0].Compressed) + require.Equal(t, uuid.NullUUID{UUID: modelConfigID, Valid: true}, got.Messages[0].ModelConfigID) + require.Equal(t, "system summary", parseMessageParts(t, got.Messages[0].Role, got.Messages[0].Content)[0].Text) + + require.Equal(t, database.ChatMessageRoleAssistant, got.Messages[1].Role) + require.Equal(t, database.ChatMessageVisibilityUser, got.Messages[1].Visibility) + require.True(t, got.Messages[1].Compressed) + callPart := parseMessageParts(t, got.Messages[1].Role, got.Messages[1].Content)[0] + require.Equal(t, codersdk.ChatMessagePartTypeToolCall, callPart.Type) + require.Equal(t, "summary-1", callPart.ToolCallID) + require.JSONEq(t, `{"source":"automatic","threshold_percent":70}`, string(callPart.Args)) + + require.Equal(t, database.ChatMessageRoleTool, got.Messages[2].Role) + require.Equal(t, database.ChatMessageVisibilityBoth, got.Messages[2].Visibility) + require.True(t, got.Messages[2].Compressed) + resultPart := parseMessageParts(t, got.Messages[2].Role, got.Messages[2].Content)[0] + require.Equal(t, codersdk.ChatMessagePartTypeToolResult, resultPart.Type) + require.Equal(t, "summary-1", resultPart.ToolCallID) + require.JSONEq(t, 
`{"summary":"user report","source":"automatic","threshold_percent":70,"usage_percent":81.5,"context_tokens":815,"context_limit_tokens":1000}`, string(resultPart.Result)) +} + +func TestCurrentTurnStepCount_ExcludesCompressedCompactionMessages(t *testing.T) { + t.Parallel() + + messages := []database.ChatMessage{ + dbMessage(t, 1, database.ChatMessageRoleUser, false, codersdk.ChatMessageText("start")), + dbMessage(t, 2, database.ChatMessageRoleAssistant, false, codersdk.ChatMessageText("first")), + dbMessage(t, 3, database.ChatMessageRoleUser, true, codersdk.ChatMessageText("compressed summary")), + dbMessage(t, 4, database.ChatMessageRoleAssistant, true, codersdk.ChatMessageToolCall("summary", "chat_summarized", nil)), + dbMessage(t, 5, database.ChatMessageRoleTool, true, codersdk.ChatMessageToolResult("summary", "chat_summarized", json.RawMessage(`{}`), false, false)), + dbMessage(t, 6, database.ChatMessageRoleAssistant, false, codersdk.ChatMessageText("second")), + } + got := currentTurnStepCount(messages) + require.Equal(t, 2, got) +} + +func TestCurrentTurnStepCount_CountsAssistantMessagesAfterLatestUser(t *testing.T) { + t.Parallel() + + messages := []database.ChatMessage{ + dbMessage(t, 1, database.ChatMessageRoleUser, false, codersdk.ChatMessageText("old")), + dbMessage(t, 2, database.ChatMessageRoleAssistant, false, codersdk.ChatMessageText("old answer")), + dbMessage(t, 3, database.ChatMessageRoleUser, false, codersdk.ChatMessageText("new")), + dbMessage(t, 4, database.ChatMessageRoleAssistant, false, codersdk.ChatMessageText("one")), + dbMessage(t, 5, database.ChatMessageRoleTool, false, codersdk.ChatMessageToolResult("call", "tool", json.RawMessage(`{}`), false, false)), + dbMessage(t, 6, database.ChatMessageRoleAssistant, false, codersdk.ChatMessageText("two")), + } + got := currentTurnStepCount(messages) + require.Equal(t, 2, got) +} + +func TestDecisionDetectsStopAfterToolFromCommittedHistory(t *testing.T) { + t.Parallel() + + messages := 
[]database.ChatMessage{ + dbMessage(t, 1, database.ChatMessageRoleUser, false, codersdk.ChatMessageText("plan")), + dbMessage(t, 2, database.ChatMessageRoleAssistant, false, codersdk.ChatMessageToolCall("plan-1", "propose_plan", json.RawMessage(`{}`))), + dbMessage(t, 3, database.ChatMessageRoleTool, false, codersdk.ChatMessageToolResult("plan-1", "propose_plan", json.RawMessage(`{"ok":true}`), false, false)), + } + got, err := historyHasStopAfterToolResult(messages, map[string]struct{}{"propose_plan": {}}) + require.NoError(t, err) + require.True(t, got) + + messages[2] = dbMessage(t, 3, database.ChatMessageRoleTool, false, codersdk.ChatMessageToolResult("plan-1", "propose_plan", json.RawMessage(`{"error":"no"}`), true, false)) + got, err = historyHasStopAfterToolResult(messages, map[string]struct{}{"propose_plan": {}}) + require.NoError(t, err) + require.False(t, got) +} + +func TestDecisionDetectsCurrentHistoryCompletion(t *testing.T) { + t.Parallel() + + complete, err := currentHistoryComplete([]database.ChatMessage{ + dbMessage(t, 1, database.ChatMessageRoleUser, false, codersdk.ChatMessageText("hello")), + dbMessage(t, 2, database.ChatMessageRoleAssistant, false, codersdk.ChatMessageText("done")), + }) + require.NoError(t, err) + require.True(t, complete) + + complete, err = currentHistoryComplete([]database.ChatMessage{ + dbMessage(t, 1, database.ChatMessageRoleUser, false, codersdk.ChatMessageText("hello")), + dbMessage(t, 2, database.ChatMessageRoleAssistant, false, codersdk.ChatMessageToolCall("call-1", "execute", json.RawMessage(`{}`))), + }) + require.NoError(t, err) + require.False(t, complete) + + complete, err = currentHistoryComplete([]database.ChatMessage{ + dbMessage(t, 1, database.ChatMessageRoleUser, false, codersdk.ChatMessageText("hello")), + dbMessage(t, 2, database.ChatMessageRoleAssistant, false, codersdk.ChatMessageToolCall("call-1", "execute", json.RawMessage(`{}`))), + dbMessage(t, 3, database.ChatMessageRoleTool, false, 
codersdk.ChatMessageToolResult("call-1", "execute", json.RawMessage(`{"ok":true}`), false, false)), + }) + require.NoError(t, err) + require.False(t, complete) +} + +func TestBufferedPartsToPartialMessages_NormalizesToolCallDeltasBeforeFinal(t *testing.T) { + t.Parallel() + + createdAt := time.Date(2026, 3, 4, 5, 6, 7, 0, time.UTC) + parts := []messagepartbuffer.Part{ + {Seq: 1, Role: codersdk.ChatMessageRoleAssistant, MessagePart: codersdk.ChatMessageText("partial ")}, + {Seq: 2, Role: codersdk.ChatMessageRoleAssistant, MessagePart: codersdk.ChatMessagePart{Type: codersdk.ChatMessagePartTypeToolCall, ToolCallID: "call-1", ToolName: "execute", ArgsDelta: `{"cmd":`}}, + {Seq: 3, Role: codersdk.ChatMessageRoleAssistant, MessagePart: codersdk.ChatMessagePart{Type: codersdk.ChatMessagePartTypeToolCall, ToolCallID: "call-1", ToolName: "execute", ArgsDelta: `"ignored"}`}}, + {Seq: 4, Role: codersdk.ChatMessageRoleAssistant, MessagePart: codersdk.ChatMessageToolCall("call-1", "execute", json.RawMessage(`{"cmd":"pwd"}`))}, + } + got, err := bufferedPartsToPartialMessages(bufferedPartsToPartialMessagesInput{ + parts: parts, + modelConfigID: uuid.New(), + contentVersion: chatprompt.CurrentContentVersion, + logger: slog.Make(), + interruptedAt: createdAt, + }) + require.NoError(t, err) + require.Len(t, got, 2) + assistantParts := parseMessageParts(t, got[0].Role, got[0].Content) + require.Len(t, assistantParts, 2) + require.Equal(t, codersdk.ChatMessagePartTypeText, assistantParts[0].Type) + call := assistantParts[1] + require.Equal(t, codersdk.ChatMessagePartTypeToolCall, call.Type) + require.Equal(t, "call-1", call.ToolCallID) + require.Empty(t, call.ArgsDelta) + require.JSONEq(t, `{"cmd":"pwd"}`, string(call.Args)) + syntheticParts := parseMessageParts(t, got[1].Role, got[1].Content) + require.Len(t, syntheticParts, 1) + require.Equal(t, "call-1", syntheticParts[0].ToolCallID) +} + +func TestBufferedPartsToPartialMessages_MergesToolCallDeltasWithoutFinal(t *testing.T) { + 
t.Parallel() + + parts := []messagepartbuffer.Part{ + {Seq: 1, Role: codersdk.ChatMessageRoleAssistant, MessagePart: codersdk.ChatMessagePart{Type: codersdk.ChatMessagePartTypeToolCall, ToolCallID: "call-1", ToolName: "execute", ArgsDelta: `{"cmd":`}}, + {Seq: 2, Role: codersdk.ChatMessageRoleAssistant, MessagePart: codersdk.ChatMessagePart{Type: codersdk.ChatMessagePartTypeToolCall, ToolCallID: "call-1", ToolName: "execute", ArgsDelta: `"pwd"}`}}, + } + got, err := bufferedPartsToPartialMessages(bufferedPartsToPartialMessagesInput{ + parts: parts, + modelConfigID: uuid.New(), + contentVersion: chatprompt.CurrentContentVersion, + logger: slog.Make(), + }) + require.NoError(t, err) + require.Len(t, got, 2) + assistantParts := parseMessageParts(t, got[0].Role, got[0].Content) + require.Len(t, assistantParts, 1) + require.Empty(t, assistantParts[0].ArgsDelta) + require.JSONEq(t, `{"cmd":"pwd"}`, string(assistantParts[0].Args)) + syntheticParts := parseMessageParts(t, got[1].Role, got[1].Content) + require.Len(t, syntheticParts, 1) + require.Equal(t, "call-1", syntheticParts[0].ToolCallID) +} + +func TestBufferedPartsToPartialMessages_DeltaOnlyToolResultDoesNotAnswer(t *testing.T) { + t.Parallel() + + logSink := &partialConversionLogSink{} + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).AppendSinks(logSink) + parts := []messagepartbuffer.Part{ + {Seq: 1, Role: codersdk.ChatMessageRoleAssistant, MessagePart: codersdk.ChatMessageToolCall("call-1", "advisor", json.RawMessage(`{}`))}, + {Seq: 2, Role: codersdk.ChatMessageRoleTool, MessagePart: codersdk.ChatMessagePart{Type: codersdk.ChatMessagePartTypeToolResult, ToolCallID: "call-1", ToolName: "advisor", ResultDelta: `{"type":"advice"}`}}, + } + got, err := bufferedPartsToPartialMessages(bufferedPartsToPartialMessagesInput{ + parts: parts, + modelConfigID: uuid.New(), + contentVersion: chatprompt.CurrentContentVersion, + logger: logger, + }) + require.NoError(t, err) + require.Len(t, got, 2) + 
toolParts := parseMessageParts(t, got[1].Role, got[1].Content) + require.Len(t, toolParts, 1) + require.Equal(t, "call-1", toolParts[0].ToolCallID) + require.True(t, toolParts[0].IsError) + require.Empty(t, toolParts[0].ResultDelta) + require.JSONEq(t, `{"error":"tool call was interrupted before it produced a result"}`, string(toolParts[0].Result)) + require.NotEmpty(t, logSink.entriesAtLevelWithMessage(slog.LevelWarn, "skipping buffered chat message part")) +} + +func TestBufferedPartsToPartialMessages_LogsMalformedSkippedParts(t *testing.T) { + t.Parallel() + + logSink := &partialConversionLogSink{} + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).AppendSinks(logSink) + parts := []messagepartbuffer.Part{ + {Seq: 1, Role: codersdk.ChatMessageRoleSystem, MessagePart: codersdk.ChatMessageText("bad role")}, + {Seq: 2, Role: codersdk.ChatMessageRoleAssistant, MessagePart: codersdk.ChatMessagePart{}}, + {Seq: 3, Role: codersdk.ChatMessageRoleTool, MessagePart: codersdk.ChatMessagePart{Type: codersdk.ChatMessagePartTypeToolResult, ToolName: "execute", Result: json.RawMessage(`{"ok":true}`)}}, + {Seq: 4, Role: codersdk.ChatMessageRoleAssistant, MessagePart: codersdk.ChatMessagePart{Type: codersdk.ChatMessagePartTypeToolCall, ToolCallID: "bad-args", ToolName: "execute", ArgsDelta: `{"cmd":`}}, + } + got, err := bufferedPartsToPartialMessages(bufferedPartsToPartialMessagesInput{ + parts: parts, + modelConfigID: uuid.New(), + contentVersion: chatprompt.CurrentContentVersion, + logger: logger, + }) + require.NoError(t, err) + require.Empty(t, got) + require.GreaterOrEqual(t, len(logSink.entriesAtLevelWithMessage(slog.LevelWarn, "skipping buffered chat message part")), 4) +} + +func TestBufferedPartsToPartialMessages_SynthesizesMissingToolResults(t *testing.T) { + t.Parallel() + + modelConfigID := uuid.New() + createdAt := time.Date(2026, 3, 4, 5, 6, 7, 0, time.UTC) + reasoningStartedAt := createdAt.Add(-2 * time.Second) + reasoningPart := 
codersdk.ChatMessageReasoning("partial thought") + reasoningPart.CreatedAt = &reasoningStartedAt + parts := []messagepartbuffer.Part{ + {Seq: 1, Role: codersdk.ChatMessageRoleAssistant, MessagePart: codersdk.ChatMessageText("partial ")}, + {Seq: 2, Role: codersdk.ChatMessageRoleAssistant, MessagePart: reasoningPart}, + {Seq: 3, Role: codersdk.ChatMessageRoleAssistant, MessagePart: codersdk.ChatMessageToolCall("call-1", "execute", json.RawMessage(`{}`))}, + {Seq: 4, Role: codersdk.ChatMessageRoleAssistant, MessagePart: codersdk.ChatMessageToolCall("call-2", "read_file", json.RawMessage(`{}`))}, + {Seq: 5, Role: codersdk.ChatMessageRoleTool, MessagePart: withCreatedAt(codersdk.ChatMessageToolResult("call-2", "read_file", json.RawMessage(`{"ok":true}`), false, false), createdAt)}, + } + got, err := bufferedPartsToPartialMessages(bufferedPartsToPartialMessagesInput{ + parts: parts, + modelConfigID: modelConfigID, + contentVersion: chatprompt.CurrentContentVersion, + logger: slog.Make(), + interruptedAt: createdAt, + }) + require.NoError(t, err) + require.Len(t, got, 3) + require.Equal(t, database.ChatMessageRoleAssistant, got[0].Role) + assistantParts := parseMessageParts(t, got[0].Role, got[0].Content) + require.Len(t, assistantParts, 4) + require.Equal(t, codersdk.ChatMessagePartTypeReasoning, assistantParts[1].Type) + require.Equal(t, "partial thought", assistantParts[1].Text) + require.Equal(t, reasoningStartedAt, requireNotNilTime(t, assistantParts[1].CreatedAt)) + require.Equal(t, createdAt, requireNotNilTime(t, assistantParts[1].CompletedAt)) + require.Equal(t, codersdk.ChatMessagePartTypeToolCall, assistantParts[2].Type) + require.Equal(t, codersdk.ChatMessagePartTypeToolCall, assistantParts[3].Type) + + require.Equal(t, database.ChatMessageRoleTool, got[1].Role) + toolParts := parseMessageParts(t, got[1].Role, got[1].Content) + require.Equal(t, "call-2", toolParts[0].ToolCallID) + require.Equal(t, createdAt, requireNotNilTime(t, toolParts[0].CreatedAt)) + + 
require.Equal(t, database.ChatMessageRoleTool, got[2].Role) + syntheticParts := parseMessageParts(t, got[2].Role, got[2].Content) + require.Len(t, syntheticParts, 1) + require.Equal(t, "call-1", syntheticParts[0].ToolCallID) + require.Equal(t, "execute", syntheticParts[0].ToolName) + require.True(t, syntheticParts[0].IsError) + require.JSONEq(t, `{"error":"tool call was interrupted before it produced a result"}`, string(syntheticParts[0].Result)) + require.Equal(t, createdAt, requireNotNilTime(t, syntheticParts[0].CreatedAt)) + require.Equal(t, uuid.NullUUID{UUID: modelConfigID, Valid: true}, got[2].ModelConfigID) +} + +func parseMessageParts(t *testing.T, role database.ChatMessageRole, raw pqtype.NullRawMessage) []codersdk.ChatMessagePart { + t.Helper() + parts, err := chatprompt.ParseContent(database.ChatMessage{ + Role: role, + Content: raw, + }) + require.NoError(t, err) + return parts +} + +func dbMessage(t *testing.T, id int64, role database.ChatMessageRole, compressed bool, parts ...codersdk.ChatMessagePart) database.ChatMessage { + t.Helper() + raw, err := chatprompt.MarshalParts(parts) + require.NoError(t, err) + return database.ChatMessage{ + ID: id, + Role: role, + Content: raw, + ContentVersion: chatprompt.CurrentContentVersion, + Visibility: database.ChatMessageVisibilityBoth, + Compressed: compressed, + } +} + +func requireNotNilTime(t *testing.T, value *time.Time) time.Time { + t.Helper() + require.NotNil(t, value) + return *value +} + +func withCreatedAt(part codersdk.ChatMessagePart, createdAt time.Time) codersdk.ChatMessagePart { + part.CreatedAt = &createdAt + return part +} + +type partialConversionLogSink struct { + mu sync.Mutex + entries []slog.SinkEntry +} + +func (s *partialConversionLogSink) LogEntry(_ context.Context, entry slog.SinkEntry) { + s.mu.Lock() + defer s.mu.Unlock() + s.entries = append(s.entries, entry) +} + +func (*partialConversionLogSink) Sync() {} + +func (s *partialConversionLogSink) entriesAtLevelWithMessage(level 
slog.Level, message string) []slog.SinkEntry { + s.mu.Lock() + defer s.mu.Unlock() + + entries := make([]slog.SinkEntry, 0, len(s.entries)) + for _, entry := range s.entries { + if entry.Level == level && entry.Message == message { + entries = append(entries, entry) + } + } + return entries +} diff --git a/coderd/x/chatd/messagepartbuffer/message_part_buffer.go b/coderd/x/chatd/messagepartbuffer/message_part_buffer.go new file mode 100644 index 0000000000..0176ebe66b --- /dev/null +++ b/coderd/x/chatd/messagepartbuffer/message_part_buffer.go @@ -0,0 +1,492 @@ +package messagepartbuffer + +import ( + "container/heap" + "context" + "encoding/json" + "sync" + "time" + + "github.com/google/uuid" + "golang.org/x/xerrors" + + "github.com/coder/coder/v2/codersdk" + "github.com/coder/quartz" +) + +const ( + defaultMaxEpisodeBytes = int64(1024 * 1024) + defaultClosedEpisodeRetention = 15 * time.Second + defaultSubscriberSendTimeout = 10 * time.Second + defaultSubscriberChannelSize = 16 +) + +var ( + // ErrEpisodeExists means the episode already exists. + ErrEpisodeExists = xerrors.New("message part episode already exists") + // ErrEpisodeNotFound means the episode has not been created. + ErrEpisodeNotFound = xerrors.New("message part episode not found") + // ErrEpisodeClosed means the episode no longer accepts parts. + ErrEpisodeClosed = xerrors.New("message part episode closed") + // ErrEpisodeFull means the episode byte limit would be exceeded. + ErrEpisodeFull = xerrors.New("message part episode full") + // ErrMessagePartBufferClosed means the whole buffer is closed. + ErrMessagePartBufferClosed = xerrors.New("message part buffer closed") +) + +// Key identifies a buffered message part episode. +type Key struct { + ChatID uuid.UUID + HistoryVersion int64 + GenerationAttempt int64 +} + +// Part is a buffered chat message part with its sequence number. 
+type Part struct { + Seq int64 + Role codersdk.ChatMessageRole + MessagePart codersdk.ChatMessagePart +} + +type partJSON struct { + Seq int64 `json:"seq"` + Role codersdk.ChatMessageRole `json:"role"` + Part codersdk.ChatMessagePart `json:"part"` +} + +func (p Part) jsonValue() partJSON { + return partJSON{ + Seq: p.Seq, + Role: p.Role, + Part: p.MessagePart, + } +} + +// Options configures a Buffer. +type Options struct { + MaxEpisodeBytes int64 + ClosedEpisodeRetention time.Duration + SubscriberSendTimeout time.Duration + SubscriberChannelSize int + Clock quartz.Clock +} + +// Buffer stores streamed message parts by episode. +type Buffer struct { + mu sync.Mutex + opts Options + episodes map[Key]*episodeState + closedEpisodes closedEpisodeHeap + closed bool + done chan struct{} +} + +type episodeState struct { + created bool + createdCh chan struct{} + closed bool + closedAt time.Time + closedHeapItem *closedEpisodeItem + parts []Part + bytes int64 + subscribers map[*episodeSubscriber]struct{} +} + +type closedEpisodeItem struct { + key Key + closedAt time.Time +} + +type closedEpisodeHeap []*closedEpisodeItem + +func (h closedEpisodeHeap) Len() int { + return len(h) +} + +func (h closedEpisodeHeap) Less(i, j int) bool { + return h[i].closedAt.Before(h[j].closedAt) +} + +func (h closedEpisodeHeap) Swap(i, j int) { + h[i], h[j] = h[j], h[i] +} + +func (h *closedEpisodeHeap) Push(value any) { + item, ok := value.(*closedEpisodeItem) + if !ok { + panic("closed episode heap received invalid item") + } + *h = append(*h, item) +} + +func (h *closedEpisodeHeap) Pop() any { + old := *h + last := old[len(old)-1] + old[len(old)-1] = nil + *h = old[:len(old)-1] + return last +} + +type episodeSubscriber struct { + out chan Part + notifyCh chan struct{} + stopCh chan struct{} + next int + stopOnce sync.Once +} + +// New returns a message part buffer. 
+func New(options Options) *Buffer { + if options.MaxEpisodeBytes <= 0 { + options.MaxEpisodeBytes = defaultMaxEpisodeBytes + } + if options.ClosedEpisodeRetention <= 0 { + options.ClosedEpisodeRetention = defaultClosedEpisodeRetention + } + if options.SubscriberSendTimeout <= 0 { + options.SubscriberSendTimeout = defaultSubscriberSendTimeout + } + if options.SubscriberChannelSize < 0 { + options.SubscriberChannelSize = 0 + } + if options.SubscriberChannelSize == 0 { + options.SubscriberChannelSize = defaultSubscriberChannelSize + } + if options.Clock == nil { + options.Clock = quartz.NewReal() + } + buffer := &Buffer{ + opts: options, + episodes: make(map[Key]*episodeState), + done: make(chan struct{}), + } + buffer.startCleanupLoop() + return buffer +} + +// CreateEpisode creates a new episode. +func (b *Buffer) CreateEpisode(key Key) error { + b.mu.Lock() + defer b.mu.Unlock() + if b.closed { + return ErrMessagePartBufferClosed + } + b.gcClosedEpisodesLocked(b.opts.Clock.Now("message-part-buffer", "create")) + episode := b.episodeLocked(key) + if episode.created { + return ErrEpisodeExists + } + episode.created = true + close(episode.createdCh) + return nil +} + +// AddPart appends a part to an existing episode. 
+func (b *Buffer) AddPart(key Key, role codersdk.ChatMessageRole, part codersdk.ChatMessagePart) error { + b.mu.Lock() + defer b.mu.Unlock() + if b.closed { + return ErrMessagePartBufferClosed + } + episode := b.episodes[key] + if episode == nil || !episode.created { + return ErrEpisodeNotFound + } + if episode.closed { + return ErrEpisodeClosed + } + buffered := Part{ + Seq: int64(len(episode.parts) + 1), + Role: role, + MessagePart: part, + } + sizeBytes, err := serializedPartBytes(buffered) + if err != nil { + return err + } + if episode.bytes+sizeBytes > b.opts.MaxEpisodeBytes { + return ErrEpisodeFull + } + episode.parts = append(episode.parts, buffered) + episode.bytes += sizeBytes + for subscriber := range episode.subscribers { + notifySubscriber(subscriber) + } + return nil +} + +// CloseEpisode marks an episode closed and closes its subscribers. +func (b *Buffer) CloseEpisode(key Key) error { + b.mu.Lock() + defer b.mu.Unlock() + if b.closed { + return ErrMessagePartBufferClosed + } + episode := b.episodeLocked(key) + if !episode.created { + episode.created = true + close(episode.createdCh) + } + if episode.closed { + return nil + } + episode.closed = true + episode.closedAt = b.opts.Clock.Now("message-part-buffer", "close") + b.queueClosedEpisodeLocked(key, episode) + for subscriber := range episode.subscribers { + notifySubscriber(subscriber) + } + return nil +} + +// GetParts returns a snapshot of buffered parts for an episode. +func (b *Buffer) GetParts(key Key) ([]Part, error) { + b.mu.Lock() + defer b.mu.Unlock() + if b.closed { + return nil, ErrMessagePartBufferClosed + } + b.gcClosedEpisodesLocked(b.opts.Clock.Now("message-part-buffer", "get")) + episode := b.episodes[key] + if episode == nil || !episode.created { + return nil, ErrEpisodeNotFound + } + return append([]Part(nil), episode.parts...), nil +} + +// SubscribeToEpisode replays existing parts and streams new parts. 
+func (b *Buffer) SubscribeToEpisode(ctx context.Context, key Key) (<-chan Part, func(), error) { + b.mu.Lock() + if b.closed { + b.mu.Unlock() + return nil, nil, ErrMessagePartBufferClosed + } + episode := b.episodeLocked(key) + subscriber := &episodeSubscriber{ + out: make(chan Part), + notifyCh: make(chan struct{}, 1), + stopCh: make(chan struct{}), + } + if episode.subscribers == nil { + episode.subscribers = make(map[*episodeSubscriber]struct{}) + } + episode.subscribers[subscriber] = struct{}{} + notifySubscriber(subscriber) + b.mu.Unlock() + + go b.deliverSubscriber(ctx, key, subscriber) + cancel := func() { + b.cancelSubscriber(key, subscriber) + } + return subscriber.out, cancel, nil +} + +// Close closes the buffer and all pending subscriptions. +func (b *Buffer) Close() { + b.mu.Lock() + if b.closed { + b.mu.Unlock() + return + } + b.closed = true + close(b.done) + for _, episode := range b.episodes { + for subscriber := range episode.subscribers { + b.stopSubscriberLocked(episode, subscriber) + } + if !episode.created { + episode.created = true + close(episode.createdCh) + } + } + b.mu.Unlock() +} + +func (b *Buffer) startCleanupLoop() { + ticker := b.opts.Clock.NewTicker(b.opts.ClosedEpisodeRetention, "message-part-buffer", "cleanup") + go func() { + defer ticker.Stop() + for { + select { + case <-ticker.C: + b.mu.Lock() + if b.closed { + b.mu.Unlock() + return + } + b.gcClosedEpisodesLocked(b.opts.Clock.Now("message-part-buffer", "cleanup")) + b.mu.Unlock() + case <-b.done: + return + } + } + }() +} + +func (b *Buffer) gcClosedEpisodesLocked(now time.Time) { + cutoff := now.Add(-b.opts.ClosedEpisodeRetention) + type retainedEpisode struct { + key Key + episode *episodeState + } + retained := make([]retainedEpisode, 0) + for b.closedEpisodes.Len() > 0 { + item := b.closedEpisodes[0] + if item.closedAt.After(cutoff) { + break + } + popped, ok := heap.Pop(&b.closedEpisodes).(*closedEpisodeItem) + if !ok || popped != item { + continue + } + episode := 
b.episodes[item.key] + if episode == nil || episode.closedHeapItem != item || !episode.closed { + continue + } + episode.closedHeapItem = nil + if len(episode.subscribers) > 0 { + retained = append(retained, retainedEpisode{key: item.key, episode: episode}) + continue + } + delete(b.episodes, item.key) + } + for _, item := range retained { + if b.episodes[item.key] != item.episode || !item.episode.closed || item.episode.closedHeapItem != nil { + continue + } + b.queueClosedEpisodeLocked(item.key, item.episode) + } +} + +func (b *Buffer) queueClosedEpisodeLocked(key Key, episode *episodeState) { + if episode.closedHeapItem != nil { + return + } + item := &closedEpisodeItem{key: key, closedAt: episode.closedAt} + episode.closedHeapItem = item + heap.Push(&b.closedEpisodes, item) +} + +func (b *Buffer) episodeLocked(key Key) *episodeState { + episode := b.episodes[key] + if episode != nil { + return episode + } + episode = &episodeState{createdCh: make(chan struct{})} + b.episodes[key] = episode + return episode +} + +func (b *Buffer) subscriberParts(key Key, subscriber *episodeSubscriber) (parts []Part, closed bool, ok bool) { + b.mu.Lock() + defer b.mu.Unlock() + if b.closed { + return nil, false, false + } + episode := b.episodes[key] + if episode == nil { + return nil, false, false + } + if !episode.created { + return nil, false, true + } + if subscriber.next > len(episode.parts) { + return nil, false, false + } + parts = append([]Part(nil), episode.parts[subscriber.next:]...) 
+ subscriber.next = len(episode.parts) + return parts, episode.closed && subscriber.next == len(episode.parts), true +} + +func (b *Buffer) deliverSubscriber(ctx context.Context, key Key, subscriber *episodeSubscriber) { + defer close(subscriber.out) + defer b.removeSubscriber(key, subscriber) + for { + parts, closed, ok := b.subscriberParts(key, subscriber) + if !ok { + return + } + for _, part := range parts { + if !b.sendSubscriberPart(ctx, subscriber, part) { + return + } + } + if closed { + return + } + select { + case <-subscriber.notifyCh: + case <-subscriber.stopCh: + return + case <-ctx.Done(): + return + case <-b.done: + return + } + } +} + +func (b *Buffer) sendSubscriberPart(ctx context.Context, subscriber *episodeSubscriber, part Part) bool { + timer := b.opts.Clock.NewTimer(b.opts.SubscriberSendTimeout, "message-part-buffer", "subscriber-send") + defer timer.Stop() + select { + case subscriber.out <- part: + return true + case <-timer.C: + return false + case <-subscriber.stopCh: + return false + case <-ctx.Done(): + return false + case <-b.done: + return false + } +} + +func (b *Buffer) cancelSubscriber(key Key, subscriber *episodeSubscriber) { + b.mu.Lock() + defer b.mu.Unlock() + episode := b.episodes[key] + if episode != nil { + b.stopSubscriberLocked(episode, subscriber) + return + } + subscriber.stop() +} + +func (b *Buffer) removeSubscriber(key Key, subscriber *episodeSubscriber) { + b.mu.Lock() + defer b.mu.Unlock() + episode := b.episodes[key] + if episode == nil { + return + } + delete(episode.subscribers, subscriber) + if episode.closed && len(episode.subscribers) == 0 { + b.queueClosedEpisodeLocked(key, episode) + } +} + +func (*Buffer) stopSubscriberLocked(episode *episodeState, subscriber *episodeSubscriber) { + delete(episode.subscribers, subscriber) + subscriber.stop() +} + +func notifySubscriber(subscriber *episodeSubscriber) { + select { + case subscriber.notifyCh <- struct{}{}: + default: + } +} + +func (s *episodeSubscriber) stop() 
{ + s.stopOnce.Do(func() { close(s.stopCh) }) +} + +func serializedPartBytes(part Part) (int64, error) { + data, err := json.Marshal(part.jsonValue()) + if err != nil { + return 0, err + } + return int64(len(data)), nil +} diff --git a/coderd/x/chatd/messagepartbuffer/message_part_buffer_test.go b/coderd/x/chatd/messagepartbuffer/message_part_buffer_test.go new file mode 100644 index 0000000000..3d2bc7823c --- /dev/null +++ b/coderd/x/chatd/messagepartbuffer/message_part_buffer_test.go @@ -0,0 +1,377 @@ +package messagepartbuffer_test + +import ( + "context" + "encoding/json" + "testing" + "time" + + "github.com/google/uuid" + "github.com/stretchr/testify/require" + + "github.com/coder/coder/v2/coderd/x/chatd/messagepartbuffer" + "github.com/coder/coder/v2/codersdk" + "github.com/coder/coder/v2/testutil" + "github.com/coder/quartz" +) + +func TestBuffer_CreateEpisodeRejectsDuplicate(t *testing.T) { + t.Parallel() + + buffer := messagepartbuffer.New(messagepartbuffer.Options{}) + key := testEpisodeKey() + require.NoError(t, buffer.CreateEpisode(key)) + require.ErrorIs(t, buffer.CreateEpisode(key), messagepartbuffer.ErrEpisodeExists) +} + +func TestBuffer_AddPartAndGetParts(t *testing.T) { + t.Parallel() + + buffer := messagepartbuffer.New(messagepartbuffer.Options{}) + key := testEpisodeKey() + require.NoError(t, buffer.CreateEpisode(key)) + require.NoError(t, buffer.AddPart(key, codersdk.ChatMessageRoleAssistant, codersdk.ChatMessageText("hello"))) + + parts, err := buffer.GetParts(key) + require.NoError(t, err) + require.Len(t, parts, 1) + require.Equal(t, int64(1), parts[0].Seq) + require.Equal(t, codersdk.ChatMessageRoleAssistant, parts[0].Role) + require.Equal(t, codersdk.ChatMessageText("hello"), parts[0].MessagePart) +} + +func TestBuffer_AddPartMissingEpisodeReturnsNotFound(t *testing.T) { + t.Parallel() + + buffer := messagepartbuffer.New(messagepartbuffer.Options{}) + err := buffer.AddPart(testEpisodeKey(), codersdk.ChatMessageRoleAssistant, 
codersdk.ChatMessageText("hello")) + require.ErrorIs(t, err, messagepartbuffer.ErrEpisodeNotFound) +} + +func TestBuffer_GetPartsMissingEpisodeReturnsNotFound(t *testing.T) { + t.Parallel() + + buffer := messagepartbuffer.New(messagepartbuffer.Options{}) + _, err := buffer.GetParts(testEpisodeKey()) + require.ErrorIs(t, err, messagepartbuffer.ErrEpisodeNotFound) +} + +func TestBuffer_AddPartFullEpisodeReturnsFull(t *testing.T) { + t.Parallel() + + buffer := messagepartbuffer.New(messagepartbuffer.Options{MaxEpisodeBytes: 1}) + key := testEpisodeKey() + require.NoError(t, buffer.CreateEpisode(key)) + err := buffer.AddPart(key, codersdk.ChatMessageRoleAssistant, codersdk.ChatMessageText("hello")) + require.ErrorIs(t, err, messagepartbuffer.ErrEpisodeFull) + parts, getErr := buffer.GetParts(key) + require.NoError(t, getErr) + require.Empty(t, parts) +} + +func TestBuffer_CloseEpisodeMissingCreatesClosedEpisode(t *testing.T) { + t.Parallel() + + buffer := messagepartbuffer.New(messagepartbuffer.Options{}) + key := testEpisodeKey() + require.NoError(t, buffer.CloseEpisode(key)) + parts, err := buffer.GetParts(key) + require.NoError(t, err) + require.Empty(t, parts) + err = buffer.AddPart(key, codersdk.ChatMessageRoleAssistant, codersdk.ChatMessageText("tail")) + require.ErrorIs(t, err, messagepartbuffer.ErrEpisodeClosed) +} + +func TestBuffer_CloseEpisodeIdempotent(t *testing.T) { + t.Parallel() + + buffer := messagepartbuffer.New(messagepartbuffer.Options{}) + key := testEpisodeKey() + require.NoError(t, buffer.CreateEpisode(key)) + require.NoError(t, buffer.CloseEpisode(key)) + require.NoError(t, buffer.CloseEpisode(key)) +} + +func TestBuffer_SubscribeExistingReplaysThenStreamsLiveParts(t *testing.T) { + t.Parallel() + + buffer := messagepartbuffer.New(messagepartbuffer.Options{}) + key := testEpisodeKey() + require.NoError(t, buffer.CreateEpisode(key)) + require.NoError(t, buffer.AddPart(key, codersdk.ChatMessageRoleAssistant, codersdk.ChatMessageText("before"))) + 
+ ctx := testutil.Context(t, testutil.WaitLong) + ch, cancel, err := buffer.SubscribeToEpisode(ctx, key) + require.NoError(t, err) + defer cancel() + require.Equal(t, "before", receivePart(t, ch).MessagePart.Text) + + require.NoError(t, buffer.AddPart(key, codersdk.ChatMessageRoleAssistant, codersdk.ChatMessageText("after"))) + require.Equal(t, "after", receivePart(t, ch).MessagePart.Text) +} + +func TestBuffer_SubscribeClosedEpisodeReplaysThenCloses(t *testing.T) { + t.Parallel() + + buffer := messagepartbuffer.New(messagepartbuffer.Options{}) + key := testEpisodeKey() + require.NoError(t, buffer.CreateEpisode(key)) + require.NoError(t, buffer.AddPart(key, codersdk.ChatMessageRoleAssistant, codersdk.ChatMessageText("before"))) + require.NoError(t, buffer.CloseEpisode(key)) + + ctx := testutil.Context(t, testutil.WaitLong) + ch, cancel, err := buffer.SubscribeToEpisode(ctx, key) + require.NoError(t, err) + defer cancel() + require.Equal(t, "before", receivePart(t, ch).MessagePart.Text) + assertChannelClosed(t, ch) +} + +func TestBuffer_SubscribeBeforeCreateReturnsAndWaitsWithoutNotFound(t *testing.T) { + t.Parallel() + + buffer := messagepartbuffer.New(messagepartbuffer.Options{}) + key := testEpisodeKey() + ctx := testutil.Context(t, testutil.WaitLong) + ch, cancel, err := buffer.SubscribeToEpisode(ctx, key) + require.NoError(t, err) + defer cancel() + + select { + case part := <-ch: + t.Fatalf("received part before episode create: %+v", part) + default: + } + + require.NoError(t, buffer.CreateEpisode(key)) + require.NoError(t, buffer.AddPart(key, codersdk.ChatMessageRoleAssistant, codersdk.ChatMessageText("live"))) + require.Equal(t, "live", receivePart(t, ch).MessagePart.Text) +} + +func TestBuffer_AddPartAssignsContiguousSeq(t *testing.T) { + t.Parallel() + + buffer := messagepartbuffer.New(messagepartbuffer.Options{}) + key := testEpisodeKey() + require.NoError(t, buffer.CreateEpisode(key)) + for i := range 3 { + require.NoError(t, buffer.AddPart(key, 
codersdk.ChatMessageRoleAssistant, codersdk.ChatMessageText(string(rune('a'+i))))) + } + parts, err := buffer.GetParts(key) + require.NoError(t, err) + require.Equal(t, []int64{1, 2, 3}, []int64{parts[0].Seq, parts[1].Seq, parts[2].Seq}) +} + +func TestBuffer_EpisodeByteLimitUsesJSONAccounting(t *testing.T) { + t.Parallel() + + part := codersdk.ChatMessageText("hello") + limit := serializedPartBytes(t, messagepartbuffer.Part{Seq: 1, Role: codersdk.ChatMessageRoleAssistant, MessagePart: part}) + buffer := messagepartbuffer.New(messagepartbuffer.Options{MaxEpisodeBytes: limit}) + key := testEpisodeKey() + require.NoError(t, buffer.CreateEpisode(key)) + require.NoError(t, buffer.AddPart(key, codersdk.ChatMessageRoleAssistant, part)) + err := buffer.AddPart(key, codersdk.ChatMessageRoleAssistant, codersdk.ChatMessageText("too much")) + require.ErrorIs(t, err, messagepartbuffer.ErrEpisodeFull) +} + +func TestBuffer_GCClosedEpisodeAfterGraceAndNoSubscribers(t *testing.T) { + t.Parallel() + + clock := quartz.NewMock(t) + trap := clock.Trap().NewTimer("message-part-buffer", "subscriber-send") + defer trap.Close() + buffer := messagepartbuffer.New(messagepartbuffer.Options{ + Clock: clock, + ClosedEpisodeRetention: time.Minute, + SubscriberSendTimeout: 10 * time.Minute, + }) + key := testEpisodeKey() + require.NoError(t, buffer.CreateEpisode(key)) + require.NoError(t, buffer.AddPart(key, codersdk.ChatMessageRoleAssistant, codersdk.ChatMessageText("held"))) + ctx := testutil.Context(t, testutil.WaitLong) + ch, cancel, err := buffer.SubscribeToEpisode(ctx, key) + require.NoError(t, err) + require.NoError(t, buffer.CloseEpisode(key)) + call := trap.MustWait(ctx) + call.MustRelease(ctx) + clock.Advance(time.Minute).MustWait(ctx) + clock.Advance(time.Second).MustWait(ctx) + _, err = buffer.GetParts(key) + require.NoError(t, err) + + cancel() + drainUntilClosed(t, ch) + _, err = buffer.GetParts(key) + require.ErrorIs(t, err, messagepartbuffer.ErrEpisodeNotFound) +} + +func 
TestBuffer_GCRetainedSubscribedEpisodeDoesNotBlockOtherExpiredEpisodes(t *testing.T) { + t.Parallel() + + clock := quartz.NewMock(t) + trap := clock.Trap().NewTimer("message-part-buffer", "subscriber-send") + defer trap.Close() + buffer := messagepartbuffer.New(messagepartbuffer.Options{ + Clock: clock, + ClosedEpisodeRetention: time.Minute, + SubscriberSendTimeout: 10 * time.Minute, + }) + retainedKey := testEpisodeKey() + collectedKey := testEpisodeKey() + require.NoError(t, buffer.CreateEpisode(retainedKey)) + require.NoError(t, buffer.AddPart(retainedKey, codersdk.ChatMessageRoleAssistant, codersdk.ChatMessageText("held"))) + require.NoError(t, buffer.CreateEpisode(collectedKey)) + require.NoError(t, buffer.AddPart(collectedKey, codersdk.ChatMessageRoleAssistant, codersdk.ChatMessageText("collect me"))) + ctx := testutil.Context(t, testutil.WaitLong) + ch, cancel, err := buffer.SubscribeToEpisode(ctx, retainedKey) + require.NoError(t, err) + defer cancel() + require.NoError(t, buffer.CloseEpisode(retainedKey)) + require.NoError(t, buffer.CloseEpisode(collectedKey)) + call := trap.MustWait(ctx) + call.MustRelease(ctx) + clock.Advance(time.Minute).MustWait(ctx) + clock.Advance(time.Second).MustWait(ctx) + + _, err = buffer.GetParts(retainedKey) + require.NoError(t, err) + _, err = buffer.GetParts(collectedKey) + require.ErrorIs(t, err, messagepartbuffer.ErrEpisodeNotFound) + + cancel() + drainUntilClosed(t, ch) + _, err = buffer.GetParts(retainedKey) + require.ErrorIs(t, err, messagepartbuffer.ErrEpisodeNotFound) +} + +func TestBuffer_SlowSubscriberClosed(t *testing.T) { + t.Parallel() + + clock := quartz.NewMock(t) + trap := clock.Trap().NewTimer("message-part-buffer", "subscriber-send") + defer trap.Close() + stopTrap := clock.Trap().TimerStop() + defer stopTrap.Close() + buffer := messagepartbuffer.New(messagepartbuffer.Options{ + Clock: clock, + SubscriberSendTimeout: time.Second, + }) + key := testEpisodeKey() + require.NoError(t, buffer.CreateEpisode(key)) 
+ ctx := testutil.Context(t, testutil.WaitLong) + ch, cancel, err := buffer.SubscribeToEpisode(ctx, key) + require.NoError(t, err) + defer cancel() + + require.NoError(t, buffer.AddPart(key, codersdk.ChatMessageRoleAssistant, codersdk.ChatMessageText("blocked"))) + call := trap.MustWait(ctx) + call.MustRelease(ctx) + clock.Advance(time.Second).MustWait(ctx) + stopCall := stopTrap.MustWait(ctx) + stopCall.MustRelease(ctx) + assertChannelClosed(t, ch) +} + +func TestBuffer_BurstyOutputDoesNotCloseSubscriberBeforeSendTimeout(t *testing.T) { + t.Parallel() + + buffer := messagepartbuffer.New(messagepartbuffer.Options{SubscriberChannelSize: 1}) + key := testEpisodeKey() + require.NoError(t, buffer.CreateEpisode(key)) + ctx := testutil.Context(t, testutil.WaitLong) + ch, cancel, err := buffer.SubscribeToEpisode(ctx, key) + require.NoError(t, err) + defer cancel() + + for i := range 8 { + require.NoError(t, buffer.AddPart(key, codersdk.ChatMessageRoleAssistant, codersdk.ChatMessageText(string(rune('a'+i))))) + } + for i := range 8 { + part := receivePart(t, ch) + require.Equal(t, string(rune('a'+i)), part.MessagePart.Text) + } +} + +func TestBuffer_SubscribeCanceledBeforeCreateCanCreateEpisode(t *testing.T) { + t.Parallel() + + buffer := messagepartbuffer.New(messagepartbuffer.Options{}) + key := testEpisodeKey() + ctx, cancel := context.WithCancel(context.Background()) + ch, cancelSub, err := buffer.SubscribeToEpisode(ctx, key) + require.NoError(t, err) + cancel() + drainUntilClosed(t, ch) + cancelSub() + require.NoError(t, buffer.CreateEpisode(key)) +} + +func TestBuffer_CloseClosesPendingSubscriptionAndRejectsOperations(t *testing.T) { + t.Parallel() + + buffer := messagepartbuffer.New(messagepartbuffer.Options{}) + key := testEpisodeKey() + ctx := testutil.Context(t, testutil.WaitLong) + ch, cancel, err := buffer.SubscribeToEpisode(ctx, key) + require.NoError(t, err) + defer cancel() + buffer.Close() + assertChannelClosed(t, ch) + require.ErrorIs(t, 
buffer.CreateEpisode(key), messagepartbuffer.ErrMessagePartBufferClosed) +} + +func testEpisodeKey() messagepartbuffer.Key { + return messagepartbuffer.Key{ChatID: uuid.New(), HistoryVersion: 1, GenerationAttempt: 1} +} + +func receivePart(t *testing.T, ch <-chan messagepartbuffer.Part) messagepartbuffer.Part { + t.Helper() + select { + case part, ok := <-ch: + require.True(t, ok) + return part + case <-time.After(testutil.WaitLong): + t.Fatal("timed out waiting for buffered part") + return messagepartbuffer.Part{} + } +} + +func assertChannelClosed[T any](t *testing.T, ch <-chan T) { + t.Helper() + select { + case _, ok := <-ch: + require.False(t, ok) + case <-time.After(testutil.WaitLong): + t.Fatal("timed out waiting for channel close") + } +} + +func drainUntilClosed[T any](t *testing.T, ch <-chan T) { + t.Helper() + for { + select { + case _, ok := <-ch: + if !ok { + return + } + case <-time.After(testutil.WaitLong): + t.Fatal("timed out waiting for channel close") + } + } +} + +func serializedPartBytes(t *testing.T, part messagepartbuffer.Part) int64 { + t.Helper() + data, err := json.Marshal(struct { + Seq int64 `json:"seq"` + Role codersdk.ChatMessageRole `json:"role"` + Part codersdk.ChatMessagePart `json:"part"` + }{ + Seq: part.Seq, + Role: part.Role, + Part: part.MessagePart, + }) + require.NoError(t, err) + return int64(len(data)) +} diff --git a/coderd/x/chatd/options.go b/coderd/x/chatd/options.go new file mode 100644 index 0000000000..016eb3a7ff --- /dev/null +++ b/coderd/x/chatd/options.go @@ -0,0 +1,155 @@ +package chatd + +import ( + "context" + "database/sql" + "sync/atomic" + "time" + + "github.com/google/uuid" + "golang.org/x/xerrors" + + "github.com/prometheus/client_golang/prometheus" + + "cdr.dev/slog/v3" + "github.com/coder/coder/v2/coderd/audit" + "github.com/coder/coder/v2/coderd/database" + dbpubsub "github.com/coder/coder/v2/coderd/database/pubsub" + "github.com/coder/coder/v2/coderd/notifications" + 
// chatWorkerPubsub is the pubsub dependency of the chat worker: it publishes
// raw messages to a named event and subscribes to an event with an
// error-aware listener.
type chatWorkerPubsub interface {
	// Publish sends message on the named event.
	Publish(event string, message []byte) error
	// SubscribeWithErr registers listener for the named event. The runner
	// stores the returned func as its unsubscribe callback.
	SubscribeWithErr(event string, listener dbpubsub.ListenerWithErr) (func(), error)
}
+type chatWorkerOptions struct { + WorkerID uuid.UUID + + Store database.Store + Pubsub chatWorkerPubsub + Logger slog.Logger + Clock quartz.Clock + TaskStarter chatWorkerTaskStarter + MessagePartBuffer *messagepartbuffer.Buffer + + NotificationsEnqueuer notifications.Enqueuer + Auditor *atomic.Pointer[audit.Auditor] + AutoArchiveRecords prometheus.Counter + + AcquisitionInterval time.Duration + AcquisitionBatchSize int32 + ArchiveInterval time.Duration + ArchiveBatchSize int32 + RunnerSyncInterval time.Duration + HeartbeatInterval time.Duration + HeartbeatCleanupInterval time.Duration + HeartbeatStaleSeconds int32 + StateChannelSize int + RunnerManagerChannelSize int + AcquisitionWakeChannelSize int + TaskRetryInitialBackoff time.Duration + TaskRetryMaxBackoff time.Duration +} + +func (o chatWorkerOptions) withDefaults() (chatWorkerOptions, error) { + if o.Store == nil { + return chatWorkerOptions{}, xerrors.New("chatworker: store is required") + } + if o.Pubsub == nil { + return chatWorkerOptions{}, xerrors.New("chatworker: pubsub is required") + } + if o.TaskStarter == nil && o.MessagePartBuffer == nil { + return chatWorkerOptions{}, xerrors.New("chatworker: task starter or message part buffer is required") + } + if o.WorkerID == uuid.Nil { + return chatWorkerOptions{}, xerrors.New("chatworker: worker ID is required") + } + if o.Clock == nil { + o.Clock = quartz.NewReal() + } + if o.AcquisitionInterval <= 0 { + o.AcquisitionInterval = defaultAcquisitionInterval + } + if o.AcquisitionBatchSize <= 0 { + o.AcquisitionBatchSize = defaultAcquisitionBatchSize + } + if o.ArchiveInterval <= 0 { + o.ArchiveInterval = defaultArchiveInterval + } + if o.ArchiveBatchSize <= 0 { + o.ArchiveBatchSize = defaultArchiveBatchSize + } + if o.NotificationsEnqueuer == nil { + o.NotificationsEnqueuer = notifications.NewNoopEnqueuer() + } + if o.RunnerSyncInterval <= 0 { + o.RunnerSyncInterval = defaultRunnerSyncInterval + } + if o.HeartbeatInterval <= 0 { + o.HeartbeatInterval = 
defaultHeartbeatInterval + } + if o.HeartbeatCleanupInterval <= 0 { + o.HeartbeatCleanupInterval = defaultHeartbeatCleanupEvery + } + if o.HeartbeatStaleSeconds <= 0 { + o.HeartbeatStaleSeconds = defaultHeartbeatStaleSeconds + } + if o.StateChannelSize <= 0 { + o.StateChannelSize = defaultStateChannelSize + } + if o.RunnerManagerChannelSize <= 0 { + o.RunnerManagerChannelSize = defaultStateChannelSize + } + if o.AcquisitionWakeChannelSize <= 0 { + o.AcquisitionWakeChannelSize = 1 + } + if o.TaskRetryInitialBackoff <= 0 { + o.TaskRetryInitialBackoff = defaultTaskRetryInitialBackoff + } + if o.TaskRetryMaxBackoff <= 0 { + o.TaskRetryMaxBackoff = defaultTaskRetryMaxBackoff + } + if o.TaskRetryMaxBackoff < o.TaskRetryInitialBackoff { + o.TaskRetryMaxBackoff = o.TaskRetryInitialBackoff + } + return o, nil +} diff --git a/coderd/x/chatd/quickgen.go b/coderd/x/chatd/quickgen.go index 774e02d107..4019169759 100644 --- a/coderd/x/chatd/quickgen.go +++ b/coderd/x/chatd/quickgen.go @@ -20,6 +20,7 @@ import ( "golang.org/x/xerrors" "cdr.dev/slog/v3" + "github.com/coder/coder/v2/coderd/aibridge" "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/x/chatd/chatdebug" "github.com/coder/coder/v2/coderd/x/chatd/chatprompt" @@ -136,6 +137,68 @@ type generatedTurnStatusLabel struct { Label string `json:"label" description:"Compact 2-5 word current chat status label"` } +// GenerateChatTitleAsync fires a best-effort, automatic title-generation +// pass for a freshly created chat. It is intended to be called from the +// chat-creation endpoint right after the chat and its initial user +// message are persisted. +// +// The work runs in a tracked goroutine with a detached context so it +// neither blocks the HTTP response nor is canceled when the request +// completes. It resolves the chat's model and provider keys, then +// delegates to maybeGenerateChatTitle, which only acts on the first user +// turn (see titleInput) and is otherwise a no-op. 
Errors are logged and +// swallowed. +func (p *Server) GenerateChatTitleAsync(ctx context.Context, chat database.Chat) { + logger := p.logger.With( + slog.F("chat_id", chat.ID), + slog.F("owner_id", chat.OwnerID), + ) + // Snapshot the messages synchronously so the first-turn eligibility + // check (titleInput) is evaluated against creation-time state. Loading + // inside the goroutine would race the chat worker's first assistant + // reply and could skip title generation. + messages, err := p.db.GetChatMessagesForPromptByChatID(ctx, chat.ID) + if err != nil { + logger.Debug(ctx, "failed to load messages for automatic title generation", + slog.Error(err), + ) + return + } + if _, ok := titleInput(chat, messages); !ok { + return + } + // Detach from the request lifetime so title generation can finish + // even after the create response is written. + titleCtx := context.WithoutCancel(ctx) + p.inflight.Go(func() { + modelOpts := modelBuildOptionsFromMessages(messages) + if modelOpts.ActiveAPIKeyID != "" { + titleCtx = aibridge.WithDelegatedAPIKeyID(titleCtx, modelOpts.ActiveAPIKeyID) + } + model, modelConfig, keys, route, _, _, _, err := p.resolveChatModel(titleCtx, chat, modelOpts) + if err != nil { + logger.Debug(titleCtx, "failed to resolve model for automatic title generation", + slog.Error(err), + ) + return + } + p.maybeGenerateChatTitle( + titleCtx, + chat, + messages, + modelConfig.Provider, + modelConfig.Model, + model, + route, + keys, + modelOpts, + &generatedChatTitle{}, + logger, + p.existingDebugService(), + ) + }) +} + // maybeGenerateChatTitle generates an AI title for the chat when // appropriate (first user message, no assistant reply yet, and the // current title is either empty or still the fallback truncation). 
// runner consumes state updates for one runner record (rec) and starts or
// cancels tasks in response. At most one task is marked active at a time;
// canceled tasks stay in the bookkeeping maps until they have finished.
type runner struct {
	// ctx bounds the runner's lifetime; every task context is derived from it.
	ctx context.Context
	// mgr is the owning manager, used to route state hints and to request
	// cleanup of this runner.
	mgr *runnerManager
	// rec is the manager-side record for this runner (key, worker ID, state
	// channel).
	rec  *runnerRecord
	opts chatWorkerOptions

	// lastSnapshotVersion is the SnapshotVersion of the most recently
	// accepted state; updates at or below it are ignored.
	lastSnapshotVersion int64
	// hasAcceptedState reports whether any state has been accepted yet.
	hasAcceptedState bool
	// latestState is the most recently accepted state.
	latestState runnerStateUpdate

	// activeTaskID identifies the current task; it is only meaningful while
	// activeTaskSet is true.
	activeTaskID  taskInstanceID
	activeTaskSet bool
	// tasks holds every task that has been started and not yet observed as
	// finished, keyed by task instance ID.
	tasks map[taskInstanceID]*taskRecord
	// tasksByIndex maps (kind, history version, status) to the most recently
	// started task for that combination.
	tasksByIndex map[taskIndexKey]taskInstanceID
	// localLocks serializes task bodies that share the same localWorkKey.
	localLocks *localLockSet
}
(r *runner) run() { + if !r.bootstrap() { + return + } + for { + select { + case state := <-r.rec.stateCh: + r.processState(state) + case <-r.ctx.Done(): + r.cancelActiveTask() + return + } + } +} + +func (r *runner) bootstrap() bool { + channel := coderdpubsub.ChatStateUpdateChannel(r.rec.key.ChatID) + unsubscribe, err := r.opts.Pubsub.SubscribeWithErr(channel, coderdpubsub.HandleChatStateUpdate( + func(ctx context.Context, payload coderdpubsub.ChatStateUpdateMessage, err error) { + if err != nil { + r.opts.Logger.Warn(ctx, "chatworker state update decode failed", slogError(err)) + return + } + r.mgr.RouteStateHint(ctx, stateUpdateFromPubsub(r.rec.key.ChatID, payload)) + }, + )) + if err != nil { + r.mgr.requestCleanup(r.ctx, r.rec.key) + return false + } + if !r.rec.setUnsubscribe(unsubscribe) { + return false + } + chat, err := r.opts.Store.GetChatByID(r.ctx, r.rec.key.ChatID) + if err != nil { + r.opts.Logger.Warn(r.ctx, "chatworker runner bootstrap failed", slogError(err)) + r.mgr.requestCleanup(r.ctx, r.rec.key) + return false + } + r.mgr.RouteStateHint(r.ctx, stateUpdateFromChat(chat)) + return true +} + +func stateUpdateFromPubsub(chatID uuid.UUID, payload coderdpubsub.ChatStateUpdateMessage) runnerStateUpdate { + return runnerStateUpdate{ + ChatID: chatID, + WorkerID: payload.WorkerID, + RunnerID: payload.RunnerID, + SnapshotVersion: payload.SnapshotVersion, + HistoryVersion: payload.HistoryVersion, + QueueVersion: payload.QueueVersion, + GenerationAttempt: payload.GenerationAttempt, + Status: database.ChatStatus(payload.Status), + Archived: payload.Archived, + } +} + +func (r *runner) processState(state runnerStateUpdate) { + if state.SnapshotVersion <= r.lastSnapshotVersion { + return + } + + r.removeFinishedTasks() + + if !uuidPtrEqual(state.WorkerID, r.rec.workerID) || !uuidPtrEqual(state.RunnerID, r.rec.key.RunnerID) { + r.acceptState(state) + r.mgr.requestCleanup(r.ctx, r.rec.key) + return + } + + changed := !r.hasAcceptedState || + 
r.latestState.HistoryVersion != state.HistoryVersion || + r.latestState.Status != state.Status || + r.latestState.Archived != state.Archived + if !changed { + r.acceptState(state) + return + } + if r.hasAcceptedState && r.activeTaskSet { + r.cancelActiveTask() + } + + r.spawnForState(state) + r.acceptState(state) +} + +func (r *runner) acceptState(state runnerStateUpdate) { + r.hasAcceptedState = true + r.latestState = state + r.lastSnapshotVersion = state.SnapshotVersion +} + +func (r *runner) spawnForState(state runnerStateUpdate) { + if state.Archived { + r.spawnTaskIfNeeded(taskKindAbandon, state) + return + } + switch state.Status { + case database.ChatStatusRunning: + r.spawnTaskIfNeeded(taskKindGeneration, state) + case database.ChatStatusInterrupting: + r.spawnTaskIfNeeded(taskKindInterrupt, state) + case database.ChatStatusRequiresAction: + r.spawnTaskIfNeeded(taskKindRequiresActionTimeout, state) + case database.ChatStatusWaiting, database.ChatStatusError: + r.spawnTaskIfNeeded(taskKindAbandon, state) + default: + r.spawnTaskIfNeeded(taskKindAbandon, state) + } +} + +func (r *runner) spawnTaskIfNeeded(kind taskKind, state runnerStateUpdate) { + key := localWorkKey{historyVersion: state.HistoryVersion, status: state.Status} + idx := taskIndexKey{kind: kind, key: key} + if r.activeTaskSet && r.tasksByIndex[idx] == r.activeTaskID { + return + } + + id := taskInstanceID(uuid.New()) + taskCtx, cancel := context.WithCancel(r.ctx) + done := make(chan struct{}) + record := &taskRecord{ + id: id, + kind: kind, + localKey: key, + cancel: cancel, + done: done, + } + r.tasks[id] = record + r.tasksByIndex[idx] = id + r.activeTaskID = id + r.activeTaskSet = true + + input := chatWorkerTaskStartInput{ + TaskID: uuid.UUID(id), + ChatID: r.rec.key.ChatID, + WorkerID: r.rec.workerID, + RunnerID: r.rec.key.RunnerID, + HistoryVersion: state.HistoryVersion, + GenerationAttempt: state.GenerationAttempt, + Status: state.Status, + RequiresActionDeadlineAt: 
state.RequiresActionDeadlineAt, + } + go r.runTask(taskCtx, kind, key, input, done) +} + +func (r *runner) runTask( + ctx context.Context, + kind taskKind, + key localWorkKey, + input chatWorkerTaskStartInput, + done chan<- struct{}, +) { + defer close(done) + err := runTaskWithRetry(ctx, r.opts.retryOptions(), kind, func(ctx context.Context) error { + unlock, ok := r.localLocks.acquire(ctx, key) + if !ok { + return errTaskExpectedExit + } + defer unlock() + if ctx.Err() != nil { + return errTaskExpectedExit + } + + switch kind { + case taskKindGeneration: + return r.opts.TaskStarter.StartGeneration(ctx, input) + case taskKindInterrupt: + return r.opts.TaskStarter.StartInterrupt(ctx, input) + case taskKindRequiresActionTimeout: + return r.opts.TaskStarter.StartRequiresActionTimeout(ctx, input) + case taskKindAbandon: + return r.opts.TaskStarter.StartAbandon(ctx, input) + default: + return errors.Join(errTaskExpectedExit, xerrors.Errorf("unknown task kind %q", kind)) + } + }) + if err != nil && ctx.Err() == nil { + r.opts.Logger.Warn(ctx, "chatworker task failed", slogError(err)) + } +} + +func (r *runner) cancelActiveTask() { + if !r.activeTaskSet { + return + } + id := r.activeTaskID + r.activeTaskSet = false + if record := r.tasks[id]; record != nil { + record.cancel() + } +} + +func (r *runner) removeFinishedTasks() { + for id, record := range r.tasks { + select { + case <-record.done: + delete(r.tasks, id) + idx := taskIndexKey{kind: record.kind, key: record.localKey} + if r.tasksByIndex[idx] == id { + delete(r.tasksByIndex, idx) + } + if r.activeTaskSet && r.activeTaskID == id { + r.activeTaskSet = false + } + default: + } + } +} + +func uuidPtrEqual(got *uuid.UUID, want uuid.UUID) bool { + return got != nil && *got == want +} + +type localLockSet struct { + mu sync.Mutex + locked map[localWorkKey]chan struct{} +} + +func newLocalLockSet() *localLockSet { + return &localLockSet{locked: make(map[localWorkKey]chan struct{})} +} + +func (l *localLockSet) 
acquire(ctx context.Context, key localWorkKey) (func(), bool) { + for { + l.mu.Lock() + wait, ok := l.locked[key] + if !ok { + released := make(chan struct{}) + l.locked[key] = released + l.mu.Unlock() + return func() { + l.mu.Lock() + if l.locked[key] == released { + delete(l.locked, key) + close(released) + } + l.mu.Unlock() + }, true + } + l.mu.Unlock() + + select { + case <-wait: + case <-ctx.Done(): + return nil, false + } + } +} diff --git a/coderd/x/chatd/runner_manager.go b/coderd/x/chatd/runner_manager.go new file mode 100644 index 0000000000..515cfff880 --- /dev/null +++ b/coderd/x/chatd/runner_manager.go @@ -0,0 +1,523 @@ +package chatd + +import ( + "context" + "database/sql" + "encoding/json" + "sync" + "time" + + "github.com/google/uuid" + "golang.org/x/xerrors" + + "cdr.dev/slog/v3" + "github.com/coder/coder/v2/coderd/database" + coderdpubsub "github.com/coder/coder/v2/coderd/pubsub" +) + +const shutdownCleanupTimeout = 5 * time.Second + +type runnerKey struct { + ChatID uuid.UUID + RunnerID uuid.UUID +} + +type runnerStateUpdate struct { + ChatID uuid.UUID + WorkerID *uuid.UUID + RunnerID *uuid.UUID + SnapshotVersion int64 + HistoryVersion int64 + QueueVersion int64 + GenerationAttempt int64 + Status database.ChatStatus + Archived bool + RequiresActionDeadlineAt sql.NullTime +} + +type spawnRunnerRequest struct { + ChatID uuid.UUID + WorkerID uuid.UUID + RunnerID uuid.UUID +} + +type runnerRecord struct { + key runnerKey + workerID uuid.UUID + cancel context.CancelFunc + done <-chan struct{} + stateCh chan runnerStateUpdate + + mu sync.Mutex + unsubscribe func() + cleanupStarted bool +} + +func (r *runnerRecord) setUnsubscribe(unsubscribe func()) bool { + r.mu.Lock() + if r.cleanupStarted { + r.mu.Unlock() + if unsubscribe != nil { + unsubscribe() + } + return false + } + r.unsubscribe = unsubscribe + r.mu.Unlock() + return true +} + +func (r *runnerRecord) startCleanup() { + r.mu.Lock() + if r.cleanupStarted { + r.mu.Unlock() + return + } + 
// runnerManager owns the runners of one chat worker. Its run loop handles
// spawn requests, cleanup requests and cleanup completions received on the
// three channels below.
type runnerManager struct {
	server *Server
	opts   chatWorkerOptions
	// ctx bounds the manager's lifetime; runner contexts are derived from it.
	ctx context.Context

	// closed is set once the manager has begun shutting down; Spawn rejects
	// requests afterwards. It is read and written while holding spawnMu.
	closed  bool
	spawnMu sync.Mutex

	// mu guards runners, runnersByChat and cleaning.
	mu sync.Mutex
	// spawnCh, cleanupReqCh and cleanupDoneCh feed the run loop.
	spawnCh       chan spawnRunnerRequest
	cleanupReqCh  chan runnerKey
	cleanupDoneCh chan runnerKey
	// runners holds the active runners by key.
	runners map[runnerKey]*runnerRecord
	// runnersByChat indexes the active runners by chat ID, then runner ID.
	runnersByChat map[uuid.UUID]map[uuid.UUID]*runnerRecord
	// cleaning holds runners whose cleanup has started but whose goroutine
	// has not been observed as finished yet.
	cleaning map[runnerKey]*runnerRecord

	// wg tracks every goroutine started by the manager; wait blocks on it.
	wg sync.WaitGroup
}
RouteStateHint(ctx context.Context, state runnerStateUpdate) { + m.mu.Lock() + byRunner := m.runnersByChat[state.ChatID] + targets := make([]*runnerRecord, 0, len(byRunner)) + for _, rec := range byRunner { + targets = append(targets, rec) + } + m.mu.Unlock() + + for _, rec := range targets { + select { + case rec.stateCh <- state: + case <-rec.done: + case <-ctx.Done(): + return + case <-m.ctx.Done(): + return + default: + } + } +} + +func (m *runnerManager) run() { + for { + select { + case req := <-m.spawnCh: + m.handleSpawn(req) + case key := <-m.cleanupReqCh: + m.handleCleanupRequest(key) + case key := <-m.cleanupDoneCh: + m.handleCleanupDone(key) + case <-m.ctx.Done(): + queued := m.closeAndDrainQueues() + m.cancelAll() + m.releaseOwnershipHints(queued) + return + } + } +} + +func (m *runnerManager) handleSpawn(req spawnRunnerRequest) { + key := runnerKey{ChatID: req.ChatID, RunnerID: req.RunnerID} + m.mu.Lock() + if _, ok := m.runners[key]; ok { + m.opts.Logger.Warn(m.ctx, "invalid spawn request: chat runner already spawned", slog.F("key", key)) + m.mu.Unlock() + return + } + if _, ok := m.cleaning[key]; ok { + m.opts.Logger.Warn(m.ctx, "invalid spawn request: chat runner in cleanup", slog.F("key", key)) + m.mu.Unlock() + return + } + runnerCtx, cancel := context.WithCancel(m.ctx) + done := make(chan struct{}) + rec := &runnerRecord{ + key: key, + workerID: req.WorkerID, + cancel: cancel, + done: done, + stateCh: make(chan runnerStateUpdate, m.opts.StateChannelSize), + } + m.runners[key] = rec + if m.runnersByChat[req.ChatID] == nil { + m.runnersByChat[req.ChatID] = make(map[uuid.UUID]*runnerRecord) + } + m.runnersByChat[req.ChatID][req.RunnerID] = rec + m.mu.Unlock() + + r := newRunner(runnerCtx, m, rec, m.opts) + m.wg.Go(func() { + defer close(done) + r.run() + }) +} + +func (m *runnerManager) closeAndDrainQueues() []runnerKey { + m.spawnMu.Lock() + defer m.spawnMu.Unlock() + + m.closed = true + return m.drainQueues() +} + +func (m *runnerManager) 
drainQueues() []runnerKey { + queued := make([]runnerKey, 0) + for { + select { + case req := <-m.spawnCh: + queued = append(queued, runnerKey{ChatID: req.ChatID, RunnerID: req.RunnerID}) + case key := <-m.cleanupReqCh: + m.handleCleanupRequest(key) + case key := <-m.cleanupDoneCh: + m.handleCleanupDone(key) + default: + return queued + } + } +} + +func (m *runnerManager) handleCleanupRequest(key runnerKey) { + m.mu.Lock() + rec, ok := m.runners[key] + if !ok { + m.mu.Unlock() + return + } + delete(m.runners, key) + if byChat := m.runnersByChat[key.ChatID]; byChat != nil { + delete(byChat, key.RunnerID) + if len(byChat) == 0 { + delete(m.runnersByChat, key.ChatID) + } + } + m.cleaning[key] = rec + m.mu.Unlock() + + rec.startCleanup() + m.registerCleanupWaiter(key, rec) +} + +func (m *runnerManager) registerCleanupWaiter(key runnerKey, rec *runnerRecord) { + m.wg.Go(func() { + <-rec.done + if m.ctx.Err() != nil { + m.mu.Lock() + delete(m.cleaning, key) + m.mu.Unlock() + return + } + select { + case m.cleanupDoneCh <- key: + case <-m.ctx.Done(): + m.mu.Lock() + delete(m.cleaning, key) + m.mu.Unlock() + } + }) +} + +func (m *runnerManager) handleCleanupDone(key runnerKey) { + m.mu.Lock() + delete(m.cleaning, key) + m.mu.Unlock() +} + +func (m *runnerManager) cancelAll() { + type cleanupTarget struct { + key runnerKey + rec *runnerRecord + } + + m.mu.Lock() + active := make([]cleanupTarget, 0, len(m.runners)) + cleaning := make([]*runnerRecord, 0, len(m.cleaning)) + for _, rec := range m.cleaning { + cleaning = append(cleaning, rec) + } + for key, rec := range m.runners { + delete(m.runners, key) + m.cleaning[key] = rec + active = append(active, cleanupTarget{key: key, rec: rec}) + } + clear(m.runnersByChat) + m.mu.Unlock() + + keys := make([]runnerKey, 0, len(cleaning)+len(active)) + for _, rec := range cleaning { + rec.startCleanup() + keys = append(keys, rec.key) + } + for _, target := range active { + target.rec.startCleanup() + m.registerCleanupWaiter(target.key, 
target.rec) + keys = append(keys, target.key) + } + m.releaseOwnershipHints(keys) +} + +func (m *runnerManager) releaseOwnershipHints(keys []runnerKey) { + if len(keys) == 0 { + return + } + ctx, cancel := context.WithTimeout(context.WithoutCancel(m.ctx), shutdownCleanupTimeout) + defer cancel() + + chatIDs := make([]uuid.UUID, 0, len(keys)) + runnerIDs := make([]uuid.UUID, 0, len(keys)) + uniqueChatIDs := make(map[uuid.UUID]struct{}, len(keys)) + for _, key := range keys { + chatIDs = append(chatIDs, key.ChatID) + runnerIDs = append(runnerIDs, key.RunnerID) + uniqueChatIDs[key.ChatID] = struct{}{} + } + if _, err := m.opts.Store.DeleteChatHeartbeats(ctx, database.DeleteChatHeartbeatsParams{ + ChatIds: chatIDs, + RunnerIds: runnerIDs, + }); err != nil { + m.opts.Logger.Warn(ctx, "chatworker shutdown heartbeat cleanup failed", slogError(err)) + } + + syncIDs := make([]uuid.UUID, 0, len(uniqueChatIDs)) + for id := range uniqueChatIDs { + syncIDs = append(syncIDs, id) + } + chats, err := m.opts.Store.GetChatsByIDsForRunnerSync(ctx, syncIDs) + if err != nil { + m.opts.Logger.Warn(ctx, "chatworker shutdown ownership lookup failed", slogError(err)) + } + snapshotByChat := make(map[uuid.UUID]int64, len(chats)) + for _, chat := range chats { + snapshotByChat[chat.ID] = chat.SnapshotVersion + } + for _, key := range keys { + payload, err := json.Marshal(coderdpubsub.ChatStateOwnershipMessage{ + ChatID: key.ChatID, + SnapshotVersion: snapshotByChat[key.ChatID], + }) + if err != nil { + m.opts.Logger.Warn(ctx, "chatworker shutdown ownership marshal failed", slogError(err)) + continue + } + if err := m.opts.Pubsub.Publish(coderdpubsub.ChatStateOwnershipChannel, payload); err != nil { + m.opts.Logger.Warn(ctx, "chatworker shutdown ownership publish failed", slogError(err)) + } + } +} + +func (m *runnerManager) snapshotRunnerKeys() []runnerKey { + m.mu.Lock() + defer m.mu.Unlock() + keys := make([]runnerKey, 0, len(m.runners)) + for key := range m.runners { + keys = append(keys, 
key) + } + return keys +} + +func (m *runnerManager) databaseSyncLoop() { + ticker := m.opts.Clock.NewTicker(m.opts.RunnerSyncInterval, "chatworker", "runner-sync") + defer ticker.Stop() + for { + select { + case <-ticker.C: + if err := m.syncOnce(m.ctx); err != nil { + m.opts.Logger.Warn(m.ctx, "chatworker runner sync failed", slogError(err)) + } + case <-m.ctx.Done(): + return + } + } +} + +func (m *runnerManager) syncOnce(ctx context.Context) error { + keys := m.snapshotRunnerKeys() + if len(keys) == 0 { + return nil + } + idsByChat := make(map[uuid.UUID]struct{}, len(keys)) + for _, key := range keys { + idsByChat[key.ChatID] = struct{}{} + } + chatIDs := make([]uuid.UUID, 0, len(idsByChat)) + for id := range idsByChat { + chatIDs = append(chatIDs, id) + } + chats, err := m.opts.Store.GetChatsByIDsForRunnerSync(ctx, chatIDs) + if err != nil { + return xerrors.Errorf("get chats for runner sync: %w", err) + } + seen := make(map[uuid.UUID]struct{}, len(chats)) + for _, chat := range chats { + seen[chat.ID] = struct{}{} + m.RouteStateHint(ctx, stateUpdateFromChat(chat)) + } + for _, key := range keys { + if _, ok := seen[key.ChatID]; !ok { + m.requestCleanup(ctx, key) + } + } + return nil +} + +func (m *runnerManager) heartbeatLoop() { + ticker := m.opts.Clock.NewTicker(m.opts.HeartbeatInterval, "chatworker", "heartbeat") + defer ticker.Stop() + for { + select { + case <-ticker.C: + if err := m.heartbeatOnce(m.ctx); err != nil { + m.opts.Logger.Warn(m.ctx, "chatworker heartbeat failed", slogError(err)) + } + case <-m.ctx.Done(): + return + } + } +} + +func (m *runnerManager) heartbeatOnce(ctx context.Context) error { + keys := m.snapshotRunnerKeys() + if len(keys) == 0 { + return nil + } + chatIDs := make([]uuid.UUID, 0, len(keys)) + runnerIDs := make([]uuid.UUID, 0, len(keys)) + for _, key := range keys { + chatIDs = append(chatIDs, key.ChatID) + runnerIDs = append(runnerIDs, key.RunnerID) + } + return m.opts.Store.UpsertChatHeartbeats(ctx, 
database.UpsertChatHeartbeatsParams{ + ChatIds: chatIDs, + RunnerIds: runnerIDs, + }) +} + +func (m *runnerManager) heartbeatCleanupLoop() { + ticker := m.opts.Clock.NewTicker(m.opts.HeartbeatCleanupInterval, "chatworker", "heartbeat-cleanup") + defer ticker.Stop() + for { + select { + case <-ticker.C: + if err := m.heartbeatCleanupOnce(m.ctx); err != nil { + m.opts.Logger.Warn(m.ctx, "chatworker heartbeat cleanup failed", slogError(err)) + } + case <-m.ctx.Done(): + return + } + } +} + +func (m *runnerManager) heartbeatCleanupOnce(ctx context.Context) error { + _, err := m.opts.Store.DeleteStaleChatHeartbeats(ctx, m.opts.HeartbeatStaleSeconds) + return err +} + +func stateUpdateFromChat(chat database.Chat) runnerStateUpdate { + var workerID *uuid.UUID + if chat.WorkerID.Valid { + id := chat.WorkerID.UUID + workerID = &id + } + var runnerID *uuid.UUID + if chat.RunnerID.Valid { + id := chat.RunnerID.UUID + runnerID = &id + } + return runnerStateUpdate{ + ChatID: chat.ID, + WorkerID: workerID, + RunnerID: runnerID, + SnapshotVersion: chat.SnapshotVersion, + HistoryVersion: chat.HistoryVersion, + QueueVersion: chat.QueueVersion, + GenerationAttempt: chat.GenerationAttempt, + Status: chat.Status, + Archived: chat.Archived, + RequiresActionDeadlineAt: chat.RequiresActionDeadlineAt, + } +} + +func slogError(err error) slog.Field { + return slog.Error(err) +} diff --git a/coderd/x/chatd/runner_test.go b/coderd/x/chatd/runner_test.go new file mode 100644 index 0000000000..eca1df9a26 --- /dev/null +++ b/coderd/x/chatd/runner_test.go @@ -0,0 +1,137 @@ +package chatd //nolint:testpackage // Uses unexported chatworker helpers. 
+ +import ( + "testing" + "time" + + "github.com/google/uuid" + "github.com/stretchr/testify/require" + + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/testutil" + "github.com/coder/quartz" +) + +func TestRunner_IgnoresDuplicateStateNotifications(t *testing.T) { + t.Parallel() + f := newWorkerTestFixture(t) + chat := f.createRunningChat(t) + starter := newBlockingTaskStarter(false) + startWorker(t, testOptions(t, f, starter)) + starter.waitCall(t, taskKindGeneration, chat.ID) + latest, err := f.db.GetChatByID(testutil.Context(t, testutil.WaitShort), chat.ID) + require.NoError(t, err) + + publishChatUpdate(t, f, latest) + publishChatUpdate(t, f, latest) + starter.assertNoCall(t) +} + +func TestRunner_CancelsActiveTaskWhenHistoryChanges(t *testing.T) { + t.Parallel() + f := newWorkerTestFixture(t) + chat := f.createRunningChat(t) + starter := newBlockingTaskStarter(false) + startWorker(t, testOptions(t, f, starter)) + first := starter.waitCall(t, taskKindGeneration, chat.ID) + + updated := commitAssistantStep(t, f, chat.ID, "first step") + require.Greater(t, updated.HistoryVersion, first.input.HistoryVersion) + requireTaskCanceled(t, first) + second := starter.waitCall(t, taskKindGeneration, chat.ID) + require.Equal(t, updated.HistoryVersion, second.input.HistoryVersion) +} + +func TestRunner_CancelsActiveTaskWhenStatusChanges(t *testing.T) { + t.Parallel() + f := newWorkerTestFixture(t) + chat := f.createRunningChat(t) + starter := newBlockingTaskStarter(false) + startWorker(t, testOptions(t, f, starter)) + first := starter.waitCall(t, taskKindGeneration, chat.ID) + + updated := interruptChat(t, f, chat.ID) + require.Equal(t, database.ChatStatusInterrupting, updated.Status) + requireTaskCanceled(t, first) + second := starter.waitCall(t, taskKindInterrupt, chat.ID) + require.Equal(t, updated.HistoryVersion, second.input.HistoryVersion) +} + +func TestRunner_CleansUpOnOwnershipTakeover(t *testing.T) { + t.Parallel() + f := 
newWorkerTestFixture(t) + chat := f.createRunningChat(t) + starter := newBlockingTaskStarter(false) + startWorker(t, testOptions(t, f, starter)) + first := starter.waitCall(t, taskKindGeneration, chat.ID) + + acquireChat(t, f, chat.ID, uuid.New(), uuid.New()) + requireTaskCanceled(t, first) + starter.assertNoCall(t) +} + +func TestRunner_SerializesReplacementTasksForSameHistoryAndStatus(t *testing.T) { + t.Parallel() + f := newWorkerTestFixture(t) + chat := f.createRunningChat(t) + starter := newBlockingTaskStarter(true) + defer starter.releaseAll() + startWorker(t, testOptions(t, f, starter)) + first := starter.waitCall(t, taskKindGeneration, chat.ID) + + forceExecutionStateAndPublish(t, f, chat.ID, database.ChatStatusInterrupting, false) + starter.waitCall(t, taskKindInterrupt, chat.ID) + forceExecutionStateAndPublish(t, f, chat.ID, database.ChatStatusRunning, false) + starter.assertNoCall(t) + + starter.release(t, 0) + replacement := starter.waitCall(t, taskKindGeneration, chat.ID) + require.Equal(t, first.input.HistoryVersion, replacement.input.HistoryVersion) +} + +func TestRunner_AllowsReplacementForDifferentHistoryOrStatus(t *testing.T) { + t.Parallel() + f := newWorkerTestFixture(t) + chat := f.createRunningChat(t) + starter := newBlockingTaskStarter(true) + defer starter.releaseAll() + startWorker(t, testOptions(t, f, starter)) + first := starter.waitCall(t, taskKindGeneration, chat.ID) + + updated := commitAssistantStep(t, f, chat.ID, "different history") + second := starter.waitCall(t, taskKindGeneration, chat.ID) + require.Greater(t, second.input.HistoryVersion, first.input.HistoryVersion) + require.Equal(t, updated.HistoryVersion, second.input.HistoryVersion) +} + +func TestWorker_RoutesDatabaseSyncStateToActiveRunner(t *testing.T) { + t.Parallel() + f := newWorkerTestFixture(t) + chat := f.createRunningChat(t) + clock := quartz.NewMock(t) + starter := newBlockingTaskStarter(false) + opts := testOptions(t, f, starter) + opts.Clock = clock + 
opts.RunnerSyncInterval = time.Minute + startWorker(t, opts) + first := starter.waitCall(t, taskKindGeneration, chat.ID) + + forceExecutionState(t, f, chat.ID, database.ChatStatusInterrupting, false) + clock.Advance(time.Minute).MustWait(testutil.Context(t, testutil.WaitLong)) + requireTaskCanceled(t, first) + starter.waitCall(t, taskKindInterrupt, chat.ID) +} + +func TestWorker_CleanupStopsRoutingAndCancelsTasks(t *testing.T) { + t.Parallel() + f := newWorkerTestFixture(t) + chat := f.createRunningChat(t) + starter := newBlockingTaskStarter(false) + startWorker(t, testOptions(t, f, starter)) + first := starter.waitCall(t, taskKindGeneration, chat.ID) + + latest := acquireChat(t, f, chat.ID, uuid.New(), uuid.New()) + requireTaskCanceled(t, first) + publishChatUpdate(t, f, latest) + starter.assertNoCall(t) +} diff --git a/coderd/x/chatd/streamcollector_internal_test.go b/coderd/x/chatd/streamcollector_internal_test.go index 81dae5f133..089ad26290 100644 --- a/coderd/x/chatd/streamcollector_internal_test.go +++ b/coderd/x/chatd/streamcollector_internal_test.go @@ -10,11 +10,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "cdr.dev/slog/v3" - "github.com/coder/coder/v2/coderd/x/chatd/chatloop" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/testutil" - "github.com/coder/quartz" ) // TestStreamStateCollector exercises the four gauges emitted by @@ -168,49 +165,3 @@ func newSubscribers(t *testing.T, n int) map[uuid.UUID]chan codersdk.ChatStreamE } return subs } - -// TestStreamStateCollector_BufferDroppedIncrementsOnCapacity pre-fills -// a buffer to capacity and asserts stream_buffer_dropped_total -// increments on each subsequent publishToStream drop. 
-func TestStreamStateCollector_BufferDroppedIncrementsOnCapacity(t *testing.T) { - t.Parallel() - - reg := prometheus.NewRegistry() - server := &Server{ - logger: slog.Make(), - clock: quartz.NewMock(t), - metrics: chatloop.NewMetrics(reg), - } - - chatID := uuid.New() - server.chatStreams.Store(chatID, &chatStreamState{ - buffering: true, - buffer: make([]bufferedStreamPart, maxStreamBufferSize), - }) - - partEvent := codersdk.ChatStreamEvent{ - Type: codersdk.ChatStreamEventTypeMessagePart, - MessagePart: &codersdk.ChatStreamMessagePart{}, - } - - server.publishToStream(chatID, partEvent) - assert.Equal(t, float64(1), counterValue(t, reg, "coderd_chatd_stream_buffer_dropped_total")) - - server.publishToStream(chatID, partEvent) - assert.Equal(t, float64(2), counterValue(t, reg, "coderd_chatd_stream_buffer_dropped_total")) -} - -func counterValue(t *testing.T, reg *prometheus.Registry, name string) float64 { - t.Helper() - families, err := reg.Gather() - require.NoError(t, err) - for _, f := range families { - if f.GetName() != name { - continue - } - require.Len(t, f.GetMetric(), 1, "counter %q should have exactly one sample", name) - return f.GetMetric()[0].GetCounter().GetValue() - } - t.Fatalf("counter %q not registered", name) - return 0 -} diff --git a/coderd/x/chatd/subagent.go b/coderd/x/chatd/subagent.go index fdc2e56f7a..7fc2fc7f9f 100644 --- a/coderd/x/chatd/subagent.go +++ b/coderd/x/chatd/subagent.go @@ -20,9 +20,11 @@ import ( "github.com/coder/coder/v2/coderd/aibridge" "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/dbauthz" + dbpubsub "github.com/coder/coder/v2/coderd/database/pubsub" coderdpubsub "github.com/coder/coder/v2/coderd/pubsub" "github.com/coder/coder/v2/coderd/x/chatd/chatprompt" "github.com/coder/coder/v2/coderd/x/chatd/chatprovider" + "github.com/coder/coder/v2/coderd/x/chatd/chatstate" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/codersdk/workspacesdk" ) @@ -987,147 +989,114 
@@ func (p *Server) createChildSubagentChatWithOptions( // for another pool checkout. deploymentPrompt := p.resolveDeploymentSystemPrompt(ctx) - var child database.Chat - txErr := p.db.InTx(func(tx database.Store) error { - if limitErr := p.checkUsageLimit(ctx, tx, parent.OwnerID, uuid.NullUUID{UUID: parent.OrganizationID, Valid: true}); limitErr != nil { - return limitErr - } - - insertedChat, err := tx.InsertChat(ctx, database.InsertChatParams{ - OrganizationID: parent.OrganizationID, - OwnerID: parent.OwnerID, - WorkspaceID: parent.WorkspaceID, - BuildID: parent.BuildID, - AgentID: parent.AgentID, - ParentChatID: uuid.NullUUID{UUID: parent.ID, Valid: true}, - RootChatID: uuid.NullUUID{UUID: rootChatID, Valid: true}, - LastModelConfigID: modelConfigID, - Title: title, - Mode: opts.chatMode, - PlanMode: childPlanMode, - ClientType: parent.ClientType, - Status: database.ChatStatusPending, - MCPServerIDs: mcpServerIDs, - Labels: pqtype.NullRawMessage{ - RawMessage: labelsJSON, - Valid: true, - }, - DynamicTools: pqtype.NullRawMessage{}, - }) - if err != nil { - return xerrors.Errorf("insert child chat: %w", err) - } - - workspaceAwareness := workspaceDetachedNoCreateAwareness - if insertedChat.WorkspaceID.Valid { - workspaceAwareness = workspaceAttachedAwareness - } - workspaceAwarenessContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{ - codersdk.ChatMessageText(workspaceAwareness), - }) - if err != nil { - return xerrors.Errorf("marshal workspace awareness: %w", err) - } - userContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{codersdk.ChatMessageText(prompt)}) - if err != nil { - return xerrors.Errorf("marshal initial user content: %w", err) - } - - systemParams := database.InsertChatMessagesParams{ //nolint:exhaustruct // Fields populated by appendChatMessage. 
- ChatID: insertedChat.ID, - } - if deploymentPrompt != "" { - deploymentContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{ - codersdk.ChatMessageText(deploymentPrompt), - }) - if err != nil { - return xerrors.Errorf("marshal deployment system prompt: %w", err) - } - appendChatMessage(&systemParams, newChatMessage( - database.ChatMessageRoleSystem, - deploymentContent, - database.ChatMessageVisibilityModel, - modelConfigID, - chatprompt.CurrentContentVersion, - )) - } - if childSystemPrompt != "" { - childSystemPromptContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{ - codersdk.ChatMessageText(childSystemPrompt), - }) - if err != nil { - return xerrors.Errorf("marshal child system prompt: %w", err) - } - appendChatMessage(&systemParams, newChatMessage( - database.ChatMessageRoleSystem, - childSystemPromptContent, - database.ChatMessageVisibilityModel, - modelConfigID, - chatprompt.CurrentContentVersion, - )) - } - appendChatMessage(&systemParams, newChatMessage( - database.ChatMessageRoleSystem, - workspaceAwarenessContent, - database.ChatMessageVisibilityModel, - modelConfigID, - chatprompt.CurrentContentVersion, - )) - if _, err := tx.InsertChatMessages(ctx, systemParams); err != nil { - return xerrors.Errorf("insert initial child system messages: %w", err) - } - - child = insertedChat - - // Copy persisted context before the initial child prompt so the - // child cannot be acquired until its inherited context is in - // place. signalWake runs only after commit. - copiedContextParts, err := copyParentContextMessages(ctx, p.logger, tx, parent, child) - if err != nil { - return xerrors.Errorf("copy parent context messages: %w", err) - } - if err := updateChildLastInjectedContext(ctx, p.logger, tx, child.ID, copiedContextParts); err != nil { - return xerrors.Errorf("update child injected context: %w", err) - } - - userParams := database.InsertChatMessagesParams{ //nolint:exhaustruct // Fields populated by appendUserChatMessage. 
- ChatID: insertedChat.ID, - } - childUserMsg := newUserChatMessage( - childAPIKeyID, - userContent, - database.ChatMessageVisibilityBoth, - modelConfigID, - chatprompt.CurrentContentVersion, - ) - childUserMsg = childUserMsg.withCreatedBy(parent.OwnerID) - appendUserChatMessage(&userParams, childUserMsg) - if _, err := tx.InsertChatMessages(ctx, userParams); err != nil { - return xerrors.Errorf("insert initial child user message: %w", err) - } - - return nil - }, nil) - if txErr != nil { - return database.Chat{}, xerrors.Errorf("create child chat: %w", txErr) + if limitErr := p.checkUsageLimit(ctx, p.db, parent.OwnerID, uuid.NullUUID{UUID: parent.OrganizationID, Valid: true}); limitErr != nil { + return database.Chat{}, limitErr } + workspaceAwareness := workspaceDetachedNoCreateAwareness + if parent.WorkspaceID.Valid { + workspaceAwareness = workspaceAttachedAwareness + } + workspaceAwarenessContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{ + codersdk.ChatMessageText(workspaceAwareness), + }) + if err != nil { + return database.Chat{}, xerrors.Errorf("marshal workspace awareness: %w", err) + } + userContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{codersdk.ChatMessageText(prompt)}) + if err != nil { + return database.Chat{}, xerrors.Errorf("marshal initial user content: %w", err) + } + + initialMessages := make([]chatstate.Message, 0, 4) + if deploymentPrompt != "" { + deploymentContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{ + codersdk.ChatMessageText(deploymentPrompt), + }) + if err != nil { + return database.Chat{}, xerrors.Errorf("marshal deployment system prompt: %w", err) + } + initialMessages = append(initialMessages, systemMessage(deploymentContent, modelConfigID)) + } + if childSystemPrompt != "" { + childSystemPromptContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{ + codersdk.ChatMessageText(childSystemPrompt), + }) + if err != nil { + return database.Chat{}, 
xerrors.Errorf("marshal child system prompt: %w", err) + } + initialMessages = append(initialMessages, systemMessage(childSystemPromptContent, modelConfigID)) + } + initialMessages = append(initialMessages, systemMessage(workspaceAwarenessContent, modelConfigID)) + + copiedContextParts, err := copyParentContextMessages(ctx, p.logger, p.db, parent) + if err != nil { + return database.Chat{}, xerrors.Errorf("copy parent context messages: %w", err) + } + var lastInjectedContext pqtype.NullRawMessage + if len(copiedContextParts) > 0 { + filteredContent, err := chatprompt.MarshalParts(copiedContextParts) + if err != nil { + return database.Chat{}, xerrors.Errorf("marshal copied context parts: %w", err) + } + initialMessages = append(initialMessages, userMessageWithAPIKeyID( + filteredContent, + modelConfigID, + parent.OwnerID, + childAPIKeyID, + )) + lastInjectedContext, err = BuildLastInjectedContext(FilterContextPartsToLatestAgent(copiedContextParts)) + if err != nil { + return database.Chat{}, xerrors.Errorf("build inherited injected context: %w", err) + } + } + initialMessages = append(initialMessages, userMessageWithAPIKeyID(userContent, modelConfigID, parent.OwnerID, childAPIKeyID)) + + publisher := p.pubsub + if publisher == nil { + publisher = dbpubsub.NewInMemory() + } + result, err := chatstate.CreateChat(ctx, p.db, publisher, chatstate.CreateChatInput{ + OrganizationID: parent.OrganizationID, + OwnerID: parent.OwnerID, + WorkspaceID: parent.WorkspaceID, + BuildID: parent.BuildID, + AgentID: parent.AgentID, + ParentChatID: uuid.NullUUID{UUID: parent.ID, Valid: true}, + RootChatID: uuid.NullUUID{UUID: rootChatID, Valid: true}, + LastModelConfigID: modelConfigID, + Title: title, + Mode: opts.chatMode, + PlanMode: childPlanMode, + MCPServerIDs: mcpServerIDs, + Labels: pqtype.NullRawMessage{ + RawMessage: labelsJSON, + Valid: true, + }, + DynamicTools: pqtype.NullRawMessage{}, + ClientType: parent.ClientType, + InitialMessages: initialMessages, + 
LastInjectedContext: lastInjectedContext, + }) + if err != nil { + return database.Chat{}, xerrors.Errorf("create child chat: %w", err) + } + + child := result.Chat + p.publishChatPubsubEvent(child, codersdk.ChatWatchEventKindCreated, nil) - p.signalWake() return child, nil } // copyParentContextMessages reads persisted context-file and skill -// messages from the parent chat and inserts copies into the child -// chat. This ensures sub-agents inherit the same instruction and -// skill context as their parent without independently re-fetching -// from the agent. +// messages from the parent chat. This ensures sub-agents inherit the +// same instruction and skill context as their parent without +// independently re-fetching from the agent. func copyParentContextMessages( ctx context.Context, logger slog.Logger, store database.Store, parent database.Chat, - child database.Chat, ) ([]codersdk.ChatMessagePart, error) { parentMessages, err := store.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{ ChatID: parent.ID, @@ -1137,12 +1106,7 @@ func copyParentContextMessages( return nil, xerrors.Errorf("get parent messages: %w", err) } - var ( - copiedParts []codersdk.ChatMessagePart - copiedRole database.ChatMessageRole - copiedVisibility database.ChatMessageVisibility - copiedVersion int16 - ) + var copiedParts []codersdk.ChatMessagePart for _, msg := range parentMessages { if !msg.Content.Valid { continue @@ -1161,11 +1125,6 @@ func copyParentContextMessages( if len(messageContextParts) == 0 { continue } - if copiedParts == nil { - copiedRole = msg.Role - copiedVisibility = msg.Visibility - copiedVersion = msg.ContentVersion - } copiedParts = append(copiedParts, messageContextParts...) 
} if len(copiedParts) == 0 { @@ -1173,69 +1132,10 @@ func copyParentContextMessages( } copiedParts = FilterContextPartsToLatestAgent(copiedParts) - filteredContent, err := chatprompt.MarshalParts(copiedParts) - if err != nil { - return nil, xerrors.Errorf("marshal filtered context parts: %w", err) - } - - msgParams := database.InsertChatMessagesParams{ //nolint:exhaustruct // Fields populated by append[User]ChatMessage. - ChatID: child.ID, - } - if copiedRole == database.ChatMessageRoleUser { - copiedAPIKeyID, _ := aibridge.DelegatedAPIKeyIDFromContext(ctx) - appendUserChatMessage(&msgParams, newUserChatMessage( - copiedAPIKeyID, - filteredContent, - copiedVisibility, - child.LastModelConfigID, - copiedVersion, - )) - } else { - appendChatMessage(&msgParams, newChatMessage( - copiedRole, - filteredContent, - copiedVisibility, - child.LastModelConfigID, - copiedVersion, - )) - } - if _, err := store.InsertChatMessages(ctx, msgParams); err != nil { - return nil, xerrors.Errorf("insert context message: %w", err) - } return copiedParts, nil } -func updateChildLastInjectedContext( - ctx context.Context, - logger slog.Logger, - store database.Store, - chatID uuid.UUID, - parts []codersdk.ChatMessagePart, -) error { - parts = FilterContextPartsToLatestAgent(parts) - param, err := BuildLastInjectedContext(parts) - if err != nil { - logger.Warn(ctx, "failed to marshal inherited injected context", - slog.F("chat_id", chatID), - slog.Error(err), - ) - return xerrors.Errorf("marshal inherited injected context: %w", err) - } - if _, err := store.UpdateChatLastInjectedContext(ctx, database.UpdateChatLastInjectedContextParams{ - ID: chatID, - LastInjectedContext: param, - }); err != nil { - logger.Warn(ctx, "failed to update inherited injected context", - slog.F("chat_id", chatID), - slog.Error(err), - ) - return xerrors.Errorf("update inherited injected context: %w", err) - } - - return nil -} - func (p *Server) sendSubagentMessage( ctx context.Context, parentChatID uuid.UUID, 
@@ -1317,7 +1217,7 @@ func (p *Server) awaitSubagentCompletion( ch := make(chan struct{}, 1) notifyCh := (<-chan struct{})(ch) cancel, subErr := p.pubsub.SubscribeWithErr( - coderdpubsub.ChatStreamNotifyChannel(targetChatID), + coderdpubsub.ChatStateUpdateChannel(targetChatID), func(_ context.Context, _ []byte, _ error) { // Non-blocking send so we never stall the // pubsub dispatch goroutine. diff --git a/coderd/x/chatd/subagent_internal_test.go b/coderd/x/chatd/subagent_internal_test.go index a776d95c4e..fd31300e1b 100644 --- a/coderd/x/chatd/subagent_internal_test.go +++ b/coderd/x/chatd/subagent_internal_test.go @@ -4,7 +4,8 @@ import ( "context" "database/sql" "encoding/json" - "errors" + "net/http" + "net/http/httptest" "sync" "testing" "time" @@ -14,6 +15,7 @@ import ( "github.com/sqlc-dev/pqtype" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "golang.org/x/xerrors" "cdr.dev/slog/v3" "cdr.dev/slog/v3/sloggers/slogtest" @@ -146,11 +148,11 @@ type subscribeFailingPubsub struct { } func (subscribeFailingPubsub) Subscribe(_ string, _ pubsub.Listener) (func(), error) { - return nil, errors.New("subscribe disabled") + return nil, xerrors.New("subscribe disabled") } func (subscribeFailingPubsub) SubscribeWithErr(_ string, _ pubsub.ListenerWithErr) (func(), error) { - return nil, errors.New("subscribe disabled") + return nil, xerrors.New("subscribe disabled") } type subagentTestLogSink struct { @@ -802,7 +804,7 @@ func TestCreateChildSubagentChatInheritsWorkspaceBinding(t *testing.T) { parentChat, err := db.GetChatByID(ctx, parent.ID) require.NoError(t, err) - child, err := server.createChildSubagentChat(ctx, parentChat, "inspect bindings", "") + child, err := server.createChildSubagentChatWithOptions(ctx, parentChat, "inspect bindings", "", childSubagentChatOptions{}) require.NoError(t, err) childChat, err := db.GetChatByID(ctx, child.ID) @@ -965,7 +967,7 @@ func TestCreateChildSubagentChatCopiesPlanMode(t *testing.T) { 
require.NoError(t, err) require.Equal(t, planMode, parentChat.PlanMode) - child, err := server.createChildSubagentChat(ctx, parentChat, "inspect bindings", "") + child, err := server.createChildSubagentChatWithOptions(ctx, parentChat, "inspect bindings", "", childSubagentChatOptions{}) require.NoError(t, err) childChat, err := db.GetChatByID(ctx, child.ID) @@ -2825,11 +2827,12 @@ func TestCreateChildSubagentChat_InheritsMCPServerIDs(t *testing.T) { "parent chat must have the MCP server IDs we set") // Spawn a child subagent chat. - child, err := server.createChildSubagentChat( + child, err := server.createChildSubagentChatWithOptions( ctx, parentChat, "do some work", "child-task", + childSubagentChatOptions{}, ) require.NoError(t, err) @@ -2863,11 +2866,12 @@ func TestCreateChildSubagentChat_NoMCPServersStaysEmpty(t *testing.T) { require.NoError(t, err) // Spawn a child. - child, err := server.createChildSubagentChat( + child, err := server.createChildSubagentChatWithOptions( ctx, parentChat, "do some work", "child-no-mcp", + childSubagentChatOptions{}, ) require.NoError(t, err) @@ -3359,23 +3363,6 @@ func TestAwaitSubagentCompletion(t *testing.T) { parent, child := createParentChildChats(ctx, t, server, user, org, model) - // signalWake from CreateChat triggers background processing. Wait - // for those runs to finish, then reset both chats so this test owns - // the state transition observed by the poll loop. 
- testutil.Eventually(ctx, t, func(ctx context.Context) bool { - parentChat, err := db.GetChatByID(ctx, parent.ID) - if err != nil { - return false - } - childChat, err := db.GetChatByID(ctx, child.ID) - if err != nil { - return false - } - return parentChat.Status != database.ChatStatusPending && - parentChat.Status != database.ChatStatusRunning && - childChat.Status != database.ChatStatusPending && - childChat.Status != database.ChatStatusRunning - }, testutil.IntervalFast) setChatStatus(ctx, t, db, parent.ID, database.ChatStatusRunning, "") setChatStatus(ctx, t, db, child.ID, database.ChatStatusRunning, "") @@ -3463,7 +3450,7 @@ func TestAwaitSubagentCompletion(t *testing.T) { probeCh := make(chan struct{}, 1) cancelProbe, err := ps.SubscribeWithErr( - coderdpubsub.ChatStreamNotifyChannel(child.ID), + coderdpubsub.ChatStateUpdateChannel(child.ID), func(_ context.Context, _ []byte, _ error) { select { case probeCh <- struct{}{}: @@ -3493,7 +3480,7 @@ func TestAwaitSubagentCompletion(t *testing.T) { assert.Equal(c, "pubsub result", report) }, testutil.WaitMedium, testutil.IntervalFast) require.NoError(t, ps.Publish( - coderdpubsub.ChatStreamNotifyChannel(child.ID), + coderdpubsub.ChatStateUpdateChannel(child.ID), []byte("done"), )) testutil.RequireReceive(ctx, t, probeCh) @@ -3566,11 +3553,44 @@ func TestAwaitSubagentCompletion(t *testing.T) { t.Run("ContextCanceled", func(t *testing.T) { t.Parallel() + + providerCalled := make(chan struct{}, 1) + providerReleased := make(chan struct{}) + providerServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + select { + case providerCalled <- struct{}{}: + default: + } + + select { + case <-r.Context().Done(): + case <-providerReleased: + } + })) + t.Cleanup(func() { + close(providerReleased) + providerServer.Close() + }) + + db, ps := dbtestutil.NewDB(t) + server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{}) ctx := chatdTestContext(t) + user, org, _ := 
seedInternalChatDeps(t, db) + provider := dbgen.ChatProvider(t, db, database.ChatProvider{ + Provider: "openai", + DisplayName: "OpenAI", + BaseUrl: providerServer.URL, + }) + model := dbgen.ChatModelConfig(t, db, database.ChatModelConfig{ + Provider: "openai", + Model: "gpt-4o-mini", + AIProviderID: uuid.NullUUID{UUID: provider.ID, Valid: true}, + }) parent, child := createParentChildChats(ctx, t, server, user, org, model) - setChatStatus(ctx, t, db, child.ID, database.ChatStatusRunning, "") + testutil.RequireReceive(ctx, t, providerCalled) + // Use a short-lived context instead of goroutine + sleep. shortCtx, cancel := context.WithTimeout(ctx, testutil.IntervalMedium) defer cancel() diff --git a/coderd/x/chatd/subscribe_out_of_order_internal_test.go b/coderd/x/chatd/subscribe_out_of_order_internal_test.go deleted file mode 100644 index a6bb083785..0000000000 --- a/coderd/x/chatd/subscribe_out_of_order_internal_test.go +++ /dev/null @@ -1,212 +0,0 @@ -package chatd - -import ( - "testing" - "time" - - "github.com/google/uuid" - "github.com/stretchr/testify/require" - "go.uber.org/mock/gomock" - - "github.com/coder/coder/v2/coderd/database" - "github.com/coder/coder/v2/coderd/database/dbmock" - coderdpubsub "github.com/coder/coder/v2/coderd/pubsub" - "github.com/coder/coder/v2/codersdk" - "github.com/coder/coder/v2/testutil" -) - -// TestSubscribeDeliversOutOfOrderDurableMessage tests that a -// late-arriving lower-ID durable message is delivered when a -// higher-ID was already cached and sent. 
-func TestSubscribeDeliversOutOfOrderDurableMessage(t *testing.T) { - t.Parallel() - - ctx := testutil.Context(t, testutil.WaitMedium) - - ctrl := gomock.NewController(t) - db := dbmock.NewMockStore(ctrl) - - chatID := uuid.New() - chat := database.Chat{ID: chatID, Status: database.ChatStatusRequiresAction} - initialUser := database.ChatMessage{ID: 3, ChatID: chatID, Role: database.ChatMessageRoleUser} - initialAssistant := database.ChatMessage{ID: 4, ChatID: chatID, Role: database.ChatMessageRoleAssistant} - - gomock.InOrder( - db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil), - db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil), - db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{ - ChatID: chatID, - AfterID: 0, - }).Return([]database.ChatMessage{initialUser, initialAssistant}, nil), - db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil), - ) - // Notify-driven catch-up queries return nothing so the test only - // exercises the cache delivery path. - db.EXPECT().GetChatMessagesByChatID(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() - - server := newSubscribeTestServer(t, db) - - toolResult := codersdk.ChatMessage{ID: 5, ChatID: chatID, Role: codersdk.ChatMessageRoleTool} - resumed := codersdk.ChatMessage{ID: 7, ChatID: chatID, Role: codersdk.ChatMessageRoleAssistant} - promoted := codersdk.ChatMessage{ID: 6, ChatID: chatID, Role: codersdk.ChatMessageRoleUser} - - server.cacheDurableMessage(chatID, codersdk.ChatStreamEvent{ - Type: codersdk.ChatStreamEventTypeMessage, ChatID: chatID, - Message: &codersdk.ChatMessage{ID: 4, ChatID: chatID, Role: codersdk.ChatMessageRoleAssistant}, - }) - - _, events, cancel, ok := server.Subscribe(ctx, chatID, nil, 0) - require.True(t, ok) - defer cancel() - - // Cache id=5 and id=7, but not id=6, then emit the notify for - // id=5. The merge goroutine drains [5, 7] from the cache. 
- server.cacheDurableMessage(chatID, codersdk.ChatStreamEvent{ - Type: codersdk.ChatStreamEventTypeMessage, ChatID: chatID, Message: &toolResult, - }) - server.cacheDurableMessage(chatID, codersdk.ChatStreamEvent{ - Type: codersdk.ChatStreamEventTypeMessage, ChatID: chatID, Message: &resumed, - }) - server.publishChatStreamNotify(chatID, coderdpubsub.ChatStreamNotifyMessage{AfterMessageID: 4}) - - first := testutil.RequireReceive(ctx, t, events) - require.Equal(t, codersdk.ChatStreamEventTypeMessage, first.Type) - require.NotNil(t, first.Message) - require.Equal(t, int64(5), first.Message.ID) - second := testutil.RequireReceive(ctx, t, events) - require.Equal(t, codersdk.ChatStreamEventTypeMessage, second.Type) - require.NotNil(t, second.Message) - require.Equal(t, int64(7), second.Message.ID) - - // Cache id=6 after the merge goroutine has already advanced - // lastMessageID to 7, then emit the notify for id=6. - server.cacheDurableMessage(chatID, codersdk.ChatStreamEvent{ - Type: codersdk.ChatStreamEventTypeMessage, ChatID: chatID, Message: &promoted, - }) - server.publishChatStreamNotify(chatID, coderdpubsub.ChatStreamNotifyMessage{AfterMessageID: 5}) - - third := testutil.RequireReceive(ctx, t, events) - require.Equal(t, codersdk.ChatStreamEventTypeMessage, third.Type) - require.NotNil(t, third.Message) - require.Equal(t, int64(6), third.Message.ID) - - requireNoStreamEvent(t, events, 200*time.Millisecond) -} - -// TestSubscribeRespectsAfterMessageIDOnLateNotify tests that -// lookupAfter never drops below afterMessageID, preventing -// re-emission of messages the client already has via REST. 
-func TestSubscribeRespectsAfterMessageIDOnLateNotify(t *testing.T) { - t.Parallel() - - ctx := testutil.Context(t, testutil.WaitMedium) - - ctrl := gomock.NewController(t) - db := dbmock.NewMockStore(ctrl) - - chatID := uuid.New() - chat := database.Chat{ID: chatID, Status: database.ChatStatusRunning} - - gomock.InOrder( - db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil), - db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil), - db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{ - ChatID: chatID, - AfterID: 100, - }).Return(nil, nil), - db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil), - ) - db.EXPECT().GetChatMessagesByChatID(gomock.Any(), gomock.Any()).Return(nil, nil).AnyTimes() - - server := newSubscribeTestServer(t, db) - - // Seed the cache with messages the client claims to already have - // (id<=100) plus one new message (id=101). - for _, id := range []int64{96, 97, 98, 99, 100, 101} { - msg := &codersdk.ChatMessage{ID: id, ChatID: chatID, Role: codersdk.ChatMessageRoleAssistant} - server.cacheDurableMessage(chatID, codersdk.ChatStreamEvent{ - Type: codersdk.ChatStreamEventTypeMessage, ChatID: chatID, Message: msg, - }) - } - - _, events, cancel, ok := server.Subscribe(ctx, chatID, nil, 100) - require.True(t, ok) - defer cancel() - - // A stale notify with AfterMessageID=95 would naively pull - // id=96..101 back from the cache; only id=101 should reach the - // live stream because the client already has 96-100. 
- server.publishChatStreamNotify(chatID, coderdpubsub.ChatStreamNotifyMessage{AfterMessageID: 95}) - - ev := testutil.RequireReceive(ctx, t, events) - require.Equal(t, codersdk.ChatStreamEventTypeMessage, ev.Type) - require.NotNil(t, ev.Message) - require.Equal(t, int64(101), ev.Message.ID, - "messages at or below afterMessageID must not be re-emitted") - - requireNoStreamEvent(t, events, 200*time.Millisecond) -} - -// TestSubscribeRunsDBFallbackWhenCacheDeliversUnrelatedMessage tests -// that the DB fallback runs even when the cache delivers, so -// cross-replica messages are not dropped. -func TestSubscribeRunsDBFallbackWhenCacheDeliversUnrelatedMessage(t *testing.T) { - t.Parallel() - - ctx := testutil.Context(t, testutil.WaitMedium) - - ctrl := gomock.NewController(t) - db := dbmock.NewMockStore(ctrl) - - chatID := uuid.New() - chat := database.Chat{ID: chatID, Status: database.ChatStatusRunning} - crossReplica := database.ChatMessage{ID: 6, ChatID: chatID, Role: database.ChatMessageRoleUser} - - gomock.InOrder( - db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil), - db.EXPECT().GetChatByID(gomock.Any(), chatID).Return(chat, nil), - // Snapshot: nothing above the client's afterMessageID=5 yet. - db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{ - ChatID: chatID, - AfterID: 5, - }).Return(nil, nil), - db.EXPECT().GetChatQueuedMessages(gomock.Any(), chatID).Return(nil, nil), - // Notify catchup: the cross-replica message lives only in the - // DB on this replica. - db.EXPECT().GetChatMessagesByChatID(gomock.Any(), database.GetChatMessagesByChatIDParams{ - ChatID: chatID, - AfterID: 5, - }).Return([]database.ChatMessage{crossReplica}, nil), - ) - - server := newSubscribeTestServer(t, db) - - // Cache a locally-published higher-ID message so the cache pass - // has something to deliver without covering id=6. 
- localOnly := codersdk.ChatMessage{ID: 8, ChatID: chatID, Role: codersdk.ChatMessageRoleAssistant} - server.cacheDurableMessage(chatID, codersdk.ChatStreamEvent{ - Type: codersdk.ChatStreamEventTypeMessage, ChatID: chatID, Message: &localOnly, - }) - - _, events, cancel, ok := server.Subscribe(ctx, chatID, nil, 5) - require.True(t, ok) - defer cancel() - - server.publishChatStreamNotify(chatID, coderdpubsub.ChatStreamNotifyMessage{AfterMessageID: 5}) - - // The cache pass delivers id=8; the DB pass must still run and - // deliver id=6. Order between them is set by cache iteration vs - // DB query, so accept either ordering. - first := testutil.RequireReceive(ctx, t, events) - require.Equal(t, codersdk.ChatStreamEventTypeMessage, first.Type) - require.NotNil(t, first.Message) - second := testutil.RequireReceive(ctx, t, events) - require.Equal(t, codersdk.ChatStreamEventTypeMessage, second.Type) - require.NotNil(t, second.Message) - - got := map[int64]bool{first.Message.ID: true, second.Message.ID: true} - require.True(t, got[6], "cross-replica DB message id=6 must be delivered") - require.True(t, got[8], "locally-cached message id=8 must be delivered") - - requireNoStreamEvent(t, events, 200*time.Millisecond) -} diff --git a/coderd/x/chatd/tasks.go b/coderd/x/chatd/tasks.go new file mode 100644 index 0000000000..7563bd5dec --- /dev/null +++ b/coderd/x/chatd/tasks.go @@ -0,0 +1,645 @@ +package chatd + +import ( + "context" + "database/sql" + "encoding/json" + "errors" + "strings" + "time" + + "cdr.dev/slog/v3" + "github.com/google/uuid" + "golang.org/x/xerrors" + + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/db2sdk" + coderdpubsub "github.com/coder/coder/v2/coderd/pubsub" + "github.com/coder/coder/v2/coderd/x/chatd/chatprompt" + "github.com/coder/coder/v2/coderd/x/chatd/chatstate" + "github.com/coder/coder/v2/coderd/x/chatd/messagepartbuffer" + "github.com/coder/coder/v2/codersdk" + "github.com/coder/quartz" +) + +const 
postCommitWatchPublishTimeout = 10 * time.Second + +var ( + errTaskExpectedExit = xerrors.New("chatworker task expected exit") + errTaskRetryable = xerrors.New("chatworker task retryable error") +) + +type taskRetryableError struct { + err error +} + +func (e taskRetryableError) Error() string { + if e.err == nil { + return errTaskRetryable.Error() + } + return e.err.Error() +} + +func (e taskRetryableError) Unwrap() error { + if e.err == nil { + return errTaskRetryable + } + return errors.Join(errTaskRetryable, e.err) +} + +type retryWrapperOptions struct { + clock quartz.Clock + initialDelay time.Duration + maxDelay time.Duration +} + +func runTaskWithRetry( + ctx context.Context, + opts retryWrapperOptions, + kind taskKind, + fn func(context.Context) error, +) error { + if opts.clock == nil { + opts.clock = quartz.NewReal() + } + if opts.initialDelay <= 0 { + opts.initialDelay = defaultTaskRetryInitialBackoff + } + if opts.maxDelay <= 0 { + opts.maxDelay = defaultTaskRetryMaxBackoff + } + if opts.maxDelay < opts.initialDelay { + opts.maxDelay = opts.initialDelay + } + + delay := opts.initialDelay + for { + err := executeTaskSafely(ctx, fn) + switch { + case err == nil: + return nil + case errors.Is(err, errTaskExpectedExit): + return nil + case ctx.Err() != nil: + return nil + } + + timer := opts.clock.NewTimer(delay, "chatworker", "task-retry-"+string(kind)) + select { + case <-timer.C: + case <-ctx.Done(): + timer.Stop() + return nil + } + timer.Stop() + if delay < opts.maxDelay { + delay *= 2 + if delay > opts.maxDelay { + delay = opts.maxDelay + } + } + } +} + +func executeTaskSafely(ctx context.Context, fn func(context.Context) error) (err error) { + defer func() { + if recovered := recover(); recovered != nil { + err = xerrors.Errorf("chatworker task panic: %v", recovered) + } + }() + return fn(ctx) +} + +type interruptionOutcome struct { + Chat database.Chat + Kind runnerActionKind + WatchEventKind codersdk.ChatWatchEventKind +} + +type taskStarter struct 
// StartInterrupt finishes an in-progress interruption for the chat named by
// input. It works in three phases:
//
//  1. Under the chat machine's read lock, load the chat and verify the task
//     fence (ownership, ChatStatusInterrupting, matching history version).
//  2. Close the message part buffer episode keyed by chat ID, history version
//     and generation attempt, read the buffered parts, and convert them into
//     partial messages. A missing episode is treated as "no parts".
//  3. Inside a machine.Update transaction, re-verify the fence, append
//     cancellation results for unresolved local tool calls found in committed
//     history, call FinishInterruption, and reload the chat row.
//
// After a successful update it publishes a status-change watch event, routes a
// state hint, and runs the after-interruption hook.
//
// Errors: errTaskExpectedExit when the chat row is missing, the fence no
// longer matches, or a buffer operation fails while ctx is already done;
// taskRetryableError for other buffer failures; otherwise whatever
// normalizeTaskInfrastructureError / normalizeTaskTransitionError produce.
func (s *taskStarter) StartInterrupt(ctx context.Context, input chatWorkerTaskStartInput) error {
	machine := chatstate.NewChatMachine(s.opts.Store, s.opts.Pubsub, input.ChatID, chatstate.Options{})
	var chat database.Chat
	// Phase 1: snapshot the chat under the read lock and make sure this task
	// still owns an interrupting chat at the expected history version.
	err := machine.ReadLock(ctx, func(store database.Store) error {
		locked, err := store.GetChatByID(ctx, input.ChatID)
		if errors.Is(err, sql.ErrNoRows) {
			return errTaskExpectedExit
		}
		if err != nil {
			return xerrors.Errorf("load locked chat: %w", err)
		}
		if err := verifyTaskFence(locked, input, database.ChatStatusInterrupting, taskFenceOptions{requireHistory: true}); err != nil {
			return err
		}
		chat = locked
		return nil
	})
	if err != nil {
		return normalizeTaskInfrastructureError(err, "lock chat for interrupt")
	}

	// Phase 2: the episode key takes the history version from the task input
	// and the generation attempt from the row read under the lock.
	key := messagepartbuffer.Key{
		ChatID:            input.ChatID,
		HistoryVersion:    input.HistoryVersion,
		GenerationAttempt: chat.GenerationAttempt,
	}
	if err := s.opts.MessagePartBuffer.CloseEpisode(key); err != nil {
		// A failure while ctx is done is reported as an expected exit rather
		// than a retryable error.
		if ctx.Err() != nil {
			return errTaskExpectedExit
		}
		return taskRetryableError{err: xerrors.Errorf("close message part episode: %w", err)}
	}
	parts, err := s.opts.MessagePartBuffer.GetParts(key)
	if errors.Is(err, messagepartbuffer.ErrEpisodeNotFound) {
		// No episode means nothing was buffered; continue with no partials.
		parts = nil
		err = nil
	}
	if err != nil {
		if ctx.Err() != nil {
			return errTaskExpectedExit
		}
		return taskRetryableError{err: xerrors.Errorf("get message part episode: %w", err)}
	}
	partialMessages, err := bufferedPartsToPartialMessages(bufferedPartsToPartialMessagesInput{
		parts:          parts,
		modelConfigID:  chat.LastModelConfigID,
		contentVersion: chatprompt.CurrentContentVersion,
		logger:         s.opts.Logger,
		interruptedAt:  s.opts.Clock.Now("chatworker", "interrupt"),
	})
	if err != nil {
		// Returned as a plain error (not taskRetryableError).
		return xerrors.Errorf("convert buffered parts: %w", err)
	}

	// Phase 3: commit the interruption. committed stays at its zero value
	// unless the callback reaches the final reload.
	var committed database.Chat
	err = machine.Update(ctx, func(tx *chatstate.Tx) error {
		locked, err := tx.Store().GetChatByID(ctx, input.ChatID)
		if errors.Is(err, sql.ErrNoRows) {
			return errTaskExpectedExit
		}
		if err != nil {
			return xerrors.Errorf("load chat: %w", err)
		}
		// Re-check the fence inside the transaction; the chat may have
		// changed since the read lock was released.
		if err := verifyTaskFence(locked, input, database.ChatStatusInterrupting, taskFenceOptions{requireHistory: true}); err != nil {
			return err
		}
		messages := partialMessages
		committedCancels, err := committedPendingLocalToolCancellationMessages(ctx, tx.Store(), locked, s.opts.Clock.Now("chatworker", "interrupt"), s.opts.Logger)
		if err != nil {
			return err
		}
		if len(committedCancels) > 0 {
			// Build a fresh slice so partialMessages itself is never appended to.
			messages = append(append([]chatstate.Message{}, partialMessages...), committedCancels...)
		}
		if _, err := tx.FinishInterruption(chatstate.FinishInterruptionInput{PartialMessages: messages}); err != nil {
			return err
		}
		committed, err = tx.Store().GetChatByID(ctx, input.ChatID)
		if err != nil {
			return xerrors.Errorf("load committed chat: %w", err)
		}
		return nil
	})
	if err != nil {
		// Update reported an error, but if the stored row still matches what
		// the callback observed after FinishInterruption, continue with the
		// post-commit publish instead of failing.
		// NOTE(review): this path publishes and routes but does not call
		// runAfterInterruptionOutcome, unlike the success path below —
		// confirm that skipping the hook here is intentional.
		if current, ok := s.committedStateAfterUpdateError(ctx, committed); ok {
			return s.publishWatchAndRoute(ctx, current, codersdk.ChatWatchEventKindStatusChange)
		}
		return normalizeTaskTransitionError(err, "finish interruption")
	}
	if err := s.publishWatchAndRoute(ctx, committed, codersdk.ChatWatchEventKindStatusChange); err != nil {
		return err
	}
	return s.runAfterInterruptionOutcome(ctx, interruptionOutcome{
		Chat:           committed,
		Kind:           runnerActionKindFinishInterruption,
		WatchEventKind: codersdk.ChatWatchEventKindStatusChange,
	})
}
s.cancelRequiresAction(ctx, machine, input, decision.reason) + } + if !decision.waitUntil.Valid { + return errTaskExpectedExit + } + if err := s.waitUntil(ctx, decision.waitUntil.Time); err != nil { + return err + } + } +} + +type requiresActionTimeoutDecision struct { + cancel bool + reason string + waitUntil sql.NullTime +} + +func decideRequiresActionTimeout( + ctx context.Context, + machine *chatstate.ChatMachine, + input chatWorkerTaskStartInput, +) (requiresActionTimeoutDecision, error) { + var decision requiresActionTimeoutDecision + err := machine.ReadLock(ctx, func(store database.Store) error { + locked, err := store.GetChatByID(ctx, input.ChatID) + if errors.Is(err, sql.ErrNoRows) { + return errTaskExpectedExit + } + if err != nil { + return xerrors.Errorf("load locked chat: %w", err) + } + if err := verifyTaskFence(locked, input, database.ChatStatusRequiresAction, taskFenceOptions{requireHistory: true}); err != nil { + return err + } + if !locked.RequiresActionDeadlineAt.Valid { + decision.cancel = true + decision.reason = "Tool execution canceled because the action deadline was missing" + return nil + } + now, err := store.GetDatabaseNow(ctx) + if err != nil { + return xerrors.Errorf("get database time: %w", err) + } + if now.Before(locked.RequiresActionDeadlineAt.Time) { + decision.waitUntil = locked.RequiresActionDeadlineAt + return nil + } + decision.cancel = true + decision.reason = "Tool execution timed out" + return nil + }) + if err != nil { + return requiresActionTimeoutDecision{}, normalizeTaskInfrastructureError(err, "lock chat for requires action timeout") + } + return decision, nil +} + +func (s *taskStarter) waitUntil(ctx context.Context, deadline time.Time) error { + now := s.opts.Clock.Now("chatworker", "requires-action-timeout") + if !now.Before(deadline) { + return nil + } + timer := s.opts.Clock.NewTimer(deadline.Sub(now), "chatworker", "requires-action-timeout") + defer timer.Stop() + select { + case <-timer.C: + return nil + case 
<-ctx.Done(): + return errTaskExpectedExit + } +} + +func (s *taskStarter) cancelRequiresAction( + ctx context.Context, + machine *chatstate.ChatMachine, + input chatWorkerTaskStartInput, + reason string, +) error { + var committed database.Chat + err := machine.Update(ctx, func(tx *chatstate.Tx) error { + locked, err := tx.Store().GetChatByID(ctx, input.ChatID) + if errors.Is(err, sql.ErrNoRows) { + return errTaskExpectedExit + } + if err != nil { + return xerrors.Errorf("load chat: %w", err) + } + if err := verifyTaskFence(locked, input, database.ChatStatusRequiresAction, taskFenceOptions{requireHistory: true}); err != nil { + return err + } + if locked.RequiresActionDeadlineAt.Valid { + now, err := tx.Store().GetDatabaseNow(ctx) + if err != nil { + return xerrors.Errorf("get database time: %w", err) + } + if now.Before(locked.RequiresActionDeadlineAt.Time) { + return errTaskExpectedExit + } + } + if _, err := tx.CancelRequiresAction(chatstate.CancelRequiresActionInput{Reason: reason}); err != nil { + return err + } + committed, err = tx.Store().GetChatByID(ctx, input.ChatID) + if err != nil { + return xerrors.Errorf("load committed chat: %w", err) + } + return nil + }) + if err != nil { + if current, ok := s.committedStateAfterUpdateError(ctx, committed); ok { + return s.publishWatchAndRoute(ctx, current, codersdk.ChatWatchEventKindStatusChange) + } + return normalizeTaskTransitionError(err, "cancel requires action") + } + return s.publishWatchAndRoute(ctx, committed, codersdk.ChatWatchEventKindStatusChange) +} + +func (s *taskStarter) StartAbandon(ctx context.Context, input chatWorkerTaskStartInput) error { + machine := chatstate.NewChatMachine(s.opts.Store, s.opts.Pubsub, input.ChatID, chatstate.Options{}) + mismatch := false + err := machine.Update(ctx, func(tx *chatstate.Tx) error { + locked, err := tx.Store().GetChatByID(ctx, input.ChatID) + if errors.Is(err, sql.ErrNoRows) { + mismatch = true + return errTaskExpectedExit + } + if err != nil { + return 
xerrors.Errorf("load chat: %w", err) + } + if !ownedByTask(locked, input) { + mismatch = true + return errTaskExpectedExit + } + if err := verifyTaskFence(locked, input, input.Status, taskFenceOptions{requireHistory: true, allowArchived: true}); err != nil { + return err + } + if _, err := tx.Abandon(chatstate.AbandonInput{}); err != nil { + return err + } + return nil + }) + if err != nil { + if errors.Is(err, errTaskExpectedExit) && mismatch { + s.requestCleanup(ctx, runnerKey{ChatID: input.ChatID, RunnerID: input.RunnerID}) + return nil + } + return normalizeTaskTransitionError(err, "abandon chat") + } + s.requestCleanup(ctx, runnerKey{ChatID: input.ChatID, RunnerID: input.RunnerID}) + return nil +} + +func (s *taskStarter) committedStateAfterUpdateError(ctx context.Context, committed database.Chat) (database.Chat, bool) { + if committed.ID == uuid.Nil { + return database.Chat{}, false + } + current, err := s.opts.Store.GetChatByID(ctx, committed.ID) + if err != nil { + return database.Chat{}, false + } + if current.SnapshotVersion != committed.SnapshotVersion || + current.HistoryVersion != committed.HistoryVersion || + current.QueueVersion != committed.QueueVersion || + current.GenerationAttempt != committed.GenerationAttempt || + current.Status != committed.Status || + current.Archived != committed.Archived || + current.WorkerID != committed.WorkerID || + current.RunnerID != committed.RunnerID { + return database.Chat{}, false + } + return current, true +} + +func (s *taskStarter) publishWatchAndRoute( + ctx context.Context, + chat database.Chat, + kind codersdk.ChatWatchEventKind, +) error { + watchCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), postCommitWatchPublishTimeout) + defer cancel() + if err := s.publishWatchWithRetry(watchCtx, chat, kind); err != nil { + return err + } + s.routeStateHint(ctx, stateUpdateFromChat(chat)) + return nil +} + +func (s *taskStarter) publishWatchWithRetry( + ctx context.Context, + chat database.Chat, + kind 
codersdk.ChatWatchEventKind, +) error { + delay := s.opts.TaskRetryInitialBackoff + for { + if err := publishChatWatchEvent(s.opts.Pubsub, chat, kind); err == nil { + return nil + } else if ctx.Err() != nil { + return errTaskExpectedExit + } + timer := s.opts.Clock.NewTimer(delay, "chatworker", "watch-publish-retry") + select { + case <-timer.C: + case <-ctx.Done(): + timer.Stop() + return errTaskExpectedExit + } + timer.Stop() + if delay < s.opts.TaskRetryMaxBackoff { + delay *= 2 + if delay > s.opts.TaskRetryMaxBackoff { + delay = s.opts.TaskRetryMaxBackoff + } + } + } +} + +func publishChatWatchEvent(pubsub chatWorkerPubsub, chat database.Chat, kind codersdk.ChatWatchEventKind) error { + event := codersdk.ChatWatchEvent{ + Kind: kind, + Chat: db2sdk.Chat(chat, nil, nil), + } + payload, err := json.Marshal(event) + if err != nil { + return xerrors.Errorf("marshal chat watch event: %w", err) + } + if err := pubsub.Publish(coderdpubsub.ChatWatchEventChannel(chat.OwnerID), payload); err != nil { + return xerrors.Errorf("publish chat watch event: %w", err) + } + return nil +} + +type taskFenceOptions struct { + requireHistory bool + allowArchived bool +} + +func verifyTaskFence( + chat database.Chat, + input chatWorkerTaskStartInput, + status database.ChatStatus, + opts taskFenceOptions, +) error { + if !ownedByTask(chat, input) { + return errTaskExpectedExit + } + if chat.Status != status { + return errTaskExpectedExit + } + if !opts.allowArchived && chat.Archived { + return errTaskExpectedExit + } + if opts.requireHistory && chat.HistoryVersion != input.HistoryVersion { + return errTaskExpectedExit + } + return nil +} + +func ownedByTask(chat database.Chat, input chatWorkerTaskStartInput) bool { + return chat.WorkerID.Valid && chat.WorkerID.UUID == input.WorkerID && + chat.RunnerID.Valid && chat.RunnerID.UUID == input.RunnerID +} + +func normalizeTaskInfrastructureError(err error, action string) error { + if err == nil { + return nil + } + if errors.Is(err, 
errTaskExpectedExit) || errors.Is(err, chatstate.ErrChatNotFound) || errors.Is(err, sql.ErrNoRows) || errors.Is(err, context.Canceled) { + return errTaskExpectedExit + } + return taskRetryableError{err: xerrors.Errorf("%s: %w", action, err)} +} + +func normalizeTaskTransitionError(err error, action string) error { + if err == nil { + return nil + } + if errors.Is(err, errTaskExpectedExit) || errors.Is(err, chatstate.ErrChatNotFound) || errors.Is(err, sql.ErrNoRows) || errors.Is(err, context.Canceled) { + return errTaskExpectedExit + } + if errors.Is(err, chatstate.ErrTransitionNotAllowed) || errors.Is(err, chatstate.ErrInvalidState) { + return xerrors.Errorf("%s: %w", action, err) + } + return taskRetryableError{err: xerrors.Errorf("%s: %w", action, err)} +} + +func dynamicToolNamesFromChat(chat database.Chat) map[string]bool { + if !chat.DynamicTools.Valid || len(chat.DynamicTools.RawMessage) == 0 { + return nil + } + var tools []codersdk.DynamicTool + if err := json.Unmarshal(chat.DynamicTools.RawMessage, &tools); err != nil { + return nil + } + names := make(map[string]bool, len(tools)) + for _, tool := range tools { + name := strings.TrimSpace(tool.Name) + if name != "" { + names[name] = true + } + } + return names +} + +func committedPendingLocalToolCancellationMessages( + ctx context.Context, + store database.Store, + chat database.Chat, + interruptedAt time.Time, + logger slog.Logger, +) ([]chatstate.Message, error) { + messages, err := store.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{ + ChatID: chat.ID, + AfterID: 0, + }) + if err != nil { + return nil, xerrors.Errorf("load committed messages for interruption: %w", err) + } + localCalls, _, err := unresolvedToolCallsFromHistory(messages, dynamicToolNamesFromChat(chat)) + if err != nil { + return nil, err + } + if len(localCalls) == 0 { + return nil, nil + } + result := make([]chatstate.Message, 0, len(localCalls)) + for _, call := range localCalls { + payload, err := 
json.Marshal(map[string]string{"error": interruptedToolResultErrorMessage}) + if err != nil { + return nil, xerrors.Errorf("marshal interrupted tool result: %w", err) + } + part := codersdk.ChatMessageToolResult(call.ToolCallID, call.ToolName, payload, true, false) + if !interruptedAt.IsZero() { + part.CreatedAt = &interruptedAt + } + content, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{part}) + if err != nil { + return nil, xerrors.Errorf("marshal interrupted tool result part: %w", err) + } + result = append(result, chatstate.Message{ + Role: database.ChatMessageRoleTool, + Content: content, + Visibility: database.ChatMessageVisibilityBoth, + ModelConfigID: uuid.NullUUID{UUID: chat.LastModelConfigID, Valid: chat.LastModelConfigID != uuid.Nil}, + ContentVersion: chatprompt.CurrentContentVersion, + }) + } + _ = logger + return result, nil +} diff --git a/coderd/x/chatd/tasks_test.go b/coderd/x/chatd/tasks_test.go new file mode 100644 index 0000000000..626854e8af --- /dev/null +++ b/coderd/x/chatd/tasks_test.go @@ -0,0 +1,1125 @@ +//nolint:testpackage // These tests exercise package-private task seams. 
+package chatd + +import ( + "context" + "database/sql" + "encoding/json" + "sync" + "testing" + "time" + + "github.com/google/uuid" + "github.com/sqlc-dev/pqtype" + "github.com/stretchr/testify/require" + "golang.org/x/xerrors" + + "cdr.dev/slog/v3" + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbgen" + "github.com/coder/coder/v2/coderd/database/dbtestutil" + dbpubsub "github.com/coder/coder/v2/coderd/database/pubsub" + coderdpubsub "github.com/coder/coder/v2/coderd/pubsub" + "github.com/coder/coder/v2/coderd/x/chatd/chaterror" + "github.com/coder/coder/v2/coderd/x/chatd/chatprompt" + "github.com/coder/coder/v2/coderd/x/chatd/chatretry" + "github.com/coder/coder/v2/coderd/x/chatd/chatstate" + "github.com/coder/coder/v2/coderd/x/chatd/messagepartbuffer" + "github.com/coder/coder/v2/codersdk" + "github.com/coder/coder/v2/testutil" + "github.com/coder/quartz" +) + +func TestRetryWrapper_ExpectedExitsDoNotRetry(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitShort) + calls := 0 + err := runTaskWithRetry(ctx, retryWrapperOptions{ + clock: quartz.NewMock(t), + initialDelay: time.Second, + maxDelay: time.Second, + }, taskKindInterrupt, func(context.Context) error { + calls++ + return errTaskExpectedExit + }) + require.NoError(t, err) + require.Equal(t, 1, calls) +} + +func TestRetryWrapper_UnexpectedErrorsRetry(t *testing.T) { + t.Parallel() + + clock := quartz.NewMock(t) + trap := clock.Trap().NewTimer("chatworker", "task-retry-requires_action_timeout") + defer trap.Close() + ctx := testutil.Context(t, testutil.WaitLong) + calls := 0 + done := make(chan error, 1) + go func() { + done <- runTaskWithRetry(ctx, retryWrapperOptions{ + clock: clock, + initialDelay: time.Minute, + maxDelay: time.Minute, + }, taskKindRequiresActionTimeout, func(context.Context) error { + calls++ + if calls == 1 { + return xerrors.New("database unavailable") + } + return nil + }) + }() + + 
trap.MustWait(ctx).MustRelease(ctx) + clock.Advance(time.Minute).MustWait(ctx) + require.NoError(t, <-done) + require.Equal(t, 2, calls) +} + +func TestRetryWrapper_PanicsRetry(t *testing.T) { + t.Parallel() + + clock := quartz.NewMock(t) + trap := clock.Trap().NewTimer("chatworker", "task-retry-generation") + defer trap.Close() + ctx := testutil.Context(t, testutil.WaitLong) + calls := 0 + done := make(chan error, 1) + go func() { + done <- runTaskWithRetry(ctx, retryWrapperOptions{ + clock: clock, + initialDelay: time.Minute, + maxDelay: time.Minute, + }, taskKindGeneration, func(context.Context) error { + calls++ + if calls == 1 { + panic("database unavailable") + } + return nil + }) + }() + + trap.MustWait(ctx).MustRelease(ctx) + clock.Advance(time.Minute).MustWait(ctx) + require.NoError(t, <-done) + require.Equal(t, 2, calls) +} + +func TestInterruptTask_FinishInterruptionOnly(t *testing.T) { + t.Parallel() + + f := newTaskTestFixture(t) + chat := f.createRunningChat(t) + workerID := uuid.New() + runnerID := uuid.New() + acquired := f.acquireChat(t, chat.ID, workerID, runnerID) + buffer := messagepartbuffer.New(messagepartbuffer.Options{}) + key := messagepartbuffer.Key{ + ChatID: chat.ID, + HistoryVersion: acquired.HistoryVersion, + GenerationAttempt: acquired.GenerationAttempt, + } + require.NoError(t, buffer.CreateEpisode(key)) + require.NoError(t, buffer.AddPart(key, codersdk.ChatMessageRoleAssistant, codersdk.ChatMessageText("partial answer"))) + interrupting := f.interruptChat(t, chat.ID) + require.Equal(t, database.ChatStatusInterrupting, interrupting.Status) + recorder := newTaskSideEffectRecorder() + starter := newTestTaskStarter(t, f, buffer, recorder) + + err := starter.StartInterrupt(testutil.Context(t, testutil.WaitLong), chatWorkerTaskStartInput{ + ChatID: chat.ID, + WorkerID: workerID, + RunnerID: runnerID, + HistoryVersion: interrupting.HistoryVersion, + GenerationAttempt: interrupting.GenerationAttempt, + Status: 
database.ChatStatusInterrupting, + }) + require.NoError(t, err) + + latest, err := f.db.GetChatByID(testutil.Context(t, testutil.WaitShort), chat.ID) + require.NoError(t, err) + require.Equal(t, database.ChatStatusRunning, latest.Status) + recorder.requireStateHint(t, chat.ID, latest.SnapshotVersion, database.ChatStatusRunning) + recorder.requireInterruptionOutcome(t, chat.ID, database.ChatStatusRunning) + recorder.requireCleanupCount(t, 0) + f.requireWatchEvent(t, chat.ID, codersdk.ChatWatchEventKindStatusChange) + + messages, err := f.db.GetChatMessagesByChatID(testutil.Context(t, testutil.WaitShort), database.GetChatMessagesByChatIDParams{ChatID: chat.ID}) + require.NoError(t, err) + require.GreaterOrEqual(t, len(messages), 3) + parts, err := chatprompt.ParseContent(messages[len(messages)-2]) + require.NoError(t, err) + require.Equal(t, []codersdk.ChatMessagePart{codersdk.ChatMessageText("partial answer")}, parts) + require.Equal(t, database.ChatMessageRoleUser, messages[len(messages)-1].Role) +} + +func TestInterruptTask_StaleFenceExits(t *testing.T) { + t.Parallel() + + f := newTaskTestFixture(t) + chat := f.createRunningChat(t) + workerID := uuid.New() + runnerID := uuid.New() + f.acquireChat(t, chat.ID, workerID, runnerID) + interrupting := f.interruptChat(t, chat.ID) + otherWorkerID := uuid.New() + otherRunnerID := uuid.New() + f.acquireChat(t, chat.ID, otherWorkerID, otherRunnerID) + recorder := newTaskSideEffectRecorder() + starter := newTestTaskStarter(t, f, messagepartbuffer.New(messagepartbuffer.Options{}), recorder) + + err := starter.StartInterrupt(testutil.Context(t, testutil.WaitLong), chatWorkerTaskStartInput{ + ChatID: chat.ID, + WorkerID: workerID, + RunnerID: runnerID, + HistoryVersion: interrupting.HistoryVersion, + GenerationAttempt: interrupting.GenerationAttempt, + Status: database.ChatStatusInterrupting, + }) + require.ErrorIs(t, err, errTaskExpectedExit) + + latest, err := f.db.GetChatByID(testutil.Context(t, testutil.WaitShort), chat.ID) 
+ require.NoError(t, err) + require.Equal(t, database.ChatStatusInterrupting, latest.Status) + require.Equal(t, otherWorkerID, latest.WorkerID.UUID) + require.Equal(t, otherRunnerID, latest.RunnerID.UUID) + recorder.requireStateHintCount(t, 0) + f.requireNoWatchEvents(t) +} + +func TestInterruptTask_MissingEpisodePersistsNilPartials(t *testing.T) { + t.Parallel() + + f := newTaskTestFixture(t) + chat := f.createRunningChat(t) + workerID := uuid.New() + runnerID := uuid.New() + f.acquireChat(t, chat.ID, workerID, runnerID) + interrupting := f.forceExecutionState(t, chat.ID, database.ChatStatusInterrupting, false, sql.NullTime{}) + recorder := newTaskSideEffectRecorder() + starter := newTestTaskStarter(t, f, messagepartbuffer.New(messagepartbuffer.Options{}), recorder) + + err := starter.StartInterrupt(testutil.Context(t, testutil.WaitLong), chatWorkerTaskStartInput{ + ChatID: chat.ID, + WorkerID: workerID, + RunnerID: runnerID, + HistoryVersion: interrupting.HistoryVersion, + GenerationAttempt: interrupting.GenerationAttempt, + Status: database.ChatStatusInterrupting, + }) + require.NoError(t, err) + + latest, err := f.db.GetChatByID(testutil.Context(t, testutil.WaitShort), chat.ID) + require.NoError(t, err) + require.Equal(t, database.ChatStatusWaiting, latest.Status) + recorder.requireInterruptionOutcome(t, chat.ID, database.ChatStatusWaiting) + messages, err := f.db.GetChatMessagesByChatID(testutil.Context(t, testutil.WaitShort), database.GetChatMessagesByChatIDParams{ChatID: chat.ID}) + require.NoError(t, err) + require.Len(t, messages, 1) + recorder.requireStateHint(t, chat.ID, latest.SnapshotVersion, database.ChatStatusWaiting) +} + +func TestInterruptTask_BufferedPartsBecomePartialMessages(t *testing.T) { + t.Parallel() + + f := newTaskTestFixture(t) + chat := f.createRunningChat(t) + workerID := uuid.New() + runnerID := uuid.New() + acquired := f.acquireChat(t, chat.ID, workerID, runnerID) + buffer := messagepartbuffer.New(messagepartbuffer.Options{}) + key 
:= messagepartbuffer.Key{ChatID: chat.ID, HistoryVersion: acquired.HistoryVersion, GenerationAttempt: acquired.GenerationAttempt} + require.NoError(t, buffer.CreateEpisode(key)) + callID := "call_" + uuid.NewString() + require.NoError(t, buffer.AddPart(key, codersdk.ChatMessageRoleAssistant, codersdk.ChatMessagePart{ + Type: codersdk.ChatMessagePartTypeToolCall, + ToolCallID: callID, + ToolName: "local_tool", + Args: json.RawMessage(`{"value":1}`), + })) + interrupting := f.interruptChat(t, chat.ID) + recorder := newTaskSideEffectRecorder() + starter := newTestTaskStarter(t, f, buffer, recorder) + + err := starter.StartInterrupt(testutil.Context(t, testutil.WaitLong), chatWorkerTaskStartInput{ + ChatID: chat.ID, + WorkerID: workerID, + RunnerID: runnerID, + HistoryVersion: interrupting.HistoryVersion, + GenerationAttempt: interrupting.GenerationAttempt, + Status: database.ChatStatusInterrupting, + }) + require.NoError(t, err) + + messages, err := f.db.GetChatMessagesByChatID(testutil.Context(t, testutil.WaitShort), database.GetChatMessagesByChatIDParams{ChatID: chat.ID}) + require.NoError(t, err) + require.GreaterOrEqual(t, len(messages), 4) + assistant := messages[len(messages)-3] + tool := messages[len(messages)-2] + require.Equal(t, database.ChatMessageRoleAssistant, assistant.Role) + require.Equal(t, database.ChatMessageRoleTool, tool.Role) + toolParts, err := chatprompt.ParseContent(tool) + require.NoError(t, err) + require.Len(t, toolParts, 1) + require.Equal(t, codersdk.ChatMessagePartTypeToolResult, toolParts[0].Type) + require.Equal(t, callID, toolParts[0].ToolCallID) + require.True(t, toolParts[0].IsError) +} + +func TestRequiresActionTimeout_ExpiredCancelsOnly(t *testing.T) { + t.Parallel() + + f := newTaskTestFixture(t) + chat := f.createRequiresActionChat(t) + workerID := uuid.New() + runnerID := uuid.New() + acquired := f.acquireChat(t, chat.ID, workerID, runnerID) + expired := f.setRequiresActionDeadline(t, chat.ID, sql.NullTime{Time: 
time.Now().Add(-time.Minute), Valid: true}) + recorder := newTaskSideEffectRecorder() + starter := newTestTaskStarter(t, f, messagepartbuffer.New(messagepartbuffer.Options{}), recorder) + + err := starter.StartRequiresActionTimeout(testutil.Context(t, testutil.WaitLong), chatWorkerTaskStartInput{ + ChatID: chat.ID, + WorkerID: workerID, + RunnerID: runnerID, + HistoryVersion: acquired.HistoryVersion, + Status: database.ChatStatusRequiresAction, + RequiresActionDeadlineAt: expired.RequiresActionDeadlineAt, + }) + require.NoError(t, err) + + latest, err := f.db.GetChatByID(testutil.Context(t, testutil.WaitShort), chat.ID) + require.NoError(t, err) + require.Equal(t, database.ChatStatusRunning, latest.Status) + require.False(t, latest.RequiresActionDeadlineAt.Valid) + recorder.requireStateHint(t, chat.ID, latest.SnapshotVersion, database.ChatStatusRunning) + f.requireWatchEvent(t, chat.ID, codersdk.ChatWatchEventKindStatusChange) +} + +func TestRequiresActionTimeout_NullDeadlineCancelsImmediately(t *testing.T) { + t.Parallel() + + f := newTaskTestFixture(t) + chat := f.createRequiresActionChat(t) + workerID := uuid.New() + runnerID := uuid.New() + acquired := f.acquireChat(t, chat.ID, workerID, runnerID) + nullDeadline := f.setRequiresActionDeadline(t, chat.ID, sql.NullTime{}) + recorder := newTaskSideEffectRecorder() + starter := newTestTaskStarter(t, f, messagepartbuffer.New(messagepartbuffer.Options{}), recorder) + + err := starter.StartRequiresActionTimeout(testutil.Context(t, testutil.WaitLong), chatWorkerTaskStartInput{ + ChatID: chat.ID, + WorkerID: workerID, + RunnerID: runnerID, + HistoryVersion: acquired.HistoryVersion, + Status: database.ChatStatusRequiresAction, + RequiresActionDeadlineAt: nullDeadline.RequiresActionDeadlineAt, + }) + require.NoError(t, err) + + latest, err := f.db.GetChatByID(testutil.Context(t, testutil.WaitShort), chat.ID) + require.NoError(t, err) + require.Equal(t, database.ChatStatusRunning, latest.Status) + 
recorder.requireStateHint(t, chat.ID, latest.SnapshotVersion, database.ChatStatusRunning) +} + +func TestRequiresActionTimeout_StaleFenceExitsAfterToolResult(t *testing.T) { + t.Parallel() + + f := newTaskTestFixture(t) + chat := f.createRequiresActionChat(t) + workerID := uuid.New() + runnerID := uuid.New() + acquired := f.acquireChat(t, chat.ID, workerID, runnerID) + expired := f.setRequiresActionDeadline(t, chat.ID, sql.NullTime{Time: time.Now().Add(-time.Minute), Valid: true}) + f.forceExecutionState(t, chat.ID, database.ChatStatusRunning, false, sql.NullTime{}) + recorder := newTaskSideEffectRecorder() + starter := newTestTaskStarter(t, f, messagepartbuffer.New(messagepartbuffer.Options{}), recorder) + + err := starter.StartRequiresActionTimeout(testutil.Context(t, testutil.WaitLong), chatWorkerTaskStartInput{ + ChatID: chat.ID, + WorkerID: workerID, + RunnerID: runnerID, + HistoryVersion: acquired.HistoryVersion, + Status: database.ChatStatusRequiresAction, + RequiresActionDeadlineAt: expired.RequiresActionDeadlineAt, + }) + require.ErrorIs(t, err, errTaskExpectedExit) + + latest, err := f.db.GetChatByID(testutil.Context(t, testutil.WaitShort), chat.ID) + require.NoError(t, err) + require.Equal(t, database.ChatStatusRunning, latest.Status) + recorder.requireStateHintCount(t, 0) + f.requireNoWatchEvents(t) +} + +func TestAbandonTask_AbandonOnly(t *testing.T) { + t.Parallel() + + f := newTaskTestFixture(t) + chat := f.createRunningChat(t) + workerID := uuid.New() + runnerID := uuid.New() + acquired := f.acquireChat(t, chat.ID, workerID, runnerID) + recorder := newTaskSideEffectRecorder() + starter := newTestTaskStarter(t, f, messagepartbuffer.New(messagepartbuffer.Options{}), recorder) + + err := starter.StartAbandon(testutil.Context(t, testutil.WaitLong), chatWorkerTaskStartInput{ + ChatID: chat.ID, + WorkerID: workerID, + RunnerID: runnerID, + HistoryVersion: acquired.HistoryVersion, + Status: database.ChatStatusRunning, + }) + require.NoError(t, err) + + 
latest, err := f.db.GetChatByID(testutil.Context(t, testutil.WaitShort), chat.ID) + require.NoError(t, err) + require.False(t, latest.WorkerID.Valid) + require.False(t, latest.RunnerID.Valid) + recorder.requireCleanup(t, chat.ID, runnerID) + recorder.requireStateHintCount(t, 0) + f.requireNoWatchEvents(t) +} + +func TestAbandonTask_OwnershipMismatchRequestsCleanup(t *testing.T) { + t.Parallel() + + f := newTaskTestFixture(t) + chat := f.createRunningChat(t) + workerID := uuid.New() + runnerID := uuid.New() + f.acquireChat(t, chat.ID, workerID, runnerID) + otherWorkerID := uuid.New() + otherRunnerID := uuid.New() + latestOwner := f.acquireChat(t, chat.ID, otherWorkerID, otherRunnerID) + recorder := newTaskSideEffectRecorder() + starter := newTestTaskStarter(t, f, messagepartbuffer.New(messagepartbuffer.Options{}), recorder) + + err := starter.StartAbandon(testutil.Context(t, testutil.WaitLong), chatWorkerTaskStartInput{ + ChatID: chat.ID, + WorkerID: workerID, + RunnerID: runnerID, + HistoryVersion: latestOwner.HistoryVersion, + Status: database.ChatStatusRunning, + }) + require.NoError(t, err) + + latest, err := f.db.GetChatByID(testutil.Context(t, testutil.WaitShort), chat.ID) + require.NoError(t, err) + require.Equal(t, otherWorkerID, latest.WorkerID.UUID) + require.Equal(t, otherRunnerID, latest.RunnerID.UUID) + recorder.requireCleanup(t, chat.ID, runnerID) +} + +func TestAbandonTask_StaleStatusFenceExits(t *testing.T) { + t.Parallel() + + f := newTaskTestFixture(t) + chat := f.createRunningChat(t) + workerID := uuid.New() + runnerID := uuid.New() + acquired := f.acquireChat(t, chat.ID, workerID, runnerID) + f.forceExecutionState(t, chat.ID, database.ChatStatusInterrupting, false, sql.NullTime{}) + recorder := newTaskSideEffectRecorder() + starter := newTestTaskStarter(t, f, messagepartbuffer.New(messagepartbuffer.Options{}), recorder) + + err := starter.StartAbandon(testutil.Context(t, testutil.WaitLong), chatWorkerTaskStartInput{ + ChatID: chat.ID, + WorkerID: 
workerID, + RunnerID: runnerID, + HistoryVersion: acquired.HistoryVersion, + Status: database.ChatStatusWaiting, + }) + require.ErrorIs(t, err, errTaskExpectedExit) + + latest, err := f.db.GetChatByID(testutil.Context(t, testutil.WaitShort), chat.ID) + require.NoError(t, err) + require.True(t, latest.WorkerID.Valid) + require.True(t, latest.RunnerID.Valid) + require.Equal(t, database.ChatStatusInterrupting, latest.Status) + recorder.requireCleanupCount(t, 0) +} + +func TestGenerationTask_RecordRetryState(t *testing.T) { + t.Parallel() + + f := newTaskTestFixture(t) + chat := f.createRunningChat(t) + workerID := uuid.New() + runnerID := uuid.New() + acquired := f.acquireChat(t, chat.ID, workerID, runnerID) + recorder := newTaskSideEffectRecorder() + starter := newTestTaskStarter(t, f, messagepartbuffer.New(messagepartbuffer.Options{}), recorder) + + attempt, _, _, closeEpisode, err := starter.beginGenerationAttempt( + testutil.Context(t, testutil.WaitLong), + chatstate.NewChatMachine(f.db, f.pubsub, chat.ID, chatstate.Options{}), + chatWorkerTaskStartInput{ + ChatID: chat.ID, + WorkerID: workerID, + RunnerID: runnerID, + HistoryVersion: acquired.HistoryVersion, + Status: database.ChatStatusRunning, + }, + ) + require.NoError(t, err) + closeEpisode() + require.Equal(t, int64(1), attempt) + before, err := f.db.GetChatByID(testutil.Context(t, testutil.WaitShort), chat.ID) + require.NoError(t, err) + require.False(t, before.RetryState.Valid) + + decision, err := starter.recordGenerationRetry( + testutil.Context(t, testutil.WaitLong), + chatstate.NewChatMachine(f.db, f.pubsub, chat.ID, chatstate.Options{}), + chatWorkerTaskStartInput{ + ChatID: chat.ID, + WorkerID: workerID, + RunnerID: runnerID, + HistoryVersion: acquired.HistoryVersion, + Status: database.ChatStatusRunning, + }, + chaterror.ClassifiedError{ + Message: "OpenAI is rate limiting requests.", + Kind: codersdk.ChatErrorKindRateLimit, + Provider: "openai", + Retryable: true, + StatusCode: 429, + }, + ) + 
require.NoError(t, err) + require.True(t, decision.retry) + require.Equal(t, int64(1), decision.generationAttempt) + require.Equal(t, chatretry.Delay(0), decision.delay) + + latest, err := f.db.GetChatByID(testutil.Context(t, testutil.WaitShort), chat.ID) + require.NoError(t, err) + require.True(t, latest.RetryState.Valid) + require.Equal(t, latest.SnapshotVersion, latest.RetryStateVersion) + require.Greater(t, latest.RetryStateVersion, before.RetryStateVersion) + require.Equal(t, before.GenerationAttempt, latest.GenerationAttempt) + recorder.requireStateHintCount(t, 0) + + var retryPayload codersdk.ChatStreamRetry + require.NoError(t, json.Unmarshal(latest.RetryState.RawMessage, &retryPayload)) + require.Equal(t, 1, retryPayload.Attempt) + require.Equal(t, chatretry.Delay(0).Milliseconds(), retryPayload.DelayMs) + require.Equal(t, "OpenAI is rate limiting requests.", retryPayload.Error) + require.Equal(t, codersdk.ChatErrorKindRateLimit, retryPayload.Kind) + require.Equal(t, "openai", retryPayload.Provider) + require.Equal(t, 429, retryPayload.StatusCode) + require.False(t, retryPayload.RetryingAt.IsZero()) +} + +func TestGenerationTask_RecordRetryStateUsesDurableGenerationAttempt(t *testing.T) { + t.Parallel() + + f := newTaskTestFixture(t) + chat := f.createRunningChat(t) + workerID := uuid.New() + runnerID := uuid.New() + acquired := f.acquireChat(t, chat.ID, workerID, runnerID) + starter := newTestTaskStarter(t, f, messagepartbuffer.New(messagepartbuffer.Options{}), newTaskSideEffectRecorder()) + machine := chatstate.NewChatMachine(f.db, f.pubsub, chat.ID, chatstate.Options{}) + + for range 3 { + attempt, _, _, closeEpisode, err := starter.beginGenerationAttempt( + testutil.Context(t, testutil.WaitLong), + machine, + chatWorkerTaskStartInput{ + ChatID: chat.ID, + WorkerID: workerID, + RunnerID: runnerID, + HistoryVersion: acquired.HistoryVersion, + Status: database.ChatStatusRunning, + }, + ) + require.NoError(t, err) + closeEpisode() + require.Positive(t, 
attempt) + } + + decision, err := starter.recordGenerationRetry( + testutil.Context(t, testutil.WaitLong), + machine, + chatWorkerTaskStartInput{ + ChatID: chat.ID, + WorkerID: workerID, + RunnerID: runnerID, + HistoryVersion: acquired.HistoryVersion, + Status: database.ChatStatusRunning, + }, + chaterror.ClassifiedError{ + Message: "OpenAI is temporarily unavailable.", + Kind: codersdk.ChatErrorKindTimeout, + Provider: "openai", + Retryable: true, + }, + ) + require.NoError(t, err) + require.True(t, decision.retry) + require.Equal(t, int64(3), decision.generationAttempt) + require.Equal(t, chatretry.Delay(2), decision.delay) + + latest, err := f.db.GetChatByID(testutil.Context(t, testutil.WaitShort), chat.ID) + require.NoError(t, err) + var retryPayload codersdk.ChatStreamRetry + require.NoError(t, json.Unmarshal(latest.RetryState.RawMessage, &retryPayload)) + require.Equal(t, 3, retryPayload.Attempt) + require.Equal(t, chatretry.Delay(2).Milliseconds(), retryPayload.DelayMs) +} + +func TestGenerationTask_RecordRetryStateClearedByNextAttempt(t *testing.T) { + t.Parallel() + + f := newTaskTestFixture(t) + chat := f.createRunningChat(t) + workerID := uuid.New() + runnerID := uuid.New() + acquired := f.acquireChat(t, chat.ID, workerID, runnerID) + starter := newTestTaskStarter(t, f, messagepartbuffer.New(messagepartbuffer.Options{}), newTaskSideEffectRecorder()) + machine := chatstate.NewChatMachine(f.db, f.pubsub, chat.ID, chatstate.Options{}) + input := chatWorkerTaskStartInput{ + ChatID: chat.ID, + WorkerID: workerID, + RunnerID: runnerID, + HistoryVersion: acquired.HistoryVersion, + Status: database.ChatStatusRunning, + } + + attempt, _, _, closeEpisode, err := starter.beginGenerationAttempt(testutil.Context(t, testutil.WaitLong), machine, input) + require.NoError(t, err) + closeEpisode() + require.Equal(t, int64(1), attempt) + _, err = starter.recordGenerationRetry( + testutil.Context(t, testutil.WaitLong), + machine, + input, + chaterror.ClassifiedError{ + 
Message: "OpenAI is temporarily unavailable.", + Kind: codersdk.ChatErrorKindTimeout, + Provider: "openai", + Retryable: true, + }, + ) + require.NoError(t, err) + withRetry, err := f.db.GetChatByID(testutil.Context(t, testutil.WaitShort), chat.ID) + require.NoError(t, err) + require.True(t, withRetry.RetryState.Valid) + + attempt, _, _, closeEpisode, err = starter.beginGenerationAttempt(testutil.Context(t, testutil.WaitLong), machine, input) + require.NoError(t, err) + closeEpisode() + require.Equal(t, int64(2), attempt) + after, err := f.db.GetChatByID(testutil.Context(t, testutil.WaitShort), chat.ID) + require.NoError(t, err) + require.False(t, after.RetryState.Valid) + require.Equal(t, after.SnapshotVersion, after.RetryStateVersion) + require.Greater(t, after.RetryStateVersion, withRetry.RetryStateVersion) +} + +func TestGenerationTask_RecordRetryStateStaleFenceExits(t *testing.T) { + t.Parallel() + + f := newTaskTestFixture(t) + chat := f.createRunningChat(t) + workerID := uuid.New() + runnerID := uuid.New() + acquired := f.acquireChat(t, chat.ID, workerID, runnerID) + starter := newTestTaskStarter(t, f, messagepartbuffer.New(messagepartbuffer.Options{}), newTaskSideEffectRecorder()) + machine := chatstate.NewChatMachine(f.db, f.pubsub, chat.ID, chatstate.Options{}) + attempt, _, _, closeEpisode, err := starter.beginGenerationAttempt( + testutil.Context(t, testutil.WaitLong), + machine, + chatWorkerTaskStartInput{ + ChatID: chat.ID, + WorkerID: workerID, + RunnerID: runnerID, + HistoryVersion: acquired.HistoryVersion, + Status: database.ChatStatusRunning, + }, + ) + require.NoError(t, err) + closeEpisode() + require.Equal(t, int64(1), attempt) + + otherWorkerID := uuid.New() + otherRunnerID := uuid.New() + f.acquireChat(t, chat.ID, otherWorkerID, otherRunnerID) + _, err = starter.recordGenerationRetry( + testutil.Context(t, testutil.WaitLong), + machine, + chatWorkerTaskStartInput{ + ChatID: chat.ID, + WorkerID: workerID, + RunnerID: runnerID, + 
HistoryVersion: acquired.HistoryVersion, + Status: database.ChatStatusRunning, + }, + chaterror.ClassifiedError{ + Message: "OpenAI is temporarily unavailable.", + Kind: codersdk.ChatErrorKindTimeout, + Provider: "openai", + Retryable: true, + }, + ) + require.ErrorIs(t, err, errTaskExpectedExit) + latest, err := f.db.GetChatByID(testutil.Context(t, testutil.WaitShort), chat.ID) + require.NoError(t, err) + require.False(t, latest.RetryState.Valid) + require.Equal(t, otherWorkerID, latest.WorkerID.UUID) + require.Equal(t, otherRunnerID, latest.RunnerID.UUID) +} + +func TestRunner_StartsRealInterruptTask(t *testing.T) { + t.Parallel() + + f := newTaskTestFixture(t) + chat := f.createRunningChat(t) + worker := startRealTaskWorker(t, f, messagepartbuffer.New(messagepartbuffer.Options{})) + waitOwnedChat(t, f, chat.ID, worker.chatWorkerID()) + + interrupting := f.interruptChat(t, chat.ID) + require.Equal(t, database.ChatStatusInterrupting, interrupting.Status) + testutil.Eventually(testutil.Context(t, testutil.WaitLong), t, func(ctx context.Context) bool { + latest, err := f.db.GetChatByID(ctx, chat.ID) + return err == nil && latest.Status == database.ChatStatusRunning + }, testutil.IntervalFast) + latest, err := f.db.GetChatByID(testutil.Context(t, testutil.WaitShort), chat.ID) + require.NoError(t, err) + require.Equal(t, worker.chatWorkerID(), latest.WorkerID.UUID) + f.requireWatchEvent(t, chat.ID, codersdk.ChatWatchEventKindStatusChange) +} + +func TestRunner_StartsRealRequiresActionTimeoutTask(t *testing.T) { + t.Parallel() + + f := newTaskTestFixture(t) + chat := f.createRequiresActionChat(t) + f.setRequiresActionDeadline(t, chat.ID, sql.NullTime{Time: time.Now().Add(-time.Minute), Valid: true}) + worker := startRealTaskWorker(t, f, messagepartbuffer.New(messagepartbuffer.Options{})) + + testutil.Eventually(testutil.Context(t, testutil.WaitLong), t, func(ctx context.Context) bool { + latest, err := f.db.GetChatByID(ctx, chat.ID) + return err == nil && latest.Status 
== database.ChatStatusRunning && latest.WorkerID.Valid && latest.WorkerID.UUID == worker.chatWorkerID() + }, testutil.IntervalFast) + latest, err := f.db.GetChatByID(testutil.Context(t, testutil.WaitShort), chat.ID) + require.NoError(t, err) + require.True(t, latest.RunnerID.Valid) + f.requireWatchEvent(t, chat.ID, codersdk.ChatWatchEventKindStatusChange) +} + +func TestRunner_StartsRealAbandonTask(t *testing.T) { + t.Parallel() + + f := newTaskTestFixture(t) + chat := f.createRunningChat(t) + worker := startRealTaskWorker(t, f, messagepartbuffer.New(messagepartbuffer.Options{})) + waitOwnedChat(t, f, chat.ID, worker.chatWorkerID()) + + updated := f.forceExecutionState(t, chat.ID, database.ChatStatusError, false, sql.NullTime{}) + f.publishChatUpdate(t, updated) + testutil.Eventually(testutil.Context(t, testutil.WaitLong), t, func(ctx context.Context) bool { + latest, err := f.db.GetChatByID(ctx, chat.ID) + return err == nil && !latest.WorkerID.Valid && !latest.RunnerID.Valid + }, testutil.IntervalFast) +} + +type taskTestFixture struct { + db database.Store + pubsub *taskRecordingPubsub + sqlDB *sql.DB + user database.User + org database.Organization + model database.ChatModelConfig +} + +func newTaskTestFixture(t *testing.T) *taskTestFixture { + t.Helper() + db, ps, sqlDB := dbtestutil.NewDBWithSQLDB(t) + user := dbgen.User(t, db, database.User{}) + org := dbgen.Organization(t, db, database.Organization{}) + dbgen.OrganizationMember(t, db, database.OrganizationMember{UserID: user.ID, OrganizationID: org.ID}) + dbgen.ChatProvider(t, db, database.ChatProvider{ + Provider: "openai", + DisplayName: "openai", + BaseUrl: "http://example.invalid", + }) + model := dbgen.ChatModelConfig(t, db, database.ChatModelConfig{Provider: "openai", IsDefault: true}) + return &taskTestFixture{db: db, pubsub: newTaskRecordingPubsub(ps), sqlDB: sqlDB, user: user, org: org, model: model} +} + +func (f *taskTestFixture) createRunningChat(t *testing.T) database.Chat { + t.Helper() + res, 
err := chatstate.CreateChat(testutil.Context(t, testutil.WaitShort), f.db, f.pubsub, chatstate.CreateChatInput{ + OrganizationID: f.org.ID, + OwnerID: f.user.ID, + LastModelConfigID: f.model.ID, + Title: "test", + ClientType: database.ChatClientTypeApi, + InitialMessages: []chatstate.Message{taskUserTextMessage(t, "hello", f.user.ID, f.model.ID)}, + }) + require.NoError(t, err) + f.pubsub.clear() + return res.Chat +} + +func (f *taskTestFixture) createRequiresActionChat(t *testing.T) database.Chat { + t.Helper() + toolName := "dynamic_" + uuid.NewString() + dynamicTools, err := json.Marshal([]codersdk.DynamicTool{{ + Name: toolName, + Description: "test tool", + InputSchema: json.RawMessage(`{"type":"object"}`), + }}) + require.NoError(t, err) + res, err := chatstate.CreateChat(testutil.Context(t, testutil.WaitShort), f.db, f.pubsub, chatstate.CreateChatInput{ + OrganizationID: f.org.ID, + OwnerID: f.user.ID, + LastModelConfigID: f.model.ID, + Title: "test", + ClientType: database.ChatClientTypeApi, + DynamicTools: pqtype.NullRawMessage{RawMessage: dynamicTools, Valid: true}, + InitialMessages: []chatstate.Message{taskUserTextMessage(t, "hello", f.user.ID, f.model.ID)}, + }) + require.NoError(t, err) + machine := chatstate.NewChatMachine(f.db, f.pubsub, res.Chat.ID, chatstate.Options{}) + require.NoError(t, machine.Update(testutil.Context(t, testutil.WaitShort), func(tx *chatstate.Tx) error { + _, err := tx.CommitStep(chatstate.CommitStepInput{Messages: []chatstate.Message{taskAssistantToolCallMessage(t, f.model.ID, toolName)}}) + return err + })) + require.NoError(t, machine.Update(testutil.Context(t, testutil.WaitShort), func(tx *chatstate.Tx) error { + _, err := tx.EnterRequiresAction(chatstate.EnterRequiresActionInput{}) + return err + })) + chat, err := f.db.GetChatByID(testutil.Context(t, testutil.WaitShort), res.Chat.ID) + require.NoError(t, err) + f.pubsub.clear() + return chat +} + +func (f *taskTestFixture) acquireChat(t *testing.T, chatID uuid.UUID, 
workerID uuid.UUID, runnerID uuid.UUID) database.Chat { + t.Helper() + machine := chatstate.NewChatMachine(f.db, f.pubsub, chatID, chatstate.Options{}) + require.NoError(t, machine.Update(testutil.Context(t, testutil.WaitShort), func(tx *chatstate.Tx) error { + _, err := tx.Acquire(chatstate.AcquireInput{WorkerID: workerID, RunnerID: runnerID}) + return err + })) + chat, err := f.db.GetChatByID(testutil.Context(t, testutil.WaitShort), chatID) + require.NoError(t, err) + f.pubsub.clear() + return chat +} + +func (f *taskTestFixture) interruptChat(t *testing.T, chatID uuid.UUID) database.Chat { + t.Helper() + machine := chatstate.NewChatMachine(f.db, f.pubsub, chatID, chatstate.Options{}) + require.NoError(t, machine.Update(testutil.Context(t, testutil.WaitShort), func(tx *chatstate.Tx) error { + _, err := tx.SendMessage(chatstate.SendMessageInput{ + Message: taskUserTextMessage(t, "interrupt", f.user.ID, f.model.ID), + BusyBehavior: chatstate.BusyBehaviorInterrupt, + }) + return err + })) + chat, err := f.db.GetChatByID(testutil.Context(t, testutil.WaitShort), chatID) + require.NoError(t, err) + f.pubsub.clear() + return chat +} + +func (f *taskTestFixture) forceExecutionState(t *testing.T, chatID uuid.UUID, status database.ChatStatus, archived bool, deadline sql.NullTime) database.Chat { + t.Helper() + var updated database.Chat + require.NoError(t, f.db.InTx(func(store database.Store) error { + if _, err := store.LockChatAndBumpSnapshotVersion(testutil.Context(t, testutil.WaitShort), chatID); err != nil { + return err + } + chat, err := store.GetChatByID(testutil.Context(t, testutil.WaitShort), chatID) + if err != nil { + return err + } + updated, err = store.UpdateChatExecutionState(testutil.Context(t, testutil.WaitShort), database.UpdateChatExecutionStateParams{ + ID: chat.ID, + Status: status, + Archived: archived, + WorkerID: chat.WorkerID, + RunnerID: chat.RunnerID, + LastError: chat.LastError, + RequiresActionDeadlineAt: deadline, + }) + return err + }, nil)) 
+ f.pubsub.clear() + return updated +} + +func (f *taskTestFixture) setRequiresActionDeadline(t *testing.T, chatID uuid.UUID, deadline sql.NullTime) database.Chat { + t.Helper() + chat, err := f.db.GetChatByID(testutil.Context(t, testutil.WaitShort), chatID) + require.NoError(t, err) + return f.forceExecutionState(t, chatID, chat.Status, chat.Archived, deadline) +} + +func (f *taskTestFixture) publishChatUpdate(t *testing.T, chat database.Chat) { + t.Helper() + msg := coderdpubsub.ChatStateUpdateMessage{ + SnapshotVersion: chat.SnapshotVersion, + HistoryVersion: chat.HistoryVersion, + QueueVersion: chat.QueueVersion, + RetryStateVersion: chat.RetryStateVersion, + GenerationAttempt: chat.GenerationAttempt, + Status: string(chat.Status), + Archived: chat.Archived, + } + if chat.WorkerID.Valid { + id := chat.WorkerID.UUID + msg.WorkerID = &id + } + if chat.RunnerID.Valid { + id := chat.RunnerID.UUID + msg.RunnerID = &id + } + payload, err := json.Marshal(msg) + require.NoError(t, err) + require.NoError(t, f.pubsub.Publish(coderdpubsub.ChatStateUpdateChannel(chat.ID), payload)) +} + +func (f *taskTestFixture) requireWatchEvent(t *testing.T, chatID uuid.UUID, kind codersdk.ChatWatchEventKind) { + t.Helper() + events := f.pubsub.watchEvents(t) + for _, event := range events { + if event.Kind == kind && event.Chat.ID == chatID { + return + } + } + t.Fatalf("missing watch event kind=%s chat_id=%s events=%v", kind, chatID, events) +} + +func (f *taskTestFixture) requireNoWatchEvents(t *testing.T) { + t.Helper() + require.Empty(t, f.pubsub.watchEvents(t)) +} +func taskUserTextMessage(t *testing.T, text string, createdBy uuid.UUID, modelConfigID uuid.UUID) chatstate.Message { + t.Helper() + raw, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{codersdk.ChatMessageText(text)}) + require.NoError(t, err) + return chatstate.Message{ + Role: database.ChatMessageRoleUser, + Content: raw, + Visibility: database.ChatMessageVisibilityBoth, + ContentVersion: 
chatprompt.CurrentContentVersion, + CreatedBy: uuid.NullUUID{UUID: createdBy, Valid: true}, + ModelConfigID: uuid.NullUUID{UUID: modelConfigID, Valid: true}, + } +} + +func taskAssistantToolCallMessage(t *testing.T, modelConfigID uuid.UUID, toolName string) chatstate.Message { + t.Helper() + raw, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{{ + Type: codersdk.ChatMessagePartTypeToolCall, + ToolCallID: "call_" + uuid.NewString(), + ToolName: toolName, + Args: json.RawMessage(`{}`), + }}) + require.NoError(t, err) + return chatstate.Message{ + Role: database.ChatMessageRoleAssistant, + Content: raw, + Visibility: database.ChatMessageVisibilityBoth, + ContentVersion: chatprompt.CurrentContentVersion, + ModelConfigID: uuid.NullUUID{UUID: modelConfigID, Valid: true}, + } +} + +type taskPublishedEvent struct { + channel string + payload []byte +} + +type taskRecordingPubsub struct { + inner dbpubsub.Pubsub + mu sync.Mutex + sent []taskPublishedEvent +} + +func newTaskRecordingPubsub(inner dbpubsub.Pubsub) *taskRecordingPubsub { + return &taskRecordingPubsub{inner: inner} +} + +func (p *taskRecordingPubsub) Publish(channel string, payload []byte) error { + p.mu.Lock() + p.sent = append(p.sent, taskPublishedEvent{channel: channel, payload: append([]byte(nil), payload...)}) + p.mu.Unlock() + return p.inner.Publish(channel, payload) +} + +func (p *taskRecordingPubsub) SubscribeWithErr(channel string, listener dbpubsub.ListenerWithErr) (func(), error) { + return p.inner.SubscribeWithErr(channel, listener) +} + +func (p *taskRecordingPubsub) clear() { + p.mu.Lock() + p.sent = nil + p.mu.Unlock() +} + +func (p *taskRecordingPubsub) events() []taskPublishedEvent { + p.mu.Lock() + defer p.mu.Unlock() + return append([]taskPublishedEvent(nil), p.sent...) 
+} + +func (p *taskRecordingPubsub) watchEvents(t *testing.T) []codersdk.ChatWatchEvent { + t.Helper() + events := p.events() + out := make([]codersdk.ChatWatchEvent, 0) + for _, event := range events { + var payload codersdk.ChatWatchEvent + if err := json.Unmarshal(event.payload, &payload); err != nil { + continue + } + if event.channel != coderdpubsub.ChatWatchEventChannel(payload.Chat.OwnerID) { + continue + } + out = append(out, payload) + } + return out +} + +func startRealTaskWorker(t *testing.T, f *taskTestFixture, buffer *messagepartbuffer.Buffer) *chatWorker { + t.Helper() + worker, err := newChatWorker(nil, chatWorkerOptions{ + WorkerID: uuid.New(), + Store: f.db, + Pubsub: f.pubsub, + Logger: slog.Make(), + MessagePartBuffer: buffer, + AcquisitionInterval: time.Hour, + AcquisitionBatchSize: 10, + RunnerSyncInterval: time.Hour, + HeartbeatInterval: time.Hour, + HeartbeatCleanupInterval: time.Hour, + HeartbeatStaleSeconds: 30, + StateChannelSize: 16, + RunnerManagerChannelSize: 16, + AcquisitionWakeChannelSize: 1, + TaskRetryInitialBackoff: time.Millisecond, + TaskRetryMaxBackoff: time.Millisecond, + }) + require.NoError(t, err) + require.NoError(t, worker.Start(context.Background())) + t.Cleanup(func() { require.NoError(t, worker.Close()) }) + return worker +} + +func waitOwnedChat(t *testing.T, f *taskTestFixture, chatID uuid.UUID, workerID uuid.UUID) database.Chat { + t.Helper() + var latest database.Chat + testutil.Eventually(testutil.Context(t, testutil.WaitLong), t, func(ctx context.Context) bool { + chat, err := f.db.GetChatByID(ctx, chatID) + if err != nil { + return false + } + latest = chat + return chat.WorkerID.Valid && chat.WorkerID.UUID == workerID && chat.RunnerID.Valid + }, testutil.IntervalFast) + return latest +} + +type taskSideEffectRecorder struct { + mu sync.Mutex + hints []runnerStateUpdate + cleanups []runnerKey + interrupts []interruptionOutcome +} + +func newTaskSideEffectRecorder() *taskSideEffectRecorder { + return 
&taskSideEffectRecorder{} +} + +func (r *taskSideEffectRecorder) routeStateHint(_ context.Context, state runnerStateUpdate) { + r.mu.Lock() + r.hints = append(r.hints, state) + r.mu.Unlock() +} + +func (r *taskSideEffectRecorder) requestCleanup(_ context.Context, key runnerKey) { + r.mu.Lock() + r.cleanups = append(r.cleanups, key) + r.mu.Unlock() +} + +func (r *taskSideEffectRecorder) afterInterruptionOutcome(_ context.Context, outcome interruptionOutcome) error { + r.mu.Lock() + r.interrupts = append(r.interrupts, outcome) + r.mu.Unlock() + return nil +} + +func (r *taskSideEffectRecorder) requireStateHint(t *testing.T, chatID uuid.UUID, snapshot int64, status database.ChatStatus) { + t.Helper() + r.mu.Lock() + defer r.mu.Unlock() + for _, hint := range r.hints { + if hint.ChatID == chatID && hint.SnapshotVersion == snapshot && hint.Status == status { + return + } + } + t.Fatalf("missing state hint chat_id=%s snapshot=%d status=%s hints=%v", chatID, snapshot, status, r.hints) +} + +func (r *taskSideEffectRecorder) requireStateHintCount(t *testing.T, count int) { + t.Helper() + r.mu.Lock() + defer r.mu.Unlock() + require.Len(t, r.hints, count) +} + +func (r *taskSideEffectRecorder) requireCleanup(t *testing.T, chatID uuid.UUID, runnerID uuid.UUID) { + t.Helper() + r.mu.Lock() + defer r.mu.Unlock() + for _, cleanup := range r.cleanups { + if cleanup.ChatID == chatID && cleanup.RunnerID == runnerID { + return + } + } + t.Fatalf("missing cleanup chat_id=%s runner_id=%s cleanups=%v", chatID, runnerID, r.cleanups) +} + +func (r *taskSideEffectRecorder) requireCleanupCount(t *testing.T, count int) { + t.Helper() + r.mu.Lock() + defer r.mu.Unlock() + require.Len(t, r.cleanups, count) +} + +func (r *taskSideEffectRecorder) requireInterruptionOutcome(t *testing.T, chatID uuid.UUID, status database.ChatStatus) { + t.Helper() + r.mu.Lock() + defer r.mu.Unlock() + for _, outcome := range r.interrupts { + if outcome.Chat.ID == chatID && outcome.Chat.Status == status { + return 
+ } + } + t.Fatalf("missing interruption outcome chat_id=%s status=%s outcomes=%v", chatID, status, r.interrupts) +} + +func newTestTaskStarter(t *testing.T, f *taskTestFixture, buffer *messagepartbuffer.Buffer, recorder *taskSideEffectRecorder) *taskStarter { + t.Helper() + starter, err := newTaskStarter(nil, chatWorkerOptions{ + Store: f.db, + Pubsub: f.pubsub, + Logger: slog.Make(), + Clock: quartz.NewReal(), + MessagePartBuffer: buffer, + TaskRetryInitialBackoff: time.Millisecond, + TaskRetryMaxBackoff: time.Millisecond, + }, recorder.routeStateHint, recorder.requestCleanup) + require.NoError(t, err) + starter.afterInterruptionOutcome = recorder.afterInterruptionOutcome + return starter +} diff --git a/coderd/x/chatd/testhooks.go b/coderd/x/chatd/testhooks.go index 7c7177b88b..c1356ee3d0 100644 --- a/coderd/x/chatd/testhooks.go +++ b/coderd/x/chatd/testhooks.go @@ -1,9 +1,20 @@ package chatd +import ( + "context" + "time" +) + // WaitUntilIdleForTest waits for background chat work tracked by the server to // finish without shutting the server down. Tests use this to assert final // database state only after asynchronous chat processing has completed. // Close waits for the same tracked work, but also stops the server. 
func WaitUntilIdleForTest(server *Server) { server.drainInflight() + if server.chatWorker == nil { + return + } + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + _ = server.chatWorker.WaitIdle(ctx) } diff --git a/coderd/x/chatd/turn_summary_internal_test.go b/coderd/x/chatd/turn_summary_internal_test.go index c38d57754f..c626da28dd 100644 --- a/coderd/x/chatd/turn_summary_internal_test.go +++ b/coderd/x/chatd/turn_summary_internal_test.go @@ -6,7 +6,6 @@ import ( "encoding/json" "sync/atomic" "testing" - "time" "charm.land/fantasy" "github.com/google/uuid" @@ -16,6 +15,8 @@ import ( "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/dbgen" "github.com/coder/coder/v2/coderd/database/dbtestutil" + "github.com/coder/coder/v2/coderd/x/chatd/chatprompt" + "github.com/coder/coder/v2/coderd/x/chatd/chatstate" "github.com/coder/coder/v2/coderd/x/chatd/chattest" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/testutil" @@ -55,38 +56,63 @@ func TestUpdateLastTurnSummaryRejectsStaleWrites(t *testing.T) { }) require.NoError(t, err) - chat, err := db.InsertChat(ctx, database.InsertChatParams{ + content, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{ + codersdk.ChatMessageText("hello"), + }) + require.NoError(t, err) + created, err := chatstate.CreateChat(ctx, db, ps, chatstate.CreateChatInput{ OrganizationID: org.ID, - Status: database.ChatStatusWaiting, - ClientType: database.ChatClientTypeUi, OwnerID: owner.ID, LastModelConfigID: modelCfg.ID, Title: "summary-chat", + ClientType: database.ChatClientTypeUi, + InitialMessages: []chatstate.Message{ + { + Role: database.ChatMessageRoleUser, + Content: content, + Visibility: database.ChatMessageVisibilityBoth, + ContentVersion: chatprompt.CurrentContentVersion, + CreatedBy: uuid.NullUUID{UUID: owner.ID, Valid: true}, + ModelConfigID: uuid.NullUUID{UUID: modelCfg.ID, Valid: true}, + }, + }, }) require.NoError(t, err) + chat := 
created.Chat logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) server := &Server{db: db, pubsub: ps} - server.updateLastTurnSummary(ctx, chat, chat.UpdatedAt, "fresh summary", logger) + server.updateLastTurnSummary(ctx, chat, chat.HistoryVersion, "fresh summary", logger) fetched, err := db.GetChatByID(ctx, chat.ID) require.NoError(t, err) require.Equal(t, sql.NullString{String: "fresh summary", Valid: true}, fetched.LastTurnSummary) - advancedUpdatedAt := chat.UpdatedAt.Add(time.Second) - _, err = db.UpdateChatStatusPreserveUpdatedAt(ctx, database.UpdateChatStatusPreserveUpdatedAtParams{ - ID: chat.ID, - Status: database.ChatStatusRunning, - UpdatedAt: advancedUpdatedAt, + assistantContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{ + codersdk.ChatMessageText("assistant response"), }) require.NoError(t, err) + machine := chatstate.NewChatMachine(db, ps, chat.ID, chatstate.Options{}) + require.NoError(t, machine.Update(ctx, func(tx *chatstate.Tx) error { + _, err := tx.CommitStep(chatstate.CommitStepInput{ + Messages: []chatstate.Message{ + { + Role: database.ChatMessageRoleAssistant, + Content: assistantContent, + Visibility: database.ChatMessageVisibilityBoth, + ContentVersion: chatprompt.CurrentContentVersion, + ModelConfigID: uuid.NullUUID{UUID: modelCfg.ID, Valid: true}, + }, + }, + }) + return err + })) - server.updateLastTurnSummary(context.WithoutCancel(ctx), chat, chat.UpdatedAt, "stale summary", logger) + server.updateLastTurnSummary(context.WithoutCancel(ctx), chat, chat.HistoryVersion, "stale summary", logger) fetched, err = db.GetChatByID(ctx, chat.ID) require.NoError(t, err) require.Equal(t, sql.NullString{String: "fresh summary", Valid: true}, fetched.LastTurnSummary) - require.Equal(t, advancedUpdatedAt, fetched.UpdatedAt) } func TestPendingChatPersistsSummaryButSkipsWebPush(t *testing.T) { diff --git a/coderd/x/chatd/worker.go b/coderd/x/chatd/worker.go new file mode 100644 index 0000000000..ffec11be94 --- /dev/null 
+++ b/coderd/x/chatd/worker.go @@ -0,0 +1,314 @@ +package chatd + +import ( + "context" + "database/sql" + "errors" + "sync" + "time" + + "github.com/google/uuid" + "golang.org/x/xerrors" + + "github.com/coder/coder/v2/coderd/database" + coderdpubsub "github.com/coder/coder/v2/coderd/pubsub" + "github.com/coder/coder/v2/coderd/x/chatd/chatstate" +) + +// chatWorker owns chat acquisition and runner lifecycle for one process. +type chatWorker struct { + server *Server + opts chatWorkerOptions + + mu sync.Mutex + started bool + ctx context.Context + cancel context.CancelFunc + manager *runnerManager + unsubscribe func() + wakeCh chan struct{} + wg sync.WaitGroup +} + +// newChatWorker constructs a chat worker from opts after applying opts.withDefaults, returning any error withDefaults reports. The worker is idle until Start is called. +func newChatWorker(server *Server, opts chatWorkerOptions) (*chatWorker, error) { + withDefaults, err := opts.withDefaults() + if err != nil { + return nil, err + } + return &chatWorker{server: server, opts: withDefaults}, nil +} + +// chatWorkerID returns this worker's configured worker ID. +func (w *chatWorker) chatWorkerID() uuid.UUID { + return w.opts.WorkerID +} + +// Start starts the acquisition and runner manager loops. 
+func (w *chatWorker) Start(ctx context.Context) error { + w.mu.Lock() + defer w.mu.Unlock() + if w.started { + return xerrors.New("chatworker: worker already started") + } + workerID := w.opts.WorkerID + workerCtx, cancel := context.WithCancel(ctx) + manager := newRunnerManager(workerCtx, w.server, w.opts) + if manager.opts.TaskStarter == nil { + starter, err := newTaskStarter(manager.server, manager.opts, manager.RouteStateHint, manager.requestCleanup) + if err != nil { + cancel() + return err + } + manager.opts.TaskStarter = starter + } + wakeCh := make(chan struct{}, w.opts.AcquisitionWakeChannelSize) + + unsubscribe, err := w.opts.Pubsub.SubscribeWithErr( + coderdpubsub.ChatStateOwnershipChannel, + coderdpubsub.HandleChatStateOwnership(func(ctx context.Context, _ coderdpubsub.ChatStateOwnershipMessage, err error) { + if err != nil { + w.opts.Logger.Warn(ctx, "chatworker ownership hint decode failed", slogError(err)) + return + } + wake(wakeCh) + }), + ) + if err != nil { + cancel() + return xerrors.Errorf("subscribe ownership hints: %w", err) + } + + w.started = true + w.ctx = workerCtx + w.cancel = cancel + w.manager = manager + w.unsubscribe = unsubscribe + w.wakeCh = wakeCh + + manager.start() + w.wg.Go(func() { + w.acquisitionLoop(workerCtx, workerID, manager, wakeCh) + }) + w.wg.Go(func() { + w.archiveLoop(workerCtx) + }) + wake(wakeCh) + return nil +} + +// Wake requests an immediate acquisition pass. +func (w *chatWorker) Wake() { + w.mu.Lock() + wakeCh := w.wakeCh + w.mu.Unlock() + if wakeCh != nil { + wake(wakeCh) + } +} + +// WaitIdle waits until the worker has no active or cleaning runners. 
+func (w *chatWorker) WaitIdle(ctx context.Context) error { + for { + w.mu.Lock() + manager := w.manager + w.mu.Unlock() + if manager == nil || manager.idle() { + return nil + } + timer := w.opts.Clock.NewTimer(10*time.Millisecond, "chatworker", "wait-idle") + select { + case <-timer.C: + case <-ctx.Done(): + timer.Stop() + return ctx.Err() + } + timer.Stop() + } +} + +// Close stops the worker and waits for its loops to exit. +func (w *chatWorker) Close() error { + w.mu.Lock() + if !w.started { + w.mu.Unlock() + return nil + } + cancel := w.cancel + unsubscribe := w.unsubscribe + manager := w.manager + w.started = false + w.cancel = nil + w.unsubscribe = nil + w.manager = nil + w.wakeCh = nil + w.mu.Unlock() + + if unsubscribe != nil { + unsubscribe() + } + cancel() + w.wg.Wait() + if manager != nil { + manager.wait() + } + return nil +} + +func wake(ch chan<- struct{}) { + select { + case ch <- struct{}{}: + default: + } +} + +func (w *chatWorker) acquisitionLoop( + ctx context.Context, + workerID uuid.UUID, + manager *runnerManager, + wakeCh <-chan struct{}, +) { + ticker := w.opts.Clock.NewTicker(w.opts.AcquisitionInterval, "chatworker", "acquisition") + defer ticker.Stop() + for { + select { + case <-wakeCh: + w.acquireOnce(ctx, workerID, manager) + case <-ticker.C: + w.acquireOnce(ctx, workerID, manager) + case <-ctx.Done(): + return + } + } +} + +func (w *chatWorker) acquireOnce(ctx context.Context, workerID uuid.UUID, manager *runnerManager) { + attempted := make(map[uuid.UUID]struct{}) + for { + rows, err := w.opts.Store.GetChatWorkerAcquisitionCandidates(ctx, database.GetChatWorkerAcquisitionCandidatesParams{ + StaleSeconds: w.opts.HeartbeatStaleSeconds, + LimitCount: w.opts.AcquisitionBatchSize, + }) + if err != nil { + if ctx.Err() == nil { + w.opts.Logger.Warn(ctx, "chatworker acquisition query failed", slogError(err)) + } + return + } + if len(rows) == 0 { + return + } + newRows := 0 + for _, row := range rows { + if _, ok := attempted[row.ID]; ok { + 
continue + } + attempted[row.ID] = struct{}{} + newRows++ + if err := w.acquireCandidateSafely(ctx, workerID, manager, row.ID); err != nil { + if ctx.Err() != nil { + return + } + w.opts.Logger.Warn(ctx, "chatworker acquisition candidate failed", slogError(err)) + } + } + if len(rows) < int(w.opts.AcquisitionBatchSize) || newRows == 0 { + return + } + } +} + +var errSkipAcquire = xerrors.New("skip acquire") + +func (w *chatWorker) acquireCandidateSafely( + ctx context.Context, + workerID uuid.UUID, + manager *runnerManager, + chatID uuid.UUID, +) (err error) { + defer func() { + if recovered := recover(); recovered != nil { + err = xerrors.Errorf("chatworker acquisition panic: %v", recovered) + } + }() + return w.acquireCandidate(ctx, workerID, manager, chatID) +} + +func (w *chatWorker) acquireCandidate( + ctx context.Context, + workerID uuid.UUID, + manager *runnerManager, + chatID uuid.UUID, +) error { + runnerID := uuid.New() + machine := chatstate.NewChatMachine(w.opts.Store, w.opts.Pubsub, chatID, chatstate.Options{}) + err := machine.Update(ctx, func(tx *chatstate.Tx) error { + store := tx.Store() + chat, err := store.GetChatByID(ctx, chatID) + if errors.Is(err, sql.ErrNoRows) { + return errSkipAcquire + } + if err != nil { + return xerrors.Errorf("load chat: %w", err) + } + queueCount, err := store.CountChatQueuedMessages(ctx, chatID) + if err != nil { + return xerrors.Errorf("count queue: %w", err) + } + if !chatstate.ClassifyExecutionState(chat, queueCount > 0, true).IsRunnable() || chat.Archived { + return errSkipAcquire + } + if chat.WorkerID.Valid && chat.RunnerID.Valid { + stale, err := store.IsChatHeartbeatStale(ctx, database.IsChatHeartbeatStaleParams{ + ChatID: chat.ID, + RunnerID: chat.RunnerID.UUID, + StaleSeconds: w.opts.HeartbeatStaleSeconds, + }) + if err != nil { + return xerrors.Errorf("check heartbeat stale: %w", err) + } + if !stale { + return errSkipAcquire + } + } + _, err = tx.Acquire(chatstate.AcquireInput{WorkerID: workerID, RunnerID: 
runnerID}) + return err + }) + if errors.Is(err, errSkipAcquire) || errors.Is(err, chatstate.ErrChatNotFound) { + return nil + } + if err != nil { + return err + } + if err := manager.Spawn(ctx, spawnRunnerRequest{ChatID: chatID, WorkerID: workerID, RunnerID: runnerID}); err != nil { + if errAbandon := w.abandonAcquiredChat(ctx, workerID, runnerID, chatID); errAbandon != nil { + return errors.Join(err, errAbandon) + } + return err + } + return nil +} + +func (w *chatWorker) abandonAcquiredChat(ctx context.Context, workerID uuid.UUID, runnerID uuid.UUID, chatID uuid.UUID) error { + cleanupCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), shutdownCleanupTimeout) + defer cancel() + machine := chatstate.NewChatMachine(w.opts.Store, w.opts.Pubsub, chatID, chatstate.Options{}) + err := machine.Update(cleanupCtx, func(tx *chatstate.Tx) error { + chat, err := tx.Store().GetChatByID(cleanupCtx, chatID) + if errors.Is(err, sql.ErrNoRows) { + return errSkipAcquire + } + if err != nil { + return xerrors.Errorf("load chat: %w", err) + } + if !chat.WorkerID.Valid || chat.WorkerID.UUID != workerID || !chat.RunnerID.Valid || chat.RunnerID.UUID != runnerID { + return errSkipAcquire + } + _, err = tx.Abandon(chatstate.AbandonInput{}) + return err + }) + if errors.Is(err, errSkipAcquire) || errors.Is(err, chatstate.ErrChatNotFound) { + return nil + } + return err +} diff --git a/coderd/x/chatd/worker_internal_test.go b/coderd/x/chatd/worker_internal_test.go new file mode 100644 index 0000000000..6a635d8418 --- /dev/null +++ b/coderd/x/chatd/worker_internal_test.go @@ -0,0 +1,315 @@ +package chatd //nolint:testpackage // Tests unexported chat worker internals. 
+ +import ( + "context" + "database/sql" + "errors" + "testing" + "time" + + "github.com/google/uuid" + "github.com/stretchr/testify/require" + + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/testutil" + "github.com/coder/quartz" +) + +func TestWorker_NewRequiresTaskStarterOrMessagePartBuffer(t *testing.T) { + t.Parallel() + f := newWorkerTestFixture(t) + _, err := newChatWorker(nil, chatWorkerOptions{WorkerID: uuid.New(), Store: f.db, Pubsub: f.pubsub}) + require.ErrorContains(t, err, "task starter or message part buffer is required") +} + +func TestWorker_NewRequiresWorkerID(t *testing.T) { + t.Parallel() + f := newWorkerTestFixture(t) + opts := testOptions(t, f, newRecordingTaskStarter()) + opts.WorkerID = uuid.Nil + _, err := newChatWorker(nil, opts) + require.ErrorContains(t, err, "worker ID is required") +} + +func TestWorker_UsesConfiguredWorkerID(t *testing.T) { + t.Parallel() + f := newWorkerTestFixture(t) + starter := newRecordingTaskStarter() + opts := testOptions(t, f, starter) + workerID := opts.WorkerID + worker, err := newChatWorker(nil, opts) + require.NoError(t, err) + require.Equal(t, workerID, worker.chatWorkerID()) + require.NoError(t, worker.Start(context.Background())) + require.Equal(t, workerID, worker.chatWorkerID()) + require.NoError(t, worker.Close()) +} + +func TestWorker_AcquiresRunnableChatFromOwnershipHint(t *testing.T) { + t.Parallel() + f := newWorkerTestFixture(t) + chat := f.createRunningChat(t) + starter := newRecordingTaskStarter() + worker := startWorker(t, testOptions(t, f, starter)) + + call := starter.waitCall(t, taskKindGeneration, chat.ID) + require.Equal(t, worker.chatWorkerID(), call.input.WorkerID) + require.Equal(t, database.ChatStatusRunning, call.input.Status) + require.NotEqual(t, uuid.Nil, call.input.RunnerID) + + latest, err := f.db.GetChatByID(testutil.Context(t, testutil.WaitShort), chat.ID) + require.NoError(t, err) + require.Equal(t, worker.chatWorkerID(), latest.WorkerID.UUID) + 
require.Equal(t, call.input.RunnerID, latest.RunnerID.UUID) + _, err = f.db.GetChatHeartbeat(testutil.Context(t, testutil.WaitShort), database.GetChatHeartbeatParams{ + ChatID: chat.ID, + RunnerID: call.input.RunnerID, + }) + require.NoError(t, err) +} + +func TestWorker_AcquiresRequiresActionChatFromOwnershipHint(t *testing.T) { + t.Parallel() + f := newWorkerTestFixture(t) + chat := f.createRequiresActionChat(t) + starter := newRecordingTaskStarter() + startWorker(t, testOptions(t, f, starter)) + + call := starter.waitCall(t, taskKindRequiresActionTimeout, chat.ID) + require.Equal(t, database.ChatStatusRequiresAction, call.input.Status) + require.True(t, call.input.RequiresActionDeadlineAt.Valid) +} + +func TestWorker_SkipsFreshlyOwnedChat(t *testing.T) { + t.Parallel() + f := newWorkerTestFixture(t) + chat := f.createRunningChat(t) + otherWorker := uuid.New() + otherRunner := uuid.New() + acquireChat(t, f, chat.ID, otherWorker, otherRunner) + starter := newRecordingTaskStarter() + worker := startWorker(t, testOptions(t, f, starter)) + worker.Wake() + + starter.assertNoCall(t) + latest, err := f.db.GetChatByID(testutil.Context(t, testutil.WaitShort), chat.ID) + require.NoError(t, err) + require.Equal(t, otherWorker, latest.WorkerID.UUID) + require.Equal(t, otherRunner, latest.RunnerID.UUID) +} + +func TestWorker_TwoWorkersRaceSingleOwner(t *testing.T) { + t.Parallel() + f := newWorkerTestFixture(t) + chat := f.createRunningChat(t) + firstStarter := newRecordingTaskStarter() + secondStarter := newRecordingTaskStarter() + first := startWorker(t, testOptions(t, f, firstStarter)) + second := startWorker(t, testOptions(t, f, secondStarter)) + + call := waitAnyTaskCall(t, firstStarter, secondStarter, taskKindGeneration, chat.ID) + require.Contains(t, []uuid.UUID{first.chatWorkerID(), second.chatWorkerID()}, call.input.WorkerID) + firstStarter.assertNoCall(t) + secondStarter.assertNoCall(t) + + latest, err := f.db.GetChatByID(testutil.Context(t, testutil.WaitShort), 
chat.ID) + require.NoError(t, err) + require.True(t, latest.WorkerID.Valid) + require.True(t, latest.RunnerID.Valid) + require.Equal(t, call.input.WorkerID, latest.WorkerID.UUID) + require.Equal(t, call.input.RunnerID, latest.RunnerID.UUID) +} + +func TestWorker_DrainsMultipleRunnableChatsOnWake(t *testing.T) { + t.Parallel() + f := newWorkerTestFixture(t) + first := f.createRunningChat(t) + second := f.createRunningChat(t) + third := f.createRunningChat(t) + starter := newRecordingTaskStarter() + opts := testOptions(t, f, starter) + opts.AcquisitionBatchSize = 1 + startWorker(t, opts) + + want := map[uuid.UUID]bool{first.ID: true, second.ID: true, third.ID: true} + for range 3 { + call := starter.waitCall(t, taskKindGeneration, uuid.Nil) + delete(want, call.input.ChatID) + } + require.Empty(t, want) +} + +func TestWorker_DoesNotAcquireIdleOrArchivedChats(t *testing.T) { + t.Parallel() + f := newWorkerTestFixture(t) + waiting := f.createRunningChat(t) + finishTurn(t, f, waiting.ID) + errorChat := f.createRunningChat(t) + forceExecutionStateAndPublish(t, f, errorChat.ID, database.ChatStatusError, false) + archived := f.createRunningChat(t) + forceExecutionStateAndPublish(t, f, archived.ID, database.ChatStatusRunning, true) + starter := newRecordingTaskStarter() + worker := startWorker(t, testOptions(t, f, starter)) + worker.Wake() + + starter.assertNoCall(t) +} + +func TestWorker_HeartbeatLoopRefreshesActiveRunnerHeartbeat(t *testing.T) { + t.Parallel() + f := newWorkerTestFixture(t) + chat := f.createRunningChat(t) + clock := quartz.NewMock(t) + heartbeatTrap := clock.Trap().NewTicker("chatworker", "heartbeat") + defer heartbeatTrap.Close() + starter := newBlockingTaskStarter(false) + opts := testOptions(t, f, starter) + opts.Clock = clock + opts.HeartbeatInterval = time.Minute + startWorker(t, opts) + heartbeatTrap.MustWait(testutil.Context(t, testutil.WaitLong)).MustRelease(testutil.Context(t, testutil.WaitLong)) + call := starter.waitCall(t, taskKindGeneration, 
chat.ID) + oldHeartbeat := makeHeartbeatStale(t, f, chat.ID, call.input.RunnerID) + + clock.Advance(time.Minute).MustWait(testutil.Context(t, testutil.WaitLong)) + testutil.Eventually(testutil.Context(t, testutil.WaitLong), t, func(ctx context.Context) bool { + heartbeat, err := f.db.GetChatHeartbeat(ctx, database.GetChatHeartbeatParams{ + ChatID: chat.ID, + RunnerID: call.input.RunnerID, + }) + return err == nil && heartbeat.HeartbeatAt.After(oldHeartbeat) + }, testutil.IntervalFast, "heartbeat should be refreshed") +} + +func TestWorker_HeartbeatCleanupDeletesStaleRows(t *testing.T) { + t.Parallel() + f := newWorkerTestFixture(t) + chat := f.createRunningChat(t) + finishTurn(t, f, chat.ID) + runnerID := uuid.New() + require.NoError(t, f.db.UpsertChatHeartbeat(testutil.Context(t, testutil.WaitShort), database.UpsertChatHeartbeatParams{ + ChatID: chat.ID, + RunnerID: runnerID, + })) + makeHeartbeatStale(t, f, chat.ID, runnerID) + clock := quartz.NewMock(t) + cleanupTrap := clock.Trap().NewTicker("chatworker", "heartbeat-cleanup") + defer cleanupTrap.Close() + starter := newRecordingTaskStarter() + opts := testOptions(t, f, starter) + opts.Clock = clock + opts.HeartbeatCleanupInterval = time.Minute + startWorker(t, opts) + cleanupTrap.MustWait(testutil.Context(t, testutil.WaitLong)).MustRelease(testutil.Context(t, testutil.WaitLong)) + + clock.Advance(time.Minute).MustWait(testutil.Context(t, testutil.WaitLong)) + testutil.Eventually(testutil.Context(t, testutil.WaitLong), t, func(ctx context.Context) bool { + _, err := f.db.GetChatHeartbeat(ctx, database.GetChatHeartbeatParams{ + ChatID: chat.ID, + RunnerID: runnerID, + }) + return errors.Is(err, sql.ErrNoRows) + }, testutil.IntervalFast) +} + +func TestWorker_CloseDeletesOwnedHeartbeatsAndPublishesOwnershipHints(t *testing.T) { + t.Parallel() + f := newWorkerTestFixture(t) + first := f.createRunningChat(t) + second := f.createRunningChat(t) + starter := newBlockingTaskStarter(false) + pubsub := 
newRecordingPubsub(f.pubsub) + opts := testOptions(t, f, starter) + opts.Pubsub = pubsub + worker := startWorker(t, opts) + callsByChat := make(map[uuid.UUID]taskCall) + for range 2 { + call := starter.waitCall(t, taskKindGeneration, uuid.Nil) + callsByChat[call.input.ChatID] = call + } + require.Contains(t, callsByChat, first.ID) + require.Contains(t, callsByChat, second.ID) + + require.NoError(t, worker.Close()) + for _, call := range callsByChat { + _, err := f.db.GetChatHeartbeat(testutil.Context(t, testutil.WaitShort), database.GetChatHeartbeatParams{ + ChatID: call.input.ChatID, + RunnerID: call.input.RunnerID, + }) + require.ErrorIs(t, err, sql.ErrNoRows) + } + + messages := pubsub.ownershipMessages(t) + seen := make(map[uuid.UUID]bool) + for _, msg := range messages { + seen[msg.ChatID] = true + require.NotZero(t, msg.SnapshotVersion) + } + require.True(t, seen[first.ID], "expected ownership hint for first runner") + require.True(t, seen[second.ID], "expected ownership hint for second runner") +} + +func TestWorker_CloseIsIdempotentAndDoesNotBlock(t *testing.T) { + t.Parallel() + f := newWorkerTestFixture(t) + chat := f.createRunningChat(t) + starter := newBlockingTaskStarter(false) + worker := startWorker(t, testOptions(t, f, starter)) + call := starter.waitCall(t, taskKindGeneration, chat.ID) + + closed := make(chan error, 1) + go func() { + if err := worker.Close(); err != nil { + closed <- err + return + } + closed <- worker.Close() + }() + select { + case err := <-closed: + require.NoError(t, err) + case <-time.After(testutil.WaitLong): + t.Fatal("worker close did not return") + } + select { + case <-call.ctx.Done(): + case <-time.After(testutil.WaitLong): + t.Fatal("active task was not canceled") + } +} + +func waitAnyTaskCall( + t *testing.T, + first *recordingTaskStarter, + second *recordingTaskStarter, + kind taskKind, + chatID uuid.UUID, +) taskCall { + t.Helper() + deadline := time.After(testutil.WaitLong) + for { + select { + case call := 
<-first.callCh: + if call.kind == kind && call.input.ChatID == chatID { + return call + } + case call := <-second.callCh: + if call.kind == kind && call.input.ChatID == chatID { + return call + } + case <-deadline: + t.Fatal("timed out waiting for either worker to start task") + return taskCall{} + } + } +} + +func requireTaskCanceled(t *testing.T, call taskCall) { + t.Helper() + select { + case <-call.ctx.Done(): + require.True(t, errors.Is(call.ctx.Err(), context.Canceled)) + case <-time.After(testutil.WaitLong): + t.Fatal("task context was not canceled") + } +} diff --git a/coderd/x/chatd/workspace_context_builder.go b/coderd/x/chatd/workspace_context_builder.go new file mode 100644 index 0000000000..66c975324e --- /dev/null +++ b/coderd/x/chatd/workspace_context_builder.go @@ -0,0 +1,143 @@ +package chatd + +import ( + "context" + "sync" + + "github.com/google/uuid" + + "cdr.dev/slog/v3" + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/x/chatd/chatprompt" + "github.com/coder/coder/v2/coderd/x/chatd/chatstate" + "github.com/coder/coder/v2/codersdk" + "github.com/coder/coder/v2/codersdk/workspacesdk" +) + +// buildWorkspaceContext fetches workspace context for the chat's +// bound workspace agent and returns durable chatstate.Message values +// for the generation action to commit. It returns an empty result +// (Messages == nil) when there is nothing safe to persist for the +// current committed metadata; that is treated as an expected exit by +// the generation action. 
+func (server *Server) buildWorkspaceContext( + ctx context.Context, + input workspaceContextBuildInput, +) (workspaceContextBuildResult, error) { + chat := input.Chat + if !chat.WorkspaceID.Valid || !chat.AgentID.Valid { + return workspaceContextBuildResult{}, nil + } + logger := server.logger.With( + slog.F("chat_id", chat.ID), + slog.F("owner_id", chat.OwnerID), + ) + + // Build a per-call workspace context with the latest committed + // chat snapshot so getWorkspaceAgent and getWorkspaceConn dial + // the agent we actually want to fetch context from. + currentChat := chat + var chatStateMu sync.Mutex + wsCtx := turnWorkspaceContext{ + server: server, + chatStateMu: &chatStateMu, + currentChat: ¤tChat, + loadChatSnapshot: server.db.GetChatByID, + } + defer wsCtx.close() + + parts, discoveredSkillsPart, _, expectedAgentID := server.fetchContextForBuild(ctx, chat, &wsCtx, logger) + _ = discoveredSkillsPart + // If the workspace or agent is gone, fall back to no-op so the + // generation action exits without committing stale context. + if expectedAgentID == uuid.Nil { + return workspaceContextBuildResult{}, nil + } + + hasContent := false + hasContextFilePart := false + for _, part := range parts { + if part.Type == codersdk.ChatMessagePartTypeContextFile { + hasContextFilePart = true + if part.ContextFileContent != "" { + hasContent = true + } + } + } + + agentID := uuid.NullUUID{UUID: expectedAgentID, Valid: true} + + // If we have no content but the agent is known, commit a blank + // context-file marker (sentinel) so subsequent turns skip the + // workspace-agent dial and the decision helper observes the + // attempt in committed history. This applies whether the + // workspace connection succeeded but returned no AGENTS.md, or + // the agent's context config fetch failed: in both cases we + // have a known agent and committing a sentinel breaks the + // otherwise-infinite decision loop. 
+ if !hasContent { + if !hasContextFilePart { + parts = append([]codersdk.ChatMessagePart{{ + Type: codersdk.ChatMessagePartTypeContextFile, + ContextFileAgentID: agentID, + }}, parts...) + } + } + + content, err := chatprompt.MarshalParts(parts) + if err != nil { + return workspaceContextBuildResult{}, nil + } + + modelConfigID := chat.LastModelConfigID + msg := chatstate.Message{ + Role: database.ChatMessageRoleUser, + Content: content, + Visibility: database.ChatMessageVisibilityBoth, + ModelConfigID: uuid.NullUUID{UUID: modelConfigID, Valid: modelConfigID != uuid.Nil}, + ContentVersion: chatprompt.CurrentContentVersion, + } + + // Update the cache column so subsequent turns can read the last + // injected context without scanning messages. This is a + // best-effort write that does not mutate chat history; the + // generation action separately commits the durable message + // below. + stripped := make([]codersdk.ChatMessagePart, len(parts)) + copy(stripped, parts) + for i := range stripped { + stripped[i].StripInternal() + } + server.updateLastInjectedContext(ctx, chat.ID, stripped) + + return workspaceContextBuildResult{Messages: []chatstate.Message{msg}}, nil +} + +// fetchContextForBuild fetches workspace context parts from the +// agent, returning the parts to persist, any discovered skill +// metadata (carried in parts), and whether the workspace connection +// succeeded. expectedAgentID is the agent ID the fetch was bound to, +// or uuid.Nil if the agent could not be resolved. 
+func (server *Server) fetchContextForBuild( + ctx context.Context, + chat database.Chat, + wsCtx *turnWorkspaceContext, + logger slog.Logger, +) (parts []codersdk.ChatMessagePart, discoveredSkills []codersdk.ChatMessagePart, workspaceConnOK bool, expectedAgentID uuid.UUID) { + agent, agentParts, _, connOK := server.fetchWorkspaceContext( + ctx, chat, wsCtx.getWorkspaceAgent, + func(instructionCtx context.Context) (workspacesdk.AgentConn, error) { + if _, _, err := wsCtx.workspaceAgentIDForConn(instructionCtx); err != nil { + return nil, err + } + return wsCtx.getWorkspaceConn(instructionCtx) + }, + ) + if agent == nil { + // fetchWorkspaceContext returns nil for the agent when the + // chat has no valid workspace or the agent lookup fails. + logger.Debug(ctx, "workspace context build: workspace agent not resolvable") + return nil, nil, false, uuid.Nil + } + return agentParts, agentParts, connOK, agent.ID +} diff --git a/enterprise/coderd/exp_chats_test.go b/enterprise/coderd/exp_chats_test.go index d29240dd2e..83ede0531f 100644 --- a/enterprise/coderd/exp_chats_test.go +++ b/enterprise/coderd/exp_chats_test.go @@ -67,6 +67,7 @@ func createOpenAIModelConfigForTest( func TestChatStreamRelay(t *testing.T) { t.Parallel() + t.Skip("chatd refactor: remove in PR 4") t.Run("RelayMessagePartsAcrossReplicas", func(t *testing.T) { t.Parallel() diff --git a/enterprise/coderd/x/chatd/chatd_test.go b/enterprise/coderd/x/chatd/chatd_test.go index 9587c7e6b3..a4c48ad4d9 100644 --- a/enterprise/coderd/x/chatd/chatd_test.go +++ b/enterprise/coderd/x/chatd/chatd_test.go @@ -31,6 +31,8 @@ import ( "github.com/coder/quartz" ) +const skipLegacyChatStream = "chatd refactor: remove in PR 4" + func chatLastErrorMessage(raw pqtype.NullRawMessage) string { if !raw.Valid { return "" @@ -485,6 +487,7 @@ func TestSubscribeRelaySnapshotDelivered(t *testing.T) { func TestSubscribeRetryEventAcrossInstances(t *testing.T) { t.Parallel() + t.Skip(skipLegacyChatStream) db, ps := dbtestutil.NewDB(t) 
workerID := uuid.New() @@ -1287,6 +1290,7 @@ func TestSubscribeRelayMultipleReconnects(t *testing.T) { // the durable message ID and dropped from new buffer snapshots. func TestSubscribeRelayDialCanceledOnFastCompletion(t *testing.T) { t.Parallel() + t.Skip(skipLegacyChatStream) db, ps := dbtestutil.NewDB(t) workerID := uuid.New()