diff --git a/coderd/x/chatd/model_routing_aibridge.go b/coderd/x/chatd/model_routing_aibridge.go index 5db1a16e53..0193448286 100644 --- a/coderd/x/chatd/model_routing_aibridge.go +++ b/coderd/x/chatd/model_routing_aibridge.go @@ -27,6 +27,19 @@ const ( aibridgeDelegatedBYOKMarker = "delegated" ) +// Synthetic quickgen calls are still routed through AI Bridge, but they should +// not become promptless root cards in the user's chat session timeline. +type suppressAIBridgeSessionHeadersKey struct{} + +func contextWithoutAIBridgeSessionHeaders(ctx context.Context) context.Context { + return context.WithValue(ctx, suppressAIBridgeSessionHeadersKey{}, true) +} + +func suppressAIBridgeSessionHeadersFromContext(ctx context.Context) bool { + suppress, _ := ctx.Value(suppressAIBridgeSessionHeadersKey{}).(bool) + return suppress +} + type aiGatewayModelRoute struct { Provider database.AIProvider ModelProviderHint string @@ -76,6 +89,10 @@ type aiGatewayRoundTripper struct { func (t *aiGatewayRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { ctx := aibridge.WithDelegatedAPIKeyID(req.Context(), t.apiKeyID) cloned := req.Clone(ctx) + if suppressAIBridgeSessionHeadersFromContext(req.Context()) { + cloned.Header.Del(chatprovider.HeaderCoderChatID) + cloned.Header.Del(chatprovider.HeaderCoderSubchatID) + } for name, value := range t.providerAuth.Headers { cloned.Header.Set(name, value) } diff --git a/coderd/x/chatd/model_routing_internal_test.go b/coderd/x/chatd/model_routing_internal_test.go index 786365d9fb..2ad1ec2a0b 100644 --- a/coderd/x/chatd/model_routing_internal_test.go +++ b/coderd/x/chatd/model_routing_internal_test.go @@ -1,6 +1,7 @@ package chatd import ( + "context" "database/sql" "fmt" "io" @@ -374,6 +375,66 @@ func TestAIGatewayModelForwardsProviderAuth(t *testing.T) { }) } +func TestAIGatewayRoundTripperCanSuppressSessionHeaders(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + ctx func() context.Context + wantChatID string + wantSubchat string + }{ + { + name: "preserves session headers by default", + ctx: func() context.Context { return t.Context() }, + wantChatID: "chat-id", + wantSubchat: "subchat-id", + }, + { + name: "suppresses session headers when marked", + ctx: func() context.Context { + return contextWithoutAIBridgeSessionHeaders(t.Context()) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + seen := make(chan http.Header, 1) + rt := &aiGatewayRoundTripper{ + base: roundTripFunc(func(req *http.Request) (*http.Response, error) { + seen <- req.Header.Clone() + return &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{}, + Body: io.NopCloser(strings.NewReader("")), + Request: req, + }, nil + }), + apiKeyID: uuid.NewString(), + } + req, err := http.NewRequestWithContext(tt.ctx(), http.MethodPost, "http://coder-aibridge/v1/responses", nil) + require.NoError(t, err) + req.Header.Set(chatprovider.HeaderCoderOwnerID, "owner-id") + req.Header.Set(chatprovider.HeaderCoderChatID, "chat-id") + req.Header.Set(chatprovider.HeaderCoderSubchatID, "subchat-id") + req.Header.Set(chatprovider.HeaderCoderWorkspaceID, "workspace-id") + + resp, err := rt.RoundTrip(req) + require.NoError(t, err) + require.NoError(t, resp.Body.Close()) + + got := <-seen + require.Equal(t, "owner-id", got.Get(chatprovider.HeaderCoderOwnerID)) + require.Equal(t, tt.wantChatID, got.Get(chatprovider.HeaderCoderChatID)) + require.Equal(t, tt.wantSubchat, got.Get(chatprovider.HeaderCoderSubchatID)) + require.Equal(t, "workspace-id", got.Get(chatprovider.HeaderCoderWorkspaceID)) + }) + } +} + func TestActiveTurnAPIKeyIDFromMessages(t *testing.T) { t.Parallel() diff --git a/coderd/x/chatd/quickgen.go b/coderd/x/chatd/quickgen.go index 6c3f52900f..10e470a84d 100644 --- a/coderd/x/chatd/quickgen.go +++ b/coderd/x/chatd/quickgen.go @@ -495,6 +495,7 @@ func generateStructuredTitleWithUsage( return "", fantasy.Usage{}, xerrors.New("title input was empty") } + ctx = contextWithoutAIBridgeSessionHeaders(ctx) prompt := syntheticObjectGenerationPrompt(systemPrompt, userInput) var maxOutputTokens int64 = 256 @@ -922,6 +923,7 @@ func generateStructuredTurnStatusLabel( return "", xerrors.New("turn status label input was empty") } + ctx = contextWithoutAIBridgeSessionHeaders(ctx) prompt := syntheticObjectGenerationPrompt(systemPrompt, userInput) var maxOutputTokens int64 = 64 diff --git a/coderd/x/chatd/quickgen_internal_test.go b/coderd/x/chatd/quickgen_internal_test.go index e373ef670b..56d5bd08fe 100644 --- a/coderd/x/chatd/quickgen_internal_test.go +++ b/coderd/x/chatd/quickgen_internal_test.go @@ -415,8 +415,9 @@ func TestMaybeGenerateChatTitlePreservesUpdatedAt(t *testing.T) { const wantTitle = "Failed workspace logs" model := &chattest.FakeModel{ - GenerateObjectFn: func(_ context.Context, call fantasy.ObjectCall) (*fantasy.ObjectResponse, error) { + GenerateObjectFn: func(ctx context.Context, call fantasy.ObjectCall) (*fantasy.ObjectResponse, error) { require.Equal(t, "propose_title", call.SchemaName) + requireSyntheticQuickgenContext(ctx, t) requireSyntheticQuickgenPrompt(t, call.Prompt, userPrompt) return &fantasy.ObjectResponse{ Object: map[string]any{"title": wantTitle}, @@ -495,6 +496,7 @@ func Test_generateManualTitle_UsesTimeout(t *testing.T) { deadline, 2*time.Second, ) + requireSyntheticQuickgenContext(ctx, t) requireSyntheticQuickgenPrompt(t, call.Prompt, "refresh chat title") require.Equal(t, "propose_title", call.SchemaName) return &fantasy.ObjectResponse{Object: map[string]any{"title": "Refresh title"}}, nil @@ -524,7 +526,8 @@ func Test_generateManualTitle_TruncatesFirstUserInput(t *testing.T) { } model := &chattest.FakeModel{ - GenerateObjectFn: func(_ context.Context, call fantasy.ObjectCall) (*fantasy.ObjectResponse, error) { + GenerateObjectFn: func(ctx context.Context, call fantasy.ObjectCall) (*fantasy.ObjectResponse, error) { + requireSyntheticQuickgenContext(ctx, t) requireSyntheticQuickgenPrompt(t, call.Prompt, truncateRunes(longFirstUserText, maxLatestUserMessageRunes)) // The manual title system prompt also includes the latest user excerpt. systemText, ok := call.Prompt[0].Content[0].(fantasy.TextPart) @@ -764,6 +767,12 @@ func openAICompatTestModel(t *testing.T, baseURL string) fantasy.LanguageModel { return model } +func requireSyntheticQuickgenContext(ctx context.Context, t *testing.T) { + t.Helper() + + require.True(t, suppressAIBridgeSessionHeadersFromContext(ctx)) +} + func requireSyntheticQuickgenPrompt(t *testing.T, prompt fantasy.Prompt, userInput string) { t.Helper() @@ -804,8 +813,9 @@ func TestGenerateStructuredTurnStatusLabel(t *testing.T) { t.Parallel() model := &chattest.FakeModel{ - GenerateObjectFn: func(_ context.Context, call fantasy.ObjectCall) (*fantasy.ObjectResponse, error) { + GenerateObjectFn: func(ctx context.Context, call fantasy.ObjectCall) (*fantasy.ObjectResponse, error) { require.Equal(t, "propose_turn_status_label", call.SchemaName) + requireSyntheticQuickgenContext(ctx, t) requireSyntheticQuickgenPrompt(t, call.Prompt, "done") return &fantasy.ObjectResponse{ Object: map[string]any{"label": "Submitted PR"},