feat: route chatd provider traffic through aibridge (#25629)

## Summary

Routes chatd model calls backed by concrete AI Provider rows through the
in-process aibridge transport by default. Deployment options switch chatd
back to direct provider routing when either AI Gateway or chat AI Gateway
routing is disabled.

- Splits model routing into common, direct provider, and AI Gateway
paths behind a single deployment-mode entry point.
- Builds chatd models through explicit request, route, and options data.
Active API key attribution is passed explicitly instead of being hidden
inside generic model construction.
- For AI Gateway BYOK routes, resolves the user's provider key in chatd,
forwards it through provider-specific auth headers, and sets
`X-Coder-AI-Governance-Token` to the `delegated` marker so aibridge
preserves those headers while still stripping Coder-specific metadata.
- Keeps central provider credentials and deployment fallback credentials
out of forwarded provider auth headers, so AI Gateway central policy
remains authoritative.
- Redacts delegated provider auth from default string formatting to
avoid accidental plaintext logging of user BYOK credentials.
- Covers selected chat models, advisor overrides, title and quickgen
paths, subagent overrides, computer use model selection, and an
integration-style chat turn through the aibridge transport path.
- Persists initiating API key IDs on chat and queued user messages,
including subagent child messages, and fails closed for AI
Gateway-routed model builds without an active key.
- Removes unused `api_key_id` indexes while keeping the persistence
columns and foreign keys.
- Keeps the deployment option available through config and env parsing,
but hides it from CLI help and generated docs.
- Stabilizes the subagent poll fallback test so background CreateChat
processing cannot win the state transition under slower CI environments.

## Tests

- `go test ./coderd/x/chatd -run
'TestAIGatewayProviderAuthForUser|TestAIGatewayProviderAuthRedactsFormatting|TestResolveModelRouteForConfigAIGatewayProviderAuth|TestAIGatewayModelForwardsProviderAuth|TestProcessChat_AIGatewayRoutingUsesDelegatedAPIKey|TestAwaitSubagentCompletion'
-count=1`
- `go test ./coderd/aibridged -run
'TestServeHTTP_DelegatedAPIKey|TestServeHTTP_StripCoderToken' -count=1`
- `git diff --check HEAD~1..HEAD`
- `make lint`

> Mux working on behalf of Mike.
This commit is contained in:
Michael Suchacz
2026-05-26 21:31:52 +02:00
committed by GitHub
parent a56c88a0cc
commit 8b1705eb65
31 changed files with 2463 additions and 377 deletions
+5
View File
@@ -765,6 +765,11 @@ chat:
# opt-in settings.
# (default: false, type: bool)
debugLoggingEnabled: false
# Route chat model requests through AI Gateway when both chat routing and AI
# Gateway are enabled. Otherwise, chat calls AI providers directly. Pending chats
# without API key metadata may need a retry or temporary direct routing.
# (default: true, type: bool)
aiGatewayRoutingEnabled: true
aibridge:
# Deprecated: use --ai-gateway-enabled or CODER_AI_GATEWAY_ENABLED instead.
# Whether to start an in-memory aibridged instance.
+5
View File
@@ -807,6 +807,9 @@ func New(options *Options) *API {
providerAPIKeys = *options.ChatProviderAPIKeys
}
chatAIGatewayRoutingEnabled := options.DeploymentValues.AI.BridgeConfig.Enabled.Value() &&
options.DeploymentValues.AI.Chat.AIGatewayRoutingEnabled.Value()
api.chatDaemon = chatd.New(chatd.Config{
Logger: options.Logger.Named("chatd"),
Database: options.Database,
@@ -816,6 +819,8 @@ func New(options *Options) *API {
ProviderAPIKeys: providerAPIKeys,
AllowBYOK: options.DeploymentValues.AI.BridgeConfig.AllowBYOK.Value(),
AllowBYOKSet: true,
AIBridgeTransportFactory: &api.AIBridgeTransportFactory,
AIGatewayRoutingEnabled: chatAIGatewayRoutingEnabled,
AlwaysEnableDebugLogs: options.DeploymentValues.AI.Chat.DebugLoggingEnabled.Value(),
AgentConn: api.agentProvider.AgentConn,
AgentInactiveDisconnectTimeout: api.AgentInactiveDisconnectTimeout,
+1
View File
@@ -119,6 +119,7 @@ func ChatMessage(t testing.TB, db database.Store, seed database.ChatMessage) dat
msgs, err := db.InsertChatMessages(genCtx, database.InsertChatMessagesParams{
ChatID: seed.ChatID,
CreatedBy: []uuid.UUID{seed.CreatedBy.UUID},
APIKeyID: []string{seed.APIKeyID.String},
ModelConfigID: []uuid.UUID{seed.ModelConfigID.UUID},
Role: []database.ChatMessageRole{takeFirst(seed.Role, database.ChatMessageRoleUser)},
Content: []string{content},
+10 -2
View File
@@ -1554,7 +1554,8 @@ CREATE TABLE chat_messages (
total_cost_micros bigint,
runtime_ms bigint,
deleted boolean DEFAULT false NOT NULL,
provider_response_id text
provider_response_id text,
api_key_id text
);
CREATE SEQUENCE chat_messages_id_seq
@@ -1593,7 +1594,8 @@ CREATE TABLE chat_queued_messages (
chat_id uuid NOT NULL,
content jsonb NOT NULL,
created_at timestamp with time zone DEFAULT now() NOT NULL,
model_config_id uuid
model_config_id uuid,
api_key_id text
);
CREATE SEQUENCE chat_queued_messages_id_seq
@@ -4453,6 +4455,9 @@ ALTER TABLE ONLY chat_files
ALTER TABLE ONLY chat_files
ADD CONSTRAINT chat_files_owner_id_fkey FOREIGN KEY (owner_id) REFERENCES users(id) ON DELETE CASCADE;
ALTER TABLE ONLY chat_messages
ADD CONSTRAINT chat_messages_api_key_id_fkey FOREIGN KEY (api_key_id) REFERENCES api_keys(id) ON DELETE SET NULL;
ALTER TABLE ONLY chat_messages
ADD CONSTRAINT chat_messages_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE;
@@ -4468,6 +4473,9 @@ ALTER TABLE ONLY chat_model_configs
ALTER TABLE ONLY chat_model_configs
ADD CONSTRAINT chat_model_configs_updated_by_fkey FOREIGN KEY (updated_by) REFERENCES users(id);
ALTER TABLE ONLY chat_queued_messages
ADD CONSTRAINT chat_queued_messages_api_key_id_fkey FOREIGN KEY (api_key_id) REFERENCES api_keys(id) ON DELETE SET NULL;
ALTER TABLE ONLY chat_queued_messages
ADD CONSTRAINT chat_queued_messages_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE;
@@ -21,11 +21,13 @@ const (
ForeignKeyChatFileLinksFileID ForeignKeyConstraint = "chat_file_links_file_id_fkey" // ALTER TABLE ONLY chat_file_links ADD CONSTRAINT chat_file_links_file_id_fkey FOREIGN KEY (file_id) REFERENCES chat_files(id) ON DELETE CASCADE;
ForeignKeyChatFilesOrganizationID ForeignKeyConstraint = "chat_files_organization_id_fkey" // ALTER TABLE ONLY chat_files ADD CONSTRAINT chat_files_organization_id_fkey FOREIGN KEY (organization_id) REFERENCES organizations(id) ON DELETE CASCADE;
ForeignKeyChatFilesOwnerID ForeignKeyConstraint = "chat_files_owner_id_fkey" // ALTER TABLE ONLY chat_files ADD CONSTRAINT chat_files_owner_id_fkey FOREIGN KEY (owner_id) REFERENCES users(id) ON DELETE CASCADE;
ForeignKeyChatMessagesAPIKeyID ForeignKeyConstraint = "chat_messages_api_key_id_fkey" // ALTER TABLE ONLY chat_messages ADD CONSTRAINT chat_messages_api_key_id_fkey FOREIGN KEY (api_key_id) REFERENCES api_keys(id) ON DELETE SET NULL;
ForeignKeyChatMessagesChatID ForeignKeyConstraint = "chat_messages_chat_id_fkey" // ALTER TABLE ONLY chat_messages ADD CONSTRAINT chat_messages_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE;
ForeignKeyChatMessagesModelConfigID ForeignKeyConstraint = "chat_messages_model_config_id_fkey" // ALTER TABLE ONLY chat_messages ADD CONSTRAINT chat_messages_model_config_id_fkey FOREIGN KEY (model_config_id) REFERENCES chat_model_configs(id);
ForeignKeyChatModelConfigsAiProviderID ForeignKeyConstraint = "chat_model_configs_ai_provider_id_fkey" // ALTER TABLE ONLY chat_model_configs ADD CONSTRAINT chat_model_configs_ai_provider_id_fkey FOREIGN KEY (ai_provider_id) REFERENCES ai_providers(id);
ForeignKeyChatModelConfigsCreatedBy ForeignKeyConstraint = "chat_model_configs_created_by_fkey" // ALTER TABLE ONLY chat_model_configs ADD CONSTRAINT chat_model_configs_created_by_fkey FOREIGN KEY (created_by) REFERENCES users(id);
ForeignKeyChatModelConfigsUpdatedBy ForeignKeyConstraint = "chat_model_configs_updated_by_fkey" // ALTER TABLE ONLY chat_model_configs ADD CONSTRAINT chat_model_configs_updated_by_fkey FOREIGN KEY (updated_by) REFERENCES users(id);
ForeignKeyChatQueuedMessagesAPIKeyID ForeignKeyConstraint = "chat_queued_messages_api_key_id_fkey" // ALTER TABLE ONLY chat_queued_messages ADD CONSTRAINT chat_queued_messages_api_key_id_fkey FOREIGN KEY (api_key_id) REFERENCES api_keys(id) ON DELETE SET NULL;
ForeignKeyChatQueuedMessagesChatID ForeignKeyConstraint = "chat_queued_messages_chat_id_fkey" // ALTER TABLE ONLY chat_queued_messages ADD CONSTRAINT chat_queued_messages_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE;
ForeignKeyChatsAgentID ForeignKeyConstraint = "chats_agent_id_fkey" // ALTER TABLE ONLY chats ADD CONSTRAINT chats_agent_id_fkey FOREIGN KEY (agent_id) REFERENCES workspace_agents(id) ON DELETE SET NULL;
ForeignKeyChatsBuildID ForeignKeyConstraint = "chats_build_id_fkey" // ALTER TABLE ONLY chats ADD CONSTRAINT chats_build_id_fkey FOREIGN KEY (build_id) REFERENCES workspace_builds(id) ON DELETE SET NULL;
@@ -0,0 +1,5 @@
-- Reverts the up migration by dropping the api_key_id attribution columns
-- from queued and persisted chat messages. PostgreSQL removes the foreign key
-- constraints that involve a dropped column along with the column itself.
ALTER TABLE chat_queued_messages
DROP COLUMN api_key_id;
ALTER TABLE chat_messages
DROP COLUMN api_key_id;
@@ -0,0 +1,8 @@
-- Preserve chat history when API keys are deleted. Pending work whose latest
-- user turn loses this attribution will fail closed under AI Gateway routing;
-- operators can retry the turn or temporarily use direct routing.
ALTER TABLE chat_messages
ADD COLUMN api_key_id text REFERENCES api_keys(id) ON DELETE SET NULL;
ALTER TABLE chat_queued_messages
ADD COLUMN api_key_id text REFERENCES api_keys(id) ON DELETE SET NULL;
+2
View File
@@ -4699,6 +4699,7 @@ type ChatMessage struct {
RuntimeMs sql.NullInt64 `db:"runtime_ms" json:"runtime_ms"`
Deleted bool `db:"deleted" json:"deleted"`
ProviderResponseID sql.NullString `db:"provider_response_id" json:"provider_response_id"`
APIKeyID sql.NullString `db:"api_key_id" json:"api_key_id"`
}
type ChatModelConfig struct {
@@ -4726,6 +4727,7 @@ type ChatQueuedMessage struct {
Content json.RawMessage `db:"content" json:"content"`
CreatedAt time.Time `db:"created_at" json:"created_at"`
ModelConfigID uuid.NullUUID `db:"model_config_id" json:"model_config_id"`
APIKeyID sql.NullString `db:"api_key_id" json:"api_key_id"`
}
type ChatTable struct {
+48 -26
View File
@@ -7157,7 +7157,7 @@ func (q *sqlQuerier) GetChatDiffStatusesByChatIDs(ctx context.Context, chatIds [
const getChatMessageByID = `-- name: GetChatMessageByID :one
SELECT
id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros, runtime_ms, deleted, provider_response_id
id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros, runtime_ms, deleted, provider_response_id, api_key_id
FROM
chat_messages
WHERE
@@ -7190,6 +7190,7 @@ func (q *sqlQuerier) GetChatMessageByID(ctx context.Context, id int64) (ChatMess
&i.RuntimeMs,
&i.Deleted,
&i.ProviderResponseID,
&i.APIKeyID,
)
return i, err
}
@@ -7279,7 +7280,7 @@ func (q *sqlQuerier) GetChatMessageSummariesPerChat(ctx context.Context, created
const getChatMessagesByChatID = `-- name: GetChatMessagesByChatID :many
SELECT
id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros, runtime_ms, deleted, provider_response_id
id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros, runtime_ms, deleted, provider_response_id, api_key_id
FROM
chat_messages
WHERE
@@ -7327,6 +7328,7 @@ func (q *sqlQuerier) GetChatMessagesByChatID(ctx context.Context, arg GetChatMes
&i.RuntimeMs,
&i.Deleted,
&i.ProviderResponseID,
&i.APIKeyID,
); err != nil {
return nil, err
}
@@ -7343,7 +7345,7 @@ func (q *sqlQuerier) GetChatMessagesByChatID(ctx context.Context, arg GetChatMes
const getChatMessagesByChatIDAscPaginated = `-- name: GetChatMessagesByChatIDAscPaginated :many
SELECT
id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros, runtime_ms, deleted, provider_response_id
id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros, runtime_ms, deleted, provider_response_id, api_key_id
FROM
chat_messages
WHERE
@@ -7394,6 +7396,7 @@ func (q *sqlQuerier) GetChatMessagesByChatIDAscPaginated(ctx context.Context, ar
&i.RuntimeMs,
&i.Deleted,
&i.ProviderResponseID,
&i.APIKeyID,
); err != nil {
return nil, err
}
@@ -7410,7 +7413,7 @@ func (q *sqlQuerier) GetChatMessagesByChatIDAscPaginated(ctx context.Context, ar
const getChatMessagesByChatIDDescPaginated = `-- name: GetChatMessagesByChatIDDescPaginated :many
SELECT
id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros, runtime_ms, deleted, provider_response_id
id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros, runtime_ms, deleted, provider_response_id, api_key_id
FROM
chat_messages
WHERE
@@ -7474,6 +7477,7 @@ func (q *sqlQuerier) GetChatMessagesByChatIDDescPaginated(ctx context.Context, a
&i.RuntimeMs,
&i.Deleted,
&i.ProviderResponseID,
&i.APIKeyID,
); err != nil {
return nil, err
}
@@ -7506,7 +7510,7 @@ WITH latest_compressed_summary AS (
1
)
SELECT
id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros, runtime_ms, deleted, provider_response_id
id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros, runtime_ms, deleted, provider_response_id, api_key_id
FROM
chat_messages
WHERE
@@ -7578,6 +7582,7 @@ func (q *sqlQuerier) GetChatMessagesForPromptByChatID(ctx context.Context, chatI
&i.RuntimeMs,
&i.Deleted,
&i.ProviderResponseID,
&i.APIKeyID,
); err != nil {
return nil, err
}
@@ -7639,7 +7644,7 @@ func (q *sqlQuerier) GetChatModelConfigsForTelemetry(ctx context.Context) ([]Get
}
const getChatQueuedMessages = `-- name: GetChatQueuedMessages :many
SELECT id, chat_id, content, created_at, model_config_id FROM chat_queued_messages
SELECT id, chat_id, content, created_at, model_config_id, api_key_id FROM chat_queued_messages
WHERE chat_id = $1
ORDER BY created_at ASC, id ASC
`
@@ -7659,6 +7664,7 @@ func (q *sqlQuerier) GetChatQueuedMessages(ctx context.Context, chatID uuid.UUID
&i.Content,
&i.CreatedAt,
&i.ModelConfigID,
&i.APIKeyID,
); err != nil {
return nil, err
}
@@ -8354,7 +8360,7 @@ func (q *sqlQuerier) GetChildChatsByParentIDs(ctx context.Context, arg GetChildC
const getLastChatMessageByRole = `-- name: GetLastChatMessageByRole :one
SELECT
id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros, runtime_ms, deleted, provider_response_id
id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros, runtime_ms, deleted, provider_response_id, api_key_id
FROM
chat_messages
WHERE
@@ -8397,6 +8403,7 @@ func (q *sqlQuerier) GetLastChatMessageByRole(ctx context.Context, arg GetLastCh
&i.RuntimeMs,
&i.Deleted,
&i.ProviderResponseID,
&i.APIKeyID,
)
return i, err
}
@@ -8709,7 +8716,7 @@ WITH updated_chat AS (
SET
last_model_config_id = (
SELECT val
FROM UNNEST($3::uuid[])
FROM UNNEST($4::uuid[])
WITH ORDINALITY AS t(val, ord)
WHERE val != '00000000-0000-0000-0000-000000000000'::uuid
ORDER BY ord DESC
@@ -8719,12 +8726,12 @@ WITH updated_chat AS (
id = $1::uuid
AND EXISTS (
SELECT 1
FROM UNNEST($3::uuid[])
FROM UNNEST($4::uuid[])
WHERE unnest != '00000000-0000-0000-0000-000000000000'::uuid
)
AND chats.last_model_config_id IS DISTINCT FROM (
SELECT val
FROM UNNEST($3::uuid[])
FROM UNNEST($4::uuid[])
WITH ORDINALITY AS t(val, ord)
WHERE val != '00000000-0000-0000-0000-000000000000'::uuid
ORDER BY ord DESC
@@ -8734,6 +8741,7 @@ WITH updated_chat AS (
INSERT INTO chat_messages (
chat_id,
created_by,
api_key_id,
model_config_id,
role,
content,
@@ -8754,29 +8762,31 @@ INSERT INTO chat_messages (
SELECT
$1::uuid,
NULLIF(UNNEST($2::uuid[]), '00000000-0000-0000-0000-000000000000'::uuid),
NULLIF(UNNEST($3::uuid[]), '00000000-0000-0000-0000-000000000000'::uuid),
UNNEST($4::chat_message_role[]),
UNNEST($5::text[])::jsonb,
UNNEST($6::smallint[]),
UNNEST($7::chat_message_visibility[]),
NULLIF(UNNEST($8::bigint[]), 0),
NULLIF(UNNEST($3::text[]), ''),
NULLIF(UNNEST($4::uuid[]), '00000000-0000-0000-0000-000000000000'::uuid),
UNNEST($5::chat_message_role[]),
UNNEST($6::text[])::jsonb,
UNNEST($7::smallint[]),
UNNEST($8::chat_message_visibility[]),
NULLIF(UNNEST($9::bigint[]), 0),
NULLIF(UNNEST($10::bigint[]), 0),
NULLIF(UNNEST($11::bigint[]), 0),
NULLIF(UNNEST($12::bigint[]), 0),
NULLIF(UNNEST($13::bigint[]), 0),
NULLIF(UNNEST($14::bigint[]), 0),
UNNEST($15::boolean[]),
NULLIF(UNNEST($16::bigint[]), 0),
NULLIF(UNNEST($15::bigint[]), 0),
UNNEST($16::boolean[]),
NULLIF(UNNEST($17::bigint[]), 0),
NULLIF(UNNEST($18::text[]), '')
NULLIF(UNNEST($18::bigint[]), 0),
NULLIF(UNNEST($19::text[]), '')
RETURNING
id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros, runtime_ms, deleted, provider_response_id
id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros, runtime_ms, deleted, provider_response_id, api_key_id
`
type InsertChatMessagesParams struct {
ChatID uuid.UUID `db:"chat_id" json:"chat_id"`
CreatedBy []uuid.UUID `db:"created_by" json:"created_by"`
APIKeyID []string `db:"api_key_id" json:"api_key_id"`
ModelConfigID []uuid.UUID `db:"model_config_id" json:"model_config_id"`
Role []ChatMessageRole `db:"role" json:"role"`
Content []string `db:"content" json:"content"`
@@ -8799,6 +8809,7 @@ func (q *sqlQuerier) InsertChatMessages(ctx context.Context, arg InsertChatMessa
rows, err := q.db.QueryContext(ctx, insertChatMessages,
arg.ChatID,
pq.Array(arg.CreatedBy),
pq.Array(arg.APIKeyID),
pq.Array(arg.ModelConfigID),
pq.Array(arg.Role),
pq.Array(arg.Content),
@@ -8845,6 +8856,7 @@ func (q *sqlQuerier) InsertChatMessages(ctx context.Context, arg InsertChatMessa
&i.RuntimeMs,
&i.Deleted,
&i.ProviderResponseID,
&i.APIKeyID,
); err != nil {
return nil, err
}
@@ -8860,23 +8872,30 @@ func (q *sqlQuerier) InsertChatMessages(ctx context.Context, arg InsertChatMessa
}
const insertChatQueuedMessage = `-- name: InsertChatQueuedMessage :one
INSERT INTO chat_queued_messages (chat_id, content, model_config_id)
INSERT INTO chat_queued_messages (chat_id, content, model_config_id, api_key_id)
VALUES (
$1,
$2,
$3::uuid
$3::uuid,
$4::text
)
RETURNING id, chat_id, content, created_at, model_config_id
RETURNING id, chat_id, content, created_at, model_config_id, api_key_id
`
type InsertChatQueuedMessageParams struct {
ChatID uuid.UUID `db:"chat_id" json:"chat_id"`
Content json.RawMessage `db:"content" json:"content"`
ModelConfigID uuid.NullUUID `db:"model_config_id" json:"model_config_id"`
APIKeyID sql.NullString `db:"api_key_id" json:"api_key_id"`
}
func (q *sqlQuerier) InsertChatQueuedMessage(ctx context.Context, arg InsertChatQueuedMessageParams) (ChatQueuedMessage, error) {
row := q.db.QueryRowContext(ctx, insertChatQueuedMessage, arg.ChatID, arg.Content, arg.ModelConfigID)
row := q.db.QueryRowContext(ctx, insertChatQueuedMessage,
arg.ChatID,
arg.Content,
arg.ModelConfigID,
arg.APIKeyID,
)
var i ChatQueuedMessage
err := row.Scan(
&i.ID,
@@ -8884,6 +8903,7 @@ func (q *sqlQuerier) InsertChatQueuedMessage(ctx context.Context, arg InsertChat
&i.Content,
&i.CreatedAt,
&i.ModelConfigID,
&i.APIKeyID,
)
return i, err
}
@@ -9108,7 +9128,7 @@ WHERE id = (
ORDER BY cqm.created_at ASC, cqm.id ASC
LIMIT 1
)
RETURNING id, chat_id, content, created_at, model_config_id
RETURNING id, chat_id, content, created_at, model_config_id, api_key_id
`
func (q *sqlQuerier) PopNextQueuedMessage(ctx context.Context, chatID uuid.UUID) (ChatQueuedMessage, error) {
@@ -9120,6 +9140,7 @@ func (q *sqlQuerier) PopNextQueuedMessage(ctx context.Context, chatID uuid.UUID)
&i.Content,
&i.CreatedAt,
&i.ModelConfigID,
&i.APIKeyID,
)
return i, err
}
@@ -10145,7 +10166,7 @@ SET
WHERE
id = $3::bigint
RETURNING
id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros, runtime_ms, deleted, provider_response_id
id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros, runtime_ms, deleted, provider_response_id, api_key_id
`
type UpdateChatMessageByIDParams struct {
@@ -10179,6 +10200,7 @@ func (q *sqlQuerier) UpdateChatMessageByID(ctx context.Context, arg UpdateChatMe
&i.RuntimeMs,
&i.Deleted,
&i.ProviderResponseID,
&i.APIKeyID,
)
return i, err
}
+5 -2
View File
@@ -763,6 +763,7 @@ WITH updated_chat AS (
INSERT INTO chat_messages (
chat_id,
created_by,
api_key_id,
model_config_id,
role,
content,
@@ -783,6 +784,7 @@ INSERT INTO chat_messages (
SELECT
@chat_id::uuid,
NULLIF(UNNEST(@created_by::uuid[]), '00000000-0000-0000-0000-000000000000'::uuid),
NULLIF(UNNEST(@api_key_id::text[]), ''),
NULLIF(UNNEST(@model_config_id::uuid[]), '00000000-0000-0000-0000-000000000000'::uuid),
UNNEST(@role::chat_message_role[]),
UNNEST(@content::text[])::jsonb,
@@ -1688,11 +1690,12 @@ RETURNING
*;
-- name: InsertChatQueuedMessage :one
INSERT INTO chat_queued_messages (chat_id, content, model_config_id)
INSERT INTO chat_queued_messages (chat_id, content, model_config_id, api_key_id)
VALUES (
@chat_id,
@content,
sqlc.narg('model_config_id')::uuid
sqlc.narg('model_config_id')::uuid,
sqlc.narg('api_key_id')::text
)
RETURNING *;
+3
View File
@@ -1220,6 +1220,7 @@ func (api *API) postChats(rw http.ResponseWriter, r *http.Request) {
ClientType: clientType,
SystemPrompt: req.SystemPrompt,
InitialUserContent: contentBlocks,
APIKeyID: apiKey.ID,
MCPServerIDs: mcpServerIDs,
Labels: labels,
DynamicTools: dynamicToolsJSON,
@@ -3089,6 +3090,7 @@ func (api *API) postChatMessages(rw http.ResponseWriter, r *http.Request) {
CreatedBy: apiKey.UserID,
Content: contentBlocks,
ModelConfigID: modelConfigID,
APIKeyID: apiKey.ID,
BusyBehavior: busyBehavior,
PlanMode: sendPlanMode,
MCPServerIDs: req.MCPServerIDs,
@@ -3229,6 +3231,7 @@ func (api *API) patchChatMessage(rw http.ResponseWriter, r *http.Request) {
CreatedBy: apiKey.UserID,
EditedMessageID: messageID,
Content: contentBlocks,
APIKeyID: apiKey.ID,
ModelConfigID: editModelConfigID,
})
if editErr != nil {
+222 -10
View File
@@ -13,6 +13,7 @@ import (
"golang.org/x/xerrors"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/coderd/aibridge"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/x/chatd/chatadvisor"
"github.com/coder/coder/v2/coderd/x/chatd/chatprovider"
@@ -30,7 +31,11 @@ import (
// advisorOverrideStubStore embeds database.Store and adds optional function
// fields that a test can set to stub individual store queries. The AI provider
// methods defined on this type call the matching field when it is set and
// return an "unexpected ... call" error when it is nil.
type advisorOverrideStubStore struct {
database.Store
getEnabledChatModelConfigByID func(context.Context, uuid.UUID) (database.ChatModelConfig, error)
getEnabledChatModelConfigByID func(context.Context, uuid.UUID) (database.ChatModelConfig, error)
getAIProviderByID func(context.Context, uuid.UUID) (database.AIProvider, error)
getAIProviders func(context.Context, database.GetAIProvidersParams) ([]database.AIProvider, error)
getAIProviderKeysByProviderID func(context.Context, uuid.UUID) ([]database.AIProviderKey, error)
getAIProviderKeysByProviderIDs func(context.Context, []uuid.UUID) ([]database.AIProviderKey, error)
}
func (s *advisorOverrideStubStore) GetEnabledChatModelConfigByID(
@@ -43,6 +48,46 @@ func (s *advisorOverrideStubStore) GetEnabledChatModelConfigByID(
return s.getEnabledChatModelConfigByID(ctx, id)
}
// GetAIProviderByID forwards to the getAIProviderByID stub when one is set.
// Without a stub it returns a zero AIProvider and an error reporting the call
// as unexpected.
func (s *advisorOverrideStubStore) GetAIProviderByID(
	ctx context.Context,
	id uuid.UUID,
) (database.AIProvider, error) {
	if s.getAIProviderByID != nil {
		return s.getAIProviderByID(ctx, id)
	}
	return database.AIProvider{}, xerrors.New("unexpected GetAIProviderByID call")
}
// GetAIProviders forwards to the getAIProviders stub when one is set. Without
// a stub it returns nil and an error reporting the call as unexpected.
func (s *advisorOverrideStubStore) GetAIProviders(
	ctx context.Context,
	params database.GetAIProvidersParams,
) ([]database.AIProvider, error) {
	if s.getAIProviders != nil {
		return s.getAIProviders(ctx, params)
	}
	return nil, xerrors.New("unexpected GetAIProviders call")
}
// GetAIProviderKeysByProviderID forwards to the getAIProviderKeysByProviderID
// stub when one is set. Without a stub it returns nil and an error reporting
// the call as unexpected.
func (s *advisorOverrideStubStore) GetAIProviderKeysByProviderID(
	ctx context.Context,
	providerID uuid.UUID,
) ([]database.AIProviderKey, error) {
	if s.getAIProviderKeysByProviderID != nil {
		return s.getAIProviderKeysByProviderID(ctx, providerID)
	}
	return nil, xerrors.New("unexpected GetAIProviderKeysByProviderID call")
}
// GetAIProviderKeysByProviderIDs forwards to the
// getAIProviderKeysByProviderIDs stub when one is set. Without a stub it
// returns nil and an error reporting the call as unexpected.
func (s *advisorOverrideStubStore) GetAIProviderKeysByProviderIDs(
	ctx context.Context,
	providerIDs []uuid.UUID,
) ([]database.AIProviderKey, error) {
	if s.getAIProviderKeysByProviderIDs != nil {
		return s.getAIProviderKeysByProviderIDs(ctx, providerIDs)
	}
	return nil, xerrors.New("unexpected GetAIProviderKeysByProviderIDs call")
}
func newAdvisorTestServer(
ctx context.Context,
t *testing.T,
@@ -56,6 +101,60 @@ func newAdvisorTestServer(
}
}
// resolveAdvisorModelOverrideOrFallback calls resolveAdvisorModelOverride with
// the given arguments and hides its error from the caller. When resolution
// succeeds it returns the resolved model and call config. When resolution
// fails it logs a warning and returns fallbackModel and fallbackCallConfig.
func (p *Server) resolveAdvisorModelOverrideOrFallback(
	ctx context.Context,
	chat database.Chat,
	advisorCfg codersdk.AdvisorConfig,
	fallbackModel fantasy.LanguageModel,
	fallbackCallConfig codersdk.ChatModelCallConfig,
	providerKeys chatprovider.ProviderAPIKeys,
	modelOpts modelBuildOptions,
	logger slog.Logger,
) (fantasy.LanguageModel, codersdk.ChatModelCallConfig) {
	resolvedModel, resolvedCfg, resolveErr := p.resolveAdvisorModelOverride(
		ctx, chat, advisorCfg, fallbackModel, fallbackCallConfig, providerKeys, modelOpts, logger,
	)
	if resolveErr == nil {
		return resolvedModel, resolvedCfg
	}
	// Resolution failed: warn and keep using the chat model.
	logger.Warn(ctx, "failed to resolve advisor model override, continuing with chat model", slog.Error(resolveErr))
	return fallbackModel, fallbackCallConfig
}
// newAdvisorRuntimeOrFallback calls newAdvisorRuntime with the given arguments
// and hides its error from the caller. When construction succeeds it returns
// the runtime. When construction fails it logs a warning and returns nil.
func (p *Server) newAdvisorRuntimeOrFallback(
	ctx context.Context,
	chat database.Chat,
	advisorCfg codersdk.AdvisorConfig,
	fallbackModel fantasy.LanguageModel,
	fallbackCallConfig codersdk.ChatModelCallConfig,
	providerKeys chatprovider.ProviderAPIKeys,
	modelOpts modelBuildOptions,
	logger slog.Logger,
) *chatadvisor.Runtime {
	runtime, buildErr := p.newAdvisorRuntime(
		ctx, chat, advisorCfg, fallbackModel, fallbackCallConfig, providerKeys, modelOpts, logger,
	)
	if buildErr == nil {
		return runtime
	}
	// Construction failed: warn and signal "no advisor" with nil.
	logger.Warn(ctx, "failed to create advisor runtime, continuing without advisor", slog.Error(buildErr))
	return nil
}
// TestResolveAdvisorModelOverride covers the early-return, each fallback
// branch, and the success path. Prior tests only hit the ModelConfigID ==
// uuid.Nil early return, so the override body never executed.
@@ -73,13 +172,14 @@ func TestResolveAdvisorModelOverride(t *testing.T) {
store := &advisorOverrideStubStore{}
p := newAdvisorTestServer(ctx, t, store)
gotModel, gotCfg := p.resolveAdvisorModelOverride(
gotModel, gotCfg := p.resolveAdvisorModelOverrideOrFallback(
ctx,
database.Chat{},
codersdk.AdvisorConfig{},
fallbackModel,
fallbackCallConfig,
chatprovider.ProviderAPIKeys{},
modelBuildOptions{},
logger,
)
require.Equal(t, fallbackModel, gotModel)
@@ -96,13 +196,14 @@ func TestResolveAdvisorModelOverride(t *testing.T) {
}
p := newAdvisorTestServer(ctx, t, store)
gotModel, gotCfg := p.resolveAdvisorModelOverride(
gotModel, gotCfg := p.resolveAdvisorModelOverrideOrFallback(
ctx,
database.Chat{},
codersdk.AdvisorConfig{ModelConfigID: uuid.New()},
fallbackModel,
fallbackCallConfig,
chatprovider.ProviderAPIKeys{OpenAI: "sk-test"},
modelBuildOptions{},
logger,
)
require.Equal(t, fallbackModel, gotModel)
@@ -125,13 +226,14 @@ func TestResolveAdvisorModelOverride(t *testing.T) {
}
p := newAdvisorTestServer(ctx, t, store)
gotModel, gotCfg := p.resolveAdvisorModelOverride(
gotModel, gotCfg := p.resolveAdvisorModelOverrideOrFallback(
ctx,
database.Chat{},
codersdk.AdvisorConfig{ModelConfigID: uuid.New()},
fallbackModel,
fallbackCallConfig,
chatprovider.ProviderAPIKeys{OpenAI: "sk-test"},
modelBuildOptions{},
logger,
)
require.Equal(t, fallbackModel, gotModel)
@@ -158,13 +260,14 @@ func TestResolveAdvisorModelOverride(t *testing.T) {
}
p := newAdvisorTestServer(ctx, t, store)
gotModel, gotCfg := p.resolveAdvisorModelOverride(
gotModel, gotCfg := p.resolveAdvisorModelOverrideOrFallback(
ctx,
database.Chat{},
codersdk.AdvisorConfig{ModelConfigID: configID},
fallbackModel,
fallbackCallConfig,
chatprovider.ProviderAPIKeys{OpenAI: "sk-test"},
modelBuildOptions{},
logger,
)
require.Equal(t, fallbackModel, gotModel)
@@ -175,6 +278,7 @@ func TestResolveAdvisorModelOverride(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
configID := uuid.New()
providerID := uuid.New()
store := &advisorOverrideStubStore{
getEnabledChatModelConfigByID: func(context.Context, uuid.UUID) (database.ChatModelConfig, error) {
return database.ChatModelConfig{
@@ -187,16 +291,27 @@ func TestResolveAdvisorModelOverride(t *testing.T) {
DisplayName: "gpt-5.2",
}, nil
},
getAIProviders: func(context.Context, database.GetAIProvidersParams) ([]database.AIProvider, error) {
return []database.AIProvider{{
ID: providerID,
Type: database.AiProviderTypeOpenai,
Enabled: true,
}}, nil
},
getAIProviderKeysByProviderIDs: func(context.Context, []uuid.UUID) ([]database.AIProviderKey, error) {
return nil, nil
},
}
p := newAdvisorTestServer(ctx, t, store)
gotModel, gotCfg := p.resolveAdvisorModelOverride(
gotModel, gotCfg := p.resolveAdvisorModelOverrideOrFallback(
ctx,
database.Chat{},
codersdk.AdvisorConfig{ModelConfigID: configID},
fallbackModel,
fallbackCallConfig,
chatprovider.ProviderAPIKeys{},
modelBuildOptions{},
logger,
)
require.Equal(t, fallbackModel, gotModel)
@@ -227,13 +342,14 @@ func TestResolveAdvisorModelOverride(t *testing.T) {
}
p := newAdvisorTestServer(ctx, t, store)
gotModel, gotCfg := p.resolveAdvisorModelOverride(
gotModel, gotCfg := p.resolveAdvisorModelOverrideOrFallback(
ctx,
database.Chat{},
codersdk.AdvisorConfig{ModelConfigID: configID},
fallbackModel,
fallbackCallConfig,
chatprovider.ProviderAPIKeys{OpenAI: "sk-test"},
modelBuildOptions{},
logger,
)
require.NotEqual(t, fantasy.LanguageModel(fallbackModel), gotModel,
@@ -247,6 +363,99 @@ func TestResolveAdvisorModelOverride(t *testing.T) {
require.NotNil(t, gotCfg.Temperature)
require.InDelta(t, 0.42, *gotCfg.Temperature, 1e-9)
})
t.Run("AIProviderIDResolvesOverrideProviderKeys", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitShort)
configID := uuid.New()
providerID := uuid.New()
store := &advisorOverrideStubStore{
getEnabledChatModelConfigByID: func(context.Context, uuid.UUID) (database.ChatModelConfig, error) {
return database.ChatModelConfig{
ID: configID,
Provider: "openai",
Model: "gpt-5.2",
Enabled: true,
CreatedAt: time.Unix(0, 0).UTC(),
UpdatedAt: time.Unix(0, 0).UTC(),
DisplayName: "gpt-5.2",
AIProviderID: uuid.NullUUID{UUID: providerID, Valid: true},
}, nil
},
getAIProviderByID: func(context.Context, uuid.UUID) (database.AIProvider, error) {
return database.AIProvider{
ID: providerID,
Type: database.AiProviderTypeOpenai,
Enabled: true,
}, nil
},
getAIProviderKeysByProviderID: func(context.Context, uuid.UUID) ([]database.AIProviderKey, error) {
return []database.AIProviderKey{{
ProviderID: providerID,
APIKey: "sk-selected",
}}, nil
},
}
p := newAdvisorTestServer(ctx, t, store)
gotModel, gotCfg := p.resolveAdvisorModelOverrideOrFallback(
ctx,
database.Chat{},
codersdk.AdvisorConfig{ModelConfigID: configID},
fallbackModel,
fallbackCallConfig,
chatprovider.ProviderAPIKeys{},
modelBuildOptions{},
logger,
)
require.NotEqual(t, fantasy.LanguageModel(fallbackModel), gotModel)
require.NotNil(t, gotModel)
require.Equal(t, "openai", gotModel.Provider())
require.Equal(t, "gpt-5.2", gotModel.Model())
require.Equal(t, fallbackCallConfig, gotCfg)
})
}
// TestResolveAdvisorModelOverridePromotesAIBridgeErrors checks that, with AI
// Gateway routing enabled and an override config tied to a concrete AI
// provider row, resolveAdvisorModelOverride returns an error and a nil model
// instead of quietly handing back the fallback chat model.
func TestResolveAdvisorModelOverridePromotesAIBridgeErrors(t *testing.T) {
	t.Parallel()

	ctx := testutil.Context(t, testutil.WaitShort)
	var (
		configID   = uuid.New()
		providerID = uuid.New()
	)

	// Every lookup succeeds, so any failure must originate from the model
	// build step rather than from config or provider resolution.
	stubStore := &advisorOverrideStubStore{
		getEnabledChatModelConfigByID: func(context.Context, uuid.UUID) (database.ChatModelConfig, error) {
			return database.ChatModelConfig{
				ID:           configID,
				Provider:     "openai",
				Model:        "gpt-5.2",
				Enabled:      true,
				DisplayName:  "gpt-5.2",
				AIProviderID: uuid.NullUUID{UUID: providerID, Valid: true},
			}, nil
		},
		getAIProviderByID: func(context.Context, uuid.UUID) (database.AIProvider, error) {
			return database.AIProvider{
				ID:      providerID,
				Type:    database.AiProviderTypeOpenai,
				Name:    "primary-openai",
				Enabled: true,
			}, nil
		},
		getAIProviderKeysByProviderID: func(context.Context, uuid.UUID) ([]database.AIProviderKey, error) {
			return []database.AIProviderKey{{
				ProviderID: providerID,
				APIKey:     "sk-selected",
			}}, nil
		},
	}

	server := newAdvisorTestServer(ctx, t, stubStore)
	server.aiGatewayRoutingEnabled = true

	// NOTE(review): the delegated key ID stored on the context and the
	// ActiveAPIKeyID passed below are two independent random values.
	// Presumably production derives both from the same message; confirm the
	// mismatch is irrelevant to what this test asserts.
	ctx = aibridge.WithDelegatedAPIKeyID(ctx, uuid.NewString())

	chat := database.Chat{ID: uuid.New(), OwnerID: uuid.New()}
	advisorCfg := codersdk.AdvisorConfig{ModelConfigID: configID}
	fallback := &chattest.FakeModel{ProviderName: "stub", ModelName: "stub"}
	buildOpts := modelBuildOptions{ActiveAPIKeyID: uuid.NewString()}

	gotModel, _, err := server.resolveAdvisorModelOverride(
		ctx,
		chat,
		advisorCfg,
		fallback,
		codersdk.ChatModelCallConfig{},
		chatprovider.ProviderAPIKeys{},
		buildOpts,
		slog.Make(),
	)

	// Presumably the test server has no transport factory installed, which
	// is what surfaces this particular error text.
	require.ErrorContains(t, err, "AI Gateway transport factory")
	require.Nil(t, gotModel)
}
// TestStripAdvisorGuidanceBlock exercises the filter that keeps the advisor
@@ -346,7 +555,7 @@ func TestNewAdvisorRuntime(t *testing.T) {
store := &advisorOverrideStubStore{}
p := newAdvisorTestServer(ctx, t, store)
rt := p.newAdvisorRuntime(
rt := p.newAdvisorRuntimeOrFallback(
ctx,
database.Chat{},
codersdk.AdvisorConfig{
@@ -357,6 +566,7 @@ func TestNewAdvisorRuntime(t *testing.T) {
fallbackModel,
fallbackCallConfig,
chatprovider.ProviderAPIKeys{},
modelBuildOptions{},
logger,
)
require.NotNil(t, rt, "zero max uses must default rather than bail out")
@@ -370,7 +580,7 @@ func TestNewAdvisorRuntime(t *testing.T) {
store := &advisorOverrideStubStore{}
p := newAdvisorTestServer(ctx, t, store)
rt := p.newAdvisorRuntime(
rt := p.newAdvisorRuntimeOrFallback(
ctx,
database.Chat{},
codersdk.AdvisorConfig{
@@ -381,6 +591,7 @@ func TestNewAdvisorRuntime(t *testing.T) {
fallbackModel,
fallbackCallConfig,
chatprovider.ProviderAPIKeys{},
modelBuildOptions{},
logger,
)
require.Nil(t, rt, "negative max uses must disable the advisor")
@@ -392,7 +603,7 @@ func TestNewAdvisorRuntime(t *testing.T) {
store := &advisorOverrideStubStore{}
p := newAdvisorTestServer(ctx, t, store)
rt := p.newAdvisorRuntime(
rt := p.newAdvisorRuntimeOrFallback(
ctx,
database.Chat{},
codersdk.AdvisorConfig{
@@ -403,6 +614,7 @@ func TestNewAdvisorRuntime(t *testing.T) {
fallbackModel,
fallbackCallConfig,
chatprovider.ProviderAPIKeys{},
modelBuildOptions{},
logger,
)
require.NotNil(t, rt,
+239 -156
View File
@@ -29,6 +29,7 @@ import (
"golang.org/x/xerrors"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/coderd/aibridge"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/db2sdk"
"github.com/coder/coder/v2/coderd/database/dbauthz"
@@ -254,6 +255,9 @@ type Server struct {
metrics *chatloop.Metrics
recordingSem chan struct{}
aibridgeTransportFactory *atomic.Pointer[aibridge.TransportFactory]
aiGatewayRoutingEnabled bool
// Configuration
pendingChatAcquireInterval time.Duration
maxChatsPerAcquire int32
@@ -344,10 +348,11 @@ func (p *Server) resolveAdvisorModelOverride(
fallbackModel fantasy.LanguageModel,
fallbackCallConfig codersdk.ChatModelCallConfig,
providerKeys chatprovider.ProviderAPIKeys,
modelOpts modelBuildOptions,
logger slog.Logger,
) (fantasy.LanguageModel, codersdk.ChatModelCallConfig) {
) (fantasy.LanguageModel, codersdk.ChatModelCallConfig, error) {
if advisorCfg.ModelConfigID == uuid.Nil {
return fallbackModel, fallbackCallConfig
return fallbackModel, fallbackCallConfig, nil
}
// Re-read the override instead of using the cache so disabled models
@@ -363,7 +368,7 @@ func (p *Server) resolveAdvisorModelOverride(
"advisor model config is disabled or unavailable, continuing with chat model",
slog.F("model_config_id", advisorCfg.ModelConfigID),
)
return fallbackModel, fallbackCallConfig
return fallbackModel, fallbackCallConfig, nil
}
logger.Warn(
ctx,
@@ -371,7 +376,7 @@ func (p *Server) resolveAdvisorModelOverride(
slog.F("model_config_id", advisorCfg.ModelConfigID),
slog.Error(err),
)
return fallbackModel, fallbackCallConfig
return fallbackModel, fallbackCallConfig, nil
}
overrideCallConfig := codersdk.ChatModelCallConfig{}
@@ -383,29 +388,48 @@ func (p *Server) resolveAdvisorModelOverride(
slog.F("model_config_id", advisorCfg.ModelConfigID),
slog.Error(err),
)
return fallbackModel, fallbackCallConfig
return fallbackModel, fallbackCallConfig, nil
}
}
overrideModel, err := chatprovider.ModelFromConfig(
overrideConfig.Provider,
overrideConfig.Model,
route, err := p.resolveModelRouteForConfig(
ctx,
chat.OwnerID,
overrideConfig,
providerKeys,
chatprovider.UserAgent(),
chatprovider.CoderHeaders(chat),
nil,
)
if err != nil {
if p.shouldUseAIGatewayRouting() && overrideConfig.AIProviderID.Valid {
return nil, codersdk.ChatModelCallConfig{}, xerrors.Errorf("resolve advisor override route: %w", err)
}
logger.Warn(
ctx,
"failed to resolve advisor override route, continuing with chat model",
slog.F("model_config_id", advisorCfg.ModelConfigID),
slog.Error(err),
)
return fallbackModel, fallbackCallConfig, nil
}
overrideModel, err := p.newModel(ctx, modelClientRequest{
Chat: chat,
ModelName: overrideConfig.Model,
UserAgent: chatprovider.UserAgent(),
ExtraHeaders: chatprovider.CoderHeaders(chat),
}, route, modelOpts)
if err != nil {
if p.shouldUseAIGatewayRouting() && overrideConfig.AIProviderID.Valid {
return nil, codersdk.ChatModelCallConfig{}, xerrors.Errorf("create advisor override model: %w", err)
}
logger.Warn(
ctx,
"failed to create advisor override model, continuing with chat model",
slog.F("model_config_id", advisorCfg.ModelConfigID),
slog.Error(err),
)
return fallbackModel, fallbackCallConfig
return fallbackModel, fallbackCallConfig, nil
}
return overrideModel, overrideCallConfig
return overrideModel, overrideCallConfig, nil
}
func (p *Server) newAdvisorRuntime(
@@ -415,17 +439,22 @@ func (p *Server) newAdvisorRuntime(
fallbackModel fantasy.LanguageModel,
fallbackCallConfig codersdk.ChatModelCallConfig,
providerKeys chatprovider.ProviderAPIKeys,
modelOpts modelBuildOptions,
logger slog.Logger,
) *chatadvisor.Runtime {
advisorModel, advisorCallConfig := p.resolveAdvisorModelOverride(
) (*chatadvisor.Runtime, error) {
advisorModel, advisorCallConfig, err := p.resolveAdvisorModelOverride(
ctx,
chat,
advisorCfg,
fallbackModel,
fallbackCallConfig,
providerKeys,
modelOpts,
logger,
)
if err != nil {
return nil, err
}
maxUsesPerRun := advisorCfg.MaxUsesPerRun
switch {
@@ -441,7 +470,7 @@ func (p *Server) newAdvisorRuntime(
"invalid advisor max uses per run, continuing without advisor",
slog.F("max_uses_per_run", maxUsesPerRun),
)
return nil
return nil, nil //nolint:nilnil // Nil runtime with nil error means advisor is skipped for this turn.
}
maxOutputTokens := advisorCfg.MaxOutputTokens
@@ -468,9 +497,9 @@ func (p *Server) newAdvisorRuntime(
"failed to create advisor runtime, continuing without advisor",
slog.Error(err),
)
return nil
return nil, nil //nolint:nilnil // Nil runtime with nil error means advisor is skipped for this turn.
}
return rt
return rt, nil
}
// cachedWorkspaceMCPTools stores workspace MCP tools discovered
@@ -1436,6 +1465,7 @@ type CreateOptions struct {
ClientType database.ChatClientType
SystemPrompt string
InitialUserContent []codersdk.ChatMessagePart
APIKeyID string
MCPServerIDs []uuid.UUID
Labels database.StringMap
DynamicTools json.RawMessage
@@ -1460,6 +1490,7 @@ type SendMessageOptions struct {
CreatedBy uuid.UUID
Content []codersdk.ChatMessagePart
ModelConfigID uuid.UUID
APIKeyID string
BusyBehavior SendMessageBusyBehavior
PlanMode *database.NullChatPlanMode
MCPServerIDs *[]uuid.UUID
@@ -1479,6 +1510,7 @@ type EditMessageOptions struct {
CreatedBy uuid.UUID
EditedMessageID int64
Content []codersdk.ChatMessagePart
APIKeyID string
// ModelConfigID, when non-zero, overrides the model used for
// the replacement user message. When set to uuid.Nil the
// original message's model is preserved.
@@ -1647,7 +1679,7 @@ func (p *Server) CreateChat(ctx context.Context, opts CreateOptions) (database.C
database.ChatMessageVisibilityBoth,
opts.ModelConfigID,
chatprompt.CurrentContentVersion,
).withCreatedBy(opts.OwnerID))
).withCreatedBy(opts.OwnerID).withAPIKeyID(opts.APIKeyID))
_, err = tx.InsertChatMessages(ctx, msgParams)
if err != nil {
@@ -1786,6 +1818,10 @@ func (p *Server) SendMessage(
UUID: modelConfigID,
Valid: modelConfigID != uuid.Nil,
},
APIKeyID: sql.NullString{
String: opts.APIKeyID,
Valid: opts.APIKeyID != "",
},
})
if err != nil {
return xerrors.Errorf("insert queued message: %w", err)
@@ -1810,6 +1846,7 @@ func (p *Server) SendMessage(
modelConfigID,
content,
opts.CreatedBy,
opts.APIKeyID,
)
if err != nil {
return err
@@ -2083,7 +2120,7 @@ func (p *Server) EditMessage(
editedMsg.Visibility,
messageModelConfigID,
chatprompt.CurrentContentVersion,
).withCreatedBy(opts.CreatedBy))
).withCreatedBy(opts.CreatedBy).withAPIKeyID(opts.APIKeyID))
newMessages, err := insertChatMessageWithStore(ctx, tx, msgParams)
if err != nil {
return xerrors.Errorf("insert replacement message: %w", err)
@@ -2416,12 +2453,14 @@ func (p *Server) PromoteQueued(
var (
targetContent json.RawMessage
targetModelConfigID uuid.NullUUID
targetAPIKeyID sql.NullString
found bool
)
for _, qm := range queuedMessages {
if qm.ID == opts.QueuedMessageID {
targetContent = qm.Content
targetModelConfigID = qm.ModelConfigID
targetAPIKeyID = qm.APIKeyID
found = true
break
}
@@ -2511,6 +2550,7 @@ func (p *Server) PromoteQueued(
Valid: len(targetContent) > 0,
},
opts.CreatedBy,
targetAPIKeyID.String,
)
if err != nil {
return err
@@ -2754,6 +2794,7 @@ func (p *Server) SubmitToolResults(
params := database.InsertChatMessagesParams{
ChatID: opts.ChatID,
CreatedBy: make([]uuid.UUID, n),
APIKeyID: make([]string, n),
ModelConfigID: make([]uuid.UUID, n),
Role: make([]database.ChatMessageRole, n),
Content: make([]string, n),
@@ -2862,16 +2903,18 @@ var ErrManualTitleRegenerationInProgress = xerrors.New(
)
type manualTitleCandidateResult struct {
title string
modelConfig database.ChatModelConfig
usage fantasy.Usage
hasMessages bool
title string
modelConfig database.ChatModelConfig
usage fantasy.Usage
activeAPIKeyID string
hasMessages bool
}
type manualTitleGenerationError struct {
cause error
modelConfig database.ChatModelConfig
usage fantasy.Usage
cause error
modelConfig database.ChatModelConfig
usage fantasy.Usage
activeAPIKeyID string
}
func (e *manualTitleGenerationError) Error() string {
@@ -3105,6 +3148,7 @@ func (p *Server) recordManualTitleGenerationFailure(
chat,
generationErr.modelConfig,
generationErr.usage,
generationErr.activeAPIKeyID,
"",
); recordErr != nil {
return errors.Join(
@@ -3154,11 +3198,13 @@ func (p *Server) generateManualTitleCandidate(
if len(messages) == 0 {
return manualTitleCandidateResult{}, nil
}
modelOpts := modelBuildOptionsFromMessages(messages)
model, modelConfig, modelKeys, err := p.resolveManualTitleModel(ctx, store, chat, keys)
model, modelConfig, modelKeys, err := p.resolveManualTitleModel(ctx, store, chat, keys, modelOpts)
result := manualTitleCandidateResult{
modelConfig: modelConfig,
hasMessages: true,
modelConfig: modelConfig,
activeAPIKeyID: modelOpts.ActiveAPIKeyID,
hasMessages: true,
}
if err != nil {
return result, err
@@ -3174,6 +3220,7 @@ func (p *Server) generateManualTitleCandidate(
chat,
modelConfig,
modelKeys,
modelOpts,
messages,
model,
)
@@ -3189,9 +3236,10 @@ func (p *Server) generateManualTitleCandidate(
return result, wrappedErr
}
return result, &manualTitleGenerationError{
cause: wrappedErr,
modelConfig: modelConfig,
usage: usage,
cause: wrappedErr,
modelConfig: modelConfig,
usage: usage,
activeAPIKeyID: modelOpts.ActiveAPIKeyID,
}
}
@@ -3220,6 +3268,7 @@ func (p *Server) proposeChatTitleWithStore(
chat,
result.modelConfig,
result.usage,
result.activeAPIKeyID,
"",
); recordErr != nil {
return "", xerrors.Errorf("record manual title usage: %w", recordErr)
@@ -3250,6 +3299,7 @@ func (p *Server) regenerateChatTitleWithStore(
chat,
result.modelConfig,
result.usage,
result.activeAPIKeyID,
result.title,
)
if recordErr != nil {
@@ -3272,6 +3322,7 @@ func (p *Server) prepareManualTitleDebugRun(
chat database.Chat,
modelConfig database.ChatModelConfig,
keys chatprovider.ProviderAPIKeys,
modelOpts modelBuildOptions,
messages []database.ChatMessage,
fallbackModel fantasy.LanguageModel,
) (context.Context, fantasy.LanguageModel, func(error)) {
@@ -3279,15 +3330,21 @@ func (p *Server) prepareManualTitleDebugRun(
titleModel := fallbackModel
finishDebugRun := func(error) {}
httpClient := &http.Client{Transport: &chatdebug.RecordingTransport{}}
debugModel, debugModelErr := chatprovider.ModelFromConfig(
modelConfig.Provider,
modelConfig.Model,
keys,
chatprovider.UserAgent(),
chatprovider.CoderHeaders(chat),
httpClient,
)
route, routeErr := p.resolveModelRouteForConfig(ctx, chat.OwnerID, modelConfig, keys)
debugOpts := modelOpts
debugOpts.RecordHTTP = true
var debugModelErr error
var debugModel fantasy.LanguageModel
if routeErr != nil {
debugModelErr = routeErr
} else {
debugModel, debugModelErr = p.newModel(ctx, modelClientRequest{
Chat: chat,
ModelName: modelConfig.Model,
UserAgent: chatprovider.UserAgent(),
ExtraHeaders: chatprovider.CoderHeaders(chat),
}, route, debugOpts)
}
switch {
case debugModelErr != nil:
p.logger.Warn(ctx, "failed to create debug-aware manual title model",
@@ -3535,11 +3592,13 @@ func (p *Server) resolveManualTitleModel(
store database.Store,
chat database.Chat,
keys chatprovider.ProviderAPIKeys,
modelOpts modelBuildOptions,
) (fantasy.LanguageModel, database.ChatModelConfig, chatprovider.ProviderAPIKeys, error) {
overrideConfig, overrideModel, overrideKeys, overrideSet, overrideErr := p.resolveTitleGenerationModelOverride(
overrideConfig, overrideModel, overrideKeys, _, overrideSet, overrideErr := p.resolveTitleGenerationModelOverride(
ctx,
chat,
keys,
modelOpts,
)
if overrideErr != nil {
if overrideSet {
@@ -3562,15 +3621,15 @@ func (p *Server) resolveManualTitleModel(
slog.F("chat_id", chat.ID),
slog.Error(err),
)
return p.resolveFallbackManualTitleModel(ctx, chat, keys)
return p.resolveFallbackManualTitleModel(ctx, chat, keys, modelOpts)
}
config, ok := selectPreferredConfiguredShortTextModelConfig(configs)
if !ok {
return p.resolveFallbackManualTitleModel(ctx, chat, keys)
return p.resolveFallbackManualTitleModel(ctx, chat, keys, modelOpts)
}
providerHint, modelKeys, err := p.resolveModelConfigProviderHintAndKeys(ctx, chat.OwnerID, config, keys)
route, err := p.resolveModelRouteForConfig(ctx, chat.OwnerID, config, keys)
if err != nil {
p.logger.Debug(ctx, "manual title preferred model unavailable",
slog.F("chat_id", chat.ID),
@@ -3578,33 +3637,32 @@ func (p *Server) resolveManualTitleModel(
slog.F("model", config.Model),
slog.Error(err),
)
return p.resolveFallbackManualTitleModel(ctx, chat, keys)
return p.resolveFallbackManualTitleModel(ctx, chat, keys, modelOpts)
}
model, err := chatprovider.ModelFromConfig(
providerHint,
config.Model,
modelKeys,
chatprovider.UserAgent(),
chatprovider.CoderHeaders(chat),
nil,
)
model, err := p.newModel(ctx, modelClientRequest{
Chat: chat,
ModelName: config.Model,
UserAgent: chatprovider.UserAgent(),
ExtraHeaders: chatprovider.CoderHeaders(chat),
}, route, modelOpts)
if err != nil {
p.logger.Debug(ctx, "manual title preferred model unavailable",
slog.F("chat_id", chat.ID),
slog.F("provider", providerHint),
slog.F("provider", config.Provider),
slog.F("model", config.Model),
slog.Error(err),
)
return p.resolveFallbackManualTitleModel(ctx, chat, keys)
return p.resolveFallbackManualTitleModel(ctx, chat, keys, modelOpts)
}
return model, config, modelKeys, nil
return model, config, route.directProviderKeys(), nil
}
func (p *Server) resolveFallbackManualTitleModel(
ctx context.Context,
chat database.Chat,
keys chatprovider.ProviderAPIKeys,
modelOpts modelBuildOptions,
) (fantasy.LanguageModel, database.ChatModelConfig, chatprovider.ProviderAPIKeys, error) {
config, err := p.resolveModelConfig(ctx, chat)
if err != nil {
@@ -3613,25 +3671,23 @@ func (p *Server) resolveFallbackManualTitleModel(
err,
)
}
providerHint, modelKeys, err := p.resolveModelConfigProviderHintAndKeys(ctx, chat.OwnerID, config, keys)
route, err := p.resolveModelRouteForConfig(ctx, chat.OwnerID, config, keys)
if err != nil {
return nil, database.ChatModelConfig{}, chatprovider.ProviderAPIKeys{}, err
}
model, err := chatprovider.ModelFromConfig(
providerHint,
config.Model,
modelKeys,
chatprovider.UserAgent(),
chatprovider.CoderHeaders(chat),
nil,
)
model, err := p.newModel(ctx, modelClientRequest{
Chat: chat,
ModelName: config.Model,
UserAgent: chatprovider.UserAgent(),
ExtraHeaders: chatprovider.CoderHeaders(chat),
}, route, modelOpts)
if err != nil {
return nil, database.ChatModelConfig{}, chatprovider.ProviderAPIKeys{}, xerrors.Errorf(
"create fallback manual title model: %w",
err,
)
}
return model, config, modelKeys, nil
return model, config, route.directProviderKeys(), nil
}
func mergeManualTitleMessages(
@@ -3682,6 +3738,7 @@ func recordManualTitleUsage(
chat database.Chat,
modelConfig database.ChatModelConfig,
usage fantasy.Usage,
activeAPIKeyID string,
newTitle string,
) (database.Chat, error) {
hasUsage := usage != (fantasy.Usage{})
@@ -3720,6 +3777,7 @@ func recordManualTitleUsage(
messages, err := tx.InsertChatMessages(ctx, database.InsertChatMessagesParams{
ChatID: chat.ID,
CreatedBy: []uuid.UUID{chat.OwnerID},
APIKeyID: []string{activeAPIKeyID},
ModelConfigID: []uuid.UUID{modelConfig.ID},
Role: []database.ChatMessageRole{database.ChatMessageRoleAssistant},
Content: []string{content},
@@ -3845,6 +3903,7 @@ type chatMessage struct {
visibility database.ChatMessageVisibility
modelConfigID uuid.UUID
createdBy uuid.UUID
apiKeyID string
contentVersion int16
compressed bool
inputTokens int64
@@ -3880,6 +3939,11 @@ func (m chatMessage) withCreatedBy(id uuid.UUID) chatMessage {
return m
}
// withAPIKeyID returns a copy of m whose apiKeyID is set to id. The receiver
// is taken by value, so the caller's chatMessage is left unmodified and the
// call can be chained with the other with* builders.
func (m chatMessage) withAPIKeyID(id string) chatMessage {
m.apiKeyID = id
return m
}
func (m chatMessage) withCompressed() chatMessage {
m.compressed = true
return m
@@ -3924,6 +3988,7 @@ func appendChatMessage(
msg chatMessage,
) {
params.CreatedBy = append(params.CreatedBy, msg.createdBy)
params.APIKeyID = append(params.APIKeyID, msg.apiKeyID)
params.ModelConfigID = append(params.ModelConfigID, msg.modelConfigID)
params.Role = append(params.Role, msg.role)
params.Content = append(params.Content, string(msg.content.RawMessage))
@@ -3973,6 +4038,7 @@ func insertUserMessageAndSetPending(
modelConfigID uuid.UUID,
content pqtype.NullRawMessage,
createdBy uuid.UUID,
apiKeyID string,
) (database.ChatMessage, database.Chat, error) {
msgParams := database.InsertChatMessagesParams{ //nolint:exhaustruct // Fields populated by appendChatMessage.
ChatID: lockedChat.ID,
@@ -3983,7 +4049,7 @@ func insertUserMessageAndSetPending(
database.ChatMessageVisibilityBoth,
modelConfigID,
chatprompt.CurrentContentVersion,
).withCreatedBy(createdBy))
).withCreatedBy(createdBy).withAPIKeyID(apiKeyID))
messages, err := insertChatMessageWithStore(ctx, store, msgParams)
if err != nil {
return database.ChatMessage{}, database.Chat{}, err
@@ -4052,7 +4118,10 @@ type Config struct {
WebpushDispatcher webpush.Dispatcher
UsageTracker *workspacestats.UsageTracker
Clock quartz.Clock
PrometheusRegistry prometheus.Registerer
AIBridgeTransportFactory *atomic.Pointer[aibridge.TransportFactory]
AIGatewayRoutingEnabled bool
PrometheusRegistry prometheus.Registerer
// OIDCTokenSource resolves the calling user's OIDC access
// token for MCP servers configured with auth_type=user_oidc.
@@ -4139,6 +4208,8 @@ func New(cfg Config) *Server {
debugSvc.SetStaleAfter(inFlightChatStaleAfter * 3)
return debugSvc
},
aibridgeTransportFactory: cfg.AIBridgeTransportFactory,
aiGatewayRoutingEnabled: cfg.AIGatewayRoutingEnabled,
pendingChatAcquireInterval: pendingChatAcquireInterval,
maxChatsPerAcquire: maxChatsPerAcquire,
inFlightChatStaleAfter: inFlightChatStaleAfter,
@@ -5803,7 +5874,7 @@ func (p *Server) tryAutoPromoteQueuedMessage(
database.ChatMessageVisibilityBoth,
effectiveModelConfigID,
chatprompt.CurrentContentVersion,
).withCreatedBy(chat.OwnerID))
).withCreatedBy(chat.OwnerID).withAPIKeyID(nextQueued.APIKeyID.String))
msgs, err := insertChatMessageWithStore(ctx, tx, msgParams)
if err != nil {
return nil, nil, false, xerrors.Errorf("insert promoted message: %w", err)
@@ -6322,11 +6393,39 @@ type runChatResult struct {
ProviderKeys chatprovider.ProviderAPIKeys
PendingDynamicToolCalls []chatloop.PendingToolCall
FallbackProvider string
FallbackRoute resolvedModelRoute
FallbackModel string
ModelBuildOptions modelBuildOptions
TriggerMessageID int64
HistoryTipMessageID int64
}
// contextWithActiveTurnAPIKeyID attaches the API key ID of the active turn to
// ctx via aibridge.WithDelegatedAPIKeyID. When
// activeTurnAPIKeyIDFromMessages finds no usable key ID in messages, ctx is
// returned untouched.
func contextWithActiveTurnAPIKeyID(ctx context.Context, messages []database.ChatMessage) context.Context {
	if keyID, found := activeTurnAPIKeyIDFromMessages(messages); found {
		return aibridge.WithDelegatedAPIKeyID(ctx, keyID)
	}
	return ctx
}
// activeTurnAPIKeyIDFromMessages scans messages from newest to oldest and
// inspects the first user-role message whose visibility is "both" or "user".
// It returns that message's API key ID and true when the ID is valid and
// non-empty. It returns ("", false) when that message carries no key ID or
// when no such message exists; older messages are never consulted as a
// fallback once a qualifying message has been found.
func activeTurnAPIKeyIDFromMessages(messages []database.ChatMessage) (string, bool) {
	for idx := len(messages); idx > 0; idx-- {
		msg := &messages[idx-1]

		userVisible := msg.Visibility == database.ChatMessageVisibilityBoth ||
			msg.Visibility == database.ChatMessageVisibilityUser
		if msg.Role != database.ChatMessageRoleUser || !userVisible {
			continue
		}

		// The newest qualifying message decides the outcome either way.
		if key := msg.APIKeyID; key.Valid && key.String != "" {
			return key.String, true
		}
		return "", false
	}
	return "", false
}
func allToolNames(allTools []fantasy.AgentTool) []string {
toolNames := make([]string, 0, len(allTools))
for _, tool := range allTools {
@@ -6948,12 +7047,20 @@ func (p *Server) runChat(
err error
debugEnabled bool
debugProvider string
modelRoute resolvedModelRoute
debugModel string
)
// Load MCP server configs and user tokens in parallel with
// model resolution and message loading. These queries have
// no dependencies on each other and all hit different tables.
messages, err = p.db.GetChatMessagesForPromptByChatID(ctx, chat.ID)
if err != nil {
return result, xerrors.Errorf("get chat messages: %w", err)
}
modelOpts := modelBuildOptionsFromMessages(messages)
ctx = contextWithActiveTurnAPIKeyID(ctx, messages)
// Load MCP server configs and user tokens in parallel with model
// resolution. These queries have no dependencies on each other and all
// hit different tables.
var (
mcpConfigs []database.MCPServerConfig
mcpTokens []database.MCPServerUserToken
@@ -6961,7 +7068,7 @@ func (p *Server) runChat(
var g errgroup.Group
g.Go(func() error {
var err error
model, modelConfig, providerKeys, debugEnabled, debugProvider, debugModel, err = p.resolveChatModel(ctx, chat)
model, modelConfig, providerKeys, modelRoute, debugEnabled, debugProvider, debugModel, err = p.resolveChatModel(ctx, chat, modelOpts)
if err != nil {
return err
}
@@ -6972,14 +7079,6 @@ func (p *Server) runChat(
}
return nil
})
g.Go(func() error {
var err error
messages, err = p.db.GetChatMessagesForPromptByChatID(ctx, chat.ID)
if err != nil {
return xerrors.Errorf("get chat messages: %w", err)
}
return nil
})
if len(chat.MCPServerIDs) > 0 {
g.Go(func() error {
var err error
@@ -7055,15 +7154,20 @@ func (p *Server) runChat(
// registering the runtime there would inject guidance for a tool
// that is never exposed to the model.
if advisorCfg.Enabled && isRootChat && !isPlanModeTurn && !isExploreSubagent {
advisorRuntime = p.newAdvisorRuntime(
var advisorErr error
advisorRuntime, advisorErr = p.newAdvisorRuntime(
ctx,
chat,
advisorCfg,
model,
callConfig,
providerKeys,
modelOpts,
logger,
)
if advisorErr != nil {
return result, advisorErr
}
}
var advisorPromptSnapshot []fantasy.Message
@@ -7090,7 +7194,9 @@ func (p *Server) runChat(
result.StatusLabelModel = model
result.ProviderKeys = providerKeys
result.FallbackProvider = modelConfig.Provider
result.FallbackRoute = modelRoute
result.FallbackModel = modelConfig.Model
result.ModelBuildOptions = modelOpts
debugSvc := p.existingDebugService()
// Fire title generation asynchronously so it doesn't block the
// chat response. It uses a detached context so it can finish
@@ -7111,7 +7217,9 @@ func (p *Server) runChat(
modelConfig.Provider,
modelConfig.Model,
titleModel,
modelRoute,
titleProviderKeys,
modelOpts,
generatedTitle,
titleLogger,
debugSvc,
@@ -7681,20 +7789,21 @@ func (p *Server) runChat(
}
if isComputerUse {
computerUseProviderKeys, keyErr := p.resolveUserProviderAPIKeysForProviderType(ctx, chat.OwnerID, computerUseModelProvider)
computerUseRoute, keyErr := p.resolveModelRouteForProviderType(ctx, chat.OwnerID, computerUseModelProvider)
if keyErr != nil {
return result, xerrors.Errorf("resolve computer use provider API keys: %w", keyErr)
return result, xerrors.Errorf("resolve computer use provider route: %w", keyErr)
}
providerKeys = computerUseProviderKeys
providerKeys = computerUseRoute.directProviderKeys()
// Override model for computer use subagent.
cuModel, cuDebugEnabled, resolvedProvider, resolvedModel, cuErr := p.resolveComputerUseModel(
ctx,
chat,
providerKeys,
computerUseRoute,
computerUseProvider,
computerUseModelProvider,
computerUseModelName,
modelOpts,
)
if cuErr != nil {
return result, cuErr
@@ -8394,45 +8503,15 @@ func (p *Server) persistChatContextSummary(
return nil
}
// resolveModelConfigProviderHintAndKeys returns the provider hint and the
// provider API keys to use for modelConfig on behalf of ownerID.
//
// When modelConfig has no linked AI provider row (AIProviderID is not valid),
// the hint is modelConfig.Provider. fallbackKeys are reused as-is if they are
// non-empty and userCanUseProviderKeys accepts them for that provider;
// otherwise keys are resolved through resolveUserProviderAPIKeys with
// uuid.Nil.
//
// When modelConfig links an AI provider row, the row is loaded, a disabled
// provider is reported as an error, and the hint is the provider's Type with
// keys resolved through resolveUserProviderAPIKeysForProvider.
//
// On any error the returned hint is empty and the keys are the zero value.
func (p *Server) resolveModelConfigProviderHintAndKeys(
ctx context.Context,
ownerID uuid.UUID,
modelConfig database.ChatModelConfig,
fallbackKeys chatprovider.ProviderAPIKeys,
) (string, chatprovider.ProviderAPIKeys, error) {
providerHint := modelConfig.Provider
if !modelConfig.AIProviderID.Valid {
// No concrete provider row: prefer the caller-supplied keys when they
// are usable for this provider, otherwise resolve them for the owner.
if !fallbackKeys.Empty() && userCanUseProviderKeys(fallbackKeys, providerHint) {
return providerHint, fallbackKeys, nil
}
keys, err := p.resolveUserProviderAPIKeys(ctx, ownerID, uuid.Nil)
if err != nil {
return "", chatprovider.ProviderAPIKeys{}, xerrors.Errorf("resolve provider API keys: %w", err)
}
return providerHint, keys, nil
}
//nolint:gocritic // Manual title generation needs chatd-scoped provider reads for user-owned chats.
provider, err := p.db.GetAIProviderByID(dbauthz.AsChatd(ctx), modelConfig.AIProviderID.UUID)
if err != nil {
return "", chatprovider.ProviderAPIKeys{}, xerrors.Errorf("get AI provider: %w", err)
}
if !provider.Enabled {
return "", chatprovider.ProviderAPIKeys{}, xerrors.Errorf("AI provider %s is disabled", provider.ID)
}
providerKeys, err := p.resolveUserProviderAPIKeysForProvider(ctx, ownerID, provider)
if err != nil {
return "", chatprovider.ProviderAPIKeys{}, xerrors.Errorf("resolve provider API keys: %w", err)
}
// The linked provider row, not modelConfig.Provider, supplies the hint here.
return string(provider.Type), providerKeys, nil
}
func (p *Server) resolveChatModel(
ctx context.Context,
chat database.Chat,
modelOpts modelBuildOptions,
) (
model fantasy.LanguageModel,
dbConfig database.ChatModelConfig,
keys chatprovider.ProviderAPIKeys,
route resolvedModelRoute,
debugEnabled bool,
resolvedProvider string,
resolvedModel string,
@@ -8440,57 +8519,45 @@ func (p *Server) resolveChatModel(
) {
dbConfig, err = p.resolveModelConfig(ctx, chat)
if err != nil {
return nil, database.ChatModelConfig{}, chatprovider.ProviderAPIKeys{}, false, "", "", xerrors.Errorf("resolve model config: %w", err)
return nil, database.ChatModelConfig{}, chatprovider.ProviderAPIKeys{}, resolvedModelRoute{}, false, "", "", xerrors.Errorf("resolve model config: %w", err)
}
if !dbConfig.Enabled {
return nil, database.ChatModelConfig{}, chatprovider.ProviderAPIKeys{}, false, "", "", xerrors.Errorf("chat model config %s is disabled", dbConfig.ID)
return nil, database.ChatModelConfig{}, chatprovider.ProviderAPIKeys{}, resolvedModelRoute{}, false, "", "", xerrors.Errorf("chat model config %s is disabled", dbConfig.ID)
}
providerHint := dbConfig.Provider
var keyErr error
if dbConfig.AIProviderID.Valid {
provider, err := p.db.GetAIProviderByID(ctx, dbConfig.AIProviderID.UUID)
if err != nil {
return nil, database.ChatModelConfig{}, chatprovider.ProviderAPIKeys{}, false, "", "", xerrors.Errorf("get AI provider: %w", err)
}
if !provider.Enabled {
return nil, database.ChatModelConfig{}, chatprovider.ProviderAPIKeys{}, false, "", "", xerrors.Errorf("AI provider %s is disabled", provider.ID)
}
providerHint = string(provider.Type)
keys, keyErr = p.resolveUserProviderAPIKeysForProvider(ctx, chat.OwnerID, provider)
} else {
keys, keyErr = p.resolveUserProviderAPIKeys(ctx, chat.OwnerID, uuid.Nil)
}
if keyErr != nil {
return nil, database.ChatModelConfig{}, chatprovider.ProviderAPIKeys{}, false, "", "", xerrors.Errorf("resolve provider API keys: %w", keyErr)
route, err = p.resolveModelRouteForConfig(ctx, chat.OwnerID, dbConfig, chatprovider.ProviderAPIKeys{})
if err != nil {
return nil, database.ChatModelConfig{}, chatprovider.ProviderAPIKeys{}, resolvedModelRoute{}, false, "", "", err
}
keys = route.directProviderKeys()
providerHint, err := route.providerHint()
if err != nil {
return nil, database.ChatModelConfig{}, chatprovider.ProviderAPIKeys{}, resolvedModelRoute{}, false, "", "", err
}
resolvedProvider, resolvedModel, err = chatprovider.ResolveModelWithProviderHint(
dbConfig.Model,
providerHint,
)
if err != nil {
return nil, database.ChatModelConfig{}, chatprovider.ProviderAPIKeys{}, false, "", "", xerrors.Errorf(
return nil, database.ChatModelConfig{}, chatprovider.ProviderAPIKeys{}, resolvedModelRoute{}, false, "", "", xerrors.Errorf(
"resolve model metadata: %w", err,
)
}
model, debugEnabled, err = p.newDebugAwareModelFromConfig(
ctx,
chat,
providerHint,
dbConfig.Model,
keys,
chatprovider.UserAgent(),
chatprovider.CoderHeaders(chat),
)
model, debugEnabled, err = p.newDebugAwareModel(ctx, modelClientRequest{
Chat: chat,
ModelName: dbConfig.Model,
UserAgent: chatprovider.UserAgent(),
ExtraHeaders: chatprovider.CoderHeaders(chat),
}, route, modelOpts)
if err != nil {
return nil, database.ChatModelConfig{}, chatprovider.ProviderAPIKeys{}, false, "", "", xerrors.Errorf(
return nil, database.ChatModelConfig{}, chatprovider.ProviderAPIKeys{}, resolvedModelRoute{}, false, "", "", xerrors.Errorf(
"create model: %w", err,
)
}
return model, dbConfig, keys, debugEnabled, resolvedProvider, resolvedModel, nil
return model, dbConfig, keys, route, debugEnabled, resolvedProvider, resolvedModel, nil
}
func (p *Server) aiProviderConfig(ctx context.Context, provider database.AIProvider) (chatprovider.ConfiguredProvider, error) {
@@ -8605,9 +8672,18 @@ func (p *Server) resolveUserProviderAPIKeysForProviderType(
ownerID uuid.UUID,
providerType string,
) (chatprovider.ProviderAPIKeys, error) {
keys, _, err := p.resolveUserProviderAPIKeysAndProviderForProviderType(ctx, ownerID, providerType)
return keys, err
}
func (p *Server) resolveUserProviderAPIKeysAndProviderForProviderType(
ctx context.Context,
ownerID uuid.UUID,
providerType string,
) (chatprovider.ProviderAPIKeys, *database.AIProvider, error) {
providers, err := p.db.GetAIProviders(ctx, database.GetAIProvidersParams{})
if err != nil {
return chatprovider.ProviderAPIKeys{}, xerrors.Errorf("get enabled AI providers: %w", err)
return chatprovider.ProviderAPIKeys{}, nil, xerrors.Errorf("get enabled AI providers: %w", err)
}
normalizedProviderType := chatprovider.NormalizeProvider(providerType)
for _, provider := range providers {
@@ -8616,13 +8692,17 @@ func (p *Server) resolveUserProviderAPIKeysForProviderType(
}
keys, err := p.resolveUserProviderAPIKeysForProvider(ctx, ownerID, provider)
if err != nil {
return chatprovider.ProviderAPIKeys{}, err
return chatprovider.ProviderAPIKeys{}, nil, err
}
if userCanUseProviderKeys(keys, normalizedProviderType) {
return keys, nil
return keys, &provider, nil
}
}
return p.resolveUserProviderAPIKeys(ctx, ownerID, uuid.Nil)
keys, err := p.resolveUserProviderAPIKeys(ctx, ownerID, uuid.Nil)
if err != nil {
return chatprovider.ProviderAPIKeys{}, nil, err
}
return keys, nil, nil
}
func (p *Server) resolveUserProviderAPIKeys(
@@ -9391,6 +9471,7 @@ func insertSyntheticToolResultsTx(
params := database.InsertChatMessagesParams{
ChatID: chat.ID,
CreatedBy: make([]uuid.UUID, n),
APIKeyID: make([]string, n),
ModelConfigID: make([]uuid.UUID, n),
Role: make([]database.ChatMessageRole, n),
Content: make([]string, n),
@@ -9537,7 +9618,7 @@ func (p *Server) generateFinalTurnStatusLabel(
return fallbackTurnStatusLabel(status)
}
statusLabel := generateTurnStatusLabel(
statusLabel := p.generateTurnStatusLabel(
ctx,
chat,
status,
@@ -9545,7 +9626,9 @@ func (p *Server) generateFinalTurnStatusLabel(
runResult.FallbackProvider,
runResult.FallbackModel,
runResult.StatusLabelModel,
runResult.FallbackRoute,
runResult.ProviderKeys,
runResult.ModelBuildOptions,
logger,
p.existingDebugService(),
runResult.TriggerMessageID,
+16 -34
View File
@@ -2,14 +2,11 @@ package chatd
import (
"context"
"net/http"
"time"
"charm.land/fantasy"
"golang.org/x/xerrors"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/x/chatd/chatdebug"
"github.com/coder/coder/v2/coderd/x/chatd/chatprovider"
)
@@ -109,53 +106,38 @@ func (p *Server) scheduleDebugCleanup(
}()
}
func (p *Server) newDebugAwareModelFromConfig(
func (p *Server) newDebugAwareModel(
ctx context.Context,
chat database.Chat,
providerHint string,
modelName string,
providerKeys chatprovider.ProviderAPIKeys,
userAgent string,
extraHeaders map[string]string,
req modelClientRequest,
route resolvedModelRoute,
opts modelBuildOptions,
) (fantasy.LanguageModel, bool, error) {
provider, resolvedModel, err := chatprovider.ResolveModelWithProviderHint(modelName, providerHint)
providerHint, err := route.providerHint()
if err != nil {
return nil, false, err
}
provider, resolvedModel, err := chatprovider.ResolveModelWithProviderHint(req.ModelName, providerHint)
if err != nil {
return nil, false, err
}
route = route.withProviderHint(provider)
req.ModelName = resolvedModel
debugSvc := p.debugService()
debugEnabled := debugSvc != nil && debugSvc.IsEnabled(ctx, chat.ID, chat.OwnerID)
debugEnabled := debugSvc != nil && debugSvc.IsEnabled(ctx, req.Chat.ID, req.Chat.OwnerID)
opts.RecordHTTP = debugEnabled
var httpClient *http.Client
if debugEnabled {
httpClient = &http.Client{Transport: &chatdebug.RecordingTransport{}}
}
model, err := chatprovider.ModelFromConfig(
provider,
resolvedModel,
providerKeys,
userAgent,
extraHeaders,
httpClient,
)
model, err := p.newModel(ctx, req, route, opts)
if err != nil {
return nil, debugEnabled, err
}
if model == nil {
return nil, debugEnabled, xerrors.Errorf(
"create model for %s/%s returned nil",
provider,
resolvedModel,
)
}
if !debugEnabled {
return model, false, nil
}
return chatdebug.WrapModel(model, debugSvc, chatdebug.RecorderOptions{
ChatID: chat.ID,
OwnerID: chat.OwnerID,
ChatID: req.Chat.ID,
OwnerID: req.Chat.OwnerID,
Provider: provider,
Model: resolvedModel,
}), true, nil
+59 -3
View File
@@ -152,10 +152,11 @@ func TestResolveComputerUseModel_OpenAIMissingCredentials(t *testing.T) {
model, debugEnabled, resolvedProvider, resolvedModel, err := server.resolveComputerUseModel(
context.Background(),
database.Chat{ID: uuid.New(), OwnerID: uuid.New()},
chatprovider.ProviderAPIKeys{},
newDirectModelRoute(modelProvider, chatprovider.ProviderAPIKeys{}),
provider,
modelProvider,
modelName,
modelBuildOptions{},
)
require.Error(t, err)
require.Nil(t, model)
@@ -167,6 +168,55 @@ func TestResolveComputerUseModel_OpenAIMissingCredentials(t *testing.T) {
require.NotContains(t, err.Error(), "ANTHROPIC_API_KEY")
}
// TestResolveUserProviderAPIKeysAndProviderForProviderTypeProviderMatch checks
// that resolving by provider type returns both the usable keys and the
// provider row they came from.
func TestResolveUserProviderAPIKeysAndProviderForProviderTypeProviderMatch(t *testing.T) {
	t.Parallel()

	var (
		ctx      = testutil.Context(t, testutil.WaitShort)
		store    = dbmock.NewMockStore(gomock.NewController(t))
		userID   = uuid.New()
		openAIID = uuid.New()
	)

	// An Anthropic row is listed ahead of the OpenAI row; only the OpenAI
	// provider has a key lookup expectation registered.
	listed := []database.AIProvider{
		{ID: uuid.New(), Type: database.AiProviderTypeAnthropic, Enabled: true},
		{ID: openAIID, Type: database.AiProviderTypeOpenai, Enabled: true},
	}
	store.EXPECT().
		GetAIProviders(gomock.Any(), database.GetAIProvidersParams{}).
		Return(listed, nil)
	store.EXPECT().
		GetAIProviderKeysByProviderID(gomock.Any(), openAIID).
		Return([]database.AIProviderKey{{
			ProviderID: openAIID,
			APIKey:     "test-key",
		}}, nil)

	srv := &Server{db: store}
	gotKeys, gotProvider, err := srv.resolveUserProviderAPIKeysAndProviderForProviderType(
		ctx,
		userID,
		chattool.ComputerUseProviderOpenAI,
	)
	require.NoError(t, err)
	require.Equal(t, "test-key", gotKeys.APIKey(chattool.ComputerUseProviderOpenAI))
	require.NotNil(t, gotProvider)
	require.Equal(t, openAIID, gotProvider.ID)
	require.Equal(t, database.AiProviderTypeOpenai, gotProvider.Type)
}
// TestResolveModelRouteForProviderTypeAIGatewayRequiresProvider checks that,
// with AI Gateway routing enabled, resolving a route by provider type fails
// when the store returns no AI providers.
func TestResolveModelRouteForProviderTypeAIGatewayRequiresProvider(t *testing.T) {
	t.Parallel()

	ctx := testutil.Context(t, testutil.WaitShort)
	store := dbmock.NewMockStore(gomock.NewController(t))
	store.EXPECT().
		GetAIProviders(gomock.Any(), database.GetAIProvidersParams{}).
		Return(nil, nil)

	srv := &Server{aiGatewayRoutingEnabled: true, db: store}
	_, err := srv.resolveModelRouteForProviderType(ctx, uuid.New(), chattool.ComputerUseProviderOpenAI)
	require.ErrorContains(t, err, "AI Gateway routing requires a usable AI provider")
}
func TestAppendComputerUseProviderTool(t *testing.T) {
t.Parallel()
@@ -731,6 +781,11 @@ func TestRenameChatTitle(t *testing.T) {
})
}
// withChatMessageAPIKeyID returns a copy of message whose APIKeyID is set from
// apiKeyID via sqlNullString. The argument is received by value, so the
// caller's message is not modified.
func withChatMessageAPIKeyID(message database.ChatMessage, apiKeyID string) database.ChatMessage {
	stamped := message
	stamped.APIKeyID = sqlNullString(apiKeyID)
	return stamped
}
func TestRegenerateChatTitle_PersistsAndBroadcasts(t *testing.T) {
t.Parallel()
@@ -749,6 +804,7 @@ func TestRegenerateChatTitle_PersistsAndBroadcasts(t *testing.T) {
modelConfigID := uuid.New()
workerID := uuid.New()
userPrompt := "review pull request 23633 and fix review threads"
activeAPIKeyID := "key-" + uuid.NewString()
wantTitle := "Review PR 23633"
chat := database.Chat{
@@ -814,12 +870,12 @@ func TestRegenerateChatTitle_PersistsAndBroadcasts(t *testing.T) {
LimitVal: manualTitleMessageWindowLimit,
},
).Return([]database.ChatMessage{
mustChatMessage(
withChatMessageAPIKeyID(mustChatMessage(
t,
database.ChatMessageRoleUser,
database.ChatMessageVisibilityBoth,
codersdk.ChatMessageText(userPrompt),
),
), activeAPIKeyID),
mustChatMessage(
t,
database.ChatMessageRoleAssistant,
+264 -6
View File
@@ -11,6 +11,7 @@ import (
"io"
"net/http"
"net/http/httptest"
"net/url"
"os"
"path/filepath"
"slices"
@@ -32,6 +33,7 @@ import (
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/coder/v2/agent/agentcontextconfig"
"github.com/coder/coder/v2/agent/agenttest"
"github.com/coder/coder/v2/coderd/aibridge"
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/db2sdk"
@@ -67,6 +69,88 @@ type recordedOpenAIRequest struct {
ContentLength int64
}
// chatAIGatewayRecordedRequest is one request observed by
// chatAIGatewayRoundTripper, reduced to the routing and auth details that
// tests assert on.
type chatAIGatewayRecordedRequest struct {
	// ProviderName and Source are the arguments the transport was created with.
	ProviderName string
	Source       aibridge.Source
	// APIKeyID is the delegated API key ID read from the request context.
	APIKeyID string
	// Path is the request URL path as received, before it is rewritten for
	// the upstream target.
	Path string
	// Authorization, XAPIKey and CoderToken hold the values of the
	// Authorization, X-Api-Key and aibridge.HeaderCoderToken request headers.
	Authorization string
	XAPIKey       string
	CoderToken    string
}
// chatAIGatewayTestFactory is a test transport factory: the round trippers it
// returns record every request and forward it to target using transport.
type chatAIGatewayTestFactory struct {
	// target is the upstream base URL that recorded requests are redirected to.
	target *url.URL
	// transport performs the actual upstream round trip.
	transport http.RoundTripper
	// mu guards requests.
	mu sync.Mutex
	// requests accumulates one entry per observed request.
	requests []chatAIGatewayRecordedRequest
}
// newChatAIGatewayTestFactory builds a recording factory that forwards to
// targetBaseURL over http.DefaultTransport. It fails the test if the URL does
// not parse.
func newChatAIGatewayTestFactory(t testing.TB, targetBaseURL string) *chatAIGatewayTestFactory {
	t.Helper()

	parsed, parseErr := url.Parse(targetBaseURL)
	require.NoError(t, parseErr)

	factory := &chatAIGatewayTestFactory{
		target:    parsed,
		transport: http.DefaultTransport,
	}
	return factory
}
// TransportFor returns a recording round tripper bound to this factory and
// tagged with the given provider name and source. It never returns an error.
func (f *chatAIGatewayTestFactory) TransportFor(providerName string, source aibridge.Source) (http.RoundTripper, error) {
	rt := chatAIGatewayRoundTripper{
		factory:      f,
		providerName: providerName,
		source:       source,
	}
	return rt, nil
}
// requestsSnapshot returns a copy of the requests recorded so far, taken while
// holding the factory mutex so later appends do not affect the result.
func (f *chatAIGatewayTestFactory) requestsSnapshot() []chatAIGatewayRecordedRequest {
	f.mu.Lock()
	snapshot := slices.Clone(f.requests)
	f.mu.Unlock()
	return snapshot
}
// chatAIGatewayRoundTripper is the http.RoundTripper handed out by
// chatAIGatewayTestFactory.TransportFor.
type chatAIGatewayRoundTripper struct {
	// factory owns the recorded request log, the upstream target and the
	// transport used to reach it.
	factory *chatAIGatewayTestFactory
	// providerName and source are copied into every recorded request.
	providerName string
	source       aibridge.Source
}
// RoundTrip records the incoming request on the factory and then forwards a
// clone of it to the factory's target. The forwarded path drops a leading
// "/v1" (falling back to "/" when nothing remains) and keeps the query string.
func (t chatAIGatewayRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	// The lookup's second result is discarded; a request without a delegated
	// key is recorded with whatever ID value the lookup returned.
	delegatedKeyID, _ := aibridge.DelegatedAPIKeyIDFromContext(req.Context())
	record := chatAIGatewayRecordedRequest{
		ProviderName:  t.providerName,
		Source:        t.source,
		APIKeyID:      delegatedKeyID,
		Path:          req.URL.Path,
		Authorization: req.Header.Get("Authorization"),
		XAPIKey:       req.Header.Get("X-Api-Key"),
		CoderToken:    req.Header.Get(aibridge.HeaderCoderToken),
	}

	t.factory.mu.Lock()
	t.factory.requests = append(t.factory.requests, record)
	t.factory.mu.Unlock()

	// Work on a copy of the target URL so the shared factory value is never
	// mutated.
	upstream := *t.factory.target
	if trimmed := strings.TrimPrefix(req.URL.Path, "/v1"); trimmed != "" {
		upstream.Path = trimmed
	} else {
		upstream.Path = "/"
	}
	upstream.RawQuery = req.URL.RawQuery

	forwarded := req.Clone(req.Context())
	forwarded.URL = &upstream
	forwarded.Host = t.factory.target.Host
	return t.factory.transport.RoundTrip(forwarded)
}
// chatAIGatewayTransportFactoryPointer wraps factory in a freshly allocated
// atomic pointer, the form the chatd config field takes in these tests.
func chatAIGatewayTransportFactoryPointer(factory aibridge.TransportFactory) *atomic.Pointer[aibridge.TransportFactory] {
	holder := new(atomic.Pointer[aibridge.TransportFactory])
	holder.Store(&factory)
	return holder
}
// directChatRoutingDeploymentValues returns test deployment values with the
// chat AI Gateway routing option set to "false". It fails the test if the
// option rejects the value.
func directChatRoutingDeploymentValues(t testing.TB) *codersdk.DeploymentValues {
	t.Helper()

	dv := coderdtest.DeploymentValues(t)
	err := dv.AI.Chat.AIGatewayRoutingEnabled.Set("false")
	require.NoError(t, err)
	return dv
}
// openAIToolName returns the first non-empty value among the tool's function
// name, its top-level name, and its type, in that order. It returns "" when
// all three are empty.
func openAIToolName(tool chattest.OpenAITool) string {
	for _, candidate := range []string{tool.Function.Name, tool.Name, tool.Type} {
		if candidate != "" {
			return candidate
		}
	}
	return ""
}
@@ -231,7 +315,7 @@ func TestSubagentChatExcludesWorkspaceProvisioningTools(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
deploymentValues := coderdtest.DeploymentValues(t)
deploymentValues := directChatRoutingDeploymentValues(t)
client := coderdtest.New(t, &coderdtest.Options{
DeploymentValues: deploymentValues,
IncludeProvisionerDaemon: true,
@@ -388,7 +472,7 @@ func TestPlanModeSubagentChatExcludesAskUserQuestion(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
deploymentValues := coderdtest.DeploymentValues(t)
deploymentValues := directChatRoutingDeploymentValues(t)
client := coderdtest.New(t, &coderdtest.Options{
DeploymentValues: deploymentValues,
IncludeProvisionerDaemon: true,
@@ -555,7 +639,7 @@ func TestExploreSubagentIsReadOnly(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
deploymentValues := coderdtest.DeploymentValues(t)
deploymentValues := directChatRoutingDeploymentValues(t)
client, db := coderdtest.NewWithDatabase(t, &coderdtest.Options{
DeploymentValues: deploymentValues,
IncludeProvisionerDaemon: true,
@@ -1796,6 +1880,73 @@ func TestUpdateChatHeartbeatsRequiresOwnership(t *testing.T) {
require.Equal(t, chat.ID, ids[0])
}
// TestCreateChatPersistsAPIKeyIDOnInitialUserMessage checks that the API key ID
// passed to CreateChat is stored on the single initial user message.
func TestCreateChatPersistsAPIKeyIDOnInitialUserMessage(t *testing.T) {
	t.Parallel()

	db, ps := dbtestutil.NewDB(t)
	srv := newTestServer(t, db, ps, uuid.New())

	ctx := testutil.Context(t, testutil.WaitLong)
	user, org, model := seedChatDependencies(t, db)
	// Only the key row is needed; the second result is unused here.
	key, _ := dbgen.APIKey(t, db, database.APIKey{UserID: user.ID})

	opts := chatd.CreateOptions{
		OrganizationID:     org.ID,
		OwnerID:            user.ID,
		Title:              "create-chat-api-key-id",
		ModelConfigID:      model.ID,
		InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
		APIKeyID:           key.ID,
	}
	created, err := srv.CreateChat(ctx, opts)
	require.NoError(t, err)

	stored, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
		ChatID:  created.ID,
		AfterID: 0,
	})
	require.NoError(t, err)
	require.Len(t, stored, 1)

	first := stored[0]
	require.Equal(t, database.ChatMessageRoleUser, first.Role)
	require.True(t, first.APIKeyID.Valid)
	require.Equal(t, key.ID, first.APIKeyID.String)
}
// TestSendMessagePersistsAPIKeyIDOnUserMessage checks that the API key ID passed
// to SendMessage appears both on the returned message and on the row read back
// from the database.
func TestSendMessagePersistsAPIKeyIDOnUserMessage(t *testing.T) {
	t.Parallel()

	db, ps := dbtestutil.NewDB(t)
	srv := newTestServer(t, db, ps, uuid.New())

	ctx := testutil.Context(t, testutil.WaitLong)
	user, org, model := seedChatDependencies(t, db)
	// Only the key row is needed; the second result is unused here.
	key, _ := dbgen.APIKey(t, db, database.APIKey{UserID: user.ID})

	seeded := dbgen.Chat(t, db, database.Chat{
		OrganizationID:    org.ID,
		OwnerID:           user.ID,
		LastModelConfigID: model.ID,
		Title:             "send-message-api-key-id",
	})

	opts := chatd.SendMessageOptions{
		ChatID:    seeded.ID,
		CreatedBy: user.ID,
		Content: []codersdk.ChatMessagePart{
			codersdk.ChatMessageText("message with api key id"),
		},
		APIKeyID: key.ID,
	}
	sent, err := srv.SendMessage(ctx, opts)
	require.NoError(t, err)
	require.False(t, sent.Queued)
	require.True(t, sent.Message.APIKeyID.Valid)
	require.Equal(t, key.ID, sent.Message.APIKeyID.String)

	persisted, err := db.GetChatMessageByID(ctx, sent.Message.ID)
	require.NoError(t, err)
	require.True(t, persisted.APIKeyID.Valid)
	require.Equal(t, key.ID, persisted.APIKeyID.String)
}
func TestSendMessageQueueBehaviorQueuesWhenBusy(t *testing.T) {
t.Parallel()
@@ -2143,15 +2294,20 @@ func TestEditMessageUpdatesAndTruncatesAndClearsQueue(t *testing.T) {
})
require.NoError(t, err)
apiKey, _ := dbgen.APIKey(t, db, database.APIKey{UserID: user.ID})
apiKeyID := apiKey.ID
editResult, err := replica.EditMessage(ctx, chatd.EditMessageOptions{
ChatID: chat.ID,
EditedMessageID: editedMessageID,
Content: []codersdk.ChatMessagePart{codersdk.ChatMessageText("edited")},
APIKeyID: apiKeyID,
})
require.NoError(t, err)
// The edited message is soft-deleted and a new message is inserted,
// so the returned message ID will differ from the original.
require.NotEqual(t, editedMessageID, editResult.Message.ID)
require.True(t, editResult.Message.APIKeyID.Valid)
require.Equal(t, apiKeyID, editResult.Message.APIKeyID.String)
require.Equal(t, database.ChatStatusPending, editResult.Chat.Status)
require.False(t, editResult.Chat.WorkerID.Valid)
@@ -2166,6 +2322,8 @@ func TestEditMessageUpdatesAndTruncatesAndClearsQueue(t *testing.T) {
require.NoError(t, err)
require.Len(t, messages, 1)
require.Equal(t, editResult.Message.ID, messages[0].ID)
require.True(t, messages[0].APIKeyID.Valid)
require.Equal(t, apiKeyID, messages[0].APIKeyID.String)
onlyMessage := db2sdk.ChatMessage(messages[0])
require.Len(t, onlyMessage.Content, 1)
require.Equal(t, "edited", onlyMessage.Content[0].Text)
@@ -2366,6 +2524,7 @@ func TestPromoteQueuedAllowsAlreadyQueuedMessageWhenUsageLimitReached(t *testing
ctx := testutil.Context(t, testutil.WaitLong)
user, org, model := seedChatDependencies(t, db)
apiKey, _ := dbgen.APIKey(t, db, database.APIKey{UserID: user.ID})
_, err := db.UpsertChatUsageLimitConfig(ctx, database.UpsertChatUsageLimitConfigParams{
Enabled: true,
@@ -2400,11 +2559,14 @@ func TestPromoteQueuedAllowsAlreadyQueuedMessageWhenUsageLimitReached(t *testing
queuedResult, err := replica.SendMessage(ctx, chatd.SendMessageOptions{
ChatID: chat.ID,
Content: []codersdk.ChatMessagePart{codersdk.ChatMessageText("queued")},
APIKeyID: apiKey.ID,
BusyBehavior: chatd.SendMessageBusyBehaviorQueue,
})
require.NoError(t, err)
require.True(t, queuedResult.Queued)
require.NotNil(t, queuedResult.QueuedMessage)
require.True(t, queuedResult.QueuedMessage.APIKeyID.Valid)
require.Equal(t, apiKey.ID, queuedResult.QueuedMessage.APIKeyID.String)
assistantContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{
codersdk.ChatMessageText("assistant"),
@@ -2437,6 +2599,8 @@ func TestPromoteQueuedAllowsAlreadyQueuedMessageWhenUsageLimitReached(t *testing
})
require.NoError(t, err)
require.Equal(t, database.ChatMessageRoleUser, result.PromotedMessage.Role)
require.True(t, result.PromotedMessage.APIKeyID.Valid)
require.Equal(t, apiKey.ID, result.PromotedMessage.APIKeyID.String)
queued, err := db.GetChatQueuedMessages(ctx, chat.ID)
require.NoError(t, err)
@@ -4864,7 +5028,7 @@ func TestCreateWorkspaceTool_EndToEnd(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
deploymentValues := coderdtest.DeploymentValues(t)
deploymentValues := directChatRoutingDeploymentValues(t)
client := coderdtest.New(t, &coderdtest.Options{
DeploymentValues: deploymentValues,
IncludeProvisionerDaemon: true,
@@ -5029,7 +5193,7 @@ func TestStartWorkspaceTool_EndToEnd(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitSuperLong)
deploymentValues := coderdtest.DeploymentValues(t)
deploymentValues := directChatRoutingDeploymentValues(t)
client := coderdtest.New(t, &coderdtest.Options{
DeploymentValues: deploymentValues,
IncludeProvisionerDaemon: true,
@@ -7252,6 +7416,100 @@ func TestProcessChat_UserProviderKey_Success(t *testing.T) {
require.Contains(t, recordedAuthHeaders, "Bearer "+userAPIKey)
}
// TestProcessChat_AIGatewayRoutingUsesDelegatedAPIKey runs a chat turn with AI
// Gateway routing and BYOK enabled and asserts that every request seen by the
// gateway transport carries the initiating API key ID, the user's provider key
// as a bearer token, and the "delegated" Coder token marker.
func TestProcessChat_AIGatewayRoutingUsesDelegatedAPIKey(t *testing.T) {
	t.Parallel()

	db, ps := dbtestutil.NewDB(t)
	ctx := testutil.Context(t, testutil.WaitLong)

	// Fake OpenAI upstream: streaming requests get text chunks, all other
	// requests get a JSON title payload.
	openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
		if req.Stream {
			return chattest.OpenAIStreamingResponse(
				chattest.OpenAITextChunks("hello through AI Gateway")...,
			)
		}
		return chattest.OpenAINonStreamingResponse(`{"title":"AI Gateway Chat"}`)
	})
	// The recording factory forwards gateway traffic to the fake upstream.
	factory := newChatAIGatewayTestFactory(t, openAIURL)

	user := dbgen.User(t, db, database.User{})
	org := dbgen.Organization(t, db, database.Organization{})
	dbgen.OrganizationMember(t, db, database.OrganizationMember{
		UserID:         user.ID,
		OrganizationID: org.ID,
	})
	provider := dbgen.AIProvider(t, db, database.AIProvider{
		Type:    database.AiProviderTypeOpenai,
		Name:    "primary-openai-" + uuid.NewString(),
		BaseUrl: openAIURL,
	})
	// The model config points at the provider row through AIProviderID.
	model := dbgen.ChatModelConfig(t, db, database.ChatModelConfig{
		Provider:     string(database.AiProviderTypeOpenai),
		Model:        "gpt-4o-mini",
		IsDefault:    true,
		AIProviderID: uuid.NullUUID{UUID: provider.ID, Valid: true},
	})
	apiKey, _ := dbgen.APIKey(t, db, database.APIKey{UserID: user.ID})
	// Store a user-owned key for the provider; the assertions below expect it
	// to be forwarded as the Authorization bearer token.
	_, err := db.UpsertUserAIProviderKey(ctx, database.UpsertUserAIProviderKeyParams{
		ID:           uuid.New(),
		UserID:       user.ID,
		AIProviderID: provider.ID,
		APIKey:       "sk-user-aibridge",
	})
	require.NoError(t, err)

	// creator is used only to create the chat and subscribe to its events.
	creator := newTestServer(t, db, ps, uuid.New())
	chat, err := creator.CreateChat(ctx, chatd.CreateOptions{
		OrganizationID: org.ID,
		OwnerID:        user.ID,
		Title:          "aigateway-routing",
		ModelConfigID:  model.ID,
		APIKeyID:       apiKey.ID,
		InitialUserContent: []codersdk.ChatMessagePart{
			codersdk.ChatMessageText("say hello"),
		},
	})
	require.NoError(t, err)

	// Subscribe before the second server is started below.
	_, events, cancel, ok := creator.Subscribe(ctx, chat.ID, nil, 0)
	require.True(t, ok)
	t.Cleanup(cancel)

	// Second server configured with the recording transport factory, AI
	// Gateway routing, and BYOK.
	// NOTE(review): presumably this is the server that picks up and processes
	// the pending chat; confirm against newActiveTestServer.
	_ = newActiveTestServer(t, db, ps, func(cfg *chatd.Config) {
		cfg.AIBridgeTransportFactory = chatAIGatewayTransportFactoryPointer(factory)
		cfg.AIGatewayRoutingEnabled = true
		cfg.AllowBYOK = true
		cfg.AllowBYOKSet = true
	})

	terminalStatus := waitForTerminalChatStatusEvent(ctx, t, events)
	require.Equal(t, codersdk.ChatStatusWaiting, terminalStatus)

	chatResult := waitForTerminalChat(ctx, t, db, chat.ID)
	require.Equal(t, database.ChatStatusWaiting, chatResult.Status)
	require.False(t, chatResult.LastError.Valid)

	requests := factory.requestsSnapshot()
	require.NotEmpty(t, requests)
	// At least one recorded request must be the /v1/responses call with the
	// exact expected attribution and auth values.
	require.Contains(t, requests, chatAIGatewayRecordedRequest{
		ProviderName:  provider.Name,
		Source:        aibridge.SourceAgents,
		APIKeyID:      apiKey.ID,
		Path:          "/v1/responses",
		Authorization: "Bearer sk-user-aibridge",
		CoderToken:    "delegated",
	})
	// Every recorded request, whatever its path, must carry the same
	// attribution and auth and must not set X-Api-Key.
	for _, req := range requests {
		require.Equal(t, provider.Name, req.ProviderName)
		require.Equal(t, aibridge.SourceAgents, req.Source)
		require.Equal(t, apiKey.ID, req.APIKeyID)
		require.Equal(t, "Bearer sk-user-aibridge", req.Authorization)
		require.Empty(t, req.XAPIKey)
		require.Equal(t, "delegated", req.CoderToken)
		require.True(t, strings.HasPrefix(req.Path, "/v1/"), "unexpected aibridge path %q", req.Path)
	}
}
func TestProcessChat_UserProviderKey_MissingKeyError(t *testing.T) {
t.Parallel()
@@ -8493,7 +8751,7 @@ func TestAgentContextFilesAndSkillsLoadedIntoChat(t *testing.T) {
))
ctx := testutil.Context(t, testutil.WaitSuperLong)
deploymentValues := coderdtest.DeploymentValues(t)
deploymentValues := directChatRoutingDeploymentValues(t)
client := coderdtest.New(t, &coderdtest.Options{
DeploymentValues: deploymentValues,
IncludeProvisionerDaemon: true,
+8 -10
View File
@@ -60,10 +60,11 @@ func (p *Server) computerUseProviderAndModelFromConfig(
func (p *Server) resolveComputerUseModel(
ctx context.Context,
chat database.Chat,
providerKeys chatprovider.ProviderAPIKeys,
route resolvedModelRoute,
computerUseProvider string,
computerUseModelProvider string,
computerUseModelName string,
modelOpts modelBuildOptions,
) (
model fantasy.LanguageModel,
debugEnabled bool,
@@ -84,15 +85,12 @@ func (p *Server) resolveComputerUseModel(
)
}
model, debugEnabled, err = p.newDebugAwareModelFromConfig(
ctx,
chat,
computerUseModelProvider,
computerUseModelName,
providerKeys,
chatprovider.UserAgent(),
chatprovider.CoderHeaders(chat),
)
model, debugEnabled, err = p.newDebugAwareModel(ctx, modelClientRequest{
Chat: chat,
ModelName: computerUseModelName,
UserAgent: chatprovider.UserAgent(),
ExtraHeaders: chatprovider.CoderHeaders(chat),
}, route, modelOpts)
if err != nil {
return nil, false, "", "", xerrors.Errorf(
"resolve computer use model for provider %q model %q: %w",
+168
View File
@@ -0,0 +1,168 @@
package chatd
import (
"context"
"net/http"
"charm.land/fantasy"
"github.com/google/uuid"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/x/chatd/chatprovider"
)
// modelClientRequest carries the per-call inputs for building a language
// model client, independent of how the call is routed.
type modelClientRequest struct {
	// Chat is the chat the model is built for.
	Chat database.Chat
	// ModelName is the model identifier passed to the provider client.
	ModelName string
	// UserAgent and ExtraHeaders are forwarded to the provider client.
	UserAgent    string
	ExtraHeaders map[string]string
}
// modelBuildOptions carries settings that affect how a model client is built
// but are not part of the request or the route.
type modelBuildOptions struct {
	// ActiveAPIKeyID is the API key ID attributed to the current turn. AI
	// Gateway model builds fail when it is empty.
	ActiveAPIKeyID string
	// RecordHTTP requests that the model's HTTP transport be wrapped in a
	// chatdebug recording transport.
	RecordHTTP bool
}
// modelBuildOptionsFromMessages builds options whose ActiveAPIKeyID is whatever
// activeTurnAPIKeyIDFromMessages reports for messages. All other fields keep
// their zero values.
func modelBuildOptionsFromMessages(messages []database.ChatMessage) modelBuildOptions {
	var opts modelBuildOptions
	// The helper's second result is deliberately ignored: when no key is
	// found the ID is simply left as returned.
	opts.ActiveAPIKeyID, _ = activeTurnAPIKeyIDFromMessages(messages)
	return opts
}
// modelRouteKind identifies which routing path a resolvedModelRoute uses.
type modelRouteKind int

const (
	// The zero value is left unnamed so that an uninitialized route matches
	// neither kind.
	_ modelRouteKind = iota
	// modelRouteKindDirect routes straight to the provider.
	modelRouteKindDirect
	// modelRouteKindAIGateway routes through the AI Gateway transport.
	modelRouteKindAIGateway
)
// resolvedModelRoute is the outcome of route resolution. kind selects which of
// the two payloads is meaningful; a zero kind means the route is not
// configured.
type resolvedModelRoute struct {
	kind modelRouteKind
	// direct is used when kind is modelRouteKindDirect.
	direct directModelRoute
	// aiGateway is used when kind is modelRouteKindAIGateway.
	aiGateway aiGatewayModelRoute
}
// newDirectModelRoute returns a direct-provider route carrying the given
// provider hint and provider API keys.
func newDirectModelRoute(providerHint string, keys chatprovider.ProviderAPIKeys) resolvedModelRoute {
	var route resolvedModelRoute
	route.kind = modelRouteKindDirect
	route.direct.ProviderHint = providerHint
	route.direct.Keys = keys
	return route
}
// providerHint returns the provider hint stored on the active payload of the
// route. It returns an error when the route kind is neither direct nor AI
// Gateway.
func (r resolvedModelRoute) providerHint() (string, error) {
	if r.kind == modelRouteKindDirect {
		return r.direct.ProviderHint, nil
	}
	if r.kind == modelRouteKindAIGateway {
		return r.aiGateway.ModelProviderHint, nil
	}
	return "", xerrors.New("model route is not configured")
}
// withProviderHint returns a copy of the route with the provider hint of its
// active payload replaced. A route of any other kind is returned unchanged.
func (r resolvedModelRoute) withProviderHint(providerHint string) resolvedModelRoute {
	updated := r
	if updated.kind == modelRouteKindDirect {
		updated.direct.ProviderHint = providerHint
	} else if updated.kind == modelRouteKindAIGateway {
		updated.aiGateway.ModelProviderHint = providerHint
	}
	return updated
}
// directProviderKeys returns the provider API keys of a direct route. For any
// other route kind it returns empty keys.
func (r resolvedModelRoute) directProviderKeys() chatprovider.ProviderAPIKeys {
	if r.kind == modelRouteKindDirect {
		return r.direct.Keys
	}
	return chatprovider.ProviderAPIKeys{}
}
// enabledAIProviderByID loads the AI provider with the given ID and returns it
// only if it is enabled. Lookup failures are wrapped; a disabled provider
// yields an error naming its ID.
func (p *Server) enabledAIProviderByID(ctx context.Context, providerID uuid.UUID) (database.AIProvider, error) {
	row, err := p.db.GetAIProviderByID(ctx, providerID)
	switch {
	case err != nil:
		return database.AIProvider{}, xerrors.Errorf("get AI provider: %w", err)
	case !row.Enabled:
		return database.AIProvider{}, xerrors.Errorf("AI provider %s is disabled", row.ID)
	default:
		return row, nil
	}
}
// shouldUseAIGatewayRouting reports whether model routes should be resolved
// through the AI Gateway path. It reflects the server's
// aiGatewayRoutingEnabled setting.
func (p *Server) shouldUseAIGatewayRouting() bool {
	return p.aiGatewayRoutingEnabled
}
// resolveModelRouteForConfig resolves the route for a model config, choosing
// the AI Gateway or direct resolver based on shouldUseAIGatewayRouting.
// fallbackKeys is passed only to the direct resolver.
func (p *Server) resolveModelRouteForConfig(
	ctx context.Context,
	ownerID uuid.UUID,
	modelConfig database.ChatModelConfig,
	fallbackKeys chatprovider.ProviderAPIKeys,
) (resolvedModelRoute, error) {
	if !p.shouldUseAIGatewayRouting() {
		return p.resolveDirectModelRouteForConfig(ctx, ownerID, modelConfig, fallbackKeys)
	}
	return p.resolveAIGatewayModelRouteForConfig(ctx, ownerID, modelConfig)
}
// resolveModelRouteForProviderType resolves the route for a provider type,
// choosing the AI Gateway or direct resolver based on
// shouldUseAIGatewayRouting.
func (p *Server) resolveModelRouteForProviderType(
	ctx context.Context,
	ownerID uuid.UUID,
	providerType string,
) (resolvedModelRoute, error) {
	if !p.shouldUseAIGatewayRouting() {
		return p.resolveDirectModelRouteForProviderType(ctx, ownerID, providerType)
	}
	return p.resolveAIGatewayModelRouteForProviderType(ctx, ownerID, providerType)
}
// newModel builds a language model for the given route, dispatching on the
// route kind to the direct or AI Gateway builder. It returns an error when the
// route kind is neither.
func (p *Server) newModel(
	ctx context.Context,
	req modelClientRequest,
	route resolvedModelRoute,
	opts modelBuildOptions,
) (fantasy.LanguageModel, error) {
	if route.kind == modelRouteKindDirect {
		return p.newDirectModel(ctx, req, route.direct, opts)
	}
	if route.kind == modelRouteKindAIGateway {
		return p.newAIGatewayModel(ctx, req, route.aiGateway, opts)
	}
	return nil, xerrors.New("model route is not configured")
}
// newLanguageModel builds a model via chatprovider.ModelFromConfig and treats
// a nil model returned without an error as a failure. In that case the
// provider and model are re-resolved so they can be named in the error.
func newLanguageModel(
	providerHint string,
	modelName string,
	providerKeys chatprovider.ProviderAPIKeys,
	userAgent string,
	extraHeaders map[string]string,
	httpClient *http.Client,
) (fantasy.LanguageModel, error) {
	model, err := chatprovider.ModelFromConfig(providerHint, modelName, providerKeys, userAgent, extraHeaders, httpClient)
	if err != nil {
		return nil, err
	}
	if model != nil {
		return model, nil
	}

	// ModelFromConfig produced neither a model nor an error.
	resolvedProvider, resolvedName, err := chatprovider.ResolveModelWithProviderHint(modelName, providerHint)
	if err != nil {
		return nil, err
	}
	return nil, xerrors.Errorf("create model for %s/%s returned nil", resolvedProvider, resolvedName)
}
+292
View File
@@ -0,0 +1,292 @@
package chatd
import (
"context"
"database/sql"
"net/http"
"strings"
"charm.land/fantasy"
fantasyanthropic "charm.land/fantasy/providers/anthropic"
fantasyopenai "charm.land/fantasy/providers/openai"
fantasyopenaicompat "charm.land/fantasy/providers/openaicompat"
"github.com/google/uuid"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/coderd/aibridge"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/x/chatd/chatdebug"
"github.com/coder/coder/v2/coderd/x/chatd/chatprovider"
)
const (
	// aibridgeLocalBaseURL is the base URL given to fantasy clients for AI
	// Gateway routes; fantasyConfigForAIBridge uses it with or without a
	// "/v1" suffix depending on the provider type.
	aibridgeLocalBaseURL = "http://coder-aibridge"
	// aibridgePlaceholderAPIKey satisfies fantasy clients that require a
	// non-empty API key before aibridged resolves the real credential.
	aibridgePlaceholderAPIKey = "coder-aibridge"
	// aibridgeDelegatedBYOKMarker is the value written to the
	// aibridge.HeaderCoderToken header whenever user provider auth headers
	// are forwarded.
	aibridgeDelegatedBYOKMarker = "delegated"
)
// aiGatewayModelRoute describes a model call routed through the AI Gateway.
type aiGatewayModelRoute struct {
	// Provider is the AI provider row the call is attributed to. Its ID and
	// Name must be set for a model to be built, and its Type selects the
	// fantasy client used.
	Provider database.AIProvider
	// ModelProviderHint is the provider hint used when resolving the model
	// name.
	ModelProviderHint string
	// ProviderAuth holds user provider auth headers to forward, if any.
	ProviderAuth aiGatewayProviderAuth
}
// newAIGatewayModelRoute returns an AI Gateway route for the given provider,
// model provider hint, and forwarded provider auth.
func newAIGatewayModelRoute(
	provider database.AIProvider,
	modelProviderHint string,
	auth aiGatewayProviderAuth,
) resolvedModelRoute {
	var route resolvedModelRoute
	route.kind = modelRouteKindAIGateway
	route.aiGateway.Provider = provider
	route.aiGateway.ModelProviderHint = modelProviderHint
	route.aiGateway.ProviderAuth = auth
	return route
}
// aiGatewayProviderAuth holds provider auth headers to forward on AI Gateway
// requests. Its String and GoString methods never reveal header contents.
type aiGatewayProviderAuth struct {
	// Headers maps header names to the credential values to set.
	Headers map[string]string
}

// redactedAIGatewayProviderAuth is the fixed text returned by every formatting
// method of aiGatewayProviderAuth.
const redactedAIGatewayProviderAuth = "aiGatewayProviderAuth{Headers:<redacted>}"

// String implements fmt.Stringer and returns a fixed redacted placeholder.
func (aiGatewayProviderAuth) String() string {
	return redactedAIGatewayProviderAuth
}

// GoString implements fmt.GoStringer so that %#v formatting is redacted in the
// same way as String.
func (aiGatewayProviderAuth) GoString() string {
	return redactedAIGatewayProviderAuth
}
// aiGatewayRequestFormat selects which provider auth header convention is used
// when forwarding a user's provider key.
type aiGatewayRequestFormat int

const (
	// aiGatewayRequestFormatOpenAI is the zero value; the key is sent as an
	// "Authorization: Bearer" header.
	aiGatewayRequestFormatOpenAI aiGatewayRequestFormat = iota
	// aiGatewayRequestFormatAnthropic sends the key in the "X-Api-Key"
	// header.
	aiGatewayRequestFormatAnthropic
)
// aiGatewayRoundTripper decorates requests with the delegated API key ID and
// any forwarded provider auth before passing them to base.
type aiGatewayRoundTripper struct {
	// base performs the actual round trip.
	base http.RoundTripper
	// apiKeyID is attached to each request's context as the delegated API
	// key ID.
	apiKeyID string
	// providerAuth holds the provider auth headers to set on each request.
	providerAuth aiGatewayProviderAuth
}
// RoundTrip forwards a clone of req whose context carries the delegated API key
// ID. When provider auth headers are configured they are set on the clone,
// followed by the delegated marker in aibridge.HeaderCoderToken. The caller's
// request is never modified.
func (t *aiGatewayRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	outbound := req.Clone(aibridge.WithDelegatedAPIKeyID(req.Context(), t.apiKeyID))

	authHeaders := t.providerAuth.Headers
	if len(authHeaders) == 0 {
		return t.base.RoundTrip(outbound)
	}

	for name, value := range authHeaders {
		outbound.Header.Set(name, value)
	}
	// The marker is written last so it wins over any same-named auth header.
	outbound.Header.Set(aibridge.HeaderCoderToken, aibridgeDelegatedBYOKMarker)
	return t.base.RoundTrip(outbound)
}
// newAIGatewayModel builds a language model whose HTTP traffic goes through the
// AI Gateway transport for route.Provider.
//
// It fails closed, returning an error when the route has no concrete provider
// ID or name, when opts.ActiveAPIKeyID is empty, when no transport factory is
// configured, or when the factory fails or yields a nil transport. On success
// the transport attaches the active API key ID and any forwarded provider auth
// to each request, and is wrapped in a chatdebug recording transport when
// opts.RecordHTTP is set.
func (p *Server) newAIGatewayModel(
	_ context.Context,
	req modelClientRequest,
	route aiGatewayModelRoute,
	opts modelBuildOptions,
) (fantasy.LanguageModel, error) {
	if route.Provider.ID == uuid.Nil {
		return nil, xerrors.New("AI Gateway routing requires a concrete AI provider")
	}
	if route.Provider.Name == "" {
		return nil, xerrors.New("AI Gateway routing requires an AI provider name")
	}
	if opts.ActiveAPIKeyID == "" {
		return nil, xerrors.New("AI Gateway routing requires the active turn API key ID")
	}

	// The factory is held behind an atomic pointer; both the pointer and the
	// interface value it points at may be unset.
	factoryPtr := p.aibridgeTransportFactory
	if factoryPtr == nil {
		return nil, xerrors.New("AI Gateway transport factory is not configured")
	}
	factory := factoryPtr.Load()
	if factory == nil || *factory == nil {
		return nil, xerrors.New("AI Gateway transport factory is not configured")
	}

	rt, err := (*factory).TransportFor(route.Provider.Name, aibridge.SourceAgents)
	if err != nil {
		return nil, xerrors.Errorf("create AI Gateway transport: %w", err)
	}
	// A nil transport returned without an error would otherwise panic on the
	// first request, so reject it here like the other preconditions.
	if rt == nil {
		return nil, xerrors.New("AI Gateway transport factory returned a nil transport")
	}

	var baseRT http.RoundTripper = &aiGatewayRoundTripper{
		base:         rt,
		apiKeyID:     opts.ActiveAPIKeyID,
		providerAuth: route.ProviderAuth,
	}
	if opts.RecordHTTP {
		baseRT = &chatdebug.RecordingTransport{Base: baseRT}
	}

	// The fantasy client gets a placeholder key and the local gateway base
	// URL; requests are carried by baseRT.
	config := fantasyConfigForAIBridge(route.Provider.Type)
	return newLanguageModel(
		config.ProviderHint,
		req.ModelName,
		config.Keys,
		req.UserAgent,
		req.ExtraHeaders,
		&http.Client{Transport: baseRT},
	)
}
// aibridgeFantasyConfig is the fantasy client configuration used for an AI
// Gateway route.
type aibridgeFantasyConfig struct {
	// ProviderHint names the fantasy provider implementation to use.
	ProviderHint string
	// Keys holds the placeholder API key and gateway base URL for that
	// provider.
	Keys chatprovider.ProviderAPIKeys
}
// fantasyConfigForAIBridge maps an AI provider type to the fantasy provider and
// base URL used for AI Gateway routing. Anthropic and Bedrock use the Anthropic
// client with the bare gateway URL; OpenAI uses the OpenAI client and every
// other type the OpenAI-compatible client, both with a "/v1" suffix. The API
// key is always the placeholder.
func fantasyConfigForAIBridge(providerType database.AIProviderType) aibridgeFantasyConfig {
	const versionedBaseURL = aibridgeLocalBaseURL + "/v1"

	// Start from the catch-all configuration and override per provider type.
	hint, endpoint := fantasyopenaicompat.Name, versionedBaseURL
	switch providerType {
	case database.AiProviderTypeAnthropic, database.AiProviderTypeBedrock:
		hint, endpoint = fantasyanthropic.Name, aibridgeLocalBaseURL
	case database.AiProviderTypeOpenai:
		hint = fantasyopenai.Name
	}

	keys := chatprovider.ProviderAPIKeys{
		ByProvider:        map[string]string{hint: aibridgePlaceholderAPIKey},
		BaseURLByProvider: map[string]string{hint: endpoint},
	}
	return aibridgeFantasyConfig{ProviderHint: hint, Keys: keys}
}
// aiGatewayRequestFormatForProviderType returns the Anthropic request format
// for Anthropic and Bedrock providers and the OpenAI format for every other
// provider type.
func aiGatewayRequestFormatForProviderType(providerType database.AIProviderType) aiGatewayRequestFormat {
	if providerType == database.AiProviderTypeAnthropic || providerType == database.AiProviderTypeBedrock {
		return aiGatewayRequestFormatAnthropic
	}
	return aiGatewayRequestFormatOpenAI
}
// aiGatewayProviderAuthForUser returns the auth headers that forward the user's
// own key for provider, formatted for the given request format.
//
// It returns empty auth and no error when BYOK is not allowed, when the user
// has no key row for the provider, or when the stored key is blank after
// trimming. Any other lookup failure is returned wrapped.
func (p *Server) aiGatewayProviderAuthForUser(
	ctx context.Context,
	ownerID uuid.UUID,
	provider database.AIProvider,
	format aiGatewayRequestFormat,
) (aiGatewayProviderAuth, error) {
	var none aiGatewayProviderAuth
	if !p.allowBYOK {
		return none, nil
	}

	row, err := p.db.GetUserAIProviderKeyByProviderID(ctx, database.GetUserAIProviderKeyByProviderIDParams{
		UserID:       ownerID,
		AIProviderID: provider.ID,
	})
	switch {
	case xerrors.Is(err, sql.ErrNoRows):
		return none, nil
	case err != nil:
		return none, xerrors.Errorf("get user AI provider key: %w", err)
	}

	secret := strings.TrimSpace(row.APIKey)
	if secret == "" {
		return none, nil
	}

	if format == aiGatewayRequestFormatAnthropic {
		return aiGatewayProviderAuth{Headers: map[string]string{"X-Api-Key": secret}}, nil
	}
	return aiGatewayProviderAuth{Headers: map[string]string{"Authorization": "Bearer " + secret}}, nil
}
// resolveAIGatewayRoute resolves the owner's provider auth for the given
// provider, using the request format derived from the provider type, and wraps
// the result in an AI Gateway model route.
func (p *Server) resolveAIGatewayRoute(
	ctx context.Context,
	ownerID uuid.UUID,
	provider database.AIProvider,
	modelProviderHint string,
) (resolvedModelRoute, error) {
	format := aiGatewayRequestFormatForProviderType(provider.Type)
	providerAuth, err := p.aiGatewayProviderAuthForUser(ctx, ownerID, provider, format)
	if err != nil {
		return resolvedModelRoute{}, xerrors.Errorf("resolve AI Gateway provider auth: %w", err)
	}

	route := newAIGatewayModelRoute(provider, modelProviderHint, providerAuth)
	return route, nil
}
// resolveAIGatewayModelRouteForConfig resolves an AI Gateway route for a model
// config. The provider comes from the config's AI provider reference, and the
// provider's type string is used as the model provider hint.
func (p *Server) resolveAIGatewayModelRouteForConfig(
	ctx context.Context,
	ownerID uuid.UUID,
	modelConfig database.ChatModelConfig,
) (resolvedModelRoute, error) {
	gatewayProvider, err := p.gatewayProviderForConfig(ctx, modelConfig)
	if err != nil {
		return resolvedModelRoute{}, err
	}

	hint := string(gatewayProvider.Type)
	return p.resolveAIGatewayRoute(ctx, ownerID, gatewayProvider, hint)
}
// resolveAIGatewayModelRouteForProviderType resolves an AI Gateway route for a
// provider type name. The provider is looked up by type, and the normalized
// type name is used as the model provider hint.
func (p *Server) resolveAIGatewayModelRouteForProviderType(
	ctx context.Context,
	ownerID uuid.UUID,
	providerType string,
) (resolvedModelRoute, error) {
	matched, err := p.aiProviderForProviderType(ctx, providerType)
	if err != nil {
		return resolvedModelRoute{}, err
	}

	hint := chatprovider.NormalizeProvider(providerType)
	return p.resolveAIGatewayRoute(ctx, ownerID, matched, hint)
}
// gatewayProviderForConfig loads the enabled AI provider referenced by a model
// config. It returns an error when the config carries no valid AI provider ID.
func (p *Server) gatewayProviderForConfig(
	ctx context.Context,
	modelConfig database.ChatModelConfig,
) (database.AIProvider, error) {
	if ref := modelConfig.AIProviderID; ref.Valid {
		return p.enabledAIProviderByID(ctx, ref.UUID)
	}

	return database.AIProvider{}, xerrors.Errorf(
		"AI Gateway routing requires AI provider metadata for model config %s (%s)",
		modelConfig.ID,
		modelConfig.Model,
	)
}
// aiProviderForProviderType returns the first enabled AI provider whose
// normalized type matches the normalized providerType. It returns an error
// when the provider listing fails or when no enabled provider matches.
func (p *Server) aiProviderForProviderType(
	ctx context.Context,
	providerType string,
) (database.AIProvider, error) {
	all, err := p.db.GetAIProviders(ctx, database.GetAIProvidersParams{})
	if err != nil {
		return database.AIProvider{}, xerrors.Errorf("get enabled AI providers: %w", err)
	}

	want := chatprovider.NormalizeProvider(providerType)
	for i := range all {
		candidate := all[i]
		// Disabled providers are never eligible, regardless of type.
		if candidate.Enabled && chatprovider.NormalizeProvider(string(candidate.Type)) == want {
			return candidate, nil
		}
	}

	return database.AIProvider{}, xerrors.Errorf(
		"AI Gateway routing requires a usable AI provider for provider type %q",
		providerType,
	)
}
+93
View File
@@ -0,0 +1,93 @@
package chatd
import (
"context"
"net/http"
"charm.land/fantasy"
"github.com/google/uuid"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/x/chatd/chatdebug"
"github.com/coder/coder/v2/coderd/x/chatd/chatprovider"
)
// directModelRoute holds the inputs newDirectModel passes to newLanguageModel
// for a model built without the AI Gateway.
type directModelRoute struct {
// ProviderHint is the provider name handed to newLanguageModel.
ProviderHint string
// Keys is the provider key set handed to newLanguageModel.
Keys chatprovider.ProviderAPIKeys
}
// newDirectModel builds a language model from a direct route. When
// opts.RecordHTTP is set, the model is given an HTTP client backed by a
// chatdebug.RecordingTransport; otherwise a nil client is passed through.
// The context argument is unused.
func (*Server) newDirectModel(
	_ context.Context,
	req modelClientRequest,
	route directModelRoute,
	opts modelBuildOptions,
) (fantasy.LanguageModel, error) {
	build := func(client *http.Client) (fantasy.LanguageModel, error) {
		return newLanguageModel(
			route.ProviderHint,
			req.ModelName,
			route.Keys,
			req.UserAgent,
			req.ExtraHeaders,
			client,
		)
	}

	if !opts.RecordHTTP {
		return build(nil)
	}
	return build(&http.Client{Transport: &chatdebug.RecordingTransport{}})
}
// resolveDirectModelRouteForConfig resolves a direct route for a model config.
//
// When the config references an AI provider, keys are resolved for that
// provider. Otherwise the supplied fallback keys are used if they are
// non-empty and usable for the provider hint, and failing that the owner's
// provider keys are resolved with a nil provider ID.
func (p *Server) resolveDirectModelRouteForConfig(
	ctx context.Context,
	ownerID uuid.UUID,
	modelConfig database.ChatModelConfig,
	fallbackKeys chatprovider.ProviderAPIKeys,
) (resolvedModelRoute, error) {
	hint, linkedProvider, err := p.directProviderHintAndProviderForConfig(ctx, modelConfig)
	if err != nil {
		return resolvedModelRoute{}, err
	}

	// Config is tied to a concrete provider row: resolve keys for that row.
	if linkedProvider != nil {
		linkedKeys, err := p.resolveUserProviderAPIKeysForProvider(ctx, ownerID, *linkedProvider)
		if err != nil {
			return resolvedModelRoute{}, xerrors.Errorf("resolve provider API keys: %w", err)
		}
		return newDirectModelRoute(hint, linkedKeys), nil
	}

	// No provider row: prefer the caller's fallback keys when they apply.
	if !fallbackKeys.Empty() && userCanUseProviderKeys(fallbackKeys, hint) {
		return newDirectModelRoute(hint, fallbackKeys), nil
	}

	ownerKeys, err := p.resolveUserProviderAPIKeys(ctx, ownerID, uuid.Nil)
	if err != nil {
		return resolvedModelRoute{}, xerrors.Errorf("resolve provider API keys: %w", err)
	}
	return newDirectModelRoute(hint, ownerKeys), nil
}
// resolveDirectModelRouteForProviderType resolves a direct route for a provider
// type name. The route's provider hint is the normalized provider type, and
// its keys are those resolved for the owner and provider type.
func (p *Server) resolveDirectModelRouteForProviderType(
	ctx context.Context,
	ownerID uuid.UUID,
	providerType string,
) (resolvedModelRoute, error) {
	hint := chatprovider.NormalizeProvider(providerType)

	resolvedKeys, _, err := p.resolveUserProviderAPIKeysAndProviderForProviderType(ctx, ownerID, providerType)
	if err == nil {
		return newDirectModelRoute(hint, resolvedKeys), nil
	}
	return resolvedModelRoute{}, err
}
// directProviderHintAndProviderForConfig returns the provider hint for a model
// config, plus the referenced AI provider when there is one.
//
// Without a valid AI provider ID, the config's own Provider string is the hint
// and the provider result is nil. With one, the enabled provider is loaded and
// its type string becomes the hint.
func (p *Server) directProviderHintAndProviderForConfig(
	ctx context.Context,
	modelConfig database.ChatModelConfig,
) (string, *database.AIProvider, error) {
	ref := modelConfig.AIProviderID
	if !ref.Valid {
		return modelConfig.Provider, nil, nil
	}

	loaded, err := p.enabledAIProviderByID(ctx, ref.UUID)
	if err != nil {
		return "", nil, err
	}
	return string(loaded.Type), &loaded, nil
}
@@ -0,0 +1,640 @@
package chatd
import (
"database/sql"
"fmt"
"io"
"net/http"
"strings"
"sync/atomic"
"testing"
"charm.land/fantasy"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"go.uber.org/mock/gomock"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/coderd/aibridge"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbgen"
"github.com/coder/coder/v2/coderd/database/dbmock"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/coderd/x/chatd/chatprovider"
"github.com/coder/coder/v2/coderd/x/chatd/chattool"
)
// aibridgeTestFactory is a test double whose TransportFor method records its
// arguments and returns a preconfigured transport or error.
type aibridgeTestFactory struct {
// providerName is the provider name from the most recent TransportFor call.
providerName string
// source is the source from the most recent TransportFor call.
source aibridge.Source
// err, when non-nil, is returned by TransportFor instead of rt.
err error
// rt is the transport returned by TransportFor when err is nil.
rt http.RoundTripper
}
// TransportFor records the requested provider name and source on the factory,
// then returns the configured error if one is set, or the configured transport
// otherwise.
func (f *aibridgeTestFactory) TransportFor(providerName string, source aibridge.Source) (http.RoundTripper, error) {
	f.providerName, f.source = providerName, source
	if f.err == nil {
		return f.rt, nil
	}
	return nil, f.err
}
// roundTripFunc adapts a plain function so it can be used wherever an
// http.RoundTripper is expected.
type roundTripFunc func(*http.Request) (*http.Response, error)

// RoundTrip calls f with req and returns its results unchanged.
func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
	resp, err := f(req)
	return resp, err
}
// aibridgeTestFactoryPointer returns a new atomic pointer that already holds
// the given transport factory.
func aibridgeTestFactoryPointer(factory aibridge.TransportFactory) *atomic.Pointer[aibridge.TransportFactory] {
	holder := new(atomic.Pointer[aibridge.TransportFactory])
	holder.Store(&factory)
	return holder
}
// aibridgeTestAIProvider returns an enabled AI provider fixture with the given
// ID, name, and type.
func aibridgeTestAIProvider(providerID uuid.UUID, providerName string, providerType database.AIProviderType) database.AIProvider {
	var fixture database.AIProvider
	fixture.ID = providerID
	fixture.Name = providerName
	fixture.Type = providerType
	fixture.Enabled = true
	return fixture
}
// aibridgeTestRoute returns an AI Gateway route for aiProvider that uses the
// provider's type string as the hint and carries empty provider auth.
func aibridgeTestRoute(aiProvider database.AIProvider) resolvedModelRoute {
	var noAuth aiGatewayProviderAuth
	hint := string(aiProvider.Type)
	return newAIGatewayModelRoute(aiProvider, hint, noAuth)
}
// aibridgeTestRequest returns a model client request for the given chat and
// model name, using the chatprovider user agent.
func aibridgeTestRequest(chat database.Chat, model string) modelClientRequest {
	var req modelClientRequest
	req.Chat = chat
	req.ModelName = model
	req.UserAgent = chatprovider.UserAgent()
	return req
}
// TestAIBridgeProviderFormatMapping checks the fantasy provider name, base
// URL, and placeholder API key that fantasyConfigForAIBridge produces for each
// covered provider type.
func TestAIBridgeProviderFormatMapping(t *testing.T) {
	t.Parallel()

	const (
		rootURL = "http://coder-aibridge"
		v1URL   = "http://coder-aibridge/v1"
	)

	type mappingCase struct {
		name         string
		providerType database.AIProviderType
		wantProvider string
		wantBaseURL  string
	}
	cases := []mappingCase{
		{"OpenAI", database.AiProviderTypeOpenai, "openai", v1URL},
		{"Anthropic", database.AiProviderTypeAnthropic, "anthropic", rootURL},
		{"Bedrock", database.AiProviderTypeBedrock, "anthropic", rootURL},
		{"Google", database.AiProviderTypeGoogle, "openai-compat", v1URL},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			got := fantasyConfigForAIBridge(tc.providerType)
			hint := got.ProviderHint

			require.Equal(t, tc.wantProvider, hint)
			require.Equal(t, tc.wantBaseURL, got.Keys.BaseURL(hint))
			require.Equal(t, aibridgePlaceholderAPIKey, got.Keys.APIKey(hint))
		})
	}
}
func TestResolveModelRouteForConfigPreservesBaseURL(t *testing.T) {
t.Parallel()
ctx := t.Context()
ctrl := gomock.NewController(t)
db := dbmock.NewMockStore(ctrl)
ownerID := uuid.New()
providerID := uuid.New()
baseURL := "https://openai.example.com/v1"
db.EXPECT().GetAIProviderByID(gomock.Any(), providerID).Return(database.AIProvider{
ID: providerID,
Type: database.AiProviderTypeOpenai,
Name: "primary-openai",
Enabled: true,
BaseUrl: baseURL,
}, nil)
db.EXPECT().GetAIProviderKeysByProviderID(gomock.Any(), providerID).Return([]database.AIProviderKey{{
ProviderID: providerID,
APIKey: "provider-key",
}}, nil)
server := &Server{db: db}
route, err := server.resolveModelRouteForConfig(ctx, ownerID, database.ChatModelConfig{
Provider: "openai",
AIProviderID: uuid.NullUUID{UUID: providerID, Valid: true},
}, chatprovider.ProviderAPIKeys{})
require.NoError(t, err)
require.Equal(t, modelRouteKindDirect, route.kind)
require.Equal(t, "openai", route.direct.ProviderHint)
require.Equal(t, "provider-key", route.direct.Keys.APIKey("openai"))
require.Equal(t, baseURL, route.direct.Keys.BaseURL("openai"))
}
// TestAIGatewayProviderAuthForUser covers the headers produced by
// aiGatewayProviderAuthForUser: a bearer token for the OpenAI format, an
// X-Api-Key header for the Anthropic format, and no headers when the user has
// no stored key or BYOK is disabled.
func TestAIGatewayProviderAuthForUser(t *testing.T) {
	t.Parallel()

	ctx := t.Context()
	ownerID, providerID := uuid.New(), uuid.New()
	provider := database.AIProvider{
		ID:      providerID,
		Type:    database.AiProviderTypeOpenai,
		Enabled: true,
	}
	lookup := database.GetUserAIProviderKeyByProviderIDParams{
		UserID:       ownerID,
		AIProviderID: providerID,
	}

	// byokServer returns a BYOK-enabled server whose store answers exactly one
	// user key lookup with the given row and error.
	byokServer := func(t *testing.T, row database.UserAiProviderKey, lookupErr error) *Server {
		t.Helper()
		store := dbmock.NewMockStore(gomock.NewController(t))
		store.EXPECT().
			GetUserAIProviderKeyByProviderID(gomock.Any(), lookup).
			Return(row, lookupErr)
		return &Server{db: store, allowBYOK: true}
	}

	t.Run("OpenAIUserKey", func(t *testing.T) {
		t.Parallel()

		srv := byokServer(t, database.UserAiProviderKey{APIKey: "sk-user"}, nil)
		got, err := srv.aiGatewayProviderAuthForUser(ctx, ownerID, provider, aiGatewayRequestFormatOpenAI)
		require.NoError(t, err)
		require.Equal(t, "Bearer sk-user", got.Headers["Authorization"])
		require.Empty(t, got.Headers["X-Api-Key"])
	})

	t.Run("AnthropicUserKey", func(t *testing.T) {
		t.Parallel()

		srv := byokServer(t, database.UserAiProviderKey{APIKey: "sk-user"}, nil)
		got, err := srv.aiGatewayProviderAuthForUser(ctx, ownerID, provider, aiGatewayRequestFormatAnthropic)
		require.NoError(t, err)
		require.Equal(t, "sk-user", got.Headers["X-Api-Key"])
		require.Empty(t, got.Headers["Authorization"])
	})

	t.Run("NoUserKey", func(t *testing.T) {
		t.Parallel()

		srv := byokServer(t, database.UserAiProviderKey{}, sql.ErrNoRows)
		got, err := srv.aiGatewayProviderAuthForUser(ctx, ownerID, provider, aiGatewayRequestFormatOpenAI)
		require.NoError(t, err)
		require.Empty(t, got.Headers)
	})

	t.Run("BYOKDisabled", func(t *testing.T) {
		t.Parallel()

		// No expectations are registered: the store must not be consulted.
		store := dbmock.NewMockStore(gomock.NewController(t))
		srv := &Server{db: store, allowBYOK: false}
		got, err := srv.aiGatewayProviderAuthForUser(ctx, ownerID, provider, aiGatewayRequestFormatOpenAI)
		require.NoError(t, err)
		require.Empty(t, got.Headers)
	})
}
// TestAIGatewayProviderAuthRedactsFormatting checks that formatting provider
// auth with fmt.Sprint, %+v, or %#v never prints the header values and always
// contains the word "redacted".
func TestAIGatewayProviderAuthRedactsFormatting(t *testing.T) {
	t.Parallel()

	secretAuth := aiGatewayProviderAuth{
		Headers: map[string]string{
			"Authorization": "Bearer sk-user",
			"X-Api-Key":     "sk-user",
		},
	}

	renderings := []string{
		fmt.Sprint(secretAuth),
		fmt.Sprintf("%+v", secretAuth),
		fmt.Sprintf("%#v", secretAuth),
	}
	for i := range renderings {
		rendered := renderings[i]
		require.NotContains(t, rendered, "sk-user")
		require.NotContains(t, rendered, "Bearer sk-user")
		require.Contains(t, rendered, "redacted")
	}
}
// TestResolveModelRouteForConfigAIGatewayProviderAuth checks the provider auth
// attached to AI Gateway routes: the user's key is forwarded when BYOK is
// enabled, and no headers are attached when BYOK is disabled.
func TestResolveModelRouteForConfigAIGatewayProviderAuth(t *testing.T) {
	t.Parallel()

	ctx := t.Context()
	ownerID, providerID := uuid.New(), uuid.New()
	gatewayProvider := database.AIProvider{
		ID:      providerID,
		Type:    database.AiProviderTypeOpenai,
		Name:    "primary-openai",
		Enabled: true,
	}
	cfg := database.ChatModelConfig{
		ID:           uuid.New(),
		Model:        "gpt-4",
		Provider:     "openai",
		AIProviderID: uuid.NullUUID{UUID: providerID, Valid: true},
	}

	// storeWithProvider returns a mock store that serves the provider row once.
	storeWithProvider := func(t *testing.T) *dbmock.MockStore {
		t.Helper()
		store := dbmock.NewMockStore(gomock.NewController(t))
		store.EXPECT().GetAIProviderByID(gomock.Any(), providerID).Return(gatewayProvider, nil)
		return store
	}

	t.Run("UserKey", func(t *testing.T) {
		t.Parallel()

		store := storeWithProvider(t)
		store.EXPECT().
			GetUserAIProviderKeyByProviderID(gomock.Any(), database.GetUserAIProviderKeyByProviderIDParams{
				UserID:       ownerID,
				AIProviderID: providerID,
			}).
			Return(database.UserAiProviderKey{APIKey: "sk-user"}, nil)

		srv := &Server{db: store, aiGatewayRoutingEnabled: true, allowBYOK: true}
		got, err := srv.resolveModelRouteForConfig(ctx, ownerID, cfg, chatprovider.ProviderAPIKeys{})
		require.NoError(t, err)
		require.Equal(t, modelRouteKindAIGateway, got.kind)
		require.Equal(t, "Bearer sk-user", got.aiGateway.ProviderAuth.Headers["Authorization"])
	})

	t.Run("CentralProviderCredentialsNotForwarded", func(t *testing.T) {
		t.Parallel()

		store := storeWithProvider(t)
		srv := &Server{db: store, aiGatewayRoutingEnabled: true, allowBYOK: false}
		got, err := srv.resolveModelRouteForConfig(ctx, ownerID, cfg, chatprovider.ProviderAPIKeys{})
		require.NoError(t, err)
		require.Equal(t, modelRouteKindAIGateway, got.kind)
		require.Empty(t, got.aiGateway.ProviderAuth.Headers)
	})
}
// TestAIGatewayModelForwardsProviderAuth builds AI Gateway routed models on
// top of a stub transport and checks which auth headers, Coder token header,
// delegated API key ID, and URL path reach that transport.
func TestAIGatewayModelForwardsProviderAuth(t *testing.T) {
t.Parallel()
// seenRequest captures the request attributes observed by the stub transport.
type seenRequest struct {
authorization string
xAPIKey string
coderToken string
apiKeyID string
path string
}
// newServer returns a server wired to a stub transport factory, plus an AI
// Gateway route for the provider with the given auth. The stub publishes
// what it saw on the seen channel and answers with a canned success body.
newServer := func(t *testing.T, provider database.AIProvider, auth aiGatewayProviderAuth, seen chan seenRequest) (*Server, resolvedModelRoute) {
factory := &aibridgeTestFactory{rt: roundTripFunc(func(req *http.Request) (*http.Response, error) {
apiKeyID, _ := aibridge.DelegatedAPIKeyIDFromContext(req.Context())
seen <- seenRequest{
authorization: req.Header.Get("Authorization"),
xAPIKey: req.Header.Get("X-Api-Key"),
coderToken: req.Header.Get(aibridge.HeaderCoderToken),
apiKeyID: apiKeyID,
path: req.URL.Path,
}
// Default canned body; replaced below for Anthropic providers.
body := `{"id":"resp_test","object":"response","created_at":0,"status":"completed","model":"gpt-4","output":[{"id":"msg_test","type":"message","role":"assistant","content":[{"type":"output_text","text":"hello"}]}],"usage":{"input_tokens":1,"output_tokens":1,"total_tokens":2}}`
if provider.Type == database.AiProviderTypeAnthropic {
body = `{"id":"msg_test","type":"message","role":"assistant","model":"claude-haiku-4-5","content":[{"type":"text","text":"hello"}],"stop_reason":"end_turn","stop_sequence":null,"usage":{"input_tokens":1,"output_tokens":1}}`
}
return &http.Response{
StatusCode: http.StatusOK,
Header: http.Header{"Content-Type": []string{"application/json"}},
Body: io.NopCloser(strings.NewReader(body)),
Request: req,
}, nil
})}
server := &Server{
aiGatewayRoutingEnabled: true,
aibridgeTransportFactory: aibridgeTestFactoryPointer(factory),
}
route := newAIGatewayModelRoute(provider, string(provider.Type), auth)
return server, route
}
// A user bearer token is forwarded as-is and the delegated marker is set.
t.Run("OpenAI", func(t *testing.T) {
t.Parallel()
seen := make(chan seenRequest, 1)
provider := aibridgeTestAIProvider(uuid.New(), "primary-openai", database.AiProviderTypeOpenai)
server, route := newServer(t, provider, aiGatewayProviderAuth{
Headers: map[string]string{"Authorization": "Bearer sk-user"},
}, seen)
apiKeyID := uuid.NewString()
model, err := server.newModel(t.Context(), aibridgeTestRequest(database.Chat{ID: uuid.New(), OwnerID: uuid.New()}, "gpt-4"), route, modelBuildOptions{ActiveAPIKeyID: apiKeyID, RecordHTTP: true})
require.NoError(t, err)
_, err = model.Generate(t.Context(), fantasy.Call{Prompt: []fantasy.Message{{Role: fantasy.MessageRoleUser, Content: []fantasy.MessagePart{fantasy.TextPart{Text: "hello"}}}}})
require.NoError(t, err)
got := <-seen
require.Equal(t, "Bearer sk-user", got.authorization)
require.Empty(t, got.xAPIKey)
require.Equal(t, aibridgeDelegatedBYOKMarker, got.coderToken)
require.Equal(t, apiKeyID, got.apiKeyID)
require.Equal(t, "/v1/responses", got.path)
})
// A user X-Api-Key header is forwarded and the delegated marker is set.
t.Run("Anthropic", func(t *testing.T) {
t.Parallel()
seen := make(chan seenRequest, 1)
provider := aibridgeTestAIProvider(uuid.New(), "primary-anthropic", database.AiProviderTypeAnthropic)
server, route := newServer(t, provider, aiGatewayProviderAuth{
Headers: map[string]string{"X-Api-Key": "sk-user"},
}, seen)
apiKeyID := uuid.NewString()
model, err := server.newModel(t.Context(), aibridgeTestRequest(database.Chat{ID: uuid.New(), OwnerID: uuid.New()}, "claude-haiku-4-5"), route, modelBuildOptions{ActiveAPIKeyID: apiKeyID})
require.NoError(t, err)
_, err = model.Generate(t.Context(), fantasy.Call{Prompt: []fantasy.Message{{Role: fantasy.MessageRoleUser, Content: []fantasy.MessagePart{fantasy.TextPart{Text: "hello"}}}}})
require.NoError(t, err)
got := <-seen
require.Equal(t, "sk-user", got.xAPIKey)
require.Equal(t, aibridgeDelegatedBYOKMarker, got.coderToken)
require.Equal(t, apiKeyID, got.apiKeyID)
require.Equal(t, "/v1/messages", got.path)
})
// With empty provider auth the placeholder bearer token is sent and the
// Coder token header stays empty.
t.Run("NoUserKeyLeavesPlaceholderForAIBridged", func(t *testing.T) {
t.Parallel()
seen := make(chan seenRequest, 1)
provider := aibridgeTestAIProvider(uuid.New(), "primary-openai", database.AiProviderTypeOpenai)
server, route := newServer(t, provider, aiGatewayProviderAuth{}, seen)
apiKeyID := uuid.NewString()
model, err := server.newModel(t.Context(), aibridgeTestRequest(database.Chat{ID: uuid.New(), OwnerID: uuid.New()}, "gpt-4"), route, modelBuildOptions{ActiveAPIKeyID: apiKeyID})
require.NoError(t, err)
_, err = model.Generate(t.Context(), fantasy.Call{Prompt: []fantasy.Message{{Role: fantasy.MessageRoleUser, Content: []fantasy.MessagePart{fantasy.TextPart{Text: "hello"}}}}})
require.NoError(t, err)
got := <-seen
require.Equal(t, "Bearer "+aibridgePlaceholderAPIKey, got.authorization)
require.Empty(t, got.xAPIKey)
require.Empty(t, got.coderToken)
require.Equal(t, apiKeyID, got.apiKeyID)
})
}
// TestActiveTurnAPIKeyIDFromMessages checks which API key ID is selected from
// a message history, both directly through activeTurnAPIKeyIDFromMessages and
// through the context built by contextWithActiveTurnAPIKeyID.
func TestActiveTurnAPIKeyIDFromMessages(t *testing.T) {
	t.Parallel()

	previousKey := uuid.NewString()
	latestKey := uuid.NewString()

	const (
		roleUser      = database.ChatMessageRoleUser
		roleAssistant = database.ChatMessageRoleAssistant
		visibleBoth   = database.ChatMessageVisibilityBoth
		visibleModel  = database.ChatMessageVisibilityModel
	)

	type turnCase struct {
		name     string
		messages []database.ChatMessage
		wantKey  string
		wantOK   bool
	}
	cases := []turnCase{
		{
			name: "CurrentUserMessage",
			messages: []database.ChatMessage{
				{ID: 1, Role: roleUser, Visibility: visibleBoth, APIKeyID: sqlNullString(previousKey)},
				{ID: 2, Role: roleAssistant, Visibility: visibleBoth},
				{ID: 3, Role: roleUser, Visibility: visibleBoth, APIKeyID: sqlNullString(latestKey)},
			},
			wantKey: latestKey,
			wantOK:  true,
		},
		{
			// The newest user message has no key, so nothing is returned even
			// though an earlier message carries one.
			name: "MissingCurrentUserAPIKeyDoesNotFallBack",
			messages: []database.ChatMessage{
				{ID: 1, Role: roleUser, Visibility: visibleBoth, APIKeyID: sqlNullString(previousKey)},
				{ID: 2, Role: roleUser, Visibility: visibleBoth},
			},
		},
		{
			name: "SkipsModelOnlyUserMessages",
			messages: []database.ChatMessage{
				{ID: 1, Role: roleUser, Visibility: visibleBoth, APIKeyID: sqlNullString(previousKey)},
				{ID: 2, Role: roleUser, Visibility: visibleModel, APIKeyID: sqlNullString(latestKey)},
			},
			wantKey: previousKey,
			wantOK:  true,
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()

			directKey, directOK := activeTurnAPIKeyIDFromMessages(tc.messages)
			require.Equal(t, tc.wantOK, directOK)
			require.Equal(t, tc.wantKey, directKey)

			turnCtx := contextWithActiveTurnAPIKeyID(t.Context(), tc.messages)
			fromCtxKey, fromCtxOK := aibridge.DelegatedAPIKeyIDFromContext(turnCtx)
			require.Equal(t, tc.wantOK, fromCtxOK)
			require.Equal(t, tc.wantKey, fromCtxKey)
		})
	}
}
// TestActiveTurnContextUsesPromptMessages stores a chat history in a test
// database, loads it with GetChatMessagesForPromptByChatID, and checks that
// the delegated API key ID placed on the context is the one from the third
// inserted message (currentKey).
func TestActiveTurnContextUsesPromptMessages(t *testing.T) {
t.Parallel()
db, _ := dbtestutil.NewDB(t)
ctx := t.Context()
user := dbgen.User(t, db, database.User{})
org := dbgen.Organization(t, db, database.Organization{})
model := dbgen.ChatModelConfig(t, db, database.ChatModelConfig{})
chat := dbgen.Chat(t, db, database.Chat{OrganizationID: org.ID, OwnerID: user.ID, LastModelConfigID: model.ID})
oldKey, _ := dbgen.APIKey(t, db, database.APIKey{UserID: user.ID})
currentKey, _ := dbgen.APIKey(t, db, database.APIKey{UserID: user.ID})
modelOnlyKey, _ := dbgen.APIKey(t, db, database.APIKey{UserID: user.ID})
// First message: user role, visible to both, tagged with oldKey.
dbgen.ChatMessage(t, db, database.ChatMessage{
ChatID: chat.ID,
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
ModelConfigID: uuid.NullUUID{UUID: model.ID, Valid: true},
Role: database.ChatMessageRoleUser,
Visibility: database.ChatMessageVisibilityBoth,
APIKeyID: sqlNullString(oldKey.ID),
})
// Second message: compressed, model-only system message with no API key.
dbgen.ChatMessage(t, db, database.ChatMessage{
ChatID: chat.ID,
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
ModelConfigID: uuid.NullUUID{UUID: model.ID, Valid: true},
Role: database.ChatMessageRoleSystem,
Visibility: database.ChatMessageVisibilityModel,
Compressed: true,
})
// Third message: user role, visible to both, tagged with currentKey.
dbgen.ChatMessage(t, db, database.ChatMessage{
ChatID: chat.ID,
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
ModelConfigID: uuid.NullUUID{UUID: model.ID, Valid: true},
Role: database.ChatMessageRoleUser,
Visibility: database.ChatMessageVisibilityBoth,
APIKeyID: sqlNullString(currentKey.ID),
})
// Fourth message: model-only user message tagged with modelOnlyKey; the
// final assertion expects this key not to be chosen.
dbgen.ChatMessage(t, db, database.ChatMessage{
ChatID: chat.ID,
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
ModelConfigID: uuid.NullUUID{UUID: model.ID, Valid: true},
Role: database.ChatMessageRoleUser,
Visibility: database.ChatMessageVisibilityModel,
APIKeyID: sqlNullString(modelOnlyKey.ID),
})
messages, err := db.GetChatMessagesForPromptByChatID(ctx, chat.ID)
require.NoError(t, err)
ctx = contextWithActiveTurnAPIKeyID(ctx, messages)
gotKey, ok := aibridge.DelegatedAPIKeyIDFromContext(ctx)
require.True(t, ok)
require.Equal(t, currentKey.ID, gotKey)
}
// sqlNullString converts a string into a sql.NullString. The empty string
// maps to the invalid zero value; any other string is marked valid.
func sqlNullString(value string) sql.NullString {
	if value == "" {
		return sql.NullString{}
	}
	return sql.NullString{String: value, Valid: true}
}
// TestAIBridgeRoutingFailClosed checks that building an AI Gateway routed
// model fails with a descriptive error when a prerequisite is missing: the
// transport factory, a working factory, the provider name, the active API key
// ID, or a concrete AI provider.
func TestAIBridgeRoutingFailClosed(t *testing.T) {
	t.Parallel()

	providerID := uuid.New()
	chat := database.Chat{ID: uuid.New(), OwnerID: uuid.New()}
	openAIProvider := aibridgeTestAIProvider(providerID, "primary-openai", database.AiProviderTypeOpenai)

	// buildErr attempts a "gpt-4" model build and returns only the error.
	buildErr := func(t *testing.T, srv *Server, route resolvedModelRoute, opts modelBuildOptions) error {
		t.Helper()
		_, err := srv.newModel(t.Context(), aibridgeTestRequest(chat, "gpt-4"), route, opts)
		return err
	}
	// withFreshKey returns build options carrying a newly generated key ID.
	withFreshKey := func() modelBuildOptions {
		return modelBuildOptions{ActiveAPIKeyID: uuid.NewString()}
	}

	t.Run("NilFactory", func(t *testing.T) {
		t.Parallel()

		srv := &Server{aiGatewayRoutingEnabled: true}
		err := buildErr(t, srv, aibridgeTestRoute(openAIProvider), withFreshKey())
		require.ErrorContains(t, err, "transport factory")
	})

	t.Run("FactoryError", func(t *testing.T) {
		t.Parallel()

		failing := &aibridgeTestFactory{err: xerrors.New("boom")}
		srv := &Server{
			aiGatewayRoutingEnabled:  true,
			aibridgeTransportFactory: aibridgeTestFactoryPointer(failing),
		}
		err := buildErr(t, srv, aibridgeTestRoute(openAIProvider), withFreshKey())
		require.ErrorContains(t, err, "boom")
	})

	t.Run("MissingProviderName", func(t *testing.T) {
		t.Parallel()

		srv := &Server{aiGatewayRoutingEnabled: true}
		unnamed := aibridgeTestAIProvider(providerID, "", database.AiProviderTypeOpenai)
		err := buildErr(t, srv, aibridgeTestRoute(unnamed), withFreshKey())
		require.ErrorContains(t, err, "AI provider name")
	})

	t.Run("MissingAPIKeyID", func(t *testing.T) {
		t.Parallel()

		guarded := &aibridgeTestFactory{rt: roundTripFunc(func(*http.Request) (*http.Response, error) {
			t.Fatal("transport must not be used without an API key ID")
			return nil, xerrors.New("unreachable")
		})}
		srv := &Server{
			aiGatewayRoutingEnabled:  true,
			aibridgeTransportFactory: aibridgeTestFactoryPointer(guarded),
		}
		err := buildErr(t, srv, aibridgeTestRoute(openAIProvider), modelBuildOptions{})
		require.ErrorContains(t, err, "active turn API key ID")
	})

	t.Run("StaticModel", func(t *testing.T) {
		t.Parallel()

		srv := &Server{aiGatewayRoutingEnabled: true}
		emptyRoute := newAIGatewayModelRoute(database.AIProvider{}, "", aiGatewayProviderAuth{})
		err := buildErr(t, srv, emptyRoute, withFreshKey())
		require.ErrorContains(t, err, "concrete AI provider")
	})
}
func TestDirectModelBuildDoesNotRequireActiveAPIKeyID(t *testing.T) {
t.Parallel()
server := &Server{}
model, err := server.newModel(t.Context(), modelClientRequest{
Chat: database.Chat{ID: uuid.New(), OwnerID: uuid.New()},
ModelName: "gpt-4",
UserAgent: chatprovider.UserAgent(),
}, newDirectModelRoute("openai", chatprovider.ProviderAPIKeys{OpenAI: "sk-test"}), modelBuildOptions{})
require.NoError(t, err)
require.NotNil(t, model)
}
// TestAIBridgeComputerUseModelUsesRoute checks that resolving the computer use
// model through an AI Gateway route asks the transport factory for the routed
// provider with the agents source, and never sends a request while doing so.
func TestAIBridgeComputerUseModelUsesRoute(t *testing.T) {
	t.Parallel()

	providerID := uuid.New()
	activeKeyID := uuid.NewString()

	// Any request during model construction fails the test.
	silentFactory := &aibridgeTestFactory{rt: roundTripFunc(func(*http.Request) (*http.Response, error) {
		t.Fatal("computer use model construction must not send a request")
		return nil, xerrors.New("unreachable")
	})}
	chat := database.Chat{ID: uuid.New(), OwnerID: uuid.New()}
	srv := &Server{
		aiGatewayRoutingEnabled:  true,
		aibridgeTransportFactory: aibridgeTestFactoryPointer(silentFactory),
	}

	computerUseProvider := chattool.ComputerUseProviderOpenAI
	defaultProvider, defaultModel, found := chattool.DefaultComputerUseModel(computerUseProvider)
	require.True(t, found)

	// The context carries a different key ID than the build options do.
	ctx := aibridge.WithDelegatedAPIKeyID(t.Context(), "context-key-must-be-ignored")
	route := aibridgeTestRoute(aibridgeTestAIProvider(providerID, "primary-openai", database.AiProviderTypeOpenai))
	built, debugEnabled, gotProvider, gotModel, err := srv.resolveComputerUseModel(
		ctx,
		chat,
		route,
		computerUseProvider,
		defaultProvider,
		defaultModel,
		modelBuildOptions{ActiveAPIKeyID: activeKeyID},
	)
	require.NoError(t, err)
	require.NotNil(t, built)
	require.False(t, debugEnabled)
	require.Equal(t, chattool.ComputerUseProviderOpenAI, gotProvider)
	require.Equal(t, defaultModel, gotModel)
	require.Equal(t, "primary-openai", silentFactory.providerName)
	require.Equal(t, aibridge.SourceAgents, silentFactory.source)
}
// TestAIBridgeDelegatedContextPropagation checks that the request reaching the
// aibridge transport carries the API key ID from the model build options, not
// the one already present on the context used to build the model.
func TestAIBridgeDelegatedContextPropagation(t *testing.T) {
	t.Parallel()

	const cannedResponse = `{"id":"resp_test","object":"response","created_at":0,"status":"completed","model":"gpt-4","output":[{"id":"msg_test","type":"message","role":"assistant","content":[{"type":"output_text","text":"hello"}]}],"usage":{"input_tokens":1,"output_tokens":1,"total_tokens":2}}`

	// observed records what the stub transport saw on the outgoing request.
	type observed struct {
		apiKeyID string
		ok       bool
		path     string
	}

	providerID := uuid.New()
	activeKeyID := uuid.NewString()
	observations := make(chan observed, 1)

	stub := &aibridgeTestFactory{rt: roundTripFunc(func(req *http.Request) (*http.Response, error) {
		keyID, present := aibridge.DelegatedAPIKeyIDFromContext(req.Context())
		observations <- observed{
			apiKeyID: keyID,
			ok:       present,
			path:     req.URL.Path,
		}
		resp := &http.Response{
			StatusCode: http.StatusOK,
			Header:     http.Header{"Content-Type": []string{"application/json"}},
			Body:       io.NopCloser(strings.NewReader(cannedResponse)),
			Request:    req,
		}
		return resp, nil
	})}

	chat := database.Chat{ID: uuid.New(), OwnerID: uuid.New()}
	srv := &Server{
		aiGatewayRoutingEnabled:  true,
		aibridgeTransportFactory: aibridgeTestFactoryPointer(stub),
	}

	buildCtx := aibridge.WithDelegatedAPIKeyID(t.Context(), "context-key-must-be-ignored")
	route := aibridgeTestRoute(aibridgeTestAIProvider(providerID, "primary-openai", database.AiProviderTypeOpenai))
	built, err := srv.newModel(
		buildCtx,
		aibridgeTestRequest(chat, "gpt-4"),
		route,
		modelBuildOptions{ActiveAPIKeyID: activeKeyID, RecordHTTP: true},
	)
	require.NoError(t, err)

	prompt := []fantasy.Message{{
		Role:    fantasy.MessageRoleUser,
		Content: []fantasy.MessagePart{fantasy.TextPart{Text: "hello"}},
	}}
	_, err = built.Generate(t.Context(), fantasy.Call{Prompt: prompt})
	require.NoError(t, err)

	got := <-observations
	require.Equal(t, "primary-openai", stub.providerName)
	require.Equal(t, aibridge.SourceAgents, stub.source)
	require.True(t, got.ok)
	require.Equal(t, "/v1/responses", got.path)
	require.Equal(t, activeKeyID, got.apiKeyID)
}
+66 -77
View File
@@ -4,7 +4,6 @@ import (
"context"
"errors"
"fmt"
"net/http"
"slices"
"strings"
"time"
@@ -69,9 +68,39 @@ var preferredTitleModels = []struct {
// shortTextCandidate describes one model that may be tried for short-text
// generation.
type shortTextCandidate struct {
// provider is the provider name of the candidate model.
provider string
// model is the model name of the candidate.
model string
// route is the resolved model route associated with the candidate.
route resolvedModelRoute
// lm is the language model instance built for the candidate.
lm fantasy.LanguageModel
}
// preferredShortTextCandidates builds a candidate for every entry in
// preferredTitleModels that can be constructed with the given keys. Entries
// whose model construction fails are skipped.
//
// It returns nil when AI Gateway routing should be used. The returned slice
// is allocated with one spare slot of capacity beyond the preferred list.
func (p *Server) preferredShortTextCandidates(
	chat database.Chat,
	keys chatprovider.ProviderAPIKeys,
) []shortTextCandidate {
	if p.shouldUseAIGatewayRouting() {
		return nil
	}

	var (
		agent   = chatprovider.UserAgent()
		headers = chatprovider.CoderHeaders(chat)
		out     = make([]shortTextCandidate, 0, len(preferredTitleModels)+1)
	)
	for _, preferred := range preferredTitleModels {
		lm, err := chatprovider.ModelFromConfig(
			preferred.provider, preferred.model, keys, agent,
			headers,
			nil,
		)
		if err != nil {
			// This preferred model cannot be built with these keys; skip it.
			continue
		}
		out = append(out, shortTextCandidate{
			provider: preferred.provider,
			model:    preferred.model,
			route:    newDirectModelRoute(preferred.provider, keys),
			lm:       lm,
		})
	}
	return out
}
func selectPreferredConfiguredShortTextModelConfig(
configs []database.ChatModelConfig,
) (database.ChatModelConfig, bool) {
@@ -121,7 +150,9 @@ func (p *Server) maybeGenerateChatTitle(
fallbackProvider string,
fallbackModelName string,
fallbackModel fantasy.LanguageModel,
fallbackRoute resolvedModelRoute,
keys chatprovider.ProviderAPIKeys,
modelOpts modelBuildOptions,
generatedTitle *generatedChatTitle,
logger slog.Logger,
debugSvc *chatdebug.Service,
@@ -135,10 +166,11 @@ func (p *Server) maybeGenerateChatTitle(
titleCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
overrideConfig, overrideModel, overrideKeys, overrideSet, overrideErr := p.resolveTitleGenerationModelOverride(
overrideConfig, overrideModel, _, overrideRoute, overrideSet, overrideErr := p.resolveTitleGenerationModelOverride(
titleCtx,
chat,
keys,
modelOpts,
)
if overrideErr != nil {
if overrideSet {
@@ -161,29 +193,15 @@ func (p *Server) maybeGenerateChatTitle(
candidates = []shortTextCandidate{{
provider: overrideConfig.Provider,
model: overrideConfig.Model,
route: overrideRoute,
lm: overrideModel,
}}
} else {
// Build candidate list: preferred lightweight models first,
// then the user's chat model as last resort.
candidates = make([]shortTextCandidate, 0, len(preferredTitleModels)+1)
for _, c := range preferredTitleModels {
m, err := chatprovider.ModelFromConfig(
c.provider, c.model, keys, chatprovider.UserAgent(),
chatprovider.CoderHeaders(chat),
nil,
)
if err == nil {
candidates = append(candidates, shortTextCandidate{
provider: c.provider,
model: c.model,
lm: m,
})
}
}
candidates = p.preferredShortTextCandidates(chat, keys)
candidates = append(candidates, shortTextCandidate{
provider: fallbackProvider,
model: fallbackModelName,
route: fallbackRoute,
lm: fallbackModel,
})
}
@@ -213,17 +231,13 @@ func (p *Server) maybeGenerateChatTitle(
candidateCtx := titleCtx
candidateModel := candidate.lm
finishDebugRun := func(error) {}
candidateKeys := keys
if overrideSet {
candidateKeys = overrideKeys
}
if debugEnabled {
candidateCtx, candidateModel, finishDebugRun = prepareQuickgenDebugCandidate(
candidateCtx, candidateModel, finishDebugRun = p.prepareQuickgenDebugCandidate(
titleCtx,
chat,
candidateKeys,
debugSvc,
candidate,
modelOpts,
chatdebug.KindTitleGeneration,
triggerMessageID,
historyTipMessageID,
@@ -293,32 +307,26 @@ func (p *Server) maybeGenerateChatTitle(
}
}
func newQuickgenDebugModel(
func (p *Server) newQuickgenDebugModel(
ctx context.Context,
chat database.Chat,
keys chatprovider.ProviderAPIKeys,
debugSvc *chatdebug.Service,
provider string,
model string,
route resolvedModelRoute,
modelOpts modelBuildOptions,
) (fantasy.LanguageModel, error) {
httpClient := &http.Client{Transport: &chatdebug.RecordingTransport{}}
debugModel, err := chatprovider.ModelFromConfig(
provider,
model,
keys,
chatprovider.UserAgent(),
chatprovider.CoderHeaders(chat),
httpClient,
)
debugOpts := modelOpts
debugOpts.RecordHTTP = true
debugModel, err := p.newModel(ctx, modelClientRequest{
Chat: chat,
ModelName: model,
UserAgent: chatprovider.UserAgent(),
ExtraHeaders: chatprovider.CoderHeaders(chat),
}, route, debugOpts)
if err != nil {
return nil, err
}
if debugModel == nil {
return nil, xerrors.Errorf(
"create model for %s/%s returned nil",
provider,
model,
)
}
return chatdebug.WrapModel(debugModel, debugSvc, chatdebug.RecorderOptions{
ChatID: chat.ID,
@@ -328,12 +336,12 @@ func newQuickgenDebugModel(
}), nil
}
func prepareQuickgenDebugCandidate(
func (p *Server) prepareQuickgenDebugCandidate(
ctx context.Context,
chat database.Chat,
keys chatprovider.ProviderAPIKeys,
debugSvc *chatdebug.Service,
candidate shortTextCandidate,
modelOpts modelBuildOptions,
kind chatdebug.RunKind,
triggerMessageID int64,
historyTipMessageID int64,
@@ -345,12 +353,14 @@ func prepareQuickgenDebugCandidate(
return ctx, candidate.lm, finishDebugRun
}
debugModel, err := newQuickgenDebugModel(
debugModel, err := p.newQuickgenDebugModel(
ctx,
chat,
keys,
debugSvc,
candidate.provider,
candidate.model,
candidate.route,
modelOpts,
)
if err != nil {
logger.Warn(ctx, "failed to build short-text debug model",
@@ -393,18 +403,8 @@ func prepareQuickgenDebugCandidate(
return ctx, candidate.lm, finishDebugRun
}
runCtx := chatdebug.ContextWithRun(
ctx,
&chatdebug.RunContext{
RunID: run.ID,
ChatID: chat.ID,
TriggerMessageID: triggerMessageID,
HistoryTipMessageID: historyTipMessageID,
Kind: kind,
Provider: candidate.provider,
Model: candidate.model,
},
)
runContext := chatdebugRunContext(run)
runCtx := chatdebug.ContextWithRun(ctx, &runContext)
finishDebugRun = func(runErr error) {
if finalizeErr := debugSvc.FinalizeRun(ctx, chatdebug.FinalizeRunParams{
RunID: run.ID,
@@ -824,7 +824,7 @@ const turnStatusLabelPrompt = "You write compact chat status labels for a sideba
// message text. It follows the same candidate-selection strategy
// as title generation: try preferred lightweight models first, then
// fall back to the provided model. Returns "" on any failure.
func generateTurnStatusLabel(
func (p *Server) generateTurnStatusLabel(
ctx context.Context,
chat database.Chat,
status database.ChatStatus,
@@ -832,7 +832,9 @@ func generateTurnStatusLabel(
fallbackProvider string,
fallbackModelName string,
fallbackModel fantasy.LanguageModel,
fallbackRoute resolvedModelRoute,
keys chatprovider.ProviderAPIKeys,
modelOpts modelBuildOptions,
logger slog.Logger,
debugSvc *chatdebug.Service,
triggerMessageID int64,
@@ -848,24 +850,11 @@ func generateTurnStatusLabel(
"\nChat title: " + chat.Title +
"\n\nAgent's latest message:\n" + assistantText
candidates := make([]shortTextCandidate, 0, len(preferredTitleModels)+1)
for _, c := range preferredTitleModels {
m, err := chatprovider.ModelFromConfig(
c.provider, c.model, keys, chatprovider.UserAgent(),
chatprovider.CoderHeaders(chat),
nil,
)
if err == nil {
candidates = append(candidates, shortTextCandidate{
provider: c.provider,
model: c.model,
lm: m,
})
}
}
candidates := p.preferredShortTextCandidates(chat, keys)
candidates = append(candidates, shortTextCandidate{
provider: fallbackProvider,
model: fallbackModelName,
route: fallbackRoute,
lm: fallbackModel,
})
@@ -876,12 +865,12 @@ func generateTurnStatusLabel(
candidateModel := candidate.lm
finishDebugRun := func(error) {}
if debugEnabled {
candidateCtx, candidateModel, finishDebugRun = prepareQuickgenDebugCandidate(
candidateCtx, candidateModel, finishDebugRun = p.prepareQuickgenDebugCandidate(
labelCtx,
chat,
keys,
debugSvc,
candidate,
modelOpts,
chatdebug.KindQuickgen,
triggerMessageID,
historyTipMessageID,
+12
View File
@@ -359,6 +359,16 @@ func Test_renderManualTitlePrompt(t *testing.T) {
}
}
// TestPreferredShortTextCandidatesNilUnderAIGateway checks that a Server with
// aiGatewayRoutingEnabled set returns nil from preferredShortTextCandidates,
// even when the supplied provider keys contain an entry.
func TestPreferredShortTextCandidatesNilUnderAIGateway(t *testing.T) {
	t.Parallel()

	var (
		srv  = &Server{aiGatewayRoutingEnabled: true}
		chat = database.Chat{}
		keys = chatprovider.ProviderAPIKeys{
			ByProvider: map[string]string{"openai": "test-key"},
		}
	)

	got := srv.preferredShortTextCandidates(chat, keys)
	require.Nil(t, got)
}
func TestMaybeGenerateChatTitlePreservesUpdatedAt(t *testing.T) {
t.Parallel()
@@ -428,7 +438,9 @@ func TestMaybeGenerateChatTitlePreservesUpdatedAt(t *testing.T) {
"openai",
"test-model",
model,
resolvedModelRoute{},
chatprovider.ProviderAPIKeys{},
modelBuildOptions{},
generated,
logger,
nil,
+20 -1
View File
@@ -17,6 +17,7 @@ import (
"golang.org/x/xerrors"
"cdr.dev/slog/v3"
"github.com/coder/coder/v2/coderd/aibridge"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
coderdpubsub "github.com/coder/coder/v2/coderd/pubsub"
@@ -908,6 +909,14 @@ func (p *Server) resolveExploreToolSnapshot(
return inheritedMCPServerIDs, nil
}
// delegatedAPIKeyIDForSubagent returns the API key ID that
// aibridge.DelegatedAPIKeyIDFromContext reads from ctx.
//
// If ctx carries no such ID and shouldUseAIGatewayRouting reports true, an
// error is returned instead. Otherwise the value from the context lookup is
// passed through unchanged, including when the lookup found nothing.
func (p *Server) delegatedAPIKeyIDForSubagent(ctx context.Context) (string, error) {
	keyID, found := aibridge.DelegatedAPIKeyIDFromContext(ctx)
	switch {
	case found:
		return keyID, nil
	case p.shouldUseAIGatewayRouting():
		// Only consulted when the context lookup came back empty.
		return "", xerrors.New("AI Gateway routing requires the active turn API key ID for subagent messages")
	default:
		return keyID, nil
	}
}
func (p *Server) createChildSubagentChat(
ctx context.Context,
parent database.Chat,
@@ -950,6 +959,10 @@ func (p *Server) createChildSubagentChatWithOptions(
if modelConfigID == uuid.Nil {
return database.Chat{}, xerrors.New("model config is required")
}
childAPIKeyID, err := p.delegatedAPIKeyIDForSubagent(ctx)
if err != nil {
return database.Chat{}, err
}
childPlanMode := parent.PlanMode
if opts.planModeOverride != nil {
@@ -1086,7 +1099,7 @@ func (p *Server) createChildSubagentChatWithOptions(
database.ChatMessageVisibilityBoth,
modelConfigID,
chatprompt.CurrentContentVersion,
).withCreatedBy(parent.OwnerID))
).withCreatedBy(parent.OwnerID).withAPIKeyID(childAPIKeyID))
if _, err := tx.InsertChatMessages(ctx, userParams); err != nil {
return xerrors.Errorf("insert initial child user message: %w", err)
}
@@ -1236,10 +1249,16 @@ func (p *Server) sendSubagentMessage(
return database.Chat{}, xerrors.Errorf("get target chat: %w", err)
}
apiKeyID, err := p.delegatedAPIKeyIDForSubagent(ctx)
if err != nil {
return database.Chat{}, err
}
sendResult, err := p.SendMessage(ctx, SendMessageOptions{
ChatID: targetChatID,
CreatedBy: targetChat.OwnerID,
Content: []codersdk.ChatMessagePart{codersdk.ChatMessageText(message)},
APIKeyID: apiKeyID,
BusyBehavior: busyBehavior,
})
if err != nil {
+179 -1
View File
@@ -16,6 +16,7 @@ import (
"cdr.dev/slog/v3"
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/coder/v2/coderd/aibridge"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/database/dbgen"
@@ -231,6 +232,163 @@ func insertInternalAIProvider(
})
}
// TestCreateChildSubagentChatPropagatesActiveTurnAPIKeyID checks that the API
// key ID placed on the context with aibridge.WithDelegatedAPIKeyID is stored on
// the first user message of the chat returned by createChildSubagentChat.
func TestCreateChildSubagentChatPropagatesActiveTurnAPIKeyID(t *testing.T) {
	t.Parallel()

	ctx := testutil.Context(t, testutil.WaitShort)
	db, _ := dbtestutil.NewDB(t)

	// Seed an owner who is a member of the organization owning the parent chat.
	owner := dbgen.User(t, db, database.User{})
	organization := dbgen.Organization(t, db, database.Organization{})
	dbgen.OrganizationMember(t, db, database.OrganizationMember{
		UserID:         owner.ID,
		OrganizationID: organization.ID,
	})
	modelConfig := dbgen.ChatModelConfig(t, db, database.ChatModelConfig{})
	parentChat := dbgen.Chat(t, db, database.Chat{
		OrganizationID:    organization.ID,
		OwnerID:           owner.ID,
		LastModelConfigID: modelConfig.ID,
	})

	// Attach the key ID to the context before creating the child chat.
	activeKey, _ := dbgen.APIKey(t, db, database.APIKey{UserID: owner.ID})
	ctx = aibridge.WithDelegatedAPIKeyID(ctx, activeKey.ID)

	srv := &Server{
		db:     db,
		logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}),
	}
	childChat, err := srv.createChildSubagentChat(ctx, parentChat, "inspect the workspace", "")
	require.NoError(t, err)

	rows, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
		ChatID: childChat.ID,
	})
	require.NoError(t, err)

	// Pick the first user-role message in the returned order.
	var firstUserMessage database.ChatMessage
	for i := range rows {
		if rows[i].Role != database.ChatMessageRoleUser {
			continue
		}
		firstUserMessage = rows[i]
		break
	}

	require.NotZero(t, firstUserMessage.ID)
	require.True(t, firstUserMessage.APIKeyID.Valid)
	require.Equal(t, activeKey.ID, firstUserMessage.APIKeyID.String)
}
// TestSendSubagentMessagePropagatesActiveTurnAPIKeyID checks that
// sendSubagentMessage stores the API key ID carried on the context (set with
// aibridge.WithDelegatedAPIKeyID) on the newest user message of the child chat.
func TestSendSubagentMessagePropagatesActiveTurnAPIKeyID(t *testing.T) {
	t.Parallel()

	db, ps := dbtestutil.NewDB(t)
	srv := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{})

	ctx := chatdTestContext(t)
	owner, organization, modelConfig := seedInternalChatDeps(t, db)
	activeKey, _ := dbgen.APIKey(t, db, database.APIKey{UserID: owner.ID})

	parentChat, err := srv.CreateChat(ctx, CreateOptions{
		OrganizationID:     organization.ID,
		OwnerID:            owner.ID,
		Title:              "parent-send-subagent-key",
		ModelConfigID:      modelConfig.ID,
		InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
		APIKeyID:           activeKey.ID,
	})
	require.NoError(t, err)

	// The child is created without an API key ID of its own.
	parentRef := uuid.NullUUID{UUID: parentChat.ID, Valid: true}
	childChat, err := srv.CreateChat(ctx, CreateOptions{
		OrganizationID: organization.ID,
		OwnerID:        owner.ID,
		ParentChatID:   parentRef,
		RootChatID:     parentRef,
		Title:          "child-send-subagent-key",
		ModelConfigID:  modelConfig.ID,
		InitialUserContent: []codersdk.ChatMessagePart{
			codersdk.ChatMessageText("do work"),
		},
	})
	require.NoError(t, err)
	setChatStatus(ctx, t, db, childChat.ID, database.ChatStatusWaiting, "")

	// From here on, the context carries the key ID.
	turnCtx := aibridge.WithDelegatedAPIKeyID(ctx, activeKey.ID)
	_, err = srv.sendSubagentMessage(
		turnCtx,
		parentChat.ID,
		childChat.ID,
		"follow up",
		SendMessageBusyBehaviorInterrupt,
	)
	require.NoError(t, err)

	rows, err := db.GetChatMessagesByChatID(turnCtx, database.GetChatMessagesByChatIDParams{
		ChatID: childChat.ID,
	})
	require.NoError(t, err)

	// Keep the user-role message with the highest ID.
	var newestUserMessage database.ChatMessage
	for i := range rows {
		candidate := rows[i]
		if candidate.Role != database.ChatMessageRoleUser || candidate.ID <= newestUserMessage.ID {
			continue
		}
		newestUserMessage = candidate
	}

	require.NotZero(t, newestUserMessage.ID)
	require.True(t, newestUserMessage.APIKeyID.Valid)
	require.Equal(t, activeKey.ID, newestUserMessage.APIKeyID.String)
}
// TestCreateChildSubagentChatRequiresActiveTurnAPIKeyIDForAIGateway checks that
// createChildSubagentChat returns the "requires the active turn API key ID"
// error when the Server has aiGatewayRoutingEnabled set and the context was
// never given a key ID via aibridge.WithDelegatedAPIKeyID.
func TestCreateChildSubagentChatRequiresActiveTurnAPIKeyIDForAIGateway(t *testing.T) {
	t.Parallel()

	ctx := testutil.Context(t, testutil.WaitShort)
	db, _ := dbtestutil.NewDB(t)

	owner := dbgen.User(t, db, database.User{})
	organization := dbgen.Organization(t, db, database.Organization{})
	dbgen.OrganizationMember(t, db, database.OrganizationMember{
		UserID:         owner.ID,
		OrganizationID: organization.ID,
	})
	modelConfig := dbgen.ChatModelConfig(t, db, database.ChatModelConfig{})
	parentChat := dbgen.Chat(t, db, database.Chat{
		OrganizationID:    organization.ID,
		OwnerID:           owner.ID,
		LastModelConfigID: modelConfig.ID,
	})

	srv := &Server{
		db:                      db,
		logger:                  slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}),
		aiGatewayRoutingEnabled: true,
	}

	// No key ID is attached to ctx, so the call is expected to fail.
	_, createErr := srv.createChildSubagentChat(ctx, parentChat, "inspect the workspace", "")
	require.ErrorContains(t, createErr, "AI Gateway routing requires the active turn API key ID for subagent messages")
}
// TestSendSubagentMessageRequiresActiveTurnAPIKeyIDForAIGateway checks that
// sendSubagentMessage returns the "requires the active turn API key ID" error
// when aiGatewayRoutingEnabled is set on the server and the context was never
// given a key ID via aibridge.WithDelegatedAPIKeyID.
func TestSendSubagentMessageRequiresActiveTurnAPIKeyIDForAIGateway(t *testing.T) {
	t.Parallel()

	db, ps := dbtestutil.NewDB(t)
	srv := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{})
	srv.aiGatewayRoutingEnabled = true

	ctx := chatdTestContext(t)
	owner, organization, modelConfig := seedInternalChatDeps(t, db)

	// Both chats are created without an APIKeyID.
	parentChat, err := srv.CreateChat(ctx, CreateOptions{
		OrganizationID:     organization.ID,
		OwnerID:            owner.ID,
		Title:              "parent-send-subagent-missing-key",
		ModelConfigID:      modelConfig.ID,
		InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
	})
	require.NoError(t, err)

	parentRef := uuid.NullUUID{UUID: parentChat.ID, Valid: true}
	childChat, err := srv.CreateChat(ctx, CreateOptions{
		OrganizationID: organization.ID,
		OwnerID:        owner.ID,
		ParentChatID:   parentRef,
		RootChatID:     parentRef,
		Title:          "child-send-subagent-missing-key",
		ModelConfigID:  modelConfig.ID,
		InitialUserContent: []codersdk.ChatMessagePart{
			codersdk.ChatMessageText("do work"),
		},
	})
	require.NoError(t, err)
	setChatStatus(ctx, t, db, childChat.ID, database.ChatStatusWaiting, "")

	// ctx carries no key ID, so the send is expected to fail.
	_, sendErr := srv.sendSubagentMessage(
		ctx,
		parentChat.ID,
		childChat.ID,
		"follow up",
		SendMessageBusyBehaviorInterrupt,
	)
	require.ErrorContains(t, sendErr, "AI Gateway routing requires the active turn API key ID for subagent messages")
}
func TestResolveUserProviderAPIKeys_AIProvider(t *testing.T) {
t.Parallel()
@@ -351,7 +509,7 @@ func TestResolveChatModel_AIProviderDisabled(t *testing.T) {
LastModelConfigID: modelConfig.ID,
})
model, config, keys, debugEnabled, resolvedProvider, resolvedModel, err := server.resolveChatModel(ctx, chat)
model, config, keys, _, debugEnabled, resolvedProvider, resolvedModel, err := server.resolveChatModel(ctx, chat, modelBuildOptions{})
require.ErrorContains(t, err, "is disabled")
require.Nil(t, model)
require.Equal(t, database.ChatModelConfig{}, config)
@@ -3187,6 +3345,26 @@ func TestAwaitSubagentCompletion(t *testing.T) {
parent, child := createParentChildChats(ctx, t, server, user, org, model)
// signalWake from CreateChat triggers background processing. Wait
// for those runs to finish, then reset both chats so this test owns
// the state transition observed by the poll loop.
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
parentChat, err := db.GetChatByID(ctx, parent.ID)
if err != nil {
return false
}
childChat, err := db.GetChatByID(ctx, child.ID)
if err != nil {
return false
}
return parentChat.Status != database.ChatStatusPending &&
parentChat.Status != database.ChatStatusRunning &&
childChat.Status != database.ChatStatusPending &&
childChat.Status != database.ChatStatusRunning
}, testutil.IntervalFast)
setChatStatus(ctx, t, db, parent.ID, database.ChatStatusRunning, "")
setChatStatus(ctx, t, db, child.ID, database.ChatStatusRunning, "")
// Set the trap BEFORE starting the goroutine so we
// deterministically catch the ticker creation.
tickTrap := mClock.Trap().NewTicker("chatd", "subagent_poll")
+21 -46
View File
@@ -31,28 +31,18 @@ func readTitleGenerationModelOverride(
}
// resolveTitleGenerationModelOverride resolves the deployment-wide title
// generation model override. It returns four values:
//
// - modelConfig and model: populated only on success.
// - overrideSet: true when the admin configured a non-empty override,
// regardless of whether resolution succeeded. Callers MUST always check
// err first; overrideSet alone does not imply the model is usable.
// - err: non-nil when resolution failed. DB read failure returns
// (zero, nil, false, err). With overrideSet=true, the override is
// configured but unusable (deleted model, missing credentials, etc.) and
// callers should treat this as a hard failure for explicit-override
// semantics, not a soft fallback.
//
// When the override is unset or stored as malformed, the function returns
// (zero, nil, false, nil) so callers can fall back to default behavior.
// generation model override. overrideSet is true when an override was
// configured; in that case any returned error is a hard failure. When
// overrideSet is false, callers may fall back to the default title model.
func (p *Server) resolveTitleGenerationModelOverride(
ctx context.Context,
chat database.Chat,
keys chatprovider.ProviderAPIKeys,
) (database.ChatModelConfig, fantasy.LanguageModel, chatprovider.ProviderAPIKeys, bool, error) {
modelOpts modelBuildOptions,
) (database.ChatModelConfig, fantasy.LanguageModel, chatprovider.ProviderAPIKeys, resolvedModelRoute, bool, error) {
raw, err := readTitleGenerationModelOverride(ctx, p.db)
if err != nil {
return database.ChatModelConfig{}, nil, chatprovider.ProviderAPIKeys{}, false, xerrors.Errorf(
return database.ChatModelConfig{}, nil, chatprovider.ProviderAPIKeys{}, resolvedModelRoute{}, false, xerrors.Errorf(
"read title generation model override: %w",
err,
)
@@ -84,43 +74,28 @@ func (p *Server) resolveTitleGenerationModelOverride(
modelOverrideFailureModeHard,
)
if err != nil {
return database.ChatModelConfig{}, nil, chatprovider.ProviderAPIKeys{}, overrideSet, err
return database.ChatModelConfig{}, nil, chatprovider.ProviderAPIKeys{}, resolvedModelRoute{}, overrideSet, err
}
if !overrideSet {
return database.ChatModelConfig{}, nil, keys, false, nil
return database.ChatModelConfig{}, nil, keys, resolvedModelRoute{}, false, nil
}
providerHint := modelConfig.Provider
if modelConfig.AIProviderID.Valid {
//nolint:gocritic // Title overrides need chatd-scoped provider reads for user-owned chats.
provider, err := p.db.GetAIProviderByID(dbauthz.AsChatd(ctx), modelConfig.AIProviderID.UUID)
if err != nil {
return database.ChatModelConfig{}, nil, chatprovider.ProviderAPIKeys{}, true, xerrors.Errorf("get AI provider for title generation override: %w", err)
}
if !provider.Enabled {
return database.ChatModelConfig{}, nil, chatprovider.ProviderAPIKeys{}, true, xerrors.Errorf("AI provider %s is disabled", modelConfig.AIProviderID.UUID)
}
providerHint = string(provider.Type)
}
model, err := chatprovider.ModelFromConfig(
providerHint,
modelConfig.Model,
overrideProviderKeys,
chatprovider.UserAgent(),
chatprovider.CoderHeaders(chat),
nil,
)
//nolint:gocritic // Title overrides need chatd-scoped provider reads for user-owned chats.
route, err := p.resolveModelRouteForConfig(dbauthz.AsChatd(ctx), chat.OwnerID, modelConfig, overrideProviderKeys)
if err != nil {
return database.ChatModelConfig{}, nil, chatprovider.ProviderAPIKeys{}, true, xerrors.Errorf(
return database.ChatModelConfig{}, nil, chatprovider.ProviderAPIKeys{}, resolvedModelRoute{}, true, err
}
model, err := p.newModel(ctx, modelClientRequest{
Chat: chat,
ModelName: modelConfig.Model,
UserAgent: chatprovider.UserAgent(),
ExtraHeaders: chatprovider.CoderHeaders(chat),
}, route, modelOpts)
if err != nil {
return database.ChatModelConfig{}, nil, chatprovider.ProviderAPIKeys{}, resolvedModelRoute{}, true, xerrors.Errorf(
"create title generation model override: %w",
err,
)
}
if model == nil {
return database.ChatModelConfig{}, nil, chatprovider.ProviderAPIKeys{}, true, xerrors.Errorf(
"create title generation model override returned nil",
)
}
return modelConfig, model, overrideProviderKeys, true, nil
return modelConfig, model, route.directProviderKeys(), route, true, nil
}
+21 -1
View File
@@ -65,7 +65,9 @@ func TestMaybeGenerateChatTitle_TitleGenerationOverrideUnset(t *testing.T) {
"openai",
"fallback-chat-model",
fallbackModel,
resolvedModelRoute{},
keys,
modelBuildOptions{},
generated,
logger,
nil,
@@ -112,7 +114,9 @@ func TestMaybeGenerateChatTitle_TitleGenerationOverrideUnset(t *testing.T) {
"openai",
"fallback-chat-model",
fallbackModel,
resolvedModelRoute{},
chatprovider.ProviderAPIKeys{},
modelBuildOptions{},
generated,
logger,
nil,
@@ -160,7 +164,9 @@ func TestMaybeGenerateChatTitle_TitleGenerationOverrideReadDBError(t *testing.T)
"openai",
"fallback-chat-model",
fallbackModel,
resolvedModelRoute{},
chatprovider.ProviderAPIKeys{},
modelBuildOptions{},
generated,
logger,
nil,
@@ -207,7 +213,9 @@ func TestMaybeGenerateChatTitle_TitleGenerationOverrideMalformedFallsThrough(t *
"openai",
"fallback-chat-model",
fallbackModel,
resolvedModelRoute{},
chatprovider.ProviderAPIKeys{},
modelBuildOptions{},
generated,
logger,
nil,
@@ -257,7 +265,7 @@ func TestMaybeGenerateChatTitle_TitleGenerationOverrideSetUsable(t *testing.T) {
db.EXPECT().GetAIProviderKeysByProviderID(gomock.Any(), providerID).Return([]database.AIProviderKey{{
ProviderID: providerID,
APIKey: "test-key",
}}, nil)
}}, nil).Times(2)
db.EXPECT().UpdateChatTitleByID(gomock.Any(), database.UpdateChatTitleByIDParams{
ID: chat.ID,
Title: wantTitle,
@@ -272,7 +280,9 @@ func TestMaybeGenerateChatTitle_TitleGenerationOverrideSetUsable(t *testing.T) {
"openai",
"fallback-chat-model",
fallbackModel,
resolvedModelRoute{},
chatprovider.ProviderAPIKeys{},
modelBuildOptions{},
generated,
logger,
nil,
@@ -312,7 +322,9 @@ func TestMaybeGenerateChatTitle_TitleGenerationOverrideSetUnusableSkips(t *testi
"openai",
"fallback-chat-model",
fallbackModel,
resolvedModelRoute{},
chatprovider.ProviderAPIKeys{},
modelBuildOptions{},
generated,
logger,
nil,
@@ -360,7 +372,9 @@ func TestMaybeGenerateChatTitle_TitleGenerationOverrideCallFailureSkipsFallback(
"openai",
"fallback-chat-model",
fallbackModel,
resolvedModelRoute{},
keys,
modelBuildOptions{},
generated,
logger,
nil,
@@ -398,6 +412,7 @@ func TestResolveManualTitleModel_TitleGenerationOverrideUnset(t *testing.T) {
db,
chat,
chatprovider.ProviderAPIKeys{ByProvider: map[string]string{"openai": "test-key"}},
modelBuildOptions{},
)
require.NoError(t, err)
require.NotNil(t, model)
@@ -447,6 +462,7 @@ func TestResolveManualTitleModel_TitleGenerationOverrideUnsetAIProvider(t *testi
db,
chat,
chatprovider.ProviderAPIKeys{},
modelBuildOptions{},
)
require.NoError(t, err)
require.NotNil(t, model)
@@ -481,6 +497,7 @@ func TestResolveManualTitleModel_TitleGenerationOverrideReadDBError(t *testing.T
db,
chat,
chatprovider.ProviderAPIKeys{ByProvider: map[string]string{"openai": "test-key"}},
modelBuildOptions{},
)
require.NoError(t, err)
require.NotNil(t, model)
@@ -508,6 +525,7 @@ func TestResolveManualTitleModel_TitleGenerationOverrideSetUsable(t *testing.T)
db,
chat,
chatprovider.ProviderAPIKeys{ByProvider: map[string]string{"openai": "test-key"}},
modelBuildOptions{},
)
require.NoError(t, err)
require.NotNil(t, model)
@@ -535,6 +553,7 @@ func TestResolveManualTitleModel_TitleGenerationOverrideMissingCredentials(t *te
db,
chat,
chatprovider.ProviderAPIKeys{},
modelBuildOptions{},
)
require.Error(t, err)
require.ErrorContains(t, err, "resolve manual title generation model override")
@@ -562,6 +581,7 @@ func TestResolveManualTitleModel_TitleGenerationOverrideSetUnusable(t *testing.T
db,
chat,
chatprovider.ProviderAPIKeys{ByProvider: map[string]string{"openai": "test-key"}},
modelBuildOptions{},
)
require.Error(t, err)
require.ErrorContains(t, err, "resolve manual title generation model override")
+14 -2
View File
@@ -4058,6 +4058,17 @@ Write out the current server config as YAML to stdout.`,
Group: &deploymentGroupChat,
YAML: "debugLoggingEnabled",
},
{
Name: "Chat: AI Gateway Routing Enabled",
Description: "Route chat model requests through AI Gateway when both chat routing and AI Gateway are enabled. Otherwise, chat calls AI providers directly. Pending chats without API key metadata may need a retry or temporary direct routing.",
Flag: "chat-ai-gateway-routing-enabled",
Env: "CODER_CHAT_AI_GATEWAY_ROUTING_ENABLED",
Value: &c.AI.Chat.AIGatewayRoutingEnabled,
Default: "true",
Group: &deploymentGroupChat,
YAML: "aiGatewayRoutingEnabled",
Hidden: true,
},
// AI Bridge Options (deprecated in favor of AI Gateway options)
{
Name: "AI Bridge Enabled",
@@ -4714,8 +4725,9 @@ type AIBridgeProxyConfig struct {
}
// ChatConfig holds the chat-related deployment settings. Each field is a
// serpent value with a JSON tag; AIGatewayRoutingEnabled additionally carries
// a swaggerignore tag.
//
// NOTE(review): AcquireBatchSize and DebugLoggingEnabled are listed twice
// below with identical types and tags — this looks like pre- and post-change
// lines of a diff shown together; confirm only one copy of each belongs here.
type ChatConfig struct {
AcquireBatchSize serpent.Int64 `json:"acquire_batch_size" typescript:",notnull"`
DebugLoggingEnabled serpent.Bool `json:"debug_logging_enabled" typescript:",notnull"`
AcquireBatchSize serpent.Int64 `json:"acquire_batch_size" typescript:",notnull"`
DebugLoggingEnabled serpent.Bool `json:"debug_logging_enabled" typescript:",notnull"`
// AIGatewayRoutingEnabled is the destination of the hidden
// "chat-ai-gateway-routing-enabled" option.
AIGatewayRoutingEnabled serpent.Bool `json:"ai_gateway_routing_enabled" typescript:",notnull" swaggerignore:"true"`
}
type AIConfig struct {
+9
View File
@@ -916,6 +916,15 @@ func TestRetentionConfigParsing(t *testing.T) {
}
}
// TestChatAIGatewayRoutingEnabledDefault checks that, after SetDefaults is
// applied to the options of a zero-value DeploymentValues,
// AI.Chat.AIGatewayRoutingEnabled reports true.
func TestChatAIGatewayRoutingEnabledDefault(t *testing.T) {
	t.Parallel()

	var values codersdk.DeploymentValues
	options := values.Options()

	err := options.SetDefaults()
	require.NoError(t, err)

	enabled := values.AI.Chat.AIGatewayRoutingEnabled.Value()
	require.True(t, enabled)
}
func TestAIBudgetConfigParsing(t *testing.T) {
t.Parallel()
+25
View File
@@ -77,6 +77,9 @@ func TestChatStreamRelay(t *testing.T) {
Options: &coderdtest.Options{
Database: db,
Pubsub: pubsub,
DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) {
require.NoError(t, dv.AI.Chat.AIGatewayRoutingEnabled.Set("false"))
}),
},
LicenseOptions: &coderdenttest.LicenseOptions{
Features: license.Features{
@@ -89,6 +92,9 @@ func TestChatStreamRelay(t *testing.T) {
Options: &coderdtest.Options{
Database: db,
Pubsub: pubsub,
DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) {
require.NoError(t, dv.AI.Chat.AIGatewayRoutingEnabled.Set("false"))
}),
},
DontAddLicense: true,
DontAddFirstUser: true,
@@ -219,6 +225,9 @@ func TestChatStreamRelay(t *testing.T) {
Database: db,
Pubsub: pubsub,
TLSCertificates: certificates,
DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) {
require.NoError(t, dv.AI.Chat.AIGatewayRoutingEnabled.Set("false"))
}),
},
LicenseOptions: &coderdenttest.LicenseOptions{
Features: license.Features{
@@ -232,6 +241,9 @@ func TestChatStreamRelay(t *testing.T) {
Database: db,
Pubsub: pubsub,
TLSCertificates: certificates,
DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) {
require.NoError(t, dv.AI.Chat.AIGatewayRoutingEnabled.Set("false"))
}),
},
DontAddLicense: true,
DontAddFirstUser: true,
@@ -398,6 +410,9 @@ func TestChatStreamRelay(t *testing.T) {
Options: &coderdtest.Options{
Database: db,
Pubsub: pubsub,
DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) {
require.NoError(t, dv.AI.Chat.AIGatewayRoutingEnabled.Set("false"))
}),
},
LicenseOptions: &coderdenttest.LicenseOptions{
Features: license.Features{
@@ -410,6 +425,9 @@ func TestChatStreamRelay(t *testing.T) {
Options: &coderdtest.Options{
Database: db,
Pubsub: pubsub,
DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) {
require.NoError(t, dv.AI.Chat.AIGatewayRoutingEnabled.Set("false"))
}),
},
DontAddLicense: true,
DontAddFirstUser: true,
@@ -544,6 +562,7 @@ func TestChatStreamRelay(t *testing.T) {
db, pubsub := dbtestutil.NewDB(t)
hostPrefixValues := coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) {
require.NoError(t, dv.AI.Chat.AIGatewayRoutingEnabled.Set("false"))
dv.HTTPCookies.EnableHostPrefix = true
dv.HTTPCookies.Secure = true
})
@@ -696,6 +715,9 @@ func TestChatStreamRelay(t *testing.T) {
Options: &coderdtest.Options{
Database: db,
Pubsub: pubsub,
DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) {
require.NoError(t, dv.AI.Chat.AIGatewayRoutingEnabled.Set("false"))
}),
},
LicenseOptions: &coderdenttest.LicenseOptions{
Features: license.Features{
@@ -708,6 +730,9 @@ func TestChatStreamRelay(t *testing.T) {
Options: &coderdtest.Options{
Database: db,
Pubsub: pubsub,
DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) {
require.NoError(t, dv.AI.Chat.AIGatewayRoutingEnabled.Set("false"))
}),
},
DontAddLicense: true,
DontAddFirstUser: true,
+1
View File
@@ -1613,6 +1613,7 @@ export interface ChatComputerUseProviderResponse {
/**
 * Chat configuration as exposed to TypeScript clients. All fields are
 * read-only; the property names are snake_case.
 */
export interface ChatConfig {
readonly acquire_batch_size: number;
readonly debug_logging_enabled: boolean;
readonly ai_gateway_routing_enabled: boolean;
}
// From codersdk/chats.go