fix(coderd): pass title API key context (#25723)

Fixes CODAGT-503

- Add failing-first coverage for manual title generation with missing
message `api_key_id`, with both context fallback and fail-closed cases.
- Set `aibridge.WithDelegatedAPIKeyID(ctx, apiKey.ID)` in
`regenerateChatTitle` and `proposeChatTitle`.
- In `generateManualTitleCandidate`, fall back to
`aibridge.DelegatedAPIKeyIDFromContext(ctx)` only when
`modelBuildOptionsFromMessages` yields an empty `ActiveAPIKeyID`.
- Keep `modelBuildOptionsFromMessages` pure and leave automatic title
generation unchanged.
This commit is contained in:
Cian Johnston
2026-05-27 13:20:36 +01:00
committed by GitHub
parent 10f37db35d
commit 0c27224fc2
4 changed files with 204 additions and 0 deletions
+3
View File
@@ -27,6 +27,7 @@ import (
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/agent/agentssh"
"github.com/coder/coder/v2/coderd/aibridge"
"github.com/coder/coder/v2/coderd/audit"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/db2sdk"
@@ -3665,6 +3666,7 @@ func (api *API) regenerateChatTitle(rw http.ResponseWriter, r *http.Request) {
return
}
ctx = aibridge.WithDelegatedAPIKeyID(ctx, apiKey.ID)
updatedChat, err := api.chatDaemon.RegenerateChatTitle(ctx, chat)
if err != nil {
if errors.Is(err, chatd.ErrManualTitleRegenerationInProgress) {
@@ -3718,6 +3720,7 @@ func (api *API) proposeChatTitle(rw http.ResponseWriter, r *http.Request) {
return
}
ctx = aibridge.WithDelegatedAPIKeyID(ctx, apiKey.ID)
title, err := api.chatDaemon.ProposeChatTitle(ctx, chat)
if err != nil {
if errors.Is(err, chatd.ErrManualTitleRegenerationInProgress) {
+82
View File
@@ -7,10 +7,12 @@ import (
"encoding/json"
stderrors "errors"
"fmt"
"io"
"mime"
"net/http"
"regexp"
"slices"
"strconv"
"strings"
"sync/atomic"
"testing"
@@ -24,6 +26,7 @@ import (
"golang.org/x/xerrors"
"github.com/coder/coder/v2/coderd"
"github.com/coder/coder/v2/coderd/aibridge"
"github.com/coder/coder/v2/coderd/audit"
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/coderd/database"
@@ -8470,6 +8473,85 @@ func TestProposeChatTitle(t *testing.T) {
})
}
func TestManualTitleEndpointsPassCallerAPIKeyToAIGateway(t *testing.T) {
t.Parallel()
for _, tt := range []struct {
name string
call func(context.Context, *codersdk.ExperimentalClient, uuid.UUID) error
}{
{
name: "RegenerateChatTitle",
call: func(ctx context.Context, client *codersdk.ExperimentalClient, chatID uuid.UUID) error {
_, err := client.RegenerateChatTitle(ctx, chatID)
return err
},
},
{
name: "ProposeChatTitle",
call: func(ctx context.Context, client *codersdk.ExperimentalClient, chatID uuid.UUID) error {
_, err := client.ProposeChatTitle(ctx, chatID)
return err
},
},
} {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
values := chatDeploymentValues(t)
require.NoError(t, values.AI.BridgeConfig.Enabled.Set("true"))
require.NoError(t, values.AI.Chat.AIGatewayRoutingEnabled.Set("true"))
client, db, api := newChatClientWithAPIAndDatabase(t, func(opts *coderdtest.Options) {
opts.DeploymentValues = values
})
firstUser := coderdtest.CreateFirstUser(t, client.Client)
modelConfig := createAdditionalChatModelConfig(t, client, "openai", "gpt-4.1")
wantAPIKeyID := strings.Split(client.SessionToken(), "-")[0]
wantTitle := "Fallback title"
seenAPIKeyID := make(chan string, 1)
stub := &stubTransportFactory{
handler: http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
apiKeyID, _ := aibridge.DelegatedAPIKeyIDFromContext(r.Context())
seenAPIKeyID <- apiKeyID
rw.Header().Set("Content-Type", "application/json")
text := strconv.Quote(`{"title":"` + wantTitle + `"}`)
_, _ = io.WriteString(rw, `{"id":"resp_test","object":"response","created_at":0,"status":"completed","model":"gpt-4.1","output":[{"id":"msg_test","type":"message","role":"assistant","content":[{"type":"output_text","text":`+text+`}]}],"usage":{"input_tokens":1,"output_tokens":1,"total_tokens":2}}`)
}),
calls: make(chan callRecord, 1),
}
var factory aibridge.TransportFactory = stub
api.AIBridgeTransportFactory.Store(&factory)
require.NoError(t, client.UpdateChatModelOverride(ctx, codersdk.ChatModelOverrideContextTitleGeneration, codersdk.UpdateChatModelOverrideRequest{
ModelConfigID: modelConfig.ID.String(),
}))
chat := dbgen.Chat(t, db, database.Chat{
OrganizationID: firstUser.OrganizationID,
OwnerID: firstUser.UserID,
LastModelConfigID: modelConfig.ID,
Title: "initial title",
Status: database.ChatStatusCompleted,
})
content, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{
codersdk.ChatMessageText("manual title source"),
})
require.NoError(t, err)
_ = dbgen.ChatMessage(t, db, database.ChatMessage{
ChatID: chat.ID,
CreatedBy: uuid.NullUUID{UUID: firstUser.UserID, Valid: true},
ModelConfigID: uuid.NullUUID{UUID: modelConfig.ID, Valid: true},
Role: database.ChatMessageRoleUser,
Visibility: database.ChatMessageVisibilityBoth,
Content: content,
})
require.NoError(t, tt.call(ctx, client, chat.ID))
require.Equal(t, wantAPIKeyID, testutil.RequireReceive(ctx, t, seenAPIKeyID))
})
}
}
func TestGetChatDiffStatus(t *testing.T) {
t.Parallel()
+8
View File
@@ -3162,6 +3162,7 @@ func (p *Server) recordManualTitleGenerationFailure(
// generateManualTitleCandidate performs only model generation and returns the
// candidate plus accounting metadata. Endpoint-specific commit paths are
// responsible for recording usage and deciding whether to persist the title.
// The context may carry the caller's delegated API key for manual title routes.
func (p *Server) generateManualTitleCandidate(
ctx context.Context,
store database.Store,
@@ -3199,6 +3200,13 @@ func (p *Server) generateManualTitleCandidate(
return manualTitleCandidateResult{}, nil
}
modelOpts := modelBuildOptionsFromMessages(messages)
// Manual title routes can run over messages that lack API key attribution.
// Fall back to the authenticated caller's delegated key for AI Gateway routing.
if modelOpts.ActiveAPIKeyID == "" {
if apiKeyID, ok := aibridge.DelegatedAPIKeyIDFromContext(ctx); ok {
modelOpts.ActiveAPIKeyID = apiKeyID
}
}
model, modelConfig, modelKeys, err := p.resolveManualTitleModel(ctx, store, chat, keys, modelOpts)
result := manualTitleCandidateResult{
@@ -3,6 +3,10 @@ package chatd
import (
"context"
"database/sql"
"io"
"net/http"
"strconv"
"strings"
"sync/atomic"
"testing"
@@ -14,6 +18,7 @@ import (
"cdr.dev/slog/v3"
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/coder/v2/coderd/aibridge"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbmock"
"github.com/coder/coder/v2/coderd/x/chatd/chatprovider"
@@ -562,6 +567,112 @@ func TestResolveManualTitleModel_TitleGenerationOverrideMissingCredentials(t *te
require.Equal(t, database.ChatModelConfig{}, gotConfig)
}
func TestGenerateManualTitleCandidate_ActiveAPIKeyIDFallback(t *testing.T) {
t.Parallel()
contextAPIKeyID := uuid.NewString()
messageAPIKeyID := uuid.NewString()
shadowedContextAPIKeyID := uuid.NewString()
tests := []struct {
name string
messageAPIKeyID string
contextAPIKeyID string
wantAPIKeyID string
wantErrContains string
}{
{
name: "ContextFallback",
contextAPIKeyID: contextAPIKeyID,
wantAPIKeyID: contextAPIKeyID,
},
{
name: "MessageTakesPrecedence",
messageAPIKeyID: messageAPIKeyID,
contextAPIKeyID: shadowedContextAPIKeyID,
wantAPIKeyID: messageAPIKeyID,
},
{
name: "NoKeyAnywhereFailsClosed",
wantErrContains: "AI Gateway routing requires the active turn API key ID",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
if tt.contextAPIKeyID != "" {
ctx = aibridge.WithDelegatedAPIKeyID(ctx, tt.contextAPIKeyID)
}
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
chat, messages := titleOverrideTestChatAndMessages(t)
chat.OrganizationID = uuid.New()
if tt.messageAPIKeyID != "" {
messages[0] = withChatMessageAPIKeyID(messages[0], tt.messageAPIKeyID)
}
overrideConfig := titleOverrideModelConfig("gpt-4.1", true)
providerID := uuid.New()
overrideConfig.AIProviderID = uuid.NullUUID{UUID: providerID, Valid: true}
provider := database.AIProvider{
ID: providerID,
Name: "primary-openai",
Type: database.AiProviderTypeOpenai,
Enabled: true,
}
wantTitle := "Context title"
seenAPIKeyID := make(chan string, 1)
factory := &aibridgeTestFactory{rt: roundTripFunc(func(req *http.Request) (*http.Response, error) {
apiKeyID, _ := aibridge.DelegatedAPIKeyIDFromContext(req.Context())
seenAPIKeyID <- apiKeyID
text := strconv.Quote(`{"title":"` + wantTitle + `"}`)
body := `{"id":"resp_test","object":"response","created_at":0,"status":"completed","model":"gpt-4.1","output":[{"id":"msg_test","type":"message","role":"assistant","content":[{"type":"output_text","text":` + text + `}]}],"usage":{"input_tokens":1,"output_tokens":1,"total_tokens":2}}`
return &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"application/json"}},
Body: io.NopCloser(strings.NewReader(body)),
Request: req,
}, nil
})}
db.EXPECT().GetChatUsageLimitConfig(gomock.Any()).Return(database.ChatUsageLimitConfig{}, sql.ErrNoRows)
db.EXPECT().GetChatMessagesByChatIDAscPaginated(gomock.Any(), database.GetChatMessagesByChatIDAscPaginatedParams{
ChatID: chat.ID,
AfterID: 0,
LimitVal: manualTitleMessageWindowLimit,
}).Return(messages, nil)
db.EXPECT().GetChatMessagesByChatIDDescPaginated(gomock.Any(), database.GetChatMessagesByChatIDDescPaginatedParams{
ChatID: chat.ID,
BeforeID: 0,
LimitVal: manualTitleMessageWindowLimit,
}).Return(nil, nil)
db.EXPECT().GetChatTitleGenerationModelOverride(gomock.Any()).Return(overrideConfig.ID.String(), nil)
db.EXPECT().GetChatModelConfigByID(gomock.Any(), overrideConfig.ID).Return(overrideConfig, nil)
db.EXPECT().GetAIProviderByID(gomock.Any(), providerID).Return(provider, nil).AnyTimes()
db.EXPECT().GetAIProviderKeysByProviderID(gomock.Any(), providerID).Return([]database.AIProviderKey{{
ProviderID: providerID,
APIKey: "test-key",
}}, nil).AnyTimes()
server := titleOverrideTestServer(db, logger)
server.aiGatewayRoutingEnabled = true
server.aibridgeTransportFactory = aibridgeTestFactoryPointer(factory)
result, err := server.generateManualTitleCandidate(ctx, db, chat, chatprovider.ProviderAPIKeys{})
if tt.wantErrContains != "" {
require.ErrorContains(t, err, tt.wantErrContains)
return
}
require.NoError(t, err)
require.Equal(t, wantTitle, result.title)
require.True(t, result.hasMessages)
require.Equal(t, tt.wantAPIKeyID, result.activeAPIKeyID)
require.Equal(t, tt.wantAPIKeyID, testutil.RequireReceive(ctx, t, seenAPIKeyID))
})
}
}
func TestResolveManualTitleModel_TitleGenerationOverrideSetUnusable(t *testing.T) {
t.Parallel()