mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
feat: route chatd provider traffic through aibridge (#25629)
## Summary Routes chatd model calls backed by concrete AI Provider rows through the in-process aibridge transport by default, with deployment options to use direct provider routing when AI Gateway is disabled or chat AI Gateway routing is disabled. - Splits model routing into common, direct provider, and AI Gateway paths behind a single deployment-mode entry point. - Builds chatd models through explicit request, route, and options data. Active API key attribution is passed explicitly instead of being hidden inside generic model construction. - For AI Gateway BYOK routes, resolves the user's provider key in chatd, forwards it through provider-specific auth headers, and sets `X-Coder-AI-Governance-Token` to the `delegated` marker so aibridge preserves those headers while still stripping Coder-specific metadata. - Keeps central provider credentials and deployment fallback credentials out of forwarded provider auth headers, so AI Gateway central policy remains authoritative. - Redacts delegated provider auth from default string formatting to avoid accidental plaintext logging of user BYOK credentials. - Covers selected chat models, advisor overrides, title and quickgen paths, subagent overrides, computer use model selection, and an integration-style chat turn through the aibridge transport path. - Persists initiating API key IDs on chat and queued user messages, including subagent child messages, and fails closed for AI Gateway-routed model builds without an active key. - Removes unused `api_key_id` indexes while keeping the persistence columns and foreign keys. - Keeps the deployment option available through config and env parsing, but hides it from CLI help and generated docs. - Stabilizes the subagent poll fallback test so background CreateChat processing cannot win the state transition under slower CI environments. ## Tests - `go test ./coderd/x/chatd -run 'TestAIGatewayProviderAuthForUser|TestAIGatewayProviderAuthRedactsFormatting|TestResolveModelRouteForConfigAIGatewayProviderAuth|TestAIGatewayModelForwardsProviderAuth|TestProcessChat_AIGatewayRoutingUsesDelegatedAPIKey|TestAwaitSubagentCompletion' -count=1` - `go test ./coderd/aibridged -run 'TestServeHTTP_DelegatedAPIKey|TestServeHTTP_StripCoderToken' -count=1` - `git diff --check HEAD~1..HEAD` - `make lint` > Mux working on behalf of Mike.
This commit is contained in:
+5
@@ -765,6 +765,11 @@ chat:
|
||||
# opt-in settings.
|
||||
# (default: false, type: bool)
|
||||
debugLoggingEnabled: false
|
||||
# Route chat model requests through AI Gateway when both chat routing and AI
|
||||
# Gateway are enabled. Otherwise, chat calls AI providers directly. Pending chats
|
||||
# without API key metadata may need a retry or temporary direct routing.
|
||||
# (default: true, type: bool)
|
||||
aiGatewayRoutingEnabled: true
|
||||
aibridge:
|
||||
# Deprecated: use --ai-gateway-enabled or CODER_AI_GATEWAY_ENABLED instead.
|
||||
# Whether to start an in-memory aibridged instance.
|
||||
|
||||
@@ -807,6 +807,9 @@ func New(options *Options) *API {
|
||||
providerAPIKeys = *options.ChatProviderAPIKeys
|
||||
}
|
||||
|
||||
chatAIGatewayRoutingEnabled := options.DeploymentValues.AI.BridgeConfig.Enabled.Value() &&
|
||||
options.DeploymentValues.AI.Chat.AIGatewayRoutingEnabled.Value()
|
||||
|
||||
api.chatDaemon = chatd.New(chatd.Config{
|
||||
Logger: options.Logger.Named("chatd"),
|
||||
Database: options.Database,
|
||||
@@ -816,6 +819,8 @@ func New(options *Options) *API {
|
||||
ProviderAPIKeys: providerAPIKeys,
|
||||
AllowBYOK: options.DeploymentValues.AI.BridgeConfig.AllowBYOK.Value(),
|
||||
AllowBYOKSet: true,
|
||||
AIBridgeTransportFactory: &api.AIBridgeTransportFactory,
|
||||
AIGatewayRoutingEnabled: chatAIGatewayRoutingEnabled,
|
||||
AlwaysEnableDebugLogs: options.DeploymentValues.AI.Chat.DebugLoggingEnabled.Value(),
|
||||
AgentConn: api.agentProvider.AgentConn,
|
||||
AgentInactiveDisconnectTimeout: api.AgentInactiveDisconnectTimeout,
|
||||
|
||||
@@ -119,6 +119,7 @@ func ChatMessage(t testing.TB, db database.Store, seed database.ChatMessage) dat
|
||||
msgs, err := db.InsertChatMessages(genCtx, database.InsertChatMessagesParams{
|
||||
ChatID: seed.ChatID,
|
||||
CreatedBy: []uuid.UUID{seed.CreatedBy.UUID},
|
||||
APIKeyID: []string{seed.APIKeyID.String},
|
||||
ModelConfigID: []uuid.UUID{seed.ModelConfigID.UUID},
|
||||
Role: []database.ChatMessageRole{takeFirst(seed.Role, database.ChatMessageRoleUser)},
|
||||
Content: []string{content},
|
||||
|
||||
Generated
+10
-2
@@ -1554,7 +1554,8 @@ CREATE TABLE chat_messages (
|
||||
total_cost_micros bigint,
|
||||
runtime_ms bigint,
|
||||
deleted boolean DEFAULT false NOT NULL,
|
||||
provider_response_id text
|
||||
provider_response_id text,
|
||||
api_key_id text
|
||||
);
|
||||
|
||||
CREATE SEQUENCE chat_messages_id_seq
|
||||
@@ -1593,7 +1594,8 @@ CREATE TABLE chat_queued_messages (
|
||||
chat_id uuid NOT NULL,
|
||||
content jsonb NOT NULL,
|
||||
created_at timestamp with time zone DEFAULT now() NOT NULL,
|
||||
model_config_id uuid
|
||||
model_config_id uuid,
|
||||
api_key_id text
|
||||
);
|
||||
|
||||
CREATE SEQUENCE chat_queued_messages_id_seq
|
||||
@@ -4453,6 +4455,9 @@ ALTER TABLE ONLY chat_files
|
||||
ALTER TABLE ONLY chat_files
|
||||
ADD CONSTRAINT chat_files_owner_id_fkey FOREIGN KEY (owner_id) REFERENCES users(id) ON DELETE CASCADE;
|
||||
|
||||
ALTER TABLE ONLY chat_messages
|
||||
ADD CONSTRAINT chat_messages_api_key_id_fkey FOREIGN KEY (api_key_id) REFERENCES api_keys(id) ON DELETE SET NULL;
|
||||
|
||||
ALTER TABLE ONLY chat_messages
|
||||
ADD CONSTRAINT chat_messages_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE;
|
||||
|
||||
@@ -4468,6 +4473,9 @@ ALTER TABLE ONLY chat_model_configs
|
||||
ALTER TABLE ONLY chat_model_configs
|
||||
ADD CONSTRAINT chat_model_configs_updated_by_fkey FOREIGN KEY (updated_by) REFERENCES users(id);
|
||||
|
||||
ALTER TABLE ONLY chat_queued_messages
|
||||
ADD CONSTRAINT chat_queued_messages_api_key_id_fkey FOREIGN KEY (api_key_id) REFERENCES api_keys(id) ON DELETE SET NULL;
|
||||
|
||||
ALTER TABLE ONLY chat_queued_messages
|
||||
ADD CONSTRAINT chat_queued_messages_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE;
|
||||
|
||||
|
||||
@@ -21,11 +21,13 @@ const (
|
||||
ForeignKeyChatFileLinksFileID ForeignKeyConstraint = "chat_file_links_file_id_fkey" // ALTER TABLE ONLY chat_file_links ADD CONSTRAINT chat_file_links_file_id_fkey FOREIGN KEY (file_id) REFERENCES chat_files(id) ON DELETE CASCADE;
|
||||
ForeignKeyChatFilesOrganizationID ForeignKeyConstraint = "chat_files_organization_id_fkey" // ALTER TABLE ONLY chat_files ADD CONSTRAINT chat_files_organization_id_fkey FOREIGN KEY (organization_id) REFERENCES organizations(id) ON DELETE CASCADE;
|
||||
ForeignKeyChatFilesOwnerID ForeignKeyConstraint = "chat_files_owner_id_fkey" // ALTER TABLE ONLY chat_files ADD CONSTRAINT chat_files_owner_id_fkey FOREIGN KEY (owner_id) REFERENCES users(id) ON DELETE CASCADE;
|
||||
ForeignKeyChatMessagesAPIKeyID ForeignKeyConstraint = "chat_messages_api_key_id_fkey" // ALTER TABLE ONLY chat_messages ADD CONSTRAINT chat_messages_api_key_id_fkey FOREIGN KEY (api_key_id) REFERENCES api_keys(id) ON DELETE SET NULL;
|
||||
ForeignKeyChatMessagesChatID ForeignKeyConstraint = "chat_messages_chat_id_fkey" // ALTER TABLE ONLY chat_messages ADD CONSTRAINT chat_messages_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE;
|
||||
ForeignKeyChatMessagesModelConfigID ForeignKeyConstraint = "chat_messages_model_config_id_fkey" // ALTER TABLE ONLY chat_messages ADD CONSTRAINT chat_messages_model_config_id_fkey FOREIGN KEY (model_config_id) REFERENCES chat_model_configs(id);
|
||||
ForeignKeyChatModelConfigsAiProviderID ForeignKeyConstraint = "chat_model_configs_ai_provider_id_fkey" // ALTER TABLE ONLY chat_model_configs ADD CONSTRAINT chat_model_configs_ai_provider_id_fkey FOREIGN KEY (ai_provider_id) REFERENCES ai_providers(id);
|
||||
ForeignKeyChatModelConfigsCreatedBy ForeignKeyConstraint = "chat_model_configs_created_by_fkey" // ALTER TABLE ONLY chat_model_configs ADD CONSTRAINT chat_model_configs_created_by_fkey FOREIGN KEY (created_by) REFERENCES users(id);
|
||||
ForeignKeyChatModelConfigsUpdatedBy ForeignKeyConstraint = "chat_model_configs_updated_by_fkey" // ALTER TABLE ONLY chat_model_configs ADD CONSTRAINT chat_model_configs_updated_by_fkey FOREIGN KEY (updated_by) REFERENCES users(id);
|
||||
ForeignKeyChatQueuedMessagesAPIKeyID ForeignKeyConstraint = "chat_queued_messages_api_key_id_fkey" // ALTER TABLE ONLY chat_queued_messages ADD CONSTRAINT chat_queued_messages_api_key_id_fkey FOREIGN KEY (api_key_id) REFERENCES api_keys(id) ON DELETE SET NULL;
|
||||
ForeignKeyChatQueuedMessagesChatID ForeignKeyConstraint = "chat_queued_messages_chat_id_fkey" // ALTER TABLE ONLY chat_queued_messages ADD CONSTRAINT chat_queued_messages_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE;
|
||||
ForeignKeyChatsAgentID ForeignKeyConstraint = "chats_agent_id_fkey" // ALTER TABLE ONLY chats ADD CONSTRAINT chats_agent_id_fkey FOREIGN KEY (agent_id) REFERENCES workspace_agents(id) ON DELETE SET NULL;
|
||||
ForeignKeyChatsBuildID ForeignKeyConstraint = "chats_build_id_fkey" // ALTER TABLE ONLY chats ADD CONSTRAINT chats_build_id_fkey FOREIGN KEY (build_id) REFERENCES workspace_builds(id) ON DELETE SET NULL;
|
||||
|
||||
@@ -0,0 +1,5 @@
|
||||
ALTER TABLE chat_queued_messages
|
||||
DROP COLUMN api_key_id;
|
||||
|
||||
ALTER TABLE chat_messages
|
||||
DROP COLUMN api_key_id;
|
||||
@@ -0,0 +1,8 @@
|
||||
-- Preserve chat history when API keys are deleted. Pending work whose latest
|
||||
-- user turn loses this attribution will fail closed under AI Gateway routing;
|
||||
-- operators can retry the turn or temporarily use direct routing.
|
||||
ALTER TABLE chat_messages
|
||||
ADD COLUMN api_key_id text REFERENCES api_keys(id) ON DELETE SET NULL;
|
||||
|
||||
ALTER TABLE chat_queued_messages
|
||||
ADD COLUMN api_key_id text REFERENCES api_keys(id) ON DELETE SET NULL;
|
||||
@@ -4699,6 +4699,7 @@ type ChatMessage struct {
|
||||
RuntimeMs sql.NullInt64 `db:"runtime_ms" json:"runtime_ms"`
|
||||
Deleted bool `db:"deleted" json:"deleted"`
|
||||
ProviderResponseID sql.NullString `db:"provider_response_id" json:"provider_response_id"`
|
||||
APIKeyID sql.NullString `db:"api_key_id" json:"api_key_id"`
|
||||
}
|
||||
|
||||
type ChatModelConfig struct {
|
||||
@@ -4726,6 +4727,7 @@ type ChatQueuedMessage struct {
|
||||
Content json.RawMessage `db:"content" json:"content"`
|
||||
CreatedAt time.Time `db:"created_at" json:"created_at"`
|
||||
ModelConfigID uuid.NullUUID `db:"model_config_id" json:"model_config_id"`
|
||||
APIKeyID sql.NullString `db:"api_key_id" json:"api_key_id"`
|
||||
}
|
||||
|
||||
type ChatTable struct {
|
||||
|
||||
@@ -7157,7 +7157,7 @@ func (q *sqlQuerier) GetChatDiffStatusesByChatIDs(ctx context.Context, chatIds [
|
||||
|
||||
const getChatMessageByID = `-- name: GetChatMessageByID :one
|
||||
SELECT
|
||||
id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros, runtime_ms, deleted, provider_response_id
|
||||
id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros, runtime_ms, deleted, provider_response_id, api_key_id
|
||||
FROM
|
||||
chat_messages
|
||||
WHERE
|
||||
@@ -7190,6 +7190,7 @@ func (q *sqlQuerier) GetChatMessageByID(ctx context.Context, id int64) (ChatMess
|
||||
&i.RuntimeMs,
|
||||
&i.Deleted,
|
||||
&i.ProviderResponseID,
|
||||
&i.APIKeyID,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
@@ -7279,7 +7280,7 @@ func (q *sqlQuerier) GetChatMessageSummariesPerChat(ctx context.Context, created
|
||||
|
||||
const getChatMessagesByChatID = `-- name: GetChatMessagesByChatID :many
|
||||
SELECT
|
||||
id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros, runtime_ms, deleted, provider_response_id
|
||||
id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros, runtime_ms, deleted, provider_response_id, api_key_id
|
||||
FROM
|
||||
chat_messages
|
||||
WHERE
|
||||
@@ -7327,6 +7328,7 @@ func (q *sqlQuerier) GetChatMessagesByChatID(ctx context.Context, arg GetChatMes
|
||||
&i.RuntimeMs,
|
||||
&i.Deleted,
|
||||
&i.ProviderResponseID,
|
||||
&i.APIKeyID,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -7343,7 +7345,7 @@ func (q *sqlQuerier) GetChatMessagesByChatID(ctx context.Context, arg GetChatMes
|
||||
|
||||
const getChatMessagesByChatIDAscPaginated = `-- name: GetChatMessagesByChatIDAscPaginated :many
|
||||
SELECT
|
||||
id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros, runtime_ms, deleted, provider_response_id
|
||||
id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros, runtime_ms, deleted, provider_response_id, api_key_id
|
||||
FROM
|
||||
chat_messages
|
||||
WHERE
|
||||
@@ -7394,6 +7396,7 @@ func (q *sqlQuerier) GetChatMessagesByChatIDAscPaginated(ctx context.Context, ar
|
||||
&i.RuntimeMs,
|
||||
&i.Deleted,
|
||||
&i.ProviderResponseID,
|
||||
&i.APIKeyID,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -7410,7 +7413,7 @@ func (q *sqlQuerier) GetChatMessagesByChatIDAscPaginated(ctx context.Context, ar
|
||||
|
||||
const getChatMessagesByChatIDDescPaginated = `-- name: GetChatMessagesByChatIDDescPaginated :many
|
||||
SELECT
|
||||
id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros, runtime_ms, deleted, provider_response_id
|
||||
id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros, runtime_ms, deleted, provider_response_id, api_key_id
|
||||
FROM
|
||||
chat_messages
|
||||
WHERE
|
||||
@@ -7474,6 +7477,7 @@ func (q *sqlQuerier) GetChatMessagesByChatIDDescPaginated(ctx context.Context, a
|
||||
&i.RuntimeMs,
|
||||
&i.Deleted,
|
||||
&i.ProviderResponseID,
|
||||
&i.APIKeyID,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -7506,7 +7510,7 @@ WITH latest_compressed_summary AS (
|
||||
1
|
||||
)
|
||||
SELECT
|
||||
id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros, runtime_ms, deleted, provider_response_id
|
||||
id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros, runtime_ms, deleted, provider_response_id, api_key_id
|
||||
FROM
|
||||
chat_messages
|
||||
WHERE
|
||||
@@ -7578,6 +7582,7 @@ func (q *sqlQuerier) GetChatMessagesForPromptByChatID(ctx context.Context, chatI
|
||||
&i.RuntimeMs,
|
||||
&i.Deleted,
|
||||
&i.ProviderResponseID,
|
||||
&i.APIKeyID,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -7639,7 +7644,7 @@ func (q *sqlQuerier) GetChatModelConfigsForTelemetry(ctx context.Context) ([]Get
|
||||
}
|
||||
|
||||
const getChatQueuedMessages = `-- name: GetChatQueuedMessages :many
|
||||
SELECT id, chat_id, content, created_at, model_config_id FROM chat_queued_messages
|
||||
SELECT id, chat_id, content, created_at, model_config_id, api_key_id FROM chat_queued_messages
|
||||
WHERE chat_id = $1
|
||||
ORDER BY created_at ASC, id ASC
|
||||
`
|
||||
@@ -7659,6 +7664,7 @@ func (q *sqlQuerier) GetChatQueuedMessages(ctx context.Context, chatID uuid.UUID
|
||||
&i.Content,
|
||||
&i.CreatedAt,
|
||||
&i.ModelConfigID,
|
||||
&i.APIKeyID,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -8354,7 +8360,7 @@ func (q *sqlQuerier) GetChildChatsByParentIDs(ctx context.Context, arg GetChildC
|
||||
|
||||
const getLastChatMessageByRole = `-- name: GetLastChatMessageByRole :one
|
||||
SELECT
|
||||
id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros, runtime_ms, deleted, provider_response_id
|
||||
id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros, runtime_ms, deleted, provider_response_id, api_key_id
|
||||
FROM
|
||||
chat_messages
|
||||
WHERE
|
||||
@@ -8397,6 +8403,7 @@ func (q *sqlQuerier) GetLastChatMessageByRole(ctx context.Context, arg GetLastCh
|
||||
&i.RuntimeMs,
|
||||
&i.Deleted,
|
||||
&i.ProviderResponseID,
|
||||
&i.APIKeyID,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
@@ -8709,7 +8716,7 @@ WITH updated_chat AS (
|
||||
SET
|
||||
last_model_config_id = (
|
||||
SELECT val
|
||||
FROM UNNEST($3::uuid[])
|
||||
FROM UNNEST($4::uuid[])
|
||||
WITH ORDINALITY AS t(val, ord)
|
||||
WHERE val != '00000000-0000-0000-0000-000000000000'::uuid
|
||||
ORDER BY ord DESC
|
||||
@@ -8719,12 +8726,12 @@ WITH updated_chat AS (
|
||||
id = $1::uuid
|
||||
AND EXISTS (
|
||||
SELECT 1
|
||||
FROM UNNEST($3::uuid[])
|
||||
FROM UNNEST($4::uuid[])
|
||||
WHERE unnest != '00000000-0000-0000-0000-000000000000'::uuid
|
||||
)
|
||||
AND chats.last_model_config_id IS DISTINCT FROM (
|
||||
SELECT val
|
||||
FROM UNNEST($3::uuid[])
|
||||
FROM UNNEST($4::uuid[])
|
||||
WITH ORDINALITY AS t(val, ord)
|
||||
WHERE val != '00000000-0000-0000-0000-000000000000'::uuid
|
||||
ORDER BY ord DESC
|
||||
@@ -8734,6 +8741,7 @@ WITH updated_chat AS (
|
||||
INSERT INTO chat_messages (
|
||||
chat_id,
|
||||
created_by,
|
||||
api_key_id,
|
||||
model_config_id,
|
||||
role,
|
||||
content,
|
||||
@@ -8754,29 +8762,31 @@ INSERT INTO chat_messages (
|
||||
SELECT
|
||||
$1::uuid,
|
||||
NULLIF(UNNEST($2::uuid[]), '00000000-0000-0000-0000-000000000000'::uuid),
|
||||
NULLIF(UNNEST($3::uuid[]), '00000000-0000-0000-0000-000000000000'::uuid),
|
||||
UNNEST($4::chat_message_role[]),
|
||||
UNNEST($5::text[])::jsonb,
|
||||
UNNEST($6::smallint[]),
|
||||
UNNEST($7::chat_message_visibility[]),
|
||||
NULLIF(UNNEST($8::bigint[]), 0),
|
||||
NULLIF(UNNEST($3::text[]), ''),
|
||||
NULLIF(UNNEST($4::uuid[]), '00000000-0000-0000-0000-000000000000'::uuid),
|
||||
UNNEST($5::chat_message_role[]),
|
||||
UNNEST($6::text[])::jsonb,
|
||||
UNNEST($7::smallint[]),
|
||||
UNNEST($8::chat_message_visibility[]),
|
||||
NULLIF(UNNEST($9::bigint[]), 0),
|
||||
NULLIF(UNNEST($10::bigint[]), 0),
|
||||
NULLIF(UNNEST($11::bigint[]), 0),
|
||||
NULLIF(UNNEST($12::bigint[]), 0),
|
||||
NULLIF(UNNEST($13::bigint[]), 0),
|
||||
NULLIF(UNNEST($14::bigint[]), 0),
|
||||
UNNEST($15::boolean[]),
|
||||
NULLIF(UNNEST($16::bigint[]), 0),
|
||||
NULLIF(UNNEST($15::bigint[]), 0),
|
||||
UNNEST($16::boolean[]),
|
||||
NULLIF(UNNEST($17::bigint[]), 0),
|
||||
NULLIF(UNNEST($18::text[]), '')
|
||||
NULLIF(UNNEST($18::bigint[]), 0),
|
||||
NULLIF(UNNEST($19::text[]), '')
|
||||
RETURNING
|
||||
id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros, runtime_ms, deleted, provider_response_id
|
||||
id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros, runtime_ms, deleted, provider_response_id, api_key_id
|
||||
`
|
||||
|
||||
type InsertChatMessagesParams struct {
|
||||
ChatID uuid.UUID `db:"chat_id" json:"chat_id"`
|
||||
CreatedBy []uuid.UUID `db:"created_by" json:"created_by"`
|
||||
APIKeyID []string `db:"api_key_id" json:"api_key_id"`
|
||||
ModelConfigID []uuid.UUID `db:"model_config_id" json:"model_config_id"`
|
||||
Role []ChatMessageRole `db:"role" json:"role"`
|
||||
Content []string `db:"content" json:"content"`
|
||||
@@ -8799,6 +8809,7 @@ func (q *sqlQuerier) InsertChatMessages(ctx context.Context, arg InsertChatMessa
|
||||
rows, err := q.db.QueryContext(ctx, insertChatMessages,
|
||||
arg.ChatID,
|
||||
pq.Array(arg.CreatedBy),
|
||||
pq.Array(arg.APIKeyID),
|
||||
pq.Array(arg.ModelConfigID),
|
||||
pq.Array(arg.Role),
|
||||
pq.Array(arg.Content),
|
||||
@@ -8845,6 +8856,7 @@ func (q *sqlQuerier) InsertChatMessages(ctx context.Context, arg InsertChatMessa
|
||||
&i.RuntimeMs,
|
||||
&i.Deleted,
|
||||
&i.ProviderResponseID,
|
||||
&i.APIKeyID,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -8860,23 +8872,30 @@ func (q *sqlQuerier) InsertChatMessages(ctx context.Context, arg InsertChatMessa
|
||||
}
|
||||
|
||||
const insertChatQueuedMessage = `-- name: InsertChatQueuedMessage :one
|
||||
INSERT INTO chat_queued_messages (chat_id, content, model_config_id)
|
||||
INSERT INTO chat_queued_messages (chat_id, content, model_config_id, api_key_id)
|
||||
VALUES (
|
||||
$1,
|
||||
$2,
|
||||
$3::uuid
|
||||
$3::uuid,
|
||||
$4::text
|
||||
)
|
||||
RETURNING id, chat_id, content, created_at, model_config_id
|
||||
RETURNING id, chat_id, content, created_at, model_config_id, api_key_id
|
||||
`
|
||||
|
||||
type InsertChatQueuedMessageParams struct {
|
||||
ChatID uuid.UUID `db:"chat_id" json:"chat_id"`
|
||||
Content json.RawMessage `db:"content" json:"content"`
|
||||
ModelConfigID uuid.NullUUID `db:"model_config_id" json:"model_config_id"`
|
||||
APIKeyID sql.NullString `db:"api_key_id" json:"api_key_id"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) InsertChatQueuedMessage(ctx context.Context, arg InsertChatQueuedMessageParams) (ChatQueuedMessage, error) {
|
||||
row := q.db.QueryRowContext(ctx, insertChatQueuedMessage, arg.ChatID, arg.Content, arg.ModelConfigID)
|
||||
row := q.db.QueryRowContext(ctx, insertChatQueuedMessage,
|
||||
arg.ChatID,
|
||||
arg.Content,
|
||||
arg.ModelConfigID,
|
||||
arg.APIKeyID,
|
||||
)
|
||||
var i ChatQueuedMessage
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
@@ -8884,6 +8903,7 @@ func (q *sqlQuerier) InsertChatQueuedMessage(ctx context.Context, arg InsertChat
|
||||
&i.Content,
|
||||
&i.CreatedAt,
|
||||
&i.ModelConfigID,
|
||||
&i.APIKeyID,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
@@ -9108,7 +9128,7 @@ WHERE id = (
|
||||
ORDER BY cqm.created_at ASC, cqm.id ASC
|
||||
LIMIT 1
|
||||
)
|
||||
RETURNING id, chat_id, content, created_at, model_config_id
|
||||
RETURNING id, chat_id, content, created_at, model_config_id, api_key_id
|
||||
`
|
||||
|
||||
func (q *sqlQuerier) PopNextQueuedMessage(ctx context.Context, chatID uuid.UUID) (ChatQueuedMessage, error) {
|
||||
@@ -9120,6 +9140,7 @@ func (q *sqlQuerier) PopNextQueuedMessage(ctx context.Context, chatID uuid.UUID)
|
||||
&i.Content,
|
||||
&i.CreatedAt,
|
||||
&i.ModelConfigID,
|
||||
&i.APIKeyID,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
@@ -10145,7 +10166,7 @@ SET
|
||||
WHERE
|
||||
id = $3::bigint
|
||||
RETURNING
|
||||
id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros, runtime_ms, deleted, provider_response_id
|
||||
id, chat_id, model_config_id, created_at, role, content, visibility, input_tokens, output_tokens, total_tokens, reasoning_tokens, cache_creation_tokens, cache_read_tokens, context_limit, compressed, created_by, content_version, total_cost_micros, runtime_ms, deleted, provider_response_id, api_key_id
|
||||
`
|
||||
|
||||
type UpdateChatMessageByIDParams struct {
|
||||
@@ -10179,6 +10200,7 @@ func (q *sqlQuerier) UpdateChatMessageByID(ctx context.Context, arg UpdateChatMe
|
||||
&i.RuntimeMs,
|
||||
&i.Deleted,
|
||||
&i.ProviderResponseID,
|
||||
&i.APIKeyID,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
@@ -763,6 +763,7 @@ WITH updated_chat AS (
|
||||
INSERT INTO chat_messages (
|
||||
chat_id,
|
||||
created_by,
|
||||
api_key_id,
|
||||
model_config_id,
|
||||
role,
|
||||
content,
|
||||
@@ -783,6 +784,7 @@ INSERT INTO chat_messages (
|
||||
SELECT
|
||||
@chat_id::uuid,
|
||||
NULLIF(UNNEST(@created_by::uuid[]), '00000000-0000-0000-0000-000000000000'::uuid),
|
||||
NULLIF(UNNEST(@api_key_id::text[]), ''),
|
||||
NULLIF(UNNEST(@model_config_id::uuid[]), '00000000-0000-0000-0000-000000000000'::uuid),
|
||||
UNNEST(@role::chat_message_role[]),
|
||||
UNNEST(@content::text[])::jsonb,
|
||||
@@ -1688,11 +1690,12 @@ RETURNING
|
||||
*;
|
||||
|
||||
-- name: InsertChatQueuedMessage :one
|
||||
INSERT INTO chat_queued_messages (chat_id, content, model_config_id)
|
||||
INSERT INTO chat_queued_messages (chat_id, content, model_config_id, api_key_id)
|
||||
VALUES (
|
||||
@chat_id,
|
||||
@content,
|
||||
sqlc.narg('model_config_id')::uuid
|
||||
sqlc.narg('model_config_id')::uuid,
|
||||
sqlc.narg('api_key_id')::text
|
||||
)
|
||||
RETURNING *;
|
||||
|
||||
|
||||
@@ -1220,6 +1220,7 @@ func (api *API) postChats(rw http.ResponseWriter, r *http.Request) {
|
||||
ClientType: clientType,
|
||||
SystemPrompt: req.SystemPrompt,
|
||||
InitialUserContent: contentBlocks,
|
||||
APIKeyID: apiKey.ID,
|
||||
MCPServerIDs: mcpServerIDs,
|
||||
Labels: labels,
|
||||
DynamicTools: dynamicToolsJSON,
|
||||
@@ -3089,6 +3090,7 @@ func (api *API) postChatMessages(rw http.ResponseWriter, r *http.Request) {
|
||||
CreatedBy: apiKey.UserID,
|
||||
Content: contentBlocks,
|
||||
ModelConfigID: modelConfigID,
|
||||
APIKeyID: apiKey.ID,
|
||||
BusyBehavior: busyBehavior,
|
||||
PlanMode: sendPlanMode,
|
||||
MCPServerIDs: req.MCPServerIDs,
|
||||
@@ -3229,6 +3231,7 @@ func (api *API) patchChatMessage(rw http.ResponseWriter, r *http.Request) {
|
||||
CreatedBy: apiKey.UserID,
|
||||
EditedMessageID: messageID,
|
||||
Content: contentBlocks,
|
||||
APIKeyID: apiKey.ID,
|
||||
ModelConfigID: editModelConfigID,
|
||||
})
|
||||
if editErr != nil {
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/coderd/aibridge"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatadvisor"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatprovider"
|
||||
@@ -30,7 +31,11 @@ import (
|
||||
type advisorOverrideStubStore struct {
|
||||
database.Store
|
||||
|
||||
getEnabledChatModelConfigByID func(context.Context, uuid.UUID) (database.ChatModelConfig, error)
|
||||
getEnabledChatModelConfigByID func(context.Context, uuid.UUID) (database.ChatModelConfig, error)
|
||||
getAIProviderByID func(context.Context, uuid.UUID) (database.AIProvider, error)
|
||||
getAIProviders func(context.Context, database.GetAIProvidersParams) ([]database.AIProvider, error)
|
||||
getAIProviderKeysByProviderID func(context.Context, uuid.UUID) ([]database.AIProviderKey, error)
|
||||
getAIProviderKeysByProviderIDs func(context.Context, []uuid.UUID) ([]database.AIProviderKey, error)
|
||||
}
|
||||
|
||||
func (s *advisorOverrideStubStore) GetEnabledChatModelConfigByID(
|
||||
@@ -43,6 +48,46 @@ func (s *advisorOverrideStubStore) GetEnabledChatModelConfigByID(
|
||||
return s.getEnabledChatModelConfigByID(ctx, id)
|
||||
}
|
||||
|
||||
func (s *advisorOverrideStubStore) GetAIProviderByID(
|
||||
ctx context.Context,
|
||||
id uuid.UUID,
|
||||
) (database.AIProvider, error) {
|
||||
if s.getAIProviderByID == nil {
|
||||
return database.AIProvider{}, xerrors.New("unexpected GetAIProviderByID call")
|
||||
}
|
||||
return s.getAIProviderByID(ctx, id)
|
||||
}
|
||||
|
||||
func (s *advisorOverrideStubStore) GetAIProviders(
|
||||
ctx context.Context,
|
||||
params database.GetAIProvidersParams,
|
||||
) ([]database.AIProvider, error) {
|
||||
if s.getAIProviders == nil {
|
||||
return nil, xerrors.New("unexpected GetAIProviders call")
|
||||
}
|
||||
return s.getAIProviders(ctx, params)
|
||||
}
|
||||
|
||||
func (s *advisorOverrideStubStore) GetAIProviderKeysByProviderID(
|
||||
ctx context.Context,
|
||||
providerID uuid.UUID,
|
||||
) ([]database.AIProviderKey, error) {
|
||||
if s.getAIProviderKeysByProviderID == nil {
|
||||
return nil, xerrors.New("unexpected GetAIProviderKeysByProviderID call")
|
||||
}
|
||||
return s.getAIProviderKeysByProviderID(ctx, providerID)
|
||||
}
|
||||
|
||||
func (s *advisorOverrideStubStore) GetAIProviderKeysByProviderIDs(
|
||||
ctx context.Context,
|
||||
providerIDs []uuid.UUID,
|
||||
) ([]database.AIProviderKey, error) {
|
||||
if s.getAIProviderKeysByProviderIDs == nil {
|
||||
return nil, xerrors.New("unexpected GetAIProviderKeysByProviderIDs call")
|
||||
}
|
||||
return s.getAIProviderKeysByProviderIDs(ctx, providerIDs)
|
||||
}
|
||||
|
||||
func newAdvisorTestServer(
|
||||
ctx context.Context,
|
||||
t *testing.T,
|
||||
@@ -56,6 +101,60 @@ func newAdvisorTestServer(
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Server) resolveAdvisorModelOverrideOrFallback(
|
||||
ctx context.Context,
|
||||
chat database.Chat,
|
||||
advisorCfg codersdk.AdvisorConfig,
|
||||
fallbackModel fantasy.LanguageModel,
|
||||
fallbackCallConfig codersdk.ChatModelCallConfig,
|
||||
providerKeys chatprovider.ProviderAPIKeys,
|
||||
modelOpts modelBuildOptions,
|
||||
logger slog.Logger,
|
||||
) (fantasy.LanguageModel, codersdk.ChatModelCallConfig) {
|
||||
model, cfg, err := p.resolveAdvisorModelOverride(
|
||||
ctx,
|
||||
chat,
|
||||
advisorCfg,
|
||||
fallbackModel,
|
||||
fallbackCallConfig,
|
||||
providerKeys,
|
||||
modelOpts,
|
||||
logger,
|
||||
)
|
||||
if err != nil {
|
||||
logger.Warn(ctx, "failed to resolve advisor model override, continuing with chat model", slog.Error(err))
|
||||
return fallbackModel, fallbackCallConfig
|
||||
}
|
||||
return model, cfg
|
||||
}
|
||||
|
||||
func (p *Server) newAdvisorRuntimeOrFallback(
|
||||
ctx context.Context,
|
||||
chat database.Chat,
|
||||
advisorCfg codersdk.AdvisorConfig,
|
||||
fallbackModel fantasy.LanguageModel,
|
||||
fallbackCallConfig codersdk.ChatModelCallConfig,
|
||||
providerKeys chatprovider.ProviderAPIKeys,
|
||||
modelOpts modelBuildOptions,
|
||||
logger slog.Logger,
|
||||
) *chatadvisor.Runtime {
|
||||
rt, err := p.newAdvisorRuntime(
|
||||
ctx,
|
||||
chat,
|
||||
advisorCfg,
|
||||
fallbackModel,
|
||||
fallbackCallConfig,
|
||||
providerKeys,
|
||||
modelOpts,
|
||||
logger,
|
||||
)
|
||||
if err != nil {
|
||||
logger.Warn(ctx, "failed to create advisor runtime, continuing without advisor", slog.Error(err))
|
||||
return nil
|
||||
}
|
||||
return rt
|
||||
}
|
||||
|
||||
// TestResolveAdvisorModelOverride covers the early-return, each fallback
|
||||
// branch, and the success path. Prior tests only hit the ModelConfigID ==
|
||||
// uuid.Nil early return, so the override body never executed.
|
||||
@@ -73,13 +172,14 @@ func TestResolveAdvisorModelOverride(t *testing.T) {
|
||||
store := &advisorOverrideStubStore{}
|
||||
p := newAdvisorTestServer(ctx, t, store)
|
||||
|
||||
gotModel, gotCfg := p.resolveAdvisorModelOverride(
|
||||
gotModel, gotCfg := p.resolveAdvisorModelOverrideOrFallback(
|
||||
ctx,
|
||||
database.Chat{},
|
||||
codersdk.AdvisorConfig{},
|
||||
fallbackModel,
|
||||
fallbackCallConfig,
|
||||
chatprovider.ProviderAPIKeys{},
|
||||
modelBuildOptions{},
|
||||
logger,
|
||||
)
|
||||
require.Equal(t, fallbackModel, gotModel)
|
||||
@@ -96,13 +196,14 @@ func TestResolveAdvisorModelOverride(t *testing.T) {
|
||||
}
|
||||
p := newAdvisorTestServer(ctx, t, store)
|
||||
|
||||
gotModel, gotCfg := p.resolveAdvisorModelOverride(
|
||||
gotModel, gotCfg := p.resolveAdvisorModelOverrideOrFallback(
|
||||
ctx,
|
||||
database.Chat{},
|
||||
codersdk.AdvisorConfig{ModelConfigID: uuid.New()},
|
||||
fallbackModel,
|
||||
fallbackCallConfig,
|
||||
chatprovider.ProviderAPIKeys{OpenAI: "sk-test"},
|
||||
modelBuildOptions{},
|
||||
logger,
|
||||
)
|
||||
require.Equal(t, fallbackModel, gotModel)
|
||||
@@ -125,13 +226,14 @@ func TestResolveAdvisorModelOverride(t *testing.T) {
|
||||
}
|
||||
p := newAdvisorTestServer(ctx, t, store)
|
||||
|
||||
gotModel, gotCfg := p.resolveAdvisorModelOverride(
|
||||
gotModel, gotCfg := p.resolveAdvisorModelOverrideOrFallback(
|
||||
ctx,
|
||||
database.Chat{},
|
||||
codersdk.AdvisorConfig{ModelConfigID: uuid.New()},
|
||||
fallbackModel,
|
||||
fallbackCallConfig,
|
||||
chatprovider.ProviderAPIKeys{OpenAI: "sk-test"},
|
||||
modelBuildOptions{},
|
||||
logger,
|
||||
)
|
||||
require.Equal(t, fallbackModel, gotModel)
|
||||
@@ -158,13 +260,14 @@ func TestResolveAdvisorModelOverride(t *testing.T) {
|
||||
}
|
||||
p := newAdvisorTestServer(ctx, t, store)
|
||||
|
||||
gotModel, gotCfg := p.resolveAdvisorModelOverride(
|
||||
gotModel, gotCfg := p.resolveAdvisorModelOverrideOrFallback(
|
||||
ctx,
|
||||
database.Chat{},
|
||||
codersdk.AdvisorConfig{ModelConfigID: configID},
|
||||
fallbackModel,
|
||||
fallbackCallConfig,
|
||||
chatprovider.ProviderAPIKeys{OpenAI: "sk-test"},
|
||||
modelBuildOptions{},
|
||||
logger,
|
||||
)
|
||||
require.Equal(t, fallbackModel, gotModel)
|
||||
@@ -175,6 +278,7 @@ func TestResolveAdvisorModelOverride(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
configID := uuid.New()
|
||||
providerID := uuid.New()
|
||||
store := &advisorOverrideStubStore{
|
||||
getEnabledChatModelConfigByID: func(context.Context, uuid.UUID) (database.ChatModelConfig, error) {
|
||||
return database.ChatModelConfig{
|
||||
@@ -187,16 +291,27 @@ func TestResolveAdvisorModelOverride(t *testing.T) {
|
||||
DisplayName: "gpt-5.2",
|
||||
}, nil
|
||||
},
|
||||
getAIProviders: func(context.Context, database.GetAIProvidersParams) ([]database.AIProvider, error) {
|
||||
return []database.AIProvider{{
|
||||
ID: providerID,
|
||||
Type: database.AiProviderTypeOpenai,
|
||||
Enabled: true,
|
||||
}}, nil
|
||||
},
|
||||
getAIProviderKeysByProviderIDs: func(context.Context, []uuid.UUID) ([]database.AIProviderKey, error) {
|
||||
return nil, nil
|
||||
},
|
||||
}
|
||||
p := newAdvisorTestServer(ctx, t, store)
|
||||
|
||||
gotModel, gotCfg := p.resolveAdvisorModelOverride(
|
||||
gotModel, gotCfg := p.resolveAdvisorModelOverrideOrFallback(
|
||||
ctx,
|
||||
database.Chat{},
|
||||
codersdk.AdvisorConfig{ModelConfigID: configID},
|
||||
fallbackModel,
|
||||
fallbackCallConfig,
|
||||
chatprovider.ProviderAPIKeys{},
|
||||
modelBuildOptions{},
|
||||
logger,
|
||||
)
|
||||
require.Equal(t, fallbackModel, gotModel)
|
||||
@@ -227,13 +342,14 @@ func TestResolveAdvisorModelOverride(t *testing.T) {
|
||||
}
|
||||
p := newAdvisorTestServer(ctx, t, store)
|
||||
|
||||
gotModel, gotCfg := p.resolveAdvisorModelOverride(
|
||||
gotModel, gotCfg := p.resolveAdvisorModelOverrideOrFallback(
|
||||
ctx,
|
||||
database.Chat{},
|
||||
codersdk.AdvisorConfig{ModelConfigID: configID},
|
||||
fallbackModel,
|
||||
fallbackCallConfig,
|
||||
chatprovider.ProviderAPIKeys{OpenAI: "sk-test"},
|
||||
modelBuildOptions{},
|
||||
logger,
|
||||
)
|
||||
require.NotEqual(t, fantasy.LanguageModel(fallbackModel), gotModel,
|
||||
@@ -247,6 +363,99 @@ func TestResolveAdvisorModelOverride(t *testing.T) {
|
||||
require.NotNil(t, gotCfg.Temperature)
|
||||
require.InDelta(t, 0.42, *gotCfg.Temperature, 1e-9)
|
||||
})
|
||||
|
||||
t.Run("AIProviderIDResolvesOverrideProviderKeys", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
configID := uuid.New()
|
||||
providerID := uuid.New()
|
||||
store := &advisorOverrideStubStore{
|
||||
getEnabledChatModelConfigByID: func(context.Context, uuid.UUID) (database.ChatModelConfig, error) {
|
||||
return database.ChatModelConfig{
|
||||
ID: configID,
|
||||
Provider: "openai",
|
||||
Model: "gpt-5.2",
|
||||
Enabled: true,
|
||||
CreatedAt: time.Unix(0, 0).UTC(),
|
||||
UpdatedAt: time.Unix(0, 0).UTC(),
|
||||
DisplayName: "gpt-5.2",
|
||||
AIProviderID: uuid.NullUUID{UUID: providerID, Valid: true},
|
||||
}, nil
|
||||
},
|
||||
getAIProviderByID: func(context.Context, uuid.UUID) (database.AIProvider, error) {
|
||||
return database.AIProvider{
|
||||
ID: providerID,
|
||||
Type: database.AiProviderTypeOpenai,
|
||||
Enabled: true,
|
||||
}, nil
|
||||
},
|
||||
getAIProviderKeysByProviderID: func(context.Context, uuid.UUID) ([]database.AIProviderKey, error) {
|
||||
return []database.AIProviderKey{{
|
||||
ProviderID: providerID,
|
||||
APIKey: "sk-selected",
|
||||
}}, nil
|
||||
},
|
||||
}
|
||||
p := newAdvisorTestServer(ctx, t, store)
|
||||
|
||||
gotModel, gotCfg := p.resolveAdvisorModelOverrideOrFallback(
|
||||
ctx,
|
||||
database.Chat{},
|
||||
codersdk.AdvisorConfig{ModelConfigID: configID},
|
||||
fallbackModel,
|
||||
fallbackCallConfig,
|
||||
chatprovider.ProviderAPIKeys{},
|
||||
modelBuildOptions{},
|
||||
logger,
|
||||
)
|
||||
require.NotEqual(t, fantasy.LanguageModel(fallbackModel), gotModel)
|
||||
require.NotNil(t, gotModel)
|
||||
require.Equal(t, "openai", gotModel.Provider())
|
||||
require.Equal(t, "gpt-5.2", gotModel.Model())
|
||||
require.Equal(t, fallbackCallConfig, gotCfg)
|
||||
})
|
||||
}
|
||||
|
||||
func TestResolveAdvisorModelOverridePromotesAIBridgeErrors(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
configID := uuid.New()
|
||||
providerID := uuid.New()
|
||||
store := &advisorOverrideStubStore{
|
||||
getEnabledChatModelConfigByID: func(context.Context, uuid.UUID) (database.ChatModelConfig, error) {
|
||||
return database.ChatModelConfig{
|
||||
ID: configID,
|
||||
Provider: "openai",
|
||||
Model: "gpt-5.2",
|
||||
Enabled: true,
|
||||
DisplayName: "gpt-5.2",
|
||||
AIProviderID: uuid.NullUUID{UUID: providerID, Valid: true},
|
||||
}, nil
|
||||
},
|
||||
getAIProviderByID: func(context.Context, uuid.UUID) (database.AIProvider, error) {
|
||||
return database.AIProvider{ID: providerID, Type: database.AiProviderTypeOpenai, Name: "primary-openai", Enabled: true}, nil
|
||||
},
|
||||
getAIProviderKeysByProviderID: func(context.Context, uuid.UUID) ([]database.AIProviderKey, error) {
|
||||
return []database.AIProviderKey{{ProviderID: providerID, APIKey: "sk-selected"}}, nil
|
||||
},
|
||||
}
|
||||
p := newAdvisorTestServer(ctx, t, store)
|
||||
p.aiGatewayRoutingEnabled = true
|
||||
|
||||
ctx = aibridge.WithDelegatedAPIKeyID(ctx, uuid.NewString())
|
||||
model, _, err := p.resolveAdvisorModelOverride(
|
||||
ctx,
|
||||
database.Chat{ID: uuid.New(), OwnerID: uuid.New()},
|
||||
codersdk.AdvisorConfig{ModelConfigID: configID},
|
||||
&chattest.FakeModel{ProviderName: "stub", ModelName: "stub"},
|
||||
codersdk.ChatModelCallConfig{},
|
||||
chatprovider.ProviderAPIKeys{},
|
||||
modelBuildOptions{ActiveAPIKeyID: uuid.NewString()},
|
||||
slog.Make(),
|
||||
)
|
||||
require.ErrorContains(t, err, "AI Gateway transport factory")
|
||||
require.Nil(t, model)
|
||||
}
|
||||
|
||||
// TestStripAdvisorGuidanceBlock exercises the filter that keeps the advisor
|
||||
@@ -346,7 +555,7 @@ func TestNewAdvisorRuntime(t *testing.T) {
|
||||
store := &advisorOverrideStubStore{}
|
||||
p := newAdvisorTestServer(ctx, t, store)
|
||||
|
||||
rt := p.newAdvisorRuntime(
|
||||
rt := p.newAdvisorRuntimeOrFallback(
|
||||
ctx,
|
||||
database.Chat{},
|
||||
codersdk.AdvisorConfig{
|
||||
@@ -357,6 +566,7 @@ func TestNewAdvisorRuntime(t *testing.T) {
|
||||
fallbackModel,
|
||||
fallbackCallConfig,
|
||||
chatprovider.ProviderAPIKeys{},
|
||||
modelBuildOptions{},
|
||||
logger,
|
||||
)
|
||||
require.NotNil(t, rt, "zero max uses must default rather than bail out")
|
||||
@@ -370,7 +580,7 @@ func TestNewAdvisorRuntime(t *testing.T) {
|
||||
store := &advisorOverrideStubStore{}
|
||||
p := newAdvisorTestServer(ctx, t, store)
|
||||
|
||||
rt := p.newAdvisorRuntime(
|
||||
rt := p.newAdvisorRuntimeOrFallback(
|
||||
ctx,
|
||||
database.Chat{},
|
||||
codersdk.AdvisorConfig{
|
||||
@@ -381,6 +591,7 @@ func TestNewAdvisorRuntime(t *testing.T) {
|
||||
fallbackModel,
|
||||
fallbackCallConfig,
|
||||
chatprovider.ProviderAPIKeys{},
|
||||
modelBuildOptions{},
|
||||
logger,
|
||||
)
|
||||
require.Nil(t, rt, "negative max uses must disable the advisor")
|
||||
@@ -392,7 +603,7 @@ func TestNewAdvisorRuntime(t *testing.T) {
|
||||
store := &advisorOverrideStubStore{}
|
||||
p := newAdvisorTestServer(ctx, t, store)
|
||||
|
||||
rt := p.newAdvisorRuntime(
|
||||
rt := p.newAdvisorRuntimeOrFallback(
|
||||
ctx,
|
||||
database.Chat{},
|
||||
codersdk.AdvisorConfig{
|
||||
@@ -403,6 +614,7 @@ func TestNewAdvisorRuntime(t *testing.T) {
|
||||
fallbackModel,
|
||||
fallbackCallConfig,
|
||||
chatprovider.ProviderAPIKeys{},
|
||||
modelBuildOptions{},
|
||||
logger,
|
||||
)
|
||||
require.NotNil(t, rt,
|
||||
|
||||
+239
-156
@@ -29,6 +29,7 @@ import (
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/coderd/aibridge"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/db2sdk"
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
@@ -254,6 +255,9 @@ type Server struct {
|
||||
metrics *chatloop.Metrics
|
||||
recordingSem chan struct{}
|
||||
|
||||
aibridgeTransportFactory *atomic.Pointer[aibridge.TransportFactory]
|
||||
aiGatewayRoutingEnabled bool
|
||||
|
||||
// Configuration
|
||||
pendingChatAcquireInterval time.Duration
|
||||
maxChatsPerAcquire int32
|
||||
@@ -344,10 +348,11 @@ func (p *Server) resolveAdvisorModelOverride(
|
||||
fallbackModel fantasy.LanguageModel,
|
||||
fallbackCallConfig codersdk.ChatModelCallConfig,
|
||||
providerKeys chatprovider.ProviderAPIKeys,
|
||||
modelOpts modelBuildOptions,
|
||||
logger slog.Logger,
|
||||
) (fantasy.LanguageModel, codersdk.ChatModelCallConfig) {
|
||||
) (fantasy.LanguageModel, codersdk.ChatModelCallConfig, error) {
|
||||
if advisorCfg.ModelConfigID == uuid.Nil {
|
||||
return fallbackModel, fallbackCallConfig
|
||||
return fallbackModel, fallbackCallConfig, nil
|
||||
}
|
||||
|
||||
// Re-read the override instead of using the cache so disabled models
|
||||
@@ -363,7 +368,7 @@ func (p *Server) resolveAdvisorModelOverride(
|
||||
"advisor model config is disabled or unavailable, continuing with chat model",
|
||||
slog.F("model_config_id", advisorCfg.ModelConfigID),
|
||||
)
|
||||
return fallbackModel, fallbackCallConfig
|
||||
return fallbackModel, fallbackCallConfig, nil
|
||||
}
|
||||
logger.Warn(
|
||||
ctx,
|
||||
@@ -371,7 +376,7 @@ func (p *Server) resolveAdvisorModelOverride(
|
||||
slog.F("model_config_id", advisorCfg.ModelConfigID),
|
||||
slog.Error(err),
|
||||
)
|
||||
return fallbackModel, fallbackCallConfig
|
||||
return fallbackModel, fallbackCallConfig, nil
|
||||
}
|
||||
|
||||
overrideCallConfig := codersdk.ChatModelCallConfig{}
|
||||
@@ -383,29 +388,48 @@ func (p *Server) resolveAdvisorModelOverride(
|
||||
slog.F("model_config_id", advisorCfg.ModelConfigID),
|
||||
slog.Error(err),
|
||||
)
|
||||
return fallbackModel, fallbackCallConfig
|
||||
return fallbackModel, fallbackCallConfig, nil
|
||||
}
|
||||
}
|
||||
|
||||
overrideModel, err := chatprovider.ModelFromConfig(
|
||||
overrideConfig.Provider,
|
||||
overrideConfig.Model,
|
||||
route, err := p.resolveModelRouteForConfig(
|
||||
ctx,
|
||||
chat.OwnerID,
|
||||
overrideConfig,
|
||||
providerKeys,
|
||||
chatprovider.UserAgent(),
|
||||
chatprovider.CoderHeaders(chat),
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
if p.shouldUseAIGatewayRouting() && overrideConfig.AIProviderID.Valid {
|
||||
return nil, codersdk.ChatModelCallConfig{}, xerrors.Errorf("resolve advisor override route: %w", err)
|
||||
}
|
||||
logger.Warn(
|
||||
ctx,
|
||||
"failed to resolve advisor override route, continuing with chat model",
|
||||
slog.F("model_config_id", advisorCfg.ModelConfigID),
|
||||
slog.Error(err),
|
||||
)
|
||||
return fallbackModel, fallbackCallConfig, nil
|
||||
}
|
||||
overrideModel, err := p.newModel(ctx, modelClientRequest{
|
||||
Chat: chat,
|
||||
ModelName: overrideConfig.Model,
|
||||
UserAgent: chatprovider.UserAgent(),
|
||||
ExtraHeaders: chatprovider.CoderHeaders(chat),
|
||||
}, route, modelOpts)
|
||||
if err != nil {
|
||||
if p.shouldUseAIGatewayRouting() && overrideConfig.AIProviderID.Valid {
|
||||
return nil, codersdk.ChatModelCallConfig{}, xerrors.Errorf("create advisor override model: %w", err)
|
||||
}
|
||||
logger.Warn(
|
||||
ctx,
|
||||
"failed to create advisor override model, continuing with chat model",
|
||||
slog.F("model_config_id", advisorCfg.ModelConfigID),
|
||||
slog.Error(err),
|
||||
)
|
||||
return fallbackModel, fallbackCallConfig
|
||||
return fallbackModel, fallbackCallConfig, nil
|
||||
}
|
||||
|
||||
return overrideModel, overrideCallConfig
|
||||
return overrideModel, overrideCallConfig, nil
|
||||
}
|
||||
|
||||
func (p *Server) newAdvisorRuntime(
|
||||
@@ -415,17 +439,22 @@ func (p *Server) newAdvisorRuntime(
|
||||
fallbackModel fantasy.LanguageModel,
|
||||
fallbackCallConfig codersdk.ChatModelCallConfig,
|
||||
providerKeys chatprovider.ProviderAPIKeys,
|
||||
modelOpts modelBuildOptions,
|
||||
logger slog.Logger,
|
||||
) *chatadvisor.Runtime {
|
||||
advisorModel, advisorCallConfig := p.resolveAdvisorModelOverride(
|
||||
) (*chatadvisor.Runtime, error) {
|
||||
advisorModel, advisorCallConfig, err := p.resolveAdvisorModelOverride(
|
||||
ctx,
|
||||
chat,
|
||||
advisorCfg,
|
||||
fallbackModel,
|
||||
fallbackCallConfig,
|
||||
providerKeys,
|
||||
modelOpts,
|
||||
logger,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
maxUsesPerRun := advisorCfg.MaxUsesPerRun
|
||||
switch {
|
||||
@@ -441,7 +470,7 @@ func (p *Server) newAdvisorRuntime(
|
||||
"invalid advisor max uses per run, continuing without advisor",
|
||||
slog.F("max_uses_per_run", maxUsesPerRun),
|
||||
)
|
||||
return nil
|
||||
return nil, nil //nolint:nilnil // Nil runtime with nil error means advisor is skipped for this turn.
|
||||
}
|
||||
|
||||
maxOutputTokens := advisorCfg.MaxOutputTokens
|
||||
@@ -468,9 +497,9 @@ func (p *Server) newAdvisorRuntime(
|
||||
"failed to create advisor runtime, continuing without advisor",
|
||||
slog.Error(err),
|
||||
)
|
||||
return nil
|
||||
return nil, nil //nolint:nilnil // Nil runtime with nil error means advisor is skipped for this turn.
|
||||
}
|
||||
return rt
|
||||
return rt, nil
|
||||
}
|
||||
|
||||
// cachedWorkspaceMCPTools stores workspace MCP tools discovered
|
||||
@@ -1436,6 +1465,7 @@ type CreateOptions struct {
|
||||
ClientType database.ChatClientType
|
||||
SystemPrompt string
|
||||
InitialUserContent []codersdk.ChatMessagePart
|
||||
APIKeyID string
|
||||
MCPServerIDs []uuid.UUID
|
||||
Labels database.StringMap
|
||||
DynamicTools json.RawMessage
|
||||
@@ -1460,6 +1490,7 @@ type SendMessageOptions struct {
|
||||
CreatedBy uuid.UUID
|
||||
Content []codersdk.ChatMessagePart
|
||||
ModelConfigID uuid.UUID
|
||||
APIKeyID string
|
||||
BusyBehavior SendMessageBusyBehavior
|
||||
PlanMode *database.NullChatPlanMode
|
||||
MCPServerIDs *[]uuid.UUID
|
||||
@@ -1479,6 +1510,7 @@ type EditMessageOptions struct {
|
||||
CreatedBy uuid.UUID
|
||||
EditedMessageID int64
|
||||
Content []codersdk.ChatMessagePart
|
||||
APIKeyID string
|
||||
// ModelConfigID, when non-zero, overrides the model used for
|
||||
// the replacement user message. When set to uuid.Nil the
|
||||
// original message's model is preserved.
|
||||
@@ -1647,7 +1679,7 @@ func (p *Server) CreateChat(ctx context.Context, opts CreateOptions) (database.C
|
||||
database.ChatMessageVisibilityBoth,
|
||||
opts.ModelConfigID,
|
||||
chatprompt.CurrentContentVersion,
|
||||
).withCreatedBy(opts.OwnerID))
|
||||
).withCreatedBy(opts.OwnerID).withAPIKeyID(opts.APIKeyID))
|
||||
|
||||
_, err = tx.InsertChatMessages(ctx, msgParams)
|
||||
if err != nil {
|
||||
@@ -1786,6 +1818,10 @@ func (p *Server) SendMessage(
|
||||
UUID: modelConfigID,
|
||||
Valid: modelConfigID != uuid.Nil,
|
||||
},
|
||||
APIKeyID: sql.NullString{
|
||||
String: opts.APIKeyID,
|
||||
Valid: opts.APIKeyID != "",
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("insert queued message: %w", err)
|
||||
@@ -1810,6 +1846,7 @@ func (p *Server) SendMessage(
|
||||
modelConfigID,
|
||||
content,
|
||||
opts.CreatedBy,
|
||||
opts.APIKeyID,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -2083,7 +2120,7 @@ func (p *Server) EditMessage(
|
||||
editedMsg.Visibility,
|
||||
messageModelConfigID,
|
||||
chatprompt.CurrentContentVersion,
|
||||
).withCreatedBy(opts.CreatedBy))
|
||||
).withCreatedBy(opts.CreatedBy).withAPIKeyID(opts.APIKeyID))
|
||||
newMessages, err := insertChatMessageWithStore(ctx, tx, msgParams)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("insert replacement message: %w", err)
|
||||
@@ -2416,12 +2453,14 @@ func (p *Server) PromoteQueued(
|
||||
var (
|
||||
targetContent json.RawMessage
|
||||
targetModelConfigID uuid.NullUUID
|
||||
targetAPIKeyID sql.NullString
|
||||
found bool
|
||||
)
|
||||
for _, qm := range queuedMessages {
|
||||
if qm.ID == opts.QueuedMessageID {
|
||||
targetContent = qm.Content
|
||||
targetModelConfigID = qm.ModelConfigID
|
||||
targetAPIKeyID = qm.APIKeyID
|
||||
found = true
|
||||
break
|
||||
}
|
||||
@@ -2511,6 +2550,7 @@ func (p *Server) PromoteQueued(
|
||||
Valid: len(targetContent) > 0,
|
||||
},
|
||||
opts.CreatedBy,
|
||||
targetAPIKeyID.String,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -2754,6 +2794,7 @@ func (p *Server) SubmitToolResults(
|
||||
params := database.InsertChatMessagesParams{
|
||||
ChatID: opts.ChatID,
|
||||
CreatedBy: make([]uuid.UUID, n),
|
||||
APIKeyID: make([]string, n),
|
||||
ModelConfigID: make([]uuid.UUID, n),
|
||||
Role: make([]database.ChatMessageRole, n),
|
||||
Content: make([]string, n),
|
||||
@@ -2862,16 +2903,18 @@ var ErrManualTitleRegenerationInProgress = xerrors.New(
|
||||
)
|
||||
|
||||
type manualTitleCandidateResult struct {
|
||||
title string
|
||||
modelConfig database.ChatModelConfig
|
||||
usage fantasy.Usage
|
||||
hasMessages bool
|
||||
title string
|
||||
modelConfig database.ChatModelConfig
|
||||
usage fantasy.Usage
|
||||
activeAPIKeyID string
|
||||
hasMessages bool
|
||||
}
|
||||
|
||||
type manualTitleGenerationError struct {
|
||||
cause error
|
||||
modelConfig database.ChatModelConfig
|
||||
usage fantasy.Usage
|
||||
cause error
|
||||
modelConfig database.ChatModelConfig
|
||||
usage fantasy.Usage
|
||||
activeAPIKeyID string
|
||||
}
|
||||
|
||||
func (e *manualTitleGenerationError) Error() string {
|
||||
@@ -3105,6 +3148,7 @@ func (p *Server) recordManualTitleGenerationFailure(
|
||||
chat,
|
||||
generationErr.modelConfig,
|
||||
generationErr.usage,
|
||||
generationErr.activeAPIKeyID,
|
||||
"",
|
||||
); recordErr != nil {
|
||||
return errors.Join(
|
||||
@@ -3154,11 +3198,13 @@ func (p *Server) generateManualTitleCandidate(
|
||||
if len(messages) == 0 {
|
||||
return manualTitleCandidateResult{}, nil
|
||||
}
|
||||
modelOpts := modelBuildOptionsFromMessages(messages)
|
||||
|
||||
model, modelConfig, modelKeys, err := p.resolveManualTitleModel(ctx, store, chat, keys)
|
||||
model, modelConfig, modelKeys, err := p.resolveManualTitleModel(ctx, store, chat, keys, modelOpts)
|
||||
result := manualTitleCandidateResult{
|
||||
modelConfig: modelConfig,
|
||||
hasMessages: true,
|
||||
modelConfig: modelConfig,
|
||||
activeAPIKeyID: modelOpts.ActiveAPIKeyID,
|
||||
hasMessages: true,
|
||||
}
|
||||
if err != nil {
|
||||
return result, err
|
||||
@@ -3174,6 +3220,7 @@ func (p *Server) generateManualTitleCandidate(
|
||||
chat,
|
||||
modelConfig,
|
||||
modelKeys,
|
||||
modelOpts,
|
||||
messages,
|
||||
model,
|
||||
)
|
||||
@@ -3189,9 +3236,10 @@ func (p *Server) generateManualTitleCandidate(
|
||||
return result, wrappedErr
|
||||
}
|
||||
return result, &manualTitleGenerationError{
|
||||
cause: wrappedErr,
|
||||
modelConfig: modelConfig,
|
||||
usage: usage,
|
||||
cause: wrappedErr,
|
||||
modelConfig: modelConfig,
|
||||
usage: usage,
|
||||
activeAPIKeyID: modelOpts.ActiveAPIKeyID,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3220,6 +3268,7 @@ func (p *Server) proposeChatTitleWithStore(
|
||||
chat,
|
||||
result.modelConfig,
|
||||
result.usage,
|
||||
result.activeAPIKeyID,
|
||||
"",
|
||||
); recordErr != nil {
|
||||
return "", xerrors.Errorf("record manual title usage: %w", recordErr)
|
||||
@@ -3250,6 +3299,7 @@ func (p *Server) regenerateChatTitleWithStore(
|
||||
chat,
|
||||
result.modelConfig,
|
||||
result.usage,
|
||||
result.activeAPIKeyID,
|
||||
result.title,
|
||||
)
|
||||
if recordErr != nil {
|
||||
@@ -3272,6 +3322,7 @@ func (p *Server) prepareManualTitleDebugRun(
|
||||
chat database.Chat,
|
||||
modelConfig database.ChatModelConfig,
|
||||
keys chatprovider.ProviderAPIKeys,
|
||||
modelOpts modelBuildOptions,
|
||||
messages []database.ChatMessage,
|
||||
fallbackModel fantasy.LanguageModel,
|
||||
) (context.Context, fantasy.LanguageModel, func(error)) {
|
||||
@@ -3279,15 +3330,21 @@ func (p *Server) prepareManualTitleDebugRun(
|
||||
titleModel := fallbackModel
|
||||
finishDebugRun := func(error) {}
|
||||
|
||||
httpClient := &http.Client{Transport: &chatdebug.RecordingTransport{}}
|
||||
debugModel, debugModelErr := chatprovider.ModelFromConfig(
|
||||
modelConfig.Provider,
|
||||
modelConfig.Model,
|
||||
keys,
|
||||
chatprovider.UserAgent(),
|
||||
chatprovider.CoderHeaders(chat),
|
||||
httpClient,
|
||||
)
|
||||
route, routeErr := p.resolveModelRouteForConfig(ctx, chat.OwnerID, modelConfig, keys)
|
||||
debugOpts := modelOpts
|
||||
debugOpts.RecordHTTP = true
|
||||
var debugModelErr error
|
||||
var debugModel fantasy.LanguageModel
|
||||
if routeErr != nil {
|
||||
debugModelErr = routeErr
|
||||
} else {
|
||||
debugModel, debugModelErr = p.newModel(ctx, modelClientRequest{
|
||||
Chat: chat,
|
||||
ModelName: modelConfig.Model,
|
||||
UserAgent: chatprovider.UserAgent(),
|
||||
ExtraHeaders: chatprovider.CoderHeaders(chat),
|
||||
}, route, debugOpts)
|
||||
}
|
||||
switch {
|
||||
case debugModelErr != nil:
|
||||
p.logger.Warn(ctx, "failed to create debug-aware manual title model",
|
||||
@@ -3535,11 +3592,13 @@ func (p *Server) resolveManualTitleModel(
|
||||
store database.Store,
|
||||
chat database.Chat,
|
||||
keys chatprovider.ProviderAPIKeys,
|
||||
modelOpts modelBuildOptions,
|
||||
) (fantasy.LanguageModel, database.ChatModelConfig, chatprovider.ProviderAPIKeys, error) {
|
||||
overrideConfig, overrideModel, overrideKeys, overrideSet, overrideErr := p.resolveTitleGenerationModelOverride(
|
||||
overrideConfig, overrideModel, overrideKeys, _, overrideSet, overrideErr := p.resolveTitleGenerationModelOverride(
|
||||
ctx,
|
||||
chat,
|
||||
keys,
|
||||
modelOpts,
|
||||
)
|
||||
if overrideErr != nil {
|
||||
if overrideSet {
|
||||
@@ -3562,15 +3621,15 @@ func (p *Server) resolveManualTitleModel(
|
||||
slog.F("chat_id", chat.ID),
|
||||
slog.Error(err),
|
||||
)
|
||||
return p.resolveFallbackManualTitleModel(ctx, chat, keys)
|
||||
return p.resolveFallbackManualTitleModel(ctx, chat, keys, modelOpts)
|
||||
}
|
||||
|
||||
config, ok := selectPreferredConfiguredShortTextModelConfig(configs)
|
||||
if !ok {
|
||||
return p.resolveFallbackManualTitleModel(ctx, chat, keys)
|
||||
return p.resolveFallbackManualTitleModel(ctx, chat, keys, modelOpts)
|
||||
}
|
||||
|
||||
providerHint, modelKeys, err := p.resolveModelConfigProviderHintAndKeys(ctx, chat.OwnerID, config, keys)
|
||||
route, err := p.resolveModelRouteForConfig(ctx, chat.OwnerID, config, keys)
|
||||
if err != nil {
|
||||
p.logger.Debug(ctx, "manual title preferred model unavailable",
|
||||
slog.F("chat_id", chat.ID),
|
||||
@@ -3578,33 +3637,32 @@ func (p *Server) resolveManualTitleModel(
|
||||
slog.F("model", config.Model),
|
||||
slog.Error(err),
|
||||
)
|
||||
return p.resolveFallbackManualTitleModel(ctx, chat, keys)
|
||||
return p.resolveFallbackManualTitleModel(ctx, chat, keys, modelOpts)
|
||||
}
|
||||
model, err := chatprovider.ModelFromConfig(
|
||||
providerHint,
|
||||
config.Model,
|
||||
modelKeys,
|
||||
chatprovider.UserAgent(),
|
||||
chatprovider.CoderHeaders(chat),
|
||||
nil,
|
||||
)
|
||||
model, err := p.newModel(ctx, modelClientRequest{
|
||||
Chat: chat,
|
||||
ModelName: config.Model,
|
||||
UserAgent: chatprovider.UserAgent(),
|
||||
ExtraHeaders: chatprovider.CoderHeaders(chat),
|
||||
}, route, modelOpts)
|
||||
if err != nil {
|
||||
p.logger.Debug(ctx, "manual title preferred model unavailable",
|
||||
slog.F("chat_id", chat.ID),
|
||||
slog.F("provider", providerHint),
|
||||
slog.F("provider", config.Provider),
|
||||
slog.F("model", config.Model),
|
||||
slog.Error(err),
|
||||
)
|
||||
return p.resolveFallbackManualTitleModel(ctx, chat, keys)
|
||||
return p.resolveFallbackManualTitleModel(ctx, chat, keys, modelOpts)
|
||||
}
|
||||
|
||||
return model, config, modelKeys, nil
|
||||
return model, config, route.directProviderKeys(), nil
|
||||
}
|
||||
|
||||
func (p *Server) resolveFallbackManualTitleModel(
|
||||
ctx context.Context,
|
||||
chat database.Chat,
|
||||
keys chatprovider.ProviderAPIKeys,
|
||||
modelOpts modelBuildOptions,
|
||||
) (fantasy.LanguageModel, database.ChatModelConfig, chatprovider.ProviderAPIKeys, error) {
|
||||
config, err := p.resolveModelConfig(ctx, chat)
|
||||
if err != nil {
|
||||
@@ -3613,25 +3671,23 @@ func (p *Server) resolveFallbackManualTitleModel(
|
||||
err,
|
||||
)
|
||||
}
|
||||
providerHint, modelKeys, err := p.resolveModelConfigProviderHintAndKeys(ctx, chat.OwnerID, config, keys)
|
||||
route, err := p.resolveModelRouteForConfig(ctx, chat.OwnerID, config, keys)
|
||||
if err != nil {
|
||||
return nil, database.ChatModelConfig{}, chatprovider.ProviderAPIKeys{}, err
|
||||
}
|
||||
model, err := chatprovider.ModelFromConfig(
|
||||
providerHint,
|
||||
config.Model,
|
||||
modelKeys,
|
||||
chatprovider.UserAgent(),
|
||||
chatprovider.CoderHeaders(chat),
|
||||
nil,
|
||||
)
|
||||
model, err := p.newModel(ctx, modelClientRequest{
|
||||
Chat: chat,
|
||||
ModelName: config.Model,
|
||||
UserAgent: chatprovider.UserAgent(),
|
||||
ExtraHeaders: chatprovider.CoderHeaders(chat),
|
||||
}, route, modelOpts)
|
||||
if err != nil {
|
||||
return nil, database.ChatModelConfig{}, chatprovider.ProviderAPIKeys{}, xerrors.Errorf(
|
||||
"create fallback manual title model: %w",
|
||||
err,
|
||||
)
|
||||
}
|
||||
return model, config, modelKeys, nil
|
||||
return model, config, route.directProviderKeys(), nil
|
||||
}
|
||||
|
||||
func mergeManualTitleMessages(
|
||||
@@ -3682,6 +3738,7 @@ func recordManualTitleUsage(
|
||||
chat database.Chat,
|
||||
modelConfig database.ChatModelConfig,
|
||||
usage fantasy.Usage,
|
||||
activeAPIKeyID string,
|
||||
newTitle string,
|
||||
) (database.Chat, error) {
|
||||
hasUsage := usage != (fantasy.Usage{})
|
||||
@@ -3720,6 +3777,7 @@ func recordManualTitleUsage(
|
||||
messages, err := tx.InsertChatMessages(ctx, database.InsertChatMessagesParams{
|
||||
ChatID: chat.ID,
|
||||
CreatedBy: []uuid.UUID{chat.OwnerID},
|
||||
APIKeyID: []string{activeAPIKeyID},
|
||||
ModelConfigID: []uuid.UUID{modelConfig.ID},
|
||||
Role: []database.ChatMessageRole{database.ChatMessageRoleAssistant},
|
||||
Content: []string{content},
|
||||
@@ -3845,6 +3903,7 @@ type chatMessage struct {
|
||||
visibility database.ChatMessageVisibility
|
||||
modelConfigID uuid.UUID
|
||||
createdBy uuid.UUID
|
||||
apiKeyID string
|
||||
contentVersion int16
|
||||
compressed bool
|
||||
inputTokens int64
|
||||
@@ -3880,6 +3939,11 @@ func (m chatMessage) withCreatedBy(id uuid.UUID) chatMessage {
|
||||
return m
|
||||
}
|
||||
|
||||
func (m chatMessage) withAPIKeyID(id string) chatMessage {
|
||||
m.apiKeyID = id
|
||||
return m
|
||||
}
|
||||
|
||||
func (m chatMessage) withCompressed() chatMessage {
|
||||
m.compressed = true
|
||||
return m
|
||||
@@ -3924,6 +3988,7 @@ func appendChatMessage(
|
||||
msg chatMessage,
|
||||
) {
|
||||
params.CreatedBy = append(params.CreatedBy, msg.createdBy)
|
||||
params.APIKeyID = append(params.APIKeyID, msg.apiKeyID)
|
||||
params.ModelConfigID = append(params.ModelConfigID, msg.modelConfigID)
|
||||
params.Role = append(params.Role, msg.role)
|
||||
params.Content = append(params.Content, string(msg.content.RawMessage))
|
||||
@@ -3973,6 +4038,7 @@ func insertUserMessageAndSetPending(
|
||||
modelConfigID uuid.UUID,
|
||||
content pqtype.NullRawMessage,
|
||||
createdBy uuid.UUID,
|
||||
apiKeyID string,
|
||||
) (database.ChatMessage, database.Chat, error) {
|
||||
msgParams := database.InsertChatMessagesParams{ //nolint:exhaustruct // Fields populated by appendChatMessage.
|
||||
ChatID: lockedChat.ID,
|
||||
@@ -3983,7 +4049,7 @@ func insertUserMessageAndSetPending(
|
||||
database.ChatMessageVisibilityBoth,
|
||||
modelConfigID,
|
||||
chatprompt.CurrentContentVersion,
|
||||
).withCreatedBy(createdBy))
|
||||
).withCreatedBy(createdBy).withAPIKeyID(apiKeyID))
|
||||
messages, err := insertChatMessageWithStore(ctx, store, msgParams)
|
||||
if err != nil {
|
||||
return database.ChatMessage{}, database.Chat{}, err
|
||||
@@ -4052,7 +4118,10 @@ type Config struct {
|
||||
WebpushDispatcher webpush.Dispatcher
|
||||
UsageTracker *workspacestats.UsageTracker
|
||||
Clock quartz.Clock
|
||||
PrometheusRegistry prometheus.Registerer
|
||||
AIBridgeTransportFactory *atomic.Pointer[aibridge.TransportFactory]
|
||||
AIGatewayRoutingEnabled bool
|
||||
|
||||
PrometheusRegistry prometheus.Registerer
|
||||
|
||||
// OIDCTokenSource resolves the calling user's OIDC access
|
||||
// token for MCP servers configured with auth_type=user_oidc.
|
||||
@@ -4139,6 +4208,8 @@ func New(cfg Config) *Server {
|
||||
debugSvc.SetStaleAfter(inFlightChatStaleAfter * 3)
|
||||
return debugSvc
|
||||
},
|
||||
aibridgeTransportFactory: cfg.AIBridgeTransportFactory,
|
||||
aiGatewayRoutingEnabled: cfg.AIGatewayRoutingEnabled,
|
||||
pendingChatAcquireInterval: pendingChatAcquireInterval,
|
||||
maxChatsPerAcquire: maxChatsPerAcquire,
|
||||
inFlightChatStaleAfter: inFlightChatStaleAfter,
|
||||
@@ -5803,7 +5874,7 @@ func (p *Server) tryAutoPromoteQueuedMessage(
|
||||
database.ChatMessageVisibilityBoth,
|
||||
effectiveModelConfigID,
|
||||
chatprompt.CurrentContentVersion,
|
||||
).withCreatedBy(chat.OwnerID))
|
||||
).withCreatedBy(chat.OwnerID).withAPIKeyID(nextQueued.APIKeyID.String))
|
||||
msgs, err := insertChatMessageWithStore(ctx, tx, msgParams)
|
||||
if err != nil {
|
||||
return nil, nil, false, xerrors.Errorf("insert promoted message: %w", err)
|
||||
@@ -6322,11 +6393,39 @@ type runChatResult struct {
|
||||
ProviderKeys chatprovider.ProviderAPIKeys
|
||||
PendingDynamicToolCalls []chatloop.PendingToolCall
|
||||
FallbackProvider string
|
||||
FallbackRoute resolvedModelRoute
|
||||
FallbackModel string
|
||||
ModelBuildOptions modelBuildOptions
|
||||
TriggerMessageID int64
|
||||
HistoryTipMessageID int64
|
||||
}
|
||||
|
||||
func contextWithActiveTurnAPIKeyID(ctx context.Context, messages []database.ChatMessage) context.Context {
|
||||
apiKeyID, ok := activeTurnAPIKeyIDFromMessages(messages)
|
||||
if !ok {
|
||||
return ctx
|
||||
}
|
||||
return aibridge.WithDelegatedAPIKeyID(ctx, apiKeyID)
|
||||
}
|
||||
|
||||
func activeTurnAPIKeyIDFromMessages(messages []database.ChatMessage) (string, bool) {
|
||||
for i := len(messages) - 1; i >= 0; i-- {
|
||||
message := messages[i]
|
||||
if message.Role != database.ChatMessageRoleUser {
|
||||
continue
|
||||
}
|
||||
if message.Visibility != database.ChatMessageVisibilityBoth &&
|
||||
message.Visibility != database.ChatMessageVisibilityUser {
|
||||
continue
|
||||
}
|
||||
if !message.APIKeyID.Valid || message.APIKeyID.String == "" {
|
||||
return "", false
|
||||
}
|
||||
return message.APIKeyID.String, true
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
func allToolNames(allTools []fantasy.AgentTool) []string {
|
||||
toolNames := make([]string, 0, len(allTools))
|
||||
for _, tool := range allTools {
|
||||
@@ -6948,12 +7047,20 @@ func (p *Server) runChat(
|
||||
err error
|
||||
debugEnabled bool
|
||||
debugProvider string
|
||||
modelRoute resolvedModelRoute
|
||||
debugModel string
|
||||
)
|
||||
|
||||
// Load MCP server configs and user tokens in parallel with
|
||||
// model resolution and message loading. These queries have
|
||||
// no dependencies on each other and all hit different tables.
|
||||
messages, err = p.db.GetChatMessagesForPromptByChatID(ctx, chat.ID)
|
||||
if err != nil {
|
||||
return result, xerrors.Errorf("get chat messages: %w", err)
|
||||
}
|
||||
modelOpts := modelBuildOptionsFromMessages(messages)
|
||||
ctx = contextWithActiveTurnAPIKeyID(ctx, messages)
|
||||
|
||||
// Load MCP server configs and user tokens in parallel with model
|
||||
// resolution. These queries have no dependencies on each other and all
|
||||
// hit different tables.
|
||||
var (
|
||||
mcpConfigs []database.MCPServerConfig
|
||||
mcpTokens []database.MCPServerUserToken
|
||||
@@ -6961,7 +7068,7 @@ func (p *Server) runChat(
|
||||
var g errgroup.Group
|
||||
g.Go(func() error {
|
||||
var err error
|
||||
model, modelConfig, providerKeys, debugEnabled, debugProvider, debugModel, err = p.resolveChatModel(ctx, chat)
|
||||
model, modelConfig, providerKeys, modelRoute, debugEnabled, debugProvider, debugModel, err = p.resolveChatModel(ctx, chat, modelOpts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -6972,14 +7079,6 @@ func (p *Server) runChat(
|
||||
}
|
||||
return nil
|
||||
})
|
||||
g.Go(func() error {
|
||||
var err error
|
||||
messages, err = p.db.GetChatMessagesForPromptByChatID(ctx, chat.ID)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get chat messages: %w", err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if len(chat.MCPServerIDs) > 0 {
|
||||
g.Go(func() error {
|
||||
var err error
|
||||
@@ -7055,15 +7154,20 @@ func (p *Server) runChat(
|
||||
// registering the runtime there would inject guidance for a tool
|
||||
// that is never exposed to the model.
|
||||
if advisorCfg.Enabled && isRootChat && !isPlanModeTurn && !isExploreSubagent {
|
||||
advisorRuntime = p.newAdvisorRuntime(
|
||||
var advisorErr error
|
||||
advisorRuntime, advisorErr = p.newAdvisorRuntime(
|
||||
ctx,
|
||||
chat,
|
||||
advisorCfg,
|
||||
model,
|
||||
callConfig,
|
||||
providerKeys,
|
||||
modelOpts,
|
||||
logger,
|
||||
)
|
||||
if advisorErr != nil {
|
||||
return result, advisorErr
|
||||
}
|
||||
}
|
||||
|
||||
var advisorPromptSnapshot []fantasy.Message
|
||||
@@ -7090,7 +7194,9 @@ func (p *Server) runChat(
|
||||
result.StatusLabelModel = model
|
||||
result.ProviderKeys = providerKeys
|
||||
result.FallbackProvider = modelConfig.Provider
|
||||
result.FallbackRoute = modelRoute
|
||||
result.FallbackModel = modelConfig.Model
|
||||
result.ModelBuildOptions = modelOpts
|
||||
debugSvc := p.existingDebugService()
|
||||
// Fire title generation asynchronously so it doesn't block the
|
||||
// chat response. It uses a detached context so it can finish
|
||||
@@ -7111,7 +7217,9 @@ func (p *Server) runChat(
|
||||
modelConfig.Provider,
|
||||
modelConfig.Model,
|
||||
titleModel,
|
||||
modelRoute,
|
||||
titleProviderKeys,
|
||||
modelOpts,
|
||||
generatedTitle,
|
||||
titleLogger,
|
||||
debugSvc,
|
||||
@@ -7681,20 +7789,21 @@ func (p *Server) runChat(
|
||||
}
|
||||
|
||||
if isComputerUse {
|
||||
computerUseProviderKeys, keyErr := p.resolveUserProviderAPIKeysForProviderType(ctx, chat.OwnerID, computerUseModelProvider)
|
||||
computerUseRoute, keyErr := p.resolveModelRouteForProviderType(ctx, chat.OwnerID, computerUseModelProvider)
|
||||
if keyErr != nil {
|
||||
return result, xerrors.Errorf("resolve computer use provider API keys: %w", keyErr)
|
||||
return result, xerrors.Errorf("resolve computer use provider route: %w", keyErr)
|
||||
}
|
||||
providerKeys = computerUseProviderKeys
|
||||
providerKeys = computerUseRoute.directProviderKeys()
|
||||
|
||||
// Override model for computer use subagent.
|
||||
cuModel, cuDebugEnabled, resolvedProvider, resolvedModel, cuErr := p.resolveComputerUseModel(
|
||||
ctx,
|
||||
chat,
|
||||
providerKeys,
|
||||
computerUseRoute,
|
||||
computerUseProvider,
|
||||
computerUseModelProvider,
|
||||
computerUseModelName,
|
||||
modelOpts,
|
||||
)
|
||||
if cuErr != nil {
|
||||
return result, cuErr
|
||||
@@ -8394,45 +8503,15 @@ func (p *Server) persistChatContextSummary(
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Server) resolveModelConfigProviderHintAndKeys(
|
||||
ctx context.Context,
|
||||
ownerID uuid.UUID,
|
||||
modelConfig database.ChatModelConfig,
|
||||
fallbackKeys chatprovider.ProviderAPIKeys,
|
||||
) (string, chatprovider.ProviderAPIKeys, error) {
|
||||
providerHint := modelConfig.Provider
|
||||
if !modelConfig.AIProviderID.Valid {
|
||||
if !fallbackKeys.Empty() && userCanUseProviderKeys(fallbackKeys, providerHint) {
|
||||
return providerHint, fallbackKeys, nil
|
||||
}
|
||||
keys, err := p.resolveUserProviderAPIKeys(ctx, ownerID, uuid.Nil)
|
||||
if err != nil {
|
||||
return "", chatprovider.ProviderAPIKeys{}, xerrors.Errorf("resolve provider API keys: %w", err)
|
||||
}
|
||||
return providerHint, keys, nil
|
||||
}
|
||||
//nolint:gocritic // Manual title generation needs chatd-scoped provider reads for user-owned chats.
|
||||
provider, err := p.db.GetAIProviderByID(dbauthz.AsChatd(ctx), modelConfig.AIProviderID.UUID)
|
||||
if err != nil {
|
||||
return "", chatprovider.ProviderAPIKeys{}, xerrors.Errorf("get AI provider: %w", err)
|
||||
}
|
||||
if !provider.Enabled {
|
||||
return "", chatprovider.ProviderAPIKeys{}, xerrors.Errorf("AI provider %s is disabled", provider.ID)
|
||||
}
|
||||
providerKeys, err := p.resolveUserProviderAPIKeysForProvider(ctx, ownerID, provider)
|
||||
if err != nil {
|
||||
return "", chatprovider.ProviderAPIKeys{}, xerrors.Errorf("resolve provider API keys: %w", err)
|
||||
}
|
||||
return string(provider.Type), providerKeys, nil
|
||||
}
|
||||
|
||||
func (p *Server) resolveChatModel(
|
||||
ctx context.Context,
|
||||
chat database.Chat,
|
||||
modelOpts modelBuildOptions,
|
||||
) (
|
||||
model fantasy.LanguageModel,
|
||||
dbConfig database.ChatModelConfig,
|
||||
keys chatprovider.ProviderAPIKeys,
|
||||
route resolvedModelRoute,
|
||||
debugEnabled bool,
|
||||
resolvedProvider string,
|
||||
resolvedModel string,
|
||||
@@ -8440,57 +8519,45 @@ func (p *Server) resolveChatModel(
|
||||
) {
|
||||
dbConfig, err = p.resolveModelConfig(ctx, chat)
|
||||
if err != nil {
|
||||
return nil, database.ChatModelConfig{}, chatprovider.ProviderAPIKeys{}, false, "", "", xerrors.Errorf("resolve model config: %w", err)
|
||||
return nil, database.ChatModelConfig{}, chatprovider.ProviderAPIKeys{}, resolvedModelRoute{}, false, "", "", xerrors.Errorf("resolve model config: %w", err)
|
||||
}
|
||||
|
||||
if !dbConfig.Enabled {
|
||||
return nil, database.ChatModelConfig{}, chatprovider.ProviderAPIKeys{}, false, "", "", xerrors.Errorf("chat model config %s is disabled", dbConfig.ID)
|
||||
return nil, database.ChatModelConfig{}, chatprovider.ProviderAPIKeys{}, resolvedModelRoute{}, false, "", "", xerrors.Errorf("chat model config %s is disabled", dbConfig.ID)
|
||||
}
|
||||
|
||||
providerHint := dbConfig.Provider
|
||||
var keyErr error
|
||||
if dbConfig.AIProviderID.Valid {
|
||||
provider, err := p.db.GetAIProviderByID(ctx, dbConfig.AIProviderID.UUID)
|
||||
if err != nil {
|
||||
return nil, database.ChatModelConfig{}, chatprovider.ProviderAPIKeys{}, false, "", "", xerrors.Errorf("get AI provider: %w", err)
|
||||
}
|
||||
if !provider.Enabled {
|
||||
return nil, database.ChatModelConfig{}, chatprovider.ProviderAPIKeys{}, false, "", "", xerrors.Errorf("AI provider %s is disabled", provider.ID)
|
||||
}
|
||||
providerHint = string(provider.Type)
|
||||
keys, keyErr = p.resolveUserProviderAPIKeysForProvider(ctx, chat.OwnerID, provider)
|
||||
} else {
|
||||
keys, keyErr = p.resolveUserProviderAPIKeys(ctx, chat.OwnerID, uuid.Nil)
|
||||
}
|
||||
if keyErr != nil {
|
||||
return nil, database.ChatModelConfig{}, chatprovider.ProviderAPIKeys{}, false, "", "", xerrors.Errorf("resolve provider API keys: %w", keyErr)
|
||||
route, err = p.resolveModelRouteForConfig(ctx, chat.OwnerID, dbConfig, chatprovider.ProviderAPIKeys{})
|
||||
if err != nil {
|
||||
return nil, database.ChatModelConfig{}, chatprovider.ProviderAPIKeys{}, resolvedModelRoute{}, false, "", "", err
|
||||
}
|
||||
keys = route.directProviderKeys()
|
||||
|
||||
providerHint, err := route.providerHint()
|
||||
if err != nil {
|
||||
return nil, database.ChatModelConfig{}, chatprovider.ProviderAPIKeys{}, resolvedModelRoute{}, false, "", "", err
|
||||
}
|
||||
resolvedProvider, resolvedModel, err = chatprovider.ResolveModelWithProviderHint(
|
||||
dbConfig.Model,
|
||||
providerHint,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, database.ChatModelConfig{}, chatprovider.ProviderAPIKeys{}, false, "", "", xerrors.Errorf(
|
||||
return nil, database.ChatModelConfig{}, chatprovider.ProviderAPIKeys{}, resolvedModelRoute{}, false, "", "", xerrors.Errorf(
|
||||
"resolve model metadata: %w", err,
|
||||
)
|
||||
}
|
||||
|
||||
model, debugEnabled, err = p.newDebugAwareModelFromConfig(
|
||||
ctx,
|
||||
chat,
|
||||
providerHint,
|
||||
dbConfig.Model,
|
||||
keys,
|
||||
chatprovider.UserAgent(),
|
||||
chatprovider.CoderHeaders(chat),
|
||||
)
|
||||
model, debugEnabled, err = p.newDebugAwareModel(ctx, modelClientRequest{
|
||||
Chat: chat,
|
||||
ModelName: dbConfig.Model,
|
||||
UserAgent: chatprovider.UserAgent(),
|
||||
ExtraHeaders: chatprovider.CoderHeaders(chat),
|
||||
}, route, modelOpts)
|
||||
if err != nil {
|
||||
return nil, database.ChatModelConfig{}, chatprovider.ProviderAPIKeys{}, false, "", "", xerrors.Errorf(
|
||||
return nil, database.ChatModelConfig{}, chatprovider.ProviderAPIKeys{}, resolvedModelRoute{}, false, "", "", xerrors.Errorf(
|
||||
"create model: %w", err,
|
||||
)
|
||||
}
|
||||
return model, dbConfig, keys, debugEnabled, resolvedProvider, resolvedModel, nil
|
||||
return model, dbConfig, keys, route, debugEnabled, resolvedProvider, resolvedModel, nil
|
||||
}
|
||||
|
||||
func (p *Server) aiProviderConfig(ctx context.Context, provider database.AIProvider) (chatprovider.ConfiguredProvider, error) {
|
||||
@@ -8605,9 +8672,18 @@ func (p *Server) resolveUserProviderAPIKeysForProviderType(
|
||||
ownerID uuid.UUID,
|
||||
providerType string,
|
||||
) (chatprovider.ProviderAPIKeys, error) {
|
||||
keys, _, err := p.resolveUserProviderAPIKeysAndProviderForProviderType(ctx, ownerID, providerType)
|
||||
return keys, err
|
||||
}
|
||||
|
||||
func (p *Server) resolveUserProviderAPIKeysAndProviderForProviderType(
|
||||
ctx context.Context,
|
||||
ownerID uuid.UUID,
|
||||
providerType string,
|
||||
) (chatprovider.ProviderAPIKeys, *database.AIProvider, error) {
|
||||
providers, err := p.db.GetAIProviders(ctx, database.GetAIProvidersParams{})
|
||||
if err != nil {
|
||||
return chatprovider.ProviderAPIKeys{}, xerrors.Errorf("get enabled AI providers: %w", err)
|
||||
return chatprovider.ProviderAPIKeys{}, nil, xerrors.Errorf("get enabled AI providers: %w", err)
|
||||
}
|
||||
normalizedProviderType := chatprovider.NormalizeProvider(providerType)
|
||||
for _, provider := range providers {
|
||||
@@ -8616,13 +8692,17 @@ func (p *Server) resolveUserProviderAPIKeysForProviderType(
|
||||
}
|
||||
keys, err := p.resolveUserProviderAPIKeysForProvider(ctx, ownerID, provider)
|
||||
if err != nil {
|
||||
return chatprovider.ProviderAPIKeys{}, err
|
||||
return chatprovider.ProviderAPIKeys{}, nil, err
|
||||
}
|
||||
if userCanUseProviderKeys(keys, normalizedProviderType) {
|
||||
return keys, nil
|
||||
return keys, &provider, nil
|
||||
}
|
||||
}
|
||||
return p.resolveUserProviderAPIKeys(ctx, ownerID, uuid.Nil)
|
||||
keys, err := p.resolveUserProviderAPIKeys(ctx, ownerID, uuid.Nil)
|
||||
if err != nil {
|
||||
return chatprovider.ProviderAPIKeys{}, nil, err
|
||||
}
|
||||
return keys, nil, nil
|
||||
}
|
||||
|
||||
func (p *Server) resolveUserProviderAPIKeys(
|
||||
@@ -9391,6 +9471,7 @@ func insertSyntheticToolResultsTx(
|
||||
params := database.InsertChatMessagesParams{
|
||||
ChatID: chat.ID,
|
||||
CreatedBy: make([]uuid.UUID, n),
|
||||
APIKeyID: make([]string, n),
|
||||
ModelConfigID: make([]uuid.UUID, n),
|
||||
Role: make([]database.ChatMessageRole, n),
|
||||
Content: make([]string, n),
|
||||
@@ -9537,7 +9618,7 @@ func (p *Server) generateFinalTurnStatusLabel(
|
||||
return fallbackTurnStatusLabel(status)
|
||||
}
|
||||
|
||||
statusLabel := generateTurnStatusLabel(
|
||||
statusLabel := p.generateTurnStatusLabel(
|
||||
ctx,
|
||||
chat,
|
||||
status,
|
||||
@@ -9545,7 +9626,9 @@ func (p *Server) generateFinalTurnStatusLabel(
|
||||
runResult.FallbackProvider,
|
||||
runResult.FallbackModel,
|
||||
runResult.StatusLabelModel,
|
||||
runResult.FallbackRoute,
|
||||
runResult.ProviderKeys,
|
||||
runResult.ModelBuildOptions,
|
||||
logger,
|
||||
p.existingDebugService(),
|
||||
runResult.TriggerMessageID,
|
||||
|
||||
@@ -2,14 +2,11 @@ package chatd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatdebug"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatprovider"
|
||||
)
|
||||
@@ -109,53 +106,38 @@ func (p *Server) scheduleDebugCleanup(
|
||||
}()
|
||||
}
|
||||
|
||||
func (p *Server) newDebugAwareModelFromConfig(
|
||||
func (p *Server) newDebugAwareModel(
|
||||
ctx context.Context,
|
||||
chat database.Chat,
|
||||
providerHint string,
|
||||
modelName string,
|
||||
providerKeys chatprovider.ProviderAPIKeys,
|
||||
userAgent string,
|
||||
extraHeaders map[string]string,
|
||||
req modelClientRequest,
|
||||
route resolvedModelRoute,
|
||||
opts modelBuildOptions,
|
||||
) (fantasy.LanguageModel, bool, error) {
|
||||
provider, resolvedModel, err := chatprovider.ResolveModelWithProviderHint(modelName, providerHint)
|
||||
providerHint, err := route.providerHint()
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
provider, resolvedModel, err := chatprovider.ResolveModelWithProviderHint(req.ModelName, providerHint)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
route = route.withProviderHint(provider)
|
||||
req.ModelName = resolvedModel
|
||||
|
||||
debugSvc := p.debugService()
|
||||
debugEnabled := debugSvc != nil && debugSvc.IsEnabled(ctx, chat.ID, chat.OwnerID)
|
||||
debugEnabled := debugSvc != nil && debugSvc.IsEnabled(ctx, req.Chat.ID, req.Chat.OwnerID)
|
||||
opts.RecordHTTP = debugEnabled
|
||||
|
||||
var httpClient *http.Client
|
||||
if debugEnabled {
|
||||
httpClient = &http.Client{Transport: &chatdebug.RecordingTransport{}}
|
||||
}
|
||||
|
||||
model, err := chatprovider.ModelFromConfig(
|
||||
provider,
|
||||
resolvedModel,
|
||||
providerKeys,
|
||||
userAgent,
|
||||
extraHeaders,
|
||||
httpClient,
|
||||
)
|
||||
model, err := p.newModel(ctx, req, route, opts)
|
||||
if err != nil {
|
||||
return nil, debugEnabled, err
|
||||
}
|
||||
if model == nil {
|
||||
return nil, debugEnabled, xerrors.Errorf(
|
||||
"create model for %s/%s returned nil",
|
||||
provider,
|
||||
resolvedModel,
|
||||
)
|
||||
}
|
||||
if !debugEnabled {
|
||||
return model, false, nil
|
||||
}
|
||||
|
||||
return chatdebug.WrapModel(model, debugSvc, chatdebug.RecorderOptions{
|
||||
ChatID: chat.ID,
|
||||
OwnerID: chat.OwnerID,
|
||||
ChatID: req.Chat.ID,
|
||||
OwnerID: req.Chat.OwnerID,
|
||||
Provider: provider,
|
||||
Model: resolvedModel,
|
||||
}), true, nil
|
||||
|
||||
@@ -152,10 +152,11 @@ func TestResolveComputerUseModel_OpenAIMissingCredentials(t *testing.T) {
|
||||
model, debugEnabled, resolvedProvider, resolvedModel, err := server.resolveComputerUseModel(
|
||||
context.Background(),
|
||||
database.Chat{ID: uuid.New(), OwnerID: uuid.New()},
|
||||
chatprovider.ProviderAPIKeys{},
|
||||
newDirectModelRoute(modelProvider, chatprovider.ProviderAPIKeys{}),
|
||||
provider,
|
||||
modelProvider,
|
||||
modelName,
|
||||
modelBuildOptions{},
|
||||
)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, model)
|
||||
@@ -167,6 +168,55 @@ func TestResolveComputerUseModel_OpenAIMissingCredentials(t *testing.T) {
|
||||
require.NotContains(t, err.Error(), "ANTHROPIC_API_KEY")
|
||||
}
|
||||
|
||||
func TestResolveUserProviderAPIKeysAndProviderForProviderTypeProviderMatch(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
ownerID := uuid.New()
|
||||
providerID := uuid.New()
|
||||
|
||||
db.EXPECT().GetAIProviders(gomock.Any(), database.GetAIProvidersParams{}).Return([]database.AIProvider{
|
||||
{ID: uuid.New(), Type: database.AiProviderTypeAnthropic, Enabled: true},
|
||||
{ID: providerID, Type: database.AiProviderTypeOpenai, Enabled: true},
|
||||
}, nil)
|
||||
db.EXPECT().GetAIProviderKeysByProviderID(gomock.Any(), providerID).Return([]database.AIProviderKey{{
|
||||
ProviderID: providerID,
|
||||
APIKey: "test-key",
|
||||
}}, nil)
|
||||
|
||||
server := &Server{db: db}
|
||||
keys, aiProvider, err := server.resolveUserProviderAPIKeysAndProviderForProviderType(
|
||||
ctx,
|
||||
ownerID,
|
||||
chattool.ComputerUseProviderOpenAI,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "test-key", keys.APIKey(chattool.ComputerUseProviderOpenAI))
|
||||
require.NotNil(t, aiProvider)
|
||||
require.Equal(t, providerID, aiProvider.ID)
|
||||
require.Equal(t, database.AiProviderTypeOpenai, aiProvider.Type)
|
||||
}
|
||||
|
||||
func TestResolveModelRouteForProviderTypeAIGatewayRequiresProvider(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
|
||||
db.EXPECT().GetAIProviders(gomock.Any(), database.GetAIProvidersParams{}).Return(nil, nil)
|
||||
|
||||
server := &Server{db: db, aiGatewayRoutingEnabled: true}
|
||||
_, err := server.resolveModelRouteForProviderType(
|
||||
ctx,
|
||||
uuid.New(),
|
||||
chattool.ComputerUseProviderOpenAI,
|
||||
)
|
||||
require.ErrorContains(t, err, "AI Gateway routing requires a usable AI provider")
|
||||
}
|
||||
|
||||
func TestAppendComputerUseProviderTool(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -731,6 +781,11 @@ func TestRenameChatTitle(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func withChatMessageAPIKeyID(message database.ChatMessage, apiKeyID string) database.ChatMessage {
|
||||
message.APIKeyID = sqlNullString(apiKeyID)
|
||||
return message
|
||||
}
|
||||
|
||||
func TestRegenerateChatTitle_PersistsAndBroadcasts(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -749,6 +804,7 @@ func TestRegenerateChatTitle_PersistsAndBroadcasts(t *testing.T) {
|
||||
modelConfigID := uuid.New()
|
||||
workerID := uuid.New()
|
||||
userPrompt := "review pull request 23633 and fix review threads"
|
||||
activeAPIKeyID := "key-" + uuid.NewString()
|
||||
wantTitle := "Review PR 23633"
|
||||
|
||||
chat := database.Chat{
|
||||
@@ -814,12 +870,12 @@ func TestRegenerateChatTitle_PersistsAndBroadcasts(t *testing.T) {
|
||||
LimitVal: manualTitleMessageWindowLimit,
|
||||
},
|
||||
).Return([]database.ChatMessage{
|
||||
mustChatMessage(
|
||||
withChatMessageAPIKeyID(mustChatMessage(
|
||||
t,
|
||||
database.ChatMessageRoleUser,
|
||||
database.ChatMessageVisibilityBoth,
|
||||
codersdk.ChatMessageText(userPrompt),
|
||||
),
|
||||
), activeAPIKeyID),
|
||||
mustChatMessage(
|
||||
t,
|
||||
database.ChatMessageRoleAssistant,
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
@@ -32,6 +33,7 @@ import (
|
||||
"cdr.dev/slog/v3/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/agent/agentcontextconfig"
|
||||
"github.com/coder/coder/v2/agent/agenttest"
|
||||
"github.com/coder/coder/v2/coderd/aibridge"
|
||||
"github.com/coder/coder/v2/coderd/coderdtest"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/db2sdk"
|
||||
@@ -67,6 +69,88 @@ type recordedOpenAIRequest struct {
|
||||
ContentLength int64
|
||||
}
|
||||
|
||||
type chatAIGatewayRecordedRequest struct {
|
||||
ProviderName string
|
||||
Source aibridge.Source
|
||||
APIKeyID string
|
||||
Path string
|
||||
Authorization string
|
||||
XAPIKey string
|
||||
CoderToken string
|
||||
}
|
||||
|
||||
type chatAIGatewayTestFactory struct {
|
||||
target *url.URL
|
||||
transport http.RoundTripper
|
||||
mu sync.Mutex
|
||||
requests []chatAIGatewayRecordedRequest
|
||||
}
|
||||
|
||||
func newChatAIGatewayTestFactory(t testing.TB, targetBaseURL string) *chatAIGatewayTestFactory {
|
||||
t.Helper()
|
||||
|
||||
target, err := url.Parse(targetBaseURL)
|
||||
require.NoError(t, err)
|
||||
return &chatAIGatewayTestFactory{target: target, transport: http.DefaultTransport}
|
||||
}
|
||||
|
||||
func (f *chatAIGatewayTestFactory) TransportFor(providerName string, source aibridge.Source) (http.RoundTripper, error) {
|
||||
return chatAIGatewayRoundTripper{factory: f, providerName: providerName, source: source}, nil
|
||||
}
|
||||
|
||||
func (f *chatAIGatewayTestFactory) requestsSnapshot() []chatAIGatewayRecordedRequest {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
return slices.Clone(f.requests)
|
||||
}
|
||||
|
||||
type chatAIGatewayRoundTripper struct {
|
||||
factory *chatAIGatewayTestFactory
|
||||
providerName string
|
||||
source aibridge.Source
|
||||
}
|
||||
|
||||
func (t chatAIGatewayRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
apiKeyID, _ := aibridge.DelegatedAPIKeyIDFromContext(req.Context())
|
||||
t.factory.mu.Lock()
|
||||
t.factory.requests = append(t.factory.requests, chatAIGatewayRecordedRequest{
|
||||
ProviderName: t.providerName,
|
||||
Source: t.source,
|
||||
APIKeyID: apiKeyID,
|
||||
Path: req.URL.Path,
|
||||
Authorization: req.Header.Get("Authorization"),
|
||||
XAPIKey: req.Header.Get("X-Api-Key"),
|
||||
CoderToken: req.Header.Get(aibridge.HeaderCoderToken),
|
||||
})
|
||||
t.factory.mu.Unlock()
|
||||
|
||||
targetURL := *t.factory.target
|
||||
targetURL.Path = strings.TrimPrefix(req.URL.Path, "/v1")
|
||||
if targetURL.Path == "" {
|
||||
targetURL.Path = "/"
|
||||
}
|
||||
targetURL.RawQuery = req.URL.RawQuery
|
||||
|
||||
cloned := req.Clone(req.Context())
|
||||
cloned.URL = &targetURL
|
||||
cloned.Host = t.factory.target.Host
|
||||
return t.factory.transport.RoundTrip(cloned)
|
||||
}
|
||||
|
||||
func chatAIGatewayTransportFactoryPointer(factory aibridge.TransportFactory) *atomic.Pointer[aibridge.TransportFactory] {
|
||||
var ptr atomic.Pointer[aibridge.TransportFactory]
|
||||
ptr.Store(&factory)
|
||||
return &ptr
|
||||
}
|
||||
|
||||
func directChatRoutingDeploymentValues(t testing.TB) *codersdk.DeploymentValues {
|
||||
t.Helper()
|
||||
|
||||
values := coderdtest.DeploymentValues(t)
|
||||
require.NoError(t, values.AI.Chat.AIGatewayRoutingEnabled.Set("false"))
|
||||
return values
|
||||
}
|
||||
|
||||
func openAIToolName(tool chattest.OpenAITool) string {
|
||||
return cmp.Or(tool.Function.Name, tool.Name, tool.Type)
|
||||
}
|
||||
@@ -231,7 +315,7 @@ func TestSubagentChatExcludesWorkspaceProvisioningTools(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
deploymentValues := coderdtest.DeploymentValues(t)
|
||||
deploymentValues := directChatRoutingDeploymentValues(t)
|
||||
client := coderdtest.New(t, &coderdtest.Options{
|
||||
DeploymentValues: deploymentValues,
|
||||
IncludeProvisionerDaemon: true,
|
||||
@@ -388,7 +472,7 @@ func TestPlanModeSubagentChatExcludesAskUserQuestion(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
deploymentValues := coderdtest.DeploymentValues(t)
|
||||
deploymentValues := directChatRoutingDeploymentValues(t)
|
||||
client := coderdtest.New(t, &coderdtest.Options{
|
||||
DeploymentValues: deploymentValues,
|
||||
IncludeProvisionerDaemon: true,
|
||||
@@ -555,7 +639,7 @@ func TestExploreSubagentIsReadOnly(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
deploymentValues := coderdtest.DeploymentValues(t)
|
||||
deploymentValues := directChatRoutingDeploymentValues(t)
|
||||
client, db := coderdtest.NewWithDatabase(t, &coderdtest.Options{
|
||||
DeploymentValues: deploymentValues,
|
||||
IncludeProvisionerDaemon: true,
|
||||
@@ -1796,6 +1880,73 @@ func TestUpdateChatHeartbeatsRequiresOwnership(t *testing.T) {
|
||||
require.Equal(t, chat.ID, ids[0])
|
||||
}
|
||||
|
||||
func TestCreateChatPersistsAPIKeyIDOnInitialUserMessage(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
replica := newTestServer(t, db, ps, uuid.New())
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
user, org, model := seedChatDependencies(t, db)
|
||||
apiKey, _ := dbgen.APIKey(t, db, database.APIKey{UserID: user.ID})
|
||||
|
||||
chat, err := replica.CreateChat(ctx, chatd.CreateOptions{
|
||||
OrganizationID: org.ID,
|
||||
OwnerID: user.ID,
|
||||
Title: "create-chat-api-key-id",
|
||||
ModelConfigID: model.ID,
|
||||
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
|
||||
APIKeyID: apiKey.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
messages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{
|
||||
ChatID: chat.ID,
|
||||
AfterID: 0,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Len(t, messages, 1)
|
||||
require.Equal(t, database.ChatMessageRoleUser, messages[0].Role)
|
||||
require.True(t, messages[0].APIKeyID.Valid)
|
||||
require.Equal(t, apiKey.ID, messages[0].APIKeyID.String)
|
||||
}
|
||||
|
||||
func TestSendMessagePersistsAPIKeyIDOnUserMessage(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
replica := newTestServer(t, db, ps, uuid.New())
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
user, org, model := seedChatDependencies(t, db)
|
||||
apiKey, _ := dbgen.APIKey(t, db, database.APIKey{UserID: user.ID})
|
||||
|
||||
chat := dbgen.Chat(t, db, database.Chat{
|
||||
OrganizationID: org.ID,
|
||||
OwnerID: user.ID,
|
||||
LastModelConfigID: model.ID,
|
||||
Title: "send-message-api-key-id",
|
||||
})
|
||||
|
||||
result, err := replica.SendMessage(ctx, chatd.SendMessageOptions{
|
||||
ChatID: chat.ID,
|
||||
CreatedBy: user.ID,
|
||||
Content: []codersdk.ChatMessagePart{
|
||||
codersdk.ChatMessageText("message with api key id"),
|
||||
},
|
||||
APIKeyID: apiKey.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.False(t, result.Queued)
|
||||
require.True(t, result.Message.APIKeyID.Valid)
|
||||
require.Equal(t, apiKey.ID, result.Message.APIKeyID.String)
|
||||
|
||||
stored, err := db.GetChatMessageByID(ctx, result.Message.ID)
|
||||
require.NoError(t, err)
|
||||
require.True(t, stored.APIKeyID.Valid)
|
||||
require.Equal(t, apiKey.ID, stored.APIKeyID.String)
|
||||
}
|
||||
|
||||
func TestSendMessageQueueBehaviorQueuesWhenBusy(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -2143,15 +2294,20 @@ func TestEditMessageUpdatesAndTruncatesAndClearsQueue(t *testing.T) {
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
apiKey, _ := dbgen.APIKey(t, db, database.APIKey{UserID: user.ID})
|
||||
apiKeyID := apiKey.ID
|
||||
editResult, err := replica.EditMessage(ctx, chatd.EditMessageOptions{
|
||||
ChatID: chat.ID,
|
||||
EditedMessageID: editedMessageID,
|
||||
Content: []codersdk.ChatMessagePart{codersdk.ChatMessageText("edited")},
|
||||
APIKeyID: apiKeyID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
// The edited message is soft-deleted and a new message is inserted,
|
||||
// so the returned message ID will differ from the original.
|
||||
require.NotEqual(t, editedMessageID, editResult.Message.ID)
|
||||
require.True(t, editResult.Message.APIKeyID.Valid)
|
||||
require.Equal(t, apiKeyID, editResult.Message.APIKeyID.String)
|
||||
require.Equal(t, database.ChatStatusPending, editResult.Chat.Status)
|
||||
require.False(t, editResult.Chat.WorkerID.Valid)
|
||||
|
||||
@@ -2166,6 +2322,8 @@ func TestEditMessageUpdatesAndTruncatesAndClearsQueue(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
require.Len(t, messages, 1)
|
||||
require.Equal(t, editResult.Message.ID, messages[0].ID)
|
||||
require.True(t, messages[0].APIKeyID.Valid)
|
||||
require.Equal(t, apiKeyID, messages[0].APIKeyID.String)
|
||||
onlyMessage := db2sdk.ChatMessage(messages[0])
|
||||
require.Len(t, onlyMessage.Content, 1)
|
||||
require.Equal(t, "edited", onlyMessage.Content[0].Text)
|
||||
@@ -2366,6 +2524,7 @@ func TestPromoteQueuedAllowsAlreadyQueuedMessageWhenUsageLimitReached(t *testing
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
user, org, model := seedChatDependencies(t, db)
|
||||
apiKey, _ := dbgen.APIKey(t, db, database.APIKey{UserID: user.ID})
|
||||
|
||||
_, err := db.UpsertChatUsageLimitConfig(ctx, database.UpsertChatUsageLimitConfigParams{
|
||||
Enabled: true,
|
||||
@@ -2400,11 +2559,14 @@ func TestPromoteQueuedAllowsAlreadyQueuedMessageWhenUsageLimitReached(t *testing
|
||||
queuedResult, err := replica.SendMessage(ctx, chatd.SendMessageOptions{
|
||||
ChatID: chat.ID,
|
||||
Content: []codersdk.ChatMessagePart{codersdk.ChatMessageText("queued")},
|
||||
APIKeyID: apiKey.ID,
|
||||
BusyBehavior: chatd.SendMessageBusyBehaviorQueue,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.True(t, queuedResult.Queued)
|
||||
require.NotNil(t, queuedResult.QueuedMessage)
|
||||
require.True(t, queuedResult.QueuedMessage.APIKeyID.Valid)
|
||||
require.Equal(t, apiKey.ID, queuedResult.QueuedMessage.APIKeyID.String)
|
||||
|
||||
assistantContent, err := chatprompt.MarshalParts([]codersdk.ChatMessagePart{
|
||||
codersdk.ChatMessageText("assistant"),
|
||||
@@ -2437,6 +2599,8 @@ func TestPromoteQueuedAllowsAlreadyQueuedMessageWhenUsageLimitReached(t *testing
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, database.ChatMessageRoleUser, result.PromotedMessage.Role)
|
||||
require.True(t, result.PromotedMessage.APIKeyID.Valid)
|
||||
require.Equal(t, apiKey.ID, result.PromotedMessage.APIKeyID.String)
|
||||
|
||||
queued, err := db.GetChatQueuedMessages(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
@@ -4864,7 +5028,7 @@ func TestCreateWorkspaceTool_EndToEnd(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
deploymentValues := coderdtest.DeploymentValues(t)
|
||||
deploymentValues := directChatRoutingDeploymentValues(t)
|
||||
client := coderdtest.New(t, &coderdtest.Options{
|
||||
DeploymentValues: deploymentValues,
|
||||
IncludeProvisionerDaemon: true,
|
||||
@@ -5029,7 +5193,7 @@ func TestStartWorkspaceTool_EndToEnd(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitSuperLong)
|
||||
deploymentValues := coderdtest.DeploymentValues(t)
|
||||
deploymentValues := directChatRoutingDeploymentValues(t)
|
||||
client := coderdtest.New(t, &coderdtest.Options{
|
||||
DeploymentValues: deploymentValues,
|
||||
IncludeProvisionerDaemon: true,
|
||||
@@ -7252,6 +7416,100 @@ func TestProcessChat_UserProviderKey_Success(t *testing.T) {
|
||||
require.Contains(t, recordedAuthHeaders, "Bearer "+userAPIKey)
|
||||
}
|
||||
|
||||
func TestProcessChat_AIGatewayRoutingUsesDelegatedAPIKey(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
openAIURL := chattest.NewOpenAI(t, func(req *chattest.OpenAIRequest) chattest.OpenAIResponse {
|
||||
if req.Stream {
|
||||
return chattest.OpenAIStreamingResponse(
|
||||
chattest.OpenAITextChunks("hello through AI Gateway")...,
|
||||
)
|
||||
}
|
||||
return chattest.OpenAINonStreamingResponse(`{"title":"AI Gateway Chat"}`)
|
||||
})
|
||||
factory := newChatAIGatewayTestFactory(t, openAIURL)
|
||||
|
||||
user := dbgen.User(t, db, database.User{})
|
||||
org := dbgen.Organization(t, db, database.Organization{})
|
||||
dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
||||
UserID: user.ID,
|
||||
OrganizationID: org.ID,
|
||||
})
|
||||
provider := dbgen.AIProvider(t, db, database.AIProvider{
|
||||
Type: database.AiProviderTypeOpenai,
|
||||
Name: "primary-openai-" + uuid.NewString(),
|
||||
BaseUrl: openAIURL,
|
||||
})
|
||||
model := dbgen.ChatModelConfig(t, db, database.ChatModelConfig{
|
||||
Provider: string(database.AiProviderTypeOpenai),
|
||||
Model: "gpt-4o-mini",
|
||||
IsDefault: true,
|
||||
AIProviderID: uuid.NullUUID{UUID: provider.ID, Valid: true},
|
||||
})
|
||||
apiKey, _ := dbgen.APIKey(t, db, database.APIKey{UserID: user.ID})
|
||||
_, err := db.UpsertUserAIProviderKey(ctx, database.UpsertUserAIProviderKeyParams{
|
||||
ID: uuid.New(),
|
||||
UserID: user.ID,
|
||||
AIProviderID: provider.ID,
|
||||
APIKey: "sk-user-aibridge",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
creator := newTestServer(t, db, ps, uuid.New())
|
||||
chat, err := creator.CreateChat(ctx, chatd.CreateOptions{
|
||||
OrganizationID: org.ID,
|
||||
OwnerID: user.ID,
|
||||
Title: "aigateway-routing",
|
||||
ModelConfigID: model.ID,
|
||||
APIKeyID: apiKey.ID,
|
||||
InitialUserContent: []codersdk.ChatMessagePart{
|
||||
codersdk.ChatMessageText("say hello"),
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, events, cancel, ok := creator.Subscribe(ctx, chat.ID, nil, 0)
|
||||
require.True(t, ok)
|
||||
t.Cleanup(cancel)
|
||||
|
||||
_ = newActiveTestServer(t, db, ps, func(cfg *chatd.Config) {
|
||||
cfg.AIBridgeTransportFactory = chatAIGatewayTransportFactoryPointer(factory)
|
||||
cfg.AIGatewayRoutingEnabled = true
|
||||
cfg.AllowBYOK = true
|
||||
cfg.AllowBYOKSet = true
|
||||
})
|
||||
|
||||
terminalStatus := waitForTerminalChatStatusEvent(ctx, t, events)
|
||||
require.Equal(t, codersdk.ChatStatusWaiting, terminalStatus)
|
||||
|
||||
chatResult := waitForTerminalChat(ctx, t, db, chat.ID)
|
||||
require.Equal(t, database.ChatStatusWaiting, chatResult.Status)
|
||||
require.False(t, chatResult.LastError.Valid)
|
||||
|
||||
requests := factory.requestsSnapshot()
|
||||
require.NotEmpty(t, requests)
|
||||
require.Contains(t, requests, chatAIGatewayRecordedRequest{
|
||||
ProviderName: provider.Name,
|
||||
Source: aibridge.SourceAgents,
|
||||
APIKeyID: apiKey.ID,
|
||||
Path: "/v1/responses",
|
||||
Authorization: "Bearer sk-user-aibridge",
|
||||
CoderToken: "delegated",
|
||||
})
|
||||
for _, req := range requests {
|
||||
require.Equal(t, provider.Name, req.ProviderName)
|
||||
require.Equal(t, aibridge.SourceAgents, req.Source)
|
||||
require.Equal(t, apiKey.ID, req.APIKeyID)
|
||||
require.Equal(t, "Bearer sk-user-aibridge", req.Authorization)
|
||||
require.Empty(t, req.XAPIKey)
|
||||
require.Equal(t, "delegated", req.CoderToken)
|
||||
require.True(t, strings.HasPrefix(req.Path, "/v1/"), "unexpected aibridge path %q", req.Path)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProcessChat_UserProviderKey_MissingKeyError(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -8493,7 +8751,7 @@ func TestAgentContextFilesAndSkillsLoadedIntoChat(t *testing.T) {
|
||||
))
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitSuperLong)
|
||||
deploymentValues := coderdtest.DeploymentValues(t)
|
||||
deploymentValues := directChatRoutingDeploymentValues(t)
|
||||
client := coderdtest.New(t, &coderdtest.Options{
|
||||
DeploymentValues: deploymentValues,
|
||||
IncludeProvisionerDaemon: true,
|
||||
|
||||
@@ -60,10 +60,11 @@ func (p *Server) computerUseProviderAndModelFromConfig(
|
||||
func (p *Server) resolveComputerUseModel(
|
||||
ctx context.Context,
|
||||
chat database.Chat,
|
||||
providerKeys chatprovider.ProviderAPIKeys,
|
||||
route resolvedModelRoute,
|
||||
computerUseProvider string,
|
||||
computerUseModelProvider string,
|
||||
computerUseModelName string,
|
||||
modelOpts modelBuildOptions,
|
||||
) (
|
||||
model fantasy.LanguageModel,
|
||||
debugEnabled bool,
|
||||
@@ -84,15 +85,12 @@ func (p *Server) resolveComputerUseModel(
|
||||
)
|
||||
}
|
||||
|
||||
model, debugEnabled, err = p.newDebugAwareModelFromConfig(
|
||||
ctx,
|
||||
chat,
|
||||
computerUseModelProvider,
|
||||
computerUseModelName,
|
||||
providerKeys,
|
||||
chatprovider.UserAgent(),
|
||||
chatprovider.CoderHeaders(chat),
|
||||
)
|
||||
model, debugEnabled, err = p.newDebugAwareModel(ctx, modelClientRequest{
|
||||
Chat: chat,
|
||||
ModelName: computerUseModelName,
|
||||
UserAgent: chatprovider.UserAgent(),
|
||||
ExtraHeaders: chatprovider.CoderHeaders(chat),
|
||||
}, route, modelOpts)
|
||||
if err != nil {
|
||||
return nil, false, "", "", xerrors.Errorf(
|
||||
"resolve computer use model for provider %q model %q: %w",
|
||||
|
||||
@@ -0,0 +1,168 @@
|
||||
package chatd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatprovider"
|
||||
)
|
||||
|
||||
type modelClientRequest struct {
|
||||
Chat database.Chat
|
||||
ModelName string
|
||||
UserAgent string
|
||||
ExtraHeaders map[string]string
|
||||
}
|
||||
|
||||
type modelBuildOptions struct {
|
||||
ActiveAPIKeyID string
|
||||
RecordHTTP bool
|
||||
}
|
||||
|
||||
func modelBuildOptionsFromMessages(messages []database.ChatMessage) modelBuildOptions {
|
||||
apiKeyID, _ := activeTurnAPIKeyIDFromMessages(messages)
|
||||
return modelBuildOptions{ActiveAPIKeyID: apiKeyID}
|
||||
}
|
||||
|
||||
type modelRouteKind int
|
||||
|
||||
const (
|
||||
modelRouteKindDirect modelRouteKind = iota + 1
|
||||
modelRouteKindAIGateway
|
||||
)
|
||||
|
||||
type resolvedModelRoute struct {
|
||||
kind modelRouteKind
|
||||
direct directModelRoute
|
||||
aiGateway aiGatewayModelRoute
|
||||
}
|
||||
|
||||
func newDirectModelRoute(providerHint string, keys chatprovider.ProviderAPIKeys) resolvedModelRoute {
|
||||
return resolvedModelRoute{
|
||||
kind: modelRouteKindDirect,
|
||||
direct: directModelRoute{
|
||||
ProviderHint: providerHint,
|
||||
Keys: keys,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (r resolvedModelRoute) providerHint() (string, error) {
|
||||
switch r.kind {
|
||||
case modelRouteKindDirect:
|
||||
return r.direct.ProviderHint, nil
|
||||
case modelRouteKindAIGateway:
|
||||
return r.aiGateway.ModelProviderHint, nil
|
||||
default:
|
||||
return "", xerrors.New("model route is not configured")
|
||||
}
|
||||
}
|
||||
|
||||
func (r resolvedModelRoute) withProviderHint(providerHint string) resolvedModelRoute {
|
||||
switch r.kind {
|
||||
case modelRouteKindDirect:
|
||||
r.direct.ProviderHint = providerHint
|
||||
case modelRouteKindAIGateway:
|
||||
r.aiGateway.ModelProviderHint = providerHint
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
func (r resolvedModelRoute) directProviderKeys() chatprovider.ProviderAPIKeys {
|
||||
if r.kind != modelRouteKindDirect {
|
||||
return chatprovider.ProviderAPIKeys{}
|
||||
}
|
||||
return r.direct.Keys
|
||||
}
|
||||
|
||||
func (p *Server) enabledAIProviderByID(ctx context.Context, providerID uuid.UUID) (database.AIProvider, error) {
|
||||
provider, err := p.db.GetAIProviderByID(ctx, providerID)
|
||||
if err != nil {
|
||||
return database.AIProvider{}, xerrors.Errorf("get AI provider: %w", err)
|
||||
}
|
||||
if !provider.Enabled {
|
||||
return database.AIProvider{}, xerrors.Errorf("AI provider %s is disabled", provider.ID)
|
||||
}
|
||||
return provider, nil
|
||||
}
|
||||
|
||||
func (p *Server) shouldUseAIGatewayRouting() bool {
|
||||
return p.aiGatewayRoutingEnabled
|
||||
}
|
||||
|
||||
func (p *Server) resolveModelRouteForConfig(
|
||||
ctx context.Context,
|
||||
ownerID uuid.UUID,
|
||||
modelConfig database.ChatModelConfig,
|
||||
fallbackKeys chatprovider.ProviderAPIKeys,
|
||||
) (resolvedModelRoute, error) {
|
||||
if p.shouldUseAIGatewayRouting() {
|
||||
return p.resolveAIGatewayModelRouteForConfig(ctx, ownerID, modelConfig)
|
||||
}
|
||||
return p.resolveDirectModelRouteForConfig(ctx, ownerID, modelConfig, fallbackKeys)
|
||||
}
|
||||
|
||||
func (p *Server) resolveModelRouteForProviderType(
|
||||
ctx context.Context,
|
||||
ownerID uuid.UUID,
|
||||
providerType string,
|
||||
) (resolvedModelRoute, error) {
|
||||
if p.shouldUseAIGatewayRouting() {
|
||||
return p.resolveAIGatewayModelRouteForProviderType(ctx, ownerID, providerType)
|
||||
}
|
||||
return p.resolveDirectModelRouteForProviderType(ctx, ownerID, providerType)
|
||||
}
|
||||
|
||||
func (p *Server) newModel(
|
||||
ctx context.Context,
|
||||
req modelClientRequest,
|
||||
route resolvedModelRoute,
|
||||
opts modelBuildOptions,
|
||||
) (fantasy.LanguageModel, error) {
|
||||
switch route.kind {
|
||||
case modelRouteKindDirect:
|
||||
return p.newDirectModel(ctx, req, route.direct, opts)
|
||||
case modelRouteKindAIGateway:
|
||||
return p.newAIGatewayModel(ctx, req, route.aiGateway, opts)
|
||||
default:
|
||||
return nil, xerrors.New("model route is not configured")
|
||||
}
|
||||
}
|
||||
|
||||
func newLanguageModel(
|
||||
providerHint string,
|
||||
modelName string,
|
||||
providerKeys chatprovider.ProviderAPIKeys,
|
||||
userAgent string,
|
||||
extraHeaders map[string]string,
|
||||
httpClient *http.Client,
|
||||
) (fantasy.LanguageModel, error) {
|
||||
model, err := chatprovider.ModelFromConfig(
|
||||
providerHint,
|
||||
modelName,
|
||||
providerKeys,
|
||||
userAgent,
|
||||
extraHeaders,
|
||||
httpClient,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if model == nil {
|
||||
provider, resolvedModel, resolveErr := chatprovider.ResolveModelWithProviderHint(modelName, providerHint)
|
||||
if resolveErr != nil {
|
||||
return nil, resolveErr
|
||||
}
|
||||
return nil, xerrors.Errorf(
|
||||
"create model for %s/%s returned nil",
|
||||
provider,
|
||||
resolvedModel,
|
||||
)
|
||||
}
|
||||
return model, nil
|
||||
}
|
||||
@@ -0,0 +1,292 @@
|
||||
package chatd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"charm.land/fantasy"
|
||||
fantasyanthropic "charm.land/fantasy/providers/anthropic"
|
||||
fantasyopenai "charm.land/fantasy/providers/openai"
|
||||
fantasyopenaicompat "charm.land/fantasy/providers/openaicompat"
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/aibridge"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatdebug"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatprovider"
|
||||
)
|
||||
|
||||
const (
|
||||
aibridgeLocalBaseURL = "http://coder-aibridge"
|
||||
// aibridgePlaceholderAPIKey satisfies fantasy clients that require a
|
||||
// non-empty API key before aibridged resolves the real credential.
|
||||
aibridgePlaceholderAPIKey = "coder-aibridge"
|
||||
aibridgeDelegatedBYOKMarker = "delegated"
|
||||
)
|
||||
|
||||
type aiGatewayModelRoute struct {
|
||||
Provider database.AIProvider
|
||||
ModelProviderHint string
|
||||
ProviderAuth aiGatewayProviderAuth
|
||||
}
|
||||
|
||||
func newAIGatewayModelRoute(
|
||||
provider database.AIProvider,
|
||||
modelProviderHint string,
|
||||
auth aiGatewayProviderAuth,
|
||||
) resolvedModelRoute {
|
||||
return resolvedModelRoute{
|
||||
kind: modelRouteKindAIGateway,
|
||||
aiGateway: aiGatewayModelRoute{
|
||||
Provider: provider,
|
||||
ModelProviderHint: modelProviderHint,
|
||||
ProviderAuth: auth,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
type aiGatewayProviderAuth struct {
|
||||
Headers map[string]string
|
||||
}
|
||||
|
||||
func (aiGatewayProviderAuth) String() string {
|
||||
return "aiGatewayProviderAuth{Headers:<redacted>}"
|
||||
}
|
||||
|
||||
func (a aiGatewayProviderAuth) GoString() string {
|
||||
return a.String()
|
||||
}
|
||||
|
||||
type aiGatewayRequestFormat int
|
||||
|
||||
const (
|
||||
aiGatewayRequestFormatOpenAI aiGatewayRequestFormat = iota
|
||||
aiGatewayRequestFormatAnthropic
|
||||
)
|
||||
|
||||
type aiGatewayRoundTripper struct {
|
||||
base http.RoundTripper
|
||||
apiKeyID string
|
||||
providerAuth aiGatewayProviderAuth
|
||||
}
|
||||
|
||||
func (t *aiGatewayRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
ctx := aibridge.WithDelegatedAPIKeyID(req.Context(), t.apiKeyID)
|
||||
cloned := req.Clone(ctx)
|
||||
for name, value := range t.providerAuth.Headers {
|
||||
cloned.Header.Set(name, value)
|
||||
}
|
||||
if len(t.providerAuth.Headers) > 0 {
|
||||
cloned.Header.Set(aibridge.HeaderCoderToken, aibridgeDelegatedBYOKMarker)
|
||||
}
|
||||
return t.base.RoundTrip(cloned)
|
||||
}
|
||||
|
||||
func (p *Server) newAIGatewayModel(
|
||||
_ context.Context,
|
||||
req modelClientRequest,
|
||||
route aiGatewayModelRoute,
|
||||
opts modelBuildOptions,
|
||||
) (fantasy.LanguageModel, error) {
|
||||
if route.Provider.ID == uuid.Nil {
|
||||
return nil, xerrors.New("AI Gateway routing requires a concrete AI provider")
|
||||
}
|
||||
if route.Provider.Name == "" {
|
||||
return nil, xerrors.New("AI Gateway routing requires an AI provider name")
|
||||
}
|
||||
if opts.ActiveAPIKeyID == "" {
|
||||
return nil, xerrors.New("AI Gateway routing requires the active turn API key ID")
|
||||
}
|
||||
|
||||
factoryPtr := p.aibridgeTransportFactory
|
||||
if factoryPtr == nil {
|
||||
return nil, xerrors.New("AI Gateway transport factory is not configured")
|
||||
}
|
||||
factory := factoryPtr.Load()
|
||||
if factory == nil || *factory == nil {
|
||||
return nil, xerrors.New("AI Gateway transport factory is not configured")
|
||||
}
|
||||
rt, err := (*factory).TransportFor(route.Provider.Name, aibridge.SourceAgents)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("create AI Gateway transport: %w", err)
|
||||
}
|
||||
baseRT := http.RoundTripper(&aiGatewayRoundTripper{
|
||||
base: rt,
|
||||
apiKeyID: opts.ActiveAPIKeyID,
|
||||
providerAuth: route.ProviderAuth,
|
||||
})
|
||||
if opts.RecordHTTP {
|
||||
baseRT = &chatdebug.RecordingTransport{Base: baseRT}
|
||||
}
|
||||
|
||||
config := fantasyConfigForAIBridge(route.Provider.Type)
|
||||
return newLanguageModel(
|
||||
config.ProviderHint,
|
||||
req.ModelName,
|
||||
config.Keys,
|
||||
req.UserAgent,
|
||||
req.ExtraHeaders,
|
||||
&http.Client{Transport: baseRT},
|
||||
)
|
||||
}
|
||||
|
||||
type aibridgeFantasyConfig struct {
|
||||
ProviderHint string
|
||||
Keys chatprovider.ProviderAPIKeys
|
||||
}
|
||||
|
||||
func fantasyConfigForAIBridge(providerType database.AIProviderType) aibridgeFantasyConfig {
|
||||
var fantasyProvider string
|
||||
baseURL := aibridgeLocalBaseURL + "/v1"
|
||||
switch providerType {
|
||||
case database.AiProviderTypeAnthropic, database.AiProviderTypeBedrock:
|
||||
fantasyProvider = fantasyanthropic.Name
|
||||
baseURL = aibridgeLocalBaseURL
|
||||
case database.AiProviderTypeOpenai:
|
||||
fantasyProvider = fantasyopenai.Name
|
||||
default:
|
||||
fantasyProvider = fantasyopenaicompat.Name
|
||||
}
|
||||
return aibridgeFantasyConfig{
|
||||
ProviderHint: fantasyProvider,
|
||||
Keys: chatprovider.ProviderAPIKeys{
|
||||
ByProvider: map[string]string{
|
||||
fantasyProvider: aibridgePlaceholderAPIKey,
|
||||
},
|
||||
BaseURLByProvider: map[string]string{
|
||||
fantasyProvider: baseURL,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func aiGatewayRequestFormatForProviderType(providerType database.AIProviderType) aiGatewayRequestFormat {
|
||||
switch providerType {
|
||||
case database.AiProviderTypeAnthropic, database.AiProviderTypeBedrock:
|
||||
return aiGatewayRequestFormatAnthropic
|
||||
default:
|
||||
return aiGatewayRequestFormatOpenAI
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Server) aiGatewayProviderAuthForUser(
|
||||
ctx context.Context,
|
||||
ownerID uuid.UUID,
|
||||
provider database.AIProvider,
|
||||
format aiGatewayRequestFormat,
|
||||
) (aiGatewayProviderAuth, error) {
|
||||
if !p.allowBYOK {
|
||||
return aiGatewayProviderAuth{}, nil
|
||||
}
|
||||
userKey, err := p.db.GetUserAIProviderKeyByProviderID(ctx, database.GetUserAIProviderKeyByProviderIDParams{
|
||||
UserID: ownerID,
|
||||
AIProviderID: provider.ID,
|
||||
})
|
||||
if err != nil {
|
||||
if xerrors.Is(err, sql.ErrNoRows) {
|
||||
return aiGatewayProviderAuth{}, nil
|
||||
}
|
||||
return aiGatewayProviderAuth{}, xerrors.Errorf("get user AI provider key: %w", err)
|
||||
}
|
||||
apiKey := strings.TrimSpace(userKey.APIKey)
|
||||
if apiKey == "" {
|
||||
return aiGatewayProviderAuth{}, nil
|
||||
}
|
||||
|
||||
headers := map[string]string{}
|
||||
switch format {
|
||||
case aiGatewayRequestFormatAnthropic:
|
||||
headers["X-Api-Key"] = apiKey
|
||||
default:
|
||||
headers["Authorization"] = "Bearer " + apiKey
|
||||
}
|
||||
return aiGatewayProviderAuth{Headers: headers}, nil
|
||||
}
|
||||
|
||||
func (p *Server) resolveAIGatewayRoute(
|
||||
ctx context.Context,
|
||||
ownerID uuid.UUID,
|
||||
provider database.AIProvider,
|
||||
modelProviderHint string,
|
||||
) (resolvedModelRoute, error) {
|
||||
auth, err := p.aiGatewayProviderAuthForUser(
|
||||
ctx,
|
||||
ownerID,
|
||||
provider,
|
||||
aiGatewayRequestFormatForProviderType(provider.Type),
|
||||
)
|
||||
if err != nil {
|
||||
return resolvedModelRoute{}, xerrors.Errorf("resolve AI Gateway provider auth: %w", err)
|
||||
}
|
||||
return newAIGatewayModelRoute(provider, modelProviderHint, auth), nil
|
||||
}
|
||||
|
||||
func (p *Server) resolveAIGatewayModelRouteForConfig(
|
||||
ctx context.Context,
|
||||
ownerID uuid.UUID,
|
||||
modelConfig database.ChatModelConfig,
|
||||
) (resolvedModelRoute, error) {
|
||||
provider, err := p.gatewayProviderForConfig(ctx, modelConfig)
|
||||
if err != nil {
|
||||
return resolvedModelRoute{}, err
|
||||
}
|
||||
return p.resolveAIGatewayRoute(ctx, ownerID, provider, string(provider.Type))
|
||||
}
|
||||
|
||||
func (p *Server) resolveAIGatewayModelRouteForProviderType(
|
||||
ctx context.Context,
|
||||
ownerID uuid.UUID,
|
||||
providerType string,
|
||||
) (resolvedModelRoute, error) {
|
||||
provider, err := p.aiProviderForProviderType(ctx, providerType)
|
||||
if err != nil {
|
||||
return resolvedModelRoute{}, err
|
||||
}
|
||||
return p.resolveAIGatewayRoute(
|
||||
ctx,
|
||||
ownerID,
|
||||
provider,
|
||||
chatprovider.NormalizeProvider(providerType),
|
||||
)
|
||||
}
|
||||
|
||||
func (p *Server) gatewayProviderForConfig(
|
||||
ctx context.Context,
|
||||
modelConfig database.ChatModelConfig,
|
||||
) (database.AIProvider, error) {
|
||||
if !modelConfig.AIProviderID.Valid {
|
||||
return database.AIProvider{}, xerrors.Errorf(
|
||||
"AI Gateway routing requires AI provider metadata for model config %s (%s)",
|
||||
modelConfig.ID,
|
||||
modelConfig.Model,
|
||||
)
|
||||
}
|
||||
return p.enabledAIProviderByID(ctx, modelConfig.AIProviderID.UUID)
|
||||
}
|
||||
|
||||
func (p *Server) aiProviderForProviderType(
|
||||
ctx context.Context,
|
||||
providerType string,
|
||||
) (database.AIProvider, error) {
|
||||
providers, err := p.db.GetAIProviders(ctx, database.GetAIProvidersParams{})
|
||||
if err != nil {
|
||||
return database.AIProvider{}, xerrors.Errorf("get enabled AI providers: %w", err)
|
||||
}
|
||||
normalizedProviderType := chatprovider.NormalizeProvider(providerType)
|
||||
for _, provider := range providers {
|
||||
if !provider.Enabled {
|
||||
continue
|
||||
}
|
||||
if chatprovider.NormalizeProvider(string(provider.Type)) != normalizedProviderType {
|
||||
continue
|
||||
}
|
||||
return provider, nil
|
||||
}
|
||||
return database.AIProvider{}, xerrors.Errorf(
|
||||
"AI Gateway routing requires a usable AI provider for provider type %q",
|
||||
providerType,
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,93 @@
|
||||
package chatd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatdebug"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatprovider"
|
||||
)
|
||||
|
||||
type directModelRoute struct {
|
||||
ProviderHint string
|
||||
Keys chatprovider.ProviderAPIKeys
|
||||
}
|
||||
|
||||
func (*Server) newDirectModel(
|
||||
_ context.Context,
|
||||
req modelClientRequest,
|
||||
route directModelRoute,
|
||||
opts modelBuildOptions,
|
||||
) (fantasy.LanguageModel, error) {
|
||||
var httpClient *http.Client
|
||||
if opts.RecordHTTP {
|
||||
httpClient = &http.Client{Transport: &chatdebug.RecordingTransport{}}
|
||||
}
|
||||
return newLanguageModel(
|
||||
route.ProviderHint,
|
||||
req.ModelName,
|
||||
route.Keys,
|
||||
req.UserAgent,
|
||||
req.ExtraHeaders,
|
||||
httpClient,
|
||||
)
|
||||
}
|
||||
|
||||
func (p *Server) resolveDirectModelRouteForConfig(
|
||||
ctx context.Context,
|
||||
ownerID uuid.UUID,
|
||||
modelConfig database.ChatModelConfig,
|
||||
fallbackKeys chatprovider.ProviderAPIKeys,
|
||||
) (resolvedModelRoute, error) {
|
||||
providerHint, provider, err := p.directProviderHintAndProviderForConfig(ctx, modelConfig)
|
||||
if err != nil {
|
||||
return resolvedModelRoute{}, err
|
||||
}
|
||||
if provider == nil {
|
||||
if !fallbackKeys.Empty() && userCanUseProviderKeys(fallbackKeys, providerHint) {
|
||||
return newDirectModelRoute(providerHint, fallbackKeys), nil
|
||||
}
|
||||
keys, err := p.resolveUserProviderAPIKeys(ctx, ownerID, uuid.Nil)
|
||||
if err != nil {
|
||||
return resolvedModelRoute{}, xerrors.Errorf("resolve provider API keys: %w", err)
|
||||
}
|
||||
return newDirectModelRoute(providerHint, keys), nil
|
||||
}
|
||||
providerKeys, err := p.resolveUserProviderAPIKeysForProvider(ctx, ownerID, *provider)
|
||||
if err != nil {
|
||||
return resolvedModelRoute{}, xerrors.Errorf("resolve provider API keys: %w", err)
|
||||
}
|
||||
return newDirectModelRoute(providerHint, providerKeys), nil
|
||||
}
|
||||
|
||||
func (p *Server) resolveDirectModelRouteForProviderType(
|
||||
ctx context.Context,
|
||||
ownerID uuid.UUID,
|
||||
providerType string,
|
||||
) (resolvedModelRoute, error) {
|
||||
normalizedProviderType := chatprovider.NormalizeProvider(providerType)
|
||||
keys, _, err := p.resolveUserProviderAPIKeysAndProviderForProviderType(ctx, ownerID, providerType)
|
||||
if err != nil {
|
||||
return resolvedModelRoute{}, err
|
||||
}
|
||||
return newDirectModelRoute(normalizedProviderType, keys), nil
|
||||
}
|
||||
|
||||
func (p *Server) directProviderHintAndProviderForConfig(
|
||||
ctx context.Context,
|
||||
modelConfig database.ChatModelConfig,
|
||||
) (string, *database.AIProvider, error) {
|
||||
if !modelConfig.AIProviderID.Valid {
|
||||
return modelConfig.Provider, nil, nil
|
||||
}
|
||||
provider, err := p.enabledAIProviderByID(ctx, modelConfig.AIProviderID.UUID)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
return string(provider.Type), &provider, nil
|
||||
}
|
||||
@@ -0,0 +1,640 @@
|
||||
package chatd
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
|
||||
"charm.land/fantasy"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.uber.org/mock/gomock"
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/aibridge"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbgen"
|
||||
"github.com/coder/coder/v2/coderd/database/dbmock"
|
||||
"github.com/coder/coder/v2/coderd/database/dbtestutil"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chatprovider"
|
||||
"github.com/coder/coder/v2/coderd/x/chatd/chattool"
|
||||
)
|
||||
|
||||
type aibridgeTestFactory struct {
|
||||
providerName string
|
||||
source aibridge.Source
|
||||
err error
|
||||
rt http.RoundTripper
|
||||
}
|
||||
|
||||
func (f *aibridgeTestFactory) TransportFor(providerName string, source aibridge.Source) (http.RoundTripper, error) {
|
||||
f.providerName = providerName
|
||||
f.source = source
|
||||
if f.err != nil {
|
||||
return nil, f.err
|
||||
}
|
||||
return f.rt, nil
|
||||
}
|
||||
|
||||
type roundTripFunc func(*http.Request) (*http.Response, error)
|
||||
|
||||
func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
return f(req)
|
||||
}
|
||||
|
||||
func aibridgeTestFactoryPointer(factory aibridge.TransportFactory) *atomic.Pointer[aibridge.TransportFactory] {
|
||||
var ptr atomic.Pointer[aibridge.TransportFactory]
|
||||
ptr.Store(&factory)
|
||||
return &ptr
|
||||
}
|
||||
|
||||
func aibridgeTestAIProvider(providerID uuid.UUID, providerName string, providerType database.AIProviderType) database.AIProvider {
|
||||
return database.AIProvider{
|
||||
ID: providerID,
|
||||
Name: providerName,
|
||||
Type: providerType,
|
||||
Enabled: true,
|
||||
}
|
||||
}
|
||||
|
||||
func aibridgeTestRoute(aiProvider database.AIProvider) resolvedModelRoute {
|
||||
return newAIGatewayModelRoute(aiProvider, string(aiProvider.Type), aiGatewayProviderAuth{})
|
||||
}
|
||||
|
||||
func aibridgeTestRequest(chat database.Chat, model string) modelClientRequest {
|
||||
return modelClientRequest{
|
||||
Chat: chat,
|
||||
ModelName: model,
|
||||
UserAgent: chatprovider.UserAgent(),
|
||||
}
|
||||
}
|
||||
|
||||
func TestAIBridgeProviderFormatMapping(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
providerType database.AIProviderType
|
||||
wantProvider string
|
||||
wantBaseURL string
|
||||
}{
|
||||
{name: "OpenAI", providerType: database.AiProviderTypeOpenai, wantProvider: "openai", wantBaseURL: "http://coder-aibridge/v1"},
|
||||
{name: "Anthropic", providerType: database.AiProviderTypeAnthropic, wantProvider: "anthropic", wantBaseURL: "http://coder-aibridge"},
|
||||
{name: "Bedrock", providerType: database.AiProviderTypeBedrock, wantProvider: "anthropic", wantBaseURL: "http://coder-aibridge"},
|
||||
{name: "Google", providerType: database.AiProviderTypeGoogle, wantProvider: "openai-compat", wantBaseURL: "http://coder-aibridge/v1"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
config := fantasyConfigForAIBridge(tt.providerType)
|
||||
require.Equal(t, tt.wantProvider, config.ProviderHint)
|
||||
require.Equal(t, tt.wantBaseURL, config.Keys.BaseURL(config.ProviderHint))
|
||||
require.Equal(t, aibridgePlaceholderAPIKey, config.Keys.APIKey(config.ProviderHint))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveModelRouteForConfigPreservesBaseURL(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := t.Context()
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
ownerID := uuid.New()
|
||||
providerID := uuid.New()
|
||||
baseURL := "https://openai.example.com/v1"
|
||||
|
||||
db.EXPECT().GetAIProviderByID(gomock.Any(), providerID).Return(database.AIProvider{
|
||||
ID: providerID,
|
||||
Type: database.AiProviderTypeOpenai,
|
||||
Name: "primary-openai",
|
||||
Enabled: true,
|
||||
BaseUrl: baseURL,
|
||||
}, nil)
|
||||
db.EXPECT().GetAIProviderKeysByProviderID(gomock.Any(), providerID).Return([]database.AIProviderKey{{
|
||||
ProviderID: providerID,
|
||||
APIKey: "provider-key",
|
||||
}}, nil)
|
||||
|
||||
server := &Server{db: db}
|
||||
route, err := server.resolveModelRouteForConfig(ctx, ownerID, database.ChatModelConfig{
|
||||
Provider: "openai",
|
||||
AIProviderID: uuid.NullUUID{UUID: providerID, Valid: true},
|
||||
}, chatprovider.ProviderAPIKeys{})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, modelRouteKindDirect, route.kind)
|
||||
require.Equal(t, "openai", route.direct.ProviderHint)
|
||||
require.Equal(t, "provider-key", route.direct.Keys.APIKey("openai"))
|
||||
require.Equal(t, baseURL, route.direct.Keys.BaseURL("openai"))
|
||||
}
|
||||
|
||||
func TestAIGatewayProviderAuthForUser(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := t.Context()
|
||||
ownerID := uuid.New()
|
||||
providerID := uuid.New()
|
||||
provider := database.AIProvider{ID: providerID, Type: database.AiProviderTypeOpenai, Enabled: true}
|
||||
|
||||
t.Run("OpenAIUserKey", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
db.EXPECT().GetUserAIProviderKeyByProviderID(gomock.Any(), database.GetUserAIProviderKeyByProviderIDParams{
|
||||
UserID: ownerID,
|
||||
AIProviderID: providerID,
|
||||
}).Return(database.UserAiProviderKey{APIKey: "sk-user"}, nil)
|
||||
|
||||
server := &Server{db: db, allowBYOK: true}
|
||||
auth, err := server.aiGatewayProviderAuthForUser(ctx, ownerID, provider, aiGatewayRequestFormatOpenAI)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "Bearer sk-user", auth.Headers["Authorization"])
|
||||
require.Empty(t, auth.Headers["X-Api-Key"])
|
||||
})
|
||||
|
||||
t.Run("AnthropicUserKey", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
db.EXPECT().GetUserAIProviderKeyByProviderID(gomock.Any(), database.GetUserAIProviderKeyByProviderIDParams{
|
||||
UserID: ownerID,
|
||||
AIProviderID: providerID,
|
||||
}).Return(database.UserAiProviderKey{APIKey: "sk-user"}, nil)
|
||||
|
||||
server := &Server{db: db, allowBYOK: true}
|
||||
auth, err := server.aiGatewayProviderAuthForUser(ctx, ownerID, provider, aiGatewayRequestFormatAnthropic)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "sk-user", auth.Headers["X-Api-Key"])
|
||||
require.Empty(t, auth.Headers["Authorization"])
|
||||
})
|
||||
|
||||
t.Run("NoUserKey", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
db.EXPECT().GetUserAIProviderKeyByProviderID(gomock.Any(), database.GetUserAIProviderKeyByProviderIDParams{
|
||||
UserID: ownerID,
|
||||
AIProviderID: providerID,
|
||||
}).Return(database.UserAiProviderKey{}, sql.ErrNoRows)
|
||||
|
||||
server := &Server{db: db, allowBYOK: true}
|
||||
auth, err := server.aiGatewayProviderAuthForUser(ctx, ownerID, provider, aiGatewayRequestFormatOpenAI)
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, auth.Headers)
|
||||
})
|
||||
|
||||
t.Run("BYOKDisabled", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
server := &Server{db: db, allowBYOK: false}
|
||||
auth, err := server.aiGatewayProviderAuthForUser(ctx, ownerID, provider, aiGatewayRequestFormatOpenAI)
|
||||
require.NoError(t, err)
|
||||
require.Empty(t, auth.Headers)
|
||||
})
|
||||
}
|
||||
|
||||
func TestAIGatewayProviderAuthRedactsFormatting(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
auth := aiGatewayProviderAuth{Headers: map[string]string{
|
||||
"Authorization": "Bearer sk-user",
|
||||
"X-Api-Key": "sk-user",
|
||||
}}
|
||||
for _, formatted := range []string{
|
||||
fmt.Sprint(auth),
|
||||
fmt.Sprintf("%+v", auth),
|
||||
fmt.Sprintf("%#v", auth),
|
||||
} {
|
||||
require.NotContains(t, formatted, "sk-user")
|
||||
require.NotContains(t, formatted, "Bearer sk-user")
|
||||
require.Contains(t, formatted, "redacted")
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveModelRouteForConfigAIGatewayProviderAuth(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := t.Context()
|
||||
ownerID := uuid.New()
|
||||
providerID := uuid.New()
|
||||
provider := database.AIProvider{
|
||||
ID: providerID,
|
||||
Type: database.AiProviderTypeOpenai,
|
||||
Name: "primary-openai",
|
||||
Enabled: true,
|
||||
}
|
||||
modelConfig := database.ChatModelConfig{
|
||||
ID: uuid.New(),
|
||||
Model: "gpt-4",
|
||||
Provider: "openai",
|
||||
AIProviderID: uuid.NullUUID{UUID: providerID, Valid: true},
|
||||
}
|
||||
|
||||
t.Run("UserKey", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
db.EXPECT().GetAIProviderByID(gomock.Any(), providerID).Return(provider, nil)
|
||||
db.EXPECT().GetUserAIProviderKeyByProviderID(gomock.Any(), database.GetUserAIProviderKeyByProviderIDParams{
|
||||
UserID: ownerID,
|
||||
AIProviderID: providerID,
|
||||
}).Return(database.UserAiProviderKey{APIKey: "sk-user"}, nil)
|
||||
|
||||
server := &Server{db: db, aiGatewayRoutingEnabled: true, allowBYOK: true}
|
||||
route, err := server.resolveModelRouteForConfig(ctx, ownerID, modelConfig, chatprovider.ProviderAPIKeys{})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, modelRouteKindAIGateway, route.kind)
|
||||
require.Equal(t, "Bearer sk-user", route.aiGateway.ProviderAuth.Headers["Authorization"])
|
||||
})
|
||||
|
||||
t.Run("CentralProviderCredentialsNotForwarded", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctrl := gomock.NewController(t)
|
||||
db := dbmock.NewMockStore(ctrl)
|
||||
db.EXPECT().GetAIProviderByID(gomock.Any(), providerID).Return(provider, nil)
|
||||
|
||||
server := &Server{db: db, aiGatewayRoutingEnabled: true, allowBYOK: false}
|
||||
route, err := server.resolveModelRouteForConfig(ctx, ownerID, modelConfig, chatprovider.ProviderAPIKeys{})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, modelRouteKindAIGateway, route.kind)
|
||||
require.Empty(t, route.aiGateway.ProviderAuth.Headers)
|
||||
})
|
||||
}
|
||||
|
||||
func TestAIGatewayModelForwardsProviderAuth(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
type seenRequest struct {
|
||||
authorization string
|
||||
xAPIKey string
|
||||
coderToken string
|
||||
apiKeyID string
|
||||
path string
|
||||
}
|
||||
newServer := func(t *testing.T, provider database.AIProvider, auth aiGatewayProviderAuth, seen chan seenRequest) (*Server, resolvedModelRoute) {
|
||||
factory := &aibridgeTestFactory{rt: roundTripFunc(func(req *http.Request) (*http.Response, error) {
|
||||
apiKeyID, _ := aibridge.DelegatedAPIKeyIDFromContext(req.Context())
|
||||
seen <- seenRequest{
|
||||
authorization: req.Header.Get("Authorization"),
|
||||
xAPIKey: req.Header.Get("X-Api-Key"),
|
||||
coderToken: req.Header.Get(aibridge.HeaderCoderToken),
|
||||
apiKeyID: apiKeyID,
|
||||
path: req.URL.Path,
|
||||
}
|
||||
body := `{"id":"resp_test","object":"response","created_at":0,"status":"completed","model":"gpt-4","output":[{"id":"msg_test","type":"message","role":"assistant","content":[{"type":"output_text","text":"hello"}]}],"usage":{"input_tokens":1,"output_tokens":1,"total_tokens":2}}`
|
||||
if provider.Type == database.AiProviderTypeAnthropic {
|
||||
body = `{"id":"msg_test","type":"message","role":"assistant","model":"claude-haiku-4-5","content":[{"type":"text","text":"hello"}],"stop_reason":"end_turn","stop_sequence":null,"usage":{"input_tokens":1,"output_tokens":1}}`
|
||||
}
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
Body: io.NopCloser(strings.NewReader(body)),
|
||||
Request: req,
|
||||
}, nil
|
||||
})}
|
||||
server := &Server{
|
||||
aiGatewayRoutingEnabled: true,
|
||||
aibridgeTransportFactory: aibridgeTestFactoryPointer(factory),
|
||||
}
|
||||
route := newAIGatewayModelRoute(provider, string(provider.Type), auth)
|
||||
return server, route
|
||||
}
|
||||
|
||||
t.Run("OpenAI", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
seen := make(chan seenRequest, 1)
|
||||
provider := aibridgeTestAIProvider(uuid.New(), "primary-openai", database.AiProviderTypeOpenai)
|
||||
server, route := newServer(t, provider, aiGatewayProviderAuth{
|
||||
Headers: map[string]string{"Authorization": "Bearer sk-user"},
|
||||
}, seen)
|
||||
apiKeyID := uuid.NewString()
|
||||
model, err := server.newModel(t.Context(), aibridgeTestRequest(database.Chat{ID: uuid.New(), OwnerID: uuid.New()}, "gpt-4"), route, modelBuildOptions{ActiveAPIKeyID: apiKeyID, RecordHTTP: true})
|
||||
require.NoError(t, err)
|
||||
_, err = model.Generate(t.Context(), fantasy.Call{Prompt: []fantasy.Message{{Role: fantasy.MessageRoleUser, Content: []fantasy.MessagePart{fantasy.TextPart{Text: "hello"}}}}})
|
||||
require.NoError(t, err)
|
||||
|
||||
got := <-seen
|
||||
require.Equal(t, "Bearer sk-user", got.authorization)
|
||||
require.Empty(t, got.xAPIKey)
|
||||
require.Equal(t, aibridgeDelegatedBYOKMarker, got.coderToken)
|
||||
require.Equal(t, apiKeyID, got.apiKeyID)
|
||||
require.Equal(t, "/v1/responses", got.path)
|
||||
})
|
||||
|
||||
t.Run("Anthropic", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
seen := make(chan seenRequest, 1)
|
||||
provider := aibridgeTestAIProvider(uuid.New(), "primary-anthropic", database.AiProviderTypeAnthropic)
|
||||
server, route := newServer(t, provider, aiGatewayProviderAuth{
|
||||
Headers: map[string]string{"X-Api-Key": "sk-user"},
|
||||
}, seen)
|
||||
apiKeyID := uuid.NewString()
|
||||
model, err := server.newModel(t.Context(), aibridgeTestRequest(database.Chat{ID: uuid.New(), OwnerID: uuid.New()}, "claude-haiku-4-5"), route, modelBuildOptions{ActiveAPIKeyID: apiKeyID})
|
||||
require.NoError(t, err)
|
||||
_, err = model.Generate(t.Context(), fantasy.Call{Prompt: []fantasy.Message{{Role: fantasy.MessageRoleUser, Content: []fantasy.MessagePart{fantasy.TextPart{Text: "hello"}}}}})
|
||||
require.NoError(t, err)
|
||||
|
||||
got := <-seen
|
||||
require.Equal(t, "sk-user", got.xAPIKey)
|
||||
require.Equal(t, aibridgeDelegatedBYOKMarker, got.coderToken)
|
||||
require.Equal(t, apiKeyID, got.apiKeyID)
|
||||
require.Equal(t, "/v1/messages", got.path)
|
||||
})
|
||||
|
||||
t.Run("NoUserKeyLeavesPlaceholderForAIBridged", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
seen := make(chan seenRequest, 1)
|
||||
provider := aibridgeTestAIProvider(uuid.New(), "primary-openai", database.AiProviderTypeOpenai)
|
||||
server, route := newServer(t, provider, aiGatewayProviderAuth{}, seen)
|
||||
apiKeyID := uuid.NewString()
|
||||
model, err := server.newModel(t.Context(), aibridgeTestRequest(database.Chat{ID: uuid.New(), OwnerID: uuid.New()}, "gpt-4"), route, modelBuildOptions{ActiveAPIKeyID: apiKeyID})
|
||||
require.NoError(t, err)
|
||||
_, err = model.Generate(t.Context(), fantasy.Call{Prompt: []fantasy.Message{{Role: fantasy.MessageRoleUser, Content: []fantasy.MessagePart{fantasy.TextPart{Text: "hello"}}}}})
|
||||
require.NoError(t, err)
|
||||
|
||||
got := <-seen
|
||||
require.Equal(t, "Bearer "+aibridgePlaceholderAPIKey, got.authorization)
|
||||
require.Empty(t, got.xAPIKey)
|
||||
require.Empty(t, got.coderToken)
|
||||
require.Equal(t, apiKeyID, got.apiKeyID)
|
||||
})
|
||||
}
|
||||
|
||||
func TestActiveTurnAPIKeyIDFromMessages(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
oldKeyID := uuid.NewString()
|
||||
currentKeyID := uuid.NewString()
|
||||
tests := []struct {
|
||||
name string
|
||||
messages []database.ChatMessage
|
||||
wantKey string
|
||||
wantOK bool
|
||||
}{
|
||||
{
|
||||
name: "CurrentUserMessage",
|
||||
messages: []database.ChatMessage{
|
||||
{ID: 1, Role: database.ChatMessageRoleUser, Visibility: database.ChatMessageVisibilityBoth, APIKeyID: sqlNullString(oldKeyID)},
|
||||
{ID: 2, Role: database.ChatMessageRoleAssistant, Visibility: database.ChatMessageVisibilityBoth},
|
||||
{ID: 3, Role: database.ChatMessageRoleUser, Visibility: database.ChatMessageVisibilityBoth, APIKeyID: sqlNullString(currentKeyID)},
|
||||
},
|
||||
wantKey: currentKeyID,
|
||||
wantOK: true,
|
||||
},
|
||||
{
|
||||
name: "MissingCurrentUserAPIKeyDoesNotFallBack",
|
||||
messages: []database.ChatMessage{
|
||||
{ID: 1, Role: database.ChatMessageRoleUser, Visibility: database.ChatMessageVisibilityBoth, APIKeyID: sqlNullString(oldKeyID)},
|
||||
{ID: 2, Role: database.ChatMessageRoleUser, Visibility: database.ChatMessageVisibilityBoth},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "SkipsModelOnlyUserMessages",
|
||||
messages: []database.ChatMessage{
|
||||
{ID: 1, Role: database.ChatMessageRoleUser, Visibility: database.ChatMessageVisibilityBoth, APIKeyID: sqlNullString(oldKeyID)},
|
||||
{ID: 2, Role: database.ChatMessageRoleUser, Visibility: database.ChatMessageVisibilityModel, APIKeyID: sqlNullString(currentKeyID)},
|
||||
},
|
||||
wantKey: oldKeyID,
|
||||
wantOK: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
gotKey, gotOK := activeTurnAPIKeyIDFromMessages(tt.messages)
|
||||
require.Equal(t, tt.wantOK, gotOK)
|
||||
require.Equal(t, tt.wantKey, gotKey)
|
||||
ctx := contextWithActiveTurnAPIKeyID(t.Context(), tt.messages)
|
||||
ctxKey, ctxOK := aibridge.DelegatedAPIKeyIDFromContext(ctx)
|
||||
require.Equal(t, tt.wantOK, ctxOK)
|
||||
require.Equal(t, tt.wantKey, ctxKey)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestActiveTurnContextUsesPromptMessages(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
ctx := t.Context()
|
||||
user := dbgen.User(t, db, database.User{})
|
||||
org := dbgen.Organization(t, db, database.Organization{})
|
||||
model := dbgen.ChatModelConfig(t, db, database.ChatModelConfig{})
|
||||
chat := dbgen.Chat(t, db, database.Chat{OrganizationID: org.ID, OwnerID: user.ID, LastModelConfigID: model.ID})
|
||||
oldKey, _ := dbgen.APIKey(t, db, database.APIKey{UserID: user.ID})
|
||||
currentKey, _ := dbgen.APIKey(t, db, database.APIKey{UserID: user.ID})
|
||||
modelOnlyKey, _ := dbgen.APIKey(t, db, database.APIKey{UserID: user.ID})
|
||||
|
||||
dbgen.ChatMessage(t, db, database.ChatMessage{
|
||||
ChatID: chat.ID,
|
||||
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
||||
ModelConfigID: uuid.NullUUID{UUID: model.ID, Valid: true},
|
||||
Role: database.ChatMessageRoleUser,
|
||||
Visibility: database.ChatMessageVisibilityBoth,
|
||||
APIKeyID: sqlNullString(oldKey.ID),
|
||||
})
|
||||
dbgen.ChatMessage(t, db, database.ChatMessage{
|
||||
ChatID: chat.ID,
|
||||
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
||||
ModelConfigID: uuid.NullUUID{UUID: model.ID, Valid: true},
|
||||
Role: database.ChatMessageRoleSystem,
|
||||
Visibility: database.ChatMessageVisibilityModel,
|
||||
Compressed: true,
|
||||
})
|
||||
dbgen.ChatMessage(t, db, database.ChatMessage{
|
||||
ChatID: chat.ID,
|
||||
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
||||
ModelConfigID: uuid.NullUUID{UUID: model.ID, Valid: true},
|
||||
Role: database.ChatMessageRoleUser,
|
||||
Visibility: database.ChatMessageVisibilityBoth,
|
||||
APIKeyID: sqlNullString(currentKey.ID),
|
||||
})
|
||||
dbgen.ChatMessage(t, db, database.ChatMessage{
|
||||
ChatID: chat.ID,
|
||||
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
|
||||
ModelConfigID: uuid.NullUUID{UUID: model.ID, Valid: true},
|
||||
Role: database.ChatMessageRoleUser,
|
||||
Visibility: database.ChatMessageVisibilityModel,
|
||||
APIKeyID: sqlNullString(modelOnlyKey.ID),
|
||||
})
|
||||
|
||||
messages, err := db.GetChatMessagesForPromptByChatID(ctx, chat.ID)
|
||||
require.NoError(t, err)
|
||||
ctx = contextWithActiveTurnAPIKeyID(ctx, messages)
|
||||
gotKey, ok := aibridge.DelegatedAPIKeyIDFromContext(ctx)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, currentKey.ID, gotKey)
|
||||
}
|
||||
|
||||
func sqlNullString(value string) sql.NullString {
|
||||
return sql.NullString{String: value, Valid: value != ""}
|
||||
}
|
||||
|
||||
func TestAIBridgeRoutingFailClosed(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
providerID := uuid.New()
|
||||
chat := database.Chat{ID: uuid.New(), OwnerID: uuid.New()}
|
||||
aiProvider := aibridgeTestAIProvider(providerID, "primary-openai", database.AiProviderTypeOpenai)
|
||||
|
||||
t.Run("NilFactory", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
server := &Server{aiGatewayRoutingEnabled: true}
|
||||
_, err := server.newModel(t.Context(), aibridgeTestRequest(chat, "gpt-4"), aibridgeTestRoute(aiProvider), modelBuildOptions{ActiveAPIKeyID: uuid.NewString()})
|
||||
require.ErrorContains(t, err, "transport factory")
|
||||
})
|
||||
|
||||
t.Run("FactoryError", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
factory := &aibridgeTestFactory{err: xerrors.New("boom")}
|
||||
server := &Server{
|
||||
aiGatewayRoutingEnabled: true,
|
||||
aibridgeTransportFactory: aibridgeTestFactoryPointer(factory),
|
||||
}
|
||||
_, err := server.newModel(t.Context(), aibridgeTestRequest(chat, "gpt-4"), aibridgeTestRoute(aiProvider), modelBuildOptions{ActiveAPIKeyID: uuid.NewString()})
|
||||
require.ErrorContains(t, err, "boom")
|
||||
})
|
||||
|
||||
t.Run("MissingProviderName", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
server := &Server{aiGatewayRoutingEnabled: true}
|
||||
missingNameProvider := aibridgeTestAIProvider(providerID, "", database.AiProviderTypeOpenai)
|
||||
_, err := server.newModel(t.Context(), aibridgeTestRequest(chat, "gpt-4"), aibridgeTestRoute(missingNameProvider), modelBuildOptions{ActiveAPIKeyID: uuid.NewString()})
|
||||
require.ErrorContains(t, err, "AI provider name")
|
||||
})
|
||||
|
||||
t.Run("MissingAPIKeyID", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
factory := &aibridgeTestFactory{rt: roundTripFunc(func(*http.Request) (*http.Response, error) {
|
||||
t.Fatal("transport must not be used without an API key ID")
|
||||
return nil, xerrors.New("unreachable")
|
||||
})}
|
||||
server := &Server{
|
||||
aiGatewayRoutingEnabled: true,
|
||||
aibridgeTransportFactory: aibridgeTestFactoryPointer(factory),
|
||||
}
|
||||
_, err := server.newModel(t.Context(), aibridgeTestRequest(chat, "gpt-4"), aibridgeTestRoute(aiProvider), modelBuildOptions{})
|
||||
require.ErrorContains(t, err, "active turn API key ID")
|
||||
})
|
||||
|
||||
t.Run("StaticModel", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
server := &Server{aiGatewayRoutingEnabled: true}
|
||||
_, err := server.newModel(t.Context(), aibridgeTestRequest(chat, "gpt-4"), newAIGatewayModelRoute(database.AIProvider{}, "", aiGatewayProviderAuth{}), modelBuildOptions{ActiveAPIKeyID: uuid.NewString()})
|
||||
require.ErrorContains(t, err, "concrete AI provider")
|
||||
})
|
||||
}
|
||||
|
||||
func TestDirectModelBuildDoesNotRequireActiveAPIKeyID(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := &Server{}
|
||||
model, err := server.newModel(t.Context(), modelClientRequest{
|
||||
Chat: database.Chat{ID: uuid.New(), OwnerID: uuid.New()},
|
||||
ModelName: "gpt-4",
|
||||
UserAgent: chatprovider.UserAgent(),
|
||||
}, newDirectModelRoute("openai", chatprovider.ProviderAPIKeys{OpenAI: "sk-test"}), modelBuildOptions{})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, model)
|
||||
}
|
||||
|
||||
func TestAIBridgeComputerUseModelUsesRoute(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
providerID := uuid.New()
|
||||
apiKeyID := uuid.NewString()
|
||||
factory := &aibridgeTestFactory{rt: roundTripFunc(func(*http.Request) (*http.Response, error) {
|
||||
t.Fatal("computer use model construction must not send a request")
|
||||
return nil, xerrors.New("unreachable")
|
||||
})}
|
||||
chat := database.Chat{ID: uuid.New(), OwnerID: uuid.New()}
|
||||
server := &Server{
|
||||
aiGatewayRoutingEnabled: true,
|
||||
aibridgeTransportFactory: aibridgeTestFactoryPointer(factory),
|
||||
}
|
||||
provider := chattool.ComputerUseProviderOpenAI
|
||||
modelProvider, modelName, ok := chattool.DefaultComputerUseModel(provider)
|
||||
require.True(t, ok)
|
||||
|
||||
ctx := aibridge.WithDelegatedAPIKeyID(t.Context(), "context-key-must-be-ignored")
|
||||
model, debugEnabled, resolvedProvider, resolvedModel, err := server.resolveComputerUseModel(
|
||||
ctx,
|
||||
chat,
|
||||
aibridgeTestRoute(aibridgeTestAIProvider(providerID, "primary-openai", database.AiProviderTypeOpenai)),
|
||||
provider,
|
||||
modelProvider,
|
||||
modelName,
|
||||
modelBuildOptions{ActiveAPIKeyID: apiKeyID},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, model)
|
||||
require.False(t, debugEnabled)
|
||||
require.Equal(t, chattool.ComputerUseProviderOpenAI, resolvedProvider)
|
||||
require.Equal(t, modelName, resolvedModel)
|
||||
require.Equal(t, "primary-openai", factory.providerName)
|
||||
require.Equal(t, aibridge.SourceAgents, factory.source)
|
||||
}
|
||||
|
||||
func TestAIBridgeDelegatedContextPropagation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
providerID := uuid.New()
|
||||
apiKeyID := uuid.NewString()
|
||||
type seenRequest struct {
|
||||
apiKeyID string
|
||||
ok bool
|
||||
path string
|
||||
}
|
||||
seen := make(chan seenRequest, 1)
|
||||
factory := &aibridgeTestFactory{rt: roundTripFunc(func(req *http.Request) (*http.Response, error) {
|
||||
gotAPIKeyID, ok := aibridge.DelegatedAPIKeyIDFromContext(req.Context())
|
||||
seen <- seenRequest{
|
||||
apiKeyID: gotAPIKeyID,
|
||||
ok: ok,
|
||||
path: req.URL.Path,
|
||||
}
|
||||
body := `{"id":"resp_test","object":"response","created_at":0,"status":"completed","model":"gpt-4","output":[{"id":"msg_test","type":"message","role":"assistant","content":[{"type":"output_text","text":"hello"}]}],"usage":{"input_tokens":1,"output_tokens":1,"total_tokens":2}}`
|
||||
return &http.Response{
|
||||
StatusCode: http.StatusOK,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
Body: io.NopCloser(strings.NewReader(body)),
|
||||
Request: req,
|
||||
}, nil
|
||||
})}
|
||||
chat := database.Chat{ID: uuid.New(), OwnerID: uuid.New()}
|
||||
server := &Server{
|
||||
aiGatewayRoutingEnabled: true,
|
||||
aibridgeTransportFactory: aibridgeTestFactoryPointer(factory),
|
||||
}
|
||||
|
||||
ctx := aibridge.WithDelegatedAPIKeyID(t.Context(), "context-key-must-be-ignored")
|
||||
model, err := server.newModel(ctx, aibridgeTestRequest(chat, "gpt-4"), aibridgeTestRoute(aibridgeTestAIProvider(providerID, "primary-openai", database.AiProviderTypeOpenai)), modelBuildOptions{ActiveAPIKeyID: apiKeyID, RecordHTTP: true})
|
||||
require.NoError(t, err)
|
||||
_, err = model.Generate(t.Context(), fantasy.Call{Prompt: []fantasy.Message{{
|
||||
Role: fantasy.MessageRoleUser,
|
||||
Content: []fantasy.MessagePart{fantasy.TextPart{Text: "hello"}},
|
||||
}}})
|
||||
require.NoError(t, err)
|
||||
|
||||
got := <-seen
|
||||
require.Equal(t, "primary-openai", factory.providerName)
|
||||
require.Equal(t, aibridge.SourceAgents, factory.source)
|
||||
require.True(t, got.ok)
|
||||
require.Equal(t, "/v1/responses", got.path)
|
||||
require.Equal(t, apiKeyID, got.apiKeyID)
|
||||
}
|
||||
+66
-77
@@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -69,9 +68,39 @@ var preferredTitleModels = []struct {
|
||||
type shortTextCandidate struct {
|
||||
provider string
|
||||
model string
|
||||
route resolvedModelRoute
|
||||
lm fantasy.LanguageModel
|
||||
}
|
||||
|
||||
func (p *Server) preferredShortTextCandidates(
|
||||
chat database.Chat,
|
||||
keys chatprovider.ProviderAPIKeys,
|
||||
) []shortTextCandidate {
|
||||
if p.shouldUseAIGatewayRouting() {
|
||||
return nil
|
||||
}
|
||||
|
||||
candidates := make([]shortTextCandidate, 0, len(preferredTitleModels)+1)
|
||||
userAgent := chatprovider.UserAgent()
|
||||
extraHeaders := chatprovider.CoderHeaders(chat)
|
||||
for _, candidate := range preferredTitleModels {
|
||||
model, err := chatprovider.ModelFromConfig(
|
||||
candidate.provider, candidate.model, keys, userAgent,
|
||||
extraHeaders,
|
||||
nil,
|
||||
)
|
||||
if err == nil {
|
||||
candidates = append(candidates, shortTextCandidate{
|
||||
provider: candidate.provider,
|
||||
model: candidate.model,
|
||||
route: newDirectModelRoute(candidate.provider, keys),
|
||||
lm: model,
|
||||
})
|
||||
}
|
||||
}
|
||||
return candidates
|
||||
}
|
||||
|
||||
func selectPreferredConfiguredShortTextModelConfig(
|
||||
configs []database.ChatModelConfig,
|
||||
) (database.ChatModelConfig, bool) {
|
||||
@@ -121,7 +150,9 @@ func (p *Server) maybeGenerateChatTitle(
|
||||
fallbackProvider string,
|
||||
fallbackModelName string,
|
||||
fallbackModel fantasy.LanguageModel,
|
||||
fallbackRoute resolvedModelRoute,
|
||||
keys chatprovider.ProviderAPIKeys,
|
||||
modelOpts modelBuildOptions,
|
||||
generatedTitle *generatedChatTitle,
|
||||
logger slog.Logger,
|
||||
debugSvc *chatdebug.Service,
|
||||
@@ -135,10 +166,11 @@ func (p *Server) maybeGenerateChatTitle(
|
||||
titleCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
overrideConfig, overrideModel, overrideKeys, overrideSet, overrideErr := p.resolveTitleGenerationModelOverride(
|
||||
overrideConfig, overrideModel, _, overrideRoute, overrideSet, overrideErr := p.resolveTitleGenerationModelOverride(
|
||||
titleCtx,
|
||||
chat,
|
||||
keys,
|
||||
modelOpts,
|
||||
)
|
||||
if overrideErr != nil {
|
||||
if overrideSet {
|
||||
@@ -161,29 +193,15 @@ func (p *Server) maybeGenerateChatTitle(
|
||||
candidates = []shortTextCandidate{{
|
||||
provider: overrideConfig.Provider,
|
||||
model: overrideConfig.Model,
|
||||
route: overrideRoute,
|
||||
lm: overrideModel,
|
||||
}}
|
||||
} else {
|
||||
// Build candidate list: preferred lightweight models first,
|
||||
// then the user's chat model as last resort.
|
||||
candidates = make([]shortTextCandidate, 0, len(preferredTitleModels)+1)
|
||||
for _, c := range preferredTitleModels {
|
||||
m, err := chatprovider.ModelFromConfig(
|
||||
c.provider, c.model, keys, chatprovider.UserAgent(),
|
||||
chatprovider.CoderHeaders(chat),
|
||||
nil,
|
||||
)
|
||||
if err == nil {
|
||||
candidates = append(candidates, shortTextCandidate{
|
||||
provider: c.provider,
|
||||
model: c.model,
|
||||
lm: m,
|
||||
})
|
||||
}
|
||||
}
|
||||
candidates = p.preferredShortTextCandidates(chat, keys)
|
||||
candidates = append(candidates, shortTextCandidate{
|
||||
provider: fallbackProvider,
|
||||
model: fallbackModelName,
|
||||
route: fallbackRoute,
|
||||
lm: fallbackModel,
|
||||
})
|
||||
}
|
||||
@@ -213,17 +231,13 @@ func (p *Server) maybeGenerateChatTitle(
|
||||
candidateCtx := titleCtx
|
||||
candidateModel := candidate.lm
|
||||
finishDebugRun := func(error) {}
|
||||
candidateKeys := keys
|
||||
if overrideSet {
|
||||
candidateKeys = overrideKeys
|
||||
}
|
||||
if debugEnabled {
|
||||
candidateCtx, candidateModel, finishDebugRun = prepareQuickgenDebugCandidate(
|
||||
candidateCtx, candidateModel, finishDebugRun = p.prepareQuickgenDebugCandidate(
|
||||
titleCtx,
|
||||
chat,
|
||||
candidateKeys,
|
||||
debugSvc,
|
||||
candidate,
|
||||
modelOpts,
|
||||
chatdebug.KindTitleGeneration,
|
||||
triggerMessageID,
|
||||
historyTipMessageID,
|
||||
@@ -293,32 +307,26 @@ func (p *Server) maybeGenerateChatTitle(
|
||||
}
|
||||
}
|
||||
|
||||
func newQuickgenDebugModel(
|
||||
func (p *Server) newQuickgenDebugModel(
|
||||
ctx context.Context,
|
||||
chat database.Chat,
|
||||
keys chatprovider.ProviderAPIKeys,
|
||||
debugSvc *chatdebug.Service,
|
||||
provider string,
|
||||
model string,
|
||||
route resolvedModelRoute,
|
||||
modelOpts modelBuildOptions,
|
||||
) (fantasy.LanguageModel, error) {
|
||||
httpClient := &http.Client{Transport: &chatdebug.RecordingTransport{}}
|
||||
debugModel, err := chatprovider.ModelFromConfig(
|
||||
provider,
|
||||
model,
|
||||
keys,
|
||||
chatprovider.UserAgent(),
|
||||
chatprovider.CoderHeaders(chat),
|
||||
httpClient,
|
||||
)
|
||||
debugOpts := modelOpts
|
||||
debugOpts.RecordHTTP = true
|
||||
debugModel, err := p.newModel(ctx, modelClientRequest{
|
||||
Chat: chat,
|
||||
ModelName: model,
|
||||
UserAgent: chatprovider.UserAgent(),
|
||||
ExtraHeaders: chatprovider.CoderHeaders(chat),
|
||||
}, route, debugOpts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if debugModel == nil {
|
||||
return nil, xerrors.Errorf(
|
||||
"create model for %s/%s returned nil",
|
||||
provider,
|
||||
model,
|
||||
)
|
||||
}
|
||||
|
||||
return chatdebug.WrapModel(debugModel, debugSvc, chatdebug.RecorderOptions{
|
||||
ChatID: chat.ID,
|
||||
@@ -328,12 +336,12 @@ func newQuickgenDebugModel(
|
||||
}), nil
|
||||
}
|
||||
|
||||
func prepareQuickgenDebugCandidate(
|
||||
func (p *Server) prepareQuickgenDebugCandidate(
|
||||
ctx context.Context,
|
||||
chat database.Chat,
|
||||
keys chatprovider.ProviderAPIKeys,
|
||||
debugSvc *chatdebug.Service,
|
||||
candidate shortTextCandidate,
|
||||
modelOpts modelBuildOptions,
|
||||
kind chatdebug.RunKind,
|
||||
triggerMessageID int64,
|
||||
historyTipMessageID int64,
|
||||
@@ -345,12 +353,14 @@ func prepareQuickgenDebugCandidate(
|
||||
return ctx, candidate.lm, finishDebugRun
|
||||
}
|
||||
|
||||
debugModel, err := newQuickgenDebugModel(
|
||||
debugModel, err := p.newQuickgenDebugModel(
|
||||
ctx,
|
||||
chat,
|
||||
keys,
|
||||
debugSvc,
|
||||
candidate.provider,
|
||||
candidate.model,
|
||||
candidate.route,
|
||||
modelOpts,
|
||||
)
|
||||
if err != nil {
|
||||
logger.Warn(ctx, "failed to build short-text debug model",
|
||||
@@ -393,18 +403,8 @@ func prepareQuickgenDebugCandidate(
|
||||
return ctx, candidate.lm, finishDebugRun
|
||||
}
|
||||
|
||||
runCtx := chatdebug.ContextWithRun(
|
||||
ctx,
|
||||
&chatdebug.RunContext{
|
||||
RunID: run.ID,
|
||||
ChatID: chat.ID,
|
||||
TriggerMessageID: triggerMessageID,
|
||||
HistoryTipMessageID: historyTipMessageID,
|
||||
Kind: kind,
|
||||
Provider: candidate.provider,
|
||||
Model: candidate.model,
|
||||
},
|
||||
)
|
||||
runContext := chatdebugRunContext(run)
|
||||
runCtx := chatdebug.ContextWithRun(ctx, &runContext)
|
||||
finishDebugRun = func(runErr error) {
|
||||
if finalizeErr := debugSvc.FinalizeRun(ctx, chatdebug.FinalizeRunParams{
|
||||
RunID: run.ID,
|
||||
@@ -824,7 +824,7 @@ const turnStatusLabelPrompt = "You write compact chat status labels for a sideba
|
||||
// message text. It follows the same candidate-selection strategy
|
||||
// as title generation: try preferred lightweight models first, then
|
||||
// fall back to the provided model. Returns "" on any failure.
|
||||
func generateTurnStatusLabel(
|
||||
func (p *Server) generateTurnStatusLabel(
|
||||
ctx context.Context,
|
||||
chat database.Chat,
|
||||
status database.ChatStatus,
|
||||
@@ -832,7 +832,9 @@ func generateTurnStatusLabel(
|
||||
fallbackProvider string,
|
||||
fallbackModelName string,
|
||||
fallbackModel fantasy.LanguageModel,
|
||||
fallbackRoute resolvedModelRoute,
|
||||
keys chatprovider.ProviderAPIKeys,
|
||||
modelOpts modelBuildOptions,
|
||||
logger slog.Logger,
|
||||
debugSvc *chatdebug.Service,
|
||||
triggerMessageID int64,
|
||||
@@ -848,24 +850,11 @@ func generateTurnStatusLabel(
|
||||
"\nChat title: " + chat.Title +
|
||||
"\n\nAgent's latest message:\n" + assistantText
|
||||
|
||||
candidates := make([]shortTextCandidate, 0, len(preferredTitleModels)+1)
|
||||
for _, c := range preferredTitleModels {
|
||||
m, err := chatprovider.ModelFromConfig(
|
||||
c.provider, c.model, keys, chatprovider.UserAgent(),
|
||||
chatprovider.CoderHeaders(chat),
|
||||
nil,
|
||||
)
|
||||
if err == nil {
|
||||
candidates = append(candidates, shortTextCandidate{
|
||||
provider: c.provider,
|
||||
model: c.model,
|
||||
lm: m,
|
||||
})
|
||||
}
|
||||
}
|
||||
candidates := p.preferredShortTextCandidates(chat, keys)
|
||||
candidates = append(candidates, shortTextCandidate{
|
||||
provider: fallbackProvider,
|
||||
model: fallbackModelName,
|
||||
route: fallbackRoute,
|
||||
lm: fallbackModel,
|
||||
})
|
||||
|
||||
@@ -876,12 +865,12 @@ func generateTurnStatusLabel(
|
||||
candidateModel := candidate.lm
|
||||
finishDebugRun := func(error) {}
|
||||
if debugEnabled {
|
||||
candidateCtx, candidateModel, finishDebugRun = prepareQuickgenDebugCandidate(
|
||||
candidateCtx, candidateModel, finishDebugRun = p.prepareQuickgenDebugCandidate(
|
||||
labelCtx,
|
||||
chat,
|
||||
keys,
|
||||
debugSvc,
|
||||
candidate,
|
||||
modelOpts,
|
||||
chatdebug.KindQuickgen,
|
||||
triggerMessageID,
|
||||
historyTipMessageID,
|
||||
|
||||
@@ -359,6 +359,16 @@ func Test_renderManualTitlePrompt(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestPreferredShortTextCandidatesNilUnderAIGateway(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
server := &Server{aiGatewayRoutingEnabled: true}
|
||||
candidates := server.preferredShortTextCandidates(database.Chat{}, chatprovider.ProviderAPIKeys{
|
||||
ByProvider: map[string]string{"openai": "test-key"},
|
||||
})
|
||||
require.Nil(t, candidates)
|
||||
}
|
||||
|
||||
func TestMaybeGenerateChatTitlePreservesUpdatedAt(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -428,7 +438,9 @@ func TestMaybeGenerateChatTitlePreservesUpdatedAt(t *testing.T) {
|
||||
"openai",
|
||||
"test-model",
|
||||
model,
|
||||
resolvedModelRoute{},
|
||||
chatprovider.ProviderAPIKeys{},
|
||||
modelBuildOptions{},
|
||||
generated,
|
||||
logger,
|
||||
nil,
|
||||
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"github.com/coder/coder/v2/coderd/aibridge"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
coderdpubsub "github.com/coder/coder/v2/coderd/pubsub"
|
||||
@@ -908,6 +909,14 @@ func (p *Server) resolveExploreToolSnapshot(
|
||||
return inheritedMCPServerIDs, nil
|
||||
}
|
||||
|
||||
func (p *Server) delegatedAPIKeyIDForSubagent(ctx context.Context) (string, error) {
|
||||
apiKeyID, ok := aibridge.DelegatedAPIKeyIDFromContext(ctx)
|
||||
if !ok && p.shouldUseAIGatewayRouting() {
|
||||
return "", xerrors.New("AI Gateway routing requires the active turn API key ID for subagent messages")
|
||||
}
|
||||
return apiKeyID, nil
|
||||
}
|
||||
|
||||
func (p *Server) createChildSubagentChat(
|
||||
ctx context.Context,
|
||||
parent database.Chat,
|
||||
@@ -950,6 +959,10 @@ func (p *Server) createChildSubagentChatWithOptions(
|
||||
if modelConfigID == uuid.Nil {
|
||||
return database.Chat{}, xerrors.New("model config is required")
|
||||
}
|
||||
childAPIKeyID, err := p.delegatedAPIKeyIDForSubagent(ctx)
|
||||
if err != nil {
|
||||
return database.Chat{}, err
|
||||
}
|
||||
|
||||
childPlanMode := parent.PlanMode
|
||||
if opts.planModeOverride != nil {
|
||||
@@ -1086,7 +1099,7 @@ func (p *Server) createChildSubagentChatWithOptions(
|
||||
database.ChatMessageVisibilityBoth,
|
||||
modelConfigID,
|
||||
chatprompt.CurrentContentVersion,
|
||||
).withCreatedBy(parent.OwnerID))
|
||||
).withCreatedBy(parent.OwnerID).withAPIKeyID(childAPIKeyID))
|
||||
if _, err := tx.InsertChatMessages(ctx, userParams); err != nil {
|
||||
return xerrors.Errorf("insert initial child user message: %w", err)
|
||||
}
|
||||
@@ -1236,10 +1249,16 @@ func (p *Server) sendSubagentMessage(
|
||||
return database.Chat{}, xerrors.Errorf("get target chat: %w", err)
|
||||
}
|
||||
|
||||
apiKeyID, err := p.delegatedAPIKeyIDForSubagent(ctx)
|
||||
if err != nil {
|
||||
return database.Chat{}, err
|
||||
}
|
||||
|
||||
sendResult, err := p.SendMessage(ctx, SendMessageOptions{
|
||||
ChatID: targetChatID,
|
||||
CreatedBy: targetChat.OwnerID,
|
||||
Content: []codersdk.ChatMessagePart{codersdk.ChatMessageText(message)},
|
||||
APIKeyID: apiKeyID,
|
||||
BusyBehavior: busyBehavior,
|
||||
})
|
||||
if err != nil {
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
|
||||
"cdr.dev/slog/v3"
|
||||
"cdr.dev/slog/v3/sloggers/slogtest"
|
||||
"github.com/coder/coder/v2/coderd/aibridge"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
"github.com/coder/coder/v2/coderd/database/dbgen"
|
||||
@@ -231,6 +232,163 @@ func insertInternalAIProvider(
|
||||
})
|
||||
}
|
||||
|
||||
func TestCreateChildSubagentChatPropagatesActiveTurnAPIKeyID(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
user := dbgen.User(t, db, database.User{})
|
||||
org := dbgen.Organization(t, db, database.Organization{})
|
||||
dbgen.OrganizationMember(t, db, database.OrganizationMember{UserID: user.ID, OrganizationID: org.ID})
|
||||
model := dbgen.ChatModelConfig(t, db, database.ChatModelConfig{})
|
||||
parent := dbgen.Chat(t, db, database.Chat{
|
||||
OrganizationID: org.ID,
|
||||
OwnerID: user.ID,
|
||||
LastModelConfigID: model.ID,
|
||||
})
|
||||
|
||||
apiKey, _ := dbgen.APIKey(t, db, database.APIKey{UserID: user.ID})
|
||||
ctx = aibridge.WithDelegatedAPIKeyID(ctx, apiKey.ID)
|
||||
|
||||
server := &Server{db: db, logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})}
|
||||
child, err := server.createChildSubagentChat(ctx, parent, "inspect the workspace", "")
|
||||
require.NoError(t, err)
|
||||
|
||||
messages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{ChatID: child.ID})
|
||||
require.NoError(t, err)
|
||||
var childUserMessage database.ChatMessage
|
||||
for _, message := range messages {
|
||||
if message.Role == database.ChatMessageRoleUser {
|
||||
childUserMessage = message
|
||||
break
|
||||
}
|
||||
}
|
||||
require.NotZero(t, childUserMessage.ID)
|
||||
require.True(t, childUserMessage.APIKeyID.Valid)
|
||||
require.Equal(t, apiKey.ID, childUserMessage.APIKeyID.String)
|
||||
}
|
||||
|
||||
func TestSendSubagentMessagePropagatesActiveTurnAPIKeyID(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{})
|
||||
ctx := chatdTestContext(t)
|
||||
user, org, model := seedInternalChatDeps(t, db)
|
||||
apiKey, _ := dbgen.APIKey(t, db, database.APIKey{UserID: user.ID})
|
||||
|
||||
parent, err := server.CreateChat(ctx, CreateOptions{
|
||||
OrganizationID: org.ID,
|
||||
OwnerID: user.ID,
|
||||
Title: "parent-send-subagent-key",
|
||||
ModelConfigID: model.ID,
|
||||
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
|
||||
APIKeyID: apiKey.ID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
child, err := server.CreateChat(ctx, CreateOptions{
|
||||
OrganizationID: org.ID,
|
||||
OwnerID: user.ID,
|
||||
ParentChatID: uuid.NullUUID{UUID: parent.ID, Valid: true},
|
||||
RootChatID: uuid.NullUUID{UUID: parent.ID, Valid: true},
|
||||
Title: "child-send-subagent-key",
|
||||
ModelConfigID: model.ID,
|
||||
InitialUserContent: []codersdk.ChatMessagePart{
|
||||
codersdk.ChatMessageText("do work"),
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
setChatStatus(ctx, t, db, child.ID, database.ChatStatusWaiting, "")
|
||||
|
||||
ctx = aibridge.WithDelegatedAPIKeyID(ctx, apiKey.ID)
|
||||
_, err = server.sendSubagentMessage(
|
||||
ctx,
|
||||
parent.ID,
|
||||
child.ID,
|
||||
"follow up",
|
||||
SendMessageBusyBehaviorInterrupt,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
messages, err := db.GetChatMessagesByChatID(ctx, database.GetChatMessagesByChatIDParams{ChatID: child.ID})
|
||||
require.NoError(t, err)
|
||||
var latestUserMessage database.ChatMessage
|
||||
for _, message := range messages {
|
||||
if message.Role == database.ChatMessageRoleUser && message.ID > latestUserMessage.ID {
|
||||
latestUserMessage = message
|
||||
}
|
||||
}
|
||||
require.NotZero(t, latestUserMessage.ID)
|
||||
require.True(t, latestUserMessage.APIKeyID.Valid)
|
||||
require.Equal(t, apiKey.ID, latestUserMessage.APIKeyID.String)
|
||||
}
|
||||
|
||||
func TestCreateChildSubagentChatRequiresActiveTurnAPIKeyIDForAIGateway(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
db, _ := dbtestutil.NewDB(t)
|
||||
user := dbgen.User(t, db, database.User{})
|
||||
org := dbgen.Organization(t, db, database.Organization{})
|
||||
dbgen.OrganizationMember(t, db, database.OrganizationMember{UserID: user.ID, OrganizationID: org.ID})
|
||||
model := dbgen.ChatModelConfig(t, db, database.ChatModelConfig{})
|
||||
parent := dbgen.Chat(t, db, database.Chat{
|
||||
OrganizationID: org.ID,
|
||||
OwnerID: user.ID,
|
||||
LastModelConfigID: model.ID,
|
||||
})
|
||||
|
||||
server := &Server{
|
||||
db: db,
|
||||
logger: slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}),
|
||||
aiGatewayRoutingEnabled: true,
|
||||
}
|
||||
_, err := server.createChildSubagentChat(ctx, parent, "inspect the workspace", "")
|
||||
require.ErrorContains(t, err, "AI Gateway routing requires the active turn API key ID for subagent messages")
|
||||
}
|
||||
|
||||
func TestSendSubagentMessageRequiresActiveTurnAPIKeyIDForAIGateway(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
db, ps := dbtestutil.NewDB(t)
|
||||
server := newInternalTestServer(t, db, ps, chatprovider.ProviderAPIKeys{})
|
||||
server.aiGatewayRoutingEnabled = true
|
||||
ctx := chatdTestContext(t)
|
||||
user, org, model := seedInternalChatDeps(t, db)
|
||||
|
||||
parent, err := server.CreateChat(ctx, CreateOptions{
|
||||
OrganizationID: org.ID,
|
||||
OwnerID: user.ID,
|
||||
Title: "parent-send-subagent-missing-key",
|
||||
ModelConfigID: model.ID,
|
||||
InitialUserContent: []codersdk.ChatMessagePart{codersdk.ChatMessageText("hello")},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
child, err := server.CreateChat(ctx, CreateOptions{
|
||||
OrganizationID: org.ID,
|
||||
OwnerID: user.ID,
|
||||
ParentChatID: uuid.NullUUID{UUID: parent.ID, Valid: true},
|
||||
RootChatID: uuid.NullUUID{UUID: parent.ID, Valid: true},
|
||||
Title: "child-send-subagent-missing-key",
|
||||
ModelConfigID: model.ID,
|
||||
InitialUserContent: []codersdk.ChatMessagePart{
|
||||
codersdk.ChatMessageText("do work"),
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
setChatStatus(ctx, t, db, child.ID, database.ChatStatusWaiting, "")
|
||||
_, err = server.sendSubagentMessage(
|
||||
ctx,
|
||||
parent.ID,
|
||||
child.ID,
|
||||
"follow up",
|
||||
SendMessageBusyBehaviorInterrupt,
|
||||
)
|
||||
require.ErrorContains(t, err, "AI Gateway routing requires the active turn API key ID for subagent messages")
|
||||
}
|
||||
|
||||
func TestResolveUserProviderAPIKeys_AIProvider(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -351,7 +509,7 @@ func TestResolveChatModel_AIProviderDisabled(t *testing.T) {
|
||||
LastModelConfigID: modelConfig.ID,
|
||||
})
|
||||
|
||||
model, config, keys, debugEnabled, resolvedProvider, resolvedModel, err := server.resolveChatModel(ctx, chat)
|
||||
model, config, keys, _, debugEnabled, resolvedProvider, resolvedModel, err := server.resolveChatModel(ctx, chat, modelBuildOptions{})
|
||||
require.ErrorContains(t, err, "is disabled")
|
||||
require.Nil(t, model)
|
||||
require.Equal(t, database.ChatModelConfig{}, config)
|
||||
@@ -3187,6 +3345,26 @@ func TestAwaitSubagentCompletion(t *testing.T) {
|
||||
|
||||
parent, child := createParentChildChats(ctx, t, server, user, org, model)
|
||||
|
||||
// signalWake from CreateChat triggers background processing. Wait
|
||||
// for those runs to finish, then reset both chats so this test owns
|
||||
// the state transition observed by the poll loop.
|
||||
testutil.Eventually(ctx, t, func(ctx context.Context) bool {
|
||||
parentChat, err := db.GetChatByID(ctx, parent.ID)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
childChat, err := db.GetChatByID(ctx, child.ID)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return parentChat.Status != database.ChatStatusPending &&
|
||||
parentChat.Status != database.ChatStatusRunning &&
|
||||
childChat.Status != database.ChatStatusPending &&
|
||||
childChat.Status != database.ChatStatusRunning
|
||||
}, testutil.IntervalFast)
|
||||
setChatStatus(ctx, t, db, parent.ID, database.ChatStatusRunning, "")
|
||||
setChatStatus(ctx, t, db, child.ID, database.ChatStatusRunning, "")
|
||||
|
||||
// Set the trap BEFORE starting the goroutine so we
|
||||
// deterministically catch the ticker creation.
|
||||
tickTrap := mClock.Trap().NewTicker("chatd", "subagent_poll")
|
||||
|
||||
@@ -31,28 +31,18 @@ func readTitleGenerationModelOverride(
|
||||
}
|
||||
|
||||
// resolveTitleGenerationModelOverride resolves the deployment-wide title
|
||||
// generation model override. It returns four values:
|
||||
//
|
||||
// - modelConfig and model: populated only on success.
|
||||
// - overrideSet: true when the admin configured a non-empty override,
|
||||
// regardless of whether resolution succeeded. Callers MUST always check
|
||||
// err first; overrideSet alone does not imply the model is usable.
|
||||
// - err: non-nil when resolution failed. DB read failure returns
|
||||
// (zero, nil, false, err). With overrideSet=true, the override is
|
||||
// configured but unusable (deleted model, missing credentials, etc.) and
|
||||
// callers should treat this as a hard failure for explicit-override
|
||||
// semantics, not a soft fallback.
|
||||
//
|
||||
// When the override is unset or stored as malformed, the function returns
|
||||
// (zero, nil, false, nil) so callers can fall back to default behavior.
|
||||
// generation model override. overrideSet is true when an override was
|
||||
// configured; in that case any returned error is a hard failure. When
|
||||
// overrideSet is false, callers may fall back to the default title model.
|
||||
func (p *Server) resolveTitleGenerationModelOverride(
|
||||
ctx context.Context,
|
||||
chat database.Chat,
|
||||
keys chatprovider.ProviderAPIKeys,
|
||||
) (database.ChatModelConfig, fantasy.LanguageModel, chatprovider.ProviderAPIKeys, bool, error) {
|
||||
modelOpts modelBuildOptions,
|
||||
) (database.ChatModelConfig, fantasy.LanguageModel, chatprovider.ProviderAPIKeys, resolvedModelRoute, bool, error) {
|
||||
raw, err := readTitleGenerationModelOverride(ctx, p.db)
|
||||
if err != nil {
|
||||
return database.ChatModelConfig{}, nil, chatprovider.ProviderAPIKeys{}, false, xerrors.Errorf(
|
||||
return database.ChatModelConfig{}, nil, chatprovider.ProviderAPIKeys{}, resolvedModelRoute{}, false, xerrors.Errorf(
|
||||
"read title generation model override: %w",
|
||||
err,
|
||||
)
|
||||
@@ -84,43 +74,28 @@ func (p *Server) resolveTitleGenerationModelOverride(
|
||||
modelOverrideFailureModeHard,
|
||||
)
|
||||
if err != nil {
|
||||
return database.ChatModelConfig{}, nil, chatprovider.ProviderAPIKeys{}, overrideSet, err
|
||||
return database.ChatModelConfig{}, nil, chatprovider.ProviderAPIKeys{}, resolvedModelRoute{}, overrideSet, err
|
||||
}
|
||||
if !overrideSet {
|
||||
return database.ChatModelConfig{}, nil, keys, false, nil
|
||||
return database.ChatModelConfig{}, nil, keys, resolvedModelRoute{}, false, nil
|
||||
}
|
||||
|
||||
providerHint := modelConfig.Provider
|
||||
if modelConfig.AIProviderID.Valid {
|
||||
//nolint:gocritic // Title overrides need chatd-scoped provider reads for user-owned chats.
|
||||
provider, err := p.db.GetAIProviderByID(dbauthz.AsChatd(ctx), modelConfig.AIProviderID.UUID)
|
||||
if err != nil {
|
||||
return database.ChatModelConfig{}, nil, chatprovider.ProviderAPIKeys{}, true, xerrors.Errorf("get AI provider for title generation override: %w", err)
|
||||
}
|
||||
if !provider.Enabled {
|
||||
return database.ChatModelConfig{}, nil, chatprovider.ProviderAPIKeys{}, true, xerrors.Errorf("AI provider %s is disabled", modelConfig.AIProviderID.UUID)
|
||||
}
|
||||
providerHint = string(provider.Type)
|
||||
}
|
||||
model, err := chatprovider.ModelFromConfig(
|
||||
providerHint,
|
||||
modelConfig.Model,
|
||||
overrideProviderKeys,
|
||||
chatprovider.UserAgent(),
|
||||
chatprovider.CoderHeaders(chat),
|
||||
nil,
|
||||
)
|
||||
//nolint:gocritic // Title overrides need chatd-scoped provider reads for user-owned chats.
|
||||
route, err := p.resolveModelRouteForConfig(dbauthz.AsChatd(ctx), chat.OwnerID, modelConfig, overrideProviderKeys)
|
||||
if err != nil {
|
||||
return database.ChatModelConfig{}, nil, chatprovider.ProviderAPIKeys{}, true, xerrors.Errorf(
|
||||
return database.ChatModelConfig{}, nil, chatprovider.ProviderAPIKeys{}, resolvedModelRoute{}, true, err
|
||||
}
|
||||
model, err := p.newModel(ctx, modelClientRequest{
|
||||
Chat: chat,
|
||||
ModelName: modelConfig.Model,
|
||||
UserAgent: chatprovider.UserAgent(),
|
||||
ExtraHeaders: chatprovider.CoderHeaders(chat),
|
||||
}, route, modelOpts)
|
||||
if err != nil {
|
||||
return database.ChatModelConfig{}, nil, chatprovider.ProviderAPIKeys{}, resolvedModelRoute{}, true, xerrors.Errorf(
|
||||
"create title generation model override: %w",
|
||||
err,
|
||||
)
|
||||
}
|
||||
if model == nil {
|
||||
return database.ChatModelConfig{}, nil, chatprovider.ProviderAPIKeys{}, true, xerrors.Errorf(
|
||||
"create title generation model override returned nil",
|
||||
)
|
||||
}
|
||||
|
||||
return modelConfig, model, overrideProviderKeys, true, nil
|
||||
return modelConfig, model, route.directProviderKeys(), route, true, nil
|
||||
}
|
||||
|
||||
@@ -65,7 +65,9 @@ func TestMaybeGenerateChatTitle_TitleGenerationOverrideUnset(t *testing.T) {
|
||||
"openai",
|
||||
"fallback-chat-model",
|
||||
fallbackModel,
|
||||
resolvedModelRoute{},
|
||||
keys,
|
||||
modelBuildOptions{},
|
||||
generated,
|
||||
logger,
|
||||
nil,
|
||||
@@ -112,7 +114,9 @@ func TestMaybeGenerateChatTitle_TitleGenerationOverrideUnset(t *testing.T) {
|
||||
"openai",
|
||||
"fallback-chat-model",
|
||||
fallbackModel,
|
||||
resolvedModelRoute{},
|
||||
chatprovider.ProviderAPIKeys{},
|
||||
modelBuildOptions{},
|
||||
generated,
|
||||
logger,
|
||||
nil,
|
||||
@@ -160,7 +164,9 @@ func TestMaybeGenerateChatTitle_TitleGenerationOverrideReadDBError(t *testing.T)
|
||||
"openai",
|
||||
"fallback-chat-model",
|
||||
fallbackModel,
|
||||
resolvedModelRoute{},
|
||||
chatprovider.ProviderAPIKeys{},
|
||||
modelBuildOptions{},
|
||||
generated,
|
||||
logger,
|
||||
nil,
|
||||
@@ -207,7 +213,9 @@ func TestMaybeGenerateChatTitle_TitleGenerationOverrideMalformedFallsThrough(t *
|
||||
"openai",
|
||||
"fallback-chat-model",
|
||||
fallbackModel,
|
||||
resolvedModelRoute{},
|
||||
chatprovider.ProviderAPIKeys{},
|
||||
modelBuildOptions{},
|
||||
generated,
|
||||
logger,
|
||||
nil,
|
||||
@@ -257,7 +265,7 @@ func TestMaybeGenerateChatTitle_TitleGenerationOverrideSetUsable(t *testing.T) {
|
||||
db.EXPECT().GetAIProviderKeysByProviderID(gomock.Any(), providerID).Return([]database.AIProviderKey{{
|
||||
ProviderID: providerID,
|
||||
APIKey: "test-key",
|
||||
}}, nil)
|
||||
}}, nil).Times(2)
|
||||
db.EXPECT().UpdateChatTitleByID(gomock.Any(), database.UpdateChatTitleByIDParams{
|
||||
ID: chat.ID,
|
||||
Title: wantTitle,
|
||||
@@ -272,7 +280,9 @@ func TestMaybeGenerateChatTitle_TitleGenerationOverrideSetUsable(t *testing.T) {
|
||||
"openai",
|
||||
"fallback-chat-model",
|
||||
fallbackModel,
|
||||
resolvedModelRoute{},
|
||||
chatprovider.ProviderAPIKeys{},
|
||||
modelBuildOptions{},
|
||||
generated,
|
||||
logger,
|
||||
nil,
|
||||
@@ -312,7 +322,9 @@ func TestMaybeGenerateChatTitle_TitleGenerationOverrideSetUnusableSkips(t *testi
|
||||
"openai",
|
||||
"fallback-chat-model",
|
||||
fallbackModel,
|
||||
resolvedModelRoute{},
|
||||
chatprovider.ProviderAPIKeys{},
|
||||
modelBuildOptions{},
|
||||
generated,
|
||||
logger,
|
||||
nil,
|
||||
@@ -360,7 +372,9 @@ func TestMaybeGenerateChatTitle_TitleGenerationOverrideCallFailureSkipsFallback(
|
||||
"openai",
|
||||
"fallback-chat-model",
|
||||
fallbackModel,
|
||||
resolvedModelRoute{},
|
||||
keys,
|
||||
modelBuildOptions{},
|
||||
generated,
|
||||
logger,
|
||||
nil,
|
||||
@@ -398,6 +412,7 @@ func TestResolveManualTitleModel_TitleGenerationOverrideUnset(t *testing.T) {
|
||||
db,
|
||||
chat,
|
||||
chatprovider.ProviderAPIKeys{ByProvider: map[string]string{"openai": "test-key"}},
|
||||
modelBuildOptions{},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, model)
|
||||
@@ -447,6 +462,7 @@ func TestResolveManualTitleModel_TitleGenerationOverrideUnsetAIProvider(t *testi
|
||||
db,
|
||||
chat,
|
||||
chatprovider.ProviderAPIKeys{},
|
||||
modelBuildOptions{},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, model)
|
||||
@@ -481,6 +497,7 @@ func TestResolveManualTitleModel_TitleGenerationOverrideReadDBError(t *testing.T
|
||||
db,
|
||||
chat,
|
||||
chatprovider.ProviderAPIKeys{ByProvider: map[string]string{"openai": "test-key"}},
|
||||
modelBuildOptions{},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, model)
|
||||
@@ -508,6 +525,7 @@ func TestResolveManualTitleModel_TitleGenerationOverrideSetUsable(t *testing.T)
|
||||
db,
|
||||
chat,
|
||||
chatprovider.ProviderAPIKeys{ByProvider: map[string]string{"openai": "test-key"}},
|
||||
modelBuildOptions{},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, model)
|
||||
@@ -535,6 +553,7 @@ func TestResolveManualTitleModel_TitleGenerationOverrideMissingCredentials(t *te
|
||||
db,
|
||||
chat,
|
||||
chatprovider.ProviderAPIKeys{},
|
||||
modelBuildOptions{},
|
||||
)
|
||||
require.Error(t, err)
|
||||
require.ErrorContains(t, err, "resolve manual title generation model override")
|
||||
@@ -562,6 +581,7 @@ func TestResolveManualTitleModel_TitleGenerationOverrideSetUnusable(t *testing.T
|
||||
db,
|
||||
chat,
|
||||
chatprovider.ProviderAPIKeys{ByProvider: map[string]string{"openai": "test-key"}},
|
||||
modelBuildOptions{},
|
||||
)
|
||||
require.Error(t, err)
|
||||
require.ErrorContains(t, err, "resolve manual title generation model override")
|
||||
|
||||
+14
-2
@@ -4058,6 +4058,17 @@ Write out the current server config as YAML to stdout.`,
|
||||
Group: &deploymentGroupChat,
|
||||
YAML: "debugLoggingEnabled",
|
||||
},
|
||||
{
|
||||
Name: "Chat: AI Gateway Routing Enabled",
|
||||
Description: "Route chat model requests through AI Gateway when both chat routing and AI Gateway are enabled. Otherwise, chat calls AI providers directly. Pending chats without API key metadata may need a retry or temporary direct routing.",
|
||||
Flag: "chat-ai-gateway-routing-enabled",
|
||||
Env: "CODER_CHAT_AI_GATEWAY_ROUTING_ENABLED",
|
||||
Value: &c.AI.Chat.AIGatewayRoutingEnabled,
|
||||
Default: "true",
|
||||
Group: &deploymentGroupChat,
|
||||
YAML: "aiGatewayRoutingEnabled",
|
||||
Hidden: true,
|
||||
},
|
||||
// AI Bridge Options (deprecated in favor of AI Gateway options)
|
||||
{
|
||||
Name: "AI Bridge Enabled",
|
||||
@@ -4714,8 +4725,9 @@ type AIBridgeProxyConfig struct {
|
||||
}
|
||||
|
||||
type ChatConfig struct {
|
||||
AcquireBatchSize serpent.Int64 `json:"acquire_batch_size" typescript:",notnull"`
|
||||
DebugLoggingEnabled serpent.Bool `json:"debug_logging_enabled" typescript:",notnull"`
|
||||
AcquireBatchSize serpent.Int64 `json:"acquire_batch_size" typescript:",notnull"`
|
||||
DebugLoggingEnabled serpent.Bool `json:"debug_logging_enabled" typescript:",notnull"`
|
||||
AIGatewayRoutingEnabled serpent.Bool `json:"ai_gateway_routing_enabled" typescript:",notnull" swaggerignore:"true"`
|
||||
}
|
||||
|
||||
type AIConfig struct {
|
||||
|
||||
@@ -916,6 +916,15 @@ func TestRetentionConfigParsing(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestChatAIGatewayRoutingEnabledDefault(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
dv := codersdk.DeploymentValues{}
|
||||
opts := dv.Options()
|
||||
require.NoError(t, opts.SetDefaults())
|
||||
require.True(t, dv.AI.Chat.AIGatewayRoutingEnabled.Value())
|
||||
}
|
||||
|
||||
func TestAIBudgetConfigParsing(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
@@ -77,6 +77,9 @@ func TestChatStreamRelay(t *testing.T) {
|
||||
Options: &coderdtest.Options{
|
||||
Database: db,
|
||||
Pubsub: pubsub,
|
||||
DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) {
|
||||
require.NoError(t, dv.AI.Chat.AIGatewayRoutingEnabled.Set("false"))
|
||||
}),
|
||||
},
|
||||
LicenseOptions: &coderdenttest.LicenseOptions{
|
||||
Features: license.Features{
|
||||
@@ -89,6 +92,9 @@ func TestChatStreamRelay(t *testing.T) {
|
||||
Options: &coderdtest.Options{
|
||||
Database: db,
|
||||
Pubsub: pubsub,
|
||||
DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) {
|
||||
require.NoError(t, dv.AI.Chat.AIGatewayRoutingEnabled.Set("false"))
|
||||
}),
|
||||
},
|
||||
DontAddLicense: true,
|
||||
DontAddFirstUser: true,
|
||||
@@ -219,6 +225,9 @@ func TestChatStreamRelay(t *testing.T) {
|
||||
Database: db,
|
||||
Pubsub: pubsub,
|
||||
TLSCertificates: certificates,
|
||||
DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) {
|
||||
require.NoError(t, dv.AI.Chat.AIGatewayRoutingEnabled.Set("false"))
|
||||
}),
|
||||
},
|
||||
LicenseOptions: &coderdenttest.LicenseOptions{
|
||||
Features: license.Features{
|
||||
@@ -232,6 +241,9 @@ func TestChatStreamRelay(t *testing.T) {
|
||||
Database: db,
|
||||
Pubsub: pubsub,
|
||||
TLSCertificates: certificates,
|
||||
DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) {
|
||||
require.NoError(t, dv.AI.Chat.AIGatewayRoutingEnabled.Set("false"))
|
||||
}),
|
||||
},
|
||||
DontAddLicense: true,
|
||||
DontAddFirstUser: true,
|
||||
@@ -398,6 +410,9 @@ func TestChatStreamRelay(t *testing.T) {
|
||||
Options: &coderdtest.Options{
|
||||
Database: db,
|
||||
Pubsub: pubsub,
|
||||
DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) {
|
||||
require.NoError(t, dv.AI.Chat.AIGatewayRoutingEnabled.Set("false"))
|
||||
}),
|
||||
},
|
||||
LicenseOptions: &coderdenttest.LicenseOptions{
|
||||
Features: license.Features{
|
||||
@@ -410,6 +425,9 @@ func TestChatStreamRelay(t *testing.T) {
|
||||
Options: &coderdtest.Options{
|
||||
Database: db,
|
||||
Pubsub: pubsub,
|
||||
DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) {
|
||||
require.NoError(t, dv.AI.Chat.AIGatewayRoutingEnabled.Set("false"))
|
||||
}),
|
||||
},
|
||||
DontAddLicense: true,
|
||||
DontAddFirstUser: true,
|
||||
@@ -544,6 +562,7 @@ func TestChatStreamRelay(t *testing.T) {
|
||||
|
||||
db, pubsub := dbtestutil.NewDB(t)
|
||||
hostPrefixValues := coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) {
|
||||
require.NoError(t, dv.AI.Chat.AIGatewayRoutingEnabled.Set("false"))
|
||||
dv.HTTPCookies.EnableHostPrefix = true
|
||||
dv.HTTPCookies.Secure = true
|
||||
})
|
||||
@@ -696,6 +715,9 @@ func TestChatStreamRelay(t *testing.T) {
|
||||
Options: &coderdtest.Options{
|
||||
Database: db,
|
||||
Pubsub: pubsub,
|
||||
DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) {
|
||||
require.NoError(t, dv.AI.Chat.AIGatewayRoutingEnabled.Set("false"))
|
||||
}),
|
||||
},
|
||||
LicenseOptions: &coderdenttest.LicenseOptions{
|
||||
Features: license.Features{
|
||||
@@ -708,6 +730,9 @@ func TestChatStreamRelay(t *testing.T) {
|
||||
Options: &coderdtest.Options{
|
||||
Database: db,
|
||||
Pubsub: pubsub,
|
||||
DeploymentValues: coderdtest.DeploymentValues(t, func(dv *codersdk.DeploymentValues) {
|
||||
require.NoError(t, dv.AI.Chat.AIGatewayRoutingEnabled.Set("false"))
|
||||
}),
|
||||
},
|
||||
DontAddLicense: true,
|
||||
DontAddFirstUser: true,
|
||||
|
||||
Generated
+1
@@ -1613,6 +1613,7 @@ export interface ChatComputerUseProviderResponse {
|
||||
export interface ChatConfig {
|
||||
readonly acquire_batch_size: number;
|
||||
readonly debug_logging_enabled: boolean;
|
||||
readonly ai_gateway_routing_enabled: boolean;
|
||||
}
|
||||
|
||||
// From codersdk/chats.go
|
||||
|
||||
Reference in New Issue
Block a user