mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
fix(coderd/x/chatd): keep quickgen out of chat sessions
This commit is contained in:
@@ -27,6 +27,19 @@ const (
|
||||
aibridgeDelegatedBYOKMarker = "delegated"
|
||||
)
|
||||
|
||||
// Synthetic quickgen calls are still routed through AI Bridge, but they should
|
||||
// not become promptless root cards in the user's chat session timeline.
|
||||
type suppressAIBridgeSessionHeadersKey struct{}
|
||||
|
||||
func contextWithoutAIBridgeSessionHeaders(ctx context.Context) context.Context {
|
||||
return context.WithValue(ctx, suppressAIBridgeSessionHeadersKey{}, true)
|
||||
}
|
||||
|
||||
func suppressAIBridgeSessionHeadersFromContext(ctx context.Context) bool {
|
||||
suppress, _ := ctx.Value(suppressAIBridgeSessionHeadersKey{}).(bool)
|
||||
return suppress
|
||||
}
|
||||
|
||||
type aiGatewayModelRoute struct {
|
||||
Provider database.AIProvider
|
||||
ModelProviderHint string
|
||||
@@ -76,6 +89,10 @@ type aiGatewayRoundTripper struct {
|
||||
func (t *aiGatewayRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
ctx := aibridge.WithDelegatedAPIKeyID(req.Context(), t.apiKeyID)
|
||||
cloned := req.Clone(ctx)
|
||||
if suppressAIBridgeSessionHeadersFromContext(req.Context()) {
|
||||
cloned.Header.Del(chatprovider.HeaderCoderChatID)
|
||||
cloned.Header.Del(chatprovider.HeaderCoderSubchatID)
|
||||
}
|
||||
for name, value := range t.providerAuth.Headers {
|
||||
cloned.Header.Set(name, value)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package chatd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -374,6 +375,66 @@ func TestAIGatewayModelForwardsProviderAuth(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestAIGatewayRoundTripperCanSuppressSessionHeaders(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
ctx func() context.Context
|
||||
wantChatID string
|
||||
wantSubchat string
|
||||
}{
|
||||
{
|
||||
name: "preserves session headers by default",
|
||||
ctx: func() context.Context { return t.Context() },
|
||||
wantChatID: "chat-id",
|
||||
wantSubchat: "subchat-id",
|
||||
},
|
||||
{
|
||||
name: "suppresses session headers when marked",
|
||||
ctx: func() context.Context {
|
||||
return contextWithoutAIBridgeSessionHeaders(t.Context())
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
seen := make(chan http.Header, 1)
|
||||
rt := &aiGatewayRoundTripper{
|
||||
base: roundTripFunc(func(req *http.Request) (*http.Response, error) {
|
||||
seen <- req.Header.Clone()
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{},
|
||||
Body: io.NopCloser(strings.NewReader("")),
|
||||
Request: req,
|
||||
}, nil
|
||||
}),
|
||||
apiKeyID: uuid.NewString(),
|
||||
}
|
||||
req, err := http.NewRequestWithContext(tt.ctx(), http.MethodPost, "http://coder-aibridge/v1/responses", nil)
|
||||
require.NoError(t, err)
|
||||
req.Header.Set(chatprovider.HeaderCoderOwnerID, "owner-id")
|
||||
req.Header.Set(chatprovider.HeaderCoderChatID, "chat-id")
|
||||
req.Header.Set(chatprovider.HeaderCoderSubchatID, "subchat-id")
|
||||
req.Header.Set(chatprovider.HeaderCoderWorkspaceID, "workspace-id")
|
||||
|
||||
resp, err := rt.RoundTrip(req)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, resp.Body.Close())
|
||||
|
||||
got := <-seen
|
||||
require.Equal(t, "owner-id", got.Get(chatprovider.HeaderCoderOwnerID))
|
||||
require.Equal(t, tt.wantChatID, got.Get(chatprovider.HeaderCoderChatID))
|
||||
require.Equal(t, tt.wantSubchat, got.Get(chatprovider.HeaderCoderSubchatID))
|
||||
require.Equal(t, "workspace-id", got.Get(chatprovider.HeaderCoderWorkspaceID))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestActiveTurnAPIKeyIDFromMessages(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
@@ -495,6 +495,7 @@ func generateStructuredTitleWithUsage(
|
||||
return "", fantasy.Usage{}, xerrors.New("title input was empty")
|
||||
}
|
||||
|
||||
ctx = contextWithoutAIBridgeSessionHeaders(ctx)
|
||||
prompt := syntheticObjectGenerationPrompt(systemPrompt, userInput)
|
||||
|
||||
var maxOutputTokens int64 = 256
|
||||
@@ -922,6 +923,7 @@ func generateStructuredTurnStatusLabel(
|
||||
return "", xerrors.New("turn status label input was empty")
|
||||
}
|
||||
|
||||
ctx = contextWithoutAIBridgeSessionHeaders(ctx)
|
||||
prompt := syntheticObjectGenerationPrompt(systemPrompt, userInput)
|
||||
|
||||
var maxOutputTokens int64 = 64
|
||||
|
||||
@@ -415,8 +415,9 @@ func TestMaybeGenerateChatTitlePreservesUpdatedAt(t *testing.T) {
|
||||
|
||||
const wantTitle = "Failed workspace logs"
|
||||
model := &chattest.FakeModel{
|
||||
GenerateObjectFn: func(_ context.Context, call fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
|
||||
GenerateObjectFn: func(ctx context.Context, call fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
|
||||
require.Equal(t, "propose_title", call.SchemaName)
|
||||
requireSyntheticQuickgenContext(ctx, t)
|
||||
requireSyntheticQuickgenPrompt(t, call.Prompt, userPrompt)
|
||||
return &fantasy.ObjectResponse{
|
||||
Object: map[string]any{"title": wantTitle},
|
||||
@@ -495,6 +496,7 @@ func Test_generateManualTitle_UsesTimeout(t *testing.T) {
|
||||
deadline,
|
||||
2*time.Second,
|
||||
)
|
||||
requireSyntheticQuickgenContext(ctx, t)
|
||||
requireSyntheticQuickgenPrompt(t, call.Prompt, "refresh chat title")
|
||||
require.Equal(t, "propose_title", call.SchemaName)
|
||||
return &fantasy.ObjectResponse{Object: map[string]any{"title": "Refresh title"}}, nil
|
||||
@@ -524,7 +526,8 @@ func Test_generateManualTitle_TruncatesFirstUserInput(t *testing.T) {
|
||||
}
|
||||
|
||||
model := &chattest.FakeModel{
|
||||
GenerateObjectFn: func(_ context.Context, call fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
|
||||
GenerateObjectFn: func(ctx context.Context, call fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
|
||||
requireSyntheticQuickgenContext(ctx, t)
|
||||
requireSyntheticQuickgenPrompt(t, call.Prompt, truncateRunes(longFirstUserText, maxLatestUserMessageRunes))
|
||||
// The manual title system prompt also includes the latest user excerpt.
|
||||
systemText, ok := call.Prompt[0].Content[0].(fantasy.TextPart)
|
||||
@@ -764,6 +767,12 @@ func openAICompatTestModel(t *testing.T, baseURL string) fantasy.LanguageModel {
|
||||
return model
|
||||
}
|
||||
|
||||
func requireSyntheticQuickgenContext(ctx context.Context, t *testing.T) {
|
||||
t.Helper()
|
||||
|
||||
require.True(t, suppressAIBridgeSessionHeadersFromContext(ctx))
|
||||
}
|
||||
|
||||
func requireSyntheticQuickgenPrompt(t *testing.T, prompt fantasy.Prompt, userInput string) {
|
||||
t.Helper()
|
||||
|
||||
@@ -804,8 +813,9 @@ func TestGenerateStructuredTurnStatusLabel(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
model := &chattest.FakeModel{
|
||||
GenerateObjectFn: func(_ context.Context, call fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
|
||||
GenerateObjectFn: func(ctx context.Context, call fantasy.ObjectCall) (*fantasy.ObjectResponse, error) {
|
||||
require.Equal(t, "propose_turn_status_label", call.SchemaName)
|
||||
requireSyntheticQuickgenContext(ctx, t)
|
||||
requireSyntheticQuickgenPrompt(t, call.Prompt, "done")
|
||||
return &fantasy.ObjectResponse{
|
||||
Object: map[string]any{"label": "Submitted PR"},
|
||||
|
||||
Reference in New Issue
Block a user