feat: add manual chat title regeneration (#23633)

## Summary

Adds a "Generate new title" action that lets users manually regenerate a
chat's title using richer conversation context than the automatic
first-message title path.

## Changes

### Backend
- **New endpoint:** `POST
/api/experimental/chats/{chatID}/title/regenerate` returns the updated
Chat with a regenerated title
- **Manual title algorithm:** Extracts useful user/assistant text turns
→ selects first user turn + last 3 turns → builds context with gap
markers → renders prompt with anti-recency guidance → calls lightweight
model → normalizes output
- **Helpers:** `extractManualTitleTurns`,
`selectManualTitleTurnIndexes`, `buildManualTitleContext`,
`renderManualTitlePrompt`, `generateManualTitle` — all private, with the
public `Server.RegenerateChatTitle` method
- **SDK:** `ExperimentalClient.RegenerateChatTitle(ctx, chatID) (Chat,
error)`
- Persists title via existing `UpdateChatByID` and broadcasts
`ChatEventKindTitleChange`

### Frontend
- API client method + React Query mutation with cache invalidation
- "Generate new title" menu item (with wand icon) in both TopBar and
Sidebar dropdown menus
- Loading/disabled state while regeneration is in-flight
- Error toast on failure
- Stories updated for both menus

### Tests
- `quickgen_test.go`: Table-driven tests for all 4 helper functions
(turn extraction, index selection, context building, prompt rendering)
- `exp_chats_test.go`: Handler tests (ChatNotFound,
NotFoundForDifferentUser, NoDaemon)

## Design notes
- The existing auto-title path (`maybeGenerateChatTitle`, `titleInput`)
is completely unchanged
- Manual regeneration uses richer context (first user turn + last 3
turns + gap markers) vs the auto path's single first message
- Endpoint is experimental and marked with `@x-apidocgen {"skip": true}`
This commit is contained in:
Michael Suchacz
2026-03-27 01:47:19 +01:00
committed by GitHub
parent f35f2a28e6
commit 2312e5c428
30 changed files with 2430 additions and 63 deletions
+1
View File
@@ -1234,6 +1234,7 @@ func New(options *Options) *API {
r.Get("/git", api.watchChatGit)
})
r.Post("/interrupt", api.interruptChat)
r.Post("/title/regenerate", api.regenerateChatTitle)
r.Get("/diff", api.getChatDiffContents)
r.Route("/queue/{queuedMessage}", func(r chi.Router) {
r.Delete("/", api.deleteChatQueuedMessage)
+30
View File
@@ -2614,6 +2614,14 @@ func (q *querier) GetChatMessagesByChatID(ctx context.Context, arg database.GetC
return q.db.GetChatMessagesByChatID(ctx, arg)
}
func (q *querier) GetChatMessagesByChatIDAscPaginated(ctx context.Context, arg database.GetChatMessagesByChatIDAscPaginatedParams) ([]database.ChatMessage, error) {
_, err := q.GetChatByID(ctx, arg.ChatID)
if err != nil {
return nil, err
}
return q.db.GetChatMessagesByChatIDAscPaginated(ctx, arg)
}
func (q *querier) GetChatMessagesByChatIDDescPaginated(ctx context.Context, arg database.GetChatMessagesByChatIDDescPaginatedParams) ([]database.ChatMessage, error) {
_, err := q.GetChatByID(ctx, arg.ChatID)
if err != nil {
@@ -5736,6 +5744,17 @@ func (q *querier) UpdateChatLabelsByID(ctx context.Context, arg database.UpdateC
return q.db.UpdateChatLabelsByID(ctx, arg)
}
func (q *querier) UpdateChatLastModelConfigByID(ctx context.Context, arg database.UpdateChatLastModelConfigByIDParams) (database.Chat, error) {
chat, err := q.db.GetChatByID(ctx, arg.ID)
if err != nil {
return database.Chat{}, err
}
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
return database.Chat{}, err
}
return q.db.UpdateChatLastModelConfigByID(ctx, arg)
}
func (q *querier) UpdateChatMCPServerIDs(ctx context.Context, arg database.UpdateChatMCPServerIDsParams) (database.Chat, error) {
chat, err := q.db.GetChatByID(ctx, arg.ID)
if err != nil {
@@ -5801,6 +5820,17 @@ func (q *querier) UpdateChatStatus(ctx context.Context, arg database.UpdateChatS
return q.db.UpdateChatStatus(ctx, arg)
}
func (q *querier) UpdateChatStatusPreserveUpdatedAt(ctx context.Context, arg database.UpdateChatStatusPreserveUpdatedAtParams) (database.Chat, error) {
chat, err := q.db.GetChatByID(ctx, arg.ID)
if err != nil {
return database.Chat{}, err
}
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
return database.Chat{}, err
}
return q.db.UpdateChatStatusPreserveUpdatedAt(ctx, arg)
}
func (q *querier) UpdateChatWorkspaceBinding(ctx context.Context, arg database.UpdateChatWorkspaceBindingParams) (database.Chat, error) {
chat, err := q.db.GetChatByID(ctx, arg.ID)
if err != nil {
+28
View File
@@ -592,6 +592,14 @@ func (s *MethodTestSuite) TestChats() {
dbm.EXPECT().GetChatMessagesByChatID(gomock.Any(), arg).Return(msgs, nil).AnyTimes()
check.Args(arg).Asserts(chat, policy.ActionRead).Returns(msgs)
}))
s.Run("GetChatMessagesByChatIDAscPaginated", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
msgs := []database.ChatMessage{testutil.Fake(s.T(), faker, database.ChatMessage{ChatID: chat.ID})}
arg := database.GetChatMessagesByChatIDAscPaginatedParams{ChatID: chat.ID, AfterID: 0, LimitVal: 50}
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
dbm.EXPECT().GetChatMessagesByChatIDAscPaginated(gomock.Any(), arg).Return(msgs, nil).AnyTimes()
check.Args(arg).Asserts(chat, policy.ActionRead).Returns(msgs)
}))
s.Run("GetChatMessagesByChatIDDescPaginated", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
msgs := []database.ChatMessage{testutil.Fake(s.T(), faker, database.ChatMessage{ChatID: chat.ID})}
@@ -789,6 +797,26 @@ func (s *MethodTestSuite) TestChats() {
dbm.EXPECT().UpdateChatLabelsByID(gomock.Any(), arg).Return(chat, nil).AnyTimes()
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(chat)
}))
s.Run("UpdateChatLastModelConfigByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
arg := database.UpdateChatLastModelConfigByIDParams{
ID: chat.ID,
LastModelConfigID: uuid.New(),
}
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
dbm.EXPECT().UpdateChatLastModelConfigByID(gomock.Any(), arg).Return(chat, nil).AnyTimes()
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(chat)
}))
s.Run("UpdateChatStatusPreserveUpdatedAt", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
arg := database.UpdateChatStatusPreserveUpdatedAtParams{
ID: chat.ID,
Status: database.ChatStatusRunning,
}
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
dbm.EXPECT().UpdateChatStatusPreserveUpdatedAt(gomock.Any(), arg).Return(chat, nil).AnyTimes()
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(chat)
}))
s.Run("UpdateChatHeartbeat", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
chat := testutil.Fake(s.T(), faker, database.Chat{})
arg := database.UpdateChatHeartbeatParams{
+24
View File
@@ -1144,6 +1144,14 @@ func (m queryMetricsStore) GetChatMessagesByChatID(ctx context.Context, chatID d
return r0, r1
}
func (m queryMetricsStore) GetChatMessagesByChatIDAscPaginated(ctx context.Context, arg database.GetChatMessagesByChatIDAscPaginatedParams) ([]database.ChatMessage, error) {
start := time.Now()
r0, r1 := m.s.GetChatMessagesByChatIDAscPaginated(ctx, arg)
m.queryLatencies.WithLabelValues("GetChatMessagesByChatIDAscPaginated").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatMessagesByChatIDAscPaginated").Inc()
return r0, r1
}
func (m queryMetricsStore) GetChatMessagesByChatIDDescPaginated(ctx context.Context, arg database.GetChatMessagesByChatIDDescPaginatedParams) ([]database.ChatMessage, error) {
start := time.Now()
r0, r1 := m.s.GetChatMessagesByChatIDDescPaginated(ctx, arg)
@@ -4096,6 +4104,14 @@ func (m queryMetricsStore) UpdateChatLabelsByID(ctx context.Context, arg databas
return r0, r1
}
func (m queryMetricsStore) UpdateChatLastModelConfigByID(ctx context.Context, arg database.UpdateChatLastModelConfigByIDParams) (database.Chat, error) {
start := time.Now()
r0, r1 := m.s.UpdateChatLastModelConfigByID(ctx, arg)
m.queryLatencies.WithLabelValues("UpdateChatLastModelConfigByID").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateChatLastModelConfigByID").Inc()
return r0, r1
}
func (m queryMetricsStore) UpdateChatMCPServerIDs(ctx context.Context, arg database.UpdateChatMCPServerIDsParams) (database.Chat, error) {
start := time.Now()
r0, r1 := m.s.UpdateChatMCPServerIDs(ctx, arg)
@@ -4144,6 +4160,14 @@ func (m queryMetricsStore) UpdateChatStatus(ctx context.Context, arg database.Up
return r0, r1
}
func (m queryMetricsStore) UpdateChatStatusPreserveUpdatedAt(ctx context.Context, arg database.UpdateChatStatusPreserveUpdatedAtParams) (database.Chat, error) {
start := time.Now()
r0, r1 := m.s.UpdateChatStatusPreserveUpdatedAt(ctx, arg)
m.queryLatencies.WithLabelValues("UpdateChatStatusPreserveUpdatedAt").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateChatStatusPreserveUpdatedAt").Inc()
return r0, r1
}
func (m queryMetricsStore) UpdateChatWorkspaceBinding(ctx context.Context, arg database.UpdateChatWorkspaceBindingParams) (database.Chat, error) {
start := time.Now()
r0, r1 := m.s.UpdateChatWorkspaceBinding(ctx, arg)
+45
View File
@@ -2103,6 +2103,21 @@ func (mr *MockStoreMockRecorder) GetChatMessagesByChatID(ctx, arg any) *gomock.C
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatMessagesByChatID", reflect.TypeOf((*MockStore)(nil).GetChatMessagesByChatID), ctx, arg)
}
// GetChatMessagesByChatIDAscPaginated mocks base method.
func (m *MockStore) GetChatMessagesByChatIDAscPaginated(ctx context.Context, arg database.GetChatMessagesByChatIDAscPaginatedParams) ([]database.ChatMessage, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetChatMessagesByChatIDAscPaginated", ctx, arg)
ret0, _ := ret[0].([]database.ChatMessage)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetChatMessagesByChatIDAscPaginated indicates an expected call of GetChatMessagesByChatIDAscPaginated.
func (mr *MockStoreMockRecorder) GetChatMessagesByChatIDAscPaginated(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatMessagesByChatIDAscPaginated", reflect.TypeOf((*MockStore)(nil).GetChatMessagesByChatIDAscPaginated), ctx, arg)
}
// GetChatMessagesByChatIDDescPaginated mocks base method.
func (m *MockStore) GetChatMessagesByChatIDDescPaginated(ctx context.Context, arg database.GetChatMessagesByChatIDDescPaginatedParams) ([]database.ChatMessage, error) {
m.ctrl.T.Helper()
@@ -7745,6 +7760,21 @@ func (mr *MockStoreMockRecorder) UpdateChatLabelsByID(ctx, arg any) *gomock.Call
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatLabelsByID", reflect.TypeOf((*MockStore)(nil).UpdateChatLabelsByID), ctx, arg)
}
// UpdateChatLastModelConfigByID mocks base method.
func (m *MockStore) UpdateChatLastModelConfigByID(ctx context.Context, arg database.UpdateChatLastModelConfigByIDParams) (database.Chat, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdateChatLastModelConfigByID", ctx, arg)
ret0, _ := ret[0].(database.Chat)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// UpdateChatLastModelConfigByID indicates an expected call of UpdateChatLastModelConfigByID.
func (mr *MockStoreMockRecorder) UpdateChatLastModelConfigByID(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatLastModelConfigByID", reflect.TypeOf((*MockStore)(nil).UpdateChatLastModelConfigByID), ctx, arg)
}
// UpdateChatMCPServerIDs mocks base method.
func (m *MockStore) UpdateChatMCPServerIDs(ctx context.Context, arg database.UpdateChatMCPServerIDsParams) (database.Chat, error) {
m.ctrl.T.Helper()
@@ -7834,6 +7864,21 @@ func (mr *MockStoreMockRecorder) UpdateChatStatus(ctx, arg any) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatStatus", reflect.TypeOf((*MockStore)(nil).UpdateChatStatus), ctx, arg)
}
// UpdateChatStatusPreserveUpdatedAt mocks base method.
func (m *MockStore) UpdateChatStatusPreserveUpdatedAt(ctx context.Context, arg database.UpdateChatStatusPreserveUpdatedAtParams) (database.Chat, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdateChatStatusPreserveUpdatedAt", ctx, arg)
ret0, _ := ret[0].(database.Chat)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// UpdateChatStatusPreserveUpdatedAt indicates an expected call of UpdateChatStatusPreserveUpdatedAt.
func (mr *MockStoreMockRecorder) UpdateChatStatusPreserveUpdatedAt(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatStatusPreserveUpdatedAt", reflect.TypeOf((*MockStore)(nil).UpdateChatStatusPreserveUpdatedAt), ctx, arg)
}
// UpdateChatWorkspaceBinding mocks base method.
func (m *MockStore) UpdateChatWorkspaceBinding(ctx context.Context, arg database.UpdateChatWorkspaceBindingParams) (database.Chat, error) {
m.ctrl.T.Helper()
+3
View File
@@ -250,6 +250,7 @@ type sqlcQuerier interface {
GetChatIncludeDefaultSystemPrompt(ctx context.Context) (bool, error)
GetChatMessageByID(ctx context.Context, id int64) (ChatMessage, error)
GetChatMessagesByChatID(ctx context.Context, arg GetChatMessagesByChatIDParams) ([]ChatMessage, error)
GetChatMessagesByChatIDAscPaginated(ctx context.Context, arg GetChatMessagesByChatIDAscPaginatedParams) ([]ChatMessage, error)
GetChatMessagesByChatIDDescPaginated(ctx context.Context, arg GetChatMessagesByChatIDDescPaginatedParams) ([]ChatMessage, error)
GetChatMessagesForPromptByChatID(ctx context.Context, chatID uuid.UUID) ([]ChatMessage, error)
GetChatModelConfigByID(ctx context.Context, id uuid.UUID) (ChatModelConfig, error)
@@ -852,12 +853,14 @@ type sqlcQuerier interface {
// replicas know the worker is still alive.
UpdateChatHeartbeat(ctx context.Context, arg UpdateChatHeartbeatParams) (int64, error)
UpdateChatLabelsByID(ctx context.Context, arg UpdateChatLabelsByIDParams) (Chat, error)
UpdateChatLastModelConfigByID(ctx context.Context, arg UpdateChatLastModelConfigByIDParams) (Chat, error)
UpdateChatMCPServerIDs(ctx context.Context, arg UpdateChatMCPServerIDsParams) (Chat, error)
UpdateChatMessageByID(ctx context.Context, arg UpdateChatMessageByIDParams) (ChatMessage, error)
UpdateChatModelConfig(ctx context.Context, arg UpdateChatModelConfigParams) (ChatModelConfig, error)
UpdateChatPinOrder(ctx context.Context, arg UpdateChatPinOrderParams) error
UpdateChatProvider(ctx context.Context, arg UpdateChatProviderParams) (ChatProvider, error)
UpdateChatStatus(ctx context.Context, arg UpdateChatStatusParams) (Chat, error)
UpdateChatStatusPreserveUpdatedAt(ctx context.Context, arg UpdateChatStatusPreserveUpdatedAtParams) (Chat, error)
UpdateChatWorkspaceBinding(ctx context.Context, arg UpdateChatWorkspaceBindingParams) (Chat, error)
UpdateCryptoKeyDeletesAt(ctx context.Context, arg UpdateCryptoKeyDeletesAtParams) (CryptoKey, error)
UpdateCustomRole(ctx context.Context, arg UpdateCustomRoleParams) (CustomRole, error)
+176
View File
@@ -4932,6 +4932,73 @@ func (q *sqlQuerier) GetChatMessagesByChatID(ctx context.Context, arg GetChatMes
return items, nil
}
const getChatMessagesByChatIDAscPaginated = `-- name: GetChatMessagesByChatIDAscPaginated :many
SELECT
id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros, runtime_ms, deleted, provider_response_id
FROM
chat_messages
WHERE
chat_id = $1::uuid
AND id > $2::bigint
AND visibility IN ('user', 'both')
AND deleted = false
ORDER BY
id ASC
LIMIT
COALESCE(NULLIF($3::int, 0), 50)
`
type GetChatMessagesByChatIDAscPaginatedParams struct {
ChatID uuid.UUID `db:"chat_id" json:"chat_id"`
AfterID int64 `db:"after_id" json:"after_id"`
LimitVal int32 `db:"limit_val" json:"limit_val"`
}
func (q *sqlQuerier) GetChatMessagesByChatIDAscPaginated(ctx context.Context, arg GetChatMessagesByChatIDAscPaginatedParams) ([]ChatMessage, error) {
rows, err := q.db.QueryContext(ctx, getChatMessagesByChatIDAscPaginated, arg.ChatID, arg.AfterID, arg.LimitVal)
if err != nil {
return nil, err
}
defer rows.Close()
var items []ChatMessage
for rows.Next() {
var i ChatMessage
if err := rows.Scan(
&i.ID,
&i.ChatID,
&i.ModelConfigID,
&i.CreatedAt,
&i.Role,
&i.Content,
&i.Visibility,
&i.InputTokens,
&i.OutputTokens,
&i.TotalTokens,
&i.ReasoningTokens,
&i.CacheCreationTokens,
&i.CacheReadTokens,
&i.ContextLimit,
&i.Compressed,
&i.CreatedBy,
&i.ContentVersion,
&i.TotalCostMicros,
&i.RuntimeMs,
&i.Deleted,
&i.ProviderResponseID,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const getChatMessagesByChatIDDescPaginated = `-- name: GetChatMessagesByChatIDDescPaginated :many
SELECT
id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros, runtime_ms, deleted, provider_response_id
@@ -6254,6 +6321,52 @@ func (q *sqlQuerier) UpdateChatLabelsByID(ctx context.Context, arg UpdateChatLab
return i, err
}
const updateChatLastModelConfigByID = `-- name: UpdateChatLastModelConfigByID :one
UPDATE
chats
SET
-- NOTE: updated_at is intentionally NOT touched here to avoid changing list ordering.
last_model_config_id = $1::uuid
WHERE
id = $2::uuid
RETURNING
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order
`
type UpdateChatLastModelConfigByIDParams struct {
LastModelConfigID uuid.UUID `db:"last_model_config_id" json:"last_model_config_id"`
ID uuid.UUID `db:"id" json:"id"`
}
func (q *sqlQuerier) UpdateChatLastModelConfigByID(ctx context.Context, arg UpdateChatLastModelConfigByIDParams) (Chat, error) {
row := q.db.QueryRowContext(ctx, updateChatLastModelConfigByID, arg.LastModelConfigID, arg.ID)
var i Chat
err := row.Scan(
&i.ID,
&i.OwnerID,
&i.WorkspaceID,
&i.Title,
&i.Status,
&i.WorkerID,
&i.StartedAt,
&i.HeartbeatAt,
&i.CreatedAt,
&i.UpdatedAt,
&i.ParentChatID,
&i.RootChatID,
&i.LastModelConfigID,
&i.Archived,
&i.LastError,
&i.Mode,
pq.Array(&i.MCPServerIDs),
&i.Labels,
&i.BuildID,
&i.AgentID,
&i.PinOrder,
)
return i, err
}
const updateChatMCPServerIDs = `-- name: UpdateChatMCPServerIDs :one
UPDATE
chats
@@ -6479,6 +6592,69 @@ func (q *sqlQuerier) UpdateChatStatus(ctx context.Context, arg UpdateChatStatusP
return i, err
}
const updateChatStatusPreserveUpdatedAt = `-- name: UpdateChatStatusPreserveUpdatedAt :one
UPDATE
chats
SET
status = $1::chat_status,
worker_id = $2::uuid,
started_at = $3::timestamptz,
heartbeat_at = $4::timestamptz,
last_error = $5::text,
updated_at = $6::timestamptz
WHERE
id = $7::uuid
RETURNING
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order
`
type UpdateChatStatusPreserveUpdatedAtParams struct {
Status ChatStatus `db:"status" json:"status"`
WorkerID uuid.NullUUID `db:"worker_id" json:"worker_id"`
StartedAt sql.NullTime `db:"started_at" json:"started_at"`
HeartbeatAt sql.NullTime `db:"heartbeat_at" json:"heartbeat_at"`
LastError sql.NullString `db:"last_error" json:"last_error"`
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
ID uuid.UUID `db:"id" json:"id"`
}
func (q *sqlQuerier) UpdateChatStatusPreserveUpdatedAt(ctx context.Context, arg UpdateChatStatusPreserveUpdatedAtParams) (Chat, error) {
row := q.db.QueryRowContext(ctx, updateChatStatusPreserveUpdatedAt,
arg.Status,
arg.WorkerID,
arg.StartedAt,
arg.HeartbeatAt,
arg.LastError,
arg.UpdatedAt,
arg.ID,
)
var i Chat
err := row.Scan(
&i.ID,
&i.OwnerID,
&i.WorkspaceID,
&i.Title,
&i.Status,
&i.WorkerID,
&i.StartedAt,
&i.HeartbeatAt,
&i.CreatedAt,
&i.UpdatedAt,
&i.ParentChatID,
&i.RootChatID,
&i.LastModelConfigID,
&i.Archived,
&i.LastError,
&i.Mode,
pq.Array(&i.MCPServerIDs),
&i.Labels,
&i.BuildID,
&i.AgentID,
&i.PinOrder,
)
return i, err
}
const updateChatWorkspaceBinding = `-- name: UpdateChatWorkspaceBinding :one
UPDATE chats SET
workspace_id = $1::uuid,
+41
View File
@@ -220,6 +220,21 @@ WHERE
ORDER BY
created_at ASC;
-- name: GetChatMessagesByChatIDAscPaginated :many
SELECT
*
FROM
chat_messages
WHERE
chat_id = @chat_id::uuid
AND id > @after_id::bigint
AND visibility IN ('user', 'both')
AND deleted = false
ORDER BY
id ASC
LIMIT
COALESCE(NULLIF(@limit_val::int, 0), 50);
-- name: GetChatMessagesByChatIDDescPaginated :many
SELECT
*
@@ -466,6 +481,17 @@ WHERE
RETURNING
*;
-- name: UpdateChatLastModelConfigByID :one
UPDATE
chats
SET
-- NOTE: updated_at is intentionally NOT touched here to avoid changing list ordering.
last_model_config_id = @last_model_config_id::uuid
WHERE
id = @id::uuid
RETURNING
*;
-- name: UpdateChatLabelsByID :one
UPDATE
chats
@@ -550,6 +576,21 @@ WHERE
RETURNING
*;
-- name: UpdateChatStatusPreserveUpdatedAt :one
UPDATE
chats
SET
status = @status::chat_status,
worker_id = sqlc.narg('worker_id')::uuid,
started_at = sqlc.narg('started_at')::timestamptz,
heartbeat_at = sqlc.narg('heartbeat_at')::timestamptz,
last_error = sqlc.narg('last_error')::text,
updated_at = @updated_at::timestamptz
WHERE
id = @id::uuid
RETURNING
*;
-- name: GetStaleChats :many
-- Find chats that appear stuck (running but heartbeat has expired).
-- Used for recovery after coderd crashes or long hangs.
+44
View File
@@ -2105,6 +2105,50 @@ func (api *API) interruptChat(rw http.ResponseWriter, r *http.Request) {
httpapi.Write(ctx, rw, http.StatusOK, db2sdk.Chat(chat, nil))
}
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
//
//nolint:revive // HTTP handler writes to ResponseWriter.
func (api *API) regenerateChatTitle(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
chat := httpmw.ChatParam(r)
if !api.Authorize(r, policy.ActionUpdate, chat.RBACObject()) {
httpapi.ResourceNotFound(rw)
return
}
if api.chatDaemon == nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Chat processor is unavailable.",
Detail: "Chat processor is not configured.",
})
return
}
updatedChat, err := api.chatDaemon.RegenerateChatTitle(ctx, chat)
if err != nil {
if errors.Is(err, chatd.ErrManualTitleRegenerationInProgress) {
httpapi.Write(ctx, rw, http.StatusConflict, codersdk.Response{
Message: "Title regeneration already in progress for this chat.",
})
return
}
if maybeWriteLimitErr(ctx, rw, err) {
return
}
if httpapi.Is404Error(err) {
httpapi.ResourceNotFound(rw)
return
}
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to regenerate chat title.",
Detail: err.Error(),
})
return
}
httpapi.Write(ctx, rw, http.StatusOK, db2sdk.Chat(updatedChat, nil))
}
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
//
//nolint:revive // HTTP handler writes to ResponseWriter.
+264
View File
@@ -19,6 +19,7 @@ import (
"github.com/google/uuid"
"github.com/shopspring/decimal"
"github.com/stretchr/testify/require"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/coderd/coderdtest/oidctest"
@@ -31,6 +32,7 @@ import (
"github.com/coder/coder/v2/coderd/externalauth"
coderdpubsub "github.com/coder/coder/v2/coderd/pubsub"
"github.com/coder/coder/v2/coderd/rbac"
"github.com/coder/coder/v2/coderd/rbac/policy"
"github.com/coder/coder/v2/coderd/util/ptr"
"github.com/coder/coder/v2/coderd/x/chatd"
"github.com/coder/coder/v2/coderd/x/chatd/chatprompt"
@@ -3441,6 +3443,268 @@ func TestInterruptChat(t *testing.T) {
})
}
func TestRegenerateChatTitle(t *testing.T) {
t.Parallel()
t.Run("ChatNotFound", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client := newChatClient(t)
_ = coderdtest.CreateFirstUser(t, client.Client)
_, err := client.RegenerateChatTitle(ctx, uuid.New())
requireSDKError(t, err, http.StatusNotFound)
})
t.Run("UpdateDenied", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
clientRaw, db := coderdtest.NewWithDatabase(t, &coderdtest.Options{
Authorizer: &coderdtest.FakeAuthorizer{
ConditionalReturn: func(_ context.Context, _ rbac.Subject, action policy.Action, object rbac.Object) error {
if action == policy.ActionUpdate && object.Type == rbac.ResourceChat.Type {
return xerrors.New("denied")
}
return nil
},
},
DeploymentValues: chatDeploymentValues(t),
})
client := codersdk.NewExperimentalClient(clientRaw)
user := coderdtest.CreateFirstUser(t, client.Client)
modelConfig := createChatModelConfig(t, client)
chat, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{
OwnerID: user.UserID,
LastModelConfigID: modelConfig.ID,
Title: "chat with update denied",
})
require.NoError(t, err)
_, err = client.RegenerateChatTitle(ctx, chat.ID)
requireSDKError(t, err, http.StatusNotFound)
})
t.Run("NotFoundForDifferentUser", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client := newChatClient(t)
firstUser := coderdtest.CreateFirstUser(t, client.Client)
_ = createChatModelConfig(t, client)
createdChat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
Content: []codersdk.ChatInputPart{
{
Type: codersdk.ChatInputPartTypeText,
Text: "private chat",
},
},
})
require.NoError(t, err)
otherClientRaw, _ := coderdtest.CreateAnotherUser(t, client.Client, firstUser.OrganizationID)
otherClient := codersdk.NewExperimentalClient(otherClientRaw)
_, err = otherClient.RegenerateChatTitle(ctx, createdChat.ID)
requireSDKError(t, err, http.StatusNotFound)
})
t.Run("Unauthenticated", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client := newChatClient(t)
_ = coderdtest.CreateFirstUser(t, client.Client)
_ = createChatModelConfig(t, client)
chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
Content: []codersdk.ChatInputPart{{
Type: codersdk.ChatInputPartTypeText,
Text: "chat for unauthenticated regeneration",
}},
})
require.NoError(t, err)
unauthenticatedClient := codersdk.NewExperimentalClient(codersdk.New(client.URL))
_, err = unauthenticatedClient.RegenerateChatTitle(ctx, chat.ID)
requireSDKError(t, err, http.StatusUnauthorized)
})
t.Run("UsageLimitExceeded", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client, db := newChatClientWithDatabase(t)
_ = coderdtest.CreateFirstUser(t, client.Client)
modelConfig := createChatModelConfig(t, client)
chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
Content: []codersdk.ChatInputPart{{
Type: codersdk.ChatInputPartTypeText,
Text: "chat over usage limit",
}},
})
require.NoError(t, err)
wantResetsAt := enableDailyChatUsageLimit(ctx, t, db, 100)
insertAssistantCostMessage(ctx, t, db, chat.ID, modelConfig.ID, 100)
_, err = db.UpdateChatStatus(dbauthz.AsSystemRestricted(ctx), database.UpdateChatStatusParams{
ID: chat.ID,
Status: database.ChatStatusCompleted,
WorkerID: uuid.NullUUID{},
StartedAt: sql.NullTime{},
HeartbeatAt: sql.NullTime{},
LastError: sql.NullString{},
})
require.NoError(t, err)
_, err = client.RegenerateChatTitle(ctx, chat.ID)
limitErr := codersdk.ChatUsageLimitExceededFrom(err)
require.NotNil(t, limitErr)
require.Equal(t, "Chat usage limit exceeded.", limitErr.Message)
require.Equal(t, int64(100), limitErr.SpentMicros)
require.Equal(t, int64(100), limitErr.LimitMicros)
require.True(
t,
limitErr.ResetsAt.Equal(wantResetsAt),
"expected resets_at %s, got %s",
wantResetsAt.UTC().Format(time.RFC3339),
limitErr.ResetsAt.UTC().Format(time.RFC3339),
)
})
t.Run("AlreadyInProgress", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client, db := newChatClientWithDatabase(t)
user := coderdtest.CreateFirstUser(t, client.Client)
modelConfig := createChatModelConfig(t, client)
chat, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{
OwnerID: user.UserID,
LastModelConfigID: modelConfig.ID,
Title: "chat with lock held",
})
require.NoError(t, err)
_, err = db.UpdateChatStatus(dbauthz.AsSystemRestricted(ctx), database.UpdateChatStatusParams{
ID: chat.ID,
Status: database.ChatStatusCompleted,
WorkerID: uuid.NullUUID{UUID: uuid.MustParse("00000000-0000-0000-0000-000000000001"), Valid: true},
StartedAt: sql.NullTime{Time: time.Now(), Valid: true},
HeartbeatAt: sql.NullTime{Time: time.Now(), Valid: true},
LastError: sql.NullString{},
})
require.NoError(t, err)
res, err := client.Request(
ctx,
http.MethodPost,
fmt.Sprintf("/api/experimental/chats/%s/title/regenerate", chat.ID),
nil,
)
require.NoError(t, err)
defer res.Body.Close()
require.Equal(t, http.StatusConflict, res.StatusCode)
var resp codersdk.Response
require.NoError(t, json.NewDecoder(res.Body).Decode(&resp))
require.Equal(t, "Title regeneration already in progress for this chat.", resp.Message)
})
t.Run("PendingWithoutWorker", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client, db := newChatClientWithDatabase(t)
user := coderdtest.CreateFirstUser(t, client.Client)
modelConfig := createChatModelConfig(t, client)
chat, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{
OwnerID: user.UserID,
LastModelConfigID: modelConfig.ID,
Title: "pending chat without worker",
})
require.NoError(t, err)
chat, err = db.UpdateChatStatus(dbauthz.AsSystemRestricted(ctx), database.UpdateChatStatusParams{
ID: chat.ID,
Status: database.ChatStatusPending,
WorkerID: uuid.NullUUID{},
StartedAt: sql.NullTime{},
HeartbeatAt: sql.NullTime{},
LastError: sql.NullString{},
})
require.NoError(t, err)
before, err := db.GetChatByID(dbauthz.AsSystemRestricted(ctx), chat.ID)
require.NoError(t, err)
res, err := client.Request(
ctx,
http.MethodPost,
fmt.Sprintf("/api/experimental/chats/%s/title/regenerate", chat.ID),
nil,
)
require.NoError(t, err)
defer res.Body.Close()
require.Equal(t, http.StatusConflict, res.StatusCode)
var resp codersdk.Response
require.NoError(t, json.NewDecoder(res.Body).Decode(&resp))
require.Equal(t, "Title regeneration already in progress for this chat.", resp.Message)
persisted, err := db.GetChatByID(dbauthz.AsSystemRestricted(ctx), chat.ID)
require.NoError(t, err)
require.Equal(t, database.ChatStatusPending, persisted.Status)
require.False(t, persisted.WorkerID.Valid)
require.True(t, persisted.UpdatedAt.Equal(before.UpdatedAt))
})
t.Run("RegenerationFailure", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client, db := newChatClientWithDatabase(t)
_ = coderdtest.CreateFirstUser(t, client.Client)
_ = createChatModelConfig(t, client)
chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
Content: []codersdk.ChatInputPart{
{
Type: codersdk.ChatInputPartTypeText,
Text: "test chat",
},
},
})
require.NoError(t, err)
_, err = db.UpdateChatStatus(dbauthz.AsSystemRestricted(ctx), database.UpdateChatStatusParams{
ID: chat.ID,
Status: database.ChatStatusCompleted,
WorkerID: uuid.NullUUID{},
StartedAt: sql.NullTime{},
HeartbeatAt: sql.NullTime{},
LastError: sql.NullString{},
})
require.NoError(t, err)
before, err := db.GetChatByID(dbauthz.AsSystemRestricted(ctx), chat.ID)
require.NoError(t, err)
_, err = client.RegenerateChatTitle(ctx, chat.ID)
requireSDKError(t, err, http.StatusInternalServerError)
after, err := db.GetChatByID(dbauthz.AsSystemRestricted(ctx), chat.ID)
require.NoError(t, err)
require.True(t, after.UpdatedAt.Equal(before.UpdatedAt))
})
}
func TestGetChatDiffStatus(t *testing.T) {
t.Parallel()
+499 -37
View File
@@ -1333,6 +1333,477 @@ func (p *Server) InterruptChat(
return updatedChat
}
const manualTitleMessageWindowLimit = 50
var ErrManualTitleRegenerationInProgress = xerrors.New(
"manual title regeneration already in progress",
)
type manualTitleGenerationError struct {
cause error
modelConfig database.ChatModelConfig
usage fantasy.Usage
}
func (e *manualTitleGenerationError) Error() string {
return e.cause.Error()
}
func (e *manualTitleGenerationError) Unwrap() error {
return e.cause
}
var manualTitleLockWorkerID = uuid.MustParse(
"00000000-0000-0000-0000-000000000001",
)
const manualTitleLockStaleAfter = time.Minute
func isPendingOrRunningChatStatus(status database.ChatStatus) bool {
switch status {
case database.ChatStatusPending, database.ChatStatusRunning:
return true
default:
return false
}
}
func isFreshManualTitleLock(chat database.Chat, now time.Time) bool {
if !chat.WorkerID.Valid || chat.WorkerID.UUID != manualTitleLockWorkerID {
return false
}
leaseAt := chat.HeartbeatAt
if !leaseAt.Valid {
leaseAt = chat.StartedAt
}
return leaseAt.Valid && leaseAt.Time.After(now.Add(-manualTitleLockStaleAfter))
}
// updateChatStatusPreserveUpdatedAt applies internal lock transitions without
// changing chat recency, because chat list ordering uses updated_at.
func updateChatStatusPreserveUpdatedAt(
ctx context.Context,
store database.Store,
chat database.Chat,
workerID uuid.NullUUID,
startedAt sql.NullTime,
heartbeatAt sql.NullTime,
) (database.Chat, error) {
return store.UpdateChatStatusPreserveUpdatedAt(
ctx,
database.UpdateChatStatusPreserveUpdatedAtParams{
ID: chat.ID,
Status: chat.Status,
WorkerID: workerID,
StartedAt: startedAt,
HeartbeatAt: heartbeatAt,
LastError: chat.LastError,
UpdatedAt: chat.UpdatedAt,
},
)
}
func (p *Server) acquireManualTitleLock(ctx context.Context, chatID uuid.UUID) error {
now := time.Now()
return p.db.InTx(func(tx database.Store) error {
lockedChat, err := tx.GetChatByIDForUpdate(ctx, chatID)
if err != nil {
return xerrors.Errorf("lock chat for manual title regeneration: %w", err)
}
if isPendingOrRunningChatStatus(lockedChat.Status) ||
isFreshManualTitleLock(lockedChat, now) {
return ErrManualTitleRegenerationInProgress
}
_, err = updateChatStatusPreserveUpdatedAt(
ctx,
tx,
lockedChat,
uuid.NullUUID{UUID: manualTitleLockWorkerID, Valid: true},
sql.NullTime{Time: now, Valid: true},
sql.NullTime{Time: now, Valid: true},
)
if err != nil {
return xerrors.Errorf("mark chat for manual title regeneration: %w", err)
}
return nil
}, database.DefaultTXOptions().WithID("chat_title_regenerate_lock"))
}
func (p *Server) releaseManualTitleLock(ctx context.Context, chatID uuid.UUID) {
cleanupCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), 5*time.Second)
defer cancel()
err := p.db.InTx(func(tx database.Store) error {
lockedChat, err := tx.GetChatByIDForUpdate(cleanupCtx, chatID)
if err != nil {
return xerrors.Errorf("lock chat to release manual title regeneration: %w", err)
}
if !lockedChat.WorkerID.Valid || lockedChat.WorkerID.UUID != manualTitleLockWorkerID {
return nil
}
_, err = updateChatStatusPreserveUpdatedAt(
cleanupCtx,
tx,
lockedChat,
uuid.NullUUID{},
sql.NullTime{},
sql.NullTime{},
)
if err != nil {
return xerrors.Errorf("clear manual title regeneration marker: %w", err)
}
return nil
}, database.DefaultTXOptions().WithID("chat_title_regenerate_unlock"))
if err != nil {
p.logger.Warn(cleanupCtx, "failed to release manual title regeneration marker",
slog.F("chat_id", chatID),
slog.Error(err),
)
}
}
// RegenerateChatTitle regenerates a chat title from the chat's visible
// messages, persists it when it changes, and broadcasts the update.
func (p *Server) RegenerateChatTitle(
ctx context.Context,
chat database.Chat,
) (database.Chat, error) {
// Reuse chatd's scoped auth context for deployment-config lookups while
// keeping chat ownership authorization at the HTTP layer.
//nolint:gocritic // Non-admin users need chatd-scoped config reads here.
chatdCtx := dbauthz.AsChatd(ctx)
keys, err := p.resolveProviderAPIKeys(chatdCtx)
if err != nil {
return database.Chat{}, xerrors.Errorf("resolve chat providers: %w", err)
}
if err := p.acquireManualTitleLock(ctx, chat.ID); err != nil {
return database.Chat{}, err
}
defer p.releaseManualTitleLock(chatdCtx, chat.ID)
updatedChat, err := p.regenerateChatTitleWithStore(
chatdCtx,
p.db,
chat,
keys,
)
if err != nil {
var generationErr *manualTitleGenerationError
if errors.As(err, &generationErr) {
// Reuse chatd's scoped auth context for failure accounting while
// detaching from request cancellation so usage is still recorded.
//nolint:gocritic // Failure accounting still needs chatd-scoped config reads.
recordCtx, recordCancel := context.WithTimeout(
dbauthz.AsChatd(context.WithoutCancel(ctx)),
5*time.Second,
)
defer recordCancel()
if _, recordErr := recordManualTitleUsage(
recordCtx,
p.db,
chat,
generationErr.modelConfig,
generationErr.usage,
"",
); recordErr != nil {
return database.Chat{}, errors.Join(
generationErr,
xerrors.Errorf("record manual title usage: %w", recordErr),
)
}
return database.Chat{}, generationErr
}
return database.Chat{}, err
}
return updatedChat, nil
}
func (p *Server) regenerateChatTitleWithStore(
ctx context.Context,
store database.Store,
chat database.Chat,
keys chatprovider.ProviderAPIKeys,
) (database.Chat, error) {
if limitErr := p.checkUsageLimit(ctx, store, chat.OwnerID); limitErr != nil {
return database.Chat{}, limitErr
}
headMessages, err := store.GetChatMessagesByChatIDAscPaginated(
ctx,
database.GetChatMessagesByChatIDAscPaginatedParams{
ChatID: chat.ID,
AfterID: 0,
LimitVal: manualTitleMessageWindowLimit,
},
)
if err != nil {
return database.Chat{}, xerrors.Errorf("get head chat messages: %w", err)
}
tailMessages, err := store.GetChatMessagesByChatIDDescPaginated(
ctx,
database.GetChatMessagesByChatIDDescPaginatedParams{
ChatID: chat.ID,
BeforeID: 0,
LimitVal: manualTitleMessageWindowLimit,
},
)
if err != nil {
return database.Chat{}, xerrors.Errorf("get tail chat messages: %w", err)
}
messages := mergeManualTitleMessages(headMessages, tailMessages)
if len(messages) == 0 {
return chat, nil
}
model, modelConfig, err := p.resolveManualTitleModel(ctx, store, chat, keys)
if err != nil {
return database.Chat{}, err
}
title, usage, err := generateManualTitle(ctx, messages, model)
if err != nil {
wrappedErr := xerrors.Errorf("generate manual title: %w", err)
if usage == (fantasy.Usage{}) {
return database.Chat{}, wrappedErr
}
return database.Chat{}, &manualTitleGenerationError{
cause: wrappedErr,
modelConfig: modelConfig,
usage: usage,
}
}
recordCtx, recordCancel := context.WithTimeout(context.WithoutCancel(ctx), 5*time.Second)
defer recordCancel()
updatedChat, recordErr := recordManualTitleUsage(
recordCtx,
store,
chat,
modelConfig,
usage,
title,
)
if recordErr != nil {
if title != "" {
return database.Chat{}, xerrors.Errorf("record manual title usage and update chat title: %w", recordErr)
}
return database.Chat{}, xerrors.Errorf("record manual title usage: %w", recordErr)
}
if updatedChat.Title == chat.Title {
return updatedChat, nil
}
p.publishChatPubsubEvent(updatedChat, coderdpubsub.ChatEventKindTitleChange, nil)
return updatedChat, nil
}
func (p *Server) resolveManualTitleModel(
ctx context.Context,
store database.Store,
chat database.Chat,
keys chatprovider.ProviderAPIKeys,
) (fantasy.LanguageModel, database.ChatModelConfig, error) {
configs, err := store.GetEnabledChatModelConfigs(ctx)
if err != nil {
p.logger.Debug(ctx, "failed to list manual title model configs",
slog.F("chat_id", chat.ID),
slog.Error(err),
)
return p.resolveFallbackManualTitleModel(ctx, chat, keys)
}
config, ok := selectPreferredConfiguredShortTextModelConfig(configs)
if !ok {
return p.resolveFallbackManualTitleModel(ctx, chat, keys)
}
model, err := chatprovider.ModelFromConfig(
config.Provider,
config.Model,
keys,
chatprovider.UserAgent(),
chatprovider.CoderHeaders(chat),
)
if err != nil {
p.logger.Debug(ctx, "manual title preferred model unavailable",
slog.F("chat_id", chat.ID),
slog.F("provider", config.Provider),
slog.F("model", config.Model),
slog.Error(err),
)
return p.resolveFallbackManualTitleModel(ctx, chat, keys)
}
return model, config, nil
}
func (p *Server) resolveFallbackManualTitleModel(
ctx context.Context,
chat database.Chat,
keys chatprovider.ProviderAPIKeys,
) (fantasy.LanguageModel, database.ChatModelConfig, error) {
config, err := p.resolveModelConfig(ctx, chat)
if err != nil {
return nil, database.ChatModelConfig{}, xerrors.Errorf(
"resolve fallback manual title model config: %w",
err,
)
}
model, err := chatprovider.ModelFromConfig(
config.Provider,
config.Model,
keys,
chatprovider.UserAgent(),
chatprovider.CoderHeaders(chat),
)
if err != nil {
return nil, database.ChatModelConfig{}, xerrors.Errorf(
"create fallback manual title model: %w",
err,
)
}
return model, config, nil
}
func mergeManualTitleMessages(
headMessages []database.ChatMessage,
tailMessagesDesc []database.ChatMessage,
) []database.ChatMessage {
merged := make([]database.ChatMessage, 0, len(headMessages)+len(tailMessagesDesc))
seen := make(map[int64]struct{}, len(headMessages)+len(tailMessagesDesc))
appendUnique := func(message database.ChatMessage) {
if _, ok := seen[message.ID]; ok {
return
}
seen[message.ID] = struct{}{}
merged = append(merged, message)
}
for _, message := range headMessages {
appendUnique(message)
}
for i := len(tailMessagesDesc) - 1; i >= 0; i-- {
appendUnique(tailMessagesDesc[i])
}
return merged
}
func fantasyUsageToChatMessageUsage(usage fantasy.Usage) codersdk.ChatMessageUsage {
var chatUsage codersdk.ChatMessageUsage
if usage.InputTokens != 0 {
chatUsage.InputTokens = ptr.Ref(usage.InputTokens)
}
if usage.OutputTokens != 0 {
chatUsage.OutputTokens = ptr.Ref(usage.OutputTokens)
}
if usage.ReasoningTokens != 0 {
chatUsage.ReasoningTokens = ptr.Ref(usage.ReasoningTokens)
}
if usage.CacheCreationTokens != 0 {
chatUsage.CacheCreationTokens = ptr.Ref(usage.CacheCreationTokens)
}
if usage.CacheReadTokens != 0 {
chatUsage.CacheReadTokens = ptr.Ref(usage.CacheReadTokens)
}
return chatUsage
}
func recordManualTitleUsage(
ctx context.Context,
store database.Store,
chat database.Chat,
modelConfig database.ChatModelConfig,
usage fantasy.Usage,
newTitle string,
) (database.Chat, error) {
hasUsage := usage != (fantasy.Usage{})
if !hasUsage && newTitle == "" {
return chat, nil
}
var totalCostMicros *int64
if hasUsage {
callConfig := codersdk.ChatModelCallConfig{}
if len(modelConfig.Options) > 0 {
if err := json.Unmarshal(modelConfig.Options, &callConfig); err != nil {
return database.Chat{}, xerrors.Errorf("parse model call config: %w", err)
}
}
totalCostMicros = chatcost.CalculateTotalCostMicros(
fantasyUsageToChatMessageUsage(usage),
callConfig.Cost,
)
}
// Use a valid empty JSON array for the content column.
// MarshalParts returns a null NullRawMessage for empty
// slices, which becomes an empty string that PostgreSQL
// rejects as invalid JSON.
content := "[]"
updatedChat := chat
err := store.InTx(func(tx database.Store) error {
lockedChat, err := tx.GetChatByIDForUpdate(ctx, chat.ID)
if err != nil {
return xerrors.Errorf("lock chat for manual title usage: %w", err)
}
updatedChat = lockedChat
if hasUsage {
messages, err := tx.InsertChatMessages(ctx, database.InsertChatMessagesParams{
ChatID: chat.ID,
CreatedBy: []uuid.UUID{chat.OwnerID},
ModelConfigID: []uuid.UUID{modelConfig.ID},
Role: []database.ChatMessageRole{database.ChatMessageRoleAssistant},
Content: []string{content},
ContentVersion: []int16{chatprompt.CurrentContentVersion},
Visibility: []database.ChatMessageVisibility{database.ChatMessageVisibilityModel},
InputTokens: []int64{usage.InputTokens},
OutputTokens: []int64{usage.OutputTokens},
TotalTokens: []int64{usage.TotalTokens},
ReasoningTokens: []int64{usage.ReasoningTokens},
CacheCreationTokens: []int64{usage.CacheCreationTokens},
CacheReadTokens: []int64{usage.CacheReadTokens},
ContextLimit: []int64{modelConfig.ContextLimit},
Compressed: []bool{false},
TotalCostMicros: []int64{ptr.NilToDefault(totalCostMicros, 0)},
RuntimeMs: []int64{0},
ProviderResponseID: []string{""},
})
if err != nil {
return xerrors.Errorf("insert manual title usage message: %w", err)
}
if len(messages) != 1 {
return xerrors.Errorf("expected 1 manual title usage message, got %d", len(messages))
}
if err := tx.SoftDeleteChatMessageByID(ctx, messages[0].ID); err != nil {
return xerrors.Errorf("soft delete manual title usage message: %w", err)
}
if lockedChat.LastModelConfigID != modelConfig.ID {
if _, err := tx.UpdateChatLastModelConfigByID(ctx, database.UpdateChatLastModelConfigByIDParams{
ID: chat.ID,
LastModelConfigID: lockedChat.LastModelConfigID,
}); err != nil {
return xerrors.Errorf("restore chat model config after manual title usage: %w", err)
}
}
}
if newTitle != "" && lockedChat.Title == chat.Title && newTitle != lockedChat.Title {
updatedChat, err = tx.UpdateChatByID(ctx, database.UpdateChatByIDParams{
ID: chat.ID,
Title: newTitle,
})
if err != nil {
return xerrors.Errorf("update chat title: %w", err)
}
}
return nil
}, nil)
if err != nil {
return database.Chat{}, err
}
return updatedChat, nil
}
// RefreshStatus loads the latest chat status and publishes it to stream subscribers.
func (p *Server) RefreshStatus(ctx context.Context, chatID uuid.UUID) error {
if chatID == uuid.Nil {
@@ -3475,24 +3946,7 @@ func (p *Server) runChat(
}
hasUsage := step.Usage != (fantasy.Usage{})
var usageForCost codersdk.ChatMessageUsage
if hasUsage {
if step.Usage.InputTokens != 0 {
usageForCost.InputTokens = ptr.Ref(step.Usage.InputTokens)
}
if step.Usage.OutputTokens != 0 {
usageForCost.OutputTokens = ptr.Ref(step.Usage.OutputTokens)
}
if step.Usage.ReasoningTokens != 0 {
usageForCost.ReasoningTokens = ptr.Ref(step.Usage.ReasoningTokens)
}
if step.Usage.CacheCreationTokens != 0 {
usageForCost.CacheCreationTokens = ptr.Ref(step.Usage.CacheCreationTokens)
}
if step.Usage.CacheReadTokens != 0 {
usageForCost.CacheReadTokens = ptr.Ref(step.Usage.CacheReadTokens)
}
}
usageForCost := fantasyUsageToChatMessageUsage(step.Usage)
totalCostMicros := chatcost.CalculateTotalCostMicros(usageForCost, callConfig.Cost)
var insertedMessages []database.ChatMessage
@@ -4078,10 +4532,8 @@ func (p *Server) resolveChatModel(
ctx context.Context,
chat database.Chat,
) (fantasy.LanguageModel, database.ChatModelConfig, chatprovider.ProviderAPIKeys, error) {
var (
dbConfig database.ChatModelConfig
providers []database.ChatProvider
)
var dbConfig database.ChatModelConfig
var keys chatprovider.ProviderAPIKeys
var g errgroup.Group
g.Go(func() error {
@@ -4094,28 +4546,15 @@ func (p *Server) resolveChatModel(
})
g.Go(func() error {
var err error
providers, err = p.configCache.EnabledProviders(ctx)
keys, err = p.resolveProviderAPIKeys(ctx)
if err != nil {
return xerrors.Errorf("get enabled chat providers: %w", err)
return xerrors.Errorf("resolve provider API keys: %w", err)
}
return nil
})
if err := g.Wait(); err != nil {
return nil, database.ChatModelConfig{}, chatprovider.ProviderAPIKeys{}, err
}
dbProviders := make(
[]chatprovider.ConfiguredProvider, 0, len(providers),
)
for _, provider := range providers {
dbProviders = append(dbProviders, chatprovider.ConfiguredProvider{
Provider: provider.Provider,
APIKey: provider.APIKey,
BaseURL: provider.BaseUrl,
})
}
keys := chatprovider.MergeProviderAPIKeys(
p.providerAPIKeys, dbProviders,
)
model, err := chatprovider.ModelFromConfig(
dbConfig.Provider, dbConfig.Model, keys, chatprovider.UserAgent(),
@@ -4129,6 +4568,29 @@ func (p *Server) resolveChatModel(
return model, dbConfig, keys, nil
}
func (p *Server) resolveProviderAPIKeys(
ctx context.Context,
) (chatprovider.ProviderAPIKeys, error) {
providers, err := p.configCache.EnabledProviders(ctx)
if err != nil {
return chatprovider.ProviderAPIKeys{}, xerrors.Errorf(
"get enabled chat providers: %w",
err,
)
}
dbProviders := make(
[]chatprovider.ConfiguredProvider, 0, len(providers),
)
for _, provider := range providers {
dbProviders = append(dbProviders, chatprovider.ConfiguredProvider{
Provider: provider.Provider,
APIKey: provider.APIKey,
BaseURL: provider.BaseUrl,
})
}
return chatprovider.MergeProviderAPIKeys(p.providerAPIKeys, dbProviders), nil
}
// resolveModelConfig looks up the chat's model config by its
// LastModelConfigID. If the referenced config no longer exists
// (e.g. it was deleted), it falls back to the default model
+178
View File
@@ -23,6 +23,7 @@ import (
dbpubsub "github.com/coder/coder/v2/coderd/database/pubsub"
coderdpubsub "github.com/coder/coder/v2/coderd/pubsub"
"github.com/coder/coder/v2/coderd/x/chatd/chaterror"
"github.com/coder/coder/v2/coderd/x/chatd/chattest"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/codersdk/workspacesdk"
"github.com/coder/coder/v2/codersdk/workspacesdk/agentconnmock"
@@ -30,6 +31,183 @@ import (
"github.com/coder/quartz"
)
func TestRegenerateChatTitle_PersistsAndBroadcasts(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
lockTx := dbmock.NewMockStore(ctrl)
usageTx := dbmock.NewMockStore(ctrl)
unlockTx := dbmock.NewMockStore(ctrl)
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
pubsub := dbpubsub.NewInMemory()
clock := quartz.NewReal()
ownerID := uuid.New()
chatID := uuid.New()
modelConfigID := uuid.New()
userPrompt := "review pull request 23633 and fix review threads"
wantTitle := "Review PR 23633"
chat := database.Chat{
ID: chatID,
OwnerID: ownerID,
LastModelConfigID: modelConfigID,
Status: database.ChatStatusCompleted,
Title: fallbackChatTitle(userPrompt),
}
modelConfig := database.ChatModelConfig{
ID: modelConfigID,
Provider: "anthropic",
Model: "claude-haiku-4-5",
ContextLimit: 8192,
}
updatedChat := chat
updatedChat.Title = wantTitle
messageEvents := make(chan struct {
payload coderdpubsub.ChatEvent
err error
}, 1)
cancelSub, err := pubsub.SubscribeWithErr(
coderdpubsub.ChatEventChannel(ownerID),
coderdpubsub.HandleChatEvent(func(_ context.Context, payload coderdpubsub.ChatEvent, err error) {
messageEvents <- struct {
payload coderdpubsub.ChatEvent
err error
}{payload: payload, err: err}
}),
)
require.NoError(t, err)
defer cancelSub()
serverURL := chattest.NewAnthropic(t, func(req *chattest.AnthropicRequest) chattest.AnthropicResponse {
require.Equal(t, "claude-haiku-4-5", req.Model)
return chattest.AnthropicNonStreamingResponse(wantTitle)
})
server := &Server{
db: db,
logger: logger,
pubsub: pubsub,
configCache: newChatConfigCache(context.Background(), db, clock),
}
db.EXPECT().GetChatModelConfigByID(gomock.Any(), modelConfigID).Return(modelConfig, nil)
db.EXPECT().GetEnabledChatProviders(gomock.Any()).Return([]database.ChatProvider{{
Provider: "anthropic",
APIKey: "test-key",
BaseUrl: serverURL,
}}, nil)
db.EXPECT().GetChatUsageLimitConfig(gomock.Any()).Return(database.ChatUsageLimitConfig{}, sql.ErrNoRows)
db.EXPECT().GetChatMessagesByChatIDAscPaginated(
gomock.Any(),
database.GetChatMessagesByChatIDAscPaginatedParams{
ChatID: chatID,
AfterID: 0,
LimitVal: manualTitleMessageWindowLimit,
},
).Return([]database.ChatMessage{
mustChatMessage(
t,
database.ChatMessageRoleUser,
database.ChatMessageVisibilityBoth,
codersdk.ChatMessageText(userPrompt),
),
mustChatMessage(
t,
database.ChatMessageRoleAssistant,
database.ChatMessageVisibilityBoth,
codersdk.ChatMessageText("checking the diff now"),
),
}, nil)
db.EXPECT().GetChatMessagesByChatIDDescPaginated(
gomock.Any(),
database.GetChatMessagesByChatIDDescPaginatedParams{
ChatID: chatID,
BeforeID: 0,
LimitVal: manualTitleMessageWindowLimit,
},
).Return(nil, nil)
db.EXPECT().GetEnabledChatModelConfigs(gomock.Any()).Return(nil, nil)
gomock.InOrder(
db.EXPECT().InTx(gomock.Any(), database.DefaultTXOptions().WithID("chat_title_regenerate_lock")).DoAndReturn(
func(fn func(database.Store) error, opts *database.TxOptions) error {
require.Equal(t, "chat_title_regenerate_lock", opts.TxIdentifier)
return fn(lockTx)
},
),
db.EXPECT().InTx(gomock.Any(), nil).DoAndReturn(
func(fn func(database.Store) error, opts *database.TxOptions) error {
require.Nil(t, opts)
return fn(usageTx)
},
),
db.EXPECT().InTx(gomock.Any(), database.DefaultTXOptions().WithID("chat_title_regenerate_unlock")).DoAndReturn(
func(fn func(database.Store) error, opts *database.TxOptions) error {
require.Equal(t, "chat_title_regenerate_unlock", opts.TxIdentifier)
return fn(unlockTx)
},
),
)
lockTx.EXPECT().GetChatByIDForUpdate(gomock.Any(), chatID).Return(chat, nil)
lockTx.EXPECT().UpdateChatStatusPreserveUpdatedAt(gomock.Any(), gomock.AssignableToTypeOf(database.UpdateChatStatusPreserveUpdatedAtParams{})).DoAndReturn(
func(_ context.Context, arg database.UpdateChatStatusPreserveUpdatedAtParams) (database.Chat, error) {
require.Equal(t, chatID, arg.ID)
require.Equal(t, chat.Status, arg.Status)
require.Equal(t, uuid.NullUUID{UUID: manualTitleLockWorkerID, Valid: true}, arg.WorkerID)
require.True(t, arg.StartedAt.Valid)
require.True(t, arg.HeartbeatAt.Valid)
return chat, nil
},
)
usageTx.EXPECT().GetChatByIDForUpdate(gomock.Any(), chatID).Return(chat, nil)
usageTx.EXPECT().InsertChatMessages(gomock.Any(), gomock.AssignableToTypeOf(database.InsertChatMessagesParams{})).DoAndReturn(
func(_ context.Context, arg database.InsertChatMessagesParams) ([]database.ChatMessage, error) {
require.Equal(t, []uuid.UUID{ownerID}, arg.CreatedBy)
require.Equal(t, []uuid.UUID{modelConfigID}, arg.ModelConfigID)
require.Equal(t, []string{"[]"}, arg.Content)
return []database.ChatMessage{{ID: 91}}, nil
},
)
usageTx.EXPECT().SoftDeleteChatMessageByID(gomock.Any(), int64(91)).Return(nil)
usageTx.EXPECT().UpdateChatByID(gomock.Any(), database.UpdateChatByIDParams{
ID: chatID,
Title: wantTitle,
}).Return(updatedChat, nil)
lockedChatWithMarker := updatedChat
lockedChatWithMarker.WorkerID = uuid.NullUUID{UUID: manualTitleLockWorkerID, Valid: true}
unlockTx.EXPECT().GetChatByIDForUpdate(gomock.Any(), chatID).Return(lockedChatWithMarker, nil)
unlockTx.EXPECT().UpdateChatStatusPreserveUpdatedAt(gomock.Any(), gomock.AssignableToTypeOf(database.UpdateChatStatusPreserveUpdatedAtParams{})).DoAndReturn(
func(_ context.Context, arg database.UpdateChatStatusPreserveUpdatedAtParams) (database.Chat, error) {
require.Equal(t, chatID, arg.ID)
require.False(t, arg.WorkerID.Valid)
require.False(t, arg.StartedAt.Valid)
require.False(t, arg.HeartbeatAt.Valid)
return updatedChat, nil
},
)
gotChat, err := server.RegenerateChatTitle(ctx, chat)
require.NoError(t, err)
require.Equal(t, updatedChat, gotChat)
select {
case event := <-messageEvents:
require.NoError(t, event.err)
require.Equal(t, coderdpubsub.ChatEventKindTitleChange, event.payload.Kind)
require.Equal(t, chatID, event.payload.Chat.ID)
require.Equal(t, wantTitle, event.payload.Chat.Title)
case <-time.After(time.Second):
t.Fatal("timed out waiting for title change pubsub event")
}
}
func TestRefreshChatWorkspaceSnapshot_NoReloadWhenWorkspacePresent(t *testing.T) {
t.Parallel()
+229 -12
View File
@@ -2,6 +2,8 @@ package chatd
import (
"context"
"fmt"
"slices"
"strings"
"time"
@@ -36,6 +38,17 @@ const titleGenerationPrompt = "You are a title generator. Your ONLY job is to ou
"Output ONLY the title — no quotes, no emoji, no markdown, no code fences, " +
"no trailing punctuation, no preamble, no explanation. Sentence case."
const (
// maxConversationContextRunes caps the conversation sample in manual
// title prompts to avoid exceeding model context windows.
maxConversationContextRunes = 6000
// maxLatestUserMessageRunes caps the latest user message excerpt.
maxLatestUserMessageRunes = 1000
// recentTurnWindow is the number of most recent turns included
// alongside the first user turn in manual title context.
recentTurnWindow = 3
)
// preferredTitleModels are lightweight models used for title
// generation, one per provider type. Each entry uses the
// cheapest/fastest small model for that provider as identified
@@ -54,6 +67,33 @@ var preferredTitleModels = []struct {
{fantasyvercel.Name, "anthropic/claude-haiku-4.5"},
}
func selectPreferredConfiguredShortTextModelConfig(
configs []database.ChatModelConfig,
) (database.ChatModelConfig, bool) {
for _, preferred := range preferredTitleModels {
for _, config := range configs {
if chatprovider.NormalizeProvider(config.Provider) != preferred.provider {
continue
}
if !strings.EqualFold(strings.TrimSpace(config.Model), preferred.model) {
continue
}
return config, true
}
}
return database.ChatModelConfig{}, false
}
func normalizeShortTextOutput(text string) string {
text = strings.TrimSpace(text)
if text == "" {
return ""
}
text = strings.Trim(text, "\"'`")
return strings.Join(strings.Fields(text), " ")
}
// maybeGenerateChatTitle generates an AI title for the chat when
// appropriate (first user message, no assistant reply yet, and the
// current title is either empty or still the fallback truncation).
@@ -138,7 +178,7 @@ func generateTitle(
model fantasy.LanguageModel,
input string,
) (string, error) {
title, err := generateShortText(ctx, model, titleGenerationPrompt, input)
title, _, err := generateShortText(ctx, model, titleGenerationPrompt, input)
if err != nil {
return "", err
}
@@ -199,13 +239,10 @@ func titleInput(
}
func normalizeTitleOutput(title string) string {
title = strings.TrimSpace(title)
title = normalizeShortTextOutput(title)
if title == "" {
return ""
}
title = strings.Trim(title, "\"'`")
title = strings.Join(strings.Fields(title), " ")
return truncateRunes(title, 80)
}
@@ -226,7 +263,7 @@ func fallbackChatTitle(message string) string {
title := strings.Join(words, " ")
if truncated {
title += "…"
return truncateRunes(title, maxRunes-1) + "…"
}
return truncateRunes(title, maxRunes)
@@ -260,6 +297,186 @@ func truncateRunes(value string, maxLen int) string {
return string(runes[:maxLen])
}
// Manual title regeneration is user-initiated and can use richer
// conversation context than the automatic first-message title path
// above. These helpers keep the manual prompt-building logic private
// while reusing the shared title-generation utilities in this file.
type manualTitleTurn struct {
role string
text string
}
func extractManualTitleTurns(messages []database.ChatMessage) []manualTitleTurn {
turns := make([]manualTitleTurn, 0, len(messages))
for _, message := range messages {
if message.Visibility == database.ChatMessageVisibilityModel {
continue
}
role := ""
switch message.Role {
case database.ChatMessageRoleUser:
role = string(database.ChatMessageRoleUser)
case database.ChatMessageRoleAssistant:
role = string(database.ChatMessageRoleAssistant)
default:
continue
}
parts, err := chatprompt.ParseContent(message)
if err != nil {
continue
}
text := strings.TrimSpace(contentBlocksToText(parts))
if text == "" {
continue
}
turns = append(turns, manualTitleTurn{
role: role,
text: text,
})
}
return turns
}
func selectManualTitleTurnIndexes(turns []manualTitleTurn) []int {
firstUserIndex := slices.IndexFunc(turns, func(turn manualTitleTurn) bool {
return turn.role == string(database.ChatMessageRoleUser)
})
if firstUserIndex == -1 {
return nil
}
windowStart := max(0, len(turns)-recentTurnWindow)
selected := make([]int, 0, recentTurnWindow+1)
if firstUserIndex < windowStart {
selected = append(selected, firstUserIndex)
}
for i := windowStart; i < len(turns); i++ {
selected = append(selected, i)
}
return selected
}
func buildManualTitleContext(
turns []manualTitleTurn,
selected []int,
) (conversationBlock string, latestUserMsg string) {
userCount := 0
for _, turn := range turns {
if turn.role != string(database.ChatMessageRoleUser) {
continue
}
userCount++
latestUserMsg = turn.text
}
latestUserMsg = truncateRunes(latestUserMsg, maxLatestUserMessageRunes)
if userCount <= 1 || len(selected) == 0 {
return "", latestUserMsg
}
lines := make([]string, 0, len(selected)+1)
for i, idx := range selected {
if i == 1 {
if gap := idx - selected[i-1] - 1; gap > 0 {
lines = append(lines, fmt.Sprintf("[... %d earlier turns omitted ...]", gap))
}
}
lines = append(lines, fmt.Sprintf("[%s]: %s", turns[idx].role, turns[idx].text))
}
conversationBlock = strings.Join(lines, "\n")
conversationBlock = truncateRunes(conversationBlock, maxConversationContextRunes)
return conversationBlock, latestUserMsg
}
func renderManualTitlePrompt(
conversationBlock string,
firstUserText string,
latestUserMsg string,
) string {
var prompt strings.Builder
write := func(value string) {
_, _ = prompt.WriteString(value)
}
write("You are a title generator for an AI coding assistant conversation.\n\n")
write("The user's primary objective was:\n<primary_objective>\n")
write(firstUserText)
write("\n</primary_objective>")
if conversationBlock != "" {
write("\n\nConversation sample:\n<conversation_sample>\n")
write(conversationBlock)
write("\n</conversation_sample>")
}
if strings.TrimSpace(latestUserMsg) != strings.TrimSpace(truncateRunes(firstUserText, maxLatestUserMessageRunes)) {
write("\n\nThe user's most recent message:\n<latest_message>\n")
write(latestUserMsg)
write("\n</latest_message>\n")
write("Note: Weight the overall conversation arc more heavily than just the latest message.")
}
write("\n\nRequirements:\n")
write("- Output a short title of 2-8 words.\n")
write("- Use verb-noun format in sentence case.\n")
write("- Preserve specific identifiers (PR numbers, repo names, file paths, function names, error messages).\n")
write("- No trailing punctuation, quotes, emoji, or markdown.\n")
write("- No temporal phrasing (\"Continue\", \"Follow up on\") or meta phrasing (\"Chat about\").\n")
write("- Output ONLY the title - nothing else.\n")
return prompt.String()
}
func generateManualTitle(
ctx context.Context,
messages []database.ChatMessage,
fallbackModel fantasy.LanguageModel,
) (string, fantasy.Usage, error) {
turns := extractManualTitleTurns(messages)
selected := selectManualTitleTurnIndexes(turns)
firstUserIndex := slices.IndexFunc(turns, func(turn manualTitleTurn) bool {
return turn.role == string(database.ChatMessageRoleUser)
})
if firstUserIndex == -1 {
return "", fantasy.Usage{}, nil
}
firstUserText := truncateRunes(turns[firstUserIndex].text, maxLatestUserMessageRunes)
conversationBlock, latestUserMsg := buildManualTitleContext(turns, selected)
systemPrompt := renderManualTitlePrompt(
conversationBlock,
firstUserText,
latestUserMsg,
)
titleCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
title, usage, err := generateShortText(
titleCtx,
fallbackModel,
systemPrompt,
"Generate the title.",
)
if err != nil {
return "", fantasy.Usage{}, err
}
title = normalizeTitleOutput(title)
if title == "" {
return "", usage, xerrors.New("generated title was empty")
}
return title, usage, nil
}
const pushSummaryPrompt = "You are a notification assistant. Given a chat title " +
"and the agent's last message, write a single short sentence (under 100 characters) " +
"summarizing what the agent did. This will be shown as a push notification body. " +
@@ -281,6 +498,7 @@ func generatePushSummary(
summaryCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
assistantText = truncateRunes(assistantText, maxConversationContextRunes)
input := "Chat title: " + chat.Title + "\n\nAgent's last message:\n" + assistantText
candidates := make([]fantasy.LanguageModel, 0, len(preferredTitleModels)+1)
@@ -296,7 +514,7 @@ func generatePushSummary(
candidates = append(candidates, fallbackModel)
for _, model := range candidates {
summary, err := generateShortText(summaryCtx, model, pushSummaryPrompt, input)
summary, _, err := generateShortText(summaryCtx, model, pushSummaryPrompt, input)
if err != nil {
logger.Debug(ctx, "push summary model candidate failed",
slog.Error(err),
@@ -318,7 +536,7 @@ func generateShortText(
model fantasy.LanguageModel,
systemPrompt string,
userInput string,
) (string, error) {
) (string, fantasy.Usage, error) {
prompt := []fantasy.Message{
{
Role: fantasy.MessageRoleSystem,
@@ -346,7 +564,7 @@ func generateShortText(
return genErr
}, nil)
if err != nil {
return "", xerrors.Errorf("generate short text: %w", err)
return "", fantasy.Usage{}, xerrors.Errorf("generate short text: %w", err)
}
responseParts := make([]codersdk.ChatMessagePart, 0, len(response.Content))
@@ -355,7 +573,6 @@ func generateShortText(
responseParts = append(responseParts, p)
}
}
text := strings.TrimSpace(contentBlocksToText(responseParts))
text = strings.Trim(text, "\"'`")
return text, nil
text := normalizeShortTextOutput(contentBlocksToText(responseParts))
return text, response.Usage, nil
}
+584
View File
@@ -0,0 +1,584 @@
package chatd //nolint:testpackage // Keeps internal helper tests in-package.
import (
"context"
"encoding/json"
"strings"
"testing"
"time"
"charm.land/fantasy"
"github.com/sqlc-dev/pqtype"
"github.com/stretchr/testify/require"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/codersdk"
)
func Test_extractManualTitleTurns(t *testing.T) {
t.Parallel()
tests := []struct {
name string
messages []database.ChatMessage
want []manualTitleTurn
}{
{
name: "filters to visible user and assistant text turns",
messages: []database.ChatMessage{
mustChatMessage(t, database.ChatMessageRoleUser, database.ChatMessageVisibilityBoth,
codersdk.ChatMessagePart{Type: codersdk.ChatMessagePartTypeText, Text: " review quickgen helpers "},
),
mustChatMessage(t, database.ChatMessageRoleAssistant, database.ChatMessageVisibilityBoth,
codersdk.ChatMessagePart{Type: codersdk.ChatMessagePartTypeText, Text: " drafted a plan "},
),
mustChatMessage(t, database.ChatMessageRoleSystem, database.ChatMessageVisibilityBoth,
codersdk.ChatMessagePart{Type: codersdk.ChatMessagePartTypeText, Text: "system prompt"},
),
mustChatMessage(t, database.ChatMessageRoleTool, database.ChatMessageVisibilityBoth,
codersdk.ChatMessagePart{Type: codersdk.ChatMessagePartTypeText, Text: "tool output"},
),
mustChatMessage(t, database.ChatMessageRoleUser, database.ChatMessageVisibilityModel,
codersdk.ChatMessagePart{Type: codersdk.ChatMessagePartTypeText, Text: "hidden model note"},
),
mustChatMessage(t, database.ChatMessageRoleUser, database.ChatMessageVisibilityBoth,
codersdk.ChatMessagePart{Type: codersdk.ChatMessagePartTypeText, Text: " "},
),
mustChatMessage(t, database.ChatMessageRoleAssistant, database.ChatMessageVisibilityBoth,
codersdk.ChatMessagePart{Type: codersdk.ChatMessagePartTypeReasoning, Text: "reasoning only"},
),
mustChatMessage(t, database.ChatMessageRoleUser, database.ChatMessageVisibilityBoth,
codersdk.ChatMessagePart{Type: codersdk.ChatMessagePartTypeFile, MediaType: "text/plain"},
),
},
want: []manualTitleTurn{
{role: "user", text: "review quickgen helpers"},
{role: "assistant", text: "drafted a plan"},
},
},
{
name: "reuses text extraction for multi-part content",
messages: []database.ChatMessage{
mustChatMessage(t, database.ChatMessageRoleUser, database.ChatMessageVisibilityBoth,
codersdk.ChatMessagePart{Type: codersdk.ChatMessagePartTypeText, Text: "first chunk"},
codersdk.ChatMessagePart{Type: codersdk.ChatMessagePartTypeReasoning, Text: "skip me"},
codersdk.ChatMessagePart{Type: codersdk.ChatMessagePartTypeText, Text: " second chunk "},
),
},
want: []manualTitleTurn{{role: "user", text: "first chunk second chunk"}},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := extractManualTitleTurns(tt.messages)
require.Equal(t, tt.want, got)
})
}
}
func Test_selectManualTitleTurnIndexes(t *testing.T) {
t.Parallel()
tests := []struct {
name string
turns []manualTitleTurn
want []int
}{
{
name: "single user turn",
turns: []manualTitleTurn{
{role: "user", text: "one"},
},
want: []int{0},
},
{
name: "first user plus trailing window",
turns: []manualTitleTurn{
{role: "user", text: "one"},
{role: "assistant", text: "two"},
{role: "user", text: "three"},
{role: "assistant", text: "four"},
{role: "user", text: "five"},
},
want: []int{0, 2, 3, 4},
},
{
name: "two turns returns both",
turns: []manualTitleTurn{
{role: "user", text: "one"},
{role: "assistant", text: "two"},
},
want: []int{0, 1},
},
{
name: "prepends first user when before trailing window",
turns: []manualTitleTurn{
{role: "assistant", text: "intro"},
{role: "assistant", text: "setup"},
{role: "user", text: "goal"},
{role: "assistant", text: "a"},
{role: "assistant", text: "b"},
{role: "assistant", text: "c"},
},
want: []int{2, 3, 4, 5},
},
{
name: "ten plus turns keeps first user and last three",
turns: []manualTitleTurn{
{role: "assistant", text: "0"},
{role: "assistant", text: "1"},
{role: "user", text: "2"},
{role: "assistant", text: "3"},
{role: "assistant", text: "4"},
{role: "assistant", text: "5"},
{role: "assistant", text: "6"},
{role: "assistant", text: "7"},
{role: "assistant", text: "8"},
{role: "user", text: "9"},
{role: "assistant", text: "10"},
{role: "user", text: "11"},
},
want: []int{2, 9, 10, 11},
},
{
name: "no user turns",
turns: []manualTitleTurn{
{role: "assistant", text: "one"},
{role: "assistant", text: "two"},
},
want: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := selectManualTitleTurnIndexes(tt.turns)
require.Equal(t, tt.want, got)
})
}
}
func Test_buildManualTitleContext(t *testing.T) {
t.Parallel()
longConversationText := strings.Repeat("a", 3500)
longLatestUserText := strings.Repeat("z", 1200)
tests := []struct {
name string
turns []manualTitleTurn
selected []int
wantConversation string
wantConversationEmpty bool
wantConversationHasGap bool
wantConversationRunes int
wantLatestUser string
wantLatestUserRunes int
wantLatestUserContains string
wantLatestUserNotEmpty bool
}{
{
name: "adds gap marker when selected turns skip earlier context",
turns: []manualTitleTurn{
{role: "user", text: "open pull request"},
{role: "assistant", text: "checked CI"},
{role: "user", text: "review logs"},
{role: "assistant", text: "found flaky test"},
{role: "user", text: "update chat title"},
},
selected: []int{0, 3, 4},
wantConversationHasGap: true,
wantLatestUser: "update chat title",
},
{
name: "omits gap marker for contiguous selection",
turns: []manualTitleTurn{
{role: "user", text: "open pull request"},
{role: "assistant", text: "checked CI"},
{role: "user", text: "update chat title"},
},
selected: []int{0, 1, 2},
wantConversation: "[user]: open pull request\n[assistant]: checked CI\n[user]: update chat title",
wantConversationHasGap: false,
wantLatestUser: "update chat title",
},
{
name: "single useful user turn returns empty conversation block",
turns: []manualTitleTurn{{role: "user", text: "rename helper"}},
selected: []int{0},
wantConversationEmpty: true,
wantLatestUser: "rename helper",
},
{
name: "truncates conversation block at six thousand runes",
turns: []manualTitleTurn{
{role: "user", text: longConversationText},
{role: "assistant", text: longConversationText},
{role: "user", text: "latest"},
},
selected: []int{0, 1, 2},
wantConversationRunes: 6000,
wantLatestUser: "latest",
},
{
name: "truncates latest user message at one thousand runes",
turns: []manualTitleTurn{
{role: "user", text: "first"},
{role: "assistant", text: "reply"},
{role: "user", text: longLatestUserText},
},
selected: []int{0, 1, 2},
wantLatestUserRunes: 1000,
wantLatestUserContains: strings.Repeat("z", 1000),
wantLatestUserNotEmpty: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
conversationBlock, latestUserMsg := buildManualTitleContext(tt.turns, tt.selected)
if tt.wantConversationEmpty {
require.Empty(t, conversationBlock)
}
if tt.wantConversation != "" {
require.Equal(t, tt.wantConversation, conversationBlock)
}
if tt.wantConversationHasGap {
require.Contains(t, conversationBlock, "[... 2 earlier turns omitted ...]")
} else if !tt.wantConversationEmpty {
require.NotContains(t, conversationBlock, "earlier turns omitted")
}
if tt.wantConversationRunes > 0 {
require.Len(t, []rune(conversationBlock), tt.wantConversationRunes)
}
if tt.wantLatestUser != "" {
require.Equal(t, tt.wantLatestUser, latestUserMsg)
}
if tt.wantLatestUserRunes > 0 {
require.Len(t, []rune(latestUserMsg), tt.wantLatestUserRunes)
}
if tt.wantLatestUserContains != "" {
require.Equal(t, tt.wantLatestUserContains, latestUserMsg)
}
if tt.wantLatestUserNotEmpty {
require.NotEmpty(t, latestUserMsg)
}
})
}
}
func Test_renderManualTitlePrompt(t *testing.T) {
t.Parallel()
longFirstUserText := strings.Repeat("b", 1501)
tests := []struct {
name string
conversationBlock string
firstUserText string
latestUserMsg string
wantConversationSample bool
wantLatestSection bool
}{
{
name: "includes conversation sample when provided",
conversationBlock: "[user]: inspect logs\n[assistant]: found flaky test",
firstUserText: "inspect logs",
latestUserMsg: "update quickgen title",
wantConversationSample: true,
wantLatestSection: true,
},
{
name: "omits optional sections when not needed",
conversationBlock: "",
firstUserText: "inspect logs",
latestUserMsg: "inspect logs",
wantConversationSample: false,
wantLatestSection: false,
},
{
name: "latest section compares trimmed text",
conversationBlock: "",
firstUserText: "inspect logs",
latestUserMsg: " inspect logs ",
wantConversationSample: false,
wantLatestSection: false,
},
{
name: "omits latest section when same message truncated",
conversationBlock: "",
firstUserText: longFirstUserText,
latestUserMsg: truncateRunes(longFirstUserText, 1000),
wantConversationSample: false,
wantLatestSection: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
prompt := renderManualTitlePrompt(tt.conversationBlock, tt.firstUserText, tt.latestUserMsg)
require.Contains(t, prompt, "The user's primary objective was:")
require.Contains(t, prompt, "Requirements:")
require.Contains(t, prompt, "- Output a short title of 2-8 words.")
require.Contains(t, prompt, "- Output ONLY the title - nothing else.")
if tt.wantConversationSample {
require.Contains(t, prompt, "Conversation sample:")
require.Contains(t, prompt, tt.conversationBlock)
} else {
require.NotContains(t, prompt, "Conversation sample:")
}
if tt.wantLatestSection {
require.Contains(t, prompt, "The user's most recent message:")
require.Contains(t, prompt, "Note: Weight the overall conversation arc more heavily than just the latest message.")
require.Contains(t, prompt, strings.TrimSpace(tt.latestUserMsg))
} else {
require.NotContains(t, prompt, "The user's most recent message:")
require.NotContains(t, prompt, "Weight the overall conversation arc more heavily")
}
})
}
}
func Test_generateManualTitle_UsesTimeout(t *testing.T) {
t.Parallel()
messages := []database.ChatMessage{
mustChatMessage(
t,
database.ChatMessageRoleUser,
database.ChatMessageVisibilityBoth,
codersdk.ChatMessageText("refresh chat title"),
),
}
model := &stubModel{
generateFn: func(ctx context.Context, call fantasy.Call) (*fantasy.Response, error) {
deadline, ok := ctx.Deadline()
require.True(t, ok, "manual title generation should set a deadline")
require.WithinDuration(
t,
time.Now().Add(30*time.Second),
deadline,
2*time.Second,
)
require.Len(t, call.Prompt, 2)
return &fantasy.Response{
Content: fantasy.ResponseContent{
fantasy.TextContent{Text: "Refresh title"},
},
}, nil
},
}
title, _, err := generateManualTitle(
context.Background(),
messages,
model,
)
require.NoError(t, err)
require.Equal(t, "Refresh title", title)
}
func Test_generateManualTitle_TruncatesFirstUserInput(t *testing.T) {
t.Parallel()
longFirstUserText := strings.Repeat("a", 1500)
messages := []database.ChatMessage{
mustChatMessage(
t,
database.ChatMessageRoleUser,
database.ChatMessageVisibilityBoth,
codersdk.ChatMessageText(longFirstUserText),
),
}
model := &stubModel{
generateFn: func(_ context.Context, call fantasy.Call) (*fantasy.Response, error) {
require.Len(t, call.Prompt, 2)
systemText, ok := call.Prompt[0].Content[0].(fantasy.TextPart)
require.True(t, ok)
require.Contains(t, systemText.Text, truncateRunes(longFirstUserText, 1000))
userText, ok := call.Prompt[1].Content[0].(fantasy.TextPart)
require.True(t, ok)
require.Equal(t, "Generate the title.", userText.Text)
return &fantasy.Response{
Content: fantasy.ResponseContent{
fantasy.TextContent{Text: "Refresh title"},
},
}, nil
},
}
_, _, err := generateManualTitle(
context.Background(),
messages,
model,
)
require.NoError(t, err)
}
func Test_generateManualTitle_ReturnsUsageForEmptyNormalizedTitle(t *testing.T) {
t.Parallel()
messages := []database.ChatMessage{
mustChatMessage(
t,
database.ChatMessageRoleUser,
database.ChatMessageVisibilityBoth,
codersdk.ChatMessageText("refresh chat title"),
),
}
model := &stubModel{
generateFn: func(_ context.Context, _ fantasy.Call) (*fantasy.Response, error) {
return &fantasy.Response{
Content: fantasy.ResponseContent{
fantasy.TextContent{Text: "\"\""},
},
Usage: fantasy.Usage{
InputTokens: 11,
OutputTokens: 7,
TotalTokens: 18,
},
}, nil
},
}
_, usage, err := generateManualTitle(
context.Background(),
messages,
model,
)
require.ErrorContains(t, err, "generated title was empty")
require.Equal(t, int64(11), usage.InputTokens)
require.Equal(t, int64(7), usage.OutputTokens)
require.Equal(t, int64(18), usage.TotalTokens)
}
func Test_selectPreferredConfiguredShortTextModelConfig(t *testing.T) {
t.Parallel()
t.Run("chooses the highest-priority configured lightweight model", func(t *testing.T) {
t.Parallel()
configs := []database.ChatModelConfig{
{Provider: preferredTitleModels[2].provider, Model: preferredTitleModels[2].model},
{Provider: preferredTitleModels[1].provider, Model: preferredTitleModels[1].model},
{Provider: "openai", Model: "gpt-4.1"},
}
got, ok := selectPreferredConfiguredShortTextModelConfig(configs)
require.True(t, ok)
require.Equal(t, preferredTitleModels[1].provider, got.Provider)
require.Equal(t, preferredTitleModels[1].model, got.Model)
})
t.Run("returns false when no preferred lightweight model is configured", func(t *testing.T) {
t.Parallel()
got, ok := selectPreferredConfiguredShortTextModelConfig([]database.ChatModelConfig{{
Provider: "openai",
Model: "gpt-4.1",
}})
require.False(t, ok)
require.Equal(t, database.ChatModelConfig{}, got)
})
}
func Test_generateShortText_NormalizesQuotedOutput(t *testing.T) {
t.Parallel()
model := &stubModel{
generateFn: func(_ context.Context, _ fantasy.Call) (*fantasy.Response, error) {
return &fantasy.Response{
Content: fantasy.ResponseContent{
fantasy.TextContent{Text: " \"Quoted summary\" "},
},
Usage: fantasy.Usage{InputTokens: 3, OutputTokens: 2, TotalTokens: 5},
}, nil
},
}
text, usage, err := generateShortText(context.Background(), model, "system", "user")
require.NoError(t, err)
require.Equal(t, "Quoted summary", text)
require.Equal(t, int64(3), usage.InputTokens)
require.Equal(t, int64(2), usage.OutputTokens)
require.Equal(t, int64(5), usage.TotalTokens)
}
type stubModel struct {
generateFn func(context.Context, fantasy.Call) (*fantasy.Response, error)
}
func (m *stubModel) Generate(
ctx context.Context,
call fantasy.Call,
) (*fantasy.Response, error) {
return m.generateFn(ctx, call)
}
func (*stubModel) Stream(
context.Context,
fantasy.Call,
) (fantasy.StreamResponse, error) {
return nil, xerrors.New("stream not implemented")
}
func (*stubModel) GenerateObject(
context.Context,
fantasy.ObjectCall,
) (*fantasy.ObjectResponse, error) {
return nil, xerrors.New("generate object not implemented")
}
func (*stubModel) StreamObject(
context.Context,
fantasy.ObjectCall,
) (fantasy.ObjectStreamResponse, error) {
return nil, xerrors.New("stream object not implemented")
}
func (*stubModel) Provider() string {
return "test"
}
func (*stubModel) Model() string {
return "test"
}
func mustChatMessage(
t *testing.T,
role database.ChatMessageRole,
visibility database.ChatMessageVisibility,
parts ...codersdk.ChatMessagePart,
) database.ChatMessage {
t.Helper()
content, err := json.Marshal(parts)
require.NoError(t, err)
return database.ChatMessage{
Role: role,
Visibility: visibility,
Content: pqtype.NullRawMessage{
RawMessage: content,
Valid: len(content) > 0,
},
}
}
+15
View File
@@ -1870,6 +1870,21 @@ func (c *ExperimentalClient) InterruptChat(ctx context.Context, chatID uuid.UUID
return chat, json.NewDecoder(res.Body).Decode(&chat)
}
// RegenerateChatTitle requests the server to regenerate the chat's
// title using richer conversation context.
func (c *ExperimentalClient) RegenerateChatTitle(ctx context.Context, chatID uuid.UUID) (Chat, error) {
res, err := c.Request(ctx, http.MethodPost, fmt.Sprintf("/api/experimental/chats/%s/title/regenerate", chatID), nil)
if err != nil {
return Chat{}, err
}
defer res.Body.Close()
if res.StatusCode != http.StatusOK {
return Chat{}, readBodyAsChatUsageLimitError(res)
}
var chat Chat
return chat, json.NewDecoder(res.Body).Decode(&chat)
}
// GetChatGitChanges returns git changes for a chat.
func (c *ExperimentalClient) GetChatGitChanges(ctx context.Context, chatID uuid.UUID) ([]ChatGitChange, error) {
res, err := c.Request(ctx, http.MethodGet, fmt.Sprintf("/api/experimental/chats/%s/git-changes", chatID), nil)
+7
View File
@@ -3126,6 +3126,13 @@ class ExperimentalApiMethods {
await this.axios.patch(`/api/experimental/chats/${chatId}`, req);
};
regenerateChatTitle = async (chatId: string): Promise<TypesGen.Chat> => {
const response = await this.axios.post<TypesGen.Chat>(
`/api/experimental/chats/${chatId}/title/regenerate`,
);
return response.data;
};
createChatMessage = async (
chatId: string,
req: TypesGen.CreateChatMessageRequest,
+47
View File
@@ -22,6 +22,7 @@ import {
invalidateChatListQueries,
pinChat,
promoteChatQueuedMessage,
regenerateChatTitle,
reorderPinnedChat,
unarchiveChat,
unpinChat,
@@ -41,6 +42,7 @@ vi.mock("api/api", () => ({
editChatMessage: vi.fn(),
interruptChat: vi.fn(),
promoteChatQueuedMessage: vi.fn(),
regenerateChatTitle: vi.fn(),
},
},
}));
@@ -556,6 +558,51 @@ describe("reorderPinnedChat", () => {
});
});
describe("regenerateChatTitle cache updates", () => {
it("preserves existing chat detail fields when the response is partial", () => {
const queryClient = createTestQueryClient();
const chatId = "chat-1";
const cachedChat = makeChat(chatId, {
diff_status: {
chat_id: chatId,
url: "https://example.com/pr/1",
pull_request_state: "open",
pull_request_title: "",
pull_request_draft: false,
changes_requested: false,
additions: 1,
deletions: 2,
changed_files: 3,
refreshed_at: "2025-01-01T00:00:00.000Z",
stale_at: "2025-01-01T01:00:00.000Z",
},
});
queryClient.setQueryData(chatKey(chatId), cachedChat);
seedInfiniteChats(queryClient, [cachedChat]);
const mutation = regenerateChatTitle(queryClient);
const updatedChat = {
id: chatId,
title: "New title",
} satisfies Partial<TypesGen.Chat>;
mutation.onSuccess(updatedChat as TypesGen.Chat);
const cachedDetail = queryClient.getQueryData<TypesGen.Chat>(
chatKey(chatId),
);
expect(cachedDetail).toEqual({
...cachedChat,
title: "New title",
});
expect(cachedDetail?.diff_status).toEqual(cachedChat.diff_status);
expect(readInfiniteChats(queryClient)?.[0]).toMatchObject({
id: chatId,
title: "New title",
});
});
});
describe("chat cost query factories", () => {
it("builds the summary query key and forwards snake_case params", async () => {
const user = "user-1";
+31
View File
@@ -476,6 +476,37 @@ export const reorderPinnedChat = (queryClient: QueryClient) => ({
},
});
export const regenerateChatTitle = (queryClient: QueryClient) => ({
mutationFn: (chatId: string) => API.experimental.regenerateChatTitle(chatId),
onSuccess: (updatedChat: TypesGen.Chat) => {
queryClient.setQueryData<TypesGen.Chat>(
chatKey(updatedChat.id),
(previousChat) =>
previousChat ? { ...previousChat, ...updatedChat } : updatedChat,
);
updateInfiniteChatsCache(queryClient, (chats) =>
chats.map((chat) =>
chat.id === updatedChat.id
? { ...chat, title: updatedChat.title }
: chat,
),
);
},
onSettled: async (
_data: TypesGen.Chat | undefined,
_error: unknown,
chatId: string,
) => {
await invalidateChatListQueries(queryClient);
await queryClient.invalidateQueries({
queryKey: chatKey(chatId),
exact: true,
});
},
});
export const createChat = (queryClient: QueryClient) => ({
mutationFn: (req: TypesGen.CreateChatRequest) =>
API.experimental.createChat(req),
@@ -55,6 +55,9 @@ const AgentDetailLayout: FC = () => {
requestUnarchiveAgent: () => {},
requestPinAgent: () => {},
requestUnpinAgent: () => {},
onRegenerateTitle: () => {},
isRegeneratingTitle: false,
regeneratingTitleChatId: null,
isSidebarCollapsed: false,
onToggleSidebarCollapsed: () => {},
onExpandSidebar: () => {},
+17
View File
@@ -292,6 +292,9 @@ const AgentDetail: FC = () => {
requestArchiveAgent,
requestArchiveAndDeleteWorkspace,
requestUnarchiveAgent,
onRegenerateTitle,
isRegeneratingTitle,
regeneratingTitleChatId,
isSidebarCollapsed,
onToggleSidebarCollapsed,
onChatReady,
@@ -326,6 +329,9 @@ const AgentDetail: FC = () => {
});
};
const isRegeneratingThisChat =
isRegeneratingTitle && regeneratingTitleChatId === agentId;
const chatQuery = useQuery({
...chat(agentId ?? ""),
enabled: Boolean(agentId),
@@ -480,6 +486,7 @@ const AgentDetail: FC = () => {
}
: undefined;
const isArchived = chatRecord?.archived ?? false;
const isRegenerateTitleDisabled = isArchived || isRegeneratingTitle;
const chatLastModelConfigID = chatRecord?.last_model_config_id;
const sendMutation = useMutation(
@@ -887,6 +894,13 @@ const AgentDetail: FC = () => {
onChatReady();
}, [onChatReady, chatMessagesQuery.isSuccess, agentId]);
const handleRegenerateTitle = () => {
if (!agentId || isRegenerateTitleDisabled || !onRegenerateTitle) {
return;
}
onRegenerateTitle(agentId);
};
if (chatQuery.isLoading || chatMessagesQuery.isLoading) {
return (
<AgentDetailLoadingView
@@ -958,6 +972,9 @@ const AgentDetail: FC = () => {
handleArchiveAndDeleteWorkspaceAction={
handleArchiveAndDeleteWorkspaceAction
}
handleRegenerateTitle={handleRegenerateTitle}
isRegeneratingTitle={isRegeneratingThisChat}
isRegenerateTitleDisabled={isRegenerateTitleDisabled}
urlTransform={urlTransform}
scrollContainerRef={scrollContainerRef}
scrollToBottomRef={scrollToBottomRef}
@@ -233,6 +233,9 @@ const AgentEmbedPage: FC = () => {
requestPinAgent: () => {},
requestUnpinAgent: () => {},
requestArchiveAndDeleteWorkspace,
// Title regeneration is not supported in embed mode.
isRegeneratingTitle: false,
regeneratingTitleChatId: null,
isSidebarCollapsed,
onToggleSidebarCollapsed,
onExpandSidebar: () => {},
+16
View File
@@ -22,6 +22,7 @@ import {
pinChat,
prependToInfiniteChatsCache,
readInfiniteChatsCache,
regenerateChatTitle,
reorderPinnedChat,
unarchiveChat,
unpinChat,
@@ -220,6 +221,12 @@ const AgentsPage: FC = () => {
toast.error(getErrorMessage(error, "Failed to reorder pinned agents."));
},
});
const regenerateTitleMutation = useMutation({
...regenerateChatTitle(queryClient),
onError: (error: unknown) => {
toast.error(getErrorMessage(error, "Failed to generate new title."));
},
});
const [isSidebarCollapsed, setIsSidebarCollapsed] = useState(false);
const [chatErrorReasons, setChatErrorReasons] = useState<
Record<string, ChatDetailError>
@@ -359,6 +366,12 @@ const AgentsPage: FC = () => {
const requestReorderPinnedAgent = (chatId: string, pinOrder: number) => {
reorderPinnedChatMutation.mutate({ chatId, pinOrder });
};
const requestRegenerateTitle = (chatId: string) => {
if (regenerateTitleMutation.isPending) {
return;
}
regenerateTitleMutation.mutate(chatId);
};
const handleToggleSidebarCollapsed = () =>
setIsSidebarCollapsed((prev) => !prev);
@@ -609,6 +622,9 @@ const AgentsPage: FC = () => {
requestPinAgent={requestPinAgent}
requestUnpinAgent={requestUnpinAgent}
requestReorderPinnedAgent={requestReorderPinnedAgent}
onRegenerateTitle={requestRegenerateTitle}
isRegeneratingTitle={regenerateTitleMutation.isPending}
regeneratingTitleChatId={regenerateTitleMutation.variables ?? null}
onToggleSidebarCollapsed={handleToggleSidebarCollapsed}
isAgentsAdmin={isAgentsAdmin}
hasNextPage={chatsQuery.hasNextPage}
@@ -23,6 +23,9 @@ export interface AgentsOutletContext {
requestPinAgent: (chatId: string) => void;
requestUnpinAgent: (chatId: string) => void;
requestReorderPinnedAgent?: (chatId: string, pinOrder: number) => void;
onRegenerateTitle?: (chatId: string) => void;
isRegeneratingTitle: boolean;
regeneratingTitleChatId: string | null;
isSidebarCollapsed: boolean;
onToggleSidebarCollapsed: () => void;
onExpandSidebar: () => void;
@@ -59,6 +62,9 @@ interface AgentsPageViewProps {
requestPinAgent: (chatId: string) => void;
requestUnpinAgent: (chatId: string) => void;
requestReorderPinnedAgent?: (chatId: string, pinOrder: number) => void;
onRegenerateTitle: (chatId: string) => void;
isRegeneratingTitle: boolean;
regeneratingTitleChatId: string | null;
onToggleSidebarCollapsed: () => void;
isAgentsAdmin: boolean;
hasNextPage: boolean | undefined;
@@ -93,6 +99,9 @@ export const AgentsPageView: FC<AgentsPageViewProps> = ({
requestPinAgent,
requestUnpinAgent,
requestReorderPinnedAgent,
onRegenerateTitle,
isRegeneratingTitle,
regeneratingTitleChatId,
onToggleSidebarCollapsed,
isAgentsAdmin,
hasNextPage,
@@ -133,6 +142,9 @@ export const AgentsPageView: FC<AgentsPageViewProps> = ({
requestPinAgent,
requestUnpinAgent,
requestReorderPinnedAgent,
onRegenerateTitle,
isRegeneratingTitle,
regeneratingTitleChatId,
isSidebarCollapsed,
onToggleSidebarCollapsed,
onExpandSidebar,
@@ -166,6 +178,9 @@ export const AgentsPageView: FC<AgentsPageViewProps> = ({
onPinAgent={requestPinAgent}
onUnpinAgent={requestUnpinAgent}
onReorderPinnedAgent={requestReorderPinnedAgent}
onRegenerateTitle={onRegenerateTitle}
isRegeneratingTitle={isRegeneratingTitle}
regeneratingTitleChatId={regeneratingTitleChatId}
onBeforeNewAgent={handleNewAgent}
isCreating={isCreating}
isArchiving={isArchiving}
@@ -1,26 +1,27 @@
import type { Meta, StoryObj } from "@storybook/react-vite";
import { expect, userEvent, waitFor, within } from "storybook/test";
import { expect, fn, userEvent, waitFor, within } from "storybook/test";
import { AgentDetailTopBar } from "./TopBar";
const defaultProps = {
chatTitle: "Build authentication feature",
panel: {
showSidebarPanel: false,
onToggleSidebar: () => {},
onToggleSidebar: fn(),
},
workspace: {
canOpenEditors: true,
canOpenWorkspace: true,
onOpenInEditor: () => {},
onViewWorkspace: () => {},
onOpenTerminal: () => {},
onOpenInEditor: fn(),
onViewWorkspace: fn(),
onOpenTerminal: fn(),
sshCommand: "ssh main.my-workspace.admin.coder",
},
onArchiveAgent: () => {},
onArchiveAndDeleteWorkspace: () => {},
onUnarchiveAgent: () => {},
onArchiveAgent: fn(),
onArchiveAndDeleteWorkspace: fn(),
onRegenerateTitle: fn(),
onUnarchiveAgent: fn(),
isSidebarCollapsed: false,
onToggleSidebarCollapsed: () => {},
onToggleSidebarCollapsed: fn(),
} satisfies React.ComponentProps<typeof AgentDetailTopBar>;
const meta: Meta<typeof AgentDetailTopBar> = {
@@ -36,6 +37,13 @@ type Story = StoryObj<typeof AgentDetailTopBar>;
export const Default: Story = {};
export const RegeneratingTitle: Story = {
args: {
...Default.args,
isRegeneratingTitle: true,
},
};
export const WithPanelOpen: Story = {
args: {
panel: {
@@ -227,10 +235,22 @@ export const MobileWithClosedPR: Story = {
},
};
export const GenerateTitle: Story = {
play: async ({ canvasElement }) => {
const canvas = within(canvasElement);
const trigger = canvas.getByLabelText("Open agent actions");
await userEvent.click(trigger);
await waitFor(() => {
const body = within(document.body);
expect(body.getByText("Generate new title")).toBeInTheDocument();
});
},
};
export const ArchivedWithUnarchive: Story = {
args: {
isArchived: true,
onUnarchiveAgent: () => {},
onUnarchiveAgent: fn(),
},
play: async ({ canvasElement }) => {
const canvas = within(canvasElement);
@@ -243,6 +263,7 @@ export const ArchivedWithUnarchive: Story = {
expect(body.getByText("Unarchive Agent")).toBeInTheDocument();
});
const body = within(document.body);
expect(body.queryByText("Generate new title")).not.toBeInTheDocument();
expect(body.queryByText("Archive Agent")).not.toBeInTheDocument();
expect(
body.queryByText("Archive & Delete Workspace"),
@@ -12,10 +12,12 @@ import {
PanelRightOpenIcon,
TerminalIcon,
Trash2Icon,
WandSparklesIcon,
} from "lucide-react";
import type { FC } from "react";
import { Link } from "react-router";
import { toast } from "sonner";
import { cn } from "utils/cn";
import type * as TypesGen from "#/api/typesGenerated";
import type { ChatDiffStatus } from "#/api/typesGenerated";
import { Button } from "#/components/Button/Button";
@@ -26,7 +28,7 @@ import {
DropdownMenuSeparator,
DropdownMenuTrigger,
} from "#/components/DropdownMenu/DropdownMenu";
import { cn } from "#/utils/cn";
import { Spinner } from "#/components/Spinner/Spinner";
import { parsePullRequestUrl } from "../../utils/pullRequest";
import { useEmbedContext } from "../EmbedContext";
import { PrStateIcon } from "../GitPanel/GitPanel";
@@ -53,6 +55,9 @@ type AgentDetailTopBarProps = {
onArchiveAgent: () => void;
onUnarchiveAgent: () => void;
onArchiveAndDeleteWorkspace: () => void;
onRegenerateTitle?: () => void;
isRegeneratingTitle?: boolean;
isRegenerateTitleDisabled?: boolean;
hasWorkspace?: boolean;
isArchived?: boolean;
isSidebarCollapsed: boolean;
@@ -68,6 +73,9 @@ export const AgentDetailTopBar: FC<AgentDetailTopBarProps> = ({
onArchiveAgent,
onUnarchiveAgent,
onArchiveAndDeleteWorkspace,
onRegenerateTitle,
isRegeneratingTitle,
isRegenerateTitleDisabled,
hasWorkspace,
isArchived,
isSidebarCollapsed,
@@ -115,7 +123,12 @@ export const AgentDetailTopBar: FC<AgentDetailTopBarProps> = ({
{/* Title area */}
<div className="flex min-w-0 flex-1 items-center">
{chatTitle && (
<div className="flex min-w-0 items-center gap-1.5">
<div
role="status"
aria-live="polite"
aria-busy={isRegeneratingTitle}
className="flex min-w-0 items-center gap-1.5"
>
{parentChat && (
<>
<Button
@@ -131,9 +144,21 @@ export const AgentDetailTopBar: FC<AgentDetailTopBarProps> = ({
<ChevronRightIcon className="h-3.5 w-3.5 shrink-0 text-content-secondary/70 -ml-0.5" />
</>
)}
<span className="truncate text-sm text-content-primary">
<span
className={cn(
"truncate text-sm text-content-primary",
isRegeneratingTitle && "animate-pulse",
)}
>
{chatTitle}
</span>
{isRegeneratingTitle && (
<Spinner
aria-label="Regenerating title"
className="h-3.5 w-3.5 shrink-0 text-content-secondary"
loading
/>
)}
</div>
)}
</div>
@@ -227,7 +252,23 @@ export const AgentDetailTopBar: FC<AgentDetailTopBarProps> = ({
<MonitorIcon className="h-3.5 w-3.5" />
View Workspace
</DropdownMenuItem>
<DropdownMenuSeparator />
{!isArchived && (
<>
<DropdownMenuSeparator />
{onRegenerateTitle && (
<>
<DropdownMenuItem
disabled={isRegenerateTitleDisabled}
onSelect={onRegenerateTitle}
>
<WandSparklesIcon className="h-3.5 w-3.5" />
Generate new title
</DropdownMenuItem>
<DropdownMenuSeparator />
</>
)}
</>
)}
{isArchived ? (
<DropdownMenuItem onSelect={onUnarchiveAgent}>
<ArchiveRestoreIcon className="h-3.5 w-3.5" />
@@ -142,6 +142,7 @@ const StoryAgentDetailView: FC<StoryProps> = ({ editing, ...overrides }) => {
handleArchiveAgentAction: fn(),
handleUnarchiveAgentAction: fn(),
handleArchiveAndDeleteWorkspaceAction: fn(),
handleRegenerateTitle: fn(),
scrollContainerRef: { current: null },
hasMoreMessages: false,
isFetchingMoreMessages: false,
@@ -117,6 +117,9 @@ interface AgentDetailViewProps {
handleArchiveAgentAction: () => void;
handleUnarchiveAgentAction: () => void;
handleArchiveAndDeleteWorkspaceAction: () => void;
handleRegenerateTitle?: () => void;
isRegeneratingTitle?: boolean;
isRegenerateTitleDisabled?: boolean;
// Scroll container ref.
scrollContainerRef: RefObject<HTMLDivElement | null>;
@@ -179,6 +182,9 @@ export const AgentDetailView: FC<AgentDetailViewProps> = ({
handleArchiveAgentAction,
handleUnarchiveAgentAction,
handleArchiveAndDeleteWorkspaceAction,
handleRegenerateTitle,
isRegeneratingTitle,
isRegenerateTitleDisabled,
scrollContainerRef,
scrollToBottomRef,
hasMoreMessages,
@@ -261,6 +267,11 @@ export const AgentDetailView: FC<AgentDetailViewProps> = ({
onArchiveAndDeleteWorkspace={
handleArchiveAndDeleteWorkspaceAction
}
{...(handleRegenerateTitle
? { onRegenerateTitle: handleRegenerateTitle }
: {})}
isRegeneratingTitle={isRegeneratingTitle}
isRegenerateTitleDisabled={isRegenerateTitleDisabled}
hasWorkspace={hasWorkspace}
isArchived={isArchived}
diffStatusData={diffStatusData}
@@ -435,6 +446,7 @@ export const AgentDetailLoadingView: FC<AgentDetailLoadingViewProps> = ({
}}
onArchiveAgent={() => {}}
onUnarchiveAgent={() => {}}
onRegenerateTitle={() => {}}
onArchiveAndDeleteWorkspace={() => {}}
hasWorkspace={false}
isSidebarCollapsed={isSidebarCollapsed}
@@ -507,6 +519,7 @@ export const AgentDetailNotFoundView: FC<AgentDetailNotFoundViewProps> = ({
}}
onArchiveAgent={() => {}}
onUnarchiveAgent={() => {}}
onRegenerateTitle={() => {}}
onArchiveAndDeleteWorkspace={() => {}}
hasWorkspace={false}
isSidebarCollapsed={isSidebarCollapsed}
@@ -74,8 +74,11 @@ const meta: Meta<typeof AgentsSidebar> = {
onArchiveAndDeleteWorkspace: fn(),
onPinAgent: fn(),
onUnpinAgent: fn(),
onRegenerateTitle: fn(),
onBeforeNewAgent: fn(),
isCreating: false,
isRegeneratingTitle: false,
regeneratingTitleChatId: null,
archivedFilter: "active" as const,
onArchivedFilterChange: fn(),
},
@@ -108,6 +108,9 @@ const defaultProps: React.ComponentProps<typeof AgentsSidebar> = {
onArchiveAndDeleteWorkspace: vi.fn(),
onPinAgent: vi.fn(),
onUnpinAgent: vi.fn(),
onRegenerateTitle: vi.fn(),
isRegeneratingTitle: false,
regeneratingTitleChatId: null,
onBeforeNewAgent: vi.fn(),
isCreating: false,
archivedFilter: "active" as const,
@@ -72,6 +72,7 @@ import {
DropdownMenu,
DropdownMenuContent,
DropdownMenuItem,
DropdownMenuSeparator,
DropdownMenuTrigger,
} from "#/components/DropdownMenu/DropdownMenu";
import { ExternalImage } from "#/components/ExternalImage/ExternalImage";
@@ -123,10 +124,13 @@ interface AgentsSidebarProps {
onPinAgent: (chatId: string) => void;
onUnpinAgent: (chatId: string) => void;
onReorderPinnedAgent?: (chatId: string, pinOrder: number) => void;
onRegenerateTitle: (chatId: string) => void;
onBeforeNewAgent?: () => void;
isCreating: boolean;
isArchiving?: boolean;
archivingChatId?: string | null;
isRegeneratingTitle?: boolean;
regeneratingTitleChatId?: string | null;
isLoading?: boolean;
loadError?: unknown;
onRetryLoad?: () => void;
@@ -367,6 +371,8 @@ interface ChatTreeContextValue {
readonly chatErrorReasons: Record<string, string>;
readonly isArchiving: boolean;
readonly archivingChatId: string | null;
readonly isRegeneratingTitle: boolean;
readonly regeneratingTitleChatId: string | null;
readonly toggleExpanded: (chatID: string) => void;
readonly onArchiveAgent: (chatId: string) => void;
readonly onUnarchiveAgent: (chatId: string) => void;
@@ -376,6 +382,7 @@ interface ChatTreeContextValue {
) => void;
readonly onPinAgent: (chatId: string) => void;
readonly onUnpinAgent: (chatId: string) => void;
readonly onRegenerateTitle: (chatId: string) => void;
}
const ChatTreeContext = createContext<ChatTreeContextValue | null>(null);
@@ -405,12 +412,15 @@ const ChatTreeNode: FC<ChatTreeNodeProps> = ({ chat, isChildNode }) => {
chatErrorReasons,
isArchiving,
archivingChatId,
isRegeneratingTitle,
regeneratingTitleChatId,
toggleExpanded,
onArchiveAgent,
onUnarchiveAgent,
onArchiveAndDeleteWorkspace,
onPinAgent,
onUnpinAgent,
onRegenerateTitle,
} = useChatTree();
const chatID = chat.id;
const childIDs = (chatTree.childrenById.get(chatID) ?? []).filter((childID) =>
@@ -448,6 +458,8 @@ const ChatTreeNode: FC<ChatTreeNodeProps> = ({ chat, isChildNode }) => {
}`;
const workspaceId = chat.workspace_id;
const isArchivingThisChat = isArchiving && archivingChatId === chat.id;
const isRegeneratingThisChat =
isRegeneratingTitle && regeneratingTitleChatId === chat.id;
const isExpanded = normalizedSearch ? true : (expandedById[chatID] ?? false);
return (
@@ -509,13 +521,21 @@ const ChatTreeNode: FC<ChatTreeNodeProps> = ({ chat, isChildNode }) => {
<div className="min-w-0 flex-1 overflow-hidden text-left">
<div className="flex min-w-0 items-center gap-1.5 overflow-hidden">
<span
aria-busy={isRegeneratingThisChat}
className={cn(
"block flex-1 truncate text-[13px] text-content-primary",
isActive && "font-medium",
// Pulse-only in sidebar (no spinner) — space-constrained card layout.
isRegeneratingThisChat && "animate-pulse",
)}
>
{chat.title}
</span>
{isRegeneratingThisChat && (
<span className="sr-only" role="status">
Regenerating title
</span>
)}
</div>
<div className="flex min-w-0 items-center gap-1.5">
{hasLinkedDiffStatus && hasLineStats && (
@@ -598,6 +618,14 @@ const ChatTreeNode: FC<ChatTreeNodeProps> = ({ chat, isChildNode }) => {
</DropdownMenuItem>
) : (
<>
<DropdownMenuItem
disabled={isRegeneratingTitle}
onSelect={() => onRegenerateTitle(chat.id)}
>
<WandSparklesIcon className="h-3.5 w-3.5" />
Generate new title
</DropdownMenuItem>
<DropdownMenuSeparator />
<DropdownMenuItem
className="text-content-destructive focus:text-content-destructive"
disabled={isArchiving}
@@ -710,10 +738,13 @@ export const AgentsSidebar: FC<AgentsSidebarProps> = (props) => {
onPinAgent,
onUnpinAgent,
onReorderPinnedAgent,
onRegenerateTitle,
onBeforeNewAgent,
isCreating,
isArchiving = false,
archivingChatId = null,
isRegeneratingTitle = false,
regeneratingTitleChatId = null,
isLoading = false,
loadError,
onRetryLoad,
@@ -910,12 +941,15 @@ export const AgentsSidebar: FC<AgentsSidebarProps> = (props) => {
chatErrorReasons,
isArchiving,
archivingChatId,
isRegeneratingTitle,
regeneratingTitleChatId,
toggleExpanded,
onArchiveAgent,
onUnarchiveAgent,
onArchiveAndDeleteWorkspace,
onPinAgent,
onUnpinAgent,
onRegenerateTitle,
};
const subNavTitle = "Settings";