mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
fix(coderd): pass title API key context (#25723)
Fixes CODAGT-503 - Add failing-first coverage for manual title generation with missing message `api_key_id`, with both context fallback and fail-closed cases. - Set `aibridge.WithDelegatedAPIKeyID(ctx, apiKey.ID)` in `regenerateChatTitle` and `proposeChatTitle`. - In `generateManualTitleCandidate`, fall back to `aibridge.DelegatedAPIKeyIDFromContext(ctx)` only when `modelBuildOptionsFromMessages` yields an empty `ActiveAPIKeyID`. - Keep `modelBuildOptionsFromMessages` pure and leave automatic title generation unchanged.
This commit is contained in:
@@ -27,6 +27,7 @@ import (
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/agent/agentssh"
|
||||
"github.com/coder/coder/v2/coderd/aibridge"
|
||||
"github.com/coder/coder/v2/coderd/audit"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/db2sdk"
|
||||
@@ -3665,6 +3666,7 @@ func (api *API) regenerateChatTitle(rw http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
ctx = aibridge.WithDelegatedAPIKeyID(ctx, apiKey.ID)
|
||||
updatedChat, err := api.chatDaemon.RegenerateChatTitle(ctx, chat)
|
||||
if err != nil {
|
||||
if errors.Is(err, chatd.ErrManualTitleRegenerationInProgress) {
|
||||
@@ -3718,6 +3720,7 @@ func (api *API) proposeChatTitle(rw http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
ctx = aibridge.WithDelegatedAPIKeyID(ctx, apiKey.ID)
|
||||
title, err := api.chatDaemon.ProposeChatTitle(ctx, chat)
|
||||
if err != nil {
|
||||
if errors.Is(err, chatd.ErrManualTitleRegenerationInProgress) {
|
||||
|
||||
@@ -7,10 +7,12 @@ import (
|
||||
"encoding/json"
|
||||
stderrors "errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"mime"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
@@ -24,6 +26,7 @@ import (
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/coderd"
|
||||
"github.com/coder/coder/v2/coderd/aibridge"
|
||||
"github.com/coder/coder/v2/coderd/audit"
|
||||
"github.com/coder/coder/v2/coderd/coderdtest"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
@@ -8470,6 +8473,85 @@ func TestProposeChatTitle(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestManualTitleEndpointsPassCallerAPIKeyToAIGateway(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
for _, tt := range []struct {
|
||||
name string
|
||||
call func(context.Context, *codersdk.ExperimentalClient, uuid.UUID) error
|
||||
}{
|
||||
{
|
||||
name: "RegenerateChatTitle",
|
||||
call: func(ctx context.Context, client *codersdk.ExperimentalClient, chatID uuid.UUID) error {
|
||||
_, err := client.RegenerateChatTitle(ctx, chatID)
|
||||
return err
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "ProposeChatTitle",
|
||||
call: func(ctx context.Context, client *codersdk.ExperimentalClient, chatID uuid.UUID) error {
|
||||
_, err := client.ProposeChatTitle(ctx, chatID)
|
||||
return err
|
||||
},
|
||||
},
|
||||
} {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
values := chatDeploymentValues(t)
|
||||
require.NoError(t, values.AI.BridgeConfig.Enabled.Set("true"))
|
||||
require.NoError(t, values.AI.Chat.AIGatewayRoutingEnabled.Set("true"))
|
||||
client, db, api := newChatClientWithAPIAndDatabase(t, func(opts *coderdtest.Options) {
|
||||
opts.DeploymentValues = values
|
||||
})
|
||||
firstUser := coderdtest.CreateFirstUser(t, client.Client)
|
||||
modelConfig := createAdditionalChatModelConfig(t, client, "openai", "gpt-4.1")
|
||||
wantAPIKeyID := strings.Split(client.SessionToken(), "-")[0]
|
||||
wantTitle := "Fallback title"
|
||||
seenAPIKeyID := make(chan string, 1)
|
||||
stub := &stubTransportFactory{
|
||||
handler: http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
apiKeyID, _ := aibridge.DelegatedAPIKeyIDFromContext(r.Context())
|
||||
seenAPIKeyID <- apiKeyID
|
||||
rw.Header().Set("Content-Type", "application/json")
|
||||
text := strconv.Quote(`{"title":"` + wantTitle + `"}`)
|
||||
_, _ = io.WriteString(rw, `{"id":"resp_test","object":"response","created_at":0,"status":"completed","model":"gpt-4.1","output":[{"id":"msg_test","type":"message","role":"assistant","content":[{"type":"output_text","text":`+text+`}]}],"usage":{"input_tokens":1,"output_tokens":1,"total_tokens":2}}`)
|
||||
}),
|
||||
calls: make(chan callRecord, 1),
|
||||
}
|
||||
var factory aibridge.TransportFactory = stub
|
||||
api.AIBridgeTransportFactory.Store(&factory)
|
||||
require.NoError(t, client.UpdateChatModelOverride(ctx, codersdk.ChatModelOverrideContextTitleGeneration, codersdk.UpdateChatModelOverrideRequest{
|
||||
ModelConfigID: modelConfig.ID.String(),
|
||||
}))
|
||||
|
||||
chat := dbgen.Chat(t, db, database.Chat{
|
||||
OrganizationID: firstUser.OrganizationID,
|
||||
OwnerID: firstUser.UserID,
|
||||
LastModelConfigID: modelConfig.ID,
|
||||
Title: "initial title",
|
||||
Status: database.ChatStatusCompleted,
|
||||
})
|
||||
content, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{
|
||||
codersdk.ChatMessageText("manual title source"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
_ = dbgen.ChatMessage(t, db, database.ChatMessage{
|
||||
ChatID: chat.ID,
|
||||
CreatedBy: uuid.NullUUID{UUID: firstUser.UserID, Valid: true},
|
||||
ModelConfigID: uuid.NullUUID{UUID: modelConfig.ID, Valid: true},
|
||||
Role: database.ChatMessageRoleUser,
|
||||
Visibility: database.ChatMessageVisibilityBoth,
|
||||
Content: content,
|
||||
})
|
||||
|
||||
require.NoError(t, tt.call(ctx, client, chat.ID))
|
||||
require.Equal(t, wantAPIKeyID, testutil.RequireReceive(ctx, t, seenAPIKeyID))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetChatDiffStatus(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
@@ -3162,6 +3162,7 @@ func (p *Server) recordManualTitleGenerationFailure(
|
||||
// generateManualTitleCandidate performs only model generation and returns the
|
||||
// candidate plus accounting metadata. Endpoint-specific commit paths are
|
||||
// responsible for recording usage and deciding whether to persist the title.
|
||||
// The context may carry the caller's delegated API key for manual title routes.
|
||||
func (p *Server) generateManualTitleCandidate(
|
||||
ctx context.Context,
|
||||
store database.Store,
|
||||
@@ -3199,6 +3200,13 @@ func (p *Server) generateManualTitleCandidate(
|
||||
return manualTitleCandidateResult{}, nil
|
||||
}
|
||||
modelOpts := modelBuildOptionsFromMessages(messages)
|
||||
// Manual title routes can run over messages that lack API key attribution.
|
||||
// Fall back to the authenticated caller's delegated key for AI Gateway routing.
|
||||
if modelOpts.ActiveAPIKeyID == "" {
|
||||
if apiKeyID, ok := aibridge.DelegatedAPIKeyIDFromContext(ctx); ok {
|
||||
modelOpts.ActiveAPIKeyID = apiKeyID
|
||||
}
|
||||
}
|
||||
|
||||
model, modelConfig, modelKeys, err := p.resolveManualTitleModel(ctx, store, chat, keys, modelOpts)
|
||||
result := manualTitleCandidateResult{
|
||||
|
||||
@@ -3,6 +3,10 @@ package chatd
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"io"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
@@ -14,6 +18,7 @@ import (
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"cdr.dev/slog/v3/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/coderd/aibridge"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbmock"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatprovider"
|
||||
@@ -562,6 +567,112 @@ func TestResolveManualTitleModel_TitleGenerationOverrideMissingCredentials(t *te
|
||||
require.Equal(t, database.ChatModelConfig{}, gotConfig)
|
||||
}
|
||||
|
||||
func TestGenerateManualTitleCandidate_ActiveAPIKeyIDFallback(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
contextAPIKeyID := uuid.NewString()
|
||||
messageAPIKeyID := uuid.NewString()
|
||||
shadowedContextAPIKeyID := uuid.NewString()
|
||||
tests := []struct {
|
||||
name string
|
||||
messageAPIKeyID string
|
||||
contextAPIKeyID string
|
||||
wantAPIKeyID string
|
||||
wantErrContains string
|
||||
}{
|
||||
{
|
||||
name: "ContextFallback",
|
||||
contextAPIKeyID: contextAPIKeyID,
|
||||
wantAPIKeyID: contextAPIKeyID,
|
||||
},
|
||||
{
|
||||
name: "MessageTakesPrecedence",
|
||||
messageAPIKeyID: messageAPIKeyID,
|
||||
contextAPIKeyID: shadowedContextAPIKeyID,
|
||||
wantAPIKeyID: messageAPIKeyID,
|
||||
},
|
||||
{
|
||||
name: "NoKeyAnywhereFailsClosed",
|
||||
wantErrContains: "AI Gateway routing requires the active turn API key ID",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
if tt.contextAPIKeyID != "" {
|
||||
ctx = aibridge.WithDelegatedAPIKeyID(ctx, tt.contextAPIKeyID)
|
||||
}
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
|
||||
chat, messages := titleOverrideTestChatAndMessages(t)
|
||||
chat.OrganizationID = uuid.New()
|
||||
if tt.messageAPIKeyID != "" {
|
||||
messages[0] = withChatMessageAPIKeyID(messages[0], tt.messageAPIKeyID)
|
||||
}
|
||||
overrideConfig := titleOverrideModelConfig("gpt-4.1", true)
|
||||
providerID := uuid.New()
|
||||
overrideConfig.AIProviderID = uuid.NullUUID{UUID: providerID, Valid: true}
|
||||
provider := database.AIProvider{
|
||||
ID: providerID,
|
||||
Name: "primary-openai",
|
||||
Type: database.AiProviderTypeOpenai,
|
||||
Enabled: true,
|
||||
}
|
||||
wantTitle := "Context title"
|
||||
seenAPIKeyID := make(chan string, 1)
|
||||
factory := &aibridgeTestFactory{rt: roundTripFunc(func(req *http.Request) (*http.Response, error) {
|
||||
apiKeyID, _ := aibridge.DelegatedAPIKeyIDFromContext(req.Context())
|
||||
seenAPIKeyID <- apiKeyID
|
||||
text := strconv.Quote(`{"title":"` + wantTitle + `"}`)
|
||||
body := `{"id":"resp_test","object":"response","created_at":0,"status":"completed","model":"gpt-4.1","output":[{"id":"msg_test","type":"message","role":"assistant","content":[{"type":"output_text","text":` + text + `}]}],"usage":{"input_tokens":1,"output_tokens":1,"total_tokens":2}}`
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
Body: io.NopCloser(strings.NewReader(body)),
|
||||
Request: req,
|
||||
}, nil
|
||||
})}
|
||||
|
||||
db.EXPECT().GetChatUsageLimitConfig(gomock.Any()).Return(database.ChatUsageLimitConfig{}, sql.ErrNoRows)
|
||||
db.EXPECT().GetChatMessagesByChatIDAscPaginated(gomock.Any(), database.GetChatMessagesByChatIDAscPaginatedParams{
|
||||
ChatID: chat.ID,
|
||||
AfterID: 0,
|
||||
LimitVal: manualTitleMessageWindowLimit,
|
||||
}).Return(messages, nil)
|
||||
db.EXPECT().GetChatMessagesByChatIDDescPaginated(gomock.Any(), database.GetChatMessagesByChatIDDescPaginatedParams{
|
||||
ChatID: chat.ID,
|
||||
BeforeID: 0,
|
||||
LimitVal: manualTitleMessageWindowLimit,
|
||||
}).Return(nil, nil)
|
||||
db.EXPECT().GetChatTitleGenerationModelOverride(gomock.Any()).Return(overrideConfig.ID.String(), nil)
|
||||
db.EXPECT().GetChatModelConfigByID(gomock.Any(), overrideConfig.ID).Return(overrideConfig, nil)
|
||||
db.EXPECT().GetAIProviderByID(gomock.Any(), providerID).Return(provider, nil).AnyTimes()
|
||||
db.EXPECT().GetAIProviderKeysByProviderID(gomock.Any(), providerID).Return([]database.AIProviderKey{{
|
||||
ProviderID: providerID,
|
||||
APIKey: "test-key",
|
||||
}}, nil).AnyTimes()
|
||||
|
||||
server := titleOverrideTestServer(db, logger)
|
||||
server.aiGatewayRoutingEnabled = true
|
||||
server.aibridgeTransportFactory = aibridgeTestFactoryPointer(factory)
|
||||
result, err := server.generateManualTitleCandidate(ctx, db, chat, chatprovider.ProviderAPIKeys{})
|
||||
if tt.wantErrContains != "" {
|
||||
require.ErrorContains(t, err, tt.wantErrContains)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, wantTitle, result.title)
|
||||
require.True(t, result.hasMessages)
|
||||
require.Equal(t, tt.wantAPIKeyID, result.activeAPIKeyID)
|
||||
require.Equal(t, tt.wantAPIKeyID, testutil.RequireReceive(ctx, t, seenAPIKeyID))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveManualTitleModel_TitleGenerationOverrideSetUnusable(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user