mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
fix(coderd): pass title API key context (#25723)
Fixes CODAGT-503 - Add failing-first coverage for manual title generation with missing message `api_key_id`, with both context fallback and fail-closed cases. - Set `aibridge.WithDelegatedAPIKeyID(ctx, apiKey.ID)` in `regenerateChatTitle` and `proposeChatTitle`. - In `generateManualTitleCandidate`, fall back to `aibridge.DelegatedAPIKeyIDFromContext(ctx)` only when `modelBuildOptionsFromMessages` yields an empty `ActiveAPIKeyID`. - Keep `modelBuildOptionsFromMessages` pure and leave automatic title generation unchanged.
This commit is contained in:
@@ -3162,6 +3162,7 @@ func (p *Server) recordManualTitleGenerationFailure(
|
||||
// generateManualTitleCandidate performs only model generation and returns the
|
||||
// candidate plus accounting metadata. Endpoint-specific commit paths are
|
||||
// responsible for recording usage and deciding whether to persist the title.
|
||||
// The context may carry the caller's delegated API key for manual title routes.
|
||||
func (p *Server) generateManualTitleCandidate(
|
||||
ctx context.Context,
|
||||
store database.Store,
|
||||
@@ -3199,6 +3200,13 @@ func (p *Server) generateManualTitleCandidate(
|
||||
return manualTitleCandidateResult{}, nil
|
||||
}
|
||||
modelOpts := modelBuildOptionsFromMessages(messages)
|
||||
// Manual title routes can run over messages that lack API key attribution.
|
||||
// Fall back to the authenticated caller's delegated key for AI Gateway routing.
|
||||
if modelOpts.ActiveAPIKeyID == "" {
|
||||
if apiKeyID, ok := aibridge.DelegatedAPIKeyIDFromContext(ctx); ok {
|
||||
modelOpts.ActiveAPIKeyID = apiKeyID
|
||||
}
|
||||
}
|
||||
|
||||
model, modelConfig, modelKeys, err := p.resolveManualTitleModel(ctx, store, chat, keys, modelOpts)
|
||||
result := manualTitleCandidateResult{
|
||||
|
||||
@@ -3,6 +3,10 @@ package chatd
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"io"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
@@ -14,6 +18,7 @@ import (
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"cdr.dev/slog/v3/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/coderd/aibridge"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbmock"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatprovider"
|
||||
@@ -562,6 +567,112 @@ func TestResolveManualTitleModel_TitleGenerationOverrideMissingCredentials(t *te
|
||||
require.Equal(t, database.ChatModelConfig{}, gotConfig)
|
||||
}
|
||||
|
||||
func TestGenerateManualTitleCandidate_ActiveAPIKeyIDFallback(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
contextAPIKeyID := uuid.NewString()
|
||||
messageAPIKeyID := uuid.NewString()
|
||||
shadowedContextAPIKeyID := uuid.NewString()
|
||||
tests := []struct {
|
||||
name string
|
||||
messageAPIKeyID string
|
||||
contextAPIKeyID string
|
||||
wantAPIKeyID string
|
||||
wantErrContains string
|
||||
}{
|
||||
{
|
||||
name: "ContextFallback",
|
||||
contextAPIKeyID: contextAPIKeyID,
|
||||
wantAPIKeyID: contextAPIKeyID,
|
||||
},
|
||||
{
|
||||
name: "MessageTakesPrecedence",
|
||||
messageAPIKeyID: messageAPIKeyID,
|
||||
contextAPIKeyID: shadowedContextAPIKeyID,
|
||||
wantAPIKeyID: messageAPIKeyID,
|
||||
},
|
||||
{
|
||||
name: "NoKeyAnywhereFailsClosed",
|
||||
wantErrContains: "AI Gateway routing requires the active turn API key ID",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
if tt.contextAPIKeyID != "" {
|
||||
ctx = aibridge.WithDelegatedAPIKeyID(ctx, tt.contextAPIKeyID)
|
||||
}
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
chat, messages := titleOverrideTestChatAndMessages(t)
|
||||
chat.OrganizationID = uuid.New()
|
||||
if tt.messageAPIKeyID != "" {
|
||||
messages[0] = withChatMessageAPIKeyID(messages[0], tt.messageAPIKeyID)
|
||||
}
|
||||
overrideConfig := titleOverrideModelConfig("gpt-4.1", true)
|
||||
providerID := uuid.New()
|
||||
overrideConfig.AIProviderID = uuid.NullUUID{UUID: providerID, Valid: true}
|
||||
provider := database.AIProvider{
|
||||
ID: providerID,
|
||||
Name: "primary-openai",
|
||||
Type: database.AiProviderTypeOpenai,
|
||||
Enabled: true,
|
||||
}
|
||||
wantTitle := "Context title"
|
||||
seenAPIKeyID := make(chan string, 1)
|
||||
factory := &aibridgeTestFactory{rt: roundTripFunc(func(req *http.Request) (*http.Response, error) {
|
||||
apiKeyID, _ := aibridge.DelegatedAPIKeyIDFromContext(req.Context())
|
||||
seenAPIKeyID <- apiKeyID
|
||||
text := strconv.Quote(`{"title":"` + wantTitle + `"}`)
|
||||
body := `{"id":"resp_test","object":"response","created_at":0,"status":"completed","model":"gpt-4.1","output":[{"id":"msg_test","type":"message","role":"assistant","content":[{"type":"output_text","text":` + text + `}]}],"usage":{"input_tokens":1,"output_tokens":1,"total_tokens":2}}`
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
Body: io.NopCloser(strings.NewReader(body)),
|
||||
Request: req,
|
||||
}, nil
|
||||
})}
|
||||
|
||||
db.EXPECT().GetChatUsageLimitConfig(gomock.Any()).Return(database.ChatUsageLimitConfig{}, sql.ErrNoRows)
|
||||
db.EXPECT().GetChatMessagesByChatIDAscPaginated(gomock.Any(), database.GetChatMessagesByChatIDAscPaginatedParams{
|
||||
ChatID: chat.ID,
|
||||
AfterID: 0,
|
||||
LimitVal: manualTitleMessageWindowLimit,
|
||||
}).Return(messages, nil)
|
||||
db.EXPECT().GetChatMessagesByChatIDDescPaginated(gomock.Any(), database.GetChatMessagesByChatIDDescPaginatedParams{
|
||||
ChatID: chat.ID,
|
||||
BeforeID: 0,
|
||||
LimitVal: manualTitleMessageWindowLimit,
|
||||
}).Return(nil, nil)
|
||||
db.EXPECT().GetChatTitleGenerationModelOverride(gomock.Any()).Return(overrideConfig.ID.String(), nil)
|
||||
db.EXPECT().GetChatModelConfigByID(gomock.Any(), overrideConfig.ID).Return(overrideConfig, nil)
|
||||
db.EXPECT().GetAIProviderByID(gomock.Any(), providerID).Return(provider, nil).AnyTimes()
|
||||
db.EXPECT().GetAIProviderKeysByProviderID(gomock.Any(), providerID).Return([]database.AIProviderKey{{
|
||||
ProviderID: providerID,
|
||||
APIKey: "test-key",
|
||||
}}, nil).AnyTimes()
|
||||
|
||||
server := titleOverrideTestServer(db, logger)
|
||||
server.aiGatewayRoutingEnabled = true
|
||||
server.aibridgeTransportFactory = aibridgeTestFactoryPointer(factory)
|
||||
result, err := server.generateManualTitleCandidate(ctx, db, chat, chatprovider.ProviderAPIKeys{})
|
||||
if tt.wantErrContains != "" {
|
||||
require.ErrorContains(t, err, tt.wantErrContains)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, wantTitle, result.title)
|
||||
require.True(t, result.hasMessages)
|
||||
require.Equal(t, tt.wantAPIKeyID, result.activeAPIKeyID)
|
||||
require.Equal(t, tt.wantAPIKeyID, testutil.RequireReceive(ctx, t, seenAPIKeyID))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveManualTitleModel_TitleGenerationOverrideSetUnusable(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user