diff --git a/coderd/chatd/chatd.go b/coderd/chatd/chatd.go index 566cbd0675..ea3a2f87be 100644 --- a/coderd/chatd/chatd.go +++ b/coderd/chatd/chatd.go @@ -2022,33 +2022,8 @@ func (p *Server) processChat(ctx context.Context, chat database.Chat) { chat.Status = status p.publishChatPubsubEvent(chat, coderdpubsub.ChatEventKindStatusChange) - // Send a web push notification when the agent finishes - // processing. We only notify for terminal states (waiting - // = success, error = failure) and skip sub-agent chats - // and user-interrupted chats to avoid unnecessary - // notifications. - if p.webpushDispatcher != nil && p.webpushDispatcher.PublicKey() != "" && !chat.ParentChatID.Valid && !wasInterrupted { - if status == database.ChatStatusWaiting || status == database.ChatStatusError { - pushMsg := codersdk.WebpushMessage{ - Title: chat.Title, - Body: "Agent has finished running.", - Icon: "/favicon.ico", - Data: map[string]string{"url": fmt.Sprintf("/agents/%s", chat.ID)}, - } - if status == database.ChatStatusError { - pushMsg.Body = "Agent encountered an error." - if lastError != "" { - pushMsg.Body = lastError - } - } - if err := p.webpushDispatcher.Dispatch(cleanupCtx, chat.OwnerID, pushMsg); err != nil { - logger.Warn(cleanupCtx, "failed to send chat completion web push", - slog.F("chat_id", chat.ID), - slog.F("status", status), - slog.Error(err), - ) - } - } + if !wasInterrupted { + p.maybeSendPushNotification(cleanupCtx, chat, status, lastError, logger) } }() @@ -2977,6 +2952,90 @@ func (p *Server) recoverStaleChats(ctx context.Context) { } } +// maybeSendPushNotification sends a web push notification when an +// agent chat reaches a terminal state. For errors it dispatches +// synchronously; for successful completions it spawns a goroutine +// that generates a short LLM summary before dispatching. The caller +// is responsible for skipping interrupted chats. +func (p *Server) maybeSendPushNotification( + ctx context.Context, + chat database.Chat, + status database.ChatStatus, + lastError string, + logger slog.Logger, +) { + if p.webpushDispatcher == nil || p.webpushDispatcher.PublicKey() == "" { + return + } + if chat.ParentChatID.Valid { + return + } + + switch status { + case database.ChatStatusError: + pushBody := "Agent encountered an error." + if lastError != "" { + pushBody = lastError + } + p.dispatchPush(ctx, chat, pushBody, status, logger) + + case database.ChatStatusWaiting: + // Generate a push notification summary asynchronously + // using a cheap LLM model. This avoids blocking the + // deferred cleanup path while still providing a + // meaningful notification body. + p.inflight.Add(1) + go func() { + defer p.inflight.Done() + pushCtx := context.WithoutCancel(ctx) + pushBody := "Agent has finished running." + + msg, err := p.db.GetLastChatMessageByRole(pushCtx, database.GetLastChatMessageByRoleParams{ + ChatID: chat.ID, + Role: "assistant", + }) + if err == nil { + content, parseErr := chatprompt.ParseContent(msg.Role, msg.Content) + if parseErr == nil { + assistantText := strings.TrimSpace(contentBlocksToText(content)) + if assistantText != "" { + model, _, keys, resolveErr := p.resolveChatModel(pushCtx, chat) + if resolveErr == nil { + if summary := generatePushSummary(pushCtx, chat.Title, assistantText, model, keys, logger); summary != "" { + pushBody = summary + } + } + } + } + } + + p.dispatchPush(pushCtx, chat, pushBody, status, logger) + }() + } +} + +func (p *Server) dispatchPush( + ctx context.Context, + chat database.Chat, + body string, + status database.ChatStatus, + logger slog.Logger, +) { + pushMsg := codersdk.WebpushMessage{ + Title: chat.Title, + Body: body, + Icon: "/favicon.ico", + Data: map[string]string{"url": fmt.Sprintf("/agents/%s", chat.ID)}, + } + if err := p.webpushDispatcher.Dispatch(ctx, chat.OwnerID, pushMsg); err != nil { + logger.Warn(ctx, "failed to send chat completion web push", + slog.F("chat_id", chat.ID), + slog.F("status", status), + slog.Error(err), + ) + } +} + // Close stops the processor and waits for it to finish. func (p *Server) Close() error { p.cancel() diff --git a/coderd/chatd/chatd_test.go b/coderd/chatd/chatd_test.go index de93e366a9..028df0bef2 100644 --- a/coderd/chatd/chatd_test.go +++ b/coderd/chatd/chatd_test.go @@ -1482,6 +1482,12 @@ func (m *mockWebpushDispatcher) Dispatch(_ context.Context, userID uuid.UUID, ms return nil } +func (m *mockWebpushDispatcher) getLastMessage() codersdk.WebpushMessage { + m.mu.Lock() + defer m.mu.Unlock() + return m.lastMessage +} + func (*mockWebpushDispatcher) Test(_ context.Context, _ codersdk.WebpushSubscription) error { return nil } @@ -1674,3 +1680,66 @@ func TestCloseDuringShutdownContextCanceledShouldRetryOnNewReplica(t *testing.T) !fromDB.LastError.Valid }, testutil.WaitMedium, testutil.IntervalFast) } + +func TestSuccessfulChatSendsWebPushWithSummary(t *testing.T) { + t.Parallel() + + db, ps := dbtestutil.NewDB(t) + ctx := testutil.Context(t, testutil.WaitLong) + + const assistantText = "I have completed the task successfully and all tests are passing now." + const summaryText = "Completed task and verified all tests pass." + + openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse { + if !req.Stream { + // Non-streaming calls are used for title + // generation and push summary generation. + // Return the summary text for both — the title + // result is irrelevant to this test. + return chattest.OpenAINonStreamingResponse(summaryText) + } + return chattest.OpenAIStreamingResponse( + chattest.OpenAITextChunks(assistantText)..., + ) + }) + + mockPush := &mockWebpushDispatcher{} + + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) + server := chatd.New(chatd.Config{ + Logger: logger, + Database: db, + ReplicaID: uuid.New(), + Pubsub: ps, + PendingChatAcquireInterval: 10 * time.Millisecond, + InFlightChatStaleAfter: testutil.WaitSuperLong, + WebpushDispatcher: mockPush, + }) + t.Cleanup(func() { + require.NoError(t, server.Close()) + }) + + user, model := seedChatDependencies(ctx, t, db) + setOpenAIProviderBaseURL(ctx, t, db, openAIURL) + + _, err := server.CreateChat(ctx, chatd.CreateOptions{ + OwnerID: user.ID, + Title: "summary-push-test", + ModelConfigID: model.ID, + InitialUserContent: []fantasy.Content{fantasy.TextContent{Text: "do the thing"}}, + }) + require.NoError(t, err) + + // The push notification is dispatched asynchronously after the + // chat finishes, so we poll for it rather than checking + // immediately after the status transitions to waiting. + testutil.Eventually(ctx, t, func(ctx context.Context) bool { + return mockPush.dispatchCount.Load() >= 1 + }, testutil.IntervalFast) + + msg := mockPush.getLastMessage() + require.Equal(t, summaryText, msg.Body, + "push body should be the LLM-generated summary") + require.NotEqual(t, "Agent has finished running.", msg.Body, + "push body should not use the default fallback text") +} diff --git a/coderd/chatd/title.go b/coderd/chatd/quickgen.go similarity index 75% rename from coderd/chatd/title.go rename to coderd/chatd/quickgen.go index 0ab706a634..747ae6788e 100644 --- a/coderd/chatd/title.go +++ b/coderd/chatd/quickgen.go @@ -128,37 +128,11 @@ func generateTitle( model fantasy.LanguageModel, input string, ) (string, error) { - prompt := []fantasy.Message{ - { - Role: fantasy.MessageRoleSystem, - Content: []fantasy.MessagePart{ - fantasy.TextPart{Text: titleGenerationPrompt}, - }, - }, - { - Role: fantasy.MessageRoleUser, - Content: []fantasy.MessagePart{ - fantasy.TextPart{Text: input}, - }, - }, - } - - var maxOutputTokens int64 = 256 - - var response *fantasy.Response - err := chatretry.Retry(ctx, func(retryCtx context.Context) error { - var genErr error - response, genErr = model.Generate(retryCtx, fantasy.Call{ - Prompt: prompt, - MaxOutputTokens: &maxOutputTokens, - }) - return genErr - }, nil) + title, err := generateShortText(ctx, model, titleGenerationPrompt, input) if err != nil { - return "", xerrors.Errorf("generate title text: %w", err) + return "", err } - - title := normalizeTitleOutput(contentBlocksToText(response.Content)) + title = normalizeTitleOutput(title) if title == "" { return "", xerrors.New("generated title was empty") } @@ -278,3 +252,96 @@ func truncateRunes(value string, maxLen int) string { } return string(runes[:maxLen]) } + +const pushSummaryPrompt = "You are a notification assistant. Given a chat title " + + "and the agent's last message, write a single short sentence (under 100 characters) " + + "summarizing what the agent did. This will be shown as a push notification body. " + + "Return plain text only — no quotes, no emoji, no markdown." + +// generatePushSummary calls a cheap model to produce a short push +// notification body from the chat title and the last assistant +// message text. It follows the same candidate-selection strategy +// as title generation: try preferred lightweight models first, then +// fall back to the provided model. Returns "" on any failure. +func generatePushSummary( + ctx context.Context, + chatTitle string, + assistantText string, + fallbackModel fantasy.LanguageModel, + keys chatprovider.ProviderAPIKeys, + logger slog.Logger, +) string { + summaryCtx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + + input := "Chat title: " + chatTitle + "\n\nAgent's last message:\n" + assistantText + + candidates := make([]fantasy.LanguageModel, 0, len(preferredTitleModels)+1) + for _, c := range preferredTitleModels { + m, err := chatprovider.ModelFromConfig( + c.provider, c.model, keys, + ) + if err == nil { + candidates = append(candidates, m) + } + } + candidates = append(candidates, fallbackModel) + + for _, model := range candidates { + summary, err := generateShortText(summaryCtx, model, pushSummaryPrompt, input) + if err != nil { + logger.Debug(ctx, "push summary model candidate failed", + slog.Error(err), + ) + continue + } + if summary != "" { + return summary + } + } + return "" +} + +// generateShortText calls a model with a system prompt and user +// input, returning a cleaned-up short text response. It reuses the +// same retry logic as title generation. +func generateShortText( + ctx context.Context, + model fantasy.LanguageModel, + systemPrompt string, + userInput string, +) (string, error) { + prompt := []fantasy.Message{ + { + Role: fantasy.MessageRoleSystem, + Content: []fantasy.MessagePart{ + fantasy.TextPart{Text: systemPrompt}, + }, + }, + { + Role: fantasy.MessageRoleUser, + Content: []fantasy.MessagePart{ + fantasy.TextPart{Text: userInput}, + }, + }, + } + + var maxOutputTokens int64 = 256 + + var response *fantasy.Response + err := chatretry.Retry(ctx, func(retryCtx context.Context) error { + var genErr error + response, genErr = model.Generate(retryCtx, fantasy.Call{ + Prompt: prompt, + MaxOutputTokens: &maxOutputTokens, + }) + return genErr + }, nil) + if err != nil { + return "", xerrors.Errorf("generate short text: %w", err) + } + + text := strings.TrimSpace(contentBlocksToText(response.Content)) + text = strings.Trim(text, "\"'`") + return text, nil +} diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index 9b63649ada..a803260676 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -2831,6 +2831,15 @@ func (q *querier) GetInboxNotificationsByUserID(ctx context.Context, userID data return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetInboxNotificationsByUserID)(ctx, userID) } +func (q *querier) GetLastChatMessageByRole(ctx context.Context, arg database.GetLastChatMessageByRoleParams) (database.ChatMessage, error) { + // Authorize read on the parent chat. + _, err := q.GetChatByID(ctx, arg.ChatID) + if err != nil { + return database.ChatMessage{}, err + } + return q.db.GetLastChatMessageByRole(ctx, arg) +} + func (q *querier) GetLastUpdateCheck(ctx context.Context) (string, error) { if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil { return "", err diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go index 3dbb103c4c..5eb590af7b 100644 --- a/coderd/database/dbauthz/dbauthz_test.go +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -488,6 +488,14 @@ func (s *MethodTestSuite) TestChats() { dbm.EXPECT().GetChatMessagesByChatID(gomock.Any(), arg).Return(msgs, nil).AnyTimes() check.Args(arg).Asserts(chat, policy.ActionRead).Returns(msgs) })) + s.Run("GetLastChatMessageByRole", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + chat := testutil.Fake(s.T(), faker, database.Chat{}) + msg := testutil.Fake(s.T(), faker, database.ChatMessage{ChatID: chat.ID}) + arg := database.GetLastChatMessageByRoleParams{ChatID: chat.ID, Role: "assistant"} + dbm.EXPECT().GetChatByID(gomock.Any(), chat.ID).Return(chat, nil).AnyTimes() + dbm.EXPECT().GetLastChatMessageByRole(gomock.Any(), arg).Return(msg, nil).AnyTimes() + check.Args(arg).Asserts(chat, policy.ActionRead).Returns(msg) + })) s.Run("GetChatMessagesForPromptByChatID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { chat := testutil.Fake(s.T(), faker, database.Chat{}) msgs := []database.ChatMessage{testutil.Fake(s.T(), faker, database.ChatMessage{ChatID: chat.ID})} diff --git a/coderd/database/dbmetrics/querymetrics.go b/coderd/database/dbmetrics/querymetrics.go index 876e4535df..7739254bef 100644 --- a/coderd/database/dbmetrics/querymetrics.go +++ b/coderd/database/dbmetrics/querymetrics.go @@ -1407,6 +1407,14 @@ func (m queryMetricsStore) GetInboxNotificationsByUserID(ctx context.Context, ar return r0, r1 } +func (m queryMetricsStore) GetLastChatMessageByRole(ctx context.Context, arg database.GetLastChatMessageByRoleParams) (database.ChatMessage, error) { + start := time.Now() + r0, r1 := m.s.GetLastChatMessageByRole(ctx, arg) + m.queryLatencies.WithLabelValues("GetLastChatMessageByRole").Observe(time.Since(start).Seconds()) + m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetLastChatMessageByRole").Inc() + return r0, r1 +} + func (m queryMetricsStore) GetLastUpdateCheck(ctx context.Context) (string, error) { start := time.Now() r0, r1 := m.s.GetLastUpdateCheck(ctx) diff --git a/coderd/database/dbmock/dbmock.go b/coderd/database/dbmock/dbmock.go index 8d81f43a43..33bc2653d6 100644 --- a/coderd/database/dbmock/dbmock.go +++ b/coderd/database/dbmock/dbmock.go @@ -2587,6 +2587,21 @@ func (mr *MockStoreMockRecorder) GetInboxNotificationsByUserID(ctx, arg any) *go return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetInboxNotificationsByUserID", reflect.TypeOf((*MockStore)(nil).GetInboxNotificationsByUserID), ctx, arg) } +// GetLastChatMessageByRole mocks base method. +func (m *MockStore) GetLastChatMessageByRole(ctx context.Context, arg database.GetLastChatMessageByRoleParams) (database.ChatMessage, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetLastChatMessageByRole", ctx, arg) + ret0, _ := ret[0].(database.ChatMessage) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetLastChatMessageByRole indicates an expected call of GetLastChatMessageByRole. +func (mr *MockStoreMockRecorder) GetLastChatMessageByRole(ctx, arg any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetLastChatMessageByRole", reflect.TypeOf((*MockStore)(nil).GetLastChatMessageByRole), ctx, arg) +} + // GetLastUpdateCheck mocks base method. func (m *MockStore) GetLastUpdateCheck(ctx context.Context) (string, error) { m.ctrl.T.Helper() diff --git a/coderd/database/querier.go b/coderd/database/querier.go index 5d8138694f..e29227e459 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -284,6 +284,7 @@ type sqlcQuerier interface { // param created_at_opt: The created_at timestamp to filter by. This parameter is usd for pagination - it fetches notifications created before the specified timestamp if it is not the zero value // param limit_opt: The limit of notifications to fetch. If the limit is not specified, it defaults to 25 GetInboxNotificationsByUserID(ctx context.Context, arg GetInboxNotificationsByUserIDParams) ([]InboxNotification, error) + GetLastChatMessageByRole(ctx context.Context, arg GetLastChatMessageByRoleParams) (ChatMessage, error) GetLastUpdateCheck(ctx context.Context) (string, error) GetLatestCryptoKeyByFeature(ctx context.Context, feature CryptoKeyFeature) (CryptoKey, error) GetLatestWorkspaceAppStatusByAppID(ctx context.Context, appID uuid.UUID) (WorkspaceAppStatus, error) diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 9632fd21b1..0f6c2db116 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -3536,6 +3536,48 @@ func (q *sqlQuerier) GetChatsByOwnerID(ctx context.Context, arg GetChatsByOwnerI return items, nil } +const getLastChatMessageByRole = `-- name: GetLastChatMessageByRole :one +SELECT + id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed +FROM + chat_messages +WHERE + chat_id = $1::uuid + AND role = $2::text +ORDER BY + created_at DESC, id DESC +LIMIT + 1 +` + +type GetLastChatMessageByRoleParams struct { + ChatID uuid.UUID `db:"chat_id" json:"chat_id"` + Role string `db:"role" json:"role"` +} + +func (q *sqlQuerier) GetLastChatMessageByRole(ctx context.Context, arg GetLastChatMessageByRoleParams) (ChatMessage, error) { + row := q.db.QueryRowContext(ctx, getLastChatMessageByRole, arg.ChatID, arg.Role) + var i ChatMessage + err := row.Scan( + &i.ID, + &i.ChatID, + &i.ModelConfigID, + &i.CreatedAt, + &i.Role, + &i.Content, + &i.Visibility, + &i.InputTokens, + &i.OutputTokens, + &i.TotalTokens, + &i.ReasoningTokens, + &i.CacheCreationTokens, + &i.CacheReadTokens, + &i.ContextLimit, + &i.Compressed, + ) + return i, err +} + const getStaleChats = `-- name: GetStaleChats :many SELECT id, owner_id, workspace_id, title, status, worker_id, started_at, heartbeat_at, created_at, updated_at, parent_chat_id, root_chat_id, last_model_config_id, archived, last_error diff --git a/coderd/database/queries/chats.sql b/coderd/database/queries/chats.sql index 79257ebe58..d4d73fc136 100644 --- a/coderd/database/queries/chats.sql +++ b/coderd/database/queries/chats.sql @@ -433,5 +433,18 @@ WHERE id = ( ) RETURNING *; +-- name: GetLastChatMessageByRole :one +SELECT + * +FROM + chat_messages +WHERE + chat_id = @chat_id::uuid + AND role = @role::text +ORDER BY + created_at DESC, id DESC +LIMIT + 1; + -- name: GetChatByIDForUpdate :one SELECT * FROM chats WHERE id = @id::uuid FOR UPDATE;