feat: use AI provider chat APIs (#25415)

This commit is contained in:
Michael Suchacz
2026-05-22 07:53:23 +02:00
committed by GitHub
parent 10efde3e6c
commit 06526a5822
41 changed files with 2195 additions and 1126 deletions
+40
View File
@@ -157,6 +157,46 @@ func TestAIProvidersCRUD(t *testing.T) {
require.Equal(t, "no-display", created.DisplayName)
})
t.Run("RequiredBaseURL", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, nil)
_ = coderdtest.CreateFirstUser(t, client)
ctx := testutil.Context(t, testutil.WaitLong)
//nolint:gocritic // Owner role is the audience for this endpoint.
_, err := client.CreateAIProvider(ctx, codersdk.CreateAIProviderRequest{
Type: codersdk.AIProviderTypeOpenAI,
Name: "missing-base-url",
Enabled: true,
})
sdkErr := requireSDKError(t, err, http.StatusBadRequest)
require.Equal(t, "Invalid AI provider request.", sdkErr.Message)
require.Contains(t, sdkErr.Validations, codersdk.ValidationError{Field: "base_url", Detail: "base_url is required"})
created, err := client.CreateAIProvider(ctx, codersdk.CreateAIProviderRequest{
Type: codersdk.AIProviderTypeOpenAI,
Name: "required-base-url",
Enabled: true,
BaseURL: "https://api.openai.com/v1",
})
require.NoError(t, err)
baseURL := "https://proxy.example.com/v1"
updated, err := client.UpdateAIProvider(ctx, created.Name, codersdk.UpdateAIProviderRequest{
BaseURL: &baseURL,
})
require.NoError(t, err)
require.Equal(t, baseURL, updated.BaseURL)
baseURL = ""
_, err = client.UpdateAIProvider(ctx, created.Name, codersdk.UpdateAIProviderRequest{
BaseURL: &baseURL,
})
sdkErr = requireSDKError(t, err, http.StatusBadRequest)
require.Equal(t, "Invalid AI provider request.", sdkErr.Message)
require.Contains(t, sdkErr.Validations, codersdk.ValidationError{Field: "base_url", Detail: "base_url is required"})
})
t.Run("DuplicateNameConflict", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, nil)
+11
View File
@@ -1202,6 +1202,17 @@ func New(options *Options) *API {
r.Delete("/", api.deleteUserSkill)
})
})
r.Route("/users/{user}/ai-provider-keys", func(r chi.Router) {
r.Use(
apiKeyMiddleware,
httpmw.ExtractUserParam(options.Database),
)
r.Get("/", api.listUserAIProviderKeyConfigs)
r.Route("/{aiProvider}", func(r chi.Router) {
r.Put("/", api.upsertUserAIProviderKey)
r.Delete("/", api.deleteUserAIProviderKey)
})
})
r.Route("/chats", func(r chi.Router) {
r.Use(
apiKeyMiddleware,
+13
View File
@@ -65,11 +65,24 @@ func CreateOpenAICompatChatModelConfig(
BaseURL: baseURL,
})
require.NoError(t, err)
aiProviderBaseURL := baseURL
if aiProviderBaseURL == "" {
aiProviderBaseURL = "https://api.example.com/v1"
}
provider, err := client.CreateAIProvider(ctx, codersdk.CreateAIProviderRequest{
Type: codersdk.AIProviderType(TestChatProviderOpenAICompat),
Name: "test-" + uuid.NewString(),
BaseURL: aiProviderBaseURL,
Enabled: true,
APIKeys: []string{TestChatProviderAPIKey},
})
require.NoError(t, err)
contextLimit := int64(4096)
isDefault := true
modelConfig, err := client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
Provider: TestChatProviderOpenAICompat,
AIProviderID: &provider.ID,
Model: TestChatModelOpenAICompat,
ContextLimit: &contextLimit,
IsDefault: &isDefault,
+38
View File
@@ -722,6 +722,24 @@ var (
}),
Scope: rbac.ScopeAll,
}.WithCachedASTValue()
subjectAIProviderMetadataReader = rbac.Subject{
Type: rbac.SubjectTypeAIProviderMetadataReader,
FriendlyName: "AI Provider Metadata Reader",
ID: uuid.Nil.String(),
Roles: rbac.Roles([]rbac.Role{
{
Identifier: rbac.RoleIdentifier{Name: "ai-provider-metadata-reader"},
DisplayName: "AI Provider Metadata Reader",
Site: rbac.Permissions(map[string][]policy.Action{
rbac.ResourceAIProvider.Type: {policy.ActionRead},
}),
User: []rbac.Permission{},
ByOrgID: map[string]rbac.OrgPermissions{},
},
}),
Scope: rbac.ScopeAll,
}.WithCachedASTValue()
)
// AsProvisionerd returns a context with an actor that has permissions required
@@ -846,6 +864,12 @@ func AsChatd(ctx context.Context) context.Context {
return As(ctx, subjectChatd)
}
// AsAIProviderMetadataReader returns a context with an actor that can read
// AI provider metadata and provider-key presence.
func AsAIProviderMetadataReader(ctx context.Context) context.Context {
return As(ctx, subjectAIProviderMetadataReader)
}
var AsRemoveActor = rbac.Subject{
ID: "remove-actor",
}
@@ -2546,6 +2570,13 @@ func (q *querier) GetAIProviderByID(ctx context.Context, id uuid.UUID) (database
return q.db.GetAIProviderByID(ctx, id)
}
func (q *querier) GetAIProviderByIDForReferenceLock(ctx context.Context, id uuid.UUID) (database.AIProvider, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceAIProvider); err != nil {
return database.AIProvider{}, err
}
return q.db.GetAIProviderByIDForReferenceLock(ctx, id)
}
func (q *querier) GetAIProviderByName(ctx context.Context, name string) (database.AIProvider, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceAIProvider); err != nil {
return database.AIProvider{}, err
@@ -2560,6 +2591,13 @@ func (q *querier) GetAIProviderKeyByID(ctx context.Context, id uuid.UUID) (datab
return q.db.GetAIProviderKeyByID(ctx, id)
}
func (q *querier) GetAIProviderKeyPresence(ctx context.Context, arg []uuid.UUID) ([]uuid.UUID, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceAIProvider); err != nil {
return nil, err
}
return q.db.GetAIProviderKeyPresence(ctx, arg)
}
func (q *querier) GetAIProviderKeys(ctx context.Context, includeDeleted bool) ([]database.AIProviderKey, error) {
// Callers pass include_deleted=TRUE only from the dbcrypt key
// rotation utility, which needs to re-encrypt every row that holds
+13
View File
@@ -6509,6 +6509,11 @@ func (s *MethodTestSuite) TestAIBridge() {
dbm.EXPECT().GetAIProviderByID(gomock.Any(), provider.ID).Return(provider, nil).AnyTimes()
check.Args(provider.ID).Asserts(rbac.ResourceAIProvider, policy.ActionRead).Returns(provider)
}))
s.Run("GetAIProviderByIDForReferenceLock", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
provider := testutil.Fake(s.T(), faker, database.AIProvider{})
dbm.EXPECT().GetAIProviderByIDForReferenceLock(gomock.Any(), provider.ID).Return(provider, nil).AnyTimes()
check.Args(provider.ID).Asserts(rbac.ResourceAIProvider, policy.ActionRead).Returns(provider)
}))
s.Run("GetAIProviderByName", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
provider := testutil.Fake(s.T(), faker, database.AIProvider{})
dbm.EXPECT().GetAIProviderByName(gomock.Any(), provider.Name).Return(provider, nil).AnyTimes()
@@ -6562,6 +6567,14 @@ func (s *MethodTestSuite) TestAIBridge() {
dbm.EXPECT().GetAIProviderKeyByID(gomock.Any(), key.ID).Return(key, nil).AnyTimes()
check.Args(key.ID).Asserts(rbac.ResourceAIProvider, policy.ActionRead).Returns(key)
}))
s.Run("GetAIProviderKeyPresence", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
providerA := testutil.Fake(s.T(), faker, database.AIProvider{})
providerB := testutil.Fake(s.T(), faker, database.AIProvider{})
arg := []uuid.UUID{providerA.ID, providerB.ID}
providerIDs := []uuid.UUID{providerA.ID}
dbm.EXPECT().GetAIProviderKeyPresence(gomock.Any(), arg).Return(providerIDs, nil).AnyTimes()
check.Args(arg).Asserts(rbac.ResourceAIProvider, policy.ActionRead).Returns(providerIDs)
}))
s.Run("GetAIProviderKeysByProviderID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
provider := testutil.Fake(s.T(), faker, database.AIProvider{})
keyA := testutil.Fake(s.T(), faker, database.AIProviderKey{ProviderID: provider.ID})
+1
View File
@@ -160,6 +160,7 @@ func ChatModelConfig(t testing.TB, db database.Store, seed database.ChatModelCon
ContextLimit: takeFirst(seed.ContextLimit, defaultChatModelContextLimit),
CompressionThreshold: takeFirst(seed.CompressionThreshold, defaultChatModelCompressionThreshold),
Options: takeFirstSlice(seed.Options, json.RawMessage(`{}`)),
AIProviderID: seed.AIProviderID,
}
for _, fn := range munge {
fn(&params)
+16
View File
@@ -1041,6 +1041,14 @@ func (m queryMetricsStore) GetAIProviderByID(ctx context.Context, id uuid.UUID)
return r0, r1
}
func (m queryMetricsStore) GetAIProviderByIDForReferenceLock(ctx context.Context, id uuid.UUID) (database.AIProvider, error) {
start := time.Now()
r0, r1 := m.s.GetAIProviderByIDForReferenceLock(ctx, id)
m.queryLatencies.WithLabelValues("GetAIProviderByIDForReferenceLock").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetAIProviderByIDForReferenceLock").Inc()
return r0, r1
}
func (m queryMetricsStore) GetAIProviderByName(ctx context.Context, name string) (database.AIProvider, error) {
start := time.Now()
r0, r1 := m.s.GetAIProviderByName(ctx, name)
@@ -1057,6 +1065,14 @@ func (m queryMetricsStore) GetAIProviderKeyByID(ctx context.Context, id uuid.UUI
return r0, r1
}
func (m queryMetricsStore) GetAIProviderKeyPresence(ctx context.Context, arg []uuid.UUID) ([]uuid.UUID, error) {
start := time.Now()
r0, r1 := m.s.GetAIProviderKeyPresence(ctx, arg)
m.queryLatencies.WithLabelValues("GetAIProviderKeyPresence").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetAIProviderKeyPresence").Inc()
return r0, r1
}
func (m queryMetricsStore) GetAIProviderKeys(ctx context.Context, includeDeleted bool) ([]database.AIProviderKey, error) {
start := time.Now()
r0, r1 := m.s.GetAIProviderKeys(ctx, includeDeleted)
+30
View File
@@ -1799,6 +1799,21 @@ func (mr *MockStoreMockRecorder) GetAIProviderByID(ctx, id any) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAIProviderByID", reflect.TypeOf((*MockStore)(nil).GetAIProviderByID), ctx, id)
}
// GetAIProviderByIDForReferenceLock mocks base method.
func (m *MockStore) GetAIProviderByIDForReferenceLock(ctx context.Context, id uuid.UUID) (database.AIProvider, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetAIProviderByIDForReferenceLock", ctx, id)
ret0, _ := ret[0].(database.AIProvider)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetAIProviderByIDForReferenceLock indicates an expected call of GetAIProviderByIDForReferenceLock.
func (mr *MockStoreMockRecorder) GetAIProviderByIDForReferenceLock(ctx, id any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAIProviderByIDForReferenceLock", reflect.TypeOf((*MockStore)(nil).GetAIProviderByIDForReferenceLock), ctx, id)
}
// GetAIProviderByName mocks base method.
func (m *MockStore) GetAIProviderByName(ctx context.Context, name string) (database.AIProvider, error) {
m.ctrl.T.Helper()
@@ -1829,6 +1844,21 @@ func (mr *MockStoreMockRecorder) GetAIProviderKeyByID(ctx, id any) *gomock.Call
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAIProviderKeyByID", reflect.TypeOf((*MockStore)(nil).GetAIProviderKeyByID), ctx, id)
}
// GetAIProviderKeyPresence mocks base method.
func (m *MockStore) GetAIProviderKeyPresence(ctx context.Context, providerIds []uuid.UUID) ([]uuid.UUID, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetAIProviderKeyPresence", ctx, providerIds)
ret0, _ := ret[0].([]uuid.UUID)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetAIProviderKeyPresence indicates an expected call of GetAIProviderKeyPresence.
func (mr *MockStoreMockRecorder) GetAIProviderKeyPresence(ctx, providerIds any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAIProviderKeyPresence", reflect.TypeOf((*MockStore)(nil).GetAIProviderKeyPresence), ctx, providerIds)
}
// GetAIProviderKeys mocks base method.
func (m *MockStore) GetAIProviderKeys(ctx context.Context, includeDeleted bool) ([]database.AIProviderKey, error) {
m.ctrl.T.Helper()
+6
View File
@@ -253,8 +253,14 @@ type sqlcQuerier interface {
GetAIBridgeUserPromptsByInterceptionID(ctx context.Context, interceptionID uuid.UUID) ([]AIBridgeUserPrompt, error)
GetAIModelPriceByProviderModel(ctx context.Context, arg GetAIModelPriceByProviderModelParams) (AiModelPrice, error)
GetAIProviderByID(ctx context.Context, id uuid.UUID) (AIProvider, error)
// Lock the provider row until the model-config write completes. The
// transaction alone does not stop a concurrent soft-delete or disable
// between validation and writing the model config reference.
GetAIProviderByIDForReferenceLock(ctx context.Context, id uuid.UUID) (AIProvider, error)
GetAIProviderByName(ctx context.Context, name string) (AIProvider, error)
GetAIProviderKeyByID(ctx context.Context, id uuid.UUID) (AIProviderKey, error)
// Returns the provider IDs that have at least one provider-scoped key.
GetAIProviderKeyPresence(ctx context.Context, providerIds []uuid.UUID) ([]uuid.UUID, error)
// Returns AI provider key rows. By default, only rows whose parent
// provider is live (deleted = FALSE) are returned, so the API list
// handler can fetch every visible provider's keys in a single query.
+71
View File
@@ -10567,6 +10567,77 @@ func TestInsertWorkspaceAgentDevcontainers(t *testing.T) {
}
}
func TestGetEnabledChatModelConfigsUsesAIProviders(t *testing.T) {
t.Parallel()
store, _ := dbtestutil.NewDB(t)
ctx := testutil.Context(t, testutil.WaitMedium)
enabledProvider := dbgen.AIProvider(t, store, database.AIProvider{
Type: database.AiProviderTypeOpenrouter,
Name: "openrouter-" + uuid.NewString(),
})
disabledProvider := dbgen.AIProvider(t, store, database.AIProvider{
Type: database.AiProviderTypeVercel,
Name: "vercel-" + uuid.NewString(),
}, func(params *database.InsertAIProviderParams) {
params.Enabled = false
})
enabledConfig := dbgen.ChatModelConfig(t, store, database.ChatModelConfig{
Provider: string(enabledProvider.Type),
Model: "openrouter-model-" + uuid.NewString(),
AIProviderID: uuid.NullUUID{
UUID: enabledProvider.ID,
Valid: true,
},
})
disabledProviderConfig := dbgen.ChatModelConfig(t, store, database.ChatModelConfig{
Provider: string(disabledProvider.Type),
Model: "vercel-model-" + uuid.NewString(),
AIProviderID: uuid.NullUUID{
UUID: disabledProvider.ID,
Valid: true,
},
})
disabledModelConfig := dbgen.ChatModelConfig(t, store, database.ChatModelConfig{
Provider: string(enabledProvider.Type),
Model: "disabled-model-" + uuid.NewString(),
AIProviderID: uuid.NullUUID{
UUID: enabledProvider.ID,
Valid: true,
},
}, func(params *database.InsertChatModelConfigParams) {
params.Enabled = false
})
legacyProvider := dbgen.ChatProvider(t, store, database.ChatProvider{Provider: "google"})
legacyConfig := dbgen.ChatModelConfig(t, store, database.ChatModelConfig{
Provider: legacyProvider.Provider,
Model: "google-model-" + uuid.NewString(),
})
configs, err := store.GetEnabledChatModelConfigs(ctx)
require.NoError(t, err)
require.True(t, slices.ContainsFunc(configs, func(config database.ChatModelConfig) bool {
return config.ID == enabledConfig.ID
}))
require.True(t, slices.ContainsFunc(configs, func(config database.ChatModelConfig) bool {
return config.ID == legacyConfig.ID
}))
require.False(t, slices.ContainsFunc(configs, func(config database.ChatModelConfig) bool {
return config.ID == disabledProviderConfig.ID
}))
require.False(t, slices.ContainsFunc(configs, func(config database.ChatModelConfig) bool {
return config.ID == disabledModelConfig.ID
}))
config, err := store.GetEnabledChatModelConfigByID(ctx, enabledConfig.ID)
require.NoError(t, err)
require.Equal(t, enabledConfig.ID, config.ID)
_, err = store.GetEnabledChatModelConfigByID(ctx, disabledProviderConfig.ID)
require.ErrorIs(t, err, sql.ErrNoRows)
}
func TestInsertChatMessages(t *testing.T) {
t.Parallel()
+93 -9
View File
@@ -146,6 +146,41 @@ func (q *sqlQuerier) GetAIProviderKeyByID(ctx context.Context, id uuid.UUID) (AI
return i, err
}
const getAIProviderKeyPresence = `-- name: GetAIProviderKeyPresence :many
SELECT DISTINCT
provider_id
FROM
ai_provider_keys
WHERE
provider_id = ANY($1::uuid[])
ORDER BY
provider_id ASC
`
// Returns the provider IDs that have at least one provider-scoped key.
func (q *sqlQuerier) GetAIProviderKeyPresence(ctx context.Context, providerIds []uuid.UUID) ([]uuid.UUID, error) {
rows, err := q.db.QueryContext(ctx, getAIProviderKeyPresence, pq.Array(providerIds))
if err != nil {
return nil, err
}
defer rows.Close()
var items []uuid.UUID
for rows.Next() {
var provider_id uuid.UUID
if err := rows.Scan(&provider_id); err != nil {
return nil, err
}
items = append(items, provider_id)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const getAIProviderKeys = `-- name: GetAIProviderKeys :many
SELECT
ai_provider_keys.id, ai_provider_keys.provider_id, ai_provider_keys.api_key, ai_provider_keys.api_key_key_id, ai_provider_keys.created_at, ai_provider_keys.updated_at
@@ -371,6 +406,38 @@ func (q *sqlQuerier) GetAIProviderByID(ctx context.Context, id uuid.UUID) (AIPro
return i, err
}
const getAIProviderByIDForReferenceLock = `-- name: GetAIProviderByIDForReferenceLock :one
SELECT
id, type, name, display_name, enabled, deleted, base_url, settings, settings_key_id, created_at, updated_at
FROM
ai_providers
WHERE
id = $1::uuid AND deleted = FALSE
FOR SHARE
`
// Lock the provider row until the model-config write completes. The
// transaction alone does not stop a concurrent soft-delete or disable
// between validation and writing the model config reference.
func (q *sqlQuerier) GetAIProviderByIDForReferenceLock(ctx context.Context, id uuid.UUID) (AIProvider, error) {
row := q.db.QueryRowContext(ctx, getAIProviderByIDForReferenceLock, id)
var i AIProvider
err := row.Scan(
&i.ID,
&i.Type,
&i.Name,
&i.DisplayName,
&i.Enabled,
&i.Deleted,
&i.BaseUrl,
&i.Settings,
&i.SettingsKeyID,
&i.CreatedAt,
&i.UpdatedAt,
)
return i, err
}
const getAIProviderByName = `-- name: GetAIProviderByName :one
SELECT
id, type, name, display_name, enabled, deleted, base_url, settings, settings_key_id, created_at, updated_at
@@ -5097,13 +5164,18 @@ SELECT
cmc.id, cmc.provider, cmc.model, cmc.display_name, cmc.created_by, cmc.updated_by, cmc.enabled, cmc.is_default, cmc.deleted, cmc.deleted_at, cmc.created_at, cmc.updated_at, cmc.context_limit, cmc.compression_threshold, cmc.options, cmc.ai_provider_id
FROM
chat_model_configs cmc
JOIN
chat_providers cp ON cp.provider = cmc.provider
LEFT JOIN
ai_providers ap ON ap.id = cmc.ai_provider_id
LEFT JOIN
chat_providers cp ON cp.provider = cmc.provider AND cmc.ai_provider_id IS NULL
WHERE
cmc.id = $1::uuid
AND cmc.deleted = FALSE
AND cmc.enabled = TRUE
AND cp.enabled = TRUE
AND (
(cmc.ai_provider_id IS NOT NULL AND ap.enabled = TRUE AND ap.deleted = FALSE)
OR (cmc.ai_provider_id IS NULL AND cp.enabled = TRUE)
)
`
// Providers can be disabled independently of their model configs.
@@ -5137,12 +5209,17 @@ SELECT
cmc.id, cmc.provider, cmc.model, cmc.display_name, cmc.created_by, cmc.updated_by, cmc.enabled, cmc.is_default, cmc.deleted, cmc.deleted_at, cmc.created_at, cmc.updated_at, cmc.context_limit, cmc.compression_threshold, cmc.options, cmc.ai_provider_id
FROM
chat_model_configs cmc
JOIN
chat_providers cp ON cp.provider = cmc.provider
LEFT JOIN
ai_providers ap ON ap.id = cmc.ai_provider_id
LEFT JOIN
chat_providers cp ON cp.provider = cmc.provider AND cmc.ai_provider_id IS NULL
WHERE
cmc.enabled = TRUE
AND cmc.deleted = FALSE
AND cp.enabled = TRUE
AND (
(cmc.ai_provider_id IS NOT NULL AND ap.enabled = TRUE AND ap.deleted = FALSE)
OR (cmc.ai_provider_id IS NULL AND cp.enabled = TRUE)
)
ORDER BY
cmc.provider ASC,
cmc.model ASC,
@@ -5201,7 +5278,8 @@ INSERT INTO chat_model_configs (
is_default,
context_limit,
compression_threshold,
options
options,
ai_provider_id
) VALUES (
$1::text,
$2::text,
@@ -5212,7 +5290,8 @@ INSERT INTO chat_model_configs (
$7::boolean,
$8::bigint,
$9::integer,
$10::jsonb
$10::jsonb,
$11::uuid
)
RETURNING
id, provider, model, display_name, created_by, updated_by, enabled, is_default, deleted, deleted_at, created_at, updated_at, context_limit, compression_threshold, options, ai_provider_id
@@ -5229,6 +5308,7 @@ type InsertChatModelConfigParams struct {
ContextLimit int64 `db:"context_limit" json:"context_limit"`
CompressionThreshold int32 `db:"compression_threshold" json:"compression_threshold"`
Options json.RawMessage `db:"options" json:"options"`
AIProviderID uuid.NullUUID `db:"ai_provider_id" json:"ai_provider_id"`
}
func (q *sqlQuerier) InsertChatModelConfig(ctx context.Context, arg InsertChatModelConfigParams) (ChatModelConfig, error) {
@@ -5243,6 +5323,7 @@ func (q *sqlQuerier) InsertChatModelConfig(ctx context.Context, arg InsertChatMo
arg.ContextLimit,
arg.CompressionThreshold,
arg.Options,
arg.AIProviderID,
)
var i ChatModelConfig
err := row.Scan(
@@ -5295,9 +5376,10 @@ SET
context_limit = $7::bigint,
compression_threshold = $8::integer,
options = $9::jsonb,
ai_provider_id = $10::uuid,
updated_at = NOW()
WHERE
id = $10::uuid
id = $11::uuid
AND deleted = FALSE
RETURNING
id, provider, model, display_name, created_by, updated_by, enabled, is_default, deleted, deleted_at, created_at, updated_at, context_limit, compression_threshold, options, ai_provider_id
@@ -5313,6 +5395,7 @@ type UpdateChatModelConfigParams struct {
ContextLimit int64 `db:"context_limit" json:"context_limit"`
CompressionThreshold int32 `db:"compression_threshold" json:"compression_threshold"`
Options json.RawMessage `db:"options" json:"options"`
AIProviderID uuid.NullUUID `db:"ai_provider_id" json:"ai_provider_id"`
ID uuid.UUID `db:"id" json:"id"`
}
@@ -5327,6 +5410,7 @@ func (q *sqlQuerier) UpdateChatModelConfig(ctx context.Context, arg UpdateChatMo
arg.ContextLimit,
arg.CompressionThreshold,
arg.Options,
arg.AIProviderID,
arg.ID,
)
var i ChatModelConfig
@@ -21,6 +21,17 @@ ORDER BY
created_at ASC,
id ASC;
-- name: GetAIProviderKeyPresence :many
-- Returns the provider IDs that have at least one provider-scoped key.
SELECT DISTINCT
provider_id
FROM
ai_provider_keys
WHERE
provider_id = ANY(@provider_ids::uuid[])
ORDER BY
provider_id ASC;
-- name: GetAIProviderKeys :many
-- Returns AI provider key rows. By default, only rows whose parent
-- provider is live (deleted = FALSE) are returned, so the API list
+12
View File
@@ -6,6 +6,18 @@ FROM
WHERE
id = @id::uuid AND deleted = FALSE;
-- name: GetAIProviderByIDForReferenceLock :one
SELECT
*
FROM
ai_providers
WHERE
id = @id::uuid AND deleted = FALSE
-- Lock the provider row until the model-config write completes. The
-- transaction alone does not stop a concurrent soft-delete or disable
-- between validation and writing the model config reference.
FOR SHARE;
-- name: GetAIProviderByName :one
SELECT
*
+21 -8
View File
@@ -34,12 +34,17 @@ SELECT
cmc.*
FROM
chat_model_configs cmc
JOIN
chat_providers cp ON cp.provider = cmc.provider
LEFT JOIN
ai_providers ap ON ap.id = cmc.ai_provider_id
LEFT JOIN
chat_providers cp ON cp.provider = cmc.provider AND cmc.ai_provider_id IS NULL
WHERE
cmc.enabled = TRUE
AND cmc.deleted = FALSE
AND cp.enabled = TRUE
AND (
(cmc.ai_provider_id IS NOT NULL AND ap.enabled = TRUE AND ap.deleted = FALSE)
OR (cmc.ai_provider_id IS NULL AND cp.enabled = TRUE)
)
ORDER BY
cmc.provider ASC,
cmc.model ASC,
@@ -53,13 +58,18 @@ FROM
chat_model_configs cmc
-- Providers can be disabled independently of their model configs.
-- Check both to ensure the selected config is actually usable.
JOIN
chat_providers cp ON cp.provider = cmc.provider
LEFT JOIN
ai_providers ap ON ap.id = cmc.ai_provider_id
LEFT JOIN
chat_providers cp ON cp.provider = cmc.provider AND cmc.ai_provider_id IS NULL
WHERE
cmc.id = @id::uuid
AND cmc.deleted = FALSE
AND cmc.enabled = TRUE
AND cp.enabled = TRUE;
AND (
(cmc.ai_provider_id IS NOT NULL AND ap.enabled = TRUE AND ap.deleted = FALSE)
OR (cmc.ai_provider_id IS NULL AND cp.enabled = TRUE)
);
-- name: InsertChatModelConfig :one
INSERT INTO chat_model_configs (
@@ -72,7 +82,8 @@ INSERT INTO chat_model_configs (
is_default,
context_limit,
compression_threshold,
options
options,
ai_provider_id
) VALUES (
@provider::text,
@model::text,
@@ -83,7 +94,8 @@ INSERT INTO chat_model_configs (
@is_default::boolean,
@context_limit::bigint,
@compression_threshold::integer,
@options::jsonb
@options::jsonb,
sqlc.narg('ai_provider_id')::uuid
)
RETURNING
*;
@@ -101,6 +113,7 @@ SET
context_limit = @context_limit::bigint,
compression_threshold = @compression_threshold::integer,
options = @options::jsonb,
ai_provider_id = sqlc.narg('ai_provider_id')::uuid,
updated_at = NOW()
WHERE
id = @id::uuid
+265 -16
View File
@@ -6396,6 +6396,179 @@ func convertChatMessages(messages []database.ChatMessage) []codersdk.ChatMessage
return result
}
func parseUserAIProviderID(r *http.Request) (uuid.UUID, error) {
return uuid.Parse(chi.URLParam(r, "aiProvider"))
}
func convertAIProviderSummary(provider database.AIProvider) codersdk.AIProviderSummary {
displayName := provider.Name
if provider.DisplayName.Valid && provider.DisplayName.String != "" {
displayName = provider.DisplayName.String
}
return codersdk.AIProviderSummary{
ID: provider.ID,
Type: codersdk.AIProviderType(provider.Type),
Name: provider.Name,
DisplayName: displayName,
Enabled: provider.Enabled,
Deleted: provider.Deleted,
}
}
func (api *API) listUserAIProviderKeyConfigs(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
targetUser := httpmw.UserParam(r)
//nolint:gocritic // Users can list limited provider metadata to manage their own AI provider keys.
metadataCtx := dbauthz.AsAIProviderMetadataReader(ctx)
providers, err := api.Database.GetAIProviders(metadataCtx, database.GetAIProvidersParams{IncludeDisabled: true})
if err != nil {
api.Logger.Error(ctx, "failed to list user AI provider configs", slog.Error(err), slog.F("user_id", targetUser.ID))
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{Message: "Failed to list AI providers."})
return
}
keys, err := api.Database.GetUserAIProviderKeysByUserID(ctx, targetUser.ID)
if err != nil {
api.Logger.Error(ctx, "failed to list user AI provider keys", slog.Error(err), slog.F("user_id", targetUser.ID))
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{Message: "Failed to list user AI provider keys."})
return
}
keysByProviderID := make(map[uuid.UUID]struct{}, len(keys))
for _, key := range keys {
keysByProviderID[key.AIProviderID] = struct{}{}
}
visibleProviders := make([]database.AIProvider, 0, len(providers))
visibleProviderIDs := make([]uuid.UUID, 0, len(providers))
for _, provider := range providers {
_, hasUserKey := keysByProviderID[provider.ID]
if !provider.Enabled && !hasUserKey {
continue
}
visibleProviders = append(visibleProviders, provider)
visibleProviderIDs = append(visibleProviderIDs, provider.ID)
}
providerKeysByProviderID := make(map[uuid.UUID]struct{}, len(visibleProviderIDs))
if len(visibleProviderIDs) > 0 {
providerKeyIDs, err := api.Database.GetAIProviderKeyPresence(metadataCtx, visibleProviderIDs)
if err != nil {
api.Logger.Error(ctx, "failed to list AI provider key presence", slog.Error(err), slog.F("user_id", targetUser.ID))
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{Message: "Failed to list AI provider keys."})
return
}
for _, providerID := range providerKeyIDs {
providerKeysByProviderID[providerID] = struct{}{}
}
}
byokEnabled := api.DeploymentValues.AI.BridgeConfig.AllowBYOK.Value()
configs := make([]codersdk.UserAIProviderKeyConfig, 0, len(visibleProviders))
for _, provider := range visibleProviders {
_, hasUserKey := keysByProviderID[provider.ID]
_, hasProviderKey := providerKeysByProviderID[provider.ID]
configs = append(configs, codersdk.UserAIProviderKeyConfig{
Provider: convertAIProviderSummary(provider),
HasUserAPIKey: hasUserKey,
HasProviderAPIKey: hasProviderKey,
BYOKEnabled: byokEnabled,
})
}
httpapi.Write(ctx, rw, http.StatusOK, configs)
}
func (api *API) upsertUserAIProviderKey(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
if !api.DeploymentValues.AI.BridgeConfig.AllowBYOK.Value() {
httpapi.Write(ctx, rw, http.StatusForbidden, codersdk.Response{Message: "BYOK is disabled."})
return
}
targetUser := httpmw.UserParam(r)
providerID, err := parseUserAIProviderID(r)
if err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{Message: "Invalid AI provider ID."})
return
}
//nolint:gocritic // Users can attach their own key to an enabled provider without AI provider admin permissions.
metadataCtx := dbauthz.AsAIProviderMetadataReader(ctx)
provider, err := api.Database.GetAIProviderByID(metadataCtx, providerID)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{Message: "AI provider not found."})
return
}
api.Logger.Error(ctx, "failed to get AI provider", slog.Error(err), slog.F("ai_provider_id", providerID))
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{Message: "Failed to get AI provider."})
return
}
if !provider.Enabled {
httpapi.Write(ctx, rw, http.StatusPreconditionFailed, codersdk.Response{Message: "AI provider is disabled."})
return
}
var req codersdk.CreateUserAIProviderKeyRequest
if !httpapi.Read(ctx, rw, r, &req) {
return
}
if err := validateChatProviderAPIKeySize(req.APIKey); err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "API key too large.",
Detail: err.Error(),
})
return
}
if req.APIKey == "" {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{Message: "API key is required."})
return
}
if strings.TrimSpace(req.APIKey) != req.APIKey {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{Message: "API key must not contain leading or trailing whitespace."})
return
}
providerKeys, err := api.Database.GetAIProviderKeyPresence(metadataCtx, []uuid.UUID{providerID})
if err != nil {
api.Logger.Error(ctx, "failed to list AI provider key presence", slog.Error(err), slog.F("ai_provider_id", providerID))
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{Message: "Failed to list AI provider keys."})
return
}
now := api.Clock.Now()
_, err = api.Database.UpsertUserAIProviderKey(ctx, database.UpsertUserAIProviderKeyParams{
ID: uuid.New(),
UserID: targetUser.ID,
AIProviderID: providerID,
APIKey: req.APIKey,
ApiKeyKeyID: sql.NullString{},
CreatedAt: now,
UpdatedAt: now,
})
if err != nil {
api.Logger.Error(ctx, "failed to update user AI provider key", slog.Error(err), slog.F("user_id", targetUser.ID), slog.F("ai_provider_id", providerID))
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{Message: "Failed to update user AI provider key."})
return
}
httpapi.Write(ctx, rw, http.StatusOK, codersdk.UserAIProviderKeyConfig{
Provider: convertAIProviderSummary(provider),
HasUserAPIKey: true,
HasProviderAPIKey: len(providerKeys) > 0,
BYOKEnabled: true,
})
}
func (api *API) deleteUserAIProviderKey(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
targetUser := httpmw.UserParam(r)
providerID, err := parseUserAIProviderID(r)
if err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{Message: "Invalid AI provider ID."})
return
}
if err := api.Database.DeleteUserAIProviderKey(ctx, database.DeleteUserAIProviderKeyParams{UserID: targetUser.ID, AIProviderID: providerID}); err != nil {
api.Logger.Error(ctx, "failed to delete user AI provider key", slog.Error(err), slog.F("user_id", targetUser.ID), slog.F("ai_provider_id", providerID))
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{Message: "Failed to delete user AI provider key."})
return
}
httpapi.Write(ctx, rw, http.StatusNoContent, nil)
}
func (api *API) listChatProviders(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
//nolint:gocritic // System context required to read enabled chat providers.
@@ -6890,6 +7063,7 @@ func (api *API) listUserChatProviderConfigs(rw http.ResponseWriter, r *http.Requ
provider,
hasUserAPIKey,
hasCentralAPIKeyFallback,
api.DeploymentValues.AI.BridgeConfig.AllowBYOK.Value(),
),
)
}
@@ -6978,6 +7152,7 @@ func (api *API) upsertUserChatProviderKey(rw http.ResponseWriter, r *http.Reques
provider,
true,
hasCentralAPIKeyFallback,
api.DeploymentValues.AI.BridgeConfig.AllowBYOK.Value(),
),
)
}
@@ -7052,14 +7227,29 @@ func (api *API) createChatModelConfig(rw http.ResponseWriter, r *http.Request) {
return
}
provider := normalizeChatProvider(req.Provider)
if provider == "" {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Invalid provider.",
Detail: chatProviderValidationDetail(),
if req.AIProviderID == nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{Message: "AI provider ID is required."})
return
}
//nolint:gocritic // The route already authorized chat model config updates.
aiProvider, err := api.Database.GetAIProviderByID(dbauthz.AsChatd(ctx), *req.AIProviderID)
if err != nil {
if httpapi.Is404Error(err) {
httpapi.Write(ctx, rw, http.StatusPreconditionFailed, codersdk.Response{Message: "AI provider is not configured."})
return
}
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to get AI provider.",
Detail: err.Error(),
})
return
}
if !aiProvider.Enabled {
httpapi.Write(ctx, rw, http.StatusPreconditionFailed, codersdk.Response{Message: "AI provider is disabled."})
return
}
provider := string(aiProvider.Type)
aiProviderID := uuid.NullUUID{UUID: aiProvider.ID, Valid: true}
model := strings.TrimSpace(req.Model)
if model == "" {
@@ -7117,15 +7307,25 @@ func (api *API) createChatModelConfig(rw http.ResponseWriter, r *http.Request) {
ContextLimit: contextLimit,
CompressionThreshold: compressionThreshold,
Options: modelConfigRaw,
AIProviderID: aiProviderID,
CreatedBy: uuid.NullUUID{UUID: apiKey.UserID, Valid: apiKey.UserID != uuid.Nil},
UpdatedBy: uuid.NullUUID{UUID: apiKey.UserID, Valid: apiKey.UserID != uuid.Nil},
}
var inserted database.ChatModelConfig
err := api.Database.InTx(func(tx database.Store) error {
if err := requireChatProviderForModelConfig(ctx, tx, insertParams.Provider); err != nil {
return err
err = api.Database.InTx(func(tx database.Store) error {
//nolint:gocritic // The route already authorized chat model config updates.
lockedAIProvider, err := tx.GetAIProviderByIDForReferenceLock(dbauthz.AsChatd(ctx), insertParams.AIProviderID.UUID)
if err != nil {
if xerrors.Is(err, sql.ErrNoRows) {
return errChatProviderNotConfigured
}
return xerrors.Errorf("get AI provider for update: %w", err)
}
if !lockedAIProvider.Enabled {
return errChatProviderNotConfigured
}
insertParams.Provider = string(lockedAIProvider.Type)
insertAsDefault := isDefault
if !insertAsDefault {
@@ -7173,7 +7373,7 @@ func (api *API) createChatModelConfig(rw http.ResponseWriter, r *http.Request) {
})
return
case xerrors.Is(err, errChatProviderNotConfigured):
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
httpapi.Write(ctx, rw, http.StatusPreconditionFailed, codersdk.Response{
Message: "Chat provider is not configured.",
Detail: err.Error(),
})
@@ -7224,15 +7424,40 @@ func (api *API) updateChatModelConfig(rw http.ResponseWriter, r *http.Request) {
}
provider := existing.Provider
if strings.TrimSpace(req.Provider) != "" {
provider = normalizeChatProvider(req.Provider)
if provider == "" {
aiProviderID := existing.AIProviderID
if req.AIProviderID != nil {
//nolint:gocritic // The route already authorized chat model config updates.
aiProvider, err := api.Database.GetAIProviderByID(dbauthz.AsChatd(ctx), *req.AIProviderID)
if err != nil {
if httpapi.Is404Error(err) {
httpapi.Write(ctx, rw, http.StatusPreconditionFailed, codersdk.Response{Message: "AI provider is not configured."})
return
}
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Failed to get AI provider.",
Detail: err.Error(),
})
return
}
if !aiProvider.Enabled {
httpapi.Write(ctx, rw, http.StatusPreconditionFailed, codersdk.Response{Message: "AI provider is disabled."})
return
}
provider = string(aiProvider.Type)
aiProviderID = uuid.NullUUID{UUID: aiProvider.ID, Valid: true}
} else if strings.TrimSpace(req.Provider) != "" {
requestedProvider := normalizeChatProvider(req.Provider)
if requestedProvider == "" {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Invalid provider.",
Detail: chatProviderValidationDetail(),
})
return
}
provider = requestedProvider
if requestedProvider != existing.Provider {
aiProviderID = uuid.NullUUID{}
}
}
model := existing.Model
@@ -7299,14 +7524,30 @@ func (api *API) updateChatModelConfig(rw http.ResponseWriter, r *http.Request) {
ContextLimit: contextLimit,
CompressionThreshold: compressionThreshold,
Options: modelConfigRaw,
AIProviderID: aiProviderID,
UpdatedBy: uuid.NullUUID{UUID: apiKey.UserID, Valid: apiKey.UserID != uuid.Nil},
ID: existing.ID,
}
var updated database.ChatModelConfig
err = api.Database.InTx(func(tx database.Store) error {
if err := requireChatProviderForModelConfig(ctx, tx, updateParams.Provider); err != nil {
return err
if updateParams.AIProviderID.Valid && req.AIProviderID != nil {
//nolint:gocritic // The route already authorized chat model config updates.
aiProvider, err := tx.GetAIProviderByIDForReferenceLock(dbauthz.AsChatd(ctx), updateParams.AIProviderID.UUID)
if err != nil {
if xerrors.Is(err, sql.ErrNoRows) {
return errChatProviderNotConfigured
}
return xerrors.Errorf("get AI provider for update: %w", err)
}
if !aiProvider.Enabled {
return errChatProviderNotConfigured
}
updateParams.Provider = string(aiProvider.Type)
} else if !updateParams.AIProviderID.Valid {
if err := requireChatProviderForModelConfig(ctx, tx, updateParams.Provider); err != nil {
return err
}
}
setAsDefault := updateParams.IsDefault && !existing.IsDefault
@@ -7357,7 +7598,7 @@ func (api *API) updateChatModelConfig(rw http.ResponseWriter, r *http.Request) {
})
return
case xerrors.Is(err, errChatProviderNotConfigured):
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
httpapi.Write(ctx, rw, http.StatusPreconditionFailed, codersdk.Response{
Message: "Chat provider is not configured.",
Detail: err.Error(),
})
@@ -7487,6 +7728,7 @@ func chatModelConfigToUpdateParams(
ContextLimit: config.ContextLimit,
CompressionThreshold: config.CompressionThreshold,
Options: config.Options,
AIProviderID: config.AIProviderID,
UpdatedBy: uuid.NullUUID{},
ID: config.ID,
}
@@ -7589,6 +7831,7 @@ func convertUserChatProviderConfig(
provider database.ChatProvider,
hasUserAPIKey bool,
hasCentralAPIKeyFallback bool,
byokEnabled bool,
) codersdk.UserChatProviderConfig {
displayName := strings.TrimSpace(provider.DisplayName)
if displayName == "" {
@@ -7601,13 +7844,19 @@ func convertUserChatProviderConfig(
DisplayName: displayName,
HasUserAPIKey: hasUserAPIKey,
HasCentralAPIKeyFallback: hasCentralAPIKeyFallback,
BYOKEnabled: byokEnabled,
}
}
func convertChatModelConfig(config database.ChatModelConfig) codersdk.ChatModelConfig {
var aiProviderID *uuid.UUID
if config.AIProviderID.Valid {
aiProviderID = &config.AIProviderID.UUID
}
return codersdk.ChatModelConfig{
ID: config.ID,
Provider: config.Provider,
AIProviderID: aiProviderID,
Model: config.Model,
DisplayName: config.DisplayName,
Enabled: config.Enabled,
@@ -7787,7 +8036,7 @@ const maxChatProviderAPIKeySize = 10240 // 10 KB
func validateChatProviderAPIKeySize(apiKey string) error {
if len(apiKey) > maxChatProviderAPIKeySize {
return xerrors.Errorf("API key exceeds maximum size of %d bytes", maxChatProviderAPIKeySize)
return xerrors.Errorf("API key exceeds maximum size of 10 KB (%d bytes)", maxChatProviderAPIKeySize)
}
return nil
}
+444 -54
View File
@@ -1801,16 +1801,19 @@ func TestListChatModels(t *testing.T) {
client := newChatClient(t)
_ = coderdtest.CreateFirstUser(t, client.Client)
provider, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
Provider: "anthropic",
providerType := database.AiProviderTypeAnthropic
chatProvider, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
Provider: string(providerType),
CentralAPIKeyEnabled: ptr.Ref(false),
AllowUserAPIKey: ptr.Ref(true),
})
require.NoError(t, err)
aiProvider := createAIProviderForTest(t, client, string(providerType), "")
contextLimit := int64(4096)
_, err = client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
Provider: "anthropic",
Provider: string(providerType),
AIProviderID: &aiProvider.ID,
Model: "claude-sonnet",
ContextLimit: &contextLimit,
})
@@ -1821,7 +1824,7 @@ func TestListChatModels(t *testing.T) {
var anthropicProvider *codersdk.ChatModelProvider
for i := range models.Providers {
if models.Providers[i].Provider == "anthropic" {
if models.Providers[i].Provider == string(providerType) {
anthropicProvider = &models.Providers[i]
break
}
@@ -1830,7 +1833,7 @@ func TestListChatModels(t *testing.T) {
require.False(t, anthropicProvider.Available)
require.Equal(t, codersdk.ChatModelProviderUnavailableReasonUserAPIKeyRequired, anthropicProvider.UnavailableReason)
_, err = client.UpsertUserChatProviderKey(ctx, provider.ID, codersdk.CreateUserChatProviderKeyRequest{
_, err = client.UpsertUserChatProviderKey(ctx, chatProvider.ID, codersdk.CreateUserChatProviderKeyRequest{
APIKey: "user-api-key",
})
require.NoError(t, err)
@@ -1856,7 +1859,7 @@ func TestListChatModels(t *testing.T) {
client := newChatClient(t)
_ = coderdtest.CreateFirstUser(t, client.Client)
provider, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
chatProvider, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
Provider: "google",
APIKey: "central-api-key",
CentralAPIKeyEnabled: ptr.Ref(true),
@@ -1864,10 +1867,12 @@ func TestListChatModels(t *testing.T) {
AllowCentralAPIKeyFallback: ptr.Ref(true),
})
require.NoError(t, err)
aiProvider := createAIProviderForTest(t, client, "google", "provider-api-key")
contextLimit := int64(4096)
_, err = client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
Provider: "google",
AIProviderID: &aiProvider.ID,
Model: "gemini-1.5-pro",
ContextLimit: &contextLimit,
})
@@ -1886,7 +1891,7 @@ func TestListChatModels(t *testing.T) {
require.NotNil(t, googleProvider)
require.True(t, googleProvider.Available)
_, err = client.UpsertUserChatProviderKey(ctx, provider.ID, codersdk.CreateUserChatProviderKeyRequest{
_, err = client.UpsertUserChatProviderKey(ctx, chatProvider.ID, codersdk.CreateUserChatProviderKeyRequest{
APIKey: "user-api-key",
})
require.NoError(t, err)
@@ -1914,15 +1919,17 @@ func TestListChatModels(t *testing.T) {
client := newChatClientWithDeploymentValues(t, values)
_ = coderdtest.CreateFirstUser(t, client.Client)
provider, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
chatProvider, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
Provider: "openai",
APIKey: "test-key",
})
require.NoError(t, err)
aiProvider := createAIProviderForTest(t, client, "openai", "test-key")
contextLimit := int64(4096)
_, err = client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
Provider: "openai",
AIProviderID: &aiProvider.ID,
Model: "gpt-4o-mini",
ContextLimit: &contextLimit,
})
@@ -1936,7 +1943,7 @@ func TestListChatModels(t *testing.T) {
require.Equal(t, "gpt-4o-mini", models.Providers[0].Models[0].Model)
enabled := false
_, err = client.UpdateChatProvider(ctx, provider.ID, codersdk.UpdateChatProviderConfigRequest{
_, err = client.UpdateChatProvider(ctx, chatProvider.ID, codersdk.UpdateChatProviderConfigRequest{
Enabled: &enabled,
})
require.NoError(t, err)
@@ -2261,6 +2268,186 @@ func TestWatchChats(t *testing.T) {
})
}
func TestUserAIProviderKeys(t *testing.T) {
t.Parallel()
createOpenAIProvider := func(t *testing.T, client *codersdk.ExperimentalClient, name string, enabled bool, apiKeys ...string) codersdk.AIProvider {
t.Helper()
provider, err := client.CreateAIProvider(testutil.Context(t, testutil.WaitLong), codersdk.CreateAIProviderRequest{
Type: codersdk.AIProviderTypeOpenAI,
Name: name,
Enabled: enabled,
BaseURL: "https://api.openai.example.com/v1",
APIKeys: apiKeys,
})
require.NoError(t, err)
return provider
}
findUserAIProviderKeyConfig := func(
t *testing.T,
configs []codersdk.UserAIProviderKeyConfig,
providerID uuid.UUID,
) *codersdk.UserAIProviderKeyConfig {
t.Helper()
for i := range configs {
if configs[i].Provider.ID == providerID {
return &configs[i]
}
}
return nil
}
t.Run("SelfServiceLifecycle", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
adminClient := newChatClient(t)
firstUser := coderdtest.CreateFirstUser(t, adminClient.Client)
memberClientRaw, _ := coderdtest.CreateAnotherUser(t, adminClient.Client, firstUser.OrganizationID)
memberClient := codersdk.NewExperimentalClient(memberClientRaw)
provider := createOpenAIProvider(t, adminClient, "test-user-key-"+uuid.NewString(), true, "test-provider-api-key")
configs, err := memberClient.ListUserAIProviderKeyConfigs(ctx, "me")
require.NoError(t, err)
cfg := findUserAIProviderKeyConfig(t, configs, provider.ID)
require.NotNil(t, cfg)
require.False(t, cfg.HasUserAPIKey)
require.True(t, cfg.HasProviderAPIKey)
require.True(t, cfg.BYOKEnabled)
cfgValue, err := memberClient.UpsertUserAIProviderKey(ctx, "me", provider.ID, codersdk.CreateUserAIProviderKeyRequest{APIKey: "test-user-api-key"})
require.NoError(t, err)
require.Equal(t, provider.ID, cfgValue.Provider.ID)
require.True(t, cfgValue.HasUserAPIKey)
require.True(t, cfgValue.HasProviderAPIKey)
require.True(t, cfgValue.BYOKEnabled)
configs, err = memberClient.ListUserAIProviderKeyConfigs(ctx, "me")
require.NoError(t, err)
cfg = findUserAIProviderKeyConfig(t, configs, provider.ID)
require.NotNil(t, cfg)
require.True(t, cfg.HasUserAPIKey)
cfgValue, err = memberClient.UpsertUserAIProviderKey(ctx, "me", provider.ID, codersdk.CreateUserAIProviderKeyRequest{APIKey: "replacement-user-api-key"})
require.NoError(t, err)
require.Equal(t, provider.ID, cfgValue.Provider.ID)
require.True(t, cfgValue.HasUserAPIKey)
configs, err = memberClient.ListUserAIProviderKeyConfigs(ctx, "me")
require.NoError(t, err)
cfg = findUserAIProviderKeyConfig(t, configs, provider.ID)
require.NotNil(t, cfg)
require.True(t, cfg.HasUserAPIKey)
require.NoError(t, memberClient.DeleteUserAIProviderKey(ctx, "me", provider.ID))
configs, err = memberClient.ListUserAIProviderKeyConfigs(ctx, "me")
require.NoError(t, err)
cfg = findUserAIProviderKeyConfig(t, configs, provider.ID)
require.NotNil(t, cfg)
require.False(t, cfg.HasUserAPIKey)
})
t.Run("ListsDisabledProviderWithSavedUserKey", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
adminClient := newChatClient(t)
firstUser := coderdtest.CreateFirstUser(t, adminClient.Client)
memberClientRaw, _ := coderdtest.CreateAnotherUser(t, adminClient.Client, firstUser.OrganizationID)
memberClient := codersdk.NewExperimentalClient(memberClientRaw)
provider := createOpenAIProvider(t, adminClient, "test-disabled-saved-user-key-"+uuid.NewString(), true)
_, err := memberClient.UpsertUserAIProviderKey(ctx, "me", provider.ID, codersdk.CreateUserAIProviderKeyRequest{APIKey: "test-user-api-key"})
require.NoError(t, err)
enabled := false
_, err = adminClient.UpdateAIProvider(ctx, provider.ID.String(), codersdk.UpdateAIProviderRequest{Enabled: &enabled})
require.NoError(t, err)
configs, err := memberClient.ListUserAIProviderKeyConfigs(ctx, "me")
require.NoError(t, err)
cfg := findUserAIProviderKeyConfig(t, configs, provider.ID)
require.NotNil(t, cfg)
require.False(t, cfg.Provider.Enabled)
require.True(t, cfg.HasUserAPIKey)
})
t.Run("RejectsDisabledProvider", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
adminClient := newChatClient(t)
firstUser := coderdtest.CreateFirstUser(t, adminClient.Client)
memberClientRaw, _ := coderdtest.CreateAnotherUser(t, adminClient.Client, firstUser.OrganizationID)
memberClient := codersdk.NewExperimentalClient(memberClientRaw)
provider := createOpenAIProvider(t, adminClient, "test-disabled-user-key-"+uuid.NewString(), false)
_, err := memberClient.UpsertUserAIProviderKey(ctx, "me", provider.ID, codersdk.CreateUserAIProviderKeyRequest{APIKey: "test-user-api-key"})
sdkErr := requireSDKError(t, err, http.StatusPreconditionFailed)
require.Equal(t, "AI provider is disabled.", sdkErr.Message)
})
t.Run("RejectsLargeAPIKey", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
adminClient := newChatClient(t)
firstUser := coderdtest.CreateFirstUser(t, adminClient.Client)
memberClientRaw, _ := coderdtest.CreateAnotherUser(t, adminClient.Client, firstUser.OrganizationID)
memberClient := codersdk.NewExperimentalClient(memberClientRaw)
provider := createOpenAIProvider(t, adminClient, "test-large-user-key-"+uuid.NewString(), true)
_, err := memberClient.UpsertUserAIProviderKey(ctx, "me", provider.ID, codersdk.CreateUserAIProviderKeyRequest{APIKey: strings.Repeat("x", 10241)})
sdkErr := requireSDKError(t, err, http.StatusBadRequest)
require.Equal(t, "API key too large.", sdkErr.Message)
})
t.Run("RejectsWhitespaceAPIKey", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
adminClient := newChatClient(t)
firstUser := coderdtest.CreateFirstUser(t, adminClient.Client)
memberClientRaw, _ := coderdtest.CreateAnotherUser(t, adminClient.Client, firstUser.OrganizationID)
memberClient := codersdk.NewExperimentalClient(memberClientRaw)
provider := createOpenAIProvider(t, adminClient, "test-whitespace-user-key-"+uuid.NewString(), true)
_, err := memberClient.UpsertUserAIProviderKey(ctx, "me", provider.ID, codersdk.CreateUserAIProviderKeyRequest{APIKey: " "})
sdkErr := requireSDKError(t, err, http.StatusBadRequest)
require.Equal(t, "API key must not contain leading or trailing whitespace.", sdkErr.Message)
})
t.Run("BYOKDisabledRejectsUpsertAndAllowsDelete", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
values := chatDeploymentValues(t)
values.AI.BridgeConfig.AllowBYOK = serpent.Bool(false)
client := newChatClientWithDeploymentValues(t, values)
_ = coderdtest.CreateFirstUser(t, client.Client)
provider := createOpenAIProvider(t, client, "test-byok-disabled-"+uuid.NewString(), true)
_, err := client.UpsertUserAIProviderKey(ctx, "me", provider.ID, codersdk.CreateUserAIProviderKeyRequest{APIKey: "test-user-api-key"})
sdkErr := requireSDKError(t, err, http.StatusForbidden)
require.Equal(t, "BYOK is disabled.", sdkErr.Message)
configs, err := client.ListUserAIProviderKeyConfigs(ctx, "me")
require.NoError(t, err)
cfg := findUserAIProviderKeyConfig(t, configs, provider.ID)
require.NotNil(t, cfg)
require.False(t, cfg.BYOKEnabled)
require.NoError(t, client.DeleteUserAIProviderKey(ctx, "me", provider.ID))
})
}
func TestListChatProviders(t *testing.T) {
t.Parallel()
@@ -2611,7 +2798,7 @@ func TestCreateChatProvider(t *testing.T) {
})
sdkErr := requireSDKError(t, err, http.StatusBadRequest)
require.Equal(t, "API key too large.", sdkErr.Message)
require.Equal(t, fmt.Sprintf("API key exceeds maximum size of %d bytes", chatProviderAPIKeySizeLimit), sdkErr.Detail)
require.Equal(t, fmt.Sprintf("API key exceeds maximum size of 10 KB (%d bytes)", chatProviderAPIKeySizeLimit), sdkErr.Detail)
})
t.Run("AllowsMaxSizedAPIKey", func(t *testing.T) {
@@ -2898,7 +3085,7 @@ func TestUpdateChatProvider(t *testing.T) {
})
sdkErr := requireSDKError(t, err, http.StatusBadRequest)
require.Equal(t, "API key too large.", sdkErr.Message)
require.Equal(t, fmt.Sprintf("API key exceeds maximum size of %d bytes", chatProviderAPIKeySizeLimit), sdkErr.Detail)
require.Equal(t, fmt.Sprintf("API key exceeds maximum size of 10 KB (%d bytes)", chatProviderAPIKeySizeLimit), sdkErr.Detail)
})
t.Run("AllowsMaxSizedAPIKey", func(t *testing.T) {
@@ -2961,11 +3148,13 @@ func TestDeleteChatProvider(t *testing.T) {
AllowUserAPIKey: ptr.Ref(true),
})
require.NoError(t, err)
aiProviderToDelete := createAIProviderForTest(t, client, providerToDelete.Provider, "delete-api-key")
deleteContextLimit := int64(4096)
deleteIsDefault := true
configToDelete, err := client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
Provider: providerToDelete.Provider,
AIProviderID: &aiProviderToDelete.ID,
Model: "gpt-4o-delete-provider",
ContextLimit: &deleteContextLimit,
IsDefault: &deleteIsDefault,
@@ -2977,10 +3166,12 @@ func TestDeleteChatProvider(t *testing.T) {
APIKey: "keep-api-key",
})
require.NoError(t, err)
keepAIProvider := createAIProviderForTest(t, client, keepProvider.Provider, "keep-api-key")
keepContextLimit := int64(8192)
keepConfig, err := client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
Provider: keepProvider.Provider,
AIProviderID: &keepAIProvider.ID,
Model: "claude-keep-provider",
ContextLimit: &keepContextLimit,
})
@@ -3082,11 +3273,13 @@ func TestDeleteChatProvider(t *testing.T) {
APIKey: "only-provider-api-key",
})
require.NoError(t, err)
aiProvider := createAIProviderForTest(t, client, provider.Provider, "only-provider-api-key")
contextLimit := int64(4096)
isDefault := true
config, err := client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
Provider: provider.Provider,
AIProviderID: &aiProvider.ID,
Model: "gpt-4o-only-provider",
ContextLimit: &contextLimit,
IsDefault: &isDefault,
@@ -3636,7 +3829,7 @@ func TestUpsertUserChatProviderKey(t *testing.T) {
})
sdkErr := requireSDKError(t, err, http.StatusBadRequest)
require.Equal(t, "API key too large.", sdkErr.Message)
require.Equal(t, fmt.Sprintf("API key exceeds maximum size of %d bytes", chatProviderAPIKeySizeLimit), sdkErr.Detail)
require.Equal(t, fmt.Sprintf("API key exceeds maximum size of 10 KB (%d bytes)", chatProviderAPIKeySizeLimit), sdkErr.Detail)
})
t.Run("AllowsMaxSizedAPIKey", func(t *testing.T) {
@@ -3695,16 +3888,13 @@ func TestListChatModelConfigs(t *testing.T) {
client := newChatClient(t)
_ = coderdtest.CreateFirstUser(t, client.Client)
_, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
Provider: "openai",
APIKey: "test-api-key",
})
require.NoError(t, err)
aiProvider := createAIProviderForTest(t, client, "openai", "test-api-key")
contextLimit := int64(4096)
enabled := false
disabledConfig, err := client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
Provider: "openai",
AIProviderID: &aiProvider.ID,
Model: "gpt-4o-disabled",
DisplayName: "GPT-4o Disabled",
Enabled: &enabled,
@@ -3741,6 +3931,7 @@ func TestListChatModelConfigs(t *testing.T) {
enabled := false
_, err := adminClient.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
Provider: enabledConfig.Provider,
AIProviderID: enabledConfig.AIProviderID,
Model: "gpt-4o-disabled",
DisplayName: "GPT-4o Disabled",
Enabled: &enabled,
@@ -3762,15 +3953,12 @@ func TestListChatModelConfigs(t *testing.T) {
client, db := newChatClientWithDatabase(t)
firstUser := coderdtest.CreateFirstUser(t, client.Client)
_, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
Provider: "openai",
APIKey: "test-api-key",
})
require.NoError(t, err)
aiProvider := createAIProviderForTest(t, client, "openai", "test-api-key")
legacyOptions := json.RawMessage(`{"input_price_per_million_tokens":0.15,"output_price_per_million_tokens":0.6,"cache_read_price_per_million_tokens":0.03,"cache_write_price_per_million_tokens":0.3}`)
storedConfig := dbgen.ChatModelConfig(t, db, database.ChatModelConfig{
Provider: "openai",
AIProviderID: uuid.NullUUID{UUID: aiProvider.ID, Valid: true},
Model: "gpt-4o-mini-legacy",
DisplayName: "GPT-4o Mini Legacy",
CreatedBy: uuid.NullUUID{UUID: firstUser.UserID, Valid: true},
@@ -3831,11 +4019,7 @@ func TestCreateChatModelConfig(t *testing.T) {
client := newChatClient(t)
_ = coderdtest.CreateFirstUser(t, client.Client)
_, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
Provider: "openai",
APIKey: "test-api-key",
})
require.NoError(t, err)
aiProvider := createAIProviderForTest(t, client, "openai", "test-api-key")
contextLimit := int64(4096)
isDefault := true
@@ -3849,6 +4033,7 @@ func TestCreateChatModelConfig(t *testing.T) {
}
modelConfig, err := client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
Provider: "openai",
AIProviderID: &aiProvider.ID,
Model: "gpt-4o-mini",
ContextLimit: &contextLimit,
IsDefault: &isDefault,
@@ -3875,15 +4060,12 @@ func TestCreateChatModelConfig(t *testing.T) {
client := newChatClient(t)
_ = coderdtest.CreateFirstUser(t, client.Client)
_, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
Provider: "openai",
APIKey: "test-api-key",
})
require.NoError(t, err)
aiProvider := createAIProviderForTest(t, client, "openai", "test-api-key")
contextLimit := int64(4096)
_, err = client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
_, err := client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
Provider: "openai",
AIProviderID: &aiProvider.ID,
Model: "gpt-4o-mini",
ContextLimit: &contextLimit,
ModelConfig: &codersdk.ChatModelCallConfig{
@@ -3907,16 +4089,18 @@ func TestCreateChatModelConfig(t *testing.T) {
ctx := testutil.Context(t, testutil.WaitLong)
client := newChatClient(t)
_ = coderdtest.CreateFirstUser(t, client.Client)
aiProvider := createAIProviderForTest(t, client, "openai", "test-api-key")
_, err := client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
Provider: "openai",
Model: "gpt-4o-mini",
Provider: "openai",
AIProviderID: &aiProvider.ID,
Model: "gpt-4o-mini",
})
sdkErr := requireSDKError(t, err, http.StatusBadRequest)
require.Equal(t, "Context limit is required.", sdkErr.Message)
})
t.Run("ProviderNotConfigured", func(t *testing.T) {
t.Run("AIProviderIDRequired", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
@@ -3930,7 +4114,94 @@ func TestCreateChatModelConfig(t *testing.T) {
ContextLimit: &contextLimit,
})
sdkErr := requireSDKError(t, err, http.StatusBadRequest)
require.Equal(t, "Chat provider is not configured.", sdkErr.Message)
require.Equal(t, "AI provider ID is required.", sdkErr.Message)
})
t.Run("ProviderNotConfigured", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client := newChatClient(t)
_ = coderdtest.CreateFirstUser(t, client.Client)
contextLimit := int64(4096)
missingProviderID := uuid.New()
_, err := client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
Provider: "openai",
AIProviderID: &missingProviderID,
Model: "gpt-4o-mini",
ContextLimit: &contextLimit,
})
sdkErr := requireSDKError(t, err, http.StatusPreconditionFailed)
require.Equal(t, "AI provider is not configured.", sdkErr.Message)
})
t.Run("WithAIProviderID", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client := newChatClient(t)
_ = coderdtest.CreateFirstUser(t, client.Client)
provider, err := client.CreateAIProvider(ctx, codersdk.CreateAIProviderRequest{
Type: codersdk.AIProviderTypeOpenAI,
Name: "test-model-config-provider-" + uuid.NewString(),
Enabled: true,
BaseURL: "https://api.openai.com/v1",
})
require.NoError(t, err)
contextLimit := int64(4096)
modelConfig, err := client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
AIProviderID: &provider.ID,
Model: "gpt-4o-mini",
ContextLimit: &contextLimit,
})
require.NoError(t, err)
require.Equal(t, "openai", modelConfig.Provider)
require.NotNil(t, modelConfig.AIProviderID)
require.Equal(t, provider.ID, *modelConfig.AIProviderID)
})
t.Run("AIProviderIDNotConfigured", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client := newChatClient(t)
_ = coderdtest.CreateFirstUser(t, client.Client)
missingProviderID := uuid.New()
contextLimit := int64(4096)
_, err := client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
AIProviderID: &missingProviderID,
Model: "gpt-4o-mini",
ContextLimit: &contextLimit,
})
sdkErr := requireSDKError(t, err, http.StatusPreconditionFailed)
require.Equal(t, "AI provider is not configured.", sdkErr.Message)
})
t.Run("AIProviderIDDisabled", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client := newChatClient(t)
_ = coderdtest.CreateFirstUser(t, client.Client)
provider, err := client.CreateAIProvider(ctx, codersdk.CreateAIProviderRequest{
Type: codersdk.AIProviderTypeOpenAI,
Name: "test-disabled-model-provider-" + uuid.NewString(),
Enabled: false,
BaseURL: "https://api.openai.com/v1",
})
require.NoError(t, err)
contextLimit := int64(4096)
_, err = client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
AIProviderID: &provider.ID,
Model: "gpt-4o-mini",
ContextLimit: &contextLimit,
})
sdkErr := requireSDKError(t, err, http.StatusPreconditionFailed)
require.Equal(t, "AI provider is disabled.", sdkErr.Message)
})
t.Run("ForbiddenForOrganizationMember", func(t *testing.T) {
@@ -3942,15 +4213,12 @@ func TestCreateChatModelConfig(t *testing.T) {
memberClientRaw, _ := coderdtest.CreateAnotherUser(t, adminClient.Client, firstUser.OrganizationID)
memberClient := codersdk.NewExperimentalClient(memberClientRaw)
_, err := adminClient.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
Provider: "openai",
APIKey: "test-api-key",
})
require.NoError(t, err)
aiProvider := createAIProviderForTest(t, adminClient, "openai", "test-api-key")
contextLimit := int64(4096)
_, err = memberClient.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
_, err := memberClient.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
Provider: "openai",
AIProviderID: &aiProvider.ID,
Model: "gpt-4o-mini",
ContextLimit: &contextLimit,
})
@@ -4041,16 +4309,13 @@ func TestUpdateChatModelConfig(t *testing.T) {
memberClientRaw, _ := coderdtest.CreateAnotherUser(t, adminClient.Client, firstUser.OrganizationID)
memberClient := codersdk.NewExperimentalClient(memberClientRaw)
_, err := adminClient.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
Provider: "openai",
APIKey: "test-api-key",
})
require.NoError(t, err)
aiProvider := createAIProviderForTest(t, adminClient, "openai", "test-api-key")
contextLimit := int64(4096)
enabled := false
modelConfig, err := adminClient.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
Provider: "openai",
AIProviderID: &aiProvider.ID,
Model: "gpt-4o-reenable",
DisplayName: "GPT-4o Re-enable",
Enabled: &enabled,
@@ -4115,6 +4380,100 @@ func TestUpdateChatModelConfig(t *testing.T) {
)
})
t.Run("UpdateAIProviderID", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client := newChatClient(t)
_ = coderdtest.CreateFirstUser(t, client.Client)
modelConfig := createChatModelConfig(t, client)
provider, err := client.CreateAIProvider(ctx, codersdk.CreateAIProviderRequest{
Type: codersdk.AIProviderTypeAnthropic,
Name: "test-update-model-provider-" + uuid.NewString(),
Enabled: true,
BaseURL: "https://api.anthropic.com",
})
require.NoError(t, err)
updated, err := client.UpdateChatModelConfig(ctx, modelConfig.ID, codersdk.UpdateChatModelConfigRequest{
AIProviderID: &provider.ID,
Model: "claude-3-5-sonnet-latest",
})
require.NoError(t, err)
require.Equal(t, "anthropic", updated.Provider)
require.NotNil(t, updated.AIProviderID)
require.Equal(t, provider.ID, *updated.AIProviderID)
})
t.Run("UpdateProviderPreservesAIProviderIDWhenTypeUnchanged", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client := newChatClient(t)
_ = coderdtest.CreateFirstUser(t, client.Client)
modelConfig := createChatModelConfig(t, client)
provider, err := client.CreateAIProvider(ctx, codersdk.CreateAIProviderRequest{
Type: codersdk.AIProviderTypeAnthropic,
Name: "test-preserve-model-provider-" + uuid.NewString(),
Enabled: true,
BaseURL: "https://api.anthropic.com",
})
require.NoError(t, err)
updated, err := client.UpdateChatModelConfig(ctx, modelConfig.ID, codersdk.UpdateChatModelConfigRequest{
AIProviderID: &provider.ID,
Model: "claude-3-5-sonnet-latest",
})
require.NoError(t, err)
require.NotNil(t, updated.AIProviderID)
updated, err = client.UpdateChatModelConfig(ctx, modelConfig.ID, codersdk.UpdateChatModelConfigRequest{
Provider: "anthropic",
Model: "claude-3-5-haiku-latest",
})
require.NoError(t, err)
require.NotNil(t, updated.AIProviderID)
require.Equal(t, provider.ID, *updated.AIProviderID)
})
t.Run("UpdateAIProviderIDNotConfigured", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client := newChatClient(t)
_ = coderdtest.CreateFirstUser(t, client.Client)
modelConfig := createChatModelConfig(t, client)
missingProviderID := uuid.New()
_, err := client.UpdateChatModelConfig(ctx, modelConfig.ID, codersdk.UpdateChatModelConfigRequest{
AIProviderID: &missingProviderID,
})
sdkErr := requireSDKError(t, err, http.StatusPreconditionFailed)
require.Equal(t, "AI provider is not configured.", sdkErr.Message)
})
t.Run("UpdateAIProviderIDDisabled", func(t *testing.T) {
t.Parallel()
ctx := testutil.Context(t, testutil.WaitLong)
client := newChatClient(t)
_ = coderdtest.CreateFirstUser(t, client.Client)
modelConfig := createChatModelConfig(t, client)
provider, err := client.CreateAIProvider(ctx, codersdk.CreateAIProviderRequest{
Type: codersdk.AIProviderTypeOpenAI,
Name: "test-update-disabled-model-provider-" + uuid.NewString(),
Enabled: false,
BaseURL: "https://api.openai.com/v1",
})
require.NoError(t, err)
_, err = client.UpdateChatModelConfig(ctx, modelConfig.ID, codersdk.UpdateChatModelConfigRequest{
AIProviderID: &provider.ID,
})
sdkErr := requireSDKError(t, err, http.StatusPreconditionFailed)
require.Equal(t, "AI provider is disabled.", sdkErr.Message)
})
t.Run("ProviderNotConfigured", func(t *testing.T) {
t.Parallel()
@@ -4126,7 +4485,7 @@ func TestUpdateChatModelConfig(t *testing.T) {
_, err := client.UpdateChatModelConfig(ctx, modelConfig.ID, codersdk.UpdateChatModelConfigRequest{
Provider: "anthropic",
})
sdkErr := requireSDKError(t, err, http.StatusBadRequest)
sdkErr := requireSDKError(t, err, http.StatusPreconditionFailed)
require.Equal(t, "Chat provider is not configured.", sdkErr.Message)
})
@@ -4167,16 +4526,13 @@ func TestUpdateChatModelConfig(t *testing.T) {
_ = coderdtest.CreateFirstUser(t, client.Client)
defaultConfig := createChatModelConfig(t, client)
_, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
Provider: "anthropic",
APIKey: "candidate-api-key",
})
require.NoError(t, err)
aiProvider := createAIProviderForTest(t, client, "anthropic", "candidate-api-key")
contextLimit := int64(4096)
isDefault := false
candidateConfig, err := client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
Provider: "anthropic",
AIProviderID: &aiProvider.ID,
Model: "claude-3-5-sonnet",
ContextLimit: &contextLimit,
IsDefault: &isDefault,
@@ -10495,6 +10851,38 @@ func TestWatchChatGitAuthz(t *testing.T) {
require.Equal(t, http.StatusForbidden, res.StatusCode)
}
func createAIProviderForTest(
t testing.TB,
client *codersdk.ExperimentalClient,
provider string,
apiKey string,
) codersdk.AIProvider {
t.Helper()
ctx := testutil.Context(t, testutil.WaitLong)
req := codersdk.CreateAIProviderRequest{
Type: codersdk.AIProviderType(provider),
Name: "test-" + provider + "-" + uuid.NewString(),
BaseURL: aiProviderBaseURLForTest(provider),
Enabled: true,
}
if apiKey != "" {
req.APIKeys = []string{apiKey}
}
aiProvider, err := client.CreateAIProvider(ctx, req)
require.NoError(t, err)
return aiProvider
}
func aiProviderBaseURLForTest(provider string) string {
switch provider {
case "anthropic", "bedrock", "google":
return "https://api.example.com"
default:
return "https://api.example.com/v1"
}
}
func createChatModelConfig(t testing.TB, client *codersdk.ExperimentalClient) codersdk.ChatModelConfig {
t.Helper()
return coderdtest.CreateOpenAICompatChatModelConfig(t, client, "")
@@ -10529,10 +10917,12 @@ func createAdditionalChatModelConfig(
t.Helper()
ctx := testutil.Context(t, testutil.WaitLong)
aiProvider := createAIProviderForTest(t, client, provider, "test-api-key")
contextLimit := int64(4096)
isDefault := false
modelConfig, err := client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
Provider: provider,
AIProviderID: &aiProvider.ID,
Model: model,
ContextLimit: &contextLimit,
IsDefault: &isDefault,
+1
View File
@@ -84,6 +84,7 @@ const (
SubjectTypeBoundaryUsageTracker SubjectType = "boundary_usage_tracker"
SubjectTypeWorkspaceBuilder SubjectType = "workspace_builder"
SubjectTypeChatd SubjectType = "chatd"
SubjectTypeAIProviderMetadataReader SubjectType = "ai_provider_metadata_reader"
)
const (
+7 -7
View File
@@ -350,13 +350,13 @@ func (p *Server) resolveAdvisorModelOverride(
return fallbackModel, fallbackCallConfig
}
// GetEnabledChatModelConfigByID joins on chat_providers.enabled = TRUE
// and chat_model_configs.enabled = TRUE, so it returns sql.ErrNoRows
// the moment an admin disables either the model config or its provider.
// Using the cached ModelConfigByID here would keep resolving an override
// whose provider was just disabled, and an env or central fallback key
// would let ModelFromConfig succeed, silently routing advisor prompts
// to a provider the admin expects to be off.
// GetEnabledChatModelConfigByID checks the model config and referenced
// provider enabled state, so it returns sql.ErrNoRows the moment an
// admin disables either one. Using the cached ModelConfigByID here
// would keep resolving an override whose provider was just disabled,
// and an available fallback key would let ModelFromConfig succeed,
// silently routing advisor prompts to a provider the admin expects to
// be off.
overrideConfig, err := p.db.GetEnabledChatModelConfigByID(
ctx,
advisorCfg.ModelConfigID,
+7 -97
View File
@@ -288,22 +288,7 @@ func TestSubagentChatExcludesWorkspaceProvisioningTools(t *testing.T) {
)
})
_, err := expClient.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
Provider: "openai-compat",
APIKey: "test-api-key",
BaseURL: openAIURL,
})
require.NoError(t, err)
contextLimit := int64(4096)
isDefault := true
_, err = expClient.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
Provider: "openai-compat",
Model: "gpt-4o-mini",
ContextLimit: &contextLimit,
IsDefault: &isDefault,
})
require.NoError(t, err)
coderdtest.CreateOpenAICompatChatModelConfig(t, expClient, openAIURL)
// Create a root chat whose first model call will spawn a subagent.
chat, err := expClient.CreateChat(ctx, codersdk.CreateChatRequest{
@@ -483,22 +468,7 @@ func TestPlanModeSubagentChatExcludesAskUserQuestion(t *testing.T) {
)
})
_, err = expClient.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
Provider: "openai-compat",
APIKey: "test-api-key",
BaseURL: openAIURL,
})
require.NoError(t, err)
contextLimit := int64(4096)
isDefault := true
_, err = expClient.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
Provider: "openai-compat",
Model: "gpt-4o-mini",
ContextLimit: &contextLimit,
IsDefault: &isDefault,
})
require.NoError(t, err)
coderdtest.CreateOpenAICompatChatModelConfig(t, expClient, openAIURL)
chat, err := expClient.CreateChat(ctx, codersdk.CreateChatRequest{
OrganizationID: user.OrganizationID,
@@ -638,24 +608,9 @@ func TestExploreSubagentIsReadOnly(t *testing.T) {
)
})
_, err := expClient.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
Provider: "openai-compat",
APIKey: "test-api-key",
BaseURL: openAIURL,
})
require.NoError(t, err)
coderdtest.CreateOpenAICompatChatModelConfig(t, expClient, openAIURL)
contextLimit := int64(4096)
isDefault := true
_, err = expClient.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
Provider: "openai-compat",
Model: "gpt-4o-mini",
ContextLimit: &contextLimit,
IsDefault: &isDefault,
})
require.NoError(t, err)
_, err = expClient.CreateChat(ctx, codersdk.CreateChatRequest{
_, err := expClient.CreateChat(ctx, codersdk.CreateChatRequest{
OrganizationID: user.OrganizationID,
WorkspaceID: &workspace.ID,
Content: []codersdk.ChatInputPart{
@@ -4953,22 +4908,7 @@ func TestCreateWorkspaceTool_EndToEnd(t *testing.T) {
)
})
_, err := expClient.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
Provider: "openai-compat",
APIKey: "test-api-key",
BaseURL: openAIURL,
})
require.NoError(t, err)
contextLimit := int64(4096)
isDefault := true
_, err = expClient.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
Provider: "openai-compat",
Model: "gpt-4o-mini",
ContextLimit: &contextLimit,
IsDefault: &isDefault,
})
require.NoError(t, err)
coderdtest.CreateOpenAICompatChatModelConfig(t, expClient, openAIURL)
chat, err := expClient.CreateChat(ctx, codersdk.CreateChatRequest{
OrganizationID: user.OrganizationID,
@@ -5123,22 +5063,7 @@ func TestStartWorkspaceTool_EndToEnd(t *testing.T) {
)
})
_, err := expClient.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
Provider: "openai-compat",
APIKey: "test-api-key",
BaseURL: openAIURL,
})
require.NoError(t, err)
contextLimit := int64(4096)
isDefault := true
_, err = expClient.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
Provider: "openai-compat",
Model: "gpt-4o-mini",
ContextLimit: &contextLimit,
IsDefault: &isDefault,
})
require.NoError(t, err)
coderdtest.CreateOpenAICompatChatModelConfig(t, expClient, openAIURL)
// Create a chat with the stopped workspace pre-associated.
chat, err := expClient.CreateChat(ctx, codersdk.CreateChatRequest{
@@ -8586,22 +8511,7 @@ func TestAgentContextFilesAndSkillsLoadedIntoChat(t *testing.T) {
)
})
_, err := expClient.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
Provider: "openai-compat",
APIKey: "test-api-key",
BaseURL: openAIURL,
})
require.NoError(t, err)
contextLimit := int64(4096)
isDefault := true
_, err = expClient.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
Provider: "openai-compat",
Model: "gpt-4o-mini",
ContextLimit: &contextLimit,
IsDefault: &isDefault,
})
require.NoError(t, err)
coderdtest.CreateOpenAICompatChatModelConfig(t, expClient, openAIURL)
workspaceID := workspace.ID
chat, err := expClient.CreateChat(ctx, codersdk.CreateChatRequest{
+65 -27
View File
@@ -5,6 +5,7 @@ import (
"os"
"testing"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/coderd/coderdtest"
@@ -13,6 +14,52 @@ import (
"github.com/coder/coder/v2/testutil"
)
func createIntegrationAIProvider(
ctx context.Context,
t testing.TB,
client *codersdk.ExperimentalClient,
providerType codersdk.AIProviderType,
apiKey string,
baseURL string,
) codersdk.AIProvider {
t.Helper()
if baseURL == "" {
baseURL = defaultIntegrationAIProviderBaseURL(providerType)
}
provider, err := client.CreateAIProvider(ctx, codersdk.CreateAIProviderRequest{
Type: providerType,
Name: string(providerType) + "-" + uuid.NewString(),
DisplayName: aiProviderDisplayName(providerType),
Enabled: true,
BaseURL: baseURL,
APIKeys: []string{apiKey},
})
require.NoError(t, err)
return provider
}
func defaultIntegrationAIProviderBaseURL(providerType codersdk.AIProviderType) string {
switch providerType {
case codersdk.AIProviderTypeAnthropic:
return "https://api.anthropic.com"
case codersdk.AIProviderTypeOpenAI:
return "https://api.openai.com/v1"
default:
return "https://api.example.com"
}
}
func aiProviderDisplayName(providerType codersdk.AIProviderType) string {
switch providerType {
case codersdk.AIProviderTypeAnthropic:
return "Anthropic"
case codersdk.AIProviderTypeOpenAI:
return "OpenAI"
default:
return string(providerType)
}
}
// TestAnthropicWebSearchRoundTrip is an integration test that verifies
// provider-executed tool results (web_search) survive the full
// persist → reconstruct → re-send cycle. It sends a query that
@@ -43,19 +90,16 @@ func TestAnthropicWebSearchRoundTrip(t *testing.T) {
user := coderdtest.CreateFirstUser(t, client)
expClient := codersdk.NewExperimentalClient(client)
// Configure an Anthropic provider with the real API key.
_, err := expClient.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
Provider: "anthropic",
APIKey: apiKey,
BaseURL: baseURL,
})
require.NoError(t, err)
provider := createIntegrationAIProvider(
ctx, t, expClient, codersdk.AIProviderTypeAnthropic, apiKey, baseURL,
)
// Create a model config that enables web_search.
contextLimit := int64(200000)
isDefault := true
_, err = expClient.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
Provider: "anthropic",
_, err := expClient.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
Provider: string(provider.Type),
AIProviderID: &provider.ID,
Model: "claude-sonnet-4-20250514",
ContextLimit: &contextLimit,
IsDefault: &isDefault,
@@ -303,13 +347,9 @@ func TestOpenAIReasoningRoundTrip(t *testing.T) {
user := coderdtest.CreateFirstUser(t, client)
expClient := codersdk.NewExperimentalClient(client)
// Configure an OpenAI provider with the real API key.
_, err := expClient.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
Provider: "openai",
APIKey: apiKey,
BaseURL: baseURL,
})
require.NoError(t, err)
provider := createIntegrationAIProvider(
ctx, t, expClient, codersdk.AIProviderTypeOpenAI, apiKey, baseURL,
)
// Create a model config for a reasoning model with Store: true
// (the default). Using o4-mini because it always produces
@@ -317,8 +357,9 @@ func TestOpenAIReasoningRoundTrip(t *testing.T) {
contextLimit := int64(200000)
isDefault := true
reasoningSummary := "auto"
_, err = expClient.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
Provider: "openai",
_, err := expClient.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
Provider: string(provider.Type),
AIProviderID: &provider.ID,
Model: "o4-mini",
ContextLimit: &contextLimit,
IsDefault: &isDefault,
@@ -457,21 +498,18 @@ func TestOpenAIReasoningRoundTripStoreFalse(t *testing.T) {
user := coderdtest.CreateFirstUser(t, client)
expClient := codersdk.NewExperimentalClient(client)
// Configure an OpenAI provider with the real API key.
_, err := expClient.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
Provider: "openai",
APIKey: apiKey,
BaseURL: baseURL,
})
require.NoError(t, err)
provider := createIntegrationAIProvider(
ctx, t, expClient, codersdk.AIProviderTypeOpenAI, apiKey, baseURL,
)
// Create a model config for a reasoning model with Store: false.
// Using o4-mini because it always produces reasoning items.
contextLimit := int64(200000)
isDefault := true
reasoningSummary := "auto"
_, err = expClient.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
Provider: "openai",
_, err := expClient.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
Provider: string(provider.Type),
AIProviderID: &provider.ID,
Model: "o4-mini",
ContextLimit: &contextLimit,
IsDefault: &isDefault,
+10 -9
View File
@@ -218,9 +218,7 @@ func (req CreateAIProviderRequest) Validate() []ValidationError {
})
}
validations = append(validations, validateAIProviderName(req.Name)...)
if req.BaseURL != "" {
validations = append(validations, validateAIProviderBaseURL(req.BaseURL)...)
}
validations = append(validations, validateRequiredAIProviderBaseURL(req.BaseURL)...)
validations = append(validations, validateAIProviderAPIKeys(req.APIKeys)...)
if req.Settings.Bedrock != nil && req.Type != AIProviderTypeAnthropic {
validations = append(validations, ValidationError{
@@ -264,12 +262,8 @@ type AIProviderKeyMutation struct {
// should reject empty patches with IsEmpty before invoking Validate.
func (req UpdateAIProviderRequest) Validate() []ValidationError {
var validations []ValidationError
switch {
case req.BaseURL == nil:
case *req.BaseURL == "":
validations = append(validations, ValidationError{Field: "base_url", Detail: "base_url cannot be empty"})
default:
validations = append(validations, validateAIProviderBaseURL(*req.BaseURL)...)
if req.BaseURL != nil {
validations = append(validations, validateRequiredAIProviderBaseURL(*req.BaseURL)...)
}
if req.APIKeys != nil {
validations = append(validations, validateAIProviderKeyMutations(*req.APIKeys)...)
@@ -296,6 +290,13 @@ func validateAIProviderName(name string) []ValidationError {
return validations
}
func validateRequiredAIProviderBaseURL(raw string) []ValidationError {
if raw == "" {
return []ValidationError{{Field: "base_url", Detail: "base_url is required"}}
}
return validateAIProviderBaseURL(raw)
}
func validateAIProviderBaseURL(raw string) []ValidationError {
var validations []ValidationError
parsed, err := url.Parse(raw)
+75 -1
View File
@@ -1120,6 +1120,31 @@ type UpdateChatProviderConfigRequest struct {
AllowCentralAPIKeyFallback *bool `json:"allow_central_api_key_fallback,omitempty"`
}
// AIProviderSummary is provider metadata embedded in other API responses.
type AIProviderSummary struct {
ID uuid.UUID `json:"id" format:"uuid"`
Type AIProviderType `json:"type"`
Name string `json:"name"`
DisplayName string `json:"display_name"`
Enabled bool `json:"enabled"`
Deleted bool `json:"deleted"`
}
// UserAIProviderKeyConfig is a provider summary from the current user's
// perspective. It reports key presence but never returns key material.
type UserAIProviderKeyConfig struct {
Provider AIProviderSummary `json:"provider"`
HasUserAPIKey bool `json:"has_user_api_key"`
HasProviderAPIKey bool `json:"has_provider_api_key"`
BYOKEnabled bool `json:"byok_enabled"`
}
// CreateUserAIProviderKeyRequest creates or replaces a user's API key
// for an AI provider.
type CreateUserAIProviderKeyRequest struct {
APIKey string `json:"api_key"`
}
// UserChatProviderConfig is a summary of a provider that allows
// user-supplied keys, as seen from the current user's perspective.
type UserChatProviderConfig struct {
@@ -1128,6 +1153,7 @@ type UserChatProviderConfig struct {
DisplayName string `json:"display_name"`
HasUserAPIKey bool `json:"has_user_api_key"`
HasCentralAPIKeyFallback bool `json:"has_central_api_key_fallback"`
BYOKEnabled bool `json:"byok_enabled"`
}
// CreateUserChatProviderKeyRequest creates or replaces a user's API key
@@ -1140,6 +1166,7 @@ type CreateUserChatProviderKeyRequest struct {
type ChatModelConfig struct {
ID uuid.UUID `json:"id" format:"uuid"`
Provider string `json:"provider"`
AIProviderID *uuid.UUID `json:"ai_provider_id,omitempty" format:"uuid"`
Model string `json:"model"`
DisplayName string `json:"display_name"`
Enabled bool `json:"enabled"`
@@ -1349,7 +1376,8 @@ func (c *ChatModelCallConfig) UnmarshalJSON(data []byte) error {
// CreateChatModelConfigRequest creates a chat model config.
type CreateChatModelConfigRequest struct {
Provider string `json:"provider"`
Provider string `json:"provider,omitempty"`
AIProviderID *uuid.UUID `json:"ai_provider_id,omitempty" format:"uuid"`
Model string `json:"model"`
DisplayName string `json:"display_name,omitempty"`
Enabled *bool `json:"enabled,omitempty"`
@@ -1362,6 +1390,7 @@ type CreateChatModelConfigRequest struct {
// UpdateChatModelConfigRequest updates a chat model config.
type UpdateChatModelConfigRequest struct {
Provider string `json:"provider,omitempty"`
AIProviderID *uuid.UUID `json:"ai_provider_id,omitempty" format:"uuid"`
Model string `json:"model,omitempty"`
DisplayName string `json:"display_name,omitempty"`
Enabled *bool `json:"enabled,omitempty"`
@@ -2116,6 +2145,51 @@ func (c *ExperimentalClient) DeleteChatProvider(ctx context.Context, providerID
return nil
}
// ListUserAIProviderKeyConfigs returns user-scoped AI provider key configs.
func (c *ExperimentalClient) ListUserAIProviderKeyConfigs(ctx context.Context, user string) ([]UserAIProviderKeyConfig, error) {
res, err := c.Request(ctx, http.MethodGet, userAIProviderKeysPath(user), nil)
if err != nil {
return nil, xerrors.Errorf("list user AI provider key configs: %w", err)
}
defer res.Body.Close()
if res.StatusCode != http.StatusOK {
return nil, ReadBodyAsError(res)
}
var configs []UserAIProviderKeyConfig
return configs, json.NewDecoder(res.Body).Decode(&configs)
}
// UpsertUserAIProviderKey creates or replaces a user API key for an AI provider.
func (c *ExperimentalClient) UpsertUserAIProviderKey(ctx context.Context, user string, providerID uuid.UUID, req CreateUserAIProviderKeyRequest) (UserAIProviderKeyConfig, error) {
res, err := c.Request(ctx, http.MethodPut, fmt.Sprintf("%s/%s", userAIProviderKeysPath(user), providerID), req)
if err != nil {
return UserAIProviderKeyConfig{}, xerrors.Errorf("upsert user AI provider key: %w", err)
}
defer res.Body.Close()
if res.StatusCode != http.StatusOK {
return UserAIProviderKeyConfig{}, ReadBodyAsError(res)
}
var config UserAIProviderKeyConfig
return config, json.NewDecoder(res.Body).Decode(&config)
}
// DeleteUserAIProviderKey deletes a user API key for an AI provider.
func (c *ExperimentalClient) DeleteUserAIProviderKey(ctx context.Context, user string, providerID uuid.UUID) error {
res, err := c.Request(ctx, http.MethodDelete, fmt.Sprintf("%s/%s", userAIProviderKeysPath(user), providerID), nil)
if err != nil {
return xerrors.Errorf("delete user AI provider key: %w", err)
}
defer res.Body.Close()
if res.StatusCode != http.StatusNoContent {
return ReadBodyAsError(res)
}
return nil
}
func userAIProviderKeysPath(user string) string {
return fmt.Sprintf("/api/experimental/users/%s/ai-provider-keys", url.PathEscape(user))
}
// ListUserChatProviderConfigs returns user-scoped chat provider configs.
func (c *ExperimentalClient) ListUserChatProviderConfigs(ctx context.Context) ([]UserChatProviderConfig, error) {
res, err := c.Request(ctx, http.MethodGet, "/api/experimental/chats/user-provider-configs", nil)
+69 -124
View File
@@ -24,6 +24,47 @@ import (
"github.com/coder/websocket"
)
func createOpenAIProviderForTest(
ctx context.Context,
t testing.TB,
client *codersdk.ExperimentalClient,
apiKey string,
baseURL string,
) codersdk.AIProvider {
t.Helper()
provider, err := client.CreateAIProvider(ctx, codersdk.CreateAIProviderRequest{
Type: codersdk.AIProviderTypeOpenAI,
Name: "openai-" + uuid.NewString(),
DisplayName: "OpenAI",
Enabled: true,
BaseURL: baseURL,
APIKeys: []string{apiKey},
})
require.NoError(t, err)
return provider
}
func createOpenAIModelConfigForTest(
ctx context.Context,
t testing.TB,
client *codersdk.ExperimentalClient,
apiKey string,
baseURL string,
) codersdk.ChatModelConfig {
t.Helper()
provider := createOpenAIProviderForTest(ctx, t, client, apiKey, baseURL)
model, err := client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
Provider: string(provider.Type),
AIProviderID: &provider.ID,
Model: "gpt-4",
DisplayName: "GPT-4",
ContextLimit: ptr.Ref(int64(1000)),
CompressionThreshold: ptr.Ref(int32(70)),
})
require.NoError(t, err)
return model
}
func TestChatStreamRelay(t *testing.T) {
t.Parallel()
@@ -74,27 +115,11 @@ func TestChatStreamRelay(t *testing.T) {
return chattest.OpenAINonStreamingResponse("ok")
})
//nolint:gocritic // Test uses owner client to configure chat providers.
provider, err := codersdk.NewExperimentalClient(firstClient).CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
Provider: "openai",
DisplayName: "OpenAI",
APIKey: "test",
BaseURL: openai,
})
require.NoError(t, err)
require.Equal(t, codersdk.ChatProviderConfigSourceDatabase, provider.Source)
model, err := codersdk.NewExperimentalClient(firstClient).CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
Provider: provider.Provider,
Model: "gpt-4",
DisplayName: "GPT-4",
ContextLimit: &[]int64{1000}[0],
CompressionThreshold: &[]int32{70}[0],
})
require.NoError(t, err)
expClient := codersdk.NewExperimentalClient(firstClient)
model := createOpenAIModelConfigForTest(ctx, t, expClient, "test", openai)
// Create a chat on the first replica
chat, err := codersdk.NewExperimentalClient(firstClient).CreateChat(ctx, codersdk.CreateChatRequest{
chat, err := expClient.CreateChat(ctx, codersdk.CreateChatRequest{
OrganizationID: firstUser.OrganizationID,
Content: []codersdk.ChatInputPart{{
Type: codersdk.ChatInputPartTypeText,
@@ -264,26 +289,11 @@ func TestChatStreamRelay(t *testing.T) {
return chattest.OpenAINonStreamingResponse("ok")
})
//nolint:gocritic // Test uses owner client to configure chat providers.
provider, err := codersdk.NewExperimentalClient(firstClient).CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
Provider: "openai",
DisplayName: "OpenAI",
APIKey: "test",
BaseURL: openai,
})
require.NoError(t, err)
model, err := codersdk.NewExperimentalClient(firstClient).CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
Provider: provider.Provider,
Model: "gpt-4",
DisplayName: "GPT-4",
ContextLimit: &[]int64{1000}[0],
CompressionThreshold: &[]int32{70}[0],
})
require.NoError(t, err)
expClient := codersdk.NewExperimentalClient(firstClient)
model := createOpenAIModelConfigForTest(ctx, t, expClient, "test", openai)
// Create a chat on the first replica.
chat, err := codersdk.NewExperimentalClient(firstClient).CreateChat(ctx, codersdk.CreateChatRequest{
chat, err := expClient.CreateChat(ctx, codersdk.CreateChatRequest{
OrganizationID: firstUser.OrganizationID,
Content: []codersdk.ChatInputPart{{
Type: codersdk.ChatInputPartTypeText,
@@ -435,25 +445,10 @@ func TestChatStreamRelay(t *testing.T) {
return chattest.OpenAINonStreamingResponse("ok")
})
//nolint:gocritic // Test uses owner client to configure providers.
provider, err := codersdk.NewExperimentalClient(firstClient).CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
Provider: "openai",
DisplayName: "OpenAI",
APIKey: "test",
BaseURL: openai,
})
require.NoError(t, err)
expClient := codersdk.NewExperimentalClient(firstClient)
model := createOpenAIModelConfigForTest(ctx, t, expClient, "test", openai)
model, err := codersdk.NewExperimentalClient(firstClient).CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
Provider: provider.Provider,
Model: "gpt-4",
DisplayName: "GPT-4",
ContextLimit: &[]int64{1000}[0],
CompressionThreshold: &[]int32{70}[0],
})
require.NoError(t, err)
chat, err := codersdk.NewExperimentalClient(firstClient).CreateChat(ctx, codersdk.CreateChatRequest{
chat, err := expClient.CreateChat(ctx, codersdk.CreateChatRequest{
OrganizationID: firstUser.OrganizationID,
Content: []codersdk.ChatInputPart{{
Type: codersdk.ChatInputPartTypeText,
@@ -607,25 +602,10 @@ func TestChatStreamRelay(t *testing.T) {
return chattest.OpenAINonStreamingResponse("ok")
})
//nolint:gocritic // Test uses owner client to configure providers.
provider, err := codersdk.NewExperimentalClient(firstClient).CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
Provider: "openai",
DisplayName: "OpenAI",
APIKey: "test",
BaseURL: openai,
})
require.NoError(t, err)
expClient := codersdk.NewExperimentalClient(firstClient)
model := createOpenAIModelConfigForTest(ctx, t, expClient, "test", openai)
model, err := codersdk.NewExperimentalClient(firstClient).CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
Provider: provider.Provider,
Model: "gpt-4",
DisplayName: "GPT-4",
ContextLimit: &[]int64{1000}[0],
CompressionThreshold: &[]int32{70}[0],
})
require.NoError(t, err)
chat, err := codersdk.NewExperimentalClient(firstClient).CreateChat(ctx, codersdk.CreateChatRequest{
chat, err := expClient.CreateChat(ctx, codersdk.CreateChatRequest{
OrganizationID: firstUser.OrganizationID,
Content: []codersdk.ChatInputPart{{
Type: codersdk.ChatInputPartTypeText,
@@ -754,26 +734,11 @@ func TestChatStreamRelay(t *testing.T) {
return chattest.OpenAINonStreamingResponse("ok")
})
//nolint:gocritic // Test uses owner client to configure chat providers.
provider, err := codersdk.NewExperimentalClient(firstClient).CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
Provider: "openai",
DisplayName: "OpenAI",
APIKey: "test",
BaseURL: openai,
})
require.NoError(t, err)
model, err := codersdk.NewExperimentalClient(firstClient).CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
Provider: provider.Provider,
Model: "gpt-4",
DisplayName: "GPT-4",
ContextLimit: &[]int64{1000}[0],
CompressionThreshold: &[]int32{70}[0],
})
require.NoError(t, err)
expClient := codersdk.NewExperimentalClient(firstClient)
model := createOpenAIModelConfigForTest(ctx, t, expClient, "test", openai)
// Create a chat on the first replica.
chat, err := codersdk.NewExperimentalClient(firstClient).CreateChat(ctx, codersdk.CreateChatRequest{
chat, err := expClient.CreateChat(ctx, codersdk.CreateChatRequest{
OrganizationID: firstUser.OrganizationID,
Content: []codersdk.ChatInputPart{{
Type: codersdk.ChatInputPartTypeText,
@@ -954,17 +919,7 @@ func TestChatModelConfigDefault(t *testing.T) {
client, _ := coderdenttest.New(t, nil)
expClient := codersdk.NewExperimentalClient(client)
//nolint:gocritic // Test uses owner client to configure chat providers.
provider, err := expClient.CreateChatProvider(
ctx,
codersdk.CreateChatProviderConfigRequest{
Provider: "openai",
DisplayName: "OpenAI",
APIKey: "test",
BaseURL: "https://example.com",
},
)
require.NoError(t, err)
provider := createOpenAIProviderForTest(ctx, t, expClient, "test", "https://example.com")
contextLimit := int64(1000)
compressionThreshold := int32(70)
@@ -974,7 +929,8 @@ func TestChatModelConfigDefault(t *testing.T) {
firstModel, err := expClient.CreateChatModelConfig(
ctx,
codersdk.CreateChatModelConfigRequest{
Provider: provider.Provider,
Provider: string(provider.Type),
AIProviderID: &provider.ID,
Model: "gpt-5-a",
DisplayName: "GPT 5 A",
IsDefault: &trueValue,
@@ -988,7 +944,8 @@ func TestChatModelConfigDefault(t *testing.T) {
secondModel, err := expClient.CreateChatModelConfig(
ctx,
codersdk.CreateChatModelConfigRequest{
Provider: provider.Provider,
Provider: string(provider.Type),
AIProviderID: &provider.ID,
Model: "gpt-5-b",
DisplayName: "GPT 5 B",
IsDefault: &trueValue,
@@ -1115,16 +1072,10 @@ func TestCreateChatNonDefaultOrg(t *testing.T) {
})
expClient := codersdk.NewExperimentalClient(client)
// Set up a chat provider and model config.
provider, err := expClient.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
Provider: "openai",
DisplayName: "OpenAI",
APIKey: "test-key",
BaseURL: "https://example.com",
})
require.NoError(t, err)
_, err = expClient.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
Provider: provider.Provider,
provider := createOpenAIProviderForTest(ctx, t, expClient, "test-key", "https://example.com")
_, err := expClient.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
Provider: string(provider.Type),
AIProviderID: &provider.ID,
Model: "gpt-4o-mini",
DisplayName: "Test Model",
IsDefault: ptr.Ref(true),
@@ -1191,16 +1142,10 @@ func TestListChats_OrgAdminOnlySeesOwnChats(t *testing.T) {
})
expClient := codersdk.NewExperimentalClient(client)
// Set up a chat provider and model config.
provider, err := expClient.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
Provider: "openai",
DisplayName: "OpenAI",
APIKey: "test-key",
BaseURL: "https://example.com",
})
require.NoError(t, err)
_, err = expClient.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
Provider: provider.Provider,
provider := createOpenAIProviderForTest(ctx, t, expClient, "test-key", "https://example.com")
_, err := expClient.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
Provider: string(provider.Type),
AIProviderID: &provider.ID,
Model: "gpt-4o-mini",
DisplayName: "Test Model",
IsDefault: ptr.Ref(true),
+63
View File
@@ -408,6 +408,7 @@ export type DeploymentConfig = Readonly<{
}>;
const chatProviderConfigsPath = "/api/experimental/chats/providers";
const aiProviderConfigsPath = "/api/v2/ai/providers";
const chatModelConfigsPath = "/api/experimental/chats/model-configs";
const userChatProviderConfigsPath =
"/api/experimental/chats/user-provider-configs";
@@ -415,6 +416,8 @@ const userSkillsPath = (user: string) =>
`/api/experimental/users/${encodeURIComponent(user)}/skills`;
const userSkillPath = (user: string, name: string) =>
`${userSkillsPath(user)}/${encodeURIComponent(name)}`;
const userAIProviderKeysPath = (user = "me") =>
`/api/experimental/users/${encodeURIComponent(user)}/ai-provider-keys`;
const mcpServerConfigsPath = "/api/experimental/mcp/servers";
type ChatCostDateParams = {
@@ -3313,6 +3316,66 @@ class ExperimentalApiMethods {
return response.data;
};
listAIProviders = async (): Promise<TypesGen.AIProvider[]> => {
const response = await this.axios.get<TypesGen.AIProvider[]>(
aiProviderConfigsPath,
);
return response.data;
};
createAIProvider = async (
req: TypesGen.CreateAIProviderRequest,
): Promise<TypesGen.AIProvider> => {
const response = await this.axios.post<TypesGen.AIProvider>(
aiProviderConfigsPath,
req,
);
return response.data;
};
updateAIProvider = async (
providerId: string,
req: TypesGen.UpdateAIProviderRequest,
): Promise<TypesGen.AIProvider> => {
const response = await this.axios.patch<TypesGen.AIProvider>(
`${aiProviderConfigsPath}/${providerId}`,
req,
);
return response.data;
};
deleteAIProvider = async (providerId: string): Promise<void> => {
await this.axios.delete(`${aiProviderConfigsPath}/${providerId}`);
};
getUserAIProviderKeyConfigs = async (
user = "me",
): Promise<TypesGen.UserAIProviderKeyConfig[]> => {
const response = await this.axios.get<TypesGen.UserAIProviderKeyConfig[]>(
userAIProviderKeysPath(user),
);
return response.data;
};
upsertUserAIProviderKey = async (
providerId: string,
req: TypesGen.CreateUserAIProviderKeyRequest,
user = "me",
): Promise<TypesGen.UserAIProviderKeyConfig> => {
const response = await this.axios.put<TypesGen.UserAIProviderKeyConfig>(
`${userAIProviderKeysPath(user)}/${providerId}`,
req,
);
return response.data;
};
deleteUserAIProviderKey = async (
providerId: string,
user = "me",
): Promise<void> => {
await this.axios.delete(`${userAIProviderKeysPath(user)}/${providerId}`);
};
getChatSystemPrompt =
async (): Promise<TypesGen.ChatSystemPromptResponse> => {
const response = await this.axios.get<TypesGen.ChatSystemPromptResponse>(
+86 -12
View File
@@ -10,7 +10,9 @@ import {
type CreateChatMessageRequestWithClearablePlanMode,
} from "#/api/api";
import type * as TypesGen from "#/api/typesGenerated";
import { type AIProviderType, AIProviderTypes } from "#/api/typesGenerated";
import type { UsePaginatedQueryOptions } from "#/hooks/usePaginatedQuery";
import { formatProviderLabel } from "#/utils/aiProviders";
import {
projectEditedConversationIntoCache,
reconcileEditedMessageInCache,
@@ -1607,10 +1609,29 @@ export const chatModels = () => ({
const chatProviderConfigsKey = ["chat-provider-configs"] as const;
const toChatProviderConfig = (
provider: TypesGen.AIProvider,
): TypesGen.ChatProviderConfig => ({
id: provider.id,
provider: provider.type,
display_name: provider.display_name || provider.type,
enabled: provider.enabled,
has_api_key: provider.api_keys.length > 0,
central_api_key_enabled: true,
allow_user_api_key: true,
allow_central_api_key_fallback: true,
base_url: provider.base_url,
source: "database",
created_at: provider.created_at,
updated_at: provider.updated_at,
});
export const chatProviderConfigs = () => ({
queryKey: chatProviderConfigsKey,
queryFn: (): Promise<TypesGen.ChatProviderConfig[]> =>
API.experimental.getChatProviderConfigs(),
queryFn: async (): Promise<TypesGen.ChatProviderConfig[]> => {
const providers = await API.experimental.listAIProviders();
return providers.map(toChatProviderConfig);
},
});
const chatModelConfigsKey = ["chat-model-configs"] as const;
@@ -1627,8 +1648,17 @@ export const userChatProviderConfigsKey = [
export const userChatProviderConfigs = () => ({
queryKey: userChatProviderConfigsKey,
queryFn: (): Promise<TypesGen.UserChatProviderConfig[]> =>
API.experimental.getUserChatProviderConfigs(),
queryFn: async (): Promise<TypesGen.UserChatProviderConfig[]> => {
const configs = await API.experimental.getUserAIProviderKeyConfigs();
return configs.map((config) => ({
provider_id: config.provider.id,
provider: config.provider.type,
display_name: config.provider.display_name || config.provider.type,
has_user_api_key: config.has_user_api_key,
byok_enabled: config.byok_enabled,
has_central_api_key_fallback: config.has_provider_api_key,
}));
},
});
type UpsertUserChatProviderKeyArgs = {
@@ -1638,7 +1668,7 @@ type UpsertUserChatProviderKeyArgs = {
export const upsertUserChatProviderKey = (queryClient: QueryClient) => ({
mutationFn: ({ providerConfigId, req }: UpsertUserChatProviderKeyArgs) =>
API.experimental.upsertUserChatProviderKey(providerConfigId, req),
API.experimental.upsertUserAIProviderKey(providerConfigId, req),
onSuccess: async () => {
await Promise.all([
queryClient.invalidateQueries({
@@ -1651,7 +1681,7 @@ export const upsertUserChatProviderKey = (queryClient: QueryClient) => ({
export const deleteUserChatProviderKey = (queryClient: QueryClient) => ({
mutationFn: (providerConfigId: string) =>
API.experimental.deleteUserChatProviderKey(providerConfigId),
API.experimental.deleteUserAIProviderKey(providerConfigId),
onSuccess: async () => {
await Promise.all([
queryClient.invalidateQueries({
@@ -1670,9 +1700,41 @@ const invalidateChatConfigurationQueries = async (queryClient: QueryClient) => {
]);
};
const generatedAIProviderName = (provider: string): string => {
const suffix =
globalThis.crypto?.randomUUID?.() ??
`${Date.now().toString(36)}-${Math.random().toString(36).slice(2, 10)}`;
return `${provider}-${suffix}`;
};
const normalizeAIProviderType = (provider: string): AIProviderType => {
const normalized = provider.trim().toLowerCase();
const aliased =
normalized === "openai-compatible" || normalized === "openai_compatible"
? "openai-compat"
: normalized;
const providerType = AIProviderTypes.find(
(candidate) => candidate === aliased,
);
if (!providerType) {
throw new Error(`Unsupported AI provider type "${provider}".`);
}
return providerType;
};
export const createChatProviderConfig = (queryClient: QueryClient) => ({
mutationFn: (req: TypesGen.CreateChatProviderConfigRequest) =>
API.experimental.createChatProviderConfig(req),
mutationFn: (req: TypesGen.CreateChatProviderConfigRequest) => {
const providerType = normalizeAIProviderType(req.provider);
const apiKey = req.api_key;
return API.experimental.createAIProvider({
type: providerType,
name: generatedAIProviderName(providerType),
display_name: req.display_name || formatProviderLabel(providerType),
base_url: req.base_url ?? "",
enabled: req.enabled ?? true,
api_keys: apiKey ? [apiKey] : undefined,
});
},
onSuccess: async () => {
await invalidateChatConfigurationQueries(queryClient);
},
@@ -1684,11 +1746,23 @@ type UpdateChatProviderConfigMutationArgs = {
};
export const updateChatProviderConfig = (queryClient: QueryClient) => ({
mutationFn: ({
mutationFn: async ({
providerConfigId,
req,
}: UpdateChatProviderConfigMutationArgs) =>
API.experimental.updateChatProviderConfig(providerConfigId, req),
}: UpdateChatProviderConfigMutationArgs) => {
const apiKey = req.api_key;
return API.experimental.updateAIProvider(providerConfigId, {
display_name: req.display_name,
base_url: req.base_url,
enabled: req.enabled,
api_keys:
req.api_key === undefined
? undefined
: apiKey
? [{ api_key: apiKey }]
: [],
});
},
onSuccess: async () => {
await invalidateChatConfigurationQueries(queryClient);
},
@@ -1696,7 +1770,7 @@ export const updateChatProviderConfig = (queryClient: QueryClient) => ({
export const deleteChatProviderConfig = (queryClient: QueryClient) => ({
mutationFn: (providerConfigId: string) =>
API.experimental.deleteChatProviderConfig(providerConfigId),
API.experimental.deleteAIProvider(providerConfigId),
onSuccess: async () => {
await invalidateChatConfigurationQueries(queryClient);
},
+39 -1
View File
@@ -445,6 +445,19 @@ export interface AIProviderSettings {}
*/
export const AIProviderSettingsTypeBedrock = "bedrock";
// From codersdk/chats.go
/**
* AIProviderSummary is provider metadata embedded in other API responses.
*/
export interface AIProviderSummary {
readonly id: string;
readonly type: AIProviderType;
readonly name: string;
readonly display_name: string;
readonly enabled: boolean;
readonly deleted: boolean;
}
// From codersdk/aiproviders.go
export type AIProviderType =
| "anthropic"
@@ -2278,6 +2291,7 @@ export interface ChatModelCallConfig {
export interface ChatModelConfig {
readonly id: string;
readonly provider: string;
readonly ai_provider_id?: string;
readonly model: string;
readonly display_name: string;
readonly enabled: boolean;
@@ -3251,7 +3265,8 @@ export interface CreateChatMessageResponse {
* CreateChatModelConfigRequest creates a chat model config.
*/
export interface CreateChatModelConfigRequest {
readonly provider: string;
readonly provider?: string;
readonly ai_provider_id?: string;
readonly model: string;
readonly display_name?: string;
readonly enabled?: boolean;
@@ -3582,6 +3597,15 @@ export interface CreateTokenRequest {
readonly allow_list?: readonly APIAllowListTarget[];
}
// From codersdk/chats.go
/**
* CreateUserAIProviderKeyRequest creates or replaces a user's API key
* for an AI provider.
*/
export interface CreateUserAIProviderKeyRequest {
readonly api_key: string;
}
// From codersdk/chats.go
/**
* CreateUserChatProviderKeyRequest creates or replaces a user's API key
@@ -8449,6 +8473,7 @@ export interface UpdateChatDesktopEnabledRequest {
*/
export interface UpdateChatModelConfigRequest {
readonly provider?: string;
readonly ai_provider_id?: string;
readonly model?: string;
readonly display_name?: string;
readonly enabled?: boolean;
@@ -9070,6 +9095,18 @@ export interface User extends ReducedUser {
readonly has_ai_seat: boolean;
}
// From codersdk/chats.go
/**
* UserAIProviderKeyConfig is a provider summary from the current user's
* perspective. It reports key presence but never returns key material.
*/
export interface UserAIProviderKeyConfig {
readonly provider: AIProviderSummary;
readonly has_user_api_key: boolean;
readonly has_provider_api_key: boolean;
readonly byok_enabled: boolean;
}
// From codersdk/insights.go
/**
* UserActivity shows the session time for a user.
@@ -9196,6 +9233,7 @@ export interface UserChatProviderConfig {
readonly display_name: string;
readonly has_user_api_key: boolean;
readonly has_central_api_key_fallback: boolean;
readonly byok_enabled: boolean;
}
// From codersdk/insights.go
@@ -18,6 +18,7 @@ const createProvider = (
display_name: overrides.display_name ?? overrides.provider,
has_user_api_key: overrides.has_user_api_key ?? false,
has_central_api_key_fallback: overrides.has_central_api_key_fallback ?? false,
byok_enabled: overrides.byok_enabled ?? true,
});
const createModel = (
@@ -24,6 +24,14 @@ type ProviderStatus = {
const getProviderStatus = (
provider: UserChatProviderConfig,
): ProviderStatus => {
if (!provider.byok_enabled) {
return {
label: "User keys disabled",
variant: "default",
note: "Personal API keys are disabled by your admin.",
};
}
if (provider.has_user_api_key) {
return {
label: "Key saved",
@@ -55,11 +63,13 @@ interface ProviderKeyPanelProps {
isRemoving: boolean;
onSave: (providerConfigId: string, apiKey: string) => void;
onRemove: (providerConfigId: string) => void;
hasAmbiguousProviderType: boolean;
}
const ProviderKeyPanel: FC<ProviderKeyPanelProps> = ({
provider,
models,
hasAmbiguousProviderType,
isModelsLoading,
areModelsUnavailable,
isSaving,
@@ -76,15 +86,26 @@ const ProviderKeyPanel: FC<ProviderKeyPanelProps> = ({
const status = getProviderStatus(provider);
const enabledModels = models.filter((model) => {
return model.enabled && model.provider === provider.provider;
return (
model.enabled &&
(model.ai_provider_id === provider.provider_id ||
(!model.ai_provider_id &&
!hasAmbiguousProviderType &&
model.provider === provider.provider))
);
});
const trimmedApiKey = apiKey.trim();
const hasApiKeyValue = apiKey.trim().length > 0;
const hasAPIKeyWhitespace =
apiKey !== API_KEY_PLACEHOLDER && apiKey.trim() !== apiKey;
const saveDisabled =
trimmedApiKey.length === 0 ||
!provider.byok_enabled ||
!hasApiKeyValue ||
hasAPIKeyWhitespace ||
apiKey === API_KEY_PLACEHOLDER ||
isSaving ||
isRemoving;
const inputDisabled = isSaving || isRemoving;
const inputDisabled = !provider.byok_enabled || isSaving || isRemoving;
const removeDisabled = isSaving || isRemoving;
const providerName = provider.display_name || provider.provider;
const handleApiKeyFocus = () => {
@@ -101,7 +122,7 @@ const ProviderKeyPanel: FC<ProviderKeyPanelProps> = ({
return;
}
onSave(provider.provider_id, trimmedApiKey);
onSave(provider.provider_id, apiKey);
};
const handleRemoveKey = () => {
@@ -136,25 +157,32 @@ const ProviderKeyPanel: FC<ProviderKeyPanelProps> = ({
API Key
</label>
<div className="flex flex-col gap-3 lg:flex-row lg:items-start">
<Input
id={apiKeyInputId}
name={`provider-api-key-${provider.provider_id}`}
type="password"
autoComplete="off"
data-1p-ignore
data-lpignore="true"
data-form-type="other"
data-bwignore
className="h-9 font-mono text-[13px] lg:flex-1"
placeholder="sk-..."
value={apiKey}
onFocus={handleApiKeyFocus}
onChange={(event) => {
setApiKey(event.target.value);
setApiKeyTouched(true);
}}
disabled={inputDisabled}
/>
<div className="flex flex-col gap-1.5 lg:flex-1">
<Input
id={apiKeyInputId}
name={`provider-api-key-${provider.provider_id}`}
type="password"
autoComplete="off"
data-1p-ignore
data-lpignore="true"
data-form-type="other"
data-bwignore
className="h-9 font-mono text-[13px]"
placeholder="sk-..."
value={apiKey}
onFocus={handleApiKeyFocus}
onChange={(event) => {
setApiKey(event.target.value);
setApiKeyTouched(true);
}}
disabled={inputDisabled}
/>
{hasAPIKeyWhitespace && (
<p className="m-0 text-xs text-content-destructive">
API key must not contain leading or trailing whitespace.
</p>
)}
</div>
<div className="flex items-center gap-2">
<Button type="submit" size="sm" disabled={saveDisabled}>
Save
@@ -165,7 +193,7 @@ const ProviderKeyPanel: FC<ProviderKeyPanelProps> = ({
variant="outline"
size="sm"
onClick={() => setIsDeleteDialogOpen(true)}
disabled={inputDisabled}
disabled={removeDisabled}
>
Remove
</Button>
@@ -245,6 +273,14 @@ export const AgentSettingsAPIKeysPageView: FC<
onSave,
onRemove,
}) => {
const providerTypeCounts = new Map<string, number>();
for (const item of providerItems) {
providerTypeCounts.set(
item.provider.provider,
(providerTypeCounts.get(item.provider.provider) ?? 0) + 1,
);
}
return (
<div>
<section className="flex flex-col gap-8">
@@ -275,6 +311,9 @@ export const AgentSettingsAPIKeysPageView: FC<
isRemoving={item.isRemoving}
onSave={onSave}
onRemove={onRemove}
hasAmbiguousProviderType={
(providerTypeCounts.get(item.provider.provider) ?? 0) > 1
}
/>
))}
</div>
@@ -17,8 +17,6 @@ import {
} from "./ChatModelAdminPanel";
import { formatContextBadge, getKnownModelsForProvider } from "./knownModels";
// ── Helpers ────────────────────────────────────────────────────
const now = "2026-02-18T12:00:00.000Z";
const nilProviderConfigID = "00000000-0000-0000-0000-000000000000";
@@ -88,7 +86,7 @@ const setupChatSpies = (state: {
async (req) => {
const created = createProviderConfig({
id: `provider-${Date.now()}`,
provider: req.provider,
provider: req.provider ?? "",
display_name: req.display_name ?? "",
has_api_key: (req.api_key ?? "").trim().length > 0,
central_api_key_enabled: req.central_api_key_enabled ?? true,
@@ -99,7 +97,9 @@ const setupChatSpies = (state: {
source: "database",
});
state.providerConfigs = [
...state.providerConfigs.filter((p) => p.provider !== req.provider),
...state.providerConfigs.filter(
(p) => !(p.id === nilProviderConfigID && p.provider === req.provider),
),
created,
];
return created;
@@ -152,7 +152,7 @@ const setupChatSpies = (state: {
async (req) => {
const created = createModelConfig({
id: `model-${state.modelConfigs.length + 1}`,
provider: req.provider,
provider: req.provider ?? "",
model: req.model,
display_name: req.display_name || req.model,
enabled: req.enabled ?? true,
@@ -211,8 +211,6 @@ const setupChatSpies = (state: {
);
};
// ── Meta ───────────────────────────────────────────────────────
const meta: Meta<typeof ChatModelAdminPanel> = {
title: "pages/AgentsPage/ChatModelAdminPanel",
component: ChatModelAdminPanel,
@@ -224,7 +222,7 @@ const meta: Meta<typeof ChatModelAdminPanel> = {
providerConfigsError: null,
modelConfigsError: null,
modelCatalogError: null,
onCreateProvider: fn(async () => ({})),
onCreateProvider: fn(async () => ({ id: "" })),
onUpdateProvider: fn(async () => ({})),
onDeleteProvider: fn(async () => undefined),
isProviderMutationPending: false,
@@ -242,8 +240,6 @@ const meta: Meta<typeof ChatModelAdminPanel> = {
export default meta;
type Story = StoryObj<typeof ChatModelAdminPanel>;
// ── Providers section stories ──────────────────────────────────
export const ProviderAccordionCards: Story = {
args: {
section: "providers" as ChatModelAdminSection,
@@ -268,6 +264,100 @@ export const ProviderAccordionCards: Story = {
},
};
export const AddProviderFromMenu: Story = {
render: function AddProviderFromMenu(args) {
const [providerConfigsData, setProviderConfigsData] = useState(
args.providerConfigsData,
);
const handleCreateProvider: ChatModelAdminPanelStoryProps["onCreateProvider"] =
async (req) => {
const created = createProviderConfig({
id: `provider-${req.provider}`,
provider: req.provider ?? "",
display_name: req.display_name ?? "",
has_api_key: (req.api_key ?? "").trim().length > 0,
central_api_key_enabled: req.central_api_key_enabled ?? true,
allow_user_api_key: req.allow_user_api_key ?? true,
allow_central_api_key_fallback:
req.allow_central_api_key_fallback ?? true,
base_url: req.base_url ?? "",
source: "database",
});
await args.onCreateProvider(req);
setProviderConfigsData((current) => [...(current ?? []), created]);
return created;
};
return (
<ChatModelAdminPanel
{...args}
providerConfigsData={providerConfigsData}
onCreateProvider={handleCreateProvider}
/>
);
},
args: {
section: "providers" as ChatModelAdminSection,
sectionLabel: "Providers",
sectionDescription:
"Connect third-party LLM services like OpenAI, Anthropic, or Google.",
providerConfigsData: [
createProviderConfig({
id: "provider-anthropic",
provider: "anthropic",
display_name: "Anthropic Migration Test",
has_api_key: true,
allow_user_api_key: true,
}),
createProviderConfig({
id: "provider-openai",
provider: "openai",
display_name: "OpenAI Migration Test",
has_api_key: true,
allow_user_api_key: true,
}),
createProviderConfig({
id: "provider-openai-compatible",
provider: "openai-compat",
display_name: "OpenAI Compatible Migration Test",
has_api_key: true,
allow_user_api_key: true,
}),
],
modelCatalogData: { providers: [] },
},
play: async ({ canvasElement, args }) => {
const body = within(canvasElement.ownerDocument.body);
await userEvent.click(
await body.findByRole("button", { name: "Add provider" }),
);
await userEvent.click(
await body.findByRole("menuitem", { name: /Google/i }),
);
expect(await body.findByLabelText(/^API Key$/i)).toBeInTheDocument();
await userEvent.click(
await body.findByRole("button", { name: "Create provider config" }),
);
await waitFor(() => {
expect(args.onCreateProvider).toHaveBeenCalledWith(
expect.objectContaining({ provider: "google" }),
);
});
await waitFor(() => {
expect(body.getByRole("button", { name: "Save changes" })).toBeDisabled();
});
await userEvent.click(body.getByText("Back"));
expect(
await body.findByRole("button", { name: "Google" }),
).toBeInTheDocument();
},
};
export const EnvPresetProviders: Story = {
args: {
section: "providers" as ChatModelAdminSection,
@@ -346,10 +436,9 @@ export const CreateAndUpdateProvider: Story = {
const handleCreateProvider: ChatModelAdminPanelStoryProps["onCreateProvider"] =
async (req) => {
const result = await args.onCreateProvider(req);
const created = createProviderConfig({
id: `provider-${Date.now()}`,
provider: req.provider,
provider: req.provider ?? "",
display_name: req.display_name ?? "",
has_api_key: (req.api_key ?? "").trim().length > 0,
central_api_key_enabled: req.central_api_key_enabled ?? true,
@@ -359,11 +448,15 @@ export const CreateAndUpdateProvider: Story = {
base_url: req.base_url ?? "",
source: "database",
});
await args.onCreateProvider(req);
setProviderConfigsData((current) => [
...(current ?? []).filter((p) => p.provider !== req.provider),
...(current ?? []).filter(
(p) =>
!(p.id === nilProviderConfigID && p.provider === req.provider),
),
created,
]);
return result;
return created;
};
const handleUpdateProvider: ChatModelAdminPanelStoryProps["onUpdateProvider"] =
@@ -446,24 +539,13 @@ export const CreateAndUpdateProvider: Story = {
await userEvent.click(await body.findByRole("button", { name: /OpenAI/i }));
await expect(
await body.findByRole("switch", { name: "Central API key" }),
).toBeChecked();
expect(
await body.findByRole("switch", { name: "Allow user API keys" }),
).not.toBeChecked();
expect(
body.queryByRole("switch", { name: "Use central key as fallback" }),
).not.toBeInTheDocument();
await userEvent.type(
await body.findByLabelText(/^API Key$/i),
"sk-provider-key",
);
await userEvent.type(
await body.findByLabelText("Base URL"),
"https://proxy.example.com/v1",
);
const createBaseURLInput = await body.findByLabelText("Base URL");
await userEvent.clear(createBaseURLInput);
await userEvent.type(createBaseURLInput, "https://proxy.example.com/v1");
await userEvent.click(
await body.findByRole("button", { name: "Create provider config" }),
);
@@ -479,9 +561,6 @@ export const CreateAndUpdateProvider: Story = {
provider: "openai",
api_key: "sk-provider-key",
base_url: "https://proxy.example.com/v1",
central_api_key_enabled: true,
allow_user_api_key: false,
allow_central_api_key_fallback: false,
}),
);
@@ -491,13 +570,6 @@ export const CreateAndUpdateProvider: Story = {
).toBeInTheDocument();
});
await userEvent.click(
await body.findByRole("switch", { name: "Allow user API keys" }),
);
await userEvent.click(
await body.findByRole("switch", { name: "Use central key as fallback" }),
);
const apiKeyInput = body.getByLabelText(/^API Key$/i);
await userEvent.clear(apiKeyInput);
await userEvent.type(apiKeyInput, "sk-updated-provider-key");
@@ -514,180 +586,6 @@ export const CreateAndUpdateProvider: Story = {
expect.objectContaining({
api_key: "sk-updated-provider-key",
base_url: "https://internal-proxy.example.com/v2",
allow_user_api_key: true,
allow_central_api_key_fallback: true,
}),
);
},
};
export const ProviderWithUserKeysEnabled: Story = {
args: {
section: "providers" as ChatModelAdminSection,
providerConfigsData: [
createProviderConfig({
id: "provider-openai-user-keys",
provider: "openai",
display_name: "OpenAI",
has_api_key: true,
central_api_key_enabled: true,
allow_user_api_key: true,
allow_central_api_key_fallback: false,
}),
],
modelCatalogData: { providers: [] },
},
beforeEach: () => {
setupChatSpies({
providerConfigs: [
createProviderConfig({
id: "provider-openai-user-keys",
provider: "openai",
display_name: "OpenAI",
has_api_key: true,
central_api_key_enabled: true,
allow_user_api_key: true,
allow_central_api_key_fallback: false,
}),
],
modelConfigs: [],
modelCatalog: { providers: [] },
});
},
play: async ({ canvasElement }) => {
const body = within(canvasElement.ownerDocument.body);
await expect(
await body.findByText("User keys enabled"),
).toBeInTheDocument();
await userEvent.click(body.getByRole("button", { name: /OpenAI/i }));
await expect(
await body.findByRole("switch", { name: "Allow user API keys" }),
).toBeChecked();
await expect(
await body.findByRole("switch", {
name: "Use central key as fallback",
}),
).not.toBeChecked();
},
};
export const ProviderWithCentralFallback: Story = {
args: {
section: "providers" as ChatModelAdminSection,
providerConfigsData: [
createProviderConfig({
id: "provider-openrouter-fallback",
provider: "openrouter",
display_name: "OpenRouter",
has_api_key: true,
central_api_key_enabled: true,
allow_user_api_key: true,
allow_central_api_key_fallback: true,
}),
],
modelCatalogData: { providers: [] },
},
beforeEach: () => {
setupChatSpies({
providerConfigs: [
createProviderConfig({
id: "provider-openrouter-fallback",
provider: "openrouter",
display_name: "OpenRouter",
has_api_key: true,
central_api_key_enabled: true,
allow_user_api_key: true,
allow_central_api_key_fallback: true,
}),
],
modelConfigs: [],
modelCatalog: { providers: [] },
});
},
play: async ({ canvasElement }) => {
const body = within(canvasElement.ownerDocument.body);
await userEvent.click(
await body.findByRole("button", { name: /OpenRouter/i }),
);
await expect(
await body.findByRole("switch", {
name: "Use central key as fallback",
}),
).toBeChecked();
},
};
export const ProviderWithUserKeysOnly: Story = {
args: {
section: "providers" as ChatModelAdminSection,
providerConfigsData: [
createProviderConfig({
id: "provider-google-user-only",
provider: "google",
display_name: "Google",
has_api_key: false,
central_api_key_enabled: false,
allow_user_api_key: true,
allow_central_api_key_fallback: false,
}),
],
modelCatalogData: { providers: [] },
},
beforeEach: () => {
setupChatSpies({
providerConfigs: [
createProviderConfig({
id: "provider-google-user-only",
provider: "google",
display_name: "Google",
has_api_key: false,
central_api_key_enabled: false,
allow_user_api_key: true,
allow_central_api_key_fallback: false,
}),
],
modelConfigs: [],
modelCatalog: { providers: [] },
});
},
play: async ({ canvasElement, args }) => {
const body = within(canvasElement.ownerDocument.body);
await userEvent.click(await body.findByRole("button", { name: /Google/i }));
await expect(
await body.findByRole("switch", { name: "Central API key" }),
).not.toBeChecked();
await expect(
await body.findByRole("switch", { name: "Allow user API keys" }),
).toBeChecked();
expect(body.queryByLabelText(/^API Key$/i)).not.toBeInTheDocument();
expect(
body.queryByRole("switch", { name: "Use central key as fallback" }),
).not.toBeInTheDocument();
const saveButton = body.getByRole("button", { name: "Save changes" });
await userEvent.click(
body.getByRole("switch", { name: "Central API key" }),
);
await expect(await body.findByLabelText(/^API Key$/i)).toBeRequired();
expect(saveButton).toBeDisabled();
await userEvent.type(
body.getByLabelText(/^API Key$/i),
"sk-google-central-key",
);
await waitFor(() => {
expect(saveButton).toBeEnabled();
});
await userEvent.click(saveButton);
await waitFor(() => {
expect(args.onUpdateProvider).toHaveBeenCalledTimes(1);
});
expect(args.onUpdateProvider).toHaveBeenCalledWith(
"provider-google-user-only",
expect.objectContaining({
api_key: "sk-google-central-key",
central_api_key_enabled: true,
}),
);
},
@@ -777,51 +675,6 @@ export const ModelFormUserKeyOnlyProvider: Story = {
},
};
export const ProviderInvalidCredentialState: Story = {
args: {
section: "providers" as ChatModelAdminSection,
providerConfigsData: [
createProviderConfig({
id: "provider-bedrock-invalid",
provider: "bedrock",
display_name: "Bedrock",
has_api_key: false,
central_api_key_enabled: false,
allow_user_api_key: false,
allow_central_api_key_fallback: false,
}),
],
modelCatalogData: { providers: [] },
},
beforeEach: () => {
setupChatSpies({
providerConfigs: [
createProviderConfig({
id: "provider-bedrock-invalid",
provider: "bedrock",
display_name: "Bedrock",
has_api_key: false,
central_api_key_enabled: false,
allow_user_api_key: false,
allow_central_api_key_fallback: false,
}),
],
modelConfigs: [],
modelCatalog: { providers: [] },
});
},
play: async ({ canvasElement }) => {
const body = within(canvasElement.ownerDocument.body);
await userEvent.click(
await body.findByRole("button", { name: /Bedrock/i }),
);
await expect(
body.findByText("At least one credential source must be enabled"),
).resolves.toBeInTheDocument();
expect(body.getByRole("button", { name: "Save changes" })).toBeDisabled();
},
};
export const ProviderFormBedrockAmbientCredentials: Story = {
args: {
section: "providers" as ChatModelAdminSection,
@@ -843,6 +696,7 @@ export const ProviderFormBedrockAmbientCredentials: Story = {
);
const apiKeyInput = await body.findByLabelText(/^API Key$/i);
const baseURLInput = await body.findByLabelText("Base URL");
const createButton = body.getByRole("button", {
name: "Create provider config",
});
@@ -859,10 +713,17 @@ export const ProviderFormBedrockAmbientCredentials: Story = {
).resolves.toBeInTheDocument();
await expect(
body.findByText(
/Overrides the Bedrock runtime endpoint\.\s+Set AWS_REGION on\s+the Coder server to select the target region\./i,
/Bedrock runtime endpoint\.\s+Use the AWS region for the models this provider should call\./i,
),
).resolves.toBeInTheDocument();
await expect(createButton).toBeEnabled();
await expect(createButton).toBeDisabled();
await userEvent.type(
baseURLInput,
"https://bedrock-runtime.us-east-1.amazonaws.com",
);
await waitFor(() => {
expect(createButton).toBeEnabled();
});
await userEvent.click(createButton);
await waitFor(() => {
@@ -875,9 +736,7 @@ export const ProviderFormBedrockAmbientCredentials: Story = {
>;
expect(createRequest).toMatchObject({
provider: "bedrock",
central_api_key_enabled: true,
allow_user_api_key: false,
allow_central_api_key_fallback: false,
base_url: "https://bedrock-runtime.us-east-1.amazonaws.com",
});
expect(createRequest).not.toHaveProperty("api_key");
},
@@ -891,6 +750,7 @@ export const ProviderFormBedrockBearerToken: Story = {
id: "provider-bedrock-bearer",
provider: "bedrock",
display_name: "AWS Bedrock",
base_url: "https://bedrock-runtime.us-east-1.amazonaws.com",
has_api_key: true,
central_api_key_enabled: true,
allow_user_api_key: false,
@@ -936,6 +796,7 @@ export const ProviderFormBedrockClearBearerToken: Story = {
id: "provider-bedrock-clear",
provider: "bedrock",
display_name: "AWS Bedrock",
base_url: "https://bedrock-runtime.us-east-1.amazonaws.com",
has_api_key: true,
central_api_key_enabled: true,
allow_user_api_key: false,
@@ -1161,7 +1022,6 @@ export const UpdateModelEnabledToggle: Story = {
},
};
// ── Per-provider model form stories ────────────────────────────
// Each story opens the "Add model" form for a specific provider
// so you can visually verify the schema-driven fields render.
@@ -6,13 +6,18 @@ import { ErrorAlert } from "#/components/Alert/ErrorAlert";
import { Spinner } from "#/components/Spinner/Spinner";
import { cn } from "#/utils/cn";
import { formatProviderLabel } from "../../utils/modelOptions";
import { normalizeProvider, readOptionalString } from "./helpers";
import {
getDefaultProviderBaseURL,
normalizeProvider,
readOptionalString,
} from "./helpers";
import { ModelsSection } from "./ModelsSection";
import { ProvidersSection } from "./ProvidersSection";
// ── Exported types ─────────────────────────────────────────────
export type CreateProviderResult = { id: string };
export type ProviderState = {
key: string;
provider: string;
label: string;
providerConfig: TypesGen.ChatProviderConfig | undefined;
@@ -21,14 +26,13 @@ export type ProviderState = {
hasManagedAPIKey: boolean;
hasCatalogAPIKey: boolean;
hasEffectiveAPIKey: boolean;
allowUserAPIKey: boolean;
isEnvPreset: boolean;
baseURL: string;
};
export type ChatModelAdminSection = "providers" | "models";
// ── Internal helpers ───────────────────────────────────────────
type CatalogProvider = TypesGen.ChatModelsResponse["providers"][number];
const nilUUID = "00000000-0000-0000-0000-000000000000";
@@ -78,65 +82,106 @@ const getProviderModels = (
const getProviderBaseURL = (
providerConfig: TypesGen.ChatProviderConfig | undefined,
): string => {
return readOptionalString(providerConfig?.base_url) ?? "";
return (
readOptionalString(providerConfig?.base_url) ??
getDefaultProviderBaseURL(providerConfig?.provider ?? "")
);
};
// ── Hook: compute provider states from query data ──────────────
const providerConfigStateKey = (
providerConfig: TypesGen.ChatProviderConfig,
): string => {
const providerID = readOptionalString(providerConfig.id);
if (providerID && providerID !== nilUUID) {
return providerID;
}
return normalizeProvider(providerConfig.provider);
};
type ProviderEntry = {
key: string;
provider: string;
};
const useProviderStates = (
modelConfigs: readonly TypesGen.ChatModelConfig[],
providerConfigsData: TypesGen.ChatProviderConfig[] | null | undefined,
catalogData: TypesGen.ChatModelsResponse | null | undefined,
): readonly ProviderState[] => {
const orderedProviders: string[] = [];
const seenProviders = new Set<string>();
const includeProvider = (providerValue: string) => {
const normalized = normalizeProvider(providerValue);
if (!normalized || seenProviders.has(normalized)) return;
seenProviders.add(normalized);
orderedProviders.push(normalized);
const orderedEntries: ProviderEntry[] = [];
const seenEntries = new Set<string>();
const includeEntry = (keyValue: string, providerValue: string) => {
const key = readOptionalString(keyValue);
const provider = normalizeProvider(providerValue);
if (!key || !provider || seenEntries.has(key)) return;
seenEntries.add(key);
orderedEntries.push({ key, provider });
};
const catalogProviders = getCatalogProviders(catalogData);
const catalogProvidersByProvider = new Map<string, CatalogProvider>();
for (const cp of catalogProviders) {
const normalized = normalizeProvider(cp.provider);
if (!normalized) continue;
includeProvider(normalized);
catalogProvidersByProvider.set(normalized, cp);
const provider = normalizeProvider(cp.provider);
if (!provider) continue;
catalogProvidersByProvider.set(provider, cp);
}
const providerConfigKeysByProvider = new Map<string, string[]>();
const providerTypesWithConfigs = new Set<string>();
for (const pc of providerConfigsData ?? []) {
includeProvider(pc.provider);
const provider = normalizeProvider(pc.provider);
if (!provider) continue;
const key = providerConfigStateKey(pc);
providerTypesWithConfigs.add(provider);
providerConfigKeysByProvider.set(provider, [
...(providerConfigKeysByProvider.get(provider) ?? []),
key,
]);
includeEntry(key, provider);
}
const modelStateKey = (modelConfig: TypesGen.ChatModelConfig): string => {
const aiProviderID = readOptionalString(modelConfig.ai_provider_id);
if (aiProviderID) {
return aiProviderID;
}
const provider = normalizeProvider(modelConfig.provider);
const providerConfigKeys = providerConfigKeysByProvider.get(provider) ?? [];
if (providerConfigKeys.length === 1) {
return providerConfigKeys[0];
}
return providerConfigKeys.length === 0 ? provider : "";
};
for (const cp of catalogProviders) {
const provider = normalizeProvider(cp.provider);
if (!provider || providerTypesWithConfigs.has(provider)) continue;
includeEntry(provider, provider);
}
for (const mc of modelConfigs) {
includeProvider(mc.provider);
includeEntry(modelStateKey(mc), mc.provider);
}
const providerConfigsByProvider = new Map<
string,
TypesGen.ChatProviderConfig
>();
const providerConfigsByKey = new Map<string, TypesGen.ChatProviderConfig>();
for (const pc of providerConfigsData ?? []) {
const normalized = normalizeProvider(pc.provider);
if (!normalized) continue;
providerConfigsByProvider.set(normalized, pc);
const key = providerConfigStateKey(pc);
if (!key) continue;
providerConfigsByKey.set(key, pc);
}
const modelConfigsByProvider = new Map<string, TypesGen.ChatModelConfig[]>();
const modelConfigsByKey = new Map<string, TypesGen.ChatModelConfig[]>();
for (const mc of modelConfigs) {
const normalized = normalizeProvider(mc.provider);
if (!normalized) continue;
const existing = modelConfigsByProvider.get(normalized);
const key = modelStateKey(mc);
if (!key) continue;
const existing = modelConfigsByKey.get(key);
if (existing) {
existing.push(mc);
} else {
modelConfigsByProvider.set(normalized, [mc]);
modelConfigsByKey.set(key, [mc]);
}
}
return orderedProviders.map((provider) => {
const providerConfigEntry = providerConfigsByProvider.get(provider);
return orderedEntries.map(({ key, provider }) => {
const providerConfigEntry = providerConfigsByKey.get(key);
const providerConfigSource = getProviderConfigSource(providerConfigEntry);
const providerConfig = isDatabaseProviderConfig(
providerConfigEntry,
@@ -145,9 +190,6 @@ const useProviderStates = (
? providerConfigEntry
: undefined;
const catalogProvider = catalogProvidersByProvider.get(provider);
const catalogProviderSource = readOptionalString(
(catalogProvider as CatalogProvider & { source?: string })?.source,
);
const hasManagedAPIKey = hasProviderAPIKey(providerConfig);
const hasProviderEntryAPIKey = hasProviderAPIKey(providerConfigEntry);
const hasCatalogAPIKey = catalogProvider
@@ -159,15 +201,14 @@ const useProviderStates = (
const hasBedrockAmbientCredentials =
provider === "bedrock" &&
providerConfig?.central_api_key_enabled === true;
const modelConfigsForProvider = modelConfigsByProvider.get(provider) ?? [];
const modelConfigsForProvider = modelConfigsByKey.get(key) ?? [];
const isCatalogEnvPreset =
!providerConfig &&
envPresetProviders.has(provider) &&
(catalogProviderSource === "env" || hasCatalogAPIKey);
!providerConfig && envPresetProviders.has(provider) && hasCatalogAPIKey;
const isEnvPreset =
providerConfigSource === "env_preset" || isCatalogEnvPreset;
return {
key,
provider,
label,
providerConfig,
@@ -178,14 +219,13 @@ const useProviderStates = (
hasEffectiveAPIKey: providerConfigEntry
? hasProviderEntryAPIKey || hasBedrockAmbientCredentials
: hasManagedAPIKey || hasCatalogAPIKey,
allowUserAPIKey: providerConfigEntry?.allow_user_api_key ?? true,
isEnvPreset,
baseURL: getProviderBaseURL(providerConfigEntry),
};
});
};
// ── Component ──────────────────────────────────────────────────
interface ChatModelAdminPanelProps {
className?: string;
section?: ChatModelAdminSection;
@@ -203,7 +243,7 @@ interface ChatModelAdminPanelProps {
// Provider mutation handlers.
onCreateProvider: (
req: TypesGen.CreateChatProviderConfigRequest,
) => Promise<unknown>;
) => Promise<CreateProviderResult>;
onUpdateProvider: (
providerConfigId: string,
req: TypesGen.UpdateChatProviderConfigRequest,
@@ -251,13 +291,10 @@ export const ChatModelAdminPanel: FC<ChatModelAdminPanelProps> = ({
isDeletingModel,
modelMutationError,
}) => {
// ── Sorted model configs ───────────────────────────────────
const modelConfigs = (modelConfigsData ?? []).slice().sort((a, b) => {
const cmp = a.provider.localeCompare(b.provider);
return cmp !== 0 ? cmp : a.model.localeCompare(b.model);
});
// ── Provider states ────────────────────────────────────────
const providerStates = useProviderStates(
modelConfigs,
providerConfigsData,
@@ -276,7 +313,6 @@ export const ChatModelAdminPanel: FC<ChatModelAdminPanelProps> = ({
</div>
)}
{/* Content */}
<div className="flex flex-1 flex-col gap-8">
{section === "providers" ? (
<ProvidersSection
@@ -305,8 +341,6 @@ export const ChatModelAdminPanel: FC<ChatModelAdminPanelProps> = ({
/>
)}
</div>
{/* Errors — rendered at the bottom */}
{providerConfigsError && <ErrorAlert error={providerConfigsError} />}
{modelConfigsError && <ErrorAlert error={modelConfigsError} />}
{modelCatalogError && <ErrorAlert error={modelCatalogError} />}
@@ -34,6 +34,7 @@ import { getFormHelpers } from "#/utils/formUtils";
import { BackButton } from "../BackButton";
import { ConfirmDeleteDialog } from "../ConfirmDeleteDialog";
import type { ProviderState } from "./ChatModelAdminPanel";
import { readOptionalString } from "./helpers";
import {
GeneralModelConfigFields,
ModelConfigFields,
@@ -141,6 +142,9 @@ export const ModelForm: FC<ModelFormProps> = ({
return "add";
})();
const selectedProviderType =
selectedProviderState?.provider ?? selectedProvider;
const form = useFormik<ModelFormValues>({
initialValues,
validationSchema,
@@ -158,7 +162,7 @@ export const ModelForm: FC<ModelFormProps> = ({
);
const buildResult = buildModelConfigFromForm(
selectedProvider,
selectedProviderType,
values.config,
);
if (Object.keys(buildResult.fieldErrors).length > 0) return;
@@ -166,8 +170,17 @@ export const ModelForm: FC<ModelFormProps> = ({
const trimmedDisplayName = values.displayName.trim();
const builtModelConfig = buildResult.modelConfig;
const selectedProviderConfigID =
selectedProviderState?.providerConfig?.id;
if (isEditing && editingModel) {
const req: TypesGen.UpdateChatModelConfigRequest = {
...(selectedProviderConfigID &&
selectedProviderConfigID !==
readOptionalString(editingModel.ai_provider_id) && {
provider: selectedProviderState.provider,
ai_provider_id: selectedProviderConfigID,
}),
...(trimmedModel !== editingModel.model && {
model: trimmedModel,
}),
@@ -198,7 +211,8 @@ export const ModelForm: FC<ModelFormProps> = ({
if (!selectedProvider || !selectedProviderState?.providerConfig) return;
const req: TypesGen.CreateChatModelConfigRequest = {
provider: selectedProvider,
provider: selectedProviderState.provider,
ai_provider_id: selectedProviderState.providerConfig.id,
model: trimmedModel,
enabled: values.enabled,
is_default: values.isDefault,
@@ -227,7 +241,7 @@ export const ModelForm: FC<ModelFormProps> = ({
const getFieldHelpers = getFormHelpers(form);
const modelConfigFormBuildResult = buildModelConfigFromForm(
selectedProvider,
selectedProviderType,
form.values.config,
);
@@ -249,7 +263,10 @@ export const ModelForm: FC<ModelFormProps> = ({
<Select
value={selectedProvider ?? ""}
onValueChange={onSelectedProviderChange}
disabled={isEditing || isDuplicating || providerStates.length === 0}
disabled={
((isEditing || isDuplicating) && selectedProviderState !== null) ||
providerStates.length === 0
}
>
<SelectTrigger
id="providerSelect"
@@ -259,7 +276,7 @@ export const ModelForm: FC<ModelFormProps> = ({
</SelectTrigger>
<SelectContent>
{providerStates.map((ps) => (
<SelectItem key={ps.provider} value={ps.provider}>
<SelectItem key={ps.key} value={ps.key}>
<span className="flex items-center gap-2">
<ProviderIcon provider={ps.provider} className="h-4 w-4" />
{ps.label}
@@ -394,7 +411,7 @@ export const ModelForm: FC<ModelFormProps> = ({
form={form}
modelField={modelField}
mode={mode}
selectedProvider={selectedProvider}
selectedProvider={selectedProviderType}
disabled={isSaving}
/>
<div className="grid gap-1.5">
@@ -6,6 +6,7 @@ import type { ProviderState } from "./ChatModelAdminPanel";
import { ModelsSection } from "./ModelsSection";
const providerState: ProviderState = {
key: "provider-config-id",
provider: "openai",
label: "OpenAI",
providerConfig: {
@@ -27,6 +28,7 @@ const providerState: ProviderState = {
hasManagedAPIKey: true,
hasCatalogAPIKey: true,
hasEffectiveAPIKey: true,
allowUserAPIKey: false,
isEnvPreset: false,
baseURL: "",
};
@@ -42,6 +44,7 @@ const providerStateWithoutAPIKey: ProviderState = {
hasManagedAPIKey: false,
hasCatalogAPIKey: false,
hasEffectiveAPIKey: false,
allowUserAPIKey: false,
};
const baseModelConfig: TypesGen.ChatModelConfig = {
@@ -294,6 +297,7 @@ export const SavesDuplicateAsCreateRequest: Story = {
throw new Error("Expected create request.");
}
expect(createReq).toEqual({
ai_provider_id: "provider-config-id",
provider: "openai",
model: "gpt-4.1-copy",
display_name: "GPT-4.1 Copy",
@@ -344,6 +348,7 @@ export const SavesNonDefaultDuplicateWithEditableEnabled: Story = {
const createModelMock = args.onCreateModel as ReturnType<typeof fn>;
expect(createModelMock.mock.calls[0]?.[0]).toEqual({
ai_provider_id: "provider-config-id",
provider: "openai",
model: "gpt-4.1-copy",
display_name: "GPT-4.1",
@@ -385,6 +390,7 @@ export const SavesDisabledDuplicateWithEditableEnabled: Story = {
await waitFor(() => expect(args.onCreateModel).toHaveBeenCalledTimes(1));
const createModelMock = args.onCreateModel as ReturnType<typeof fn>;
expect(createModelMock.mock.calls[0]?.[0]).toEqual({
ai_provider_id: "provider-config-id",
provider: "openai",
model: "gpt-4.1-disabled-copy",
display_name: "GPT-4.1 Disabled",
@@ -6,7 +6,7 @@ import {
StarIcon,
TriangleAlertIcon,
} from "lucide-react";
import type { FC } from "react";
import { type FC, useEffect, useState } from "react";
import { Link, useLocation, useSearchParams } from "react-router";
import type * as TypesGen from "#/api/typesGenerated";
import { Badge } from "#/components/Badge/Badge";
@@ -25,6 +25,7 @@ import {
import { cn } from "#/utils/cn";
import { SectionHeader } from "../SectionHeader";
import type { ProviderState } from "./ChatModelAdminPanel";
import { normalizeProvider, readOptionalString } from "./helpers";
import { ModelForm } from "./ModelForm";
import { ProviderIcon } from "./ProviderIcon";
import { hasCustomPricing } from "./pricingFields";
@@ -44,6 +45,28 @@ const clearModelViewParams = (params: URLSearchParams) => {
}
};
const modelConfigProviderKey = (
modelConfig: TypesGen.ChatModelConfig,
providerStates: readonly ProviderState[],
): string => {
const providerID = readOptionalString(modelConfig.ai_provider_id);
if (providerID) {
return providerID;
}
const provider = normalizeProvider(modelConfig.provider);
const providerMatches = providerStates.filter(
(providerState) => providerState.provider === provider,
);
if (providerMatches.length === 1) {
return providerMatches[0].key;
}
if (providerMatches.length > 1) {
return "";
}
return provider;
};
const canManageProviderModels = (providerState: ProviderState | undefined) => {
return Boolean(
providerState?.providerConfig &&
@@ -85,6 +108,9 @@ export const ModelsSection: FC<ModelsSectionProps> = ({
onDeleteModel,
}) => {
const [searchParams, setSearchParams] = useSearchParams();
const [selectedProviderOverride, setSelectedProviderOverride] = useState<
string | null
>(null);
const location = useLocation();
// Derive the current view from URL search params so that
@@ -122,9 +148,27 @@ export const ModelsSection: FC<ModelsSectionProps> = ({
state: options?.replace ? location.state : { pushed: true },
});
};
const modelViewIdentity = (() => {
switch (view.mode) {
case "add":
return `add:${view.provider}`;
case "edit":
return `edit:${view.model.id}`;
case "duplicate":
return `duplicate:${view.sourceModel.id}`;
default:
return "list";
}
})();
useEffect(() => {
void modelViewIdentity;
setSelectedProviderOverride(null);
}, [modelViewIdentity]);
// Clear model-related search params and return to the list.
const clearModelView = () => {
setSelectedProviderOverride(null);
setSearchParams((prev) => {
const next = new URLSearchParams(prev);
clearModelViewParams(next);
@@ -133,6 +177,7 @@ export const ModelsSection: FC<ModelsSectionProps> = ({
};
const exitModelView = () => {
setSelectedProviderOverride(null);
setSearchParams(
(prev) => {
const next = new URLSearchParams(prev);
@@ -153,13 +198,14 @@ export const ModelsSection: FC<ModelsSectionProps> = ({
const duplicateSourceModel =
view.mode === "duplicate" ? view.sourceModel : undefined;
const effectiveProvider =
view.mode === "edit"
? view.model.provider
selectedProviderOverride ??
(view.mode === "edit"
? modelConfigProviderKey(view.model, providerStates)
: view.mode === "duplicate"
? view.sourceModel.provider
: view.provider;
? modelConfigProviderKey(view.sourceModel, providerStates)
: view.provider);
const effectiveProviderState =
providerStates.find((ps) => ps.provider === effectiveProvider) ?? null;
providerStates.find((ps) => ps.key === effectiveProvider) ?? null;
const formKey =
view.mode === "edit"
? `edit:${view.model.id}`
@@ -178,7 +224,9 @@ export const ModelsSection: FC<ModelsSectionProps> = ({
onSelectedProviderChange={(provider) => {
if (view.mode === "add") {
setModelViewParam("newModel", provider, { replace: true });
return;
}
setSelectedProviderOverride(provider);
}}
modelConfigsUnavailable={modelConfigsUnavailable}
isSaving={isCreating || isUpdating}
@@ -222,9 +270,9 @@ export const ModelsSection: FC<ModelsSectionProps> = ({
<DropdownMenuContent align="end">
{addableProviders.map((ps) => (
<DropdownMenuItem
key={ps.provider}
key={ps.key}
onClick={() => {
setModelViewParam("newModel", ps.provider);
setModelViewParam("newModel", ps.key);
}}
className="gap-2"
>
@@ -285,10 +333,12 @@ export const ModelsSection: FC<ModelsSectionProps> = ({
const starUnavailable =
isUpdating || modelConfig.is_default || !modelConfig.enabled;
const providerState = providerStates.find(
(ps) => ps.provider === modelConfig.provider,
(ps) =>
ps.key === modelConfigProviderKey(modelConfig, providerStates),
);
const duplicateUnavailable = Boolean(
providerState && !canManageProviderModels(providerState),
);
const duplicateUnavailable =
!canManageProviderModels(providerState);
return (
<div
@@ -13,7 +13,6 @@ import { Alert, AlertDescription, AlertTitle } from "#/components/Alert/Alert";
import { Button } from "#/components/Button/Button";
import { Input } from "#/components/Input/Input";
import { Spinner } from "#/components/Spinner/Spinner";
import { Switch } from "#/components/Switch/Switch";
import {
Tooltip,
TooltipContent,
@@ -22,10 +21,12 @@ import {
import { formatProviderLabel } from "../../utils/modelOptions";
import { BackButton } from "../BackButton";
import { ConfirmDeleteDialog } from "../ConfirmDeleteDialog";
import type { ProviderState } from "./ChatModelAdminPanel";
import { readOptionalString } from "./helpers";
import type {
CreateProviderResult,
ProviderState,
} from "./ChatModelAdminPanel";
import { getProviderBaseURLPlaceholder, readOptionalString } from "./helpers";
import { ProviderIcon } from "./ProviderIcon";
import { normalizeProviderPolicyDefaults } from "./providerPolicyDefaults";
// Sentinel value used to represent an existing API key that the
// backend will not reveal. If the user has not touched the field,
@@ -38,7 +39,7 @@ interface ProviderFormProps {
isProviderMutationPending: boolean;
onCreateProvider: (
req: TypesGen.CreateChatProviderConfigRequest,
) => Promise<unknown>;
) => Promise<CreateProviderResult>;
onUpdateProvider: (
providerConfigId: string,
req: TypesGen.UpdateChatProviderConfigRequest,
@@ -62,27 +63,13 @@ export const ProviderForm: FC<ProviderFormProps> = ({
const apiKeyInputId = useId();
const baseURLInputId = useId();
// Providers backed by the OpenAI SDK expect /v1 in the base
// URL, while others (e.g. Anthropic) do not.
const baseURLPlaceholder =
provider === "anthropic" || provider === "bedrock" || provider === "google"
? "https://api.example.com"
: "https://api.example.com/v1";
const normalizedProviderConfig = providerConfig
? normalizeProviderPolicyDefaults(providerConfig)
: undefined;
const baseURLPlaceholder = getProviderBaseURLPlaceholder(provider);
// Initial values are snapshotted when the provider config changes
// so we can detect dirty state.
const [initialValues] = useState(() => ({
displayName: readOptionalString(providerConfig?.display_name) ?? "",
baseURL,
centralAPIKeyEnabled:
normalizedProviderConfig?.central_api_key_enabled ?? true,
allowUserAPIKey: normalizedProviderConfig?.allow_user_api_key ?? false,
allowCentralAPIKeyFallback:
normalizedProviderConfig?.allow_central_api_key_fallback ?? false,
}));
const [displayName, setDisplayName] = useState(initialValues.displayName);
@@ -92,84 +79,53 @@ export const ProviderForm: FC<ProviderFormProps> = ({
const [apiKeyTouched, setApiKeyTouched] = useState(false);
const [apiKeyModified, setApiKeyModified] = useState(false);
const [baseURLValue, setBaseURLValue] = useState(initialValues.baseURL);
const [centralAPIKeyEnabled, setCentralAPIKeyEnabled] = useState(
initialValues.centralAPIKeyEnabled,
);
const [allowUserAPIKey, setAllowUserAPIKey] = useState(
initialValues.allowUserAPIKey,
);
const [allowCentralAPIKeyFallback, setAllowCentralAPIKeyFallback] = useState(
initialValues.allowCentralAPIKeyFallback,
);
const [confirmingDelete, setConfirmingDelete] = useState(false);
const isBedrockProvider = provider === "bedrock";
const isAPIKeyEnvManaged = isEnvPreset && !providerConfig;
const shouldShowAPIKeyField = centralAPIKeyEnabled;
const shouldShowFallbackToggle = centralAPIKeyEnabled && allowUserAPIKey;
const effectiveInitialFallback =
initialValues.centralAPIKeyEnabled &&
initialValues.allowUserAPIKey &&
initialValues.allowCentralAPIKeyFallback;
const effectiveFallback =
shouldShowFallbackToggle && allowCentralAPIKeyFallback;
// Most providers require a stored deployment key whenever central-key
// usage is enabled and there is no saved key yet. Bedrock can also use
// ambient AWS credentials from the Coder server, so its API key stays
// optional.
const requiresAPIKey =
!isAPIKeyEnvManaged &&
!providerState.allowUserAPIKey &&
!isBedrockProvider &&
centralAPIKeyEnabled &&
!providerState.hasManagedAPIKey;
const effectiveApiKey =
apiKeyTouched && apiKey !== API_KEY_PLACEHOLDER ? apiKey.trim() : "";
apiKeyTouched && apiKey !== API_KEY_PLACEHOLDER ? apiKey : "";
const hasTypedAPIKey = effectiveApiKey.length > 0;
// Clearing a saved Bedrock bearer token switches the provider back
// to ambient AWS credentials, so updates must send an explicit
// empty string.
const isClearingBedrockAPIKey =
isBedrockProvider &&
providerState.hasManagedAPIKey &&
apiKeyModified &&
effectiveApiKey === "";
const hasPendingAPIKeyChange =
(centralAPIKeyEnabled && hasTypedAPIKey) || isClearingBedrockAPIKey;
const shouldCreateAPIKey = centralAPIKeyEnabled && hasTypedAPIKey;
const hasCredentialSource = centralAPIKeyEnabled || allowUserAPIKey;
const hasAPIKeyWhitespace =
hasTypedAPIKey && effectiveApiKey.trim() !== effectiveApiKey;
// Clearing a saved provider-scoped key switches the provider to
// BYOK-only behavior, or ambient AWS credentials for Bedrock.
const isClearingAPIKey =
providerState.hasManagedAPIKey && apiKeyModified && effectiveApiKey === "";
const hasPendingAPIKeyChange = hasTypedAPIKey || isClearingAPIKey;
const shouldCreateAPIKey = hasTypedAPIKey;
const apiKeyDescription = isBedrockProvider
? "Bearer token for Bedrock authentication. Leave empty to use ambient AWS credentials."
: "Secret key used to authenticate requests to this provider.";
const baseURLDescription = isBedrockProvider
? "Optional. Overrides the Bedrock runtime endpoint. Set AWS_REGION on the Coder server to select the target region."
: "Custom endpoint for this provider. Leave empty to use the default.";
? "Bedrock runtime endpoint. Use the AWS region for the models this provider should call."
: "Endpoint used to call this provider.";
const apiKeyPlaceholder = isBedrockProvider ? "Enter bearer token" : "sk-...";
const deleteProviderDescription = normalizedProviderConfig?.allow_user_api_key
? "Are you sure you want to delete this provider? Any personal API " +
"keys that users have saved for this provider will also be " +
"permanently deleted. This action is irreversible."
: "Are you sure you want to delete this provider? This action is irreversible.";
// New Bedrock providers can be saved immediately with ambient AWS
// credentials, even before any fields differ from their defaults.
const hasNewBedrockAmbientConfiguration =
isBedrockProvider && !providerConfig && centralAPIKeyEnabled;
const deleteProviderDescription =
"Are you sure you want to delete this provider? The provider will be " +
"disabled and hidden from new model configuration. Existing model " +
"configs that reference it remain saved but cannot run until updated.";
const hasNewProviderConfiguration = !providerConfig;
const isDirty =
displayName.trim() !== initialValues.displayName ||
hasPendingAPIKeyChange ||
baseURLValue.trim() !== initialValues.baseURL.trim() ||
centralAPIKeyEnabled !== initialValues.centralAPIKeyEnabled ||
allowUserAPIKey !== initialValues.allowUserAPIKey ||
effectiveFallback !== effectiveInitialFallback ||
hasNewBedrockAmbientConfiguration;
hasNewProviderConfiguration;
const hasBaseURL = baseURLValue.trim().length > 0;
const canSave =
!providerConfigsUnavailable &&
!isProviderMutationPending &&
!isAPIKeyEnvManaged &&
isDirty &&
hasCredentialSource &&
hasBaseURL &&
!hasAPIKeyWhitespace &&
(!requiresAPIKey || hasTypedAPIKey);
const canAddModel =
Boolean(providerConfig) &&
@@ -177,7 +133,7 @@ export const ProviderForm: FC<ProviderFormProps> = ({
providerConfig?.allow_user_api_key === true);
const handleAddModel = () => {
const params = new URLSearchParams({ newModel: provider });
const params = new URLSearchParams({ newModel: providerState.key });
navigate(`/agents/settings/models?${params.toString()}`, {
state: { pushed: true },
});
@@ -189,7 +145,8 @@ export const ProviderForm: FC<ProviderFormProps> = ({
providerConfigsUnavailable ||
isProviderMutationPending ||
isAPIKeyEnvManaged ||
!hasCredentialSource
!hasBaseURL ||
hasAPIKeyWhitespace
) {
return;
}
@@ -213,15 +170,6 @@ export const ProviderForm: FC<ProviderFormProps> = ({
...(trimmedBaseURL !== currentBaseURL && {
base_url: trimmedBaseURL,
}),
...(centralAPIKeyEnabled !== initialValues.centralAPIKeyEnabled && {
central_api_key_enabled: centralAPIKeyEnabled,
}),
...(allowUserAPIKey !== initialValues.allowUserAPIKey && {
allow_user_api_key: allowUserAPIKey,
}),
...(effectiveFallback !== effectiveInitialFallback && {
allow_central_api_key_fallback: effectiveFallback,
}),
};
if (Object.keys(req).length === 0) {
@@ -231,28 +179,21 @@ export const ProviderForm: FC<ProviderFormProps> = ({
try {
await onUpdateProvider(providerConfig.id, req);
} catch {
// Error is surfaced via the mutation's error state
// in ChatModelAdminPanel, no toast needed.
return;
}
} else {
const req: TypesGen.CreateChatProviderConfigRequest = {
provider,
base_url: trimmedBaseURL,
...(shouldCreateAPIKey && { api_key: effectiveApiKey }),
central_api_key_enabled: centralAPIKeyEnabled,
allow_user_api_key: allowUserAPIKey,
allow_central_api_key_fallback: effectiveFallback,
...(trimmedDisplayName && {
display_name: trimmedDisplayName,
}),
...(trimmedBaseURL && { base_url: trimmedBaseURL }),
};
try {
await onCreateProvider(req);
} catch {
// Error is surfaced via the mutation's error state
// in ChatModelAdminPanel, no toast needed.
return;
}
}
@@ -275,9 +216,7 @@ export const ProviderForm: FC<ProviderFormProps> = ({
return (
<div className="flex min-h-full flex-col">
{/* Back */}
<BackButton onClick={onBack} />
{/* Provider header, editable name */}
<div className="flex items-center gap-3">
<ProviderIcon provider={provider} className="h-8 w-8" />
<div className="min-w-0 flex-1">
@@ -316,57 +255,60 @@ export const ProviderForm: FC<ProviderFormProps> = ({
data-form-type="other"
>
<div className="space-y-5">
{shouldShowAPIKeyField && (
<ProviderField
label="API Key"
htmlFor={apiKeyInputId}
required={requiresAPIKey}
description={apiKeyDescription}
>
<div className="space-y-1.5">
<Input
id={apiKeyInputId}
name="provider_api_token"
type="password"
autoComplete="off"
data-1p-ignore
data-lpignore="true"
data-form-type="other"
data-bwignore
style={{ WebkitTextSecurity: "disc" } as CSSProperties}
className="h-9 font-mono text-[13px]"
placeholder={apiKeyPlaceholder}
required={requiresAPIKey}
value={apiKey}
onFocus={handleApiKeyFocus}
onChange={(event) => {
setApiKey(event.target.value);
setApiKeyTouched(true);
setApiKeyModified(true);
}}
disabled={isDisabled}
/>
{isBedrockProvider &&
providerState.hasManagedAPIKey &&
!isDisabled &&
(!apiKeyModified || apiKey !== "") && (
<div className="flex justify-end">
<button
type="button"
className="appearance-none border-0 bg-transparent p-0 text-xs text-content-link hover:cursor-pointer hover:underline"
onClick={() => {
setApiKey("");
setApiKeyTouched(true);
setApiKeyModified(true);
}}
>
Clear stored token
</button>
</div>
)}
</div>
</ProviderField>
)}
<ProviderField
label="API Key"
htmlFor={apiKeyInputId}
required={requiresAPIKey}
description={apiKeyDescription}
>
<div className="space-y-1.5">
<Input
id={apiKeyInputId}
name="provider_api_token"
type="password"
autoComplete="off"
data-1p-ignore
data-lpignore="true"
data-form-type="other"
data-bwignore
style={{ WebkitTextSecurity: "disc" } as CSSProperties}
className="h-9 font-mono text-[13px]"
placeholder={apiKeyPlaceholder}
required={requiresAPIKey}
value={apiKey}
onFocus={handleApiKeyFocus}
onChange={(event) => {
setApiKey(event.target.value);
setApiKeyTouched(true);
setApiKeyModified(true);
}}
disabled={isDisabled}
/>
{hasAPIKeyWhitespace && (
<p className="m-0 text-xs text-content-destructive">
API key must not contain leading or trailing whitespace.
</p>
)}
{isBedrockProvider &&
providerState.hasManagedAPIKey &&
!isDisabled &&
(!apiKeyModified || apiKey !== "") && (
<div className="flex justify-end">
<button
type="button"
className="appearance-none border-0 bg-transparent p-0 text-xs text-content-link hover:cursor-pointer hover:underline"
onClick={() => {
setApiKey("");
setApiKeyTouched(true);
setApiKeyModified(true);
}}
>
Clear stored token
</button>
</div>
)}
</div>
</ProviderField>
<ProviderField
label="Base URL"
@@ -378,56 +320,14 @@ export const ProviderForm: FC<ProviderFormProps> = ({
name="provider_base_url"
className="h-9 text-[13px]"
placeholder={baseURLPlaceholder}
required
autoComplete="off"
value={baseURLValue}
onChange={(event) => setBaseURLValue(event.target.value)}
disabled={isDisabled}
/>
</ProviderField>
<div className="space-y-3 rounded-lg border border-solid border-border/70 bg-surface-secondary/30 p-4">
<div className="space-y-1">
<h3 className="m-0 text-[13px] font-semibold text-content-primary">
Key policy
</h3>
<p className="m-0 text-xs text-content-secondary">
Control which credential sources this provider can use.
</p>
</div>
<div className="space-y-3">
<ProviderToggleField
label="Central API key"
description="Use a deployment-managed API key for this provider"
checked={centralAPIKeyEnabled}
onCheckedChange={setCentralAPIKeyEnabled}
disabled={isDisabled}
/>
<ProviderToggleField
label="Allow user API keys"
description="Let users provide their own API keys for this provider"
checked={allowUserAPIKey}
onCheckedChange={setAllowUserAPIKey}
disabled={isDisabled}
/>
{shouldShowFallbackToggle && (
<ProviderToggleField
label="Use central key as fallback"
description="When a user has not saved a personal key, fall back to the central API key"
checked={effectiveFallback}
onCheckedChange={setAllowCentralAPIKeyFallback}
disabled={isDisabled}
/>
)}
</div>
{!hasCredentialSource && (
<p className="m-0 text-xs text-content-destructive">
At least one credential source must be enabled
</p>
)}
</div>
</div>
{/* Footer, pushed to bottom */}
<div className="mt-auto pt-6">
<hr className="mb-4 border-0 border-t border-solid border-border" />
<div className="flex items-center justify-between">
@@ -480,50 +380,6 @@ export const ProviderForm: FC<ProviderFormProps> = ({
</div>
);
};
interface ProviderToggleFieldProps {
label: string;
description: string;
checked: boolean;
onCheckedChange: (checked: boolean) => void;
disabled?: boolean;
}
const ProviderToggleField: FC<ProviderToggleFieldProps> = ({
label,
description,
checked,
onCheckedChange,
disabled,
}) => {
const labelId = useId();
const descriptionId = useId();
return (
<div className="flex items-start justify-between gap-4">
<div className="min-w-0 space-y-1">
<p
id={labelId}
className="m-0 text-sm font-medium text-content-primary"
>
{label}
</p>
<p id={descriptionId} className="m-0 text-xs text-content-secondary">
{description}
</p>
</div>
<Switch
checked={checked}
onCheckedChange={onCheckedChange}
disabled={disabled}
aria-labelledby={labelId}
aria-describedby={descriptionId}
/>
</div>
);
};
// Field wrapper.
interface ProviderFieldProps {
label: string;
htmlFor?: string;
@@ -1,15 +1,69 @@
import { CheckCircleIcon, ChevronRightIcon, CircleIcon } from "lucide-react";
import {
CheckCircleIcon,
ChevronRightIcon,
CircleIcon,
PlusIcon,
} from "lucide-react";
import type { FC } from "react";
import { useLocation, useNavigate, useSearchParams } from "react-router";
import type * as TypesGen from "#/api/typesGenerated";
import {
type AIProviderType,
AIProviderTypes,
type CreateChatProviderConfigRequest,
type UpdateChatProviderConfigRequest,
} from "#/api/typesGenerated";
import { Badge } from "#/components/Badge/Badge";
import { Button } from "#/components/Button/Button";
import {
DropdownMenu,
DropdownMenuContent,
DropdownMenuItem,
DropdownMenuTrigger,
} from "#/components/DropdownMenu/DropdownMenu";
import { cn } from "#/utils/cn";
import { formatProviderLabel } from "../../utils/modelOptions";
import { SectionHeader } from "../SectionHeader";
import type { ProviderState } from "./ChatModelAdminPanel";
import type {
CreateProviderResult,
ProviderState,
} from "./ChatModelAdminPanel";
import { getDefaultProviderBaseURL } from "./helpers";
import { ProviderForm } from "./ProviderForm";
import { ProviderIcon } from "./ProviderIcon";
type ProviderView = { mode: "list" } | { mode: "detail"; provider: string };
type ProviderView =
| { mode: "list" }
| { mode: "detail"; provider: string }
| { mode: "new"; providerType: AIProviderType };
const providerTypeOptions = AIProviderTypes.map((providerType) => ({
providerType,
label: formatProviderLabel(providerType),
})).sort((a, b) => a.label.localeCompare(b.label));
const getAIProviderType = (
value: string | null,
): AIProviderType | undefined => {
if (!value) {
return undefined;
}
return AIProviderTypes.find((providerType) => providerType === value);
};
const newProviderState = (providerType: AIProviderType): ProviderState => ({
key: `new:${providerType}`,
provider: providerType,
label: formatProviderLabel(providerType),
providerConfig: undefined,
modelConfigs: [],
catalogModelCount: 0,
hasManagedAPIKey: false,
hasCatalogAPIKey: false,
hasEffectiveAPIKey: false,
allowUserAPIKey: true,
isEnvPreset: false,
baseURL: getDefaultProviderBaseURL(providerType),
});
interface ProvidersSectionProps {
sectionLabel?: string;
@@ -18,11 +72,11 @@ interface ProvidersSectionProps {
providerConfigsUnavailable: boolean;
isProviderMutationPending: boolean;
onCreateProvider: (
req: TypesGen.CreateChatProviderConfigRequest,
) => Promise<unknown>;
req: CreateChatProviderConfigRequest,
) => Promise<CreateProviderResult>;
onUpdateProvider: (
providerConfigId: string,
req: TypesGen.UpdateChatProviderConfigRequest,
req: UpdateChatProviderConfigRequest,
) => Promise<unknown>;
onDeleteProvider: (providerConfigId: string) => Promise<void>;
}
@@ -48,11 +102,17 @@ export const ProvidersSection: FC<ProvidersSectionProps> = ({
const view: ProviderView = (() => {
const providerParam = searchParams.get("provider");
if (providerParam) {
const exists = providerStates.some((ps) => ps.provider === providerParam);
const exists = providerStates.some((ps) => ps.key === providerParam);
return exists
? { mode: "detail", provider: providerParam }
: { mode: "list" };
}
const newProviderType = getAIProviderType(searchParams.get("newProvider"));
if (newProviderType) {
return { mode: "new", providerType: newProviderType };
}
return { mode: "list" };
})();
@@ -61,17 +121,31 @@ export const ProvidersSection: FC<ProvidersSectionProps> = ({
setSearchParams((prev) => {
const next = new URLSearchParams(prev);
next.delete("provider");
next.delete("newProvider");
return next;
});
};
const openNewProviderView = (providerType: AIProviderType) => {
setSearchParams(
(prev) => {
const next = new URLSearchParams(prev);
next.delete("provider");
next.set("newProvider", providerType);
return next;
},
{ state: { pushed: true } },
);
};
// Detail view.
const detailProvider =
view.mode === "detail"
? providerStates.find((ps) => ps.provider === view.provider)
: undefined;
? providerStates.find((ps) => ps.key === view.provider)
: view.mode === "new"
? newProviderState(view.providerType)
: undefined;
if (view.mode === "detail" && detailProvider) {
if ((view.mode === "detail" || view.mode === "new") && detailProvider) {
const providerFormKey = [
detailProvider.provider,
detailProvider.providerConfig?.id ?? "new",
@@ -91,7 +165,21 @@ export const ProvidersSection: FC<ProvidersSectionProps> = ({
providerState={detailProvider}
providerConfigsUnavailable={providerConfigsUnavailable}
isProviderMutationPending={isProviderMutationPending}
onCreateProvider={onCreateProvider}
onCreateProvider={async (req) => {
const createdProvider = await onCreateProvider(req);
if (createdProvider.id) {
setSearchParams(
(prev) => {
const next = new URLSearchParams(prev);
next.set("provider", createdProvider.id);
next.delete("newProvider");
return next;
},
{ replace: true, state: location.state },
);
}
return createdProvider;
}}
onUpdateProvider={onUpdateProvider}
onDeleteProvider={async (id) => {
await onDeleteProvider(id);
@@ -102,6 +190,7 @@ export const ProvidersSection: FC<ProvidersSectionProps> = ({
(prev) => {
const next = new URLSearchParams(prev);
next.delete("provider");
next.delete("newProvider");
return next;
},
{ replace: true },
@@ -114,66 +203,94 @@ export const ProvidersSection: FC<ProvidersSectionProps> = ({
}
// List view.
if (providerStates.length === 0) {
return (
<div className="rounded-lg border border-dashed border-border bg-surface-primary p-6 text-center text-[13px] text-content-secondary">
No provider types were returned by the backend.
</div>
);
}
const addProviderAction = (
<DropdownMenu>
<DropdownMenuTrigger asChild>
<Button
variant="outline"
size="sm"
disabled={providerConfigsUnavailable || isProviderMutationPending}
>
<PlusIcon className="h-4 w-4" />
Add provider
</Button>
</DropdownMenuTrigger>
<DropdownMenuContent align="end" className="w-56">
{providerTypeOptions.map(({ providerType, label }) => (
<DropdownMenuItem
key={providerType}
className="gap-2"
onSelect={() => openNewProviderView(providerType)}
>
<ProviderIcon provider={providerType} className="h-5 w-5" />
<span>{label}</span>
</DropdownMenuItem>
))}
</DropdownMenuContent>
</DropdownMenu>
);
const header = sectionLabel ? (
<SectionHeader
label={sectionLabel}
description={
sectionDescription ?? "Configure AI providers to use with Agents."
}
action={addProviderAction}
/>
) : null;
return (
<>
{sectionLabel && (
<SectionHeader
label={sectionLabel}
description={
sectionDescription ?? "Configure AI providers to use with Agents."
}
/>
)}
<div>
{providerStates.map((providerState, i) => (
<button
type="button"
key={providerState.provider}
aria-label={providerState.label}
onClick={() => {
setSearchParams(
{ provider: providerState.provider },
{ state: { pushed: true } },
);
}}
className={cn(
"flex w-full cursor-pointer items-center gap-3.5 border-0 bg-transparent p-0 px-3 py-3 text-left transition-colors hover:bg-surface-secondary/30",
i > 0 && "border-0 border-t border-solid border-border/50",
)}
>
<ProviderIcon
provider={providerState.provider}
className="h-8 w-8 shrink-0"
/>
<div className="min-w-0 flex-1 space-y-1">
<div className="flex flex-wrap items-center gap-2">
<span className="min-w-0 truncate text-[15px] font-medium text-content-primary text-left">
{providerState.label}
</span>
{providerState.providerConfig?.allow_user_api_key && (
<Badge size="xs" className="text-content-secondary">
User keys enabled
</Badge>
)}
{header}
{providerStates.length === 0 ? (
<div className="rounded-lg border border-dashed border-border bg-surface-primary p-6 text-center text-[13px] text-content-secondary">
No providers have been added yet.
</div>
) : (
<div>
{providerStates.map((providerState, i) => (
<button
type="button"
key={providerState.key}
aria-label={providerState.label}
onClick={() => {
setSearchParams(
{ provider: providerState.key },
{ state: { pushed: true } },
);
}}
className={cn(
"flex w-full cursor-pointer items-center gap-3.5 border-0 bg-transparent p-0 px-3 py-3 text-left transition-colors hover:bg-surface-secondary/30",
i > 0 && "border-0 border-t border-solid border-border/50",
)}
>
<ProviderIcon
provider={providerState.provider}
className="h-8 w-8 shrink-0"
/>
<div className="min-w-0 flex-1 space-y-1">
<div className="flex flex-wrap items-center gap-2">
<span className="min-w-0 truncate text-[15px] font-medium text-content-primary text-left">
{providerState.label}
</span>
{providerState.providerConfig?.allow_user_api_key && (
<Badge size="xs" className="text-content-secondary">
User keys enabled
</Badge>
)}
</div>
</div>
</div>
{providerState.hasEffectiveAPIKey ? (
<CheckCircleIcon className="h-4 w-4 shrink-0 text-content-success" />
) : (
<CircleIcon className="h-4 w-4 shrink-0 text-content-secondary opacity-40" />
)}
<ChevronRightIcon className="h-5 w-5 shrink-0 text-content-secondary" />
</button>
))}
</div>
{providerState.hasEffectiveAPIKey ? (
<CheckCircleIcon className="h-4 w-4 shrink-0 text-content-success" />
) : (
<CircleIcon className="h-4 w-4 shrink-0 text-content-secondary opacity-40" />
)}
<ChevronRightIcon className="h-5 w-5 shrink-0 text-content-secondary" />
</button>
))}
</div>
)}
</>
);
};
@@ -14,3 +14,28 @@ export function readOptionalString(value: unknown): string | undefined {
export function normalizeProvider(provider: string): string {
return provider.trim().toLowerCase();
}
const canonicalProviderBaseURLs: Record<string, string> = {
anthropic: "https://api.anthropic.com",
google: "https://generativelanguage.googleapis.com/v1beta",
openai: "https://api.openai.com/v1",
openrouter: "https://openrouter.ai/api/v1",
vercel: "https://ai-gateway.vercel.sh/v1",
};
export function getDefaultProviderBaseURL(provider: string): string {
return canonicalProviderBaseURLs[normalizeProvider(provider)] ?? "";
}
export function getProviderBaseURLPlaceholder(provider: string): string {
switch (normalizeProvider(provider)) {
case "azure":
return "https://<resource-name>.openai.azure.com";
case "bedrock":
return "https://bedrock-runtime.<region>.amazonaws.com";
case "openai-compat":
return "https://api.example.com/v1";
default:
return getDefaultProviderBaseURL(provider) || "https://api.example.com";
}
}
@@ -1,55 +0,0 @@
import { describe, expect, it } from "vitest";
import {
normalizeProviderPolicyDefaults,
type ProviderConfigWithOptionalPolicyFields,
} from "./providerPolicyDefaults";
const baseProviderConfig: ProviderConfigWithOptionalPolicyFields = {
id: "provider-1",
provider: "openai",
display_name: "OpenAI",
enabled: true,
has_api_key: true,
base_url: "https://api.openai.com/v1",
source: "database",
created_at: "2025-01-01T00:00:00Z",
updated_at: "2025-01-01T00:00:00Z",
};
describe("normalizeProviderPolicyDefaults", () => {
it("passes through explicit policy fields unchanged", () => {
const providerConfig: ProviderConfigWithOptionalPolicyFields = {
...baseProviderConfig,
central_api_key_enabled: false,
allow_user_api_key: true,
allow_central_api_key_fallback: true,
};
expect(normalizeProviderPolicyDefaults(providerConfig)).toEqual(
providerConfig,
);
});
it("defaults omitted policy fields to the expected values", () => {
expect(normalizeProviderPolicyDefaults(baseProviderConfig)).toMatchObject({
central_api_key_enabled: true,
allow_user_api_key: false,
allow_central_api_key_fallback: false,
});
});
it("defaults undefined policy fields to the expected values", () => {
const providerConfig: ProviderConfigWithOptionalPolicyFields = {
...baseProviderConfig,
central_api_key_enabled: undefined,
allow_user_api_key: undefined,
allow_central_api_key_fallback: undefined,
};
expect(normalizeProviderPolicyDefaults(providerConfig)).toMatchObject({
central_api_key_enabled: true,
allow_user_api_key: false,
allow_central_api_key_fallback: false,
});
});
});
@@ -1,26 +0,0 @@
import type * as TypesGen from "#/api/typesGenerated";
type ProviderPolicyFields = Pick<
TypesGen.ChatProviderConfig,
| "central_api_key_enabled"
| "allow_user_api_key"
| "allow_central_api_key_fallback"
>;
export type ProviderConfigWithOptionalPolicyFields = Omit<
TypesGen.ChatProviderConfig,
keyof ProviderPolicyFields
> &
Partial<ProviderPolicyFields>;
export function normalizeProviderPolicyDefaults(
providerConfig: ProviderConfigWithOptionalPolicyFields,
): TypesGen.ChatProviderConfig {
return {
...providerConfig,
central_api_key_enabled: providerConfig.central_api_key_enabled ?? true,
allow_user_api_key: providerConfig.allow_user_api_key ?? false,
allow_central_api_key_fallback:
providerConfig.allow_central_api_key_fallback ?? false,
};
}
@@ -223,6 +223,7 @@ describe("countConfiguredProviderConfigs", () => {
describe("formatProviderLabel", () => {
it("formats OpenAI compatible providers", () => {
expect(formatProviderLabel("openai-compat")).toBe("OpenAI-compatible");
expect(formatProviderLabel("openai-compatible")).toBe("OpenAI-compatible");
});
});
@@ -220,33 +220,7 @@ export const getProviderForModelOption = (
): string | undefined =>
modelOptions.find((option) => option.id === selectedModel)?.provider;
export const formatProviderLabel = (provider: string): string => {
const normalized = provider.trim().toLowerCase();
switch (normalized) {
case "openai":
return "OpenAI";
case "anthropic":
return "Anthropic";
case "azure":
return "Azure OpenAI";
case "bedrock":
return "AWS Bedrock";
case "google":
return "Google";
case "openai-compatible":
case "openai_compatible":
return "OpenAI-compatible";
case "openrouter":
return "OpenRouter";
case "vercel":
return "Vercel AI Gateway";
default:
if (!normalized) {
return "Unknown";
}
return `${normalized[0].toUpperCase()}${normalized.slice(1)}`;
}
};
export { formatProviderLabel } from "#/utils/aiProviders";
export const getModelSelectorPlaceholder = (
modelOptions: readonly ModelSelectorOption[],
+28
View File
@@ -0,0 +1,28 @@
export const formatProviderLabel = (provider: string): string => {
const normalized = provider.trim().toLowerCase();
switch (normalized) {
case "openai":
return "OpenAI";
case "anthropic":
return "Anthropic";
case "azure":
return "Azure OpenAI";
case "bedrock":
return "AWS Bedrock";
case "google":
return "Google";
case "openai-compat":
case "openai-compatible":
case "openai_compatible":
return "OpenAI-compatible";
case "openrouter":
return "OpenRouter";
case "vercel":
return "Vercel AI Gateway";
default:
if (!normalized) {
return "Unknown";
}
return `${normalized[0].toUpperCase()}${normalized.slice(1)}`;
}
};