From 5968c3dac7efccf3ac593a69f63a4c227295bd5a Mon Sep 17 00:00:00 2001 From: Michael Suchacz <203725896+ibetitsmike@users.noreply.github.com> Date: Fri, 22 May 2026 02:17:09 +0200 Subject: [PATCH] feat: use AI provider keys at runtime (#25414) --- coderd/coderd.go | 2 + coderd/database/dbauthz/dbauthz.go | 1 + coderd/database/dbgen/dbgen.go | 20 ++ coderd/x/chatd/chatd.go | 271 ++++++++++++++++---- coderd/x/chatd/chatd_internal_test.go | 38 ++- coderd/x/chatd/chatprovider/chatprovider.go | 8 + coderd/x/chatd/quickgen.go | 8 +- coderd/x/chatd/subagent.go | 56 ++-- coderd/x/chatd/subagent_catalog.go | 6 +- coderd/x/chatd/subagent_internal_test.go | 136 +++++++++- coderd/x/chatd/title_override.go | 48 +++- coderd/x/chatd/title_override_test.go | 77 +++++- 12 files changed, 567 insertions(+), 104 deletions(-) diff --git a/coderd/coderd.go b/coderd/coderd.go index 9f1d85bca7..875fd3affd 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -813,6 +813,8 @@ func New(options *Options) *API { SubscribeFn: options.ChatSubscribeFn, MaxChatsPerAcquire: int32(maxChatsPerAcquire), //nolint:gosec // maxChatsPerAcquire is clamped to int32 range above. ProviderAPIKeys: providerAPIKeys, + AllowBYOK: options.DeploymentValues.AI.BridgeConfig.AllowBYOK.Value(), + AllowBYOKSet: true, AlwaysEnableDebugLogs: options.DeploymentValues.AI.Chat.DebugLoggingEnabled.Value(), AgentConn: api.agentProvider.AgentConn, AgentInactiveDisconnectTimeout: api.AgentInactiveDisconnectTimeout, diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index a1ac393df5..b0d5fc17fa 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -710,6 +710,7 @@ var ( Identifier: rbac.RoleIdentifier{Name: "chatd"}, DisplayName: "Chat Daemon", Site: rbac.Permissions(map[string][]policy.Action{ + rbac.ResourceAIProvider.Type: {policy.ActionRead}, rbac.ResourceChat.Type: {policy.ActionCreate, policy.ActionRead, policy.ActionUpdate, policy.ActionDelete}, rbac.ResourceWorkspace.Type: {policy.ActionRead, policy.ActionUpdate}, rbac.ResourceDeploymentConfig.Type: {policy.ActionRead}, diff --git a/coderd/database/dbgen/dbgen.go b/coderd/database/dbgen/dbgen.go index a132249ba5..e037327b47 100644 --- a/coderd/database/dbgen/dbgen.go +++ b/coderd/database/dbgen/dbgen.go @@ -225,6 +225,26 @@ func AIProviderKey(t testing.TB, db database.Store, seed database.AIProviderKey, return key } +// AIProviderWithOptionalKey inserts an AI provider and, when apiKey is not +// empty, inserts a provider-scoped key for it. +func AIProviderWithOptionalKey( + t testing.TB, + db database.Store, + seed database.AIProvider, + apiKey string, + munge ...func(*database.InsertAIProviderParams), +) database.AIProvider { + t.Helper() + provider := AIProvider(t, db, seed, munge...) + if apiKey != "" { + AIProviderKey(t, db, database.AIProviderKey{ + ProviderID: provider.ID, + APIKey: apiKey, + }) + } + return provider +} + func ChatProvider(t testing.TB, db database.Store, seed database.ChatProvider, munge ...func(*database.InsertChatProviderParams)) database.ChatProvider { t.Helper() params := database.InsertChatProviderParams{ diff --git a/coderd/x/chatd/chatd.go b/coderd/x/chatd/chatd.go index 80468b03f5..2172e54731 100644 --- a/coderd/x/chatd/chatd.go +++ b/coderd/x/chatd/chatd.go @@ -230,6 +230,7 @@ type Server struct { pubsub pubsub.Pubsub webpushDispatcher webpush.Dispatcher providerAPIKeys chatprovider.ProviderAPIKeys + allowBYOK bool oidcTokenSource mcpclient.UserOIDCTokenSource debugSvc *chatdebug.Service debugSvcFactory func() *chatdebug.Service @@ -3009,9 +3010,9 @@ func (p *Server) RegenerateChatTitle( // keeping chat ownership authorization at the HTTP layer. //nolint:gocritic // Non-admin users need chatd-scoped config reads here. chatdCtx := dbauthz.AsChatd(ctx) - keys, err := p.resolveUserProviderAPIKeys(chatdCtx, chat.OwnerID) + keys, err := p.resolveUserProviderAPIKeys(chatdCtx, chat.OwnerID, uuid.Nil) if err != nil { - return database.Chat{}, xerrors.Errorf("resolve chat providers: %w", err) + keys = chatprovider.ProviderAPIKeys{} } if err := p.acquireManualTitleLock(ctx, chat.ID); err != nil { return database.Chat{}, err @@ -3073,9 +3074,9 @@ func (p *Server) ProposeChatTitle( ) (string, error) { //nolint:gocritic // Non-admin users need chatd-scoped config reads here. chatdCtx := dbauthz.AsChatd(ctx) - keys, err := p.resolveUserProviderAPIKeys(chatdCtx, chat.OwnerID) + keys, err := p.resolveUserProviderAPIKeys(chatdCtx, chat.OwnerID, uuid.Nil) if err != nil { - return "", xerrors.Errorf("resolve chat providers: %w", err) + keys = chatprovider.ProviderAPIKeys{} } if err := p.acquireManualTitleLock(ctx, chat.ID); err != nil { return "", err @@ -3161,7 +3162,7 @@ func (p *Server) generateManualTitleCandidate( return manualTitleCandidateResult{}, nil } - model, modelConfig, err := p.resolveManualTitleModel(ctx, store, chat, keys) + model, modelConfig, modelKeys, err := p.resolveManualTitleModel(ctx, store, chat, keys) result := manualTitleCandidateResult{ modelConfig: modelConfig, hasMessages: true, @@ -3179,7 +3180,7 @@ func (p *Server) generateManualTitleCandidate( debugSvc, chat, modelConfig, - keys, + modelKeys, messages, model, ) @@ -3541,15 +3542,15 @@ func (p *Server) resolveManualTitleModel( store database.Store, chat database.Chat, keys chatprovider.ProviderAPIKeys, -) (fantasy.LanguageModel, database.ChatModelConfig, error) { - overrideConfig, overrideModel, overrideSet, overrideErr := p.resolveTitleGenerationModelOverride( +) (fantasy.LanguageModel, database.ChatModelConfig, chatprovider.ProviderAPIKeys, error) { + overrideConfig, overrideModel, overrideKeys, overrideSet, overrideErr := p.resolveTitleGenerationModelOverride( ctx, chat, keys, ) if overrideErr != nil { if overrideSet { - return nil, database.ChatModelConfig{}, xerrors.Errorf( + return nil, database.ChatModelConfig{}, chatprovider.ProviderAPIKeys{}, xerrors.Errorf( "resolve manual title generation model override: %w", overrideErr, ) @@ -3559,7 +3560,7 @@ func (p *Server) resolveManualTitleModel( slog.Error(overrideErr), ) } else if overrideSet { - return overrideModel, overrideConfig, nil + return overrideModel, overrideConfig, overrideKeys, nil } configs, err := store.GetEnabledChatModelConfigs(ctx) @@ -3576,14 +3577,7 @@ func (p *Server) resolveManualTitleModel( return p.resolveFallbackManualTitleModel(ctx, chat, keys) } - model, err := chatprovider.ModelFromConfig( - config.Provider, - config.Model, - keys, - chatprovider.UserAgent(), - chatprovider.CoderHeaders(chat), - nil, - ) + providerHint, modelKeys, err := p.resolveModelConfigProviderHintAndKeys(ctx, chat.OwnerID, config, keys) if err != nil { p.logger.Debug(ctx, "manual title preferred model unavailable", slog.F("chat_id", chat.ID), @@ -3593,37 +3587,58 @@ func (p *Server) resolveManualTitleModel( ) return p.resolveFallbackManualTitleModel(ctx, chat, keys) } + model, err := chatprovider.ModelFromConfig( + providerHint, + config.Model, + modelKeys, + chatprovider.UserAgent(), + chatprovider.CoderHeaders(chat), + nil, + ) + if err != nil { + p.logger.Debug(ctx, "manual title preferred model unavailable", + slog.F("chat_id", chat.ID), + slog.F("provider", providerHint), + slog.F("model", config.Model), + slog.Error(err), + ) + return p.resolveFallbackManualTitleModel(ctx, chat, keys) + } - return model, config, nil + return model, config, modelKeys, nil } func (p *Server) resolveFallbackManualTitleModel( ctx context.Context, chat database.Chat, keys chatprovider.ProviderAPIKeys, -) (fantasy.LanguageModel, database.ChatModelConfig, error) { +) (fantasy.LanguageModel, database.ChatModelConfig, chatprovider.ProviderAPIKeys, error) { config, err := p.resolveModelConfig(ctx, chat) if err != nil { - return nil, database.ChatModelConfig{}, xerrors.Errorf( + return nil, database.ChatModelConfig{}, chatprovider.ProviderAPIKeys{}, xerrors.Errorf( "resolve fallback manual title model config: %w", err, ) } + providerHint, modelKeys, err := p.resolveModelConfigProviderHintAndKeys(ctx, chat.OwnerID, config, keys) + if err != nil { + return nil, database.ChatModelConfig{}, chatprovider.ProviderAPIKeys{}, err + } model, err := chatprovider.ModelFromConfig( - config.Provider, + providerHint, config.Model, - keys, + modelKeys, chatprovider.UserAgent(), chatprovider.CoderHeaders(chat), nil, ) if err != nil { - return nil, database.ChatModelConfig{}, xerrors.Errorf( + return nil, database.ChatModelConfig{}, chatprovider.ProviderAPIKeys{}, xerrors.Errorf( "create fallback manual title model: %w", err, ) } - return model, config, nil + return model, config, modelKeys, nil } func mergeManualTitleMessages( @@ -4038,6 +4053,8 @@ type Config struct { StopWorkspace chattool.StopWorkspaceFn Pubsub pubsub.Pubsub ProviderAPIKeys chatprovider.ProviderAPIKeys + AllowBYOK bool + AllowBYOKSet bool AlwaysEnableDebugLogs bool WebpushDispatcher webpush.Dispatcher UsageTracker *workspacestats.UsageTracker @@ -4092,6 +4109,11 @@ func New(cfg Config) *Server { workerID = uuid.New() } + allowBYOK := true + if cfg.AllowBYOKSet { + allowBYOK = cfg.AllowBYOK + } + p := &Server{ cancel: cancel, db: cfg.Database, @@ -4108,6 +4130,7 @@ func New(cfg Config) *Server { pubsub: cfg.Pubsub, webpushDispatcher: cfg.WebpushDispatcher, providerAPIKeys: cfg.ProviderAPIKeys, + allowBYOK: allowBYOK, oidcTokenSource: cfg.OIDCTokenSource, debugSvcFactory: func() *chatdebug.Service { debugSvc := chatdebug.NewService( @@ -7079,10 +7102,11 @@ func (p *Server) runChat( // Fire title generation asynchronously so it doesn't block the // chat response. It uses a detached context so it can finish // even after the chat processing context is canceled. - // Snapshot model, logger, and ctx before launch; all three get - // reassigned below (model = cuModel, logger = logger.With(...), - // ctx = runCtx) and the goroutine captures by reference. + // Snapshot model, provider keys, logger, and ctx before launch; all four get + // reassigned below (model = cuModel, providerKeys = computerUseProviderKeys, + // logger = logger.With(...), ctx = runCtx) and the goroutine captures by reference. titleModel := model + titleProviderKeys := providerKeys titleLogger := logger titleCtx := context.WithoutCancel(ctx) p.inflight.Add(1) @@ -7095,7 +7119,7 @@ func (p *Server) runChat( modelConfig.Provider, modelConfig.Model, titleModel, - providerKeys, + titleProviderKeys, generatedTitle, titleLogger, debugSvc, @@ -7665,6 +7689,12 @@ func (p *Server) runChat( } if isComputerUse { + computerUseProviderKeys, keyErr := p.resolveUserProviderAPIKeysForProviderType(ctx, chat.OwnerID, computerUseModelProvider) + if keyErr != nil { + return result, xerrors.Errorf("resolve computer use provider API keys: %w", keyErr) + } + providerKeys = computerUseProviderKeys + // Override model for computer use subagent. cuModel, cuDebugEnabled, resolvedProvider, resolvedModel, cuErr := p.resolveComputerUseModel( ctx, @@ -8372,6 +8402,38 @@ func (p *Server) persistChatContextSummary( return nil } +func (p *Server) resolveModelConfigProviderHintAndKeys( + ctx context.Context, + ownerID uuid.UUID, + modelConfig database.ChatModelConfig, + fallbackKeys chatprovider.ProviderAPIKeys, +) (string, chatprovider.ProviderAPIKeys, error) { + providerHint := modelConfig.Provider + if !modelConfig.AIProviderID.Valid { + if !fallbackKeys.Empty() && userCanUseProviderKeys(fallbackKeys, providerHint) { + return providerHint, fallbackKeys, nil + } + keys, err := p.resolveUserProviderAPIKeys(ctx, ownerID, uuid.Nil) + if err != nil { + return "", chatprovider.ProviderAPIKeys{}, xerrors.Errorf("resolve provider API keys: %w", err) + } + return providerHint, keys, nil + } + //nolint:gocritic // Manual title generation needs chatd-scoped provider reads for user-owned chats. + provider, err := p.db.GetAIProviderByID(dbauthz.AsChatd(ctx), modelConfig.AIProviderID.UUID) + if err != nil { + return "", chatprovider.ProviderAPIKeys{}, xerrors.Errorf("get AI provider: %w", err) + } + if !provider.Enabled { + return "", chatprovider.ProviderAPIKeys{}, xerrors.Errorf("AI provider %s is disabled", provider.ID) + } + providerKeys, err := p.resolveUserProviderAPIKeysForProvider(ctx, ownerID, provider) + if err != nil { + return "", chatprovider.ProviderAPIKeys{}, xerrors.Errorf("resolve provider API keys: %w", err) + } + return string(provider.Type), providerKeys, nil +} + func (p *Server) resolveChatModel( ctx context.Context, chat database.Chat, @@ -8384,30 +8446,37 @@ func (p *Server) resolveChatModel( resolvedModel string, err error, ) { - var g errgroup.Group - g.Go(func() error { - var err error - dbConfig, err = p.resolveModelConfig(ctx, chat) + dbConfig, err = p.resolveModelConfig(ctx, chat) + if err != nil { + return nil, database.ChatModelConfig{}, chatprovider.ProviderAPIKeys{}, false, "", "", xerrors.Errorf("resolve model config: %w", err) + } + + if !dbConfig.Enabled { + return nil, database.ChatModelConfig{}, chatprovider.ProviderAPIKeys{}, false, "", "", xerrors.Errorf("chat model config %s is disabled", dbConfig.ID) + } + + providerHint := dbConfig.Provider + var keyErr error + if dbConfig.AIProviderID.Valid { + provider, err := p.db.GetAIProviderByID(ctx, dbConfig.AIProviderID.UUID) if err != nil { - return xerrors.Errorf("resolve model config: %w", err) + return nil, database.ChatModelConfig{}, chatprovider.ProviderAPIKeys{}, false, "", "", xerrors.Errorf("get AI provider: %w", err) } - return nil - }) - g.Go(func() error { - var err error - keys, err = p.resolveUserProviderAPIKeys(ctx, chat.OwnerID) - if err != nil { - return xerrors.Errorf("resolve provider API keys: %w", err) + if !provider.Enabled { + return nil, database.ChatModelConfig{}, chatprovider.ProviderAPIKeys{}, false, "", "", xerrors.Errorf("AI provider %s is disabled", provider.ID) } - return nil - }) - if err := g.Wait(); err != nil { - return nil, database.ChatModelConfig{}, chatprovider.ProviderAPIKeys{}, false, "", "", err + providerHint = string(provider.Type) + keys, keyErr = p.resolveUserProviderAPIKeysForProvider(ctx, chat.OwnerID, provider) + } else { + keys, keyErr = p.resolveUserProviderAPIKeys(ctx, chat.OwnerID, uuid.Nil) + } + if keyErr != nil { + return nil, database.ChatModelConfig{}, chatprovider.ProviderAPIKeys{}, false, "", "", xerrors.Errorf("resolve provider API keys: %w", keyErr) } resolvedProvider, resolvedModel, err = chatprovider.ResolveModelWithProviderHint( dbConfig.Model, - dbConfig.Provider, + providerHint, ) if err != nil { return nil, database.ChatModelConfig{}, chatprovider.ProviderAPIKeys{}, false, "", "", xerrors.Errorf( @@ -8418,7 +8487,7 @@ func (p *Server) resolveChatModel( model, debugEnabled, err = p.newDebugAwareModelFromConfig( ctx, chat, - dbConfig.Provider, + providerHint, dbConfig.Model, keys, chatprovider.UserAgent(), @@ -8432,10 +8501,113 @@ func (p *Server) resolveChatModel( return model, dbConfig, keys, debugEnabled, resolvedProvider, resolvedModel, nil } +func (p *Server) aiProviderConfig(ctx context.Context, provider database.AIProvider) (chatprovider.ConfiguredProvider, error) { + if !provider.Enabled { + return chatprovider.ConfiguredProvider{}, xerrors.Errorf("AI provider %s is disabled", provider.ID) + } + keys, err := p.db.GetAIProviderKeysByProviderID(ctx, provider.ID) + if err != nil { + return chatprovider.ConfiguredProvider{}, xerrors.Errorf("get AI provider keys: %w", err) + } + apiKey := "" + // GetAIProviderKeysByProviderID orders keys oldest first. chatd consumes + // one provider-scoped key because runtime provider config has one API key slot. + for _, key := range keys { + if key.APIKey != "" { + apiKey = key.APIKey + break + } + } + return chatprovider.ConfiguredProvider{ + ProviderID: provider.ID, + Provider: string(provider.Type), + APIKey: apiKey, + BaseURL: provider.BaseUrl, + CentralAPIKeyEnabled: true, + AllowUserAPIKey: p.allowBYOK, + AllowCentralAPIKeyFallback: true, + }, nil +} + +func (p *Server) resolveUserProviderAPIKeysForProvider( + ctx context.Context, + ownerID uuid.UUID, + provider database.AIProvider, +) (chatprovider.ProviderAPIKeys, error) { + configuredProvider, err := p.aiProviderConfig(ctx, provider) + if err != nil { + return chatprovider.ProviderAPIKeys{}, err + } + userKeys := []chatprovider.UserProviderKey{} + if p.allowBYOK { + userKey, err := p.db.GetUserAIProviderKeyByProviderID(ctx, database.GetUserAIProviderKeyByProviderIDParams{ + UserID: ownerID, + AIProviderID: provider.ID, + }) + if err != nil && !xerrors.Is(err, sql.ErrNoRows) { + return chatprovider.ProviderAPIKeys{}, xerrors.Errorf("get user AI provider key: %w", err) + } + if err == nil { + userKeys = append(userKeys, chatprovider.UserProviderKey{ + ChatProviderID: userKey.AIProviderID, + APIKey: userKey.APIKey, + }) + } + } + keys, _ := chatprovider.ResolveUserProviderKeys( + chatprovider.ProviderAPIKeys{}, + []chatprovider.ConfiguredProvider{configuredProvider}, + userKeys, + ) + enabledProviders := map[string]struct{}{} + normalizedProvider := chatprovider.NormalizeProvider(configuredProvider.Provider) + if normalizedProvider != "" { + enabledProviders[normalizedProvider] = struct{}{} + } + chatprovider.PruneDisabledProviderKeys(&keys, enabledProviders) + return keys, nil +} + +func (p *Server) resolveUserProviderAPIKeysForProviderType( + ctx context.Context, + ownerID uuid.UUID, + providerType string, +) (chatprovider.ProviderAPIKeys, error) { + providers, err := p.db.GetAIProviders(ctx, database.GetAIProvidersParams{}) + if err != nil { + return chatprovider.ProviderAPIKeys{}, xerrors.Errorf("get enabled AI providers: %w", err) + } + normalizedProviderType := chatprovider.NormalizeProvider(providerType) + for _, provider := range providers { + if chatprovider.NormalizeProvider(string(provider.Type)) != normalizedProviderType { + continue + } + keys, err := p.resolveUserProviderAPIKeysForProvider(ctx, ownerID, provider) + if err != nil { + return chatprovider.ProviderAPIKeys{}, err + } + if userCanUseProviderKeys(keys, normalizedProviderType) { + return keys, nil + } + } + return p.resolveUserProviderAPIKeys(ctx, ownerID, uuid.Nil) +} + func (p *Server) resolveUserProviderAPIKeys( ctx context.Context, ownerID uuid.UUID, + selectedAIProviderID uuid.UUID, ) (chatprovider.ProviderAPIKeys, error) { + var configuredProviders []chatprovider.ConfiguredProvider + userKeys := []chatprovider.UserProviderKey{} + + if selectedAIProviderID != uuid.Nil { + provider, err := p.db.GetAIProviderByID(ctx, selectedAIProviderID) + if err != nil { + return chatprovider.ProviderAPIKeys{}, xerrors.Errorf("get AI provider: %w", err) + } + return p.resolveUserProviderAPIKeysForProvider(ctx, ownerID, provider) + } providers, err := p.configCache.EnabledProviders(ctx) if err != nil { return chatprovider.ProviderAPIKeys{}, xerrors.Errorf( @@ -8443,7 +8615,7 @@ func (p *Server) resolveUserProviderAPIKeys( err, ) } - configuredProviders := make( + configuredProviders = make( []chatprovider.ConfiguredProvider, 0, len(providers), ) for _, provider := range providers { @@ -8466,8 +8638,6 @@ func (p *Server) resolveUserProviderAPIKeys( break } } - - userKeys := []chatprovider.UserProviderKey{} if allowAnyUserAPIKey { userKeyRows, err := p.db.GetUserChatProviderKeys(ctx, ownerID) if err != nil { @@ -8484,6 +8654,7 @@ func (p *Server) resolveUserProviderAPIKeys( }) } } + keys, _ := chatprovider.ResolveUserProviderKeys( p.providerAPIKeys, configuredProviders, diff --git a/coderd/x/chatd/chatd_internal_test.go b/coderd/x/chatd/chatd_internal_test.go index 6136fad2cf..cbdd7a21dc 100644 --- a/coderd/x/chatd/chatd_internal_test.go +++ b/coderd/x/chatd/chatd_internal_test.go @@ -1116,7 +1116,7 @@ func TestResolveUserProviderAPIKeys_StripsDisabledFallbackKeys(t *testing.T) { AllowCentralApiKeyFallback: true, }}, nil) - keys, err := server.resolveUserProviderAPIKeys(ctx, ownerID) + keys, err := server.resolveUserProviderAPIKeys(ctx, ownerID, uuid.Nil) require.NoError(t, err) require.Empty(t, keys.OpenAI) require.Empty(t, keys.APIKey("openai")) @@ -1128,6 +1128,40 @@ func TestResolveUserProviderAPIKeys_StripsDisabledFallbackKeys(t *testing.T) { require.Equal(t, map[string]string{"anthropic": "https://anthropic.example.com"}, keys.BaseURLByProvider) } +func TestResolveUserProviderAPIKeys_SelectedAIProviderDoesNotUseDeploymentFallback(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitShort) + ctrl := gomock.NewController(t) + db := dbmock.NewMockStore(ctrl) + ownerID := uuid.New() + providerID := uuid.New() + + server := &Server{ + db: db, + providerAPIKeys: chatprovider.ProviderAPIKeys{ + OpenAI: "openai-deployment-key", + ByProvider: map[string]string{ + "openai": "openai-deployment-key", + }, + }, + } + + db.EXPECT().GetAIProviderByID(gomock.Any(), providerID).Return(database.AIProvider{ + ID: providerID, + Type: database.AiProviderTypeOpenai, + Name: "agents-openai", + Enabled: true, + }, nil) + db.EXPECT().GetAIProviderKeysByProviderID(gomock.Any(), providerID).Return(nil, nil) + + keys, err := server.resolveUserProviderAPIKeys(ctx, ownerID, providerID) + require.NoError(t, err) + require.Empty(t, keys.OpenAI) + require.Empty(t, keys.APIKey("openai")) + require.False(t, keys.HasProvider("openai")) +} + func TestResolveUserProviderAPIKeys_SkipsUserKeyLookupWhenNoProviderAllowsUserKeys(t *testing.T) { t.Parallel() @@ -1156,7 +1190,7 @@ func TestResolveUserProviderAPIKeys_SkipsUserKeyLookupWhenNoProviderAllowsUserKe CentralApiKeyEnabled: true, }}, nil) - keys, err := server.resolveUserProviderAPIKeys(ctx, ownerID) + keys, err := server.resolveUserProviderAPIKeys(ctx, ownerID, uuid.Nil) require.NoError(t, err) require.Equal(t, "openai-deployment-key", keys.OpenAI) require.Equal(t, "openai-deployment-key", keys.APIKey("openai")) diff --git a/coderd/x/chatd/chatprovider/chatprovider.go b/coderd/x/chatd/chatprovider/chatprovider.go index 2a43bcdab1..030d4f82b1 100644 --- a/coderd/x/chatd/chatprovider/chatprovider.go +++ b/coderd/x/chatd/chatprovider/chatprovider.go @@ -104,6 +104,14 @@ type ProviderAPIKeys struct { BaseURLByProvider map[string]string } +// Empty reports whether no provider keys or base URL overrides are set. +func (k ProviderAPIKeys) Empty() bool { + return k.OpenAI == "" && + k.Anthropic == "" && + len(k.ByProvider) == 0 && + len(k.BaseURLByProvider) == 0 +} + // UserProviderKey is a user-supplied API key for a specific provider. type UserProviderKey struct { ChatProviderID uuid.UUID diff --git a/coderd/x/chatd/quickgen.go b/coderd/x/chatd/quickgen.go index 0acfd7a941..a4e69bcbad 100644 --- a/coderd/x/chatd/quickgen.go +++ b/coderd/x/chatd/quickgen.go @@ -135,7 +135,7 @@ func (p *Server) maybeGenerateChatTitle( titleCtx, cancel := context.WithTimeout(ctx, 30*time.Second) defer cancel() - overrideConfig, overrideModel, overrideSet, overrideErr := p.resolveTitleGenerationModelOverride( + overrideConfig, overrideModel, overrideKeys, overrideSet, overrideErr := p.resolveTitleGenerationModelOverride( titleCtx, chat, keys, @@ -213,11 +213,15 @@ func (p *Server) maybeGenerateChatTitle( candidateCtx := titleCtx candidateModel := candidate.lm finishDebugRun := func(error) {} + candidateKeys := keys + if overrideSet { + candidateKeys = overrideKeys + } if debugEnabled { candidateCtx, candidateModel, finishDebugRun = prepareQuickgenDebugCandidate( titleCtx, chat, - keys, + candidateKeys, debugSvc, candidate, chatdebug.KindTitleGeneration, diff --git a/coderd/x/chatd/subagent.go b/coderd/x/chatd/subagent.go index dd113240d1..1f557f5eee 100644 --- a/coderd/x/chatd/subagent.go +++ b/coderd/x/chatd/subagent.go @@ -38,6 +38,7 @@ type modelOverrideConfigResolver func( type modelOverrideProviderKeysResolver func( context.Context, uuid.UUID, + uuid.UUID, ) (chatprovider.ProviderAPIKeys, error) const ( @@ -77,30 +78,6 @@ type closeAgentArgs struct { ChatID string `json:"chat_id"` } -// providerConfigured reports whether a provider has an API key from -// static configuration or from the database provider configuration. -func (p *Server) providerConfigured(ctx context.Context, provider string) (bool, error) { - normalizedProvider := chatprovider.NormalizeProvider(provider) - if normalizedProvider == "" { - return false, nil - } - if p.providerAPIKeys.APIKey(normalizedProvider) != "" { - return true, nil - } - - dbProviders, err := p.configCache.EnabledProviders(ctx) - if err != nil { - return false, xerrors.Errorf("list enabled chat providers: %w", err) - } - for _, prov := range dbProviders { - if chatprovider.NormalizeProvider(prov.Provider) == normalizedProvider && - strings.TrimSpace(prov.APIKey) != "" { - return true, nil - } - } - return false, nil -} - func (p *Server) isDesktopEnabled(ctx context.Context) bool { enabled, err := p.db.GetChatDesktopEnabled(ctx) if err != nil { @@ -292,7 +269,7 @@ func (p *Server) resolveConfiguredModelOverride( return database.ChatModelConfig{}, false, nil } - providerKeys, err := resolveProviderKeys(ctx, ownerID) + providerKeys, err := resolveProviderKeys(ctx, ownerID, modelConfigAIProviderID(modelConfig)) if err != nil { return database.ChatModelConfig{}, false, xerrors.Errorf( "resolve provider API keys: %w", @@ -425,7 +402,7 @@ func (p *Server) resolvePersonalModelOverride( } return database.ChatModelConfig{}, false, nil } - providerKeys, err := p.resolveUserProviderAPIKeys(ctx, ownerID) + providerKeys, err := p.resolveUserProviderAPIKeys(ctx, ownerID, modelConfigAIProviderID(modelConfig)) if err != nil { return database.ChatModelConfig{}, false, xerrors.Errorf( "resolve provider API keys: %w", @@ -499,6 +476,13 @@ func (p *Server) resolveSubagentModelConfigID( return modelConfig.ID, nil } +func modelConfigAIProviderID(modelConfig database.ChatModelConfig) uuid.UUID { + if !modelConfig.AIProviderID.Valid { + return uuid.Nil + } + return modelConfig.AIProviderID.UUID +} + func (p *Server) resolveModelConfigAndNormalizedProvider( ctx context.Context, modelConfigID uuid.UUID, @@ -510,6 +494,26 @@ func (p *Server) resolveModelConfigAndNormalizedProvider( if err != nil { return database.ChatModelConfig{}, "", err } + if !modelConfig.Enabled { + return database.ChatModelConfig{}, "", sql.ErrNoRows + } + if modelConfig.AIProviderID.Valid { + provider, err := p.db.GetAIProviderByID(ctx, modelConfig.AIProviderID.UUID) + if err != nil { + return database.ChatModelConfig{}, "", err + } + if !provider.Enabled { + return database.ChatModelConfig{}, "", sql.ErrNoRows + } + providerName := chatprovider.NormalizeProvider(string(provider.Type)) + if providerName == "" { + return database.ChatModelConfig{}, "", errInvalidModelOverrideMetadata + } + if _, _, err := chatprovider.ResolveModelWithProviderHint(modelConfig.Model, providerName); err != nil { + return database.ChatModelConfig{}, "", errInvalidModelOverrideMetadata + } + return modelConfig, providerName, nil + } modelConfig, providerName, err := validateModelConfigAndResolveProvider(modelConfig) if err != nil { return database.ChatModelConfig{}, "", err diff --git a/coderd/x/chatd/subagent_catalog.go b/coderd/x/chatd/subagent_catalog.go index 1d34ad1528..e567631271 100644 --- a/coderd/x/chatd/subagent_catalog.go +++ b/coderd/x/chatd/subagent_catalog.go @@ -127,16 +127,16 @@ func allSubagentDefinitions() []subagentDefinition { } return "" }, - buildOptions: func(ctx context.Context, p *Server, _ database.Chat, _ database.Chat, _ uuid.UUID, prompt string) (childSubagentChatOptions, error) { + buildOptions: func(ctx context.Context, p *Server, currentChat database.Chat, _ database.Chat, _ uuid.UUID, prompt string) (childSubagentChatOptions, error) { provider, _, _, err := p.computerUseProviderAndModelFromConfig(ctx) if err != nil { return childSubagentChatOptions{}, err } - configured, err := p.providerConfigured(ctx, provider) + providerKeys, err := p.resolveUserProviderAPIKeysForProviderType(ctx, currentChat.OwnerID, provider) if err != nil { return childSubagentChatOptions{}, err } - if !configured { + if !userCanUseProviderKeys(providerKeys, provider) { return childSubagentChatOptions{}, xerrors.Errorf( `API key for computer-use provider %q is not configured`, provider, diff --git a/coderd/x/chatd/subagent_internal_test.go b/coderd/x/chatd/subagent_internal_test.go index fffea3bc78..5a0b7fdc89 100644 --- a/coderd/x/chatd/subagent_internal_test.go +++ b/coderd/x/chatd/subagent_internal_test.go @@ -213,6 +213,137 @@ func insertEnabledAnthropicProvider( }) } +func insertInternalAIProvider( + t *testing.T, + db database.Store, + providerType database.AIProviderType, + apiKey string, + enabled bool, +) database.AIProvider { + t.Helper() + return dbgen.AIProviderWithOptionalKey(t, db, database.AIProvider{ + Type: providerType, + }, apiKey, func(params *database.InsertAIProviderParams) { + params.Enabled = enabled + }) +} + +func TestResolveUserProviderAPIKeys_AIProvider(t *testing.T) { + t.Parallel() + + t.Run("UserKeyWinsWhenBYOKEnabled", func(t *testing.T) { + t.Parallel() + + db, ps := dbtestutil.NewDB(t) + server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{}) + ctx := chatdTestContext(t) + user, _, _ := seedInternalChatDeps(t, db) + provider := insertInternalAIProvider(t, db, database.AiProviderTypeOpenai, "provider-api-key", true) + now := time.Now() + _, err := db.UpsertUserAIProviderKey(ctx, database.UpsertUserAIProviderKeyParams{ + ID: uuid.New(), + UserID: user.ID, + AIProviderID: provider.ID, + APIKey: "user-api-key", + CreatedAt: now, + UpdatedAt: now, + }) + require.NoError(t, err) + + keys, err := server.resolveUserProviderAPIKeys(ctx, user.ID, provider.ID) + require.NoError(t, err) + require.Equal(t, "user-api-key", keys.APIKey("openai")) + require.Equal(t, "https://api.example.com/", keys.BaseURL("openai")) + }) + + t.Run("ProviderKeyUsedWhenBYOKDisabled", func(t *testing.T) { + t.Parallel() + + db, ps := dbtestutil.NewDB(t) + server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{}) + server.allowBYOK = false + ctx := chatdTestContext(t) + user, _, _ := seedInternalChatDeps(t, db) + provider := insertInternalAIProvider(t, db, database.AiProviderTypeOpenai, "provider-api-key", true) + now := time.Now() + _, err := db.UpsertUserAIProviderKey(ctx, database.UpsertUserAIProviderKeyParams{ + ID: uuid.New(), + UserID: user.ID, + AIProviderID: provider.ID, + APIKey: "user-api-key", + CreatedAt: now, + UpdatedAt: now, + }) + require.NoError(t, err) + + keys, err := server.resolveUserProviderAPIKeys(ctx, user.ID, provider.ID) + require.NoError(t, err) + require.Equal(t, "provider-api-key", keys.APIKey("openai")) + }) + + t.Run("ProviderTypeUsesAIProvider", func(t *testing.T) { + t.Parallel() + + db, ps := dbtestutil.NewDB(t) + server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{}) + ctx := chatdTestContext(t) + user, _, _ := seedInternalChatDeps(t, db) + insertInternalAIProvider(t, db, database.AiProviderTypeAzure, "provider-api-key", true) + + keys, err := server.resolveUserProviderAPIKeysForProviderType(ctx, user.ID, "azure") + require.NoError(t, err) + require.Equal(t, "provider-api-key", keys.APIKey("azure")) + }) + + t.Run("BedrockUsesAmbientAuth", func(t *testing.T) { + t.Parallel() + + db, ps := dbtestutil.NewDB(t) + server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{}) + ctx := chatdTestContext(t) + user, _, _ := seedInternalChatDeps(t, db) + provider := insertInternalAIProvider(t, db, database.AiProviderTypeBedrock, "", true) + + keys, err := server.resolveUserProviderAPIKeys(ctx, user.ID, provider.ID) + require.NoError(t, err) + require.True(t, keys.HasProvider("bedrock")) + require.Empty(t, keys.APIKey("bedrock")) + }) +} + +func TestResolveChatModel_AIProviderDisabled(t *testing.T) { + t.Parallel() + + ctx := chatdTestContext(t) + db, ps, sqlDB := dbtestutil.NewDBWithSQLDB(t) + user, org, _ := seedInternalChatDeps(t, db) + provider := insertInternalAIProvider(t, db, database.AiProviderTypeOpenai, "provider-api-key", false) + modelConfig := dbgen.ChatModelConfig(t, db, database.ChatModelConfig{ + Provider: "openai", + Model: "gpt-4o-mini", + }) + _, err := sqlDB.ExecContext(ctx, "UPDATE chat_model_configs SET ai_provider_id = $1 WHERE id = $2", provider.ID, modelConfig.ID) + require.NoError(t, err) + loadedModelConfig, err := db.GetChatModelConfigByID(ctx, modelConfig.ID) + require.NoError(t, err) + require.True(t, loadedModelConfig.AIProviderID.Valid) + server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{}) + chat := dbgen.Chat(t, db, database.Chat{ + OrganizationID: org.ID, + OwnerID: user.ID, + LastModelConfigID: modelConfig.ID, + }) + + model, config, keys, debugEnabled, resolvedProvider, resolvedModel, err := server.resolveChatModel(ctx, chat) + require.ErrorContains(t, err, "is disabled") + require.Nil(t, model) + require.Equal(t, database.ChatModelConfig{}, config) + require.Equal(t, chatprovider.ProviderAPIKeys{}, keys) + require.False(t, debugEnabled) + require.Empty(t, resolvedProvider) + require.Empty(t, resolvedModel) +} + func TestResolveUserProviderAPIKeys_PreservesAnthropicKeyFromDBProvider(t *testing.T) { t.Parallel() @@ -226,7 +357,7 @@ func TestResolveUserProviderAPIKeys_PreservesAnthropicKeyFromDBProvider(t *testi user, _, _ := seedInternalChatDeps(t, db) insertEnabledAnthropicProvider(t, db, user.ID) - keys, err := server.resolveUserProviderAPIKeys(ctx, user.ID) + keys, err := server.resolveUserProviderAPIKeys(ctx, user.ID, uuid.Nil) require.NoError(t, err) require.Equal(t, "test-anthropic-key", keys.Anthropic) require.Equal(t, "test-anthropic-key", keys.APIKey("anthropic")) @@ -244,7 +375,7 @@ func TestResolveUserProviderAPIKeys_PreservesAnthropicKeyFromDBProvider(t *testi ctx := chatdTestContext(t) user, _, _ := seedInternalChatDeps(t, db) - keys, err := server.resolveUserProviderAPIKeys(ctx, user.ID) + keys, err := server.resolveUserProviderAPIKeys(ctx, user.ID, uuid.Nil) require.NoError(t, err) require.Empty(t, keys.Anthropic) require.Empty(t, keys.APIKey("anthropic")) @@ -1065,6 +1196,7 @@ func TestResolveConfiguredModelOverride_AcceptsAmbientCredentialsProvider( func( _ context.Context, resolvedOwnerID uuid.UUID, + _ uuid.UUID, ) (chatprovider.ProviderAPIKeys, error) { require.Equal(t, ownerID, resolvedOwnerID) return chatprovider.ProviderAPIKeys{ diff --git a/coderd/x/chatd/title_override.go b/coderd/x/chatd/title_override.go index b01bc1613b..9214b44254 100644 --- a/coderd/x/chatd/title_override.go +++ b/coderd/x/chatd/title_override.go @@ -49,52 +49,78 @@ func (p *Server) resolveTitleGenerationModelOverride( ctx context.Context, chat database.Chat, keys chatprovider.ProviderAPIKeys, -) (database.ChatModelConfig, fantasy.LanguageModel, bool, error) { +) (database.ChatModelConfig, fantasy.LanguageModel, chatprovider.ProviderAPIKeys, bool, error) { raw, err := readTitleGenerationModelOverride(ctx, p.db) if err != nil { - return database.ChatModelConfig{}, nil, false, xerrors.Errorf( + return database.ChatModelConfig{}, nil, chatprovider.ProviderAPIKeys{}, false, xerrors.Errorf( "read title generation model override: %w", err, ) } + overrideProviderKeys := keys modelConfig, overrideSet, err := p.resolveConfiguredModelOverride( ctx, titleGenerationOverrideContext, raw, chat.OwnerID, p.resolveModelConfigAndNormalizedProvider, - func(context.Context, uuid.UUID) (chatprovider.ProviderAPIKeys, error) { - return keys, nil + func(ctx context.Context, ownerID uuid.UUID, aiProviderID uuid.UUID) (chatprovider.ProviderAPIKeys, error) { + if aiProviderID == uuid.Nil { + resolvedProviderKeys, err := p.resolveUserProviderAPIKeys(ctx, ownerID, uuid.Nil) + if err != nil || resolvedProviderKeys.Empty() { + resolvedProviderKeys = keys + } + overrideProviderKeys = resolvedProviderKeys + return resolvedProviderKeys, nil + } + resolvedProviderKeys, err := p.resolveUserProviderAPIKeys(ctx, ownerID, aiProviderID) + if err != nil { + return chatprovider.ProviderAPIKeys{}, err + } + overrideProviderKeys = resolvedProviderKeys + return resolvedProviderKeys, nil }, modelOverrideFailureModeHard, ) if err != nil { - return database.ChatModelConfig{}, nil, overrideSet, err + return database.ChatModelConfig{}, nil, chatprovider.ProviderAPIKeys{}, overrideSet, err } if !overrideSet { - return database.ChatModelConfig{}, nil, false, nil + return database.ChatModelConfig{}, nil, keys, false, nil } + providerHint := modelConfig.Provider + if modelConfig.AIProviderID.Valid { + //nolint:gocritic // Title overrides need chatd-scoped provider reads for user-owned chats. + provider, err := p.db.GetAIProviderByID(dbauthz.AsChatd(ctx), modelConfig.AIProviderID.UUID) + if err != nil { + return database.ChatModelConfig{}, nil, chatprovider.ProviderAPIKeys{}, true, xerrors.Errorf("get AI provider for title generation override: %w", err) + } + if !provider.Enabled { + return database.ChatModelConfig{}, nil, chatprovider.ProviderAPIKeys{}, true, xerrors.Errorf("AI provider %s is disabled", modelConfig.AIProviderID.UUID) + } + providerHint = string(provider.Type) + } model, err := chatprovider.ModelFromConfig( - modelConfig.Provider, + providerHint, modelConfig.Model, - keys, + overrideProviderKeys, chatprovider.UserAgent(), chatprovider.CoderHeaders(chat), nil, ) if err != nil { - return database.ChatModelConfig{}, nil, true, xerrors.Errorf( + return database.ChatModelConfig{}, nil, chatprovider.ProviderAPIKeys{}, true, xerrors.Errorf( "create title generation model override: %w", err, ) } if model == nil { - return database.ChatModelConfig{}, nil, true, xerrors.Errorf( + return database.ChatModelConfig{}, nil, chatprovider.ProviderAPIKeys{}, true, xerrors.Errorf( "create title generation model override returned nil", ) } - return modelConfig, model, true, nil + return modelConfig, model, overrideProviderKeys, true, nil } diff --git a/coderd/x/chatd/title_override_test.go b/coderd/x/chatd/title_override_test.go index 4fc0b2badc..de2227af96 100644 --- a/coderd/x/chatd/title_override_test.go +++ b/coderd/x/chatd/title_override_test.go @@ -228,6 +228,8 @@ func TestMaybeGenerateChatTitle_TitleGenerationOverrideSetUsable(t *testing.T) { logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) chat, messages := titleOverrideTestChatAndMessages(t) overrideConfig := titleOverrideModelConfig("gpt-4.1", true) + providerID := uuid.New() + overrideConfig.AIProviderID = uuid.NullUUID{UUID: providerID, Valid: true} wantTitle := "Override title" var requestCount atomic.Int32 @@ -236,7 +238,12 @@ func TestMaybeGenerateChatTitle_TitleGenerationOverrideSetUsable(t *testing.T) { require.Equal(t, overrideConfig.Model, req.Model) return chattest.OpenAINonStreamingResponse(`{"title":"` + wantTitle + `"}`) }) - keys := titleOverrideOpenAIKeys(serverURL) + provider := database.AIProvider{ + ID: providerID, + Type: database.AiProviderTypeOpenai, + Enabled: true, + BaseUrl: serverURL, + } fallbackModel := &chattest.FakeModel{ GenerateObjectFn: func(context.Context, fantasy.ObjectCall) (*fantasy.ObjectResponse, error) { t.Fatal("fallback model should not be called when override is usable") @@ -246,7 +253,11 @@ func TestMaybeGenerateChatTitle_TitleGenerationOverrideSetUsable(t *testing.T) { db.EXPECT().GetChatTitleGenerationModelOverride(gomock.Any()).Return(overrideConfig.ID.String(), nil) db.EXPECT().GetChatModelConfigByID(gomock.Any(), overrideConfig.ID).Return(overrideConfig, nil) - db.EXPECT().GetEnabledChatProviders(gomock.Any()).Return([]database.ChatProvider{{Provider: "openai"}}, nil) + db.EXPECT().GetAIProviderByID(gomock.Any(), providerID).Return(provider, nil).AnyTimes() + db.EXPECT().GetAIProviderKeysByProviderID(gomock.Any(), providerID).Return([]database.AIProviderKey{{ + ProviderID: providerID, + APIKey: "test-key", + }}, nil) db.EXPECT().UpdateChatTitleByID(gomock.Any(), database.UpdateChatTitleByIDParams{ ID: chat.ID, Title: wantTitle, @@ -261,7 +272,7 @@ func TestMaybeGenerateChatTitle_TitleGenerationOverrideSetUsable(t *testing.T) { "openai", "fallback-chat-model", fallbackModel, - keys, + chatprovider.ProviderAPIKeys{}, generated, logger, nil, @@ -381,7 +392,7 @@ func TestResolveManualTitleModel_TitleGenerationOverrideUnset(t *testing.T) { }, nil) server := titleOverrideTestServer(db, logger) - model, gotConfig, err := server.resolveManualTitleModel( + model, gotConfig, _, err := server.resolveManualTitleModel( ctx, db, chat, @@ -392,6 +403,56 @@ func TestResolveManualTitleModel_TitleGenerationOverrideUnset(t *testing.T) { require.Equal(t, preferredConfig, gotConfig) } +func TestResolveManualTitleModel_TitleGenerationOverrideUnsetAIProvider(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitShort) + ctrl := gomock.NewController(t) + db := dbmock.NewMockStore(ctrl) + logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}) + chat, _ := titleOverrideTestChatAndMessages(t) + providerID := uuid.New() + preferredConfig := database.ChatModelConfig{ + ID: uuid.New(), + Provider: preferredTitleModels[1].provider, + AIProviderID: uuid.NullUUID{UUID: providerID, Valid: true}, + Model: preferredTitleModels[1].model, + Enabled: true, + } + serverURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse { + t.Fatal("model construction should not call the provider") + return chattest.OpenAIResponse{} + }) + provider := database.AIProvider{ + ID: providerID, + Type: database.AiProviderTypeOpenai, + Enabled: true, + BaseUrl: serverURL, + } + + db.EXPECT().GetChatTitleGenerationModelOverride(gomock.Any()).Return("", nil) + db.EXPECT().GetEnabledChatModelConfigs(gomock.Any()).Return([]database.ChatModelConfig{ + preferredConfig, + }, nil) + db.EXPECT().GetAIProviderByID(gomock.Any(), providerID).Return(provider, nil) + db.EXPECT().GetAIProviderKeysByProviderID(gomock.Any(), providerID).Return([]database.AIProviderKey{{ + ProviderID: providerID, + APIKey: "test-key", + }}, nil) + + server := titleOverrideTestServer(db, logger) + model, gotConfig, gotKeys, err := server.resolveManualTitleModel( + ctx, + db, + chat, + chatprovider.ProviderAPIKeys{}, + ) + require.NoError(t, err) + require.NotNil(t, model) + require.Equal(t, preferredConfig, gotConfig) + require.Equal(t, "test-key", gotKeys.APIKey("openai")) +} + func TestResolveManualTitleModel_TitleGenerationOverrideReadDBError(t *testing.T) { t.Parallel() @@ -414,7 +475,7 @@ func TestResolveManualTitleModel_TitleGenerationOverrideReadDBError(t *testing.T }, nil) server := titleOverrideTestServer(db, logger) - model, gotConfig, err := server.resolveManualTitleModel( + model, gotConfig, _, err := server.resolveManualTitleModel( ctx, db, chat, @@ -440,7 +501,7 @@ func TestResolveManualTitleModel_TitleGenerationOverrideSetUsable(t *testing.T) db.EXPECT().GetEnabledChatProviders(gomock.Any()).Return([]database.ChatProvider{{Provider: "openai"}}, nil) server := titleOverrideTestServer(db, logger) - model, gotConfig, err := server.resolveManualTitleModel( + model, gotConfig, _, err := server.resolveManualTitleModel( ctx, db, chat, @@ -466,7 +527,7 @@ func TestResolveManualTitleModel_TitleGenerationOverrideMissingCredentials(t *te db.EXPECT().GetEnabledChatProviders(gomock.Any()).Return([]database.ChatProvider{{Provider: "openai"}}, nil) server := titleOverrideTestServer(db, logger) - model, gotConfig, err := server.resolveManualTitleModel( + model, gotConfig, _, err := server.resolveManualTitleModel( ctx, db, chat, @@ -493,7 +554,7 @@ func TestResolveManualTitleModel_TitleGenerationOverrideSetUnusable(t *testing.T db.EXPECT().GetChatModelConfigByID(gomock.Any(), overrideConfig.ID).Return(overrideConfig, nil) server := titleOverrideTestServer(db, logger) - model, gotConfig, err := server.resolveManualTitleModel( + model, gotConfig, _, err := server.resolveManualTitleModel( ctx, db, chat,