fix(coderd/x/chatd): preserve chat API key after compaction (#25930) (#25974)

This commit is contained in:
github-actions[bot]
2026-06-02 09:52:58 -04:00
committed by GitHub
parent 26c035d742
commit 2f60b14649
3 changed files with 182 additions and 53 deletions
+20 -14
View File
@@ -6468,22 +6468,14 @@ type runChatResult struct {
HistoryTipMessageID int64 HistoryTipMessageID int64
} }
func contextWithActiveTurnAPIKeyID(ctx context.Context, messages []database.ChatMessage) context.Context {
apiKeyID, ok := activeTurnAPIKeyIDFromMessages(messages)
if !ok {
return ctx
}
return aibridge.WithDelegatedAPIKeyID(ctx, apiKeyID)
}
func activeTurnAPIKeyIDFromMessages(messages []database.ChatMessage) (string, bool) { func activeTurnAPIKeyIDFromMessages(messages []database.ChatMessage) (string, bool) {
for i := len(messages) - 1; i >= 0; i-- { for i := len(messages) - 1; i >= 0; i-- {
message := messages[i] message := messages[i]
if message.Role != database.ChatMessageRoleUser { if message.Role != database.ChatMessageRoleUser {
continue continue
} }
if message.Visibility != database.ChatMessageVisibilityBoth && if !isUserVisibleChatMessage(message) &&
message.Visibility != database.ChatMessageVisibilityUser { !(message.Visibility == database.ChatMessageVisibilityModel && message.Compressed) {
continue continue
} }
if !message.APIKeyID.Valid || message.APIKeyID.String == "" { if !message.APIKeyID.Valid || message.APIKeyID.String == "" {
@@ -6494,6 +6486,11 @@ func activeTurnAPIKeyIDFromMessages(messages []database.ChatMessage) (string, bo
return "", false return "", false
} }
func isUserVisibleChatMessage(message database.ChatMessage) bool {
return message.Visibility == database.ChatMessageVisibilityBoth ||
message.Visibility == database.ChatMessageVisibilityUser
}
func allToolNames(allTools []fantasy.AgentTool) []string { func allToolNames(allTools []fantasy.AgentTool) []string {
toolNames := make([]string, 0, len(allTools)) toolNames := make([]string, 0, len(allTools))
for _, tool := range allTools { for _, tool := range allTools {
@@ -7124,7 +7121,9 @@ func (p *Server) runChat(
return result, xerrors.Errorf("get chat messages: %w", err) return result, xerrors.Errorf("get chat messages: %w", err)
} }
modelOpts := modelBuildOptionsFromMessages(messages) modelOpts := modelBuildOptionsFromMessages(messages)
ctx = contextWithActiveTurnAPIKeyID(ctx, messages) if modelOpts.ActiveAPIKeyID != "" {
ctx = aibridge.WithDelegatedAPIKeyID(ctx, modelOpts.ActiveAPIKeyID)
}
// Load MCP server configs and user tokens in parallel with model // Load MCP server configs and user tokens in parallel with model
// resolution. These queries have no dependencies on each other and all // resolution. These queries have no dependencies on each other and all
@@ -7831,6 +7830,7 @@ func (p *Server) runChat(
persistCtx, persistCtx,
chat.ID, chat.ID,
modelConfig.ID, modelConfig.ID,
modelOpts.ActiveAPIKeyID,
compactionToolCallID, compactionToolCallID,
result, result,
); err != nil { ); err != nil {
@@ -8460,12 +8460,14 @@ func buildProviderTools(options *codersdk.ChatModelProviderOptions) []chatloop.P
return tools return tools
} }
// persistChatContextSummary persists a chat context summary to the database. // persistChatContextSummary is called from the chat loop's compaction
// This is invoked via the chat loop's compaction callback. // callback. activeAPIKeyID is stamped onto the summary user message. When
// empty, it falls back to the delegated key in ctx.
func (p *Server) persistChatContextSummary( func (p *Server) persistChatContextSummary(
ctx context.Context, ctx context.Context,
chatID uuid.UUID, chatID uuid.UUID,
modelConfigID uuid.UUID, modelConfigID uuid.UUID,
activeAPIKeyID string,
toolCallID string, toolCallID string,
result chatloop.CompactionResult, result chatloop.CompactionResult,
) error { ) error {
@@ -8514,6 +8516,11 @@ func (p *Server) persistChatContextSummary(
return xerrors.Errorf("encode summary tool result: %w", err) return xerrors.Errorf("encode summary tool result: %w", err)
} }
summaryAPIKeyID := activeAPIKeyID
if summaryAPIKeyID == "" {
summaryAPIKeyID, _ = aibridge.DelegatedAPIKeyIDFromContext(ctx)
}
var insertedMessages []database.ChatMessage var insertedMessages []database.ChatMessage
txErr := p.db.InTx(func(tx database.Store) error { txErr := p.db.InTx(func(tx database.Store) error {
@@ -8522,7 +8529,6 @@ func (p *Server) persistChatContextSummary(
} }
// Hidden summary user message (not published to subscribers). // Hidden summary user message (not published to subscribers).
summaryAPIKeyID, _ := aibridge.DelegatedAPIKeyIDFromContext(ctx)
summaryUserMsg := newUserChatMessage( summaryUserMsg := newUserChatMessage(
summaryAPIKeyID, summaryAPIKeyID,
systemContent, systemContent,
+52 -31
View File
@@ -6651,42 +6651,63 @@ func TestPersistChatContextSummarySetsAPIKeyID(t *testing.T) {
UserID: user.ID, UserID: user.ID,
}) })
ctx = aibridge.WithDelegatedAPIKeyID(ctx, apiKey.ID)
server := &Server{db: db} server := &Server{db: db}
persistAndAssertSummaryKey := func(
summaryCtx context.Context,
chatID uuid.UUID,
activeAPIKeyID string,
wantAPIKeyID string,
toolCallID string,
) {
t.Helper()
err := server.persistChatContextSummary( err := server.persistChatContextSummary(
ctx, summaryCtx,
chat.ID, chatID,
modelConfig.ID, modelConfig.ID,
"tool-call-id-1", activeAPIKeyID,
chatloop.CompactionResult{ toolCallID,
SystemSummary: "summarized context", chatloop.CompactionResult{
SummaryReport: "context was summarized", SystemSummary: "summarized context",
ThresholdPercent: 70, SummaryReport: "context was summarized",
UsagePercent: 85.0, ThresholdPercent: 70,
ContextTokens: 8500, UsagePercent: 85.0,
ContextLimit: 10000, ContextTokens: 8500,
}, ContextLimit: 10000,
) },
require.NoError(t, err) )
require.NoError(t, err)
msgs, err := db.GetChatMessagesForPromptByChatID(ctx, chat.ID) msgs, err := db.GetChatMessagesForPromptByChatID(ctx, chatID)
require.NoError(t, err) require.NoError(t, err)
// GetChatMessagesForPromptByChatID uses a compaction boundary CTE // GetChatMessagesForPromptByChatID uses a compaction boundary CTE
// that selects compressed=true, visibility='model'. Only the user // that selects compressed=true, visibility='model'. Only the user
// summary qualifies; the assistant (visibility=user) and tool // summary qualifies; the assistant (visibility=user) and tool
// result (visibility=both) are excluded by the CTE filter. // result (visibility=both) are excluded by the CTE filter.
require.NotEmpty(t, msgs) require.NotEmpty(t, msgs)
var foundUserSummary bool var foundUserSummary bool
for _, msg := range msgs { for _, msg := range msgs {
if msg.Role == database.ChatMessageRoleUser { if msg.Role == database.ChatMessageRoleUser {
foundUserSummary = true foundUserSummary = true
require.True(t, msg.APIKeyID.Valid, "summary user message must have APIKeyID set") require.True(t, msg.APIKeyID.Valid, "summary user message must have APIKeyID set")
require.Equal(t, apiKey.ID, msg.APIKeyID.String, "summary user message APIKeyID must match") require.Equal(t, wantAPIKeyID, msg.APIKeyID.String, "summary user message APIKeyID must match")
}
} }
require.True(t, foundUserSummary, "expected to find compressed user summary message")
} }
require.True(t, foundUserSummary, "expected to find compressed user summary message")
persistAndAssertSummaryKey(ctx, chat.ID, apiKey.ID, apiKey.ID, "tool-call-id-1")
fallbackChat := dbgen.Chat(t, db, database.Chat{
OwnerID: user.ID,
OrganizationID: org.ID,
LastModelConfigID: modelConfig.ID,
})
fallbackKey, _ := dbgen.APIKey(t, db, database.APIKey{
UserID: user.ID,
})
fallbackCtx := aibridge.WithDelegatedAPIKeyID(ctx, fallbackKey.ID)
persistAndAssertSummaryKey(fallbackCtx, fallbackChat.ID, "", fallbackKey.ID, "tool-call-id-2")
} }
+110 -8
View File
@@ -405,7 +405,7 @@ func TestActiveTurnAPIKeyIDFromMessages(t *testing.T) {
}, },
}, },
{ {
name: "SkipsModelOnlyUserMessages", name: "SkipsUncompressedModelOnlyUserMessages",
messages: []database.ChatMessage{ messages: []database.ChatMessage{
{ID: 1, Role: database.ChatMessageRoleUser, Visibility: database.ChatMessageVisibilityBoth, APIKeyID: sqlNullString(oldKeyID)}, {ID: 1, Role: database.ChatMessageRoleUser, Visibility: database.ChatMessageVisibilityBoth, APIKeyID: sqlNullString(oldKeyID)},
{ID: 2, Role: database.ChatMessageRoleUser, Visibility: database.ChatMessageVisibilityModel, APIKeyID: sqlNullString(currentKeyID)}, {ID: 2, Role: database.ChatMessageRoleUser, Visibility: database.ChatMessageVisibilityModel, APIKeyID: sqlNullString(currentKeyID)},
@@ -413,6 +413,54 @@ func TestActiveTurnAPIKeyIDFromMessages(t *testing.T) {
wantKey: oldKeyID, wantKey: oldKeyID,
wantOK: true, wantOK: true,
}, },
{
name: "CompressedSummaryFallback",
messages: []database.ChatMessage{
{ID: 1, Role: database.ChatMessageRoleUser, Visibility: database.ChatMessageVisibilityModel, Compressed: true, APIKeyID: sqlNullString(currentKeyID)},
{ID: 2, Role: database.ChatMessageRoleAssistant, Visibility: database.ChatMessageVisibilityBoth},
},
wantKey: currentKeyID,
wantOK: true,
},
{
name: "LatestCompressedSummaryWins",
messages: []database.ChatMessage{
{ID: 1, Role: database.ChatMessageRoleUser, Visibility: database.ChatMessageVisibilityModel, Compressed: true, APIKeyID: sqlNullString(oldKeyID)},
{ID: 2, Role: database.ChatMessageRoleUser, Visibility: database.ChatMessageVisibilityModel, Compressed: true, APIKeyID: sqlNullString(currentKeyID)},
{ID: 3, Role: database.ChatMessageRoleAssistant, Visibility: database.ChatMessageVisibilityBoth},
},
wantKey: currentKeyID,
wantOK: true,
},
{
name: "VisibleUserWinsOverCompressedSummary",
messages: []database.ChatMessage{
{ID: 1, Role: database.ChatMessageRoleUser, Visibility: database.ChatMessageVisibilityModel, Compressed: true, APIKeyID: sqlNullString(oldKeyID)},
{ID: 2, Role: database.ChatMessageRoleUser, Visibility: database.ChatMessageVisibilityBoth, APIKeyID: sqlNullString(currentKeyID)},
},
wantKey: currentKeyID,
wantOK: true,
},
{
name: "MissingVisibleUserKeyDoesNotFallBackToCompressedSummary",
messages: []database.ChatMessage{
{ID: 1, Role: database.ChatMessageRoleUser, Visibility: database.ChatMessageVisibilityModel, Compressed: true, APIKeyID: sqlNullString(oldKeyID)},
{ID: 2, Role: database.ChatMessageRoleUser, Visibility: database.ChatMessageVisibilityBoth},
},
},
{
name: "UncompressedModelOnlyUserIgnored",
messages: []database.ChatMessage{
{ID: 1, Role: database.ChatMessageRoleUser, Visibility: database.ChatMessageVisibilityModel, APIKeyID: sqlNullString(currentKeyID)},
},
},
{
name: "CompressedSummaryMissingKeyDoesNotFallBack",
messages: []database.ChatMessage{
{ID: 1, Role: database.ChatMessageRoleUser, Visibility: database.ChatMessageVisibilityBoth, APIKeyID: sqlNullString(oldKeyID)},
{ID: 2, Role: database.ChatMessageRoleUser, Visibility: database.ChatMessageVisibilityModel, Compressed: true},
},
},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@@ -421,15 +469,11 @@ func TestActiveTurnAPIKeyIDFromMessages(t *testing.T) {
gotKey, gotOK := activeTurnAPIKeyIDFromMessages(tt.messages) gotKey, gotOK := activeTurnAPIKeyIDFromMessages(tt.messages)
require.Equal(t, tt.wantOK, gotOK) require.Equal(t, tt.wantOK, gotOK)
require.Equal(t, tt.wantKey, gotKey) require.Equal(t, tt.wantKey, gotKey)
ctx := contextWithActiveTurnAPIKeyID(t.Context(), tt.messages)
ctxKey, ctxOK := aibridge.DelegatedAPIKeyIDFromContext(ctx)
require.Equal(t, tt.wantOK, ctxOK)
require.Equal(t, tt.wantKey, ctxKey)
}) })
} }
} }
func TestActiveTurnContextUsesPromptMessages(t *testing.T) { func TestPromptMessagesForVisibleUserPreserveActiveAPIKeyID(t *testing.T) {
t.Parallel() t.Parallel()
db, _ := dbtestutil.NewDB(t) db, _ := dbtestutil.NewDB(t)
@@ -477,12 +521,70 @@ func TestActiveTurnContextUsesPromptMessages(t *testing.T) {
messages, err := db.GetChatMessagesForPromptByChatID(ctx, chat.ID) messages, err := db.GetChatMessagesForPromptByChatID(ctx, chat.ID)
require.NoError(t, err) require.NoError(t, err)
ctx = contextWithActiveTurnAPIKeyID(ctx, messages) gotKey, ok := activeTurnAPIKeyIDFromMessages(messages)
gotKey, ok := aibridge.DelegatedAPIKeyIDFromContext(ctx)
require.True(t, ok) require.True(t, ok)
require.Equal(t, currentKey.ID, gotKey) require.Equal(t, currentKey.ID, gotKey)
} }
func TestPromptMessagesForCompactedChatPreserveActiveAPIKeyID(t *testing.T) {
t.Parallel()
db, _ := dbtestutil.NewDB(t)
ctx := t.Context()
user := dbgen.User(t, db, database.User{})
org := dbgen.Organization(t, db, database.Organization{})
model := dbgen.ChatModelConfig(t, db, database.ChatModelConfig{})
chat := dbgen.Chat(t, db, database.Chat{OrganizationID: org.ID, OwnerID: user.ID, LastModelConfigID: model.ID})
key, _ := dbgen.APIKey(t, db, database.APIKey{UserID: user.ID})
visibleUser := dbgen.ChatMessage(t, db, database.ChatMessage{
ChatID: chat.ID,
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
ModelConfigID: uuid.NullUUID{UUID: model.ID, Valid: true},
Role: database.ChatMessageRoleUser,
Visibility: database.ChatMessageVisibilityBoth,
APIKeyID: sqlNullString(key.ID),
})
dbgen.ChatMessage(t, db, database.ChatMessage{
ChatID: chat.ID,
ModelConfigID: uuid.NullUUID{UUID: model.ID, Valid: true},
Role: database.ChatMessageRoleAssistant,
Visibility: database.ChatMessageVisibilityBoth,
})
compressedSummary := dbgen.ChatMessage(t, db, database.ChatMessage{
ChatID: chat.ID,
ModelConfigID: uuid.NullUUID{UUID: model.ID, Valid: true},
Role: database.ChatMessageRoleUser,
Visibility: database.ChatMessageVisibilityModel,
Compressed: true,
APIKeyID: sqlNullString(key.ID),
})
afterSummary := dbgen.ChatMessage(t, db, database.ChatMessage{
ChatID: chat.ID,
ModelConfigID: uuid.NullUUID{UUID: model.ID, Valid: true},
Role: database.ChatMessageRoleAssistant,
Visibility: database.ChatMessageVisibilityBoth,
})
messages, err := db.GetChatMessagesForPromptByChatID(ctx, chat.ID)
require.NoError(t, err)
ids := make(map[int64]struct{}, len(messages))
for _, message := range messages {
ids[message.ID] = struct{}{}
}
_, hasVisibleUser := ids[visibleUser.ID]
require.False(t, hasVisibleUser)
_, hasSummary := ids[compressedSummary.ID]
require.True(t, hasSummary)
_, hasAfterSummary := ids[afterSummary.ID]
require.True(t, hasAfterSummary)
gotKey, ok := activeTurnAPIKeyIDFromMessages(messages)
require.True(t, ok)
require.Equal(t, key.ID, gotKey)
}
func sqlNullString(value string) sql.NullString { func sqlNullString(value string) sql.NullString {
return sql.NullString{String: value, Valid: value != ""} return sql.NullString{String: value, Valid: value != ""}
} }