mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
This commit is contained in:
committed by
GitHub
parent
26c035d742
commit
2f60b14649
+20
-14
@@ -6468,22 +6468,14 @@ type runChatResult struct {
|
||||
HistoryTipMessageID int64
|
||||
}
|
||||
|
||||
func contextWithActiveTurnAPIKeyID(ctx context.Context, messages []database.ChatMessage) context.Context {
|
||||
apiKeyID, ok := activeTurnAPIKeyIDFromMessages(messages)
|
||||
if !ok {
|
||||
return ctx
|
||||
}
|
||||
return aibridge.WithDelegatedAPIKeyID(ctx, apiKeyID)
|
||||
}
|
||||
|
||||
func activeTurnAPIKeyIDFromMessages(messages []database.ChatMessage) (string, bool) {
|
||||
for i := len(messages) - 1; i >= 0; i-- {
|
||||
message := messages[i]
|
||||
if message.Role != database.ChatMessageRoleUser {
|
||||
continue
|
||||
}
|
||||
if message.Visibility != database.ChatMessageVisibilityBoth &&
|
||||
message.Visibility != database.ChatMessageVisibilityUser {
|
||||
if !isUserVisibleChatMessage(message) &&
|
||||
!(message.Visibility == database.ChatMessageVisibilityModel && message.Compressed) {
|
||||
continue
|
||||
}
|
||||
if !message.APIKeyID.Valid || message.APIKeyID.String == "" {
|
||||
@@ -6494,6 +6486,11 @@ func activeTurnAPIKeyIDFromMessages(messages []database.ChatMessage) (string, bo
|
||||
return "", false
|
||||
}
|
||||
|
||||
func isUserVisibleChatMessage(message database.ChatMessage) bool {
|
||||
return message.Visibility == database.ChatMessageVisibilityBoth ||
|
||||
message.Visibility == database.ChatMessageVisibilityUser
|
||||
}
|
||||
|
||||
func allToolNames(allTools []fantasy.AgentTool) []string {
|
||||
toolNames := make([]string, 0, len(allTools))
|
||||
for _, tool := range allTools {
|
||||
@@ -7124,7 +7121,9 @@ func (p *Server) runChat(
|
||||
return result, xerrors.Errorf("get chat messages: %w", err)
|
||||
}
|
||||
modelOpts := modelBuildOptionsFromMessages(messages)
|
||||
ctx = contextWithActiveTurnAPIKeyID(ctx, messages)
|
||||
if modelOpts.ActiveAPIKeyID != "" {
|
||||
ctx = aibridge.WithDelegatedAPIKeyID(ctx, modelOpts.ActiveAPIKeyID)
|
||||
}
|
||||
|
||||
// Load MCP server configs and user tokens in parallel with model
|
||||
// resolution. These queries have no dependencies on each other and all
|
||||
@@ -7831,6 +7830,7 @@ func (p *Server) runChat(
|
||||
persistCtx,
|
||||
chat.ID,
|
||||
modelConfig.ID,
|
||||
modelOpts.ActiveAPIKeyID,
|
||||
compactionToolCallID,
|
||||
result,
|
||||
); err != nil {
|
||||
@@ -8460,12 +8460,14 @@ func buildProviderTools(options *codersdk.ChatModelProviderOptions) []chatloop.P
|
||||
return tools
|
||||
}
|
||||
|
||||
// persistChatContextSummary persists a chat context summary to the database.
|
||||
// This is invoked via the chat loop's compaction callback.
|
||||
// persistChatContextSummary is called from the chat loop's compaction
|
||||
// callback. activeAPIKeyID is stamped onto the summary user message. When
|
||||
// empty, it falls back to the delegated key in ctx.
|
||||
func (p *Server) persistChatContextSummary(
|
||||
ctx context.Context,
|
||||
chatID uuid.UUID,
|
||||
modelConfigID uuid.UUID,
|
||||
activeAPIKeyID string,
|
||||
toolCallID string,
|
||||
result chatloop.CompactionResult,
|
||||
) error {
|
||||
@@ -8514,6 +8516,11 @@ func (p *Server) persistChatContextSummary(
|
||||
return xerrors.Errorf("encode summary tool result: %w", err)
|
||||
}
|
||||
|
||||
summaryAPIKeyID := activeAPIKeyID
|
||||
if summaryAPIKeyID == "" {
|
||||
summaryAPIKeyID, _ = aibridge.DelegatedAPIKeyIDFromContext(ctx)
|
||||
}
|
||||
|
||||
var insertedMessages []database.ChatMessage
|
||||
|
||||
txErr := p.db.InTx(func(tx database.Store) error {
|
||||
@@ -8522,7 +8529,6 @@ func (p *Server) persistChatContextSummary(
|
||||
}
|
||||
|
||||
// Hidden summary user message (not published to subscribers).
|
||||
summaryAPIKeyID, _ := aibridge.DelegatedAPIKeyIDFromContext(ctx)
|
||||
summaryUserMsg := newUserChatMessage(
|
||||
summaryAPIKeyID,
|
||||
systemContent,
|
||||
|
||||
@@ -6651,42 +6651,63 @@ func TestPersistChatContextSummarySetsAPIKeyID(t *testing.T) {
|
||||
UserID: user.ID,
|
||||
})
|
||||
|
||||
ctx = aibridge.WithDelegatedAPIKeyID(ctx, apiKey.ID)
|
||||
|
||||
server := &Server{db: db}
|
||||
persistAndAssertSummaryKey := func(
|
||||
summaryCtx context.Context,
|
||||
chatID uuid.UUID,
|
||||
activeAPIKeyID string,
|
||||
wantAPIKeyID string,
|
||||
toolCallID string,
|
||||
) {
|
||||
t.Helper()
|
||||
|
||||
err := server.persistChatContextSummary(
|
||||
ctx,
|
||||
chat.ID,
|
||||
modelConfig.ID,
|
||||
"tool-call-id-1",
|
||||
chatloop.CompactionResult{
|
||||
SystemSummary: "summarized context",
|
||||
SummaryReport: "context was summarized",
|
||||
ThresholdPercent: 70,
|
||||
UsagePercent: 85.0,
|
||||
ContextTokens: 8500,
|
||||
ContextLimit: 10000,
|
||||
},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
err := server.persistChatContextSummary(
|
||||
summaryCtx,
|
||||
chatID,
|
||||
modelConfig.ID,
|
||||
activeAPIKeyID,
|
||||
toolCallID,
|
||||
chatloop.CompactionResult{
|
||||
SystemSummary: "summarized context",
|
||||
SummaryReport: "context was summarized",
|
||||
ThresholdPercent: 70,
|
||||
UsagePercent: 85.0,
|
||||
ContextTokens: 8500,
|
||||
ContextLimit: 10000,
|
||||
},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
msgs, err := db.GetChatMessagesForPromptByChatID(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
msgs, err := db.GetChatMessagesForPromptByChatID(ctx, chatID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// GetChatMessagesForPromptByChatID uses a compaction boundary CTE
|
||||
// that selects compressed=true, visibility='model'. Only the user
|
||||
// summary qualifies; the assistant (visibility=user) and tool
|
||||
// result (visibility=both) are excluded by the CTE filter.
|
||||
require.NotEmpty(t, msgs)
|
||||
// GetChatMessagesForPromptByChatID uses a compaction boundary CTE
|
||||
// that selects compressed=true, visibility='model'. Only the user
|
||||
// summary qualifies; the assistant (visibility=user) and tool
|
||||
// result (visibility=both) are excluded by the CTE filter.
|
||||
require.NotEmpty(t, msgs)
|
||||
|
||||
var foundUserSummary bool
|
||||
for _, msg := range msgs {
|
||||
if msg.Role == database.ChatMessageRoleUser {
|
||||
foundUserSummary = true
|
||||
require.True(t, msg.APIKeyID.Valid, "summary user message must have APIKeyID set")
|
||||
require.Equal(t, apiKey.ID, msg.APIKeyID.String, "summary user message APIKeyID must match")
|
||||
var foundUserSummary bool
|
||||
for _, msg := range msgs {
|
||||
if msg.Role == database.ChatMessageRoleUser {
|
||||
foundUserSummary = true
|
||||
require.True(t, msg.APIKeyID.Valid, "summary user message must have APIKeyID set")
|
||||
require.Equal(t, wantAPIKeyID, msg.APIKeyID.String, "summary user message APIKeyID must match")
|
||||
}
|
||||
}
|
||||
require.True(t, foundUserSummary, "expected to find compressed user summary message")
|
||||
}
|
||||
require.True(t, foundUserSummary, "expected to find compressed user summary message")
|
||||
|
||||
persistAndAssertSummaryKey(ctx, chat.ID, apiKey.ID, apiKey.ID, "tool-call-id-1")
|
||||
|
||||
fallbackChat := dbgen.Chat(t, db, database.Chat{
|
||||
OwnerID: user.ID,
|
||||
OrganizationID: org.ID,
|
||||
LastModelConfigID: modelConfig.ID,
|
||||
})
|
||||
fallbackKey, _ := dbgen.APIKey(t, db, database.APIKey{
|
||||
UserID: user.ID,
|
||||
})
|
||||
fallbackCtx := aibridge.WithDelegatedAPIKeyID(ctx, fallbackKey.ID)
|
||||
persistAndAssertSummaryKey(fallbackCtx, fallbackChat.ID, "", fallbackKey.ID, "tool-call-id-2")
|
||||
}
|
||||
|
||||
@@ -405,7 +405,7 @@ func TestActiveTurnAPIKeyIDFromMessages(t *testing.T) {
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "SkipsModelOnlyUserMessages",
|
||||
name: "SkipsUncompressedModelOnlyUserMessages",
|
||||
messages: []database.ChatMessage{
|
||||
{ID: 1, Role: database.ChatMessageRoleUser, Visibility: database.ChatMessageVisibilityBoth, APIKeyID: sqlNullString(oldKeyID)},
|
||||
{ID: 2, Role: database.ChatMessageRoleUser, Visibility: database.ChatMessageVisibilityModel, APIKeyID: sqlNullString(currentKeyID)},
|
||||
@@ -413,6 +413,54 @@ func TestActiveTurnAPIKeyIDFromMessages(t *testing.T) {
|
||||
wantKey: oldKeyID,
|
||||
wantOK: true,
|
||||
},
|
||||
{
|
||||
name: "CompressedSummaryFallback",
|
||||
messages: []database.ChatMessage{
|
||||
{ID: 1, Role: database.ChatMessageRoleUser, Visibility: database.ChatMessageVisibilityModel, Compressed: true, APIKeyID: sqlNullString(currentKeyID)},
|
||||
{ID: 2, Role: database.ChatMessageRoleAssistant, Visibility: database.ChatMessageVisibilityBoth},
|
||||
},
|
||||
wantKey: currentKeyID,
|
||||
wantOK: true,
|
||||
},
|
||||
{
|
||||
name: "LatestCompressedSummaryWins",
|
||||
messages: []database.ChatMessage{
|
||||
{ID: 1, Role: database.ChatMessageRoleUser, Visibility: database.ChatMessageVisibilityModel, Compressed: true, APIKeyID: sqlNullString(oldKeyID)},
|
||||
{ID: 2, Role: database.ChatMessageRoleUser, Visibility: database.ChatMessageVisibilityModel, Compressed: true, APIKeyID: sqlNullString(currentKeyID)},
|
||||
{ID: 3, Role: database.ChatMessageRoleAssistant, Visibility: database.ChatMessageVisibilityBoth},
|
||||
},
|
||||
wantKey: currentKeyID,
|
||||
wantOK: true,
|
||||
},
|
||||
{
|
||||
name: "VisibleUserWinsOverCompressedSummary",
|
||||
messages: []database.ChatMessage{
|
||||
{ID: 1, Role: database.ChatMessageRoleUser, Visibility: database.ChatMessageVisibilityModel, Compressed: true, APIKeyID: sqlNullString(oldKeyID)},
|
||||
{ID: 2, Role: database.ChatMessageRoleUser, Visibility: database.ChatMessageVisibilityBoth, APIKeyID: sqlNullString(currentKeyID)},
|
||||
},
|
||||
wantKey: currentKeyID,
|
||||
wantOK: true,
|
||||
},
|
||||
{
|
||||
name: "MissingVisibleUserKeyDoesNotFallBackToCompressedSummary",
|
||||
messages: []database.ChatMessage{
|
||||
{ID: 1, Role: database.ChatMessageRoleUser, Visibility: database.ChatMessageVisibilityModel, Compressed: true, APIKeyID: sqlNullString(oldKeyID)},
|
||||
{ID: 2, Role: database.ChatMessageRoleUser, Visibility: database.ChatMessageVisibilityBoth},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "UncompressedModelOnlyUserIgnored",
|
||||
messages: []database.ChatMessage{
|
||||
{ID: 1, Role: database.ChatMessageRoleUser, Visibility: database.ChatMessageVisibilityModel, APIKeyID: sqlNullString(currentKeyID)},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "CompressedSummaryMissingKeyDoesNotFallBack",
|
||||
messages: []database.ChatMessage{
|
||||
{ID: 1, Role: database.ChatMessageRoleUser, Visibility: database.ChatMessageVisibilityBoth, APIKeyID: sqlNullString(oldKeyID)},
|
||||
{ID: 2, Role: database.ChatMessageRoleUser, Visibility: database.ChatMessageVisibilityModel, Compressed: true},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
@@ -421,15 +469,11 @@ func TestActiveTurnAPIKeyIDFromMessages(t *testing.T) {
|
||||
gotKey, gotOK := activeTurnAPIKeyIDFromMessages(tt.messages)
|
||||
require.Equal(t, tt.wantOK, gotOK)
|
||||
require.Equal(t, tt.wantKey, gotKey)
|
||||
ctx := contextWithActiveTurnAPIKeyID(t.Context(), tt.messages)
|
||||
ctxKey, ctxOK := aibridge.DelegatedAPIKeyIDFromContext(ctx)
|
||||
require.Equal(t, tt.wantOK, ctxOK)
|
||||
require.Equal(t, tt.wantKey, ctxKey)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestActiveTurnContextUsesPromptMessages(t *testing.T) {
|
||||
func TestPromptMessagesForVisibleUserPreserveActiveAPIKeyID(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
@@ -477,12 +521,70 @@ func TestActiveTurnContextUsesPromptMessages(t *testing.T) {
|
||||
|
||||
messages, err := db.GetChatMessagesForPromptByChatID(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
ctx = contextWithActiveTurnAPIKeyID(ctx, messages)
|
||||
gotKey, ok := aibridge.DelegatedAPIKeyIDFromContext(ctx)
|
||||
gotKey, ok := activeTurnAPIKeyIDFromMessages(messages)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, currentKey.ID, gotKey)
|
||||
}
|
||||
|
||||
func TestPromptMessagesForCompactedChatPreserveActiveAPIKeyID(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
ctx := t.Context()
|
||||
user := dbgen.User(t, db, database.User{})
|
||||
org := dbgen.Organization(t, db, database.Organization{})
|
||||
model := dbgen.ChatModelConfig(t, db, database.ChatModelConfig{})
|
||||
chat := dbgen.Chat(t, db, database.Chat{OrganizationID: org.ID, OwnerID: user.ID, LastModelConfigID: model.ID})
|
||||
key, _ := dbgen.APIKey(t, db, database.APIKey{UserID: user.ID})
|
||||
|
||||
visibleUser := dbgen.ChatMessage(t, db, database.ChatMessage{
|
||||
ChatID: chat.ID,
|
||||
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
||||
ModelConfigID: uuid.NullUUID{UUID: model.ID, Valid: true},
|
||||
Role: database.ChatMessageRoleUser,
|
||||
Visibility: database.ChatMessageVisibilityBoth,
|
||||
APIKeyID: sqlNullString(key.ID),
|
||||
})
|
||||
dbgen.ChatMessage(t, db, database.ChatMessage{
|
||||
ChatID: chat.ID,
|
||||
ModelConfigID: uuid.NullUUID{UUID: model.ID, Valid: true},
|
||||
Role: database.ChatMessageRoleAssistant,
|
||||
Visibility: database.ChatMessageVisibilityBoth,
|
||||
})
|
||||
compressedSummary := dbgen.ChatMessage(t, db, database.ChatMessage{
|
||||
ChatID: chat.ID,
|
||||
ModelConfigID: uuid.NullUUID{UUID: model.ID, Valid: true},
|
||||
Role: database.ChatMessageRoleUser,
|
||||
Visibility: database.ChatMessageVisibilityModel,
|
||||
Compressed: true,
|
||||
APIKeyID: sqlNullString(key.ID),
|
||||
})
|
||||
afterSummary := dbgen.ChatMessage(t, db, database.ChatMessage{
|
||||
ChatID: chat.ID,
|
||||
ModelConfigID: uuid.NullUUID{UUID: model.ID, Valid: true},
|
||||
Role: database.ChatMessageRoleAssistant,
|
||||
Visibility: database.ChatMessageVisibilityBoth,
|
||||
})
|
||||
|
||||
messages, err := db.GetChatMessagesForPromptByChatID(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
|
||||
ids := make(map[int64]struct{}, len(messages))
|
||||
for _, message := range messages {
|
||||
ids[message.ID] = struct{}{}
|
||||
}
|
||||
_, hasVisibleUser := ids[visibleUser.ID]
|
||||
require.False(t, hasVisibleUser)
|
||||
_, hasSummary := ids[compressedSummary.ID]
|
||||
require.True(t, hasSummary)
|
||||
_, hasAfterSummary := ids[afterSummary.ID]
|
||||
require.True(t, hasAfterSummary)
|
||||
|
||||
gotKey, ok := activeTurnAPIKeyIDFromMessages(messages)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, key.ID, gotKey)
|
||||
}
|
||||
|
||||
func sqlNullString(value string) sql.NullString {
|
||||
return sql.NullString{String: value, Valid: value != ""}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user