mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
feat: add manual chat title regeneration (#23633)
## Summary
Adds a "Generate new title" action that lets users manually regenerate a
chat's title using richer conversation context than the automatic
first-message title path.
## Changes
### Backend
- **New endpoint:** `POST
/api/experimental/chats/{chatID}/title/regenerate` returns the updated
Chat with a regenerated title
- **Manual title algorithm:** Extracts useful user/assistant text turns
→ selects first user turn + last 3 turns → builds context with gap
markers → renders prompt with anti-recency guidance → calls lightweight
model → normalizes output
- **Helpers:** `extractManualTitleTurns`,
`selectManualTitleTurnIndexes`, `buildManualTitleContext`,
`renderManualTitlePrompt`, `generateManualTitle` — all private, with the
public `Server.RegenerateChatTitle` method
- **SDK:** `ExperimentalClient.RegenerateChatTitle(ctx, chatID) (Chat,
error)`
- Persists title via existing `UpdateChatByID` and broadcasts
`ChatEventKindTitleChange`
### Frontend
- API client method + React Query mutation with cache invalidation
- "Generate new title" menu item (with wand icon) in both TopBar and
Sidebar dropdown menus
- Loading/disabled state while regeneration is in-flight
- Error toast on failure
- Stories updated for both menus
### Tests
- `quickgen_test.go`: Table-driven tests for all 4 helper functions
(turn extraction, index selection, context building, prompt rendering)
- `exp_chats_test.go`: Handler tests (ChatNotFound,
NotFoundForDifferentUser, NoDaemon)
## Design notes
- The existing auto-title path (`maybeGenerateChatTitle`, `titleInput`)
is completely unchanged
- Manual regeneration uses richer context (first user turn + last 3
turns + gap markers) vs the auto path's single first message
- Endpoint is experimental and marked with `@x-apidocgen {"skip": true}`
This commit is contained in:
@@ -1234,6 +1234,7 @@ func New(options *Options) *API {
|
||||
r.Get("/git", api.watchChatGit)
|
||||
})
|
||||
r.Post("/interrupt", api.interruptChat)
|
||||
r.Post("/title/regenerate", api.regenerateChatTitle)
|
||||
r.Get("/diff", api.getChatDiffContents)
|
||||
r.Route("/queue/{queuedMessage}", func(r chi.Router) {
|
||||
r.Delete("/", api.deleteChatQueuedMessage)
|
||||
|
||||
@@ -2614,6 +2614,14 @@ func (q *querier) GetChatMessagesByChatID(ctx context.Context, arg database.GetC
|
||||
return q.db.GetChatMessagesByChatID(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatMessagesByChatIDAscPaginated(ctx context.Context, arg database.GetChatMessagesByChatIDAscPaginatedParams) ([]database.ChatMessage, error) {
|
||||
_, err := q.GetChatByID(ctx, arg.ChatID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.GetChatMessagesByChatIDAscPaginated(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) GetChatMessagesByChatIDDescPaginated(ctx context.Context, arg database.GetChatMessagesByChatIDDescPaginatedParams) ([]database.ChatMessage, error) {
|
||||
_, err := q.GetChatByID(ctx, arg.ChatID)
|
||||
if err != nil {
|
||||
@@ -5736,6 +5744,17 @@ func (q *querier) UpdateChatLabelsByID(ctx context.Context, arg database.UpdateC
|
||||
return q.db.UpdateChatLabelsByID(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateChatLastModelConfigByID(ctx context.Context, arg database.UpdateChatLastModelConfigByIDParams) (database.Chat, error) {
|
||||
chat, err := q.db.GetChatByID(ctx, arg.ID)
|
||||
if err != nil {
|
||||
return database.Chat{}, err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
|
||||
return database.Chat{}, err
|
||||
}
|
||||
return q.db.UpdateChatLastModelConfigByID(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateChatMCPServerIDs(ctx context.Context, arg database.UpdateChatMCPServerIDsParams) (database.Chat, error) {
|
||||
chat, err := q.db.GetChatByID(ctx, arg.ID)
|
||||
if err != nil {
|
||||
@@ -5801,6 +5820,17 @@ func (q *querier) UpdateChatStatus(ctx context.Context, arg database.UpdateChatS
|
||||
return q.db.UpdateChatStatus(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateChatStatusPreserveUpdatedAt(ctx context.Context, arg database.UpdateChatStatusPreserveUpdatedAtParams) (database.Chat, error) {
|
||||
chat, err := q.db.GetChatByID(ctx, arg.ID)
|
||||
if err != nil {
|
||||
return database.Chat{}, err
|
||||
}
|
||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, chat); err != nil {
|
||||
return database.Chat{}, err
|
||||
}
|
||||
return q.db.UpdateChatStatusPreserveUpdatedAt(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) UpdateChatWorkspaceBinding(ctx context.Context, arg database.UpdateChatWorkspaceBindingParams) (database.Chat, error) {
|
||||
chat, err := q.db.GetChatByID(ctx, arg.ID)
|
||||
if err != nil {
|
||||
|
||||
@@ -592,6 +592,14 @@ func (s *MethodTestSuite) TestChats() {
|
||||
dbm.EXPECT().GetChatMessagesByChatID(gomock.Any(), arg).Return(msgs, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(chat, policy.ActionRead).Returns(msgs)
|
||||
}))
|
||||
s.Run("GetChatMessagesByChatIDAscPaginated", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
msgs := []database.ChatMessage{testutil.Fake(s.T(), faker, database.ChatMessage{ChatID: chat.ID})}
|
||||
arg := database.GetChatMessagesByChatIDAscPaginatedParams{ChatID: chat.ID, AfterID: 0, LimitVal: 50}
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().GetChatMessagesByChatIDAscPaginated(gomock.Any(), arg).Return(msgs, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(chat, policy.ActionRead).Returns(msgs)
|
||||
}))
|
||||
s.Run("GetChatMessagesByChatIDDescPaginated", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
msgs := []database.ChatMessage{testutil.Fake(s.T(), faker, database.ChatMessage{ChatID: chat.ID})}
|
||||
@@ -789,6 +797,26 @@ func (s *MethodTestSuite) TestChats() {
|
||||
dbm.EXPECT().UpdateChatLabelsByID(gomock.Any(), arg).Return(chat, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(chat)
|
||||
}))
|
||||
s.Run("UpdateChatLastModelConfigByID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
arg := database.UpdateChatLastModelConfigByIDParams{
|
||||
ID: chat.ID,
|
||||
LastModelConfigID: uuid.New(),
|
||||
}
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().UpdateChatLastModelConfigByID(gomock.Any(), arg).Return(chat, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(chat)
|
||||
}))
|
||||
s.Run("UpdateChatStatusPreserveUpdatedAt", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
arg := database.UpdateChatStatusPreserveUpdatedAtParams{
|
||||
ID: chat.ID,
|
||||
Status: database.ChatStatusRunning,
|
||||
}
|
||||
dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes()
|
||||
dbm.EXPECT().UpdateChatStatusPreserveUpdatedAt(gomock.Any(), arg).Return(chat, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(chat, policy.ActionUpdate).Returns(chat)
|
||||
}))
|
||||
s.Run("UpdateChatHeartbeat", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
chat := testutil.Fake(s.T(), faker, database.Chat{})
|
||||
arg := database.UpdateChatHeartbeatParams{
|
||||
|
||||
@@ -1144,6 +1144,14 @@ func (m queryMetricsStore) GetChatMessagesByChatID(ctx context.Context, chatID d
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatMessagesByChatIDAscPaginated(ctx context.Context, arg database.GetChatMessagesByChatIDAscPaginatedParams) ([]database.ChatMessage, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatMessagesByChatIDAscPaginated(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("GetChatMessagesByChatIDAscPaginated").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetChatMessagesByChatIDAscPaginated").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetChatMessagesByChatIDDescPaginated(ctx context.Context, arg database.GetChatMessagesByChatIDDescPaginatedParams) ([]database.ChatMessage, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetChatMessagesByChatIDDescPaginated(ctx, arg)
|
||||
@@ -4096,6 +4104,14 @@ func (m queryMetricsStore) UpdateChatLabelsByID(ctx context.Context, arg databas
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpdateChatLastModelConfigByID(ctx context.Context, arg database.UpdateChatLastModelConfigByIDParams) (database.Chat, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpdateChatLastModelConfigByID(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpdateChatLastModelConfigByID").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateChatLastModelConfigByID").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpdateChatMCPServerIDs(ctx context.Context, arg database.UpdateChatMCPServerIDsParams) (database.Chat, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpdateChatMCPServerIDs(ctx, arg)
|
||||
@@ -4144,6 +4160,14 @@ func (m queryMetricsStore) UpdateChatStatus(ctx context.Context, arg database.Up
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpdateChatStatusPreserveUpdatedAt(ctx context.Context, arg database.UpdateChatStatusPreserveUpdatedAtParams) (database.Chat, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpdateChatStatusPreserveUpdatedAt(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("UpdateChatStatusPreserveUpdatedAt").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateChatStatusPreserveUpdatedAt").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) UpdateChatWorkspaceBinding(ctx context.Context, arg database.UpdateChatWorkspaceBindingParams) (database.Chat, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.UpdateChatWorkspaceBinding(ctx, arg)
|
||||
|
||||
@@ -2103,6 +2103,21 @@ func (mr *MockStoreMockRecorder) GetChatMessagesByChatID(ctx, arg any) *gomock.C
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatMessagesByChatID", reflect.TypeOf((*MockStore)(nil).GetChatMessagesByChatID), ctx, arg)
|
||||
}
|
||||
|
||||
// GetChatMessagesByChatIDAscPaginated mocks base method.
|
||||
func (m *MockStore) GetChatMessagesByChatIDAscPaginated(ctx context.Context, arg database.GetChatMessagesByChatIDAscPaginatedParams) ([]database.ChatMessage, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetChatMessagesByChatIDAscPaginated", ctx, arg)
|
||||
ret0, _ := ret[0].([]database.ChatMessage)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetChatMessagesByChatIDAscPaginated indicates an expected call of GetChatMessagesByChatIDAscPaginated.
|
||||
func (mr *MockStoreMockRecorder) GetChatMessagesByChatIDAscPaginated(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetChatMessagesByChatIDAscPaginated", reflect.TypeOf((*MockStore)(nil).GetChatMessagesByChatIDAscPaginated), ctx, arg)
|
||||
}
|
||||
|
||||
// GetChatMessagesByChatIDDescPaginated mocks base method.
|
||||
func (m *MockStore) GetChatMessagesByChatIDDescPaginated(ctx context.Context, arg database.GetChatMessagesByChatIDDescPaginatedParams) ([]database.ChatMessage, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -7745,6 +7760,21 @@ func (mr *MockStoreMockRecorder) UpdateChatLabelsByID(ctx, arg any) *gomock.Call
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatLabelsByID", reflect.TypeOf((*MockStore)(nil).UpdateChatLabelsByID), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateChatLastModelConfigByID mocks base method.
|
||||
func (m *MockStore) UpdateChatLastModelConfigByID(ctx context.Context, arg database.UpdateChatLastModelConfigByIDParams) (database.Chat, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdateChatLastModelConfigByID", ctx, arg)
|
||||
ret0, _ := ret[0].(database.Chat)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UpdateChatLastModelConfigByID indicates an expected call of UpdateChatLastModelConfigByID.
|
||||
func (mr *MockStoreMockRecorder) UpdateChatLastModelConfigByID(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatLastModelConfigByID", reflect.TypeOf((*MockStore)(nil).UpdateChatLastModelConfigByID), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateChatMCPServerIDs mocks base method.
|
||||
func (m *MockStore) UpdateChatMCPServerIDs(ctx context.Context, arg database.UpdateChatMCPServerIDsParams) (database.Chat, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -7834,6 +7864,21 @@ func (mr *MockStoreMockRecorder) UpdateChatStatus(ctx, arg any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatStatus", reflect.TypeOf((*MockStore)(nil).UpdateChatStatus), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateChatStatusPreserveUpdatedAt mocks base method.
|
||||
func (m *MockStore) UpdateChatStatusPreserveUpdatedAt(ctx context.Context, arg database.UpdateChatStatusPreserveUpdatedAtParams) (database.Chat, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "UpdateChatStatusPreserveUpdatedAt", ctx, arg)
|
||||
ret0, _ := ret[0].(database.Chat)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// UpdateChatStatusPreserveUpdatedAt indicates an expected call of UpdateChatStatusPreserveUpdatedAt.
|
||||
func (mr *MockStoreMockRecorder) UpdateChatStatusPreserveUpdatedAt(ctx, arg any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateChatStatusPreserveUpdatedAt", reflect.TypeOf((*MockStore)(nil).UpdateChatStatusPreserveUpdatedAt), ctx, arg)
|
||||
}
|
||||
|
||||
// UpdateChatWorkspaceBinding mocks base method.
|
||||
func (m *MockStore) UpdateChatWorkspaceBinding(ctx context.Context, arg database.UpdateChatWorkspaceBindingParams) (database.Chat, error) {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
@@ -250,6 +250,7 @@ type sqlcQuerier interface {
|
||||
GetChatIncludeDefaultSystemPrompt(ctx context.Context) (bool, error)
|
||||
GetChatMessageByID(ctx context.Context, id int64) (ChatMessage, error)
|
||||
GetChatMessagesByChatID(ctx context.Context, arg GetChatMessagesByChatIDParams) ([]ChatMessage, error)
|
||||
GetChatMessagesByChatIDAscPaginated(ctx context.Context, arg GetChatMessagesByChatIDAscPaginatedParams) ([]ChatMessage, error)
|
||||
GetChatMessagesByChatIDDescPaginated(ctx context.Context, arg GetChatMessagesByChatIDDescPaginatedParams) ([]ChatMessage, error)
|
||||
GetChatMessagesForPromptByChatID(ctx context.Context, chatID uuid.UUID) ([]ChatMessage, error)
|
||||
GetChatModelConfigByID(ctx context.Context, id uuid.UUID) (ChatModelConfig, error)
|
||||
@@ -852,12 +853,14 @@ type sqlcQuerier interface {
|
||||
// replicas know the worker is still alive.
|
||||
UpdateChatHeartbeat(ctx context.Context, arg UpdateChatHeartbeatParams) (int64, error)
|
||||
UpdateChatLabelsByID(ctx context.Context, arg UpdateChatLabelsByIDParams) (Chat, error)
|
||||
UpdateChatLastModelConfigByID(ctx context.Context, arg UpdateChatLastModelConfigByIDParams) (Chat, error)
|
||||
UpdateChatMCPServerIDs(ctx context.Context, arg UpdateChatMCPServerIDsParams) (Chat, error)
|
||||
UpdateChatMessageByID(ctx context.Context, arg UpdateChatMessageByIDParams) (ChatMessage, error)
|
||||
UpdateChatModelConfig(ctx context.Context, arg UpdateChatModelConfigParams) (ChatModelConfig, error)
|
||||
UpdateChatPinOrder(ctx context.Context, arg UpdateChatPinOrderParams) error
|
||||
UpdateChatProvider(ctx context.Context, arg UpdateChatProviderParams) (ChatProvider, error)
|
||||
UpdateChatStatus(ctx context.Context, arg UpdateChatStatusParams) (Chat, error)
|
||||
UpdateChatStatusPreserveUpdatedAt(ctx context.Context, arg UpdateChatStatusPreserveUpdatedAtParams) (Chat, error)
|
||||
UpdateChatWorkspaceBinding(ctx context.Context, arg UpdateChatWorkspaceBindingParams) (Chat, error)
|
||||
UpdateCryptoKeyDeletesAt(ctx context.Context, arg UpdateCryptoKeyDeletesAtParams) (CryptoKey, error)
|
||||
UpdateCustomRole(ctx context.Context, arg UpdateCustomRoleParams) (CustomRole, error)
|
||||
|
||||
@@ -4932,6 +4932,73 @@ func (q *sqlQuerier) GetChatMessagesByChatID(ctx context.Context, arg GetChatMes
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const getChatMessagesByChatIDAscPaginated = `-- name: GetChatMessagesByChatIDAscPaginated :many
|
||||
SELECT
|
||||
id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros, runtime_ms, deleted, provider_response_id
|
||||
FROM
|
||||
chat_messages
|
||||
WHERE
|
||||
chat_id = $1::uuid
|
||||
AND id > $2::bigint
|
||||
AND visibility IN ('user', 'both')
|
||||
AND deleted = false
|
||||
ORDER BY
|
||||
id ASC
|
||||
LIMIT
|
||||
COALESCE(NULLIF($3::int, 0), 50)
|
||||
`
|
||||
|
||||
type GetChatMessagesByChatIDAscPaginatedParams struct {
|
||||
ChatID uuid.UUID `db:"chat_id" json:"chat_id"`
|
||||
AfterID int64 `db:"after_id" json:"after_id"`
|
||||
LimitVal int32 `db:"limit_val" json:"limit_val"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) GetChatMessagesByChatIDAscPaginated(ctx context.Context, arg GetChatMessagesByChatIDAscPaginatedParams) ([]ChatMessage, error) {
|
||||
rows, err := q.db.QueryContext(ctx, getChatMessagesByChatIDAscPaginated, arg.ChatID, arg.AfterID, arg.LimitVal)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []ChatMessage
|
||||
for rows.Next() {
|
||||
var i ChatMessage
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.ChatID,
|
||||
&i.ModelConfigID,
|
||||
&i.CreatedAt,
|
||||
&i.Role,
|
||||
&i.Content,
|
||||
&i.Visibility,
|
||||
&i.InputTokens,
|
||||
&i.OutputTokens,
|
||||
&i.TotalTokens,
|
||||
&i.ReasoningTokens,
|
||||
&i.CacheCreationTokens,
|
||||
&i.CacheReadTokens,
|
||||
&i.ContextLimit,
|
||||
&i.Compressed,
|
||||
&i.CreatedBy,
|
||||
&i.ContentVersion,
|
||||
&i.TotalCostMicros,
|
||||
&i.RuntimeMs,
|
||||
&i.Deleted,
|
||||
&i.ProviderResponseID,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const getChatMessagesByChatIDDescPaginated = `-- name: GetChatMessagesByChatIDDescPaginated :many
|
||||
SELECT
|
||||
id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros, runtime_ms, deleted, provider_response_id
|
||||
@@ -6254,6 +6321,52 @@ func (q *sqlQuerier) UpdateChatLabelsByID(ctx context.Context, arg UpdateChatLab
|
||||
return i, err
|
||||
}
|
||||
|
||||
const updateChatLastModelConfigByID = `-- name: UpdateChatLastModelConfigByID :one
|
||||
UPDATE
|
||||
chats
|
||||
SET
|
||||
-- NOTE: updated_at is intentionally NOT touched here to avoid changing list ordering.
|
||||
last_model_config_id = $1::uuid
|
||||
WHERE
|
||||
id = $2::uuid
|
||||
RETURNING
|
||||
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order
|
||||
`
|
||||
|
||||
type UpdateChatLastModelConfigByIDParams struct {
|
||||
LastModelConfigID uuid.UUID `db:"last_model_config_id" json:"last_model_config_id"`
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) UpdateChatLastModelConfigByID(ctx context.Context, arg UpdateChatLastModelConfigByIDParams) (Chat, error) {
|
||||
row := q.db.QueryRowContext(ctx, updateChatLastModelConfigByID, arg.LastModelConfigID, arg.ID)
|
||||
var i Chat
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.OwnerID,
|
||||
&i.WorkspaceID,
|
||||
&i.Title,
|
||||
&i.Status,
|
||||
&i.WorkerID,
|
||||
&i.StartedAt,
|
||||
&i.HeartbeatAt,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.ParentChatID,
|
||||
&i.RootChatID,
|
||||
&i.LastModelConfigID,
|
||||
&i.Archived,
|
||||
&i.LastError,
|
||||
&i.Mode,
|
||||
pq.Array(&i.MCPServerIDs),
|
||||
&i.Labels,
|
||||
&i.BuildID,
|
||||
&i.AgentID,
|
||||
&i.PinOrder,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const updateChatMCPServerIDs = `-- name: UpdateChatMCPServerIDs :one
|
||||
UPDATE
|
||||
chats
|
||||
@@ -6479,6 +6592,69 @@ func (q *sqlQuerier) UpdateChatStatus(ctx context.Context, arg UpdateChatStatusP
|
||||
return i, err
|
||||
}
|
||||
|
||||
const updateChatStatusPreserveUpdatedAt = `-- name: UpdateChatStatusPreserveUpdatedAt :one
|
||||
UPDATE
|
||||
chats
|
||||
SET
|
||||
status = $1::chat_status,
|
||||
worker_id = $2::uuid,
|
||||
started_at = $3::timestamptz,
|
||||
heartbeat_at = $4::timestamptz,
|
||||
last_error = $5::text,
|
||||
updated_at = $6::timestamptz
|
||||
WHERE
|
||||
id = $7::uuid
|
||||
RETURNING
|
||||
id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error, mode, mcp_server_ids, labels, build_id, agent_id, pin_order
|
||||
`
|
||||
|
||||
type UpdateChatStatusPreserveUpdatedAtParams struct {
|
||||
Status ChatStatus `db:"status" json:"status"`
|
||||
WorkerID uuid.NullUUID `db:"worker_id" json:"worker_id"`
|
||||
StartedAt sql.NullTime `db:"started_at" json:"started_at"`
|
||||
HeartbeatAt sql.NullTime `db:"heartbeat_at" json:"heartbeat_at"`
|
||||
LastError sql.NullString `db:"last_error" json:"last_error"`
|
||||
UpdatedAt time.Time `db:"updated_at" json:"updated_at"`
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) UpdateChatStatusPreserveUpdatedAt(ctx context.Context, arg UpdateChatStatusPreserveUpdatedAtParams) (Chat, error) {
|
||||
row := q.db.QueryRowContext(ctx, updateChatStatusPreserveUpdatedAt,
|
||||
arg.Status,
|
||||
arg.WorkerID,
|
||||
arg.StartedAt,
|
||||
arg.HeartbeatAt,
|
||||
arg.LastError,
|
||||
arg.UpdatedAt,
|
||||
arg.ID,
|
||||
)
|
||||
var i Chat
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.OwnerID,
|
||||
&i.WorkspaceID,
|
||||
&i.Title,
|
||||
&i.Status,
|
||||
&i.WorkerID,
|
||||
&i.StartedAt,
|
||||
&i.HeartbeatAt,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
&i.ParentChatID,
|
||||
&i.RootChatID,
|
||||
&i.LastModelConfigID,
|
||||
&i.Archived,
|
||||
&i.LastError,
|
||||
&i.Mode,
|
||||
pq.Array(&i.MCPServerIDs),
|
||||
&i.Labels,
|
||||
&i.BuildID,
|
||||
&i.AgentID,
|
||||
&i.PinOrder,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const updateChatWorkspaceBinding = `-- name: UpdateChatWorkspaceBinding :one
|
||||
UPDATE chats SET
|
||||
workspace_id = $1::uuid,
|
||||
|
||||
@@ -220,6 +220,21 @@ WHERE
|
||||
ORDER BY
|
||||
created_at ASC;
|
||||
|
||||
-- name: GetChatMessagesByChatIDAscPaginated :many
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
chat_messages
|
||||
WHERE
|
||||
chat_id = @chat_id::uuid
|
||||
AND id > @after_id::bigint
|
||||
AND visibility IN ('user', 'both')
|
||||
AND deleted = false
|
||||
ORDER BY
|
||||
id ASC
|
||||
LIMIT
|
||||
COALESCE(NULLIF(@limit_val::int, 0), 50);
|
||||
|
||||
-- name: GetChatMessagesByChatIDDescPaginated :many
|
||||
SELECT
|
||||
*
|
||||
@@ -466,6 +481,17 @@ WHERE
|
||||
RETURNING
|
||||
*;
|
||||
|
||||
-- name: UpdateChatLastModelConfigByID :one
|
||||
UPDATE
|
||||
chats
|
||||
SET
|
||||
-- NOTE: updated_at is intentionally NOT touched here to avoid changing list ordering.
|
||||
last_model_config_id = @last_model_config_id::uuid
|
||||
WHERE
|
||||
id = @id::uuid
|
||||
RETURNING
|
||||
*;
|
||||
|
||||
-- name: UpdateChatLabelsByID :one
|
||||
UPDATE
|
||||
chats
|
||||
@@ -550,6 +576,21 @@ WHERE
|
||||
RETURNING
|
||||
*;
|
||||
|
||||
-- name: UpdateChatStatusPreserveUpdatedAt :one
|
||||
UPDATE
|
||||
chats
|
||||
SET
|
||||
status = @status::chat_status,
|
||||
worker_id = sqlc.narg('worker_id')::uuid,
|
||||
started_at = sqlc.narg('started_at')::timestamptz,
|
||||
heartbeat_at = sqlc.narg('heartbeat_at')::timestamptz,
|
||||
last_error = sqlc.narg('last_error')::text,
|
||||
updated_at = @updated_at::timestamptz
|
||||
WHERE
|
||||
id = @id::uuid
|
||||
RETURNING
|
||||
*;
|
||||
|
||||
-- name: GetStaleChats :many
|
||||
-- Find chats that appear stuck (running but heartbeat has expired).
|
||||
-- Used for recovery after coderd crashes or long hangs.
|
||||
|
||||
@@ -2105,6 +2105,50 @@ func (api *API) interruptChat(rw http.ResponseWriter, r *http.Request) {
|
||||
httpapi.Write(ctx, rw, http.StatusOK, db2sdk.Chat(chat, nil))
|
||||
}
|
||||
|
||||
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
||||
//
|
||||
//nolint:revive // HTTP handler writes to ResponseWriter.
|
||||
func (api *API) regenerateChatTitle(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
chat := httpmw.ChatParam(r)
|
||||
|
||||
if !api.Authorize(r, policy.ActionUpdate, chat.RBACObject()) {
|
||||
httpapi.ResourceNotFound(rw)
|
||||
return
|
||||
}
|
||||
if api.chatDaemon == nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Chat processor is unavailable.",
|
||||
Detail: "Chat processor is not configured.",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
updatedChat, err := api.chatDaemon.RegenerateChatTitle(ctx, chat)
|
||||
if err != nil {
|
||||
if errors.Is(err, chatd.ErrManualTitleRegenerationInProgress) {
|
||||
httpapi.Write(ctx, rw, http.StatusConflict, codersdk.Response{
|
||||
Message: "Title regeneration already in progress for this chat.",
|
||||
})
|
||||
return
|
||||
}
|
||||
if maybeWriteLimitErr(ctx, rw, err) {
|
||||
return
|
||||
}
|
||||
if httpapi.Is404Error(err) {
|
||||
httpapi.ResourceNotFound(rw)
|
||||
return
|
||||
}
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to regenerate chat title.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, db2sdk.Chat(updatedChat, nil))
|
||||
}
|
||||
|
||||
// EXPERIMENTAL: this endpoint is experimental and is subject to change.
|
||||
//
|
||||
//nolint:revive // HTTP handler writes to ResponseWriter.
|
||||
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
"github.com/google/uuid"
|
||||
"github.com/shopspring/decimal"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/coderdtest"
|
||||
"github.com/coder/coder/v2/coderd/coderdtest/oidctest"
|
||||
@@ -31,6 +32,7 @@ import (
|
||||
"github.com/coder/coder/v2/coderd/externalauth"
|
||||
coderdpubsub "github.com/coder/coder/v2/coderd/pubsub"
|
||||
"github.com/coder/coder/v2/coderd/rbac"
|
||||
"github.com/coder/coder/v2/coderd/rbac/policy"
|
||||
"github.com/coder/coder/v2/coderd/util/ptr"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatprompt"
|
||||
@@ -3441,6 +3443,268 @@ func TestInterruptChat(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestRegenerateChatTitle(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("ChatNotFound", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client := newChatClient(t)
|
||||
_ = coderdtest.CreateFirstUser(t, client.Client)
|
||||
|
||||
_, err := client.RegenerateChatTitle(ctx, uuid.New())
|
||||
requireSDKError(t, err, http.StatusNotFound)
|
||||
})
|
||||
|
||||
t.Run("UpdateDenied", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
clientRaw, db := coderdtest.NewWithDatabase(t, &coderdtest.Options{
|
||||
Authorizer: &coderdtest.FakeAuthorizer{
|
||||
ConditionalReturn: func(_ context.Context, _ rbac.Subject, action policy.Action, object rbac.Object) error {
|
||||
if action == policy.ActionUpdate && object.Type == rbac.ResourceChat.Type {
|
||||
return xerrors.New("denied")
|
||||
}
|
||||
return nil
|
||||
},
|
||||
},
|
||||
DeploymentValues: chatDeploymentValues(t),
|
||||
})
|
||||
client := codersdk.NewExperimentalClient(clientRaw)
|
||||
user := coderdtest.CreateFirstUser(t, client.Client)
|
||||
modelConfig := createChatModelConfig(t, client)
|
||||
|
||||
chat, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{
|
||||
OwnerID: user.UserID,
|
||||
LastModelConfigID: modelConfig.ID,
|
||||
Title: "chat with update denied",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = client.RegenerateChatTitle(ctx, chat.ID)
|
||||
requireSDKError(t, err, http.StatusNotFound)
|
||||
})
|
||||
|
||||
t.Run("NotFoundForDifferentUser", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client := newChatClient(t)
|
||||
firstUser := coderdtest.CreateFirstUser(t, client.Client)
|
||||
_ = createChatModelConfig(t, client)
|
||||
|
||||
createdChat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
Content: []codersdk.ChatInputPart{
|
||||
{
|
||||
Type: codersdk.ChatInputPartTypeText,
|
||||
Text: "private chat",
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
otherClientRaw, _ := coderdtest.CreateAnotherUser(t, client.Client, firstUser.OrganizationID)
|
||||
otherClient := codersdk.NewExperimentalClient(otherClientRaw)
|
||||
_, err = otherClient.RegenerateChatTitle(ctx, createdChat.ID)
|
||||
requireSDKError(t, err, http.StatusNotFound)
|
||||
})
|
||||
|
||||
t.Run("Unauthenticated", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client := newChatClient(t)
|
||||
_ = coderdtest.CreateFirstUser(t, client.Client)
|
||||
_ = createChatModelConfig(t, client)
|
||||
|
||||
chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
Content: []codersdk.ChatInputPart{{
|
||||
Type: codersdk.ChatInputPartTypeText,
|
||||
Text: "chat for unauthenticated regeneration",
|
||||
}},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
unauthenticatedClient := codersdk.NewExperimentalClient(codersdk.New(client.URL))
|
||||
_, err = unauthenticatedClient.RegenerateChatTitle(ctx, chat.ID)
|
||||
requireSDKError(t, err, http.StatusUnauthorized)
|
||||
})
|
||||
|
||||
t.Run("UsageLimitExceeded", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client, db := newChatClientWithDatabase(t)
|
||||
_ = coderdtest.CreateFirstUser(t, client.Client)
|
||||
modelConfig := createChatModelConfig(t, client)
|
||||
|
||||
chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
Content: []codersdk.ChatInputPart{{
|
||||
Type: codersdk.ChatInputPartTypeText,
|
||||
Text: "chat over usage limit",
|
||||
}},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
wantResetsAt := enableDailyChatUsageLimit(ctx, t, db, 100)
|
||||
insertAssistantCostMessage(ctx, t, db, chat.ID, modelConfig.ID, 100)
|
||||
|
||||
_, err = db.UpdateChatStatus(dbauthz.AsSystemRestricted(ctx), database.UpdateChatStatusParams{
|
||||
ID: chat.ID,
|
||||
Status: database.ChatStatusCompleted,
|
||||
WorkerID: uuid.NullUUID{},
|
||||
StartedAt: sql.NullTime{},
|
||||
HeartbeatAt: sql.NullTime{},
|
||||
LastError: sql.NullString{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = client.RegenerateChatTitle(ctx, chat.ID)
|
||||
limitErr := codersdk.ChatUsageLimitExceededFrom(err)
|
||||
require.NotNil(t, limitErr)
|
||||
require.Equal(t, "Chat usage limit exceeded.", limitErr.Message)
|
||||
require.Equal(t, int64(100), limitErr.SpentMicros)
|
||||
require.Equal(t, int64(100), limitErr.LimitMicros)
|
||||
require.True(
|
||||
t,
|
||||
limitErr.ResetsAt.Equal(wantResetsAt),
|
||||
"expected resets_at %s, got %s",
|
||||
wantResetsAt.UTC().Format(time.RFC3339),
|
||||
limitErr.ResetsAt.UTC().Format(time.RFC3339),
|
||||
)
|
||||
})
|
||||
|
||||
t.Run("AlreadyInProgress", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client, db := newChatClientWithDatabase(t)
|
||||
user := coderdtest.CreateFirstUser(t, client.Client)
|
||||
modelConfig := createChatModelConfig(t, client)
|
||||
|
||||
chat, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{
|
||||
OwnerID: user.UserID,
|
||||
LastModelConfigID: modelConfig.ID,
|
||||
Title: "chat with lock held",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = db.UpdateChatStatus(dbauthz.AsSystemRestricted(ctx), database.UpdateChatStatusParams{
|
||||
ID: chat.ID,
|
||||
Status: database.ChatStatusCompleted,
|
||||
WorkerID: uuid.NullUUID{UUID: uuid.MustParse("00000000-0000-0000-0000-000000000001"), Valid: true},
|
||||
StartedAt: sql.NullTime{Time: time.Now(), Valid: true},
|
||||
HeartbeatAt: sql.NullTime{Time: time.Now(), Valid: true},
|
||||
LastError: sql.NullString{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
res, err := client.Request(
|
||||
ctx,
|
||||
http.MethodPost,
|
||||
fmt.Sprintf("/api/experimental/chats/%s/title/regenerate", chat.ID),
|
||||
nil,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
defer res.Body.Close()
|
||||
require.Equal(t, http.StatusConflict, res.StatusCode)
|
||||
|
||||
var resp codersdk.Response
|
||||
require.NoError(t, json.NewDecoder(res.Body).Decode(&resp))
|
||||
require.Equal(t, "Title regeneration already in progress for this chat.", resp.Message)
|
||||
})
|
||||
|
||||
t.Run("PendingWithoutWorker", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client, db := newChatClientWithDatabase(t)
|
||||
user := coderdtest.CreateFirstUser(t, client.Client)
|
||||
modelConfig := createChatModelConfig(t, client)
|
||||
|
||||
chat, err := db.InsertChat(dbauthz.AsSystemRestricted(ctx), database.InsertChatParams{
|
||||
OwnerID: user.UserID,
|
||||
LastModelConfigID: modelConfig.ID,
|
||||
Title: "pending chat without worker",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
chat, err = db.UpdateChatStatus(dbauthz.AsSystemRestricted(ctx), database.UpdateChatStatusParams{
|
||||
ID: chat.ID,
|
||||
Status: database.ChatStatusPending,
|
||||
WorkerID: uuid.NullUUID{},
|
||||
StartedAt: sql.NullTime{},
|
||||
HeartbeatAt: sql.NullTime{},
|
||||
LastError: sql.NullString{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
before, err := db.GetChatByID(dbauthz.AsSystemRestricted(ctx), chat.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
res, err := client.Request(
|
||||
ctx,
|
||||
http.MethodPost,
|
||||
fmt.Sprintf("/api/experimental/chats/%s/title/regenerate", chat.ID),
|
||||
nil,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
defer res.Body.Close()
|
||||
require.Equal(t, http.StatusConflict, res.StatusCode)
|
||||
|
||||
var resp codersdk.Response
|
||||
require.NoError(t, json.NewDecoder(res.Body).Decode(&resp))
|
||||
require.Equal(t, "Title regeneration already in progress for this chat.", resp.Message)
|
||||
|
||||
persisted, err := db.GetChatByID(dbauthz.AsSystemRestricted(ctx), chat.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, database.ChatStatusPending, persisted.Status)
|
||||
require.False(t, persisted.WorkerID.Valid)
|
||||
require.True(t, persisted.UpdatedAt.Equal(before.UpdatedAt))
|
||||
})
|
||||
|
||||
t.Run("RegenerationFailure", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client, db := newChatClientWithDatabase(t)
|
||||
_ = coderdtest.CreateFirstUser(t, client.Client)
|
||||
_ = createChatModelConfig(t, client)
|
||||
|
||||
chat, err := client.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
Content: []codersdk.ChatInputPart{
|
||||
{
|
||||
Type: codersdk.ChatInputPartTypeText,
|
||||
Text: "test chat",
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = db.UpdateChatStatus(dbauthz.AsSystemRestricted(ctx), database.UpdateChatStatusParams{
|
||||
ID: chat.ID,
|
||||
Status: database.ChatStatusCompleted,
|
||||
WorkerID: uuid.NullUUID{},
|
||||
StartedAt: sql.NullTime{},
|
||||
HeartbeatAt: sql.NullTime{},
|
||||
LastError: sql.NullString{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
before, err := db.GetChatByID(dbauthz.AsSystemRestricted(ctx), chat.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = client.RegenerateChatTitle(ctx, chat.ID)
|
||||
requireSDKError(t, err, http.StatusInternalServerError)
|
||||
|
||||
after, err := db.GetChatByID(dbauthz.AsSystemRestricted(ctx), chat.ID)
|
||||
require.NoError(t, err)
|
||||
require.True(t, after.UpdatedAt.Equal(before.UpdatedAt))
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetChatDiffStatus(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
+499
-37
@@ -1333,6 +1333,477 @@ func (p *Server) InterruptChat(
|
||||
return updatedChat
|
||||
}
|
||||
|
||||
const manualTitleMessageWindowLimit = 50
|
||||
|
||||
var ErrManualTitleRegenerationInProgress = xerrors.New(
|
||||
"manual title regeneration already in progress",
|
||||
)
|
||||
|
||||
type manualTitleGenerationError struct {
|
||||
cause error
|
||||
modelConfig database.ChatModelConfig
|
||||
usage fantasy.Usage
|
||||
}
|
||||
|
||||
func (e *manualTitleGenerationError) Error() string {
|
||||
return e.cause.Error()
|
||||
}
|
||||
|
||||
func (e *manualTitleGenerationError) Unwrap() error {
|
||||
return e.cause
|
||||
}
|
||||
|
||||
var manualTitleLockWorkerID = uuid.MustParse(
|
||||
"00000000-0000-0000-0000-000000000001",
|
||||
)
|
||||
|
||||
const manualTitleLockStaleAfter = time.Minute
|
||||
|
||||
func isPendingOrRunningChatStatus(status database.ChatStatus) bool {
|
||||
switch status {
|
||||
case database.ChatStatusPending, database.ChatStatusRunning:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func isFreshManualTitleLock(chat database.Chat, now time.Time) bool {
|
||||
if !chat.WorkerID.Valid || chat.WorkerID.UUID != manualTitleLockWorkerID {
|
||||
return false
|
||||
}
|
||||
leaseAt := chat.HeartbeatAt
|
||||
if !leaseAt.Valid {
|
||||
leaseAt = chat.StartedAt
|
||||
}
|
||||
return leaseAt.Valid && leaseAt.Time.After(now.Add(-manualTitleLockStaleAfter))
|
||||
}
|
||||
|
||||
// updateChatStatusPreserveUpdatedAt applies internal lock transitions without
|
||||
// changing chat recency, because chat list ordering uses updated_at.
|
||||
func updateChatStatusPreserveUpdatedAt(
|
||||
ctx context.Context,
|
||||
store database.Store,
|
||||
chat database.Chat,
|
||||
workerID uuid.NullUUID,
|
||||
startedAt sql.NullTime,
|
||||
heartbeatAt sql.NullTime,
|
||||
) (database.Chat, error) {
|
||||
return store.UpdateChatStatusPreserveUpdatedAt(
|
||||
ctx,
|
||||
database.UpdateChatStatusPreserveUpdatedAtParams{
|
||||
ID: chat.ID,
|
||||
Status: chat.Status,
|
||||
WorkerID: workerID,
|
||||
StartedAt: startedAt,
|
||||
HeartbeatAt: heartbeatAt,
|
||||
LastError: chat.LastError,
|
||||
UpdatedAt: chat.UpdatedAt,
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
func (p *Server) acquireManualTitleLock(ctx context.Context, chatID uuid.UUID) error {
|
||||
now := time.Now()
|
||||
return p.db.InTx(func(tx database.Store) error {
|
||||
lockedChat, err := tx.GetChatByIDForUpdate(ctx, chatID)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("lock chat for manual title regeneration: %w", err)
|
||||
}
|
||||
if isPendingOrRunningChatStatus(lockedChat.Status) ||
|
||||
isFreshManualTitleLock(lockedChat, now) {
|
||||
return ErrManualTitleRegenerationInProgress
|
||||
}
|
||||
_, err = updateChatStatusPreserveUpdatedAt(
|
||||
ctx,
|
||||
tx,
|
||||
lockedChat,
|
||||
uuid.NullUUID{UUID: manualTitleLockWorkerID, Valid: true},
|
||||
sql.NullTime{Time: now, Valid: true},
|
||||
sql.NullTime{Time: now, Valid: true},
|
||||
)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("mark chat for manual title regeneration: %w", err)
|
||||
}
|
||||
return nil
|
||||
}, database.DefaultTXOptions().WithID("chat_title_regenerate_lock"))
|
||||
}
|
||||
|
||||
func (p *Server) releaseManualTitleLock(ctx context.Context, chatID uuid.UUID) {
|
||||
cleanupCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
err := p.db.InTx(func(tx database.Store) error {
|
||||
lockedChat, err := tx.GetChatByIDForUpdate(cleanupCtx, chatID)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("lock chat to release manual title regeneration: %w", err)
|
||||
}
|
||||
if !lockedChat.WorkerID.Valid || lockedChat.WorkerID.UUID != manualTitleLockWorkerID {
|
||||
return nil
|
||||
}
|
||||
_, err = updateChatStatusPreserveUpdatedAt(
|
||||
cleanupCtx,
|
||||
tx,
|
||||
lockedChat,
|
||||
uuid.NullUUID{},
|
||||
sql.NullTime{},
|
||||
sql.NullTime{},
|
||||
)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("clear manual title regeneration marker: %w", err)
|
||||
}
|
||||
return nil
|
||||
}, database.DefaultTXOptions().WithID("chat_title_regenerate_unlock"))
|
||||
if err != nil {
|
||||
p.logger.Warn(cleanupCtx, "failed to release manual title regeneration marker",
|
||||
slog.F("chat_id", chatID),
|
||||
slog.Error(err),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// RegenerateChatTitle regenerates a chat title from the chat's visible
|
||||
// messages, persists it when it changes, and broadcasts the update.
|
||||
func (p *Server) RegenerateChatTitle(
|
||||
ctx context.Context,
|
||||
chat database.Chat,
|
||||
) (database.Chat, error) {
|
||||
// Reuse chatd's scoped auth context for deployment-config lookups while
|
||||
// keeping chat ownership authorization at the HTTP layer.
|
||||
//nolint:gocritic // Non-admin users need chatd-scoped config reads here.
|
||||
chatdCtx := dbauthz.AsChatd(ctx)
|
||||
keys, err := p.resolveProviderAPIKeys(chatdCtx)
|
||||
if err != nil {
|
||||
return database.Chat{}, xerrors.Errorf("resolve chat providers: %w", err)
|
||||
}
|
||||
if err := p.acquireManualTitleLock(ctx, chat.ID); err != nil {
|
||||
return database.Chat{}, err
|
||||
}
|
||||
defer p.releaseManualTitleLock(chatdCtx, chat.ID)
|
||||
|
||||
updatedChat, err := p.regenerateChatTitleWithStore(
|
||||
chatdCtx,
|
||||
p.db,
|
||||
chat,
|
||||
keys,
|
||||
)
|
||||
if err != nil {
|
||||
var generationErr *manualTitleGenerationError
|
||||
if errors.As(err, &generationErr) {
|
||||
// Reuse chatd's scoped auth context for failure accounting while
|
||||
// detaching from request cancellation so usage is still recorded.
|
||||
//nolint:gocritic // Failure accounting still needs chatd-scoped config reads.
|
||||
recordCtx, recordCancel := context.WithTimeout(
|
||||
dbauthz.AsChatd(context.WithoutCancel(ctx)),
|
||||
5*time.Second,
|
||||
)
|
||||
defer recordCancel()
|
||||
if _, recordErr := recordManualTitleUsage(
|
||||
recordCtx,
|
||||
p.db,
|
||||
chat,
|
||||
generationErr.modelConfig,
|
||||
generationErr.usage,
|
||||
"",
|
||||
); recordErr != nil {
|
||||
return database.Chat{}, errors.Join(
|
||||
generationErr,
|
||||
xerrors.Errorf("record manual title usage: %w", recordErr),
|
||||
)
|
||||
}
|
||||
return database.Chat{}, generationErr
|
||||
}
|
||||
return database.Chat{}, err
|
||||
}
|
||||
return updatedChat, nil
|
||||
}
|
||||
|
||||
func (p *Server) regenerateChatTitleWithStore(
|
||||
ctx context.Context,
|
||||
store database.Store,
|
||||
chat database.Chat,
|
||||
keys chatprovider.ProviderAPIKeys,
|
||||
) (database.Chat, error) {
|
||||
if limitErr := p.checkUsageLimit(ctx, store, chat.OwnerID); limitErr != nil {
|
||||
return database.Chat{}, limitErr
|
||||
}
|
||||
|
||||
headMessages, err := store.GetChatMessagesByChatIDAscPaginated(
|
||||
ctx,
|
||||
database.GetChatMessagesByChatIDAscPaginatedParams{
|
||||
ChatID: chat.ID,
|
||||
AfterID: 0,
|
||||
LimitVal: manualTitleMessageWindowLimit,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return database.Chat{}, xerrors.Errorf("get head chat messages: %w", err)
|
||||
}
|
||||
tailMessages, err := store.GetChatMessagesByChatIDDescPaginated(
|
||||
ctx,
|
||||
database.GetChatMessagesByChatIDDescPaginatedParams{
|
||||
ChatID: chat.ID,
|
||||
BeforeID: 0,
|
||||
LimitVal: manualTitleMessageWindowLimit,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return database.Chat{}, xerrors.Errorf("get tail chat messages: %w", err)
|
||||
}
|
||||
messages := mergeManualTitleMessages(headMessages, tailMessages)
|
||||
if len(messages) == 0 {
|
||||
return chat, nil
|
||||
}
|
||||
|
||||
model, modelConfig, err := p.resolveManualTitleModel(ctx, store, chat, keys)
|
||||
if err != nil {
|
||||
return database.Chat{}, err
|
||||
}
|
||||
|
||||
title, usage, err := generateManualTitle(ctx, messages, model)
|
||||
if err != nil {
|
||||
wrappedErr := xerrors.Errorf("generate manual title: %w", err)
|
||||
if usage == (fantasy.Usage{}) {
|
||||
return database.Chat{}, wrappedErr
|
||||
}
|
||||
return database.Chat{}, &manualTitleGenerationError{
|
||||
cause: wrappedErr,
|
||||
modelConfig: modelConfig,
|
||||
usage: usage,
|
||||
}
|
||||
}
|
||||
|
||||
recordCtx, recordCancel := context.WithTimeout(context.WithoutCancel(ctx), 5*time.Second)
|
||||
defer recordCancel()
|
||||
|
||||
updatedChat, recordErr := recordManualTitleUsage(
|
||||
recordCtx,
|
||||
store,
|
||||
chat,
|
||||
modelConfig,
|
||||
usage,
|
||||
title,
|
||||
)
|
||||
if recordErr != nil {
|
||||
if title != "" {
|
||||
return database.Chat{}, xerrors.Errorf("record manual title usage and update chat title: %w", recordErr)
|
||||
}
|
||||
return database.Chat{}, xerrors.Errorf("record manual title usage: %w", recordErr)
|
||||
}
|
||||
if updatedChat.Title == chat.Title {
|
||||
return updatedChat, nil
|
||||
}
|
||||
|
||||
p.publishChatPubsubEvent(updatedChat, coderdpubsub.ChatEventKindTitleChange, nil)
|
||||
return updatedChat, nil
|
||||
}
|
||||
|
||||
func (p *Server) resolveManualTitleModel(
|
||||
ctx context.Context,
|
||||
store database.Store,
|
||||
chat database.Chat,
|
||||
keys chatprovider.ProviderAPIKeys,
|
||||
) (fantasy.LanguageModel, database.ChatModelConfig, error) {
|
||||
configs, err := store.GetEnabledChatModelConfigs(ctx)
|
||||
if err != nil {
|
||||
p.logger.Debug(ctx, "failed to list manual title model configs",
|
||||
slog.F("chat_id", chat.ID),
|
||||
slog.Error(err),
|
||||
)
|
||||
return p.resolveFallbackManualTitleModel(ctx, chat, keys)
|
||||
}
|
||||
|
||||
config, ok := selectPreferredConfiguredShortTextModelConfig(configs)
|
||||
if !ok {
|
||||
return p.resolveFallbackManualTitleModel(ctx, chat, keys)
|
||||
}
|
||||
|
||||
model, err := chatprovider.ModelFromConfig(
|
||||
config.Provider,
|
||||
config.Model,
|
||||
keys,
|
||||
chatprovider.UserAgent(),
|
||||
chatprovider.CoderHeaders(chat),
|
||||
)
|
||||
if err != nil {
|
||||
p.logger.Debug(ctx, "manual title preferred model unavailable",
|
||||
slog.F("chat_id", chat.ID),
|
||||
slog.F("provider", config.Provider),
|
||||
slog.F("model", config.Model),
|
||||
slog.Error(err),
|
||||
)
|
||||
return p.resolveFallbackManualTitleModel(ctx, chat, keys)
|
||||
}
|
||||
|
||||
return model, config, nil
|
||||
}
|
||||
|
||||
func (p *Server) resolveFallbackManualTitleModel(
|
||||
ctx context.Context,
|
||||
chat database.Chat,
|
||||
keys chatprovider.ProviderAPIKeys,
|
||||
) (fantasy.LanguageModel, database.ChatModelConfig, error) {
|
||||
config, err := p.resolveModelConfig(ctx, chat)
|
||||
if err != nil {
|
||||
return nil, database.ChatModelConfig{}, xerrors.Errorf(
|
||||
"resolve fallback manual title model config: %w",
|
||||
err,
|
||||
)
|
||||
}
|
||||
model, err := chatprovider.ModelFromConfig(
|
||||
config.Provider,
|
||||
config.Model,
|
||||
keys,
|
||||
chatprovider.UserAgent(),
|
||||
chatprovider.CoderHeaders(chat),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, database.ChatModelConfig{}, xerrors.Errorf(
|
||||
"create fallback manual title model: %w",
|
||||
err,
|
||||
)
|
||||
}
|
||||
return model, config, nil
|
||||
}
|
||||
|
||||
func mergeManualTitleMessages(
|
||||
headMessages []database.ChatMessage,
|
||||
tailMessagesDesc []database.ChatMessage,
|
||||
) []database.ChatMessage {
|
||||
merged := make([]database.ChatMessage, 0, len(headMessages)+len(tailMessagesDesc))
|
||||
seen := make(map[int64]struct{}, len(headMessages)+len(tailMessagesDesc))
|
||||
appendUnique := func(message database.ChatMessage) {
|
||||
if _, ok := seen[message.ID]; ok {
|
||||
return
|
||||
}
|
||||
seen[message.ID] = struct{}{}
|
||||
merged = append(merged, message)
|
||||
}
|
||||
for _, message := range headMessages {
|
||||
appendUnique(message)
|
||||
}
|
||||
for i := len(tailMessagesDesc) - 1; i >= 0; i-- {
|
||||
appendUnique(tailMessagesDesc[i])
|
||||
}
|
||||
return merged
|
||||
}
|
||||
|
||||
func fantasyUsageToChatMessageUsage(usage fantasy.Usage) codersdk.ChatMessageUsage {
|
||||
var chatUsage codersdk.ChatMessageUsage
|
||||
if usage.InputTokens != 0 {
|
||||
chatUsage.InputTokens = ptr.Ref(usage.InputTokens)
|
||||
}
|
||||
if usage.OutputTokens != 0 {
|
||||
chatUsage.OutputTokens = ptr.Ref(usage.OutputTokens)
|
||||
}
|
||||
if usage.ReasoningTokens != 0 {
|
||||
chatUsage.ReasoningTokens = ptr.Ref(usage.ReasoningTokens)
|
||||
}
|
||||
if usage.CacheCreationTokens != 0 {
|
||||
chatUsage.CacheCreationTokens = ptr.Ref(usage.CacheCreationTokens)
|
||||
}
|
||||
if usage.CacheReadTokens != 0 {
|
||||
chatUsage.CacheReadTokens = ptr.Ref(usage.CacheReadTokens)
|
||||
}
|
||||
return chatUsage
|
||||
}
|
||||
|
||||
func recordManualTitleUsage(
|
||||
ctx context.Context,
|
||||
store database.Store,
|
||||
chat database.Chat,
|
||||
modelConfig database.ChatModelConfig,
|
||||
usage fantasy.Usage,
|
||||
newTitle string,
|
||||
) (database.Chat, error) {
|
||||
hasUsage := usage != (fantasy.Usage{})
|
||||
if !hasUsage && newTitle == "" {
|
||||
return chat, nil
|
||||
}
|
||||
|
||||
var totalCostMicros *int64
|
||||
if hasUsage {
|
||||
callConfig := codersdk.ChatModelCallConfig{}
|
||||
if len(modelConfig.Options) > 0 {
|
||||
if err := json.Unmarshal(modelConfig.Options, &callConfig); err != nil {
|
||||
return database.Chat{}, xerrors.Errorf("parse model call config: %w", err)
|
||||
}
|
||||
}
|
||||
totalCostMicros = chatcost.CalculateTotalCostMicros(
|
||||
fantasyUsageToChatMessageUsage(usage),
|
||||
callConfig.Cost,
|
||||
)
|
||||
}
|
||||
|
||||
// Use a valid empty JSON array for the content column.
|
||||
// MarshalParts returns a null NullRawMessage for empty
|
||||
// slices, which becomes an empty string that PostgreSQL
|
||||
// rejects as invalid JSON.
|
||||
content := "[]"
|
||||
|
||||
updatedChat := chat
|
||||
err := store.InTx(func(tx database.Store) error {
|
||||
lockedChat, err := tx.GetChatByIDForUpdate(ctx, chat.ID)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("lock chat for manual title usage: %w", err)
|
||||
}
|
||||
updatedChat = lockedChat
|
||||
if hasUsage {
|
||||
messages, err := tx.InsertChatMessages(ctx, database.InsertChatMessagesParams{
|
||||
ChatID: chat.ID,
|
||||
CreatedBy: []uuid.UUID{chat.OwnerID},
|
||||
ModelConfigID: []uuid.UUID{modelConfig.ID},
|
||||
Role: []database.ChatMessageRole{database.ChatMessageRoleAssistant},
|
||||
Content: []string{content},
|
||||
ContentVersion: []int16{chatprompt.CurrentContentVersion},
|
||||
Visibility: []database.ChatMessageVisibility{database.ChatMessageVisibilityModel},
|
||||
InputTokens: []int64{usage.InputTokens},
|
||||
OutputTokens: []int64{usage.OutputTokens},
|
||||
TotalTokens: []int64{usage.TotalTokens},
|
||||
ReasoningTokens: []int64{usage.ReasoningTokens},
|
||||
CacheCreationTokens: []int64{usage.CacheCreationTokens},
|
||||
CacheReadTokens: []int64{usage.CacheReadTokens},
|
||||
ContextLimit: []int64{modelConfig.ContextLimit},
|
||||
Compressed: []bool{false},
|
||||
TotalCostMicros: []int64{ptr.NilToDefault(totalCostMicros, 0)},
|
||||
RuntimeMs: []int64{0},
|
||||
ProviderResponseID: []string{""},
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("insert manual title usage message: %w", err)
|
||||
}
|
||||
if len(messages) != 1 {
|
||||
return xerrors.Errorf("expected 1 manual title usage message, got %d", len(messages))
|
||||
}
|
||||
if err := tx.SoftDeleteChatMessageByID(ctx, messages[0].ID); err != nil {
|
||||
return xerrors.Errorf("soft delete manual title usage message: %w", err)
|
||||
}
|
||||
if lockedChat.LastModelConfigID != modelConfig.ID {
|
||||
if _, err := tx.UpdateChatLastModelConfigByID(ctx, database.UpdateChatLastModelConfigByIDParams{
|
||||
ID: chat.ID,
|
||||
LastModelConfigID: lockedChat.LastModelConfigID,
|
||||
}); err != nil {
|
||||
return xerrors.Errorf("restore chat model config after manual title usage: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
if newTitle != "" && lockedChat.Title == chat.Title && newTitle != lockedChat.Title {
|
||||
updatedChat, err = tx.UpdateChatByID(ctx, database.UpdateChatByIDParams{
|
||||
ID: chat.ID,
|
||||
Title: newTitle,
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("update chat title: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}, nil)
|
||||
if err != nil {
|
||||
return database.Chat{}, err
|
||||
}
|
||||
return updatedChat, nil
|
||||
}
|
||||
|
||||
// RefreshStatus loads the latest chat status and publishes it to stream subscribers.
|
||||
func (p *Server) RefreshStatus(ctx context.Context, chatID uuid.UUID) error {
|
||||
if chatID == uuid.Nil {
|
||||
@@ -3475,24 +3946,7 @@ func (p *Server) runChat(
|
||||
}
|
||||
|
||||
hasUsage := step.Usage != (fantasy.Usage{})
|
||||
var usageForCost codersdk.ChatMessageUsage
|
||||
if hasUsage {
|
||||
if step.Usage.InputTokens != 0 {
|
||||
usageForCost.InputTokens = ptr.Ref(step.Usage.InputTokens)
|
||||
}
|
||||
if step.Usage.OutputTokens != 0 {
|
||||
usageForCost.OutputTokens = ptr.Ref(step.Usage.OutputTokens)
|
||||
}
|
||||
if step.Usage.ReasoningTokens != 0 {
|
||||
usageForCost.ReasoningTokens = ptr.Ref(step.Usage.ReasoningTokens)
|
||||
}
|
||||
if step.Usage.CacheCreationTokens != 0 {
|
||||
usageForCost.CacheCreationTokens = ptr.Ref(step.Usage.CacheCreationTokens)
|
||||
}
|
||||
if step.Usage.CacheReadTokens != 0 {
|
||||
usageForCost.CacheReadTokens = ptr.Ref(step.Usage.CacheReadTokens)
|
||||
}
|
||||
}
|
||||
usageForCost := fantasyUsageToChatMessageUsage(step.Usage)
|
||||
totalCostMicros := chatcost.CalculateTotalCostMicros(usageForCost, callConfig.Cost)
|
||||
|
||||
var insertedMessages []database.ChatMessage
|
||||
@@ -4078,10 +4532,8 @@ func (p *Server) resolveChatModel(
|
||||
ctx context.Context,
|
||||
chat database.Chat,
|
||||
) (fantasy.LanguageModel, database.ChatModelConfig, chatprovider.ProviderAPIKeys, error) {
|
||||
var (
|
||||
dbConfig database.ChatModelConfig
|
||||
providers []database.ChatProvider
|
||||
)
|
||||
var dbConfig database.ChatModelConfig
|
||||
var keys chatprovider.ProviderAPIKeys
|
||||
|
||||
var g errgroup.Group
|
||||
g.Go(func() error {
|
||||
@@ -4094,28 +4546,15 @@ func (p *Server) resolveChatModel(
|
||||
})
|
||||
g.Go(func() error {
|
||||
var err error
|
||||
providers, err = p.configCache.EnabledProviders(ctx)
|
||||
keys, err = p.resolveProviderAPIKeys(ctx)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get enabled chat providers: %w", err)
|
||||
return xerrors.Errorf("resolve provider API keys: %w", err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err := g.Wait(); err != nil {
|
||||
return nil, database.ChatModelConfig{}, chatprovider.ProviderAPIKeys{}, err
|
||||
}
|
||||
dbProviders := make(
|
||||
[]chatprovider.ConfiguredProvider, 0, len(providers),
|
||||
)
|
||||
for _, provider := range providers {
|
||||
dbProviders = append(dbProviders, chatprovider.ConfiguredProvider{
|
||||
Provider: provider.Provider,
|
||||
APIKey: provider.APIKey,
|
||||
BaseURL: provider.BaseUrl,
|
||||
})
|
||||
}
|
||||
keys := chatprovider.MergeProviderAPIKeys(
|
||||
p.providerAPIKeys, dbProviders,
|
||||
)
|
||||
|
||||
model, err := chatprovider.ModelFromConfig(
|
||||
dbConfig.Provider, dbConfig.Model, keys, chatprovider.UserAgent(),
|
||||
@@ -4129,6 +4568,29 @@ func (p *Server) resolveChatModel(
|
||||
return model, dbConfig, keys, nil
|
||||
}
|
||||
|
||||
func (p *Server) resolveProviderAPIKeys(
|
||||
ctx context.Context,
|
||||
) (chatprovider.ProviderAPIKeys, error) {
|
||||
providers, err := p.configCache.EnabledProviders(ctx)
|
||||
if err != nil {
|
||||
return chatprovider.ProviderAPIKeys{}, xerrors.Errorf(
|
||||
"get enabled chat providers: %w",
|
||||
err,
|
||||
)
|
||||
}
|
||||
dbProviders := make(
|
||||
[]chatprovider.ConfiguredProvider, 0, len(providers),
|
||||
)
|
||||
for _, provider := range providers {
|
||||
dbProviders = append(dbProviders, chatprovider.ConfiguredProvider{
|
||||
Provider: provider.Provider,
|
||||
APIKey: provider.APIKey,
|
||||
BaseURL: provider.BaseUrl,
|
||||
})
|
||||
}
|
||||
return chatprovider.MergeProviderAPIKeys(p.providerAPIKeys, dbProviders), nil
|
||||
}
|
||||
|
||||
// resolveModelConfig looks up the chat's model config by its
|
||||
// LastModelConfigID. If the referenced config no longer exists
|
||||
// (e.g. it was deleted), it falls back to the default model
|
||||
|
||||
@@ -23,6 +23,7 @@ import (
|
||||
dbpubsub "github.com/coder/coder/v2/coderd/database/pubsub"
|
||||
coderdpubsub "github.com/coder/coder/v2/coderd/pubsub"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chaterror"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chattest"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk"
|
||||
"github.com/coder/coder/v2/codersdk/workspacesdk/agentconnmock"
|
||||
@@ -30,6 +31,183 @@ import (
|
||||
"github.com/coder/quartz"
|
||||
)
|
||||
|
||||
func TestRegenerateChatTitle_PersistsAndBroadcasts(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
lockTx := dbmock.NewMockStore(ctrl)
|
||||
usageTx := dbmock.NewMockStore(ctrl)
|
||||
unlockTx := dbmock.NewMockStore(ctrl)
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
pubsub := dbpubsub.NewInMemory()
|
||||
clock := quartz.NewReal()
|
||||
|
||||
ownerID := uuid.New()
|
||||
chatID := uuid.New()
|
||||
modelConfigID := uuid.New()
|
||||
userPrompt := "review pull request 23633 and fix review threads"
|
||||
wantTitle := "Review PR 23633"
|
||||
|
||||
chat := database.Chat{
|
||||
ID: chatID,
|
||||
OwnerID: ownerID,
|
||||
LastModelConfigID: modelConfigID,
|
||||
Status: database.ChatStatusCompleted,
|
||||
Title: fallbackChatTitle(userPrompt),
|
||||
}
|
||||
modelConfig := database.ChatModelConfig{
|
||||
ID: modelConfigID,
|
||||
Provider: "anthropic",
|
||||
Model: "claude-haiku-4-5",
|
||||
ContextLimit: 8192,
|
||||
}
|
||||
updatedChat := chat
|
||||
updatedChat.Title = wantTitle
|
||||
|
||||
messageEvents := make(chan struct {
|
||||
payload coderdpubsub.ChatEvent
|
||||
err error
|
||||
}, 1)
|
||||
cancelSub, err := pubsub.SubscribeWithErr(
|
||||
coderdpubsub.ChatEventChannel(ownerID),
|
||||
coderdpubsub.HandleChatEvent(func(_ context.Context, payload coderdpubsub.ChatEvent, err error) {
|
||||
messageEvents <- struct {
|
||||
payload coderdpubsub.ChatEvent
|
||||
err error
|
||||
}{payload: payload, err: err}
|
||||
}),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
defer cancelSub()
|
||||
|
||||
serverURL := chattest.NewAnthropic(t, func(req *chattest.AnthropicRequest) chattest.AnthropicResponse {
|
||||
require.Equal(t, "claude-haiku-4-5", req.Model)
|
||||
return chattest.AnthropicNonStreamingResponse(wantTitle)
|
||||
})
|
||||
|
||||
server := &Server{
|
||||
db: db,
|
||||
logger: logger,
|
||||
pubsub: pubsub,
|
||||
configCache: newChatConfigCache(context.Background(), db, clock),
|
||||
}
|
||||
|
||||
db.EXPECT().GetChatModelConfigByID(gomock.Any(), modelConfigID).Return(modelConfig, nil)
|
||||
db.EXPECT().GetEnabledChatProviders(gomock.Any()).Return([]database.ChatProvider{{
|
||||
Provider: "anthropic",
|
||||
APIKey: "test-key",
|
||||
BaseUrl: serverURL,
|
||||
}}, nil)
|
||||
db.EXPECT().GetChatUsageLimitConfig(gomock.Any()).Return(database.ChatUsageLimitConfig{}, sql.ErrNoRows)
|
||||
db.EXPECT().GetChatMessagesByChatIDAscPaginated(
|
||||
gomock.Any(),
|
||||
database.GetChatMessagesByChatIDAscPaginatedParams{
|
||||
ChatID: chatID,
|
||||
AfterID: 0,
|
||||
LimitVal: manualTitleMessageWindowLimit,
|
||||
},
|
||||
).Return([]database.ChatMessage{
|
||||
mustChatMessage(
|
||||
t,
|
||||
database.ChatMessageRoleUser,
|
||||
database.ChatMessageVisibilityBoth,
|
||||
codersdk.ChatMessageText(userPrompt),
|
||||
),
|
||||
mustChatMessage(
|
||||
t,
|
||||
database.ChatMessageRoleAssistant,
|
||||
database.ChatMessageVisibilityBoth,
|
||||
codersdk.ChatMessageText("checking the diff now"),
|
||||
),
|
||||
}, nil)
|
||||
db.EXPECT().GetChatMessagesByChatIDDescPaginated(
|
||||
gomock.Any(),
|
||||
database.GetChatMessagesByChatIDDescPaginatedParams{
|
||||
ChatID: chatID,
|
||||
BeforeID: 0,
|
||||
LimitVal: manualTitleMessageWindowLimit,
|
||||
},
|
||||
).Return(nil, nil)
|
||||
db.EXPECT().GetEnabledChatModelConfigs(gomock.Any()).Return(nil, nil)
|
||||
|
||||
gomock.InOrder(
|
||||
db.EXPECT().InTx(gomock.Any(), database.DefaultTXOptions().WithID("chat_title_regenerate_lock")).DoAndReturn(
|
||||
func(fn func(database.Store) error, opts *database.TxOptions) error {
|
||||
require.Equal(t, "chat_title_regenerate_lock", opts.TxIdentifier)
|
||||
return fn(lockTx)
|
||||
},
|
||||
),
|
||||
db.EXPECT().InTx(gomock.Any(), nil).DoAndReturn(
|
||||
func(fn func(database.Store) error, opts *database.TxOptions) error {
|
||||
require.Nil(t, opts)
|
||||
return fn(usageTx)
|
||||
},
|
||||
),
|
||||
db.EXPECT().InTx(gomock.Any(), database.DefaultTXOptions().WithID("chat_title_regenerate_unlock")).DoAndReturn(
|
||||
func(fn func(database.Store) error, opts *database.TxOptions) error {
|
||||
require.Equal(t, "chat_title_regenerate_unlock", opts.TxIdentifier)
|
||||
return fn(unlockTx)
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
lockTx.EXPECT().GetChatByIDForUpdate(gomock.Any(), chatID).Return(chat, nil)
|
||||
lockTx.EXPECT().UpdateChatStatusPreserveUpdatedAt(gomock.Any(), gomock.AssignableToTypeOf(database.UpdateChatStatusPreserveUpdatedAtParams{})).DoAndReturn(
|
||||
func(_ context.Context, arg database.UpdateChatStatusPreserveUpdatedAtParams) (database.Chat, error) {
|
||||
require.Equal(t, chatID, arg.ID)
|
||||
require.Equal(t, chat.Status, arg.Status)
|
||||
require.Equal(t, uuid.NullUUID{UUID: manualTitleLockWorkerID, Valid: true}, arg.WorkerID)
|
||||
require.True(t, arg.StartedAt.Valid)
|
||||
require.True(t, arg.HeartbeatAt.Valid)
|
||||
return chat, nil
|
||||
},
|
||||
)
|
||||
|
||||
usageTx.EXPECT().GetChatByIDForUpdate(gomock.Any(), chatID).Return(chat, nil)
|
||||
usageTx.EXPECT().InsertChatMessages(gomock.Any(), gomock.AssignableToTypeOf(database.InsertChatMessagesParams{})).DoAndReturn(
|
||||
func(_ context.Context, arg database.InsertChatMessagesParams) ([]database.ChatMessage, error) {
|
||||
require.Equal(t, []uuid.UUID{ownerID}, arg.CreatedBy)
|
||||
require.Equal(t, []uuid.UUID{modelConfigID}, arg.ModelConfigID)
|
||||
require.Equal(t, []string{"[]"}, arg.Content)
|
||||
return []database.ChatMessage{{ID: 91}}, nil
|
||||
},
|
||||
)
|
||||
usageTx.EXPECT().SoftDeleteChatMessageByID(gomock.Any(), int64(91)).Return(nil)
|
||||
usageTx.EXPECT().UpdateChatByID(gomock.Any(), database.UpdateChatByIDParams{
|
||||
ID: chatID,
|
||||
Title: wantTitle,
|
||||
}).Return(updatedChat, nil)
|
||||
|
||||
lockedChatWithMarker := updatedChat
|
||||
lockedChatWithMarker.WorkerID = uuid.NullUUID{UUID: manualTitleLockWorkerID, Valid: true}
|
||||
unlockTx.EXPECT().GetChatByIDForUpdate(gomock.Any(), chatID).Return(lockedChatWithMarker, nil)
|
||||
unlockTx.EXPECT().UpdateChatStatusPreserveUpdatedAt(gomock.Any(), gomock.AssignableToTypeOf(database.UpdateChatStatusPreserveUpdatedAtParams{})).DoAndReturn(
|
||||
func(_ context.Context, arg database.UpdateChatStatusPreserveUpdatedAtParams) (database.Chat, error) {
|
||||
require.Equal(t, chatID, arg.ID)
|
||||
require.False(t, arg.WorkerID.Valid)
|
||||
require.False(t, arg.StartedAt.Valid)
|
||||
require.False(t, arg.HeartbeatAt.Valid)
|
||||
return updatedChat, nil
|
||||
},
|
||||
)
|
||||
|
||||
gotChat, err := server.RegenerateChatTitle(ctx, chat)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, updatedChat, gotChat)
|
||||
|
||||
select {
|
||||
case event := <-messageEvents:
|
||||
require.NoError(t, event.err)
|
||||
require.Equal(t, coderdpubsub.ChatEventKindTitleChange, event.payload.Kind)
|
||||
require.Equal(t, chatID, event.payload.Chat.ID)
|
||||
require.Equal(t, wantTitle, event.payload.Chat.Title)
|
||||
case <-time.After(time.Second):
|
||||
t.Fatal("timed out waiting for title change pubsub event")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRefreshChatWorkspaceSnapshot_NoReloadWhenWorkspacePresent(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
+229
-12
@@ -2,6 +2,8 @@ package chatd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -36,6 +38,17 @@ const titleGenerationPrompt = "You are a title generator. Your ONLY job is to ou
|
||||
"Output ONLY the title — no quotes, no emoji, no markdown, no code fences, " +
|
||||
"no trailing punctuation, no preamble, no explanation. Sentence case."
|
||||
|
||||
const (
|
||||
// maxConversationContextRunes caps the conversation sample in manual
|
||||
// title prompts to avoid exceeding model context windows.
|
||||
maxConversationContextRunes = 6000
|
||||
// maxLatestUserMessageRunes caps the latest user message excerpt.
|
||||
maxLatestUserMessageRunes = 1000
|
||||
// recentTurnWindow is the number of most recent turns included
|
||||
// alongside the first user turn in manual title context.
|
||||
recentTurnWindow = 3
|
||||
)
|
||||
|
||||
// preferredTitleModels are lightweight models used for title
|
||||
// generation, one per provider type. Each entry uses the
|
||||
// cheapest/fastest small model for that provider as identified
|
||||
@@ -54,6 +67,33 @@ var preferredTitleModels = []struct {
|
||||
{fantasyvercel.Name, "anthropic/claude-haiku-4.5"},
|
||||
}
|
||||
|
||||
func selectPreferredConfiguredShortTextModelConfig(
|
||||
configs []database.ChatModelConfig,
|
||||
) (database.ChatModelConfig, bool) {
|
||||
for _, preferred := range preferredTitleModels {
|
||||
for _, config := range configs {
|
||||
if chatprovider.NormalizeProvider(config.Provider) != preferred.provider {
|
||||
continue
|
||||
}
|
||||
if !strings.EqualFold(strings.TrimSpace(config.Model), preferred.model) {
|
||||
continue
|
||||
}
|
||||
return config, true
|
||||
}
|
||||
}
|
||||
return database.ChatModelConfig{}, false
|
||||
}
|
||||
|
||||
func normalizeShortTextOutput(text string) string {
|
||||
text = strings.TrimSpace(text)
|
||||
if text == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
text = strings.Trim(text, "\"'`")
|
||||
return strings.Join(strings.Fields(text), " ")
|
||||
}
|
||||
|
||||
// maybeGenerateChatTitle generates an AI title for the chat when
|
||||
// appropriate (first user message, no assistant reply yet, and the
|
||||
// current title is either empty or still the fallback truncation).
|
||||
@@ -138,7 +178,7 @@ func generateTitle(
|
||||
model fantasy.LanguageModel,
|
||||
input string,
|
||||
) (string, error) {
|
||||
title, err := generateShortText(ctx, model, titleGenerationPrompt, input)
|
||||
title, _, err := generateShortText(ctx, model, titleGenerationPrompt, input)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -199,13 +239,10 @@ func titleInput(
|
||||
}
|
||||
|
||||
func normalizeTitleOutput(title string) string {
|
||||
title = strings.TrimSpace(title)
|
||||
title = normalizeShortTextOutput(title)
|
||||
if title == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
title = strings.Trim(title, "\"'`")
|
||||
title = strings.Join(strings.Fields(title), " ")
|
||||
return truncateRunes(title, 80)
|
||||
}
|
||||
|
||||
@@ -226,7 +263,7 @@ func fallbackChatTitle(message string) string {
|
||||
|
||||
title := strings.Join(words, " ")
|
||||
if truncated {
|
||||
title += "…"
|
||||
return truncateRunes(title, maxRunes-1) + "…"
|
||||
}
|
||||
|
||||
return truncateRunes(title, maxRunes)
|
||||
@@ -260,6 +297,186 @@ func truncateRunes(value string, maxLen int) string {
|
||||
return string(runes[:maxLen])
|
||||
}
|
||||
|
||||
// Manual title regeneration is user-initiated and can use richer
|
||||
// conversation context than the automatic first-message title path
|
||||
// above. These helpers keep the manual prompt-building logic private
|
||||
// while reusing the shared title-generation utilities in this file.
|
||||
type manualTitleTurn struct {
|
||||
role string
|
||||
text string
|
||||
}
|
||||
|
||||
func extractManualTitleTurns(messages []database.ChatMessage) []manualTitleTurn {
|
||||
turns := make([]manualTitleTurn, 0, len(messages))
|
||||
for _, message := range messages {
|
||||
if message.Visibility == database.ChatMessageVisibilityModel {
|
||||
continue
|
||||
}
|
||||
|
||||
role := ""
|
||||
switch message.Role {
|
||||
case database.ChatMessageRoleUser:
|
||||
role = string(database.ChatMessageRoleUser)
|
||||
case database.ChatMessageRoleAssistant:
|
||||
role = string(database.ChatMessageRoleAssistant)
|
||||
default:
|
||||
continue
|
||||
}
|
||||
|
||||
parts, err := chatprompt.ParseContent(message)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
text := strings.TrimSpace(contentBlocksToText(parts))
|
||||
if text == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
turns = append(turns, manualTitleTurn{
|
||||
role: role,
|
||||
text: text,
|
||||
})
|
||||
}
|
||||
|
||||
return turns
|
||||
}
|
||||
|
||||
func selectManualTitleTurnIndexes(turns []manualTitleTurn) []int {
|
||||
firstUserIndex := slices.IndexFunc(turns, func(turn manualTitleTurn) bool {
|
||||
return turn.role == string(database.ChatMessageRoleUser)
|
||||
})
|
||||
if firstUserIndex == -1 {
|
||||
return nil
|
||||
}
|
||||
|
||||
windowStart := max(0, len(turns)-recentTurnWindow)
|
||||
selected := make([]int, 0, recentTurnWindow+1)
|
||||
if firstUserIndex < windowStart {
|
||||
selected = append(selected, firstUserIndex)
|
||||
}
|
||||
for i := windowStart; i < len(turns); i++ {
|
||||
selected = append(selected, i)
|
||||
}
|
||||
|
||||
return selected
|
||||
}
|
||||
|
||||
func buildManualTitleContext(
|
||||
turns []manualTitleTurn,
|
||||
selected []int,
|
||||
) (conversationBlock string, latestUserMsg string) {
|
||||
userCount := 0
|
||||
for _, turn := range turns {
|
||||
if turn.role != string(database.ChatMessageRoleUser) {
|
||||
continue
|
||||
}
|
||||
userCount++
|
||||
latestUserMsg = turn.text
|
||||
}
|
||||
|
||||
latestUserMsg = truncateRunes(latestUserMsg, maxLatestUserMessageRunes)
|
||||
if userCount <= 1 || len(selected) == 0 {
|
||||
return "", latestUserMsg
|
||||
}
|
||||
|
||||
lines := make([]string, 0, len(selected)+1)
|
||||
for i, idx := range selected {
|
||||
if i == 1 {
|
||||
if gap := idx - selected[i-1] - 1; gap > 0 {
|
||||
lines = append(lines, fmt.Sprintf("[... %d earlier turns omitted ...]", gap))
|
||||
}
|
||||
}
|
||||
lines = append(lines, fmt.Sprintf("[%s]: %s", turns[idx].role, turns[idx].text))
|
||||
}
|
||||
|
||||
conversationBlock = strings.Join(lines, "\n")
|
||||
conversationBlock = truncateRunes(conversationBlock, maxConversationContextRunes)
|
||||
return conversationBlock, latestUserMsg
|
||||
}
|
||||
|
||||
func renderManualTitlePrompt(
|
||||
conversationBlock string,
|
||||
firstUserText string,
|
||||
latestUserMsg string,
|
||||
) string {
|
||||
var prompt strings.Builder
|
||||
write := func(value string) {
|
||||
_, _ = prompt.WriteString(value)
|
||||
}
|
||||
|
||||
write("You are a title generator for an AI coding assistant conversation.\n\n")
|
||||
write("The user's primary objective was:\n<primary_objective>\n")
|
||||
write(firstUserText)
|
||||
write("\n</primary_objective>")
|
||||
|
||||
if conversationBlock != "" {
|
||||
write("\n\nConversation sample:\n<conversation_sample>\n")
|
||||
write(conversationBlock)
|
||||
write("\n</conversation_sample>")
|
||||
}
|
||||
|
||||
if strings.TrimSpace(latestUserMsg) != strings.TrimSpace(truncateRunes(firstUserText, maxLatestUserMessageRunes)) {
|
||||
write("\n\nThe user's most recent message:\n<latest_message>\n")
|
||||
write(latestUserMsg)
|
||||
write("\n</latest_message>\n")
|
||||
write("Note: Weight the overall conversation arc more heavily than just the latest message.")
|
||||
}
|
||||
|
||||
write("\n\nRequirements:\n")
|
||||
write("- Output a short title of 2-8 words.\n")
|
||||
write("- Use verb-noun format in sentence case.\n")
|
||||
write("- Preserve specific identifiers (PR numbers, repo names, file paths, function names, error messages).\n")
|
||||
write("- No trailing punctuation, quotes, emoji, or markdown.\n")
|
||||
write("- No temporal phrasing (\"Continue\", \"Follow up on\") or meta phrasing (\"Chat about\").\n")
|
||||
write("- Output ONLY the title - nothing else.\n")
|
||||
return prompt.String()
|
||||
}
|
||||
|
||||
func generateManualTitle(
|
||||
ctx context.Context,
|
||||
messages []database.ChatMessage,
|
||||
fallbackModel fantasy.LanguageModel,
|
||||
) (string, fantasy.Usage, error) {
|
||||
turns := extractManualTitleTurns(messages)
|
||||
selected := selectManualTitleTurnIndexes(turns)
|
||||
|
||||
firstUserIndex := slices.IndexFunc(turns, func(turn manualTitleTurn) bool {
|
||||
return turn.role == string(database.ChatMessageRoleUser)
|
||||
})
|
||||
if firstUserIndex == -1 {
|
||||
return "", fantasy.Usage{}, nil
|
||||
}
|
||||
firstUserText := truncateRunes(turns[firstUserIndex].text, maxLatestUserMessageRunes)
|
||||
|
||||
conversationBlock, latestUserMsg := buildManualTitleContext(turns, selected)
|
||||
systemPrompt := renderManualTitlePrompt(
|
||||
conversationBlock,
|
||||
firstUserText,
|
||||
latestUserMsg,
|
||||
)
|
||||
|
||||
titleCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
title, usage, err := generateShortText(
|
||||
titleCtx,
|
||||
fallbackModel,
|
||||
systemPrompt,
|
||||
"Generate the title.",
|
||||
)
|
||||
if err != nil {
|
||||
return "", fantasy.Usage{}, err
|
||||
}
|
||||
|
||||
title = normalizeTitleOutput(title)
|
||||
if title == "" {
|
||||
return "", usage, xerrors.New("generated title was empty")
|
||||
}
|
||||
|
||||
return title, usage, nil
|
||||
}
|
||||
|
||||
const pushSummaryPrompt = "You are a notification assistant. Given a chat title " +
|
||||
"and the agent's last message, write a single short sentence (under 100 characters) " +
|
||||
"summarizing what the agent did. This will be shown as a push notification body. " +
|
||||
@@ -281,6 +498,7 @@ func generatePushSummary(
|
||||
summaryCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
assistantText = truncateRunes(assistantText, maxConversationContextRunes)
|
||||
input := "Chat title: " + chat.Title + "\n\nAgent's last message:\n" + assistantText
|
||||
|
||||
candidates := make([]fantasy.LanguageModel, 0, len(preferredTitleModels)+1)
|
||||
@@ -296,7 +514,7 @@ func generatePushSummary(
|
||||
candidates = append(candidates, fallbackModel)
|
||||
|
||||
for _, model := range candidates {
|
||||
summary, err := generateShortText(summaryCtx, model, pushSummaryPrompt, input)
|
||||
summary, _, err := generateShortText(summaryCtx, model, pushSummaryPrompt, input)
|
||||
if err != nil {
|
||||
logger.Debug(ctx, "push summary model candidate failed",
|
||||
slog.Error(err),
|
||||
@@ -318,7 +536,7 @@ func generateShortText(
|
||||
model fantasy.LanguageModel,
|
||||
systemPrompt string,
|
||||
userInput string,
|
||||
) (string, error) {
|
||||
) (string, fantasy.Usage, error) {
|
||||
prompt := []fantasy.Message{
|
||||
{
|
||||
Role: fantasy.MessageRoleSystem,
|
||||
@@ -346,7 +564,7 @@ func generateShortText(
|
||||
return genErr
|
||||
}, nil)
|
||||
if err != nil {
|
||||
return "", xerrors.Errorf("generate short text: %w", err)
|
||||
return "", fantasy.Usage{}, xerrors.Errorf("generate short text: %w", err)
|
||||
}
|
||||
|
||||
responseParts := make([]codersdk.ChatMessagePart, 0, len(response.Content))
|
||||
@@ -355,7 +573,6 @@ func generateShortText(
|
||||
responseParts = append(responseParts, p)
|
||||
}
|
||||
}
|
||||
text := strings.TrimSpace(contentBlocksToText(responseParts))
|
||||
text = strings.Trim(text, "\"'`")
|
||||
return text, nil
|
||||
text := normalizeShortTextOutput(contentBlocksToText(responseParts))
|
||||
return text, response.Usage, nil
|
||||
}
|
||||
|
||||
@@ -0,0 +1,584 @@
|
||||
package chatd //nolint:testpackage // Keeps internal helper tests in-package.
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"github.com/sqlc-dev/pqtype"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/codersdk"
|
||||
)
|
||||
|
||||
func Test_extractManualTitleTurns(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
messages []database.ChatMessage
|
||||
want []manualTitleTurn
|
||||
}{
|
||||
{
|
||||
name: "filters to visible user and assistant text turns",
|
||||
messages: []database.ChatMessage{
|
||||
mustChatMessage(t, database.ChatMessageRoleUser, database.ChatMessageVisibilityBoth,
|
||||
codersdk.ChatMessagePart{Type: codersdk.ChatMessagePartTypeText, Text: " review quickgen helpers "},
|
||||
),
|
||||
mustChatMessage(t, database.ChatMessageRoleAssistant, database.ChatMessageVisibilityBoth,
|
||||
codersdk.ChatMessagePart{Type: codersdk.ChatMessagePartTypeText, Text: " drafted a plan "},
|
||||
),
|
||||
mustChatMessage(t, database.ChatMessageRoleSystem, database.ChatMessageVisibilityBoth,
|
||||
codersdk.ChatMessagePart{Type: codersdk.ChatMessagePartTypeText, Text: "system prompt"},
|
||||
),
|
||||
mustChatMessage(t, database.ChatMessageRoleTool, database.ChatMessageVisibilityBoth,
|
||||
codersdk.ChatMessagePart{Type: codersdk.ChatMessagePartTypeText, Text: "tool output"},
|
||||
),
|
||||
mustChatMessage(t, database.ChatMessageRoleUser, database.ChatMessageVisibilityModel,
|
||||
codersdk.ChatMessagePart{Type: codersdk.ChatMessagePartTypeText, Text: "hidden model note"},
|
||||
),
|
||||
mustChatMessage(t, database.ChatMessageRoleUser, database.ChatMessageVisibilityBoth,
|
||||
codersdk.ChatMessagePart{Type: codersdk.ChatMessagePartTypeText, Text: " "},
|
||||
),
|
||||
mustChatMessage(t, database.ChatMessageRoleAssistant, database.ChatMessageVisibilityBoth,
|
||||
codersdk.ChatMessagePart{Type: codersdk.ChatMessagePartTypeReasoning, Text: "reasoning only"},
|
||||
),
|
||||
mustChatMessage(t, database.ChatMessageRoleUser, database.ChatMessageVisibilityBoth,
|
||||
codersdk.ChatMessagePart{Type: codersdk.ChatMessagePartTypeFile, MediaType: "text/plain"},
|
||||
),
|
||||
},
|
||||
want: []manualTitleTurn{
|
||||
{role: "user", text: "review quickgen helpers"},
|
||||
{role: "assistant", text: "drafted a plan"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "reuses text extraction for multi-part content",
|
||||
messages: []database.ChatMessage{
|
||||
mustChatMessage(t, database.ChatMessageRoleUser, database.ChatMessageVisibilityBoth,
|
||||
codersdk.ChatMessagePart{Type: codersdk.ChatMessagePartTypeText, Text: "first chunk"},
|
||||
codersdk.ChatMessagePart{Type: codersdk.ChatMessagePartTypeReasoning, Text: "skip me"},
|
||||
codersdk.ChatMessagePart{Type: codersdk.ChatMessagePartTypeText, Text: " second chunk "},
|
||||
),
|
||||
},
|
||||
want: []manualTitleTurn{{role: "user", text: "first chunk second chunk"}},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got := extractManualTitleTurns(tt.messages)
|
||||
require.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_selectManualTitleTurnIndexes(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
turns []manualTitleTurn
|
||||
want []int
|
||||
}{
|
||||
{
|
||||
name: "single user turn",
|
||||
turns: []manualTitleTurn{
|
||||
{role: "user", text: "one"},
|
||||
},
|
||||
want: []int{0},
|
||||
},
|
||||
{
|
||||
name: "first user plus trailing window",
|
||||
turns: []manualTitleTurn{
|
||||
{role: "user", text: "one"},
|
||||
{role: "assistant", text: "two"},
|
||||
{role: "user", text: "three"},
|
||||
{role: "assistant", text: "four"},
|
||||
{role: "user", text: "five"},
|
||||
},
|
||||
want: []int{0, 2, 3, 4},
|
||||
},
|
||||
{
|
||||
name: "two turns returns both",
|
||||
turns: []manualTitleTurn{
|
||||
{role: "user", text: "one"},
|
||||
{role: "assistant", text: "two"},
|
||||
},
|
||||
want: []int{0, 1},
|
||||
},
|
||||
{
|
||||
name: "prepends first user when before trailing window",
|
||||
turns: []manualTitleTurn{
|
||||
{role: "assistant", text: "intro"},
|
||||
{role: "assistant", text: "setup"},
|
||||
{role: "user", text: "goal"},
|
||||
{role: "assistant", text: "a"},
|
||||
{role: "assistant", text: "b"},
|
||||
{role: "assistant", text: "c"},
|
||||
},
|
||||
want: []int{2, 3, 4, 5},
|
||||
},
|
||||
{
|
||||
name: "ten plus turns keeps first user and last three",
|
||||
turns: []manualTitleTurn{
|
||||
{role: "assistant", text: "0"},
|
||||
{role: "assistant", text: "1"},
|
||||
{role: "user", text: "2"},
|
||||
{role: "assistant", text: "3"},
|
||||
{role: "assistant", text: "4"},
|
||||
{role: "assistant", text: "5"},
|
||||
{role: "assistant", text: "6"},
|
||||
{role: "assistant", text: "7"},
|
||||
{role: "assistant", text: "8"},
|
||||
{role: "user", text: "9"},
|
||||
{role: "assistant", text: "10"},
|
||||
{role: "user", text: "11"},
|
||||
},
|
||||
want: []int{2, 9, 10, 11},
|
||||
},
|
||||
{
|
||||
name: "no user turns",
|
||||
turns: []manualTitleTurn{
|
||||
{role: "assistant", text: "one"},
|
||||
{role: "assistant", text: "two"},
|
||||
},
|
||||
want: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got := selectManualTitleTurnIndexes(tt.turns)
|
||||
require.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_buildManualTitleContext(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
longConversationText := strings.Repeat("a", 3500)
|
||||
longLatestUserText := strings.Repeat("z", 1200)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
turns []manualTitleTurn
|
||||
selected []int
|
||||
wantConversation string
|
||||
wantConversationEmpty bool
|
||||
wantConversationHasGap bool
|
||||
wantConversationRunes int
|
||||
wantLatestUser string
|
||||
wantLatestUserRunes int
|
||||
wantLatestUserContains string
|
||||
wantLatestUserNotEmpty bool
|
||||
}{
|
||||
{
|
||||
name: "adds gap marker when selected turns skip earlier context",
|
||||
turns: []manualTitleTurn{
|
||||
{role: "user", text: "open pull request"},
|
||||
{role: "assistant", text: "checked CI"},
|
||||
{role: "user", text: "review logs"},
|
||||
{role: "assistant", text: "found flaky test"},
|
||||
{role: "user", text: "update chat title"},
|
||||
},
|
||||
selected: []int{0, 3, 4},
|
||||
wantConversationHasGap: true,
|
||||
wantLatestUser: "update chat title",
|
||||
},
|
||||
{
|
||||
name: "omits gap marker for contiguous selection",
|
||||
turns: []manualTitleTurn{
|
||||
{role: "user", text: "open pull request"},
|
||||
{role: "assistant", text: "checked CI"},
|
||||
{role: "user", text: "update chat title"},
|
||||
},
|
||||
selected: []int{0, 1, 2},
|
||||
wantConversation: "[user]: open pull request\n[assistant]: checked CI\n[user]: update chat title",
|
||||
wantConversationHasGap: false,
|
||||
wantLatestUser: "update chat title",
|
||||
},
|
||||
{
|
||||
name: "single useful user turn returns empty conversation block",
|
||||
turns: []manualTitleTurn{{role: "user", text: "rename helper"}},
|
||||
selected: []int{0},
|
||||
wantConversationEmpty: true,
|
||||
wantLatestUser: "rename helper",
|
||||
},
|
||||
{
|
||||
name: "truncates conversation block at six thousand runes",
|
||||
turns: []manualTitleTurn{
|
||||
{role: "user", text: longConversationText},
|
||||
{role: "assistant", text: longConversationText},
|
||||
{role: "user", text: "latest"},
|
||||
},
|
||||
selected: []int{0, 1, 2},
|
||||
wantConversationRunes: 6000,
|
||||
wantLatestUser: "latest",
|
||||
},
|
||||
{
|
||||
name: "truncates latest user message at one thousand runes",
|
||||
turns: []manualTitleTurn{
|
||||
{role: "user", text: "first"},
|
||||
{role: "assistant", text: "reply"},
|
||||
{role: "user", text: longLatestUserText},
|
||||
},
|
||||
selected: []int{0, 1, 2},
|
||||
wantLatestUserRunes: 1000,
|
||||
wantLatestUserContains: strings.Repeat("z", 1000),
|
||||
wantLatestUserNotEmpty: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
conversationBlock, latestUserMsg := buildManualTitleContext(tt.turns, tt.selected)
|
||||
|
||||
if tt.wantConversationEmpty {
|
||||
require.Empty(t, conversationBlock)
|
||||
}
|
||||
if tt.wantConversation != "" {
|
||||
require.Equal(t, tt.wantConversation, conversationBlock)
|
||||
}
|
||||
if tt.wantConversationHasGap {
|
||||
require.Contains(t, conversationBlock, "[... 2 earlier turns omitted ...]")
|
||||
} else if !tt.wantConversationEmpty {
|
||||
require.NotContains(t, conversationBlock, "earlier turns omitted")
|
||||
}
|
||||
if tt.wantConversationRunes > 0 {
|
||||
require.Len(t, []rune(conversationBlock), tt.wantConversationRunes)
|
||||
}
|
||||
if tt.wantLatestUser != "" {
|
||||
require.Equal(t, tt.wantLatestUser, latestUserMsg)
|
||||
}
|
||||
if tt.wantLatestUserRunes > 0 {
|
||||
require.Len(t, []rune(latestUserMsg), tt.wantLatestUserRunes)
|
||||
}
|
||||
if tt.wantLatestUserContains != "" {
|
||||
require.Equal(t, tt.wantLatestUserContains, latestUserMsg)
|
||||
}
|
||||
if tt.wantLatestUserNotEmpty {
|
||||
require.NotEmpty(t, latestUserMsg)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_renderManualTitlePrompt(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
longFirstUserText := strings.Repeat("b", 1501)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
conversationBlock string
|
||||
firstUserText string
|
||||
latestUserMsg string
|
||||
wantConversationSample bool
|
||||
wantLatestSection bool
|
||||
}{
|
||||
{
|
||||
name: "includes conversation sample when provided",
|
||||
conversationBlock: "[user]: inspect logs\n[assistant]: found flaky test",
|
||||
firstUserText: "inspect logs",
|
||||
latestUserMsg: "update quickgen title",
|
||||
wantConversationSample: true,
|
||||
wantLatestSection: true,
|
||||
},
|
||||
{
|
||||
name: "omits optional sections when not needed",
|
||||
conversationBlock: "",
|
||||
firstUserText: "inspect logs",
|
||||
latestUserMsg: "inspect logs",
|
||||
wantConversationSample: false,
|
||||
wantLatestSection: false,
|
||||
},
|
||||
{
|
||||
name: "latest section compares trimmed text",
|
||||
conversationBlock: "",
|
||||
firstUserText: "inspect logs",
|
||||
latestUserMsg: " inspect logs ",
|
||||
wantConversationSample: false,
|
||||
wantLatestSection: false,
|
||||
},
|
||||
{
|
||||
name: "omits latest section when same message truncated",
|
||||
conversationBlock: "",
|
||||
firstUserText: longFirstUserText,
|
||||
latestUserMsg: truncateRunes(longFirstUserText, 1000),
|
||||
wantConversationSample: false,
|
||||
wantLatestSection: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
prompt := renderManualTitlePrompt(tt.conversationBlock, tt.firstUserText, tt.latestUserMsg)
|
||||
|
||||
require.Contains(t, prompt, "The user's primary objective was:")
|
||||
require.Contains(t, prompt, "Requirements:")
|
||||
require.Contains(t, prompt, "- Output a short title of 2-8 words.")
|
||||
require.Contains(t, prompt, "- Output ONLY the title - nothing else.")
|
||||
|
||||
if tt.wantConversationSample {
|
||||
require.Contains(t, prompt, "Conversation sample:")
|
||||
require.Contains(t, prompt, tt.conversationBlock)
|
||||
} else {
|
||||
require.NotContains(t, prompt, "Conversation sample:")
|
||||
}
|
||||
|
||||
if tt.wantLatestSection {
|
||||
require.Contains(t, prompt, "The user's most recent message:")
|
||||
require.Contains(t, prompt, "Note: Weight the overall conversation arc more heavily than just the latest message.")
|
||||
require.Contains(t, prompt, strings.TrimSpace(tt.latestUserMsg))
|
||||
} else {
|
||||
require.NotContains(t, prompt, "The user's most recent message:")
|
||||
require.NotContains(t, prompt, "Weight the overall conversation arc more heavily")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_generateManualTitle_UsesTimeout(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
messages := []database.ChatMessage{
|
||||
mustChatMessage(
|
||||
t,
|
||||
database.ChatMessageRoleUser,
|
||||
database.ChatMessageVisibilityBoth,
|
||||
codersdk.ChatMessageText("refresh chat title"),
|
||||
),
|
||||
}
|
||||
|
||||
model := &stubModel{
|
||||
generateFn: func(ctx context.Context, call fantasy.Call) (*fantasy.Response, error) {
|
||||
deadline, ok := ctx.Deadline()
|
||||
require.True(t, ok, "manual title generation should set a deadline")
|
||||
require.WithinDuration(
|
||||
t,
|
||||
time.Now().Add(30*time.Second),
|
||||
deadline,
|
||||
2*time.Second,
|
||||
)
|
||||
require.Len(t, call.Prompt, 2)
|
||||
return &fantasy.Response{
|
||||
Content: fantasy.ResponseContent{
|
||||
fantasy.TextContent{Text: "Refresh title"},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
title, _, err := generateManualTitle(
|
||||
context.Background(),
|
||||
messages,
|
||||
model,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "Refresh title", title)
|
||||
}
|
||||
|
||||
func Test_generateManualTitle_TruncatesFirstUserInput(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
longFirstUserText := strings.Repeat("a", 1500)
|
||||
messages := []database.ChatMessage{
|
||||
mustChatMessage(
|
||||
t,
|
||||
database.ChatMessageRoleUser,
|
||||
database.ChatMessageVisibilityBoth,
|
||||
codersdk.ChatMessageText(longFirstUserText),
|
||||
),
|
||||
}
|
||||
|
||||
model := &stubModel{
|
||||
generateFn: func(_ context.Context, call fantasy.Call) (*fantasy.Response, error) {
|
||||
require.Len(t, call.Prompt, 2)
|
||||
systemText, ok := call.Prompt[0].Content[0].(fantasy.TextPart)
|
||||
require.True(t, ok)
|
||||
require.Contains(t, systemText.Text, truncateRunes(longFirstUserText, 1000))
|
||||
|
||||
userText, ok := call.Prompt[1].Content[0].(fantasy.TextPart)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "Generate the title.", userText.Text)
|
||||
return &fantasy.Response{
|
||||
Content: fantasy.ResponseContent{
|
||||
fantasy.TextContent{Text: "Refresh title"},
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
_, _, err := generateManualTitle(
|
||||
context.Background(),
|
||||
messages,
|
||||
model,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func Test_generateManualTitle_ReturnsUsageForEmptyNormalizedTitle(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
messages := []database.ChatMessage{
|
||||
mustChatMessage(
|
||||
t,
|
||||
database.ChatMessageRoleUser,
|
||||
database.ChatMessageVisibilityBoth,
|
||||
codersdk.ChatMessageText("refresh chat title"),
|
||||
),
|
||||
}
|
||||
|
||||
model := &stubModel{
|
||||
generateFn: func(_ context.Context, _ fantasy.Call) (*fantasy.Response, error) {
|
||||
return &fantasy.Response{
|
||||
Content: fantasy.ResponseContent{
|
||||
fantasy.TextContent{Text: "\"\""},
|
||||
},
|
||||
Usage: fantasy.Usage{
|
||||
InputTokens: 11,
|
||||
OutputTokens: 7,
|
||||
TotalTokens: 18,
|
||||
},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
_, usage, err := generateManualTitle(
|
||||
context.Background(),
|
||||
messages,
|
||||
model,
|
||||
)
|
||||
require.ErrorContains(t, err, "generated title was empty")
|
||||
require.Equal(t, int64(11), usage.InputTokens)
|
||||
require.Equal(t, int64(7), usage.OutputTokens)
|
||||
require.Equal(t, int64(18), usage.TotalTokens)
|
||||
}
|
||||
|
||||
func Test_selectPreferredConfiguredShortTextModelConfig(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("chooses the highest-priority configured lightweight model", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
configs := []database.ChatModelConfig{
|
||||
{Provider: preferredTitleModels[2].provider, Model: preferredTitleModels[2].model},
|
||||
{Provider: preferredTitleModels[1].provider, Model: preferredTitleModels[1].model},
|
||||
{Provider: "openai", Model: "gpt-4.1"},
|
||||
}
|
||||
|
||||
got, ok := selectPreferredConfiguredShortTextModelConfig(configs)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, preferredTitleModels[1].provider, got.Provider)
|
||||
require.Equal(t, preferredTitleModels[1].model, got.Model)
|
||||
})
|
||||
|
||||
t.Run("returns false when no preferred lightweight model is configured", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got, ok := selectPreferredConfiguredShortTextModelConfig([]database.ChatModelConfig{{
|
||||
Provider: "openai",
|
||||
Model: "gpt-4.1",
|
||||
}})
|
||||
require.False(t, ok)
|
||||
require.Equal(t, database.ChatModelConfig{}, got)
|
||||
})
|
||||
}
|
||||
|
||||
func Test_generateShortText_NormalizesQuotedOutput(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
model := &stubModel{
|
||||
generateFn: func(_ context.Context, _ fantasy.Call) (*fantasy.Response, error) {
|
||||
return &fantasy.Response{
|
||||
Content: fantasy.ResponseContent{
|
||||
fantasy.TextContent{Text: " \"Quoted summary\" "},
|
||||
},
|
||||
Usage: fantasy.Usage{InputTokens: 3, OutputTokens: 2, TotalTokens: 5},
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
text, usage, err := generateShortText(context.Background(), model, "system", "user")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "Quoted summary", text)
|
||||
require.Equal(t, int64(3), usage.InputTokens)
|
||||
require.Equal(t, int64(2), usage.OutputTokens)
|
||||
require.Equal(t, int64(5), usage.TotalTokens)
|
||||
}
|
||||
|
||||
type stubModel struct {
|
||||
generateFn func(context.Context, fantasy.Call) (*fantasy.Response, error)
|
||||
}
|
||||
|
||||
func (m *stubModel) Generate(
|
||||
ctx context.Context,
|
||||
call fantasy.Call,
|
||||
) (*fantasy.Response, error) {
|
||||
return m.generateFn(ctx, call)
|
||||
}
|
||||
|
||||
func (*stubModel) Stream(
|
||||
context.Context,
|
||||
fantasy.Call,
|
||||
) (fantasy.StreamResponse, error) {
|
||||
return nil, xerrors.New("stream not implemented")
|
||||
}
|
||||
|
||||
func (*stubModel) GenerateObject(
|
||||
context.Context,
|
||||
fantasy.ObjectCall,
|
||||
) (*fantasy.ObjectResponse, error) {
|
||||
return nil, xerrors.New("generate object not implemented")
|
||||
}
|
||||
|
||||
func (*stubModel) StreamObject(
|
||||
context.Context,
|
||||
fantasy.ObjectCall,
|
||||
) (fantasy.ObjectStreamResponse, error) {
|
||||
return nil, xerrors.New("stream object not implemented")
|
||||
}
|
||||
|
||||
func (*stubModel) Provider() string {
|
||||
return "test"
|
||||
}
|
||||
|
||||
func (*stubModel) Model() string {
|
||||
return "test"
|
||||
}
|
||||
|
||||
func mustChatMessage(
|
||||
t *testing.T,
|
||||
role database.ChatMessageRole,
|
||||
visibility database.ChatMessageVisibility,
|
||||
parts ...codersdk.ChatMessagePart,
|
||||
) database.ChatMessage {
|
||||
t.Helper()
|
||||
|
||||
content, err := json.Marshal(parts)
|
||||
require.NoError(t, err)
|
||||
|
||||
return database.ChatMessage{
|
||||
Role: role,
|
||||
Visibility: visibility,
|
||||
Content: pqtype.NullRawMessage{
|
||||
RawMessage: content,
|
||||
Valid: len(content) > 0,
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -1870,6 +1870,21 @@ func (c *ExperimentalClient) InterruptChat(ctx context.Context, chatID uuid.UUID
|
||||
return chat, json.NewDecoder(res.Body).Decode(&chat)
|
||||
}
|
||||
|
||||
// RegenerateChatTitle requests the server to regenerate the chat's
|
||||
// title using richer conversation context.
|
||||
func (c *ExperimentalClient) RegenerateChatTitle(ctx context.Context, chatID uuid.UUID) (Chat, error) {
|
||||
res, err := c.Request(ctx, http.MethodPost, fmt.Sprintf("/api/experimental/chats/%s/title/regenerate", chatID), nil)
|
||||
if err != nil {
|
||||
return Chat{}, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != http.StatusOK {
|
||||
return Chat{}, readBodyAsChatUsageLimitError(res)
|
||||
}
|
||||
var chat Chat
|
||||
return chat, json.NewDecoder(res.Body).Decode(&chat)
|
||||
}
|
||||
|
||||
// GetChatGitChanges returns git changes for a chat.
|
||||
func (c *ExperimentalClient) GetChatGitChanges(ctx context.Context, chatID uuid.UUID) ([]ChatGitChange, error) {
|
||||
res, err := c.Request(ctx, http.MethodGet, fmt.Sprintf("/api/experimental/chats/%s/git-changes", chatID), nil)
|
||||
|
||||
@@ -3126,6 +3126,13 @@ class ExperimentalApiMethods {
|
||||
await this.axios.patch(`/api/experimental/chats/${chatId}`, req);
|
||||
};
|
||||
|
||||
regenerateChatTitle = async (chatId: string): Promise<TypesGen.Chat> => {
|
||||
const response = await this.axios.post<TypesGen.Chat>(
|
||||
`/api/experimental/chats/${chatId}/title/regenerate`,
|
||||
);
|
||||
return response.data;
|
||||
};
|
||||
|
||||
createChatMessage = async (
|
||||
chatId: string,
|
||||
req: TypesGen.CreateChatMessageRequest,
|
||||
|
||||
@@ -22,6 +22,7 @@ import {
|
||||
invalidateChatListQueries,
|
||||
pinChat,
|
||||
promoteChatQueuedMessage,
|
||||
regenerateChatTitle,
|
||||
reorderPinnedChat,
|
||||
unarchiveChat,
|
||||
unpinChat,
|
||||
@@ -41,6 +42,7 @@ vi.mock("api/api", () => ({
|
||||
editChatMessage: vi.fn(),
|
||||
interruptChat: vi.fn(),
|
||||
promoteChatQueuedMessage: vi.fn(),
|
||||
regenerateChatTitle: vi.fn(),
|
||||
},
|
||||
},
|
||||
}));
|
||||
@@ -556,6 +558,51 @@ describe("reorderPinnedChat", () => {
|
||||
});
|
||||
});
|
||||
|
||||
describe("regenerateChatTitle cache updates", () => {
|
||||
it("preserves existing chat detail fields when the response is partial", () => {
|
||||
const queryClient = createTestQueryClient();
|
||||
const chatId = "chat-1";
|
||||
const cachedChat = makeChat(chatId, {
|
||||
diff_status: {
|
||||
chat_id: chatId,
|
||||
url: "https://example.com/pr/1",
|
||||
pull_request_state: "open",
|
||||
pull_request_title: "",
|
||||
pull_request_draft: false,
|
||||
changes_requested: false,
|
||||
additions: 1,
|
||||
deletions: 2,
|
||||
changed_files: 3,
|
||||
refreshed_at: "2025-01-01T00:00:00.000Z",
|
||||
stale_at: "2025-01-01T01:00:00.000Z",
|
||||
},
|
||||
});
|
||||
queryClient.setQueryData(chatKey(chatId), cachedChat);
|
||||
seedInfiniteChats(queryClient, [cachedChat]);
|
||||
|
||||
const mutation = regenerateChatTitle(queryClient);
|
||||
const updatedChat = {
|
||||
id: chatId,
|
||||
title: "New title",
|
||||
} satisfies Partial<TypesGen.Chat>;
|
||||
|
||||
mutation.onSuccess(updatedChat as TypesGen.Chat);
|
||||
|
||||
const cachedDetail = queryClient.getQueryData<TypesGen.Chat>(
|
||||
chatKey(chatId),
|
||||
);
|
||||
expect(cachedDetail).toEqual({
|
||||
...cachedChat,
|
||||
title: "New title",
|
||||
});
|
||||
expect(cachedDetail?.diff_status).toEqual(cachedChat.diff_status);
|
||||
expect(readInfiniteChats(queryClient)?.[0]).toMatchObject({
|
||||
id: chatId,
|
||||
title: "New title",
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe("chat cost query factories", () => {
|
||||
it("builds the summary query key and forwards snake_case params", async () => {
|
||||
const user = "user-1";
|
||||
|
||||
@@ -476,6 +476,37 @@ export const reorderPinnedChat = (queryClient: QueryClient) => ({
|
||||
},
|
||||
});
|
||||
|
||||
export const regenerateChatTitle = (queryClient: QueryClient) => ({
|
||||
mutationFn: (chatId: string) => API.experimental.regenerateChatTitle(chatId),
|
||||
|
||||
onSuccess: (updatedChat: TypesGen.Chat) => {
|
||||
queryClient.setQueryData<TypesGen.Chat>(
|
||||
chatKey(updatedChat.id),
|
||||
(previousChat) =>
|
||||
previousChat ? { ...previousChat, ...updatedChat } : updatedChat,
|
||||
);
|
||||
updateInfiniteChatsCache(queryClient, (chats) =>
|
||||
chats.map((chat) =>
|
||||
chat.id === updatedChat.id
|
||||
? { ...chat, title: updatedChat.title }
|
||||
: chat,
|
||||
),
|
||||
);
|
||||
},
|
||||
|
||||
onSettled: async (
|
||||
_data: TypesGen.Chat | undefined,
|
||||
_error: unknown,
|
||||
chatId: string,
|
||||
) => {
|
||||
await invalidateChatListQueries(queryClient);
|
||||
await queryClient.invalidateQueries({
|
||||
queryKey: chatKey(chatId),
|
||||
exact: true,
|
||||
});
|
||||
},
|
||||
});
|
||||
|
||||
export const createChat = (queryClient: QueryClient) => ({
|
||||
mutationFn: (req: TypesGen.CreateChatRequest) =>
|
||||
API.experimental.createChat(req),
|
||||
|
||||
@@ -55,6 +55,9 @@ const AgentDetailLayout: FC = () => {
|
||||
requestUnarchiveAgent: () => {},
|
||||
requestPinAgent: () => {},
|
||||
requestUnpinAgent: () => {},
|
||||
onRegenerateTitle: () => {},
|
||||
isRegeneratingTitle: false,
|
||||
regeneratingTitleChatId: null,
|
||||
isSidebarCollapsed: false,
|
||||
onToggleSidebarCollapsed: () => {},
|
||||
onExpandSidebar: () => {},
|
||||
|
||||
@@ -292,6 +292,9 @@ const AgentDetail: FC = () => {
|
||||
requestArchiveAgent,
|
||||
requestArchiveAndDeleteWorkspace,
|
||||
requestUnarchiveAgent,
|
||||
onRegenerateTitle,
|
||||
isRegeneratingTitle,
|
||||
regeneratingTitleChatId,
|
||||
isSidebarCollapsed,
|
||||
onToggleSidebarCollapsed,
|
||||
onChatReady,
|
||||
@@ -326,6 +329,9 @@ const AgentDetail: FC = () => {
|
||||
});
|
||||
};
|
||||
|
||||
const isRegeneratingThisChat =
|
||||
isRegeneratingTitle && regeneratingTitleChatId === agentId;
|
||||
|
||||
const chatQuery = useQuery({
|
||||
...chat(agentId ?? ""),
|
||||
enabled: Boolean(agentId),
|
||||
@@ -480,6 +486,7 @@ const AgentDetail: FC = () => {
|
||||
}
|
||||
: undefined;
|
||||
const isArchived = chatRecord?.archived ?? false;
|
||||
const isRegenerateTitleDisabled = isArchived || isRegeneratingTitle;
|
||||
const chatLastModelConfigID = chatRecord?.last_model_config_id;
|
||||
|
||||
const sendMutation = useMutation(
|
||||
@@ -887,6 +894,13 @@ const AgentDetail: FC = () => {
|
||||
onChatReady();
|
||||
}, [onChatReady, chatMessagesQuery.isSuccess, agentId]);
|
||||
|
||||
const handleRegenerateTitle = () => {
|
||||
if (!agentId || isRegenerateTitleDisabled || !onRegenerateTitle) {
|
||||
return;
|
||||
}
|
||||
onRegenerateTitle(agentId);
|
||||
};
|
||||
|
||||
if (chatQuery.isLoading || chatMessagesQuery.isLoading) {
|
||||
return (
|
||||
<AgentDetailLoadingView
|
||||
@@ -958,6 +972,9 @@ const AgentDetail: FC = () => {
|
||||
handleArchiveAndDeleteWorkspaceAction={
|
||||
handleArchiveAndDeleteWorkspaceAction
|
||||
}
|
||||
handleRegenerateTitle={handleRegenerateTitle}
|
||||
isRegeneratingTitle={isRegeneratingThisChat}
|
||||
isRegenerateTitleDisabled={isRegenerateTitleDisabled}
|
||||
urlTransform={urlTransform}
|
||||
scrollContainerRef={scrollContainerRef}
|
||||
scrollToBottomRef={scrollToBottomRef}
|
||||
|
||||
@@ -233,6 +233,9 @@ const AgentEmbedPage: FC = () => {
|
||||
requestPinAgent: () => {},
|
||||
requestUnpinAgent: () => {},
|
||||
requestArchiveAndDeleteWorkspace,
|
||||
// Title regeneration is not supported in embed mode.
|
||||
isRegeneratingTitle: false,
|
||||
regeneratingTitleChatId: null,
|
||||
isSidebarCollapsed,
|
||||
onToggleSidebarCollapsed,
|
||||
onExpandSidebar: () => {},
|
||||
|
||||
@@ -22,6 +22,7 @@ import {
|
||||
pinChat,
|
||||
prependToInfiniteChatsCache,
|
||||
readInfiniteChatsCache,
|
||||
regenerateChatTitle,
|
||||
reorderPinnedChat,
|
||||
unarchiveChat,
|
||||
unpinChat,
|
||||
@@ -220,6 +221,12 @@ const AgentsPage: FC = () => {
|
||||
toast.error(getErrorMessage(error, "Failed to reorder pinned agents."));
|
||||
},
|
||||
});
|
||||
const regenerateTitleMutation = useMutation({
|
||||
...regenerateChatTitle(queryClient),
|
||||
onError: (error: unknown) => {
|
||||
toast.error(getErrorMessage(error, "Failed to generate new title."));
|
||||
},
|
||||
});
|
||||
const [isSidebarCollapsed, setIsSidebarCollapsed] = useState(false);
|
||||
const [chatErrorReasons, setChatErrorReasons] = useState<
|
||||
Record<string, ChatDetailError>
|
||||
@@ -359,6 +366,12 @@ const AgentsPage: FC = () => {
|
||||
const requestReorderPinnedAgent = (chatId: string, pinOrder: number) => {
|
||||
reorderPinnedChatMutation.mutate({ chatId, pinOrder });
|
||||
};
|
||||
const requestRegenerateTitle = (chatId: string) => {
|
||||
if (regenerateTitleMutation.isPending) {
|
||||
return;
|
||||
}
|
||||
regenerateTitleMutation.mutate(chatId);
|
||||
};
|
||||
const handleToggleSidebarCollapsed = () =>
|
||||
setIsSidebarCollapsed((prev) => !prev);
|
||||
|
||||
@@ -609,6 +622,9 @@ const AgentsPage: FC = () => {
|
||||
requestPinAgent={requestPinAgent}
|
||||
requestUnpinAgent={requestUnpinAgent}
|
||||
requestReorderPinnedAgent={requestReorderPinnedAgent}
|
||||
onRegenerateTitle={requestRegenerateTitle}
|
||||
isRegeneratingTitle={regenerateTitleMutation.isPending}
|
||||
regeneratingTitleChatId={regenerateTitleMutation.variables ?? null}
|
||||
onToggleSidebarCollapsed={handleToggleSidebarCollapsed}
|
||||
isAgentsAdmin={isAgentsAdmin}
|
||||
hasNextPage={chatsQuery.hasNextPage}
|
||||
|
||||
@@ -23,6 +23,9 @@ export interface AgentsOutletContext {
|
||||
requestPinAgent: (chatId: string) => void;
|
||||
requestUnpinAgent: (chatId: string) => void;
|
||||
requestReorderPinnedAgent?: (chatId: string, pinOrder: number) => void;
|
||||
onRegenerateTitle?: (chatId: string) => void;
|
||||
isRegeneratingTitle: boolean;
|
||||
regeneratingTitleChatId: string | null;
|
||||
isSidebarCollapsed: boolean;
|
||||
onToggleSidebarCollapsed: () => void;
|
||||
onExpandSidebar: () => void;
|
||||
@@ -59,6 +62,9 @@ interface AgentsPageViewProps {
|
||||
requestPinAgent: (chatId: string) => void;
|
||||
requestUnpinAgent: (chatId: string) => void;
|
||||
requestReorderPinnedAgent?: (chatId: string, pinOrder: number) => void;
|
||||
onRegenerateTitle: (chatId: string) => void;
|
||||
isRegeneratingTitle: boolean;
|
||||
regeneratingTitleChatId: string | null;
|
||||
onToggleSidebarCollapsed: () => void;
|
||||
isAgentsAdmin: boolean;
|
||||
hasNextPage: boolean | undefined;
|
||||
@@ -93,6 +99,9 @@ export const AgentsPageView: FC<AgentsPageViewProps> = ({
|
||||
requestPinAgent,
|
||||
requestUnpinAgent,
|
||||
requestReorderPinnedAgent,
|
||||
onRegenerateTitle,
|
||||
isRegeneratingTitle,
|
||||
regeneratingTitleChatId,
|
||||
onToggleSidebarCollapsed,
|
||||
isAgentsAdmin,
|
||||
hasNextPage,
|
||||
@@ -133,6 +142,9 @@ export const AgentsPageView: FC<AgentsPageViewProps> = ({
|
||||
requestPinAgent,
|
||||
requestUnpinAgent,
|
||||
requestReorderPinnedAgent,
|
||||
onRegenerateTitle,
|
||||
isRegeneratingTitle,
|
||||
regeneratingTitleChatId,
|
||||
isSidebarCollapsed,
|
||||
onToggleSidebarCollapsed,
|
||||
onExpandSidebar,
|
||||
@@ -166,6 +178,9 @@ export const AgentsPageView: FC<AgentsPageViewProps> = ({
|
||||
onPinAgent={requestPinAgent}
|
||||
onUnpinAgent={requestUnpinAgent}
|
||||
onReorderPinnedAgent={requestReorderPinnedAgent}
|
||||
onRegenerateTitle={onRegenerateTitle}
|
||||
isRegeneratingTitle={isRegeneratingTitle}
|
||||
regeneratingTitleChatId={regeneratingTitleChatId}
|
||||
onBeforeNewAgent={handleNewAgent}
|
||||
isCreating={isCreating}
|
||||
isArchiving={isArchiving}
|
||||
|
||||
@@ -1,26 +1,27 @@
|
||||
import type { Meta, StoryObj } from "@storybook/react-vite";
|
||||
import { expect, userEvent, waitFor, within } from "storybook/test";
|
||||
import { expect, fn, userEvent, waitFor, within } from "storybook/test";
|
||||
import { AgentDetailTopBar } from "./TopBar";
|
||||
|
||||
const defaultProps = {
|
||||
chatTitle: "Build authentication feature",
|
||||
panel: {
|
||||
showSidebarPanel: false,
|
||||
onToggleSidebar: () => {},
|
||||
onToggleSidebar: fn(),
|
||||
},
|
||||
workspace: {
|
||||
canOpenEditors: true,
|
||||
canOpenWorkspace: true,
|
||||
onOpenInEditor: () => {},
|
||||
onViewWorkspace: () => {},
|
||||
onOpenTerminal: () => {},
|
||||
onOpenInEditor: fn(),
|
||||
onViewWorkspace: fn(),
|
||||
onOpenTerminal: fn(),
|
||||
sshCommand: "ssh main.my-workspace.admin.coder",
|
||||
},
|
||||
onArchiveAgent: () => {},
|
||||
onArchiveAndDeleteWorkspace: () => {},
|
||||
onUnarchiveAgent: () => {},
|
||||
onArchiveAgent: fn(),
|
||||
onArchiveAndDeleteWorkspace: fn(),
|
||||
onRegenerateTitle: fn(),
|
||||
onUnarchiveAgent: fn(),
|
||||
isSidebarCollapsed: false,
|
||||
onToggleSidebarCollapsed: () => {},
|
||||
onToggleSidebarCollapsed: fn(),
|
||||
} satisfies React.ComponentProps<typeof AgentDetailTopBar>;
|
||||
|
||||
const meta: Meta<typeof AgentDetailTopBar> = {
|
||||
@@ -36,6 +37,13 @@ type Story = StoryObj<typeof AgentDetailTopBar>;
|
||||
|
||||
export const Default: Story = {};
|
||||
|
||||
export const RegeneratingTitle: Story = {
|
||||
args: {
|
||||
...Default.args,
|
||||
isRegeneratingTitle: true,
|
||||
},
|
||||
};
|
||||
|
||||
export const WithPanelOpen: Story = {
|
||||
args: {
|
||||
panel: {
|
||||
@@ -227,10 +235,22 @@ export const MobileWithClosedPR: Story = {
|
||||
},
|
||||
};
|
||||
|
||||
export const GenerateTitle: Story = {
|
||||
play: async ({ canvasElement }) => {
|
||||
const canvas = within(canvasElement);
|
||||
const trigger = canvas.getByLabelText("Open agent actions");
|
||||
await userEvent.click(trigger);
|
||||
await waitFor(() => {
|
||||
const body = within(document.body);
|
||||
expect(body.getByText("Generate new title")).toBeInTheDocument();
|
||||
});
|
||||
},
|
||||
};
|
||||
|
||||
export const ArchivedWithUnarchive: Story = {
|
||||
args: {
|
||||
isArchived: true,
|
||||
onUnarchiveAgent: () => {},
|
||||
onUnarchiveAgent: fn(),
|
||||
},
|
||||
play: async ({ canvasElement }) => {
|
||||
const canvas = within(canvasElement);
|
||||
@@ -243,6 +263,7 @@ export const ArchivedWithUnarchive: Story = {
|
||||
expect(body.getByText("Unarchive Agent")).toBeInTheDocument();
|
||||
});
|
||||
const body = within(document.body);
|
||||
expect(body.queryByText("Generate new title")).not.toBeInTheDocument();
|
||||
expect(body.queryByText("Archive Agent")).not.toBeInTheDocument();
|
||||
expect(
|
||||
body.queryByText("Archive & Delete Workspace"),
|
||||
|
||||
@@ -12,10 +12,12 @@ import {
|
||||
PanelRightOpenIcon,
|
||||
TerminalIcon,
|
||||
Trash2Icon,
|
||||
WandSparklesIcon,
|
||||
} from "lucide-react";
|
||||
import type { FC } from "react";
|
||||
import { Link } from "react-router";
|
||||
import { toast } from "sonner";
|
||||
import { cn } from "utils/cn";
|
||||
import type * as TypesGen from "#/api/typesGenerated";
|
||||
import type { ChatDiffStatus } from "#/api/typesGenerated";
|
||||
import { Button } from "#/components/Button/Button";
|
||||
@@ -26,7 +28,7 @@ import {
|
||||
DropdownMenuSeparator,
|
||||
DropdownMenuTrigger,
|
||||
} from "#/components/DropdownMenu/DropdownMenu";
|
||||
import { cn } from "#/utils/cn";
|
||||
import { Spinner } from "#/components/Spinner/Spinner";
|
||||
import { parsePullRequestUrl } from "../../utils/pullRequest";
|
||||
import { useEmbedContext } from "../EmbedContext";
|
||||
import { PrStateIcon } from "../GitPanel/GitPanel";
|
||||
@@ -53,6 +55,9 @@ type AgentDetailTopBarProps = {
|
||||
onArchiveAgent: () => void;
|
||||
onUnarchiveAgent: () => void;
|
||||
onArchiveAndDeleteWorkspace: () => void;
|
||||
onRegenerateTitle?: () => void;
|
||||
isRegeneratingTitle?: boolean;
|
||||
isRegenerateTitleDisabled?: boolean;
|
||||
hasWorkspace?: boolean;
|
||||
isArchived?: boolean;
|
||||
isSidebarCollapsed: boolean;
|
||||
@@ -68,6 +73,9 @@ export const AgentDetailTopBar: FC<AgentDetailTopBarProps> = ({
|
||||
onArchiveAgent,
|
||||
onUnarchiveAgent,
|
||||
onArchiveAndDeleteWorkspace,
|
||||
onRegenerateTitle,
|
||||
isRegeneratingTitle,
|
||||
isRegenerateTitleDisabled,
|
||||
hasWorkspace,
|
||||
isArchived,
|
||||
isSidebarCollapsed,
|
||||
@@ -115,7 +123,12 @@ export const AgentDetailTopBar: FC<AgentDetailTopBarProps> = ({
|
||||
{/* Title area */}
|
||||
<div className="flex min-w-0 flex-1 items-center">
|
||||
{chatTitle && (
|
||||
<div className="flex min-w-0 items-center gap-1.5">
|
||||
<div
|
||||
role="status"
|
||||
aria-live="polite"
|
||||
aria-busy={isRegeneratingTitle}
|
||||
className="flex min-w-0 items-center gap-1.5"
|
||||
>
|
||||
{parentChat && (
|
||||
<>
|
||||
<Button
|
||||
@@ -131,9 +144,21 @@ export const AgentDetailTopBar: FC<AgentDetailTopBarProps> = ({
|
||||
<ChevronRightIcon className="h-3.5 w-3.5 shrink-0 text-content-secondary/70 -ml-0.5" />
|
||||
</>
|
||||
)}
|
||||
<span className="truncate text-sm text-content-primary">
|
||||
<span
|
||||
className={cn(
|
||||
"truncate text-sm text-content-primary",
|
||||
isRegeneratingTitle && "animate-pulse",
|
||||
)}
|
||||
>
|
||||
{chatTitle}
|
||||
</span>
|
||||
{isRegeneratingTitle && (
|
||||
<Spinner
|
||||
aria-label="Regenerating title"
|
||||
className="h-3.5 w-3.5 shrink-0 text-content-secondary"
|
||||
loading
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
@@ -227,7 +252,23 @@ export const AgentDetailTopBar: FC<AgentDetailTopBarProps> = ({
|
||||
<MonitorIcon className="h-3.5 w-3.5" />
|
||||
View Workspace
|
||||
</DropdownMenuItem>
|
||||
<DropdownMenuSeparator />
|
||||
{!isArchived && (
|
||||
<>
|
||||
<DropdownMenuSeparator />
|
||||
{onRegenerateTitle && (
|
||||
<>
|
||||
<DropdownMenuItem
|
||||
disabled={isRegenerateTitleDisabled}
|
||||
onSelect={onRegenerateTitle}
|
||||
>
|
||||
<WandSparklesIcon className="h-3.5 w-3.5" />
|
||||
Generate new title
|
||||
</DropdownMenuItem>
|
||||
<DropdownMenuSeparator />
|
||||
</>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
{isArchived ? (
|
||||
<DropdownMenuItem onSelect={onUnarchiveAgent}>
|
||||
<ArchiveRestoreIcon className="h-3.5 w-3.5" />
|
||||
|
||||
@@ -142,6 +142,7 @@ const StoryAgentDetailView: FC<StoryProps> = ({ editing, ...overrides }) => {
|
||||
handleArchiveAgentAction: fn(),
|
||||
handleUnarchiveAgentAction: fn(),
|
||||
handleArchiveAndDeleteWorkspaceAction: fn(),
|
||||
handleRegenerateTitle: fn(),
|
||||
scrollContainerRef: { current: null },
|
||||
hasMoreMessages: false,
|
||||
isFetchingMoreMessages: false,
|
||||
|
||||
@@ -117,6 +117,9 @@ interface AgentDetailViewProps {
|
||||
handleArchiveAgentAction: () => void;
|
||||
handleUnarchiveAgentAction: () => void;
|
||||
handleArchiveAndDeleteWorkspaceAction: () => void;
|
||||
handleRegenerateTitle?: () => void;
|
||||
isRegeneratingTitle?: boolean;
|
||||
isRegenerateTitleDisabled?: boolean;
|
||||
|
||||
// Scroll container ref.
|
||||
scrollContainerRef: RefObject<HTMLDivElement | null>;
|
||||
@@ -179,6 +182,9 @@ export const AgentDetailView: FC<AgentDetailViewProps> = ({
|
||||
handleArchiveAgentAction,
|
||||
handleUnarchiveAgentAction,
|
||||
handleArchiveAndDeleteWorkspaceAction,
|
||||
handleRegenerateTitle,
|
||||
isRegeneratingTitle,
|
||||
isRegenerateTitleDisabled,
|
||||
scrollContainerRef,
|
||||
scrollToBottomRef,
|
||||
hasMoreMessages,
|
||||
@@ -261,6 +267,11 @@ export const AgentDetailView: FC<AgentDetailViewProps> = ({
|
||||
onArchiveAndDeleteWorkspace={
|
||||
handleArchiveAndDeleteWorkspaceAction
|
||||
}
|
||||
{...(handleRegenerateTitle
|
||||
? { onRegenerateTitle: handleRegenerateTitle }
|
||||
: {})}
|
||||
isRegeneratingTitle={isRegeneratingTitle}
|
||||
isRegenerateTitleDisabled={isRegenerateTitleDisabled}
|
||||
hasWorkspace={hasWorkspace}
|
||||
isArchived={isArchived}
|
||||
diffStatusData={diffStatusData}
|
||||
@@ -435,6 +446,7 @@ export const AgentDetailLoadingView: FC<AgentDetailLoadingViewProps> = ({
|
||||
}}
|
||||
onArchiveAgent={() => {}}
|
||||
onUnarchiveAgent={() => {}}
|
||||
onRegenerateTitle={() => {}}
|
||||
onArchiveAndDeleteWorkspace={() => {}}
|
||||
hasWorkspace={false}
|
||||
isSidebarCollapsed={isSidebarCollapsed}
|
||||
@@ -507,6 +519,7 @@ export const AgentDetailNotFoundView: FC<AgentDetailNotFoundViewProps> = ({
|
||||
}}
|
||||
onArchiveAgent={() => {}}
|
||||
onUnarchiveAgent={() => {}}
|
||||
onRegenerateTitle={() => {}}
|
||||
onArchiveAndDeleteWorkspace={() => {}}
|
||||
hasWorkspace={false}
|
||||
isSidebarCollapsed={isSidebarCollapsed}
|
||||
|
||||
@@ -74,8 +74,11 @@ const meta: Meta<typeof AgentsSidebar> = {
|
||||
onArchiveAndDeleteWorkspace: fn(),
|
||||
onPinAgent: fn(),
|
||||
onUnpinAgent: fn(),
|
||||
onRegenerateTitle: fn(),
|
||||
onBeforeNewAgent: fn(),
|
||||
isCreating: false,
|
||||
isRegeneratingTitle: false,
|
||||
regeneratingTitleChatId: null,
|
||||
archivedFilter: "active" as const,
|
||||
onArchivedFilterChange: fn(),
|
||||
},
|
||||
|
||||
@@ -108,6 +108,9 @@ const defaultProps: React.ComponentProps<typeof AgentsSidebar> = {
|
||||
onArchiveAndDeleteWorkspace: vi.fn(),
|
||||
onPinAgent: vi.fn(),
|
||||
onUnpinAgent: vi.fn(),
|
||||
onRegenerateTitle: vi.fn(),
|
||||
isRegeneratingTitle: false,
|
||||
regeneratingTitleChatId: null,
|
||||
onBeforeNewAgent: vi.fn(),
|
||||
isCreating: false,
|
||||
archivedFilter: "active" as const,
|
||||
|
||||
@@ -72,6 +72,7 @@ import {
|
||||
DropdownMenu,
|
||||
DropdownMenuContent,
|
||||
DropdownMenuItem,
|
||||
DropdownMenuSeparator,
|
||||
DropdownMenuTrigger,
|
||||
} from "#/components/DropdownMenu/DropdownMenu";
|
||||
import { ExternalImage } from "#/components/ExternalImage/ExternalImage";
|
||||
@@ -123,10 +124,13 @@ interface AgentsSidebarProps {
|
||||
onPinAgent: (chatId: string) => void;
|
||||
onUnpinAgent: (chatId: string) => void;
|
||||
onReorderPinnedAgent?: (chatId: string, pinOrder: number) => void;
|
||||
onRegenerateTitle: (chatId: string) => void;
|
||||
onBeforeNewAgent?: () => void;
|
||||
isCreating: boolean;
|
||||
isArchiving?: boolean;
|
||||
archivingChatId?: string | null;
|
||||
isRegeneratingTitle?: boolean;
|
||||
regeneratingTitleChatId?: string | null;
|
||||
isLoading?: boolean;
|
||||
loadError?: unknown;
|
||||
onRetryLoad?: () => void;
|
||||
@@ -367,6 +371,8 @@ interface ChatTreeContextValue {
|
||||
readonly chatErrorReasons: Record<string, string>;
|
||||
readonly isArchiving: boolean;
|
||||
readonly archivingChatId: string | null;
|
||||
readonly isRegeneratingTitle: boolean;
|
||||
readonly regeneratingTitleChatId: string | null;
|
||||
readonly toggleExpanded: (chatID: string) => void;
|
||||
readonly onArchiveAgent: (chatId: string) => void;
|
||||
readonly onUnarchiveAgent: (chatId: string) => void;
|
||||
@@ -376,6 +382,7 @@ interface ChatTreeContextValue {
|
||||
) => void;
|
||||
readonly onPinAgent: (chatId: string) => void;
|
||||
readonly onUnpinAgent: (chatId: string) => void;
|
||||
readonly onRegenerateTitle: (chatId: string) => void;
|
||||
}
|
||||
|
||||
const ChatTreeContext = createContext<ChatTreeContextValue | null>(null);
|
||||
@@ -405,12 +412,15 @@ const ChatTreeNode: FC<ChatTreeNodeProps> = ({ chat, isChildNode }) => {
|
||||
chatErrorReasons,
|
||||
isArchiving,
|
||||
archivingChatId,
|
||||
isRegeneratingTitle,
|
||||
regeneratingTitleChatId,
|
||||
toggleExpanded,
|
||||
onArchiveAgent,
|
||||
onUnarchiveAgent,
|
||||
onArchiveAndDeleteWorkspace,
|
||||
onPinAgent,
|
||||
onUnpinAgent,
|
||||
onRegenerateTitle,
|
||||
} = useChatTree();
|
||||
const chatID = chat.id;
|
||||
const childIDs = (chatTree.childrenById.get(chatID) ?? []).filter((childID) =>
|
||||
@@ -448,6 +458,8 @@ const ChatTreeNode: FC<ChatTreeNodeProps> = ({ chat, isChildNode }) => {
|
||||
}`;
|
||||
const workspaceId = chat.workspace_id;
|
||||
const isArchivingThisChat = isArchiving && archivingChatId === chat.id;
|
||||
const isRegeneratingThisChat =
|
||||
isRegeneratingTitle && regeneratingTitleChatId === chat.id;
|
||||
const isExpanded = normalizedSearch ? true : (expandedById[chatID] ?? false);
|
||||
|
||||
return (
|
||||
@@ -509,13 +521,21 @@ const ChatTreeNode: FC<ChatTreeNodeProps> = ({ chat, isChildNode }) => {
|
||||
<div className="min-w-0 flex-1 overflow-hidden text-left">
|
||||
<div className="flex min-w-0 items-center gap-1.5 overflow-hidden">
|
||||
<span
|
||||
aria-busy={isRegeneratingThisChat}
|
||||
className={cn(
|
||||
"block flex-1 truncate text-[13px] text-content-primary",
|
||||
isActive && "font-medium",
|
||||
// Pulse-only in sidebar (no spinner) — space-constrained card layout.
|
||||
isRegeneratingThisChat && "animate-pulse",
|
||||
)}
|
||||
>
|
||||
{chat.title}
|
||||
</span>
|
||||
{isRegeneratingThisChat && (
|
||||
<span className="sr-only" role="status">
|
||||
Regenerating title…
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
<div className="flex min-w-0 items-center gap-1.5">
|
||||
{hasLinkedDiffStatus && hasLineStats && (
|
||||
@@ -598,6 +618,14 @@ const ChatTreeNode: FC<ChatTreeNodeProps> = ({ chat, isChildNode }) => {
|
||||
</DropdownMenuItem>
|
||||
) : (
|
||||
<>
|
||||
<DropdownMenuItem
|
||||
disabled={isRegeneratingTitle}
|
||||
onSelect={() => onRegenerateTitle(chat.id)}
|
||||
>
|
||||
<WandSparklesIcon className="h-3.5 w-3.5" />
|
||||
Generate new title
|
||||
</DropdownMenuItem>
|
||||
<DropdownMenuSeparator />
|
||||
<DropdownMenuItem
|
||||
className="text-content-destructive focus:text-content-destructive"
|
||||
disabled={isArchiving}
|
||||
@@ -710,10 +738,13 @@ export const AgentsSidebar: FC<AgentsSidebarProps> = (props) => {
|
||||
onPinAgent,
|
||||
onUnpinAgent,
|
||||
onReorderPinnedAgent,
|
||||
onRegenerateTitle,
|
||||
onBeforeNewAgent,
|
||||
isCreating,
|
||||
isArchiving = false,
|
||||
archivingChatId = null,
|
||||
isRegeneratingTitle = false,
|
||||
regeneratingTitleChatId = null,
|
||||
isLoading = false,
|
||||
loadError,
|
||||
onRetryLoad,
|
||||
@@ -910,12 +941,15 @@ export const AgentsSidebar: FC<AgentsSidebarProps> = (props) => {
|
||||
chatErrorReasons,
|
||||
isArchiving,
|
||||
archivingChatId,
|
||||
isRegeneratingTitle,
|
||||
regeneratingTitleChatId,
|
||||
toggleExpanded,
|
||||
onArchiveAgent,
|
||||
onUnarchiveAgent,
|
||||
onArchiveAndDeleteWorkspace,
|
||||
onPinAgent,
|
||||
onUnpinAgent,
|
||||
onRegenerateTitle,
|
||||
};
|
||||
|
||||
const subNavTitle = "Settings";
|
||||
|
||||
Reference in New Issue
Block a user