fix(coderd/x/chatd): preserve chat API key after compaction (#25930) (#25974)

This commit is contained in:
github-actions[bot]
2026-06-02 09:52:58 -04:00
committed by GitHub
parent 26c035d742
commit 2f60b14649
3 changed files with 182 additions and 53 deletions
+20 -14
View File
@@ -6468,22 +6468,14 @@ type runChatResult struct {
HistoryTipMessageID int64
}
func contextWithActiveTurnAPIKeyID(ctx context.Context, messages []database.ChatMessage) context.Context {
apiKeyID, ok := activeTurnAPIKeyIDFromMessages(messages)
if !ok {
return ctx
}
return aibridge.WithDelegatedAPIKeyID(ctx, apiKeyID)
}
func activeTurnAPIKeyIDFromMessages(messages []database.ChatMessage) (string, bool) {
for i := len(messages) - 1; i >= 0; i-- {
message := messages[i]
if message.Role != database.ChatMessageRoleUser {
continue
}
if message.Visibility != database.ChatMessageVisibilityBoth &&
message.Visibility != database.ChatMessageVisibilityUser {
if !isUserVisibleChatMessage(message) &&
!(message.Visibility == database.ChatMessageVisibilityModel && message.Compressed) {
continue
}
if !message.APIKeyID.Valid || message.APIKeyID.String == "" {
@@ -6494,6 +6486,11 @@ func activeTurnAPIKeyIDFromMessages(messages []database.ChatMessage) (string, bo
return "", false
}
func isUserVisibleChatMessage(message database.ChatMessage) bool {
return message.Visibility == database.ChatMessageVisibilityBoth ||
message.Visibility == database.ChatMessageVisibilityUser
}
func allToolNames(allTools []fantasy.AgentTool) []string {
toolNames := make([]string, 0, len(allTools))
for _, tool := range allTools {
@@ -7124,7 +7121,9 @@ func (p *Server) runChat(
return result, xerrors.Errorf("get chat messages: %w", err)
}
modelOpts := modelBuildOptionsFromMessages(messages)
ctx = contextWithActiveTurnAPIKeyID(ctx, messages)
if modelOpts.ActiveAPIKeyID != "" {
ctx = aibridge.WithDelegatedAPIKeyID(ctx, modelOpts.ActiveAPIKeyID)
}
// Load MCP server configs and user tokens in parallel with model
// resolution. These queries have no dependencies on each other and all
@@ -7831,6 +7830,7 @@ func (p *Server) runChat(
persistCtx,
chat.ID,
modelConfig.ID,
modelOpts.ActiveAPIKeyID,
compactionToolCallID,
result,
); err != nil {
@@ -8460,12 +8460,14 @@ func buildProviderTools(options *codersdk.ChatModelProviderOptions) []chatloop.P
return tools
}
// persistChatContextSummary persists a chat context summary to the database.
// This is invoked via the chat loop's compaction callback.
// persistChatContextSummary is called from the chat loop's compaction
// callback. activeAPIKeyID is stamped onto the summary user message. When
// empty, it falls back to the delegated key in ctx.
func (p *Server) persistChatContextSummary(
ctx context.Context,
chatID uuid.UUID,
modelConfigID uuid.UUID,
activeAPIKeyID string,
toolCallID string,
result chatloop.CompactionResult,
) error {
@@ -8514,6 +8516,11 @@ func (p *Server) persistChatContextSummary(
return xerrors.Errorf("encode summary tool result: %w", err)
}
summaryAPIKeyID := activeAPIKeyID
if summaryAPIKeyID == "" {
summaryAPIKeyID, _ = aibridge.DelegatedAPIKeyIDFromContext(ctx)
}
var insertedMessages []database.ChatMessage
txErr := p.db.InTx(func(tx database.Store) error {
@@ -8522,7 +8529,6 @@ func (p *Server) persistChatContextSummary(
}
// Hidden summary user message (not published to subscribers).
summaryAPIKeyID, _ := aibridge.DelegatedAPIKeyIDFromContext(ctx)
summaryUserMsg := newUserChatMessage(
summaryAPIKeyID,
systemContent,
+52 -31
View File
@@ -6651,42 +6651,63 @@ func TestPersistChatContextSummarySetsAPIKeyID(t *testing.T) {
UserID: user.ID,
})
ctx = aibridge.WithDelegatedAPIKeyID(ctx, apiKey.ID)
server := &Server{db: db}
persistAndAssertSummaryKey := func(
summaryCtx context.Context,
chatID uuid.UUID,
activeAPIKeyID string,
wantAPIKeyID string,
toolCallID string,
) {
t.Helper()
err := server.persistChatContextSummary(
ctx,
chat.ID,
modelConfig.ID,
"tool-call-id-1",
chatloop.CompactionResult{
SystemSummary: "summarized context",
SummaryReport: "context was summarized",
ThresholdPercent: 70,
UsagePercent: 85.0,
ContextTokens: 8500,
ContextLimit: 10000,
},
)
require.NoError(t, err)
err := server.persistChatContextSummary(
summaryCtx,
chatID,
modelConfig.ID,
activeAPIKeyID,
toolCallID,
chatloop.CompactionResult{
SystemSummary: "summarized context",
SummaryReport: "context was summarized",
ThresholdPercent: 70,
UsagePercent: 85.0,
ContextTokens: 8500,
ContextLimit: 10000,
},
)
require.NoError(t, err)
msgs, err := db.GetChatMessagesForPromptByChatID(ctx, chat.ID)
require.NoError(t, err)
msgs, err := db.GetChatMessagesForPromptByChatID(ctx, chatID)
require.NoError(t, err)
// GetChatMessagesForPromptByChatID uses a compaction boundary CTE
// that selects compressed=true, visibility='model'. Only the user
// summary qualifies; the assistant (visibility=user) and tool
// result (visibility=both) are excluded by the CTE filter.
require.NotEmpty(t, msgs)
// GetChatMessagesForPromptByChatID uses a compaction boundary CTE
// that selects compressed=true, visibility='model'. Only the user
// summary qualifies; the assistant (visibility=user) and tool
// result (visibility=both) are excluded by the CTE filter.
require.NotEmpty(t, msgs)
var foundUserSummary bool
for _, msg := range msgs {
if msg.Role == database.ChatMessageRoleUser {
foundUserSummary = true
require.True(t, msg.APIKeyID.Valid, "summary user message must have APIKeyID set")
require.Equal(t, apiKey.ID, msg.APIKeyID.String, "summary user message APIKeyID must match")
var foundUserSummary bool
for _, msg := range msgs {
if msg.Role == database.ChatMessageRoleUser {
foundUserSummary = true
require.True(t, msg.APIKeyID.Valid, "summary user message must have APIKeyID set")
require.Equal(t, wantAPIKeyID, msg.APIKeyID.String, "summary user message APIKeyID must match")
}
}
require.True(t, foundUserSummary, "expected to find compressed user summary message")
}
require.True(t, foundUserSummary, "expected to find compressed user summary message")
persistAndAssertSummaryKey(ctx, chat.ID, apiKey.ID, apiKey.ID, "tool-call-id-1")
fallbackChat := dbgen.Chat(t, db, database.Chat{
OwnerID: user.ID,
OrganizationID: org.ID,
LastModelConfigID: modelConfig.ID,
})
fallbackKey, _ := dbgen.APIKey(t, db, database.APIKey{
UserID: user.ID,
})
fallbackCtx := aibridge.WithDelegatedAPIKeyID(ctx, fallbackKey.ID)
persistAndAssertSummaryKey(fallbackCtx, fallbackChat.ID, "", fallbackKey.ID, "tool-call-id-2")
}
+110 -8
View File
@@ -405,7 +405,7 @@ func TestActiveTurnAPIKeyIDFromMessages(t *testing.T) {
},
},
{
name: "SkipsModelOnlyUserMessages",
name: "SkipsUncompressedModelOnlyUserMessages",
messages: []database.ChatMessage{
{ID: 1, Role: database.ChatMessageRoleUser, Visibility: database.ChatMessageVisibilityBoth, APIKeyID: sqlNullString(oldKeyID)},
{ID: 2, Role: database.ChatMessageRoleUser, Visibility: database.ChatMessageVisibilityModel, APIKeyID: sqlNullString(currentKeyID)},
@@ -413,6 +413,54 @@ func TestActiveTurnAPIKeyIDFromMessages(t *testing.T) {
wantKey: oldKeyID,
wantOK: true,
},
{
name: "CompressedSummaryFallback",
messages: []database.ChatMessage{
{ID: 1, Role: database.ChatMessageRoleUser, Visibility: database.ChatMessageVisibilityModel, Compressed: true, APIKeyID: sqlNullString(currentKeyID)},
{ID: 2, Role: database.ChatMessageRoleAssistant, Visibility: database.ChatMessageVisibilityBoth},
},
wantKey: currentKeyID,
wantOK: true,
},
{
name: "LatestCompressedSummaryWins",
messages: []database.ChatMessage{
{ID: 1, Role: database.ChatMessageRoleUser, Visibility: database.ChatMessageVisibilityModel, Compressed: true, APIKeyID: sqlNullString(oldKeyID)},
{ID: 2, Role: database.ChatMessageRoleUser, Visibility: database.ChatMessageVisibilityModel, Compressed: true, APIKeyID: sqlNullString(currentKeyID)},
{ID: 3, Role: database.ChatMessageRoleAssistant, Visibility: database.ChatMessageVisibilityBoth},
},
wantKey: currentKeyID,
wantOK: true,
},
{
name: "VisibleUserWinsOverCompressedSummary",
messages: []database.ChatMessage{
{ID: 1, Role: database.ChatMessageRoleUser, Visibility: database.ChatMessageVisibilityModel, Compressed: true, APIKeyID: sqlNullString(oldKeyID)},
{ID: 2, Role: database.ChatMessageRoleUser, Visibility: database.ChatMessageVisibilityBoth, APIKeyID: sqlNullString(currentKeyID)},
},
wantKey: currentKeyID,
wantOK: true,
},
{
name: "MissingVisibleUserKeyDoesNotFallBackToCompressedSummary",
messages: []database.ChatMessage{
{ID: 1, Role: database.ChatMessageRoleUser, Visibility: database.ChatMessageVisibilityModel, Compressed: true, APIKeyID: sqlNullString(oldKeyID)},
{ID: 2, Role: database.ChatMessageRoleUser, Visibility: database.ChatMessageVisibilityBoth},
},
},
{
name: "UncompressedModelOnlyUserIgnored",
messages: []database.ChatMessage{
{ID: 1, Role: database.ChatMessageRoleUser, Visibility: database.ChatMessageVisibilityModel, APIKeyID: sqlNullString(currentKeyID)},
},
},
{
name: "CompressedSummaryMissingKeyDoesNotFallBack",
messages: []database.ChatMessage{
{ID: 1, Role: database.ChatMessageRoleUser, Visibility: database.ChatMessageVisibilityBoth, APIKeyID: sqlNullString(oldKeyID)},
{ID: 2, Role: database.ChatMessageRoleUser, Visibility: database.ChatMessageVisibilityModel, Compressed: true},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@@ -421,15 +469,11 @@ func TestActiveTurnAPIKeyIDFromMessages(t *testing.T) {
gotKey, gotOK := activeTurnAPIKeyIDFromMessages(tt.messages)
require.Equal(t, tt.wantOK, gotOK)
require.Equal(t, tt.wantKey, gotKey)
ctx := contextWithActiveTurnAPIKeyID(t.Context(), tt.messages)
ctxKey, ctxOK := aibridge.DelegatedAPIKeyIDFromContext(ctx)
require.Equal(t, tt.wantOK, ctxOK)
require.Equal(t, tt.wantKey, ctxKey)
})
}
}
func TestActiveTurnContextUsesPromptMessages(t *testing.T) {
func TestPromptMessagesForVisibleUserPreserveActiveAPIKeyID(t *testing.T) {
t.Parallel()
db, _ := dbtestutil.NewDB(t)
@@ -477,12 +521,70 @@ func TestActiveTurnContextUsesPromptMessages(t *testing.T) {
messages, err := db.GetChatMessagesForPromptByChatID(ctx, chat.ID)
require.NoError(t, err)
ctx = contextWithActiveTurnAPIKeyID(ctx, messages)
gotKey, ok := aibridge.DelegatedAPIKeyIDFromContext(ctx)
gotKey, ok := activeTurnAPIKeyIDFromMessages(messages)
require.True(t, ok)
require.Equal(t, currentKey.ID, gotKey)
}
func TestPromptMessagesForCompactedChatPreserveActiveAPIKeyID(t *testing.T) {
t.Parallel()
db, _ := dbtestutil.NewDB(t)
ctx := t.Context()
user := dbgen.User(t, db, database.User{})
org := dbgen.Organization(t, db, database.Organization{})
model := dbgen.ChatModelConfig(t, db, database.ChatModelConfig{})
chat := dbgen.Chat(t, db, database.Chat{OrganizationID: org.ID, OwnerID: user.ID, LastModelConfigID: model.ID})
key, _ := dbgen.APIKey(t, db, database.APIKey{UserID: user.ID})
visibleUser := dbgen.ChatMessage(t, db, database.ChatMessage{
ChatID: chat.ID,
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
ModelConfigID: uuid.NullUUID{UUID: model.ID, Valid: true},
Role: database.ChatMessageRoleUser,
Visibility: database.ChatMessageVisibilityBoth,
APIKeyID: sqlNullString(key.ID),
})
dbgen.ChatMessage(t, db, database.ChatMessage{
ChatID: chat.ID,
ModelConfigID: uuid.NullUUID{UUID: model.ID, Valid: true},
Role: database.ChatMessageRoleAssistant,
Visibility: database.ChatMessageVisibilityBoth,
})
compressedSummary := dbgen.ChatMessage(t, db, database.ChatMessage{
ChatID: chat.ID,
ModelConfigID: uuid.NullUUID{UUID: model.ID, Valid: true},
Role: database.ChatMessageRoleUser,
Visibility: database.ChatMessageVisibilityModel,
Compressed: true,
APIKeyID: sqlNullString(key.ID),
})
afterSummary := dbgen.ChatMessage(t, db, database.ChatMessage{
ChatID: chat.ID,
ModelConfigID: uuid.NullUUID{UUID: model.ID, Valid: true},
Role: database.ChatMessageRoleAssistant,
Visibility: database.ChatMessageVisibilityBoth,
})
messages, err := db.GetChatMessagesForPromptByChatID(ctx, chat.ID)
require.NoError(t, err)
ids := make(map[int64]struct{}, len(messages))
for _, message := range messages {
ids[message.ID] = struct{}{}
}
_, hasVisibleUser := ids[visibleUser.ID]
require.False(t, hasVisibleUser)
_, hasSummary := ids[compressedSummary.ID]
require.True(t, hasSummary)
_, hasAfterSummary := ids[afterSummary.ID]
require.True(t, hasAfterSummary)
gotKey, ok := activeTurnAPIKeyIDFromMessages(messages)
require.True(t, ok)
require.Equal(t, key.ID, gotKey)
}
func sqlNullString(value string) sql.NullString {
return sql.NullString{String: value, Valid: value != ""}
}