mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
This commit is contained in:
committed by
GitHub
parent
26c035d742
commit
2f60b14649
+20
-14
@@ -6468,22 +6468,14 @@ type runChatResult struct {
|
|||||||
HistoryTipMessageID int64
|
HistoryTipMessageID int64
|
||||||
}
|
}
|
||||||
|
|
||||||
func contextWithActiveTurnAPIKeyID(ctx context.Context, messages []database.ChatMessage) context.Context {
|
|
||||||
apiKeyID, ok := activeTurnAPIKeyIDFromMessages(messages)
|
|
||||||
if !ok {
|
|
||||||
return ctx
|
|
||||||
}
|
|
||||||
return aibridge.WithDelegatedAPIKeyID(ctx, apiKeyID)
|
|
||||||
}
|
|
||||||
|
|
||||||
func activeTurnAPIKeyIDFromMessages(messages []database.ChatMessage) (string, bool) {
|
func activeTurnAPIKeyIDFromMessages(messages []database.ChatMessage) (string, bool) {
|
||||||
for i := len(messages) - 1; i >= 0; i-- {
|
for i := len(messages) - 1; i >= 0; i-- {
|
||||||
message := messages[i]
|
message := messages[i]
|
||||||
if message.Role != database.ChatMessageRoleUser {
|
if message.Role != database.ChatMessageRoleUser {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if message.Visibility != database.ChatMessageVisibilityBoth &&
|
if !isUserVisibleChatMessage(message) &&
|
||||||
message.Visibility != database.ChatMessageVisibilityUser {
|
!(message.Visibility == database.ChatMessageVisibilityModel && message.Compressed) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if !message.APIKeyID.Valid || message.APIKeyID.String == "" {
|
if !message.APIKeyID.Valid || message.APIKeyID.String == "" {
|
||||||
@@ -6494,6 +6486,11 @@ func activeTurnAPIKeyIDFromMessages(messages []database.ChatMessage) (string, bo
|
|||||||
return "", false
|
return "", false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isUserVisibleChatMessage(message database.ChatMessage) bool {
|
||||||
|
return message.Visibility == database.ChatMessageVisibilityBoth ||
|
||||||
|
message.Visibility == database.ChatMessageVisibilityUser
|
||||||
|
}
|
||||||
|
|
||||||
func allToolNames(allTools []fantasy.AgentTool) []string {
|
func allToolNames(allTools []fantasy.AgentTool) []string {
|
||||||
toolNames := make([]string, 0, len(allTools))
|
toolNames := make([]string, 0, len(allTools))
|
||||||
for _, tool := range allTools {
|
for _, tool := range allTools {
|
||||||
@@ -7124,7 +7121,9 @@ func (p *Server) runChat(
|
|||||||
return result, xerrors.Errorf("get chat messages: %w", err)
|
return result, xerrors.Errorf("get chat messages: %w", err)
|
||||||
}
|
}
|
||||||
modelOpts := modelBuildOptionsFromMessages(messages)
|
modelOpts := modelBuildOptionsFromMessages(messages)
|
||||||
ctx = contextWithActiveTurnAPIKeyID(ctx, messages)
|
if modelOpts.ActiveAPIKeyID != "" {
|
||||||
|
ctx = aibridge.WithDelegatedAPIKeyID(ctx, modelOpts.ActiveAPIKeyID)
|
||||||
|
}
|
||||||
|
|
||||||
// Load MCP server configs and user tokens in parallel with model
|
// Load MCP server configs and user tokens in parallel with model
|
||||||
// resolution. These queries have no dependencies on each other and all
|
// resolution. These queries have no dependencies on each other and all
|
||||||
@@ -7831,6 +7830,7 @@ func (p *Server) runChat(
|
|||||||
persistCtx,
|
persistCtx,
|
||||||
chat.ID,
|
chat.ID,
|
||||||
modelConfig.ID,
|
modelConfig.ID,
|
||||||
|
modelOpts.ActiveAPIKeyID,
|
||||||
compactionToolCallID,
|
compactionToolCallID,
|
||||||
result,
|
result,
|
||||||
); err != nil {
|
); err != nil {
|
||||||
@@ -8460,12 +8460,14 @@ func buildProviderTools(options *codersdk.ChatModelProviderOptions) []chatloop.P
|
|||||||
return tools
|
return tools
|
||||||
}
|
}
|
||||||
|
|
||||||
// persistChatContextSummary persists a chat context summary to the database.
|
// persistChatContextSummary is called from the chat loop's compaction
|
||||||
// This is invoked via the chat loop's compaction callback.
|
// callback. activeAPIKeyID is stamped onto the summary user message. When
|
||||||
|
// empty, it falls back to the delegated key in ctx.
|
||||||
func (p *Server) persistChatContextSummary(
|
func (p *Server) persistChatContextSummary(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
chatID uuid.UUID,
|
chatID uuid.UUID,
|
||||||
modelConfigID uuid.UUID,
|
modelConfigID uuid.UUID,
|
||||||
|
activeAPIKeyID string,
|
||||||
toolCallID string,
|
toolCallID string,
|
||||||
result chatloop.CompactionResult,
|
result chatloop.CompactionResult,
|
||||||
) error {
|
) error {
|
||||||
@@ -8514,6 +8516,11 @@ func (p *Server) persistChatContextSummary(
|
|||||||
return xerrors.Errorf("encode summary tool result: %w", err)
|
return xerrors.Errorf("encode summary tool result: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
summaryAPIKeyID := activeAPIKeyID
|
||||||
|
if summaryAPIKeyID == "" {
|
||||||
|
summaryAPIKeyID, _ = aibridge.DelegatedAPIKeyIDFromContext(ctx)
|
||||||
|
}
|
||||||
|
|
||||||
var insertedMessages []database.ChatMessage
|
var insertedMessages []database.ChatMessage
|
||||||
|
|
||||||
txErr := p.db.InTx(func(tx database.Store) error {
|
txErr := p.db.InTx(func(tx database.Store) error {
|
||||||
@@ -8522,7 +8529,6 @@ func (p *Server) persistChatContextSummary(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Hidden summary user message (not published to subscribers).
|
// Hidden summary user message (not published to subscribers).
|
||||||
summaryAPIKeyID, _ := aibridge.DelegatedAPIKeyIDFromContext(ctx)
|
|
||||||
summaryUserMsg := newUserChatMessage(
|
summaryUserMsg := newUserChatMessage(
|
||||||
summaryAPIKeyID,
|
summaryAPIKeyID,
|
||||||
systemContent,
|
systemContent,
|
||||||
|
|||||||
@@ -6651,42 +6651,63 @@ func TestPersistChatContextSummarySetsAPIKeyID(t *testing.T) {
|
|||||||
UserID: user.ID,
|
UserID: user.ID,
|
||||||
})
|
})
|
||||||
|
|
||||||
ctx = aibridge.WithDelegatedAPIKeyID(ctx, apiKey.ID)
|
|
||||||
|
|
||||||
server := &Server{db: db}
|
server := &Server{db: db}
|
||||||
|
persistAndAssertSummaryKey := func(
|
||||||
|
summaryCtx context.Context,
|
||||||
|
chatID uuid.UUID,
|
||||||
|
activeAPIKeyID string,
|
||||||
|
wantAPIKeyID string,
|
||||||
|
toolCallID string,
|
||||||
|
) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
err := server.persistChatContextSummary(
|
err := server.persistChatContextSummary(
|
||||||
ctx,
|
summaryCtx,
|
||||||
chat.ID,
|
chatID,
|
||||||
modelConfig.ID,
|
modelConfig.ID,
|
||||||
"tool-call-id-1",
|
activeAPIKeyID,
|
||||||
chatloop.CompactionResult{
|
toolCallID,
|
||||||
SystemSummary: "summarized context",
|
chatloop.CompactionResult{
|
||||||
SummaryReport: "context was summarized",
|
SystemSummary: "summarized context",
|
||||||
ThresholdPercent: 70,
|
SummaryReport: "context was summarized",
|
||||||
UsagePercent: 85.0,
|
ThresholdPercent: 70,
|
||||||
ContextTokens: 8500,
|
UsagePercent: 85.0,
|
||||||
ContextLimit: 10000,
|
ContextTokens: 8500,
|
||||||
},
|
ContextLimit: 10000,
|
||||||
)
|
},
|
||||||
require.NoError(t, err)
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
msgs, err := db.GetChatMessagesForPromptByChatID(ctx, chat.ID)
|
msgs, err := db.GetChatMessagesForPromptByChatID(ctx, chatID)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// GetChatMessagesForPromptByChatID uses a compaction boundary CTE
|
// GetChatMessagesForPromptByChatID uses a compaction boundary CTE
|
||||||
// that selects compressed=true, visibility='model'. Only the user
|
// that selects compressed=true, visibility='model'. Only the user
|
||||||
// summary qualifies; the assistant (visibility=user) and tool
|
// summary qualifies; the assistant (visibility=user) and tool
|
||||||
// result (visibility=both) are excluded by the CTE filter.
|
// result (visibility=both) are excluded by the CTE filter.
|
||||||
require.NotEmpty(t, msgs)
|
require.NotEmpty(t, msgs)
|
||||||
|
|
||||||
var foundUserSummary bool
|
var foundUserSummary bool
|
||||||
for _, msg := range msgs {
|
for _, msg := range msgs {
|
||||||
if msg.Role == database.ChatMessageRoleUser {
|
if msg.Role == database.ChatMessageRoleUser {
|
||||||
foundUserSummary = true
|
foundUserSummary = true
|
||||||
require.True(t, msg.APIKeyID.Valid, "summary user message must have APIKeyID set")
|
require.True(t, msg.APIKeyID.Valid, "summary user message must have APIKeyID set")
|
||||||
require.Equal(t, apiKey.ID, msg.APIKeyID.String, "summary user message APIKeyID must match")
|
require.Equal(t, wantAPIKeyID, msg.APIKeyID.String, "summary user message APIKeyID must match")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
require.True(t, foundUserSummary, "expected to find compressed user summary message")
|
||||||
}
|
}
|
||||||
require.True(t, foundUserSummary, "expected to find compressed user summary message")
|
|
||||||
|
persistAndAssertSummaryKey(ctx, chat.ID, apiKey.ID, apiKey.ID, "tool-call-id-1")
|
||||||
|
|
||||||
|
fallbackChat := dbgen.Chat(t, db, database.Chat{
|
||||||
|
OwnerID: user.ID,
|
||||||
|
OrganizationID: org.ID,
|
||||||
|
LastModelConfigID: modelConfig.ID,
|
||||||
|
})
|
||||||
|
fallbackKey, _ := dbgen.APIKey(t, db, database.APIKey{
|
||||||
|
UserID: user.ID,
|
||||||
|
})
|
||||||
|
fallbackCtx := aibridge.WithDelegatedAPIKeyID(ctx, fallbackKey.ID)
|
||||||
|
persistAndAssertSummaryKey(fallbackCtx, fallbackChat.ID, "", fallbackKey.ID, "tool-call-id-2")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -405,7 +405,7 @@ func TestActiveTurnAPIKeyIDFromMessages(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "SkipsModelOnlyUserMessages",
|
name: "SkipsUncompressedModelOnlyUserMessages",
|
||||||
messages: []database.ChatMessage{
|
messages: []database.ChatMessage{
|
||||||
{ID: 1, Role: database.ChatMessageRoleUser, Visibility: database.ChatMessageVisibilityBoth, APIKeyID: sqlNullString(oldKeyID)},
|
{ID: 1, Role: database.ChatMessageRoleUser, Visibility: database.ChatMessageVisibilityBoth, APIKeyID: sqlNullString(oldKeyID)},
|
||||||
{ID: 2, Role: database.ChatMessageRoleUser, Visibility: database.ChatMessageVisibilityModel, APIKeyID: sqlNullString(currentKeyID)},
|
{ID: 2, Role: database.ChatMessageRoleUser, Visibility: database.ChatMessageVisibilityModel, APIKeyID: sqlNullString(currentKeyID)},
|
||||||
@@ -413,6 +413,54 @@ func TestActiveTurnAPIKeyIDFromMessages(t *testing.T) {
|
|||||||
wantKey: oldKeyID,
|
wantKey: oldKeyID,
|
||||||
wantOK: true,
|
wantOK: true,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "CompressedSummaryFallback",
|
||||||
|
messages: []database.ChatMessage{
|
||||||
|
{ID: 1, Role: database.ChatMessageRoleUser, Visibility: database.ChatMessageVisibilityModel, Compressed: true, APIKeyID: sqlNullString(currentKeyID)},
|
||||||
|
{ID: 2, Role: database.ChatMessageRoleAssistant, Visibility: database.ChatMessageVisibilityBoth},
|
||||||
|
},
|
||||||
|
wantKey: currentKeyID,
|
||||||
|
wantOK: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "LatestCompressedSummaryWins",
|
||||||
|
messages: []database.ChatMessage{
|
||||||
|
{ID: 1, Role: database.ChatMessageRoleUser, Visibility: database.ChatMessageVisibilityModel, Compressed: true, APIKeyID: sqlNullString(oldKeyID)},
|
||||||
|
{ID: 2, Role: database.ChatMessageRoleUser, Visibility: database.ChatMessageVisibilityModel, Compressed: true, APIKeyID: sqlNullString(currentKeyID)},
|
||||||
|
{ID: 3, Role: database.ChatMessageRoleAssistant, Visibility: database.ChatMessageVisibilityBoth},
|
||||||
|
},
|
||||||
|
wantKey: currentKeyID,
|
||||||
|
wantOK: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "VisibleUserWinsOverCompressedSummary",
|
||||||
|
messages: []database.ChatMessage{
|
||||||
|
{ID: 1, Role: database.ChatMessageRoleUser, Visibility: database.ChatMessageVisibilityModel, Compressed: true, APIKeyID: sqlNullString(oldKeyID)},
|
||||||
|
{ID: 2, Role: database.ChatMessageRoleUser, Visibility: database.ChatMessageVisibilityBoth, APIKeyID: sqlNullString(currentKeyID)},
|
||||||
|
},
|
||||||
|
wantKey: currentKeyID,
|
||||||
|
wantOK: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "MissingVisibleUserKeyDoesNotFallBackToCompressedSummary",
|
||||||
|
messages: []database.ChatMessage{
|
||||||
|
{ID: 1, Role: database.ChatMessageRoleUser, Visibility: database.ChatMessageVisibilityModel, Compressed: true, APIKeyID: sqlNullString(oldKeyID)},
|
||||||
|
{ID: 2, Role: database.ChatMessageRoleUser, Visibility: database.ChatMessageVisibilityBoth},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "UncompressedModelOnlyUserIgnored",
|
||||||
|
messages: []database.ChatMessage{
|
||||||
|
{ID: 1, Role: database.ChatMessageRoleUser, Visibility: database.ChatMessageVisibilityModel, APIKeyID: sqlNullString(currentKeyID)},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "CompressedSummaryMissingKeyDoesNotFallBack",
|
||||||
|
messages: []database.ChatMessage{
|
||||||
|
{ID: 1, Role: database.ChatMessageRoleUser, Visibility: database.ChatMessageVisibilityBoth, APIKeyID: sqlNullString(oldKeyID)},
|
||||||
|
{ID: 2, Role: database.ChatMessageRoleUser, Visibility: database.ChatMessageVisibilityModel, Compressed: true},
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
@@ -421,15 +469,11 @@ func TestActiveTurnAPIKeyIDFromMessages(t *testing.T) {
|
|||||||
gotKey, gotOK := activeTurnAPIKeyIDFromMessages(tt.messages)
|
gotKey, gotOK := activeTurnAPIKeyIDFromMessages(tt.messages)
|
||||||
require.Equal(t, tt.wantOK, gotOK)
|
require.Equal(t, tt.wantOK, gotOK)
|
||||||
require.Equal(t, tt.wantKey, gotKey)
|
require.Equal(t, tt.wantKey, gotKey)
|
||||||
ctx := contextWithActiveTurnAPIKeyID(t.Context(), tt.messages)
|
|
||||||
ctxKey, ctxOK := aibridge.DelegatedAPIKeyIDFromContext(ctx)
|
|
||||||
require.Equal(t, tt.wantOK, ctxOK)
|
|
||||||
require.Equal(t, tt.wantKey, ctxKey)
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestActiveTurnContextUsesPromptMessages(t *testing.T) {
|
func TestPromptMessagesForVisibleUserPreserveActiveAPIKeyID(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
db, _ := dbtestutil.NewDB(t)
|
db, _ := dbtestutil.NewDB(t)
|
||||||
@@ -477,12 +521,70 @@ func TestActiveTurnContextUsesPromptMessages(t *testing.T) {
|
|||||||
|
|
||||||
messages, err := db.GetChatMessagesForPromptByChatID(ctx, chat.ID)
|
messages, err := db.GetChatMessagesForPromptByChatID(ctx, chat.ID)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
ctx = contextWithActiveTurnAPIKeyID(ctx, messages)
|
gotKey, ok := activeTurnAPIKeyIDFromMessages(messages)
|
||||||
gotKey, ok := aibridge.DelegatedAPIKeyIDFromContext(ctx)
|
|
||||||
require.True(t, ok)
|
require.True(t, ok)
|
||||||
require.Equal(t, currentKey.ID, gotKey)
|
require.Equal(t, currentKey.ID, gotKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestPromptMessagesForCompactedChatPreserveActiveAPIKeyID(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
db, _ := dbtestutil.NewDB(t)
|
||||||
|
ctx := t.Context()
|
||||||
|
user := dbgen.User(t, db, database.User{})
|
||||||
|
org := dbgen.Organization(t, db, database.Organization{})
|
||||||
|
model := dbgen.ChatModelConfig(t, db, database.ChatModelConfig{})
|
||||||
|
chat := dbgen.Chat(t, db, database.Chat{OrganizationID: org.ID, OwnerID: user.ID, LastModelConfigID: model.ID})
|
||||||
|
key, _ := dbgen.APIKey(t, db, database.APIKey{UserID: user.ID})
|
||||||
|
|
||||||
|
visibleUser := dbgen.ChatMessage(t, db, database.ChatMessage{
|
||||||
|
ChatID: chat.ID,
|
||||||
|
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
||||||
|
ModelConfigID: uuid.NullUUID{UUID: model.ID, Valid: true},
|
||||||
|
Role: database.ChatMessageRoleUser,
|
||||||
|
Visibility: database.ChatMessageVisibilityBoth,
|
||||||
|
APIKeyID: sqlNullString(key.ID),
|
||||||
|
})
|
||||||
|
dbgen.ChatMessage(t, db, database.ChatMessage{
|
||||||
|
ChatID: chat.ID,
|
||||||
|
ModelConfigID: uuid.NullUUID{UUID: model.ID, Valid: true},
|
||||||
|
Role: database.ChatMessageRoleAssistant,
|
||||||
|
Visibility: database.ChatMessageVisibilityBoth,
|
||||||
|
})
|
||||||
|
compressedSummary := dbgen.ChatMessage(t, db, database.ChatMessage{
|
||||||
|
ChatID: chat.ID,
|
||||||
|
ModelConfigID: uuid.NullUUID{UUID: model.ID, Valid: true},
|
||||||
|
Role: database.ChatMessageRoleUser,
|
||||||
|
Visibility: database.ChatMessageVisibilityModel,
|
||||||
|
Compressed: true,
|
||||||
|
APIKeyID: sqlNullString(key.ID),
|
||||||
|
})
|
||||||
|
afterSummary := dbgen.ChatMessage(t, db, database.ChatMessage{
|
||||||
|
ChatID: chat.ID,
|
||||||
|
ModelConfigID: uuid.NullUUID{UUID: model.ID, Valid: true},
|
||||||
|
Role: database.ChatMessageRoleAssistant,
|
||||||
|
Visibility: database.ChatMessageVisibilityBoth,
|
||||||
|
})
|
||||||
|
|
||||||
|
messages, err := db.GetChatMessagesForPromptByChatID(ctx, chat.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
ids := make(map[int64]struct{}, len(messages))
|
||||||
|
for _, message := range messages {
|
||||||
|
ids[message.ID] = struct{}{}
|
||||||
|
}
|
||||||
|
_, hasVisibleUser := ids[visibleUser.ID]
|
||||||
|
require.False(t, hasVisibleUser)
|
||||||
|
_, hasSummary := ids[compressedSummary.ID]
|
||||||
|
require.True(t, hasSummary)
|
||||||
|
_, hasAfterSummary := ids[afterSummary.ID]
|
||||||
|
require.True(t, hasAfterSummary)
|
||||||
|
|
||||||
|
gotKey, ok := activeTurnAPIKeyIDFromMessages(messages)
|
||||||
|
require.True(t, ok)
|
||||||
|
require.Equal(t, key.ID, gotKey)
|
||||||
|
}
|
||||||
|
|
||||||
func sqlNullString(value string) sql.NullString {
|
func sqlNullString(value string) sql.NullString {
|
||||||
return sql.NullString{String: value, Valid: value != ""}
|
return sql.NullString{String: value, Valid: value != ""}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user