diff --git a/coderd/chatd/chatd.go b/coderd/chatd/chatd.go index faacc91b17..0475ef8989 100644 --- a/coderd/chatd/chatd.go +++ b/coderd/chatd/chatd.go @@ -2582,7 +2582,7 @@ func (p *Server) persistChatContextSummary( _, txErr := tx.InsertChatMessage(ctx, database.InsertChatMessageParams{ ChatID: chatID, ModelConfigID: uuid.NullUUID{UUID: modelConfigID, Valid: true}, - Role: string(fantasy.MessageRoleSystem), + Role: string(fantasy.MessageRoleUser), Content: pqtype.NullRawMessage{ RawMessage: systemContent, Valid: len(systemContent) > 0, diff --git a/coderd/chatd/chatloop/compaction_test.go b/coderd/chatd/chatloop/compaction_test.go index 4e3c6df7bd..254dc8b57e 100644 --- a/coderd/chatd/chatloop/compaction_test.go +++ b/coderd/chatd/chatloop/compaction_test.go @@ -600,4 +600,117 @@ func TestRun_Compaction(t *testing.T) { // Two stream calls: one before compaction, one after re-entry. require.Equal(t, 2, streamCallCount) }) + + t.Run("PostRunCompactionReEntryIncludesUserSummary", func(t *testing.T) { + t.Parallel() + + // After compaction the summary is stored as a user-role + // message. When the loop re-enters, the reloaded prompt + // must contain this user message so the LLM provider + // receives a valid prompt (providers like Anthropic + // require at least one non-system message). 
+ + var mu sync.Mutex + var streamCallCount int + var reEntryPrompt []fantasy.Message + persistCompactionCalls := 0 + + const summaryText = "post-run compacted summary" + + model := &loopTestModel{ + provider: "fake", + streamFn: func(_ context.Context, call fantasy.Call) (fantasy.StreamResponse, error) { + mu.Lock() + step := streamCallCount + streamCallCount++ + mu.Unlock() + + switch step { + case 0: + return streamFromParts([]fantasy.StreamPart{ + {Type: fantasy.StreamPartTypeTextStart, ID: "text-1"}, + {Type: fantasy.StreamPartTypeTextDelta, ID: "text-1", Delta: "initial response"}, + {Type: fantasy.StreamPartTypeTextEnd, ID: "text-1"}, + { + Type: fantasy.StreamPartTypeFinish, + FinishReason: fantasy.FinishReasonStop, + Usage: fantasy.Usage{ + InputTokens: 80, + TotalTokens: 85, + }, + }, + }), nil + default: + mu.Lock() + reEntryPrompt = append([]fantasy.Message(nil), call.Prompt...) + mu.Unlock() + return streamFromParts([]fantasy.StreamPart{ + {Type: fantasy.StreamPartTypeTextStart, ID: "text-2"}, + {Type: fantasy.StreamPartTypeTextDelta, ID: "text-2", Delta: "continued"}, + {Type: fantasy.StreamPartTypeTextEnd, ID: "text-2"}, + { + Type: fantasy.StreamPartTypeFinish, + FinishReason: fantasy.FinishReasonStop, + Usage: fantasy.Usage{ + InputTokens: 20, + TotalTokens: 25, + }, + }, + }), nil + } + }, + generateFn: func(_ context.Context, _ fantasy.Call) (*fantasy.Response, error) { + return &fantasy.Response{ + Content: []fantasy.Content{ + fantasy.TextContent{Text: summaryText}, + }, + }, nil + }, + } + + // Simulate real post-compaction DB state: the summary is + // a user-role message (the only non-system content). 
+ compactedMessages := []fantasy.Message{ + textMessage(fantasy.MessageRoleSystem, "system prompt"), + textMessage(fantasy.MessageRoleUser, "Summary of earlier chat context:\n\ncompacted summary"), + } + + err := Run(context.Background(), RunOptions{ + Model: model, + Messages: []fantasy.Message{ + textMessage(fantasy.MessageRoleUser, "hello"), + }, + MaxSteps: 5, + PersistStep: func(_ context.Context, _ PersistedStep) error { + return nil + }, + ContextLimitFallback: 100, + Compaction: &CompactionOptions{ + ThresholdPercent: 70, + SummaryPrompt: "summarize now", + Persist: func(_ context.Context, _ CompactionResult) error { + persistCompactionCalls++ + return nil + }, + }, + ReloadMessages: func(_ context.Context) ([]fantasy.Message, error) { + return compactedMessages, nil + }, + }) + require.NoError(t, err) + + require.GreaterOrEqual(t, persistCompactionCalls, 1) + // Re-entry happened: exactly two stream calls (initial and re-entry). + require.Equal(t, 2, streamCallCount) + // The re-entry prompt must contain the user summary. + require.NotEmpty(t, reEntryPrompt) + hasUser := false + for _, msg := range reEntryPrompt { + if msg.Role == fantasy.MessageRoleUser { + hasUser = true + break + } + } + require.True(t, hasUser, "re-entry prompt must contain a user message (the compaction summary)") + }) } diff --git a/coderd/database/querier_test.go b/coderd/database/querier_test.go index 3b2cda5983..428273bd1d 100644 --- a/coderd/database/querier_test.go +++ b/coderd/database/querier_test.go @@ -8841,3 +8841,202 @@ func TestInsertWorkspaceAgentDevcontainers(t *testing.T) { }) } } + +func TestGetChatMessagesForPromptByChatID(t *testing.T) { + t.Parallel() + + // This test exercises a complex CTE query for prompt + // reconstruction after compaction. It requires Postgres. + db, _ := dbtestutil.NewDB(t) + ctx := context.Background() + + // Shared fixtures: a user, a chat provider, and a chat model config (required FK for chats). 
+ user := dbgen.User(t, db, database.User{}) + + // A chat_providers row is required as a FK for model configs. + _, err := db.InsertChatProvider(ctx, database.InsertChatProviderParams{ + Provider: "openai", + DisplayName: "OpenAI", + APIKey: "test-key", + Enabled: true, + }) + require.NoError(t, err) + + modelCfg, err := db.InsertChatModelConfig(ctx, database.InsertChatModelConfigParams{ + Provider: "openai", + Model: "test-model", + DisplayName: "Test Model", + CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true}, + UpdatedBy: uuid.NullUUID{UUID: user.ID, Valid: true}, + Enabled: true, + IsDefault: true, + ContextLimit: 128000, + CompressionThreshold: 80, + Options: json.RawMessage(`{}`), + }) + require.NoError(t, err) + + newChat := func(t *testing.T) database.Chat { + t.Helper() + chat, err := db.InsertChat(ctx, database.InsertChatParams{ + OwnerID: user.ID, + LastModelConfigID: modelCfg.ID, + Title: "test-chat-" + uuid.NewString(), + }) + require.NoError(t, err) + return chat + } + + insertMsg := func( + t *testing.T, + chatID uuid.UUID, + role string, + vis database.ChatMessageVisibility, + compressed bool, + content string, + ) database.ChatMessage { + t.Helper() + msg, err := db.InsertChatMessage(ctx, database.InsertChatMessageParams{ + ChatID: chatID, + Role: role, + Visibility: vis, + Compressed: sql.NullBool{Bool: compressed, Valid: true}, + Content: pqtype.NullRawMessage{ + RawMessage: json.RawMessage(`"` + content + `"`), + Valid: true, + }, + }) + require.NoError(t, err) + return msg + } + + msgIDs := func(msgs []database.ChatMessage) []int64 { + ids := make([]int64, len(msgs)) + for i, m := range msgs { + ids[i] = m.ID + } + return ids + } + + t.Run("NoCompaction", func(t *testing.T) { + t.Parallel() + chat := newChat(t) + + sys := insertMsg(t, chat.ID, "system", database.ChatMessageVisibilityModel, false, "system prompt") + usr := insertMsg(t, chat.ID, "user", database.ChatMessageVisibilityBoth, false, "hello") + ast := insertMsg(t, chat.ID, 
"assistant", database.ChatMessageVisibilityBoth, false, "hi there") + + got, err := db.GetChatMessagesForPromptByChatID(ctx, chat.ID) + require.NoError(t, err) + require.Equal(t, []int64{sys.ID, usr.ID, ast.ID}, msgIDs(got)) + }) + + t.Run("UserOnlyVisibilityExcluded", func(t *testing.T) { + t.Parallel() + chat := newChat(t) + + // Messages with visibility=user should NOT appear in the + // prompt (they are only for the UI). + insertMsg(t, chat.ID, "system", database.ChatMessageVisibilityModel, false, "system prompt") + insertMsg(t, chat.ID, "user", database.ChatMessageVisibilityUser, false, "user-only msg") + usr := insertMsg(t, chat.ID, "user", database.ChatMessageVisibilityBoth, false, "hello") + + got, err := db.GetChatMessagesForPromptByChatID(ctx, chat.ID) + require.NoError(t, err) + for _, m := range got { + require.NotEqual(t, database.ChatMessageVisibilityUser, m.Visibility, + "visibility=user messages should not appear in the prompt") + } + require.Contains(t, msgIDs(got), usr.ID) + }) + + t.Run("AfterCompaction", func(t *testing.T) { + t.Parallel() + chat := newChat(t) + + // Pre-compaction conversation. + sys := insertMsg(t, chat.ID, "system", database.ChatMessageVisibilityModel, false, "system prompt") + preUser := insertMsg(t, chat.ID, "user", database.ChatMessageVisibilityBoth, false, "old question") + preAsst := insertMsg(t, chat.ID, "assistant", database.ChatMessageVisibilityBoth, false, "old answer") + + // Compaction messages: + // 1. Summary (role=user, visibility=model, compressed=true). + summary := insertMsg(t, chat.ID, "user", database.ChatMessageVisibilityModel, true, "compaction summary") + // 2. Compressed assistant tool-call (visibility=user). + insertMsg(t, chat.ID, "assistant", database.ChatMessageVisibilityUser, true, "tool call") + // 3. Compressed tool result (visibility=both). + insertMsg(t, chat.ID, "tool", database.ChatMessageVisibilityBoth, true, "tool result") + + // Post-compaction messages. 
+ postUser := insertMsg(t, chat.ID, "user", database.ChatMessageVisibilityBoth, false, "new question") + postAsst := insertMsg(t, chat.ID, "assistant", database.ChatMessageVisibilityBoth, false, "new answer") + + got, err := db.GetChatMessagesForPromptByChatID(ctx, chat.ID) + require.NoError(t, err) + + gotIDs := msgIDs(got) + + // Must include: system prompt, summary, post-compaction. + require.Contains(t, gotIDs, sys.ID, "system prompt must be included") + require.Contains(t, gotIDs, summary.ID, "compaction summary must be included") + require.Contains(t, gotIDs, postUser.ID, "post-compaction user msg must be included") + require.Contains(t, gotIDs, postAsst.ID, "post-compaction assistant msg must be included") + + // Must exclude: pre-compaction non-system messages. + require.NotContains(t, gotIDs, preUser.ID, "pre-compaction user msg must be excluded") + require.NotContains(t, gotIDs, preAsst.ID, "pre-compaction assistant msg must be excluded") + + // Verify ordering. + require.Equal(t, []int64{sys.ID, summary.ID, postUser.ID, postAsst.ID}, gotIDs) + }) + + t.Run("AfterCompactionSummaryIsUserRole", func(t *testing.T) { + t.Parallel() + chat := newChat(t) + + // After compaction the summary must appear as role=user so + // that LLM APIs (e.g. Anthropic) see at least one + // non-system message in the prompt. 
+ insertMsg(t, chat.ID, "system", database.ChatMessageVisibilityModel, false, "system prompt") + summary := insertMsg(t, chat.ID, "user", database.ChatMessageVisibilityModel, true, "summary text") + newUsr := insertMsg(t, chat.ID, "user", database.ChatMessageVisibilityBoth, false, "new question") + + got, err := db.GetChatMessagesForPromptByChatID(ctx, chat.ID) + require.NoError(t, err) + + hasNonSystem := false + for _, m := range got { + if m.Role != "system" { + hasNonSystem = true + break + } + } + require.True(t, hasNonSystem, + "prompt must contain at least one non-system message after compaction") + require.Contains(t, msgIDs(got), summary.ID) + require.Contains(t, msgIDs(got), newUsr.ID) + }) + + t.Run("CompressedToolResultNotPickedAsSummary", func(t *testing.T) { + t.Parallel() + chat := newChat(t) + + // The CTE uses visibility='model' (exact match). If it + // used IN ('model','both'), the compressed tool result + // (visibility=both) would be picked as the "summary" + // instead of the actual summary. 
+ insertMsg(t, chat.ID, "system", database.ChatMessageVisibilityModel, false, "system prompt") + summary := insertMsg(t, chat.ID, "user", database.ChatMessageVisibilityModel, true, "real summary") + compressedTool := insertMsg(t, chat.ID, "tool", database.ChatMessageVisibilityBoth, true, "tool result") + postUser := insertMsg(t, chat.ID, "user", database.ChatMessageVisibilityBoth, false, "follow-up") + + got, err := db.GetChatMessagesForPromptByChatID(ctx, chat.ID) + require.NoError(t, err) + + gotIDs := msgIDs(got) + require.Contains(t, gotIDs, summary.ID, "real summary must be included") + require.NotContains(t, gotIDs, compressedTool.ID, + "compressed tool result must not be included") + require.Contains(t, gotIDs, postUser.ID) + }) +} diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 931aa9bea6..5c051af0c7 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -3321,9 +3321,8 @@ WITH latest_compressed_summary AS ( chat_messages WHERE chat_id = $1::uuid - AND role = 'system' - AND visibility IN ('model', 'both') AND compressed = TRUE + AND visibility = 'model' ORDER BY created_at DESC, id DESC diff --git a/coderd/database/queries/chats.sql b/coderd/database/queries/chats.sql index 6a5748181a..c52c89c25c 100644 --- a/coderd/database/queries/chats.sql +++ b/coderd/database/queries/chats.sql @@ -54,9 +54,8 @@ WITH latest_compressed_summary AS ( chat_messages WHERE chat_id = @chat_id::uuid - AND role = 'system' - AND visibility IN ('model', 'both') AND compressed = TRUE + AND visibility = 'model' ORDER BY created_at DESC, id DESC