diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index ee49dd2bb1..8c8cf6d4c1 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -2369,7 +2369,7 @@ func (q *querier) FetchVolumesResourceMonitorsUpdatedAfter(ctx context.Context, return q.db.FetchVolumesResourceMonitorsUpdatedAfter(ctx, updatedAt) } -func (q *querier) FinalizeStaleChatDebugRows(ctx context.Context, updatedBefore time.Time) (database.FinalizeStaleChatDebugRowsRow, error) { +func (q *querier) FinalizeStaleChatDebugRows(ctx context.Context, updatedBefore database.FinalizeStaleChatDebugRowsParams) (database.FinalizeStaleChatDebugRowsRow, error) { // Background sweep operates across all chats. if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceChat); err != nil { return database.FinalizeStaleChatDebugRowsRow{}, err @@ -5915,6 +5915,28 @@ func (q *querier) SoftDeleteContextFileMessages(ctx context.Context, chatID uuid return q.db.SoftDeleteContextFileMessages(ctx, chatID) } +func (q *querier) TouchChatDebugRunUpdatedAt(ctx context.Context, arg database.TouchChatDebugRunUpdatedAtParams) error { + chat, err := q.db.GetChatByID(ctx, arg.ChatID) + if err != nil { + return err + } + if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil { + return err + } + return q.db.TouchChatDebugRunUpdatedAt(ctx, arg) +} + +func (q *querier) TouchChatDebugStepAndRun(ctx context.Context, arg database.TouchChatDebugStepAndRunParams) error { + chat, err := q.db.GetChatByID(ctx, arg.ChatID) + if err != nil { + return err + } + if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil { + return err + } + return q.db.TouchChatDebugStepAndRun(ctx, arg) +} + func (q *querier) TryAcquireLock(ctx context.Context, id int64) (bool, error) { return q.db.TryAcquireLock(ctx, id) } diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go index 8261a7ea6e..7826990ff2 100644 --- a/coderd/database/dbauthz/dbauthz_test.go +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -475,10 +475,14 @@ func (s *MethodTestSuite) TestChats() { check.Args(chat.ID).Asserts(chat, policy.ActionUpdate).Returns(int64(1)) })) s.Run("FinalizeStaleChatDebugRows", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { - updatedBefore := dbtime.Now() + now := dbtime.Now() + arg := database.FinalizeStaleChatDebugRowsParams{ + Now: now, + UpdatedBefore: now.Add(-5 * time.Minute), + } row := database.FinalizeStaleChatDebugRowsRow{RunsFinalized: 1, StepsFinalized: 2} - dbm.EXPECT().FinalizeStaleChatDebugRows(gomock.Any(), updatedBefore).Return(row, nil).AnyTimes() - check.Args(updatedBefore).Asserts(rbac.ResourceChat, policy.ActionUpdate).Returns(row) + dbm.EXPECT().FinalizeStaleChatDebugRows(gomock.Any(), arg).Return(row, nil).AnyTimes() + check.Args(arg).Asserts(rbac.ResourceChat, policy.ActionUpdate).Returns(row) })) s.Run("GetChatDebugLoggingAllowUsers", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { dbm.EXPECT().GetChatDebugLoggingAllowUsers(gomock.Any()).Return(true, nil).AnyTimes() @@ -532,6 +536,20 @@ func (s *MethodTestSuite) TestChats() { dbm.EXPECT().UpdateChatDebugRun(gomock.Any(), arg).Return(run, nil).AnyTimes() check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(run) })) + s.Run("TouchChatDebugRunUpdatedAt", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + chat := testutil.Fake(s.T(), faker, database.Chat{}) + arg := database.TouchChatDebugRunUpdatedAtParams{ID: uuid.New(), ChatID: chat.ID} + dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes() + dbm.EXPECT().TouchChatDebugRunUpdatedAt(gomock.Any(), arg).Return(nil).AnyTimes() + check.Args(arg).Asserts(chat, policy.ActionUpdate) + })) + s.Run("TouchChatDebugStepAndRun", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + chat := testutil.Fake(s.T(), faker, database.Chat{}) + arg := database.TouchChatDebugStepAndRunParams{StepID: uuid.New(), RunID: uuid.New(), ChatID: chat.ID} + dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes() + dbm.EXPECT().TouchChatDebugStepAndRun(gomock.Any(), arg).Return(nil).AnyTimes() + check.Args(arg).Asserts(chat, policy.ActionUpdate) + })) s.Run("UpdateChatDebugStep", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { chat := testutil.Fake(s.T(), faker, database.Chat{}) arg := database.UpdateChatDebugStepParams{ID: uuid.New(), ChatID: chat.ID} diff --git a/coderd/database/dbmetrics/querymetrics.go b/coderd/database/dbmetrics/querymetrics.go index 97752dd59d..396209fba3 100644 --- a/coderd/database/dbmetrics/querymetrics.go +++ b/coderd/database/dbmetrics/querymetrics.go @@ -888,7 +888,7 @@ func (m queryMetricsStore) FetchVolumesResourceMonitorsUpdatedAfter(ctx context. return r0, r1 } -func (m queryMetricsStore) FinalizeStaleChatDebugRows(ctx context.Context, updatedBefore time.Time) (database.FinalizeStaleChatDebugRowsRow, error) { +func (m queryMetricsStore) FinalizeStaleChatDebugRows(ctx context.Context, updatedBefore database.FinalizeStaleChatDebugRowsParams) (database.FinalizeStaleChatDebugRowsRow, error) { start := time.Now() r0, r1 := m.s.FinalizeStaleChatDebugRows(ctx, updatedBefore) m.queryLatencies.WithLabelValues("FinalizeStaleChatDebugRows").Observe(time.Since(start).Seconds()) @@ -4240,6 +4240,22 @@ func (m queryMetricsStore) SoftDeleteContextFileMessages(ctx context.Context, ch return r0 } +func (m queryMetricsStore) TouchChatDebugRunUpdatedAt(ctx context.Context, arg database.TouchChatDebugRunUpdatedAtParams) error { + start := time.Now() + r0 := m.s.TouchChatDebugRunUpdatedAt(ctx, arg) + m.queryLatencies.WithLabelValues("TouchChatDebugRunUpdatedAt").Observe(time.Since(start).Seconds()) + m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "TouchChatDebugRunUpdatedAt").Inc() + return r0 +} + +func (m queryMetricsStore) TouchChatDebugStepAndRun(ctx context.Context, arg database.TouchChatDebugStepAndRunParams) error { + start := time.Now() + r0 := m.s.TouchChatDebugStepAndRun(ctx, arg) + m.queryLatencies.WithLabelValues("TouchChatDebugStepAndRun").Observe(time.Since(start).Seconds()) + m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "TouchChatDebugStepAndRun").Inc() + return r0 +} + func (m queryMetricsStore) TryAcquireLock(ctx context.Context, pgTryAdvisoryXactLock int64) (bool, error) { start := time.Now() r0, r1 := m.s.TryAcquireLock(ctx, pgTryAdvisoryXactLock) diff --git a/coderd/database/dbmock/dbmock.go b/coderd/database/dbmock/dbmock.go index 8848d7c58f..4d66dd5907 100644 --- a/coderd/database/dbmock/dbmock.go +++ b/coderd/database/dbmock/dbmock.go @@ -1518,18 +1518,18 @@ func (mr *MockStoreMockRecorder) FetchVolumesResourceMonitorsUpdatedAfter(ctx, u } // FinalizeStaleChatDebugRows mocks base method. -func (m *MockStore) FinalizeStaleChatDebugRows(ctx context.Context, updatedBefore time.Time) (database.FinalizeStaleChatDebugRowsRow, error) { +func (m *MockStore) FinalizeStaleChatDebugRows(ctx context.Context, arg database.FinalizeStaleChatDebugRowsParams) (database.FinalizeStaleChatDebugRowsRow, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "FinalizeStaleChatDebugRows", ctx, updatedBefore) + ret := m.ctrl.Call(m, "FinalizeStaleChatDebugRows", ctx, arg) ret0, _ := ret[0].(database.FinalizeStaleChatDebugRowsRow) ret1, _ := ret[1].(error) return ret0, ret1 } // FinalizeStaleChatDebugRows indicates an expected call of FinalizeStaleChatDebugRows. -func (mr *MockStoreMockRecorder) FinalizeStaleChatDebugRows(ctx, updatedBefore any) *gomock.Call { +func (mr *MockStoreMockRecorder) FinalizeStaleChatDebugRows(ctx, arg any) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FinalizeStaleChatDebugRows", reflect.TypeOf((*MockStore)(nil).FinalizeStaleChatDebugRows), ctx, updatedBefore) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FinalizeStaleChatDebugRows", reflect.TypeOf((*MockStore)(nil).FinalizeStaleChatDebugRows), ctx, arg) } // FindMatchingPresetID mocks base method. @@ -8034,6 +8034,34 @@ func (mr *MockStoreMockRecorder) SoftDeleteContextFileMessages(ctx, chatID any) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SoftDeleteContextFileMessages", reflect.TypeOf((*MockStore)(nil).SoftDeleteContextFileMessages), ctx, chatID) } +// TouchChatDebugRunUpdatedAt mocks base method. +func (m *MockStore) TouchChatDebugRunUpdatedAt(ctx context.Context, arg database.TouchChatDebugRunUpdatedAtParams) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "TouchChatDebugRunUpdatedAt", ctx, arg) + ret0, _ := ret[0].(error) + return ret0 +} + +// TouchChatDebugRunUpdatedAt indicates an expected call of TouchChatDebugRunUpdatedAt. +func (mr *MockStoreMockRecorder) TouchChatDebugRunUpdatedAt(ctx, arg any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TouchChatDebugRunUpdatedAt", reflect.TypeOf((*MockStore)(nil).TouchChatDebugRunUpdatedAt), ctx, arg) +} + +// TouchChatDebugStepAndRun mocks base method. +func (m *MockStore) TouchChatDebugStepAndRun(ctx context.Context, arg database.TouchChatDebugStepAndRunParams) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "TouchChatDebugStepAndRun", ctx, arg) + ret0, _ := ret[0].(error) + return ret0 +} + +// TouchChatDebugStepAndRun indicates an expected call of TouchChatDebugStepAndRun. +func (mr *MockStoreMockRecorder) TouchChatDebugStepAndRun(ctx, arg any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TouchChatDebugStepAndRun", reflect.TypeOf((*MockStore)(nil).TouchChatDebugStepAndRun), ctx, arg) +} + // TryAcquireLock mocks base method. func (m *MockStore) TryAcquireLock(ctx context.Context, pgTryAdvisoryXactLock int64) (bool, error) { m.ctrl.T.Helper() diff --git a/coderd/database/querier.go b/coderd/database/querier.go index bcf0031857..94144a494a 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -197,15 +197,18 @@ type sqlcQuerier interface { FetchVolumesResourceMonitorsByAgentID(ctx context.Context, agentID uuid.UUID) ([]WorkspaceAgentVolumeResourceMonitor, error) FetchVolumesResourceMonitorsUpdatedAfter(ctx context.Context, updatedAt time.Time) ([]WorkspaceAgentVolumeResourceMonitor, error) // Marks orphaned in-progress rows as interrupted so they do not stay - // in a non-terminal state forever. The NOT IN list must match the + // in a non-terminal state forever. The NOT IN list must match the // terminal statuses defined by ChatDebugStatus in codersdk/chats.go. // // The steps CTE also catches steps whose parent run was just finalized // (via run_id IN), because PostgreSQL data-modifying CTEs share the - // same snapshot and cannot see each other's row updates. Without this, + // same snapshot and cannot see each other's row updates. Without this, // a step with a recent updated_at would survive its run's finalization // and remain in 'in_progress' state permanently. - FinalizeStaleChatDebugRows(ctx context.Context, updatedBefore time.Time) (FinalizeStaleChatDebugRowsRow, error) + // + // @now is the caller's clock timestamp so that mock-clock tests stay + // consistent with the @updated_before cutoff. + FinalizeStaleChatDebugRows(ctx context.Context, arg FinalizeStaleChatDebugRowsParams) (FinalizeStaleChatDebugRowsRow, error) // FindMatchingPresetID finds a preset ID that is the largest exact subset of the provided parameters. // It returns the preset ID if a match is found, or NULL if no match is found. // The query finds presets where all preset parameters are present in the provided parameters, @@ -780,6 +783,12 @@ type sqlcQuerier interface { InsertAuditLog(ctx context.Context, arg InsertAuditLogParams) (AuditLog, error) InsertChat(ctx context.Context, arg InsertChatParams) (Chat, error) InsertChatDebugRun(ctx context.Context, arg InsertChatDebugRunParams) (ChatDebugRun, error) + // The CTE atomically locks the parent run via UPDATE, bumps its + // updated_at (eliminating a separate TouchChatDebugRunUpdatedAt + // call), and enforces the finalization guard: if the run is already + // finished, the UPDATE returns zero rows, the INSERT gets no source + // rows, and sql.ErrNoRows is returned. The UPDATE also serializes + // with concurrent FinalizeStale under READ COMMITTED isolation. InsertChatDebugStep(ctx context.Context, arg InsertChatDebugStepParams) (ChatDebugStep, error) InsertChatFile(ctx context.Context, arg InsertChatFileParams) (InsertChatFileRow, error) InsertChatMessages(ctx context.Context, arg InsertChatMessagesParams) ([]ChatMessage, error) @@ -946,6 +955,31 @@ type sqlcQuerier interface { SoftDeleteChatMessageByID(ctx context.Context, id int64) error SoftDeleteChatMessagesAfterID(ctx context.Context, arg SoftDeleteChatMessagesAfterIDParams) error SoftDeleteContextFileMessages(ctx context.Context, chatID uuid.UUID) error + // Overrides updated_at on the parent run without touching any + // other column. Used by tests that need to stamp a run with a + // specific timestamp after the InsertChatDebugStep CTE has + // already bumped it to NOW(), so stale-row finalization paths + // can be exercised deterministically. The chatdebug service + // itself does not call this: heartbeats go through + // TouchChatDebugStepAndRun, and step creation updates the parent + // run via the InsertChatDebugStep CTE. + TouchChatDebugRunUpdatedAt(ctx context.Context, arg TouchChatDebugRunUpdatedAtParams) error + // Atomically bumps updated_at on both the step and its parent run + // in a single statement. This prevents FinalizeStale from + // interleaving between the two touches and finalizing a run whose + // step heartbeat was just written. + // + // The step UPDATE joins through touched_run (via FROM) and reads + // its RETURNING rows. Per the PostgreSQL WITH semantics, RETURNING + // is the only way to communicate values between a data-modifying + // CTE and the main query, and consuming those rows forces the run + // UPDATE to complete before the step UPDATE. That matches the + // lock order used by FinalizeStaleChatDebugRows and avoids a + // deadlock between concurrent heartbeats and stale sweeps. The + // join also constrains the step update to the specified run so a + // mismatched (run_id, step_id) pair cannot silently refresh an + // unrelated step. + TouchChatDebugStepAndRun(ctx context.Context, arg TouchChatDebugStepAndRunParams) error // Non blocking lock. Returns true if the lock was acquired, false otherwise. // // This must be called from within a transaction. The lock will be automatically @@ -966,14 +1000,25 @@ type sqlcQuerier interface { UpdateChatBuildAgentBinding(ctx context.Context, arg UpdateChatBuildAgentBindingParams) (Chat, error) UpdateChatByID(ctx context.Context, arg UpdateChatByIDParams) (Chat, error) // Uses COALESCE so that passing NULL from Go means "keep the - // existing value." This is intentional: debug rows follow a + // existing value." This is intentional: debug rows follow a // write-once-finalize pattern where fields are set at creation - // or finalization and never cleared back to NULL. + // or finalization and never cleared back to NULL. The @now + // parameter keeps updated_at under the caller's clock. + // + // finished_at is enforced as write-once at the SQL level: once + // populated it cannot be overwritten by a later call. Callers + // that issue a summary or status refresh after the run has + // already finalized therefore cannot corrupt the original + // completion timestamp, which keeps duration and ordering + // calculations stable regardless of how many times the row is + // updated. UpdateChatDebugRun(ctx context.Context, arg UpdateChatDebugRunParams) (ChatDebugRun, error) // Uses COALESCE so that passing NULL from Go means "keep the - // existing value." This is intentional: debug rows follow a + // existing value." This is intentional: debug rows follow a // write-once-finalize pattern where fields are set at creation - // or finalization and never cleared back to NULL. + // or finalization and never cleared back to NULL. The @now + // parameter keeps updated_at under the caller's clock, matching + // the injectable quartz.Clock used by FinalizeStale sweeps. UpdateChatDebugStep(ctx context.Context, arg UpdateChatDebugStepParams) (ChatDebugStep, error) // Bumps the heartbeat timestamp for the given set of chat IDs, // provided they are still running and owned by the specified diff --git a/coderd/database/querier_test.go b/coderd/database/querier_test.go index 0c35b18348..8c0593701f 100644 --- a/coderd/database/querier_test.go +++ b/coderd/database/querier_test.go @@ -11651,7 +11651,10 @@ func TestFinalizeStaleChatDebugRows(t *testing.T) { require.NoError(t, err) // --- orphanStep: in_progress step whose run is already completed --- - // its own updated_at is old, so it should be finalized directly. + // Its own updated_at is old, so it should be finalized directly. + // The step must be inserted while the run is still open because + // InsertChatDebugStep requires finished_at IS NULL on the parent + // run (atomic guard against appending steps to finalized runs). completedRun, err := store.InsertChatDebugRun(ctx, database.InsertChatDebugRunParams{ ChatID: chat.ID, ModelConfigID: uuid.NullUUID{UUID: modelCfg.ID, Valid: true}, @@ -11662,7 +11665,19 @@ func TestFinalizeStaleChatDebugRows(t *testing.T) { }) require.NoError(t, err) - // Mark the run as completed with a finished_at timestamp. + // Insert the step while the run is still open (finished_at IS NULL). + orphanStep, err := store.InsertChatDebugStep(ctx, database.InsertChatDebugStepParams{ + RunID: completedRun.ID, + ChatID: chat.ID, + StepNumber: 1, + Operation: "stream", + Status: "in_progress", + UpdatedAt: sql.NullTime{Time: staleTime, Valid: true}, + }) + require.NoError(t, err) + + // Now mark the run as completed with a finished_at timestamp, + // leaving the step orphaned in in_progress state. _, err = store.UpdateChatDebugRun(ctx, database.UpdateChatDebugRunParams{ ID: completedRun.ID, ChatID: completedRun.ChatID, @@ -11671,16 +11686,7 @@ func TestFinalizeStaleChatDebugRows(t *testing.T) { Time: time.Now(), Valid: true, }, - }) - require.NoError(t, err) - - orphanStep, err := store.InsertChatDebugStep(ctx, database.InsertChatDebugStepParams{ - RunID: completedRun.ID, - ChatID: chat.ID, - StepNumber: 1, - Operation: "stream", - Status: "in_progress", - UpdatedAt: sql.NullTime{Time: staleTime, Valid: true}, + Now: time.Now(), }) require.NoError(t, err) @@ -11715,6 +11721,16 @@ func TestFinalizeStaleChatDebugRows(t *testing.T) { }) require.NoError(t, err) + // The InsertChatDebugStep CTE atomically bumps the parent run's + // updated_at to NOW(). Reset it back to staleTime so the run is + // still caught by the age predicate in FinalizeStaleChatDebugRows. + err = store.TouchChatDebugRunUpdatedAt(ctx, database.TouchChatDebugRunUpdatedAtParams{ + ID: cascadeRun.ID, + ChatID: chat.ID, + Now: staleTime, + }) + require.NoError(t, err) + // --- alreadyDone: completed run/step --- should NOT be touched. doneRun, err := store.InsertChatDebugRun(ctx, database.InsertChatDebugRunParams{ ChatID: chat.ID, @@ -11726,6 +11742,17 @@ func TestFinalizeStaleChatDebugRows(t *testing.T) { }) require.NoError(t, err) + // Insert step while run is still open. + doneStep, err := store.InsertChatDebugStep(ctx, database.InsertChatDebugStepParams{ + RunID: doneRun.ID, + ChatID: chat.ID, + StepNumber: 1, + Operation: "stream", + Status: "completed", + }) + require.NoError(t, err) + + // Now finalize both run and step. _, err = store.UpdateChatDebugRun(ctx, database.UpdateChatDebugRunParams{ ID: doneRun.ID, ChatID: doneRun.ChatID, @@ -11734,15 +11761,7 @@ func TestFinalizeStaleChatDebugRows(t *testing.T) { Time: time.Now(), Valid: true, }, - }) - require.NoError(t, err) - - doneStep, err := store.InsertChatDebugStep(ctx, database.InsertChatDebugStepParams{ - RunID: doneRun.ID, - ChatID: chat.ID, - StepNumber: 1, - Operation: "stream", - Status: "completed", + Now: time.Now(), }) require.NoError(t, err) @@ -11754,6 +11773,7 @@ func TestFinalizeStaleChatDebugRows(t *testing.T) { Time: time.Now(), Valid: true, }, + Now: time.Now(), }) require.NoError(t, err) @@ -11769,6 +11789,17 @@ func TestFinalizeStaleChatDebugRows(t *testing.T) { }) require.NoError(t, err) + // Insert step while run is still open. + errorStep, err := store.InsertChatDebugStep(ctx, database.InsertChatDebugStepParams{ + RunID: errorRun.ID, + ChatID: chat.ID, + StepNumber: 1, + Operation: "stream", + Status: "error", + }) + require.NoError(t, err) + + // Now finalize both run and step. _, err = store.UpdateChatDebugRun(ctx, database.UpdateChatDebugRunParams{ ID: errorRun.ID, ChatID: errorRun.ChatID, @@ -11777,15 +11808,7 @@ func TestFinalizeStaleChatDebugRows(t *testing.T) { Time: time.Now(), Valid: true, }, - }) - require.NoError(t, err) - - errorStep, err := store.InsertChatDebugStep(ctx, database.InsertChatDebugStepParams{ - RunID: errorRun.ID, - ChatID: chat.ID, - StepNumber: 1, - Operation: "stream", - Status: "error", + Now: time.Now(), }) require.NoError(t, err) @@ -11797,6 +11820,7 @@ func TestFinalizeStaleChatDebugRows(t *testing.T) { Time: time.Now(), Valid: true, }, + Now: time.Now(), }) require.NoError(t, err) @@ -11828,7 +11852,10 @@ func TestFinalizeStaleChatDebugRows(t *testing.T) { require.NoError(t, err) // --- Execute the finalization sweep. --- - result, err := store.FinalizeStaleChatDebugRows(ctx, staleThreshold) + result, err := store.FinalizeStaleChatDebugRows(ctx, database.FinalizeStaleChatDebugRowsParams{ + Now: time.Now(), + UpdatedBefore: staleThreshold, + }) require.NoError(t, err) // staleRun + cascadeRun were finalized; completedRun and doneRun @@ -11921,7 +11948,10 @@ func TestFinalizeStaleChatDebugRows(t *testing.T) { "fresh step should not have a finished_at timestamp") // A second sweep should be a no-op. - result2, err := store.FinalizeStaleChatDebugRows(ctx, staleThreshold) + result2, err := store.FinalizeStaleChatDebugRows(ctx, database.FinalizeStaleChatDebugRowsParams{ + Now: time.Now(), + UpdatedBefore: staleThreshold, + }) require.NoError(t, err) assert.EqualValues(t, 0, result2.RunsFinalized, "second sweep should find nothing to finalize") @@ -12034,6 +12064,7 @@ func TestChatDebugSQLGuards(t *testing.T) { Time: time.Now(), Valid: true, }, + Now: time.Now(), }) require.ErrorIs(t, err, sql.ErrNoRows, "UpdateChatDebugRun should fail when chat_id does not match") @@ -12051,6 +12082,7 @@ func TestChatDebugSQLGuards(t *testing.T) { Time: time.Now(), Valid: true, }, + Now: time.Now(), }) require.ErrorIs(t, err, sql.ErrNoRows, "UpdateChatDebugStep should fail when chat_id does not match") @@ -12137,6 +12169,7 @@ func TestChatDebugRunCOALESCEPreservation(t *testing.T) { Time: now, Valid: true, }, + Now: now, }) require.NoError(t, err) @@ -12144,9 +12177,9 @@ func TestChatDebugRunCOALESCEPreservation(t *testing.T) { require.Equal(t, "completed", updated.Status) require.True(t, updated.FinishedAt.Valid) - // UpdatedAt should advance (set to NOW() unconditionally). - require.True(t, updated.UpdatedAt.After(original.UpdatedAt) || - updated.UpdatedAt.Equal(original.UpdatedAt)) + // UpdatedAt should be set to the @now value we passed in. + require.WithinDuration(t, now, updated.UpdatedAt, time.Millisecond, + "updated_at should equal the @now parameter") // Every field not in the update call must be preserved exactly. require.Equal(t, original.RootChatID, updated.RootChatID, @@ -12257,6 +12290,7 @@ func TestChatDebugStepCOALESCEPreservation(t *testing.T) { Time: now, Valid: true, }, + Now: now, }) require.NoError(t, err) @@ -12264,9 +12298,9 @@ func TestChatDebugStepCOALESCEPreservation(t *testing.T) { require.Equal(t, "completed", updated.Status) require.True(t, updated.FinishedAt.Valid) - // UpdatedAt should advance (set to NOW() unconditionally). - require.True(t, updated.UpdatedAt.After(original.UpdatedAt) || - updated.UpdatedAt.Equal(original.UpdatedAt)) + // UpdatedAt should be set to the @now value we passed in. + require.WithinDuration(t, now, updated.UpdatedAt, time.Millisecond, + "updated_at should equal the @now parameter") // Every field not in the update call must be preserved exactly. require.Equal(t, original.HistoryTipMessageID, updated.HistoryTipMessageID, diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index f5c150198a..8f663acb29 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -2956,9 +2956,9 @@ WITH finalized_runs AS ( UPDATE chat_debug_runs SET status = 'interrupted', - updated_at = NOW(), - finished_at = NOW() - WHERE updated_at < $1::timestamptz + updated_at = $1::timestamptz, + finished_at = $1::timestamptz + WHERE updated_at < $2::timestamptz AND finished_at IS NULL AND status NOT IN ('completed', 'error', 'interrupted') RETURNING id @@ -2966,10 +2966,10 @@ WITH finalized_runs AS ( UPDATE chat_debug_steps SET status = 'interrupted', - updated_at = NOW(), - finished_at = NOW() + updated_at = $1::timestamptz, + finished_at = $1::timestamptz WHERE ( - updated_at < $1::timestamptz + updated_at < $2::timestamptz OR run_id IN (SELECT id FROM finalized_runs) ) AND finished_at IS NULL @@ -2981,22 +2981,30 @@ SELECT (SELECT COUNT(*) FROM finalized_steps)::bigint AS steps_finalized ` +type FinalizeStaleChatDebugRowsParams struct { + Now time.Time `db:"now" json:"now"` + UpdatedBefore time.Time `db:"updated_before" json:"updated_before"` +} + type FinalizeStaleChatDebugRowsRow struct { RunsFinalized int64 `db:"runs_finalized" json:"runs_finalized"` StepsFinalized int64 `db:"steps_finalized" json:"steps_finalized"` } // Marks orphaned in-progress rows as interrupted so they do not stay -// in a non-terminal state forever. The NOT IN list must match the +// in a non-terminal state forever. The NOT IN list must match the // terminal statuses defined by ChatDebugStatus in codersdk/chats.go. // // The steps CTE also catches steps whose parent run was just finalized // (via run_id IN), because PostgreSQL data-modifying CTEs share the -// same snapshot and cannot see each other's row updates. Without this, +// same snapshot and cannot see each other's row updates. Without this, // a step with a recent updated_at would survive its run's finalization // and remain in 'in_progress' state permanently. -func (q *sqlQuerier) FinalizeStaleChatDebugRows(ctx context.Context, updatedBefore time.Time) (FinalizeStaleChatDebugRowsRow, error) { - row := q.db.QueryRowContext(ctx, finalizeStaleChatDebugRows, updatedBefore) +// +// @now is the caller's clock timestamp so that mock-clock tests stay +// consistent with the @updated_before cutoff. +func (q *sqlQuerier) FinalizeStaleChatDebugRows(ctx context.Context, arg FinalizeStaleChatDebugRowsParams) (FinalizeStaleChatDebugRowsRow, error) { + row := q.db.QueryRowContext(ctx, finalizeStaleChatDebugRows, arg.Now, arg.UpdatedBefore) var i FinalizeStaleChatDebugRowsRow err := row.Scan(&i.RunsFinalized, &i.StepsFinalized) return i, err @@ -3225,6 +3233,14 @@ func (q *sqlQuerier) InsertChatDebugRun(ctx context.Context, arg InsertChatDebug } const insertChatDebugStep = `-- name: InsertChatDebugStep :one +WITH locked_run AS ( + UPDATE chat_debug_runs + SET updated_at = COALESCE($14::timestamptz, NOW()) + WHERE id = $1::uuid + AND chat_id = $16::uuid + AND finished_at IS NULL + RETURNING chat_id +) INSERT INTO chat_debug_steps ( run_id, chat_id, @@ -3245,7 +3261,7 @@ INSERT INTO chat_debug_steps ( ) SELECT $1::uuid, - run.chat_id, + locked_run.chat_id, $2::int, $3::text, $4::text, @@ -3260,9 +3276,7 @@ SELECT COALESCE($13::timestamptz, NOW()), COALESCE($14::timestamptz, NOW()), $15::timestamptz -FROM chat_debug_runs run -WHERE run.id = $1::uuid - AND run.chat_id = $16::uuid +FROM locked_run RETURNING id, run_id, chat_id, step_number, operation, status, history_tip_message_id, assistant_message_id, normalized_request, normalized_response, usage, attempts, error, metadata, started_at, updated_at, finished_at ` @@ -3285,6 +3299,12 @@ type InsertChatDebugStepParams struct { ChatID uuid.UUID `db:"chat_id" json:"chat_id"` } +// The CTE atomically locks the parent run via UPDATE, bumps its +// updated_at (eliminating a separate TouchChatDebugRunUpdatedAt +// call), and enforces the finalization guard: if the run is already +// finished, the UPDATE returns zero rows, the INSERT gets no source +// rows, and sql.ErrNoRows is returned. The UPDATE also serializes +// with concurrent FinalizeStale under READ COMMITTED isolation. func (q *sqlQuerier) InsertChatDebugStep(ctx context.Context, arg InsertChatDebugStepParams) (ChatDebugStep, error) { row := q.db.QueryRowContext(ctx, insertChatDebugStep, arg.RunID, @@ -3327,6 +3347,80 @@ func (q *sqlQuerier) InsertChatDebugStep(ctx context.Context, arg InsertChatDebu return i, err } +const touchChatDebugRunUpdatedAt = `-- name: TouchChatDebugRunUpdatedAt :exec +UPDATE chat_debug_runs +SET updated_at = $1::timestamptz +WHERE id = $2::uuid + AND chat_id = $3::uuid +` + +type TouchChatDebugRunUpdatedAtParams struct { + Now time.Time `db:"now" json:"now"` + ID uuid.UUID `db:"id" json:"id"` + ChatID uuid.UUID `db:"chat_id" json:"chat_id"` +} + +// Overrides updated_at on the parent run without touching any +// other column. Used by tests that need to stamp a run with a +// specific timestamp after the InsertChatDebugStep CTE has +// already bumped it to NOW(), so stale-row finalization paths +// can be exercised deterministically. The chatdebug service +// itself does not call this: heartbeats go through +// TouchChatDebugStepAndRun, and step creation updates the parent +// run via the InsertChatDebugStep CTE. +func (q *sqlQuerier) TouchChatDebugRunUpdatedAt(ctx context.Context, arg TouchChatDebugRunUpdatedAtParams) error { + _, err := q.db.ExecContext(ctx, touchChatDebugRunUpdatedAt, arg.Now, arg.ID, arg.ChatID) + return err +} + +const touchChatDebugStepAndRun = `-- name: TouchChatDebugStepAndRun :exec +WITH touched_run AS ( + UPDATE chat_debug_runs + SET updated_at = $1::timestamptz + WHERE id = $3::uuid + AND chat_id = $4::uuid + RETURNING id, chat_id +) +UPDATE chat_debug_steps +SET updated_at = $1::timestamptz +FROM touched_run +WHERE chat_debug_steps.id = $2::uuid + AND chat_debug_steps.run_id = touched_run.id + AND chat_debug_steps.chat_id = touched_run.chat_id +` + +type TouchChatDebugStepAndRunParams struct { + Now time.Time `db:"now" json:"now"` + StepID uuid.UUID `db:"step_id" json:"step_id"` + RunID uuid.UUID `db:"run_id" json:"run_id"` + ChatID uuid.UUID `db:"chat_id" json:"chat_id"` +} + +// Atomically bumps updated_at on both the step and its parent run +// in a single statement. This prevents FinalizeStale from +// interleaving between the two touches and finalizing a run whose +// step heartbeat was just written. +// +// The step UPDATE joins through touched_run (via FROM) and reads +// its RETURNING rows. Per the PostgreSQL WITH semantics, RETURNING +// is the only way to communicate values between a data-modifying +// CTE and the main query, and consuming those rows forces the run +// UPDATE to complete before the step UPDATE. That matches the +// lock order used by FinalizeStaleChatDebugRows and avoids a +// deadlock between concurrent heartbeats and stale sweeps. The +// join also constrains the step update to the specified run so a +// mismatched (run_id, step_id) pair cannot silently refresh an +// unrelated step. +func (q *sqlQuerier) TouchChatDebugStepAndRun(ctx context.Context, arg TouchChatDebugStepAndRunParams) error { + _, err := q.db.ExecContext(ctx, touchChatDebugStepAndRun, + arg.Now, + arg.StepID, + arg.RunID, + arg.ChatID, + ) + return err +} + const updateChatDebugRun = `-- name: UpdateChatDebugRun :one UPDATE chat_debug_runs SET @@ -3339,10 +3433,10 @@ SET provider = COALESCE($7::text, provider), model = COALESCE($8::text, model), summary = COALESCE($9::jsonb, summary), - finished_at = COALESCE($10::timestamptz, finished_at), - updated_at = NOW() -WHERE id = $11::uuid - AND chat_id = $12::uuid + finished_at = COALESCE(finished_at, $10::timestamptz), + updated_at = $11::timestamptz +WHERE id = $12::uuid + AND chat_id = $13::uuid RETURNING id, chat_id, root_chat_id, parent_chat_id, model_config_id, trigger_message_id, history_tip_message_id, kind, status, provider, model, summary, started_at, updated_at, finished_at ` @@ -3357,14 +3451,24 @@ type UpdateChatDebugRunParams struct { Model sql.NullString `db:"model" json:"model"` Summary pqtype.NullRawMessage `db:"summary" json:"summary"` FinishedAt sql.NullTime `db:"finished_at" json:"finished_at"` + Now time.Time `db:"now" json:"now"` ID uuid.UUID `db:"id" json:"id"` ChatID uuid.UUID `db:"chat_id" json:"chat_id"` } // Uses COALESCE so that passing NULL from Go means "keep the -// existing value." This is intentional: debug rows follow a +// existing value." This is intentional: debug rows follow a // write-once-finalize pattern where fields are set at creation -// or finalization and never cleared back to NULL. +// or finalization and never cleared back to NULL. The @now +// parameter keeps updated_at under the caller's clock. +// +// finished_at is enforced as write-once at the SQL level: once +// populated it cannot be overwritten by a later call. Callers +// that issue a summary or status refresh after the run has +// already finalized therefore cannot corrupt the original +// completion timestamp, which keeps duration and ordering +// calculations stable regardless of how many times the row is +// updated. func (q *sqlQuerier) UpdateChatDebugRun(ctx context.Context, arg UpdateChatDebugRunParams) (ChatDebugRun, error) { row := q.db.QueryRowContext(ctx, updateChatDebugRun, arg.RootChatID, @@ -3377,6 +3481,7 @@ func (q *sqlQuerier) UpdateChatDebugRun(ctx context.Context, arg UpdateChatDebug arg.Model, arg.Summary, arg.FinishedAt, + arg.Now, arg.ID, arg.ChatID, ) @@ -3414,9 +3519,9 @@ SET error = COALESCE($8::jsonb, error), metadata = COALESCE($9::jsonb, metadata), finished_at = COALESCE($10::timestamptz, finished_at), - updated_at = NOW() -WHERE id = $11::uuid - AND chat_id = $12::uuid + updated_at = $11::timestamptz +WHERE id = $12::uuid + AND chat_id = $13::uuid RETURNING id, run_id, chat_id, step_number, operation, status, history_tip_message_id, assistant_message_id, normalized_request, normalized_response, usage, attempts, error, metadata, started_at, updated_at, finished_at ` @@ -3431,14 +3536,17 @@ type UpdateChatDebugStepParams struct { Error pqtype.NullRawMessage `db:"error" json:"error"` Metadata pqtype.NullRawMessage `db:"metadata" json:"metadata"` FinishedAt sql.NullTime `db:"finished_at" json:"finished_at"` + Now time.Time `db:"now" json:"now"` ID uuid.UUID `db:"id" json:"id"` ChatID uuid.UUID `db:"chat_id" json:"chat_id"` } // Uses COALESCE so that passing NULL from Go means "keep the -// existing value." This is intentional: debug rows follow a +// existing value." This is intentional: debug rows follow a // write-once-finalize pattern where fields are set at creation -// or finalization and never cleared back to NULL. +// or finalization and never cleared back to NULL. The @now +// parameter keeps updated_at under the caller's clock, matching +// the injectable quartz.Clock used by FinalizeStale sweeps. func (q *sqlQuerier) UpdateChatDebugStep(ctx context.Context, arg UpdateChatDebugStepParams) (ChatDebugStep, error) { row := q.db.QueryRowContext(ctx, updateChatDebugStep, arg.Status, @@ -3451,6 +3559,7 @@ func (q *sqlQuerier) UpdateChatDebugStep(ctx context.Context, arg UpdateChatDebu arg.Error, arg.Metadata, arg.FinishedAt, + arg.Now, arg.ID, arg.ChatID, ) diff --git a/coderd/database/queries/chatdebug.sql b/coderd/database/queries/chatdebug.sql index 1fea29fa0f..a4cef61904 100644 --- a/coderd/database/queries/chatdebug.sql +++ b/coderd/database/queries/chatdebug.sql @@ -35,9 +35,18 @@ RETURNING *; -- name: UpdateChatDebugRun :one -- Uses COALESCE so that passing NULL from Go means "keep the --- existing value." This is intentional: debug rows follow a +-- existing value." This is intentional: debug rows follow a -- write-once-finalize pattern where fields are set at creation --- or finalization and never cleared back to NULL. +-- or finalization and never cleared back to NULL. The @now +-- parameter keeps updated_at under the caller's clock. +-- +-- finished_at is enforced as write-once at the SQL level: once +-- populated it cannot be overwritten by a later call. Callers +-- that issue a summary or status refresh after the run has +-- already finalized therefore cannot corrupt the original +-- completion timestamp, which keeps duration and ordering +-- calculations stable regardless of how many times the row is +-- updated. UPDATE chat_debug_runs SET root_chat_id = COALESCE(sqlc.narg('root_chat_id')::uuid, root_chat_id), @@ -49,13 +58,27 @@ SET provider = COALESCE(sqlc.narg('provider')::text, provider), model = COALESCE(sqlc.narg('model')::text, model), summary = COALESCE(sqlc.narg('summary')::jsonb, summary), - finished_at = COALESCE(sqlc.narg('finished_at')::timestamptz, finished_at), - updated_at = NOW() + finished_at = COALESCE(finished_at, sqlc.narg('finished_at')::timestamptz), + updated_at = @now::timestamptz WHERE id = @id::uuid AND chat_id = @chat_id::uuid RETURNING *; -- name: InsertChatDebugStep :one +-- The CTE atomically locks the parent run via UPDATE, bumps its +-- updated_at (eliminating a separate TouchChatDebugRunUpdatedAt +-- call), and enforces the finalization guard: if the run is already +-- finished, the UPDATE returns zero rows, the INSERT gets no source +-- rows, and sql.ErrNoRows is returned. The UPDATE also serializes +-- with concurrent FinalizeStale under READ COMMITTED isolation. +WITH locked_run AS ( + UPDATE chat_debug_runs + SET updated_at = COALESCE(sqlc.narg('updated_at')::timestamptz, NOW()) + WHERE id = @run_id::uuid + AND chat_id = @chat_id::uuid + AND finished_at IS NULL + RETURNING chat_id +) INSERT INTO chat_debug_steps ( run_id, chat_id, @@ -76,7 +99,7 @@ INSERT INTO chat_debug_steps ( ) SELECT @run_id::uuid, - run.chat_id, + locked_run.chat_id, @step_number::int, @operation::text, @status::text, @@ -91,16 +114,16 @@ SELECT COALESCE(sqlc.narg('started_at')::timestamptz, NOW()), COALESCE(sqlc.narg('updated_at')::timestamptz, NOW()), sqlc.narg('finished_at')::timestamptz -FROM chat_debug_runs run -WHERE run.id = @run_id::uuid - AND run.chat_id = @chat_id::uuid +FROM locked_run RETURNING *; -- name: UpdateChatDebugStep :one -- Uses COALESCE so that passing NULL from Go means "keep the --- existing value." This is intentional: debug rows follow a +-- existing value." This is intentional: debug rows follow a -- write-once-finalize pattern where fields are set at creation --- or finalization and never cleared back to NULL. +-- or finalization and never cleared back to NULL. The @now +-- parameter keeps updated_at under the caller's clock, matching +-- the injectable quartz.Clock used by FinalizeStale sweeps. UPDATE chat_debug_steps SET status = COALESCE(sqlc.narg('status')::text, status), @@ -113,11 +136,55 @@ SET error = COALESCE(sqlc.narg('error')::jsonb, error), metadata = COALESCE(sqlc.narg('metadata')::jsonb, metadata), finished_at = COALESCE(sqlc.narg('finished_at')::timestamptz, finished_at), - updated_at = NOW() + updated_at = @now::timestamptz WHERE id = @id::uuid AND chat_id = @chat_id::uuid RETURNING *; +-- name: TouchChatDebugRunUpdatedAt :exec +-- Overrides updated_at on the parent run without touching any +-- other column. Used by tests that need to stamp a run with a +-- specific timestamp after the InsertChatDebugStep CTE has +-- already bumped it to NOW(), so stale-row finalization paths +-- can be exercised deterministically. The chatdebug service +-- itself does not call this: heartbeats go through +-- TouchChatDebugStepAndRun, and step creation updates the parent +-- run via the InsertChatDebugStep CTE. +UPDATE chat_debug_runs +SET updated_at = @now::timestamptz +WHERE id = @id::uuid + AND chat_id = @chat_id::uuid; + +-- name: TouchChatDebugStepAndRun :exec +-- Atomically bumps updated_at on both the step and its parent run +-- in a single statement. This prevents FinalizeStale from +-- interleaving between the two touches and finalizing a run whose +-- step heartbeat was just written. +-- +-- The step UPDATE joins through touched_run (via FROM) and reads +-- its RETURNING rows. Per the PostgreSQL WITH semantics, RETURNING +-- is the only way to communicate values between a data-modifying +-- CTE and the main query, and consuming those rows forces the run +-- UPDATE to complete before the step UPDATE. That matches the +-- lock order used by FinalizeStaleChatDebugRows and avoids a +-- deadlock between concurrent heartbeats and stale sweeps. The +-- join also constrains the step update to the specified run so a +-- mismatched (run_id, step_id) pair cannot silently refresh an +-- unrelated step. +WITH touched_run AS ( + UPDATE chat_debug_runs + SET updated_at = @now::timestamptz + WHERE id = @run_id::uuid + AND chat_id = @chat_id::uuid + RETURNING id, chat_id +) +UPDATE chat_debug_steps +SET updated_at = @now::timestamptz +FROM touched_run +WHERE chat_debug_steps.id = @step_id::uuid + AND chat_debug_steps.run_id = touched_run.id + AND chat_debug_steps.chat_id = touched_run.chat_id; + -- name: GetChatDebugRunsByChatID :many -- Returns the most recent debug runs for a chat, ordered newest-first. -- Callers must supply an explicit limit to avoid unbounded result sets. @@ -168,20 +235,23 @@ WHERE chat_id = @chat_id::uuid -- name: FinalizeStaleChatDebugRows :one -- Marks orphaned in-progress rows as interrupted so they do not stay --- in a non-terminal state forever. The NOT IN list must match the +-- in a non-terminal state forever. The NOT IN list must match the -- terminal statuses defined by ChatDebugStatus in codersdk/chats.go. -- -- The steps CTE also catches steps whose parent run was just finalized -- (via run_id IN), because PostgreSQL data-modifying CTEs share the --- same snapshot and cannot see each other's row updates. Without this, +-- same snapshot and cannot see each other's row updates. Without this, -- a step with a recent updated_at would survive its run's finalization -- and remain in 'in_progress' state permanently. +-- +-- @now is the caller's clock timestamp so that mock-clock tests stay +-- consistent with the @updated_before cutoff. WITH finalized_runs AS ( UPDATE chat_debug_runs SET status = 'interrupted', - updated_at = NOW(), - finished_at = NOW() + updated_at = @now::timestamptz, + finished_at = @now::timestamptz WHERE updated_at < @updated_before::timestamptz AND finished_at IS NULL AND status NOT IN ('completed', 'error', 'interrupted') @@ -190,8 +260,8 @@ WITH finalized_runs AS ( UPDATE chat_debug_steps SET status = 'interrupted', - updated_at = NOW(), - finished_at = NOW() + updated_at = @now::timestamptz, + finished_at = @now::timestamptz WHERE ( updated_at < @updated_before::timestamptz OR run_id IN (SELECT id FROM finalized_runs) diff --git a/coderd/x/chatd/chatdebug/context_internal_test.go b/coderd/x/chatd/chatdebug/context_internal_test.go index e4d83ae14c..e109ab1749 100644 --- a/coderd/x/chatd/chatdebug/context_internal_test.go +++ b/coderd/x/chatd/chatdebug/context_internal_test.go @@ -32,9 +32,8 @@ func TestContextWithRun_CleansUpStepCounterAfterGC(t *testing.T) { t.Cleanup(func() { CleanupStepCounter(runID) }) func() { - ctx := ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: chatID}) - handle, _ := beginStep(ctx, &Service{}, RecorderOptions{ChatID: chatID}, OperationGenerate, nil) - require.NotNil(t, handle) + _ = ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: chatID}) + require.Equal(t, int32(1), nextStepNumber(runID)) _, ok := stepCounters.Load(runID) require.True(t, ok) }() @@ -56,17 +55,15 @@ func TestContextWithRun_MultipleInstancesSameRunID(t *testing.T) { // rc2 is the surviving instance that should keep the step counter alive. rc2 := &RunContext{RunID: runID, ChatID: chatID} - ctx2 := ContextWithRun(context.Background(), rc2) + _ = ContextWithRun(context.Background(), rc2) // Create a second RunContext with the same RunID and let it become // unreachable. Its GC cleanup must NOT delete the step counter // because rc2 is still alive. func() { rc1 := &RunContext{RunID: runID, ChatID: chatID} - ctx1 := ContextWithRun(context.Background(), rc1) - h, _ := beginStep(ctx1, &Service{}, RecorderOptions{ChatID: chatID}, OperationGenerate, nil) - require.NotNil(t, h) - require.Equal(t, int32(1), h.stepCtx.StepNumber) + _ = ContextWithRun(context.Background(), rc1) + require.Equal(t, int32(1), nextStepNumber(runID)) }() // Force GC to collect rc1. @@ -80,9 +77,11 @@ func TestContextWithRun_MultipleInstancesSameRunID(t *testing.T) { require.True(t, ok, "step counter was prematurely cleaned up while another RunContext is still alive") // Subsequent steps on the surviving context must continue numbering. - h2, _ := beginStep(ctx2, &Service{}, RecorderOptions{ChatID: chatID}, OperationGenerate, nil) - require.NotNil(t, h2) - require.Equal(t, int32(2), h2.stepCtx.StepNumber) + require.Equal(t, int32(2), nextStepNumber(runID)) + + // Keep rc2 alive past the GC cycles above so the runtime cleanup + // finalizer does not fire prematurely. + runtime.KeepAlive(rc2) } func TestContextWithRun_CleansUpStepCounterOnGCAfterCancel(t *testing.T) { @@ -96,11 +95,9 @@ func TestContextWithRun_CleansUpStepCounterOnGCAfterCancel(t *testing.T) { // context cancellation, allowing GC to trigger the cleanup. func() { ctx, cancel := context.WithCancel(context.Background()) - ctx = ContextWithRun(ctx, &RunContext{RunID: runID, ChatID: chatID}) + ContextWithRun(ctx, &RunContext{RunID: runID, ChatID: chatID}) - handle, _ := beginStep(ctx, &Service{}, RecorderOptions{ChatID: chatID}, OperationGenerate, nil) - require.NotNil(t, handle) - require.Equal(t, int32(1), handle.stepCtx.StepNumber) + require.Equal(t, int32(1), nextStepNumber(runID)) _, ok := stepCounters.Load(runID) require.True(t, ok) @@ -117,8 +114,5 @@ func TestContextWithRun_CleansUpStepCounterOnGCAfterCancel(t *testing.T) { return !ok }, testutil.WaitShort, testutil.IntervalFast) - freshCtx := ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: chatID}) - freshHandle, _ := beginStep(freshCtx, &Service{}, RecorderOptions{ChatID: chatID}, OperationGenerate, nil) - require.NotNil(t, freshHandle) - require.Equal(t, int32(1), freshHandle.stepCtx.StepNumber) + require.Equal(t, int32(1), nextStepNumber(runID)) } diff --git a/coderd/x/chatd/chatdebug/model.go b/coderd/x/chatd/chatdebug/model.go index 2a33b43d88..0ac7326080 100644 --- a/coderd/x/chatd/chatdebug/model.go +++ b/coderd/x/chatd/chatdebug/model.go @@ -12,7 +12,11 @@ import ( "unicode/utf8" "charm.land/fantasy" + "github.com/google/uuid" "golang.org/x/xerrors" + + "cdr.dev/slog/v3" + stringutil "github.com/coder/coder/v2/coderd/util/strings" ) type debugModel struct { @@ -212,7 +216,13 @@ func (d *debugModel) Generate( return d.inner.Generate(ctx, call) } + // Keep the step alive during the blocking provider call so the + // stale finalizer does not mark it as interrupted. + heartbeatDone := make(chan struct{}) + launchHeartbeat(ctx, handle.svc, handle.stepCtx.StepID, handle.stepCtx.RunID, handle.stepCtx.ChatID, heartbeatDone) + resp, err := d.inner.Generate(enrichedCtx, call) + close(heartbeatDone) if err != nil { handle.finish(ctx, stepStatusForError(err), nil, nil, normalizeError(ctx, err), nil) return nil, err @@ -275,7 +285,13 @@ func (d *debugModel) GenerateObject( return d.inner.GenerateObject(ctx, call) } + // Keep the step alive during the blocking provider call so the + // stale finalizer does not mark it as interrupted. + heartbeatDone := make(chan struct{}) + launchHeartbeat(ctx, handle.svc, handle.stepCtx.StepID, handle.stepCtx.RunID, handle.stepCtx.ChatID, heartbeatDone) + resp, err := d.inner.GenerateObject(enrichedCtx, call) + close(heartbeatDone) if err != nil { handle.finish(ctx, stepStatusForError(err), nil, nil, normalizeError(ctx, err), map[string]any{"structured_output": true}) @@ -334,6 +350,53 @@ func (d *debugModel) Model() string { return d.inner.Model() } +// launchHeartbeat starts a goroutine that periodically calls TouchStep +// to keep the step and run rows alive during long-running streams. The +// goroutine also listens on the service's threshold-change channel so +// that a runtime SetStaleAfter call immediately resets the ticker +// instead of waiting for the old (possibly longer) period to elapse. +// The goroutine exits when done is closed or ctx is canceled. +func launchHeartbeat(ctx context.Context, svc *Service, stepID, runID, chatID uuid.UUID, done <-chan struct{}) { + if svc == nil { + return + } + go func() { + interval := svc.heartbeatInterval() + ticker := svc.clock.NewTicker(interval, "chatdebug", "heartbeat") + defer ticker.Stop() + thresholdCh := svc.thresholdChan() + for { + select { + case <-ctx.Done(): + return + case <-done: + return + case <-thresholdCh: + // SetStaleAfter was called; re-read the interval + // and reset the ticker immediately. + thresholdCh = svc.thresholdChan() + if newInterval := svc.heartbeatInterval(); newInterval != interval { + interval = newInterval + ticker.Reset(interval, "chatdebug", "heartbeat") + } + case <-ticker.C: + if err := svc.TouchStep(ctx, stepID, runID, chatID); err != nil { + svc.log.Debug(ctx, "heartbeat touch failed", + slog.Error(err), + slog.F("step_id", stepID), + ) + } + // Also re-read interval on every tick as a + // secondary check. + if newInterval := svc.heartbeatInterval(); newInterval != interval { + interval = newInterval + ticker.Reset(interval, "chatdebug", "heartbeat") + } + } + } + }() +} + func wrapStreamSeq( ctx context.Context, handle *stepHandle, @@ -354,6 +417,10 @@ func wrapStreamSeq( streamComplete atomic.Bool ) + // heartbeatDone is closed when the stream finalizes (either + // normally or via the safety net) to stop the heartbeat goroutine. + heartbeatDone := make(chan struct{}) + // Safety net: if the caller drops the returned iterator without // consuming it (or abandons mid-stream and the context is // canceled), finalize the step so it does not remain permanently @@ -368,10 +435,20 @@ func wrapStreamSeq( return } finalized = true + close(heartbeatDone) handle.finish(ctx, StatusInterrupted, nil, nil, nil, nil) }) + // startHeartbeat launches the heartbeat goroutine on first call. + // Deferring the start until the caller begins consuming the stream + // prevents leaked goroutines when the iterator is dropped without + // being iterated. + startHeartbeat := sync.OnceFunc(func() { + launchHeartbeat(ctx, handle.svc, handle.stepCtx.StepID, handle.stepCtx.RunID, handle.stepCtx.ChatID, heartbeatDone) + }) + return func(yield func(fantasy.StreamPart) bool) { + startHeartbeat() var ( summary streamSummary latestUsage fantasy.Usage @@ -386,7 +463,7 @@ func wrapStreamSeq( ) finalize := func(status Status) { - // Cancel the safety net since we're finalizing normally. + // Cancel the safety net and heartbeat since we're finalizing. if stop != nil { stop() } @@ -396,6 +473,7 @@ func wrapStreamSeq( return } finalized = true + close(heartbeatDone) summary.FinishReason = string(finishReason) @@ -506,6 +584,9 @@ func wrapObjectStreamSeq( finalized bool streamComplete atomic.Bool ) + + heartbeatDone := make(chan struct{}) + stop := context.AfterFunc(ctx, func() { mu.Lock() defer mu.Unlock() @@ -513,10 +594,18 @@ func wrapObjectStreamSeq( return } finalized = true + close(heartbeatDone) handle.finish(ctx, StatusInterrupted, nil, nil, nil, nil) }) + // Deferred heartbeat: start the heartbeat goroutine only when the + // caller begins consuming the stream. + startHeartbeat := sync.OnceFunc(func() { + launchHeartbeat(ctx, handle.svc, handle.stepCtx.StepID, handle.stepCtx.RunID, handle.stepCtx.ChatID, heartbeatDone) + }) + return func(yield func(fantasy.ObjectStreamPart) bool) { + startHeartbeat() var ( summary = objectStreamSummary{StructuredOutput: true} latestUsage fantasy.Usage @@ -539,6 +628,7 @@ func wrapObjectStreamSeq( return } finalized = true + close(heartbeatDone) summary.FinishReason = string(finishReason) @@ -643,7 +733,7 @@ func normalizeMessages(prompt fantasy.Prompt) []normalizedMessage { // boundText truncates s to MaxMessagePartTextLength runes, appending // an ellipsis if truncation occurs. func boundText(s string) string { - return truncateRunes(s, MaxMessagePartTextLength) + return stringutil.Truncate(s, MaxMessagePartTextLength, stringutil.TruncateWithEllipsis) } // safeMarshalJSON marshals value to JSON. On failure it returns a diff --git a/coderd/x/chatd/chatdebug/model_internal_test.go b/coderd/x/chatd/chatdebug/model_internal_test.go index f40204bb82..a3386a7058 100644 --- a/coderd/x/chatd/chatdebug/model_internal_test.go +++ b/coderd/x/chatd/chatdebug/model_internal_test.go @@ -2,11 +2,13 @@ package chatdebug import ( "context" + "encoding/json" "io" "net/http" "net/http/httptest" "strings" "testing" + "time" "charm.land/fantasy" "github.com/google/uuid" @@ -14,15 +16,131 @@ import ( "go.uber.org/mock/gomock" "golang.org/x/xerrors" + "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/dbmock" "github.com/coder/coder/v2/coderd/x/chatd/chattest" "github.com/coder/coder/v2/testutil" + "github.com/coder/quartz" ) type testError struct{ message string } func (e *testError) Error() string { return e.message } +func expectDebugLoggingEnabled( + t *testing.T, + db *dbmock.MockStore, + ownerID uuid.UUID, +) { + t.Helper() + + db.EXPECT().GetChatDebugLoggingAllowUsers(gomock.Any()).Return(true, nil) + db.EXPECT().GetUserChatDebugLoggingEnabled(gomock.Any(), ownerID).Return(true, nil) +} + +func expectCreateStepNumberWithRequestValidity( + t *testing.T, + db *dbmock.MockStore, + runID uuid.UUID, + chatID uuid.UUID, + stepNumber int32, + op Operation, + normalizedRequestValid bool, +) uuid.UUID { + t.Helper() + + stepID := uuid.New() + + db.EXPECT(). + InsertChatDebugStep(gomock.Any(), gomock.AssignableToTypeOf(database.InsertChatDebugStepParams{})). + DoAndReturn(func(_ context.Context, params database.InsertChatDebugStepParams) (database.ChatDebugStep, error) { + require.Equal(t, runID, params.RunID) + require.Equal(t, chatID, params.ChatID) + require.Equal(t, stepNumber, params.StepNumber) + require.Equal(t, string(op), params.Operation) + require.Equal(t, string(StatusInProgress), params.Status) + require.Equal(t, normalizedRequestValid, params.NormalizedRequest.Valid) + + return database.ChatDebugStep{ + ID: stepID, + RunID: runID, + ChatID: chatID, + StepNumber: params.StepNumber, + Operation: params.Operation, + Status: params.Status, + }, nil + }) + + // The INSERT CTE atomically bumps the parent run's updated_at, + // so no separate TouchChatDebugRunUpdatedAt call is needed. + + return stepID +} + +func expectCreateStepNumber( + t *testing.T, + db *dbmock.MockStore, + runID uuid.UUID, + chatID uuid.UUID, + stepNumber int32, + op Operation, +) uuid.UUID { + t.Helper() + + return expectCreateStepNumberWithRequestValidity( + t, + db, + runID, + chatID, + stepNumber, + op, + true, + ) +} + +func expectCreateStep( + t *testing.T, + db *dbmock.MockStore, + runID uuid.UUID, + chatID uuid.UUID, + op Operation, +) uuid.UUID { + t.Helper() + + return expectCreateStepNumber(t, db, runID, chatID, 1, op) +} + +func expectUpdateStep( + t *testing.T, + db *dbmock.MockStore, + stepID uuid.UUID, + chatID uuid.UUID, + status Status, + assertFn func(database.UpdateChatDebugStepParams), +) { + t.Helper() + + db.EXPECT(). + UpdateChatDebugStep(gomock.Any(), gomock.AssignableToTypeOf(database.UpdateChatDebugStepParams{})). + DoAndReturn(func(_ context.Context, params database.UpdateChatDebugStepParams) (database.ChatDebugStep, error) { + require.Equal(t, stepID, params.ID) + require.Equal(t, chatID, params.ChatID) + require.True(t, params.Status.Valid) + require.Equal(t, string(status), params.Status.String) + require.True(t, params.FinishedAt.Valid) + + if assertFn != nil { + assertFn(params) + } + + return database.ChatDebugStep{ + ID: stepID, + ChatID: chatID, + Status: params.Status.String, + }, nil + }) +} + func TestDebugModel_Provider(t *testing.T) { t.Parallel() @@ -48,6 +166,7 @@ func TestDebugModel_Disabled(t *testing.T) { db := dbmock.NewMockStore(ctrl) chatID := uuid.New() ownerID := uuid.New() + svc := NewService(db, testutil.Logger(t), nil) respWant := &fantasy.Response{FinishReason: fantasy.FinishReasonStop} inner := &chattest.FakeModel{ @@ -97,6 +216,30 @@ func TestDebugModel_Generate(t *testing.T) { Warnings: []fantasy.CallWarning{{Message: "warning"}}, } + expectDebugLoggingEnabled(t, db, ownerID) + stepID := expectCreateStep(t, db, runID, chatID, OperationGenerate) + expectUpdateStep(t, db, stepID, chatID, StatusCompleted, func(params database.UpdateChatDebugStepParams) { + require.True(t, params.NormalizedResponse.Valid) + require.True(t, params.Usage.Valid) + require.True(t, params.Attempts.Valid) + // Clean successes (no prior error) leave the error column + // as SQL NULL rather than sending jsonClear. + require.False(t, params.Error.Valid) + require.False(t, params.Metadata.Valid) + + // Verify actual JSON content so a broken tag or field + // rename is caught rather than only checking .Valid. + var usage fantasy.Usage + require.NoError(t, json.Unmarshal(params.Usage.RawMessage, &usage)) + require.EqualValues(t, 10, usage.InputTokens) + require.EqualValues(t, 4, usage.OutputTokens) + require.EqualValues(t, 14, usage.TotalTokens) + + var resp map[string]any + require.NoError(t, json.Unmarshal(params.NormalizedResponse.RawMessage, &resp)) + require.Equal(t, "stop", resp["finish_reason"]) + }) + svc := NewService(db, testutil.Logger(t), nil) inner := &chattest.FakeModel{ GenerateFn: func(ctx context.Context, got fantasy.Call) (*fantasy.Response, error) { @@ -149,6 +292,20 @@ func TestDebugModel_GeneratePersistsAttemptsWithoutResponseClose(t *testing.T) { })) defer server.Close() + expectDebugLoggingEnabled(t, db, ownerID) + stepID := expectCreateStep(t, db, runID, chatID, OperationGenerate) + expectUpdateStep(t, db, stepID, chatID, StatusCompleted, func(params database.UpdateChatDebugStepParams) { + require.True(t, params.Attempts.Valid) + require.True(t, params.NormalizedResponse.Valid) + require.True(t, params.Usage.Valid) + + var attempts []Attempt + require.NoError(t, json.Unmarshal(params.Attempts.RawMessage, &attempts)) + require.Len(t, attempts, 1) + require.Equal(t, attemptStatusCompleted, attempts[0].Status) + require.Equal(t, http.StatusCreated, attempts[0].ResponseStatus) + }) + svc := NewService(db, testutil.Logger(t), nil) inner := &chattest.FakeModel{ GenerateFn: func(ctx context.Context, call fantasy.Call) (*fantasy.Response, error) { @@ -197,6 +354,21 @@ func TestDebugModel_GenerateError(t *testing.T) { runID := uuid.New() wantErr := &testError{message: "boom"} + expectDebugLoggingEnabled(t, db, ownerID) + stepID := expectCreateStep(t, db, runID, chatID, OperationGenerate) + expectUpdateStep(t, db, stepID, chatID, StatusError, func(params database.UpdateChatDebugStepParams) { + require.False(t, params.NormalizedResponse.Valid) + require.False(t, params.Usage.Valid) + require.True(t, params.Attempts.Valid) + require.True(t, params.Error.Valid) + require.False(t, params.Metadata.Valid) + + var errPayload normalizedErrorPayload + require.NoError(t, json.Unmarshal(params.Error.RawMessage, &errPayload)) + require.Equal(t, "boom", errPayload.Message) + require.Equal(t, "*chatdebug.testError", errPayload.Type) + }) + svc := NewService(db, testutil.Logger(t), nil) model := &debugModel{ inner: &chattest.FakeModel{ @@ -215,6 +387,77 @@ func TestDebugModel_GenerateError(t *testing.T) { require.ErrorIs(t, err, wantErr) } +// TestDebugModel_GenerateRetryClearsError verifies that when a Generate +// call fails and is retried on the same reused step, a successful retry +// explicitly overwrites the stored error payload with JSONB null via +// the jsonClear sentinel. Without this, COALESCE would preserve the +// stale error and AggregateRunSummary would flag the run as errored. +func TestDebugModel_GenerateRetryClearsError(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + db := dbmock.NewMockStore(ctrl) + chatID := uuid.New() + ownerID := uuid.New() + runID := uuid.New() + wantErr := &testError{message: "transient"} + + // Allow enablement check twice, once per Generate call. + db.EXPECT().GetChatDebugLoggingAllowUsers(gomock.Any()).Return(true, nil).Times(2) + db.EXPECT().GetUserChatDebugLoggingEnabled(gomock.Any(), ownerID).Return(true, nil).Times(2) + stepID := expectCreateStep(t, db, runID, chatID, OperationGenerate) + + // First finalization: error. + expectUpdateStep(t, db, stepID, chatID, StatusError, func(params database.UpdateChatDebugStepParams) { + require.True(t, params.Error.Valid, "error payload must be present on first (failed) finalization") + require.NotEqual(t, json.RawMessage("null"), params.Error.RawMessage, + "first finalization should carry the real error, not JSONB null") + }) + + // Second finalization: success with explicit error clear. + expectUpdateStep(t, db, stepID, chatID, StatusCompleted, func(params database.UpdateChatDebugStepParams) { + require.True(t, params.Error.Valid, + "error field must be Valid (JSONB null) so COALESCE overwrites the previous error") + require.JSONEq(t, "null", string(params.Error.RawMessage), + "successful retry must send JSONB null to clear the stale error") + require.True(t, params.NormalizedResponse.Valid) + require.True(t, params.Usage.Valid) + }) + + callCount := 0 + svc := NewService(db, testutil.Logger(t), nil) + model := &debugModel{ + inner: &chattest.FakeModel{ + GenerateFn: func(_ context.Context, _ fantasy.Call) (*fantasy.Response, error) { + callCount++ + if callCount == 1 { + return nil, wantErr + } + return &fantasy.Response{ + FinishReason: fantasy.FinishReasonStop, + Usage: fantasy.Usage{InputTokens: 5, OutputTokens: 2}, + }, nil + }, + }, + svc: svc, + opts: RecorderOptions{ChatID: chatID, OwnerID: ownerID}, + } + t.Cleanup(func() { CleanupStepCounter(runID) }) + + ctx := ReuseStep(ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: chatID})) + + // First call: fails. + resp, err := model.Generate(ctx, fantasy.Call{}) + require.Nil(t, resp) + require.ErrorIs(t, err, wantErr) + + // Second call: succeeds, reuses the same step and clears the error. + resp, err = model.Generate(ctx, fantasy.Call{}) + require.NoError(t, err) + require.NotNil(t, resp) + require.Equal(t, 2, callCount) +} + func TestStepStatusForError(t *testing.T) { t.Parallel() @@ -252,6 +495,44 @@ func TestDebugModel_Stream(t *testing.T) { {Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop, Usage: fantasy.Usage{InputTokens: 8, OutputTokens: 3, TotalTokens: 11}}, } + expectDebugLoggingEnabled(t, db, ownerID) + stepID := expectCreateStep(t, db, runID, chatID, OperationStream) + expectUpdateStep(t, db, stepID, chatID, StatusError, func(params database.UpdateChatDebugStepParams) { + require.True(t, params.NormalizedResponse.Valid) + require.True(t, params.Usage.Valid) + require.True(t, params.Attempts.Valid) + require.True(t, params.Error.Valid) + require.True(t, params.Metadata.Valid) + + // Verify usage JSON content matches the finish part. + var usage normalizedUsage + require.NoError(t, json.Unmarshal(params.Usage.RawMessage, &usage)) + require.EqualValues(t, 8, usage.InputTokens) + require.EqualValues(t, 3, usage.OutputTokens) + require.EqualValues(t, 11, usage.TotalTokens) + + // Verify the response payload captures the streamed content. + var resp normalizedResponsePayload + require.NoError(t, json.Unmarshal(params.NormalizedResponse.RawMessage, &resp)) + require.Equal(t, "stop", resp.FinishReason) + require.NotEmpty(t, resp.Content, "stream response should capture content parts") + + // Verify error payload comes from the stream error part. + var errPayload normalizedErrorPayload + require.NoError(t, json.Unmarshal(params.Error.RawMessage, &errPayload)) + require.Equal(t, "chunk failed", errPayload.Message) + + // Verify metadata contains stream_summary. + var meta map[string]any + require.NoError(t, json.Unmarshal(params.Metadata.RawMessage, &meta)) + summary, ok := meta["stream_summary"].(map[string]any) + require.True(t, ok, "metadata must contain stream_summary") + require.EqualValues(t, 1, summary["text_delta_count"]) + require.EqualValues(t, 1, summary["tool_call_count"]) + require.EqualValues(t, 1, summary["source_count"]) + require.EqualValues(t, 1, summary["error_count"]) + }) + svc := NewService(db, testutil.Logger(t), nil) model := &debugModel{ inner: &chattest.FakeModel{ @@ -299,6 +580,42 @@ func TestDebugModel_StreamObject(t *testing.T) { {Type: fantasy.ObjectStreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop, Usage: fantasy.Usage{InputTokens: 5, OutputTokens: 2, TotalTokens: 7}}, } + expectDebugLoggingEnabled(t, db, ownerID) + stepID := expectCreateStep(t, db, runID, chatID, OperationStream) + expectUpdateStep(t, db, stepID, chatID, StatusCompleted, func(params database.UpdateChatDebugStepParams) { + require.True(t, params.NormalizedResponse.Valid) + require.True(t, params.Usage.Valid) + require.True(t, params.Attempts.Valid) + // Clean successes (no prior error) leave the error column + // as SQL NULL rather than sending jsonClear. + require.False(t, params.Error.Valid) + require.True(t, params.Metadata.Valid) + + // Verify usage JSON content matches the finish part. + var usage normalizedUsage + require.NoError(t, json.Unmarshal(params.Usage.RawMessage, &usage)) + require.EqualValues(t, 5, usage.InputTokens) + require.EqualValues(t, 2, usage.OutputTokens) + require.EqualValues(t, 7, usage.TotalTokens) + + // Verify the object response payload. + var resp normalizedObjectResponsePayload + require.NoError(t, json.Unmarshal(params.NormalizedResponse.RawMessage, &resp)) + require.Equal(t, "stop", resp.FinishReason) + require.True(t, resp.StructuredOutput) + // "ob" + "ject" = 6 runes. + require.Equal(t, 6, resp.RawTextLength) + + // Verify metadata contains structured_output flag. + var meta map[string]any + require.NoError(t, json.Unmarshal(params.Metadata.RawMessage, &meta)) + require.Equal(t, true, meta["structured_output"]) + summary, ok := meta["stream_summary"].(map[string]any) + require.True(t, ok, "metadata must contain stream_summary") + require.EqualValues(t, 2, summary["text_delta_count"]) + require.EqualValues(t, 1, summary["object_part_count"]) + }) + svc := NewService(db, testutil.Logger(t), nil) model := &debugModel{ inner: &chattest.FakeModel{ @@ -337,29 +654,47 @@ func TestDebugModel_StreamObject(t *testing.T) { func TestDebugModel_StreamCompletedAfterFinish(t *testing.T) { t.Parallel() - handle := &stepHandle{ - stepCtx: &StepContext{StepID: uuid.New(), RunID: uuid.New(), ChatID: uuid.New()}, - sink: &attemptSink{}, - } - + ctrl := gomock.NewController(t) + db := dbmock.NewMockStore(ctrl) + chatID := uuid.New() + ownerID := uuid.New() + runID := uuid.New() parts := []fantasy.StreamPart{ {Type: fantasy.StreamPartTypeTextDelta, Delta: "hello"}, {Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop, Usage: fantasy.Usage{InputTokens: 5, OutputTokens: 1, TotalTokens: 6}}, } - seq := wrapStreamSeq(context.Background(), handle, partsToSeq(parts)) - // Consumer reads through the finish part then breaks. The wrapper - // should finalize as completed, not interrupted. + // The mock expectation for UpdateStep with StatusCompleted is the + // assertion: if the wrapper chose StatusInterrupted instead, the + // mock would reject the call. + expectDebugLoggingEnabled(t, db, ownerID) + stepID := expectCreateStep(t, db, runID, chatID, OperationStream) + expectUpdateStep(t, db, stepID, chatID, StatusCompleted, nil) + + svc := NewService(db, testutil.Logger(t), nil) + model := &debugModel{ + inner: &chattest.FakeModel{ + StreamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) { + return partsToSeq(parts), nil + }, + }, + svc: svc, + opts: RecorderOptions{ChatID: chatID, OwnerID: ownerID}, + } + t.Cleanup(func() { CleanupStepCounter(runID) }) + ctx := ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: chatID}) + + seq, err := model.Stream(ctx, fantasy.Call{}) + require.NoError(t, err) + + // Consumer reads the finish part then breaks. This should still + // be considered a completed stream, not interrupted. for part := range seq { if part.Type == fantasy.StreamPartTypeFinish { break } } - - handle.mu.Lock() - status := handle.status - handle.mu.Unlock() - require.Equal(t, StatusCompleted, status) + // gomock verifies UpdateStep was called with StatusCompleted. } // TestDebugModel_StreamInterruptedBeforeFinish verifies that when a consumer @@ -368,17 +703,38 @@ func TestDebugModel_StreamCompletedAfterFinish(t *testing.T) { func TestDebugModel_StreamInterruptedBeforeFinish(t *testing.T) { t.Parallel() - handle := &stepHandle{ - stepCtx: &StepContext{StepID: uuid.New(), RunID: uuid.New(), ChatID: uuid.New()}, - sink: &attemptSink{}, - } - + ctrl := gomock.NewController(t) + db := dbmock.NewMockStore(ctrl) + chatID := uuid.New() + ownerID := uuid.New() + runID := uuid.New() parts := []fantasy.StreamPart{ {Type: fantasy.StreamPartTypeTextDelta, Delta: "hello"}, {Type: fantasy.StreamPartTypeTextDelta, Delta: " world"}, {Type: fantasy.StreamPartTypeFinish, FinishReason: fantasy.FinishReasonStop}, } - seq := wrapStreamSeq(context.Background(), handle, partsToSeq(parts)) + + // The mock expectation for UpdateStep with StatusInterrupted is the + // assertion: breaking before the finish part means interrupted. + expectDebugLoggingEnabled(t, db, ownerID) + stepID := expectCreateStep(t, db, runID, chatID, OperationStream) + expectUpdateStep(t, db, stepID, chatID, StatusInterrupted, nil) + + svc := NewService(db, testutil.Logger(t), nil) + model := &debugModel{ + inner: &chattest.FakeModel{ + StreamFn: func(_ context.Context, _ fantasy.Call) (fantasy.StreamResponse, error) { + return partsToSeq(parts), nil + }, + }, + svc: svc, + opts: RecorderOptions{ChatID: chatID, OwnerID: ownerID}, + } + t.Cleanup(func() { CleanupStepCounter(runID) }) + ctx := ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: chatID}) + + seq, err := model.Stream(ctx, fantasy.Call{}) + require.NoError(t, err) // Consumer reads the first delta then breaks before finish. count := 0 @@ -389,11 +745,7 @@ func TestDebugModel_StreamInterruptedBeforeFinish(t *testing.T) { } } require.Equal(t, 1, count) - - handle.mu.Lock() - status := handle.status - handle.mu.Unlock() - require.Equal(t, StatusInterrupted, status) + // gomock verifies UpdateStep was called with StatusInterrupted. } func TestDebugModel_StreamRejectsNilSequence(t *testing.T) { @@ -402,7 +754,23 @@ func TestDebugModel_StreamRejectsNilSequence(t *testing.T) { ctrl := gomock.NewController(t) db := dbmock.NewMockStore(ctrl) chatID := uuid.New() + ownerID := uuid.New() runID := uuid.New() + + expectDebugLoggingEnabled(t, db, ownerID) + stepID := expectCreateStep(t, db, runID, chatID, OperationStream) + expectUpdateStep(t, db, stepID, chatID, StatusError, func(params database.UpdateChatDebugStepParams) { + require.False(t, params.NormalizedResponse.Valid) + require.False(t, params.Usage.Valid) + require.True(t, params.Attempts.Valid) + require.True(t, params.Error.Valid) + require.False(t, params.Metadata.Valid) + + var errPayload normalizedErrorPayload + require.NoError(t, json.Unmarshal(params.Error.RawMessage, &errPayload)) + require.Contains(t, errPayload.Message, "nil") + }) + svc := NewService(db, testutil.Logger(t), nil) model := &debugModel{ inner: &chattest.FakeModel{ @@ -412,9 +780,10 @@ func TestDebugModel_StreamRejectsNilSequence(t *testing.T) { }, }, svc: svc, - opts: RecorderOptions{ChatID: chatID, OwnerID: uuid.New()}, + opts: RecorderOptions{ChatID: chatID, OwnerID: ownerID}, } t.Cleanup(func() { CleanupStepCounter(runID) }) + ctx := ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: chatID}) seq, err := model.Stream(ctx, fantasy.Call{}) @@ -428,7 +797,28 @@ func TestDebugModel_StreamObjectRejectsNilSequence(t *testing.T) { ctrl := gomock.NewController(t) db := dbmock.NewMockStore(ctrl) chatID := uuid.New() + ownerID := uuid.New() runID := uuid.New() + + expectDebugLoggingEnabled(t, db, ownerID) + stepID := expectCreateStep(t, db, runID, chatID, OperationStream) + expectUpdateStep(t, db, stepID, chatID, StatusError, func(params database.UpdateChatDebugStepParams) { + require.False(t, params.NormalizedResponse.Valid) + require.False(t, params.Usage.Valid) + require.True(t, params.Attempts.Valid) + require.True(t, params.Error.Valid) + require.True(t, params.Metadata.Valid) + + var errPayload normalizedErrorPayload + require.NoError(t, json.Unmarshal(params.Error.RawMessage, &errPayload)) + require.Contains(t, errPayload.Message, "nil") + + // Object stream always passes structured_output metadata. + var meta map[string]any + require.NoError(t, json.Unmarshal(params.Metadata.RawMessage, &meta)) + require.Equal(t, true, meta["structured_output"]) + }) + svc := NewService(db, testutil.Logger(t), nil) model := &debugModel{ inner: &chattest.FakeModel{ @@ -438,9 +828,10 @@ func TestDebugModel_StreamObjectRejectsNilSequence(t *testing.T) { }, }, svc: svc, - opts: RecorderOptions{ChatID: chatID, OwnerID: uuid.New()}, + opts: RecorderOptions{ChatID: chatID, OwnerID: ownerID}, } t.Cleanup(func() { CleanupStepCounter(runID) }) + ctx := ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: chatID}) seq, err := model.StreamObject(ctx, fantasy.ObjectCall{}) @@ -461,6 +852,32 @@ func TestDebugModel_StreamEarlyStop(t *testing.T) { {Type: fantasy.StreamPartTypeTextDelta, Delta: "second"}, } + expectDebugLoggingEnabled(t, db, ownerID) + stepID := expectCreateStep(t, db, runID, chatID, OperationStream) + expectUpdateStep(t, db, stepID, chatID, StatusInterrupted, func(params database.UpdateChatDebugStepParams) { + require.True(t, params.NormalizedResponse.Valid) + require.False(t, params.Usage.Valid) + require.True(t, params.Attempts.Valid) + require.False(t, params.Error.Valid) + require.True(t, params.Metadata.Valid) + + // Verify that the partial response captures the single + // consumed text delta. + var resp normalizedResponsePayload + require.NoError(t, json.Unmarshal(params.NormalizedResponse.RawMessage, &resp)) + require.NotEmpty(t, resp.Content) + // Finish reason is empty because consumer stopped before + // the finish part. + require.Empty(t, resp.FinishReason) + + // Verify stream_summary reflects partial consumption. + var meta map[string]any + require.NoError(t, json.Unmarshal(params.Metadata.RawMessage, &meta)) + summary, ok := meta["stream_summary"].(map[string]any) + require.True(t, ok, "metadata must contain stream_summary") + require.EqualValues(t, 1, summary["text_delta_count"]) + }) + svc := NewService(db, testutil.Logger(t), nil) model := &debugModel{ inner: &chattest.FakeModel{ @@ -549,6 +966,37 @@ func TestDebugModel_GenerateObject(t *testing.T) { Usage: fantasy.Usage{InputTokens: 5, OutputTokens: 3, TotalTokens: 8}, } + expectDebugLoggingEnabled(t, db, ownerID) + stepID := expectCreateStep(t, db, runID, chatID, OperationGenerate) + expectUpdateStep(t, db, stepID, chatID, StatusCompleted, func(params database.UpdateChatDebugStepParams) { + require.True(t, params.NormalizedResponse.Valid) + require.True(t, params.Usage.Valid) + require.True(t, params.Attempts.Valid) + require.False(t, params.Error.Valid) + // GenerateObject always passes structured_output metadata. + require.True(t, params.Metadata.Valid) + + // Verify usage JSON content. + var usage normalizedUsage + require.NoError(t, json.Unmarshal(params.Usage.RawMessage, &usage)) + require.EqualValues(t, 5, usage.InputTokens) + require.EqualValues(t, 3, usage.OutputTokens) + require.EqualValues(t, 8, usage.TotalTokens) + + // Verify the object response payload. + var resp normalizedObjectResponsePayload + require.NoError(t, json.Unmarshal(params.NormalizedResponse.RawMessage, &resp)) + require.Equal(t, "stop", resp.FinishReason) + require.True(t, resp.StructuredOutput) + // RawText is `{"title":"test"}` = 16 runes. + require.Equal(t, 16, resp.RawTextLength) + + // Verify metadata contains structured_output flag. + var meta map[string]any + require.NoError(t, json.Unmarshal(params.Metadata.RawMessage, &meta)) + require.Equal(t, true, meta["structured_output"]) + }) + svc := NewService(db, testutil.Logger(t), nil) inner := &chattest.FakeModel{ GenerateObjectFn: func(ctx context.Context, got fantasy.ObjectCall) (*fantasy.ObjectResponse, error) { @@ -583,9 +1031,30 @@ func TestDebugModel_GenerateObjectError(t *testing.T) { ctrl := gomock.NewController(t) db := dbmock.NewMockStore(ctrl) chatID := uuid.New() + ownerID := uuid.New() runID := uuid.New() wantErr := &testError{message: "object boom"} + expectDebugLoggingEnabled(t, db, ownerID) + stepID := expectCreateStep(t, db, runID, chatID, OperationGenerate) + expectUpdateStep(t, db, stepID, chatID, StatusError, func(params database.UpdateChatDebugStepParams) { + require.False(t, params.NormalizedResponse.Valid) + require.False(t, params.Usage.Valid) + require.True(t, params.Attempts.Valid) + require.True(t, params.Error.Valid) + // GenerateObject always passes structured_output metadata. + require.True(t, params.Metadata.Valid) + + var errPayload normalizedErrorPayload + require.NoError(t, json.Unmarshal(params.Error.RawMessage, &errPayload)) + require.Equal(t, "object boom", errPayload.Message) + require.Equal(t, "*chatdebug.testError", errPayload.Type) + + var meta map[string]any + require.NoError(t, json.Unmarshal(params.Metadata.RawMessage, &meta)) + require.Equal(t, true, meta["structured_output"]) + }) + svc := NewService(db, testutil.Logger(t), nil) model := &debugModel{ inner: &chattest.FakeModel{ @@ -594,7 +1063,7 @@ func TestDebugModel_GenerateObjectError(t *testing.T) { }, }, svc: svc, - opts: RecorderOptions{ChatID: chatID, OwnerID: uuid.New()}, + opts: RecorderOptions{ChatID: chatID, OwnerID: ownerID}, } t.Cleanup(func() { CleanupStepCounter(runID) }) ctx := ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: chatID}) @@ -610,8 +1079,28 @@ func TestDebugModel_GenerateObjectRejectsNilResponse(t *testing.T) { ctrl := gomock.NewController(t) db := dbmock.NewMockStore(ctrl) chatID := uuid.New() + ownerID := uuid.New() runID := uuid.New() + expectDebugLoggingEnabled(t, db, ownerID) + stepID := expectCreateStep(t, db, runID, chatID, OperationGenerate) + expectUpdateStep(t, db, stepID, chatID, StatusError, func(params database.UpdateChatDebugStepParams) { + require.False(t, params.NormalizedResponse.Valid) + require.False(t, params.Usage.Valid) + require.True(t, params.Attempts.Valid) + require.True(t, params.Error.Valid) + // GenerateObject always passes structured_output metadata. + require.True(t, params.Metadata.Valid) + + var errPayload normalizedErrorPayload + require.NoError(t, json.Unmarshal(params.Error.RawMessage, &errPayload)) + require.Contains(t, errPayload.Message, "nil") + + var meta map[string]any + require.NoError(t, json.Unmarshal(params.Metadata.RawMessage, &meta)) + require.Equal(t, true, meta["structured_output"]) + }) + svc := NewService(db, testutil.Logger(t), nil) model := &debugModel{ inner: &chattest.FakeModel{ @@ -620,7 +1109,7 @@ func TestDebugModel_GenerateObjectRejectsNilResponse(t *testing.T) { }, }, svc: svc, - opts: RecorderOptions{ChatID: chatID, OwnerID: uuid.New()}, + opts: RecorderOptions{ChatID: chatID, OwnerID: ownerID}, } t.Cleanup(func() { CleanupStepCounter(runID) }) ctx := ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: chatID}) @@ -721,3 +1210,168 @@ func TestWrapStreamSeq_DroppedStreamFinalizedOnCtxCancel(t *testing.T) { func int64Ptr(v int64) *int64 { return &v } func float64Ptr(v float64) *float64 { return &v } + +func TestLaunchHeartbeat(t *testing.T) { + t.Parallel() + + t.Run("fires_touch_step_on_tick", func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + db := dbmock.NewMockStore(ctrl) + mClock := quartz.NewMock(t) + + // Use a small stale threshold so the heartbeat interval is + // short enough to test easily (threshold/2 = 5s, clamped ≥1s). + svc := NewService(db, testutil.Logger(t), nil, + WithClock(mClock), + WithStaleThreshold(10*time.Second), + ) + + stepID := uuid.New() + runID := uuid.New() + chatID := uuid.New() + + done := make(chan struct{}) + defer close(done) + + // Trap the ticker creation so we can control it. + tickerTrap := mClock.Trap().NewTicker("chatdebug", "heartbeat") + defer tickerTrap.Close() + + ctx := testutil.Context(t, testutil.WaitShort) + + // Expect atomic TouchStep calls via TouchChatDebugStepAndRun. + touchCalled := make(chan struct{}, 5) + db.EXPECT(). + TouchChatDebugStepAndRun(gomock.Any(), gomock.Any()). + DoAndReturn(func(_ context.Context, params database.TouchChatDebugStepAndRunParams) error { + require.Equal(t, stepID, params.StepID) + require.Equal(t, runID, params.RunID) + require.Equal(t, chatID, params.ChatID) + select { + case touchCalled <- struct{}{}: + default: + } + return nil + }). + AnyTimes() + + launchHeartbeat(ctx, svc, stepID, runID, chatID, done) + + // Wait for the ticker to be created. + tickerTrap.MustWait(ctx).MustRelease(ctx) + + // Advance the clock past one heartbeat interval (5s for a + // 10s stale threshold) and verify TouchStep fires. + mClock.Advance(5 * time.Second).MustWait(ctx) + + select { + case <-touchCalled: + case <-ctx.Done(): + t.Fatal("timed out waiting for first heartbeat touch") + } + + // Advance again to verify repeated heartbeats. + mClock.Advance(5 * time.Second).MustWait(ctx) + + select { + case <-touchCalled: + case <-ctx.Done(): + t.Fatal("timed out waiting for second heartbeat touch") + } + }) + + t.Run("stops_on_done_channel", func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + db := dbmock.NewMockStore(ctrl) + mClock := quartz.NewMock(t) + + svc := NewService(db, testutil.Logger(t), nil, + WithClock(mClock), + WithStaleThreshold(10*time.Second), + ) + + stepID := uuid.New() + runID := uuid.New() + chatID := uuid.New() + + done := make(chan struct{}) + + tickerTrap := mClock.Trap().NewTicker("chatdebug", "heartbeat") + defer tickerTrap.Close() + + ctx := testutil.Context(t, testutil.WaitShort) + + launchHeartbeat(ctx, svc, stepID, runID, chatID, done) + tickerTrap.MustWait(ctx).MustRelease(ctx) + + // Close done to signal the heartbeat to stop. + close(done) + + // Give the goroutine a moment to observe the close. + // No TouchStep calls should happen after done is closed. + // (gomock would fail if TouchChatDebugStepAndRun was + // called without a matching expectation.) + }) + + t.Run("nil_service_noop", func(t *testing.T) { + t.Parallel() + + done := make(chan struct{}) + defer close(done) + + ctx := testutil.Context(t, testutil.WaitShort) + + // Should not panic. + launchHeartbeat(ctx, nil, uuid.New(), uuid.New(), uuid.New(), done) + }) + + t.Run("resets_ticker_on_threshold_change", func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + db := dbmock.NewMockStore(ctrl) + mClock := quartz.NewMock(t) + + svc := NewService(db, testutil.Logger(t), nil, + WithClock(mClock), + WithStaleThreshold(60*time.Second), + ) + + stepID := uuid.New() + runID := uuid.New() + chatID := uuid.New() + + done := make(chan struct{}) + defer close(done) + + tickerTrap := mClock.Trap().NewTicker("chatdebug", "heartbeat") + defer tickerTrap.Close() + resetTrap := mClock.Trap().TickerReset("chatdebug", "heartbeat") + defer resetTrap.Close() + + ctx := testutil.Context(t, testutil.WaitShort) + + launchHeartbeat(ctx, svc, stepID, runID, chatID, done) + + // Confirm the ticker was created with the original + // threshold/2 interval. + newCall := tickerTrap.MustWait(ctx) + require.Equal(t, 30*time.Second, newCall.Duration) + newCall.MustRelease(ctx) + + // Reducing the threshold must wake the heartbeat via the + // thresholdChan close and trigger a ticker reset to + // newThreshold/2 without advancing the mock clock. + svc.SetStaleAfter(10 * time.Second) + + resetCall := resetTrap.MustWait(ctx) + require.Equal(t, 5*time.Second, resetCall.Duration, + "ticker should reset to newThreshold/2 when SetStaleAfter"+ + " shrinks the threshold") + resetCall.MustRelease(ctx) + }) +} diff --git a/coderd/x/chatd/chatdebug/recorder.go b/coderd/x/chatd/chatdebug/recorder.go index c3cc4d2c2a..04f5c8a11e 100644 --- a/coderd/x/chatd/chatdebug/recorder.go +++ b/coderd/x/chatd/chatdebug/recorder.go @@ -81,6 +81,49 @@ func attemptSinkFromContext(ctx context.Context) *attemptSink { var stepCounters sync.Map // map[uuid.UUID]*atomic.Int32 +// runRefCounts tracks how many live RunContext instances reference each +// RunID. Cleanup of shared state (step counters) is deferred until the +// last RunContext for a given RunID is garbage collected. +var ( + runRefCounts sync.Map // map[uuid.UUID]*atomic.Int32 + // refCountMu serializes trackRunRef and releaseRunRef so the + // decrement-to-zero check and subsequent map deletions are + // atomic with respect to new references being added. + refCountMu sync.Mutex +) + +func trackRunRef(runID uuid.UUID) { + refCountMu.Lock() + defer refCountMu.Unlock() + val, _ := runRefCounts.LoadOrStore(runID, &atomic.Int32{}) + counter, ok := val.(*atomic.Int32) + if !ok { + panic("chatdebug: runRefCounts contains non-*atomic.Int32 value") + } + counter.Add(1) +} + +// releaseRunRef decrements the reference count for runID and cleans up +// shared state when the last reference is released. The mutex ensures +// no concurrent trackRunRef can increment between the zero check and +// the map deletions. +func releaseRunRef(runID uuid.UUID) { + refCountMu.Lock() + defer refCountMu.Unlock() + val, ok := runRefCounts.Load(runID) + if !ok { + return + } + counter, ok := val.(*atomic.Int32) + if !ok { + panic("chatdebug: runRefCounts contains non-*atomic.Int32 value") + } + if counter.Add(-1) <= 0 { + runRefCounts.Delete(runID) + stepCounters.Delete(runID) + } +} + func nextStepNumber(runID uuid.UUID) int32 { val, _ := stepCounters.LoadOrStore(runID, &atomic.Int32{}) counter, ok := val.(*atomic.Int32) @@ -129,13 +172,16 @@ type stepHandle struct { sink *attemptSink svc *Service opts RecorderOptions - once sync.Once mu sync.Mutex status Status response any usage any err any metadata any + // hadError tracks whether a prior finalization wrote an error + // payload. Used to decide whether a successful retry needs to + // explicitly clear the error field via jsonClear. + hadError bool } // beginStep validates preconditions, creates a debug step, and returns a @@ -223,11 +269,11 @@ func beginStep( return handle, enriched } -// finish updates the debug step with final status and data. -// sync.Once prevents data races when concurrent callers (e.g. -// retried stream wrappers sharing a reuse handle) both attempt -// to finalize the same step. Only the first finish call takes -// effect. +// finish updates the debug step with final status and data. A mutex +// guards the write so concurrent callers (e.g. retried stream wrappers +// sharing a reuse handle) don't race. Later retries are allowed to +// overwrite earlier failure results so the step reflects the final +// outcome, but stale callbacks cannot regress a terminal state. func (h *stepHandle) finish( ctx context.Context, status Status, @@ -240,38 +286,63 @@ func (h *stepHandle) finish( return } - h.once.Do(func() { - h.mu.Lock() - h.status = status - h.response = response - h.usage = usage - h.err = errPayload - h.metadata = metadata - h.mu.Unlock() - if h.svc == nil { - return - } + h.mu.Lock() + defer h.mu.Unlock() - updateCtx, cancel := stepFinalizeContext(ctx) - defer cancel() + // Reject stale callbacks that would regress a terminal state. + // Status priority: in_progress < interrupted < error < completed. + // A tardy safety-net writing "interrupted" cannot clobber a step + // that already reached "completed" or "error" from a real retry. + // Equal-priority updates are allowed so that retries ending in the + // same terminal class (e.g. error → error under ReuseStep) can + // still update the step with newer attempt data. + if h.status.IsTerminal() && status.Priority() < h.status.Priority() { + return + } - if _, updateErr := h.svc.UpdateStep(updateCtx, UpdateStepParams{ - ID: h.stepCtx.StepID, - ChatID: h.stepCtx.ChatID, - Status: status, - NormalizedResponse: response, - Usage: usage, - Attempts: h.sink.snapshot(), - Error: errPayload, - Metadata: metadata, - FinishedAt: time.Now(), - }); updateErr != nil { - h.svc.log.Warn(updateCtx, "failed to finalize chat debug step", - slog.Error(updateErr), - slog.F("step_id", h.stepCtx.StepID), - slog.F("chat_id", h.stepCtx.ChatID), - slog.F("status", status), - ) - } - }) + h.status = status + h.response = response + h.usage = usage + h.err = errPayload + h.metadata = metadata + if errPayload != nil { + h.hadError = true + } + if h.svc == nil { + return + } + + updateCtx, cancel := stepFinalizeContext(ctx) + defer cancel() + + // When the step completes successfully after a prior failed + // attempt, the error field must be explicitly cleared. A plain + // nil would leave the COALESCE-based SQL untouched, so we send + // jsonClear{} which serializes as a valid JSONB null. Only do + // this when a prior error was actually recorded; otherwise + // clean successes would get a spurious JSONB null that downstream + // aggregation could misread as an error. + errValue := errPayload + if errValue == nil && status == StatusCompleted && h.hadError { + errValue = jsonClear{} + } + + if _, updateErr := h.svc.UpdateStep(updateCtx, UpdateStepParams{ + ID: h.stepCtx.StepID, + ChatID: h.stepCtx.ChatID, + Status: status, + NormalizedResponse: response, + Usage: usage, + Attempts: h.sink.snapshot(), + Error: errValue, + Metadata: metadata, + FinishedAt: h.svc.clock.Now(), + }); updateErr != nil { + h.svc.log.Warn(updateCtx, "failed to finalize chat debug step", + slog.Error(updateErr), + slog.F("step_id", h.stepCtx.StepID), + slog.F("chat_id", h.stepCtx.ChatID), + slog.F("status", status), + ) + } } diff --git a/coderd/x/chatd/chatdebug/recorder_test.go b/coderd/x/chatd/chatdebug/recorder_test.go index 2907577354..ca85573ac9 100644 --- a/coderd/x/chatd/chatdebug/recorder_test.go +++ b/coderd/x/chatd/chatdebug/recorder_test.go @@ -9,8 +9,11 @@ import ( "charm.land/fantasy" "github.com/google/uuid" "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + "github.com/coder/coder/v2/coderd/database/dbmock" "github.com/coder/coder/v2/coderd/x/chatd/chattest" + "github.com/coder/coder/v2/testutil" ) func TestAttemptSink_ThreadSafe(t *testing.T) { @@ -145,11 +148,18 @@ func TestBeginStep_NilService(t *testing.T) { func TestBeginStep_FallsBackToRunChatID(t *testing.T) { t.Parallel() + ctrl := gomock.NewController(t) + db := dbmock.NewMockStore(ctrl) runID := uuid.New() runChatID := uuid.New() - ctx := ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: runChatID}) + ownerID := uuid.New() + expectDebugLoggingEnabled(t, db, ownerID) + expectCreateStepNumberWithRequestValidity(t, db, runID, runChatID, 1, OperationGenerate, false) - handle, enriched := beginStep(ctx, &Service{}, RecorderOptions{}, OperationGenerate, nil) + ctx := ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: runChatID}) + svc := NewService(db, testutil.Logger(t), nil) + + handle, enriched := beginStep(ctx, svc, RecorderOptions{OwnerID: ownerID}, OperationGenerate, nil) require.NotNil(t, handle) require.Equal(t, runChatID, handle.stepCtx.ChatID) diff --git a/coderd/x/chatd/chatdebug/reuse_step_test.go b/coderd/x/chatd/chatdebug/reuse_step_test.go index 90a06b7e21..d878d86599 100644 --- a/coderd/x/chatd/chatdebug/reuse_step_test.go +++ b/coderd/x/chatd/chatdebug/reuse_step_test.go @@ -6,7 +6,9 @@ import ( "github.com/google/uuid" "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + "github.com/coder/coder/v2/coderd/database/dbmock" "github.com/coder/coder/v2/testutil" ) @@ -21,7 +23,21 @@ func TestBeginStepReuseStep(t *testing.T) { runID := uuid.New() t.Cleanup(func() { CleanupStepCounter(runID) }) - svc := NewService(nil, testutil.Logger(t), nil) + ctrl := gomock.NewController(t) + db := dbmock.NewMockStore(ctrl) + expectDebugLoggingEnabled(t, db, ownerID) + expectCreateStepNumberWithRequestValidity( + t, + db, + runID, + chatID, + 1, + OperationStream, + false, + ) + expectDebugLoggingEnabled(t, db, ownerID) + + svc := NewService(db, testutil.Logger(t), nil) ctx := ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: chatID}) ctx = ReuseStep(ctx) opts := RecorderOptions{ChatID: chatID, OwnerID: ownerID} @@ -56,7 +72,30 @@ func TestBeginStepReuseStep(t *testing.T) { runID := uuid.New() t.Cleanup(func() { CleanupStepCounter(runID) }) - svc := NewService(nil, testutil.Logger(t), nil) + ctrl := gomock.NewController(t) + db := dbmock.NewMockStore(ctrl) + expectDebugLoggingEnabled(t, db, ownerID) + expectCreateStepNumberWithRequestValidity( + t, + db, + runID, + chatID, + 1, + OperationStream, + false, + ) + expectDebugLoggingEnabled(t, db, ownerID) + expectCreateStepNumberWithRequestValidity( + t, + db, + runID, + chatID, + 2, + OperationStream, + false, + ) + + svc := NewService(db, testutil.Logger(t), nil) ctx := ContextWithRun(context.Background(), &RunContext{RunID: runID, ChatID: chatID}) opts := RecorderOptions{ChatID: chatID, OwnerID: ownerID} diff --git a/coderd/x/chatd/chatdebug/service.go b/coderd/x/chatd/chatdebug/service.go new file mode 100644 index 0000000000..d2cb728e03 --- /dev/null +++ b/coderd/x/chatd/chatdebug/service.go @@ -0,0 +1,685 @@ +package chatdebug + +import ( + "bytes" + "context" + "database/sql" + "encoding/json" + "errors" + "fmt" + "sync" + "sync/atomic" + "time" + + "github.com/google/uuid" + "github.com/sqlc-dev/pqtype" + "golang.org/x/xerrors" + + "cdr.dev/slog/v3" + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbauthz" + "github.com/coder/coder/v2/coderd/database/pubsub" + "github.com/coder/quartz" +) + +// DefaultStaleThreshold is the fallback stale timeout for debug rows +// when no caller-provided value is supplied. +const DefaultStaleThreshold = 5 * time.Minute + +// Service persists chat debug rows and fans out lightweight change events. +type Service struct { + db database.Store + log slog.Logger + pubsub pubsub.Pubsub + clock quartz.Clock + alwaysEnable bool + // staleAfterNanos stores the stale threshold as nanoseconds in an + // atomic.Int64 so SetStaleAfter and FinalizeStale can be called + // from concurrent goroutines without a data race. + staleAfterNanos atomic.Int64 + + // thresholdMu protects thresholdChanged. + thresholdMu sync.Mutex + // thresholdChanged is closed by SetStaleAfter to wake heartbeat + // goroutines so they can re-read the (possibly shorter) interval + // immediately instead of waiting for the old ticker to fire. + thresholdChanged chan struct{} +} + +// ServiceOption configures optional Service behavior. +type ServiceOption func(*Service) + +// WithStaleThreshold overrides the default stale-row finalization +// threshold. Callers that already have a configurable in-flight chat +// timeout (e.g. chatd's InFlightChatStaleAfter) should pass it here +// so the two sweeps stay in sync. +func WithStaleThreshold(d time.Duration) ServiceOption { + return func(s *Service) { + if d > 0 { + s.staleAfterNanos.Store(d.Nanoseconds()) + } + } +} + +// WithAlwaysEnable forces debug logging on for every chat regardless +// of the runtime admin and user opt-in settings. This is used for the +// deployment-level serpent flag. +func WithAlwaysEnable(always bool) ServiceOption { + return func(s *Service) { + s.alwaysEnable = always + } +} + +// WithClock overrides the default real clock. Tests inject +// quartz.NewMock(t) to control time-dependent behavior such as +// heartbeat tickers and FinalizeStale timestamps. +func WithClock(c quartz.Clock) ServiceOption { + return func(s *Service) { + if c != nil { + s.clock = c + } + } +} + +// CreateRunParams contains friendly inputs for creating a debug run. +type CreateRunParams struct { + ChatID uuid.UUID + RootChatID uuid.UUID + ParentChatID uuid.UUID + ModelConfigID uuid.UUID + TriggerMessageID int64 + HistoryTipMessageID int64 + Kind RunKind + Status Status + Provider string + Model string + Summary any +} + +// UpdateRunParams contains inputs for updating a debug run. +// Zero-valued fields are treated as "keep the existing value" by the +// COALESCE-based SQL query. Once a field is set it cannot be cleared +// back to NULL; this is intentional for the write-once-finalize +// lifecycle of debug rows. +type UpdateRunParams struct { + ID uuid.UUID + ChatID uuid.UUID + Status Status + Summary any + FinishedAt time.Time +} + +// CreateStepParams contains friendly inputs for creating a debug step. +type CreateStepParams struct { + RunID uuid.UUID + ChatID uuid.UUID + StepNumber int32 + Operation Operation + Status Status + HistoryTipMessageID int64 + NormalizedRequest any +} + +// UpdateStepParams contains optional inputs for updating a debug step. +// Most payload fields are typed as any and serialized through nullJSON +// because their shape varies by provider. The Attempts field uses a +// concrete slice for compile-time safety where the schema is stable. +// Zero-valued fields are treated as "keep the existing value" by the +// COALESCE-based SQL query. Once set, fields cannot be cleared back +// to NULL. This is intentional for the write-once-finalize lifecycle +// of debug rows. +type UpdateStepParams struct { + ID uuid.UUID + ChatID uuid.UUID + Status Status + AssistantMessageID int64 + NormalizedResponse any + Usage any + Attempts []Attempt + Error any + Metadata any + FinishedAt time.Time +} + +// NewService constructs a chat debug persistence service. +func NewService(db database.Store, log slog.Logger, ps pubsub.Pubsub, opts ...ServiceOption) *Service { + if db == nil { + panic("chatdebug: nil database.Store") + } + + s := &Service{ + db: db, + log: log, + pubsub: ps, + clock: quartz.NewReal(), + thresholdChanged: make(chan struct{}), + } + s.staleAfterNanos.Store(DefaultStaleThreshold.Nanoseconds()) + for _, opt := range opts { + opt(s) + } + return s +} + +// SetStaleAfter overrides the in-flight stale threshold used when +// finalizing abandoned debug rows. Zero or negative durations are +// ignored, leaving the current threshold (initial or previously +// overridden) unchanged. Active heartbeat goroutines are woken so +// they can re-read the (possibly shorter) interval immediately. +func (s *Service) SetStaleAfter(staleAfter time.Duration) { + if s == nil || staleAfter <= 0 { + return + } + s.staleAfterNanos.Store(staleAfter.Nanoseconds()) + + // Wake all heartbeat goroutines by closing the current channel + // and replacing it with a fresh one for the next update. + s.thresholdMu.Lock() + close(s.thresholdChanged) + s.thresholdChanged = make(chan struct{}) + s.thresholdMu.Unlock() +} + +// thresholdChan returns the current threshold-change notification +// channel. Heartbeat goroutines select on this to detect runtime +// stale-threshold updates. +func (s *Service) thresholdChan() <-chan struct{} { + s.thresholdMu.Lock() + defer s.thresholdMu.Unlock() + return s.thresholdChanged +} + +// staleThreshold returns the current stale timeout. +func (s *Service) staleThreshold() time.Duration { + ns := s.staleAfterNanos.Load() + d := time.Duration(ns) + if d <= 0 { + return DefaultStaleThreshold + } + return d +} + +// heartbeatInterval returns a safe ticker interval for stream heartbeats. +// It is half the stale threshold so at least one touch lands before the +// stale sweep considers the row abandoned. The result is clamped to a +// minimum of 1 ms to prevent panics from time.NewTicker(0) with +// pathologically small thresholds, while still staying well below any +// practical stale timeout. +func (s *Service) heartbeatInterval() time.Duration { + return max(s.staleThreshold()/2, time.Millisecond) +} + +func chatdContext(ctx context.Context) context.Context { + //nolint:gocritic // AsChatd provides narrowly-scoped daemon access for + // chat debug persistence reads and writes. + return dbauthz.AsChatd(ctx) +} + +// IsEnabled returns whether debug logging is enabled for the given chat. +func (s *Service) IsEnabled( + ctx context.Context, + chatID uuid.UUID, + ownerID uuid.UUID, +) bool { + if s == nil { + return false + } + if s.alwaysEnable { + return true + } + if s.db == nil { + return false + } + + authCtx := chatdContext(ctx) + + allowUsers, err := s.db.GetChatDebugLoggingAllowUsers(authCtx) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return false + } + s.log.Warn(ctx, "failed to load runtime admin chat debug logging setting", + slog.Error(err), + ) + return false + } + if !allowUsers { + return false + } + + if ownerID == uuid.Nil { + s.log.Warn(ctx, "missing chat owner for debug logging enablement check", + slog.F("chat_id", chatID), + ) + return false + } + + enabled, err := s.db.GetUserChatDebugLoggingEnabled(authCtx, ownerID) + if err == nil { + return enabled + } + if errors.Is(err, sql.ErrNoRows) { + return false + } + + s.log.Warn(ctx, "failed to load user chat debug logging setting", + slog.Error(err), + slog.F("chat_id", chatID), + slog.F("owner_id", ownerID), + ) + return false +} + +// CreateRun inserts a new debug run and emits a run update event. +func (s *Service) CreateRun( + ctx context.Context, + params CreateRunParams, +) (database.ChatDebugRun, error) { + now := s.clock.Now() + run, err := s.db.InsertChatDebugRun(chatdContext(ctx), + database.InsertChatDebugRunParams{ + ChatID: params.ChatID, + RootChatID: nullUUID(params.RootChatID), + ParentChatID: nullUUID(params.ParentChatID), + ModelConfigID: nullUUID(params.ModelConfigID), + TriggerMessageID: nullInt64(params.TriggerMessageID), + HistoryTipMessageID: nullInt64(params.HistoryTipMessageID), + Kind: string(params.Kind), + Status: string(params.Status), + Provider: nullString(params.Provider), + Model: nullString(params.Model), + Summary: s.nullJSON(ctx, params.Summary), + StartedAt: sql.NullTime{Time: now, Valid: true}, + UpdatedAt: sql.NullTime{Time: now, Valid: true}, + FinishedAt: sql.NullTime{}, + }) + if err != nil { + return database.ChatDebugRun{}, err + } + + s.publishEvent(ctx, run.ChatID, EventKindRunUpdate, run.ID, uuid.Nil) + return run, nil +} + +// UpdateRun updates an existing debug run and emits a run update event. +// When a terminal status is set without an explicit FinishedAt, the +// service auto-fills the timestamp so the row is immediately visible +// to the InsertChatDebugStep atomic guard (finished_at IS NULL). +// UpdateChatDebugRun itself enforces finished_at as write-once: once +// the column is populated, repeated auto-fills or explicit refreshes +// never overwrite the original completion timestamp, so calling this +// more than once on an already-finalized run is idempotent. +func (s *Service) UpdateRun( + ctx context.Context, + params UpdateRunParams, +) (database.ChatDebugRun, error) { + if params.Status.IsTerminal() && params.FinishedAt.IsZero() { + params.FinishedAt = s.clock.Now() + } + run, err := s.db.UpdateChatDebugRun(chatdContext(ctx), + database.UpdateChatDebugRunParams{ + RootChatID: uuid.NullUUID{}, + ParentChatID: uuid.NullUUID{}, + ModelConfigID: uuid.NullUUID{}, + TriggerMessageID: sql.NullInt64{}, + HistoryTipMessageID: sql.NullInt64{}, + Status: nullString(string(params.Status)), + Provider: sql.NullString{}, + Model: sql.NullString{}, + Summary: s.nullJSON(ctx, params.Summary), + FinishedAt: nullTime(params.FinishedAt), + Now: s.clock.Now(), + ID: params.ID, + ChatID: params.ChatID, + }) + if err != nil { + return database.ChatDebugRun{}, err + } + + s.publishEvent(ctx, run.ChatID, EventKindRunUpdate, run.ID, uuid.Nil) + return run, nil +} + +// errRunFinalized is returned by CreateStep when the parent run has +// already reached a terminal state (finished_at IS NOT NULL). This +// prevents delayed retries from appending in-progress steps to runs +// that FinalizeStale already marked as interrupted. +var errRunFinalized = xerrors.New("parent run is already finalized") + +// errRunNotFound is returned by CreateStep when the parent run cannot +// be located (missing run_id or chat_id mismatch). This surfaces +// caller-side data bugs instead of conflating them with the legitimate +// "already finalized" terminal case. +var errRunNotFound = xerrors.New("parent run not found") + +// CreateStep inserts a new debug step and emits a step update event. +// It returns errRunFinalized if the parent run has already finished, +// or errRunNotFound if the run_id/chat_id pair does not match an +// existing run. The finalization guard is enforced atomically by the +// INSERT's CTE, which issues an UPDATE on the parent run (taking a +// row lock). This prevents concurrent FinalizeStale from setting +// finished_at between the check and the INSERT. +func (s *Service) CreateStep( + ctx context.Context, + params CreateStepParams, +) (database.ChatDebugStep, error) { + now := s.clock.Now() + insert := database.InsertChatDebugStepParams{ + RunID: params.RunID, + StepNumber: params.StepNumber, + Operation: string(params.Operation), + Status: string(params.Status), + HistoryTipMessageID: nullInt64(params.HistoryTipMessageID), + AssistantMessageID: sql.NullInt64{}, + NormalizedRequest: s.nullJSON(ctx, params.NormalizedRequest), + NormalizedResponse: pqtype.NullRawMessage{}, + Usage: pqtype.NullRawMessage{}, + Attempts: pqtype.NullRawMessage{}, + Error: pqtype.NullRawMessage{}, + Metadata: pqtype.NullRawMessage{}, + StartedAt: sql.NullTime{Time: now, Valid: true}, + UpdatedAt: sql.NullTime{Time: now, Valid: true}, + FinishedAt: sql.NullTime{}, + ChatID: params.ChatID, + } + + // Cap retry attempts to prevent infinite loops under + // pathological concurrency. Each iteration performs two DB + // round-trips (insert + list), so 10 retries is generous. + const maxCreateStepRetries = 10 + + for range maxCreateStepRetries { + if err := ctx.Err(); err != nil { + return database.ChatDebugStep{}, err + } + + step, err := s.db.InsertChatDebugStep(chatdContext(ctx), insert) + if err == nil { + // The INSERT CTE atomically bumps the parent run's + // updated_at, so no separate touch call is needed. + s.publishEvent(ctx, step.ChatID, EventKindStepUpdate, step.RunID, step.ID) + return step, nil + } + // The INSERT's locked_run CTE filters on id, chat_id, and + // finished_at IS NULL, so sql.ErrNoRows can mean "run not + // found", "chat_id mismatch", or "already finalized." Look + // the run up to disambiguate instead of conflating + // caller-side data bugs with the legitimate terminal case. + if errors.Is(err, sql.ErrNoRows) { + return database.ChatDebugStep{}, s.classifyMissingRun(ctx, params) + } + if !database.IsUniqueViolation(err, database.UniqueIndexChatDebugStepsRunStep) { + return database.ChatDebugStep{}, err + } + + steps, listErr := s.db.GetChatDebugStepsByRunID(chatdContext(ctx), params.RunID) + if listErr != nil { + return database.ChatDebugStep{}, listErr + } + nextStepNumber := insert.StepNumber + 1 + for _, existing := range steps { + if existing.StepNumber >= nextStepNumber { + nextStepNumber = existing.StepNumber + 1 + } + } + insert.StepNumber = nextStepNumber + } + + return database.ChatDebugStep{}, xerrors.Errorf( + "failed to create debug step after %d attempts (run_id=%s)", + maxCreateStepRetries, params.RunID, + ) +} + +// classifyMissingRun disambiguates the sql.ErrNoRows returned by +// InsertChatDebugStep's locked_run CTE. The CTE filters on id, +// chat_id, and finished_at IS NULL, so empty RETURNING rows can mean +// the run is absent, belongs to a different chat, or has already been +// finalized. GetChatDebugRunByID is keyed only by id, which is +// sufficient to tell these cases apart. +func (s *Service) classifyMissingRun( + ctx context.Context, + params CreateStepParams, +) error { + run, err := s.db.GetChatDebugRunByID(chatdContext(ctx), params.RunID) + if errors.Is(err, sql.ErrNoRows) { + return errRunNotFound + } + if err != nil { + return xerrors.Errorf("look up parent run after failed step insert: %w", err) + } + if run.ChatID != params.ChatID { + return errRunNotFound + } + if run.FinishedAt.Valid { + return errRunFinalized + } + // The run matches the caller's (run_id, chat_id) and is still + // open, yet the INSERT returned no rows. This is unexpected + // under write-once-finalize semantics and likely indicates a + // concurrent delete or unrelated defect; surface it instead of + // silently masking it as a terminal case. + return xerrors.Errorf( + "InsertChatDebugStep returned no rows but run is still active (run_id=%s)", + params.RunID, + ) +} + +// UpdateStep updates an existing debug step and emits a step update event. +// When a terminal status is set without an explicit FinishedAt, the +// service auto-fills the timestamp so the stale sweep does not leave +// terminal rows with finished_at = NULL. +func (s *Service) UpdateStep( + ctx context.Context, + params UpdateStepParams, +) (database.ChatDebugStep, error) { + if params.Status.IsTerminal() && params.FinishedAt.IsZero() { + params.FinishedAt = s.clock.Now() + } + step, err := s.db.UpdateChatDebugStep(chatdContext(ctx), + database.UpdateChatDebugStepParams{ + Status: nullString(string(params.Status)), + HistoryTipMessageID: sql.NullInt64{}, + AssistantMessageID: nullInt64(params.AssistantMessageID), + NormalizedRequest: pqtype.NullRawMessage{}, + NormalizedResponse: s.nullJSON(ctx, params.NormalizedResponse), + Usage: s.nullJSON(ctx, params.Usage), + Attempts: s.nullJSON(ctx, params.Attempts), + Error: s.nullJSON(ctx, params.Error), + Metadata: s.nullJSON(ctx, params.Metadata), + FinishedAt: nullTime(params.FinishedAt), + Now: s.clock.Now(), + ID: params.ID, + ChatID: params.ChatID, + }) + if err != nil { + return database.ChatDebugStep{}, err + } + + s.publishEvent(ctx, step.ChatID, EventKindStepUpdate, step.RunID, step.ID) + return step, nil +} + +// TouchStep bumps the step's and its parent run's updated_at timestamps +// without changing any other fields. This prevents long-running operations +// (e.g. streaming) from being prematurely swept by FinalizeStale, which +// first marks runs stale by chat_debug_runs.updated_at and then cascades +// to steps whose run_id was just finalized. +func (s *Service) TouchStep( + ctx context.Context, + stepID uuid.UUID, + runID uuid.UUID, + chatID uuid.UUID, +) error { + // Atomically bump both the step and its parent run so + // FinalizeStale cannot interleave between the two touches. + return s.db.TouchChatDebugStepAndRun(chatdContext(ctx), + database.TouchChatDebugStepAndRunParams{ + Now: s.clock.Now(), + StepID: stepID, + RunID: runID, + ChatID: chatID, + }) +} + +// DeleteByChatID deletes all debug data for a chat and emits a delete event. +func (s *Service) DeleteByChatID( + ctx context.Context, + chatID uuid.UUID, +) (int64, error) { + deleted, err := s.db.DeleteChatDebugDataByChatID(chatdContext(ctx), chatID) + if err != nil { + return 0, err + } + + s.publishEvent(ctx, chatID, EventKindDelete, uuid.Nil, uuid.Nil) + return deleted, nil +} + +// DeleteAfterMessageID deletes debug data newer than the given message. +func (s *Service) DeleteAfterMessageID( + ctx context.Context, + chatID uuid.UUID, + messageID int64, +) (int64, error) { + deleted, err := s.db.DeleteChatDebugDataAfterMessageID( + chatdContext(ctx), + database.DeleteChatDebugDataAfterMessageIDParams{ + ChatID: chatID, + MessageID: messageID, + }, + ) + if err != nil { + return 0, err + } + + s.publishEvent(ctx, chatID, EventKindDelete, uuid.Nil, uuid.Nil) + return deleted, nil +} + +// FinalizeStale finalizes stale in-flight debug rows and emits a broadcast. +func (s *Service) FinalizeStale( + ctx context.Context, +) (database.FinalizeStaleChatDebugRowsRow, error) { + now := s.clock.Now() + result, err := s.db.FinalizeStaleChatDebugRows( + chatdContext(ctx), + database.FinalizeStaleChatDebugRowsParams{ + Now: now, + UpdatedBefore: now.Add(-s.staleThreshold()), + }, + ) + if err != nil { + return database.FinalizeStaleChatDebugRowsRow{}, err + } + + if result.RunsFinalized > 0 || result.StepsFinalized > 0 { + s.publishEvent(ctx, uuid.Nil, EventKindFinalize, uuid.Nil, uuid.Nil) + } + return result, nil +} + +func nullUUID(id uuid.UUID) uuid.NullUUID { + return uuid.NullUUID{UUID: id, Valid: id != uuid.Nil} +} + +func nullInt64(v int64) sql.NullInt64 { + return sql.NullInt64{Int64: v, Valid: v != 0} +} + +func nullString(value string) sql.NullString { + return sql.NullString{String: value, Valid: value != ""} +} + +func nullTime(value time.Time) sql.NullTime { + return sql.NullTime{Time: value, Valid: !value.IsZero()} +} + +// jsonClear is a sentinel value that tells nullJSON to emit a valid +// JSON null (JSONB 'null') instead of SQL NULL. COALESCE treats SQL +// NULL as "keep existing" but replaces with a non-NULL JSONB value, +// so passing jsonClear explicitly overwrites a previously set field. +type jsonClear struct{} + +// nullJSON marshals value to a NullRawMessage. When value is nil +// (including typed nils such as `var p *T = nil` whose interface +// representation carries a type but no value) or marshals to JSON +// "null", the result is {Valid: false}. Typed nils fall through the +// `value == nil` guard but produce `[]byte("null")` from +// json.Marshal, which the `bytes.Equal(data, []byte("null"))` check +// catches identically. This is intentional for the write-once-finalize +// pattern: combined with the COALESCE-based UPDATE queries, passing +// nil (typed or untyped) preserves the existing column value. Fields +// accumulate monotonically (request -> response -> usage -> error) and +// never need to be cleared during normal operation. The jsonClear +// sentinel exists for the sole exception (error retry clearing). +func (s *Service) nullJSON(ctx context.Context, value any) pqtype.NullRawMessage { + if value == nil { + return pqtype.NullRawMessage{} + } + // Sentinel: emit a valid JSONB null so COALESCE replaces + // any previously stored value. + if _, ok := value.(jsonClear); ok { + return pqtype.NullRawMessage{ + RawMessage: json.RawMessage("null"), + Valid: true, + } + } + + data, err := json.Marshal(value) + if err != nil { + s.log.Warn(ctx, "failed to marshal chat debug JSON", + slog.Error(err), + slog.F("value_type", fmt.Sprintf("%T", value)), + ) + return pqtype.NullRawMessage{} + } + if bytes.Equal(data, []byte("null")) { + return pqtype.NullRawMessage{} + } + + return pqtype.NullRawMessage{RawMessage: data, Valid: true} +} + +func (s *Service) publishEvent( + ctx context.Context, + chatID uuid.UUID, + kind EventKind, + runID uuid.UUID, + stepID uuid.UUID, +) { + if s.pubsub == nil { + s.log.Debug(ctx, + "chat debug pubsub unavailable; skipping event", + slog.F("kind", kind), + slog.F("chat_id", chatID), + ) + return + } + + event := DebugEvent{ + Kind: kind, + ChatID: chatID, + RunID: runID, + StepID: stepID, + } + data, err := json.Marshal(event) + if err != nil { + s.log.Warn(ctx, "failed to marshal chat debug event", + slog.Error(err), + slog.F("kind", kind), + slog.F("chat_id", chatID), + ) + return + } + + channel := PubsubChannel(chatID) + if err := s.pubsub.Publish(channel, data); err != nil { + s.log.Warn(ctx, "failed to publish chat debug event", + slog.Error(err), + slog.F("channel", channel), + slog.F("kind", kind), + slog.F("chat_id", chatID), + ) + } +} diff --git a/coderd/x/chatd/chatdebug/service_test.go b/coderd/x/chatd/chatdebug/service_test.go new file mode 100644 index 0000000000..a52e166017 --- /dev/null +++ b/coderd/x/chatd/chatdebug/service_test.go @@ -0,0 +1,1054 @@ +package chatdebug_test + +import ( + "context" + "database/sql" + "encoding/json" + "testing" + "time" + + "github.com/google/uuid" + "github.com/lib/pq" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + "golang.org/x/xerrors" + + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbgen" + "github.com/coder/coder/v2/coderd/database/dbmock" + "github.com/coder/coder/v2/coderd/database/dbtestutil" + dbpubsub "github.com/coder/coder/v2/coderd/database/pubsub" + "github.com/coder/coder/v2/coderd/x/chatd/chatdebug" + "github.com/coder/coder/v2/coderd/x/chatd/chatprompt" + "github.com/coder/coder/v2/codersdk" + "github.com/coder/coder/v2/testutil" +) + +type testFixture struct { + ctx context.Context + db database.Store + svc *chatdebug.Service + org database.Organization + owner database.User + chat database.Chat + model database.ChatModelConfig +} + +func TestService_IsEnabled(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + db, _, _ := dbtestutil.NewDBWithSQLDB(t) + _, owner, chat, model := seedChat(ctx, t, db) + require.NotEqual(t, uuid.Nil, model.ID) + + svc := chatdebug.NewService(db, testutil.Logger(t), nil) + + // Default is off until an admin allows user opt-in. + require.False(t, svc.IsEnabled(ctx, chat.ID, owner.ID)) + + err := db.UpsertChatDebugLoggingAllowUsers(ctx, true) + require.NoError(t, err) + // Allowing user opt-in is not enough on its own; the user must opt in. + require.False(t, svc.IsEnabled(ctx, chat.ID, owner.ID)) + require.False(t, svc.IsEnabled(ctx, chat.ID, uuid.Nil)) + + err = db.UpsertUserChatDebugLoggingEnabled(ctx, + database.UpsertUserChatDebugLoggingEnabledParams{ + UserID: owner.ID, + DebugLoggingEnabled: true, + }, + ) + require.NoError(t, err) + require.True(t, svc.IsEnabled(ctx, chat.ID, owner.ID)) + + err = db.UpsertUserChatDebugLoggingEnabled(ctx, + database.UpsertUserChatDebugLoggingEnabledParams{ + UserID: owner.ID, + DebugLoggingEnabled: false, + }, + ) + require.NoError(t, err) + require.False(t, svc.IsEnabled(ctx, chat.ID, owner.ID)) +} + +func TestService_IsEnabled_AlwaysEnable(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + db, _, _ := dbtestutil.NewDBWithSQLDB(t) + _, owner, chat, model := seedChat(ctx, t, db) + require.NotEqual(t, uuid.Nil, model.ID) + + svc := chatdebug.NewService(db, testutil.Logger(t), nil, chatdebug.WithAlwaysEnable(true)) + require.True(t, svc.IsEnabled(ctx, chat.ID, owner.ID)) + require.True(t, svc.IsEnabled(ctx, chat.ID, uuid.Nil)) +} + +func TestService_IsEnabled_ZeroValueService(t *testing.T) { + t.Parallel() + + var svc *chatdebug.Service + require.False(t, svc.IsEnabled(context.Background(), uuid.Nil, uuid.Nil)) + + require.False(t, (&chatdebug.Service{}).IsEnabled(context.Background(), uuid.Nil, uuid.Nil)) +} + +func TestService_CreateRun(t *testing.T) { + t.Parallel() + + fixture := newFixture(t) + rootChat := insertChat(fixture.ctx, t, fixture.db, fixture.org.ID, fixture.owner.ID, fixture.model.ID) + parentChat := insertChat(fixture.ctx, t, fixture.db, fixture.org.ID, fixture.owner.ID, fixture.model.ID) + triggerMsg := insertMessage(fixture.ctx, t, fixture.db, fixture.chat.ID, + fixture.owner.ID, fixture.model.ID, database.ChatMessageRoleUser, "trigger") + historyTipMsg := insertMessage(fixture.ctx, t, fixture.db, fixture.chat.ID, + fixture.owner.ID, fixture.model.ID, database.ChatMessageRoleAssistant, + "history-tip") + + run, err := fixture.svc.CreateRun(fixture.ctx, chatdebug.CreateRunParams{ + ChatID: fixture.chat.ID, + RootChatID: rootChat.ID, + ParentChatID: parentChat.ID, + ModelConfigID: fixture.model.ID, + TriggerMessageID: triggerMsg.ID, + HistoryTipMessageID: historyTipMsg.ID, + Kind: chatdebug.KindChatTurn, + Status: chatdebug.StatusInProgress, + Provider: fixture.model.Provider, + Model: fixture.model.Model, + Summary: map[string]any{ + "phase": "create", + "count": 1, + }, + }) + require.NoError(t, err) + assertRunMatches(t, run, fixture.chat.ID, rootChat.ID, parentChat.ID, + fixture.model.ID, triggerMsg.ID, historyTipMsg.ID, + chatdebug.KindChatTurn, chatdebug.StatusInProgress, + fixture.model.Provider, fixture.model.Model, + `{"count":1,"phase":"create"}`) + + stored, err := fixture.db.GetChatDebugRunByID(fixture.ctx, run.ID) + require.NoError(t, err) + require.Equal(t, run.ID, stored.ID) + require.JSONEq(t, string(run.Summary), string(stored.Summary)) +} + +func TestService_CreateRun_TypedNilSummaryUsesDefaultObject(t *testing.T) { + t.Parallel() + + fixture := newFixture(t) + var summary map[string]any + + run, err := fixture.svc.CreateRun(fixture.ctx, chatdebug.CreateRunParams{ + ChatID: fixture.chat.ID, + Kind: chatdebug.KindChatTurn, + Status: chatdebug.StatusInProgress, + Summary: summary, + }) + require.NoError(t, err) + require.JSONEq(t, `{}`, string(run.Summary)) +} + +func TestService_UpdateRun(t *testing.T) { + t.Parallel() + + fixture := newFixture(t) + run, err := fixture.svc.CreateRun(fixture.ctx, chatdebug.CreateRunParams{ + ChatID: fixture.chat.ID, + Kind: chatdebug.KindChatTurn, + Status: chatdebug.StatusInProgress, + Summary: map[string]any{ + "before": true, + }, + }) + require.NoError(t, err) + + finishedAt := time.Now().UTC().Round(time.Microsecond) + updated, err := fixture.svc.UpdateRun(fixture.ctx, chatdebug.UpdateRunParams{ + ID: run.ID, + ChatID: fixture.chat.ID, + Status: chatdebug.StatusCompleted, + Summary: map[string]any{"after": "done"}, + FinishedAt: finishedAt, + }) + require.NoError(t, err) + require.Equal(t, string(chatdebug.StatusCompleted), updated.Status) + require.True(t, updated.FinishedAt.Valid) + require.WithinDuration(t, finishedAt, updated.FinishedAt.Time, time.Second) + require.JSONEq(t, `{"after":"done"}`, string(updated.Summary)) + + stored, err := fixture.db.GetChatDebugRunByID(fixture.ctx, run.ID) + require.NoError(t, err) + require.Equal(t, string(chatdebug.StatusCompleted), stored.Status) + require.JSONEq(t, `{"after":"done"}`, string(stored.Summary)) + require.True(t, stored.FinishedAt.Valid) +} + +func TestService_UpdateRun_AutoFillsFinishedAtOnTerminalStatus(t *testing.T) { + t.Parallel() + + fixture := newFixture(t) + run, err := fixture.svc.CreateRun(fixture.ctx, chatdebug.CreateRunParams{ + ChatID: fixture.chat.ID, + Kind: chatdebug.KindChatTurn, + Status: chatdebug.StatusInProgress, + }) + require.NoError(t, err) + + // Pass a terminal status without FinishedAt. The service must + // auto-fill it so the run is immediately visible to the + // InsertChatDebugStep atomic guard (finished_at IS NULL). + // Truncate to microsecond precision to match Postgres timestamptz + // resolution; without this, nanosecond-precise Go timestamps can + // appear strictly after a round-tripped value in the same + // microsecond. + before := time.Now().Truncate(time.Microsecond) + updated, err := fixture.svc.UpdateRun(fixture.ctx, chatdebug.UpdateRunParams{ + ID: run.ID, + ChatID: fixture.chat.ID, + Status: chatdebug.StatusCompleted, + }) + require.NoError(t, err) + require.Equal(t, string(chatdebug.StatusCompleted), updated.Status) + require.True(t, updated.FinishedAt.Valid, + "FinishedAt must be auto-filled for terminal status") + require.False(t, updated.FinishedAt.Time.Before(before), + "auto-filled FinishedAt should not be earlier than test start") +} + +func TestService_UpdateRun_FinishedAtIsWriteOnce(t *testing.T) { + t.Parallel() + + fixture := newFixture(t) + run, err := fixture.svc.CreateRun(fixture.ctx, chatdebug.CreateRunParams{ + ChatID: fixture.chat.ID, + Kind: chatdebug.KindChatTurn, + Status: chatdebug.StatusInProgress, + }) + require.NoError(t, err) + + // First finalization stamps finished_at with an explicit value so + // the test is independent of wall-clock timing. + originalFinishedAt := time.Now().UTC(). + Truncate(time.Microsecond).Add(-time.Hour) + first, err := fixture.svc.UpdateRun(fixture.ctx, chatdebug.UpdateRunParams{ + ID: run.ID, + ChatID: fixture.chat.ID, + Status: chatdebug.StatusCompleted, + FinishedAt: originalFinishedAt, + }) + require.NoError(t, err) + require.True(t, first.FinishedAt.Valid) + require.True(t, first.FinishedAt.Time.Equal(originalFinishedAt)) + + // A later summary refresh on the already-finalized run must not + // overwrite the original completion timestamp, even though the + // service auto-fills FinishedAt with clock.Now() whenever a + // terminal status is passed. Without the SQL write-once guard, + // this second call would clobber finished_at with the current + // time and corrupt duration/ordering calculations. + second, err := fixture.svc.UpdateRun(fixture.ctx, chatdebug.UpdateRunParams{ + ID: run.ID, + ChatID: fixture.chat.ID, + Status: chatdebug.StatusCompleted, + Summary: map[string]any{"refreshed": true}, + }) + require.NoError(t, err) + require.True(t, second.FinishedAt.Valid) + require.True(t, second.FinishedAt.Time.Equal(originalFinishedAt), + "FinishedAt must be preserved across repeated terminal-status updates") + + // Even a caller that explicitly passes a new FinishedAt cannot + // overwrite the original. + override := originalFinishedAt.Add(time.Hour) + third, err := fixture.svc.UpdateRun(fixture.ctx, chatdebug.UpdateRunParams{ + ID: run.ID, + ChatID: fixture.chat.ID, + Status: chatdebug.StatusCompleted, + FinishedAt: override, + }) + require.NoError(t, err) + require.True(t, third.FinishedAt.Time.Equal(originalFinishedAt), + "explicit FinishedAt must not overwrite an already-set value") +} + +func TestService_CreateStep(t *testing.T) { + t.Parallel() + + fixture := newFixture(t) + run := createRun(t, fixture) + historyTipMsg := insertMessage(fixture.ctx, t, fixture.db, fixture.chat.ID, + fixture.owner.ID, fixture.model.ID, database.ChatMessageRoleAssistant, + "history-tip") + + step, err := fixture.svc.CreateStep(fixture.ctx, chatdebug.CreateStepParams{ + RunID: run.ID, + ChatID: fixture.chat.ID, + StepNumber: 1, + Operation: chatdebug.OperationStream, + Status: chatdebug.StatusInProgress, + HistoryTipMessageID: historyTipMsg.ID, + NormalizedRequest: map[string]any{ + "messages": []string{"hello"}, + }, + }) + require.NoError(t, err) + require.Equal(t, fixture.chat.ID, step.ChatID) + require.Equal(t, run.ID, step.RunID) + require.EqualValues(t, 1, step.StepNumber) + require.Equal(t, string(chatdebug.OperationStream), step.Operation) + require.Equal(t, string(chatdebug.StatusInProgress), step.Status) + require.True(t, step.HistoryTipMessageID.Valid) + require.Equal(t, historyTipMsg.ID, step.HistoryTipMessageID.Int64) + require.JSONEq(t, `{"messages":["hello"]}`, string(step.NormalizedRequest)) + + steps, err := fixture.db.GetChatDebugStepsByRunID(fixture.ctx, run.ID) + require.NoError(t, err) + require.Len(t, steps, 1) + require.Equal(t, step.ID, steps[0].ID) +} + +func TestService_CreateStep_RetriesDuplicateStepNumbers(t *testing.T) { + t.Parallel() + + fixture := newFixture(t) + run := createRun(t, fixture) + + first, err := fixture.svc.CreateStep(fixture.ctx, chatdebug.CreateStepParams{ + RunID: run.ID, + ChatID: fixture.chat.ID, + StepNumber: 1, + Operation: chatdebug.OperationStream, + Status: chatdebug.StatusInProgress, + }) + require.NoError(t, err) + + second, err := fixture.svc.CreateStep(fixture.ctx, chatdebug.CreateStepParams{ + RunID: run.ID, + ChatID: fixture.chat.ID, + StepNumber: 1, + Operation: chatdebug.OperationGenerate, + Status: chatdebug.StatusInProgress, + }) + require.NoError(t, err) + require.EqualValues(t, 1, first.StepNumber) + require.EqualValues(t, 2, second.StepNumber) +} + +func TestService_CreateStep_ListRetryErrorWins(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + db := dbmock.NewMockStore(ctrl) + svc := chatdebug.NewService(db, testutil.Logger(t), nil) + runID := uuid.New() + chatID := uuid.New() + listErr := xerrors.New("list chat debug steps") + + db.EXPECT().InsertChatDebugStep( + gomock.Any(), + gomock.AssignableToTypeOf(database.InsertChatDebugStepParams{}), + ).Return(database.ChatDebugStep{}, &pq.Error{ + Code: pq.ErrorCode("23505"), + Constraint: string(database.UniqueIndexChatDebugStepsRunStep), + }) + db.EXPECT().GetChatDebugStepsByRunID(gomock.Any(), runID).Return(nil, listErr) + + _, err := svc.CreateStep(context.Background(), chatdebug.CreateStepParams{ + RunID: runID, + ChatID: chatID, + StepNumber: 1, + Operation: chatdebug.OperationStream, + Status: chatdebug.StatusInProgress, + }) + require.ErrorIs(t, err, listErr) +} + +func TestService_CreateStep_RejectsFinalizedRun(t *testing.T) { + t.Parallel() + + fixture := newFixture(t) + run := createRun(t, fixture) + + // Finalize the run so it has a terminal state. + _, err := fixture.svc.UpdateRun(fixture.ctx, chatdebug.UpdateRunParams{ + ID: run.ID, + ChatID: fixture.chat.ID, + Status: chatdebug.StatusInterrupted, + FinishedAt: time.Now(), + }) + require.NoError(t, err) + + // Creating a step on the finalized run must fail. + _, err = fixture.svc.CreateStep(fixture.ctx, chatdebug.CreateStepParams{ + RunID: run.ID, + ChatID: fixture.chat.ID, + StepNumber: 1, + Operation: chatdebug.OperationStream, + Status: chatdebug.StatusInProgress, + }) + require.Error(t, err) + require.ErrorContains(t, err, "already finalized") +} + +func TestService_CreateStep_MissingRunReportsNotFound(t *testing.T) { + t.Parallel() + + fixture := newFixture(t) + + // Use a random run ID that was never inserted. The insert CTE + // returns zero rows, which must be classified as "not found" + // instead of being conflated with the already-finalized case. + _, err := fixture.svc.CreateStep(fixture.ctx, chatdebug.CreateStepParams{ + RunID: uuid.New(), + ChatID: fixture.chat.ID, + StepNumber: 1, + Operation: chatdebug.OperationStream, + Status: chatdebug.StatusInProgress, + }) + require.Error(t, err) + require.ErrorContains(t, err, "not found", + "missing parent runs must surface as not-found, not already-finalized") + require.NotContains(t, err.Error(), "already finalized") +} + +func TestService_CreateStep_ChatIDMismatchReportsNotFound(t *testing.T) { + t.Parallel() + + fixture := newFixture(t) + run := createRun(t, fixture) + + // Create a second chat under the same owner/model and try to + // attach a step to the existing run using the wrong chat_id. + // The insert's locked_run WHERE fails on chat_id, producing + // sql.ErrNoRows; classifyMissingRun must report not-found. + otherChat := insertChat(fixture.ctx, t, fixture.db, fixture.org.ID, + fixture.owner.ID, fixture.model.ID) + + _, err := fixture.svc.CreateStep(fixture.ctx, chatdebug.CreateStepParams{ + RunID: run.ID, + ChatID: otherChat.ID, + StepNumber: 1, + Operation: chatdebug.OperationStream, + Status: chatdebug.StatusInProgress, + }) + require.Error(t, err) + require.ErrorContains(t, err, "not found", + "chat_id mismatch must surface as not-found, not already-finalized") + require.NotContains(t, err.Error(), "already finalized") +} + +func TestService_UpdateStep(t *testing.T) { + t.Parallel() + + fixture := newFixture(t) + run := createRun(t, fixture) + step, err := fixture.svc.CreateStep(fixture.ctx, chatdebug.CreateStepParams{ + RunID: run.ID, + ChatID: fixture.chat.ID, + StepNumber: 1, + Operation: chatdebug.OperationStream, + Status: chatdebug.StatusInProgress, + }) + require.NoError(t, err) + + assistantMsg := insertMessage(fixture.ctx, t, fixture.db, fixture.chat.ID, + fixture.owner.ID, fixture.model.ID, database.ChatMessageRoleAssistant, + "assistant") + finishedAt := time.Now().UTC().Round(time.Microsecond) + updated, err := fixture.svc.UpdateStep(fixture.ctx, chatdebug.UpdateStepParams{ + ID: step.ID, + ChatID: fixture.chat.ID, + Status: chatdebug.StatusCompleted, + AssistantMessageID: assistantMsg.ID, + NormalizedResponse: map[string]any{"text": "done"}, + Usage: map[string]any{"input_tokens": 10, "output_tokens": 5}, + Attempts: []chatdebug.Attempt{{ + Number: 1, + ResponseStatus: 200, + DurationMs: 25, + }}, + Metadata: map[string]any{"provider": fixture.model.Provider}, + FinishedAt: finishedAt, + }) + require.NoError(t, err) + require.Equal(t, string(chatdebug.StatusCompleted), updated.Status) + require.True(t, updated.AssistantMessageID.Valid) + require.Equal(t, assistantMsg.ID, updated.AssistantMessageID.Int64) + require.True(t, updated.NormalizedResponse.Valid) + require.JSONEq(t, `{"text":"done"}`, + string(updated.NormalizedResponse.RawMessage)) + require.True(t, updated.Usage.Valid) + require.JSONEq(t, `{"input_tokens":10,"output_tokens":5}`, + string(updated.Usage.RawMessage)) + require.JSONEq(t, + `[{"number":1,"response_status":200,"duration_ms":25}]`, + string(updated.Attempts), + ) + require.JSONEq(t, `{"provider":"`+fixture.model.Provider+`"}`, + string(updated.Metadata)) + require.True(t, updated.FinishedAt.Valid) + storedSteps, err := fixture.db.GetChatDebugStepsByRunID(fixture.ctx, run.ID) + require.NoError(t, err) + require.Len(t, storedSteps, 1) + require.Equal(t, updated.ID, storedSteps[0].ID) +} + +func TestService_UpdateStep_AutoFillsFinishedAtOnTerminalStatus(t *testing.T) { + t.Parallel() + + fixture := newFixture(t) + run := createRun(t, fixture) + step, err := fixture.svc.CreateStep(fixture.ctx, chatdebug.CreateStepParams{ + RunID: run.ID, + ChatID: fixture.chat.ID, + StepNumber: 1, + Operation: chatdebug.OperationStream, + Status: chatdebug.StatusInProgress, + }) + require.NoError(t, err) + + // Pass a terminal status without FinishedAt. The service must + // auto-fill it so the stale sweep does not leave terminal rows + // with finished_at = NULL. + // Truncate to microsecond precision to match Postgres timestamptz + // resolution. + before := time.Now().Truncate(time.Microsecond) + updated, err := fixture.svc.UpdateStep(fixture.ctx, chatdebug.UpdateStepParams{ + ID: step.ID, + ChatID: fixture.chat.ID, + Status: chatdebug.StatusError, + }) + require.NoError(t, err) + require.Equal(t, string(chatdebug.StatusError), updated.Status) + require.True(t, updated.FinishedAt.Valid, + "FinishedAt must be auto-filled for terminal status") + require.False(t, updated.FinishedAt.Time.Before(before), + "auto-filled FinishedAt should not be earlier than test start") +} + +func TestService_UpdateStep_TypedNilAttemptsPreserveExistingValue(t *testing.T) { + t.Parallel() + + fixture := newFixture(t) + run := createRun(t, fixture) + step, err := fixture.svc.CreateStep(fixture.ctx, chatdebug.CreateStepParams{ + RunID: run.ID, + ChatID: fixture.chat.ID, + StepNumber: 1, + Operation: chatdebug.OperationStream, + Status: chatdebug.StatusInProgress, + }) + require.NoError(t, err) + + _, err = fixture.svc.UpdateStep(fixture.ctx, chatdebug.UpdateStepParams{ + ID: step.ID, + ChatID: fixture.chat.ID, + Status: chatdebug.StatusCompleted, + Attempts: []chatdebug.Attempt{{ + Number: 1, + }}, + }) + require.NoError(t, err) + + var typedNilAttempts []chatdebug.Attempt + updated, err := fixture.svc.UpdateStep(fixture.ctx, chatdebug.UpdateStepParams{ + ID: step.ID, + ChatID: fixture.chat.ID, + Attempts: typedNilAttempts, + }) + require.NoError(t, err) + + var attempts []map[string]any + require.NoError(t, json.Unmarshal(updated.Attempts, &attempts)) + require.Len(t, attempts, 1) + require.EqualValues(t, 1, attempts[0]["number"]) +} + +func TestService_DeleteByChatID(t *testing.T) { + t.Parallel() + + fixture := newFixture(t) + run := createRun(t, fixture) + _, err := fixture.svc.CreateStep(fixture.ctx, chatdebug.CreateStepParams{ + RunID: run.ID, + ChatID: fixture.chat.ID, + StepNumber: 1, + Operation: chatdebug.OperationGenerate, + Status: chatdebug.StatusInProgress, + }) + require.NoError(t, err) + + deleted, err := fixture.svc.DeleteByChatID(fixture.ctx, fixture.chat.ID) + require.NoError(t, err) + require.EqualValues(t, 1, deleted) + + runs, err := fixture.db.GetChatDebugRunsByChatID(fixture.ctx, database.GetChatDebugRunsByChatIDParams{ + ChatID: fixture.chat.ID, + LimitVal: 100, + }) + require.NoError(t, err) + require.Empty(t, runs) +} + +func TestService_DeleteAfterMessageID(t *testing.T) { + t.Parallel() + + fixture := newFixture(t) + low := insertMessage(fixture.ctx, t, fixture.db, fixture.chat.ID, fixture.owner.ID, + fixture.model.ID, database.ChatMessageRoleAssistant, "low") + threshold := insertMessage(fixture.ctx, t, fixture.db, fixture.chat.ID, + fixture.owner.ID, fixture.model.ID, database.ChatMessageRoleAssistant, + "threshold") + high := insertMessage(fixture.ctx, t, fixture.db, fixture.chat.ID, fixture.owner.ID, + fixture.model.ID, database.ChatMessageRoleAssistant, "high") + require.Less(t, low.ID, threshold.ID) + require.Less(t, threshold.ID, high.ID) + + runKeep := createRun(t, fixture) + stepKeep, err := fixture.svc.CreateStep(fixture.ctx, chatdebug.CreateStepParams{ + RunID: runKeep.ID, + ChatID: fixture.chat.ID, + StepNumber: 1, + Operation: chatdebug.OperationGenerate, + Status: chatdebug.StatusInProgress, + }) + require.NoError(t, err) + _, err = fixture.svc.UpdateStep(fixture.ctx, chatdebug.UpdateStepParams{ + ID: stepKeep.ID, + ChatID: fixture.chat.ID, + AssistantMessageID: low.ID, + }) + require.NoError(t, err) + + runDelete := createRun(t, fixture) + stepDelete, err := fixture.svc.CreateStep(fixture.ctx, chatdebug.CreateStepParams{ + RunID: runDelete.ID, + ChatID: fixture.chat.ID, + StepNumber: 1, + Operation: chatdebug.OperationGenerate, + Status: chatdebug.StatusInProgress, + }) + require.NoError(t, err) + _, err = fixture.svc.UpdateStep(fixture.ctx, chatdebug.UpdateStepParams{ + ID: stepDelete.ID, + ChatID: fixture.chat.ID, + AssistantMessageID: high.ID, + }) + require.NoError(t, err) + + deleted, err := fixture.svc.DeleteAfterMessageID(fixture.ctx, fixture.chat.ID, + threshold.ID) + require.NoError(t, err) + require.EqualValues(t, 1, deleted) + + runs, err := fixture.db.GetChatDebugRunsByChatID(fixture.ctx, database.GetChatDebugRunsByChatIDParams{ + ChatID: fixture.chat.ID, + LimitVal: 100, + }) + require.NoError(t, err) + require.Len(t, runs, 1) + require.Equal(t, runKeep.ID, runs[0].ID) + + steps, err := fixture.db.GetChatDebugStepsByRunID(fixture.ctx, runKeep.ID) + require.NoError(t, err) + require.Len(t, steps, 1) + require.Equal(t, stepKeep.ID, steps[0].ID) +} + +func TestService_FinalizeStale_UsesConfiguredThreshold(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + db := dbmock.NewMockStore(ctrl) + svc := chatdebug.NewService(db, testutil.Logger(t), nil) + svc.SetStaleAfter(42 * time.Second) + + db.EXPECT().FinalizeStaleChatDebugRows(gomock.Any(), gomock.Any()).DoAndReturn( + func(_ context.Context, params database.FinalizeStaleChatDebugRowsParams) (database.FinalizeStaleChatDebugRowsRow, error) { + require.WithinDuration(t, time.Now().Add(-42*time.Second), params.UpdatedBefore, 2*time.Second) + return database.FinalizeStaleChatDebugRowsRow{}, nil + }, + ) + + result, err := svc.FinalizeStale(context.Background()) + require.NoError(t, err) + require.Zero(t, result.RunsFinalized) + require.Zero(t, result.StepsFinalized) +} + +func TestService_FinalizeStale(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + db, _ := dbtestutil.NewDB(t) + _, owner, chat, model := seedChat(ctx, t, db) + require.NotEqual(t, uuid.Nil, owner.ID) + + staleTime := time.Now().Add(-10 * time.Minute).UTC().Round(time.Microsecond) + run, err := db.InsertChatDebugRun(ctx, database.InsertChatDebugRunParams{ + ChatID: chat.ID, + ModelConfigID: uuid.NullUUID{UUID: model.ID, Valid: true}, + Kind: string(chatdebug.KindChatTurn), + Status: string(chatdebug.StatusInProgress), + StartedAt: sql.NullTime{Time: staleTime, Valid: true}, + UpdatedAt: sql.NullTime{Time: staleTime, Valid: true}, + }) + require.NoError(t, err) + step, err := db.InsertChatDebugStep(ctx, database.InsertChatDebugStepParams{ + RunID: run.ID, + StepNumber: 1, + Operation: string(chatdebug.OperationStream), + Status: string(chatdebug.StatusInProgress), + StartedAt: sql.NullTime{Time: staleTime, Valid: true}, + UpdatedAt: sql.NullTime{Time: staleTime, Valid: true}, + ChatID: chat.ID, + }) + require.NoError(t, err) + + svc := chatdebug.NewService(db, testutil.Logger(t), nil) + result, err := svc.FinalizeStale(ctx) + require.NoError(t, err) + require.EqualValues(t, 1, result.RunsFinalized) + require.EqualValues(t, 1, result.StepsFinalized) + + storedRun, err := db.GetChatDebugRunByID(ctx, run.ID) + require.NoError(t, err) + require.Equal(t, string(chatdebug.StatusInterrupted), storedRun.Status) + require.True(t, storedRun.FinishedAt.Valid) + + storedSteps, err := db.GetChatDebugStepsByRunID(ctx, run.ID) + require.NoError(t, err) + require.Len(t, storedSteps, 1) + require.Equal(t, step.ID, storedSteps[0].ID) + require.Equal(t, string(chatdebug.StatusInterrupted), storedSteps[0].Status) + require.True(t, storedSteps[0].FinishedAt.Valid) +} + +func TestService_FinalizeStale_BroadcastsFinalizeEvent(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + db, _ := dbtestutil.NewDB(t) + _, owner, chat, model := seedChat(ctx, t, db) + require.NotEqual(t, uuid.Nil, owner.ID) + + staleTime := time.Now().Add(-10 * time.Minute).UTC().Round(time.Microsecond) + run, err := db.InsertChatDebugRun(ctx, database.InsertChatDebugRunParams{ + ChatID: chat.ID, + ModelConfigID: uuid.NullUUID{UUID: model.ID, Valid: true}, + Kind: string(chatdebug.KindChatTurn), + Status: string(chatdebug.StatusInProgress), + StartedAt: sql.NullTime{Time: staleTime, Valid: true}, + UpdatedAt: sql.NullTime{Time: staleTime, Valid: true}, + }) + require.NoError(t, err) + _, err = db.InsertChatDebugStep(ctx, database.InsertChatDebugStepParams{ + RunID: run.ID, + StepNumber: 1, + Operation: string(chatdebug.OperationStream), + Status: string(chatdebug.StatusInProgress), + StartedAt: sql.NullTime{Time: staleTime, Valid: true}, + UpdatedAt: sql.NullTime{Time: staleTime, Valid: true}, + ChatID: chat.ID, + }) + require.NoError(t, err) + + memoryPubsub := dbpubsub.NewInMemory() + svc := chatdebug.NewService(db, testutil.Logger(t), memoryPubsub) + type eventResult struct { + event chatdebug.DebugEvent + err error + } + events := make(chan eventResult, 1) + cancel, err := memoryPubsub.Subscribe(chatdebug.PubsubChannel(uuid.Nil), + func(_ context.Context, message []byte) { + var event chatdebug.DebugEvent + unmarshalErr := json.Unmarshal(message, &event) + events <- eventResult{event: event, err: unmarshalErr} + }, + ) + require.NoError(t, err) + defer cancel() + + result, err := svc.FinalizeStale(ctx) + require.NoError(t, err) + require.EqualValues(t, 1, result.RunsFinalized) + require.EqualValues(t, 1, result.StepsFinalized) + + select { + case received := <-events: + require.NoError(t, received.err) + require.Equal(t, chatdebug.EventKindFinalize, received.event.Kind) + require.Equal(t, uuid.Nil, received.event.ChatID) + require.Equal(t, uuid.Nil, received.event.RunID) + require.Equal(t, uuid.Nil, received.event.StepID) + case <-time.After(testutil.WaitShort): + t.Fatal("timed out waiting for finalize event") + } +} + +func TestService_FinalizeStale_NoChangesDoesNotBroadcast(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + db, _ := dbtestutil.NewDB(t) + _, owner, chat, _ := seedChat(ctx, t, db) + require.NotEqual(t, uuid.Nil, owner.ID) + + memoryPubsub := dbpubsub.NewInMemory() + svc := chatdebug.NewService(db, testutil.Logger(t), memoryPubsub) + events := make(chan chatdebug.DebugEvent, 1) + cancel, err := memoryPubsub.Subscribe(chatdebug.PubsubChannel(uuid.Nil), + func(_ context.Context, message []byte) { + var event chatdebug.DebugEvent + if err := json.Unmarshal(message, &event); err == nil { + events <- event + } + }, + ) + require.NoError(t, err) + defer cancel() + + result, err := svc.FinalizeStale(ctx) + require.NoError(t, err) + require.EqualValues(t, 0, result.RunsFinalized) + require.EqualValues(t, 0, result.StepsFinalized) + + select { + case event := <-events: + t.Fatalf("unexpected finalize event: %+v", event) + default: + } + + _ = chat // keep seeded chat usage explicit for test readability. +} + +func TestService_PublishesEvents(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + db, _ := dbtestutil.NewDB(t) + _, owner, chat, model := seedChat(ctx, t, db) + require.NotEqual(t, uuid.Nil, owner.ID) + + memoryPubsub := dbpubsub.NewInMemory() + svc := chatdebug.NewService(db, testutil.Logger(t), memoryPubsub) + type eventResult struct { + event chatdebug.DebugEvent + err error + } + events := make(chan eventResult, 1) + cancel, err := memoryPubsub.Subscribe(chatdebug.PubsubChannel(chat.ID), + func(_ context.Context, message []byte) { + var event chatdebug.DebugEvent + unmarshalErr := json.Unmarshal(message, &event) + events <- eventResult{event: event, err: unmarshalErr} + }, + ) + require.NoError(t, err) + defer cancel() + + run, err := svc.CreateRun(ctx, chatdebug.CreateRunParams{ + ChatID: chat.ID, + ModelConfigID: model.ID, + Kind: chatdebug.KindChatTurn, + Status: chatdebug.StatusInProgress, + }) + require.NoError(t, err) + + select { + case received := <-events: + require.NoError(t, received.err) + require.Equal(t, chatdebug.EventKindRunUpdate, received.event.Kind) + require.Equal(t, chat.ID, received.event.ChatID) + require.Equal(t, run.ID, received.event.RunID) + require.Equal(t, uuid.Nil, received.event.StepID) + case <-time.After(testutil.WaitShort): + t.Fatal("timed out waiting for debug event") + } + + select { + case received := <-events: + t.Fatalf("unexpected extra event: %+v", received.event) + default: + } +} + +func newFixture(t *testing.T) testFixture { + t.Helper() + + ctx := testutil.Context(t, testutil.WaitLong) + db, _ := dbtestutil.NewDB(t) + org, owner, chat, model := seedChat(ctx, t, db) + return testFixture{ + ctx: ctx, + db: db, + svc: chatdebug.NewService(db, testutil.Logger(t), nil), + org: org, + owner: owner, + chat: chat, + model: model, + } +} + +func seedChat( + ctx context.Context, + t *testing.T, + db database.Store, +) (database.Organization, database.User, database.Chat, database.ChatModelConfig) { + t.Helper() + + org := dbgen.Organization(t, db, database.Organization{}) + owner := dbgen.User(t, db, database.User{}) + providerName := "openai" + _, err := db.InsertChatProvider(ctx, database.InsertChatProviderParams{ + Provider: providerName, + DisplayName: "OpenAI", + APIKey: "test-key", + CreatedBy: uuid.NullUUID{UUID: owner.ID, Valid: true}, + Enabled: true, + CentralApiKeyEnabled: true, + }) + require.NoError(t, err) + + model, err := db.InsertChatModelConfig(ctx, + database.InsertChatModelConfigParams{ + Provider: providerName, + Model: "model-" + uuid.NewString(), + DisplayName: "Test Model", + CreatedBy: uuid.NullUUID{UUID: owner.ID, Valid: true}, + UpdatedBy: uuid.NullUUID{UUID: owner.ID, Valid: true}, + Enabled: true, + IsDefault: true, + ContextLimit: 128000, + CompressionThreshold: 70, + Options: json.RawMessage(`{}`), + }, + ) + require.NoError(t, err) + + chat := insertChat(ctx, t, db, org.ID, owner.ID, model.ID) + return org, owner, chat, model +} + +func insertChat( + ctx context.Context, + t *testing.T, + db database.Store, + orgID uuid.UUID, + ownerID uuid.UUID, + modelID uuid.UUID, +) database.Chat { + t.Helper() + + chat, err := db.InsertChat(ctx, database.InsertChatParams{ + OrganizationID: orgID, + Status: database.ChatStatusWaiting, + ClientType: database.ChatClientTypeUi, + OwnerID: ownerID, + LastModelConfigID: modelID, + Title: "chat-" + uuid.NewString(), + }) + require.NoError(t, err) + return chat +} + +func insertMessage( + ctx context.Context, + t *testing.T, + db database.Store, + chatID uuid.UUID, + createdBy uuid.UUID, + modelID uuid.UUID, + role database.ChatMessageRole, + text string, +) database.ChatMessage { + t.Helper() + + parts, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{ + codersdk.ChatMessageText(text), + }) + require.NoError(t, err) + + messages, err := db.InsertChatMessages(ctx, database.InsertChatMessagesParams{ + ChatID: chatID, + CreatedBy: []uuid.UUID{createdBy}, + ModelConfigID: []uuid.UUID{modelID}, + Role: []database.ChatMessageRole{role}, + Content: []string{string(parts.RawMessage)}, + ContentVersion: []int16{chatprompt.CurrentContentVersion}, + Visibility: []database.ChatMessageVisibility{database.ChatMessageVisibilityBoth}, + InputTokens: []int64{0}, + OutputTokens: []int64{0}, + TotalTokens: []int64{0}, + ReasoningTokens: []int64{0}, + CacheCreationTokens: []int64{0}, + CacheReadTokens: []int64{0}, + ContextLimit: []int64{0}, + Compressed: []bool{false}, + TotalCostMicros: []int64{0}, + RuntimeMs: []int64{0}, + ProviderResponseID: []string{""}, + }) + require.NoError(t, err) + require.Len(t, messages, 1) + return messages[0] +} + +func createRun(t *testing.T, fixture testFixture) database.ChatDebugRun { + t.Helper() + + run, err := fixture.svc.CreateRun(fixture.ctx, chatdebug.CreateRunParams{ + ChatID: fixture.chat.ID, + ModelConfigID: fixture.model.ID, + Kind: chatdebug.KindChatTurn, + Status: chatdebug.StatusInProgress, + Provider: fixture.model.Provider, + Model: fixture.model.Model, + }) + require.NoError(t, err) + return run +} + +func assertRunMatches( + t *testing.T, + run database.ChatDebugRun, + chatID uuid.UUID, + rootChatID uuid.UUID, + parentChatID uuid.UUID, + modelID uuid.UUID, + triggerMessageID int64, + historyTipMessageID int64, + kind chatdebug.RunKind, + status chatdebug.Status, + provider string, + model string, + summary string, +) { + t.Helper() + + require.Equal(t, chatID, run.ChatID) + require.True(t, run.RootChatID.Valid) + require.Equal(t, rootChatID, run.RootChatID.UUID) + require.True(t, run.ParentChatID.Valid) + require.Equal(t, parentChatID, run.ParentChatID.UUID) + require.True(t, run.ModelConfigID.Valid) + require.Equal(t, modelID, run.ModelConfigID.UUID) + require.True(t, run.TriggerMessageID.Valid) + require.Equal(t, triggerMessageID, run.TriggerMessageID.Int64) + require.True(t, run.HistoryTipMessageID.Valid) + require.Equal(t, historyTipMessageID, run.HistoryTipMessageID.Int64) + require.Equal(t, string(kind), run.Kind) + require.Equal(t, string(status), run.Status) + require.True(t, run.Provider.Valid) + require.Equal(t, provider, run.Provider.String) + require.True(t, run.Model.Valid) + require.Equal(t, model, run.Model.String) + require.JSONEq(t, summary, string(run.Summary)) + require.False(t, run.StartedAt.IsZero()) + require.False(t, run.UpdatedAt.IsZero()) + require.False(t, run.FinishedAt.Valid) +} diff --git a/coderd/x/chatd/chatdebug/stubs.go b/coderd/x/chatd/chatdebug/stubs.go deleted file mode 100644 index 72dc5246d0..0000000000 --- a/coderd/x/chatd/chatdebug/stubs.go +++ /dev/null @@ -1,160 +0,0 @@ -package chatdebug - -import ( - "context" - "regexp" - "strings" - "sync" - "sync/atomic" - "time" - "unicode/utf8" - - "github.com/google/uuid" - - "cdr.dev/slog/v3" - "github.com/coder/coder/v2/coderd/database" - "github.com/coder/coder/v2/coderd/database/pubsub" -) - -// This compatibility shim forward-declares service and summary symbols -// that land in later stacked branches. Delete this file once service.go -// and summary.go are available here. - -// Service is a placeholder for the later chat debug persistence service. -type Service struct { - log slog.Logger -} - -// CreateStepParams identifies the data recorded when a debug step starts. -type CreateStepParams struct { - RunID uuid.UUID - ChatID uuid.UUID - StepNumber int32 - Operation Operation - Status Status - HistoryTipMessageID int64 - NormalizedRequest any -} - -// UpdateStepParams identifies the data recorded when a debug step finishes. -type UpdateStepParams struct { - ID uuid.UUID - ChatID uuid.UUID - Status Status - NormalizedResponse any - Usage any - Attempts []Attempt - Error any - Metadata any - FinishedAt time.Time -} - -// NewService constructs the placeholder chat debug service. -func NewService(_ database.Store, log slog.Logger, _ pubsub.Pubsub) *Service { - return &Service{log: log} -} - -// IsEnabled reports whether debug recording is enabled for a chat owner. -func (*Service) IsEnabled(context.Context, uuid.UUID, uuid.UUID) bool { - return true -} - -// CreateStep synthesizes a debug step so recorder tests can exercise the -// wrapper without requiring the later persistence service implementation. -func (*Service) CreateStep( - _ context.Context, - params CreateStepParams, -) (database.ChatDebugStep, error) { - return database.ChatDebugStep{ - ID: uuid.New(), - RunID: params.RunID, - ChatID: params.ChatID, - StepNumber: params.StepNumber, - Operation: string(params.Operation), - Status: string(params.Status), - }, nil -} - -// UpdateStep accepts final step state once recording completes. -func (*Service) UpdateStep( - _ context.Context, - params UpdateStepParams, -) (database.ChatDebugStep, error) { - return database.ChatDebugStep{ - ID: params.ID, - ChatID: params.ChatID, - Status: string(params.Status), - }, nil -} - -// runRefCounts tracks how many live RunContext instances reference each -// RunID. Cleanup of shared state (step counters) is deferred until the -// last RunContext for a given RunID is garbage collected. -var ( - runRefCounts sync.Map // map[uuid.UUID]*atomic.Int32 - // refCountMu serializes trackRunRef and releaseRunRef so the - // decrement-to-zero check and subsequent map deletions are - // atomic with respect to new references being added. - refCountMu sync.Mutex -) - -func trackRunRef(runID uuid.UUID) { - refCountMu.Lock() - defer refCountMu.Unlock() - val, _ := runRefCounts.LoadOrStore(runID, &atomic.Int32{}) - counter, ok := val.(*atomic.Int32) - if !ok { - panic("chatdebug: runRefCounts contains non-*atomic.Int32 value") - } - counter.Add(1) -} - -// releaseRunRef decrements the reference count for runID and cleans up -// shared state when the last reference is released. The mutex ensures -// no concurrent trackRunRef can increment between the zero check and -// the map deletions. -func releaseRunRef(runID uuid.UUID) { - refCountMu.Lock() - defer refCountMu.Unlock() - val, ok := runRefCounts.Load(runID) - if !ok { - return - } - counter, ok := val.(*atomic.Int32) - if !ok { - panic("chatdebug: runRefCounts contains non-*atomic.Int32 value") - } - if counter.Add(-1) <= 0 { - runRefCounts.Delete(runID) - stepCounters.Delete(runID) - } -} - -// whitespaceRun matches one or more consecutive whitespace characters. -var whitespaceRun = regexp.MustCompile(`\s+`) - -// truncateRunes truncates s to maxLen runes, appending an ellipsis -// when truncation occurs. Returns "" when maxLen <= 0. -func truncateRunes(s string, maxLen int) string { - if maxLen <= 0 { - return "" - } - if utf8.RuneCountInString(s) <= maxLen { - return s - } - if maxLen == 1 { - return "…" - } - runes := []rune(s) - return string(runes[:maxLen-1]) + "…" -} - -// TruncateLabel whitespace-normalizes and truncates text to maxLen runes. -// Returns "" if input is empty or whitespace-only. -func TruncateLabel(text string, maxLen int) string { - normalized := strings.TrimSpace(whitespaceRun.ReplaceAllString(text, " ")) - if normalized == "" { - return "" - } - return truncateRunes(normalized, maxLen) -} diff --git a/coderd/x/chatd/chatdebug/stubs_internal_test.go b/coderd/x/chatd/chatdebug/stubs_internal_test.go index 75d0aabdd4..ebef8e22a6 100644 --- a/coderd/x/chatd/chatdebug/stubs_internal_test.go +++ b/coderd/x/chatd/chatdebug/stubs_internal_test.go @@ -3,7 +3,6 @@ package chatdebug import ( "context" "testing" - "unicode/utf8" "github.com/google/uuid" "github.com/stretchr/testify/require" @@ -17,34 +16,3 @@ func TestBeginStep_SkipsNilRunID(t *testing.T) { require.Nil(t, handle) require.Equal(t, ctx, enriched) } - -func TestTruncateLabel(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - input string - maxLen int - want string - }{ - {name: "Empty", input: "", maxLen: 10, want: ""}, - {name: "WhitespaceOnly", input: " \t\n ", maxLen: 10, want: ""}, - {name: "ShortText", input: "hello world", maxLen: 20, want: "hello world"}, - {name: "ExactLength", input: "abcde", maxLen: 5, want: "abcde"}, - {name: "LongTextTruncated", input: "abcdefghij", maxLen: 5, want: "abcd…"}, - {name: "NegativeMaxLen", input: "hello", maxLen: -1, want: ""}, - {name: "ZeroMaxLen", input: "hello", maxLen: 0, want: ""}, - {name: "SingleRuneLimit", input: "hello", maxLen: 1, want: "…"}, - {name: "MultipleWhitespaceRuns", input: " hello world \t again ", maxLen: 100, want: "hello world again"}, - {name: "UnicodeRunes", input: "こんにちは世界", maxLen: 3, want: "こん…"}, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - got := TruncateLabel(tc.input, tc.maxLen) - require.Equal(t, tc.want, got) - require.LessOrEqual(t, utf8.RuneCountInString(got), max(tc.maxLen, 0)) - }) - } -} diff --git a/coderd/x/chatd/chatdebug/summary.go b/coderd/x/chatd/chatdebug/summary.go new file mode 100644 index 0000000000..9b193dfd93 --- /dev/null +++ b/coderd/x/chatd/chatdebug/summary.go @@ -0,0 +1,210 @@ +package chatdebug + +import ( + "bytes" + "context" + "encoding/json" + "regexp" + "strings" + + "charm.land/fantasy" + "github.com/google/uuid" + + "cdr.dev/slog/v3" + "github.com/coder/coder/v2/coderd/database" + stringutil "github.com/coder/coder/v2/coderd/util/strings" +) + +// whitespaceRun matches one or more consecutive whitespace characters. +var whitespaceRun = regexp.MustCompile(`\s+`) + +// TruncateLabel whitespace-normalizes and truncates text to maxLen runes. +// Returns "" if input is empty or whitespace-only. +func TruncateLabel(text string, maxLen int) string { + normalized := strings.TrimSpace(whitespaceRun.ReplaceAllString(text, " ")) + if normalized == "" { + return "" + } + return stringutil.Truncate(normalized, maxLen, stringutil.TruncateWithEllipsis) +} + +// SeedSummary builds a base summary map with a first_message label. +// Returns nil if label is empty. +func SeedSummary(label string) map[string]any { + if label == "" { + return nil + } + return map[string]any{"first_message": label} +} + +// ExtractFirstUserText extracts the plain text content from a +// fantasy.Prompt for the first user message. Used to derive +// first_message labels at run creation time. +func ExtractFirstUserText(prompt fantasy.Prompt) string { + for _, msg := range prompt { + if msg.Role != fantasy.MessageRoleUser { + continue + } + + var sb strings.Builder + for _, part := range msg.Content { + tp, ok := fantasy.AsMessagePart[fantasy.TextPart](part) + if !ok { + continue + } + _, _ = sb.WriteString(tp.Text) + } + return sb.String() + } + return "" +} + +// AggregateRunSummary reads all steps for the given run, computes token +// totals, and merges them with the run's existing summary (preserving any +// seeded first_message label). The baseSummary parameter should be the +// current run summary (may be nil). +func (s *Service) AggregateRunSummary( + ctx context.Context, + runID uuid.UUID, + baseSummary map[string]any, +) (map[string]any, error) { + if runID == uuid.Nil { + return baseSummary, nil + } + + steps, err := s.db.GetChatDebugStepsByRunID(chatdContext(ctx), runID) + if err != nil { + return nil, err + } + + // Start from a shallow copy of baseSummary to avoid mutating the + // caller's map. + // Capacity hint: baseSummary entries plus 8 derived keys + // (step_count, total_input_tokens, total_output_tokens, + // total_reasoning_tokens, total_cache_creation_tokens, + // total_cache_read_tokens, has_error, endpoint_label). + result := make(map[string]any, len(baseSummary)+8) + for k, v := range baseSummary { + result[k] = v + } + + // Clear derived fields before recomputing them so stale values from a + // previous aggregation do not survive when the new totals are zero or + // the endpoint label is unavailable. + for _, key := range []string{ + "step_count", + "total_input_tokens", + "total_output_tokens", + "total_reasoning_tokens", + "total_cache_creation_tokens", + "total_cache_read_tokens", + "endpoint_label", + "has_error", + } { + delete(result, key) + } + var ( + totalInput int64 + totalOutput int64 + totalReasoning int64 + totalCacheCreation int64 + totalCacheRead int64 + hasError bool + ) + + for _, step := range steps { + // Flag runs that hit a real error. Interrupted steps represent + // user-initiated cancellation (e.g. clicking Stop) and should + // not trigger the error indicator in the debug panel. + // A JSONB null (used by jsonClear to erase a prior error) is + // Valid but carries no meaningful content, so exclude it. + errorIsReal := step.Error.Valid && + len(step.Error.RawMessage) > 0 && + !bytes.Equal(step.Error.RawMessage, []byte("null")) + if step.Status == string(StatusError) || + (errorIsReal && step.Status != string(StatusInterrupted)) { + hasError = true + } + if !step.Usage.Valid || len(step.Usage.RawMessage) == 0 { + continue + } + + var usage fantasy.Usage + if err := json.Unmarshal(step.Usage.RawMessage, &usage); err != nil { + s.log.Warn(ctx, "skipping malformed step usage JSON", + slog.Error(err), + slog.F("run_id", runID), + slog.F("step_id", step.ID), + ) + continue + } + + totalInput += usage.InputTokens + totalOutput += usage.OutputTokens + totalReasoning += usage.ReasoningTokens + totalCacheCreation += usage.CacheCreationTokens + totalCacheRead += usage.CacheReadTokens + } + + result["step_count"] = len(steps) + result["total_input_tokens"] = totalInput + result["total_output_tokens"] = totalOutput + + // Only include reasoning/cache fields when non-zero to keep the + // summary compact for the common case. + if totalReasoning > 0 { + result["total_reasoning_tokens"] = totalReasoning + } + if totalCacheCreation > 0 { + result["total_cache_creation_tokens"] = totalCacheCreation + } + if totalCacheRead > 0 { + result["total_cache_read_tokens"] = totalCacheRead + } + + if hasError { + result["has_error"] = true + } + + // Derive endpoint_label from the first completed attempt's path + // across all steps. This gives the debug panel a meaningful + // identifier like "POST /v1/messages" for the run row. + if label := extractEndpointLabel(steps); label != "" { + result["endpoint_label"] = label + } + + return result, nil +} + +// attemptLabel is a minimal projection of Attempt used by +// extractEndpointLabel to avoid deserializing large RequestBody and +// ResponseBody fields that are not needed for label derivation. +type attemptLabel struct { + Status string `json:"status,omitempty"` + Method string `json:"method,omitempty"` + Path string `json:"path,omitempty"` +} + +// extractEndpointLabel scans steps for the first completed attempt with a +// non-empty path and returns "METHOD /path" (or just "/path"). +func extractEndpointLabel(steps []database.ChatDebugStep) string { + for _, step := range steps { + if len(step.Attempts) == 0 { + continue + } + var attempts []attemptLabel + if err := json.Unmarshal(step.Attempts, &attempts); err != nil { + continue + } + for _, a := range attempts { + if a.Status != attemptStatusCompleted || a.Path == "" { + continue + } + if a.Method != "" { + return a.Method + " " + a.Path + } + return a.Path + } + } + return "" +} diff --git a/coderd/x/chatd/chatdebug/summary_test.go b/coderd/x/chatd/chatdebug/summary_test.go new file mode 100644 index 0000000000..3c41877cd2 --- /dev/null +++ b/coderd/x/chatd/chatdebug/summary_test.go @@ -0,0 +1,516 @@ +package chatdebug_test + +import ( + "encoding/json" + "testing" + "time" + "unicode/utf8" + + "charm.land/fantasy" + "github.com/google/uuid" + "github.com/sqlc-dev/pqtype" + "github.com/stretchr/testify/require" + + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/x/chatd/chatdebug" +) + +func TestTruncateLabel(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + input string + maxLen int + want string + }{ + {name: "Empty", input: "", maxLen: 10, want: ""}, + {name: "WhitespaceOnly", input: " \t\n ", maxLen: 10, want: ""}, + {name: "ShortText", input: "hello world", maxLen: 20, want: "hello world"}, + {name: "ExactLength", input: "abcde", maxLen: 5, want: "abcde"}, + {name: "LongTextTruncated", input: "abcdefghij", maxLen: 5, want: "abcd…"}, + {name: "NegativeMaxLen", input: "hello", maxLen: -1, want: ""}, + {name: "ZeroMaxLen", input: "hello", maxLen: 0, want: ""}, + {name: "SingleRuneLimit", input: "hello", maxLen: 1, want: "…"}, + {name: "MultipleWhitespaceRuns", input: " hello world \t again ", maxLen: 100, want: "hello world again"}, + {name: "UnicodeRunes", input: "こんにちは世界", maxLen: 3, want: "こん…"}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + got := chatdebug.TruncateLabel(tc.input, tc.maxLen) + require.Equal(t, tc.want, got) + require.LessOrEqual(t, utf8.RuneCountInString(got), max(tc.maxLen, 0)) + }) + } +} + +func TestSeedSummary(t *testing.T) { + t.Parallel() + + t.Run("NonEmptyLabel", func(t *testing.T) { + t.Parallel() + got := chatdebug.SeedSummary("hello world") + require.Equal(t, map[string]any{"first_message": "hello world"}, got) + }) + + t.Run("EmptyLabel", func(t *testing.T) { + t.Parallel() + got := chatdebug.SeedSummary("") + require.Nil(t, got) + }) +} + +func TestExtractFirstUserText(t *testing.T) { + t.Parallel() + + t.Run("EmptyPrompt", func(t *testing.T) { + t.Parallel() + got := chatdebug.ExtractFirstUserText(fantasy.Prompt{}) + require.Equal(t, "", got) + }) + + t.Run("NoUserMessages", func(t *testing.T) { + t.Parallel() + prompt := fantasy.Prompt{ + { + Role: fantasy.MessageRoleSystem, + Content: []fantasy.MessagePart{fantasy.TextPart{Text: "system"}}, + }, + { + Role: fantasy.MessageRoleAssistant, + Content: []fantasy.MessagePart{fantasy.TextPart{Text: "assistant"}}, + }, + } + got := chatdebug.ExtractFirstUserText(prompt) + require.Equal(t, "", got) + }) + + t.Run("FirstUserMessageMixedParts", func(t *testing.T) { + t.Parallel() + prompt := fantasy.Prompt{ + { + Role: fantasy.MessageRoleUser, + Content: []fantasy.MessagePart{ + fantasy.TextPart{Text: "hello "}, + fantasy.FilePart{Filename: "test.png"}, + fantasy.TextPart{Text: "world"}, + }, + }, + } + got := chatdebug.ExtractFirstUserText(prompt) + require.Equal(t, "hello world", got) + }) + + t.Run("MultipleUserMessagesReturnsFirst", func(t *testing.T) { + t.Parallel() + prompt := fantasy.Prompt{ + { + Role: fantasy.MessageRoleSystem, + Content: []fantasy.MessagePart{fantasy.TextPart{Text: "system"}}, + }, + { + Role: fantasy.MessageRoleUser, + Content: []fantasy.MessagePart{fantasy.TextPart{Text: "first"}}, + }, + { + Role: fantasy.MessageRoleUser, + Content: []fantasy.MessagePart{fantasy.TextPart{Text: "second"}}, + }, + } + got := chatdebug.ExtractFirstUserText(prompt) + require.Equal(t, "first", got) + }) +} + +func TestService_AggregateRunSummary(t *testing.T) { + t.Parallel() + + t.Run("NilRunID", func(t *testing.T) { + t.Parallel() + fixture := newFixture(t) + got, err := fixture.svc.AggregateRunSummary(fixture.ctx, uuid.Nil, nil) + require.NoError(t, err) + require.Nil(t, got) + }) + + t.Run("ZeroSteps", func(t *testing.T) { + t.Parallel() + fixture := newFixture(t) + run := createRun(t, fixture) + + // No steps created. Call with a base summary containing + // first_message so we can verify it is preserved. + base := map[string]any{"first_message": "hello world"} + got, err := fixture.svc.AggregateRunSummary(fixture.ctx, run.ID, base) + require.NoError(t, err) + require.Equal(t, "hello world", got["first_message"]) + require.EqualValues(t, 0, got["step_count"]) + require.EqualValues(t, int64(0), got["total_input_tokens"]) + require.EqualValues(t, int64(0), got["total_output_tokens"]) + require.NotContains(t, got, "total_reasoning_tokens") + require.NotContains(t, got, "total_cache_creation_tokens") + require.NotContains(t, got, "total_cache_read_tokens") + require.NotContains(t, got, "has_error") + require.NotContains(t, got, "endpoint_label") + }) + + t.Run("NilBaseSummary", func(t *testing.T) { + t.Parallel() + fixture := newFixture(t) + run := createRun(t, fixture) + + // Create a step with usage. + step := createTestStep(t, fixture, run.ID) + updateTestStepWithUsage(t, fixture, step.ID, 10, 5, 0, 0) + + got, err := fixture.svc.AggregateRunSummary(fixture.ctx, run.ID, nil) + require.NoError(t, err) + require.NotNil(t, got) + require.EqualValues(t, 1, got["step_count"]) + require.EqualValues(t, int64(10), got["total_input_tokens"]) + require.EqualValues(t, int64(5), got["total_output_tokens"]) + }) + + t.Run("PreservesFirstMessage", func(t *testing.T) { + t.Parallel() + fixture := newFixture(t) + run := createRun(t, fixture) + + step := createTestStep(t, fixture, run.ID) + updateTestStepWithUsage(t, fixture, step.ID, 20, 10, 0, 0) + + base := map[string]any{"first_message": "hello world"} + got, err := fixture.svc.AggregateRunSummary(fixture.ctx, run.ID, base) + require.NoError(t, err) + require.Equal(t, "hello world", got["first_message"]) + require.EqualValues(t, 1, got["step_count"]) + require.EqualValues(t, int64(20), got["total_input_tokens"]) + require.EqualValues(t, int64(10), got["total_output_tokens"]) + }) + + t.Run("ClearsStaleDerivedFields", func(t *testing.T) { + t.Parallel() + fixture := newFixture(t) + run := createRun(t, fixture) + + step := createTestStep(t, fixture, run.ID) + updateTestStepWithUsage(t, fixture, step.ID, 10, 5, 0, 0) + + base := map[string]any{ + "first_message": "hello world", + "step_count": 9, + "total_input_tokens": 999, + "total_output_tokens": 888, + "total_reasoning_tokens": 777, + "total_cache_creation_tokens": 100, + "total_cache_read_tokens": 200, + "has_error": true, + "endpoint_label": "POST /stale", + } + + got, err := fixture.svc.AggregateRunSummary(fixture.ctx, run.ID, base) + require.NoError(t, err) + require.Equal(t, "hello world", got["first_message"]) + require.EqualValues(t, 1, got["step_count"]) + require.EqualValues(t, int64(10), got["total_input_tokens"]) + require.EqualValues(t, int64(5), got["total_output_tokens"]) + // Stale reasoning tokens must be cleared because the step + // has zero reasoning tokens. + require.NotContains(t, got, "total_reasoning_tokens") + require.NotContains(t, got, "total_cache_creation_tokens") + require.NotContains(t, got, "total_cache_read_tokens") + // has_error must be cleared because the step is not in error + // status and has no error payload. + require.NotContains(t, got, "has_error") + require.NotContains(t, got, "endpoint_label") + }) + + t.Run("RecomputesHasErrorAndCompletedEndpointLabel", func(t *testing.T) { + t.Parallel() + fixture := newFixture(t) + run := createRun(t, fixture) + + step1 := createTestStep(t, fixture, run.ID) + _, err := fixture.svc.UpdateStep(fixture.ctx, chatdebug.UpdateStepParams{ + ID: step1.ID, + ChatID: fixture.chat.ID, + Status: chatdebug.StatusError, + Attempts: []chatdebug.Attempt{{ + Number: 1, + Status: "failed", + Method: "POST", + Path: "/failed", + }}, + }) + require.NoError(t, err) + + step2 := createTestStepN(t, fixture, run.ID, 2) + _, err = fixture.svc.UpdateStep(fixture.ctx, chatdebug.UpdateStepParams{ + ID: step2.ID, + ChatID: fixture.chat.ID, + Status: chatdebug.StatusCompleted, + Attempts: []chatdebug.Attempt{{ + Number: 1, + Status: "completed", + Method: "POST", + Path: "/v1/messages", + }}, + }) + require.NoError(t, err) + + got, err := fixture.svc.AggregateRunSummary(fixture.ctx, run.ID, nil) + require.NoError(t, err) + require.Equal(t, true, got["has_error"]) + require.Equal(t, "POST /v1/messages", got["endpoint_label"]) + }) + + t.Run("EndpointLabelPathOnlyWhenMethodEmpty", func(t *testing.T) { + t.Parallel() + fixture := newFixture(t) + run := createRun(t, fixture) + + step := createTestStep(t, fixture, run.ID) + _, err := fixture.svc.UpdateStep(fixture.ctx, chatdebug.UpdateStepParams{ + ID: step.ID, + ChatID: fixture.chat.ID, + Status: chatdebug.StatusCompleted, + Attempts: []chatdebug.Attempt{{ + Number: 1, + Status: "completed", + Method: "", + Path: "/v1/messages", + }}, + }) + require.NoError(t, err) + + got, err := fixture.svc.AggregateRunSummary(fixture.ctx, run.ID, nil) + require.NoError(t, err) + require.Equal(t, "/v1/messages", got["endpoint_label"], + "endpoint_label should be path-only when method is empty") + }) + + t.Run("InterruptedStepWithErrorExcludedFromHasError", func(t *testing.T) { + t.Parallel() + fixture := newFixture(t) + run := createRun(t, fixture) + + // An interrupted step with a real error payload should NOT + // trigger has_error. Interrupted means user-initiated + // cancellation (e.g. clicking Stop). + step := createTestStep(t, fixture, run.ID) + _, err := fixture.svc.UpdateStep(fixture.ctx, chatdebug.UpdateStepParams{ + ID: step.ID, + ChatID: fixture.chat.ID, + Status: chatdebug.StatusInterrupted, + Error: map[string]any{"message": "user canceled"}, + }) + require.NoError(t, err) + + got, err := fixture.svc.AggregateRunSummary(fixture.ctx, run.ID, nil) + require.NoError(t, err) + require.NotContains(t, got, "has_error", + "interrupted steps should not trigger has_error even with error payload") + }) + + t.Run("MultipleStepsSumTokens", func(t *testing.T) { + t.Parallel() + fixture := newFixture(t) + run := createRun(t, fixture) + + step1 := createTestStep(t, fixture, run.ID) + updateTestStepWithUsage(t, fixture, step1.ID, 10, 5, 2, 3) + + step2 := createTestStepN(t, fixture, run.ID, 2) + updateTestStepWithUsage(t, fixture, step2.ID, 15, 7, 1, 4) + + got, err := fixture.svc.AggregateRunSummary(fixture.ctx, run.ID, nil) + require.NoError(t, err) + require.EqualValues(t, 2, got["step_count"]) + require.EqualValues(t, int64(25), got["total_input_tokens"]) + require.EqualValues(t, int64(12), got["total_output_tokens"]) + require.EqualValues(t, int64(3), got["total_cache_creation_tokens"]) + require.EqualValues(t, int64(7), got["total_cache_read_tokens"]) + }) + + t.Run("StepWithNilUsageContributesZeroTokens", func(t *testing.T) { + t.Parallel() + fixture := newFixture(t) + run := createRun(t, fixture) + + // Step with usage. + step1 := createTestStep(t, fixture, run.ID) + updateTestStepWithUsage(t, fixture, step1.ID, 10, 5, 0, 0) + + // Step without usage (just complete it, no usage). + step2 := createTestStepN(t, fixture, run.ID, 2) + _, err := fixture.svc.UpdateStep(fixture.ctx, chatdebug.UpdateStepParams{ + ID: step2.ID, + ChatID: fixture.chat.ID, + Status: chatdebug.StatusCompleted, + }) + require.NoError(t, err) + + got, err := fixture.svc.AggregateRunSummary(fixture.ctx, run.ID, nil) + require.NoError(t, err) + // Both steps are counted even though one has no usage. + require.EqualValues(t, 2, got["step_count"]) + require.EqualValues(t, int64(10), got["total_input_tokens"]) + require.EqualValues(t, int64(5), got["total_output_tokens"]) + }) + + t.Run("ZeroCacheTotalsOmitCacheFields", func(t *testing.T) { + t.Parallel() + fixture := newFixture(t) + run := createRun(t, fixture) + + step := createTestStep(t, fixture, run.ID) + updateTestStepWithUsage(t, fixture, step.ID, 10, 5, 0, 0) + + got, err := fixture.svc.AggregateRunSummary(fixture.ctx, run.ID, nil) + require.NoError(t, err) + _, hasCacheCreation := got["total_cache_creation_tokens"] + _, hasCacheRead := got["total_cache_read_tokens"] + require.False(t, hasCacheCreation, + "cache creation tokens should be omitted when zero") + require.False(t, hasCacheRead, + "cache read tokens should be omitted when zero") + }) + + t.Run("ReasoningTokensSummedAcrossSteps", func(t *testing.T) { + t.Parallel() + fixture := newFixture(t) + run := createRun(t, fixture) + + step1 := createTestStep(t, fixture, run.ID) + updateTestStepWithFullUsage(t, fixture, step1.ID, 10, 5, 20, 0, 0) + + step2 := createTestStepN(t, fixture, run.ID, 2) + updateTestStepWithFullUsage(t, fixture, step2.ID, 15, 7, 30, 0, 0) + + got, err := fixture.svc.AggregateRunSummary(fixture.ctx, run.ID, nil) + require.NoError(t, err) + require.EqualValues(t, 2, got["step_count"]) + require.EqualValues(t, int64(25), got["total_input_tokens"]) + require.EqualValues(t, int64(12), got["total_output_tokens"]) + require.EqualValues(t, int64(50), got["total_reasoning_tokens"], + "reasoning tokens should be summed across steps") + }) + + t.Run("ZeroReasoningTokensOmitsField", func(t *testing.T) { + t.Parallel() + fixture := newFixture(t) + run := createRun(t, fixture) + + step := createTestStep(t, fixture, run.ID) + updateTestStepWithFullUsage(t, fixture, step.ID, 10, 5, 0, 0, 0) + + got, err := fixture.svc.AggregateRunSummary(fixture.ctx, run.ID, nil) + require.NoError(t, err) + _, hasReasoning := got["total_reasoning_tokens"] + require.False(t, hasReasoning, + "reasoning tokens should be omitted when zero") + }) + + t.Run("MalformedUsageJSONSkipped", func(t *testing.T) { + t.Parallel() + fixture := newFixture(t) + run := createRun(t, fixture) + + // Step 1 has valid usage and should contribute to totals. + step1 := createTestStep(t, fixture, run.ID) + updateTestStepWithUsage(t, fixture, step1.ID, 10, 5, 0, 0) + + // Step 2 is stamped with structurally-valid JSONB that cannot + // unmarshal into fantasy.Usage (string where int64 is + // expected). Write directly through the store so the jsonb + // cast succeeds while the Go unmarshal fails, exercising the + // "skipping malformed step usage JSON" log-and-continue path. + step2 := createTestStepN(t, fixture, run.ID, 2) + _, err := fixture.db.UpdateChatDebugStep(fixture.ctx, database.UpdateChatDebugStepParams{ + ID: step2.ID, + ChatID: fixture.chat.ID, + Usage: pqtype.NullRawMessage{ + RawMessage: json.RawMessage(`{"input_tokens":"not-a-number"}`), + Valid: true, + }, + Now: time.Now(), + }) + require.NoError(t, err) + + got, err := fixture.svc.AggregateRunSummary(fixture.ctx, run.ID, nil) + require.NoError(t, err, + "malformed usage JSON must be skipped, not surfaced as an error") + + // Both steps are counted, but only step1's tokens contribute. + require.EqualValues(t, 2, got["step_count"]) + require.EqualValues(t, int64(10), got["total_input_tokens"]) + require.EqualValues(t, int64(5), got["total_output_tokens"]) + }) +} + +// createTestStep is a thin helper that creates a debug step with +// step number 1 for the given run. +func createTestStep( + t *testing.T, + fixture testFixture, + runID uuid.UUID, +) database.ChatDebugStep { + t.Helper() + return createTestStepN(t, fixture, runID, 1) +} + +// createTestStepN creates a debug step with the given step number. +func createTestStepN( + t *testing.T, + fixture testFixture, + runID uuid.UUID, + stepNumber int32, +) database.ChatDebugStep { + t.Helper() + step, err := fixture.svc.CreateStep(fixture.ctx, chatdebug.CreateStepParams{ + RunID: runID, + ChatID: fixture.chat.ID, + StepNumber: stepNumber, + Operation: chatdebug.OperationGenerate, + Status: chatdebug.StatusInProgress, + }) + require.NoError(t, err) + return step +} + +// updateTestStepWithUsage completes a step and sets token usage fields. +func updateTestStepWithUsage( + t *testing.T, + fixture testFixture, + stepID uuid.UUID, + input, output, cacheCreation, cacheRead int64, +) { + t.Helper() + updateTestStepWithFullUsage(t, fixture, stepID, input, output, 0, cacheCreation, cacheRead) +} + +// updateTestStepWithFullUsage completes a step with all token usage +// fields, including reasoning tokens. +func updateTestStepWithFullUsage( + t *testing.T, + fixture testFixture, + stepID uuid.UUID, + input, output, reasoning, cacheCreation, cacheRead int64, +) { + t.Helper() + _, err := fixture.svc.UpdateStep(fixture.ctx, chatdebug.UpdateStepParams{ + ID: stepID, + ChatID: fixture.chat.ID, + Status: chatdebug.StatusCompleted, + Usage: map[string]any{ + "input_tokens": input, + "output_tokens": output, + "reasoning_tokens": reasoning, + "cache_creation_tokens": cacheCreation, + "cache_read_tokens": cacheRead, + }, + }) + require.NoError(t, err) +} diff --git a/coderd/x/chatd/chatdebug/types.go b/coderd/x/chatd/chatdebug/types.go index 115b70b2d5..0d744be26f 100644 --- a/coderd/x/chatd/chatdebug/types.go +++ b/coderd/x/chatd/chatdebug/types.go @@ -39,6 +39,29 @@ const ( StatusInterrupted Status = "interrupted" ) +// IsTerminal reports whether the status represents a final state +// that should not be overwritten by stale callbacks. +func (s Status) IsTerminal() bool { + return s.Priority() > 0 +} + +// Priority returns a numeric ordering used to prevent stale callbacks +// from regressing a step's status. Higher values win over lower ones. +func (s Status) Priority() int { + switch s { + case StatusInProgress: + return 0 + case StatusInterrupted: + return 1 + case StatusError: + return 2 + case StatusCompleted: + return 3 + default: + return 0 + } +} + // AllStatuses contains every Status value. Update this when // adding new constants above. var AllStatuses = []Status{ @@ -131,7 +154,16 @@ type DebugEvent struct { StepID uuid.UUID `json:"step_id"` } +// BroadcastPubsubChannel is the shared pubsub channel for chat-debug events +// that are not scoped to a single chat, such as stale finalization sweeps. +const BroadcastPubsubChannel = "chat_debug:broadcast" + // PubsubChannel returns the chat-scoped pubsub channel for debug events. +// Nil chat IDs use the shared broadcast channel so publishers and subscribers +// can coordinate through one discoverable helper. func PubsubChannel(chatID uuid.UUID) string { + if chatID == uuid.Nil { + return BroadcastPubsubChannel + } return "chat_debug:" + chatID.String() }