From 8b1705eb656bd225a868bbf4a12ef89ecc48ac64 Mon Sep 17 00:00:00 2001 From: Michael Suchacz <203725896+ibetitsmike@users.noreply.github.com> Date: Tue, 26 May 2026 21:31:52 +0200 Subject: [PATCH] feat: route chatd provider traffic through aibridge (#25629) ## Summary Routes chatd model calls backed by concrete AI Provider rows through the in-process aibridge transport by default, with deployment options to use direct provider routing when AI Gateway is disabled or chat AI Gateway routing is disabled. - Splits model routing into common, direct provider, and AI Gateway paths behind a single deployment-mode entry point. - Builds chatd models through explicit request, route, and options data. Active API key attribution is passed explicitly instead of being hidden inside generic model construction. - For AI Gateway BYOK routes, resolves the user's provider key in chatd, forwards it through provider-specific auth headers, and sets `X-Coder-AI-Governance-Token` to the `delegated` marker so aibridge preserves those headers while still stripping Coder-specific metadata. - Keeps central provider credentials and deployment fallback credentials out of forwarded provider auth headers, so AI Gateway central policy remains authoritative. - Redacts delegated provider auth from default string formatting to avoid accidental plaintext logging of user BYOK credentials. - Covers selected chat models, advisor overrides, title and quickgen paths, subagent overrides, computer use model selection, and an integration-style chat turn through the aibridge transport path. - Persists initiating API key IDs on chat and queued user messages, including subagent child messages, and fails closed for AI Gateway-routed model builds without an active key. - Removes unused `api_key_id` indexes while keeping the persistence columns and foreign keys. - Keeps the deployment option available through config and env parsing, but hides it from CLI help and generated docs. - Stabilizes the subagent poll fallback test so background CreateChat processing cannot win the state transition under slower CI environments. ## Tests - `go test ./coderd/x/chatd -run 'TestAIGatewayProviderAuthForUser|TestAIGatewayProviderAuthRedactsFormatting|TestResolveModelRouteForConfigAIGatewayProviderAuth|TestAIGatewayModelForwardsProviderAuth|TestProcessChat_AIGatewayRoutingUsesDelegatedAPIKey|TestAwaitSubagentCompletion' -count=1` - `go test ./coderd/aibridged -run 'TestServeHTTP_DelegatedAPIKey|TestServeHTTP_StripCoderToken' -count=1` - `git diff --check HEAD~1..HEAD` - `make lint` > Mux working on behalf of Mike. --- cli/testdata/server-config.yaml.golden | 5 + coderd/coderd.go | 5 + coderd/database/dbgen/dbgen.go | 1 + coderd/database/dump.sql | 12 +- coderd/database/foreign_key_constraint.go | 2 + .../000508_chat_turn_api_key_id.down.sql | 5 + .../000508_chat_turn_api_key_id.up.sql | 8 + coderd/database/models.go | 2 + coderd/database/queries.sql.go | 74 +- coderd/database/queries/chats.sql | 7 +- coderd/exp_chats.go | 3 + coderd/x/chatd/advisor_internal_test.go | 232 ++++++- coderd/x/chatd/chatd.go | 395 ++++++----- coderd/x/chatd/chatd_debug.go | 50 +- coderd/x/chatd/chatd_internal_test.go | 62 +- coderd/x/chatd/chatd_test.go | 270 +++++++- coderd/x/chatd/computer_use.go | 18 +- coderd/x/chatd/model_routing.go | 168 +++++ coderd/x/chatd/model_routing_aibridge.go | 292 ++++++++ coderd/x/chatd/model_routing_direct.go | 93 +++ coderd/x/chatd/model_routing_internal_test.go | 640 ++++++++++++++++++ coderd/x/chatd/quickgen.go | 143 ++-- coderd/x/chatd/quickgen_internal_test.go | 12 + coderd/x/chatd/subagent.go | 21 +- coderd/x/chatd/subagent_internal_test.go | 180 ++++- coderd/x/chatd/title_override.go | 67 +- .../x/chatd/title_override_internal_test.go | 22 +- codersdk/deployment.go | 16 +- codersdk/deployment_test.go | 9 + enterprise/coderd/exp_chats_test.go | 25 + site/src/api/typesGenerated.ts | 1 + 31 files changed, 2463 insertions(+), 377 deletions(-) create mode 100644 coderd/database/migrations/000508_chat_turn_api_key_id.down.sql create mode 100644 coderd/database/migrations/000508_chat_turn_api_key_id.up.sql create mode 100644 coderd/x/chatd/model_routing.go create mode 100644 coderd/x/chatd/model_routing_aibridge.go create mode 100644 coderd/x/chatd/model_routing_direct.go create mode 100644 coderd/x/chatd/model_routing_internal_test.go diff --git a/cli/testdata/server-config.yaml.golden b/cli/testdata/server-config.yaml.golden index 3e73b7ca3b..5bc02e0ae6 100644 --- a/cli/testdata/server-config.yaml.golden +++ b/cli/testdata/server-config.yaml.golden @@ -765,6 +765,11 @@ chat: # opt-in settings. # (default: false, type: bool) debugLoggingEnabled: false + # Route chat model requests through AI Gateway when both chat routing and AI + # Gateway are enabled. Otherwise, chat calls AI providers directly. Pending chats + # without API key metadata may need a retry or temporary direct routing. + # (default: true, type: bool) + aiGatewayRoutingEnabled: true aibridge: # Deprecated: use --ai-gateway-enabled or CODER_AI_GATEWAY_ENABLED instead. # Whether to start an in-memory aibridged instance. diff --git a/coderd/coderd.go b/coderd/coderd.go index 91d95c5ef2..c87adc5647 100644 --- a/coderd/coderd.go +++ b/coderd/coderd.go @@ -807,6 +807,9 @@ func New(options *Options) *API { providerAPIKeys = *options.ChatProviderAPIKeys } + chatAIGatewayRoutingEnabled := options.DeploymentValues.AI.BridgeConfig.Enabled.Value() && + options.DeploymentValues.AI.Chat.AIGatewayRoutingEnabled.Value() + api.chatDaemon = chatd.New(chatd.Config{ Logger: options.Logger.Named("chatd"), Database: options.Database, @@ -816,6 +819,8 @@ func New(options *Options) *API { ProviderAPIKeys: providerAPIKeys, AllowBYOK: options.DeploymentValues.AI.BridgeConfig.AllowBYOK.Value(), AllowBYOKSet: true, + AIBridgeTransportFactory: &api.AIBridgeTransportFactory, + AIGatewayRoutingEnabled: chatAIGatewayRoutingEnabled, AlwaysEnableDebugLogs: options.DeploymentValues.AI.Chat.DebugLoggingEnabled.Value(), AgentConn: api.agentProvider.AgentConn, AgentInactiveDisconnectTimeout: api.AgentInactiveDisconnectTimeout, diff --git a/coderd/database/dbgen/dbgen.go b/coderd/database/dbgen/dbgen.go index 995294bebc..834bad6274 100644 --- a/coderd/database/dbgen/dbgen.go +++ b/coderd/database/dbgen/dbgen.go @@ -119,6 +119,7 @@ func ChatMessage(t testing.TB, db database.Store, seed database.ChatMessage) dat msgs, err := db.InsertChatMessages(genCtx, database.InsertChatMessagesParams{ ChatID: seed.ChatID, CreatedBy: []uuid.UUID{seed.CreatedBy.UUID}, + APIKeyID: []string{seed.APIKeyID.String}, ModelConfigID: []uuid.UUID{seed.ModelConfigID.UUID}, Role: []database.ChatMessageRole{takeFirst(seed.Role, database.ChatMessageRoleUser)}, Content: []string{content}, diff --git a/coderd/database/dump.sql b/coderd/database/dump.sql index 7fe8b7b80f..53ae5ff53a 100644 --- a/coderd/database/dump.sql +++ b/coderd/database/dump.sql @@ -1554,7 +1554,8 @@ CREATE TABLE chat_messages ( total_cost_micros bigint, runtime_ms bigint, deleted boolean DEFAULT false NOT NULL, - provider_response_id text + provider_response_id text, + api_key_id text ); CREATE SEQUENCE chat_messages_id_seq @@ -1593,7 +1594,8 @@ CREATE TABLE chat_queued_messages ( chat_id uuid NOT NULL, content jsonb NOT NULL, created_at timestamp with time zone DEFAULT now() NOT NULL, - model_config_id uuid + model_config_id uuid, + api_key_id text ); CREATE SEQUENCE chat_queued_messages_id_seq @@ -4453,6 +4455,9 @@ ALTER TABLE ONLY chat_files ALTER TABLE ONLY chat_files ADD CONSTRAINT chat_files_owner_id_fkey FOREIGN KEY (owner_id) REFERENCES users(id) ON DELETE CASCADE; +ALTER TABLE ONLY chat_messages + ADD CONSTRAINT chat_messages_api_key_id_fkey FOREIGN KEY (api_key_id) REFERENCES api_keys(id) ON DELETE SET NULL; + ALTER TABLE ONLY chat_messages ADD CONSTRAINT chat_messages_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE; @@ -4468,6 +4473,9 @@ ALTER TABLE ONLY chat_model_configs ALTER TABLE ONLY chat_model_configs ADD CONSTRAINT chat_model_configs_updated_by_fkey FOREIGN KEY (updated_by) REFERENCES users(id); +ALTER TABLE ONLY chat_queued_messages + ADD CONSTRAINT chat_queued_messages_api_key_id_fkey FOREIGN KEY (api_key_id) REFERENCES api_keys(id) ON DELETE SET NULL; + ALTER TABLE ONLY chat_queued_messages ADD CONSTRAINT chat_queued_messages_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE; diff --git a/coderd/database/foreign_key_constraint.go b/coderd/database/foreign_key_constraint.go index 47dba3d673..624f3229b6 100644 --- a/coderd/database/foreign_key_constraint.go +++ b/coderd/database/foreign_key_constraint.go @@ -21,11 +21,13 @@ const ( ForeignKeyChatFileLinksFileID ForeignKeyConstraint = "chat_file_links_file_id_fkey" // ALTER TABLE ONLY chat_file_links ADD CONSTRAINT chat_file_links_file_id_fkey FOREIGN KEY (file_id) REFERENCES chat_files(id) ON DELETE CASCADE; ForeignKeyChatFilesOrganizationID ForeignKeyConstraint = "chat_files_organization_id_fkey" // ALTER TABLE ONLY chat_files ADD CONSTRAINT chat_files_organization_id_fkey FOREIGN KEY (organization_id) REFERENCES organizations(id) ON DELETE CASCADE; ForeignKeyChatFilesOwnerID ForeignKeyConstraint = "chat_files_owner_id_fkey" // ALTER TABLE ONLY chat_files ADD CONSTRAINT chat_files_owner_id_fkey FOREIGN KEY (owner_id) REFERENCES users(id) ON DELETE CASCADE; + ForeignKeyChatMessagesAPIKeyID ForeignKeyConstraint = "chat_messages_api_key_id_fkey" // ALTER TABLE ONLY chat_messages ADD CONSTRAINT chat_messages_api_key_id_fkey FOREIGN KEY (api_key_id) REFERENCES api_keys(id) ON DELETE SET NULL; ForeignKeyChatMessagesChatID ForeignKeyConstraint = "chat_messages_chat_id_fkey" // ALTER TABLE ONLY chat_messages ADD CONSTRAINT chat_messages_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE; ForeignKeyChatMessagesModelConfigID ForeignKeyConstraint = "chat_messages_model_config_id_fkey" // ALTER TABLE ONLY chat_messages ADD CONSTRAINT chat_messages_model_config_id_fkey FOREIGN KEY (model_config_id) REFERENCES chat_model_configs(id); ForeignKeyChatModelConfigsAiProviderID ForeignKeyConstraint = "chat_model_configs_ai_provider_id_fkey" // ALTER TABLE ONLY chat_model_configs ADD CONSTRAINT chat_model_configs_ai_provider_id_fkey FOREIGN KEY (ai_provider_id) REFERENCES ai_providers(id); ForeignKeyChatModelConfigsCreatedBy ForeignKeyConstraint = "chat_model_configs_created_by_fkey" // ALTER TABLE ONLY chat_model_configs ADD CONSTRAINT chat_model_configs_created_by_fkey FOREIGN KEY (created_by) REFERENCES users(id); ForeignKeyChatModelConfigsUpdatedBy ForeignKeyConstraint = "chat_model_configs_updated_by_fkey" // ALTER TABLE ONLY chat_model_configs ADD CONSTRAINT chat_model_configs_updated_by_fkey FOREIGN KEY (updated_by) REFERENCES users(id); + ForeignKeyChatQueuedMessagesAPIKeyID ForeignKeyConstraint = "chat_queued_messages_api_key_id_fkey" // ALTER TABLE ONLY chat_queued_messages ADD CONSTRAINT chat_queued_messages_api_key_id_fkey FOREIGN KEY (api_key_id) REFERENCES api_keys(id) ON DELETE SET NULL; ForeignKeyChatQueuedMessagesChatID ForeignKeyConstraint = "chat_queued_messages_chat_id_fkey" // ALTER TABLE ONLY chat_queued_messages ADD CONSTRAINT chat_queued_messages_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE; ForeignKeyChatsAgentID ForeignKeyConstraint = "chats_agent_id_fkey" // ALTER TABLE ONLY chats ADD CONSTRAINT chats_agent_id_fkey FOREIGN KEY (agent_id) REFERENCES workspace_agents(id) ON DELETE SET NULL; ForeignKeyChatsBuildID ForeignKeyConstraint = "chats_build_id_fkey" // ALTER TABLE ONLY chats ADD CONSTRAINT chats_build_id_fkey FOREIGN KEY (build_id) REFERENCES workspace_builds(id) ON DELETE SET NULL; diff --git a/coderd/database/migrations/000508_chat_turn_api_key_id.down.sql b/coderd/database/migrations/000508_chat_turn_api_key_id.down.sql new file mode 100644 index 0000000000..4a8ad23b10 --- /dev/null +++ b/coderd/database/migrations/000508_chat_turn_api_key_id.down.sql @@ -0,0 +1,5 @@ +ALTER TABLE chat_queued_messages +DROP COLUMN api_key_id; + +ALTER TABLE chat_messages +DROP COLUMN api_key_id; diff --git a/coderd/database/migrations/000508_chat_turn_api_key_id.up.sql b/coderd/database/migrations/000508_chat_turn_api_key_id.up.sql new file mode 100644 index 0000000000..24a83810a5 --- /dev/null +++ b/coderd/database/migrations/000508_chat_turn_api_key_id.up.sql @@ -0,0 +1,8 @@ +-- Preserve chat history when API keys are deleted. Pending work whose latest +-- user turn loses this attribution will fail closed under AI Gateway routing; +-- operators can retry the turn or temporarily use direct routing. +ALTER TABLE chat_messages +ADD COLUMN api_key_id text REFERENCES api_keys(id) ON DELETE SET NULL; + +ALTER TABLE chat_queued_messages +ADD COLUMN api_key_id text REFERENCES api_keys(id) ON DELETE SET NULL; diff --git a/coderd/database/models.go b/coderd/database/models.go index 593d89e4d1..940904385a 100644 --- a/coderd/database/models.go +++ b/coderd/database/models.go @@ -4699,6 +4699,7 @@ type ChatMessage struct { RuntimeMs sql.NullInt64 `db:"runtime_ms" json:"runtime_ms"` Deleted bool `db:"deleted" json:"deleted"` ProviderResponseID sql.NullString `db:"provider_response_id" json:"provider_response_id"` + APIKeyID sql.NullString `db:"api_key_id" json:"api_key_id"` } type ChatModelConfig struct { @@ -4726,6 +4727,7 @@ type ChatQueuedMessage struct { Content json.RawMessage `db:"content" json:"content"` CreatedAt time.Time `db:"created_at" json:"created_at"` ModelConfigID uuid.NullUUID `db:"model_config_id" json:"model_config_id"` + APIKeyID sql.NullString `db:"api_key_id" json:"api_key_id"` } type ChatTable struct { diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index be452f5d3a..2ba4b923de 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -7157,7 +7157,7 @@ func (q *sqlQuerier) GetChatDiffStatusesByChatIDs(ctx context.Context, chatIds [ const getChatMessageByID = `-- name: GetChatMessageByID :one SELECT - id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros, runtime_ms, deleted, provider_response_id + id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros, runtime_ms, deleted, provider_response_id, api_key_id FROM chat_messages WHERE @@ -7190,6 +7190,7 @@ func (q *sqlQuerier) GetChatMessageByID(ctx context.Context, id int64) (ChatMess &i.RuntimeMs, &i.Deleted, &i.ProviderResponseID, + &i.APIKeyID, ) return i, err } @@ -7279,7 +7280,7 @@ func (q *sqlQuerier) GetChatMessageSummariesPerChat(ctx context.Context, created const getChatMessagesByChatID = `-- name: GetChatMessagesByChatID :many SELECT - id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros, runtime_ms, deleted, provider_response_id + id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros, runtime_ms, deleted, provider_response_id, api_key_id FROM chat_messages WHERE @@ -7327,6 +7328,7 @@ func (q *sqlQuerier) GetChatMessagesByChatID(ctx context.Context, arg GetChatMes &i.RuntimeMs, &i.Deleted, &i.ProviderResponseID, + &i.APIKeyID, ); err != nil { return nil, err } @@ -7343,7 +7345,7 @@ func (q *sqlQuerier) GetChatMessagesByChatID(ctx context.Context, arg GetChatMes const getChatMessagesByChatIDAscPaginated = `-- name: GetChatMessagesByChatIDAscPaginated :many SELECT - id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros, runtime_ms, deleted, provider_response_id + id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros, runtime_ms, deleted, provider_response_id, api_key_id FROM chat_messages WHERE @@ -7394,6 +7396,7 @@ func (q *sqlQuerier) GetChatMessagesByChatIDAscPaginated(ctx context.Context, ar &i.RuntimeMs, &i.Deleted, &i.ProviderResponseID, + &i.APIKeyID, ); err != nil { return nil, err } @@ -7410,7 +7413,7 @@ func (q *sqlQuerier) GetChatMessagesByChatIDAscPaginated(ctx context.Context, ar const getChatMessagesByChatIDDescPaginated = `-- name: GetChatMessagesByChatIDDescPaginated :many SELECT - id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros, runtime_ms, deleted, provider_response_id + id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros, runtime_ms, deleted, provider_response_id, api_key_id FROM chat_messages WHERE @@ -7474,6 +7477,7 @@ func (q *sqlQuerier) GetChatMessagesByChatIDDescPaginated(ctx context.Context, a &i.RuntimeMs, &i.Deleted, &i.ProviderResponseID, + &i.APIKeyID, ); err != nil { return nil, err } @@ -7506,7 +7510,7 @@ WITH latest_compressed_summary AS ( 1 ) SELECT - id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros, runtime_ms, deleted, provider_response_id + id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros, runtime_ms, deleted, provider_response_id, api_key_id FROM chat_messages WHERE @@ -7578,6 +7582,7 @@ func (q *sqlQuerier) GetChatMessagesForPromptByChatID(ctx context.Context, chatI &i.RuntimeMs, &i.Deleted, &i.ProviderResponseID, + &i.APIKeyID, ); err != nil { return nil, err } @@ -7639,7 +7644,7 @@ func (q *sqlQuerier) GetChatModelConfigsForTelemetry(ctx context.Context) ([]Get } const getChatQueuedMessages = `-- name: GetChatQueuedMessages :many -SELECT id, chat_id, content, created_at, model_config_id FROM chat_queued_messages +SELECT id, chat_id, content, created_at, model_config_id, api_key_id FROM chat_queued_messages WHERE chat_id = $1 ORDER BY created_at ASC, id ASC ` @@ -7659,6 +7664,7 @@ func (q *sqlQuerier) GetChatQueuedMessages(ctx context.Context, chatID uuid.UUID &i.Content, &i.CreatedAt, &i.ModelConfigID, + &i.APIKeyID, ); err != nil { return nil, err } @@ -8354,7 +8360,7 @@ func (q *sqlQuerier) GetChildChatsByParentIDs(ctx context.Context, arg GetChildC const getLastChatMessageByRole = `-- name: GetLastChatMessageByRole :one SELECT - id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros, runtime_ms, deleted, provider_response_id + id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros, runtime_ms, deleted, provider_response_id, api_key_id FROM chat_messages WHERE @@ -8397,6 +8403,7 @@ func (q *sqlQuerier) GetLastChatMessageByRole(ctx context.Context, arg GetLastCh &i.RuntimeMs, &i.Deleted, &i.ProviderResponseID, + &i.APIKeyID, ) return i, err } @@ -8709,7 +8716,7 @@ WITH updated_chat AS ( SET last_model_config_id = ( SELECT val - FROM UNNEST($3::uuid[]) + FROM UNNEST($4::uuid[]) WITH ORDINALITY AS t(val, ord) WHERE val != '00000000-0000-0000-0000-000000000000'::uuid ORDER BY ord DESC @@ -8719,12 +8726,12 @@ WITH updated_chat AS ( id = $1::uuid AND EXISTS ( SELECT 1 - FROM UNNEST($3::uuid[]) + FROM UNNEST($4::uuid[]) WHERE unnest != '00000000-0000-0000-0000-000000000000'::uuid ) AND chats.last_model_config_id IS DISTINCT FROM ( SELECT val - FROM UNNEST($3::uuid[]) + FROM UNNEST($4::uuid[]) WITH ORDINALITY AS t(val, ord) WHERE val != '00000000-0000-0000-0000-000000000000'::uuid ORDER BY ord DESC @@ -8734,6 +8741,7 @@ WITH updated_chat AS ( INSERT INTO chat_messages ( chat_id, created_by, + api_key_id, model_config_id, role, content, @@ -8754,29 +8762,31 @@ INSERT INTO chat_messages ( SELECT $1::uuid, NULLIF(UNNEST($2::uuid[]), '00000000-0000-0000-0000-000000000000'::uuid), - NULLIF(UNNEST($3::uuid[]), '00000000-0000-0000-0000-000000000000'::uuid), - UNNEST($4::chat_message_role[]), - UNNEST($5::text[])::jsonb, - UNNEST($6::smallint[]), - UNNEST($7::chat_message_visibility[]), - NULLIF(UNNEST($8::bigint[]), 0), + NULLIF(UNNEST($3::text[]), ''), + NULLIF(UNNEST($4::uuid[]), '00000000-0000-0000-0000-000000000000'::uuid), + UNNEST($5::chat_message_role[]), + UNNEST($6::text[])::jsonb, + UNNEST($7::smallint[]), + UNNEST($8::chat_message_visibility[]), NULLIF(UNNEST($9::bigint[]), 0), NULLIF(UNNEST($10::bigint[]), 0), NULLIF(UNNEST($11::bigint[]), 0), NULLIF(UNNEST($12::bigint[]), 0), NULLIF(UNNEST($13::bigint[]), 0), NULLIF(UNNEST($14::bigint[]), 0), - UNNEST($15::boolean[]), - NULLIF(UNNEST($16::bigint[]), 0), + NULLIF(UNNEST($15::bigint[]), 0), + UNNEST($16::boolean[]), NULLIF(UNNEST($17::bigint[]), 0), - NULLIF(UNNEST($18::text[]), '') + NULLIF(UNNEST($18::bigint[]), 0), + NULLIF(UNNEST($19::text[]), '') RETURNING - id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros, runtime_ms, deleted, provider_response_id + id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros, runtime_ms, deleted, provider_response_id, api_key_id ` type InsertChatMessagesParams struct { ChatID uuid.UUID `db:"chat_id" json:"chat_id"` CreatedBy []uuid.UUID `db:"created_by" json:"created_by"` + APIKeyID []string `db:"api_key_id" json:"api_key_id"` ModelConfigID []uuid.UUID `db:"model_config_id" json:"model_config_id"` Role []ChatMessageRole `db:"role" json:"role"` Content []string `db:"content" json:"content"` @@ -8799,6 +8809,7 @@ func (q *sqlQuerier) InsertChatMessages(ctx context.Context, arg InsertChatMessa rows, err := q.db.QueryContext(ctx, insertChatMessages, arg.ChatID, pq.Array(arg.CreatedBy), + pq.Array(arg.APIKeyID), pq.Array(arg.ModelConfigID), pq.Array(arg.Role), pq.Array(arg.Content), @@ -8845,6 +8856,7 @@ func (q *sqlQuerier) InsertChatMessages(ctx context.Context, arg InsertChatMessa &i.RuntimeMs, &i.Deleted, &i.ProviderResponseID, + &i.APIKeyID, ); err != nil { return nil, err } @@ -8860,23 +8872,30 @@ func (q *sqlQuerier) InsertChatMessages(ctx context.Context, arg InsertChatMessa } const insertChatQueuedMessage = `-- name: InsertChatQueuedMessage :one -INSERT INTO chat_queued_messages (chat_id, content, model_config_id) +INSERT INTO chat_queued_messages (chat_id, content, model_config_id, api_key_id) VALUES ( $1, $2, - $3::uuid + $3::uuid, + $4::text ) -RETURNING id, chat_id, content, created_at, model_config_id +RETURNING id, chat_id, content, created_at, model_config_id, api_key_id ` type InsertChatQueuedMessageParams struct { ChatID uuid.UUID `db:"chat_id" json:"chat_id"` Content json.RawMessage `db:"content" json:"content"` ModelConfigID uuid.NullUUID `db:"model_config_id" json:"model_config_id"` + APIKeyID sql.NullString `db:"api_key_id" json:"api_key_id"` } func (q *sqlQuerier) InsertChatQueuedMessage(ctx context.Context, arg InsertChatQueuedMessageParams) (ChatQueuedMessage, error) { - row := q.db.QueryRowContext(ctx, insertChatQueuedMessage, arg.ChatID, arg.Content, arg.ModelConfigID) + row := q.db.QueryRowContext(ctx, insertChatQueuedMessage, + arg.ChatID, + arg.Content, + arg.ModelConfigID, + arg.APIKeyID, + ) var i ChatQueuedMessage err := row.Scan( &i.ID, @@ -8884,6 +8903,7 @@ func (q *sqlQuerier) InsertChatQueuedMessage(ctx context.Context, arg InsertChat &i.Content, &i.CreatedAt, &i.ModelConfigID, + &i.APIKeyID, ) return i, err } @@ -9108,7 +9128,7 @@ WHERE id = ( ORDER BY cqm.created_at ASC, cqm.id ASC LIMIT 1 ) -RETURNING id, chat_id, content, created_at, model_config_id +RETURNING id, chat_id, content, created_at, model_config_id, api_key_id ` func (q *sqlQuerier) PopNextQueuedMessage(ctx context.Context, chatID uuid.UUID) (ChatQueuedMessage, error) { @@ -9120,6 +9140,7 @@ func (q *sqlQuerier) PopNextQueuedMessage(ctx context.Context, chatID uuid.UUID) &i.Content, &i.CreatedAt, &i.ModelConfigID, + &i.APIKeyID, ) return i, err } @@ -10145,7 +10166,7 @@ SET WHERE id = $3::bigint RETURNING - id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros, runtime_ms, deleted, provider_response_id + id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros, runtime_ms, deleted, provider_response_id, api_key_id ` type UpdateChatMessageByIDParams struct { @@ -10179,6 +10200,7 @@ func (q *sqlQuerier) UpdateChatMessageByID(ctx context.Context, arg UpdateChatMe &i.RuntimeMs, &i.Deleted, &i.ProviderResponseID, + &i.APIKeyID, ) return i, err } diff --git a/coderd/database/queries/chats.sql b/coderd/database/queries/chats.sql index 963eec817b..c8b6502cf5 100644 --- a/coderd/database/queries/chats.sql +++ b/coderd/database/queries/chats.sql @@ -763,6 +763,7 @@ WITH updated_chat AS ( INSERT INTO chat_messages ( chat_id, created_by, + api_key_id, model_config_id, role, content, @@ -783,6 +784,7 @@ INSERT INTO chat_messages ( SELECT @chat_id::uuid, NULLIF(UNNEST(@created_by::uuid[]), '00000000-0000-0000-0000-000000000000'::uuid), + NULLIF(UNNEST(@api_key_id::text[]), ''), NULLIF(UNNEST(@model_config_id::uuid[]), '00000000-0000-0000-0000-000000000000'::uuid), UNNEST(@role::chat_message_role[]), UNNEST(@content::text[])::jsonb, @@ -1688,11 +1690,12 @@ RETURNING *; -- name: InsertChatQueuedMessage :one -INSERT INTO chat_queued_messages (chat_id, content, model_config_id) +INSERT INTO chat_queued_messages (chat_id, content, model_config_id, api_key_id) VALUES ( @chat_id, @content, - sqlc.narg('model_config_id')::uuid + sqlc.narg('model_config_id')::uuid, + sqlc.narg('api_key_id')::text ) RETURNING *; diff --git a/coderd/exp_chats.go b/coderd/exp_chats.go index 3c701be52c..c01f6c29f9 100644 --- a/coderd/exp_chats.go +++ b/coderd/exp_chats.go @@ -1220,6 +1220,7 @@ func (api *API) postChats(rw http.ResponseWriter, r *http.Request) { ClientType: clientType, SystemPrompt: req.SystemPrompt, InitialUserContent: contentBlocks, + APIKeyID: apiKey.ID, MCPServerIDs: mcpServerIDs, Labels: labels, DynamicTools: dynamicToolsJSON, @@ -3089,6 +3090,7 @@ func (api *API) postChatMessages(rw http.ResponseWriter, r *http.Request) { CreatedBy: apiKey.UserID, Content: contentBlocks, ModelConfigID: modelConfigID, + APIKeyID: apiKey.ID, BusyBehavior: busyBehavior, PlanMode: sendPlanMode, MCPServerIDs: req.MCPServerIDs, @@ -3229,6 +3231,7 @@ func (api *API) patchChatMessage(rw http.ResponseWriter, r *http.Request) { CreatedBy: apiKey.UserID, EditedMessageID: messageID, Content: contentBlocks, + APIKeyID: apiKey.ID, ModelConfigID: editModelConfigID, }) if editErr != nil { diff --git a/coderd/x/chatd/advisor_internal_test.go b/coderd/x/chatd/advisor_internal_test.go index 90a6b00bf3..e8b9dc1841 100644 --- a/coderd/x/chatd/advisor_internal_test.go +++ b/coderd/x/chatd/advisor_internal_test.go @@ -13,6 +13,7 @@ import ( "golang.org/x/xerrors" "cdr.dev/slog/v3" + "github.com/coder/coder/v2/coderd/aibridge" "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/x/chatd/chatadvisor" "github.com/coder/coder/v2/coderd/x/chatd/chatprovider" @@ -30,7 +31,11 @@ import ( type advisorOverrideStubStore struct { database.Store - getEnabledChatModelConfigByID func(context.Context, uuid.UUID) (database.ChatModelConfig, error) + getEnabledChatModelConfigByID func(context.Context, uuid.UUID) (database.ChatModelConfig, error) + getAIProviderByID func(context.Context, uuid.UUID) (database.AIProvider, error) + getAIProviders func(context.Context, database.GetAIProvidersParams) ([]database.AIProvider, error) + getAIProviderKeysByProviderID func(context.Context, uuid.UUID) ([]database.AIProviderKey, error) + getAIProviderKeysByProviderIDs func(context.Context, []uuid.UUID) ([]database.AIProviderKey, error) } func (s *advisorOverrideStubStore) GetEnabledChatModelConfigByID( @@ -43,6 +48,46 @@ func (s *advisorOverrideStubStore) GetEnabledChatModelConfigByID( return s.getEnabledChatModelConfigByID(ctx, id) } +func (s *advisorOverrideStubStore) GetAIProviderByID( + ctx context.Context, + id uuid.UUID, +) (database.AIProvider, error) { + if s.getAIProviderByID == nil { + return database.AIProvider{}, xerrors.New("unexpected GetAIProviderByID call") + } + return s.getAIProviderByID(ctx, id) +} + +func (s *advisorOverrideStubStore) GetAIProviders( + ctx context.Context, + params database.GetAIProvidersParams, +) ([]database.AIProvider, error) { + if s.getAIProviders == nil { + return nil, xerrors.New("unexpected GetAIProviders call") + } + return s.getAIProviders(ctx, params) +} + +func (s *advisorOverrideStubStore) GetAIProviderKeysByProviderID( + ctx context.Context, + providerID uuid.UUID, +) ([]database.AIProviderKey, error) { + if s.getAIProviderKeysByProviderID == nil { + return nil, xerrors.New("unexpected GetAIProviderKeysByProviderID call") + } + return s.getAIProviderKeysByProviderID(ctx, providerID) +} + +func (s *advisorOverrideStubStore) GetAIProviderKeysByProviderIDs( + ctx context.Context, + providerIDs []uuid.UUID, +) ([]database.AIProviderKey, error) { + if s.getAIProviderKeysByProviderIDs == nil { + return nil, xerrors.New("unexpected GetAIProviderKeysByProviderIDs call") + } + return s.getAIProviderKeysByProviderIDs(ctx, providerIDs) +} + func newAdvisorTestServer( ctx context.Context, t *testing.T, @@ -56,6 +101,60 @@ func newAdvisorTestServer( } } +func (p *Server) resolveAdvisorModelOverrideOrFallback( + ctx context.Context, + chat database.Chat, + advisorCfg codersdk.AdvisorConfig, + fallbackModel fantasy.LanguageModel, + fallbackCallConfig codersdk.ChatModelCallConfig, + providerKeys chatprovider.ProviderAPIKeys, + modelOpts modelBuildOptions, + logger slog.Logger, +) (fantasy.LanguageModel, codersdk.ChatModelCallConfig) { + model, cfg, err := p.resolveAdvisorModelOverride( + ctx, + chat, + advisorCfg, + fallbackModel, + fallbackCallConfig, + providerKeys, + modelOpts, + logger, + ) + if err != nil { + logger.Warn(ctx, "failed to resolve advisor model override, continuing with chat model", slog.Error(err)) + return fallbackModel, fallbackCallConfig + } + return model, cfg +} + +func (p *Server) newAdvisorRuntimeOrFallback( + ctx context.Context, + chat database.Chat, + advisorCfg codersdk.AdvisorConfig, + fallbackModel fantasy.LanguageModel, + fallbackCallConfig codersdk.ChatModelCallConfig, + providerKeys chatprovider.ProviderAPIKeys, + modelOpts modelBuildOptions, + logger slog.Logger, +) *chatadvisor.Runtime { + rt, err := p.newAdvisorRuntime( + ctx, + chat, + advisorCfg, + fallbackModel, + fallbackCallConfig, + providerKeys, + modelOpts, + logger, + ) + if err != nil { + logger.Warn(ctx, "failed to create advisor runtime, continuing without advisor", slog.Error(err)) + return nil + } + return rt +} + // TestResolveAdvisorModelOverride covers the early-return, each fallback // branch, and the success path. Prior tests only hit the ModelConfigID == // uuid.Nil early return, so the override body never executed. @@ -73,13 +172,14 @@ func TestResolveAdvisorModelOverride(t *testing.T) { store := &advisorOverrideStubStore{} p := newAdvisorTestServer(ctx, t, store) - gotModel, gotCfg := p.resolveAdvisorModelOverride( + gotModel, gotCfg := p.resolveAdvisorModelOverrideOrFallback( ctx, database.Chat{}, codersdk.AdvisorConfig{}, fallbackModel, fallbackCallConfig, chatprovider.ProviderAPIKeys{}, + modelBuildOptions{}, logger, ) require.Equal(t, fallbackModel, gotModel) @@ -96,13 +196,14 @@ func TestResolveAdvisorModelOverride(t *testing.T) { } p := newAdvisorTestServer(ctx, t, store) - gotModel, gotCfg := p.resolveAdvisorModelOverride( + gotModel, gotCfg := p.resolveAdvisorModelOverrideOrFallback( ctx, database.Chat{}, codersdk.AdvisorConfig{ModelConfigID: uuid.New()}, fallbackModel, fallbackCallConfig, chatprovider.ProviderAPIKeys{OpenAI: "sk-test"}, + modelBuildOptions{}, logger, ) require.Equal(t, fallbackModel, gotModel) @@ -125,13 +226,14 @@ func TestResolveAdvisorModelOverride(t *testing.T) { } p := newAdvisorTestServer(ctx, t, store) - gotModel, gotCfg := p.resolveAdvisorModelOverride( + gotModel, gotCfg := p.resolveAdvisorModelOverrideOrFallback( ctx, database.Chat{}, codersdk.AdvisorConfig{ModelConfigID: uuid.New()}, fallbackModel, fallbackCallConfig, chatprovider.ProviderAPIKeys{OpenAI: "sk-test"}, + modelBuildOptions{}, logger, ) require.Equal(t, fallbackModel, gotModel) @@ -158,13 +260,14 @@ func TestResolveAdvisorModelOverride(t *testing.T) { } p := newAdvisorTestServer(ctx, t, store) - gotModel, gotCfg := p.resolveAdvisorModelOverride( + gotModel, gotCfg := p.resolveAdvisorModelOverrideOrFallback( ctx, database.Chat{}, codersdk.AdvisorConfig{ModelConfigID: configID}, fallbackModel, fallbackCallConfig, chatprovider.ProviderAPIKeys{OpenAI: "sk-test"}, + modelBuildOptions{}, logger, ) require.Equal(t, fallbackModel, gotModel) @@ -175,6 +278,7 @@ func TestResolveAdvisorModelOverride(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitShort) configID := uuid.New() + providerID := uuid.New() store := &advisorOverrideStubStore{ getEnabledChatModelConfigByID: func(context.Context, uuid.UUID) (database.ChatModelConfig, error) { return database.ChatModelConfig{ @@ -187,16 +291,27 @@ func TestResolveAdvisorModelOverride(t *testing.T) { DisplayName: "gpt-5.2", }, nil }, + getAIProviders: func(context.Context, database.GetAIProvidersParams) ([]database.AIProvider, error) { + return []database.AIProvider{{ + ID: providerID, + Type: database.AiProviderTypeOpenai, + Enabled: true, + }}, nil + }, + getAIProviderKeysByProviderIDs: func(context.Context, []uuid.UUID) ([]database.AIProviderKey, error) { + return nil, nil + }, } p := newAdvisorTestServer(ctx, t, store) - gotModel, gotCfg := p.resolveAdvisorModelOverride( + gotModel, gotCfg := p.resolveAdvisorModelOverrideOrFallback( ctx, database.Chat{}, codersdk.AdvisorConfig{ModelConfigID: configID}, fallbackModel, fallbackCallConfig, chatprovider.ProviderAPIKeys{}, + modelBuildOptions{}, logger, ) require.Equal(t, fallbackModel, gotModel) @@ -227,13 +342,14 @@ func TestResolveAdvisorModelOverride(t *testing.T) { } p := newAdvisorTestServer(ctx, t, store) - gotModel, gotCfg := p.resolveAdvisorModelOverride( + gotModel, gotCfg := p.resolveAdvisorModelOverrideOrFallback( ctx, database.Chat{}, codersdk.AdvisorConfig{ModelConfigID: configID}, fallbackModel, fallbackCallConfig, chatprovider.ProviderAPIKeys{OpenAI: "sk-test"}, + modelBuildOptions{}, logger, ) require.NotEqual(t, fantasy.LanguageModel(fallbackModel), gotModel, @@ -247,6 +363,99 @@ func TestResolveAdvisorModelOverride(t *testing.T) { require.NotNil(t, gotCfg.Temperature) require.InDelta(t, 0.42, *gotCfg.Temperature, 1e-9) }) + + t.Run("AIProviderIDResolvesOverrideProviderKeys", func(t *testing.T) { + t.Parallel() + ctx := testutil.Context(t, testutil.WaitShort) + configID := uuid.New() + providerID := uuid.New() + store := &advisorOverrideStubStore{ + getEnabledChatModelConfigByID: func(context.Context, uuid.UUID) (database.ChatModelConfig, error) { + return database.ChatModelConfig{ + ID: configID, + Provider: "openai", + Model: "gpt-5.2", + Enabled: true, + CreatedAt: time.Unix(0, 0).UTC(), + UpdatedAt: time.Unix(0, 0).UTC(), + DisplayName: "gpt-5.2", + AIProviderID: uuid.NullUUID{UUID: providerID, Valid: true}, + }, nil + }, + getAIProviderByID: func(context.Context, uuid.UUID) (database.AIProvider, error) { + return database.AIProvider{ + ID: providerID, + Type: database.AiProviderTypeOpenai, + Enabled: true, + }, nil + }, + getAIProviderKeysByProviderID: func(context.Context, uuid.UUID) ([]database.AIProviderKey, error) { + return []database.AIProviderKey{{ + ProviderID: providerID, + APIKey: "sk-selected", + }}, nil + }, + } + p := newAdvisorTestServer(ctx, t, store) + + gotModel, gotCfg := p.resolveAdvisorModelOverrideOrFallback( + ctx, + database.Chat{}, + codersdk.AdvisorConfig{ModelConfigID: configID}, + fallbackModel, + fallbackCallConfig, + chatprovider.ProviderAPIKeys{}, + modelBuildOptions{}, + logger, + ) + require.NotEqual(t, fantasy.LanguageModel(fallbackModel), gotModel) + require.NotNil(t, gotModel) + require.Equal(t, "openai", gotModel.Provider()) + require.Equal(t, "gpt-5.2", gotModel.Model()) + require.Equal(t, fallbackCallConfig, gotCfg) + }) +} + +func TestResolveAdvisorModelOverridePromotesAIBridgeErrors(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitShort) + configID := uuid.New() + providerID := uuid.New() + store := &advisorOverrideStubStore{ + getEnabledChatModelConfigByID: func(context.Context, uuid.UUID) (database.ChatModelConfig, error) { + return database.ChatModelConfig{ + ID: configID, + Provider: "openai", + Model: "gpt-5.2", + Enabled: true, + DisplayName: "gpt-5.2", + AIProviderID: uuid.NullUUID{UUID: providerID, Valid: true}, + }, nil + }, + getAIProviderByID: func(context.Context, uuid.UUID) (database.AIProvider, error) { + return database.AIProvider{ID: providerID, Type: database.AiProviderTypeOpenai, Name: "primary-openai", Enabled: true}, nil + }, + getAIProviderKeysByProviderID: func(context.Context, uuid.UUID) ([]database.AIProviderKey, error) { + return []database.AIProviderKey{{ProviderID: providerID, APIKey: "sk-selected"}}, nil + }, + } + p := newAdvisorTestServer(ctx, t, store) + p.aiGatewayRoutingEnabled = true + + ctx = aibridge.WithDelegatedAPIKeyID(ctx, uuid.NewString()) + model, _, err := p.resolveAdvisorModelOverride( + ctx, + database.Chat{ID: uuid.New(), OwnerID: uuid.New()}, + codersdk.AdvisorConfig{ModelConfigID: configID}, + &chattest.FakeModel{ProviderName: "stub", ModelName: "stub"}, + codersdk.ChatModelCallConfig{}, + chatprovider.ProviderAPIKeys{}, + modelBuildOptions{ActiveAPIKeyID: uuid.NewString()}, + slog.Make(), + ) + require.ErrorContains(t, err, "AI Gateway transport factory") + require.Nil(t, model) } // TestStripAdvisorGuidanceBlock exercises the filter that keeps the advisor @@ -346,7 +555,7 @@ func TestNewAdvisorRuntime(t *testing.T) { store := &advisorOverrideStubStore{} p := newAdvisorTestServer(ctx, t, store) - rt := p.newAdvisorRuntime( + rt := p.newAdvisorRuntimeOrFallback( ctx, database.Chat{}, codersdk.AdvisorConfig{ @@ -357,6 +566,7 @@ func TestNewAdvisorRuntime(t *testing.T) { fallbackModel, fallbackCallConfig, chatprovider.ProviderAPIKeys{}, + modelBuildOptions{}, logger, ) require.NotNil(t, rt, "zero max uses must default rather than bail out") @@ -370,7 +580,7 @@ func TestNewAdvisorRuntime(t *testing.T) { store := &advisorOverrideStubStore{} p := newAdvisorTestServer(ctx, t, store) - rt := p.newAdvisorRuntime( + rt := p.newAdvisorRuntimeOrFallback( ctx, database.Chat{}, codersdk.AdvisorConfig{ @@ -381,6 +591,7 @@ func TestNewAdvisorRuntime(t *testing.T) { fallbackModel, fallbackCallConfig, chatprovider.ProviderAPIKeys{}, + modelBuildOptions{}, logger, ) require.Nil(t, rt, "negative max uses must disable the advisor") @@ -392,7 +603,7 @@ func TestNewAdvisorRuntime(t *testing.T) { store := &advisorOverrideStubStore{} p := newAdvisorTestServer(ctx, t, store) - rt := p.newAdvisorRuntime( + rt := p.newAdvisorRuntimeOrFallback( ctx, database.Chat{}, codersdk.AdvisorConfig{ @@ -403,6 +614,7 @@ func TestNewAdvisorRuntime(t *testing.T) { fallbackModel, fallbackCallConfig, chatprovider.ProviderAPIKeys{}, + modelBuildOptions{}, logger, ) require.NotNil(t, rt, diff --git a/coderd/x/chatd/chatd.go b/coderd/x/chatd/chatd.go index 991108a02b..0658ede670 100644 --- a/coderd/x/chatd/chatd.go +++ b/coderd/x/chatd/chatd.go @@ -29,6 +29,7 @@ import ( "golang.org/x/xerrors" "cdr.dev/slog/v3" + "github.com/coder/coder/v2/coderd/aibridge" "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/db2sdk" "github.com/coder/coder/v2/coderd/database/dbauthz" @@ -254,6 +255,9 @@ type Server struct { metrics *chatloop.Metrics recordingSem chan struct{} + aibridgeTransportFactory *atomic.Pointer[aibridge.TransportFactory] + aiGatewayRoutingEnabled bool + // Configuration pendingChatAcquireInterval time.Duration maxChatsPerAcquire int32 @@ -344,10 +348,11 @@ func (p *Server) resolveAdvisorModelOverride( fallbackModel fantasy.LanguageModel, fallbackCallConfig codersdk.ChatModelCallConfig, providerKeys chatprovider.ProviderAPIKeys, + modelOpts modelBuildOptions, logger slog.Logger, -) (fantasy.LanguageModel, codersdk.ChatModelCallConfig) { +) (fantasy.LanguageModel, codersdk.ChatModelCallConfig, error) { if advisorCfg.ModelConfigID == uuid.Nil { - return fallbackModel, fallbackCallConfig + return fallbackModel, fallbackCallConfig, nil } // Re-read the override instead of using the cache so disabled models @@ -363,7 +368,7 @@ func (p *Server) resolveAdvisorModelOverride( "advisor model config is disabled or unavailable, continuing with chat model", slog.F("model_config_id", advisorCfg.ModelConfigID), ) - return fallbackModel, fallbackCallConfig + return fallbackModel, fallbackCallConfig, nil } logger.Warn( ctx, @@ -371,7 +376,7 @@ func (p *Server) resolveAdvisorModelOverride( slog.F("model_config_id", advisorCfg.ModelConfigID), slog.Error(err), ) - return fallbackModel, fallbackCallConfig + return fallbackModel, fallbackCallConfig, nil } overrideCallConfig := codersdk.ChatModelCallConfig{} @@ -383,29 +388,48 @@ func (p *Server) resolveAdvisorModelOverride( slog.F("model_config_id", advisorCfg.ModelConfigID), slog.Error(err), ) - return fallbackModel, fallbackCallConfig + return fallbackModel, fallbackCallConfig, nil } } - overrideModel, err := chatprovider.ModelFromConfig( - overrideConfig.Provider, - overrideConfig.Model, + route, err := p.resolveModelRouteForConfig( + ctx, + chat.OwnerID, + overrideConfig, providerKeys, - chatprovider.UserAgent(), - chatprovider.CoderHeaders(chat), - nil, ) if err != nil { + if p.shouldUseAIGatewayRouting() && overrideConfig.AIProviderID.Valid { + return nil, codersdk.ChatModelCallConfig{}, xerrors.Errorf("resolve advisor override route: %w", err) + } + logger.Warn( + ctx, + "failed to resolve advisor override route, continuing with chat model", + slog.F("model_config_id", advisorCfg.ModelConfigID), + slog.Error(err), + ) + return fallbackModel, fallbackCallConfig, nil + } + overrideModel, err := p.newModel(ctx, modelClientRequest{ + Chat: chat, + ModelName: overrideConfig.Model, + UserAgent: chatprovider.UserAgent(), + ExtraHeaders: chatprovider.CoderHeaders(chat), + }, route, modelOpts) + if err != nil { + if p.shouldUseAIGatewayRouting() && overrideConfig.AIProviderID.Valid { + return nil, codersdk.ChatModelCallConfig{}, xerrors.Errorf("create advisor override model: %w", err) + } logger.Warn( ctx, "failed to create advisor override model, continuing with chat model", slog.F("model_config_id", advisorCfg.ModelConfigID), slog.Error(err), ) - return fallbackModel, fallbackCallConfig + return fallbackModel, fallbackCallConfig, nil } - return overrideModel, overrideCallConfig + return overrideModel, overrideCallConfig, nil } func (p *Server) newAdvisorRuntime( @@ -415,17 +439,22 @@ func (p *Server) newAdvisorRuntime( fallbackModel fantasy.LanguageModel, fallbackCallConfig codersdk.ChatModelCallConfig, providerKeys chatprovider.ProviderAPIKeys, + modelOpts modelBuildOptions, logger slog.Logger, -) *chatadvisor.Runtime { - advisorModel, advisorCallConfig := p.resolveAdvisorModelOverride( +) (*chatadvisor.Runtime, error) { + advisorModel, advisorCallConfig, err := p.resolveAdvisorModelOverride( ctx, chat, advisorCfg, fallbackModel, fallbackCallConfig, providerKeys, + modelOpts, logger, ) + if err != nil { + return nil, err + } maxUsesPerRun := advisorCfg.MaxUsesPerRun switch { @@ -441,7 +470,7 @@ func (p *Server) newAdvisorRuntime( "invalid advisor max uses per run, continuing without advisor", slog.F("max_uses_per_run", maxUsesPerRun), ) - return nil + return nil, nil //nolint:nilnil // Nil runtime with nil error means advisor is skipped for this turn. } maxOutputTokens := advisorCfg.MaxOutputTokens @@ -468,9 +497,9 @@ func (p *Server) newAdvisorRuntime( "failed to create advisor runtime, continuing without advisor", slog.Error(err), ) - return nil + return nil, nil //nolint:nilnil // Nil runtime with nil error means advisor is skipped for this turn. } - return rt + return rt, nil } // cachedWorkspaceMCPTools stores workspace MCP tools discovered @@ -1436,6 +1465,7 @@ type CreateOptions struct { ClientType database.ChatClientType SystemPrompt string InitialUserContent []codersdk.ChatMessagePart + APIKeyID string MCPServerIDs []uuid.UUID Labels database.StringMap DynamicTools json.RawMessage @@ -1460,6 +1490,7 @@ type SendMessageOptions struct { CreatedBy uuid.UUID Content []codersdk.ChatMessagePart ModelConfigID uuid.UUID + APIKeyID string BusyBehavior SendMessageBusyBehavior PlanMode *database.NullChatPlanMode MCPServerIDs *[]uuid.UUID @@ -1479,6 +1510,7 @@ type EditMessageOptions struct { CreatedBy uuid.UUID EditedMessageID int64 Content []codersdk.ChatMessagePart + APIKeyID string // ModelConfigID, when non-zero, overrides the model used for // the replacement user message. When set to uuid.Nil the // original message's model is preserved. @@ -1647,7 +1679,7 @@ func (p *Server) CreateChat(ctx context.Context, opts CreateOptions) (database.C database.ChatMessageVisibilityBoth, opts.ModelConfigID, chatprompt.CurrentContentVersion, - ).withCreatedBy(opts.OwnerID)) + ).withCreatedBy(opts.OwnerID).withAPIKeyID(opts.APIKeyID)) _, err = tx.InsertChatMessages(ctx, msgParams) if err != nil { @@ -1786,6 +1818,10 @@ func (p *Server) SendMessage( UUID: modelConfigID, Valid: modelConfigID != uuid.Nil, }, + APIKeyID: sql.NullString{ + String: opts.APIKeyID, + Valid: opts.APIKeyID != "", + }, }) if err != nil { return xerrors.Errorf("insert queued message: %w", err) @@ -1810,6 +1846,7 @@ func (p *Server) SendMessage( modelConfigID, content, opts.CreatedBy, + opts.APIKeyID, ) if err != nil { return err @@ -2083,7 +2120,7 @@ func (p *Server) EditMessage( editedMsg.Visibility, messageModelConfigID, chatprompt.CurrentContentVersion, - ).withCreatedBy(opts.CreatedBy)) + ).withCreatedBy(opts.CreatedBy).withAPIKeyID(opts.APIKeyID)) newMessages, err := insertChatMessageWithStore(ctx, tx, msgParams) if err != nil { return xerrors.Errorf("insert replacement message: %w", err) @@ -2416,12 +2453,14 @@ func (p *Server) PromoteQueued( var ( targetContent json.RawMessage targetModelConfigID uuid.NullUUID + targetAPIKeyID sql.NullString found bool ) for _, qm := range queuedMessages { if qm.ID == opts.QueuedMessageID { targetContent = qm.Content targetModelConfigID = qm.ModelConfigID + targetAPIKeyID = qm.APIKeyID found = true break } @@ -2511,6 +2550,7 @@ func (p *Server) PromoteQueued( Valid: len(targetContent) > 0, }, opts.CreatedBy, + targetAPIKeyID.String, ) if err != nil { return err @@ -2754,6 +2794,7 @@ func (p *Server) SubmitToolResults( params := database.InsertChatMessagesParams{ ChatID: opts.ChatID, CreatedBy: make([]uuid.UUID, n), + APIKeyID: make([]string, n), ModelConfigID: make([]uuid.UUID, n), Role: make([]database.ChatMessageRole, n), Content: make([]string, n), @@ -2862,16 +2903,18 @@ var ErrManualTitleRegenerationInProgress = xerrors.New( ) type manualTitleCandidateResult struct { - title string - modelConfig database.ChatModelConfig - usage fantasy.Usage - hasMessages bool + title string + modelConfig database.ChatModelConfig + usage fantasy.Usage + activeAPIKeyID string + hasMessages bool } type manualTitleGenerationError struct { - cause error - modelConfig database.ChatModelConfig - usage fantasy.Usage + cause error + modelConfig database.ChatModelConfig + usage fantasy.Usage + activeAPIKeyID string } func (e *manualTitleGenerationError) Error() string { @@ -3105,6 +3148,7 @@ func (p *Server) recordManualTitleGenerationFailure( chat, generationErr.modelConfig, generationErr.usage, + generationErr.activeAPIKeyID, "", ); recordErr != nil { return errors.Join( @@ -3154,11 +3198,13 @@ func (p *Server) generateManualTitleCandidate( if len(messages) == 0 { return manualTitleCandidateResult{}, nil } + modelOpts := modelBuildOptionsFromMessages(messages) - model, modelConfig, modelKeys, err := p.resolveManualTitleModel(ctx, store, chat, keys) + model, modelConfig, modelKeys, err := p.resolveManualTitleModel(ctx, store, chat, keys, modelOpts) result := manualTitleCandidateResult{ - modelConfig: modelConfig, - hasMessages: true, + modelConfig: modelConfig, + activeAPIKeyID: modelOpts.ActiveAPIKeyID, + hasMessages: true, } if err != nil { return result, err @@ -3174,6 +3220,7 @@ func (p *Server) generateManualTitleCandidate( chat, modelConfig, modelKeys, + modelOpts, messages, model, ) @@ -3189,9 +3236,10 @@ func (p *Server) generateManualTitleCandidate( return result, wrappedErr } return result, &manualTitleGenerationError{ - cause: wrappedErr, - modelConfig: modelConfig, - usage: usage, + cause: wrappedErr, + modelConfig: modelConfig, + usage: usage, + activeAPIKeyID: modelOpts.ActiveAPIKeyID, } } @@ -3220,6 +3268,7 @@ func (p *Server) proposeChatTitleWithStore( chat, result.modelConfig, result.usage, + result.activeAPIKeyID, "", ); recordErr != nil { return "", xerrors.Errorf("record manual title usage: %w", recordErr) @@ -3250,6 +3299,7 @@ func (p *Server) regenerateChatTitleWithStore( chat, result.modelConfig, result.usage, + result.activeAPIKeyID, result.title, ) if recordErr != nil { @@ -3272,6 +3322,7 @@ func (p *Server) prepareManualTitleDebugRun( chat database.Chat, modelConfig database.ChatModelConfig, keys chatprovider.ProviderAPIKeys, + modelOpts modelBuildOptions, messages []database.ChatMessage, fallbackModel fantasy.LanguageModel, ) (context.Context, fantasy.LanguageModel, func(error)) { @@ -3279,15 +3330,21 @@ func (p *Server) prepareManualTitleDebugRun( titleModel := fallbackModel finishDebugRun := func(error) {} - httpClient := &http.Client{Transport: &chatdebug.RecordingTransport{}} - debugModel, debugModelErr := chatprovider.ModelFromConfig( - modelConfig.Provider, - modelConfig.Model, - keys, - chatprovider.UserAgent(), - chatprovider.CoderHeaders(chat), - httpClient, - ) + route, routeErr := p.resolveModelRouteForConfig(ctx, chat.OwnerID, modelConfig, keys) + debugOpts := modelOpts + debugOpts.RecordHTTP = true + var debugModelErr error + var debugModel fantasy.LanguageModel + if routeErr != nil { + debugModelErr = routeErr + } else { + debugModel, debugModelErr = p.newModel(ctx, modelClientRequest{ + Chat: chat, + ModelName: modelConfig.Model, + UserAgent: chatprovider.UserAgent(), + ExtraHeaders: chatprovider.CoderHeaders(chat), + }, route, debugOpts) + } switch { case debugModelErr != nil: p.logger.Warn(ctx, "failed to create debug-aware manual title model", @@ -3535,11 +3592,13 @@ func (p *Server) resolveManualTitleModel( store database.Store, chat database.Chat, keys chatprovider.ProviderAPIKeys, + modelOpts modelBuildOptions, ) (fantasy.LanguageModel, database.ChatModelConfig, chatprovider.ProviderAPIKeys, error) { - overrideConfig, overrideModel, overrideKeys, overrideSet, overrideErr := p.resolveTitleGenerationModelOverride( + overrideConfig, overrideModel, overrideKeys, _, overrideSet, overrideErr := p.resolveTitleGenerationModelOverride( ctx, chat, keys, + modelOpts, ) if overrideErr != nil { if overrideSet { @@ -3562,15 +3621,15 @@ func (p *Server) resolveManualTitleModel( slog.F("chat_id", chat.ID), slog.Error(err), ) - return p.resolveFallbackManualTitleModel(ctx, chat, keys) + return p.resolveFallbackManualTitleModel(ctx, chat, keys, modelOpts) } config, ok := selectPreferredConfiguredShortTextModelConfig(configs) if !ok { - return p.resolveFallbackManualTitleModel(ctx, chat, keys) + return p.resolveFallbackManualTitleModel(ctx, chat, keys, modelOpts) } - providerHint, modelKeys, err := p.resolveModelConfigProviderHintAndKeys(ctx, chat.OwnerID, config, keys) + route, err := p.resolveModelRouteForConfig(ctx, chat.OwnerID, config, keys) if err != nil { p.logger.Debug(ctx, "manual title preferred model unavailable", slog.F("chat_id", chat.ID), @@ -3578,33 +3637,32 @@ func (p *Server) resolveManualTitleModel( slog.F("model", config.Model), slog.Error(err), ) - return p.resolveFallbackManualTitleModel(ctx, chat, keys) + return p.resolveFallbackManualTitleModel(ctx, chat, keys, modelOpts) } - model, err := chatprovider.ModelFromConfig( - providerHint, - config.Model, - modelKeys, - chatprovider.UserAgent(), - chatprovider.CoderHeaders(chat), - nil, - ) + model, err := p.newModel(ctx, modelClientRequest{ + Chat: chat, + ModelName: config.Model, + UserAgent: chatprovider.UserAgent(), + ExtraHeaders: chatprovider.CoderHeaders(chat), + }, route, modelOpts) if err != nil { p.logger.Debug(ctx, "manual title preferred model unavailable", slog.F("chat_id", chat.ID), - slog.F("provider", providerHint), + slog.F("provider", config.Provider), slog.F("model", config.Model), slog.Error(err), ) - return p.resolveFallbackManualTitleModel(ctx, chat, keys) + return p.resolveFallbackManualTitleModel(ctx, chat, keys, modelOpts) } - return model, config, modelKeys, nil + return model, config, route.directProviderKeys(), nil } func (p *Server) resolveFallbackManualTitleModel( ctx context.Context, chat database.Chat, keys chatprovider.ProviderAPIKeys, + modelOpts modelBuildOptions, ) (fantasy.LanguageModel, database.ChatModelConfig, chatprovider.ProviderAPIKeys, error) { config, err := p.resolveModelConfig(ctx, chat) if err != nil { @@ -3613,25 +3671,23 @@ func (p *Server) resolveFallbackManualTitleModel( err, ) } - providerHint, modelKeys, err := p.resolveModelConfigProviderHintAndKeys(ctx, chat.OwnerID, config, keys) + route, err := p.resolveModelRouteForConfig(ctx, chat.OwnerID, config, keys) if err != nil { return nil, database.ChatModelConfig{}, chatprovider.ProviderAPIKeys{}, err } - model, err := chatprovider.ModelFromConfig( - providerHint, - config.Model, - modelKeys, - chatprovider.UserAgent(), - chatprovider.CoderHeaders(chat), - nil, - ) + model, err := p.newModel(ctx, modelClientRequest{ + Chat: chat, + ModelName: config.Model, + UserAgent: chatprovider.UserAgent(), + ExtraHeaders: chatprovider.CoderHeaders(chat), + }, route, modelOpts) if err != nil { return nil, database.ChatModelConfig{}, chatprovider.ProviderAPIKeys{}, xerrors.Errorf( "create fallback manual title model: %w", err, ) } - return model, config, modelKeys, nil + return model, config, route.directProviderKeys(), nil } func mergeManualTitleMessages( @@ -3682,6 +3738,7 @@ func recordManualTitleUsage( chat database.Chat, modelConfig database.ChatModelConfig, usage fantasy.Usage, + activeAPIKeyID string, newTitle string, ) (database.Chat, error) { hasUsage := usage != (fantasy.Usage{}) @@ -3720,6 +3777,7 @@ func recordManualTitleUsage( messages, err := tx.InsertChatMessages(ctx, database.InsertChatMessagesParams{ ChatID: chat.ID, CreatedBy: []uuid.UUID{chat.OwnerID}, + APIKeyID: []string{activeAPIKeyID}, ModelConfigID: []uuid.UUID{modelConfig.ID}, Role: []database.ChatMessageRole{database.ChatMessageRoleAssistant}, Content: []string{content}, @@ -3845,6 +3903,7 @@ type chatMessage struct { visibility database.ChatMessageVisibility modelConfigID uuid.UUID createdBy uuid.UUID + apiKeyID string contentVersion int16 compressed bool inputTokens int64 @@ -3880,6 +3939,11 @@ func (m chatMessage) withCreatedBy(id uuid.UUID) chatMessage { return m } +func (m chatMessage) withAPIKeyID(id string) chatMessage { + m.apiKeyID = id + return m +} + func (m chatMessage) withCompressed() chatMessage { m.compressed = true return m @@ -3924,6 +3988,7 @@ func appendChatMessage( msg chatMessage, ) { params.CreatedBy = append(params.CreatedBy, msg.createdBy) + params.APIKeyID = append(params.APIKeyID, msg.apiKeyID) params.ModelConfigID = append(params.ModelConfigID, msg.modelConfigID) params.Role = append(params.Role, msg.role) params.Content = append(params.Content, string(msg.content.RawMessage)) @@ -3973,6 +4038,7 @@ func insertUserMessageAndSetPending( modelConfigID uuid.UUID, content pqtype.NullRawMessage, createdBy uuid.UUID, + apiKeyID string, ) (database.ChatMessage, database.Chat, error) { msgParams := database.InsertChatMessagesParams{ //nolint:exhaustruct // Fields populated by appendChatMessage. ChatID: lockedChat.ID, @@ -3983,7 +4049,7 @@ func insertUserMessageAndSetPending( database.ChatMessageVisibilityBoth, modelConfigID, chatprompt.CurrentContentVersion, - ).withCreatedBy(createdBy)) + ).withCreatedBy(createdBy).withAPIKeyID(apiKeyID)) messages, err := insertChatMessageWithStore(ctx, store, msgParams) if err != nil { return database.ChatMessage{}, database.Chat{}, err @@ -4052,7 +4118,10 @@ type Config struct { WebpushDispatcher webpush.Dispatcher UsageTracker *workspacestats.UsageTracker Clock quartz.Clock - PrometheusRegistry prometheus.Registerer + AIBridgeTransportFactory *atomic.Pointer[aibridge.TransportFactory] + AIGatewayRoutingEnabled bool + + PrometheusRegistry prometheus.Registerer // OIDCTokenSource resolves the calling user's OIDC access // token for MCP servers configured with auth_type=user_oidc. @@ -4139,6 +4208,8 @@ func New(cfg Config) *Server { debugSvc.SetStaleAfter(inFlightChatStaleAfter * 3) return debugSvc }, + aibridgeTransportFactory: cfg.AIBridgeTransportFactory, + aiGatewayRoutingEnabled: cfg.AIGatewayRoutingEnabled, pendingChatAcquireInterval: pendingChatAcquireInterval, maxChatsPerAcquire: maxChatsPerAcquire, inFlightChatStaleAfter: inFlightChatStaleAfter, @@ -5803,7 +5874,7 @@ func (p *Server) tryAutoPromoteQueuedMessage( database.ChatMessageVisibilityBoth, effectiveModelConfigID, chatprompt.CurrentContentVersion, - ).withCreatedBy(chat.OwnerID)) + ).withCreatedBy(chat.OwnerID).withAPIKeyID(nextQueued.APIKeyID.String)) msgs, err := insertChatMessageWithStore(ctx, tx, msgParams) if err != nil { return nil, nil, false, xerrors.Errorf("insert promoted message: %w", err) @@ -6322,11 +6393,39 @@ type runChatResult struct { ProviderKeys chatprovider.ProviderAPIKeys PendingDynamicToolCalls []chatloop.PendingToolCall FallbackProvider string + FallbackRoute resolvedModelRoute FallbackModel string + ModelBuildOptions modelBuildOptions TriggerMessageID int64 HistoryTipMessageID int64 } +func contextWithActiveTurnAPIKeyID(ctx context.Context, messages []database.ChatMessage) context.Context { + apiKeyID, ok := activeTurnAPIKeyIDFromMessages(messages) + if !ok { + return ctx + } + return aibridge.WithDelegatedAPIKeyID(ctx, apiKeyID) +} + +func activeTurnAPIKeyIDFromMessages(messages []database.ChatMessage) (string, bool) { + for i := len(messages) - 1; i >= 0; i-- { + message := messages[i] + if message.Role != database.ChatMessageRoleUser { + continue + } + if message.Visibility != database.ChatMessageVisibilityBoth && + message.Visibility != database.ChatMessageVisibilityUser { + continue + } + if !message.APIKeyID.Valid || message.APIKeyID.String == "" { + return "", false + } + return message.APIKeyID.String, true + } + return "", false +} + func allToolNames(allTools []fantasy.AgentTool) []string { toolNames := make([]string, 0, len(allTools)) for _, tool := range allTools { @@ -6948,12 +7047,20 @@ func (p *Server) runChat( err error debugEnabled bool debugProvider string + modelRoute resolvedModelRoute debugModel string ) - // Load MCP server configs and user tokens in parallel with - // model resolution and message loading. These queries have - // no dependencies on each other and all hit different tables. + messages, err = p.db.GetChatMessagesForPromptByChatID(ctx, chat.ID) + if err != nil { + return result, xerrors.Errorf("get chat messages: %w", err) + } + modelOpts := modelBuildOptionsFromMessages(messages) + ctx = contextWithActiveTurnAPIKeyID(ctx, messages) + + // Load MCP server configs and user tokens in parallel with model + // resolution. These queries have no dependencies on each other and all + // hit different tables. var ( mcpConfigs []database.MCPServerConfig mcpTokens []database.MCPServerUserToken @@ -6961,7 +7068,7 @@ func (p *Server) runChat( var g errgroup.Group g.Go(func() error { var err error - model, modelConfig, providerKeys, debugEnabled, debugProvider, debugModel, err = p.resolveChatModel(ctx, chat) + model, modelConfig, providerKeys, modelRoute, debugEnabled, debugProvider, debugModel, err = p.resolveChatModel(ctx, chat, modelOpts) if err != nil { return err } @@ -6972,14 +7079,6 @@ func (p *Server) runChat( } return nil }) - g.Go(func() error { - var err error - messages, err = p.db.GetChatMessagesForPromptByChatID(ctx, chat.ID) - if err != nil { - return xerrors.Errorf("get chat messages: %w", err) - } - return nil - }) if len(chat.MCPServerIDs) > 0 { g.Go(func() error { var err error @@ -7055,15 +7154,20 @@ func (p *Server) runChat( // registering the runtime there would inject guidance for a tool // that is never exposed to the model. if advisorCfg.Enabled && isRootChat && !isPlanModeTurn && !isExploreSubagent { - advisorRuntime = p.newAdvisorRuntime( + var advisorErr error + advisorRuntime, advisorErr = p.newAdvisorRuntime( ctx, chat, advisorCfg, model, callConfig, providerKeys, + modelOpts, logger, ) + if advisorErr != nil { + return result, advisorErr + } } var advisorPromptSnapshot []fantasy.Message @@ -7090,7 +7194,9 @@ func (p *Server) runChat( result.StatusLabelModel = model result.ProviderKeys = providerKeys result.FallbackProvider = modelConfig.Provider + result.FallbackRoute = modelRoute result.FallbackModel = modelConfig.Model + result.ModelBuildOptions = modelOpts debugSvc := p.existingDebugService() // Fire title generation asynchronously so it doesn't block the // chat response. It uses a detached context so it can finish @@ -7111,7 +7217,9 @@ func (p *Server) runChat( modelConfig.Provider, modelConfig.Model, titleModel, + modelRoute, titleProviderKeys, + modelOpts, generatedTitle, titleLogger, debugSvc, @@ -7681,20 +7789,21 @@ func (p *Server) runChat( } if isComputerUse { - computerUseProviderKeys, keyErr := p.resolveUserProviderAPIKeysForProviderType(ctx, chat.OwnerID, computerUseModelProvider) + computerUseRoute, keyErr := p.resolveModelRouteForProviderType(ctx, chat.OwnerID, computerUseModelProvider) if keyErr != nil { - return result, xerrors.Errorf("resolve computer use provider API keys: %w", keyErr) + return result, xerrors.Errorf("resolve computer use provider route: %w", keyErr) } - providerKeys = computerUseProviderKeys + providerKeys = computerUseRoute.directProviderKeys() // Override model for computer use subagent. cuModel, cuDebugEnabled, resolvedProvider, resolvedModel, cuErr := p.resolveComputerUseModel( ctx, chat, - providerKeys, + computerUseRoute, computerUseProvider, computerUseModelProvider, computerUseModelName, + modelOpts, ) if cuErr != nil { return result, cuErr @@ -8394,45 +8503,15 @@ func (p *Server) persistChatContextSummary( return nil } -func (p *Server) resolveModelConfigProviderHintAndKeys( - ctx context.Context, - ownerID uuid.UUID, - modelConfig database.ChatModelConfig, - fallbackKeys chatprovider.ProviderAPIKeys, -) (string, chatprovider.ProviderAPIKeys, error) { - providerHint := modelConfig.Provider - if !modelConfig.AIProviderID.Valid { - if !fallbackKeys.Empty() && userCanUseProviderKeys(fallbackKeys, providerHint) { - return providerHint, fallbackKeys, nil - } - keys, err := p.resolveUserProviderAPIKeys(ctx, ownerID, uuid.Nil) - if err != nil { - return "", chatprovider.ProviderAPIKeys{}, xerrors.Errorf("resolve provider API keys: %w", err) - } - return providerHint, keys, nil - } - //nolint:gocritic // Manual title generation needs chatd-scoped provider reads for user-owned chats. - provider, err := p.db.GetAIProviderByID(dbauthz.AsChatd(ctx), modelConfig.AIProviderID.UUID) - if err != nil { - return "", chatprovider.ProviderAPIKeys{}, xerrors.Errorf("get AI provider: %w", err) - } - if !provider.Enabled { - return "", chatprovider.ProviderAPIKeys{}, xerrors.Errorf("AI provider %s is disabled", provider.ID) - } - providerKeys, err := p.resolveUserProviderAPIKeysForProvider(ctx, ownerID, provider) - if err != nil { - return "", chatprovider.ProviderAPIKeys{}, xerrors.Errorf("resolve provider API keys: %w", err) - } - return string(provider.Type), providerKeys, nil -} - func (p *Server) resolveChatModel( ctx context.Context, chat database.Chat, + modelOpts modelBuildOptions, ) ( model fantasy.LanguageModel, dbConfig database.ChatModelConfig, keys chatprovider.ProviderAPIKeys, + route resolvedModelRoute, debugEnabled bool, resolvedProvider string, resolvedModel string, @@ -8440,57 +8519,45 @@ func (p *Server) resolveChatModel( ) { dbConfig, err = p.resolveModelConfig(ctx, chat) if err != nil { - return nil, database.ChatModelConfig{}, chatprovider.ProviderAPIKeys{}, false, "", "", xerrors.Errorf("resolve model config: %w", err) + return nil, database.ChatModelConfig{}, chatprovider.ProviderAPIKeys{}, resolvedModelRoute{}, false, "", "", xerrors.Errorf("resolve model config: %w", err) } if !dbConfig.Enabled { - return nil, database.ChatModelConfig{}, chatprovider.ProviderAPIKeys{}, false, "", "", xerrors.Errorf("chat model config %s is disabled", dbConfig.ID) + return nil, database.ChatModelConfig{}, chatprovider.ProviderAPIKeys{}, resolvedModelRoute{}, false, "", "", xerrors.Errorf("chat model config %s is disabled", dbConfig.ID) } - providerHint := dbConfig.Provider - var keyErr error - if dbConfig.AIProviderID.Valid { - provider, err := p.db.GetAIProviderByID(ctx, dbConfig.AIProviderID.UUID) - if err != nil { - return nil, database.ChatModelConfig{}, chatprovider.ProviderAPIKeys{}, false, "", "", xerrors.Errorf("get AI provider: %w", err) - } - if !provider.Enabled { - return nil, database.ChatModelConfig{}, chatprovider.ProviderAPIKeys{}, false, "", "", xerrors.Errorf("AI provider %s is disabled", provider.ID) - } - providerHint = string(provider.Type) - keys, keyErr = p.resolveUserProviderAPIKeysForProvider(ctx, chat.OwnerID, provider) - } else { - keys, keyErr = p.resolveUserProviderAPIKeys(ctx, chat.OwnerID, uuid.Nil) - } - if keyErr != nil { - return nil, database.ChatModelConfig{}, chatprovider.ProviderAPIKeys{}, false, "", "", xerrors.Errorf("resolve provider API keys: %w", keyErr) + route, err = p.resolveModelRouteForConfig(ctx, chat.OwnerID, dbConfig, chatprovider.ProviderAPIKeys{}) + if err != nil { + return nil, database.ChatModelConfig{}, chatprovider.ProviderAPIKeys{}, resolvedModelRoute{}, false, "", "", err } + keys = route.directProviderKeys() + providerHint, err := route.providerHint() + if err != nil { + return nil, database.ChatModelConfig{}, chatprovider.ProviderAPIKeys{}, resolvedModelRoute{}, false, "", "", err + } resolvedProvider, resolvedModel, err = chatprovider.ResolveModelWithProviderHint( dbConfig.Model, providerHint, ) if err != nil { - return nil, database.ChatModelConfig{}, chatprovider.ProviderAPIKeys{}, false, "", "", xerrors.Errorf( + return nil, database.ChatModelConfig{}, chatprovider.ProviderAPIKeys{}, resolvedModelRoute{}, false, "", "", xerrors.Errorf( "resolve model metadata: %w", err, ) } - model, debugEnabled, err = p.newDebugAwareModelFromConfig( - ctx, - chat, - providerHint, - dbConfig.Model, - keys, - chatprovider.UserAgent(), - chatprovider.CoderHeaders(chat), - ) + model, debugEnabled, err = p.newDebugAwareModel(ctx, modelClientRequest{ + Chat: chat, + ModelName: dbConfig.Model, + UserAgent: chatprovider.UserAgent(), + ExtraHeaders: chatprovider.CoderHeaders(chat), + }, route, modelOpts) if err != nil { - return nil, database.ChatModelConfig{}, chatprovider.ProviderAPIKeys{}, false, "", "", xerrors.Errorf( + return nil, database.ChatModelConfig{}, chatprovider.ProviderAPIKeys{}, resolvedModelRoute{}, false, "", "", xerrors.Errorf( "create model: %w", err, ) } - return model, dbConfig, keys, debugEnabled, resolvedProvider, resolvedModel, nil + return model, dbConfig, keys, route, debugEnabled, resolvedProvider, resolvedModel, nil } func (p *Server) aiProviderConfig(ctx context.Context, provider database.AIProvider) (chatprovider.ConfiguredProvider, error) { @@ -8605,9 +8672,18 @@ func (p *Server) resolveUserProviderAPIKeysForProviderType( ownerID uuid.UUID, providerType string, ) (chatprovider.ProviderAPIKeys, error) { + keys, _, err := p.resolveUserProviderAPIKeysAndProviderForProviderType(ctx, ownerID, providerType) + return keys, err +} + +func (p *Server) resolveUserProviderAPIKeysAndProviderForProviderType( + ctx context.Context, + ownerID uuid.UUID, + providerType string, +) (chatprovider.ProviderAPIKeys, *database.AIProvider, error) { providers, err := p.db.GetAIProviders(ctx, database.GetAIProvidersParams{}) if err != nil { - return chatprovider.ProviderAPIKeys{}, xerrors.Errorf("get enabled AI providers: %w", err) + return chatprovider.ProviderAPIKeys{}, nil, xerrors.Errorf("get enabled AI providers: %w", err) } normalizedProviderType := chatprovider.NormalizeProvider(providerType) for _, provider := range providers { @@ -8616,13 +8692,17 @@ func (p *Server) resolveUserProviderAPIKeysForProviderType( } keys, err := p.resolveUserProviderAPIKeysForProvider(ctx, ownerID, provider) if err != nil { - return chatprovider.ProviderAPIKeys{}, err + return chatprovider.ProviderAPIKeys{}, nil, err } if userCanUseProviderKeys(keys, normalizedProviderType) { - return keys, nil + return keys, &provider, nil } } - return p.resolveUserProviderAPIKeys(ctx, ownerID, uuid.Nil) + keys, err := p.resolveUserProviderAPIKeys(ctx, ownerID, uuid.Nil) + if err != nil { + return chatprovider.ProviderAPIKeys{}, nil, err + } + return keys, nil, nil } func (p *Server) resolveUserProviderAPIKeys( @@ -9391,6 +9471,7 @@ func insertSyntheticToolResultsTx( params := database.InsertChatMessagesParams{ ChatID: chat.ID, CreatedBy: make([]uuid.UUID, n), + APIKeyID: make([]string, n), ModelConfigID: make([]uuid.UUID, n), Role: make([]database.ChatMessageRole, n), Content: make([]string, n), @@ -9537,7 +9618,7 @@ func (p *Server) generateFinalTurnStatusLabel( return fallbackTurnStatusLabel(status) } - statusLabel := generateTurnStatusLabel( + statusLabel := p.generateTurnStatusLabel( ctx, chat, status, @@ -9545,7 +9626,9 @@ func (p *Server) generateFinalTurnStatusLabel( runResult.FallbackProvider, runResult.FallbackModel, runResult.StatusLabelModel, + runResult.FallbackRoute, runResult.ProviderKeys, + runResult.ModelBuildOptions, logger, p.existingDebugService(), runResult.TriggerMessageID, diff --git a/coderd/x/chatd/chatd_debug.go b/coderd/x/chatd/chatd_debug.go index 3a803c9afa..79dc419418 100644 --- a/coderd/x/chatd/chatd_debug.go +++ b/coderd/x/chatd/chatd_debug.go @@ -2,14 +2,11 @@ package chatd import ( "context" - "net/http" "time" "charm.land/fantasy" - "golang.org/x/xerrors" "cdr.dev/slog/v3" - "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/x/chatd/chatdebug" "github.com/coder/coder/v2/coderd/x/chatd/chatprovider" ) @@ -109,53 +106,38 @@ func (p *Server) scheduleDebugCleanup( }() } -func (p *Server) newDebugAwareModelFromConfig( +func (p *Server) newDebugAwareModel( ctx context.Context, - chat database.Chat, - providerHint string, - modelName string, - providerKeys chatprovider.ProviderAPIKeys, - userAgent string, - extraHeaders map[string]string, + req modelClientRequest, + route resolvedModelRoute, + opts modelBuildOptions, ) (fantasy.LanguageModel, bool, error) { - provider, resolvedModel, err := chatprovider.ResolveModelWithProviderHint(modelName, providerHint) + providerHint, err := route.providerHint() if err != nil { return nil, false, err } + provider, resolvedModel, err := chatprovider.ResolveModelWithProviderHint(req.ModelName, providerHint) + if err != nil { + return nil, false, err + } + route = route.withProviderHint(provider) + req.ModelName = resolvedModel debugSvc := p.debugService() - debugEnabled := debugSvc != nil && debugSvc.IsEnabled(ctx, chat.ID, chat.OwnerID) + debugEnabled := debugSvc != nil && debugSvc.IsEnabled(ctx, req.Chat.ID, req.Chat.OwnerID) + opts.RecordHTTP = debugEnabled - var httpClient *http.Client - if debugEnabled { - httpClient = &http.Client{Transport: &chatdebug.RecordingTransport{}} - } - - model, err := chatprovider.ModelFromConfig( - provider, - resolvedModel, - providerKeys, - userAgent, - extraHeaders, - httpClient, - ) + model, err := p.newModel(ctx, req, route, opts) if err != nil { return nil, debugEnabled, err } - if model == nil { - return nil, debugEnabled, xerrors.Errorf( - "create model for %s/%s returned nil", - provider, - resolvedModel, - ) - } if !debugEnabled { return model, false, nil } return chatdebug.WrapModel(model, debugSvc, chatdebug.RecorderOptions{ - ChatID: chat.ID, - OwnerID: chat.OwnerID, + ChatID: req.Chat.ID, + OwnerID: req.Chat.OwnerID, Provider: provider, Model: resolvedModel, }), true, nil diff --git a/coderd/x/chatd/chatd_internal_test.go b/coderd/x/chatd/chatd_internal_test.go index b4d891bd06..b28c0cdb0c 100644 --- a/coderd/x/chatd/chatd_internal_test.go +++ b/coderd/x/chatd/chatd_internal_test.go @@ -152,10 +152,11 @@ func TestResolveComputerUseModel_OpenAIMissingCredentials(t *testing.T) { model, debugEnabled, resolvedProvider, resolvedModel, err := server.resolveComputerUseModel( context.Background(), database.Chat{ID: uuid.New(), OwnerID: uuid.New()}, - chatprovider.ProviderAPIKeys{}, + newDirectModelRoute(modelProvider, chatprovider.ProviderAPIKeys{}), provider, modelProvider, modelName, + modelBuildOptions{}, ) require.Error(t, err) require.Nil(t, model) @@ -167,6 +168,55 @@ func TestResolveComputerUseModel_OpenAIMissingCredentials(t *testing.T) { require.NotContains(t, err.Error(), "ANTHROPIC_API_KEY") } +func TestResolveUserProviderAPIKeysAndProviderForProviderTypeProviderMatch(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitShort) + ctrl := gomock.NewController(t) + db := dbmock.NewMockStore(ctrl) + ownerID := uuid.New() + providerID := uuid.New() + + db.EXPECT().GetAIProviders(gomock.Any(), database.GetAIProvidersParams{}).Return([]database.AIProvider{ + {ID: uuid.New(), Type: database.AiProviderTypeAnthropic, Enabled: true}, + {ID: providerID, Type: database.AiProviderTypeOpenai, Enabled: true}, + }, nil) + db.EXPECT().GetAIProviderKeysByProviderID(gomock.Any(), providerID).Return([]database.AIProviderKey{{ + ProviderID: providerID, + APIKey: "test-key", + }}, nil) + + server := &Server{db: db} + keys, aiProvider, err := server.resolveUserProviderAPIKeysAndProviderForProviderType( + ctx, + ownerID, + chattool.ComputerUseProviderOpenAI, + ) + require.NoError(t, err) + require.Equal(t, "test-key", keys.APIKey(chattool.ComputerUseProviderOpenAI)) + require.NotNil(t, aiProvider) + require.Equal(t, providerID, aiProvider.ID) + require.Equal(t, database.AiProviderTypeOpenai, aiProvider.Type) +} + +func TestResolveModelRouteForProviderTypeAIGatewayRequiresProvider(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitShort) + ctrl := gomock.NewController(t) + db := dbmock.NewMockStore(ctrl) + + db.EXPECT().GetAIProviders(gomock.Any(), database.GetAIProvidersParams{}).Return(nil, nil) + + server := &Server{db: db, aiGatewayRoutingEnabled: true} + _, err := server.resolveModelRouteForProviderType( + ctx, + uuid.New(), + chattool.ComputerUseProviderOpenAI, + ) + require.ErrorContains(t, err, "AI Gateway routing requires a usable AI provider") +} + func TestAppendComputerUseProviderTool(t *testing.T) { t.Parallel() @@ -731,6 +781,11 @@ func TestRenameChatTitle(t *testing.T) { }) } +func withChatMessageAPIKeyID(message database.ChatMessage, apiKeyID string) database.ChatMessage { + message.APIKeyID = sqlNullString(apiKeyID) + return message +} + func TestRegenerateChatTitle_PersistsAndBroadcasts(t *testing.T) { t.Parallel() @@ -749,6 +804,7 @@ func TestRegenerateChatTitle_PersistsAndBroadcasts(t *testing.T) { modelConfigID := uuid.New() workerID := uuid.New() userPrompt := "review pull request 23633 and fix review threads" + activeAPIKeyID := "key-" + uuid.NewString() wantTitle := "Review PR 23633" chat := database.Chat{ @@ -814,12 +870,12 @@ func TestRegenerateChatTitle_PersistsAndBroadcasts(t *testing.T) { LimitVal: manualTitleMessageWindowLimit, }, ).Return([]database.ChatMessage{ - mustChatMessage( + withChatMessageAPIKeyID(mustChatMessage( t, database.ChatMessageRoleUser, database.ChatMessageVisibilityBoth, codersdk.ChatMessageText(userPrompt), - ), + ), activeAPIKeyID), mustChatMessage( t, database.ChatMessageRoleAssistant, diff --git a/coderd/x/chatd/chatd_test.go b/coderd/x/chatd/chatd_test.go index d8be67eaae..9e67c10230 100644 --- a/coderd/x/chatd/chatd_test.go +++ b/coderd/x/chatd/chatd_test.go @@ -11,6 +11,7 @@ import ( "io" "net/http" "net/http/httptest" + "net/url" "os" "path/filepath" "slices" @@ -32,6 +33,7 @@ import ( "cdr.dev/slog/v3/sloggers/slogtest" "github.com/coder/coder/v2/agent/agentcontextconfig" "github.com/coder/coder/v2/agent/agenttest" + "github.com/coder/coder/v2/coderd/aibridge" "github.com/coder/coder/v2/coderd/coderdtest" "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/db2sdk" @@ -67,6 +69,88 @@ type recordedOpenAIRequest struct { ContentLength int64 } +type chatAIGatewayRecordedRequest struct { + ProviderName string + Source aibridge.Source + APIKeyID string + Path string + Authorization string + XAPIKey string + CoderToken string +} + +type chatAIGatewayTestFactory struct { + target *url.URL + transport http.RoundTripper + mu sync.Mutex + requests []chatAIGatewayRecordedRequest +} + +func newChatAIGatewayTestFactory(t testing.TB, targetBaseURL string) *chatAIGatewayTestFactory { + t.Helper() + + target, err := url.Parse(targetBaseURL) + require.NoError(t, err) + return &chatAIGatewayTestFactory{target: target, transport: http.DefaultTransport} +} + +func (f *chatAIGatewayTestFactory) TransportFor(providerName string, source aibridge.Source) (http.RoundTripper, error) { + return chatAIGatewayRoundTripper{factory: f, providerName: providerName, source: source}, nil +} + +func (f *chatAIGatewayTestFactory) requestsSnapshot() []chatAIGatewayRecordedRequest { + f.mu.Lock() + defer f.mu.Unlock() + return slices.Clone(f.requests) +} + +type chatAIGatewayRoundTripper struct { + factory *chatAIGatewayTestFactory + providerName string + source aibridge.Source +} + +func (t chatAIGatewayRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + apiKeyID, _ := aibridge.DelegatedAPIKeyIDFromContext(req.Context()) + t.factory.mu.Lock() + t.factory.requests = append(t.factory.requests, chatAIGatewayRecordedRequest{ + ProviderName: t.providerName, + Source: t.source, + APIKeyID: apiKeyID, + Path: req.URL.Path, + Authorization: req.Header.Get("Authorization"), + XAPIKey: req.Header.Get("X-Api-Key"), + CoderToken: req.Header.Get(aibridge.HeaderCoderToken), + }) + t.factory.mu.Unlock() + + targetURL := *t.factory.target + targetURL.Path = strings.TrimPrefix(req.URL.Path, "/v1") + if targetURL.Path == "" { + targetURL.Path = "/" + } + targetURL.RawQuery = req.URL.RawQuery + + cloned := req.Clone(req.Context()) + cloned.URL = &targetURL + cloned.Host = t.factory.target.Host + return t.factory.transport.RoundTrip(cloned) +} + +func chatAIGatewayTransportFactoryPointer(factory aibridge.TransportFactory) *atomic.Pointer[aibridge.TransportFactory] { + var ptr atomic.Pointer[aibridge.TransportFactory] + ptr.Store(&factory) + return &ptr +} + +func directChatRoutingDeploymentValues(t testing.TB) *codersdk.DeploymentValues { + t.Helper() + + values := coderdtest.DeploymentValues(t) + require.NoError(t, values.AI.Chat.AIGatewayRoutingEnabled.Set("false")) + return values +} + func openAIToolName(tool chattest.OpenAITool) string { return cmp.Or(tool.Function.Name, tool.Name, tool.Type) } @@ -231,7 +315,7 @@ func TestSubagentChatExcludesWorkspaceProvisioningTools(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) - deploymentValues := coderdtest.DeploymentValues(t) + deploymentValues := directChatRoutingDeploymentValues(t) client := coderdtest.New(t, &coderdtest.Options{ DeploymentValues: deploymentValues, IncludeProvisionerDaemon: true, @@ -388,7 +472,7 @@ func TestPlanModeSubagentChatExcludesAskUserQuestion(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) - deploymentValues := coderdtest.DeploymentValues(t) + deploymentValues := directChatRoutingDeploymentValues(t) client := coderdtest.New(t, &coderdtest.Options{ DeploymentValues: deploymentValues, IncludeProvisionerDaemon: true, @@ -555,7 +639,7 @@ func TestExploreSubagentIsReadOnly(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) - deploymentValues := coderdtest.DeploymentValues(t) + deploymentValues := directChatRoutingDeploymentValues(t) client, db := coderdtest.NewWithDatabase(t, &coderdtest.Options{ DeploymentValues: deploymentValues, IncludeProvisionerDaemon: true, @@ -1796,6 +1880,73 @@ func TestUpdateChatHeartbeatsRequiresOwnership(t *testing.T) { require.Equal(t, chat.ID, ids[0]) } +func TestCreateChatPersistsAPIKeyIDOnInitialUserMessage(t *testing.T) { + t.Parallel() + + db, ps := dbtestutil.NewDB(t) + replica := newTestServer(t, db, ps, uuid.New()) + + ctx := testutil.Context(t, testutil.WaitLong) + user, org, model := seedChatDependencies(t, db) + apiKey, _ := dbgen.APIKey(t, db, database.APIKey{UserID: user.ID}) + + chat, err := replica.CreateChat(ctx, chatd.CreateOptions{ + OrganizationID: org.ID, + OwnerID: user.ID, + Title: "create-chat-api-key-id", + ModelConfigID: model.ID, + InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")}, + APIKeyID: apiKey.ID, + }) + require.NoError(t, err) + + messages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{ + ChatID: chat.ID, + AfterID: 0, + }) + require.NoError(t, err) + require.Len(t, messages, 1) + require.Equal(t, database.ChatMessageRoleUser, messages[0].Role) + require.True(t, messages[0].APIKeyID.Valid) + require.Equal(t, apiKey.ID, messages[0].APIKeyID.String) +} + +func TestSendMessagePersistsAPIKeyIDOnUserMessage(t *testing.T) { + t.Parallel() + + db, ps := dbtestutil.NewDB(t) + replica := newTestServer(t, db, ps, uuid.New()) + + ctx := testutil.Context(t, testutil.WaitLong) + user, org, model := seedChatDependencies(t, db) + apiKey, _ := dbgen.APIKey(t, db, database.APIKey{UserID: user.ID}) + + chat := dbgen.Chat(t, db, database.Chat{ + OrganizationID: org.ID, + OwnerID: user.ID, + LastModelConfigID: model.ID, + Title: "send-message-api-key-id", + }) + + result, err := replica.SendMessage(ctx, chatd.SendMessageOptions{ + ChatID: chat.ID, + CreatedBy: user.ID, + Content: []codersdk.ChatMessagePart{ + codersdk.ChatMessageText("message with api key id"), + }, + APIKeyID: apiKey.ID, + }) + require.NoError(t, err) + require.False(t, result.Queued) + require.True(t, result.Message.APIKeyID.Valid) + require.Equal(t, apiKey.ID, result.Message.APIKeyID.String) + + stored, err := db.GetChatMessageByID(ctx, result.Message.ID) + require.NoError(t, err) + require.True(t, stored.APIKeyID.Valid) + require.Equal(t, apiKey.ID, stored.APIKeyID.String) +} + func TestSendMessageQueueBehaviorQueuesWhenBusy(t *testing.T) { t.Parallel() @@ -2143,15 +2294,20 @@ func TestEditMessageUpdatesAndTruncatesAndClearsQueue(t *testing.T) { }) require.NoError(t, err) + apiKey, _ := dbgen.APIKey(t, db, database.APIKey{UserID: user.ID}) + apiKeyID := apiKey.ID editResult, err := replica.EditMessage(ctx, chatd.EditMessageOptions{ ChatID: chat.ID, EditedMessageID: editedMessageID, Content: []codersdk.ChatMessagePart{codersdk.ChatMessageText("edited")}, + APIKeyID: apiKeyID, }) require.NoError(t, err) // The edited message is soft-deleted and a new message is inserted, // so the returned message ID will differ from the original. require.NotEqual(t, editedMessageID, editResult.Message.ID) + require.True(t, editResult.Message.APIKeyID.Valid) + require.Equal(t, apiKeyID, editResult.Message.APIKeyID.String) require.Equal(t, database.ChatStatusPending, editResult.Chat.Status) require.False(t, editResult.Chat.WorkerID.Valid) @@ -2166,6 +2322,8 @@ func TestEditMessageUpdatesAndTruncatesAndClearsQueue(t *testing.T) { require.NoError(t, err) require.Len(t, messages, 1) require.Equal(t, editResult.Message.ID, messages[0].ID) + require.True(t, messages[0].APIKeyID.Valid) + require.Equal(t, apiKeyID, messages[0].APIKeyID.String) onlyMessage := db2sdk.ChatMessage(messages[0]) require.Len(t, onlyMessage.Content, 1) require.Equal(t, "edited", onlyMessage.Content[0].Text) @@ -2366,6 +2524,7 @@ func TestPromoteQueuedAllowsAlreadyQueuedMessageWhenUsageLimitReached(t *testing ctx := testutil.Context(t, testutil.WaitLong) user, org, model := seedChatDependencies(t, db) + apiKey, _ := dbgen.APIKey(t, db, database.APIKey{UserID: user.ID}) _, err := db.UpsertChatUsageLimitConfig(ctx, database.UpsertChatUsageLimitConfigParams{ Enabled: true, @@ -2400,11 +2559,14 @@ func TestPromoteQueuedAllowsAlreadyQueuedMessageWhenUsageLimitReached(t *testing queuedResult, err := replica.SendMessage(ctx, chatd.SendMessageOptions{ ChatID: chat.ID, Content: []codersdk.ChatMessagePart{codersdk.ChatMessageText("queued")}, + APIKeyID: apiKey.ID, BusyBehavior: chatd.SendMessageBusyBehaviorQueue, }) require.NoError(t, err) require.True(t, queuedResult.Queued) require.NotNil(t, queuedResult.QueuedMessage) + require.True(t, queuedResult.QueuedMessage.APIKeyID.Valid) + require.Equal(t, apiKey.ID, queuedResult.QueuedMessage.APIKeyID.String) assistantContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{ codersdk.ChatMessageText("assistant"), @@ -2437,6 +2599,8 @@ func TestPromoteQueuedAllowsAlreadyQueuedMessageWhenUsageLimitReached(t *testing }) require.NoError(t, err) require.Equal(t, database.ChatMessageRoleUser, result.PromotedMessage.Role) + require.True(t, result.PromotedMessage.APIKeyID.Valid) + require.Equal(t, apiKey.ID, result.PromotedMessage.APIKeyID.String) queued, err := db.GetChatQueuedMessages(ctx, chat.ID) require.NoError(t, err) @@ -4864,7 +5028,7 @@ func TestCreateWorkspaceTool_EndToEnd(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitLong) - deploymentValues := coderdtest.DeploymentValues(t) + deploymentValues := directChatRoutingDeploymentValues(t) client := coderdtest.New(t, &coderdtest.Options{ DeploymentValues: deploymentValues, IncludeProvisionerDaemon: true, @@ -5029,7 +5193,7 @@ func TestStartWorkspaceTool_EndToEnd(t *testing.T) { t.Parallel() ctx := testutil.Context(t, testutil.WaitSuperLong) - deploymentValues := coderdtest.DeploymentValues(t) + deploymentValues := directChatRoutingDeploymentValues(t) client := coderdtest.New(t, &coderdtest.Options{ DeploymentValues: deploymentValues, IncludeProvisionerDaemon: true, @@ -7252,6 +7416,100 @@ func TestProcessChat_UserProviderKey_Success(t *testing.T) { require.Contains(t, recordedAuthHeaders, "Bearer "+userAPIKey) } +func TestProcessChat_AIGatewayRoutingUsesDelegatedAPIKey(t *testing.T) { + t.Parallel() + + db, ps := dbtestutil.NewDB(t) + ctx := testutil.Context(t, testutil.WaitLong) + + openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse { + if req.Stream { + return chattest.OpenAIStreamingResponse( + chattest.OpenAITextChunks("hello through AI Gateway")..., + ) + } + return chattest.OpenAINonStreamingResponse(`{"title":"AI Gateway Chat"}`) + }) + factory := newChatAIGatewayTestFactory(t, openAIURL) + + user := dbgen.User(t, db, database.User{}) + org := dbgen.Organization(t, db, database.Organization{}) + dbgen.OrganizationMember(t, db, database.OrganizationMember{ + UserID: user.ID, + OrganizationID: org.ID, + }) + provider := dbgen.AIProvider(t, db, database.AIProvider{ + Type: database.AiProviderTypeOpenai, + Name: "primary-openai-" + uuid.NewString(), + BaseUrl: openAIURL, + }) + model := dbgen.ChatModelConfig(t, db, database.ChatModelConfig{ + Provider: string(database.AiProviderTypeOpenai), + Model: "gpt-4o-mini", + IsDefault: true, + AIProviderID: uuid.NullUUID{UUID: provider.ID, Valid: true}, + }) + apiKey, _ := dbgen.APIKey(t, db, database.APIKey{UserID: user.ID}) + _, err := db.UpsertUserAIProviderKey(ctx, database.UpsertUserAIProviderKeyParams{ + ID: uuid.New(), + UserID: user.ID, + AIProviderID: provider.ID, + APIKey: "sk-user-aibridge", + }) + require.NoError(t, err) + + creator := newTestServer(t, db, ps, uuid.New()) + chat, err := creator.CreateChat(ctx, chatd.CreateOptions{ + OrganizationID: org.ID, + OwnerID: user.ID, + Title: "aigateway-routing", + ModelConfigID: model.ID, + APIKeyID: apiKey.ID, + InitialUserContent: []codersdk.ChatMessagePart{ + codersdk.ChatMessageText("say hello"), + }, + }) + require.NoError(t, err) + + _, events, cancel, ok := creator.Subscribe(ctx, chat.ID, nil, 0) + require.True(t, ok) + t.Cleanup(cancel) + + _ = newActiveTestServer(t, db, ps, func(cfg *chatd.Config) { + cfg.AIBridgeTransportFactory = chatAIGatewayTransportFactoryPointer(factory) + cfg.AIGatewayRoutingEnabled = true + cfg.AllowBYOK = true + cfg.AllowBYOKSet = true + }) + + terminalStatus := waitForTerminalChatStatusEvent(ctx, t, events) + require.Equal(t, codersdk.ChatStatusWaiting, terminalStatus) + + chatResult := waitForTerminalChat(ctx, t, db, chat.ID) + require.Equal(t, database.ChatStatusWaiting, chatResult.Status) + require.False(t, chatResult.LastError.Valid) + + requests := factory.requestsSnapshot() + require.NotEmpty(t, requests) + require.Contains(t, requests, chatAIGatewayRecordedRequest{ + ProviderName: provider.Name, + Source: aibridge.SourceAgents, + APIKeyID: apiKey.ID, + Path: "/v1/responses", + Authorization: "Bearer sk-user-aibridge", + CoderToken: "delegated", + }) + for _, req := range requests { + require.Equal(t, provider.Name, req.ProviderName) + require.Equal(t, aibridge.SourceAgents, req.Source) + require.Equal(t, apiKey.ID, req.APIKeyID) + require.Equal(t, "Bearer sk-user-aibridge", req.Authorization) + require.Empty(t, req.XAPIKey) + require.Equal(t, "delegated", req.CoderToken) + require.True(t, strings.HasPrefix(req.Path, "/v1/"), "unexpected aibridge path %q", req.Path) + } +} + func TestProcessChat_UserProviderKey_MissingKeyError(t *testing.T) { t.Parallel() @@ -8493,7 +8751,7 @@ func TestAgentContextFilesAndSkillsLoadedIntoChat(t *testing.T) { )) ctx := testutil.Context(t, testutil.WaitSuperLong) - deploymentValues := coderdtest.DeploymentValues(t) + deploymentValues := directChatRoutingDeploymentValues(t) client := coderdtest.New(t, &coderdtest.Options{ DeploymentValues: deploymentValues, IncludeProvisionerDaemon: true, diff --git a/coderd/x/chatd/computer_use.go b/coderd/x/chatd/computer_use.go index d41214f558..05bbcd4285 100644 --- a/coderd/x/chatd/computer_use.go +++ b/coderd/x/chatd/computer_use.go @@ -60,10 +60,11 @@ func (p *Server) computerUseProviderAndModelFromConfig( func (p *Server) resolveComputerUseModel( ctx context.Context, chat database.Chat, - providerKeys chatprovider.ProviderAPIKeys, + route resolvedModelRoute, computerUseProvider string, computerUseModelProvider string, computerUseModelName string, + modelOpts modelBuildOptions, ) ( model fantasy.LanguageModel, debugEnabled bool, @@ -84,15 +85,12 @@ func (p *Server) resolveComputerUseModel( ) } - model, debugEnabled, err = p.newDebugAwareModelFromConfig( - ctx, - chat, - computerUseModelProvider, - computerUseModelName, - providerKeys, - chatprovider.UserAgent(), - chatprovider.CoderHeaders(chat), - ) + model, debugEnabled, err = p.newDebugAwareModel(ctx, modelClientRequest{ + Chat: chat, + ModelName: computerUseModelName, + UserAgent: chatprovider.UserAgent(), + ExtraHeaders: chatprovider.CoderHeaders(chat), + }, route, modelOpts) if err != nil { return nil, false, "", "", xerrors.Errorf( "resolve computer use model for provider %q model %q: %w", diff --git a/coderd/x/chatd/model_routing.go b/coderd/x/chatd/model_routing.go new file mode 100644 index 0000000000..c5fa7129db --- /dev/null +++ b/coderd/x/chatd/model_routing.go @@ -0,0 +1,168 @@ +package chatd + +import ( + "context" + "net/http" + + "charm.land/fantasy" + "github.com/google/uuid" + "golang.org/x/xerrors" + + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/x/chatd/chatprovider" +) + +type modelClientRequest struct { + Chat database.Chat + ModelName string + UserAgent string + ExtraHeaders map[string]string +} + +type modelBuildOptions struct { + ActiveAPIKeyID string + RecordHTTP bool +} + +func modelBuildOptionsFromMessages(messages []database.ChatMessage) modelBuildOptions { + apiKeyID, _ := activeTurnAPIKeyIDFromMessages(messages) + return modelBuildOptions{ActiveAPIKeyID: apiKeyID} +} + +type modelRouteKind int + +const ( + modelRouteKindDirect modelRouteKind = iota + 1 + modelRouteKindAIGateway +) + +type resolvedModelRoute struct { + kind modelRouteKind + direct directModelRoute + aiGateway aiGatewayModelRoute +} + +func newDirectModelRoute(providerHint string, keys chatprovider.ProviderAPIKeys) resolvedModelRoute { + return resolvedModelRoute{ + kind: modelRouteKindDirect, + direct: directModelRoute{ + ProviderHint: providerHint, + Keys: keys, + }, + } +} + +func (r resolvedModelRoute) providerHint() (string, error) { + switch r.kind { + case modelRouteKindDirect: + return r.direct.ProviderHint, nil + case modelRouteKindAIGateway: + return r.aiGateway.ModelProviderHint, nil + default: + return "", xerrors.New("model route is not configured") + } +} + +func (r resolvedModelRoute) withProviderHint(providerHint string) resolvedModelRoute { + switch r.kind { + case modelRouteKindDirect: + r.direct.ProviderHint = providerHint + case modelRouteKindAIGateway: + r.aiGateway.ModelProviderHint = providerHint + } + return r +} + +func (r resolvedModelRoute) directProviderKeys() chatprovider.ProviderAPIKeys { + if r.kind != modelRouteKindDirect { + return chatprovider.ProviderAPIKeys{} + } + return r.direct.Keys +} + +func (p *Server) enabledAIProviderByID(ctx context.Context, providerID uuid.UUID) (database.AIProvider, error) { + provider, err := p.db.GetAIProviderByID(ctx, providerID) + if err != nil { + return database.AIProvider{}, xerrors.Errorf("get AI provider: %w", err) + } + if !provider.Enabled { + return database.AIProvider{}, xerrors.Errorf("AI provider %s is disabled", provider.ID) + } + return provider, nil +} + +func (p *Server) shouldUseAIGatewayRouting() bool { + return p.aiGatewayRoutingEnabled +} + +func (p *Server) resolveModelRouteForConfig( + ctx context.Context, + ownerID uuid.UUID, + modelConfig database.ChatModelConfig, + fallbackKeys chatprovider.ProviderAPIKeys, +) (resolvedModelRoute, error) { + if p.shouldUseAIGatewayRouting() { + return p.resolveAIGatewayModelRouteForConfig(ctx, ownerID, modelConfig) + } + return p.resolveDirectModelRouteForConfig(ctx, ownerID, modelConfig, fallbackKeys) +} + +func (p *Server) resolveModelRouteForProviderType( + ctx context.Context, + ownerID uuid.UUID, + providerType string, +) (resolvedModelRoute, error) { + if p.shouldUseAIGatewayRouting() { + return p.resolveAIGatewayModelRouteForProviderType(ctx, ownerID, providerType) + } + return p.resolveDirectModelRouteForProviderType(ctx, ownerID, providerType) +} + +func (p *Server) newModel( + ctx context.Context, + req modelClientRequest, + route resolvedModelRoute, + opts modelBuildOptions, +) (fantasy.LanguageModel, error) { + switch route.kind { + case modelRouteKindDirect: + return p.newDirectModel(ctx, req, route.direct, opts) + case modelRouteKindAIGateway: + return p.newAIGatewayModel(ctx, req, route.aiGateway, opts) + default: + return nil, xerrors.New("model route is not configured") + } +} + +func newLanguageModel( + providerHint string, + modelName string, + providerKeys chatprovider.ProviderAPIKeys, + userAgent string, + extraHeaders map[string]string, + httpClient *http.Client, +) (fantasy.LanguageModel, error) { + model, err := chatprovider.ModelFromConfig( + providerHint, + modelName, + providerKeys, + userAgent, + extraHeaders, + httpClient, + ) + if err != nil { + return nil, err + } + if model == nil { + provider, resolvedModel, resolveErr := chatprovider.ResolveModelWithProviderHint(modelName, providerHint) + if resolveErr != nil { + return nil, resolveErr + } + return nil, xerrors.Errorf( + "create model for %s/%s returned nil", + provider, + resolvedModel, + ) + } + return model, nil +} diff --git a/coderd/x/chatd/model_routing_aibridge.go b/coderd/x/chatd/model_routing_aibridge.go new file mode 100644 index 0000000000..5db1a16e53 --- /dev/null +++ b/coderd/x/chatd/model_routing_aibridge.go @@ -0,0 +1,292 @@ +package chatd + +import ( + "context" + "database/sql" + "net/http" + "strings" + + "charm.land/fantasy" + fantasyanthropic "charm.land/fantasy/providers/anthropic" + fantasyopenai "charm.land/fantasy/providers/openai" + fantasyopenaicompat "charm.land/fantasy/providers/openaicompat" + "github.com/google/uuid" + "golang.org/x/xerrors" + + "github.com/coder/coder/v2/coderd/aibridge" + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/x/chatd/chatdebug" + "github.com/coder/coder/v2/coderd/x/chatd/chatprovider" +) + +const ( + aibridgeLocalBaseURL = "http://coder-aibridge" + // aibridgePlaceholderAPIKey satisfies fantasy clients that require a + // non-empty API key before aibridged resolves the real credential. + aibridgePlaceholderAPIKey = "coder-aibridge" + aibridgeDelegatedBYOKMarker = "delegated" +) + +type aiGatewayModelRoute struct { + Provider database.AIProvider + ModelProviderHint string + ProviderAuth aiGatewayProviderAuth +} + +func newAIGatewayModelRoute( + provider database.AIProvider, + modelProviderHint string, + auth aiGatewayProviderAuth, +) resolvedModelRoute { + return resolvedModelRoute{ + kind: modelRouteKindAIGateway, + aiGateway: aiGatewayModelRoute{ + Provider: provider, + ModelProviderHint: modelProviderHint, + ProviderAuth: auth, + }, + } +} + +type aiGatewayProviderAuth struct { + Headers map[string]string +} + +func (aiGatewayProviderAuth) String() string { + return "aiGatewayProviderAuth{Headers:}" +} + +func (a aiGatewayProviderAuth) GoString() string { + return a.String() +} + +type aiGatewayRequestFormat int + +const ( + aiGatewayRequestFormatOpenAI aiGatewayRequestFormat = iota + aiGatewayRequestFormatAnthropic +) + +type aiGatewayRoundTripper struct { + base http.RoundTripper + apiKeyID string + providerAuth aiGatewayProviderAuth +} + +func (t *aiGatewayRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + ctx := aibridge.WithDelegatedAPIKeyID(req.Context(), t.apiKeyID) + cloned := req.Clone(ctx) + for name, value := range t.providerAuth.Headers { + cloned.Header.Set(name, value) + } + if len(t.providerAuth.Headers) > 0 { + cloned.Header.Set(aibridge.HeaderCoderToken, aibridgeDelegatedBYOKMarker) + } + return t.base.RoundTrip(cloned) +} + +func (p *Server) newAIGatewayModel( + _ context.Context, + req modelClientRequest, + route aiGatewayModelRoute, + opts modelBuildOptions, +) (fantasy.LanguageModel, error) { + if route.Provider.ID == uuid.Nil { + return nil, xerrors.New("AI Gateway routing requires a concrete AI provider") + } + if route.Provider.Name == "" { + return nil, xerrors.New("AI Gateway routing requires an AI provider name") + } + if opts.ActiveAPIKeyID == "" { + return nil, xerrors.New("AI Gateway routing requires the active turn API key ID") + } + + factoryPtr := p.aibridgeTransportFactory + if factoryPtr == nil { + return nil, xerrors.New("AI Gateway transport factory is not configured") + } + factory := factoryPtr.Load() + if factory == nil || *factory == nil { + return nil, xerrors.New("AI Gateway transport factory is not configured") + } + rt, err := (*factory).TransportFor(route.Provider.Name, aibridge.SourceAgents) + if err != nil { + return nil, xerrors.Errorf("create AI Gateway transport: %w", err) + } + baseRT := http.RoundTripper(&aiGatewayRoundTripper{ + base: rt, + apiKeyID: opts.ActiveAPIKeyID, + providerAuth: route.ProviderAuth, + }) + if opts.RecordHTTP { + baseRT = &chatdebug.RecordingTransport{Base: baseRT} + } + + config := fantasyConfigForAIBridge(route.Provider.Type) + return newLanguageModel( + config.ProviderHint, + req.ModelName, + config.Keys, + req.UserAgent, + req.ExtraHeaders, + &http.Client{Transport: baseRT}, + ) +} + +type aibridgeFantasyConfig struct { + ProviderHint string + Keys chatprovider.ProviderAPIKeys +} + +func fantasyConfigForAIBridge(providerType database.AIProviderType) aibridgeFantasyConfig { + var fantasyProvider string + baseURL := aibridgeLocalBaseURL + "/v1" + switch providerType { + case database.AiProviderTypeAnthropic, database.AiProviderTypeBedrock: + fantasyProvider = fantasyanthropic.Name + baseURL = aibridgeLocalBaseURL + case database.AiProviderTypeOpenai: + fantasyProvider = fantasyopenai.Name + default: + fantasyProvider = fantasyopenaicompat.Name + } + return aibridgeFantasyConfig{ + ProviderHint: fantasyProvider, + Keys: chatprovider.ProviderAPIKeys{ + ByProvider: map[string]string{ + fantasyProvider: aibridgePlaceholderAPIKey, + }, + BaseURLByProvider: map[string]string{ + fantasyProvider: baseURL, + }, + }, + } +} + +func aiGatewayRequestFormatForProviderType(providerType database.AIProviderType) aiGatewayRequestFormat { + switch providerType { + case database.AiProviderTypeAnthropic, database.AiProviderTypeBedrock: + return aiGatewayRequestFormatAnthropic + default: + return aiGatewayRequestFormatOpenAI + } +} + +func (p *Server) aiGatewayProviderAuthForUser( + ctx context.Context, + ownerID uuid.UUID, + provider database.AIProvider, + format aiGatewayRequestFormat, +) (aiGatewayProviderAuth, error) { + if !p.allowBYOK { + return aiGatewayProviderAuth{}, nil + } + userKey, err := p.db.GetUserAIProviderKeyByProviderID(ctx, database.GetUserAIProviderKeyByProviderIDParams{ + UserID: ownerID, + AIProviderID: provider.ID, + }) + if err != nil { + if xerrors.Is(err, sql.ErrNoRows) { + return aiGatewayProviderAuth{}, nil + } + return aiGatewayProviderAuth{}, xerrors.Errorf("get user AI provider key: %w", err) + } + apiKey := strings.TrimSpace(userKey.APIKey) + if apiKey == "" { + return aiGatewayProviderAuth{}, nil + } + + headers := map[string]string{} + switch format { + case aiGatewayRequestFormatAnthropic: + headers["X-Api-Key"] = apiKey + default: + headers["Authorization"] = "Bearer " + apiKey + } + return aiGatewayProviderAuth{Headers: headers}, nil +} + +func (p *Server) resolveAIGatewayRoute( + ctx context.Context, + ownerID uuid.UUID, + provider database.AIProvider, + modelProviderHint string, +) (resolvedModelRoute, error) { + auth, err := p.aiGatewayProviderAuthForUser( + ctx, + ownerID, + provider, + aiGatewayRequestFormatForProviderType(provider.Type), + ) + if err != nil { + return resolvedModelRoute{}, xerrors.Errorf("resolve AI Gateway provider auth: %w", err) + } + return newAIGatewayModelRoute(provider, modelProviderHint, auth), nil +} + +func (p *Server) resolveAIGatewayModelRouteForConfig( + ctx context.Context, + ownerID uuid.UUID, + modelConfig database.ChatModelConfig, +) (resolvedModelRoute, error) { + provider, err := p.gatewayProviderForConfig(ctx, modelConfig) + if err != nil { + return resolvedModelRoute{}, err + } + return p.resolveAIGatewayRoute(ctx, ownerID, provider, string(provider.Type)) +} + +func (p *Server) resolveAIGatewayModelRouteForProviderType( + ctx context.Context, + ownerID uuid.UUID, + providerType string, +) (resolvedModelRoute, error) { + provider, err := p.aiProviderForProviderType(ctx, providerType) + if err != nil { + return resolvedModelRoute{}, err + } + return p.resolveAIGatewayRoute( + ctx, + ownerID, + provider, + chatprovider.NormalizeProvider(providerType), + ) +} + +func (p *Server) gatewayProviderForConfig( + ctx context.Context, + modelConfig database.ChatModelConfig, +) (database.AIProvider, error) { + if !modelConfig.AIProviderID.Valid { + return database.AIProvider{}, xerrors.Errorf( + "AI Gateway routing requires AI provider metadata for model config %s (%s)", + modelConfig.ID, + modelConfig.Model, + ) + } + return p.enabledAIProviderByID(ctx, modelConfig.AIProviderID.UUID) +} + +func (p *Server) aiProviderForProviderType( + ctx context.Context, + providerType string, +) (database.AIProvider, error) { + providers, err := p.db.GetAIProviders(ctx, database.GetAIProvidersParams{}) + if err != nil { + return database.AIProvider{}, xerrors.Errorf("get enabled AI providers: %w", err) + } + normalizedProviderType := chatprovider.NormalizeProvider(providerType) + for _, provider := range providers { + if !provider.Enabled { + continue + } + if chatprovider.NormalizeProvider(string(provider.Type)) != normalizedProviderType { + continue + } + return provider, nil + } + return database.AIProvider{}, xerrors.Errorf( + "AI Gateway routing requires a usable AI provider for provider type %q", + providerType, + ) +} diff --git a/coderd/x/chatd/model_routing_direct.go b/coderd/x/chatd/model_routing_direct.go new file mode 100644 index 0000000000..8173aa75c9 --- /dev/null +++ b/coderd/x/chatd/model_routing_direct.go @@ -0,0 +1,93 @@ +package chatd + +import ( + "context" + "net/http" + + "charm.land/fantasy" + "github.com/google/uuid" + "golang.org/x/xerrors" + + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/x/chatd/chatdebug" + "github.com/coder/coder/v2/coderd/x/chatd/chatprovider" +) + +type directModelRoute struct { + ProviderHint string + Keys chatprovider.ProviderAPIKeys +} + +func (*Server) newDirectModel( + _ context.Context, + req modelClientRequest, + route directModelRoute, + opts modelBuildOptions, +) (fantasy.LanguageModel, error) { + var httpClient *http.Client + if opts.RecordHTTP { + httpClient = &http.Client{Transport: &chatdebug.RecordingTransport{}} + } + return newLanguageModel( + route.ProviderHint, + req.ModelName, + route.Keys, + req.UserAgent, + req.ExtraHeaders, + httpClient, + ) +} + +func (p *Server) resolveDirectModelRouteForConfig( + ctx context.Context, + ownerID uuid.UUID, + modelConfig database.ChatModelConfig, + fallbackKeys chatprovider.ProviderAPIKeys, +) (resolvedModelRoute, error) { + providerHint, provider, err := p.directProviderHintAndProviderForConfig(ctx, modelConfig) + if err != nil { + return resolvedModelRoute{}, err + } + if provider == nil { + if !fallbackKeys.Empty() && userCanUseProviderKeys(fallbackKeys, providerHint) { + return newDirectModelRoute(providerHint, fallbackKeys), nil + } + keys, err := p.resolveUserProviderAPIKeys(ctx, ownerID, uuid.Nil) + if err != nil { + return resolvedModelRoute{}, xerrors.Errorf("resolve provider API keys: %w", err) + } + return newDirectModelRoute(providerHint, keys), nil + } + providerKeys, err := p.resolveUserProviderAPIKeysForProvider(ctx, ownerID, *provider) + if err != nil { + return resolvedModelRoute{}, xerrors.Errorf("resolve provider API keys: %w", err) + } + return newDirectModelRoute(providerHint, providerKeys), nil +} + +func (p *Server) resolveDirectModelRouteForProviderType( + ctx context.Context, + ownerID uuid.UUID, + providerType string, +) (resolvedModelRoute, error) { + normalizedProviderType := chatprovider.NormalizeProvider(providerType) + keys, _, err := p.resolveUserProviderAPIKeysAndProviderForProviderType(ctx, ownerID, providerType) + if err != nil { + return resolvedModelRoute{}, err + } + return newDirectModelRoute(normalizedProviderType, keys), nil +} + +func (p *Server) directProviderHintAndProviderForConfig( + ctx context.Context, + modelConfig database.ChatModelConfig, +) (string, *database.AIProvider, error) { + if !modelConfig.AIProviderID.Valid { + return modelConfig.Provider, nil, nil + } + provider, err := p.enabledAIProviderByID(ctx, modelConfig.AIProviderID.UUID) + if err != nil { + return "", nil, err + } + return string(provider.Type), &provider, nil +} diff --git a/coderd/x/chatd/model_routing_internal_test.go b/coderd/x/chatd/model_routing_internal_test.go new file mode 100644 index 0000000000..786365d9fb --- /dev/null +++ b/coderd/x/chatd/model_routing_internal_test.go @@ -0,0 +1,640 @@ +package chatd + +import ( + "database/sql" + "fmt" + "io" + "net/http" + "strings" + "sync/atomic" + "testing" + + "charm.land/fantasy" + "github.com/google/uuid" + "github.com/stretchr/testify/require" + "go.uber.org/mock/gomock" + "golang.org/x/xerrors" + + "github.com/coder/coder/v2/coderd/aibridge" + "github.com/coder/coder/v2/coderd/database" + "github.com/coder/coder/v2/coderd/database/dbgen" + "github.com/coder/coder/v2/coderd/database/dbmock" + "github.com/coder/coder/v2/coderd/database/dbtestutil" + "github.com/coder/coder/v2/coderd/x/chatd/chatprovider" + "github.com/coder/coder/v2/coderd/x/chatd/chattool" +) + +type aibridgeTestFactory struct { + providerName string + source aibridge.Source + err error + rt http.RoundTripper +} + +func (f *aibridgeTestFactory) TransportFor(providerName string, source aibridge.Source) (http.RoundTripper, error) { + f.providerName = providerName + f.source = source + if f.err != nil { + return nil, f.err + } + return f.rt, nil +} + +type roundTripFunc func(*http.Request) (*http.Response, error) + +func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return f(req) +} + +func aibridgeTestFactoryPointer(factory aibridge.TransportFactory) *atomic.Pointer[aibridge.TransportFactory] { + var ptr atomic.Pointer[aibridge.TransportFactory] + ptr.Store(&factory) + return &ptr +} + +func aibridgeTestAIProvider(providerID uuid.UUID, providerName string, providerType database.AIProviderType) database.AIProvider { + return database.AIProvider{ + ID: providerID, + Name: providerName, + Type: providerType, + Enabled: true, + } +} + +func aibridgeTestRoute(aiProvider database.AIProvider) resolvedModelRoute { + return newAIGatewayModelRoute(aiProvider, string(aiProvider.Type), aiGatewayProviderAuth{}) +} + +func aibridgeTestRequest(chat database.Chat, model string) modelClientRequest { + return modelClientRequest{ + Chat: chat, + ModelName: model, + UserAgent: chatprovider.UserAgent(), + } +} + +func TestAIBridgeProviderFormatMapping(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + providerType database.AIProviderType + wantProvider string + wantBaseURL string + }{ + {name: "OpenAI", providerType: database.AiProviderTypeOpenai, wantProvider: "openai", wantBaseURL: "http://coder-aibridge/v1"}, + {name: "Anthropic", providerType: database.AiProviderTypeAnthropic, wantProvider: "anthropic", wantBaseURL: "http://coder-aibridge"}, + {name: "Bedrock", providerType: database.AiProviderTypeBedrock, wantProvider: "anthropic", wantBaseURL: "http://coder-aibridge"}, + {name: "Google", providerType: database.AiProviderTypeGoogle, wantProvider: "openai-compat", wantBaseURL: "http://coder-aibridge/v1"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + config := fantasyConfigForAIBridge(tt.providerType) + require.Equal(t, tt.wantProvider, config.ProviderHint) + require.Equal(t, tt.wantBaseURL, config.Keys.BaseURL(config.ProviderHint)) + require.Equal(t, aibridgePlaceholderAPIKey, config.Keys.APIKey(config.ProviderHint)) + }) + } +} + +func TestResolveModelRouteForConfigPreservesBaseURL(t *testing.T) { + t.Parallel() + + ctx := t.Context() + ctrl := gomock.NewController(t) + db := dbmock.NewMockStore(ctrl) + ownerID := uuid.New() + providerID := uuid.New() + baseURL := "https://openai.example.com/v1" + + db.EXPECT().GetAIProviderByID(gomock.Any(), providerID).Return(database.AIProvider{ + ID: providerID, + Type: database.AiProviderTypeOpenai, + Name: "primary-openai", + Enabled: true, + BaseUrl: baseURL, + }, nil) + db.EXPECT().GetAIProviderKeysByProviderID(gomock.Any(), providerID).Return([]database.AIProviderKey{{ + ProviderID: providerID, + APIKey: "provider-key", + }}, nil) + + server := &Server{db: db} + route, err := server.resolveModelRouteForConfig(ctx, ownerID, database.ChatModelConfig{ + Provider: "openai", + AIProviderID: uuid.NullUUID{UUID: providerID, Valid: true}, + }, chatprovider.ProviderAPIKeys{}) + require.NoError(t, err) + require.Equal(t, modelRouteKindDirect, route.kind) + require.Equal(t, "openai", route.direct.ProviderHint) + require.Equal(t, "provider-key", route.direct.Keys.APIKey("openai")) + require.Equal(t, baseURL, route.direct.Keys.BaseURL("openai")) +} + +func TestAIGatewayProviderAuthForUser(t *testing.T) { + t.Parallel() + + ctx := t.Context() + ownerID := uuid.New() + providerID := uuid.New() + provider := database.AIProvider{ID: providerID, Type: database.AiProviderTypeOpenai, Enabled: true} + + t.Run("OpenAIUserKey", func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + db := dbmock.NewMockStore(ctrl) + db.EXPECT().GetUserAIProviderKeyByProviderID(gomock.Any(), database.GetUserAIProviderKeyByProviderIDParams{ + UserID: ownerID, + AIProviderID: providerID, + }).Return(database.UserAiProviderKey{APIKey: "sk-user"}, nil) + + server := &Server{db: db, allowBYOK: true} + auth, err := server.aiGatewayProviderAuthForUser(ctx, ownerID, provider, aiGatewayRequestFormatOpenAI) + require.NoError(t, err) + require.Equal(t, "Bearer sk-user", auth.Headers["Authorization"]) + require.Empty(t, auth.Headers["X-Api-Key"]) + }) + + t.Run("AnthropicUserKey", func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + db := dbmock.NewMockStore(ctrl) + db.EXPECT().GetUserAIProviderKeyByProviderID(gomock.Any(), database.GetUserAIProviderKeyByProviderIDParams{ + UserID: ownerID, + AIProviderID: providerID, + }).Return(database.UserAiProviderKey{APIKey: "sk-user"}, nil) + + server := &Server{db: db, allowBYOK: true} + auth, err := server.aiGatewayProviderAuthForUser(ctx, ownerID, provider, aiGatewayRequestFormatAnthropic) + require.NoError(t, err) + require.Equal(t, "sk-user", auth.Headers["X-Api-Key"]) + require.Empty(t, auth.Headers["Authorization"]) + }) + + t.Run("NoUserKey", func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + db := dbmock.NewMockStore(ctrl) + db.EXPECT().GetUserAIProviderKeyByProviderID(gomock.Any(), database.GetUserAIProviderKeyByProviderIDParams{ + UserID: ownerID, + AIProviderID: providerID, + }).Return(database.UserAiProviderKey{}, sql.ErrNoRows) + + server := &Server{db: db, allowBYOK: true} + auth, err := server.aiGatewayProviderAuthForUser(ctx, ownerID, provider, aiGatewayRequestFormatOpenAI) + require.NoError(t, err) + require.Empty(t, auth.Headers) + }) + + t.Run("BYOKDisabled", func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + db := dbmock.NewMockStore(ctrl) + server := &Server{db: db, allowBYOK: false} + auth, err := server.aiGatewayProviderAuthForUser(ctx, ownerID, provider, aiGatewayRequestFormatOpenAI) + require.NoError(t, err) + require.Empty(t, auth.Headers) + }) +} + +func TestAIGatewayProviderAuthRedactsFormatting(t *testing.T) { + t.Parallel() + + auth := aiGatewayProviderAuth{Headers: map[string]string{ + "Authorization": "Bearer sk-user", + "X-Api-Key": "sk-user", + }} + for _, formatted := range []string{ + fmt.Sprint(auth), + fmt.Sprintf("%+v", auth), + fmt.Sprintf("%#v", auth), + } { + require.NotContains(t, formatted, "sk-user") + require.NotContains(t, formatted, "Bearer sk-user") + require.Contains(t, formatted, "redacted") + } +} + +func TestResolveModelRouteForConfigAIGatewayProviderAuth(t *testing.T) { + t.Parallel() + + ctx := t.Context() + ownerID := uuid.New() + providerID := uuid.New() + provider := database.AIProvider{ + ID: providerID, + Type: database.AiProviderTypeOpenai, + Name: "primary-openai", + Enabled: true, + } + modelConfig := database.ChatModelConfig{ + ID: uuid.New(), + Model: "gpt-4", + Provider: "openai", + AIProviderID: uuid.NullUUID{UUID: providerID, Valid: true}, + } + + t.Run("UserKey", func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + db := dbmock.NewMockStore(ctrl) + db.EXPECT().GetAIProviderByID(gomock.Any(), providerID).Return(provider, nil) + db.EXPECT().GetUserAIProviderKeyByProviderID(gomock.Any(), database.GetUserAIProviderKeyByProviderIDParams{ + UserID: ownerID, + AIProviderID: providerID, + }).Return(database.UserAiProviderKey{APIKey: "sk-user"}, nil) + + server := &Server{db: db, aiGatewayRoutingEnabled: true, allowBYOK: true} + route, err := server.resolveModelRouteForConfig(ctx, ownerID, modelConfig, chatprovider.ProviderAPIKeys{}) + require.NoError(t, err) + require.Equal(t, modelRouteKindAIGateway, route.kind) + require.Equal(t, "Bearer sk-user", route.aiGateway.ProviderAuth.Headers["Authorization"]) + }) + + t.Run("CentralProviderCredentialsNotForwarded", func(t *testing.T) { + t.Parallel() + + ctrl := gomock.NewController(t) + db := dbmock.NewMockStore(ctrl) + db.EXPECT().GetAIProviderByID(gomock.Any(), providerID).Return(provider, nil) + + server := &Server{db: db, aiGatewayRoutingEnabled: true, allowBYOK: false} + route, err := server.resolveModelRouteForConfig(ctx, ownerID, modelConfig, chatprovider.ProviderAPIKeys{}) + require.NoError(t, err) + require.Equal(t, modelRouteKindAIGateway, route.kind) + require.Empty(t, route.aiGateway.ProviderAuth.Headers) + }) +} + +func TestAIGatewayModelForwardsProviderAuth(t *testing.T) { + t.Parallel() + + type seenRequest struct { + authorization string + xAPIKey string + coderToken string + apiKeyID string + path string + } + newServer := func(t *testing.T, provider database.AIProvider, auth aiGatewayProviderAuth, seen chan seenRequest) (*Server, resolvedModelRoute) { + factory := &aibridgeTestFactory{rt: roundTripFunc(func(req *http.Request) (*http.Response, error) { + apiKeyID, _ := aibridge.DelegatedAPIKeyIDFromContext(req.Context()) + seen <- seenRequest{ + authorization: req.Header.Get("Authorization"), + xAPIKey: req.Header.Get("X-Api-Key"), + coderToken: req.Header.Get(aibridge.HeaderCoderToken), + apiKeyID: apiKeyID, + path: req.URL.Path, + } + body := `{"id":"resp_test","object":"response","created_at":0,"status":"completed","model":"gpt-4","output":[{"id":"msg_test","type":"message","role":"assistant","content":[{"type":"output_text","text":"hello"}]}],"usage":{"input_tokens":1,"output_tokens":1,"total_tokens":2}}` + if provider.Type == database.AiProviderTypeAnthropic { + body = `{"id":"msg_test","type":"message","role":"assistant","model":"claude-haiku-4-5","content":[{"type":"text","text":"hello"}],"stop_reason":"end_turn","stop_sequence":null,"usage":{"input_tokens":1,"output_tokens":1}}` + } + return &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(body)), + Request: req, + }, nil + })} + server := &Server{ + aiGatewayRoutingEnabled: true, + aibridgeTransportFactory: aibridgeTestFactoryPointer(factory), + } + route := newAIGatewayModelRoute(provider, string(provider.Type), auth) + return server, route + } + + t.Run("OpenAI", func(t *testing.T) { + t.Parallel() + + seen := make(chan seenRequest, 1) + provider := aibridgeTestAIProvider(uuid.New(), "primary-openai", database.AiProviderTypeOpenai) + server, route := newServer(t, provider, aiGatewayProviderAuth{ + Headers: map[string]string{"Authorization": "Bearer sk-user"}, + }, seen) + apiKeyID := uuid.NewString() + model, err := server.newModel(t.Context(), aibridgeTestRequest(database.Chat{ID: uuid.New(), OwnerID: uuid.New()}, "gpt-4"), route, modelBuildOptions{ActiveAPIKeyID: apiKeyID, RecordHTTP: true}) + require.NoError(t, err) + _, err = model.Generate(t.Context(), fantasy.Call{Prompt: []fantasy.Message{{Role: fantasy.MessageRoleUser, Content: []fantasy.MessagePart{fantasy.TextPart{Text: "hello"}}}}}) + require.NoError(t, err) + + got := <-seen + require.Equal(t, "Bearer sk-user", got.authorization) + require.Empty(t, got.xAPIKey) + require.Equal(t, aibridgeDelegatedBYOKMarker, got.coderToken) + require.Equal(t, apiKeyID, got.apiKeyID) + require.Equal(t, "/v1/responses", got.path) + }) + + t.Run("Anthropic", func(t *testing.T) { + t.Parallel() + + seen := make(chan seenRequest, 1) + provider := aibridgeTestAIProvider(uuid.New(), "primary-anthropic", database.AiProviderTypeAnthropic) + server, route := newServer(t, provider, aiGatewayProviderAuth{ + Headers: map[string]string{"X-Api-Key": "sk-user"}, + }, seen) + apiKeyID := uuid.NewString() + model, err := server.newModel(t.Context(), aibridgeTestRequest(database.Chat{ID: uuid.New(), OwnerID: uuid.New()}, "claude-haiku-4-5"), route, modelBuildOptions{ActiveAPIKeyID: apiKeyID}) + require.NoError(t, err) + _, err = model.Generate(t.Context(), fantasy.Call{Prompt: []fantasy.Message{{Role: fantasy.MessageRoleUser, Content: []fantasy.MessagePart{fantasy.TextPart{Text: "hello"}}}}}) + require.NoError(t, err) + + got := <-seen + require.Equal(t, "sk-user", got.xAPIKey) + require.Equal(t, aibridgeDelegatedBYOKMarker, got.coderToken) + require.Equal(t, apiKeyID, got.apiKeyID) + require.Equal(t, "/v1/messages", got.path) + }) + + t.Run("NoUserKeyLeavesPlaceholderForAIBridged", func(t *testing.T) { + t.Parallel() + + seen := make(chan seenRequest, 1) + provider := aibridgeTestAIProvider(uuid.New(), "primary-openai", database.AiProviderTypeOpenai) + server, route := newServer(t, provider, aiGatewayProviderAuth{}, seen) + apiKeyID := uuid.NewString() + model, err := server.newModel(t.Context(), aibridgeTestRequest(database.Chat{ID: uuid.New(), OwnerID: uuid.New()}, "gpt-4"), route, modelBuildOptions{ActiveAPIKeyID: apiKeyID}) + require.NoError(t, err) + _, err = model.Generate(t.Context(), fantasy.Call{Prompt: []fantasy.Message{{Role: fantasy.MessageRoleUser, Content: []fantasy.MessagePart{fantasy.TextPart{Text: "hello"}}}}}) + require.NoError(t, err) + + got := <-seen + require.Equal(t, "Bearer "+aibridgePlaceholderAPIKey, got.authorization) + require.Empty(t, got.xAPIKey) + require.Empty(t, got.coderToken) + require.Equal(t, apiKeyID, got.apiKeyID) + }) +} + +func TestActiveTurnAPIKeyIDFromMessages(t *testing.T) { + t.Parallel() + + oldKeyID := uuid.NewString() + currentKeyID := uuid.NewString() + tests := []struct { + name string + messages []database.ChatMessage + wantKey string + wantOK bool + }{ + { + name: "CurrentUserMessage", + messages: []database.ChatMessage{ + {ID: 1, Role: database.ChatMessageRoleUser, Visibility: database.ChatMessageVisibilityBoth, APIKeyID: sqlNullString(oldKeyID)}, + {ID: 2, Role: database.ChatMessageRoleAssistant, Visibility: database.ChatMessageVisibilityBoth}, + {ID: 3, Role: database.ChatMessageRoleUser, Visibility: database.ChatMessageVisibilityBoth, APIKeyID: sqlNullString(currentKeyID)}, + }, + wantKey: currentKeyID, + wantOK: true, + }, + { + name: "MissingCurrentUserAPIKeyDoesNotFallBack", + messages: []database.ChatMessage{ + {ID: 1, Role: database.ChatMessageRoleUser, Visibility: database.ChatMessageVisibilityBoth, APIKeyID: sqlNullString(oldKeyID)}, + {ID: 2, Role: database.ChatMessageRoleUser, Visibility: database.ChatMessageVisibilityBoth}, + }, + }, + { + name: "SkipsModelOnlyUserMessages", + messages: []database.ChatMessage{ + {ID: 1, Role: database.ChatMessageRoleUser, Visibility: database.ChatMessageVisibilityBoth, APIKeyID: sqlNullString(oldKeyID)}, + {ID: 2, Role: database.ChatMessageRoleUser, Visibility: database.ChatMessageVisibilityModel, APIKeyID: sqlNullString(currentKeyID)}, + }, + wantKey: oldKeyID, + wantOK: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + gotKey, gotOK := activeTurnAPIKeyIDFromMessages(tt.messages) + require.Equal(t, tt.wantOK, gotOK) + require.Equal(t, tt.wantKey, gotKey) + ctx := contextWithActiveTurnAPIKeyID(t.Context(), tt.messages) + ctxKey, ctxOK := aibridge.DelegatedAPIKeyIDFromContext(ctx) + require.Equal(t, tt.wantOK, ctxOK) + require.Equal(t, tt.wantKey, ctxKey) + }) + } +} + +func TestActiveTurnContextUsesPromptMessages(t *testing.T) { + t.Parallel() + + db, _ := dbtestutil.NewDB(t) + ctx := t.Context() + user := dbgen.User(t, db, database.User{}) + org := dbgen.Organization(t, db, database.Organization{}) + model := dbgen.ChatModelConfig(t, db, database.ChatModelConfig{}) + chat := dbgen.Chat(t, db, database.Chat{OrganizationID: org.ID, OwnerID: user.ID, LastModelConfigID: model.ID}) + oldKey, _ := dbgen.APIKey(t, db, database.APIKey{UserID: user.ID}) + currentKey, _ := dbgen.APIKey(t, db, database.APIKey{UserID: user.ID}) + modelOnlyKey, _ := dbgen.APIKey(t, db, database.APIKey{UserID: user.ID}) + + dbgen.ChatMessage(t, db, database.ChatMessage{ + ChatID: chat.ID, + CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true}, + ModelConfigID: uuid.NullUUID{UUID: model.ID, Valid: true}, + Role: database.ChatMessageRoleUser, + Visibility: database.ChatMessageVisibilityBoth, + APIKeyID: sqlNullString(oldKey.ID), + }) + dbgen.ChatMessage(t, db, database.ChatMessage{ + ChatID: chat.ID, + CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true}, + ModelConfigID: uuid.NullUUID{UUID: model.ID, Valid: true}, + Role: database.ChatMessageRoleSystem, + Visibility: database.ChatMessageVisibilityModel, + Compressed: true, + }) + dbgen.ChatMessage(t, db, database.ChatMessage{ + ChatID: chat.ID, + CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true}, + ModelConfigID: uuid.NullUUID{UUID: model.ID, Valid: true}, + Role: database.ChatMessageRoleUser, + Visibility: database.ChatMessageVisibilityBoth, + APIKeyID: sqlNullString(currentKey.ID), + }) + dbgen.ChatMessage(t, db, database.ChatMessage{ + ChatID: chat.ID, + CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true}, + ModelConfigID: uuid.NullUUID{UUID: model.ID, Valid: true}, + Role: database.ChatMessageRoleUser, + Visibility: database.ChatMessageVisibilityModel, + APIKeyID: sqlNullString(modelOnlyKey.ID), + }) + + messages, err := db.GetChatMessagesForPromptByChatID(ctx, chat.ID) + require.NoError(t, err) + ctx = contextWithActiveTurnAPIKeyID(ctx, messages) + gotKey, ok := aibridge.DelegatedAPIKeyIDFromContext(ctx) + require.True(t, ok) + require.Equal(t, currentKey.ID, gotKey) +} + +func sqlNullString(value string) sql.NullString { + return sql.NullString{String: value, Valid: value != ""} +} + +func TestAIBridgeRoutingFailClosed(t *testing.T) { + t.Parallel() + + providerID := uuid.New() + chat := database.Chat{ID: uuid.New(), OwnerID: uuid.New()} + aiProvider := aibridgeTestAIProvider(providerID, "primary-openai", database.AiProviderTypeOpenai) + + t.Run("NilFactory", func(t *testing.T) { + t.Parallel() + server := &Server{aiGatewayRoutingEnabled: true} + _, err := server.newModel(t.Context(), aibridgeTestRequest(chat, "gpt-4"), aibridgeTestRoute(aiProvider), modelBuildOptions{ActiveAPIKeyID: uuid.NewString()}) + require.ErrorContains(t, err, "transport factory") + }) + + t.Run("FactoryError", func(t *testing.T) { + t.Parallel() + factory := &aibridgeTestFactory{err: xerrors.New("boom")} + server := &Server{ + aiGatewayRoutingEnabled: true, + aibridgeTransportFactory: aibridgeTestFactoryPointer(factory), + } + _, err := server.newModel(t.Context(), aibridgeTestRequest(chat, "gpt-4"), aibridgeTestRoute(aiProvider), modelBuildOptions{ActiveAPIKeyID: uuid.NewString()}) + require.ErrorContains(t, err, "boom") + }) + + t.Run("MissingProviderName", func(t *testing.T) { + t.Parallel() + server := &Server{aiGatewayRoutingEnabled: true} + missingNameProvider := aibridgeTestAIProvider(providerID, "", database.AiProviderTypeOpenai) + _, err := server.newModel(t.Context(), aibridgeTestRequest(chat, "gpt-4"), aibridgeTestRoute(missingNameProvider), modelBuildOptions{ActiveAPIKeyID: uuid.NewString()}) + require.ErrorContains(t, err, "AI provider name") + }) + + t.Run("MissingAPIKeyID", func(t *testing.T) { + t.Parallel() + factory := &aibridgeTestFactory{rt: roundTripFunc(func(*http.Request) (*http.Response, error) { + t.Fatal("transport must not be used without an API key ID") + return nil, xerrors.New("unreachable") + })} + server := &Server{ + aiGatewayRoutingEnabled: true, + aibridgeTransportFactory: aibridgeTestFactoryPointer(factory), + } + _, err := server.newModel(t.Context(), aibridgeTestRequest(chat, "gpt-4"), aibridgeTestRoute(aiProvider), modelBuildOptions{}) + require.ErrorContains(t, err, "active turn API key ID") + }) + + t.Run("StaticModel", func(t *testing.T) { + t.Parallel() + server := &Server{aiGatewayRoutingEnabled: true} + _, err := server.newModel(t.Context(), aibridgeTestRequest(chat, "gpt-4"), newAIGatewayModelRoute(database.AIProvider{}, "", aiGatewayProviderAuth{}), modelBuildOptions{ActiveAPIKeyID: uuid.NewString()}) + require.ErrorContains(t, err, "concrete AI provider") + }) +} + +func TestDirectModelBuildDoesNotRequireActiveAPIKeyID(t *testing.T) { + t.Parallel() + + server := &Server{} + model, err := server.newModel(t.Context(), modelClientRequest{ + Chat: database.Chat{ID: uuid.New(), OwnerID: uuid.New()}, + ModelName: "gpt-4", + UserAgent: chatprovider.UserAgent(), + }, newDirectModelRoute("openai", chatprovider.ProviderAPIKeys{OpenAI: "sk-test"}), modelBuildOptions{}) + require.NoError(t, err) + require.NotNil(t, model) +} + +func TestAIBridgeComputerUseModelUsesRoute(t *testing.T) { + t.Parallel() + + providerID := uuid.New() + apiKeyID := uuid.NewString() + factory := &aibridgeTestFactory{rt: roundTripFunc(func(*http.Request) (*http.Response, error) { + t.Fatal("computer use model construction must not send a request") + return nil, xerrors.New("unreachable") + })} + chat := database.Chat{ID: uuid.New(), OwnerID: uuid.New()} + server := &Server{ + aiGatewayRoutingEnabled: true, + aibridgeTransportFactory: aibridgeTestFactoryPointer(factory), + } + provider := chattool.ComputerUseProviderOpenAI + modelProvider, modelName, ok := chattool.DefaultComputerUseModel(provider) + require.True(t, ok) + + ctx := aibridge.WithDelegatedAPIKeyID(t.Context(), "context-key-must-be-ignored") + model, debugEnabled, resolvedProvider, resolvedModel, err := server.resolveComputerUseModel( + ctx, + chat, + aibridgeTestRoute(aibridgeTestAIProvider(providerID, "primary-openai", database.AiProviderTypeOpenai)), + provider, + modelProvider, + modelName, + modelBuildOptions{ActiveAPIKeyID: apiKeyID}, + ) + require.NoError(t, err) + require.NotNil(t, model) + require.False(t, debugEnabled) + require.Equal(t, chattool.ComputerUseProviderOpenAI, resolvedProvider) + require.Equal(t, modelName, resolvedModel) + require.Equal(t, "primary-openai", factory.providerName) + require.Equal(t, aibridge.SourceAgents, factory.source) +} + +func TestAIBridgeDelegatedContextPropagation(t *testing.T) { + t.Parallel() + + providerID := uuid.New() + apiKeyID := uuid.NewString() + type seenRequest struct { + apiKeyID string + ok bool + path string + } + seen := make(chan seenRequest, 1) + factory := &aibridgeTestFactory{rt: roundTripFunc(func(req *http.Request) (*http.Response, error) { + gotAPIKeyID, ok := aibridge.DelegatedAPIKeyIDFromContext(req.Context()) + seen <- seenRequest{ + apiKeyID: gotAPIKeyID, + ok: ok, + path: req.URL.Path, + } + body := `{"id":"resp_test","object":"response","created_at":0,"status":"completed","model":"gpt-4","output":[{"id":"msg_test","type":"message","role":"assistant","content":[{"type":"output_text","text":"hello"}]}],"usage":{"input_tokens":1,"output_tokens":1,"total_tokens":2}}` + return &http.Response{ + StatusCode: http.StatusOK, + Header: http.Header{"Content-Type": []string{"application/json"}}, + Body: io.NopCloser(strings.NewReader(body)), + Request: req, + }, nil + })} + chat := database.Chat{ID: uuid.New(), OwnerID: uuid.New()} + server := &Server{ + aiGatewayRoutingEnabled: true, + aibridgeTransportFactory: aibridgeTestFactoryPointer(factory), + } + + ctx := aibridge.WithDelegatedAPIKeyID(t.Context(), "context-key-must-be-ignored") + model, err := server.newModel(ctx, aibridgeTestRequest(chat, "gpt-4"), aibridgeTestRoute(aibridgeTestAIProvider(providerID, "primary-openai", database.AiProviderTypeOpenai)), modelBuildOptions{ActiveAPIKeyID: apiKeyID, RecordHTTP: true}) + require.NoError(t, err) + _, err = model.Generate(t.Context(), fantasy.Call{Prompt: []fantasy.Message{{ + Role: fantasy.MessageRoleUser, + Content: []fantasy.MessagePart{fantasy.TextPart{Text: "hello"}}, + }}}) + require.NoError(t, err) + + got := <-seen + require.Equal(t, "primary-openai", factory.providerName) + require.Equal(t, aibridge.SourceAgents, factory.source) + require.True(t, got.ok) + require.Equal(t, "/v1/responses", got.path) + require.Equal(t, apiKeyID, got.apiKeyID) +} diff --git a/coderd/x/chatd/quickgen.go b/coderd/x/chatd/quickgen.go index a4e69bcbad..774e02d107 100644 --- a/coderd/x/chatd/quickgen.go +++ b/coderd/x/chatd/quickgen.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "net/http" "slices" "strings" "time" @@ -69,9 +68,39 @@ var preferredTitleModels = []struct { type shortTextCandidate struct { provider string model string + route resolvedModelRoute lm fantasy.LanguageModel } +func (p *Server) preferredShortTextCandidates( + chat database.Chat, + keys chatprovider.ProviderAPIKeys, +) []shortTextCandidate { + if p.shouldUseAIGatewayRouting() { + return nil + } + + candidates := make([]shortTextCandidate, 0, len(preferredTitleModels)+1) + userAgent := chatprovider.UserAgent() + extraHeaders := chatprovider.CoderHeaders(chat) + for _, candidate := range preferredTitleModels { + model, err := chatprovider.ModelFromConfig( + candidate.provider, candidate.model, keys, userAgent, + extraHeaders, + nil, + ) + if err == nil { + candidates = append(candidates, shortTextCandidate{ + provider: candidate.provider, + model: candidate.model, + route: newDirectModelRoute(candidate.provider, keys), + lm: model, + }) + } + } + return candidates +} + func selectPreferredConfiguredShortTextModelConfig( configs []database.ChatModelConfig, ) (database.ChatModelConfig, bool) { @@ -121,7 +150,9 @@ func (p *Server) maybeGenerateChatTitle( fallbackProvider string, fallbackModelName string, fallbackModel fantasy.LanguageModel, + fallbackRoute resolvedModelRoute, keys chatprovider.ProviderAPIKeys, + modelOpts modelBuildOptions, generatedTitle *generatedChatTitle, logger slog.Logger, debugSvc *chatdebug.Service, @@ -135,10 +166,11 @@ func (p *Server) maybeGenerateChatTitle( titleCtx, cancel := context.WithTimeout(ctx, 30*time.Second) defer cancel() - overrideConfig, overrideModel, overrideKeys, overrideSet, overrideErr := p.resolveTitleGenerationModelOverride( + overrideConfig, overrideModel, _, overrideRoute, overrideSet, overrideErr := p.resolveTitleGenerationModelOverride( titleCtx, chat, keys, + modelOpts, ) if overrideErr != nil { if overrideSet { @@ -161,29 +193,15 @@ func (p *Server) maybeGenerateChatTitle( candidates = []shortTextCandidate{{ provider: overrideConfig.Provider, model: overrideConfig.Model, + route: overrideRoute, lm: overrideModel, }} } else { - // Build candidate list: preferred lightweight models first, - // then the user's chat model as last resort. - candidates = make([]shortTextCandidate, 0, len(preferredTitleModels)+1) - for _, c := range preferredTitleModels { - m, err := chatprovider.ModelFromConfig( - c.provider, c.model, keys, chatprovider.UserAgent(), - chatprovider.CoderHeaders(chat), - nil, - ) - if err == nil { - candidates = append(candidates, shortTextCandidate{ - provider: c.provider, - model: c.model, - lm: m, - }) - } - } + candidates = p.preferredShortTextCandidates(chat, keys) candidates = append(candidates, shortTextCandidate{ provider: fallbackProvider, model: fallbackModelName, + route: fallbackRoute, lm: fallbackModel, }) } @@ -213,17 +231,13 @@ func (p *Server) maybeGenerateChatTitle( candidateCtx := titleCtx candidateModel := candidate.lm finishDebugRun := func(error) {} - candidateKeys := keys - if overrideSet { - candidateKeys = overrideKeys - } if debugEnabled { - candidateCtx, candidateModel, finishDebugRun = prepareQuickgenDebugCandidate( + candidateCtx, candidateModel, finishDebugRun = p.prepareQuickgenDebugCandidate( titleCtx, chat, - candidateKeys, debugSvc, candidate, + modelOpts, chatdebug.KindTitleGeneration, triggerMessageID, historyTipMessageID, @@ -293,32 +307,26 @@ func (p *Server) maybeGenerateChatTitle( } } -func newQuickgenDebugModel( +func (p *Server) newQuickgenDebugModel( + ctx context.Context, chat database.Chat, - keys chatprovider.ProviderAPIKeys, debugSvc *chatdebug.Service, provider string, model string, + route resolvedModelRoute, + modelOpts modelBuildOptions, ) (fantasy.LanguageModel, error) { - httpClient := &http.Client{Transport: &chatdebug.RecordingTransport{}} - debugModel, err := chatprovider.ModelFromConfig( - provider, - model, - keys, - chatprovider.UserAgent(), - chatprovider.CoderHeaders(chat), - httpClient, - ) + debugOpts := modelOpts + debugOpts.RecordHTTP = true + debugModel, err := p.newModel(ctx, modelClientRequest{ + Chat: chat, + ModelName: model, + UserAgent: chatprovider.UserAgent(), + ExtraHeaders: chatprovider.CoderHeaders(chat), + }, route, debugOpts) if err != nil { return nil, err } - if debugModel == nil { - return nil, xerrors.Errorf( - "create model for %s/%s returned nil", - provider, - model, - ) - } return chatdebug.WrapModel(debugModel, debugSvc, chatdebug.RecorderOptions{ ChatID: chat.ID, @@ -328,12 +336,12 @@ func newQuickgenDebugModel( }), nil } -func prepareQuickgenDebugCandidate( +func (p *Server) prepareQuickgenDebugCandidate( ctx context.Context, chat database.Chat, - keys chatprovider.ProviderAPIKeys, debugSvc *chatdebug.Service, candidate shortTextCandidate, + modelOpts modelBuildOptions, kind chatdebug.RunKind, triggerMessageID int64, historyTipMessageID int64, @@ -345,12 +353,14 @@ func prepareQuickgenDebugCandidate( return ctx, candidate.lm, finishDebugRun } - debugModel, err := newQuickgenDebugModel( + debugModel, err := p.newQuickgenDebugModel( + ctx, chat, - keys, debugSvc, candidate.provider, candidate.model, + candidate.route, + modelOpts, ) if err != nil { logger.Warn(ctx, "failed to build short-text debug model", @@ -393,18 +403,8 @@ func prepareQuickgenDebugCandidate( return ctx, candidate.lm, finishDebugRun } - runCtx := chatdebug.ContextWithRun( - ctx, - &chatdebug.RunContext{ - RunID: run.ID, - ChatID: chat.ID, - TriggerMessageID: triggerMessageID, - HistoryTipMessageID: historyTipMessageID, - Kind: kind, - Provider: candidate.provider, - Model: candidate.model, - }, - ) + runContext := chatdebugRunContext(run) + runCtx := chatdebug.ContextWithRun(ctx, &runContext) finishDebugRun = func(runErr error) { if finalizeErr := debugSvc.FinalizeRun(ctx, chatdebug.FinalizeRunParams{ RunID: run.ID, @@ -824,7 +824,7 @@ const turnStatusLabelPrompt = "You write compact chat status labels for a sideba // message text. It follows the same candidate-selection strategy // as title generation: try preferred lightweight models first, then // fall back to the provided model. Returns "" on any failure. -func generateTurnStatusLabel( +func (p *Server) generateTurnStatusLabel( ctx context.Context, chat database.Chat, status database.ChatStatus, @@ -832,7 +832,9 @@ func generateTurnStatusLabel( fallbackProvider string, fallbackModelName string, fallbackModel fantasy.LanguageModel, + fallbackRoute resolvedModelRoute, keys chatprovider.ProviderAPIKeys, + modelOpts modelBuildOptions, logger slog.Logger, debugSvc *chatdebug.Service, triggerMessageID int64, @@ -848,24 +850,11 @@ func generateTurnStatusLabel( "\nChat title: " + chat.Title + "\n\nAgent's latest message:\n" + assistantText - candidates := make([]shortTextCandidate, 0, len(preferredTitleModels)+1) - for _, c := range preferredTitleModels { - m, err := chatprovider.ModelFromConfig( - c.provider, c.model, keys, chatprovider.UserAgent(), - chatprovider.CoderHeaders(chat), - nil, - ) - if err == nil { - candidates = append(candidates, shortTextCandidate{ - provider: c.provider, - model: c.model, - lm: m, - }) - } - } + candidates := p.preferredShortTextCandidates(chat, keys) candidates = append(candidates, shortTextCandidate{ provider: fallbackProvider, model: fallbackModelName, + route: fallbackRoute, lm: fallbackModel, }) @@ -876,12 +865,12 @@ func generateTurnStatusLabel( candidateModel := candidate.lm finishDebugRun := func(error) {} if debugEnabled { - candidateCtx, candidateModel, finishDebugRun = prepareQuickgenDebugCandidate( + candidateCtx, candidateModel, finishDebugRun = p.prepareQuickgenDebugCandidate( labelCtx, chat, - keys, debugSvc, candidate, + modelOpts, chatdebug.KindQuickgen, triggerMessageID, historyTipMessageID, diff --git a/coderd/x/chatd/quickgen_internal_test.go b/coderd/x/chatd/quickgen_internal_test.go index 6be464980b..09fc8001ab 100644 --- a/coderd/x/chatd/quickgen_internal_test.go +++ b/coderd/x/chatd/quickgen_internal_test.go @@ -359,6 +359,16 @@ func Test_renderManualTitlePrompt(t *testing.T) { } } +func TestPreferredShortTextCandidatesNilUnderAIGateway(t *testing.T) { + t.Parallel() + + server := &Server{aiGatewayRoutingEnabled: true} + candidates := server.preferredShortTextCandidates(database.Chat{}, chatprovider.ProviderAPIKeys{ + ByProvider: map[string]string{"openai": "test-key"}, + }) + require.Nil(t, candidates) +} + func TestMaybeGenerateChatTitlePreservesUpdatedAt(t *testing.T) { t.Parallel() @@ -428,7 +438,9 @@ func TestMaybeGenerateChatTitlePreservesUpdatedAt(t *testing.T) { "openai", "test-model", model, + resolvedModelRoute{}, chatprovider.ProviderAPIKeys{}, + modelBuildOptions{}, generated, logger, nil, diff --git a/coderd/x/chatd/subagent.go b/coderd/x/chatd/subagent.go index cc3e35f78c..8984ac86cd 100644 --- a/coderd/x/chatd/subagent.go +++ b/coderd/x/chatd/subagent.go @@ -17,6 +17,7 @@ import ( "golang.org/x/xerrors" "cdr.dev/slog/v3" + "github.com/coder/coder/v2/coderd/aibridge" "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/dbauthz" coderdpubsub "github.com/coder/coder/v2/coderd/pubsub" @@ -908,6 +909,14 @@ func (p *Server) resolveExploreToolSnapshot( return inheritedMCPServerIDs, nil } +func (p *Server) delegatedAPIKeyIDForSubagent(ctx context.Context) (string, error) { + apiKeyID, ok := aibridge.DelegatedAPIKeyIDFromContext(ctx) + if !ok && p.shouldUseAIGatewayRouting() { + return "", xerrors.New("AI Gateway routing requires the active turn API key ID for subagent messages") + } + return apiKeyID, nil +} + func (p *Server) createChildSubagentChat( ctx context.Context, parent database.Chat, @@ -950,6 +959,10 @@ func (p *Server) createChildSubagentChatWithOptions( if modelConfigID == uuid.Nil { return database.Chat{}, xerrors.New("model config is required") } + childAPIKeyID, err := p.delegatedAPIKeyIDForSubagent(ctx) + if err != nil { + return database.Chat{}, err + } childPlanMode := parent.PlanMode if opts.planModeOverride != nil { @@ -1086,7 +1099,7 @@ func (p *Server) createChildSubagentChatWithOptions( database.ChatMessageVisibilityBoth, modelConfigID, chatprompt.CurrentContentVersion, - ).withCreatedBy(parent.OwnerID)) + ).withCreatedBy(parent.OwnerID).withAPIKeyID(childAPIKeyID)) if _, err := tx.InsertChatMessages(ctx, userParams); err != nil { return xerrors.Errorf("insert initial child user message: %w", err) } @@ -1236,10 +1249,16 @@ func (p *Server) sendSubagentMessage( return database.Chat{}, xerrors.Errorf("get target chat: %w", err) } + apiKeyID, err := p.delegatedAPIKeyIDForSubagent(ctx) + if err != nil { + return database.Chat{}, err + } + sendResult, err := p.SendMessage(ctx, SendMessageOptions{ ChatID: targetChatID, CreatedBy: targetChat.OwnerID, Content: []codersdk.ChatMessagePart{codersdk.ChatMessageText(message)}, + APIKeyID: apiKeyID, BusyBehavior: busyBehavior, }) if err != nil { diff --git a/coderd/x/chatd/subagent_internal_test.go b/coderd/x/chatd/subagent_internal_test.go index 55254db1a2..ce860f1249 100644 --- a/coderd/x/chatd/subagent_internal_test.go +++ b/coderd/x/chatd/subagent_internal_test.go @@ -16,6 +16,7 @@ import ( "cdr.dev/slog/v3" "cdr.dev/slog/v3/sloggers/slogtest" + "github.com/coder/coder/v2/coderd/aibridge" "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database/dbauthz" "github.com/coder/coder/v2/coderd/database/dbgen" @@ -231,6 +232,163 @@ func insertInternalAIProvider( }) } +func TestCreateChildSubagentChatPropagatesActiveTurnAPIKeyID(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitShort) + db, _ := dbtestutil.NewDB(t) + user := dbgen.User(t, db, database.User{}) + org := dbgen.Organization(t, db, database.Organization{}) + dbgen.OrganizationMember(t, db, database.OrganizationMember{UserID: user.ID, OrganizationID: org.ID}) + model := dbgen.ChatModelConfig(t, db, database.ChatModelConfig{}) + parent := dbgen.Chat(t, db, database.Chat{ + OrganizationID: org.ID, + OwnerID: user.ID, + LastModelConfigID: model.ID, + }) + + apiKey, _ := dbgen.APIKey(t, db, database.APIKey{UserID: user.ID}) + ctx = aibridge.WithDelegatedAPIKeyID(ctx, apiKey.ID) + + server := &Server{db: db, logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})} + child, err := server.createChildSubagentChat(ctx, parent, "inspect the workspace", "") + require.NoError(t, err) + + messages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{ChatID: child.ID}) + require.NoError(t, err) + var childUserMessage database.ChatMessage + for _, message := range messages { + if message.Role == database.ChatMessageRoleUser { + childUserMessage = message + break + } + } + require.NotZero(t, childUserMessage.ID) + require.True(t, childUserMessage.APIKeyID.Valid) + require.Equal(t, apiKey.ID, childUserMessage.APIKeyID.String) +} + +func TestSendSubagentMessagePropagatesActiveTurnAPIKeyID(t *testing.T) { + t.Parallel() + + db, ps := dbtestutil.NewDB(t) + server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{}) + ctx := chatdTestContext(t) + user, org, model := seedInternalChatDeps(t, db) + apiKey, _ := dbgen.APIKey(t, db, database.APIKey{UserID: user.ID}) + + parent, err := server.CreateChat(ctx, CreateOptions{ + OrganizationID: org.ID, + OwnerID: user.ID, + Title: "parent-send-subagent-key", + ModelConfigID: model.ID, + InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")}, + APIKeyID: apiKey.ID, + }) + require.NoError(t, err) + child, err := server.CreateChat(ctx, CreateOptions{ + OrganizationID: org.ID, + OwnerID: user.ID, + ParentChatID: uuid.NullUUID{UUID: parent.ID, Valid: true}, + RootChatID: uuid.NullUUID{UUID: parent.ID, Valid: true}, + Title: "child-send-subagent-key", + ModelConfigID: model.ID, + InitialUserContent: []codersdk.ChatMessagePart{ + codersdk.ChatMessageText("do work"), + }, + }) + require.NoError(t, err) + + setChatStatus(ctx, t, db, child.ID, database.ChatStatusWaiting, "") + + ctx = aibridge.WithDelegatedAPIKeyID(ctx, apiKey.ID) + _, err = server.sendSubagentMessage( + ctx, + parent.ID, + child.ID, + "follow up", + SendMessageBusyBehaviorInterrupt, + ) + require.NoError(t, err) + + messages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{ChatID: child.ID}) + require.NoError(t, err) + var latestUserMessage database.ChatMessage + for _, message := range messages { + if message.Role == database.ChatMessageRoleUser && message.ID > latestUserMessage.ID { + latestUserMessage = message + } + } + require.NotZero(t, latestUserMessage.ID) + require.True(t, latestUserMessage.APIKeyID.Valid) + require.Equal(t, apiKey.ID, latestUserMessage.APIKeyID.String) +} + +func TestCreateChildSubagentChatRequiresActiveTurnAPIKeyIDForAIGateway(t *testing.T) { + t.Parallel() + + ctx := testutil.Context(t, testutil.WaitShort) + db, _ := dbtestutil.NewDB(t) + user := dbgen.User(t, db, database.User{}) + org := dbgen.Organization(t, db, database.Organization{}) + dbgen.OrganizationMember(t, db, database.OrganizationMember{UserID: user.ID, OrganizationID: org.ID}) + model := dbgen.ChatModelConfig(t, db, database.ChatModelConfig{}) + parent := dbgen.Chat(t, db, database.Chat{ + OrganizationID: org.ID, + OwnerID: user.ID, + LastModelConfigID: model.ID, + }) + + server := &Server{ + db: db, + logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}), + aiGatewayRoutingEnabled: true, + } + _, err := server.createChildSubagentChat(ctx, parent, "inspect the workspace", "") + require.ErrorContains(t, err, "AI Gateway routing requires the active turn API key ID for subagent messages") +} + +func TestSendSubagentMessageRequiresActiveTurnAPIKeyIDForAIGateway(t *testing.T) { + t.Parallel() + + db, ps := dbtestutil.NewDB(t) + server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{}) + server.aiGatewayRoutingEnabled = true + ctx := chatdTestContext(t) + user, org, model := seedInternalChatDeps(t, db) + + parent, err := server.CreateChat(ctx, CreateOptions{ + OrganizationID: org.ID, + OwnerID: user.ID, + Title: "parent-send-subagent-missing-key", + ModelConfigID: model.ID, + InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")}, + }) + require.NoError(t, err) + child, err := server.CreateChat(ctx, CreateOptions{ + OrganizationID: org.ID, + OwnerID: user.ID, + ParentChatID: uuid.NullUUID{UUID: parent.ID, Valid: true}, + RootChatID: uuid.NullUUID{UUID: parent.ID, Valid: true}, + Title: "child-send-subagent-missing-key", + ModelConfigID: model.ID, + InitialUserContent: []codersdk.ChatMessagePart{ + codersdk.ChatMessageText("do work"), + }, + }) + require.NoError(t, err) + + setChatStatus(ctx, t, db, child.ID, database.ChatStatusWaiting, "") + _, err = server.sendSubagentMessage( + ctx, + parent.ID, + child.ID, + "follow up", + SendMessageBusyBehaviorInterrupt, + ) + require.ErrorContains(t, err, "AI Gateway routing requires the active turn API key ID for subagent messages") +} + func TestResolveUserProviderAPIKeys_AIProvider(t *testing.T) { t.Parallel() @@ -351,7 +509,7 @@ func TestResolveChatModel_AIProviderDisabled(t *testing.T) { LastModelConfigID: modelConfig.ID, }) - model, config, keys, debugEnabled, resolvedProvider, resolvedModel, err := server.resolveChatModel(ctx, chat) + model, config, keys, _, debugEnabled, resolvedProvider, resolvedModel, err := server.resolveChatModel(ctx, chat, modelBuildOptions{}) require.ErrorContains(t, err, "is disabled") require.Nil(t, model) require.Equal(t, database.ChatModelConfig{}, config) @@ -3187,6 +3345,26 @@ func TestAwaitSubagentCompletion(t *testing.T) { parent, child := createParentChildChats(ctx, t, server, user, org, model) + // signalWake from CreateChat triggers background processing. Wait + // for those runs to finish, then reset both chats so this test owns + // the state transition observed by the poll loop. + testutil.Eventually(ctx, t, func(ctx context.Context) bool { + parentChat, err := db.GetChatByID(ctx, parent.ID) + if err != nil { + return false + } + childChat, err := db.GetChatByID(ctx, child.ID) + if err != nil { + return false + } + return parentChat.Status != database.ChatStatusPending && + parentChat.Status != database.ChatStatusRunning && + childChat.Status != database.ChatStatusPending && + childChat.Status != database.ChatStatusRunning + }, testutil.IntervalFast) + setChatStatus(ctx, t, db, parent.ID, database.ChatStatusRunning, "") + setChatStatus(ctx, t, db, child.ID, database.ChatStatusRunning, "") + // Set the trap BEFORE starting the goroutine so we // deterministically catch the ticker creation. tickTrap := mClock.Trap().NewTicker("chatd", "subagent_poll") diff --git a/coderd/x/chatd/title_override.go b/coderd/x/chatd/title_override.go index 9214b44254..9840a3b471 100644 --- a/coderd/x/chatd/title_override.go +++ b/coderd/x/chatd/title_override.go @@ -31,28 +31,18 @@ func readTitleGenerationModelOverride( } // resolveTitleGenerationModelOverride resolves the deployment-wide title -// generation model override. It returns four values: -// -// - modelConfig and model: populated only on success. -// - overrideSet: true when the admin configured a non-empty override, -// regardless of whether resolution succeeded. Callers MUST always check -// err first; overrideSet alone does not imply the model is usable. -// - err: non-nil when resolution failed. DB read failure returns -// (zero, nil, false, err). With overrideSet=true, the override is -// configured but unusable (deleted model, missing credentials, etc.) and -// callers should treat this as a hard failure for explicit-override -// semantics, not a soft fallback. -// -// When the override is unset or stored as malformed, the function returns -// (zero, nil, false, nil) so callers can fall back to default behavior. +// generation model override. overrideSet is true when an override was +// configured; in that case any returned error is a hard failure. When +// overrideSet is false, callers may fall back to the default title model. func (p *Server) resolveTitleGenerationModelOverride( ctx context.Context, chat database.Chat, keys chatprovider.ProviderAPIKeys, -) (database.ChatModelConfig, fantasy.LanguageModel, chatprovider.ProviderAPIKeys, bool, error) { + modelOpts modelBuildOptions, +) (database.ChatModelConfig, fantasy.LanguageModel, chatprovider.ProviderAPIKeys, resolvedModelRoute, bool, error) { raw, err := readTitleGenerationModelOverride(ctx, p.db) if err != nil { - return database.ChatModelConfig{}, nil, chatprovider.ProviderAPIKeys{}, false, xerrors.Errorf( + return database.ChatModelConfig{}, nil, chatprovider.ProviderAPIKeys{}, resolvedModelRoute{}, false, xerrors.Errorf( "read title generation model override: %w", err, ) @@ -84,43 +74,28 @@ func (p *Server) resolveTitleGenerationModelOverride( modelOverrideFailureModeHard, ) if err != nil { - return database.ChatModelConfig{}, nil, chatprovider.ProviderAPIKeys{}, overrideSet, err + return database.ChatModelConfig{}, nil, chatprovider.ProviderAPIKeys{}, resolvedModelRoute{}, overrideSet, err } if !overrideSet { - return database.ChatModelConfig{}, nil, keys, false, nil + return database.ChatModelConfig{}, nil, keys, resolvedModelRoute{}, false, nil } - providerHint := modelConfig.Provider - if modelConfig.AIProviderID.Valid { - //nolint:gocritic // Title overrides need chatd-scoped provider reads for user-owned chats. - provider, err := p.db.GetAIProviderByID(dbauthz.AsChatd(ctx), modelConfig.AIProviderID.UUID) - if err != nil { - return database.ChatModelConfig{}, nil, chatprovider.ProviderAPIKeys{}, true, xerrors.Errorf("get AI provider for title generation override: %w", err) - } - if !provider.Enabled { - return database.ChatModelConfig{}, nil, chatprovider.ProviderAPIKeys{}, true, xerrors.Errorf("AI provider %s is disabled", modelConfig.AIProviderID.UUID) - } - providerHint = string(provider.Type) - } - model, err := chatprovider.ModelFromConfig( - providerHint, - modelConfig.Model, - overrideProviderKeys, - chatprovider.UserAgent(), - chatprovider.CoderHeaders(chat), - nil, - ) + //nolint:gocritic // Title overrides need chatd-scoped provider reads for user-owned chats. + route, err := p.resolveModelRouteForConfig(dbauthz.AsChatd(ctx), chat.OwnerID, modelConfig, overrideProviderKeys) if err != nil { - return database.ChatModelConfig{}, nil, chatprovider.ProviderAPIKeys{}, true, xerrors.Errorf( + return database.ChatModelConfig{}, nil, chatprovider.ProviderAPIKeys{}, resolvedModelRoute{}, true, err + } + model, err := p.newModel(ctx, modelClientRequest{ + Chat: chat, + ModelName: modelConfig.Model, + UserAgent: chatprovider.UserAgent(), + ExtraHeaders: chatprovider.CoderHeaders(chat), + }, route, modelOpts) + if err != nil { + return database.ChatModelConfig{}, nil, chatprovider.ProviderAPIKeys{}, resolvedModelRoute{}, true, xerrors.Errorf( "create title generation model override: %w", err, ) } - if model == nil { - return database.ChatModelConfig{}, nil, chatprovider.ProviderAPIKeys{}, true, xerrors.Errorf( - "create title generation model override returned nil", - ) - } - - return modelConfig, model, overrideProviderKeys, true, nil + return modelConfig, model, route.directProviderKeys(), route, true, nil } diff --git a/coderd/x/chatd/title_override_internal_test.go b/coderd/x/chatd/title_override_internal_test.go index 7352198474..9439b92fce 100644 --- a/coderd/x/chatd/title_override_internal_test.go +++ b/coderd/x/chatd/title_override_internal_test.go @@ -65,7 +65,9 @@ func TestMaybeGenerateChatTitle_TitleGenerationOverrideUnset(t *testing.T) { "openai", "fallback-chat-model", fallbackModel, + resolvedModelRoute{}, keys, + modelBuildOptions{}, generated, logger, nil, @@ -112,7 +114,9 @@ func TestMaybeGenerateChatTitle_TitleGenerationOverrideUnset(t *testing.T) { "openai", "fallback-chat-model", fallbackModel, + resolvedModelRoute{}, chatprovider.ProviderAPIKeys{}, + modelBuildOptions{}, generated, logger, nil, @@ -160,7 +164,9 @@ func TestMaybeGenerateChatTitle_TitleGenerationOverrideReadDBError(t *testing.T) "openai", "fallback-chat-model", fallbackModel, + resolvedModelRoute{}, chatprovider.ProviderAPIKeys{}, + modelBuildOptions{}, generated, logger, nil, @@ -207,7 +213,9 @@ func TestMaybeGenerateChatTitle_TitleGenerationOverrideMalformedFallsThrough(t * "openai", "fallback-chat-model", fallbackModel, + resolvedModelRoute{}, chatprovider.ProviderAPIKeys{}, + modelBuildOptions{}, generated, logger, nil, @@ -257,7 +265,7 @@ func TestMaybeGenerateChatTitle_TitleGenerationOverrideSetUsable(t *testing.T) { db.EXPECT().GetAIProviderKeysByProviderID(gomock.Any(), providerID).Return([]database.AIProviderKey{{ ProviderID: providerID, APIKey: "test-key", - }}, nil) + }}, nil).Times(2) db.EXPECT().UpdateChatTitleByID(gomock.Any(), database.UpdateChatTitleByIDParams{ ID: chat.ID, Title: wantTitle, @@ -272,7 +280,9 @@ func TestMaybeGenerateChatTitle_TitleGenerationOverrideSetUsable(t *testing.T) { "openai", "fallback-chat-model", fallbackModel, + resolvedModelRoute{}, chatprovider.ProviderAPIKeys{}, + modelBuildOptions{}, generated, logger, nil, @@ -312,7 +322,9 @@ func TestMaybeGenerateChatTitle_TitleGenerationOverrideSetUnusableSkips(t *testi "openai", "fallback-chat-model", fallbackModel, + resolvedModelRoute{}, chatprovider.ProviderAPIKeys{}, + modelBuildOptions{}, generated, logger, nil, @@ -360,7 +372,9 @@ func TestMaybeGenerateChatTitle_TitleGenerationOverrideCallFailureSkipsFallback( "openai", "fallback-chat-model", fallbackModel, + resolvedModelRoute{}, keys, + modelBuildOptions{}, generated, logger, nil, @@ -398,6 +412,7 @@ func TestResolveManualTitleModel_TitleGenerationOverrideUnset(t *testing.T) { db, chat, chatprovider.ProviderAPIKeys{ByProvider: map[string]string{"openai": "test-key"}}, + modelBuildOptions{}, ) require.NoError(t, err) require.NotNil(t, model) @@ -447,6 +462,7 @@ func TestResolveManualTitleModel_TitleGenerationOverrideUnsetAIProvider(t *testi db, chat, chatprovider.ProviderAPIKeys{}, + modelBuildOptions{}, ) require.NoError(t, err) require.NotNil(t, model) @@ -481,6 +497,7 @@ func TestResolveManualTitleModel_TitleGenerationOverrideReadDBError(t *testing.T db, chat, chatprovider.ProviderAPIKeys{ByProvider: map[string]string{"openai": "test-key"}}, + modelBuildOptions{}, ) require.NoError(t, err) require.NotNil(t, model) @@ -508,6 +525,7 @@ func TestResolveManualTitleModel_TitleGenerationOverrideSetUsable(t *testing.T) db, chat, chatprovider.ProviderAPIKeys{ByProvider: map[string]string{"openai": "test-key"}}, + modelBuildOptions{}, ) require.NoError(t, err) require.NotNil(t, model) @@ -535,6 +553,7 @@ func TestResolveManualTitleModel_TitleGenerationOverrideMissingCredentials(t *te db, chat, chatprovider.ProviderAPIKeys{}, + modelBuildOptions{}, ) require.Error(t, err) require.ErrorContains(t, err, "resolve manual title generation model override") @@ -562,6 +581,7 @@ func TestResolveManualTitleModel_TitleGenerationOverrideSetUnusable(t *testing.T db, chat, chatprovider.ProviderAPIKeys{ByProvider: map[string]string{"openai": "test-key"}}, + modelBuildOptions{}, ) require.Error(t, err) require.ErrorContains(t, err, "resolve manual title generation model override") diff --git a/codersdk/deployment.go b/codersdk/deployment.go index dab861590f..f9cdb8a8fc 100644 --- a/codersdk/deployment.go +++ b/codersdk/deployment.go @@ -4058,6 +4058,17 @@ Write out the current server config as YAML to stdout.`, Group: &deploymentGroupChat, YAML: "debugLoggingEnabled", }, + { + Name: "Chat: AI Gateway Routing Enabled", + Description: "Route chat model requests through AI Gateway when both chat routing and AI Gateway are enabled. Otherwise, chat calls AI providers directly. Pending chats without API key metadata may need a retry or temporary direct routing.", + Flag: "chat-ai-gateway-routing-enabled", + Env: "CODER_CHAT_AI_GATEWAY_ROUTING_ENABLED", + Value: &c.AI.Chat.AIGatewayRoutingEnabled, + Default: "true", + Group: &deploymentGroupChat, + YAML: "aiGatewayRoutingEnabled", + Hidden: true, + }, // AI Bridge Options (deprecated in favor of AI Gateway options) { Name: "AI Bridge Enabled", @@ -4714,8 +4725,9 @@ type AIBridgeProxyConfig struct { } type ChatConfig struct { - AcquireBatchSize serpent.Int64 `json:"acquire_batch_size" typescript:",notnull"` - DebugLoggingEnabled serpent.Bool `json:"debug_logging_enabled" typescript:",notnull"` + AcquireBatchSize serpent.Int64 `json:"acquire_batch_size" typescript:",notnull"` + DebugLoggingEnabled serpent.Bool `json:"debug_logging_enabled" typescript:",notnull"` + AIGatewayRoutingEnabled serpent.Bool `json:"ai_gateway_routing_enabled" typescript:",notnull" swaggerignore:"true"` } type AIConfig struct { diff --git a/codersdk/deployment_test.go b/codersdk/deployment_test.go index 25eeca630e..0287c0daa5 100644 --- a/codersdk/deployment_test.go +++ b/codersdk/deployment_test.go @@ -916,6 +916,15 @@ func TestRetentionConfigParsing(t *testing.T) { } } +func TestChatAIGatewayRoutingEnabledDefault(t *testing.T) { + t.Parallel() + + dv := codersdk.DeploymentValues{} + opts := dv.Options() + require.NoError(t, opts.SetDefaults()) + require.True(t, dv.AI.Chat.AIGatewayRoutingEnabled.Value()) +} + func TestAIBudgetConfigParsing(t *testing.T) { t.Parallel() diff --git a/enterprise/coderd/exp_chats_test.go b/enterprise/coderd/exp_chats_test.go index 89f315ae0b..d29240dd2e 100644 --- a/enterprise/coderd/exp_chats_test.go +++ b/enterprise/coderd/exp_chats_test.go @@ -77,6 +77,9 @@ func TestChatStreamRelay(t *testing.T) { Options: &coderdtest.Options{ Database: db, Pubsub: pubsub, + DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) { + require.NoError(t, dv.AI.Chat.AIGatewayRoutingEnabled.Set("false")) + }), }, LicenseOptions: &coderdenttest.LicenseOptions{ Features: license.Features{ @@ -89,6 +92,9 @@ func TestChatStreamRelay(t *testing.T) { Options: &coderdtest.Options{ Database: db, Pubsub: pubsub, + DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) { + require.NoError(t, dv.AI.Chat.AIGatewayRoutingEnabled.Set("false")) + }), }, DontAddLicense: true, DontAddFirstUser: true, @@ -219,6 +225,9 @@ func TestChatStreamRelay(t *testing.T) { Database: db, Pubsub: pubsub, TLSCertificates: certificates, + DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) { + require.NoError(t, dv.AI.Chat.AIGatewayRoutingEnabled.Set("false")) + }), }, LicenseOptions: &coderdenttest.LicenseOptions{ Features: license.Features{ @@ -232,6 +241,9 @@ func TestChatStreamRelay(t *testing.T) { Database: db, Pubsub: pubsub, TLSCertificates: certificates, + DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) { + require.NoError(t, dv.AI.Chat.AIGatewayRoutingEnabled.Set("false")) + }), }, DontAddLicense: true, DontAddFirstUser: true, @@ -398,6 +410,9 @@ func TestChatStreamRelay(t *testing.T) { Options: &coderdtest.Options{ Database: db, Pubsub: pubsub, + DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) { + require.NoError(t, dv.AI.Chat.AIGatewayRoutingEnabled.Set("false")) + }), }, LicenseOptions: &coderdenttest.LicenseOptions{ Features: license.Features{ @@ -410,6 +425,9 @@ func TestChatStreamRelay(t *testing.T) { Options: &coderdtest.Options{ Database: db, Pubsub: pubsub, + DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) { + require.NoError(t, dv.AI.Chat.AIGatewayRoutingEnabled.Set("false")) + }), }, DontAddLicense: true, DontAddFirstUser: true, @@ -544,6 +562,7 @@ func TestChatStreamRelay(t *testing.T) { db, pubsub := dbtestutil.NewDB(t) hostPrefixValues := coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) { + require.NoError(t, dv.AI.Chat.AIGatewayRoutingEnabled.Set("false")) dv.HTTPCookies.EnableHostPrefix = true dv.HTTPCookies.Secure = true }) @@ -696,6 +715,9 @@ func TestChatStreamRelay(t *testing.T) { Options: &coderdtest.Options{ Database: db, Pubsub: pubsub, + DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) { + require.NoError(t, dv.AI.Chat.AIGatewayRoutingEnabled.Set("false")) + }), }, LicenseOptions: &coderdenttest.LicenseOptions{ Features: license.Features{ @@ -708,6 +730,9 @@ func TestChatStreamRelay(t *testing.T) { Options: &coderdtest.Options{ Database: db, Pubsub: pubsub, + DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) { + require.NoError(t, dv.AI.Chat.AIGatewayRoutingEnabled.Set("false")) + }), }, DontAddLicense: true, DontAddFirstUser: true, diff --git a/site/src/api/typesGenerated.ts b/site/src/api/typesGenerated.ts index 29a28130a1..31369e3e3e 100644 --- a/site/src/api/typesGenerated.ts +++ b/site/src/api/typesGenerated.ts @@ -1613,6 +1613,7 @@ export interface ChatComputerUseProviderResponse { export interface ChatConfig { readonly acquire_batch_size: number; readonly debug_logging_enabled: boolean; + readonly ai_gateway_routing_enabled: boolean; } // From codersdk/chats.go