diff --git a/coderd/coderd.go b/coderd/coderd.go index e5314ef5ee..e61e5fe8c5 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -1234,6 +1234,7 @@ func New(options *Options) *API { r.Get("/git", api.watchChatGit) }) r.Post("/interrupt", api.interruptChat) + r.Post("/title/regenerate", api.regenerateChatTitle) r.Get("/diff", api.getChatDiffContents) r.Route("/queue/{queuedMessage}", func(r chi.Router) { r.Delete("/", api.deleteChatQueuedMessage) diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index ba86137ab8..fba344a914 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -2614,6 +2614,14 @@ func (q *querier) GetChatMessagesByChatID(ctx context.Context, arg database.GetC return q.db.GetChatMessagesByChatID(ctx, arg) } +func (q *querier) GetChatMessagesByChatIDAscPaginated(ctx context.Context, arg database.GetChatMessagesByChatIDAscPaginatedParams) ([]database.ChatMessage, error) { + _, err := q.GetChatByID(ctx, arg.ChatID) + if err != nil { + return nil, err + } + return q.db.GetChatMessagesByChatIDAscPaginated(ctx, arg) +} + func (q *querier) GetChatMessagesByChatIDDescPaginated(ctx context.Context, arg database.GetChatMessagesByChatIDDescPaginatedParams) ([]database.ChatMessage, error) { _, err := q.GetChatByID(ctx, arg.ChatID) if err != nil { @@ -5736,6 +5744,17 @@ func (q *querier) UpdateChatLabelsByID(ctx context.Context, arg database.UpdateC return q.db.UpdateChatLabelsByID(ctx, arg) } +func (q *querier) UpdateChatLastModelConfigByID(ctx context.Context, arg database.UpdateChatLastModelConfigByIDParams) (database.Chat, error) { + chat, err := q.db.GetChatByID(ctx, arg.ID) + if err != nil { + return database.Chat{}, err + } + if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil { + return database.Chat{}, err + } + return q.db.UpdateChatLastModelConfigByID(ctx, arg) +} + func (q *querier) UpdateChatMCPServerIDs(ctx context.Context, arg database.UpdateChatMCPServerIDsParams) (database.Chat, error) { chat, err := q.db.GetChatByID(ctx, arg.ID) if err != nil { @@ -5801,6 +5820,17 @@ func (q *querier) UpdateChatStatus(ctx context.Context, arg database.UpdateChatS return q.db.UpdateChatStatus(ctx, arg) } +func (q *querier) UpdateChatStatusPreserveUpdatedAt(ctx context.Context, arg database.UpdateChatStatusPreserveUpdatedAtParams) (database.Chat, error) { + chat, err := q.db.GetChatByID(ctx, arg.ID) + if err != nil { + return database.Chat{}, err + } + if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil { + return database.Chat{}, err + } + return q.db.UpdateChatStatusPreserveUpdatedAt(ctx, arg) +} + func (q *querier) UpdateChatWorkspaceBinding(ctx context.Context, arg database.UpdateChatWorkspaceBindingParams) (database.Chat, error) { chat, err := q.db.GetChatByID(ctx, arg.ID) if err != nil { diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go index 42e81e3e2e..f0cf24c9c2 100644 --- a/coderd/database/dbauthz/dbauthz_test.go +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -592,6 +592,14 @@ func (s *MethodTestSuite) TestChats() { dbm.EXPECT().GetChatMessagesByChatID(gomock.Any(), arg).Return(msgs, nil).AnyTimes() check.Args(arg).Asserts(chat, policy.ActionRead).Returns(msgs) })) + s.Run("GetChatMessagesByChatIDAscPaginated", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + chat := testutil.Fake(s.T(), faker, database.Chat{}) + msgs := []database.ChatMessage{testutil.Fake(s.T(), faker, database.ChatMessage{ChatID: chat.ID})} + arg := database.GetChatMessagesByChatIDAscPaginatedParams{ChatID: chat.ID, AfterID: 0, LimitVal: 50} + dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes() + dbm.EXPECT().GetChatMessagesByChatIDAscPaginated(gomock.Any(), arg).Return(msgs, nil).AnyTimes() + check.Args(arg).Asserts(chat, policy.ActionRead).Returns(msgs) + })) s.Run("GetChatMessagesByChatIDDescPaginated", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { chat := testutil.Fake(s.T(), faker, database.Chat{}) msgs := []database.ChatMessage{testutil.Fake(s.T(), faker, database.ChatMessage{ChatID: chat.ID})} @@ -789,6 +797,26 @@ func (s *MethodTestSuite) TestChats() { dbm.EXPECT().UpdateChatLabelsByID(gomock.Any(), arg).Return(chat, nil).AnyTimes() check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(chat) })) + s.Run("UpdateChatLastModelConfigByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + chat := testutil.Fake(s.T(), faker, database.Chat{}) + arg := database.UpdateChatLastModelConfigByIDParams{ + ID: chat.ID, + LastModelConfigID: uuid.New(), + } + dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes() + dbm.EXPECT().UpdateChatLastModelConfigByID(gomock.Any(), arg).Return(chat, nil).AnyTimes() + check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(chat) + })) + s.Run("UpdateChatStatusPreserveUpdatedAt", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + chat := testutil.Fake(s.T(), faker, database.Chat{}) + arg := database.UpdateChatStatusPreserveUpdatedAtParams{ + ID: chat.ID, + Status: database.ChatStatusRunning, + } + dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes() + dbm.EXPECT().UpdateChatStatusPreserveUpdatedAt(gomock.Any(), arg).Return(chat, nil).AnyTimes() + check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(chat) + })) s.Run("UpdateChatHeartbeat", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { chat := testutil.Fake(s.T(), faker, database.Chat{}) arg := database.UpdateChatHeartbeatParams{ diff --git a/coderd/database/dbmetrics/querymetrics.go b/coderd/database/dbmetrics/querymetrics.go index 53cd82a3da..64882c15b7 100644 --- a/coderd/database/dbmetrics/querymetrics.go +++ b/coderd/database/dbmetrics/querymetrics.go @@ -1144,6 +1144,14 @@ func (m queryMetricsStore) GetChatMessagesByChatID(ctx context.Context, chatID d return r0, r1 } +func (m queryMetricsStore) GetChatMessagesByChatIDAscPaginated(ctx context.Context, arg database.GetChatMessagesByChatIDAscPaginatedParams) ([]database.ChatMessage, error) { + start := time.Now() + r0, r1 := m.s.GetChatMessagesByChatIDAscPaginated(ctx, arg) + m.queryLatencies.WithLabelValues("GetChatMessagesByChatIDAscPaginated").Observe(time.Since(start).Seconds()) + m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatMessagesByChatIDAscPaginated").Inc() + return r0, r1 +} + func (m queryMetricsStore) GetChatMessagesByChatIDDescPaginated(ctx context.Context, arg database.GetChatMessagesByChatIDDescPaginatedParams) ([]database.ChatMessage, error) { start := time.Now() r0, r1 := m.s.GetChatMessagesByChatIDDescPaginated(ctx, arg) @@ -4096,6 +4104,14 @@ func (m queryMetricsStore) UpdateChatLabelsByID(ctx context.Context, arg databas return r0, r1 } +func (m queryMetricsStore) UpdateChatLastModelConfigByID(ctx context.Context, arg database.UpdateChatLastModelConfigByIDParams) (database.Chat, error) { + start := time.Now() + r0, r1 := m.s.UpdateChatLastModelConfigByID(ctx, arg) + m.queryLatencies.WithLabelValues("UpdateChatLastModelConfigByID").Observe(time.Since(start).Seconds()) + m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateChatLastModelConfigByID").Inc() + return r0, r1 +} + func (m queryMetricsStore) UpdateChatMCPServerIDs(ctx context.Context, arg database.UpdateChatMCPServerIDsParams) (database.Chat, error) { start := time.Now() r0, r1 := m.s.UpdateChatMCPServerIDs(ctx, arg) @@ -4144,6 +4160,14 @@ func (m queryMetricsStore) UpdateChatStatus(ctx context.Context, arg database.Up return r0, r1 } +func (m queryMetricsStore) UpdateChatStatusPreserveUpdatedAt(ctx context.Context, arg database.UpdateChatStatusPreserveUpdatedAtParams) (database.Chat, error) { + start := time.Now() + r0, r1 := m.s.UpdateChatStatusPreserveUpdatedAt(ctx, arg) + m.queryLatencies.WithLabelValues("UpdateChatStatusPreserveUpdatedAt").Observe(time.Since(start).Seconds()) + m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateChatStatusPreserveUpdatedAt").Inc() + return r0, r1 +} + func (m queryMetricsStore) UpdateChatWorkspaceBinding(ctx context.Context, arg database.UpdateChatWorkspaceBindingParams) (database.Chat, error) { start := time.Now() r0, r1 := m.s.UpdateChatWorkspaceBinding(ctx, arg) diff --git a/coderd/database/dbmock/dbmock.go b/coderd/database/dbmock/dbmock.go index ea437248c4..31b7843e74 100644 --- a/coderd/database/dbmock/dbmock.go +++ b/coderd/database/dbmock/dbmock.go @@ -2103,6 +2103,21 @@ func (mr *MockStoreMockRecorder) GetChatMessagesByChatID(ctx, arg any) *gomock.C return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatMessagesByChatID", reflect.TypeOf((*MockStore)(nil).GetChatMessagesByChatID), ctx, arg) } +// GetChatMessagesByChatIDAscPaginated mocks base method. +func (m *MockStore) GetChatMessagesByChatIDAscPaginated(ctx context.Context, arg database.GetChatMessagesByChatIDAscPaginatedParams) ([]database.ChatMessage, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetChatMessagesByChatIDAscPaginated", ctx, arg) + ret0, _ := ret[0].([]database.ChatMessage) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetChatMessagesByChatIDAscPaginated indicates an expected call of GetChatMessagesByChatIDAscPaginated. +func (mr *MockStoreMockRecorder) GetChatMessagesByChatIDAscPaginated(ctx, arg any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatMessagesByChatIDAscPaginated", reflect.TypeOf((*MockStore)(nil).GetChatMessagesByChatIDAscPaginated), ctx, arg) +} + // GetChatMessagesByChatIDDescPaginated mocks base method. func (m *MockStore) GetChatMessagesByChatIDDescPaginated(ctx context.Context, arg database.GetChatMessagesByChatIDDescPaginatedParams) ([]database.ChatMessage, error) { m.ctrl.T.Helper() @@ -7745,6 +7760,21 @@ func (mr *MockStoreMockRecorder) UpdateChatLabelsByID(ctx, arg any) *gomock.Call return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatLabelsByID", reflect.TypeOf((*MockStore)(nil).UpdateChatLabelsByID), ctx, arg) } +// UpdateChatLastModelConfigByID mocks base method. +func (m *MockStore) UpdateChatLastModelConfigByID(ctx context.Context, arg database.UpdateChatLastModelConfigByIDParams) (database.Chat, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateChatLastModelConfigByID", ctx, arg) + ret0, _ := ret[0].(database.Chat) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// UpdateChatLastModelConfigByID indicates an expected call of UpdateChatLastModelConfigByID. +func (mr *MockStoreMockRecorder) UpdateChatLastModelConfigByID(ctx, arg any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatLastModelConfigByID", reflect.TypeOf((*MockStore)(nil).UpdateChatLastModelConfigByID), ctx, arg) +} + // UpdateChatMCPServerIDs mocks base method. func (m *MockStore) UpdateChatMCPServerIDs(ctx context.Context, arg database.UpdateChatMCPServerIDsParams) (database.Chat, error) { m.ctrl.T.Helper() @@ -7834,6 +7864,21 @@ func (mr *MockStoreMockRecorder) UpdateChatStatus(ctx, arg any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatStatus", reflect.TypeOf((*MockStore)(nil).UpdateChatStatus), ctx, arg) } +// UpdateChatStatusPreserveUpdatedAt mocks base method. +func (m *MockStore) UpdateChatStatusPreserveUpdatedAt(ctx context.Context, arg database.UpdateChatStatusPreserveUpdatedAtParams) (database.Chat, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateChatStatusPreserveUpdatedAt", ctx, arg) + ret0, _ := ret[0].(database.Chat) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// UpdateChatStatusPreserveUpdatedAt indicates an expected call of UpdateChatStatusPreserveUpdatedAt. +func (mr *MockStoreMockRecorder) UpdateChatStatusPreserveUpdatedAt(ctx, arg any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatStatusPreserveUpdatedAt", reflect.TypeOf((*MockStore)(nil).UpdateChatStatusPreserveUpdatedAt), ctx, arg) +} + // UpdateChatWorkspaceBinding mocks base method. func (m *MockStore) UpdateChatWorkspaceBinding(ctx context.Context, arg database.UpdateChatWorkspaceBindingParams) (database.Chat, error) { m.ctrl.T.Helper() diff --git a/coderd/database/querier.go b/coderd/database/querier.go index cfc25b077e..4d5bafea57 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -250,6 +250,7 @@ type sqlcQuerier interface { GetChatIncludeDefaultSystemPrompt(ctx context.Context) (bool, error) GetChatMessageByID(ctx context.Context, id int64) (ChatMessage, error) GetChatMessagesByChatID(ctx context.Context, arg GetChatMessagesByChatIDParams) ([]ChatMessage, error) + GetChatMessagesByChatIDAscPaginated(ctx context.Context, arg GetChatMessagesByChatIDAscPaginatedParams) ([]ChatMessage, error) GetChatMessagesByChatIDDescPaginated(ctx context.Context, arg GetChatMessagesByChatIDDescPaginatedParams) ([]ChatMessage, error) GetChatMessagesForPromptByChatID(ctx context.Context, chatID uuid.UUID) ([]ChatMessage, error) GetChatModelConfigByID(ctx context.Context, id uuid.UUID) (ChatModelConfig, error) @@ -852,12 +853,14 @@ type sqlcQuerier interface { // replicas know the worker is still alive. UpdateChatHeartbeat(ctx context.Context, arg UpdateChatHeartbeatParams) (int64, error) UpdateChatLabelsByID(ctx context.Context, arg UpdateChatLabelsByIDParams) (Chat, error) + UpdateChatLastModelConfigByID(ctx context.Context, arg UpdateChatLastModelConfigByIDParams) (Chat, error) UpdateChatMCPServerIDs(ctx context.Context, arg UpdateChatMCPServerIDsParams) (Chat, error) UpdateChatMessageByID(ctx context.Context, arg UpdateChatMessageByIDParams) (ChatMessage, error) UpdateChatModelConfig(ctx context.Context, arg UpdateChatModelConfigParams) (ChatModelConfig, error) UpdateChatPinOrder(ctx context.Context, arg UpdateChatPinOrderParams) error UpdateChatProvider(ctx context.Context, arg UpdateChatProviderParams) (ChatProvider, error) UpdateChatStatus(ctx context.Context, arg UpdateChatStatusParams) (Chat, error) + UpdateChatStatusPreserveUpdatedAt(ctx context.Context, arg UpdateChatStatusPreserveUpdatedAtParams) (Chat, error) UpdateChatWorkspaceBinding(ctx context.Context, arg UpdateChatWorkspaceBindingParams) (Chat, error) UpdateCryptoKeyDeletesAt(ctx context.Context, arg UpdateCryptoKeyDeletesAtParams) (CryptoKey, error) UpdateCustomRole(ctx context.Context, arg UpdateCustomRoleParams) (CustomRole, error) diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 06ba299204..5649048776 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -4932,6 +4932,73 @@ func (q *sqlQuerier) GetChatMessagesByChatID(ctx context.Context, arg GetChatMes return items, nil } +const getChatMessagesByChatIDAscPaginated = `-- name: GetChatMessagesByChatIDAscPaginated :many +SELECT + id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros, runtime_ms, deleted, provider_response_id +FROM + chat_messages +WHERE + chat_id = $1::uuid + AND id > $2::bigint + AND visibility IN ('user', 'both') + AND deleted = false +ORDER BY + id ASC +LIMIT + COALESCE(NULLIF($3::int, 0), 50) +` + +type GetChatMessagesByChatIDAscPaginatedParams struct { + ChatID uuid.UUID `db:"chat_id" json:"chat_id"` + AfterID int64 `db:"after_id" json:"after_id"` + LimitVal int32 `db:"limit_val" json:"limit_val"` +} + +func (q *sqlQuerier) GetChatMessagesByChatIDAscPaginated(ctx context.Context, arg GetChatMessagesByChatIDAscPaginatedParams) ([]ChatMessage, error) { + rows, err := q.db.QueryContext(ctx, getChatMessagesByChatIDAscPaginated, arg.ChatID, arg.AfterID, arg.LimitVal) + if err != nil { + return nil, err + } + defer rows.Close() + var items []ChatMessage + for rows.Next() { + var i ChatMessage + if err := rows.Scan( + &i.ID, + &i.ChatID, + &i.ModelConfigID, + &i.CreatedAt, + &i.Role, + &i.Content, + &i.Visibility, + &i.InputTokens, + &i.OutputTokens, + &i.TotalTokens, + &i.ReasoningTokens, + &i.CacheCreationTokens, + &i.CacheReadTokens, + &i.ContextLimit, + &i.Compressed, + &i.CreatedBy, + &i.ContentVersion, + &i.TotalCostMicros, + &i.RuntimeMs, + &i.Deleted, + &i.ProviderResponseID, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const getChatMessagesByChatIDDescPaginated = `-- name: GetChatMessagesByChatIDDescPaginated :many SELECT id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros, runtime_ms, deleted, provider_response_id @@ -6254,6 +6321,52 @@ func (q *sqlQuerier) UpdateChatLabelsByID(ctx context.Context, arg UpdateChatLab return i, err } +const updateChatLastModelConfigByID = `-- name: UpdateChatLastModelConfigByID :one +UPDATE + chats +SET + -- NOTE: updated_at is intentionally NOT touched here to avoid changing list ordering. + last_model_config_id = $1::uuid +WHERE + id = $2::uuid +RETURNING + id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order +` + +type UpdateChatLastModelConfigByIDParams struct { + LastModelConfigID uuid.UUID `db:"last_model_config_id" json:"last_model_config_id"` + ID uuid.UUID `db:"id" json:"id"` +} + +func (q *sqlQuerier) UpdateChatLastModelConfigByID(ctx context.Context, arg UpdateChatLastModelConfigByIDParams) (Chat, error) { + row := q.db.QueryRowContext(ctx, updateChatLastModelConfigByID, arg.LastModelConfigID, arg.ID) + var i Chat + err := row.Scan( + &i.ID, + &i.OwnerID, + &i.WorkspaceID, + &i.Title, + &i.Status, + &i.WorkerID, + &i.StartedAt, + &i.HeartbeatAt, + &i.CreatedAt, + &i.UpdatedAt, + &i.ParentChatID, + &i.RootChatID, + &i.LastModelConfigID, + &i.Archived, + &i.LastError, + &i.Mode, + pq.Array(&i.MCPServerIDs), + &i.Labels, + &i.BuildID, + &i.AgentID, + &i.PinOrder, + ) + return i, err +} + const updateChatMCPServerIDs = `-- name: UpdateChatMCPServerIDs :one UPDATE chats @@ -6479,6 +6592,69 @@ func (q *sqlQuerier) UpdateChatStatus(ctx context.Context, arg UpdateChatStatusP return i, err } +const updateChatStatusPreserveUpdatedAt = `-- name: UpdateChatStatusPreserveUpdatedAt :one +UPDATE + chats +SET + status = $1::chat_status, + worker_id = $2::uuid, + started_at = $3::timestamptz, + heartbeat_at = $4::timestamptz, + last_error = $5::text, + updated_at = $6::timestamptz +WHERE + id = $7::uuid +RETURNING + id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order +` + +type UpdateChatStatusPreserveUpdatedAtParams struct { + Status ChatStatus `db:"status" json:"status"` + WorkerID uuid.NullUUID `db:"worker_id" json:"worker_id"` + StartedAt sql.NullTime `db:"started_at" json:"started_at"` + HeartbeatAt sql.NullTime `db:"heartbeat_at" json:"heartbeat_at"` + LastError sql.NullString `db:"last_error" json:"last_error"` + UpdatedAt time.Time `db:"updated_at" json:"updated_at"` + ID uuid.UUID `db:"id" json:"id"` +} + +func (q *sqlQuerier) UpdateChatStatusPreserveUpdatedAt(ctx context.Context, arg UpdateChatStatusPreserveUpdatedAtParams) (Chat, error) { + row := q.db.QueryRowContext(ctx, updateChatStatusPreserveUpdatedAt, + arg.Status, + arg.WorkerID, + arg.StartedAt, + arg.HeartbeatAt, + arg.LastError, + arg.UpdatedAt, + arg.ID, + ) + var i Chat + err := row.Scan( + &i.ID, + &i.OwnerID, + &i.WorkspaceID, + &i.Title, + &i.Status, + &i.WorkerID, + &i.StartedAt, + &i.HeartbeatAt, + &i.CreatedAt, + &i.UpdatedAt, + &i.ParentChatID, + &i.RootChatID, + &i.LastModelConfigID, + &i.Archived, + &i.LastError, + &i.Mode, + pq.Array(&i.MCPServerIDs), + &i.Labels, + &i.BuildID, + &i.AgentID, + &i.PinOrder, + ) + return i, err +} + const updateChatWorkspaceBinding = `-- name: UpdateChatWorkspaceBinding :one UPDATE chats SET workspace_id = $1::uuid, diff --git a/coderd/database/queries/chats.sql b/coderd/database/queries/chats.sql index 957e917f65..cd960f3c32 100644 --- a/coderd/database/queries/chats.sql +++ b/coderd/database/queries/chats.sql @@ -220,6 +220,21 @@ WHERE ORDER BY created_at ASC; +-- name: GetChatMessagesByChatIDAscPaginated :many +SELECT + * +FROM + chat_messages +WHERE + chat_id = @chat_id::uuid + AND id > @after_id::bigint + AND visibility IN ('user', 'both') + AND deleted = false +ORDER BY + id ASC +LIMIT + COALESCE(NULLIF(@limit_val::int, 0), 50); + -- name: GetChatMessagesByChatIDDescPaginated :many SELECT * @@ -466,6 +481,17 @@ WHERE RETURNING *; +-- name: UpdateChatLastModelConfigByID :one +UPDATE + chats +SET + -- NOTE: updated_at is intentionally NOT touched here to avoid changing list ordering. + last_model_config_id = @last_model_config_id::uuid +WHERE + id = @id::uuid +RETURNING + *; + -- name: UpdateChatLabelsByID :one UPDATE chats @@ -550,6 +576,21 @@ WHERE RETURNING *; +-- name: UpdateChatStatusPreserveUpdatedAt :one +UPDATE + chats +SET + status = @status::chat_status, + worker_id = sqlc.narg('worker_id')::uuid, + started_at = sqlc.narg('started_at')::timestamptz, + heartbeat_at = sqlc.narg('heartbeat_at')::timestamptz, + last_error = sqlc.narg('last_error')::text, + updated_at = @updated_at::timestamptz +WHERE + id = @id::uuid +RETURNING + *; + -- name: GetStaleChats :many -- Find chats that appear stuck (running but heartbeat has expired). -- Used for recovery after coderd crashes or long hangs. diff --git a/coderd/exp_chats.go b/coderd/exp_chats.go index 2c8d80fbf1..34511426f0 100644 --- a/coderd/exp_chats.go +++ b/coderd/exp_chats.go @@ -2105,6 +2105,50 @@ func (api *API) interruptChat(rw http.ResponseWriter, r *http.Request) { httpapi.Write(ctx, rw, http.StatusOK, db2sdk.Chat(chat, nil)) } +// EXPERIMENTAL: this endpoint is experimental and is subject to change. +// +//nolint:revive // HTTP handler writes to ResponseWriter. +func (api *API) regenerateChatTitle(rw http.ResponseWriter, r *http.Request) { + ctx := r.Context() + chat := httpmw.ChatParam(r) + + if !api.Authorize(r, policy.ActionUpdate, chat.RBACObject()) { + httpapi.ResourceNotFound(rw) + return + } + if api.chatDaemon == nil { + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Chat processor is unavailable.", + Detail: "Chat processor is not configured.", + }) + return + } + + updatedChat, err := api.chatDaemon.RegenerateChatTitle(ctx, chat) + if err != nil { + if errors.Is(err, chatd.ErrManualTitleRegenerationInProgress) { + httpapi.Write(ctx, rw, http.StatusConflict, codersdk.Response{ + Message: "Title regeneration already in progress for this chat.", + }) + return + } + if maybeWriteLimitErr(ctx, rw, err) { + return + } + if httpapi.Is404Error(err) { + httpapi.ResourceNotFound(rw) + return + } + httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ + Message: "Failed to regenerate chat title.", + Detail: err.Error(), + }) + return + } + + httpapi.Write(ctx, rw, http.StatusOK, db2sdk.Chat(updatedChat, nil)) +} + // EXPERIMENTAL: this endpoint is experimental and is subject to change. // //nolint:revive // HTTP handler writes to ResponseWriter. diff --git a/coderd/exp_chats_test.go b/coderd/exp_chats_test.go index 105b3e2173..cd70d67730 100644 --- a/coderd/exp_chats_test.go +++ b/coderd/exp_chats_test.go @@ -19,6 +19,7 @@ import ( "github.com/google/uuid" "github.com/shopspring/decimal" "github.com/stretchr/testify/require" + "golang.org/x/xerrors" "github.com/coder/coder/v2/coderd/coderdtest" "github.com/coder/coder/v2/coderd/coderdtest/oidctest" @@ -31,6 +32,7 @@ import ( "github.com/coder/coder/v2/coderd/externalauth" coderdpubsub "github.com/coder/coder/v2/coderd/pubsub" "github.com/coder/coder/v2/coderd/rbac" + "github.com/coder/coder/v2/coderd/rbac/policy" "github.com/coder/coder/v2/coderd/util/ptr" "github.com/coder/coder/v2/coderd/x/chatd" "github.com/coder/coder/v2/coderd/x/chatd/chatprompt" @@ -3441,6 +3443,268 @@ func TestInterruptChat(t *testing.T) { }) } +func TestRegenerateChatTitle(t *testing.T) { + t.Parallel() + + t.Run("ChatNotFound", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client := newChatClient(t) + _ = coderdtest.CreateFirstUser(t, client.Client) + + _, err := client.RegenerateChatTitle(ctx, uuid.New()) + requireSDKError(t, err, http.StatusNotFound) + }) + + t.Run("UpdateDenied", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + clientRaw, db := coderdtest.NewWithDatabase(t, &coderdtest.Options{ + Authorizer: &coderdtest.FakeAuthorizer{ + ConditionalReturn: func(_ context.Context, _ rbac.Subject, action policy.Action, object rbac.Object) error { + if action == policy.ActionUpdate && object.Type == rbac.ResourceChat.Type { + return xerrors.New("denied") + } + return nil + }, + }, + DeploymentValues: chatDeploymentValues(t), + }) + client := codersdk.NewExperimentalClient(clientRaw) + user := coderdtest.CreateFirstUser(t, client.Client) + modelConfig := createChatModelConfig(t, client) + + chat, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{ + OwnerID: user.UserID, + LastModelConfigID: modelConfig.ID, + Title: "chat with update denied", + }) + require.NoError(t, err) + + _, err = client.RegenerateChatTitle(ctx, chat.ID) + requireSDKError(t, err, http.StatusNotFound) + }) + + t.Run("NotFoundForDifferentUser", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client := newChatClient(t) + firstUser := coderdtest.CreateFirstUser(t, client.Client) + _ = createChatModelConfig(t, client) + + createdChat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{ + Content: []codersdk.ChatInputPart{ + { + Type: codersdk.ChatInputPartTypeText, + Text: "private chat", + }, + }, + }) + require.NoError(t, err) + + otherClientRaw, _ := coderdtest.CreateAnotherUser(t, client.Client, firstUser.OrganizationID) + otherClient := codersdk.NewExperimentalClient(otherClientRaw) + _, err = otherClient.RegenerateChatTitle(ctx, createdChat.ID) + requireSDKError(t, err, http.StatusNotFound) + }) + + t.Run("Unauthenticated", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client := newChatClient(t) + _ = coderdtest.CreateFirstUser(t, client.Client) + _ = createChatModelConfig(t, client) + + chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{ + Content: []codersdk.ChatInputPart{{ + Type: codersdk.ChatInputPartTypeText, + Text: "chat for unauthenticated regeneration", + }}, + }) + require.NoError(t, err) + + unauthenticatedClient := codersdk.NewExperimentalClient(codersdk.New(client.URL)) + _, err = unauthenticatedClient.RegenerateChatTitle(ctx, chat.ID) + requireSDKError(t, err, http.StatusUnauthorized) + }) + + t.Run("UsageLimitExceeded", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client, db := newChatClientWithDatabase(t) + _ = coderdtest.CreateFirstUser(t, client.Client) + modelConfig := createChatModelConfig(t, client) + + chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{ + Content: []codersdk.ChatInputPart{{ + Type: codersdk.ChatInputPartTypeText, + Text: "chat over usage limit", + }}, + }) + require.NoError(t, err) + + wantResetsAt := enableDailyChatUsageLimit(ctx, t, db, 100) + insertAssistantCostMessage(ctx, t, db, chat.ID, modelConfig.ID, 100) + + _, err = db.UpdateChatStatus(dbauthz.AsSystemRestricted(ctx), database.UpdateChatStatusParams{ + ID: chat.ID, + Status: database.ChatStatusCompleted, + WorkerID: uuid.NullUUID{}, + StartedAt: sql.NullTime{}, + HeartbeatAt: sql.NullTime{}, + LastError: sql.NullString{}, + }) + require.NoError(t, err) + + _, err = client.RegenerateChatTitle(ctx, chat.ID) + limitErr := codersdk.ChatUsageLimitExceededFrom(err) + require.NotNil(t, limitErr) + require.Equal(t, "Chat usage limit exceeded.", limitErr.Message) + require.Equal(t, int64(100), limitErr.SpentMicros) + require.Equal(t, int64(100), limitErr.LimitMicros) + require.True( + t, + limitErr.ResetsAt.Equal(wantResetsAt), + "expected resets_at %s, got %s", + wantResetsAt.UTC().Format(time.RFC3339), + limitErr.ResetsAt.UTC().Format(time.RFC3339), + ) + }) + + t.Run("AlreadyInProgress", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client, db := newChatClientWithDatabase(t) + user := coderdtest.CreateFirstUser(t, client.Client) + modelConfig := createChatModelConfig(t, client) + + chat, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{ + OwnerID: user.UserID, + LastModelConfigID: modelConfig.ID, + Title: "chat with lock held", + }) + require.NoError(t, err) + + _, err = db.UpdateChatStatus(dbauthz.AsSystemRestricted(ctx), database.UpdateChatStatusParams{ + ID: chat.ID, + Status: database.ChatStatusCompleted, + WorkerID: uuid.NullUUID{UUID: uuid.MustParse("00000000-0000-0000-0000-000000000001"), Valid: true}, + StartedAt: sql.NullTime{Time: time.Now(), Valid: true}, + HeartbeatAt: sql.NullTime{Time: time.Now(), Valid: true}, + LastError: sql.NullString{}, + }) + require.NoError(t, err) + + res, err := client.Request( + ctx, + http.MethodPost, + fmt.Sprintf("/api/experimental/chats/%s/title/regenerate", chat.ID), + nil, + ) + require.NoError(t, err) + defer res.Body.Close() + require.Equal(t, http.StatusConflict, res.StatusCode) + + var resp codersdk.Response + require.NoError(t, json.NewDecoder(res.Body).Decode(&resp)) + require.Equal(t, "Title regeneration already in progress for this chat.", resp.Message) + }) + + t.Run("PendingWithoutWorker", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client, db := newChatClientWithDatabase(t) + user := coderdtest.CreateFirstUser(t, client.Client) + modelConfig := createChatModelConfig(t, client) + + chat, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{ + OwnerID: user.UserID, + LastModelConfigID: modelConfig.ID, + Title: "pending chat without worker", + }) + require.NoError(t, err) + + chat, err = db.UpdateChatStatus(dbauthz.AsSystemRestricted(ctx), database.UpdateChatStatusParams{ + ID: chat.ID, + Status: database.ChatStatusPending, + WorkerID: uuid.NullUUID{}, + StartedAt: sql.NullTime{}, + HeartbeatAt: sql.NullTime{}, + LastError: sql.NullString{}, + }) + require.NoError(t, err) + + before, err := db.GetChatByID(dbauthz.AsSystemRestricted(ctx), chat.ID) + require.NoError(t, err) + + res, err := client.Request( + ctx, + http.MethodPost, + fmt.Sprintf("/api/experimental/chats/%s/title/regenerate", chat.ID), + nil, + ) + require.NoError(t, err) + defer res.Body.Close() + require.Equal(t, http.StatusConflict, res.StatusCode) + + var resp codersdk.Response + require.NoError(t, json.NewDecoder(res.Body).Decode(&resp)) + require.Equal(t, "Title regeneration already in progress for this chat.", resp.Message) + + persisted, err := db.GetChatByID(dbauthz.AsSystemRestricted(ctx), chat.ID) + require.NoError(t, err) + require.Equal(t, database.ChatStatusPending, persisted.Status) + require.False(t, persisted.WorkerID.Valid) + require.True(t, persisted.UpdatedAt.Equal(before.UpdatedAt)) + }) + + t.Run("RegenerationFailure", func(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitLong) + client, db := newChatClientWithDatabase(t) + _ = coderdtest.CreateFirstUser(t, client.Client) + _ = createChatModelConfig(t, client) + + chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{ + Content: []codersdk.ChatInputPart{ + { + Type: codersdk.ChatInputPartTypeText, + Text: "test chat", + }, + }, + }) + require.NoError(t, err) + + _, err = db.UpdateChatStatus(dbauthz.AsSystemRestricted(ctx), database.UpdateChatStatusParams{ + ID: chat.ID, + Status: database.ChatStatusCompleted, + WorkerID: uuid.NullUUID{}, + StartedAt: sql.NullTime{}, + HeartbeatAt: sql.NullTime{}, + LastError: sql.NullString{}, + }) + require.NoError(t, err) + + before, err := db.GetChatByID(dbauthz.AsSystemRestricted(ctx), chat.ID) + require.NoError(t, err) + + _, err = client.RegenerateChatTitle(ctx, chat.ID) + requireSDKError(t, err, http.StatusInternalServerError) + + after, err := db.GetChatByID(dbauthz.AsSystemRestricted(ctx), chat.ID) + require.NoError(t, err) + require.True(t, after.UpdatedAt.Equal(before.UpdatedAt)) + }) +} + func TestGetChatDiffStatus(t *testing.T) { t.Parallel() diff --git a/coderd/x/chatd/chatd.go b/coderd/x/chatd/chatd.go index 2c31176e4a..fd5685fa98 100644 --- a/coderd/x/chatd/chatd.go +++ b/coderd/x/chatd/chatd.go @@ -1333,6 +1333,477 @@ func (p *Server) InterruptChat( return updatedChat } +const manualTitleMessageWindowLimit = 50 + +var ErrManualTitleRegenerationInProgress = xerrors.New( + "manual title regeneration already in progress", +) + +type manualTitleGenerationError struct { + cause error + modelConfig database.ChatModelConfig + usage fantasy.Usage +} + +func (e *manualTitleGenerationError) Error() string { + return e.cause.Error() +} + +func (e *manualTitleGenerationError) Unwrap() error { + return e.cause +} + +var manualTitleLockWorkerID = uuid.MustParse( + "00000000-0000-0000-0000-000000000001", +) + +const manualTitleLockStaleAfter = time.Minute + +func isPendingOrRunningChatStatus(status database.ChatStatus) bool { + switch status { + case database.ChatStatusPending, database.ChatStatusRunning: + return true + default: + return false + } +} + +func isFreshManualTitleLock(chat database.Chat, now time.Time) bool { + if !chat.WorkerID.Valid || chat.WorkerID.UUID != manualTitleLockWorkerID { + return false + } + leaseAt := chat.HeartbeatAt + if !leaseAt.Valid { + leaseAt = chat.StartedAt + } + return leaseAt.Valid && leaseAt.Time.After(now.Add(-manualTitleLockStaleAfter)) +} + +// updateChatStatusPreserveUpdatedAt applies internal lock transitions without +// changing chat recency, because chat list ordering uses updated_at. +func updateChatStatusPreserveUpdatedAt( + ctx context.Context, + store database.Store, + chat database.Chat, + workerID uuid.NullUUID, + startedAt sql.NullTime, + heartbeatAt sql.NullTime, +) (database.Chat, error) { + return store.UpdateChatStatusPreserveUpdatedAt( + ctx, + database.UpdateChatStatusPreserveUpdatedAtParams{ + ID: chat.ID, + Status: chat.Status, + WorkerID: workerID, + StartedAt: startedAt, + HeartbeatAt: heartbeatAt, + LastError: chat.LastError, + UpdatedAt: chat.UpdatedAt, + }, + ) +} + +func (p *Server) acquireManualTitleLock(ctx context.Context, chatID uuid.UUID) error { + now := time.Now() + return p.db.InTx(func(tx database.Store) error { + lockedChat, err := tx.GetChatByIDForUpdate(ctx, chatID) + if err != nil { + return xerrors.Errorf("lock chat for manual title regeneration: %w", err) + } + if isPendingOrRunningChatStatus(lockedChat.Status) || + isFreshManualTitleLock(lockedChat, now) { + return ErrManualTitleRegenerationInProgress + } + _, err = updateChatStatusPreserveUpdatedAt( + ctx, + tx, + lockedChat, + uuid.NullUUID{UUID: manualTitleLockWorkerID, Valid: true}, + sql.NullTime{Time: now, Valid: true}, + sql.NullTime{Time: now, Valid: true}, + ) + if err != nil { + return xerrors.Errorf("mark chat for manual title regeneration: %w", err) + } + return nil + }, database.DefaultTXOptions().WithID("chat_title_regenerate_lock")) +} + +func (p *Server) releaseManualTitleLock(ctx context.Context, chatID uuid.UUID) { + cleanupCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), 5*time.Second) + defer cancel() + + err := p.db.InTx(func(tx database.Store) error { + lockedChat, err := tx.GetChatByIDForUpdate(cleanupCtx, chatID) + if err != nil { + return xerrors.Errorf("lock chat to release manual title regeneration: %w", err) + } + if !lockedChat.WorkerID.Valid || lockedChat.WorkerID.UUID != manualTitleLockWorkerID { + return nil + } + _, err = updateChatStatusPreserveUpdatedAt( + cleanupCtx, + tx, + lockedChat, + uuid.NullUUID{}, + sql.NullTime{}, + sql.NullTime{}, + ) + if err != nil { + return xerrors.Errorf("clear manual title regeneration marker: %w", err) + } + return nil + }, database.DefaultTXOptions().WithID("chat_title_regenerate_unlock")) + if err != nil { + p.logger.Warn(cleanupCtx, "failed to release manual title regeneration marker", + slog.F("chat_id", chatID), + slog.Error(err), + ) + } +} + +// RegenerateChatTitle regenerates a chat title from the chat's visible +// messages, persists it when it changes, and broadcasts the update. +func (p *Server) RegenerateChatTitle( + ctx context.Context, + chat database.Chat, +) (database.Chat, error) { + // Reuse chatd's scoped auth context for deployment-config lookups while + // keeping chat ownership authorization at the HTTP layer. + //nolint:gocritic // Non-admin users need chatd-scoped config reads here. + chatdCtx := dbauthz.AsChatd(ctx) + keys, err := p.resolveProviderAPIKeys(chatdCtx) + if err != nil { + return database.Chat{}, xerrors.Errorf("resolve chat providers: %w", err) + } + if err := p.acquireManualTitleLock(ctx, chat.ID); err != nil { + return database.Chat{}, err + } + defer p.releaseManualTitleLock(chatdCtx, chat.ID) + + updatedChat, err := p.regenerateChatTitleWithStore( + chatdCtx, + p.db, + chat, + keys, + ) + if err != nil { + var generationErr *manualTitleGenerationError + if errors.As(err, &generationErr) { + // Reuse chatd's scoped auth context for failure accounting while + // detaching from request cancellation so usage is still recorded. + //nolint:gocritic // Failure accounting still needs chatd-scoped config reads. + recordCtx, recordCancel := context.WithTimeout( + dbauthz.AsChatd(context.WithoutCancel(ctx)), + 5*time.Second, + ) + defer recordCancel() + if _, recordErr := recordManualTitleUsage( + recordCtx, + p.db, + chat, + generationErr.modelConfig, + generationErr.usage, + "", + ); recordErr != nil { + return database.Chat{}, errors.Join( + generationErr, + xerrors.Errorf("record manual title usage: %w", recordErr), + ) + } + return database.Chat{}, generationErr + } + return database.Chat{}, err + } + return updatedChat, nil +} + +func (p *Server) regenerateChatTitleWithStore( + ctx context.Context, + store database.Store, + chat database.Chat, + keys chatprovider.ProviderAPIKeys, +) (database.Chat, error) { + if limitErr := p.checkUsageLimit(ctx, store, chat.OwnerID); limitErr != nil { + return database.Chat{}, limitErr + } + + headMessages, err := store.GetChatMessagesByChatIDAscPaginated( + ctx, + database.GetChatMessagesByChatIDAscPaginatedParams{ + ChatID: chat.ID, + AfterID: 0, + LimitVal: manualTitleMessageWindowLimit, + }, + ) + if err != nil { + return database.Chat{}, xerrors.Errorf("get head chat messages: %w", err) + } + tailMessages, err := store.GetChatMessagesByChatIDDescPaginated( + ctx, + database.GetChatMessagesByChatIDDescPaginatedParams{ + ChatID: chat.ID, + BeforeID: 0, + LimitVal: manualTitleMessageWindowLimit, + }, + ) + if err != nil { + return database.Chat{}, xerrors.Errorf("get tail chat messages: %w", err) + } + messages := mergeManualTitleMessages(headMessages, tailMessages) + if len(messages) == 0 { + return chat, nil + } + + model, modelConfig, err := p.resolveManualTitleModel(ctx, store, chat, keys) + if err != nil { + return database.Chat{}, err + } + + title, usage, err := generateManualTitle(ctx, messages, model) + if err != nil { + wrappedErr := xerrors.Errorf("generate manual title: %w", err) + if usage == (fantasy.Usage{}) { + return database.Chat{}, wrappedErr + } + return database.Chat{}, &manualTitleGenerationError{ + cause: wrappedErr, + modelConfig: modelConfig, + usage: usage, + } + } + + recordCtx, recordCancel := context.WithTimeout(context.WithoutCancel(ctx), 5*time.Second) + defer recordCancel() + + updatedChat, recordErr := recordManualTitleUsage( + recordCtx, + store, + chat, + modelConfig, + usage, + title, + ) + if recordErr != nil { + if title != "" { + return database.Chat{}, xerrors.Errorf("record manual title usage and update chat title: %w", recordErr) + } + return database.Chat{}, xerrors.Errorf("record manual title usage: %w", recordErr) + } + if updatedChat.Title == chat.Title { + return updatedChat, nil + } + + p.publishChatPubsubEvent(updatedChat, coderdpubsub.ChatEventKindTitleChange, nil) + return updatedChat, nil +} + +func (p *Server) resolveManualTitleModel( + ctx context.Context, + store database.Store, + chat database.Chat, + keys chatprovider.ProviderAPIKeys, +) (fantasy.LanguageModel, database.ChatModelConfig, error) { + configs, err := store.GetEnabledChatModelConfigs(ctx) + if err != nil { + p.logger.Debug(ctx, "failed to list manual title model configs", + slog.F("chat_id", chat.ID), + slog.Error(err), + ) + return p.resolveFallbackManualTitleModel(ctx, chat, keys) + } + + config, ok := selectPreferredConfiguredShortTextModelConfig(configs) + if !ok { + return p.resolveFallbackManualTitleModel(ctx, chat, keys) + } + + model, err := chatprovider.ModelFromConfig( + config.Provider, + config.Model, + keys, + chatprovider.UserAgent(), + chatprovider.CoderHeaders(chat), + ) + if err != nil { + p.logger.Debug(ctx, "manual title preferred model unavailable", + slog.F("chat_id", chat.ID), + slog.F("provider", config.Provider), + slog.F("model", config.Model), + slog.Error(err), + ) + return p.resolveFallbackManualTitleModel(ctx, chat, keys) + } + + return model, config, nil +} + +func (p *Server) resolveFallbackManualTitleModel( + ctx context.Context, + chat database.Chat, + keys chatprovider.ProviderAPIKeys, +) (fantasy.LanguageModel, database.ChatModelConfig, error) { + config, err := p.resolveModelConfig(ctx, chat) + if err != nil { + return nil, database.ChatModelConfig{}, xerrors.Errorf( + "resolve fallback manual title model config: %w", + err, + ) + } + model, err := chatprovider.ModelFromConfig( + config.Provider, + config.Model, + keys, + chatprovider.UserAgent(), + chatprovider.CoderHeaders(chat), + ) + if err != nil { + return nil, database.ChatModelConfig{}, xerrors.Errorf( + "create fallback manual title model: %w", + err, + ) + } + return model, config, nil +} + +func mergeManualTitleMessages( + headMessages []database.ChatMessage, + tailMessagesDesc []database.ChatMessage, +) []database.ChatMessage { + merged := make([]database.ChatMessage, 0, len(headMessages)+len(tailMessagesDesc)) + seen := make(map[int64]struct{}, len(headMessages)+len(tailMessagesDesc)) + appendUnique := func(message database.ChatMessage) { + if _, ok := seen[message.ID]; ok { + return + } + seen[message.ID] = struct{}{} + merged = append(merged, message) + } + for _, message := range headMessages { + appendUnique(message) + } + for i := len(tailMessagesDesc) - 1; i >= 0; i-- { + appendUnique(tailMessagesDesc[i]) + } + return merged +} + +func fantasyUsageToChatMessageUsage(usage fantasy.Usage) codersdk.ChatMessageUsage { + var chatUsage codersdk.ChatMessageUsage + if usage.InputTokens != 0 { + chatUsage.InputTokens = ptr.Ref(usage.InputTokens) + } + if usage.OutputTokens != 0 { + chatUsage.OutputTokens = ptr.Ref(usage.OutputTokens) + } + if usage.ReasoningTokens != 0 { + chatUsage.ReasoningTokens = ptr.Ref(usage.ReasoningTokens) + } + if usage.CacheCreationTokens != 0 { + chatUsage.CacheCreationTokens = ptr.Ref(usage.CacheCreationTokens) + } + if usage.CacheReadTokens != 0 { + chatUsage.CacheReadTokens = ptr.Ref(usage.CacheReadTokens) + } + return chatUsage +} + +func recordManualTitleUsage( + ctx context.Context, + store database.Store, + chat database.Chat, + modelConfig database.ChatModelConfig, + usage fantasy.Usage, + newTitle string, +) (database.Chat, error) { + hasUsage := usage != (fantasy.Usage{}) + if !hasUsage && newTitle == "" { + return chat, nil + } + + var totalCostMicros *int64 + if hasUsage { + callConfig := codersdk.ChatModelCallConfig{} + if len(modelConfig.Options) > 0 { + if err := json.Unmarshal(modelConfig.Options, &callConfig); err != nil { + return database.Chat{}, xerrors.Errorf("parse model call config: %w", err) + } + } + totalCostMicros = chatcost.CalculateTotalCostMicros( + fantasyUsageToChatMessageUsage(usage), + callConfig.Cost, + ) + } + + // Use a valid empty JSON array for the content column. + // MarshalParts returns a null NullRawMessage for empty + // slices, which becomes an empty string that PostgreSQL + // rejects as invalid JSON. + content := "[]" + + updatedChat := chat + err := store.InTx(func(tx database.Store) error { + lockedChat, err := tx.GetChatByIDForUpdate(ctx, chat.ID) + if err != nil { + return xerrors.Errorf("lock chat for manual title usage: %w", err) + } + updatedChat = lockedChat + if hasUsage { + messages, err := tx.InsertChatMessages(ctx, database.InsertChatMessagesParams{ + ChatID: chat.ID, + CreatedBy: []uuid.UUID{chat.OwnerID}, + ModelConfigID: []uuid.UUID{modelConfig.ID}, + Role: []database.ChatMessageRole{database.ChatMessageRoleAssistant}, + Content: []string{content}, + ContentVersion: []int16{chatprompt.CurrentContentVersion}, + Visibility: []database.ChatMessageVisibility{database.ChatMessageVisibilityModel}, + InputTokens: []int64{usage.InputTokens}, + OutputTokens: []int64{usage.OutputTokens}, + TotalTokens: []int64{usage.TotalTokens}, + ReasoningTokens: []int64{usage.ReasoningTokens}, + CacheCreationTokens: []int64{usage.CacheCreationTokens}, + CacheReadTokens: []int64{usage.CacheReadTokens}, + ContextLimit: []int64{modelConfig.ContextLimit}, + Compressed: []bool{false}, + TotalCostMicros: []int64{ptr.NilToDefault(totalCostMicros, 0)}, + RuntimeMs: []int64{0}, + ProviderResponseID: []string{""}, + }) + if err != nil { + return xerrors.Errorf("insert manual title usage message: %w", err) + } + if len(messages) != 1 { + return xerrors.Errorf("expected 1 manual title usage message, got %d", len(messages)) + } + if err := tx.SoftDeleteChatMessageByID(ctx, messages[0].ID); err != nil { + return xerrors.Errorf("soft delete manual title usage message: %w", err) + } + if lockedChat.LastModelConfigID != modelConfig.ID { + if _, err := tx.UpdateChatLastModelConfigByID(ctx, database.UpdateChatLastModelConfigByIDParams{ + ID: chat.ID, + LastModelConfigID: lockedChat.LastModelConfigID, + }); err != nil { + return xerrors.Errorf("restore chat model config after manual title usage: %w", err) + } + } + } + if newTitle != "" && lockedChat.Title == chat.Title && newTitle != lockedChat.Title { + updatedChat, err = tx.UpdateChatByID(ctx, database.UpdateChatByIDParams{ + ID: chat.ID, + Title: newTitle, + }) + if err != nil { + return xerrors.Errorf("update chat title: %w", err) + } + } + return nil + }, nil) + if err != nil { + return database.Chat{}, err + } + return updatedChat, nil +} + // RefreshStatus loads the latest chat status and publishes it to stream subscribers. func (p *Server) RefreshStatus(ctx context.Context, chatID uuid.UUID) error { if chatID == uuid.Nil { @@ -3475,24 +3946,7 @@ func (p *Server) runChat( } hasUsage := step.Usage != (fantasy.Usage{}) - var usageForCost codersdk.ChatMessageUsage - if hasUsage { - if step.Usage.InputTokens != 0 { - usageForCost.InputTokens = ptr.Ref(step.Usage.InputTokens) - } - if step.Usage.OutputTokens != 0 { - usageForCost.OutputTokens = ptr.Ref(step.Usage.OutputTokens) - } - if step.Usage.ReasoningTokens != 0 { - usageForCost.ReasoningTokens = ptr.Ref(step.Usage.ReasoningTokens) - } - if step.Usage.CacheCreationTokens != 0 { - usageForCost.CacheCreationTokens = ptr.Ref(step.Usage.CacheCreationTokens) - } - if step.Usage.CacheReadTokens != 0 { - usageForCost.CacheReadTokens = ptr.Ref(step.Usage.CacheReadTokens) - } - } + usageForCost := fantasyUsageToChatMessageUsage(step.Usage) totalCostMicros := chatcost.CalculateTotalCostMicros(usageForCost, callConfig.Cost) var insertedMessages []database.ChatMessage @@ -4078,10 +4532,8 @@ func (p *Server) resolveChatModel( ctx context.Context, chat database.Chat, ) (fantasy.LanguageModel, database.ChatModelConfig, chatprovider.ProviderAPIKeys, error) { - var ( - dbConfig database.ChatModelConfig - providers []database.ChatProvider - ) + var dbConfig database.ChatModelConfig + var keys chatprovider.ProviderAPIKeys var g errgroup.Group g.Go(func() error { @@ -4094,28 +4546,15 @@ func (p *Server) resolveChatModel( }) g.Go(func() error { var err error - providers, err = p.configCache.EnabledProviders(ctx) + keys, err = p.resolveProviderAPIKeys(ctx) if err != nil { - return xerrors.Errorf("get enabled chat providers: %w", err) + return xerrors.Errorf("resolve provider API keys: %w", err) } return nil }) if err := g.Wait(); err != nil { return nil, database.ChatModelConfig{}, chatprovider.ProviderAPIKeys{}, err } - dbProviders := make( - []chatprovider.ConfiguredProvider, 0, len(providers), - ) - for _, provider := range providers { - dbProviders = append(dbProviders, chatprovider.ConfiguredProvider{ - Provider: provider.Provider, - APIKey: provider.APIKey, - BaseURL: provider.BaseUrl, - }) - } - keys := chatprovider.MergeProviderAPIKeys( - p.providerAPIKeys, dbProviders, - ) model, err := chatprovider.ModelFromConfig( dbConfig.Provider, dbConfig.Model, keys, chatprovider.UserAgent(), @@ -4129,6 +4568,29 @@ func (p *Server) resolveChatModel( return model, dbConfig, keys, nil } +func (p *Server) resolveProviderAPIKeys( + ctx context.Context, +) (chatprovider.ProviderAPIKeys, error) { + providers, err := p.configCache.EnabledProviders(ctx) + if err != nil { + return chatprovider.ProviderAPIKeys{}, xerrors.Errorf( + "get enabled chat providers: %w", + err, + ) + } + dbProviders := make( + []chatprovider.ConfiguredProvider, 0, len(providers), + ) + for _, provider := range providers { + dbProviders = append(dbProviders, chatprovider.ConfiguredProvider{ + Provider: provider.Provider, + APIKey: provider.APIKey, + BaseURL: provider.BaseUrl, + }) + } + return chatprovider.MergeProviderAPIKeys(p.providerAPIKeys, dbProviders), nil +} + // resolveModelConfig looks up the chat's model config by its // LastModelConfigID. If the referenced config no longer exists // (e.g. it was deleted), it falls back to the default model diff --git a/coderd/x/chatd/chatd_internal_test.go b/coderd/x/chatd/chatd_internal_test.go index 1fa72feda7..f13b9b8b50 100644 --- a/coderd/x/chatd/chatd_internal_test.go +++ b/coderd/x/chatd/chatd_internal_test.go @@ -23,6 +23,7 @@ import ( dbpubsub "github.com/coder/coder/v2/coderd/database/pubsub" coderdpubsub "github.com/coder/coder/v2/coderd/pubsub" "github.com/coder/coder/v2/coderd/x/chatd/chaterror" + "github.com/coder/coder/v2/coderd/x/chatd/chattest" "github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/codersdk/workspacesdk" "github.com/coder/coder/v2/codersdk/workspacesdk/agentconnmock" @@ -30,6 +31,183 @@ import ( "github.com/coder/quartz" ) +func TestRegenerateChatTitle_PersistsAndBroadcasts(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitShort) + ctrl := gomock.NewController(t) + db := dbmock.NewMockStore(ctrl) + lockTx := dbmock.NewMockStore(ctrl) + usageTx := dbmock.NewMockStore(ctrl) + unlockTx := dbmock.NewMockStore(ctrl) + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) + pubsub := dbpubsub.NewInMemory() + clock := quartz.NewReal() + + ownerID := uuid.New() + chatID := uuid.New() + modelConfigID := uuid.New() + userPrompt := "review pull request 23633 and fix review threads" + wantTitle := "Review PR 23633" + + chat := database.Chat{ + ID: chatID, + OwnerID: ownerID, + LastModelConfigID: modelConfigID, + Status: database.ChatStatusCompleted, + Title: fallbackChatTitle(userPrompt), + } + modelConfig := database.ChatModelConfig{ + ID: modelConfigID, + Provider: "anthropic", + Model: "claude-haiku-4-5", + ContextLimit: 8192, + } + updatedChat := chat + updatedChat.Title = wantTitle + + messageEvents := make(chan struct { + payload coderdpubsub.ChatEvent + err error + }, 1) + cancelSub, err := pubsub.SubscribeWithErr( + coderdpubsub.ChatEventChannel(ownerID), + coderdpubsub.HandleChatEvent(func(_ context.Context, payload coderdpubsub.ChatEvent, err error) { + messageEvents <- struct { + payload coderdpubsub.ChatEvent + err error + }{payload: payload, err: err} + }), + ) + require.NoError(t, err) + defer cancelSub() + + serverURL := chattest.NewAnthropic(t, func(req *chattest.AnthropicRequest) chattest.AnthropicResponse { + require.Equal(t, "claude-haiku-4-5", req.Model) + return chattest.AnthropicNonStreamingResponse(wantTitle) + }) + + server := &Server{ + db: db, + logger: logger, + pubsub: pubsub, + configCache: newChatConfigCache(context.Background(), db, clock), + } + + db.EXPECT().GetChatModelConfigByID(gomock.Any(), modelConfigID).Return(modelConfig, nil) + db.EXPECT().GetEnabledChatProviders(gomock.Any()).Return([]database.ChatProvider{{ + Provider: "anthropic", + APIKey: "test-key", + BaseUrl: serverURL, + }}, nil) + db.EXPECT().GetChatUsageLimitConfig(gomock.Any()).Return(database.ChatUsageLimitConfig{}, sql.ErrNoRows) + db.EXPECT().GetChatMessagesByChatIDAscPaginated( + gomock.Any(), + database.GetChatMessagesByChatIDAscPaginatedParams{ + ChatID: chatID, + AfterID: 0, + LimitVal: manualTitleMessageWindowLimit, + }, + ).Return([]database.ChatMessage{ + mustChatMessage( + t, + database.ChatMessageRoleUser, + database.ChatMessageVisibilityBoth, + codersdk.ChatMessageText(userPrompt), + ), + mustChatMessage( + t, + database.ChatMessageRoleAssistant, + database.ChatMessageVisibilityBoth, + codersdk.ChatMessageText("checking the diff now"), + ), + }, nil) + db.EXPECT().GetChatMessagesByChatIDDescPaginated( + gomock.Any(), + database.GetChatMessagesByChatIDDescPaginatedParams{ + ChatID: chatID, + BeforeID: 0, + LimitVal: manualTitleMessageWindowLimit, + }, + ).Return(nil, nil) + db.EXPECT().GetEnabledChatModelConfigs(gomock.Any()).Return(nil, nil) + + gomock.InOrder( + db.EXPECT().InTx(gomock.Any(), database.DefaultTXOptions().WithID("chat_title_regenerate_lock")).DoAndReturn( + func(fn func(database.Store) error, opts *database.TxOptions) error { + require.Equal(t, "chat_title_regenerate_lock", opts.TxIdentifier) + return fn(lockTx) + }, + ), + db.EXPECT().InTx(gomock.Any(), nil).DoAndReturn( + func(fn func(database.Store) error, opts *database.TxOptions) error { + require.Nil(t, opts) + return fn(usageTx) + }, + ), + db.EXPECT().InTx(gomock.Any(), database.DefaultTXOptions().WithID("chat_title_regenerate_unlock")).DoAndReturn( + func(fn func(database.Store) error, opts *database.TxOptions) error { + require.Equal(t, "chat_title_regenerate_unlock", opts.TxIdentifier) + return fn(unlockTx) + }, + ), + ) + + lockTx.EXPECT().GetChatByIDForUpdate(gomock.Any(), chatID).Return(chat, nil) + lockTx.EXPECT().UpdateChatStatusPreserveUpdatedAt(gomock.Any(), gomock.AssignableToTypeOf(database.UpdateChatStatusPreserveUpdatedAtParams{})).DoAndReturn( + func(_ context.Context, arg database.UpdateChatStatusPreserveUpdatedAtParams) (database.Chat, error) { + require.Equal(t, chatID, arg.ID) + require.Equal(t, chat.Status, arg.Status) + require.Equal(t, uuid.NullUUID{UUID: manualTitleLockWorkerID, Valid: true}, arg.WorkerID) + require.True(t, arg.StartedAt.Valid) + require.True(t, arg.HeartbeatAt.Valid) + return chat, nil + }, + ) + + usageTx.EXPECT().GetChatByIDForUpdate(gomock.Any(), chatID).Return(chat, nil) + usageTx.EXPECT().InsertChatMessages(gomock.Any(), gomock.AssignableToTypeOf(database.InsertChatMessagesParams{})).DoAndReturn( + func(_ context.Context, arg database.InsertChatMessagesParams) ([]database.ChatMessage, error) { + require.Equal(t, []uuid.UUID{ownerID}, arg.CreatedBy) + require.Equal(t, []uuid.UUID{modelConfigID}, arg.ModelConfigID) + require.Equal(t, []string{"[]"}, arg.Content) + return []database.ChatMessage{{ID: 91}}, nil + }, + ) + usageTx.EXPECT().SoftDeleteChatMessageByID(gomock.Any(), int64(91)).Return(nil) + usageTx.EXPECT().UpdateChatByID(gomock.Any(), database.UpdateChatByIDParams{ + ID: chatID, + Title: wantTitle, + }).Return(updatedChat, nil) + + lockedChatWithMarker := updatedChat + lockedChatWithMarker.WorkerID = uuid.NullUUID{UUID: manualTitleLockWorkerID, Valid: true} + unlockTx.EXPECT().GetChatByIDForUpdate(gomock.Any(), chatID).Return(lockedChatWithMarker, nil) + unlockTx.EXPECT().UpdateChatStatusPreserveUpdatedAt(gomock.Any(), gomock.AssignableToTypeOf(database.UpdateChatStatusPreserveUpdatedAtParams{})).DoAndReturn( + func(_ context.Context, arg database.UpdateChatStatusPreserveUpdatedAtParams) (database.Chat, error) { + require.Equal(t, chatID, arg.ID) + require.False(t, arg.WorkerID.Valid) + require.False(t, arg.StartedAt.Valid) + require.False(t, arg.HeartbeatAt.Valid) + return updatedChat, nil + }, + ) + + gotChat, err := server.RegenerateChatTitle(ctx, chat) + require.NoError(t, err) + require.Equal(t, updatedChat, gotChat) + + select { + case event := <-messageEvents: + require.NoError(t, event.err) + require.Equal(t, coderdpubsub.ChatEventKindTitleChange, event.payload.Kind) + require.Equal(t, chatID, event.payload.Chat.ID) + require.Equal(t, wantTitle, event.payload.Chat.Title) + case <-time.After(time.Second): + t.Fatal("timed out waiting for title change pubsub event") + } +} + func TestRefreshChatWorkspaceSnapshot_NoReloadWhenWorkspacePresent(t *testing.T) { t.Parallel() diff --git a/coderd/x/chatd/quickgen.go b/coderd/x/chatd/quickgen.go index b23a394b55..02f1f1b8a0 100644 --- a/coderd/x/chatd/quickgen.go +++ b/coderd/x/chatd/quickgen.go @@ -2,6 +2,8 @@ package chatd import ( "context" + "fmt" + "slices" "strings" "time" @@ -36,6 +38,17 @@ const titleGenerationPrompt = "You are a title generator. Your ONLY job is to ou "Output ONLY the title — no quotes, no emoji, no markdown, no code fences, " + "no trailing punctuation, no preamble, no explanation. Sentence case." +const ( + // maxConversationContextRunes caps the conversation sample in manual + // title prompts to avoid exceeding model context windows. + maxConversationContextRunes = 6000 + // maxLatestUserMessageRunes caps the latest user message excerpt. + maxLatestUserMessageRunes = 1000 + // recentTurnWindow is the number of most recent turns included + // alongside the first user turn in manual title context. + recentTurnWindow = 3 +) + // preferredTitleModels are lightweight models used for title // generation, one per provider type. Each entry uses the // cheapest/fastest small model for that provider as identified @@ -54,6 +67,33 @@ var preferredTitleModels = []struct { {fantasyvercel.Name, "anthropic/claude-haiku-4.5"}, } +func selectPreferredConfiguredShortTextModelConfig( + configs []database.ChatModelConfig, +) (database.ChatModelConfig, bool) { + for _, preferred := range preferredTitleModels { + for _, config := range configs { + if chatprovider.NormalizeProvider(config.Provider) != preferred.provider { + continue + } + if !strings.EqualFold(strings.TrimSpace(config.Model), preferred.model) { + continue + } + return config, true + } + } + return database.ChatModelConfig{}, false +} + +func normalizeShortTextOutput(text string) string { + text = strings.TrimSpace(text) + if text == "" { + return "" + } + + text = strings.Trim(text, "\"'`") + return strings.Join(strings.Fields(text), " ") +} + // maybeGenerateChatTitle generates an AI title for the chat when // appropriate (first user message, no assistant reply yet, and the // current title is either empty or still the fallback truncation). @@ -138,7 +178,7 @@ func generateTitle( model fantasy.LanguageModel, input string, ) (string, error) { - title, err := generateShortText(ctx, model, titleGenerationPrompt, input) + title, _, err := generateShortText(ctx, model, titleGenerationPrompt, input) if err != nil { return "", err } @@ -199,13 +239,10 @@ func titleInput( } func normalizeTitleOutput(title string) string { - title = strings.TrimSpace(title) + title = normalizeShortTextOutput(title) if title == "" { return "" } - - title = strings.Trim(title, "\"'`") - title = strings.Join(strings.Fields(title), " ") return truncateRunes(title, 80) } @@ -226,7 +263,7 @@ func fallbackChatTitle(message string) string { title := strings.Join(words, " ") if truncated { - title += "…" + return truncateRunes(title, maxRunes-1) + "…" } return truncateRunes(title, maxRunes) @@ -260,6 +297,186 @@ func truncateRunes(value string, maxLen int) string { return string(runes[:maxLen]) } +// Manual title regeneration is user-initiated and can use richer +// conversation context than the automatic first-message title path +// above. These helpers keep the manual prompt-building logic private +// while reusing the shared title-generation utilities in this file. +type manualTitleTurn struct { + role string + text string +} + +func extractManualTitleTurns(messages []database.ChatMessage) []manualTitleTurn { + turns := make([]manualTitleTurn, 0, len(messages)) + for _, message := range messages { + if message.Visibility == database.ChatMessageVisibilityModel { + continue + } + + role := "" + switch message.Role { + case database.ChatMessageRoleUser: + role = string(database.ChatMessageRoleUser) + case database.ChatMessageRoleAssistant: + role = string(database.ChatMessageRoleAssistant) + default: + continue + } + + parts, err := chatprompt.ParseContent(message) + if err != nil { + continue + } + + text := strings.TrimSpace(contentBlocksToText(parts)) + if text == "" { + continue + } + + turns = append(turns, manualTitleTurn{ + role: role, + text: text, + }) + } + + return turns +} + +func selectManualTitleTurnIndexes(turns []manualTitleTurn) []int { + firstUserIndex := slices.IndexFunc(turns, func(turn manualTitleTurn) bool { + return turn.role == string(database.ChatMessageRoleUser) + }) + if firstUserIndex == -1 { + return nil + } + + windowStart := max(0, len(turns)-recentTurnWindow) + selected := make([]int, 0, recentTurnWindow+1) + if firstUserIndex < windowStart { + selected = append(selected, firstUserIndex) + } + for i := windowStart; i < len(turns); i++ { + selected = append(selected, i) + } + + return selected +} + +func buildManualTitleContext( + turns []manualTitleTurn, + selected []int, +) (conversationBlock string, latestUserMsg string) { + userCount := 0 + for _, turn := range turns { + if turn.role != string(database.ChatMessageRoleUser) { + continue + } + userCount++ + latestUserMsg = turn.text + } + + latestUserMsg = truncateRunes(latestUserMsg, maxLatestUserMessageRunes) + if userCount <= 1 || len(selected) == 0 { + return "", latestUserMsg + } + + lines := make([]string, 0, len(selected)+1) + for i, idx := range selected { + if i == 1 { + if gap := idx - selected[i-1] - 1; gap > 0 { + lines = append(lines, fmt.Sprintf("[... %d earlier turns omitted ...]", gap)) + } + } + lines = append(lines, fmt.Sprintf("[%s]: %s", turns[idx].role, turns[idx].text)) + } + + conversationBlock = strings.Join(lines, "\n") + conversationBlock = truncateRunes(conversationBlock, maxConversationContextRunes) + return conversationBlock, latestUserMsg +} + +func renderManualTitlePrompt( + conversationBlock string, + firstUserText string, + latestUserMsg string, +) string { + var prompt strings.Builder + write := func(value string) { + _, _ = prompt.WriteString(value) + } + + write("You are a title generator for an AI coding assistant conversation.\n\n") + write("The user's primary objective was:\n\n") + write(firstUserText) + write("\n") + + if conversationBlock != "" { + write("\n\nConversation sample:\n\n") + write(conversationBlock) + write("\n") + } + + if strings.TrimSpace(latestUserMsg) != strings.TrimSpace(truncateRunes(firstUserText, maxLatestUserMessageRunes)) { + write("\n\nThe user's most recent message:\n\n") + write(latestUserMsg) + write("\n\n") + write("Note: Weight the overall conversation arc more heavily than just the latest message.") + } + + write("\n\nRequirements:\n") + write("- Output a short title of 2-8 words.\n") + write("- Use verb-noun format in sentence case.\n") + write("- Preserve specific identifiers (PR numbers, repo names, file paths, function names, error messages).\n") + write("- No trailing punctuation, quotes, emoji, or markdown.\n") + write("- No temporal phrasing (\"Continue\", \"Follow up on\") or meta phrasing (\"Chat about\").\n") + write("- Output ONLY the title - nothing else.\n") + return prompt.String() +} + +func generateManualTitle( + ctx context.Context, + messages []database.ChatMessage, + fallbackModel fantasy.LanguageModel, +) (string, fantasy.Usage, error) { + turns := extractManualTitleTurns(messages) + selected := selectManualTitleTurnIndexes(turns) + + firstUserIndex := slices.IndexFunc(turns, func(turn manualTitleTurn) bool { + return turn.role == string(database.ChatMessageRoleUser) + }) + if firstUserIndex == -1 { + return "", fantasy.Usage{}, nil + } + firstUserText := truncateRunes(turns[firstUserIndex].text, maxLatestUserMessageRunes) + + conversationBlock, latestUserMsg := buildManualTitleContext(turns, selected) + systemPrompt := renderManualTitlePrompt( + conversationBlock, + firstUserText, + latestUserMsg, + ) + + titleCtx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + + title, usage, err := generateShortText( + titleCtx, + fallbackModel, + systemPrompt, + "Generate the title.", + ) + if err != nil { + return "", fantasy.Usage{}, err + } + + title = normalizeTitleOutput(title) + if title == "" { + return "", usage, xerrors.New("generated title was empty") + } + + return title, usage, nil +} + const pushSummaryPrompt = "You are a notification assistant. Given a chat title " + "and the agent's last message, write a single short sentence (under 100 characters) " + "summarizing what the agent did. This will be shown as a push notification body. " + @@ -281,6 +498,7 @@ func generatePushSummary( summaryCtx, cancel := context.WithTimeout(ctx, 30*time.Second) defer cancel() + assistantText = truncateRunes(assistantText, maxConversationContextRunes) input := "Chat title: " + chat.Title + "\n\nAgent's last message:\n" + assistantText candidates := make([]fantasy.LanguageModel, 0, len(preferredTitleModels)+1) @@ -296,7 +514,7 @@ func generatePushSummary( candidates = append(candidates, fallbackModel) for _, model := range candidates { - summary, err := generateShortText(summaryCtx, model, pushSummaryPrompt, input) + summary, _, err := generateShortText(summaryCtx, model, pushSummaryPrompt, input) if err != nil { logger.Debug(ctx, "push summary model candidate failed", slog.Error(err), @@ -318,7 +536,7 @@ func generateShortText( model fantasy.LanguageModel, systemPrompt string, userInput string, -) (string, error) { +) (string, fantasy.Usage, error) { prompt := []fantasy.Message{ { Role: fantasy.MessageRoleSystem, @@ -346,7 +564,7 @@ func generateShortText( return genErr }, nil) if err != nil { - return "", xerrors.Errorf("generate short text: %w", err) + return "", fantasy.Usage{}, xerrors.Errorf("generate short text: %w", err) } responseParts := make([]codersdk.ChatMessagePart, 0, len(response.Content)) @@ -355,7 +573,6 @@ func generateShortText( responseParts = append(responseParts, p) } } - text := strings.TrimSpace(contentBlocksToText(responseParts)) - text = strings.Trim(text, "\"'`") - return text, nil + text := normalizeShortTextOutput(contentBlocksToText(responseParts)) + return text, response.Usage, nil } diff --git a/coderd/x/chatd/quickgen_test.go b/coderd/x/chatd/quickgen_test.go new file mode 100644 index 0000000000..e8dfee79a1 --- /dev/null +++ b/coderd/x/chatd/quickgen_test.go @@ -0,0 +1,584 @@ +package chatd //nolint:testpackage // Keeps internal helper tests in-package. + +import ( + "context" + "encoding/json" + "strings" + "testing" + "time" + + "charm.land/fantasy" + "github.com/sqlc-dev/pqtype" + "github.com/stretchr/testify/require" + "golang.org/x/xerrors" + + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/codersdk" +) + +func Test_extractManualTitleTurns(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + messages []database.ChatMessage + want []manualTitleTurn + }{ + { + name: "filters to visible user and assistant text turns", + messages: []database.ChatMessage{ + mustChatMessage(t, database.ChatMessageRoleUser, database.ChatMessageVisibilityBoth, + codersdk.ChatMessagePart{Type: codersdk.ChatMessagePartTypeText, Text: " review quickgen helpers "}, + ), + mustChatMessage(t, database.ChatMessageRoleAssistant, database.ChatMessageVisibilityBoth, + codersdk.ChatMessagePart{Type: codersdk.ChatMessagePartTypeText, Text: " drafted a plan "}, + ), + mustChatMessage(t, database.ChatMessageRoleSystem, database.ChatMessageVisibilityBoth, + codersdk.ChatMessagePart{Type: codersdk.ChatMessagePartTypeText, Text: "system prompt"}, + ), + mustChatMessage(t, database.ChatMessageRoleTool, database.ChatMessageVisibilityBoth, + codersdk.ChatMessagePart{Type: codersdk.ChatMessagePartTypeText, Text: "tool output"}, + ), + mustChatMessage(t, database.ChatMessageRoleUser, database.ChatMessageVisibilityModel, + codersdk.ChatMessagePart{Type: codersdk.ChatMessagePartTypeText, Text: "hidden model note"}, + ), + mustChatMessage(t, database.ChatMessageRoleUser, database.ChatMessageVisibilityBoth, + codersdk.ChatMessagePart{Type: codersdk.ChatMessagePartTypeText, Text: " "}, + ), + mustChatMessage(t, database.ChatMessageRoleAssistant, database.ChatMessageVisibilityBoth, + codersdk.ChatMessagePart{Type: codersdk.ChatMessagePartTypeReasoning, Text: "reasoning only"}, + ), + mustChatMessage(t, database.ChatMessageRoleUser, database.ChatMessageVisibilityBoth, + codersdk.ChatMessagePart{Type: codersdk.ChatMessagePartTypeFile, MediaType: "text/plain"}, + ), + }, + want: []manualTitleTurn{ + {role: "user", text: "review quickgen helpers"}, + {role: "assistant", text: "drafted a plan"}, + }, + }, + { + name: "reuses text extraction for multi-part content", + messages: []database.ChatMessage{ + mustChatMessage(t, database.ChatMessageRoleUser, database.ChatMessageVisibilityBoth, + codersdk.ChatMessagePart{Type: codersdk.ChatMessagePartTypeText, Text: "first chunk"}, + codersdk.ChatMessagePart{Type: codersdk.ChatMessagePartTypeReasoning, Text: "skip me"}, + codersdk.ChatMessagePart{Type: codersdk.ChatMessagePartTypeText, Text: " second chunk "}, + ), + }, + want: []manualTitleTurn{{role: "user", text: "first chunk second chunk"}}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + got := extractManualTitleTurns(tt.messages) + require.Equal(t, tt.want, got) + }) + } +} + +func Test_selectManualTitleTurnIndexes(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + turns []manualTitleTurn + want []int + }{ + { + name: "single user turn", + turns: []manualTitleTurn{ + {role: "user", text: "one"}, + }, + want: []int{0}, + }, + { + name: "first user plus trailing window", + turns: []manualTitleTurn{ + {role: "user", text: "one"}, + {role: "assistant", text: "two"}, + {role: "user", text: "three"}, + {role: "assistant", text: "four"}, + {role: "user", text: "five"}, + }, + want: []int{0, 2, 3, 4}, + }, + { + name: "two turns returns both", + turns: []manualTitleTurn{ + {role: "user", text: "one"}, + {role: "assistant", text: "two"}, + }, + want: []int{0, 1}, + }, + { + name: "prepends first user when before trailing window", + turns: []manualTitleTurn{ + {role: "assistant", text: "intro"}, + {role: "assistant", text: "setup"}, + {role: "user", text: "goal"}, + {role: "assistant", text: "a"}, + {role: "assistant", text: "b"}, + {role: "assistant", text: "c"}, + }, + want: []int{2, 3, 4, 5}, + }, + { + name: "ten plus turns keeps first user and last three", + turns: []manualTitleTurn{ + {role: "assistant", text: "0"}, + {role: "assistant", text: "1"}, + {role: "user", text: "2"}, + {role: "assistant", text: "3"}, + {role: "assistant", text: "4"}, + {role: "assistant", text: "5"}, + {role: "assistant", text: "6"}, + {role: "assistant", text: "7"}, + {role: "assistant", text: "8"}, + {role: "user", text: "9"}, + {role: "assistant", text: "10"}, + {role: "user", text: "11"}, + }, + want: []int{2, 9, 10, 11}, + }, + { + name: "no user turns", + turns: []manualTitleTurn{ + {role: "assistant", text: "one"}, + {role: "assistant", text: "two"}, + }, + want: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + got := selectManualTitleTurnIndexes(tt.turns) + require.Equal(t, tt.want, got) + }) + } +} + +func Test_buildManualTitleContext(t *testing.T) { + t.Parallel() + + longConversationText := strings.Repeat("a", 3500) + longLatestUserText := strings.Repeat("z", 1200) + + tests := []struct { + name string + turns []manualTitleTurn + selected []int + wantConversation string + wantConversationEmpty bool + wantConversationHasGap bool + wantConversationRunes int + wantLatestUser string + wantLatestUserRunes int + wantLatestUserContains string + wantLatestUserNotEmpty bool + }{ + { + name: "adds gap marker when selected turns skip earlier context", + turns: []manualTitleTurn{ + {role: "user", text: "open pull request"}, + {role: "assistant", text: "checked CI"}, + {role: "user", text: "review logs"}, + {role: "assistant", text: "found flaky test"}, + {role: "user", text: "update chat title"}, + }, + selected: []int{0, 3, 4}, + wantConversationHasGap: true, + wantLatestUser: "update chat title", + }, + { + name: "omits gap marker for contiguous selection", + turns: []manualTitleTurn{ + {role: "user", text: "open pull request"}, + {role: "assistant", text: "checked CI"}, + {role: "user", text: "update chat title"}, + }, + selected: []int{0, 1, 2}, + wantConversation: "[user]: open pull request\n[assistant]: checked CI\n[user]: update chat title", + wantConversationHasGap: false, + wantLatestUser: "update chat title", + }, + { + name: "single useful user turn returns empty conversation block", + turns: []manualTitleTurn{{role: "user", text: "rename helper"}}, + selected: []int{0}, + wantConversationEmpty: true, + wantLatestUser: "rename helper", + }, + { + name: "truncates conversation block at six thousand runes", + turns: []manualTitleTurn{ + {role: "user", text: longConversationText}, + {role: "assistant", text: longConversationText}, + {role: "user", text: "latest"}, + }, + selected: []int{0, 1, 2}, + wantConversationRunes: 6000, + wantLatestUser: "latest", + }, + { + name: "truncates latest user message at one thousand runes", + turns: []manualTitleTurn{ + {role: "user", text: "first"}, + {role: "assistant", text: "reply"}, + {role: "user", text: longLatestUserText}, + }, + selected: []int{0, 1, 2}, + wantLatestUserRunes: 1000, + wantLatestUserContains: strings.Repeat("z", 1000), + wantLatestUserNotEmpty: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + conversationBlock, latestUserMsg := buildManualTitleContext(tt.turns, tt.selected) + + if tt.wantConversationEmpty { + require.Empty(t, conversationBlock) + } + if tt.wantConversation != "" { + require.Equal(t, tt.wantConversation, conversationBlock) + } + if tt.wantConversationHasGap { + require.Contains(t, conversationBlock, "[... 2 earlier turns omitted ...]") + } else if !tt.wantConversationEmpty { + require.NotContains(t, conversationBlock, "earlier turns omitted") + } + if tt.wantConversationRunes > 0 { + require.Len(t, []rune(conversationBlock), tt.wantConversationRunes) + } + if tt.wantLatestUser != "" { + require.Equal(t, tt.wantLatestUser, latestUserMsg) + } + if tt.wantLatestUserRunes > 0 { + require.Len(t, []rune(latestUserMsg), tt.wantLatestUserRunes) + } + if tt.wantLatestUserContains != "" { + require.Equal(t, tt.wantLatestUserContains, latestUserMsg) + } + if tt.wantLatestUserNotEmpty { + require.NotEmpty(t, latestUserMsg) + } + }) + } +} + +func Test_renderManualTitlePrompt(t *testing.T) { + t.Parallel() + + longFirstUserText := strings.Repeat("b", 1501) + + tests := []struct { + name string + conversationBlock string + firstUserText string + latestUserMsg string + wantConversationSample bool + wantLatestSection bool + }{ + { + name: "includes conversation sample when provided", + conversationBlock: "[user]: inspect logs\n[assistant]: found flaky test", + firstUserText: "inspect logs", + latestUserMsg: "update quickgen title", + wantConversationSample: true, + wantLatestSection: true, + }, + { + name: "omits optional sections when not needed", + conversationBlock: "", + firstUserText: "inspect logs", + latestUserMsg: "inspect logs", + wantConversationSample: false, + wantLatestSection: false, + }, + { + name: "latest section compares trimmed text", + conversationBlock: "", + firstUserText: "inspect logs", + latestUserMsg: " inspect logs ", + wantConversationSample: false, + wantLatestSection: false, + }, + { + name: "omits latest section when same message truncated", + conversationBlock: "", + firstUserText: longFirstUserText, + latestUserMsg: truncateRunes(longFirstUserText, 1000), + wantConversationSample: false, + wantLatestSection: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + prompt := renderManualTitlePrompt(tt.conversationBlock, tt.firstUserText, tt.latestUserMsg) + + require.Contains(t, prompt, "The user's primary objective was:") + require.Contains(t, prompt, "Requirements:") + require.Contains(t, prompt, "- Output a short title of 2-8 words.") + require.Contains(t, prompt, "- Output ONLY the title - nothing else.") + + if tt.wantConversationSample { + require.Contains(t, prompt, "Conversation sample:") + require.Contains(t, prompt, tt.conversationBlock) + } else { + require.NotContains(t, prompt, "Conversation sample:") + } + + if tt.wantLatestSection { + require.Contains(t, prompt, "The user's most recent message:") + require.Contains(t, prompt, "Note: Weight the overall conversation arc more heavily than just the latest message.") + require.Contains(t, prompt, strings.TrimSpace(tt.latestUserMsg)) + } else { + require.NotContains(t, prompt, "The user's most recent message:") + require.NotContains(t, prompt, "Weight the overall conversation arc more heavily") + } + }) + } +} + +func Test_generateManualTitle_UsesTimeout(t *testing.T) { + t.Parallel() + + messages := []database.ChatMessage{ + mustChatMessage( + t, + database.ChatMessageRoleUser, + database.ChatMessageVisibilityBoth, + codersdk.ChatMessageText("refresh chat title"), + ), + } + + model := &stubModel{ + generateFn: func(ctx context.Context, call fantasy.Call) (*fantasy.Response, error) { + deadline, ok := ctx.Deadline() + require.True(t, ok, "manual title generation should set a deadline") + require.WithinDuration( + t, + time.Now().Add(30*time.Second), + deadline, + 2*time.Second, + ) + require.Len(t, call.Prompt, 2) + return &fantasy.Response{ + Content: fantasy.ResponseContent{ + fantasy.TextContent{Text: "Refresh title"}, + }, + }, nil + }, + } + + title, _, err := generateManualTitle( + context.Background(), + messages, + model, + ) + require.NoError(t, err) + require.Equal(t, "Refresh title", title) +} + +func Test_generateManualTitle_TruncatesFirstUserInput(t *testing.T) { + t.Parallel() + + longFirstUserText := strings.Repeat("a", 1500) + messages := []database.ChatMessage{ + mustChatMessage( + t, + database.ChatMessageRoleUser, + database.ChatMessageVisibilityBoth, + codersdk.ChatMessageText(longFirstUserText), + ), + } + + model := &stubModel{ + generateFn: func(_ context.Context, call fantasy.Call) (*fantasy.Response, error) { + require.Len(t, call.Prompt, 2) + systemText, ok := call.Prompt[0].Content[0].(fantasy.TextPart) + require.True(t, ok) + require.Contains(t, systemText.Text, truncateRunes(longFirstUserText, 1000)) + + userText, ok := call.Prompt[1].Content[0].(fantasy.TextPart) + require.True(t, ok) + require.Equal(t, "Generate the title.", userText.Text) + return &fantasy.Response{ + Content: fantasy.ResponseContent{ + fantasy.TextContent{Text: "Refresh title"}, + }, + }, nil + }, + } + + _, _, err := generateManualTitle( + context.Background(), + messages, + model, + ) + require.NoError(t, err) +} + +func Test_generateManualTitle_ReturnsUsageForEmptyNormalizedTitle(t *testing.T) { + t.Parallel() + + messages := []database.ChatMessage{ + mustChatMessage( + t, + database.ChatMessageRoleUser, + database.ChatMessageVisibilityBoth, + codersdk.ChatMessageText("refresh chat title"), + ), + } + + model := &stubModel{ + generateFn: func(_ context.Context, _ fantasy.Call) (*fantasy.Response, error) { + return &fantasy.Response{ + Content: fantasy.ResponseContent{ + fantasy.TextContent{Text: "\"\""}, + }, + Usage: fantasy.Usage{ + InputTokens: 11, + OutputTokens: 7, + TotalTokens: 18, + }, + }, nil + }, + } + + _, usage, err := generateManualTitle( + context.Background(), + messages, + model, + ) + require.ErrorContains(t, err, "generated title was empty") + require.Equal(t, int64(11), usage.InputTokens) + require.Equal(t, int64(7), usage.OutputTokens) + require.Equal(t, int64(18), usage.TotalTokens) +} + +func Test_selectPreferredConfiguredShortTextModelConfig(t *testing.T) { + t.Parallel() + + t.Run("chooses the highest-priority configured lightweight model", func(t *testing.T) { + t.Parallel() + + configs := []database.ChatModelConfig{ + {Provider: preferredTitleModels[2].provider, Model: preferredTitleModels[2].model}, + {Provider: preferredTitleModels[1].provider, Model: preferredTitleModels[1].model}, + {Provider: "openai", Model: "gpt-4.1"}, + } + + got, ok := selectPreferredConfiguredShortTextModelConfig(configs) + require.True(t, ok) + require.Equal(t, preferredTitleModels[1].provider, got.Provider) + require.Equal(t, preferredTitleModels[1].model, got.Model) + }) + + t.Run("returns false when no preferred lightweight model is configured", func(t *testing.T) { + t.Parallel() + + got, ok := selectPreferredConfiguredShortTextModelConfig([]database.ChatModelConfig{{ + Provider: "openai", + Model: "gpt-4.1", + }}) + require.False(t, ok) + require.Equal(t, database.ChatModelConfig{}, got) + }) +} + +func Test_generateShortText_NormalizesQuotedOutput(t *testing.T) { + t.Parallel() + + model := &stubModel{ + generateFn: func(_ context.Context, _ fantasy.Call) (*fantasy.Response, error) { + return &fantasy.Response{ + Content: fantasy.ResponseContent{ + fantasy.TextContent{Text: " \"Quoted summary\" "}, + }, + Usage: fantasy.Usage{InputTokens: 3, OutputTokens: 2, TotalTokens: 5}, + }, nil + }, + } + + text, usage, err := generateShortText(context.Background(), model, "system", "user") + require.NoError(t, err) + require.Equal(t, "Quoted summary", text) + require.Equal(t, int64(3), usage.InputTokens) + require.Equal(t, int64(2), usage.OutputTokens) + require.Equal(t, int64(5), usage.TotalTokens) +} + +type stubModel struct { + generateFn func(context.Context, fantasy.Call) (*fantasy.Response, error) +} + +func (m *stubModel) Generate( + ctx context.Context, + call fantasy.Call, +) (*fantasy.Response, error) { + return m.generateFn(ctx, call) +} + +func (*stubModel) Stream( + context.Context, + fantasy.Call, +) (fantasy.StreamResponse, error) { + return nil, xerrors.New("stream not implemented") +} + +func (*stubModel) GenerateObject( + context.Context, + fantasy.ObjectCall, +) (*fantasy.ObjectResponse, error) { + return nil, xerrors.New("generate object not implemented") +} + +func (*stubModel) StreamObject( + context.Context, + fantasy.ObjectCall, +) (fantasy.ObjectStreamResponse, error) { + return nil, xerrors.New("stream object not implemented") +} + +func (*stubModel) Provider() string { + return "test" +} + +func (*stubModel) Model() string { + return "test" +} + +func mustChatMessage( + t *testing.T, + role database.ChatMessageRole, + visibility database.ChatMessageVisibility, + parts ...codersdk.ChatMessagePart, +) database.ChatMessage { + t.Helper() + + content, err := json.Marshal(parts) + require.NoError(t, err) + + return database.ChatMessage{ + Role: role, + Visibility: visibility, + Content: pqtype.NullRawMessage{ + RawMessage: content, + Valid: len(content) > 0, + }, + } +} diff --git a/codersdk/chats.go b/codersdk/chats.go index 75223858e1..ddc06dd670 100644 --- a/codersdk/chats.go +++ b/codersdk/chats.go @@ -1870,6 +1870,21 @@ func (c *ExperimentalClient) InterruptChat(ctx context.Context, chatID uuid.UUID return chat, json.NewDecoder(res.Body).Decode(&chat) } +// RegenerateChatTitle requests the server to regenerate the chat's +// title using richer conversation context. +func (c *ExperimentalClient) RegenerateChatTitle(ctx context.Context, chatID uuid.UUID) (Chat, error) { + res, err := c.Request(ctx, http.MethodPost, fmt.Sprintf("/api/experimental/chats/%s/title/regenerate", chatID), nil) + if err != nil { + return Chat{}, err + } + defer res.Body.Close() + if res.StatusCode != http.StatusOK { + return Chat{}, readBodyAsChatUsageLimitError(res) + } + var chat Chat + return chat, json.NewDecoder(res.Body).Decode(&chat) +} + // GetChatGitChanges returns git changes for a chat. func (c *ExperimentalClient) GetChatGitChanges(ctx context.Context, chatID uuid.UUID) ([]ChatGitChange, error) { res, err := c.Request(ctx, http.MethodGet, fmt.Sprintf("/api/experimental/chats/%s/git-changes", chatID), nil) diff --git a/site/src/api/api.ts b/site/src/api/api.ts index 456cc45228..7fff88626b 100644 --- a/site/src/api/api.ts +++ b/site/src/api/api.ts @@ -3126,6 +3126,13 @@ class ExperimentalApiMethods { await this.axios.patch(`/api/experimental/chats/${chatId}`, req); }; + regenerateChatTitle = async (chatId: string): Promise => { + const response = await this.axios.post( + `/api/experimental/chats/${chatId}/title/regenerate`, + ); + return response.data; + }; + createChatMessage = async ( chatId: string, req: TypesGen.CreateChatMessageRequest, diff --git a/site/src/api/queries/chats.test.ts b/site/src/api/queries/chats.test.ts index fece4b5120..c119d60578 100644 --- a/site/src/api/queries/chats.test.ts +++ b/site/src/api/queries/chats.test.ts @@ -22,6 +22,7 @@ import { invalidateChatListQueries, pinChat, promoteChatQueuedMessage, + regenerateChatTitle, reorderPinnedChat, unarchiveChat, unpinChat, @@ -41,6 +42,7 @@ vi.mock("api/api", () => ({ editChatMessage: vi.fn(), interruptChat: vi.fn(), promoteChatQueuedMessage: vi.fn(), + regenerateChatTitle: vi.fn(), }, }, })); @@ -556,6 +558,51 @@ describe("reorderPinnedChat", () => { }); }); +describe("regenerateChatTitle cache updates", () => { + it("preserves existing chat detail fields when the response is partial", () => { + const queryClient = createTestQueryClient(); + const chatId = "chat-1"; + const cachedChat = makeChat(chatId, { + diff_status: { + chat_id: chatId, + url: "https://example.com/pr/1", + pull_request_state: "open", + pull_request_title: "", + pull_request_draft: false, + changes_requested: false, + additions: 1, + deletions: 2, + changed_files: 3, + refreshed_at: "2025-01-01T00:00:00.000Z", + stale_at: "2025-01-01T01:00:00.000Z", + }, + }); + queryClient.setQueryData(chatKey(chatId), cachedChat); + seedInfiniteChats(queryClient, [cachedChat]); + + const mutation = regenerateChatTitle(queryClient); + const updatedChat = { + id: chatId, + title: "New title", + } satisfies Partial; + + mutation.onSuccess(updatedChat as TypesGen.Chat); + + const cachedDetail = queryClient.getQueryData( + chatKey(chatId), + ); + expect(cachedDetail).toEqual({ + ...cachedChat, + title: "New title", + }); + expect(cachedDetail?.diff_status).toEqual(cachedChat.diff_status); + expect(readInfiniteChats(queryClient)?.[0]).toMatchObject({ + id: chatId, + title: "New title", + }); + }); +}); + describe("chat cost query factories", () => { it("builds the summary query key and forwards snake_case params", async () => { const user = "user-1"; diff --git a/site/src/api/queries/chats.ts b/site/src/api/queries/chats.ts index e12d552072..fead88a404 100644 --- a/site/src/api/queries/chats.ts +++ b/site/src/api/queries/chats.ts @@ -476,6 +476,37 @@ export const reorderPinnedChat = (queryClient: QueryClient) => ({ }, }); +export const regenerateChatTitle = (queryClient: QueryClient) => ({ + mutationFn: (chatId: string) => API.experimental.regenerateChatTitle(chatId), + + onSuccess: (updatedChat: TypesGen.Chat) => { + queryClient.setQueryData( + chatKey(updatedChat.id), + (previousChat) => + previousChat ? { ...previousChat, ...updatedChat } : updatedChat, + ); + updateInfiniteChatsCache(queryClient, (chats) => + chats.map((chat) => + chat.id === updatedChat.id + ? { ...chat, title: updatedChat.title } + : chat, + ), + ); + }, + + onSettled: async ( + _data: TypesGen.Chat | undefined, + _error: unknown, + chatId: string, + ) => { + await invalidateChatListQueries(queryClient); + await queryClient.invalidateQueries({ + queryKey: chatKey(chatId), + exact: true, + }); + }, +}); + export const createChat = (queryClient: QueryClient) => ({ mutationFn: (req: TypesGen.CreateChatRequest) => API.experimental.createChat(req), diff --git a/site/src/pages/AgentsPage/AgentDetail.stories.tsx b/site/src/pages/AgentsPage/AgentDetail.stories.tsx index 44bd0be38c..8857190a0d 100644 --- a/site/src/pages/AgentsPage/AgentDetail.stories.tsx +++ b/site/src/pages/AgentsPage/AgentDetail.stories.tsx @@ -55,6 +55,9 @@ const AgentDetailLayout: FC = () => { requestUnarchiveAgent: () => {}, requestPinAgent: () => {}, requestUnpinAgent: () => {}, + onRegenerateTitle: () => {}, + isRegeneratingTitle: false, + regeneratingTitleChatId: null, isSidebarCollapsed: false, onToggleSidebarCollapsed: () => {}, onExpandSidebar: () => {}, diff --git a/site/src/pages/AgentsPage/AgentDetail.tsx b/site/src/pages/AgentsPage/AgentDetail.tsx index 03990e60eb..839e42cf08 100644 --- a/site/src/pages/AgentsPage/AgentDetail.tsx +++ b/site/src/pages/AgentsPage/AgentDetail.tsx @@ -292,6 +292,9 @@ const AgentDetail: FC = () => { requestArchiveAgent, requestArchiveAndDeleteWorkspace, requestUnarchiveAgent, + onRegenerateTitle, + isRegeneratingTitle, + regeneratingTitleChatId, isSidebarCollapsed, onToggleSidebarCollapsed, onChatReady, @@ -326,6 +329,9 @@ const AgentDetail: FC = () => { }); }; + const isRegeneratingThisChat = + isRegeneratingTitle && regeneratingTitleChatId === agentId; + const chatQuery = useQuery({ ...chat(agentId ?? ""), enabled: Boolean(agentId), @@ -480,6 +486,7 @@ const AgentDetail: FC = () => { } : undefined; const isArchived = chatRecord?.archived ?? false; + const isRegenerateTitleDisabled = isArchived || isRegeneratingTitle; const chatLastModelConfigID = chatRecord?.last_model_config_id; const sendMutation = useMutation( @@ -887,6 +894,13 @@ const AgentDetail: FC = () => { onChatReady(); }, [onChatReady, chatMessagesQuery.isSuccess, agentId]); + const handleRegenerateTitle = () => { + if (!agentId || isRegenerateTitleDisabled || !onRegenerateTitle) { + return; + } + onRegenerateTitle(agentId); + }; + if (chatQuery.isLoading || chatMessagesQuery.isLoading) { return ( { handleArchiveAndDeleteWorkspaceAction={ handleArchiveAndDeleteWorkspaceAction } + handleRegenerateTitle={handleRegenerateTitle} + isRegeneratingTitle={isRegeneratingThisChat} + isRegenerateTitleDisabled={isRegenerateTitleDisabled} urlTransform={urlTransform} scrollContainerRef={scrollContainerRef} scrollToBottomRef={scrollToBottomRef} diff --git a/site/src/pages/AgentsPage/AgentEmbedPage.tsx b/site/src/pages/AgentsPage/AgentEmbedPage.tsx index f0d259fc3d..f66082a93e 100644 --- a/site/src/pages/AgentsPage/AgentEmbedPage.tsx +++ b/site/src/pages/AgentsPage/AgentEmbedPage.tsx @@ -233,6 +233,9 @@ const AgentEmbedPage: FC = () => { requestPinAgent: () => {}, requestUnpinAgent: () => {}, requestArchiveAndDeleteWorkspace, + // Title regeneration is not supported in embed mode. + isRegeneratingTitle: false, + regeneratingTitleChatId: null, isSidebarCollapsed, onToggleSidebarCollapsed, onExpandSidebar: () => {}, diff --git a/site/src/pages/AgentsPage/AgentsPage.tsx b/site/src/pages/AgentsPage/AgentsPage.tsx index bcd47ce600..cb36584680 100644 --- a/site/src/pages/AgentsPage/AgentsPage.tsx +++ b/site/src/pages/AgentsPage/AgentsPage.tsx @@ -22,6 +22,7 @@ import { pinChat, prependToInfiniteChatsCache, readInfiniteChatsCache, + regenerateChatTitle, reorderPinnedChat, unarchiveChat, unpinChat, @@ -220,6 +221,12 @@ const AgentsPage: FC = () => { toast.error(getErrorMessage(error, "Failed to reorder pinned agents.")); }, }); + const regenerateTitleMutation = useMutation({ + ...regenerateChatTitle(queryClient), + onError: (error: unknown) => { + toast.error(getErrorMessage(error, "Failed to generate new title.")); + }, + }); const [isSidebarCollapsed, setIsSidebarCollapsed] = useState(false); const [chatErrorReasons, setChatErrorReasons] = useState< Record @@ -359,6 +366,12 @@ const AgentsPage: FC = () => { const requestReorderPinnedAgent = (chatId: string, pinOrder: number) => { reorderPinnedChatMutation.mutate({ chatId, pinOrder }); }; + const requestRegenerateTitle = (chatId: string) => { + if (regenerateTitleMutation.isPending) { + return; + } + regenerateTitleMutation.mutate(chatId); + }; const handleToggleSidebarCollapsed = () => setIsSidebarCollapsed((prev) => !prev); @@ -609,6 +622,9 @@ const AgentsPage: FC = () => { requestPinAgent={requestPinAgent} requestUnpinAgent={requestUnpinAgent} requestReorderPinnedAgent={requestReorderPinnedAgent} + onRegenerateTitle={requestRegenerateTitle} + isRegeneratingTitle={regenerateTitleMutation.isPending} + regeneratingTitleChatId={regenerateTitleMutation.variables ?? null} onToggleSidebarCollapsed={handleToggleSidebarCollapsed} isAgentsAdmin={isAgentsAdmin} hasNextPage={chatsQuery.hasNextPage} diff --git a/site/src/pages/AgentsPage/AgentsPageView.tsx b/site/src/pages/AgentsPage/AgentsPageView.tsx index f598f4bcbc..284688da8e 100644 --- a/site/src/pages/AgentsPage/AgentsPageView.tsx +++ b/site/src/pages/AgentsPage/AgentsPageView.tsx @@ -23,6 +23,9 @@ export interface AgentsOutletContext { requestPinAgent: (chatId: string) => void; requestUnpinAgent: (chatId: string) => void; requestReorderPinnedAgent?: (chatId: string, pinOrder: number) => void; + onRegenerateTitle?: (chatId: string) => void; + isRegeneratingTitle: boolean; + regeneratingTitleChatId: string | null; isSidebarCollapsed: boolean; onToggleSidebarCollapsed: () => void; onExpandSidebar: () => void; @@ -59,6 +62,9 @@ interface AgentsPageViewProps { requestPinAgent: (chatId: string) => void; requestUnpinAgent: (chatId: string) => void; requestReorderPinnedAgent?: (chatId: string, pinOrder: number) => void; + onRegenerateTitle: (chatId: string) => void; + isRegeneratingTitle: boolean; + regeneratingTitleChatId: string | null; onToggleSidebarCollapsed: () => void; isAgentsAdmin: boolean; hasNextPage: boolean | undefined; @@ -93,6 +99,9 @@ export const AgentsPageView: FC = ({ requestPinAgent, requestUnpinAgent, requestReorderPinnedAgent, + onRegenerateTitle, + isRegeneratingTitle, + regeneratingTitleChatId, onToggleSidebarCollapsed, isAgentsAdmin, hasNextPage, @@ -133,6 +142,9 @@ export const AgentsPageView: FC = ({ requestPinAgent, requestUnpinAgent, requestReorderPinnedAgent, + onRegenerateTitle, + isRegeneratingTitle, + regeneratingTitleChatId, isSidebarCollapsed, onToggleSidebarCollapsed, onExpandSidebar, @@ -166,6 +178,9 @@ export const AgentsPageView: FC = ({ onPinAgent={requestPinAgent} onUnpinAgent={requestUnpinAgent} onReorderPinnedAgent={requestReorderPinnedAgent} + onRegenerateTitle={onRegenerateTitle} + isRegeneratingTitle={isRegeneratingTitle} + regeneratingTitleChatId={regeneratingTitleChatId} onBeforeNewAgent={handleNewAgent} isCreating={isCreating} isArchiving={isArchiving} diff --git a/site/src/pages/AgentsPage/components/AgentDetail/TopBar.stories.tsx b/site/src/pages/AgentsPage/components/AgentDetail/TopBar.stories.tsx index a71ed1fce9..8f2e8222c7 100644 --- a/site/src/pages/AgentsPage/components/AgentDetail/TopBar.stories.tsx +++ b/site/src/pages/AgentsPage/components/AgentDetail/TopBar.stories.tsx @@ -1,26 +1,27 @@ import type { Meta, StoryObj } from "@storybook/react-vite"; -import { expect, userEvent, waitFor, within } from "storybook/test"; +import { expect, fn, userEvent, waitFor, within } from "storybook/test"; import { AgentDetailTopBar } from "./TopBar"; const defaultProps = { chatTitle: "Build authentication feature", panel: { showSidebarPanel: false, - onToggleSidebar: () => {}, + onToggleSidebar: fn(), }, workspace: { canOpenEditors: true, canOpenWorkspace: true, - onOpenInEditor: () => {}, - onViewWorkspace: () => {}, - onOpenTerminal: () => {}, + onOpenInEditor: fn(), + onViewWorkspace: fn(), + onOpenTerminal: fn(), sshCommand: "ssh main.my-workspace.admin.coder", }, - onArchiveAgent: () => {}, - onArchiveAndDeleteWorkspace: () => {}, - onUnarchiveAgent: () => {}, + onArchiveAgent: fn(), + onArchiveAndDeleteWorkspace: fn(), + onRegenerateTitle: fn(), + onUnarchiveAgent: fn(), isSidebarCollapsed: false, - onToggleSidebarCollapsed: () => {}, + onToggleSidebarCollapsed: fn(), } satisfies React.ComponentProps; const meta: Meta = { @@ -36,6 +37,13 @@ type Story = StoryObj; export const Default: Story = {}; +export const RegeneratingTitle: Story = { + args: { + ...Default.args, + isRegeneratingTitle: true, + }, +}; + export const WithPanelOpen: Story = { args: { panel: { @@ -227,10 +235,22 @@ export const MobileWithClosedPR: Story = { }, }; +export const GenerateTitle: Story = { + play: async ({ canvasElement }) => { + const canvas = within(canvasElement); + const trigger = canvas.getByLabelText("Open agent actions"); + await userEvent.click(trigger); + await waitFor(() => { + const body = within(document.body); + expect(body.getByText("Generate new title")).toBeInTheDocument(); + }); + }, +}; + export const ArchivedWithUnarchive: Story = { args: { isArchived: true, - onUnarchiveAgent: () => {}, + onUnarchiveAgent: fn(), }, play: async ({ canvasElement }) => { const canvas = within(canvasElement); @@ -243,6 +263,7 @@ export const ArchivedWithUnarchive: Story = { expect(body.getByText("Unarchive Agent")).toBeInTheDocument(); }); const body = within(document.body); + expect(body.queryByText("Generate new title")).not.toBeInTheDocument(); expect(body.queryByText("Archive Agent")).not.toBeInTheDocument(); expect( body.queryByText("Archive & Delete Workspace"), diff --git a/site/src/pages/AgentsPage/components/AgentDetail/TopBar.tsx b/site/src/pages/AgentsPage/components/AgentDetail/TopBar.tsx index 4f55f59d98..f3ce3240e7 100644 --- a/site/src/pages/AgentsPage/components/AgentDetail/TopBar.tsx +++ b/site/src/pages/AgentsPage/components/AgentDetail/TopBar.tsx @@ -12,10 +12,12 @@ import { PanelRightOpenIcon, TerminalIcon, Trash2Icon, + WandSparklesIcon, } from "lucide-react"; import type { FC } from "react"; import { Link } from "react-router"; import { toast } from "sonner"; +import { cn } from "utils/cn"; import type * as TypesGen from "#/api/typesGenerated"; import type { ChatDiffStatus } from "#/api/typesGenerated"; import { Button } from "#/components/Button/Button"; @@ -26,7 +28,7 @@ import { DropdownMenuSeparator, DropdownMenuTrigger, } from "#/components/DropdownMenu/DropdownMenu"; -import { cn } from "#/utils/cn"; +import { Spinner } from "#/components/Spinner/Spinner"; import { parsePullRequestUrl } from "../../utils/pullRequest"; import { useEmbedContext } from "../EmbedContext"; import { PrStateIcon } from "../GitPanel/GitPanel"; @@ -53,6 +55,9 @@ type AgentDetailTopBarProps = { onArchiveAgent: () => void; onUnarchiveAgent: () => void; onArchiveAndDeleteWorkspace: () => void; + onRegenerateTitle?: () => void; + isRegeneratingTitle?: boolean; + isRegenerateTitleDisabled?: boolean; hasWorkspace?: boolean; isArchived?: boolean; isSidebarCollapsed: boolean; @@ -68,6 +73,9 @@ export const AgentDetailTopBar: FC = ({ onArchiveAgent, onUnarchiveAgent, onArchiveAndDeleteWorkspace, + onRegenerateTitle, + isRegeneratingTitle, + isRegenerateTitleDisabled, hasWorkspace, isArchived, isSidebarCollapsed, @@ -115,7 +123,12 @@ export const AgentDetailTopBar: FC = ({ {/* Title area */}
{chatTitle && ( -
+
{parentChat && ( <>
)}
@@ -227,7 +252,23 @@ export const AgentDetailTopBar: FC = ({ View Workspace - + {!isArchived && ( + <> + + {onRegenerateTitle && ( + <> + + + Generate new title + + + + )} + + )} {isArchived ? ( diff --git a/site/src/pages/AgentsPage/components/AgentDetailView.stories.tsx b/site/src/pages/AgentsPage/components/AgentDetailView.stories.tsx index 6894dd9510..8f6d62f9c5 100644 --- a/site/src/pages/AgentsPage/components/AgentDetailView.stories.tsx +++ b/site/src/pages/AgentsPage/components/AgentDetailView.stories.tsx @@ -142,6 +142,7 @@ const StoryAgentDetailView: FC = ({ editing, ...overrides }) => { handleArchiveAgentAction: fn(), handleUnarchiveAgentAction: fn(), handleArchiveAndDeleteWorkspaceAction: fn(), + handleRegenerateTitle: fn(), scrollContainerRef: { current: null }, hasMoreMessages: false, isFetchingMoreMessages: false, diff --git a/site/src/pages/AgentsPage/components/AgentDetailView.tsx b/site/src/pages/AgentsPage/components/AgentDetailView.tsx index 8b095b738a..fd3ff693a4 100644 --- a/site/src/pages/AgentsPage/components/AgentDetailView.tsx +++ b/site/src/pages/AgentsPage/components/AgentDetailView.tsx @@ -117,6 +117,9 @@ interface AgentDetailViewProps { handleArchiveAgentAction: () => void; handleUnarchiveAgentAction: () => void; handleArchiveAndDeleteWorkspaceAction: () => void; + handleRegenerateTitle?: () => void; + isRegeneratingTitle?: boolean; + isRegenerateTitleDisabled?: boolean; // Scroll container ref. scrollContainerRef: RefObject; @@ -179,6 +182,9 @@ export const AgentDetailView: FC = ({ handleArchiveAgentAction, handleUnarchiveAgentAction, handleArchiveAndDeleteWorkspaceAction, + handleRegenerateTitle, + isRegeneratingTitle, + isRegenerateTitleDisabled, scrollContainerRef, scrollToBottomRef, hasMoreMessages, @@ -261,6 +267,11 @@ export const AgentDetailView: FC = ({ onArchiveAndDeleteWorkspace={ handleArchiveAndDeleteWorkspaceAction } + {...(handleRegenerateTitle + ? { onRegenerateTitle: handleRegenerateTitle } + : {})} + isRegeneratingTitle={isRegeneratingTitle} + isRegenerateTitleDisabled={isRegenerateTitleDisabled} hasWorkspace={hasWorkspace} isArchived={isArchived} diffStatusData={diffStatusData} @@ -435,6 +446,7 @@ export const AgentDetailLoadingView: FC = ({ }} onArchiveAgent={() => {}} onUnarchiveAgent={() => {}} + onRegenerateTitle={() => {}} onArchiveAndDeleteWorkspace={() => {}} hasWorkspace={false} isSidebarCollapsed={isSidebarCollapsed} @@ -507,6 +519,7 @@ export const AgentDetailNotFoundView: FC = ({ }} onArchiveAgent={() => {}} onUnarchiveAgent={() => {}} + onRegenerateTitle={() => {}} onArchiveAndDeleteWorkspace={() => {}} hasWorkspace={false} isSidebarCollapsed={isSidebarCollapsed} diff --git a/site/src/pages/AgentsPage/components/Sidebar/AgentsSidebar.stories.tsx b/site/src/pages/AgentsPage/components/Sidebar/AgentsSidebar.stories.tsx index 7c0ef418d5..294820d1b2 100644 --- a/site/src/pages/AgentsPage/components/Sidebar/AgentsSidebar.stories.tsx +++ b/site/src/pages/AgentsPage/components/Sidebar/AgentsSidebar.stories.tsx @@ -74,8 +74,11 @@ const meta: Meta = { onArchiveAndDeleteWorkspace: fn(), onPinAgent: fn(), onUnpinAgent: fn(), + onRegenerateTitle: fn(), onBeforeNewAgent: fn(), isCreating: false, + isRegeneratingTitle: false, + regeneratingTitleChatId: null, archivedFilter: "active" as const, onArchivedFilterChange: fn(), }, diff --git a/site/src/pages/AgentsPage/components/Sidebar/AgentsSidebar.test.tsx b/site/src/pages/AgentsPage/components/Sidebar/AgentsSidebar.test.tsx index cf883ae12b..50d2829f60 100644 --- a/site/src/pages/AgentsPage/components/Sidebar/AgentsSidebar.test.tsx +++ b/site/src/pages/AgentsPage/components/Sidebar/AgentsSidebar.test.tsx @@ -108,6 +108,9 @@ const defaultProps: React.ComponentProps = { onArchiveAndDeleteWorkspace: vi.fn(), onPinAgent: vi.fn(), onUnpinAgent: vi.fn(), + onRegenerateTitle: vi.fn(), + isRegeneratingTitle: false, + regeneratingTitleChatId: null, onBeforeNewAgent: vi.fn(), isCreating: false, archivedFilter: "active" as const, diff --git a/site/src/pages/AgentsPage/components/Sidebar/AgentsSidebar.tsx b/site/src/pages/AgentsPage/components/Sidebar/AgentsSidebar.tsx index 4e8331fe54..cc811a8c8c 100644 --- a/site/src/pages/AgentsPage/components/Sidebar/AgentsSidebar.tsx +++ b/site/src/pages/AgentsPage/components/Sidebar/AgentsSidebar.tsx @@ -72,6 +72,7 @@ import { DropdownMenu, DropdownMenuContent, DropdownMenuItem, + DropdownMenuSeparator, DropdownMenuTrigger, } from "#/components/DropdownMenu/DropdownMenu"; import { ExternalImage } from "#/components/ExternalImage/ExternalImage"; @@ -123,10 +124,13 @@ interface AgentsSidebarProps { onPinAgent: (chatId: string) => void; onUnpinAgent: (chatId: string) => void; onReorderPinnedAgent?: (chatId: string, pinOrder: number) => void; + onRegenerateTitle: (chatId: string) => void; onBeforeNewAgent?: () => void; isCreating: boolean; isArchiving?: boolean; archivingChatId?: string | null; + isRegeneratingTitle?: boolean; + regeneratingTitleChatId?: string | null; isLoading?: boolean; loadError?: unknown; onRetryLoad?: () => void; @@ -367,6 +371,8 @@ interface ChatTreeContextValue { readonly chatErrorReasons: Record; readonly isArchiving: boolean; readonly archivingChatId: string | null; + readonly isRegeneratingTitle: boolean; + readonly regeneratingTitleChatId: string | null; readonly toggleExpanded: (chatID: string) => void; readonly onArchiveAgent: (chatId: string) => void; readonly onUnarchiveAgent: (chatId: string) => void; @@ -376,6 +382,7 @@ interface ChatTreeContextValue { ) => void; readonly onPinAgent: (chatId: string) => void; readonly onUnpinAgent: (chatId: string) => void; + readonly onRegenerateTitle: (chatId: string) => void; } const ChatTreeContext = createContext(null); @@ -405,12 +412,15 @@ const ChatTreeNode: FC = ({ chat, isChildNode }) => { chatErrorReasons, isArchiving, archivingChatId, + isRegeneratingTitle, + regeneratingTitleChatId, toggleExpanded, onArchiveAgent, onUnarchiveAgent, onArchiveAndDeleteWorkspace, onPinAgent, onUnpinAgent, + onRegenerateTitle, } = useChatTree(); const chatID = chat.id; const childIDs = (chatTree.childrenById.get(chatID) ?? []).filter((childID) => @@ -448,6 +458,8 @@ const ChatTreeNode: FC = ({ chat, isChildNode }) => { }`; const workspaceId = chat.workspace_id; const isArchivingThisChat = isArchiving && archivingChatId === chat.id; + const isRegeneratingThisChat = + isRegeneratingTitle && regeneratingTitleChatId === chat.id; const isExpanded = normalizedSearch ? true : (expandedById[chatID] ?? false); return ( @@ -509,13 +521,21 @@ const ChatTreeNode: FC = ({ chat, isChildNode }) => {
{chat.title} + {isRegeneratingThisChat && ( + + Regenerating title… + + )}
{hasLinkedDiffStatus && hasLineStats && ( @@ -598,6 +618,14 @@ const ChatTreeNode: FC = ({ chat, isChildNode }) => { ) : ( <> + onRegenerateTitle(chat.id)} + > + + Generate new title + + = (props) => { onPinAgent, onUnpinAgent, onReorderPinnedAgent, + onRegenerateTitle, onBeforeNewAgent, isCreating, isArchiving = false, archivingChatId = null, + isRegeneratingTitle = false, + regeneratingTitleChatId = null, isLoading = false, loadError, onRetryLoad, @@ -910,12 +941,15 @@ export const AgentsSidebar: FC = (props) => { chatErrorReasons, isArchiving, archivingChatId, + isRegeneratingTitle, + regeneratingTitleChatId, toggleExpanded, onArchiveAgent, onUnarchiveAgent, onArchiveAndDeleteWorkspace, onPinAgent, onUnpinAgent, + onRegenerateTitle, }; const subNavTitle = "Settings";