mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
feat: use AI provider chat APIs (#25415)
This commit is contained in:
@@ -157,6 +157,46 @@ func TestAIProvidersCRUD(t *testing.T) {
|
||||
require.Equal(t, "no-display", created.DisplayName)
|
||||
})
|
||||
|
||||
t.Run("RequiredBaseURL", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, nil)
|
||||
_ = coderdtest.CreateFirstUser(t, client)
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
|
||||
//nolint:gocritic // Owner role is the audience for this endpoint.
|
||||
_, err := client.CreateAIProvider(ctx, codersdk.CreateAIProviderRequest{
|
||||
Type: codersdk.AIProviderTypeOpenAI,
|
||||
Name: "missing-base-url",
|
||||
Enabled: true,
|
||||
})
|
||||
sdkErr := requireSDKError(t, err, http.StatusBadRequest)
|
||||
require.Equal(t, "Invalid AI provider request.", sdkErr.Message)
|
||||
require.Contains(t, sdkErr.Validations, codersdk.ValidationError{Field: "base_url", Detail: "base_url is required"})
|
||||
|
||||
created, err := client.CreateAIProvider(ctx, codersdk.CreateAIProviderRequest{
|
||||
Type: codersdk.AIProviderTypeOpenAI,
|
||||
Name: "required-base-url",
|
||||
Enabled: true,
|
||||
BaseURL: "https://api.openai.com/v1",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
baseURL := "https://proxy.example.com/v1"
|
||||
updated, err := client.UpdateAIProvider(ctx, created.Name, codersdk.UpdateAIProviderRequest{
|
||||
BaseURL: &baseURL,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, baseURL, updated.BaseURL)
|
||||
|
||||
baseURL = ""
|
||||
_, err = client.UpdateAIProvider(ctx, created.Name, codersdk.UpdateAIProviderRequest{
|
||||
BaseURL: &baseURL,
|
||||
})
|
||||
sdkErr = requireSDKError(t, err, http.StatusBadRequest)
|
||||
require.Equal(t, "Invalid AI provider request.", sdkErr.Message)
|
||||
require.Contains(t, sdkErr.Validations, codersdk.ValidationError{Field: "base_url", Detail: "base_url is required"})
|
||||
})
|
||||
|
||||
t.Run("DuplicateNameConflict", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client := coderdtest.New(t, nil)
|
||||
|
||||
@@ -1202,6 +1202,17 @@ func New(options *Options) *API {
|
||||
r.Delete("/", api.deleteUserSkill)
|
||||
})
|
||||
})
|
||||
r.Route("/users/{user}/ai-provider-keys", func(r chi.Router) {
|
||||
r.Use(
|
||||
apiKeyMiddleware,
|
||||
httpmw.ExtractUserParam(options.Database),
|
||||
)
|
||||
r.Get("/", api.listUserAIProviderKeyConfigs)
|
||||
r.Route("/{aiProvider}", func(r chi.Router) {
|
||||
r.Put("/", api.upsertUserAIProviderKey)
|
||||
r.Delete("/", api.deleteUserAIProviderKey)
|
||||
})
|
||||
})
|
||||
r.Route("/chats", func(r chi.Router) {
|
||||
r.Use(
|
||||
apiKeyMiddleware,
|
||||
|
||||
@@ -65,11 +65,24 @@ func CreateOpenAICompatChatModelConfig(
|
||||
BaseURL: baseURL,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
aiProviderBaseURL := baseURL
|
||||
if aiProviderBaseURL == "" {
|
||||
aiProviderBaseURL = "https://api.example.com/v1"
|
||||
}
|
||||
provider, err := client.CreateAIProvider(ctx, codersdk.CreateAIProviderRequest{
|
||||
Type: codersdk.AIProviderType(TestChatProviderOpenAICompat),
|
||||
Name: "test-" + uuid.NewString(),
|
||||
BaseURL: aiProviderBaseURL,
|
||||
Enabled: true,
|
||||
APIKeys: []string{TestChatProviderAPIKey},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
contextLimit := int64(4096)
|
||||
isDefault := true
|
||||
modelConfig, err := client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
|
||||
Provider: TestChatProviderOpenAICompat,
|
||||
AIProviderID: &provider.ID,
|
||||
Model: TestChatModelOpenAICompat,
|
||||
ContextLimit: &contextLimit,
|
||||
IsDefault: &isDefault,
|
||||
|
||||
@@ -722,6 +722,24 @@ var (
|
||||
}),
|
||||
Scope: rbac.ScopeAll,
|
||||
}.WithCachedASTValue()
|
||||
|
||||
subjectAIProviderMetadataReader = rbac.Subject{
|
||||
Type: rbac.SubjectTypeAIProviderMetadataReader,
|
||||
FriendlyName: "AI Provider Metadata Reader",
|
||||
ID: uuid.Nil.String(),
|
||||
Roles: rbac.Roles([]rbac.Role{
|
||||
{
|
||||
Identifier: rbac.RoleIdentifier{Name: "ai-provider-metadata-reader"},
|
||||
DisplayName: "AI Provider Metadata Reader",
|
||||
Site: rbac.Permissions(map[string][]policy.Action{
|
||||
rbac.ResourceAIProvider.Type: {policy.ActionRead},
|
||||
}),
|
||||
User: []rbac.Permission{},
|
||||
ByOrgID: map[string]rbac.OrgPermissions{},
|
||||
},
|
||||
}),
|
||||
Scope: rbac.ScopeAll,
|
||||
}.WithCachedASTValue()
|
||||
)
|
||||
|
||||
// AsProvisionerd returns a context with an actor that has permissions required
|
||||
@@ -846,6 +864,12 @@ func AsChatd(ctx context.Context) context.Context {
|
||||
return As(ctx, subjectChatd)
|
||||
}
|
||||
|
||||
// AsAIProviderMetadataReader returns a context with an actor that can read
|
||||
// AI provider metadata and provider-key presence.
|
||||
func AsAIProviderMetadataReader(ctx context.Context) context.Context {
|
||||
return As(ctx, subjectAIProviderMetadataReader)
|
||||
}
|
||||
|
||||
var AsRemoveActor = rbac.Subject{
|
||||
ID: "remove-actor",
|
||||
}
|
||||
@@ -2546,6 +2570,13 @@ func (q *querier) GetAIProviderByID(ctx context.Context, id uuid.UUID) (database
|
||||
return q.db.GetAIProviderByID(ctx, id)
|
||||
}
|
||||
|
||||
func (q *querier) GetAIProviderByIDForReferenceLock(ctx context.Context, id uuid.UUID) (database.AIProvider, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceAIProvider); err != nil {
|
||||
return database.AIProvider{}, err
|
||||
}
|
||||
return q.db.GetAIProviderByIDForReferenceLock(ctx, id)
|
||||
}
|
||||
|
||||
func (q *querier) GetAIProviderByName(ctx context.Context, name string) (database.AIProvider, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceAIProvider); err != nil {
|
||||
return database.AIProvider{}, err
|
||||
@@ -2560,6 +2591,13 @@ func (q *querier) GetAIProviderKeyByID(ctx context.Context, id uuid.UUID) (datab
|
||||
return q.db.GetAIProviderKeyByID(ctx, id)
|
||||
}
|
||||
|
||||
func (q *querier) GetAIProviderKeyPresence(ctx context.Context, arg []uuid.UUID) ([]uuid.UUID, error) {
|
||||
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceAIProvider); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return q.db.GetAIProviderKeyPresence(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) GetAIProviderKeys(ctx context.Context, includeDeleted bool) ([]database.AIProviderKey, error) {
|
||||
// Callers pass include_deleted=TRUE only from the dbcrypt key
|
||||
// rotation utility, which needs to re-encrypt every row that holds
|
||||
|
||||
@@ -6509,6 +6509,11 @@ func (s *MethodTestSuite) TestAIBridge() {
|
||||
dbm.EXPECT().GetAIProviderByID(gomock.Any(), provider.ID).Return(provider, nil).AnyTimes()
|
||||
check.Args(provider.ID).Asserts(rbac.ResourceAIProvider, policy.ActionRead).Returns(provider)
|
||||
}))
|
||||
s.Run("GetAIProviderByIDForReferenceLock", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
provider := testutil.Fake(s.T(), faker, database.AIProvider{})
|
||||
dbm.EXPECT().GetAIProviderByIDForReferenceLock(gomock.Any(), provider.ID).Return(provider, nil).AnyTimes()
|
||||
check.Args(provider.ID).Asserts(rbac.ResourceAIProvider, policy.ActionRead).Returns(provider)
|
||||
}))
|
||||
s.Run("GetAIProviderByName", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
provider := testutil.Fake(s.T(), faker, database.AIProvider{})
|
||||
dbm.EXPECT().GetAIProviderByName(gomock.Any(), provider.Name).Return(provider, nil).AnyTimes()
|
||||
@@ -6562,6 +6567,14 @@ func (s *MethodTestSuite) TestAIBridge() {
|
||||
dbm.EXPECT().GetAIProviderKeyByID(gomock.Any(), key.ID).Return(key, nil).AnyTimes()
|
||||
check.Args(key.ID).Asserts(rbac.ResourceAIProvider, policy.ActionRead).Returns(key)
|
||||
}))
|
||||
s.Run("GetAIProviderKeyPresence", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
providerA := testutil.Fake(s.T(), faker, database.AIProvider{})
|
||||
providerB := testutil.Fake(s.T(), faker, database.AIProvider{})
|
||||
arg := []uuid.UUID{providerA.ID, providerB.ID}
|
||||
providerIDs := []uuid.UUID{providerA.ID}
|
||||
dbm.EXPECT().GetAIProviderKeyPresence(gomock.Any(), arg).Return(providerIDs, nil).AnyTimes()
|
||||
check.Args(arg).Asserts(rbac.ResourceAIProvider, policy.ActionRead).Returns(providerIDs)
|
||||
}))
|
||||
s.Run("GetAIProviderKeysByProviderID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
|
||||
provider := testutil.Fake(s.T(), faker, database.AIProvider{})
|
||||
keyA := testutil.Fake(s.T(), faker, database.AIProviderKey{ProviderID: provider.ID})
|
||||
|
||||
@@ -160,6 +160,7 @@ func ChatModelConfig(t testing.TB, db database.Store, seed database.ChatModelCon
|
||||
ContextLimit: takeFirst(seed.ContextLimit, defaultChatModelContextLimit),
|
||||
CompressionThreshold: takeFirst(seed.CompressionThreshold, defaultChatModelCompressionThreshold),
|
||||
Options: takeFirstSlice(seed.Options, json.RawMessage(`{}`)),
|
||||
AIProviderID: seed.AIProviderID,
|
||||
}
|
||||
for _, fn := range munge {
|
||||
fn(¶ms)
|
||||
|
||||
@@ -1041,6 +1041,14 @@ func (m queryMetricsStore) GetAIProviderByID(ctx context.Context, id uuid.UUID)
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetAIProviderByIDForReferenceLock(ctx context.Context, id uuid.UUID) (database.AIProvider, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetAIProviderByIDForReferenceLock(ctx, id)
|
||||
m.queryLatencies.WithLabelValues("GetAIProviderByIDForReferenceLock").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetAIProviderByIDForReferenceLock").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetAIProviderByName(ctx context.Context, name string) (database.AIProvider, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetAIProviderByName(ctx, name)
|
||||
@@ -1057,6 +1065,14 @@ func (m queryMetricsStore) GetAIProviderKeyByID(ctx context.Context, id uuid.UUI
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetAIProviderKeyPresence(ctx context.Context, arg []uuid.UUID) ([]uuid.UUID, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetAIProviderKeyPresence(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("GetAIProviderKeyPresence").Observe(time.Since(start).Seconds())
|
||||
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetAIProviderKeyPresence").Inc()
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m queryMetricsStore) GetAIProviderKeys(ctx context.Context, includeDeleted bool) ([]database.AIProviderKey, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetAIProviderKeys(ctx, includeDeleted)
|
||||
|
||||
@@ -1799,6 +1799,21 @@ func (mr *MockStoreMockRecorder) GetAIProviderByID(ctx, id any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAIProviderByID", reflect.TypeOf((*MockStore)(nil).GetAIProviderByID), ctx, id)
|
||||
}
|
||||
|
||||
// GetAIProviderByIDForReferenceLock mocks base method.
|
||||
func (m *MockStore) GetAIProviderByIDForReferenceLock(ctx context.Context, id uuid.UUID) (database.AIProvider, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetAIProviderByIDForReferenceLock", ctx, id)
|
||||
ret0, _ := ret[0].(database.AIProvider)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetAIProviderByIDForReferenceLock indicates an expected call of GetAIProviderByIDForReferenceLock.
|
||||
func (mr *MockStoreMockRecorder) GetAIProviderByIDForReferenceLock(ctx, id any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAIProviderByIDForReferenceLock", reflect.TypeOf((*MockStore)(nil).GetAIProviderByIDForReferenceLock), ctx, id)
|
||||
}
|
||||
|
||||
// GetAIProviderByName mocks base method.
|
||||
func (m *MockStore) GetAIProviderByName(ctx context.Context, name string) (database.AIProvider, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -1829,6 +1844,21 @@ func (mr *MockStoreMockRecorder) GetAIProviderKeyByID(ctx, id any) *gomock.Call
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAIProviderKeyByID", reflect.TypeOf((*MockStore)(nil).GetAIProviderKeyByID), ctx, id)
|
||||
}
|
||||
|
||||
// GetAIProviderKeyPresence mocks base method.
|
||||
func (m *MockStore) GetAIProviderKeyPresence(ctx context.Context, providerIds []uuid.UUID) ([]uuid.UUID, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetAIProviderKeyPresence", ctx, providerIds)
|
||||
ret0, _ := ret[0].([]uuid.UUID)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetAIProviderKeyPresence indicates an expected call of GetAIProviderKeyPresence.
|
||||
func (mr *MockStoreMockRecorder) GetAIProviderKeyPresence(ctx, providerIds any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAIProviderKeyPresence", reflect.TypeOf((*MockStore)(nil).GetAIProviderKeyPresence), ctx, providerIds)
|
||||
}
|
||||
|
||||
// GetAIProviderKeys mocks base method.
|
||||
func (m *MockStore) GetAIProviderKeys(ctx context.Context, includeDeleted bool) ([]database.AIProviderKey, error) {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
@@ -253,8 +253,14 @@ type sqlcQuerier interface {
|
||||
GetAIBridgeUserPromptsByInterceptionID(ctx context.Context, interceptionID uuid.UUID) ([]AIBridgeUserPrompt, error)
|
||||
GetAIModelPriceByProviderModel(ctx context.Context, arg GetAIModelPriceByProviderModelParams) (AiModelPrice, error)
|
||||
GetAIProviderByID(ctx context.Context, id uuid.UUID) (AIProvider, error)
|
||||
// Lock the provider row until the model-config write completes. The
|
||||
// transaction alone does not stop a concurrent soft-delete or disable
|
||||
// between validation and writing the model config reference.
|
||||
GetAIProviderByIDForReferenceLock(ctx context.Context, id uuid.UUID) (AIProvider, error)
|
||||
GetAIProviderByName(ctx context.Context, name string) (AIProvider, error)
|
||||
GetAIProviderKeyByID(ctx context.Context, id uuid.UUID) (AIProviderKey, error)
|
||||
// Returns the provider IDs that have at least one provider-scoped key.
|
||||
GetAIProviderKeyPresence(ctx context.Context, providerIds []uuid.UUID) ([]uuid.UUID, error)
|
||||
// Returns AI provider key rows. By default, only rows whose parent
|
||||
// provider is live (deleted = FALSE) are returned, so the API list
|
||||
// handler can fetch every visible provider's keys in a single query.
|
||||
|
||||
@@ -10567,6 +10567,77 @@ func TestInsertWorkspaceAgentDevcontainers(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetEnabledChatModelConfigsUsesAIProviders(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
store, _ := dbtestutil.NewDB(t)
|
||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||
|
||||
enabledProvider := dbgen.AIProvider(t, store, database.AIProvider{
|
||||
Type: database.AiProviderTypeOpenrouter,
|
||||
Name: "openrouter-" + uuid.NewString(),
|
||||
})
|
||||
disabledProvider := dbgen.AIProvider(t, store, database.AIProvider{
|
||||
Type: database.AiProviderTypeVercel,
|
||||
Name: "vercel-" + uuid.NewString(),
|
||||
}, func(params *database.InsertAIProviderParams) {
|
||||
params.Enabled = false
|
||||
})
|
||||
enabledConfig := dbgen.ChatModelConfig(t, store, database.ChatModelConfig{
|
||||
Provider: string(enabledProvider.Type),
|
||||
Model: "openrouter-model-" + uuid.NewString(),
|
||||
AIProviderID: uuid.NullUUID{
|
||||
UUID: enabledProvider.ID,
|
||||
Valid: true,
|
||||
},
|
||||
})
|
||||
disabledProviderConfig := dbgen.ChatModelConfig(t, store, database.ChatModelConfig{
|
||||
Provider: string(disabledProvider.Type),
|
||||
Model: "vercel-model-" + uuid.NewString(),
|
||||
AIProviderID: uuid.NullUUID{
|
||||
UUID: disabledProvider.ID,
|
||||
Valid: true,
|
||||
},
|
||||
})
|
||||
disabledModelConfig := dbgen.ChatModelConfig(t, store, database.ChatModelConfig{
|
||||
Provider: string(enabledProvider.Type),
|
||||
Model: "disabled-model-" + uuid.NewString(),
|
||||
AIProviderID: uuid.NullUUID{
|
||||
UUID: enabledProvider.ID,
|
||||
Valid: true,
|
||||
},
|
||||
}, func(params *database.InsertChatModelConfigParams) {
|
||||
params.Enabled = false
|
||||
})
|
||||
legacyProvider := dbgen.ChatProvider(t, store, database.ChatProvider{Provider: "google"})
|
||||
legacyConfig := dbgen.ChatModelConfig(t, store, database.ChatModelConfig{
|
||||
Provider: legacyProvider.Provider,
|
||||
Model: "google-model-" + uuid.NewString(),
|
||||
})
|
||||
|
||||
configs, err := store.GetEnabledChatModelConfigs(ctx)
|
||||
require.NoError(t, err)
|
||||
require.True(t, slices.ContainsFunc(configs, func(config database.ChatModelConfig) bool {
|
||||
return config.ID == enabledConfig.ID
|
||||
}))
|
||||
require.True(t, slices.ContainsFunc(configs, func(config database.ChatModelConfig) bool {
|
||||
return config.ID == legacyConfig.ID
|
||||
}))
|
||||
require.False(t, slices.ContainsFunc(configs, func(config database.ChatModelConfig) bool {
|
||||
return config.ID == disabledProviderConfig.ID
|
||||
}))
|
||||
require.False(t, slices.ContainsFunc(configs, func(config database.ChatModelConfig) bool {
|
||||
return config.ID == disabledModelConfig.ID
|
||||
}))
|
||||
|
||||
config, err := store.GetEnabledChatModelConfigByID(ctx, enabledConfig.ID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, enabledConfig.ID, config.ID)
|
||||
|
||||
_, err = store.GetEnabledChatModelConfigByID(ctx, disabledProviderConfig.ID)
|
||||
require.ErrorIs(t, err, sql.ErrNoRows)
|
||||
}
|
||||
|
||||
func TestInsertChatMessages(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
@@ -146,6 +146,41 @@ func (q *sqlQuerier) GetAIProviderKeyByID(ctx context.Context, id uuid.UUID) (AI
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getAIProviderKeyPresence = `-- name: GetAIProviderKeyPresence :many
|
||||
SELECT DISTINCT
|
||||
provider_id
|
||||
FROM
|
||||
ai_provider_keys
|
||||
WHERE
|
||||
provider_id = ANY($1::uuid[])
|
||||
ORDER BY
|
||||
provider_id ASC
|
||||
`
|
||||
|
||||
// Returns the provider IDs that have at least one provider-scoped key.
|
||||
func (q *sqlQuerier) GetAIProviderKeyPresence(ctx context.Context, providerIds []uuid.UUID) ([]uuid.UUID, error) {
|
||||
rows, err := q.db.QueryContext(ctx, getAIProviderKeyPresence, pq.Array(providerIds))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []uuid.UUID
|
||||
for rows.Next() {
|
||||
var provider_id uuid.UUID
|
||||
if err := rows.Scan(&provider_id); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, provider_id)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const getAIProviderKeys = `-- name: GetAIProviderKeys :many
|
||||
SELECT
|
||||
ai_provider_keys.id, ai_provider_keys.provider_id, ai_provider_keys.api_key, ai_provider_keys.api_key_key_id, ai_provider_keys.created_at, ai_provider_keys.updated_at
|
||||
@@ -371,6 +406,38 @@ func (q *sqlQuerier) GetAIProviderByID(ctx context.Context, id uuid.UUID) (AIPro
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getAIProviderByIDForReferenceLock = `-- name: GetAIProviderByIDForReferenceLock :one
|
||||
SELECT
|
||||
id, type, name, display_name, enabled, deleted, base_url, settings, settings_key_id, created_at, updated_at
|
||||
FROM
|
||||
ai_providers
|
||||
WHERE
|
||||
id = $1::uuid AND deleted = FALSE
|
||||
FOR SHARE
|
||||
`
|
||||
|
||||
// Lock the provider row until the model-config write completes. The
|
||||
// transaction alone does not stop a concurrent soft-delete or disable
|
||||
// between validation and writing the model config reference.
|
||||
func (q *sqlQuerier) GetAIProviderByIDForReferenceLock(ctx context.Context, id uuid.UUID) (AIProvider, error) {
|
||||
row := q.db.QueryRowContext(ctx, getAIProviderByIDForReferenceLock, id)
|
||||
var i AIProvider
|
||||
err := row.Scan(
|
||||
&i.ID,
|
||||
&i.Type,
|
||||
&i.Name,
|
||||
&i.DisplayName,
|
||||
&i.Enabled,
|
||||
&i.Deleted,
|
||||
&i.BaseUrl,
|
||||
&i.Settings,
|
||||
&i.SettingsKeyID,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getAIProviderByName = `-- name: GetAIProviderByName :one
|
||||
SELECT
|
||||
id, type, name, display_name, enabled, deleted, base_url, settings, settings_key_id, created_at, updated_at
|
||||
@@ -5097,13 +5164,18 @@ SELECT
|
||||
cmc.id, cmc.provider, cmc.model, cmc.display_name, cmc.created_by, cmc.updated_by, cmc.enabled, cmc.is_default, cmc.deleted, cmc.deleted_at, cmc.created_at, cmc.updated_at, cmc.context_limit, cmc.compression_threshold, cmc.options, cmc.ai_provider_id
|
||||
FROM
|
||||
chat_model_configs cmc
|
||||
JOIN
|
||||
chat_providers cp ON cp.provider = cmc.provider
|
||||
LEFT JOIN
|
||||
ai_providers ap ON ap.id = cmc.ai_provider_id
|
||||
LEFT JOIN
|
||||
chat_providers cp ON cp.provider = cmc.provider AND cmc.ai_provider_id IS NULL
|
||||
WHERE
|
||||
cmc.id = $1::uuid
|
||||
AND cmc.deleted = FALSE
|
||||
AND cmc.enabled = TRUE
|
||||
AND cp.enabled = TRUE
|
||||
AND (
|
||||
(cmc.ai_provider_id IS NOT NULL AND ap.enabled = TRUE AND ap.deleted = FALSE)
|
||||
OR (cmc.ai_provider_id IS NULL AND cp.enabled = TRUE)
|
||||
)
|
||||
`
|
||||
|
||||
// Providers can be disabled independently of their model configs.
|
||||
@@ -5137,12 +5209,17 @@ SELECT
|
||||
cmc.id, cmc.provider, cmc.model, cmc.display_name, cmc.created_by, cmc.updated_by, cmc.enabled, cmc.is_default, cmc.deleted, cmc.deleted_at, cmc.created_at, cmc.updated_at, cmc.context_limit, cmc.compression_threshold, cmc.options, cmc.ai_provider_id
|
||||
FROM
|
||||
chat_model_configs cmc
|
||||
JOIN
|
||||
chat_providers cp ON cp.provider = cmc.provider
|
||||
LEFT JOIN
|
||||
ai_providers ap ON ap.id = cmc.ai_provider_id
|
||||
LEFT JOIN
|
||||
chat_providers cp ON cp.provider = cmc.provider AND cmc.ai_provider_id IS NULL
|
||||
WHERE
|
||||
cmc.enabled = TRUE
|
||||
AND cmc.deleted = FALSE
|
||||
AND cp.enabled = TRUE
|
||||
AND (
|
||||
(cmc.ai_provider_id IS NOT NULL AND ap.enabled = TRUE AND ap.deleted = FALSE)
|
||||
OR (cmc.ai_provider_id IS NULL AND cp.enabled = TRUE)
|
||||
)
|
||||
ORDER BY
|
||||
cmc.provider ASC,
|
||||
cmc.model ASC,
|
||||
@@ -5201,7 +5278,8 @@ INSERT INTO chat_model_configs (
|
||||
is_default,
|
||||
context_limit,
|
||||
compression_threshold,
|
||||
options
|
||||
options,
|
||||
ai_provider_id
|
||||
) VALUES (
|
||||
$1::text,
|
||||
$2::text,
|
||||
@@ -5212,7 +5290,8 @@ INSERT INTO chat_model_configs (
|
||||
$7::boolean,
|
||||
$8::bigint,
|
||||
$9::integer,
|
||||
$10::jsonb
|
||||
$10::jsonb,
|
||||
$11::uuid
|
||||
)
|
||||
RETURNING
|
||||
id, provider, model, display_name, created_by, updated_by, enabled, is_default, deleted, deleted_at, created_at, updated_at, context_limit, compression_threshold, options, ai_provider_id
|
||||
@@ -5229,6 +5308,7 @@ type InsertChatModelConfigParams struct {
|
||||
ContextLimit int64 `db:"context_limit" json:"context_limit"`
|
||||
CompressionThreshold int32 `db:"compression_threshold" json:"compression_threshold"`
|
||||
Options json.RawMessage `db:"options" json:"options"`
|
||||
AIProviderID uuid.NullUUID `db:"ai_provider_id" json:"ai_provider_id"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) InsertChatModelConfig(ctx context.Context, arg InsertChatModelConfigParams) (ChatModelConfig, error) {
|
||||
@@ -5243,6 +5323,7 @@ func (q *sqlQuerier) InsertChatModelConfig(ctx context.Context, arg InsertChatMo
|
||||
arg.ContextLimit,
|
||||
arg.CompressionThreshold,
|
||||
arg.Options,
|
||||
arg.AIProviderID,
|
||||
)
|
||||
var i ChatModelConfig
|
||||
err := row.Scan(
|
||||
@@ -5295,9 +5376,10 @@ SET
|
||||
context_limit = $7::bigint,
|
||||
compression_threshold = $8::integer,
|
||||
options = $9::jsonb,
|
||||
ai_provider_id = $10::uuid,
|
||||
updated_at = NOW()
|
||||
WHERE
|
||||
id = $10::uuid
|
||||
id = $11::uuid
|
||||
AND deleted = FALSE
|
||||
RETURNING
|
||||
id, provider, model, display_name, created_by, updated_by, enabled, is_default, deleted, deleted_at, created_at, updated_at, context_limit, compression_threshold, options, ai_provider_id
|
||||
@@ -5313,6 +5395,7 @@ type UpdateChatModelConfigParams struct {
|
||||
ContextLimit int64 `db:"context_limit" json:"context_limit"`
|
||||
CompressionThreshold int32 `db:"compression_threshold" json:"compression_threshold"`
|
||||
Options json.RawMessage `db:"options" json:"options"`
|
||||
AIProviderID uuid.NullUUID `db:"ai_provider_id" json:"ai_provider_id"`
|
||||
ID uuid.UUID `db:"id" json:"id"`
|
||||
}
|
||||
|
||||
@@ -5327,6 +5410,7 @@ func (q *sqlQuerier) UpdateChatModelConfig(ctx context.Context, arg UpdateChatMo
|
||||
arg.ContextLimit,
|
||||
arg.CompressionThreshold,
|
||||
arg.Options,
|
||||
arg.AIProviderID,
|
||||
arg.ID,
|
||||
)
|
||||
var i ChatModelConfig
|
||||
|
||||
@@ -21,6 +21,17 @@ ORDER BY
|
||||
created_at ASC,
|
||||
id ASC;
|
||||
|
||||
-- name: GetAIProviderKeyPresence :many
|
||||
-- Returns the provider IDs that have at least one provider-scoped key.
|
||||
SELECT DISTINCT
|
||||
provider_id
|
||||
FROM
|
||||
ai_provider_keys
|
||||
WHERE
|
||||
provider_id = ANY(@provider_ids::uuid[])
|
||||
ORDER BY
|
||||
provider_id ASC;
|
||||
|
||||
-- name: GetAIProviderKeys :many
|
||||
-- Returns AI provider key rows. By default, only rows whose parent
|
||||
-- provider is live (deleted = FALSE) are returned, so the API list
|
||||
|
||||
@@ -6,6 +6,18 @@ FROM
|
||||
WHERE
|
||||
id = @id::uuid AND deleted = FALSE;
|
||||
|
||||
-- name: GetAIProviderByIDForReferenceLock :one
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
ai_providers
|
||||
WHERE
|
||||
id = @id::uuid AND deleted = FALSE
|
||||
-- Lock the provider row until the model-config write completes. The
|
||||
-- transaction alone does not stop a concurrent soft-delete or disable
|
||||
-- between validation and writing the model config reference.
|
||||
FOR SHARE;
|
||||
|
||||
-- name: GetAIProviderByName :one
|
||||
SELECT
|
||||
*
|
||||
|
||||
@@ -34,12 +34,17 @@ SELECT
|
||||
cmc.*
|
||||
FROM
|
||||
chat_model_configs cmc
|
||||
JOIN
|
||||
chat_providers cp ON cp.provider = cmc.provider
|
||||
LEFT JOIN
|
||||
ai_providers ap ON ap.id = cmc.ai_provider_id
|
||||
LEFT JOIN
|
||||
chat_providers cp ON cp.provider = cmc.provider AND cmc.ai_provider_id IS NULL
|
||||
WHERE
|
||||
cmc.enabled = TRUE
|
||||
AND cmc.deleted = FALSE
|
||||
AND cp.enabled = TRUE
|
||||
AND (
|
||||
(cmc.ai_provider_id IS NOT NULL AND ap.enabled = TRUE AND ap.deleted = FALSE)
|
||||
OR (cmc.ai_provider_id IS NULL AND cp.enabled = TRUE)
|
||||
)
|
||||
ORDER BY
|
||||
cmc.provider ASC,
|
||||
cmc.model ASC,
|
||||
@@ -53,13 +58,18 @@ FROM
|
||||
chat_model_configs cmc
|
||||
-- Providers can be disabled independently of their model configs.
|
||||
-- Check both to ensure the selected config is actually usable.
|
||||
JOIN
|
||||
chat_providers cp ON cp.provider = cmc.provider
|
||||
LEFT JOIN
|
||||
ai_providers ap ON ap.id = cmc.ai_provider_id
|
||||
LEFT JOIN
|
||||
chat_providers cp ON cp.provider = cmc.provider AND cmc.ai_provider_id IS NULL
|
||||
WHERE
|
||||
cmc.id = @id::uuid
|
||||
AND cmc.deleted = FALSE
|
||||
AND cmc.enabled = TRUE
|
||||
AND cp.enabled = TRUE;
|
||||
AND (
|
||||
(cmc.ai_provider_id IS NOT NULL AND ap.enabled = TRUE AND ap.deleted = FALSE)
|
||||
OR (cmc.ai_provider_id IS NULL AND cp.enabled = TRUE)
|
||||
);
|
||||
|
||||
-- name: InsertChatModelConfig :one
|
||||
INSERT INTO chat_model_configs (
|
||||
@@ -72,7 +82,8 @@ INSERT INTO chat_model_configs (
|
||||
is_default,
|
||||
context_limit,
|
||||
compression_threshold,
|
||||
options
|
||||
options,
|
||||
ai_provider_id
|
||||
) VALUES (
|
||||
@provider::text,
|
||||
@model::text,
|
||||
@@ -83,7 +94,8 @@ INSERT INTO chat_model_configs (
|
||||
@is_default::boolean,
|
||||
@context_limit::bigint,
|
||||
@compression_threshold::integer,
|
||||
@options::jsonb
|
||||
@options::jsonb,
|
||||
sqlc.narg('ai_provider_id')::uuid
|
||||
)
|
||||
RETURNING
|
||||
*;
|
||||
@@ -101,6 +113,7 @@ SET
|
||||
context_limit = @context_limit::bigint,
|
||||
compression_threshold = @compression_threshold::integer,
|
||||
options = @options::jsonb,
|
||||
ai_provider_id = sqlc.narg('ai_provider_id')::uuid,
|
||||
updated_at = NOW()
|
||||
WHERE
|
||||
id = @id::uuid
|
||||
|
||||
+265
-16
@@ -6396,6 +6396,179 @@ func convertChatMessages(messages []database.ChatMessage) []codersdk.ChatMessage
|
||||
return result
|
||||
}
|
||||
|
||||
func parseUserAIProviderID(r *http.Request) (uuid.UUID, error) {
|
||||
return uuid.Parse(chi.URLParam(r, "aiProvider"))
|
||||
}
|
||||
|
||||
func convertAIProviderSummary(provider database.AIProvider) codersdk.AIProviderSummary {
|
||||
displayName := provider.Name
|
||||
if provider.DisplayName.Valid && provider.DisplayName.String != "" {
|
||||
displayName = provider.DisplayName.String
|
||||
}
|
||||
return codersdk.AIProviderSummary{
|
||||
ID: provider.ID,
|
||||
Type: codersdk.AIProviderType(provider.Type),
|
||||
Name: provider.Name,
|
||||
DisplayName: displayName,
|
||||
Enabled: provider.Enabled,
|
||||
Deleted: provider.Deleted,
|
||||
}
|
||||
}
|
||||
|
||||
func (api *API) listUserAIProviderKeyConfigs(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
targetUser := httpmw.UserParam(r)
|
||||
//nolint:gocritic // Users can list limited provider metadata to manage their own AI provider keys.
|
||||
metadataCtx := dbauthz.AsAIProviderMetadataReader(ctx)
|
||||
providers, err := api.Database.GetAIProviders(metadataCtx, database.GetAIProvidersParams{IncludeDisabled: true})
|
||||
if err != nil {
|
||||
api.Logger.Error(ctx, "failed to list user AI provider configs", slog.Error(err), slog.F("user_id", targetUser.ID))
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{Message: "Failed to list AI providers."})
|
||||
return
|
||||
}
|
||||
keys, err := api.Database.GetUserAIProviderKeysByUserID(ctx, targetUser.ID)
|
||||
if err != nil {
|
||||
api.Logger.Error(ctx, "failed to list user AI provider keys", slog.Error(err), slog.F("user_id", targetUser.ID))
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{Message: "Failed to list user AI provider keys."})
|
||||
return
|
||||
}
|
||||
|
||||
keysByProviderID := make(map[uuid.UUID]struct{}, len(keys))
|
||||
for _, key := range keys {
|
||||
keysByProviderID[key.AIProviderID] = struct{}{}
|
||||
}
|
||||
|
||||
visibleProviders := make([]database.AIProvider, 0, len(providers))
|
||||
visibleProviderIDs := make([]uuid.UUID, 0, len(providers))
|
||||
for _, provider := range providers {
|
||||
_, hasUserKey := keysByProviderID[provider.ID]
|
||||
if !provider.Enabled && !hasUserKey {
|
||||
continue
|
||||
}
|
||||
visibleProviders = append(visibleProviders, provider)
|
||||
visibleProviderIDs = append(visibleProviderIDs, provider.ID)
|
||||
}
|
||||
|
||||
providerKeysByProviderID := make(map[uuid.UUID]struct{}, len(visibleProviderIDs))
|
||||
if len(visibleProviderIDs) > 0 {
|
||||
providerKeyIDs, err := api.Database.GetAIProviderKeyPresence(metadataCtx, visibleProviderIDs)
|
||||
if err != nil {
|
||||
api.Logger.Error(ctx, "failed to list AI provider key presence", slog.Error(err), slog.F("user_id", targetUser.ID))
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{Message: "Failed to list AI provider keys."})
|
||||
return
|
||||
}
|
||||
for _, providerID := range providerKeyIDs {
|
||||
providerKeysByProviderID[providerID] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
byokEnabled := api.DeploymentValues.AI.BridgeConfig.AllowBYOK.Value()
|
||||
configs := make([]codersdk.UserAIProviderKeyConfig, 0, len(visibleProviders))
|
||||
for _, provider := range visibleProviders {
|
||||
_, hasUserKey := keysByProviderID[provider.ID]
|
||||
_, hasProviderKey := providerKeysByProviderID[provider.ID]
|
||||
configs = append(configs, codersdk.UserAIProviderKeyConfig{
|
||||
Provider: convertAIProviderSummary(provider),
|
||||
HasUserAPIKey: hasUserKey,
|
||||
HasProviderAPIKey: hasProviderKey,
|
||||
BYOKEnabled: byokEnabled,
|
||||
})
|
||||
}
|
||||
httpapi.Write(ctx, rw, http.StatusOK, configs)
|
||||
}
|
||||
|
||||
func (api *API) upsertUserAIProviderKey(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
if !api.DeploymentValues.AI.BridgeConfig.AllowBYOK.Value() {
|
||||
httpapi.Write(ctx, rw, http.StatusForbidden, codersdk.Response{Message: "BYOK is disabled."})
|
||||
return
|
||||
}
|
||||
targetUser := httpmw.UserParam(r)
|
||||
providerID, err := parseUserAIProviderID(r)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{Message: "Invalid AI provider ID."})
|
||||
return
|
||||
}
|
||||
//nolint:gocritic // Users can attach their own key to an enabled provider without AI provider admin permissions.
|
||||
metadataCtx := dbauthz.AsAIProviderMetadataReader(ctx)
|
||||
provider, err := api.Database.GetAIProviderByID(metadataCtx, providerID)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{Message: "AI provider not found."})
|
||||
return
|
||||
}
|
||||
api.Logger.Error(ctx, "failed to get AI provider", slog.Error(err), slog.F("ai_provider_id", providerID))
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{Message: "Failed to get AI provider."})
|
||||
return
|
||||
}
|
||||
if !provider.Enabled {
|
||||
httpapi.Write(ctx, rw, http.StatusPreconditionFailed, codersdk.Response{Message: "AI provider is disabled."})
|
||||
return
|
||||
}
|
||||
var req codersdk.CreateUserAIProviderKeyRequest
|
||||
if !httpapi.Read(ctx, rw, r, &req) {
|
||||
return
|
||||
}
|
||||
if err := validateChatProviderAPIKeySize(req.APIKey); err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "API key too large.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
if req.APIKey == "" {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{Message: "API key is required."})
|
||||
return
|
||||
}
|
||||
if strings.TrimSpace(req.APIKey) != req.APIKey {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{Message: "API key must not contain leading or trailing whitespace."})
|
||||
return
|
||||
}
|
||||
providerKeys, err := api.Database.GetAIProviderKeyPresence(metadataCtx, []uuid.UUID{providerID})
|
||||
if err != nil {
|
||||
api.Logger.Error(ctx, "failed to list AI provider key presence", slog.Error(err), slog.F("ai_provider_id", providerID))
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{Message: "Failed to list AI provider keys."})
|
||||
return
|
||||
}
|
||||
now := api.Clock.Now()
|
||||
_, err = api.Database.UpsertUserAIProviderKey(ctx, database.UpsertUserAIProviderKeyParams{
|
||||
ID: uuid.New(),
|
||||
UserID: targetUser.ID,
|
||||
AIProviderID: providerID,
|
||||
APIKey: req.APIKey,
|
||||
ApiKeyKeyID: sql.NullString{},
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
})
|
||||
if err != nil {
|
||||
api.Logger.Error(ctx, "failed to update user AI provider key", slog.Error(err), slog.F("user_id", targetUser.ID), slog.F("ai_provider_id", providerID))
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{Message: "Failed to update user AI provider key."})
|
||||
return
|
||||
}
|
||||
httpapi.Write(ctx, rw, http.StatusOK, codersdk.UserAIProviderKeyConfig{
|
||||
Provider: convertAIProviderSummary(provider),
|
||||
HasUserAPIKey: true,
|
||||
HasProviderAPIKey: len(providerKeys) > 0,
|
||||
BYOKEnabled: true,
|
||||
})
|
||||
}
|
||||
|
||||
func (api *API) deleteUserAIProviderKey(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
targetUser := httpmw.UserParam(r)
|
||||
providerID, err := parseUserAIProviderID(r)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{Message: "Invalid AI provider ID."})
|
||||
return
|
||||
}
|
||||
if err := api.Database.DeleteUserAIProviderKey(ctx, database.DeleteUserAIProviderKeyParams{UserID: targetUser.ID, AIProviderID: providerID}); err != nil {
|
||||
api.Logger.Error(ctx, "failed to delete user AI provider key", slog.Error(err), slog.F("user_id", targetUser.ID), slog.F("ai_provider_id", providerID))
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{Message: "Failed to delete user AI provider key."})
|
||||
return
|
||||
}
|
||||
httpapi.Write(ctx, rw, http.StatusNoContent, nil)
|
||||
}
|
||||
|
||||
func (api *API) listChatProviders(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
//nolint:gocritic // System context required to read enabled chat providers.
|
||||
@@ -6890,6 +7063,7 @@ func (api *API) listUserChatProviderConfigs(rw http.ResponseWriter, r *http.Requ
|
||||
provider,
|
||||
hasUserAPIKey,
|
||||
hasCentralAPIKeyFallback,
|
||||
api.DeploymentValues.AI.BridgeConfig.AllowBYOK.Value(),
|
||||
),
|
||||
)
|
||||
}
|
||||
@@ -6978,6 +7152,7 @@ func (api *API) upsertUserChatProviderKey(rw http.ResponseWriter, r *http.Reques
|
||||
provider,
|
||||
true,
|
||||
hasCentralAPIKeyFallback,
|
||||
api.DeploymentValues.AI.BridgeConfig.AllowBYOK.Value(),
|
||||
),
|
||||
)
|
||||
}
|
||||
@@ -7052,14 +7227,29 @@ func (api *API) createChatModelConfig(rw http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
provider := normalizeChatProvider(req.Provider)
|
||||
if provider == "" {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Invalid provider.",
|
||||
Detail: chatProviderValidationDetail(),
|
||||
if req.AIProviderID == nil {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{Message: "AI provider ID is required."})
|
||||
return
|
||||
}
|
||||
//nolint:gocritic // The route already authorized chat model config updates.
|
||||
aiProvider, err := api.Database.GetAIProviderByID(dbauthz.AsChatd(ctx), *req.AIProviderID)
|
||||
if err != nil {
|
||||
if httpapi.Is404Error(err) {
|
||||
httpapi.Write(ctx, rw, http.StatusPreconditionFailed, codersdk.Response{Message: "AI provider is not configured."})
|
||||
return
|
||||
}
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to get AI provider.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
if !aiProvider.Enabled {
|
||||
httpapi.Write(ctx, rw, http.StatusPreconditionFailed, codersdk.Response{Message: "AI provider is disabled."})
|
||||
return
|
||||
}
|
||||
provider := string(aiProvider.Type)
|
||||
aiProviderID := uuid.NullUUID{UUID: aiProvider.ID, Valid: true}
|
||||
|
||||
model := strings.TrimSpace(req.Model)
|
||||
if model == "" {
|
||||
@@ -7117,15 +7307,25 @@ func (api *API) createChatModelConfig(rw http.ResponseWriter, r *http.Request) {
|
||||
ContextLimit: contextLimit,
|
||||
CompressionThreshold: compressionThreshold,
|
||||
Options: modelConfigRaw,
|
||||
AIProviderID: aiProviderID,
|
||||
CreatedBy: uuid.NullUUID{UUID: apiKey.UserID, Valid: apiKey.UserID != uuid.Nil},
|
||||
UpdatedBy: uuid.NullUUID{UUID: apiKey.UserID, Valid: apiKey.UserID != uuid.Nil},
|
||||
}
|
||||
|
||||
var inserted database.ChatModelConfig
|
||||
err := api.Database.InTx(func(tx database.Store) error {
|
||||
if err := requireChatProviderForModelConfig(ctx, tx, insertParams.Provider); err != nil {
|
||||
return err
|
||||
err = api.Database.InTx(func(tx database.Store) error {
|
||||
//nolint:gocritic // The route already authorized chat model config updates.
|
||||
lockedAIProvider, err := tx.GetAIProviderByIDForReferenceLock(dbauthz.AsChatd(ctx), insertParams.AIProviderID.UUID)
|
||||
if err != nil {
|
||||
if xerrors.Is(err, sql.ErrNoRows) {
|
||||
return errChatProviderNotConfigured
|
||||
}
|
||||
return xerrors.Errorf("get AI provider for update: %w", err)
|
||||
}
|
||||
if !lockedAIProvider.Enabled {
|
||||
return errChatProviderNotConfigured
|
||||
}
|
||||
insertParams.Provider = string(lockedAIProvider.Type)
|
||||
|
||||
insertAsDefault := isDefault
|
||||
if !insertAsDefault {
|
||||
@@ -7173,7 +7373,7 @@ func (api *API) createChatModelConfig(rw http.ResponseWriter, r *http.Request) {
|
||||
})
|
||||
return
|
||||
case xerrors.Is(err, errChatProviderNotConfigured):
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
httpapi.Write(ctx, rw, http.StatusPreconditionFailed, codersdk.Response{
|
||||
Message: "Chat provider is not configured.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
@@ -7224,15 +7424,40 @@ func (api *API) updateChatModelConfig(rw http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
provider := existing.Provider
|
||||
if strings.TrimSpace(req.Provider) != "" {
|
||||
provider = normalizeChatProvider(req.Provider)
|
||||
if provider == "" {
|
||||
aiProviderID := existing.AIProviderID
|
||||
if req.AIProviderID != nil {
|
||||
//nolint:gocritic // The route already authorized chat model config updates.
|
||||
aiProvider, err := api.Database.GetAIProviderByID(dbauthz.AsChatd(ctx), *req.AIProviderID)
|
||||
if err != nil {
|
||||
if httpapi.Is404Error(err) {
|
||||
httpapi.Write(ctx, rw, http.StatusPreconditionFailed, codersdk.Response{Message: "AI provider is not configured."})
|
||||
return
|
||||
}
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Failed to get AI provider.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
if !aiProvider.Enabled {
|
||||
httpapi.Write(ctx, rw, http.StatusPreconditionFailed, codersdk.Response{Message: "AI provider is disabled."})
|
||||
return
|
||||
}
|
||||
provider = string(aiProvider.Type)
|
||||
aiProviderID = uuid.NullUUID{UUID: aiProvider.ID, Valid: true}
|
||||
} else if strings.TrimSpace(req.Provider) != "" {
|
||||
requestedProvider := normalizeChatProvider(req.Provider)
|
||||
if requestedProvider == "" {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: "Invalid provider.",
|
||||
Detail: chatProviderValidationDetail(),
|
||||
})
|
||||
return
|
||||
}
|
||||
provider = requestedProvider
|
||||
if requestedProvider != existing.Provider {
|
||||
aiProviderID = uuid.NullUUID{}
|
||||
}
|
||||
}
|
||||
|
||||
model := existing.Model
|
||||
@@ -7299,14 +7524,30 @@ func (api *API) updateChatModelConfig(rw http.ResponseWriter, r *http.Request) {
|
||||
ContextLimit: contextLimit,
|
||||
CompressionThreshold: compressionThreshold,
|
||||
Options: modelConfigRaw,
|
||||
AIProviderID: aiProviderID,
|
||||
UpdatedBy: uuid.NullUUID{UUID: apiKey.UserID, Valid: apiKey.UserID != uuid.Nil},
|
||||
ID: existing.ID,
|
||||
}
|
||||
|
||||
var updated database.ChatModelConfig
|
||||
err = api.Database.InTx(func(tx database.Store) error {
|
||||
if err := requireChatProviderForModelConfig(ctx, tx, updateParams.Provider); err != nil {
|
||||
return err
|
||||
if updateParams.AIProviderID.Valid && req.AIProviderID != nil {
|
||||
//nolint:gocritic // The route already authorized chat model config updates.
|
||||
aiProvider, err := tx.GetAIProviderByIDForReferenceLock(dbauthz.AsChatd(ctx), updateParams.AIProviderID.UUID)
|
||||
if err != nil {
|
||||
if xerrors.Is(err, sql.ErrNoRows) {
|
||||
return errChatProviderNotConfigured
|
||||
}
|
||||
return xerrors.Errorf("get AI provider for update: %w", err)
|
||||
}
|
||||
if !aiProvider.Enabled {
|
||||
return errChatProviderNotConfigured
|
||||
}
|
||||
updateParams.Provider = string(aiProvider.Type)
|
||||
} else if !updateParams.AIProviderID.Valid {
|
||||
if err := requireChatProviderForModelConfig(ctx, tx, updateParams.Provider); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
setAsDefault := updateParams.IsDefault && !existing.IsDefault
|
||||
@@ -7357,7 +7598,7 @@ func (api *API) updateChatModelConfig(rw http.ResponseWriter, r *http.Request) {
|
||||
})
|
||||
return
|
||||
case xerrors.Is(err, errChatProviderNotConfigured):
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
httpapi.Write(ctx, rw, http.StatusPreconditionFailed, codersdk.Response{
|
||||
Message: "Chat provider is not configured.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
@@ -7487,6 +7728,7 @@ func chatModelConfigToUpdateParams(
|
||||
ContextLimit: config.ContextLimit,
|
||||
CompressionThreshold: config.CompressionThreshold,
|
||||
Options: config.Options,
|
||||
AIProviderID: config.AIProviderID,
|
||||
UpdatedBy: uuid.NullUUID{},
|
||||
ID: config.ID,
|
||||
}
|
||||
@@ -7589,6 +7831,7 @@ func convertUserChatProviderConfig(
|
||||
provider database.ChatProvider,
|
||||
hasUserAPIKey bool,
|
||||
hasCentralAPIKeyFallback bool,
|
||||
byokEnabled bool,
|
||||
) codersdk.UserChatProviderConfig {
|
||||
displayName := strings.TrimSpace(provider.DisplayName)
|
||||
if displayName == "" {
|
||||
@@ -7601,13 +7844,19 @@ func convertUserChatProviderConfig(
|
||||
DisplayName: displayName,
|
||||
HasUserAPIKey: hasUserAPIKey,
|
||||
HasCentralAPIKeyFallback: hasCentralAPIKeyFallback,
|
||||
BYOKEnabled: byokEnabled,
|
||||
}
|
||||
}
|
||||
|
||||
func convertChatModelConfig(config database.ChatModelConfig) codersdk.ChatModelConfig {
|
||||
var aiProviderID *uuid.UUID
|
||||
if config.AIProviderID.Valid {
|
||||
aiProviderID = &config.AIProviderID.UUID
|
||||
}
|
||||
return codersdk.ChatModelConfig{
|
||||
ID: config.ID,
|
||||
Provider: config.Provider,
|
||||
AIProviderID: aiProviderID,
|
||||
Model: config.Model,
|
||||
DisplayName: config.DisplayName,
|
||||
Enabled: config.Enabled,
|
||||
@@ -7787,7 +8036,7 @@ const maxChatProviderAPIKeySize = 10240 // 10 KB
|
||||
|
||||
func validateChatProviderAPIKeySize(apiKey string) error {
|
||||
if len(apiKey) > maxChatProviderAPIKeySize {
|
||||
return xerrors.Errorf("API key exceeds maximum size of %d bytes", maxChatProviderAPIKeySize)
|
||||
return xerrors.Errorf("API key exceeds maximum size of 10 KB (%d bytes)", maxChatProviderAPIKeySize)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
+444
-54
@@ -1801,16 +1801,19 @@ func TestListChatModels(t *testing.T) {
|
||||
client := newChatClient(t)
|
||||
_ = coderdtest.CreateFirstUser(t, client.Client)
|
||||
|
||||
provider, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
|
||||
Provider: "anthropic",
|
||||
providerType := database.AiProviderTypeAnthropic
|
||||
chatProvider, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
|
||||
Provider: string(providerType),
|
||||
CentralAPIKeyEnabled: ptr.Ref(false),
|
||||
AllowUserAPIKey: ptr.Ref(true),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
aiProvider := createAIProviderForTest(t, client, string(providerType), "")
|
||||
|
||||
contextLimit := int64(4096)
|
||||
_, err = client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
|
||||
Provider: "anthropic",
|
||||
Provider: string(providerType),
|
||||
AIProviderID: &aiProvider.ID,
|
||||
Model: "claude-sonnet",
|
||||
ContextLimit: &contextLimit,
|
||||
})
|
||||
@@ -1821,7 +1824,7 @@ func TestListChatModels(t *testing.T) {
|
||||
|
||||
var anthropicProvider *codersdk.ChatModelProvider
|
||||
for i := range models.Providers {
|
||||
if models.Providers[i].Provider == "anthropic" {
|
||||
if models.Providers[i].Provider == string(providerType) {
|
||||
anthropicProvider = &models.Providers[i]
|
||||
break
|
||||
}
|
||||
@@ -1830,7 +1833,7 @@ func TestListChatModels(t *testing.T) {
|
||||
require.False(t, anthropicProvider.Available)
|
||||
require.Equal(t, codersdk.ChatModelProviderUnavailableReasonUserAPIKeyRequired, anthropicProvider.UnavailableReason)
|
||||
|
||||
_, err = client.UpsertUserChatProviderKey(ctx, provider.ID, codersdk.CreateUserChatProviderKeyRequest{
|
||||
_, err = client.UpsertUserChatProviderKey(ctx, chatProvider.ID, codersdk.CreateUserChatProviderKeyRequest{
|
||||
APIKey: "user-api-key",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
@@ -1856,7 +1859,7 @@ func TestListChatModels(t *testing.T) {
|
||||
client := newChatClient(t)
|
||||
_ = coderdtest.CreateFirstUser(t, client.Client)
|
||||
|
||||
provider, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
|
||||
chatProvider, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
|
||||
Provider: "google",
|
||||
APIKey: "central-api-key",
|
||||
CentralAPIKeyEnabled: ptr.Ref(true),
|
||||
@@ -1864,10 +1867,12 @@ func TestListChatModels(t *testing.T) {
|
||||
AllowCentralAPIKeyFallback: ptr.Ref(true),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
aiProvider := createAIProviderForTest(t, client, "google", "provider-api-key")
|
||||
|
||||
contextLimit := int64(4096)
|
||||
_, err = client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
|
||||
Provider: "google",
|
||||
AIProviderID: &aiProvider.ID,
|
||||
Model: "gemini-1.5-pro",
|
||||
ContextLimit: &contextLimit,
|
||||
})
|
||||
@@ -1886,7 +1891,7 @@ func TestListChatModels(t *testing.T) {
|
||||
require.NotNil(t, googleProvider)
|
||||
require.True(t, googleProvider.Available)
|
||||
|
||||
_, err = client.UpsertUserChatProviderKey(ctx, provider.ID, codersdk.CreateUserChatProviderKeyRequest{
|
||||
_, err = client.UpsertUserChatProviderKey(ctx, chatProvider.ID, codersdk.CreateUserChatProviderKeyRequest{
|
||||
APIKey: "user-api-key",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
@@ -1914,15 +1919,17 @@ func TestListChatModels(t *testing.T) {
|
||||
client := newChatClientWithDeploymentValues(t, values)
|
||||
_ = coderdtest.CreateFirstUser(t, client.Client)
|
||||
|
||||
provider, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
|
||||
chatProvider, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
|
||||
Provider: "openai",
|
||||
APIKey: "test-key",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
aiProvider := createAIProviderForTest(t, client, "openai", "test-key")
|
||||
|
||||
contextLimit := int64(4096)
|
||||
_, err = client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
|
||||
Provider: "openai",
|
||||
AIProviderID: &aiProvider.ID,
|
||||
Model: "gpt-4o-mini",
|
||||
ContextLimit: &contextLimit,
|
||||
})
|
||||
@@ -1936,7 +1943,7 @@ func TestListChatModels(t *testing.T) {
|
||||
require.Equal(t, "gpt-4o-mini", models.Providers[0].Models[0].Model)
|
||||
|
||||
enabled := false
|
||||
_, err = client.UpdateChatProvider(ctx, provider.ID, codersdk.UpdateChatProviderConfigRequest{
|
||||
_, err = client.UpdateChatProvider(ctx, chatProvider.ID, codersdk.UpdateChatProviderConfigRequest{
|
||||
Enabled: &enabled,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
@@ -2261,6 +2268,186 @@ func TestWatchChats(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestUserAIProviderKeys(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
createOpenAIProvider := func(t *testing.T, client *codersdk.ExperimentalClient, name string, enabled bool, apiKeys ...string) codersdk.AIProvider {
|
||||
t.Helper()
|
||||
|
||||
provider, err := client.CreateAIProvider(testutil.Context(t, testutil.WaitLong), codersdk.CreateAIProviderRequest{
|
||||
Type: codersdk.AIProviderTypeOpenAI,
|
||||
Name: name,
|
||||
Enabled: enabled,
|
||||
BaseURL: "https://api.openai.example.com/v1",
|
||||
APIKeys: apiKeys,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
return provider
|
||||
}
|
||||
|
||||
findUserAIProviderKeyConfig := func(
|
||||
t *testing.T,
|
||||
configs []codersdk.UserAIProviderKeyConfig,
|
||||
providerID uuid.UUID,
|
||||
) *codersdk.UserAIProviderKeyConfig {
|
||||
t.Helper()
|
||||
|
||||
for i := range configs {
|
||||
if configs[i].Provider.ID == providerID {
|
||||
return &configs[i]
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
t.Run("SelfServiceLifecycle", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
adminClient := newChatClient(t)
|
||||
firstUser := coderdtest.CreateFirstUser(t, adminClient.Client)
|
||||
memberClientRaw, _ := coderdtest.CreateAnotherUser(t, adminClient.Client, firstUser.OrganizationID)
|
||||
memberClient := codersdk.NewExperimentalClient(memberClientRaw)
|
||||
|
||||
provider := createOpenAIProvider(t, adminClient, "test-user-key-"+uuid.NewString(), true, "test-provider-api-key")
|
||||
|
||||
configs, err := memberClient.ListUserAIProviderKeyConfigs(ctx, "me")
|
||||
require.NoError(t, err)
|
||||
cfg := findUserAIProviderKeyConfig(t, configs, provider.ID)
|
||||
require.NotNil(t, cfg)
|
||||
require.False(t, cfg.HasUserAPIKey)
|
||||
require.True(t, cfg.HasProviderAPIKey)
|
||||
require.True(t, cfg.BYOKEnabled)
|
||||
|
||||
cfgValue, err := memberClient.UpsertUserAIProviderKey(ctx, "me", provider.ID, codersdk.CreateUserAIProviderKeyRequest{APIKey: "test-user-api-key"})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, provider.ID, cfgValue.Provider.ID)
|
||||
require.True(t, cfgValue.HasUserAPIKey)
|
||||
require.True(t, cfgValue.HasProviderAPIKey)
|
||||
require.True(t, cfgValue.BYOKEnabled)
|
||||
|
||||
configs, err = memberClient.ListUserAIProviderKeyConfigs(ctx, "me")
|
||||
require.NoError(t, err)
|
||||
cfg = findUserAIProviderKeyConfig(t, configs, provider.ID)
|
||||
require.NotNil(t, cfg)
|
||||
require.True(t, cfg.HasUserAPIKey)
|
||||
|
||||
cfgValue, err = memberClient.UpsertUserAIProviderKey(ctx, "me", provider.ID, codersdk.CreateUserAIProviderKeyRequest{APIKey: "replacement-user-api-key"})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, provider.ID, cfgValue.Provider.ID)
|
||||
require.True(t, cfgValue.HasUserAPIKey)
|
||||
|
||||
configs, err = memberClient.ListUserAIProviderKeyConfigs(ctx, "me")
|
||||
require.NoError(t, err)
|
||||
cfg = findUserAIProviderKeyConfig(t, configs, provider.ID)
|
||||
require.NotNil(t, cfg)
|
||||
require.True(t, cfg.HasUserAPIKey)
|
||||
|
||||
require.NoError(t, memberClient.DeleteUserAIProviderKey(ctx, "me", provider.ID))
|
||||
configs, err = memberClient.ListUserAIProviderKeyConfigs(ctx, "me")
|
||||
require.NoError(t, err)
|
||||
cfg = findUserAIProviderKeyConfig(t, configs, provider.ID)
|
||||
require.NotNil(t, cfg)
|
||||
require.False(t, cfg.HasUserAPIKey)
|
||||
})
|
||||
|
||||
t.Run("ListsDisabledProviderWithSavedUserKey", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
adminClient := newChatClient(t)
|
||||
firstUser := coderdtest.CreateFirstUser(t, adminClient.Client)
|
||||
memberClientRaw, _ := coderdtest.CreateAnotherUser(t, adminClient.Client, firstUser.OrganizationID)
|
||||
memberClient := codersdk.NewExperimentalClient(memberClientRaw)
|
||||
|
||||
provider := createOpenAIProvider(t, adminClient, "test-disabled-saved-user-key-"+uuid.NewString(), true)
|
||||
_, err := memberClient.UpsertUserAIProviderKey(ctx, "me", provider.ID, codersdk.CreateUserAIProviderKeyRequest{APIKey: "test-user-api-key"})
|
||||
require.NoError(t, err)
|
||||
|
||||
enabled := false
|
||||
_, err = adminClient.UpdateAIProvider(ctx, provider.ID.String(), codersdk.UpdateAIProviderRequest{Enabled: &enabled})
|
||||
require.NoError(t, err)
|
||||
|
||||
configs, err := memberClient.ListUserAIProviderKeyConfigs(ctx, "me")
|
||||
require.NoError(t, err)
|
||||
cfg := findUserAIProviderKeyConfig(t, configs, provider.ID)
|
||||
require.NotNil(t, cfg)
|
||||
require.False(t, cfg.Provider.Enabled)
|
||||
require.True(t, cfg.HasUserAPIKey)
|
||||
})
|
||||
|
||||
t.Run("RejectsDisabledProvider", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
adminClient := newChatClient(t)
|
||||
firstUser := coderdtest.CreateFirstUser(t, adminClient.Client)
|
||||
memberClientRaw, _ := coderdtest.CreateAnotherUser(t, adminClient.Client, firstUser.OrganizationID)
|
||||
memberClient := codersdk.NewExperimentalClient(memberClientRaw)
|
||||
|
||||
provider := createOpenAIProvider(t, adminClient, "test-disabled-user-key-"+uuid.NewString(), false)
|
||||
|
||||
_, err := memberClient.UpsertUserAIProviderKey(ctx, "me", provider.ID, codersdk.CreateUserAIProviderKeyRequest{APIKey: "test-user-api-key"})
|
||||
sdkErr := requireSDKError(t, err, http.StatusPreconditionFailed)
|
||||
require.Equal(t, "AI provider is disabled.", sdkErr.Message)
|
||||
})
|
||||
|
||||
t.Run("RejectsLargeAPIKey", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
adminClient := newChatClient(t)
|
||||
firstUser := coderdtest.CreateFirstUser(t, adminClient.Client)
|
||||
memberClientRaw, _ := coderdtest.CreateAnotherUser(t, adminClient.Client, firstUser.OrganizationID)
|
||||
memberClient := codersdk.NewExperimentalClient(memberClientRaw)
|
||||
|
||||
provider := createOpenAIProvider(t, adminClient, "test-large-user-key-"+uuid.NewString(), true)
|
||||
|
||||
_, err := memberClient.UpsertUserAIProviderKey(ctx, "me", provider.ID, codersdk.CreateUserAIProviderKeyRequest{APIKey: strings.Repeat("x", 10241)})
|
||||
sdkErr := requireSDKError(t, err, http.StatusBadRequest)
|
||||
require.Equal(t, "API key too large.", sdkErr.Message)
|
||||
})
|
||||
|
||||
t.Run("RejectsWhitespaceAPIKey", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
adminClient := newChatClient(t)
|
||||
firstUser := coderdtest.CreateFirstUser(t, adminClient.Client)
|
||||
memberClientRaw, _ := coderdtest.CreateAnotherUser(t, adminClient.Client, firstUser.OrganizationID)
|
||||
memberClient := codersdk.NewExperimentalClient(memberClientRaw)
|
||||
|
||||
provider := createOpenAIProvider(t, adminClient, "test-whitespace-user-key-"+uuid.NewString(), true)
|
||||
|
||||
_, err := memberClient.UpsertUserAIProviderKey(ctx, "me", provider.ID, codersdk.CreateUserAIProviderKeyRequest{APIKey: " "})
|
||||
sdkErr := requireSDKError(t, err, http.StatusBadRequest)
|
||||
require.Equal(t, "API key must not contain leading or trailing whitespace.", sdkErr.Message)
|
||||
})
|
||||
|
||||
t.Run("BYOKDisabledRejectsUpsertAndAllowsDelete", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
values := chatDeploymentValues(t)
|
||||
values.AI.BridgeConfig.AllowBYOK = serpent.Bool(false)
|
||||
client := newChatClientWithDeploymentValues(t, values)
|
||||
_ = coderdtest.CreateFirstUser(t, client.Client)
|
||||
|
||||
provider := createOpenAIProvider(t, client, "test-byok-disabled-"+uuid.NewString(), true)
|
||||
|
||||
_, err := client.UpsertUserAIProviderKey(ctx, "me", provider.ID, codersdk.CreateUserAIProviderKeyRequest{APIKey: "test-user-api-key"})
|
||||
sdkErr := requireSDKError(t, err, http.StatusForbidden)
|
||||
require.Equal(t, "BYOK is disabled.", sdkErr.Message)
|
||||
|
||||
configs, err := client.ListUserAIProviderKeyConfigs(ctx, "me")
|
||||
require.NoError(t, err)
|
||||
cfg := findUserAIProviderKeyConfig(t, configs, provider.ID)
|
||||
require.NotNil(t, cfg)
|
||||
require.False(t, cfg.BYOKEnabled)
|
||||
require.NoError(t, client.DeleteUserAIProviderKey(ctx, "me", provider.ID))
|
||||
})
|
||||
}
|
||||
|
||||
func TestListChatProviders(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -2611,7 +2798,7 @@ func TestCreateChatProvider(t *testing.T) {
|
||||
})
|
||||
sdkErr := requireSDKError(t, err, http.StatusBadRequest)
|
||||
require.Equal(t, "API key too large.", sdkErr.Message)
|
||||
require.Equal(t, fmt.Sprintf("API key exceeds maximum size of %d bytes", chatProviderAPIKeySizeLimit), sdkErr.Detail)
|
||||
require.Equal(t, fmt.Sprintf("API key exceeds maximum size of 10 KB (%d bytes)", chatProviderAPIKeySizeLimit), sdkErr.Detail)
|
||||
})
|
||||
|
||||
t.Run("AllowsMaxSizedAPIKey", func(t *testing.T) {
|
||||
@@ -2898,7 +3085,7 @@ func TestUpdateChatProvider(t *testing.T) {
|
||||
})
|
||||
sdkErr := requireSDKError(t, err, http.StatusBadRequest)
|
||||
require.Equal(t, "API key too large.", sdkErr.Message)
|
||||
require.Equal(t, fmt.Sprintf("API key exceeds maximum size of %d bytes", chatProviderAPIKeySizeLimit), sdkErr.Detail)
|
||||
require.Equal(t, fmt.Sprintf("API key exceeds maximum size of 10 KB (%d bytes)", chatProviderAPIKeySizeLimit), sdkErr.Detail)
|
||||
})
|
||||
|
||||
t.Run("AllowsMaxSizedAPIKey", func(t *testing.T) {
|
||||
@@ -2961,11 +3148,13 @@ func TestDeleteChatProvider(t *testing.T) {
|
||||
AllowUserAPIKey: ptr.Ref(true),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
aiProviderToDelete := createAIProviderForTest(t, client, providerToDelete.Provider, "delete-api-key")
|
||||
|
||||
deleteContextLimit := int64(4096)
|
||||
deleteIsDefault := true
|
||||
configToDelete, err := client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
|
||||
Provider: providerToDelete.Provider,
|
||||
AIProviderID: &aiProviderToDelete.ID,
|
||||
Model: "gpt-4o-delete-provider",
|
||||
ContextLimit: &deleteContextLimit,
|
||||
IsDefault: &deleteIsDefault,
|
||||
@@ -2977,10 +3166,12 @@ func TestDeleteChatProvider(t *testing.T) {
|
||||
APIKey: "keep-api-key",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
keepAIProvider := createAIProviderForTest(t, client, keepProvider.Provider, "keep-api-key")
|
||||
|
||||
keepContextLimit := int64(8192)
|
||||
keepConfig, err := client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
|
||||
Provider: keepProvider.Provider,
|
||||
AIProviderID: &keepAIProvider.ID,
|
||||
Model: "claude-keep-provider",
|
||||
ContextLimit: &keepContextLimit,
|
||||
})
|
||||
@@ -3082,11 +3273,13 @@ func TestDeleteChatProvider(t *testing.T) {
|
||||
APIKey: "only-provider-api-key",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
aiProvider := createAIProviderForTest(t, client, provider.Provider, "only-provider-api-key")
|
||||
|
||||
contextLimit := int64(4096)
|
||||
isDefault := true
|
||||
config, err := client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
|
||||
Provider: provider.Provider,
|
||||
AIProviderID: &aiProvider.ID,
|
||||
Model: "gpt-4o-only-provider",
|
||||
ContextLimit: &contextLimit,
|
||||
IsDefault: &isDefault,
|
||||
@@ -3636,7 +3829,7 @@ func TestUpsertUserChatProviderKey(t *testing.T) {
|
||||
})
|
||||
sdkErr := requireSDKError(t, err, http.StatusBadRequest)
|
||||
require.Equal(t, "API key too large.", sdkErr.Message)
|
||||
require.Equal(t, fmt.Sprintf("API key exceeds maximum size of %d bytes", chatProviderAPIKeySizeLimit), sdkErr.Detail)
|
||||
require.Equal(t, fmt.Sprintf("API key exceeds maximum size of 10 KB (%d bytes)", chatProviderAPIKeySizeLimit), sdkErr.Detail)
|
||||
})
|
||||
|
||||
t.Run("AllowsMaxSizedAPIKey", func(t *testing.T) {
|
||||
@@ -3695,16 +3888,13 @@ func TestListChatModelConfigs(t *testing.T) {
|
||||
client := newChatClient(t)
|
||||
_ = coderdtest.CreateFirstUser(t, client.Client)
|
||||
|
||||
_, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
|
||||
Provider: "openai",
|
||||
APIKey: "test-api-key",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
aiProvider := createAIProviderForTest(t, client, "openai", "test-api-key")
|
||||
|
||||
contextLimit := int64(4096)
|
||||
enabled := false
|
||||
disabledConfig, err := client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
|
||||
Provider: "openai",
|
||||
AIProviderID: &aiProvider.ID,
|
||||
Model: "gpt-4o-disabled",
|
||||
DisplayName: "GPT-4o Disabled",
|
||||
Enabled: &enabled,
|
||||
@@ -3741,6 +3931,7 @@ func TestListChatModelConfigs(t *testing.T) {
|
||||
enabled := false
|
||||
_, err := adminClient.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
|
||||
Provider: enabledConfig.Provider,
|
||||
AIProviderID: enabledConfig.AIProviderID,
|
||||
Model: "gpt-4o-disabled",
|
||||
DisplayName: "GPT-4o Disabled",
|
||||
Enabled: &enabled,
|
||||
@@ -3762,15 +3953,12 @@ func TestListChatModelConfigs(t *testing.T) {
|
||||
client, db := newChatClientWithDatabase(t)
|
||||
firstUser := coderdtest.CreateFirstUser(t, client.Client)
|
||||
|
||||
_, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
|
||||
Provider: "openai",
|
||||
APIKey: "test-api-key",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
aiProvider := createAIProviderForTest(t, client, "openai", "test-api-key")
|
||||
|
||||
legacyOptions := json.RawMessage(`{"input_price_per_million_tokens":0.15,"output_price_per_million_tokens":0.6,"cache_read_price_per_million_tokens":0.03,"cache_write_price_per_million_tokens":0.3}`)
|
||||
storedConfig := dbgen.ChatModelConfig(t, db, database.ChatModelConfig{
|
||||
Provider: "openai",
|
||||
AIProviderID: uuid.NullUUID{UUID: aiProvider.ID, Valid: true},
|
||||
Model: "gpt-4o-mini-legacy",
|
||||
DisplayName: "GPT-4o Mini Legacy",
|
||||
CreatedBy: uuid.NullUUID{UUID: firstUser.UserID, Valid: true},
|
||||
@@ -3831,11 +4019,7 @@ func TestCreateChatModelConfig(t *testing.T) {
|
||||
client := newChatClient(t)
|
||||
_ = coderdtest.CreateFirstUser(t, client.Client)
|
||||
|
||||
_, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
|
||||
Provider: "openai",
|
||||
APIKey: "test-api-key",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
aiProvider := createAIProviderForTest(t, client, "openai", "test-api-key")
|
||||
|
||||
contextLimit := int64(4096)
|
||||
isDefault := true
|
||||
@@ -3849,6 +4033,7 @@ func TestCreateChatModelConfig(t *testing.T) {
|
||||
}
|
||||
modelConfig, err := client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
|
||||
Provider: "openai",
|
||||
AIProviderID: &aiProvider.ID,
|
||||
Model: "gpt-4o-mini",
|
||||
ContextLimit: &contextLimit,
|
||||
IsDefault: &isDefault,
|
||||
@@ -3875,15 +4060,12 @@ func TestCreateChatModelConfig(t *testing.T) {
|
||||
client := newChatClient(t)
|
||||
_ = coderdtest.CreateFirstUser(t, client.Client)
|
||||
|
||||
_, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
|
||||
Provider: "openai",
|
||||
APIKey: "test-api-key",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
aiProvider := createAIProviderForTest(t, client, "openai", "test-api-key")
|
||||
|
||||
contextLimit := int64(4096)
|
||||
_, err = client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
|
||||
_, err := client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
|
||||
Provider: "openai",
|
||||
AIProviderID: &aiProvider.ID,
|
||||
Model: "gpt-4o-mini",
|
||||
ContextLimit: &contextLimit,
|
||||
ModelConfig: &codersdk.ChatModelCallConfig{
|
||||
@@ -3907,16 +4089,18 @@ func TestCreateChatModelConfig(t *testing.T) {
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client := newChatClient(t)
|
||||
_ = coderdtest.CreateFirstUser(t, client.Client)
|
||||
aiProvider := createAIProviderForTest(t, client, "openai", "test-api-key")
|
||||
|
||||
_, err := client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
|
||||
Provider: "openai",
|
||||
Model: "gpt-4o-mini",
|
||||
Provider: "openai",
|
||||
AIProviderID: &aiProvider.ID,
|
||||
Model: "gpt-4o-mini",
|
||||
})
|
||||
sdkErr := requireSDKError(t, err, http.StatusBadRequest)
|
||||
require.Equal(t, "Context limit is required.", sdkErr.Message)
|
||||
})
|
||||
|
||||
t.Run("ProviderNotConfigured", func(t *testing.T) {
|
||||
t.Run("AIProviderIDRequired", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
@@ -3930,7 +4114,94 @@ func TestCreateChatModelConfig(t *testing.T) {
|
||||
ContextLimit: &contextLimit,
|
||||
})
|
||||
sdkErr := requireSDKError(t, err, http.StatusBadRequest)
|
||||
require.Equal(t, "Chat provider is not configured.", sdkErr.Message)
|
||||
require.Equal(t, "AI provider ID is required.", sdkErr.Message)
|
||||
})
|
||||
|
||||
t.Run("ProviderNotConfigured", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client := newChatClient(t)
|
||||
_ = coderdtest.CreateFirstUser(t, client.Client)
|
||||
|
||||
contextLimit := int64(4096)
|
||||
missingProviderID := uuid.New()
|
||||
_, err := client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
|
||||
Provider: "openai",
|
||||
AIProviderID: &missingProviderID,
|
||||
Model: "gpt-4o-mini",
|
||||
ContextLimit: &contextLimit,
|
||||
})
|
||||
sdkErr := requireSDKError(t, err, http.StatusPreconditionFailed)
|
||||
require.Equal(t, "AI provider is not configured.", sdkErr.Message)
|
||||
})
|
||||
|
||||
t.Run("WithAIProviderID", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client := newChatClient(t)
|
||||
_ = coderdtest.CreateFirstUser(t, client.Client)
|
||||
provider, err := client.CreateAIProvider(ctx, codersdk.CreateAIProviderRequest{
|
||||
Type: codersdk.AIProviderTypeOpenAI,
|
||||
Name: "test-model-config-provider-" + uuid.NewString(),
|
||||
Enabled: true,
|
||||
BaseURL: "https://api.openai.com/v1",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
contextLimit := int64(4096)
|
||||
modelConfig, err := client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
|
||||
AIProviderID: &provider.ID,
|
||||
Model: "gpt-4o-mini",
|
||||
ContextLimit: &contextLimit,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "openai", modelConfig.Provider)
|
||||
require.NotNil(t, modelConfig.AIProviderID)
|
||||
require.Equal(t, provider.ID, *modelConfig.AIProviderID)
|
||||
})
|
||||
|
||||
t.Run("AIProviderIDNotConfigured", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client := newChatClient(t)
|
||||
_ = coderdtest.CreateFirstUser(t, client.Client)
|
||||
|
||||
missingProviderID := uuid.New()
|
||||
contextLimit := int64(4096)
|
||||
_, err := client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
|
||||
AIProviderID: &missingProviderID,
|
||||
Model: "gpt-4o-mini",
|
||||
ContextLimit: &contextLimit,
|
||||
})
|
||||
sdkErr := requireSDKError(t, err, http.StatusPreconditionFailed)
|
||||
require.Equal(t, "AI provider is not configured.", sdkErr.Message)
|
||||
})
|
||||
|
||||
t.Run("AIProviderIDDisabled", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client := newChatClient(t)
|
||||
_ = coderdtest.CreateFirstUser(t, client.Client)
|
||||
provider, err := client.CreateAIProvider(ctx, codersdk.CreateAIProviderRequest{
|
||||
Type: codersdk.AIProviderTypeOpenAI,
|
||||
Name: "test-disabled-model-provider-" + uuid.NewString(),
|
||||
Enabled: false,
|
||||
BaseURL: "https://api.openai.com/v1",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
contextLimit := int64(4096)
|
||||
_, err = client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
|
||||
AIProviderID: &provider.ID,
|
||||
Model: "gpt-4o-mini",
|
||||
ContextLimit: &contextLimit,
|
||||
})
|
||||
sdkErr := requireSDKError(t, err, http.StatusPreconditionFailed)
|
||||
require.Equal(t, "AI provider is disabled.", sdkErr.Message)
|
||||
})
|
||||
|
||||
t.Run("ForbiddenForOrganizationMember", func(t *testing.T) {
|
||||
@@ -3942,15 +4213,12 @@ func TestCreateChatModelConfig(t *testing.T) {
|
||||
memberClientRaw, _ := coderdtest.CreateAnotherUser(t, adminClient.Client, firstUser.OrganizationID)
|
||||
memberClient := codersdk.NewExperimentalClient(memberClientRaw)
|
||||
|
||||
_, err := adminClient.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
|
||||
Provider: "openai",
|
||||
APIKey: "test-api-key",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
aiProvider := createAIProviderForTest(t, adminClient, "openai", "test-api-key")
|
||||
|
||||
contextLimit := int64(4096)
|
||||
_, err = memberClient.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
|
||||
_, err := memberClient.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
|
||||
Provider: "openai",
|
||||
AIProviderID: &aiProvider.ID,
|
||||
Model: "gpt-4o-mini",
|
||||
ContextLimit: &contextLimit,
|
||||
})
|
||||
@@ -4041,16 +4309,13 @@ func TestUpdateChatModelConfig(t *testing.T) {
|
||||
memberClientRaw, _ := coderdtest.CreateAnotherUser(t, adminClient.Client, firstUser.OrganizationID)
|
||||
memberClient := codersdk.NewExperimentalClient(memberClientRaw)
|
||||
|
||||
_, err := adminClient.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
|
||||
Provider: "openai",
|
||||
APIKey: "test-api-key",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
aiProvider := createAIProviderForTest(t, adminClient, "openai", "test-api-key")
|
||||
|
||||
contextLimit := int64(4096)
|
||||
enabled := false
|
||||
modelConfig, err := adminClient.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
|
||||
Provider: "openai",
|
||||
AIProviderID: &aiProvider.ID,
|
||||
Model: "gpt-4o-reenable",
|
||||
DisplayName: "GPT-4o Re-enable",
|
||||
Enabled: &enabled,
|
||||
@@ -4115,6 +4380,100 @@ func TestUpdateChatModelConfig(t *testing.T) {
|
||||
)
|
||||
})
|
||||
|
||||
t.Run("UpdateAIProviderID", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client := newChatClient(t)
|
||||
_ = coderdtest.CreateFirstUser(t, client.Client)
|
||||
modelConfig := createChatModelConfig(t, client)
|
||||
provider, err := client.CreateAIProvider(ctx, codersdk.CreateAIProviderRequest{
|
||||
Type: codersdk.AIProviderTypeAnthropic,
|
||||
Name: "test-update-model-provider-" + uuid.NewString(),
|
||||
Enabled: true,
|
||||
BaseURL: "https://api.anthropic.com",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
updated, err := client.UpdateChatModelConfig(ctx, modelConfig.ID, codersdk.UpdateChatModelConfigRequest{
|
||||
AIProviderID: &provider.ID,
|
||||
Model: "claude-3-5-sonnet-latest",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "anthropic", updated.Provider)
|
||||
require.NotNil(t, updated.AIProviderID)
|
||||
require.Equal(t, provider.ID, *updated.AIProviderID)
|
||||
})
|
||||
|
||||
t.Run("UpdateProviderPreservesAIProviderIDWhenTypeUnchanged", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client := newChatClient(t)
|
||||
_ = coderdtest.CreateFirstUser(t, client.Client)
|
||||
modelConfig := createChatModelConfig(t, client)
|
||||
provider, err := client.CreateAIProvider(ctx, codersdk.CreateAIProviderRequest{
|
||||
Type: codersdk.AIProviderTypeAnthropic,
|
||||
Name: "test-preserve-model-provider-" + uuid.NewString(),
|
||||
Enabled: true,
|
||||
BaseURL: "https://api.anthropic.com",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
updated, err := client.UpdateChatModelConfig(ctx, modelConfig.ID, codersdk.UpdateChatModelConfigRequest{
|
||||
AIProviderID: &provider.ID,
|
||||
Model: "claude-3-5-sonnet-latest",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, updated.AIProviderID)
|
||||
|
||||
updated, err = client.UpdateChatModelConfig(ctx, modelConfig.ID, codersdk.UpdateChatModelConfigRequest{
|
||||
Provider: "anthropic",
|
||||
Model: "claude-3-5-haiku-latest",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, updated.AIProviderID)
|
||||
require.Equal(t, provider.ID, *updated.AIProviderID)
|
||||
})
|
||||
|
||||
t.Run("UpdateAIProviderIDNotConfigured", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client := newChatClient(t)
|
||||
_ = coderdtest.CreateFirstUser(t, client.Client)
|
||||
modelConfig := createChatModelConfig(t, client)
|
||||
|
||||
missingProviderID := uuid.New()
|
||||
_, err := client.UpdateChatModelConfig(ctx, modelConfig.ID, codersdk.UpdateChatModelConfigRequest{
|
||||
AIProviderID: &missingProviderID,
|
||||
})
|
||||
sdkErr := requireSDKError(t, err, http.StatusPreconditionFailed)
|
||||
require.Equal(t, "AI provider is not configured.", sdkErr.Message)
|
||||
})
|
||||
|
||||
t.Run("UpdateAIProviderIDDisabled", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
client := newChatClient(t)
|
||||
_ = coderdtest.CreateFirstUser(t, client.Client)
|
||||
modelConfig := createChatModelConfig(t, client)
|
||||
provider, err := client.CreateAIProvider(ctx, codersdk.CreateAIProviderRequest{
|
||||
Type: codersdk.AIProviderTypeOpenAI,
|
||||
Name: "test-update-disabled-model-provider-" + uuid.NewString(),
|
||||
Enabled: false,
|
||||
BaseURL: "https://api.openai.com/v1",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = client.UpdateChatModelConfig(ctx, modelConfig.ID, codersdk.UpdateChatModelConfigRequest{
|
||||
AIProviderID: &provider.ID,
|
||||
})
|
||||
sdkErr := requireSDKError(t, err, http.StatusPreconditionFailed)
|
||||
require.Equal(t, "AI provider is disabled.", sdkErr.Message)
|
||||
})
|
||||
|
||||
t.Run("ProviderNotConfigured", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -4126,7 +4485,7 @@ func TestUpdateChatModelConfig(t *testing.T) {
|
||||
_, err := client.UpdateChatModelConfig(ctx, modelConfig.ID, codersdk.UpdateChatModelConfigRequest{
|
||||
Provider: "anthropic",
|
||||
})
|
||||
sdkErr := requireSDKError(t, err, http.StatusBadRequest)
|
||||
sdkErr := requireSDKError(t, err, http.StatusPreconditionFailed)
|
||||
require.Equal(t, "Chat provider is not configured.", sdkErr.Message)
|
||||
})
|
||||
|
||||
@@ -4167,16 +4526,13 @@ func TestUpdateChatModelConfig(t *testing.T) {
|
||||
_ = coderdtest.CreateFirstUser(t, client.Client)
|
||||
defaultConfig := createChatModelConfig(t, client)
|
||||
|
||||
_, err := client.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
|
||||
Provider: "anthropic",
|
||||
APIKey: "candidate-api-key",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
aiProvider := createAIProviderForTest(t, client, "anthropic", "candidate-api-key")
|
||||
|
||||
contextLimit := int64(4096)
|
||||
isDefault := false
|
||||
candidateConfig, err := client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
|
||||
Provider: "anthropic",
|
||||
AIProviderID: &aiProvider.ID,
|
||||
Model: "claude-3-5-sonnet",
|
||||
ContextLimit: &contextLimit,
|
||||
IsDefault: &isDefault,
|
||||
@@ -10495,6 +10851,38 @@ func TestWatchChatGitAuthz(t *testing.T) {
|
||||
require.Equal(t, http.StatusForbidden, res.StatusCode)
|
||||
}
|
||||
|
||||
func createAIProviderForTest(
|
||||
t testing.TB,
|
||||
client *codersdk.ExperimentalClient,
|
||||
provider string,
|
||||
apiKey string,
|
||||
) codersdk.AIProvider {
|
||||
t.Helper()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
req := codersdk.CreateAIProviderRequest{
|
||||
Type: codersdk.AIProviderType(provider),
|
||||
Name: "test-" + provider + "-" + uuid.NewString(),
|
||||
BaseURL: aiProviderBaseURLForTest(provider),
|
||||
Enabled: true,
|
||||
}
|
||||
if apiKey != "" {
|
||||
req.APIKeys = []string{apiKey}
|
||||
}
|
||||
aiProvider, err := client.CreateAIProvider(ctx, req)
|
||||
require.NoError(t, err)
|
||||
return aiProvider
|
||||
}
|
||||
|
||||
func aiProviderBaseURLForTest(provider string) string {
|
||||
switch provider {
|
||||
case "anthropic", "bedrock", "google":
|
||||
return "https://api.example.com"
|
||||
default:
|
||||
return "https://api.example.com/v1"
|
||||
}
|
||||
}
|
||||
|
||||
func createChatModelConfig(t testing.TB, client *codersdk.ExperimentalClient) codersdk.ChatModelConfig {
|
||||
t.Helper()
|
||||
return coderdtest.CreateOpenAICompatChatModelConfig(t, client, "")
|
||||
@@ -10529,10 +10917,12 @@ func createAdditionalChatModelConfig(
|
||||
t.Helper()
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitLong)
|
||||
aiProvider := createAIProviderForTest(t, client, provider, "test-api-key")
|
||||
contextLimit := int64(4096)
|
||||
isDefault := false
|
||||
modelConfig, err := client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
|
||||
Provider: provider,
|
||||
AIProviderID: &aiProvider.ID,
|
||||
Model: model,
|
||||
ContextLimit: &contextLimit,
|
||||
IsDefault: &isDefault,
|
||||
|
||||
@@ -84,6 +84,7 @@ const (
|
||||
SubjectTypeBoundaryUsageTracker SubjectType = "boundary_usage_tracker"
|
||||
SubjectTypeWorkspaceBuilder SubjectType = "workspace_builder"
|
||||
SubjectTypeChatd SubjectType = "chatd"
|
||||
SubjectTypeAIProviderMetadataReader SubjectType = "ai_provider_metadata_reader"
|
||||
)
|
||||
|
||||
const (
|
||||
|
||||
@@ -350,13 +350,13 @@ func (p *Server) resolveAdvisorModelOverride(
|
||||
return fallbackModel, fallbackCallConfig
|
||||
}
|
||||
|
||||
// GetEnabledChatModelConfigByID joins on chat_providers.enabled = TRUE
|
||||
// and chat_model_configs.enabled = TRUE, so it returns sql.ErrNoRows
|
||||
// the moment an admin disables either the model config or its provider.
|
||||
// Using the cached ModelConfigByID here would keep resolving an override
|
||||
// whose provider was just disabled, and an env or central fallback key
|
||||
// would let ModelFromConfig succeed, silently routing advisor prompts
|
||||
// to a provider the admin expects to be off.
|
||||
// GetEnabledChatModelConfigByID checks the model config and referenced
|
||||
// provider enabled state, so it returns sql.ErrNoRows the moment an
|
||||
// admin disables either one. Using the cached ModelConfigByID here
|
||||
// would keep resolving an override whose provider was just disabled,
|
||||
// and an available fallback key would let ModelFromConfig succeed,
|
||||
// silently routing advisor prompts to a provider the admin expects to
|
||||
// be off.
|
||||
overrideConfig, err := p.db.GetEnabledChatModelConfigByID(
|
||||
ctx,
|
||||
advisorCfg.ModelConfigID,
|
||||
|
||||
@@ -288,22 +288,7 @@ func TestSubagentChatExcludesWorkspaceProvisioningTools(t *testing.T) {
|
||||
)
|
||||
})
|
||||
|
||||
_, err := expClient.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
|
||||
Provider: "openai-compat",
|
||||
APIKey: "test-api-key",
|
||||
BaseURL: openAIURL,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
contextLimit := int64(4096)
|
||||
isDefault := true
|
||||
_, err = expClient.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
|
||||
Provider: "openai-compat",
|
||||
Model: "gpt-4o-mini",
|
||||
ContextLimit: &contextLimit,
|
||||
IsDefault: &isDefault,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
coderdtest.CreateOpenAICompatChatModelConfig(t, expClient, openAIURL)
|
||||
|
||||
// Create a root chat whose first model call will spawn a subagent.
|
||||
chat, err := expClient.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
@@ -483,22 +468,7 @@ func TestPlanModeSubagentChatExcludesAskUserQuestion(t *testing.T) {
|
||||
)
|
||||
})
|
||||
|
||||
_, err = expClient.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
|
||||
Provider: "openai-compat",
|
||||
APIKey: "test-api-key",
|
||||
BaseURL: openAIURL,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
contextLimit := int64(4096)
|
||||
isDefault := true
|
||||
_, err = expClient.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
|
||||
Provider: "openai-compat",
|
||||
Model: "gpt-4o-mini",
|
||||
ContextLimit: &contextLimit,
|
||||
IsDefault: &isDefault,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
coderdtest.CreateOpenAICompatChatModelConfig(t, expClient, openAIURL)
|
||||
|
||||
chat, err := expClient.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
OrganizationID: user.OrganizationID,
|
||||
@@ -638,24 +608,9 @@ func TestExploreSubagentIsReadOnly(t *testing.T) {
|
||||
)
|
||||
})
|
||||
|
||||
_, err := expClient.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
|
||||
Provider: "openai-compat",
|
||||
APIKey: "test-api-key",
|
||||
BaseURL: openAIURL,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
coderdtest.CreateOpenAICompatChatModelConfig(t, expClient, openAIURL)
|
||||
|
||||
contextLimit := int64(4096)
|
||||
isDefault := true
|
||||
_, err = expClient.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
|
||||
Provider: "openai-compat",
|
||||
Model: "gpt-4o-mini",
|
||||
ContextLimit: &contextLimit,
|
||||
IsDefault: &isDefault,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = expClient.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
_, err := expClient.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
OrganizationID: user.OrganizationID,
|
||||
WorkspaceID: &workspace.ID,
|
||||
Content: []codersdk.ChatInputPart{
|
||||
@@ -4953,22 +4908,7 @@ func TestCreateWorkspaceTool_EndToEnd(t *testing.T) {
|
||||
)
|
||||
})
|
||||
|
||||
_, err := expClient.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
|
||||
Provider: "openai-compat",
|
||||
APIKey: "test-api-key",
|
||||
BaseURL: openAIURL,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
contextLimit := int64(4096)
|
||||
isDefault := true
|
||||
_, err = expClient.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
|
||||
Provider: "openai-compat",
|
||||
Model: "gpt-4o-mini",
|
||||
ContextLimit: &contextLimit,
|
||||
IsDefault: &isDefault,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
coderdtest.CreateOpenAICompatChatModelConfig(t, expClient, openAIURL)
|
||||
|
||||
chat, err := expClient.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
OrganizationID: user.OrganizationID,
|
||||
@@ -5123,22 +5063,7 @@ func TestStartWorkspaceTool_EndToEnd(t *testing.T) {
|
||||
)
|
||||
})
|
||||
|
||||
_, err := expClient.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
|
||||
Provider: "openai-compat",
|
||||
APIKey: "test-api-key",
|
||||
BaseURL: openAIURL,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
contextLimit := int64(4096)
|
||||
isDefault := true
|
||||
_, err = expClient.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
|
||||
Provider: "openai-compat",
|
||||
Model: "gpt-4o-mini",
|
||||
ContextLimit: &contextLimit,
|
||||
IsDefault: &isDefault,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
coderdtest.CreateOpenAICompatChatModelConfig(t, expClient, openAIURL)
|
||||
|
||||
// Create a chat with the stopped workspace pre-associated.
|
||||
chat, err := expClient.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
@@ -8586,22 +8511,7 @@ func TestAgentContextFilesAndSkillsLoadedIntoChat(t *testing.T) {
|
||||
)
|
||||
})
|
||||
|
||||
_, err := expClient.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
|
||||
Provider: "openai-compat",
|
||||
APIKey: "test-api-key",
|
||||
BaseURL: openAIURL,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
contextLimit := int64(4096)
|
||||
isDefault := true
|
||||
_, err = expClient.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
|
||||
Provider: "openai-compat",
|
||||
Model: "gpt-4o-mini",
|
||||
ContextLimit: &contextLimit,
|
||||
IsDefault: &isDefault,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
coderdtest.CreateOpenAICompatChatModelConfig(t, expClient, openAIURL)
|
||||
|
||||
workspaceID := workspace.ID
|
||||
chat, err := expClient.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/coderd/coderdtest"
|
||||
@@ -13,6 +14,52 @@ import (
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
func createIntegrationAIProvider(
|
||||
ctx context.Context,
|
||||
t testing.TB,
|
||||
client *codersdk.ExperimentalClient,
|
||||
providerType codersdk.AIProviderType,
|
||||
apiKey string,
|
||||
baseURL string,
|
||||
) codersdk.AIProvider {
|
||||
t.Helper()
|
||||
if baseURL == "" {
|
||||
baseURL = defaultIntegrationAIProviderBaseURL(providerType)
|
||||
}
|
||||
provider, err := client.CreateAIProvider(ctx, codersdk.CreateAIProviderRequest{
|
||||
Type: providerType,
|
||||
Name: string(providerType) + "-" + uuid.NewString(),
|
||||
DisplayName: aiProviderDisplayName(providerType),
|
||||
Enabled: true,
|
||||
BaseURL: baseURL,
|
||||
APIKeys: []string{apiKey},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
return provider
|
||||
}
|
||||
|
||||
func defaultIntegrationAIProviderBaseURL(providerType codersdk.AIProviderType) string {
|
||||
switch providerType {
|
||||
case codersdk.AIProviderTypeAnthropic:
|
||||
return "https://api.anthropic.com"
|
||||
case codersdk.AIProviderTypeOpenAI:
|
||||
return "https://api.openai.com/v1"
|
||||
default:
|
||||
return "https://api.example.com"
|
||||
}
|
||||
}
|
||||
|
||||
func aiProviderDisplayName(providerType codersdk.AIProviderType) string {
|
||||
switch providerType {
|
||||
case codersdk.AIProviderTypeAnthropic:
|
||||
return "Anthropic"
|
||||
case codersdk.AIProviderTypeOpenAI:
|
||||
return "OpenAI"
|
||||
default:
|
||||
return string(providerType)
|
||||
}
|
||||
}
|
||||
|
||||
// TestAnthropicWebSearchRoundTrip is an integration test that verifies
|
||||
// provider-executed tool results (web_search) survive the full
|
||||
// persist → reconstruct → re-send cycle. It sends a query that
|
||||
@@ -43,19 +90,16 @@ func TestAnthropicWebSearchRoundTrip(t *testing.T) {
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
expClient := codersdk.NewExperimentalClient(client)
|
||||
|
||||
// Configure an Anthropic provider with the real API key.
|
||||
_, err := expClient.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
|
||||
Provider: "anthropic",
|
||||
APIKey: apiKey,
|
||||
BaseURL: baseURL,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
provider := createIntegrationAIProvider(
|
||||
ctx, t, expClient, codersdk.AIProviderTypeAnthropic, apiKey, baseURL,
|
||||
)
|
||||
|
||||
// Create a model config that enables web_search.
|
||||
contextLimit := int64(200000)
|
||||
isDefault := true
|
||||
_, err = expClient.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
|
||||
Provider: "anthropic",
|
||||
_, err := expClient.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
|
||||
Provider: string(provider.Type),
|
||||
AIProviderID: &provider.ID,
|
||||
Model: "claude-sonnet-4-20250514",
|
||||
ContextLimit: &contextLimit,
|
||||
IsDefault: &isDefault,
|
||||
@@ -303,13 +347,9 @@ func TestOpenAIReasoningRoundTrip(t *testing.T) {
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
expClient := codersdk.NewExperimentalClient(client)
|
||||
|
||||
// Configure an OpenAI provider with the real API key.
|
||||
_, err := expClient.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
|
||||
Provider: "openai",
|
||||
APIKey: apiKey,
|
||||
BaseURL: baseURL,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
provider := createIntegrationAIProvider(
|
||||
ctx, t, expClient, codersdk.AIProviderTypeOpenAI, apiKey, baseURL,
|
||||
)
|
||||
|
||||
// Create a model config for a reasoning model with Store: true
|
||||
// (the default). Using o4-mini because it always produces
|
||||
@@ -317,8 +357,9 @@ func TestOpenAIReasoningRoundTrip(t *testing.T) {
|
||||
contextLimit := int64(200000)
|
||||
isDefault := true
|
||||
reasoningSummary := "auto"
|
||||
_, err = expClient.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
|
||||
Provider: "openai",
|
||||
_, err := expClient.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
|
||||
Provider: string(provider.Type),
|
||||
AIProviderID: &provider.ID,
|
||||
Model: "o4-mini",
|
||||
ContextLimit: &contextLimit,
|
||||
IsDefault: &isDefault,
|
||||
@@ -457,21 +498,18 @@ func TestOpenAIReasoningRoundTripStoreFalse(t *testing.T) {
|
||||
user := coderdtest.CreateFirstUser(t, client)
|
||||
expClient := codersdk.NewExperimentalClient(client)
|
||||
|
||||
// Configure an OpenAI provider with the real API key.
|
||||
_, err := expClient.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
|
||||
Provider: "openai",
|
||||
APIKey: apiKey,
|
||||
BaseURL: baseURL,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
provider := createIntegrationAIProvider(
|
||||
ctx, t, expClient, codersdk.AIProviderTypeOpenAI, apiKey, baseURL,
|
||||
)
|
||||
|
||||
// Create a model config for a reasoning model with Store: false.
|
||||
// Using o4-mini because it always produces reasoning items.
|
||||
contextLimit := int64(200000)
|
||||
isDefault := true
|
||||
reasoningSummary := "auto"
|
||||
_, err = expClient.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
|
||||
Provider: "openai",
|
||||
_, err := expClient.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
|
||||
Provider: string(provider.Type),
|
||||
AIProviderID: &provider.ID,
|
||||
Model: "o4-mini",
|
||||
ContextLimit: &contextLimit,
|
||||
IsDefault: &isDefault,
|
||||
|
||||
+10
-9
@@ -218,9 +218,7 @@ func (req CreateAIProviderRequest) Validate() []ValidationError {
|
||||
})
|
||||
}
|
||||
validations = append(validations, validateAIProviderName(req.Name)...)
|
||||
if req.BaseURL != "" {
|
||||
validations = append(validations, validateAIProviderBaseURL(req.BaseURL)...)
|
||||
}
|
||||
validations = append(validations, validateRequiredAIProviderBaseURL(req.BaseURL)...)
|
||||
validations = append(validations, validateAIProviderAPIKeys(req.APIKeys)...)
|
||||
if req.Settings.Bedrock != nil && req.Type != AIProviderTypeAnthropic {
|
||||
validations = append(validations, ValidationError{
|
||||
@@ -264,12 +262,8 @@ type AIProviderKeyMutation struct {
|
||||
// should reject empty patches with IsEmpty before invoking Validate.
|
||||
func (req UpdateAIProviderRequest) Validate() []ValidationError {
|
||||
var validations []ValidationError
|
||||
switch {
|
||||
case req.BaseURL == nil:
|
||||
case *req.BaseURL == "":
|
||||
validations = append(validations, ValidationError{Field: "base_url", Detail: "base_url cannot be empty"})
|
||||
default:
|
||||
validations = append(validations, validateAIProviderBaseURL(*req.BaseURL)...)
|
||||
if req.BaseURL != nil {
|
||||
validations = append(validations, validateRequiredAIProviderBaseURL(*req.BaseURL)...)
|
||||
}
|
||||
if req.APIKeys != nil {
|
||||
validations = append(validations, validateAIProviderKeyMutations(*req.APIKeys)...)
|
||||
@@ -296,6 +290,13 @@ func validateAIProviderName(name string) []ValidationError {
|
||||
return validations
|
||||
}
|
||||
|
||||
func validateRequiredAIProviderBaseURL(raw string) []ValidationError {
|
||||
if raw == "" {
|
||||
return []ValidationError{{Field: "base_url", Detail: "base_url is required"}}
|
||||
}
|
||||
return validateAIProviderBaseURL(raw)
|
||||
}
|
||||
|
||||
func validateAIProviderBaseURL(raw string) []ValidationError {
|
||||
var validations []ValidationError
|
||||
parsed, err := url.Parse(raw)
|
||||
|
||||
+75
-1
@@ -1120,6 +1120,31 @@ type UpdateChatProviderConfigRequest struct {
|
||||
AllowCentralAPIKeyFallback *bool `json:"allow_central_api_key_fallback,omitempty"`
|
||||
}
|
||||
|
||||
// AIProviderSummary is provider metadata embedded in other API responses.
|
||||
type AIProviderSummary struct {
|
||||
ID uuid.UUID `json:"id" format:"uuid"`
|
||||
Type AIProviderType `json:"type"`
|
||||
Name string `json:"name"`
|
||||
DisplayName string `json:"display_name"`
|
||||
Enabled bool `json:"enabled"`
|
||||
Deleted bool `json:"deleted"`
|
||||
}
|
||||
|
||||
// UserAIProviderKeyConfig is a provider summary from the current user's
|
||||
// perspective. It reports key presence but never returns key material.
|
||||
type UserAIProviderKeyConfig struct {
|
||||
Provider AIProviderSummary `json:"provider"`
|
||||
HasUserAPIKey bool `json:"has_user_api_key"`
|
||||
HasProviderAPIKey bool `json:"has_provider_api_key"`
|
||||
BYOKEnabled bool `json:"byok_enabled"`
|
||||
}
|
||||
|
||||
// CreateUserAIProviderKeyRequest creates or replaces a user's API key
|
||||
// for an AI provider.
|
||||
type CreateUserAIProviderKeyRequest struct {
|
||||
APIKey string `json:"api_key"`
|
||||
}
|
||||
|
||||
// UserChatProviderConfig is a summary of a provider that allows
|
||||
// user-supplied keys, as seen from the current user's perspective.
|
||||
type UserChatProviderConfig struct {
|
||||
@@ -1128,6 +1153,7 @@ type UserChatProviderConfig struct {
|
||||
DisplayName string `json:"display_name"`
|
||||
HasUserAPIKey bool `json:"has_user_api_key"`
|
||||
HasCentralAPIKeyFallback bool `json:"has_central_api_key_fallback"`
|
||||
BYOKEnabled bool `json:"byok_enabled"`
|
||||
}
|
||||
|
||||
// CreateUserChatProviderKeyRequest creates or replaces a user's API key
|
||||
@@ -1140,6 +1166,7 @@ type CreateUserChatProviderKeyRequest struct {
|
||||
type ChatModelConfig struct {
|
||||
ID uuid.UUID `json:"id" format:"uuid"`
|
||||
Provider string `json:"provider"`
|
||||
AIProviderID *uuid.UUID `json:"ai_provider_id,omitempty" format:"uuid"`
|
||||
Model string `json:"model"`
|
||||
DisplayName string `json:"display_name"`
|
||||
Enabled bool `json:"enabled"`
|
||||
@@ -1349,7 +1376,8 @@ func (c *ChatModelCallConfig) UnmarshalJSON(data []byte) error {
|
||||
|
||||
// CreateChatModelConfigRequest creates a chat model config.
|
||||
type CreateChatModelConfigRequest struct {
|
||||
Provider string `json:"provider"`
|
||||
Provider string `json:"provider,omitempty"`
|
||||
AIProviderID *uuid.UUID `json:"ai_provider_id,omitempty" format:"uuid"`
|
||||
Model string `json:"model"`
|
||||
DisplayName string `json:"display_name,omitempty"`
|
||||
Enabled *bool `json:"enabled,omitempty"`
|
||||
@@ -1362,6 +1390,7 @@ type CreateChatModelConfigRequest struct {
|
||||
// UpdateChatModelConfigRequest updates a chat model config.
|
||||
type UpdateChatModelConfigRequest struct {
|
||||
Provider string `json:"provider,omitempty"`
|
||||
AIProviderID *uuid.UUID `json:"ai_provider_id,omitempty" format:"uuid"`
|
||||
Model string `json:"model,omitempty"`
|
||||
DisplayName string `json:"display_name,omitempty"`
|
||||
Enabled *bool `json:"enabled,omitempty"`
|
||||
@@ -2116,6 +2145,51 @@ func (c *ExperimentalClient) DeleteChatProvider(ctx context.Context, providerID
|
||||
return nil
|
||||
}
|
||||
|
||||
// ListUserAIProviderKeyConfigs returns user-scoped AI provider key configs.
|
||||
func (c *ExperimentalClient) ListUserAIProviderKeyConfigs(ctx context.Context, user string) ([]UserAIProviderKeyConfig, error) {
|
||||
res, err := c.Request(ctx, http.MethodGet, userAIProviderKeysPath(user), nil)
|
||||
if err != nil {
|
||||
return nil, xerrors.Errorf("list user AI provider key configs: %w", err)
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != http.StatusOK {
|
||||
return nil, ReadBodyAsError(res)
|
||||
}
|
||||
var configs []UserAIProviderKeyConfig
|
||||
return configs, json.NewDecoder(res.Body).Decode(&configs)
|
||||
}
|
||||
|
||||
// UpsertUserAIProviderKey creates or replaces a user API key for an AI provider.
|
||||
func (c *ExperimentalClient) UpsertUserAIProviderKey(ctx context.Context, user string, providerID uuid.UUID, req CreateUserAIProviderKeyRequest) (UserAIProviderKeyConfig, error) {
|
||||
res, err := c.Request(ctx, http.MethodPut, fmt.Sprintf("%s/%s", userAIProviderKeysPath(user), providerID), req)
|
||||
if err != nil {
|
||||
return UserAIProviderKeyConfig{}, xerrors.Errorf("upsert user AI provider key: %w", err)
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != http.StatusOK {
|
||||
return UserAIProviderKeyConfig{}, ReadBodyAsError(res)
|
||||
}
|
||||
var config UserAIProviderKeyConfig
|
||||
return config, json.NewDecoder(res.Body).Decode(&config)
|
||||
}
|
||||
|
||||
// DeleteUserAIProviderKey deletes a user API key for an AI provider.
|
||||
func (c *ExperimentalClient) DeleteUserAIProviderKey(ctx context.Context, user string, providerID uuid.UUID) error {
|
||||
res, err := c.Request(ctx, http.MethodDelete, fmt.Sprintf("%s/%s", userAIProviderKeysPath(user), providerID), nil)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("delete user AI provider key: %w", err)
|
||||
}
|
||||
defer res.Body.Close()
|
||||
if res.StatusCode != http.StatusNoContent {
|
||||
return ReadBodyAsError(res)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func userAIProviderKeysPath(user string) string {
|
||||
return fmt.Sprintf("/api/experimental/users/%s/ai-provider-keys", url.PathEscape(user))
|
||||
}
|
||||
|
||||
// ListUserChatProviderConfigs returns user-scoped chat provider configs.
|
||||
func (c *ExperimentalClient) ListUserChatProviderConfigs(ctx context.Context) ([]UserChatProviderConfig, error) {
|
||||
res, err := c.Request(ctx, http.MethodGet, "/api/experimental/chats/user-provider-configs", nil)
|
||||
|
||||
@@ -24,6 +24,47 @@ import (
|
||||
"github.com/coder/websocket"
|
||||
)
|
||||
|
||||
func createOpenAIProviderForTest(
|
||||
ctx context.Context,
|
||||
t testing.TB,
|
||||
client *codersdk.ExperimentalClient,
|
||||
apiKey string,
|
||||
baseURL string,
|
||||
) codersdk.AIProvider {
|
||||
t.Helper()
|
||||
provider, err := client.CreateAIProvider(ctx, codersdk.CreateAIProviderRequest{
|
||||
Type: codersdk.AIProviderTypeOpenAI,
|
||||
Name: "openai-" + uuid.NewString(),
|
||||
DisplayName: "OpenAI",
|
||||
Enabled: true,
|
||||
BaseURL: baseURL,
|
||||
APIKeys: []string{apiKey},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
return provider
|
||||
}
|
||||
|
||||
func createOpenAIModelConfigForTest(
|
||||
ctx context.Context,
|
||||
t testing.TB,
|
||||
client *codersdk.ExperimentalClient,
|
||||
apiKey string,
|
||||
baseURL string,
|
||||
) codersdk.ChatModelConfig {
|
||||
t.Helper()
|
||||
provider := createOpenAIProviderForTest(ctx, t, client, apiKey, baseURL)
|
||||
model, err := client.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
|
||||
Provider: string(provider.Type),
|
||||
AIProviderID: &provider.ID,
|
||||
Model: "gpt-4",
|
||||
DisplayName: "GPT-4",
|
||||
ContextLimit: ptr.Ref(int64(1000)),
|
||||
CompressionThreshold: ptr.Ref(int32(70)),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
return model
|
||||
}
|
||||
|
||||
func TestChatStreamRelay(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
@@ -74,27 +115,11 @@ func TestChatStreamRelay(t *testing.T) {
|
||||
return chattest.OpenAINonStreamingResponse("ok")
|
||||
})
|
||||
|
||||
//nolint:gocritic // Test uses owner client to configure chat providers.
|
||||
provider, err := codersdk.NewExperimentalClient(firstClient).CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
|
||||
Provider: "openai",
|
||||
DisplayName: "OpenAI",
|
||||
APIKey: "test",
|
||||
BaseURL: openai,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, codersdk.ChatProviderConfigSourceDatabase, provider.Source)
|
||||
|
||||
model, err := codersdk.NewExperimentalClient(firstClient).CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
|
||||
Provider: provider.Provider,
|
||||
Model: "gpt-4",
|
||||
DisplayName: "GPT-4",
|
||||
ContextLimit: &[]int64{1000}[0],
|
||||
CompressionThreshold: &[]int32{70}[0],
|
||||
})
|
||||
require.NoError(t, err)
|
||||
expClient := codersdk.NewExperimentalClient(firstClient)
|
||||
model := createOpenAIModelConfigForTest(ctx, t, expClient, "test", openai)
|
||||
|
||||
// Create a chat on the first replica
|
||||
chat, err := codersdk.NewExperimentalClient(firstClient).CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
chat, err := expClient.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
OrganizationID: firstUser.OrganizationID,
|
||||
Content: []codersdk.ChatInputPart{{
|
||||
Type: codersdk.ChatInputPartTypeText,
|
||||
@@ -264,26 +289,11 @@ func TestChatStreamRelay(t *testing.T) {
|
||||
return chattest.OpenAINonStreamingResponse("ok")
|
||||
})
|
||||
|
||||
//nolint:gocritic // Test uses owner client to configure chat providers.
|
||||
provider, err := codersdk.NewExperimentalClient(firstClient).CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
|
||||
Provider: "openai",
|
||||
DisplayName: "OpenAI",
|
||||
APIKey: "test",
|
||||
BaseURL: openai,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
model, err := codersdk.NewExperimentalClient(firstClient).CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
|
||||
Provider: provider.Provider,
|
||||
Model: "gpt-4",
|
||||
DisplayName: "GPT-4",
|
||||
ContextLimit: &[]int64{1000}[0],
|
||||
CompressionThreshold: &[]int32{70}[0],
|
||||
})
|
||||
require.NoError(t, err)
|
||||
expClient := codersdk.NewExperimentalClient(firstClient)
|
||||
model := createOpenAIModelConfigForTest(ctx, t, expClient, "test", openai)
|
||||
|
||||
// Create a chat on the first replica.
|
||||
chat, err := codersdk.NewExperimentalClient(firstClient).CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
chat, err := expClient.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
OrganizationID: firstUser.OrganizationID,
|
||||
Content: []codersdk.ChatInputPart{{
|
||||
Type: codersdk.ChatInputPartTypeText,
|
||||
@@ -435,25 +445,10 @@ func TestChatStreamRelay(t *testing.T) {
|
||||
return chattest.OpenAINonStreamingResponse("ok")
|
||||
})
|
||||
|
||||
//nolint:gocritic // Test uses owner client to configure providers.
|
||||
provider, err := codersdk.NewExperimentalClient(firstClient).CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
|
||||
Provider: "openai",
|
||||
DisplayName: "OpenAI",
|
||||
APIKey: "test",
|
||||
BaseURL: openai,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
expClient := codersdk.NewExperimentalClient(firstClient)
|
||||
model := createOpenAIModelConfigForTest(ctx, t, expClient, "test", openai)
|
||||
|
||||
model, err := codersdk.NewExperimentalClient(firstClient).CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
|
||||
Provider: provider.Provider,
|
||||
Model: "gpt-4",
|
||||
DisplayName: "GPT-4",
|
||||
ContextLimit: &[]int64{1000}[0],
|
||||
CompressionThreshold: &[]int32{70}[0],
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
chat, err := codersdk.NewExperimentalClient(firstClient).CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
chat, err := expClient.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
OrganizationID: firstUser.OrganizationID,
|
||||
Content: []codersdk.ChatInputPart{{
|
||||
Type: codersdk.ChatInputPartTypeText,
|
||||
@@ -607,25 +602,10 @@ func TestChatStreamRelay(t *testing.T) {
|
||||
return chattest.OpenAINonStreamingResponse("ok")
|
||||
})
|
||||
|
||||
//nolint:gocritic // Test uses owner client to configure providers.
|
||||
provider, err := codersdk.NewExperimentalClient(firstClient).CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
|
||||
Provider: "openai",
|
||||
DisplayName: "OpenAI",
|
||||
APIKey: "test",
|
||||
BaseURL: openai,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
expClient := codersdk.NewExperimentalClient(firstClient)
|
||||
model := createOpenAIModelConfigForTest(ctx, t, expClient, "test", openai)
|
||||
|
||||
model, err := codersdk.NewExperimentalClient(firstClient).CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
|
||||
Provider: provider.Provider,
|
||||
Model: "gpt-4",
|
||||
DisplayName: "GPT-4",
|
||||
ContextLimit: &[]int64{1000}[0],
|
||||
CompressionThreshold: &[]int32{70}[0],
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
chat, err := codersdk.NewExperimentalClient(firstClient).CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
chat, err := expClient.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
OrganizationID: firstUser.OrganizationID,
|
||||
Content: []codersdk.ChatInputPart{{
|
||||
Type: codersdk.ChatInputPartTypeText,
|
||||
@@ -754,26 +734,11 @@ func TestChatStreamRelay(t *testing.T) {
|
||||
return chattest.OpenAINonStreamingResponse("ok")
|
||||
})
|
||||
|
||||
//nolint:gocritic // Test uses owner client to configure chat providers.
|
||||
provider, err := codersdk.NewExperimentalClient(firstClient).CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
|
||||
Provider: "openai",
|
||||
DisplayName: "OpenAI",
|
||||
APIKey: "test",
|
||||
BaseURL: openai,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
model, err := codersdk.NewExperimentalClient(firstClient).CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
|
||||
Provider: provider.Provider,
|
||||
Model: "gpt-4",
|
||||
DisplayName: "GPT-4",
|
||||
ContextLimit: &[]int64{1000}[0],
|
||||
CompressionThreshold: &[]int32{70}[0],
|
||||
})
|
||||
require.NoError(t, err)
|
||||
expClient := codersdk.NewExperimentalClient(firstClient)
|
||||
model := createOpenAIModelConfigForTest(ctx, t, expClient, "test", openai)
|
||||
|
||||
// Create a chat on the first replica.
|
||||
chat, err := codersdk.NewExperimentalClient(firstClient).CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
chat, err := expClient.CreateChat(ctx, codersdk.CreateChatRequest{
|
||||
OrganizationID: firstUser.OrganizationID,
|
||||
Content: []codersdk.ChatInputPart{{
|
||||
Type: codersdk.ChatInputPartTypeText,
|
||||
@@ -954,17 +919,7 @@ func TestChatModelConfigDefault(t *testing.T) {
|
||||
client, _ := coderdenttest.New(t, nil)
|
||||
expClient := codersdk.NewExperimentalClient(client)
|
||||
|
||||
//nolint:gocritic // Test uses owner client to configure chat providers.
|
||||
provider, err := expClient.CreateChatProvider(
|
||||
ctx,
|
||||
codersdk.CreateChatProviderConfigRequest{
|
||||
Provider: "openai",
|
||||
DisplayName: "OpenAI",
|
||||
APIKey: "test",
|
||||
BaseURL: "https://example.com",
|
||||
},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
provider := createOpenAIProviderForTest(ctx, t, expClient, "test", "https://example.com")
|
||||
|
||||
contextLimit := int64(1000)
|
||||
compressionThreshold := int32(70)
|
||||
@@ -974,7 +929,8 @@ func TestChatModelConfigDefault(t *testing.T) {
|
||||
firstModel, err := expClient.CreateChatModelConfig(
|
||||
ctx,
|
||||
codersdk.CreateChatModelConfigRequest{
|
||||
Provider: provider.Provider,
|
||||
Provider: string(provider.Type),
|
||||
AIProviderID: &provider.ID,
|
||||
Model: "gpt-5-a",
|
||||
DisplayName: "GPT 5 A",
|
||||
IsDefault: &trueValue,
|
||||
@@ -988,7 +944,8 @@ func TestChatModelConfigDefault(t *testing.T) {
|
||||
secondModel, err := expClient.CreateChatModelConfig(
|
||||
ctx,
|
||||
codersdk.CreateChatModelConfigRequest{
|
||||
Provider: provider.Provider,
|
||||
Provider: string(provider.Type),
|
||||
AIProviderID: &provider.ID,
|
||||
Model: "gpt-5-b",
|
||||
DisplayName: "GPT 5 B",
|
||||
IsDefault: &trueValue,
|
||||
@@ -1115,16 +1072,10 @@ func TestCreateChatNonDefaultOrg(t *testing.T) {
|
||||
})
|
||||
expClient := codersdk.NewExperimentalClient(client)
|
||||
|
||||
// Set up a chat provider and model config.
|
||||
provider, err := expClient.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
|
||||
Provider: "openai",
|
||||
DisplayName: "OpenAI",
|
||||
APIKey: "test-key",
|
||||
BaseURL: "https://example.com",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
_, err = expClient.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
|
||||
Provider: provider.Provider,
|
||||
provider := createOpenAIProviderForTest(ctx, t, expClient, "test-key", "https://example.com")
|
||||
_, err := expClient.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
|
||||
Provider: string(provider.Type),
|
||||
AIProviderID: &provider.ID,
|
||||
Model: "gpt-4o-mini",
|
||||
DisplayName: "Test Model",
|
||||
IsDefault: ptr.Ref(true),
|
||||
@@ -1191,16 +1142,10 @@ func TestListChats_OrgAdminOnlySeesOwnChats(t *testing.T) {
|
||||
})
|
||||
expClient := codersdk.NewExperimentalClient(client)
|
||||
|
||||
// Set up a chat provider and model config.
|
||||
provider, err := expClient.CreateChatProvider(ctx, codersdk.CreateChatProviderConfigRequest{
|
||||
Provider: "openai",
|
||||
DisplayName: "OpenAI",
|
||||
APIKey: "test-key",
|
||||
BaseURL: "https://example.com",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
_, err = expClient.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
|
||||
Provider: provider.Provider,
|
||||
provider := createOpenAIProviderForTest(ctx, t, expClient, "test-key", "https://example.com")
|
||||
_, err := expClient.CreateChatModelConfig(ctx, codersdk.CreateChatModelConfigRequest{
|
||||
Provider: string(provider.Type),
|
||||
AIProviderID: &provider.ID,
|
||||
Model: "gpt-4o-mini",
|
||||
DisplayName: "Test Model",
|
||||
IsDefault: ptr.Ref(true),
|
||||
|
||||
@@ -408,6 +408,7 @@ export type DeploymentConfig = Readonly<{
|
||||
}>;
|
||||
|
||||
const chatProviderConfigsPath = "/api/experimental/chats/providers";
|
||||
const aiProviderConfigsPath = "/api/v2/ai/providers";
|
||||
const chatModelConfigsPath = "/api/experimental/chats/model-configs";
|
||||
const userChatProviderConfigsPath =
|
||||
"/api/experimental/chats/user-provider-configs";
|
||||
@@ -415,6 +416,8 @@ const userSkillsPath = (user: string) =>
|
||||
`/api/experimental/users/${encodeURIComponent(user)}/skills`;
|
||||
const userSkillPath = (user: string, name: string) =>
|
||||
`${userSkillsPath(user)}/${encodeURIComponent(name)}`;
|
||||
const userAIProviderKeysPath = (user = "me") =>
|
||||
`/api/experimental/users/${encodeURIComponent(user)}/ai-provider-keys`;
|
||||
const mcpServerConfigsPath = "/api/experimental/mcp/servers";
|
||||
|
||||
type ChatCostDateParams = {
|
||||
@@ -3313,6 +3316,66 @@ class ExperimentalApiMethods {
|
||||
return response.data;
|
||||
};
|
||||
|
||||
listAIProviders = async (): Promise<TypesGen.AIProvider[]> => {
|
||||
const response = await this.axios.get<TypesGen.AIProvider[]>(
|
||||
aiProviderConfigsPath,
|
||||
);
|
||||
return response.data;
|
||||
};
|
||||
|
||||
createAIProvider = async (
|
||||
req: TypesGen.CreateAIProviderRequest,
|
||||
): Promise<TypesGen.AIProvider> => {
|
||||
const response = await this.axios.post<TypesGen.AIProvider>(
|
||||
aiProviderConfigsPath,
|
||||
req,
|
||||
);
|
||||
return response.data;
|
||||
};
|
||||
|
||||
updateAIProvider = async (
|
||||
providerId: string,
|
||||
req: TypesGen.UpdateAIProviderRequest,
|
||||
): Promise<TypesGen.AIProvider> => {
|
||||
const response = await this.axios.patch<TypesGen.AIProvider>(
|
||||
`${aiProviderConfigsPath}/${providerId}`,
|
||||
req,
|
||||
);
|
||||
return response.data;
|
||||
};
|
||||
|
||||
deleteAIProvider = async (providerId: string): Promise<void> => {
|
||||
await this.axios.delete(`${aiProviderConfigsPath}/${providerId}`);
|
||||
};
|
||||
|
||||
getUserAIProviderKeyConfigs = async (
|
||||
user = "me",
|
||||
): Promise<TypesGen.UserAIProviderKeyConfig[]> => {
|
||||
const response = await this.axios.get<TypesGen.UserAIProviderKeyConfig[]>(
|
||||
userAIProviderKeysPath(user),
|
||||
);
|
||||
return response.data;
|
||||
};
|
||||
|
||||
upsertUserAIProviderKey = async (
|
||||
providerId: string,
|
||||
req: TypesGen.CreateUserAIProviderKeyRequest,
|
||||
user = "me",
|
||||
): Promise<TypesGen.UserAIProviderKeyConfig> => {
|
||||
const response = await this.axios.put<TypesGen.UserAIProviderKeyConfig>(
|
||||
`${userAIProviderKeysPath(user)}/${providerId}`,
|
||||
req,
|
||||
);
|
||||
return response.data;
|
||||
};
|
||||
|
||||
deleteUserAIProviderKey = async (
|
||||
providerId: string,
|
||||
user = "me",
|
||||
): Promise<void> => {
|
||||
await this.axios.delete(`${userAIProviderKeysPath(user)}/${providerId}`);
|
||||
};
|
||||
|
||||
getChatSystemPrompt =
|
||||
async (): Promise<TypesGen.ChatSystemPromptResponse> => {
|
||||
const response = await this.axios.get<TypesGen.ChatSystemPromptResponse>(
|
||||
|
||||
@@ -10,7 +10,9 @@ import {
|
||||
type CreateChatMessageRequestWithClearablePlanMode,
|
||||
} from "#/api/api";
|
||||
import type * as TypesGen from "#/api/typesGenerated";
|
||||
import { type AIProviderType, AIProviderTypes } from "#/api/typesGenerated";
|
||||
import type { UsePaginatedQueryOptions } from "#/hooks/usePaginatedQuery";
|
||||
import { formatProviderLabel } from "#/utils/aiProviders";
|
||||
import {
|
||||
projectEditedConversationIntoCache,
|
||||
reconcileEditedMessageInCache,
|
||||
@@ -1607,10 +1609,29 @@ export const chatModels = () => ({
|
||||
|
||||
const chatProviderConfigsKey = ["chat-provider-configs"] as const;
|
||||
|
||||
const toChatProviderConfig = (
|
||||
provider: TypesGen.AIProvider,
|
||||
): TypesGen.ChatProviderConfig => ({
|
||||
id: provider.id,
|
||||
provider: provider.type,
|
||||
display_name: provider.display_name || provider.type,
|
||||
enabled: provider.enabled,
|
||||
has_api_key: provider.api_keys.length > 0,
|
||||
central_api_key_enabled: true,
|
||||
allow_user_api_key: true,
|
||||
allow_central_api_key_fallback: true,
|
||||
base_url: provider.base_url,
|
||||
source: "database",
|
||||
created_at: provider.created_at,
|
||||
updated_at: provider.updated_at,
|
||||
});
|
||||
|
||||
export const chatProviderConfigs = () => ({
|
||||
queryKey: chatProviderConfigsKey,
|
||||
queryFn: (): Promise<TypesGen.ChatProviderConfig[]> =>
|
||||
API.experimental.getChatProviderConfigs(),
|
||||
queryFn: async (): Promise<TypesGen.ChatProviderConfig[]> => {
|
||||
const providers = await API.experimental.listAIProviders();
|
||||
return providers.map(toChatProviderConfig);
|
||||
},
|
||||
});
|
||||
|
||||
const chatModelConfigsKey = ["chat-model-configs"] as const;
|
||||
@@ -1627,8 +1648,17 @@ export const userChatProviderConfigsKey = [
|
||||
|
||||
export const userChatProviderConfigs = () => ({
|
||||
queryKey: userChatProviderConfigsKey,
|
||||
queryFn: (): Promise<TypesGen.UserChatProviderConfig[]> =>
|
||||
API.experimental.getUserChatProviderConfigs(),
|
||||
queryFn: async (): Promise<TypesGen.UserChatProviderConfig[]> => {
|
||||
const configs = await API.experimental.getUserAIProviderKeyConfigs();
|
||||
return configs.map((config) => ({
|
||||
provider_id: config.provider.id,
|
||||
provider: config.provider.type,
|
||||
display_name: config.provider.display_name || config.provider.type,
|
||||
has_user_api_key: config.has_user_api_key,
|
||||
byok_enabled: config.byok_enabled,
|
||||
has_central_api_key_fallback: config.has_provider_api_key,
|
||||
}));
|
||||
},
|
||||
});
|
||||
|
||||
type UpsertUserChatProviderKeyArgs = {
|
||||
@@ -1638,7 +1668,7 @@ type UpsertUserChatProviderKeyArgs = {
|
||||
|
||||
export const upsertUserChatProviderKey = (queryClient: QueryClient) => ({
|
||||
mutationFn: ({ providerConfigId, req }: UpsertUserChatProviderKeyArgs) =>
|
||||
API.experimental.upsertUserChatProviderKey(providerConfigId, req),
|
||||
API.experimental.upsertUserAIProviderKey(providerConfigId, req),
|
||||
onSuccess: async () => {
|
||||
await Promise.all([
|
||||
queryClient.invalidateQueries({
|
||||
@@ -1651,7 +1681,7 @@ export const upsertUserChatProviderKey = (queryClient: QueryClient) => ({
|
||||
|
||||
export const deleteUserChatProviderKey = (queryClient: QueryClient) => ({
|
||||
mutationFn: (providerConfigId: string) =>
|
||||
API.experimental.deleteUserChatProviderKey(providerConfigId),
|
||||
API.experimental.deleteUserAIProviderKey(providerConfigId),
|
||||
onSuccess: async () => {
|
||||
await Promise.all([
|
||||
queryClient.invalidateQueries({
|
||||
@@ -1670,9 +1700,41 @@ const invalidateChatConfigurationQueries = async (queryClient: QueryClient) => {
|
||||
]);
|
||||
};
|
||||
|
||||
const generatedAIProviderName = (provider: string): string => {
|
||||
const suffix =
|
||||
globalThis.crypto?.randomUUID?.() ??
|
||||
`${Date.now().toString(36)}-${Math.random().toString(36).slice(2, 10)}`;
|
||||
return `${provider}-${suffix}`;
|
||||
};
|
||||
|
||||
const normalizeAIProviderType = (provider: string): AIProviderType => {
|
||||
const normalized = provider.trim().toLowerCase();
|
||||
const aliased =
|
||||
normalized === "openai-compatible" || normalized === "openai_compatible"
|
||||
? "openai-compat"
|
||||
: normalized;
|
||||
const providerType = AIProviderTypes.find(
|
||||
(candidate) => candidate === aliased,
|
||||
);
|
||||
if (!providerType) {
|
||||
throw new Error(`Unsupported AI provider type "${provider}".`);
|
||||
}
|
||||
return providerType;
|
||||
};
|
||||
|
||||
export const createChatProviderConfig = (queryClient: QueryClient) => ({
|
||||
mutationFn: (req: TypesGen.CreateChatProviderConfigRequest) =>
|
||||
API.experimental.createChatProviderConfig(req),
|
||||
mutationFn: (req: TypesGen.CreateChatProviderConfigRequest) => {
|
||||
const providerType = normalizeAIProviderType(req.provider);
|
||||
const apiKey = req.api_key;
|
||||
return API.experimental.createAIProvider({
|
||||
type: providerType,
|
||||
name: generatedAIProviderName(providerType),
|
||||
display_name: req.display_name || formatProviderLabel(providerType),
|
||||
base_url: req.base_url ?? "",
|
||||
enabled: req.enabled ?? true,
|
||||
api_keys: apiKey ? [apiKey] : undefined,
|
||||
});
|
||||
},
|
||||
onSuccess: async () => {
|
||||
await invalidateChatConfigurationQueries(queryClient);
|
||||
},
|
||||
@@ -1684,11 +1746,23 @@ type UpdateChatProviderConfigMutationArgs = {
|
||||
};
|
||||
|
||||
export const updateChatProviderConfig = (queryClient: QueryClient) => ({
|
||||
mutationFn: ({
|
||||
mutationFn: async ({
|
||||
providerConfigId,
|
||||
req,
|
||||
}: UpdateChatProviderConfigMutationArgs) =>
|
||||
API.experimental.updateChatProviderConfig(providerConfigId, req),
|
||||
}: UpdateChatProviderConfigMutationArgs) => {
|
||||
const apiKey = req.api_key;
|
||||
return API.experimental.updateAIProvider(providerConfigId, {
|
||||
display_name: req.display_name,
|
||||
base_url: req.base_url,
|
||||
enabled: req.enabled,
|
||||
api_keys:
|
||||
req.api_key === undefined
|
||||
? undefined
|
||||
: apiKey
|
||||
? [{ api_key: apiKey }]
|
||||
: [],
|
||||
});
|
||||
},
|
||||
onSuccess: async () => {
|
||||
await invalidateChatConfigurationQueries(queryClient);
|
||||
},
|
||||
@@ -1696,7 +1770,7 @@ export const updateChatProviderConfig = (queryClient: QueryClient) => ({
|
||||
|
||||
export const deleteChatProviderConfig = (queryClient: QueryClient) => ({
|
||||
mutationFn: (providerConfigId: string) =>
|
||||
API.experimental.deleteChatProviderConfig(providerConfigId),
|
||||
API.experimental.deleteAIProvider(providerConfigId),
|
||||
onSuccess: async () => {
|
||||
await invalidateChatConfigurationQueries(queryClient);
|
||||
},
|
||||
|
||||
Generated
+39
-1
@@ -445,6 +445,19 @@ export interface AIProviderSettings {}
|
||||
*/
|
||||
export const AIProviderSettingsTypeBedrock = "bedrock";
|
||||
|
||||
// From codersdk/chats.go
|
||||
/**
|
||||
* AIProviderSummary is provider metadata embedded in other API responses.
|
||||
*/
|
||||
export interface AIProviderSummary {
|
||||
readonly id: string;
|
||||
readonly type: AIProviderType;
|
||||
readonly name: string;
|
||||
readonly display_name: string;
|
||||
readonly enabled: boolean;
|
||||
readonly deleted: boolean;
|
||||
}
|
||||
|
||||
// From codersdk/aiproviders.go
|
||||
export type AIProviderType =
|
||||
| "anthropic"
|
||||
@@ -2278,6 +2291,7 @@ export interface ChatModelCallConfig {
|
||||
export interface ChatModelConfig {
|
||||
readonly id: string;
|
||||
readonly provider: string;
|
||||
readonly ai_provider_id?: string;
|
||||
readonly model: string;
|
||||
readonly display_name: string;
|
||||
readonly enabled: boolean;
|
||||
@@ -3251,7 +3265,8 @@ export interface CreateChatMessageResponse {
|
||||
* CreateChatModelConfigRequest creates a chat model config.
|
||||
*/
|
||||
export interface CreateChatModelConfigRequest {
|
||||
readonly provider: string;
|
||||
readonly provider?: string;
|
||||
readonly ai_provider_id?: string;
|
||||
readonly model: string;
|
||||
readonly display_name?: string;
|
||||
readonly enabled?: boolean;
|
||||
@@ -3582,6 +3597,15 @@ export interface CreateTokenRequest {
|
||||
readonly allow_list?: readonly APIAllowListTarget[];
|
||||
}
|
||||
|
||||
// From codersdk/chats.go
|
||||
/**
|
||||
* CreateUserAIProviderKeyRequest creates or replaces a user's API key
|
||||
* for an AI provider.
|
||||
*/
|
||||
export interface CreateUserAIProviderKeyRequest {
|
||||
readonly api_key: string;
|
||||
}
|
||||
|
||||
// From codersdk/chats.go
|
||||
/**
|
||||
* CreateUserChatProviderKeyRequest creates or replaces a user's API key
|
||||
@@ -8449,6 +8473,7 @@ export interface UpdateChatDesktopEnabledRequest {
|
||||
*/
|
||||
export interface UpdateChatModelConfigRequest {
|
||||
readonly provider?: string;
|
||||
readonly ai_provider_id?: string;
|
||||
readonly model?: string;
|
||||
readonly display_name?: string;
|
||||
readonly enabled?: boolean;
|
||||
@@ -9070,6 +9095,18 @@ export interface User extends ReducedUser {
|
||||
readonly has_ai_seat: boolean;
|
||||
}
|
||||
|
||||
// From codersdk/chats.go
|
||||
/**
|
||||
* UserAIProviderKeyConfig is a provider summary from the current user's
|
||||
* perspective. It reports key presence but never returns key material.
|
||||
*/
|
||||
export interface UserAIProviderKeyConfig {
|
||||
readonly provider: AIProviderSummary;
|
||||
readonly has_user_api_key: boolean;
|
||||
readonly has_provider_api_key: boolean;
|
||||
readonly byok_enabled: boolean;
|
||||
}
|
||||
|
||||
// From codersdk/insights.go
|
||||
/**
|
||||
* UserActivity shows the session time for a user.
|
||||
@@ -9196,6 +9233,7 @@ export interface UserChatProviderConfig {
|
||||
readonly display_name: string;
|
||||
readonly has_user_api_key: boolean;
|
||||
readonly has_central_api_key_fallback: boolean;
|
||||
readonly byok_enabled: boolean;
|
||||
}
|
||||
|
||||
// From codersdk/insights.go
|
||||
|
||||
@@ -18,6 +18,7 @@ const createProvider = (
|
||||
display_name: overrides.display_name ?? overrides.provider,
|
||||
has_user_api_key: overrides.has_user_api_key ?? false,
|
||||
has_central_api_key_fallback: overrides.has_central_api_key_fallback ?? false,
|
||||
byok_enabled: overrides.byok_enabled ?? true,
|
||||
});
|
||||
|
||||
const createModel = (
|
||||
|
||||
@@ -24,6 +24,14 @@ type ProviderStatus = {
|
||||
const getProviderStatus = (
|
||||
provider: UserChatProviderConfig,
|
||||
): ProviderStatus => {
|
||||
if (!provider.byok_enabled) {
|
||||
return {
|
||||
label: "User keys disabled",
|
||||
variant: "default",
|
||||
note: "Personal API keys are disabled by your admin.",
|
||||
};
|
||||
}
|
||||
|
||||
if (provider.has_user_api_key) {
|
||||
return {
|
||||
label: "Key saved",
|
||||
@@ -55,11 +63,13 @@ interface ProviderKeyPanelProps {
|
||||
isRemoving: boolean;
|
||||
onSave: (providerConfigId: string, apiKey: string) => void;
|
||||
onRemove: (providerConfigId: string) => void;
|
||||
hasAmbiguousProviderType: boolean;
|
||||
}
|
||||
|
||||
const ProviderKeyPanel: FC<ProviderKeyPanelProps> = ({
|
||||
provider,
|
||||
models,
|
||||
hasAmbiguousProviderType,
|
||||
isModelsLoading,
|
||||
areModelsUnavailable,
|
||||
isSaving,
|
||||
@@ -76,15 +86,26 @@ const ProviderKeyPanel: FC<ProviderKeyPanelProps> = ({
|
||||
|
||||
const status = getProviderStatus(provider);
|
||||
const enabledModels = models.filter((model) => {
|
||||
return model.enabled && model.provider === provider.provider;
|
||||
return (
|
||||
model.enabled &&
|
||||
(model.ai_provider_id === provider.provider_id ||
|
||||
(!model.ai_provider_id &&
|
||||
!hasAmbiguousProviderType &&
|
||||
model.provider === provider.provider))
|
||||
);
|
||||
});
|
||||
const trimmedApiKey = apiKey.trim();
|
||||
const hasApiKeyValue = apiKey.trim().length > 0;
|
||||
const hasAPIKeyWhitespace =
|
||||
apiKey !== API_KEY_PLACEHOLDER && apiKey.trim() !== apiKey;
|
||||
const saveDisabled =
|
||||
trimmedApiKey.length === 0 ||
|
||||
!provider.byok_enabled ||
|
||||
!hasApiKeyValue ||
|
||||
hasAPIKeyWhitespace ||
|
||||
apiKey === API_KEY_PLACEHOLDER ||
|
||||
isSaving ||
|
||||
isRemoving;
|
||||
const inputDisabled = isSaving || isRemoving;
|
||||
const inputDisabled = !provider.byok_enabled || isSaving || isRemoving;
|
||||
const removeDisabled = isSaving || isRemoving;
|
||||
const providerName = provider.display_name || provider.provider;
|
||||
|
||||
const handleApiKeyFocus = () => {
|
||||
@@ -101,7 +122,7 @@ const ProviderKeyPanel: FC<ProviderKeyPanelProps> = ({
|
||||
return;
|
||||
}
|
||||
|
||||
onSave(provider.provider_id, trimmedApiKey);
|
||||
onSave(provider.provider_id, apiKey);
|
||||
};
|
||||
|
||||
const handleRemoveKey = () => {
|
||||
@@ -136,25 +157,32 @@ const ProviderKeyPanel: FC<ProviderKeyPanelProps> = ({
|
||||
API Key
|
||||
</label>
|
||||
<div className="flex flex-col gap-3 lg:flex-row lg:items-start">
|
||||
<Input
|
||||
id={apiKeyInputId}
|
||||
name={`provider-api-key-${provider.provider_id}`}
|
||||
type="password"
|
||||
autoComplete="off"
|
||||
data-1p-ignore
|
||||
data-lpignore="true"
|
||||
data-form-type="other"
|
||||
data-bwignore
|
||||
className="h-9 font-mono text-[13px] lg:flex-1"
|
||||
placeholder="sk-..."
|
||||
value={apiKey}
|
||||
onFocus={handleApiKeyFocus}
|
||||
onChange={(event) => {
|
||||
setApiKey(event.target.value);
|
||||
setApiKeyTouched(true);
|
||||
}}
|
||||
disabled={inputDisabled}
|
||||
/>
|
||||
<div className="flex flex-col gap-1.5 lg:flex-1">
|
||||
<Input
|
||||
id={apiKeyInputId}
|
||||
name={`provider-api-key-${provider.provider_id}`}
|
||||
type="password"
|
||||
autoComplete="off"
|
||||
data-1p-ignore
|
||||
data-lpignore="true"
|
||||
data-form-type="other"
|
||||
data-bwignore
|
||||
className="h-9 font-mono text-[13px]"
|
||||
placeholder="sk-..."
|
||||
value={apiKey}
|
||||
onFocus={handleApiKeyFocus}
|
||||
onChange={(event) => {
|
||||
setApiKey(event.target.value);
|
||||
setApiKeyTouched(true);
|
||||
}}
|
||||
disabled={inputDisabled}
|
||||
/>
|
||||
{hasAPIKeyWhitespace && (
|
||||
<p className="m-0 text-xs text-content-destructive">
|
||||
API key must not contain leading or trailing whitespace.
|
||||
</p>
|
||||
)}
|
||||
</div>
|
||||
<div className="flex items-center gap-2">
|
||||
<Button type="submit" size="sm" disabled={saveDisabled}>
|
||||
Save
|
||||
@@ -165,7 +193,7 @@ const ProviderKeyPanel: FC<ProviderKeyPanelProps> = ({
|
||||
variant="outline"
|
||||
size="sm"
|
||||
onClick={() => setIsDeleteDialogOpen(true)}
|
||||
disabled={inputDisabled}
|
||||
disabled={removeDisabled}
|
||||
>
|
||||
Remove
|
||||
</Button>
|
||||
@@ -245,6 +273,14 @@ export const AgentSettingsAPIKeysPageView: FC<
|
||||
onSave,
|
||||
onRemove,
|
||||
}) => {
|
||||
const providerTypeCounts = new Map<string, number>();
|
||||
for (const item of providerItems) {
|
||||
providerTypeCounts.set(
|
||||
item.provider.provider,
|
||||
(providerTypeCounts.get(item.provider.provider) ?? 0) + 1,
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div>
|
||||
<section className="flex flex-col gap-8">
|
||||
@@ -275,6 +311,9 @@ export const AgentSettingsAPIKeysPageView: FC<
|
||||
isRemoving={item.isRemoving}
|
||||
onSave={onSave}
|
||||
onRemove={onRemove}
|
||||
hasAmbiguousProviderType={
|
||||
(providerTypeCounts.get(item.provider.provider) ?? 0) > 1
|
||||
}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
|
||||
+123
-263
@@ -17,8 +17,6 @@ import {
|
||||
} from "./ChatModelAdminPanel";
|
||||
import { formatContextBadge, getKnownModelsForProvider } from "./knownModels";
|
||||
|
||||
// ── Helpers ────────────────────────────────────────────────────
|
||||
|
||||
const now = "2026-02-18T12:00:00.000Z";
|
||||
const nilProviderConfigID = "00000000-0000-0000-0000-000000000000";
|
||||
|
||||
@@ -88,7 +86,7 @@ const setupChatSpies = (state: {
|
||||
async (req) => {
|
||||
const created = createProviderConfig({
|
||||
id: `provider-${Date.now()}`,
|
||||
provider: req.provider,
|
||||
provider: req.provider ?? "",
|
||||
display_name: req.display_name ?? "",
|
||||
has_api_key: (req.api_key ?? "").trim().length > 0,
|
||||
central_api_key_enabled: req.central_api_key_enabled ?? true,
|
||||
@@ -99,7 +97,9 @@ const setupChatSpies = (state: {
|
||||
source: "database",
|
||||
});
|
||||
state.providerConfigs = [
|
||||
...state.providerConfigs.filter((p) => p.provider !== req.provider),
|
||||
...state.providerConfigs.filter(
|
||||
(p) => !(p.id === nilProviderConfigID && p.provider === req.provider),
|
||||
),
|
||||
created,
|
||||
];
|
||||
return created;
|
||||
@@ -152,7 +152,7 @@ const setupChatSpies = (state: {
|
||||
async (req) => {
|
||||
const created = createModelConfig({
|
||||
id: `model-${state.modelConfigs.length + 1}`,
|
||||
provider: req.provider,
|
||||
provider: req.provider ?? "",
|
||||
model: req.model,
|
||||
display_name: req.display_name || req.model,
|
||||
enabled: req.enabled ?? true,
|
||||
@@ -211,8 +211,6 @@ const setupChatSpies = (state: {
|
||||
);
|
||||
};
|
||||
|
||||
// ── Meta ───────────────────────────────────────────────────────
|
||||
|
||||
const meta: Meta<typeof ChatModelAdminPanel> = {
|
||||
title: "pages/AgentsPage/ChatModelAdminPanel",
|
||||
component: ChatModelAdminPanel,
|
||||
@@ -224,7 +222,7 @@ const meta: Meta<typeof ChatModelAdminPanel> = {
|
||||
providerConfigsError: null,
|
||||
modelConfigsError: null,
|
||||
modelCatalogError: null,
|
||||
onCreateProvider: fn(async () => ({})),
|
||||
onCreateProvider: fn(async () => ({ id: "" })),
|
||||
onUpdateProvider: fn(async () => ({})),
|
||||
onDeleteProvider: fn(async () => undefined),
|
||||
isProviderMutationPending: false,
|
||||
@@ -242,8 +240,6 @@ const meta: Meta<typeof ChatModelAdminPanel> = {
|
||||
export default meta;
|
||||
type Story = StoryObj<typeof ChatModelAdminPanel>;
|
||||
|
||||
// ── Providers section stories ──────────────────────────────────
|
||||
|
||||
export const ProviderAccordionCards: Story = {
|
||||
args: {
|
||||
section: "providers" as ChatModelAdminSection,
|
||||
@@ -268,6 +264,100 @@ export const ProviderAccordionCards: Story = {
|
||||
},
|
||||
};
|
||||
|
||||
export const AddProviderFromMenu: Story = {
|
||||
render: function AddProviderFromMenu(args) {
|
||||
const [providerConfigsData, setProviderConfigsData] = useState(
|
||||
args.providerConfigsData,
|
||||
);
|
||||
|
||||
const handleCreateProvider: ChatModelAdminPanelStoryProps["onCreateProvider"] =
|
||||
async (req) => {
|
||||
const created = createProviderConfig({
|
||||
id: `provider-${req.provider}`,
|
||||
provider: req.provider ?? "",
|
||||
display_name: req.display_name ?? "",
|
||||
has_api_key: (req.api_key ?? "").trim().length > 0,
|
||||
central_api_key_enabled: req.central_api_key_enabled ?? true,
|
||||
allow_user_api_key: req.allow_user_api_key ?? true,
|
||||
allow_central_api_key_fallback:
|
||||
req.allow_central_api_key_fallback ?? true,
|
||||
base_url: req.base_url ?? "",
|
||||
source: "database",
|
||||
});
|
||||
await args.onCreateProvider(req);
|
||||
setProviderConfigsData((current) => [...(current ?? []), created]);
|
||||
return created;
|
||||
};
|
||||
|
||||
return (
|
||||
<ChatModelAdminPanel
|
||||
{...args}
|
||||
providerConfigsData={providerConfigsData}
|
||||
onCreateProvider={handleCreateProvider}
|
||||
/>
|
||||
);
|
||||
},
|
||||
args: {
|
||||
section: "providers" as ChatModelAdminSection,
|
||||
sectionLabel: "Providers",
|
||||
sectionDescription:
|
||||
"Connect third-party LLM services like OpenAI, Anthropic, or Google.",
|
||||
providerConfigsData: [
|
||||
createProviderConfig({
|
||||
id: "provider-anthropic",
|
||||
provider: "anthropic",
|
||||
display_name: "Anthropic Migration Test",
|
||||
has_api_key: true,
|
||||
allow_user_api_key: true,
|
||||
}),
|
||||
createProviderConfig({
|
||||
id: "provider-openai",
|
||||
provider: "openai",
|
||||
display_name: "OpenAI Migration Test",
|
||||
has_api_key: true,
|
||||
allow_user_api_key: true,
|
||||
}),
|
||||
createProviderConfig({
|
||||
id: "provider-openai-compatible",
|
||||
provider: "openai-compat",
|
||||
display_name: "OpenAI Compatible Migration Test",
|
||||
has_api_key: true,
|
||||
allow_user_api_key: true,
|
||||
}),
|
||||
],
|
||||
modelCatalogData: { providers: [] },
|
||||
},
|
||||
play: async ({ canvasElement, args }) => {
|
||||
const body = within(canvasElement.ownerDocument.body);
|
||||
|
||||
await userEvent.click(
|
||||
await body.findByRole("button", { name: "Add provider" }),
|
||||
);
|
||||
await userEvent.click(
|
||||
await body.findByRole("menuitem", { name: /Google/i }),
|
||||
);
|
||||
|
||||
expect(await body.findByLabelText(/^API Key$/i)).toBeInTheDocument();
|
||||
await userEvent.click(
|
||||
await body.findByRole("button", { name: "Create provider config" }),
|
||||
);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(args.onCreateProvider).toHaveBeenCalledWith(
|
||||
expect.objectContaining({ provider: "google" }),
|
||||
);
|
||||
});
|
||||
await waitFor(() => {
|
||||
expect(body.getByRole("button", { name: "Save changes" })).toBeDisabled();
|
||||
});
|
||||
|
||||
await userEvent.click(body.getByText("Back"));
|
||||
expect(
|
||||
await body.findByRole("button", { name: "Google" }),
|
||||
).toBeInTheDocument();
|
||||
},
|
||||
};
|
||||
|
||||
export const EnvPresetProviders: Story = {
|
||||
args: {
|
||||
section: "providers" as ChatModelAdminSection,
|
||||
@@ -346,10 +436,9 @@ export const CreateAndUpdateProvider: Story = {
|
||||
|
||||
const handleCreateProvider: ChatModelAdminPanelStoryProps["onCreateProvider"] =
|
||||
async (req) => {
|
||||
const result = await args.onCreateProvider(req);
|
||||
const created = createProviderConfig({
|
||||
id: `provider-${Date.now()}`,
|
||||
provider: req.provider,
|
||||
provider: req.provider ?? "",
|
||||
display_name: req.display_name ?? "",
|
||||
has_api_key: (req.api_key ?? "").trim().length > 0,
|
||||
central_api_key_enabled: req.central_api_key_enabled ?? true,
|
||||
@@ -359,11 +448,15 @@ export const CreateAndUpdateProvider: Story = {
|
||||
base_url: req.base_url ?? "",
|
||||
source: "database",
|
||||
});
|
||||
await args.onCreateProvider(req);
|
||||
setProviderConfigsData((current) => [
|
||||
...(current ?? []).filter((p) => p.provider !== req.provider),
|
||||
...(current ?? []).filter(
|
||||
(p) =>
|
||||
!(p.id === nilProviderConfigID && p.provider === req.provider),
|
||||
),
|
||||
created,
|
||||
]);
|
||||
return result;
|
||||
return created;
|
||||
};
|
||||
|
||||
const handleUpdateProvider: ChatModelAdminPanelStoryProps["onUpdateProvider"] =
|
||||
@@ -446,24 +539,13 @@ export const CreateAndUpdateProvider: Story = {
|
||||
|
||||
await userEvent.click(await body.findByRole("button", { name: /OpenAI/i }));
|
||||
|
||||
await expect(
|
||||
await body.findByRole("switch", { name: "Central API key" }),
|
||||
).toBeChecked();
|
||||
expect(
|
||||
await body.findByRole("switch", { name: "Allow user API keys" }),
|
||||
).not.toBeChecked();
|
||||
expect(
|
||||
body.queryByRole("switch", { name: "Use central key as fallback" }),
|
||||
).not.toBeInTheDocument();
|
||||
|
||||
await userEvent.type(
|
||||
await body.findByLabelText(/^API Key$/i),
|
||||
"sk-provider-key",
|
||||
);
|
||||
await userEvent.type(
|
||||
await body.findByLabelText("Base URL"),
|
||||
"https://proxy.example.com/v1",
|
||||
);
|
||||
const createBaseURLInput = await body.findByLabelText("Base URL");
|
||||
await userEvent.clear(createBaseURLInput);
|
||||
await userEvent.type(createBaseURLInput, "https://proxy.example.com/v1");
|
||||
await userEvent.click(
|
||||
await body.findByRole("button", { name: "Create provider config" }),
|
||||
);
|
||||
@@ -479,9 +561,6 @@ export const CreateAndUpdateProvider: Story = {
|
||||
provider: "openai",
|
||||
api_key: "sk-provider-key",
|
||||
base_url: "https://proxy.example.com/v1",
|
||||
central_api_key_enabled: true,
|
||||
allow_user_api_key: false,
|
||||
allow_central_api_key_fallback: false,
|
||||
}),
|
||||
);
|
||||
|
||||
@@ -491,13 +570,6 @@ export const CreateAndUpdateProvider: Story = {
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
|
||||
await userEvent.click(
|
||||
await body.findByRole("switch", { name: "Allow user API keys" }),
|
||||
);
|
||||
await userEvent.click(
|
||||
await body.findByRole("switch", { name: "Use central key as fallback" }),
|
||||
);
|
||||
|
||||
const apiKeyInput = body.getByLabelText(/^API Key$/i);
|
||||
await userEvent.clear(apiKeyInput);
|
||||
await userEvent.type(apiKeyInput, "sk-updated-provider-key");
|
||||
@@ -514,180 +586,6 @@ export const CreateAndUpdateProvider: Story = {
|
||||
expect.objectContaining({
|
||||
api_key: "sk-updated-provider-key",
|
||||
base_url: "https://internal-proxy.example.com/v2",
|
||||
allow_user_api_key: true,
|
||||
allow_central_api_key_fallback: true,
|
||||
}),
|
||||
);
|
||||
},
|
||||
};
|
||||
|
||||
export const ProviderWithUserKeysEnabled: Story = {
|
||||
args: {
|
||||
section: "providers" as ChatModelAdminSection,
|
||||
providerConfigsData: [
|
||||
createProviderConfig({
|
||||
id: "provider-openai-user-keys",
|
||||
provider: "openai",
|
||||
display_name: "OpenAI",
|
||||
has_api_key: true,
|
||||
central_api_key_enabled: true,
|
||||
allow_user_api_key: true,
|
||||
allow_central_api_key_fallback: false,
|
||||
}),
|
||||
],
|
||||
modelCatalogData: { providers: [] },
|
||||
},
|
||||
beforeEach: () => {
|
||||
setupChatSpies({
|
||||
providerConfigs: [
|
||||
createProviderConfig({
|
||||
id: "provider-openai-user-keys",
|
||||
provider: "openai",
|
||||
display_name: "OpenAI",
|
||||
has_api_key: true,
|
||||
central_api_key_enabled: true,
|
||||
allow_user_api_key: true,
|
||||
allow_central_api_key_fallback: false,
|
||||
}),
|
||||
],
|
||||
modelConfigs: [],
|
||||
modelCatalog: { providers: [] },
|
||||
});
|
||||
},
|
||||
play: async ({ canvasElement }) => {
|
||||
const body = within(canvasElement.ownerDocument.body);
|
||||
await expect(
|
||||
await body.findByText("User keys enabled"),
|
||||
).toBeInTheDocument();
|
||||
await userEvent.click(body.getByRole("button", { name: /OpenAI/i }));
|
||||
await expect(
|
||||
await body.findByRole("switch", { name: "Allow user API keys" }),
|
||||
).toBeChecked();
|
||||
await expect(
|
||||
await body.findByRole("switch", {
|
||||
name: "Use central key as fallback",
|
||||
}),
|
||||
).not.toBeChecked();
|
||||
},
|
||||
};
|
||||
|
||||
export const ProviderWithCentralFallback: Story = {
|
||||
args: {
|
||||
section: "providers" as ChatModelAdminSection,
|
||||
providerConfigsData: [
|
||||
createProviderConfig({
|
||||
id: "provider-openrouter-fallback",
|
||||
provider: "openrouter",
|
||||
display_name: "OpenRouter",
|
||||
has_api_key: true,
|
||||
central_api_key_enabled: true,
|
||||
allow_user_api_key: true,
|
||||
allow_central_api_key_fallback: true,
|
||||
}),
|
||||
],
|
||||
modelCatalogData: { providers: [] },
|
||||
},
|
||||
beforeEach: () => {
|
||||
setupChatSpies({
|
||||
providerConfigs: [
|
||||
createProviderConfig({
|
||||
id: "provider-openrouter-fallback",
|
||||
provider: "openrouter",
|
||||
display_name: "OpenRouter",
|
||||
has_api_key: true,
|
||||
central_api_key_enabled: true,
|
||||
allow_user_api_key: true,
|
||||
allow_central_api_key_fallback: true,
|
||||
}),
|
||||
],
|
||||
modelConfigs: [],
|
||||
modelCatalog: { providers: [] },
|
||||
});
|
||||
},
|
||||
play: async ({ canvasElement }) => {
|
||||
const body = within(canvasElement.ownerDocument.body);
|
||||
await userEvent.click(
|
||||
await body.findByRole("button", { name: /OpenRouter/i }),
|
||||
);
|
||||
await expect(
|
||||
await body.findByRole("switch", {
|
||||
name: "Use central key as fallback",
|
||||
}),
|
||||
).toBeChecked();
|
||||
},
|
||||
};
|
||||
|
||||
export const ProviderWithUserKeysOnly: Story = {
|
||||
args: {
|
||||
section: "providers" as ChatModelAdminSection,
|
||||
providerConfigsData: [
|
||||
createProviderConfig({
|
||||
id: "provider-google-user-only",
|
||||
provider: "google",
|
||||
display_name: "Google",
|
||||
has_api_key: false,
|
||||
central_api_key_enabled: false,
|
||||
allow_user_api_key: true,
|
||||
allow_central_api_key_fallback: false,
|
||||
}),
|
||||
],
|
||||
modelCatalogData: { providers: [] },
|
||||
},
|
||||
beforeEach: () => {
|
||||
setupChatSpies({
|
||||
providerConfigs: [
|
||||
createProviderConfig({
|
||||
id: "provider-google-user-only",
|
||||
provider: "google",
|
||||
display_name: "Google",
|
||||
has_api_key: false,
|
||||
central_api_key_enabled: false,
|
||||
allow_user_api_key: true,
|
||||
allow_central_api_key_fallback: false,
|
||||
}),
|
||||
],
|
||||
modelConfigs: [],
|
||||
modelCatalog: { providers: [] },
|
||||
});
|
||||
},
|
||||
play: async ({ canvasElement, args }) => {
|
||||
const body = within(canvasElement.ownerDocument.body);
|
||||
await userEvent.click(await body.findByRole("button", { name: /Google/i }));
|
||||
await expect(
|
||||
await body.findByRole("switch", { name: "Central API key" }),
|
||||
).not.toBeChecked();
|
||||
await expect(
|
||||
await body.findByRole("switch", { name: "Allow user API keys" }),
|
||||
).toBeChecked();
|
||||
expect(body.queryByLabelText(/^API Key$/i)).not.toBeInTheDocument();
|
||||
expect(
|
||||
body.queryByRole("switch", { name: "Use central key as fallback" }),
|
||||
).not.toBeInTheDocument();
|
||||
|
||||
const saveButton = body.getByRole("button", { name: "Save changes" });
|
||||
await userEvent.click(
|
||||
body.getByRole("switch", { name: "Central API key" }),
|
||||
);
|
||||
await expect(await body.findByLabelText(/^API Key$/i)).toBeRequired();
|
||||
expect(saveButton).toBeDisabled();
|
||||
|
||||
await userEvent.type(
|
||||
body.getByLabelText(/^API Key$/i),
|
||||
"sk-google-central-key",
|
||||
);
|
||||
await waitFor(() => {
|
||||
expect(saveButton).toBeEnabled();
|
||||
});
|
||||
await userEvent.click(saveButton);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(args.onUpdateProvider).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
expect(args.onUpdateProvider).toHaveBeenCalledWith(
|
||||
"provider-google-user-only",
|
||||
expect.objectContaining({
|
||||
api_key: "sk-google-central-key",
|
||||
central_api_key_enabled: true,
|
||||
}),
|
||||
);
|
||||
},
|
||||
@@ -777,51 +675,6 @@ export const ModelFormUserKeyOnlyProvider: Story = {
|
||||
},
|
||||
};
|
||||
|
||||
export const ProviderInvalidCredentialState: Story = {
|
||||
args: {
|
||||
section: "providers" as ChatModelAdminSection,
|
||||
providerConfigsData: [
|
||||
createProviderConfig({
|
||||
id: "provider-bedrock-invalid",
|
||||
provider: "bedrock",
|
||||
display_name: "Bedrock",
|
||||
has_api_key: false,
|
||||
central_api_key_enabled: false,
|
||||
allow_user_api_key: false,
|
||||
allow_central_api_key_fallback: false,
|
||||
}),
|
||||
],
|
||||
modelCatalogData: { providers: [] },
|
||||
},
|
||||
beforeEach: () => {
|
||||
setupChatSpies({
|
||||
providerConfigs: [
|
||||
createProviderConfig({
|
||||
id: "provider-bedrock-invalid",
|
||||
provider: "bedrock",
|
||||
display_name: "Bedrock",
|
||||
has_api_key: false,
|
||||
central_api_key_enabled: false,
|
||||
allow_user_api_key: false,
|
||||
allow_central_api_key_fallback: false,
|
||||
}),
|
||||
],
|
||||
modelConfigs: [],
|
||||
modelCatalog: { providers: [] },
|
||||
});
|
||||
},
|
||||
play: async ({ canvasElement }) => {
|
||||
const body = within(canvasElement.ownerDocument.body);
|
||||
await userEvent.click(
|
||||
await body.findByRole("button", { name: /Bedrock/i }),
|
||||
);
|
||||
await expect(
|
||||
body.findByText("At least one credential source must be enabled"),
|
||||
).resolves.toBeInTheDocument();
|
||||
expect(body.getByRole("button", { name: "Save changes" })).toBeDisabled();
|
||||
},
|
||||
};
|
||||
|
||||
export const ProviderFormBedrockAmbientCredentials: Story = {
|
||||
args: {
|
||||
section: "providers" as ChatModelAdminSection,
|
||||
@@ -843,6 +696,7 @@ export const ProviderFormBedrockAmbientCredentials: Story = {
|
||||
);
|
||||
|
||||
const apiKeyInput = await body.findByLabelText(/^API Key$/i);
|
||||
const baseURLInput = await body.findByLabelText("Base URL");
|
||||
const createButton = body.getByRole("button", {
|
||||
name: "Create provider config",
|
||||
});
|
||||
@@ -859,10 +713,17 @@ export const ProviderFormBedrockAmbientCredentials: Story = {
|
||||
).resolves.toBeInTheDocument();
|
||||
await expect(
|
||||
body.findByText(
|
||||
/Overrides the Bedrock runtime endpoint\.\s+Set AWS_REGION on\s+the Coder server to select the target region\./i,
|
||||
/Bedrock runtime endpoint\.\s+Use the AWS region for the models this provider should call\./i,
|
||||
),
|
||||
).resolves.toBeInTheDocument();
|
||||
await expect(createButton).toBeEnabled();
|
||||
await expect(createButton).toBeDisabled();
|
||||
await userEvent.type(
|
||||
baseURLInput,
|
||||
"https://bedrock-runtime.us-east-1.amazonaws.com",
|
||||
);
|
||||
await waitFor(() => {
|
||||
expect(createButton).toBeEnabled();
|
||||
});
|
||||
|
||||
await userEvent.click(createButton);
|
||||
await waitFor(() => {
|
||||
@@ -875,9 +736,7 @@ export const ProviderFormBedrockAmbientCredentials: Story = {
|
||||
>;
|
||||
expect(createRequest).toMatchObject({
|
||||
provider: "bedrock",
|
||||
central_api_key_enabled: true,
|
||||
allow_user_api_key: false,
|
||||
allow_central_api_key_fallback: false,
|
||||
base_url: "https://bedrock-runtime.us-east-1.amazonaws.com",
|
||||
});
|
||||
expect(createRequest).not.toHaveProperty("api_key");
|
||||
},
|
||||
@@ -891,6 +750,7 @@ export const ProviderFormBedrockBearerToken: Story = {
|
||||
id: "provider-bedrock-bearer",
|
||||
provider: "bedrock",
|
||||
display_name: "AWS Bedrock",
|
||||
base_url: "https://bedrock-runtime.us-east-1.amazonaws.com",
|
||||
has_api_key: true,
|
||||
central_api_key_enabled: true,
|
||||
allow_user_api_key: false,
|
||||
@@ -936,6 +796,7 @@ export const ProviderFormBedrockClearBearerToken: Story = {
|
||||
id: "provider-bedrock-clear",
|
||||
provider: "bedrock",
|
||||
display_name: "AWS Bedrock",
|
||||
base_url: "https://bedrock-runtime.us-east-1.amazonaws.com",
|
||||
has_api_key: true,
|
||||
central_api_key_enabled: true,
|
||||
allow_user_api_key: false,
|
||||
@@ -1161,7 +1022,6 @@ export const UpdateModelEnabledToggle: Story = {
|
||||
},
|
||||
};
|
||||
|
||||
// ── Per-provider model form stories ────────────────────────────
|
||||
// Each story opens the "Add model" form for a specific provider
|
||||
// so you can visually verify the schema-driven fields render.
|
||||
|
||||
|
||||
@@ -6,13 +6,18 @@ import { ErrorAlert } from "#/components/Alert/ErrorAlert";
|
||||
import { Spinner } from "#/components/Spinner/Spinner";
|
||||
import { cn } from "#/utils/cn";
|
||||
import { formatProviderLabel } from "../../utils/modelOptions";
|
||||
import { normalizeProvider, readOptionalString } from "./helpers";
|
||||
import {
|
||||
getDefaultProviderBaseURL,
|
||||
normalizeProvider,
|
||||
readOptionalString,
|
||||
} from "./helpers";
|
||||
import { ModelsSection } from "./ModelsSection";
|
||||
import { ProvidersSection } from "./ProvidersSection";
|
||||
|
||||
// ── Exported types ─────────────────────────────────────────────
|
||||
export type CreateProviderResult = { id: string };
|
||||
|
||||
export type ProviderState = {
|
||||
key: string;
|
||||
provider: string;
|
||||
label: string;
|
||||
providerConfig: TypesGen.ChatProviderConfig | undefined;
|
||||
@@ -21,14 +26,13 @@ export type ProviderState = {
|
||||
hasManagedAPIKey: boolean;
|
||||
hasCatalogAPIKey: boolean;
|
||||
hasEffectiveAPIKey: boolean;
|
||||
allowUserAPIKey: boolean;
|
||||
isEnvPreset: boolean;
|
||||
baseURL: string;
|
||||
};
|
||||
|
||||
export type ChatModelAdminSection = "providers" | "models";
|
||||
|
||||
// ── Internal helpers ───────────────────────────────────────────
|
||||
|
||||
type CatalogProvider = TypesGen.ChatModelsResponse["providers"][number];
|
||||
|
||||
const nilUUID = "00000000-0000-0000-0000-000000000000";
|
||||
@@ -78,65 +82,106 @@ const getProviderModels = (
|
||||
const getProviderBaseURL = (
|
||||
providerConfig: TypesGen.ChatProviderConfig | undefined,
|
||||
): string => {
|
||||
return readOptionalString(providerConfig?.base_url) ?? "";
|
||||
return (
|
||||
readOptionalString(providerConfig?.base_url) ??
|
||||
getDefaultProviderBaseURL(providerConfig?.provider ?? "")
|
||||
);
|
||||
};
|
||||
|
||||
// ── Hook: compute provider states from query data ──────────────
|
||||
const providerConfigStateKey = (
|
||||
providerConfig: TypesGen.ChatProviderConfig,
|
||||
): string => {
|
||||
const providerID = readOptionalString(providerConfig.id);
|
||||
if (providerID && providerID !== nilUUID) {
|
||||
return providerID;
|
||||
}
|
||||
return normalizeProvider(providerConfig.provider);
|
||||
};
|
||||
|
||||
type ProviderEntry = {
|
||||
key: string;
|
||||
provider: string;
|
||||
};
|
||||
|
||||
const useProviderStates = (
|
||||
modelConfigs: readonly TypesGen.ChatModelConfig[],
|
||||
providerConfigsData: TypesGen.ChatProviderConfig[] | null | undefined,
|
||||
catalogData: TypesGen.ChatModelsResponse | null | undefined,
|
||||
): readonly ProviderState[] => {
|
||||
const orderedProviders: string[] = [];
|
||||
const seenProviders = new Set<string>();
|
||||
const includeProvider = (providerValue: string) => {
|
||||
const normalized = normalizeProvider(providerValue);
|
||||
if (!normalized || seenProviders.has(normalized)) return;
|
||||
seenProviders.add(normalized);
|
||||
orderedProviders.push(normalized);
|
||||
const orderedEntries: ProviderEntry[] = [];
|
||||
const seenEntries = new Set<string>();
|
||||
const includeEntry = (keyValue: string, providerValue: string) => {
|
||||
const key = readOptionalString(keyValue);
|
||||
const provider = normalizeProvider(providerValue);
|
||||
if (!key || !provider || seenEntries.has(key)) return;
|
||||
seenEntries.add(key);
|
||||
orderedEntries.push({ key, provider });
|
||||
};
|
||||
|
||||
const catalogProviders = getCatalogProviders(catalogData);
|
||||
const catalogProvidersByProvider = new Map<string, CatalogProvider>();
|
||||
for (const cp of catalogProviders) {
|
||||
const normalized = normalizeProvider(cp.provider);
|
||||
if (!normalized) continue;
|
||||
includeProvider(normalized);
|
||||
catalogProvidersByProvider.set(normalized, cp);
|
||||
const provider = normalizeProvider(cp.provider);
|
||||
if (!provider) continue;
|
||||
catalogProvidersByProvider.set(provider, cp);
|
||||
}
|
||||
|
||||
const providerConfigKeysByProvider = new Map<string, string[]>();
|
||||
const providerTypesWithConfigs = new Set<string>();
|
||||
for (const pc of providerConfigsData ?? []) {
|
||||
includeProvider(pc.provider);
|
||||
const provider = normalizeProvider(pc.provider);
|
||||
if (!provider) continue;
|
||||
const key = providerConfigStateKey(pc);
|
||||
providerTypesWithConfigs.add(provider);
|
||||
providerConfigKeysByProvider.set(provider, [
|
||||
...(providerConfigKeysByProvider.get(provider) ?? []),
|
||||
key,
|
||||
]);
|
||||
includeEntry(key, provider);
|
||||
}
|
||||
const modelStateKey = (modelConfig: TypesGen.ChatModelConfig): string => {
|
||||
const aiProviderID = readOptionalString(modelConfig.ai_provider_id);
|
||||
if (aiProviderID) {
|
||||
return aiProviderID;
|
||||
}
|
||||
const provider = normalizeProvider(modelConfig.provider);
|
||||
const providerConfigKeys = providerConfigKeysByProvider.get(provider) ?? [];
|
||||
if (providerConfigKeys.length === 1) {
|
||||
return providerConfigKeys[0];
|
||||
}
|
||||
return providerConfigKeys.length === 0 ? provider : "";
|
||||
};
|
||||
|
||||
for (const cp of catalogProviders) {
|
||||
const provider = normalizeProvider(cp.provider);
|
||||
if (!provider || providerTypesWithConfigs.has(provider)) continue;
|
||||
includeEntry(provider, provider);
|
||||
}
|
||||
for (const mc of modelConfigs) {
|
||||
includeProvider(mc.provider);
|
||||
includeEntry(modelStateKey(mc), mc.provider);
|
||||
}
|
||||
|
||||
const providerConfigsByProvider = new Map<
|
||||
string,
|
||||
TypesGen.ChatProviderConfig
|
||||
>();
|
||||
const providerConfigsByKey = new Map<string, TypesGen.ChatProviderConfig>();
|
||||
for (const pc of providerConfigsData ?? []) {
|
||||
const normalized = normalizeProvider(pc.provider);
|
||||
if (!normalized) continue;
|
||||
providerConfigsByProvider.set(normalized, pc);
|
||||
const key = providerConfigStateKey(pc);
|
||||
if (!key) continue;
|
||||
providerConfigsByKey.set(key, pc);
|
||||
}
|
||||
|
||||
const modelConfigsByProvider = new Map<string, TypesGen.ChatModelConfig[]>();
|
||||
const modelConfigsByKey = new Map<string, TypesGen.ChatModelConfig[]>();
|
||||
for (const mc of modelConfigs) {
|
||||
const normalized = normalizeProvider(mc.provider);
|
||||
if (!normalized) continue;
|
||||
const existing = modelConfigsByProvider.get(normalized);
|
||||
const key = modelStateKey(mc);
|
||||
if (!key) continue;
|
||||
const existing = modelConfigsByKey.get(key);
|
||||
if (existing) {
|
||||
existing.push(mc);
|
||||
} else {
|
||||
modelConfigsByProvider.set(normalized, [mc]);
|
||||
modelConfigsByKey.set(key, [mc]);
|
||||
}
|
||||
}
|
||||
|
||||
return orderedProviders.map((provider) => {
|
||||
const providerConfigEntry = providerConfigsByProvider.get(provider);
|
||||
return orderedEntries.map(({ key, provider }) => {
|
||||
const providerConfigEntry = providerConfigsByKey.get(key);
|
||||
const providerConfigSource = getProviderConfigSource(providerConfigEntry);
|
||||
const providerConfig = isDatabaseProviderConfig(
|
||||
providerConfigEntry,
|
||||
@@ -145,9 +190,6 @@ const useProviderStates = (
|
||||
? providerConfigEntry
|
||||
: undefined;
|
||||
const catalogProvider = catalogProvidersByProvider.get(provider);
|
||||
const catalogProviderSource = readOptionalString(
|
||||
(catalogProvider as CatalogProvider & { source?: string })?.source,
|
||||
);
|
||||
const hasManagedAPIKey = hasProviderAPIKey(providerConfig);
|
||||
const hasProviderEntryAPIKey = hasProviderAPIKey(providerConfigEntry);
|
||||
const hasCatalogAPIKey = catalogProvider
|
||||
@@ -159,15 +201,14 @@ const useProviderStates = (
|
||||
const hasBedrockAmbientCredentials =
|
||||
provider === "bedrock" &&
|
||||
providerConfig?.central_api_key_enabled === true;
|
||||
const modelConfigsForProvider = modelConfigsByProvider.get(provider) ?? [];
|
||||
const modelConfigsForProvider = modelConfigsByKey.get(key) ?? [];
|
||||
const isCatalogEnvPreset =
|
||||
!providerConfig &&
|
||||
envPresetProviders.has(provider) &&
|
||||
(catalogProviderSource === "env" || hasCatalogAPIKey);
|
||||
!providerConfig && envPresetProviders.has(provider) && hasCatalogAPIKey;
|
||||
const isEnvPreset =
|
||||
providerConfigSource === "env_preset" || isCatalogEnvPreset;
|
||||
|
||||
return {
|
||||
key,
|
||||
provider,
|
||||
label,
|
||||
providerConfig,
|
||||
@@ -178,14 +219,13 @@ const useProviderStates = (
|
||||
hasEffectiveAPIKey: providerConfigEntry
|
||||
? hasProviderEntryAPIKey || hasBedrockAmbientCredentials
|
||||
: hasManagedAPIKey || hasCatalogAPIKey,
|
||||
allowUserAPIKey: providerConfigEntry?.allow_user_api_key ?? true,
|
||||
isEnvPreset,
|
||||
baseURL: getProviderBaseURL(providerConfigEntry),
|
||||
};
|
||||
});
|
||||
};
|
||||
|
||||
// ── Component ──────────────────────────────────────────────────
|
||||
|
||||
interface ChatModelAdminPanelProps {
|
||||
className?: string;
|
||||
section?: ChatModelAdminSection;
|
||||
@@ -203,7 +243,7 @@ interface ChatModelAdminPanelProps {
|
||||
// Provider mutation handlers.
|
||||
onCreateProvider: (
|
||||
req: TypesGen.CreateChatProviderConfigRequest,
|
||||
) => Promise<unknown>;
|
||||
) => Promise<CreateProviderResult>;
|
||||
onUpdateProvider: (
|
||||
providerConfigId: string,
|
||||
req: TypesGen.UpdateChatProviderConfigRequest,
|
||||
@@ -251,13 +291,10 @@ export const ChatModelAdminPanel: FC<ChatModelAdminPanelProps> = ({
|
||||
isDeletingModel,
|
||||
modelMutationError,
|
||||
}) => {
|
||||
// ── Sorted model configs ───────────────────────────────────
|
||||
const modelConfigs = (modelConfigsData ?? []).slice().sort((a, b) => {
|
||||
const cmp = a.provider.localeCompare(b.provider);
|
||||
return cmp !== 0 ? cmp : a.model.localeCompare(b.model);
|
||||
});
|
||||
|
||||
// ── Provider states ────────────────────────────────────────
|
||||
const providerStates = useProviderStates(
|
||||
modelConfigs,
|
||||
providerConfigsData,
|
||||
@@ -276,7 +313,6 @@ export const ChatModelAdminPanel: FC<ChatModelAdminPanelProps> = ({
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Content */}
|
||||
<div className="flex flex-1 flex-col gap-8">
|
||||
{section === "providers" ? (
|
||||
<ProvidersSection
|
||||
@@ -305,8 +341,6 @@ export const ChatModelAdminPanel: FC<ChatModelAdminPanelProps> = ({
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* Errors — rendered at the bottom */}
|
||||
{providerConfigsError && <ErrorAlert error={providerConfigsError} />}
|
||||
{modelConfigsError && <ErrorAlert error={modelConfigsError} />}
|
||||
{modelCatalogError && <ErrorAlert error={modelCatalogError} />}
|
||||
|
||||
@@ -34,6 +34,7 @@ import { getFormHelpers } from "#/utils/formUtils";
|
||||
import { BackButton } from "../BackButton";
|
||||
import { ConfirmDeleteDialog } from "../ConfirmDeleteDialog";
|
||||
import type { ProviderState } from "./ChatModelAdminPanel";
|
||||
import { readOptionalString } from "./helpers";
|
||||
import {
|
||||
GeneralModelConfigFields,
|
||||
ModelConfigFields,
|
||||
@@ -141,6 +142,9 @@ export const ModelForm: FC<ModelFormProps> = ({
|
||||
return "add";
|
||||
})();
|
||||
|
||||
const selectedProviderType =
|
||||
selectedProviderState?.provider ?? selectedProvider;
|
||||
|
||||
const form = useFormik<ModelFormValues>({
|
||||
initialValues,
|
||||
validationSchema,
|
||||
@@ -158,7 +162,7 @@ export const ModelForm: FC<ModelFormProps> = ({
|
||||
);
|
||||
|
||||
const buildResult = buildModelConfigFromForm(
|
||||
selectedProvider,
|
||||
selectedProviderType,
|
||||
values.config,
|
||||
);
|
||||
if (Object.keys(buildResult.fieldErrors).length > 0) return;
|
||||
@@ -166,8 +170,17 @@ export const ModelForm: FC<ModelFormProps> = ({
|
||||
const trimmedDisplayName = values.displayName.trim();
|
||||
const builtModelConfig = buildResult.modelConfig;
|
||||
|
||||
const selectedProviderConfigID =
|
||||
selectedProviderState?.providerConfig?.id;
|
||||
|
||||
if (isEditing && editingModel) {
|
||||
const req: TypesGen.UpdateChatModelConfigRequest = {
|
||||
...(selectedProviderConfigID &&
|
||||
selectedProviderConfigID !==
|
||||
readOptionalString(editingModel.ai_provider_id) && {
|
||||
provider: selectedProviderState.provider,
|
||||
ai_provider_id: selectedProviderConfigID,
|
||||
}),
|
||||
...(trimmedModel !== editingModel.model && {
|
||||
model: trimmedModel,
|
||||
}),
|
||||
@@ -198,7 +211,8 @@ export const ModelForm: FC<ModelFormProps> = ({
|
||||
if (!selectedProvider || !selectedProviderState?.providerConfig) return;
|
||||
|
||||
const req: TypesGen.CreateChatModelConfigRequest = {
|
||||
provider: selectedProvider,
|
||||
provider: selectedProviderState.provider,
|
||||
ai_provider_id: selectedProviderState.providerConfig.id,
|
||||
model: trimmedModel,
|
||||
enabled: values.enabled,
|
||||
is_default: values.isDefault,
|
||||
@@ -227,7 +241,7 @@ export const ModelForm: FC<ModelFormProps> = ({
|
||||
const getFieldHelpers = getFormHelpers(form);
|
||||
|
||||
const modelConfigFormBuildResult = buildModelConfigFromForm(
|
||||
selectedProvider,
|
||||
selectedProviderType,
|
||||
form.values.config,
|
||||
);
|
||||
|
||||
@@ -249,7 +263,10 @@ export const ModelForm: FC<ModelFormProps> = ({
|
||||
<Select
|
||||
value={selectedProvider ?? ""}
|
||||
onValueChange={onSelectedProviderChange}
|
||||
disabled={isEditing || isDuplicating || providerStates.length === 0}
|
||||
disabled={
|
||||
((isEditing || isDuplicating) && selectedProviderState !== null) ||
|
||||
providerStates.length === 0
|
||||
}
|
||||
>
|
||||
<SelectTrigger
|
||||
id="providerSelect"
|
||||
@@ -259,7 +276,7 @@ export const ModelForm: FC<ModelFormProps> = ({
|
||||
</SelectTrigger>
|
||||
<SelectContent>
|
||||
{providerStates.map((ps) => (
|
||||
<SelectItem key={ps.provider} value={ps.provider}>
|
||||
<SelectItem key={ps.key} value={ps.key}>
|
||||
<span className="flex items-center gap-2">
|
||||
<ProviderIcon provider={ps.provider} className="h-4 w-4" />
|
||||
{ps.label}
|
||||
@@ -394,7 +411,7 @@ export const ModelForm: FC<ModelFormProps> = ({
|
||||
form={form}
|
||||
modelField={modelField}
|
||||
mode={mode}
|
||||
selectedProvider={selectedProvider}
|
||||
selectedProvider={selectedProviderType}
|
||||
disabled={isSaving}
|
||||
/>
|
||||
<div className="grid gap-1.5">
|
||||
|
||||
@@ -6,6 +6,7 @@ import type { ProviderState } from "./ChatModelAdminPanel";
|
||||
import { ModelsSection } from "./ModelsSection";
|
||||
|
||||
const providerState: ProviderState = {
|
||||
key: "provider-config-id",
|
||||
provider: "openai",
|
||||
label: "OpenAI",
|
||||
providerConfig: {
|
||||
@@ -27,6 +28,7 @@ const providerState: ProviderState = {
|
||||
hasManagedAPIKey: true,
|
||||
hasCatalogAPIKey: true,
|
||||
hasEffectiveAPIKey: true,
|
||||
allowUserAPIKey: false,
|
||||
isEnvPreset: false,
|
||||
baseURL: "",
|
||||
};
|
||||
@@ -42,6 +44,7 @@ const providerStateWithoutAPIKey: ProviderState = {
|
||||
hasManagedAPIKey: false,
|
||||
hasCatalogAPIKey: false,
|
||||
hasEffectiveAPIKey: false,
|
||||
allowUserAPIKey: false,
|
||||
};
|
||||
|
||||
const baseModelConfig: TypesGen.ChatModelConfig = {
|
||||
@@ -294,6 +297,7 @@ export const SavesDuplicateAsCreateRequest: Story = {
|
||||
throw new Error("Expected create request.");
|
||||
}
|
||||
expect(createReq).toEqual({
|
||||
ai_provider_id: "provider-config-id",
|
||||
provider: "openai",
|
||||
model: "gpt-4.1-copy",
|
||||
display_name: "GPT-4.1 Copy",
|
||||
@@ -344,6 +348,7 @@ export const SavesNonDefaultDuplicateWithEditableEnabled: Story = {
|
||||
|
||||
const createModelMock = args.onCreateModel as ReturnType<typeof fn>;
|
||||
expect(createModelMock.mock.calls[0]?.[0]).toEqual({
|
||||
ai_provider_id: "provider-config-id",
|
||||
provider: "openai",
|
||||
model: "gpt-4.1-copy",
|
||||
display_name: "GPT-4.1",
|
||||
@@ -385,6 +390,7 @@ export const SavesDisabledDuplicateWithEditableEnabled: Story = {
|
||||
await waitFor(() => expect(args.onCreateModel).toHaveBeenCalledTimes(1));
|
||||
const createModelMock = args.onCreateModel as ReturnType<typeof fn>;
|
||||
expect(createModelMock.mock.calls[0]?.[0]).toEqual({
|
||||
ai_provider_id: "provider-config-id",
|
||||
provider: "openai",
|
||||
model: "gpt-4.1-disabled-copy",
|
||||
display_name: "GPT-4.1 Disabled",
|
||||
|
||||
@@ -6,7 +6,7 @@ import {
|
||||
StarIcon,
|
||||
TriangleAlertIcon,
|
||||
} from "lucide-react";
|
||||
import type { FC } from "react";
|
||||
import { type FC, useEffect, useState } from "react";
|
||||
import { Link, useLocation, useSearchParams } from "react-router";
|
||||
import type * as TypesGen from "#/api/typesGenerated";
|
||||
import { Badge } from "#/components/Badge/Badge";
|
||||
@@ -25,6 +25,7 @@ import {
|
||||
import { cn } from "#/utils/cn";
|
||||
import { SectionHeader } from "../SectionHeader";
|
||||
import type { ProviderState } from "./ChatModelAdminPanel";
|
||||
import { normalizeProvider, readOptionalString } from "./helpers";
|
||||
import { ModelForm } from "./ModelForm";
|
||||
import { ProviderIcon } from "./ProviderIcon";
|
||||
import { hasCustomPricing } from "./pricingFields";
|
||||
@@ -44,6 +45,28 @@ const clearModelViewParams = (params: URLSearchParams) => {
|
||||
}
|
||||
};
|
||||
|
||||
const modelConfigProviderKey = (
|
||||
modelConfig: TypesGen.ChatModelConfig,
|
||||
providerStates: readonly ProviderState[],
|
||||
): string => {
|
||||
const providerID = readOptionalString(modelConfig.ai_provider_id);
|
||||
if (providerID) {
|
||||
return providerID;
|
||||
}
|
||||
|
||||
const provider = normalizeProvider(modelConfig.provider);
|
||||
const providerMatches = providerStates.filter(
|
||||
(providerState) => providerState.provider === provider,
|
||||
);
|
||||
if (providerMatches.length === 1) {
|
||||
return providerMatches[0].key;
|
||||
}
|
||||
if (providerMatches.length > 1) {
|
||||
return "";
|
||||
}
|
||||
return provider;
|
||||
};
|
||||
|
||||
const canManageProviderModels = (providerState: ProviderState | undefined) => {
|
||||
return Boolean(
|
||||
providerState?.providerConfig &&
|
||||
@@ -85,6 +108,9 @@ export const ModelsSection: FC<ModelsSectionProps> = ({
|
||||
onDeleteModel,
|
||||
}) => {
|
||||
const [searchParams, setSearchParams] = useSearchParams();
|
||||
const [selectedProviderOverride, setSelectedProviderOverride] = useState<
|
||||
string | null
|
||||
>(null);
|
||||
const location = useLocation();
|
||||
|
||||
// Derive the current view from URL search params so that
|
||||
@@ -122,9 +148,27 @@ export const ModelsSection: FC<ModelsSectionProps> = ({
|
||||
state: options?.replace ? location.state : { pushed: true },
|
||||
});
|
||||
};
|
||||
const modelViewIdentity = (() => {
|
||||
switch (view.mode) {
|
||||
case "add":
|
||||
return `add:${view.provider}`;
|
||||
case "edit":
|
||||
return `edit:${view.model.id}`;
|
||||
case "duplicate":
|
||||
return `duplicate:${view.sourceModel.id}`;
|
||||
default:
|
||||
return "list";
|
||||
}
|
||||
})();
|
||||
|
||||
useEffect(() => {
|
||||
void modelViewIdentity;
|
||||
setSelectedProviderOverride(null);
|
||||
}, [modelViewIdentity]);
|
||||
|
||||
// Clear model-related search params and return to the list.
|
||||
const clearModelView = () => {
|
||||
setSelectedProviderOverride(null);
|
||||
setSearchParams((prev) => {
|
||||
const next = new URLSearchParams(prev);
|
||||
clearModelViewParams(next);
|
||||
@@ -133,6 +177,7 @@ export const ModelsSection: FC<ModelsSectionProps> = ({
|
||||
};
|
||||
|
||||
const exitModelView = () => {
|
||||
setSelectedProviderOverride(null);
|
||||
setSearchParams(
|
||||
(prev) => {
|
||||
const next = new URLSearchParams(prev);
|
||||
@@ -153,13 +198,14 @@ export const ModelsSection: FC<ModelsSectionProps> = ({
|
||||
const duplicateSourceModel =
|
||||
view.mode === "duplicate" ? view.sourceModel : undefined;
|
||||
const effectiveProvider =
|
||||
view.mode === "edit"
|
||||
? view.model.provider
|
||||
selectedProviderOverride ??
|
||||
(view.mode === "edit"
|
||||
? modelConfigProviderKey(view.model, providerStates)
|
||||
: view.mode === "duplicate"
|
||||
? view.sourceModel.provider
|
||||
: view.provider;
|
||||
? modelConfigProviderKey(view.sourceModel, providerStates)
|
||||
: view.provider);
|
||||
const effectiveProviderState =
|
||||
providerStates.find((ps) => ps.provider === effectiveProvider) ?? null;
|
||||
providerStates.find((ps) => ps.key === effectiveProvider) ?? null;
|
||||
const formKey =
|
||||
view.mode === "edit"
|
||||
? `edit:${view.model.id}`
|
||||
@@ -178,7 +224,9 @@ export const ModelsSection: FC<ModelsSectionProps> = ({
|
||||
onSelectedProviderChange={(provider) => {
|
||||
if (view.mode === "add") {
|
||||
setModelViewParam("newModel", provider, { replace: true });
|
||||
return;
|
||||
}
|
||||
setSelectedProviderOverride(provider);
|
||||
}}
|
||||
modelConfigsUnavailable={modelConfigsUnavailable}
|
||||
isSaving={isCreating || isUpdating}
|
||||
@@ -222,9 +270,9 @@ export const ModelsSection: FC<ModelsSectionProps> = ({
|
||||
<DropdownMenuContent align="end">
|
||||
{addableProviders.map((ps) => (
|
||||
<DropdownMenuItem
|
||||
key={ps.provider}
|
||||
key={ps.key}
|
||||
onClick={() => {
|
||||
setModelViewParam("newModel", ps.provider);
|
||||
setModelViewParam("newModel", ps.key);
|
||||
}}
|
||||
className="gap-2"
|
||||
>
|
||||
@@ -285,10 +333,12 @@ export const ModelsSection: FC<ModelsSectionProps> = ({
|
||||
const starUnavailable =
|
||||
isUpdating || modelConfig.is_default || !modelConfig.enabled;
|
||||
const providerState = providerStates.find(
|
||||
(ps) => ps.provider === modelConfig.provider,
|
||||
(ps) =>
|
||||
ps.key === modelConfigProviderKey(modelConfig, providerStates),
|
||||
);
|
||||
const duplicateUnavailable = Boolean(
|
||||
providerState && !canManageProviderModels(providerState),
|
||||
);
|
||||
const duplicateUnavailable =
|
||||
!canManageProviderModels(providerState);
|
||||
|
||||
return (
|
||||
<div
|
||||
|
||||
@@ -13,7 +13,6 @@ import { Alert, AlertDescription, AlertTitle } from "#/components/Alert/Alert";
|
||||
import { Button } from "#/components/Button/Button";
|
||||
import { Input } from "#/components/Input/Input";
|
||||
import { Spinner } from "#/components/Spinner/Spinner";
|
||||
import { Switch } from "#/components/Switch/Switch";
|
||||
import {
|
||||
Tooltip,
|
||||
TooltipContent,
|
||||
@@ -22,10 +21,12 @@ import {
|
||||
import { formatProviderLabel } from "../../utils/modelOptions";
|
||||
import { BackButton } from "../BackButton";
|
||||
import { ConfirmDeleteDialog } from "../ConfirmDeleteDialog";
|
||||
import type { ProviderState } from "./ChatModelAdminPanel";
|
||||
import { readOptionalString } from "./helpers";
|
||||
import type {
|
||||
CreateProviderResult,
|
||||
ProviderState,
|
||||
} from "./ChatModelAdminPanel";
|
||||
import { getProviderBaseURLPlaceholder, readOptionalString } from "./helpers";
|
||||
import { ProviderIcon } from "./ProviderIcon";
|
||||
import { normalizeProviderPolicyDefaults } from "./providerPolicyDefaults";
|
||||
|
||||
// Sentinel value used to represent an existing API key that the
|
||||
// backend will not reveal. If the user has not touched the field,
|
||||
@@ -38,7 +39,7 @@ interface ProviderFormProps {
|
||||
isProviderMutationPending: boolean;
|
||||
onCreateProvider: (
|
||||
req: TypesGen.CreateChatProviderConfigRequest,
|
||||
) => Promise<unknown>;
|
||||
) => Promise<CreateProviderResult>;
|
||||
onUpdateProvider: (
|
||||
providerConfigId: string,
|
||||
req: TypesGen.UpdateChatProviderConfigRequest,
|
||||
@@ -62,27 +63,13 @@ export const ProviderForm: FC<ProviderFormProps> = ({
|
||||
const apiKeyInputId = useId();
|
||||
const baseURLInputId = useId();
|
||||
|
||||
// Providers backed by the OpenAI SDK expect /v1 in the base
|
||||
// URL, while others (e.g. Anthropic) do not.
|
||||
const baseURLPlaceholder =
|
||||
provider === "anthropic" || provider === "bedrock" || provider === "google"
|
||||
? "https://api.example.com"
|
||||
: "https://api.example.com/v1";
|
||||
|
||||
const normalizedProviderConfig = providerConfig
|
||||
? normalizeProviderPolicyDefaults(providerConfig)
|
||||
: undefined;
|
||||
const baseURLPlaceholder = getProviderBaseURLPlaceholder(provider);
|
||||
|
||||
// Initial values are snapshotted when the provider config changes
|
||||
// so we can detect dirty state.
|
||||
const [initialValues] = useState(() => ({
|
||||
displayName: readOptionalString(providerConfig?.display_name) ?? "",
|
||||
baseURL,
|
||||
centralAPIKeyEnabled:
|
||||
normalizedProviderConfig?.central_api_key_enabled ?? true,
|
||||
allowUserAPIKey: normalizedProviderConfig?.allow_user_api_key ?? false,
|
||||
allowCentralAPIKeyFallback:
|
||||
normalizedProviderConfig?.allow_central_api_key_fallback ?? false,
|
||||
}));
|
||||
|
||||
const [displayName, setDisplayName] = useState(initialValues.displayName);
|
||||
@@ -92,84 +79,53 @@ export const ProviderForm: FC<ProviderFormProps> = ({
|
||||
const [apiKeyTouched, setApiKeyTouched] = useState(false);
|
||||
const [apiKeyModified, setApiKeyModified] = useState(false);
|
||||
const [baseURLValue, setBaseURLValue] = useState(initialValues.baseURL);
|
||||
const [centralAPIKeyEnabled, setCentralAPIKeyEnabled] = useState(
|
||||
initialValues.centralAPIKeyEnabled,
|
||||
);
|
||||
const [allowUserAPIKey, setAllowUserAPIKey] = useState(
|
||||
initialValues.allowUserAPIKey,
|
||||
);
|
||||
const [allowCentralAPIKeyFallback, setAllowCentralAPIKeyFallback] = useState(
|
||||
initialValues.allowCentralAPIKeyFallback,
|
||||
);
|
||||
const [confirmingDelete, setConfirmingDelete] = useState(false);
|
||||
|
||||
const isBedrockProvider = provider === "bedrock";
|
||||
const isAPIKeyEnvManaged = isEnvPreset && !providerConfig;
|
||||
const shouldShowAPIKeyField = centralAPIKeyEnabled;
|
||||
const shouldShowFallbackToggle = centralAPIKeyEnabled && allowUserAPIKey;
|
||||
const effectiveInitialFallback =
|
||||
initialValues.centralAPIKeyEnabled &&
|
||||
initialValues.allowUserAPIKey &&
|
||||
initialValues.allowCentralAPIKeyFallback;
|
||||
const effectiveFallback =
|
||||
shouldShowFallbackToggle && allowCentralAPIKeyFallback;
|
||||
// Most providers require a stored deployment key whenever central-key
|
||||
// usage is enabled and there is no saved key yet. Bedrock can also use
|
||||
// ambient AWS credentials from the Coder server, so its API key stays
|
||||
// optional.
|
||||
const requiresAPIKey =
|
||||
!isAPIKeyEnvManaged &&
|
||||
!providerState.allowUserAPIKey &&
|
||||
!isBedrockProvider &&
|
||||
centralAPIKeyEnabled &&
|
||||
!providerState.hasManagedAPIKey;
|
||||
|
||||
const effectiveApiKey =
|
||||
apiKeyTouched && apiKey !== API_KEY_PLACEHOLDER ? apiKey.trim() : "";
|
||||
apiKeyTouched && apiKey !== API_KEY_PLACEHOLDER ? apiKey : "";
|
||||
const hasTypedAPIKey = effectiveApiKey.length > 0;
|
||||
// Clearing a saved Bedrock bearer token switches the provider back
|
||||
// to ambient AWS credentials, so updates must send an explicit
|
||||
// empty string.
|
||||
const isClearingBedrockAPIKey =
|
||||
isBedrockProvider &&
|
||||
providerState.hasManagedAPIKey &&
|
||||
apiKeyModified &&
|
||||
effectiveApiKey === "";
|
||||
const hasPendingAPIKeyChange =
|
||||
(centralAPIKeyEnabled && hasTypedAPIKey) || isClearingBedrockAPIKey;
|
||||
const shouldCreateAPIKey = centralAPIKeyEnabled && hasTypedAPIKey;
|
||||
const hasCredentialSource = centralAPIKeyEnabled || allowUserAPIKey;
|
||||
const hasAPIKeyWhitespace =
|
||||
hasTypedAPIKey && effectiveApiKey.trim() !== effectiveApiKey;
|
||||
// Clearing a saved provider-scoped key switches the provider to
|
||||
// BYOK-only behavior, or ambient AWS credentials for Bedrock.
|
||||
const isClearingAPIKey =
|
||||
providerState.hasManagedAPIKey && apiKeyModified && effectiveApiKey === "";
|
||||
const hasPendingAPIKeyChange = hasTypedAPIKey || isClearingAPIKey;
|
||||
const shouldCreateAPIKey = hasTypedAPIKey;
|
||||
const apiKeyDescription = isBedrockProvider
|
||||
? "Bearer token for Bedrock authentication. Leave empty to use ambient AWS credentials."
|
||||
: "Secret key used to authenticate requests to this provider.";
|
||||
const baseURLDescription = isBedrockProvider
|
||||
? "Optional. Overrides the Bedrock runtime endpoint. Set AWS_REGION on the Coder server to select the target region."
|
||||
: "Custom endpoint for this provider. Leave empty to use the default.";
|
||||
? "Bedrock runtime endpoint. Use the AWS region for the models this provider should call."
|
||||
: "Endpoint used to call this provider.";
|
||||
const apiKeyPlaceholder = isBedrockProvider ? "Enter bearer token" : "sk-...";
|
||||
const deleteProviderDescription = normalizedProviderConfig?.allow_user_api_key
|
||||
? "Are you sure you want to delete this provider? Any personal API " +
|
||||
"keys that users have saved for this provider will also be " +
|
||||
"permanently deleted. This action is irreversible."
|
||||
: "Are you sure you want to delete this provider? This action is irreversible.";
|
||||
// New Bedrock providers can be saved immediately with ambient AWS
|
||||
// credentials, even before any fields differ from their defaults.
|
||||
const hasNewBedrockAmbientConfiguration =
|
||||
isBedrockProvider && !providerConfig && centralAPIKeyEnabled;
|
||||
const deleteProviderDescription =
|
||||
"Are you sure you want to delete this provider? The provider will be " +
|
||||
"disabled and hidden from new model configuration. Existing model " +
|
||||
"configs that reference it remain saved but cannot run until updated.";
|
||||
const hasNewProviderConfiguration = !providerConfig;
|
||||
|
||||
const isDirty =
|
||||
displayName.trim() !== initialValues.displayName ||
|
||||
hasPendingAPIKeyChange ||
|
||||
baseURLValue.trim() !== initialValues.baseURL.trim() ||
|
||||
centralAPIKeyEnabled !== initialValues.centralAPIKeyEnabled ||
|
||||
allowUserAPIKey !== initialValues.allowUserAPIKey ||
|
||||
effectiveFallback !== effectiveInitialFallback ||
|
||||
hasNewBedrockAmbientConfiguration;
|
||||
hasNewProviderConfiguration;
|
||||
|
||||
const hasBaseURL = baseURLValue.trim().length > 0;
|
||||
const canSave =
|
||||
!providerConfigsUnavailable &&
|
||||
!isProviderMutationPending &&
|
||||
!isAPIKeyEnvManaged &&
|
||||
isDirty &&
|
||||
hasCredentialSource &&
|
||||
hasBaseURL &&
|
||||
!hasAPIKeyWhitespace &&
|
||||
(!requiresAPIKey || hasTypedAPIKey);
|
||||
const canAddModel =
|
||||
Boolean(providerConfig) &&
|
||||
@@ -177,7 +133,7 @@ export const ProviderForm: FC<ProviderFormProps> = ({
|
||||
providerConfig?.allow_user_api_key === true);
|
||||
|
||||
const handleAddModel = () => {
|
||||
const params = new URLSearchParams({ newModel: provider });
|
||||
const params = new URLSearchParams({ newModel: providerState.key });
|
||||
navigate(`/agents/settings/models?${params.toString()}`, {
|
||||
state: { pushed: true },
|
||||
});
|
||||
@@ -189,7 +145,8 @@ export const ProviderForm: FC<ProviderFormProps> = ({
|
||||
providerConfigsUnavailable ||
|
||||
isProviderMutationPending ||
|
||||
isAPIKeyEnvManaged ||
|
||||
!hasCredentialSource
|
||||
!hasBaseURL ||
|
||||
hasAPIKeyWhitespace
|
||||
) {
|
||||
return;
|
||||
}
|
||||
@@ -213,15 +170,6 @@ export const ProviderForm: FC<ProviderFormProps> = ({
|
||||
...(trimmedBaseURL !== currentBaseURL && {
|
||||
base_url: trimmedBaseURL,
|
||||
}),
|
||||
...(centralAPIKeyEnabled !== initialValues.centralAPIKeyEnabled && {
|
||||
central_api_key_enabled: centralAPIKeyEnabled,
|
||||
}),
|
||||
...(allowUserAPIKey !== initialValues.allowUserAPIKey && {
|
||||
allow_user_api_key: allowUserAPIKey,
|
||||
}),
|
||||
...(effectiveFallback !== effectiveInitialFallback && {
|
||||
allow_central_api_key_fallback: effectiveFallback,
|
||||
}),
|
||||
};
|
||||
|
||||
if (Object.keys(req).length === 0) {
|
||||
@@ -231,28 +179,21 @@ export const ProviderForm: FC<ProviderFormProps> = ({
|
||||
try {
|
||||
await onUpdateProvider(providerConfig.id, req);
|
||||
} catch {
|
||||
// Error is surfaced via the mutation's error state
|
||||
// in ChatModelAdminPanel, no toast needed.
|
||||
return;
|
||||
}
|
||||
} else {
|
||||
const req: TypesGen.CreateChatProviderConfigRequest = {
|
||||
provider,
|
||||
base_url: trimmedBaseURL,
|
||||
...(shouldCreateAPIKey && { api_key: effectiveApiKey }),
|
||||
central_api_key_enabled: centralAPIKeyEnabled,
|
||||
allow_user_api_key: allowUserAPIKey,
|
||||
allow_central_api_key_fallback: effectiveFallback,
|
||||
...(trimmedDisplayName && {
|
||||
display_name: trimmedDisplayName,
|
||||
}),
|
||||
...(trimmedBaseURL && { base_url: trimmedBaseURL }),
|
||||
};
|
||||
|
||||
try {
|
||||
await onCreateProvider(req);
|
||||
} catch {
|
||||
// Error is surfaced via the mutation's error state
|
||||
// in ChatModelAdminPanel, no toast needed.
|
||||
return;
|
||||
}
|
||||
}
|
||||
@@ -275,9 +216,7 @@ export const ProviderForm: FC<ProviderFormProps> = ({
|
||||
|
||||
return (
|
||||
<div className="flex min-h-full flex-col">
|
||||
{/* Back */}
|
||||
<BackButton onClick={onBack} />
|
||||
{/* Provider header, editable name */}
|
||||
<div className="flex items-center gap-3">
|
||||
<ProviderIcon provider={provider} className="h-8 w-8" />
|
||||
<div className="min-w-0 flex-1">
|
||||
@@ -316,57 +255,60 @@ export const ProviderForm: FC<ProviderFormProps> = ({
|
||||
data-form-type="other"
|
||||
>
|
||||
<div className="space-y-5">
|
||||
{shouldShowAPIKeyField && (
|
||||
<ProviderField
|
||||
label="API Key"
|
||||
htmlFor={apiKeyInputId}
|
||||
required={requiresAPIKey}
|
||||
description={apiKeyDescription}
|
||||
>
|
||||
<div className="space-y-1.5">
|
||||
<Input
|
||||
id={apiKeyInputId}
|
||||
name="provider_api_token"
|
||||
type="password"
|
||||
autoComplete="off"
|
||||
data-1p-ignore
|
||||
data-lpignore="true"
|
||||
data-form-type="other"
|
||||
data-bwignore
|
||||
style={{ WebkitTextSecurity: "disc" } as CSSProperties}
|
||||
className="h-9 font-mono text-[13px]"
|
||||
placeholder={apiKeyPlaceholder}
|
||||
required={requiresAPIKey}
|
||||
value={apiKey}
|
||||
onFocus={handleApiKeyFocus}
|
||||
onChange={(event) => {
|
||||
setApiKey(event.target.value);
|
||||
setApiKeyTouched(true);
|
||||
setApiKeyModified(true);
|
||||
}}
|
||||
disabled={isDisabled}
|
||||
/>
|
||||
{isBedrockProvider &&
|
||||
providerState.hasManagedAPIKey &&
|
||||
!isDisabled &&
|
||||
(!apiKeyModified || apiKey !== "") && (
|
||||
<div className="flex justify-end">
|
||||
<button
|
||||
type="button"
|
||||
className="appearance-none border-0 bg-transparent p-0 text-xs text-content-link hover:cursor-pointer hover:underline"
|
||||
onClick={() => {
|
||||
setApiKey("");
|
||||
setApiKeyTouched(true);
|
||||
setApiKeyModified(true);
|
||||
}}
|
||||
>
|
||||
Clear stored token
|
||||
</button>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</ProviderField>
|
||||
)}
|
||||
<ProviderField
|
||||
label="API Key"
|
||||
htmlFor={apiKeyInputId}
|
||||
required={requiresAPIKey}
|
||||
description={apiKeyDescription}
|
||||
>
|
||||
<div className="space-y-1.5">
|
||||
<Input
|
||||
id={apiKeyInputId}
|
||||
name="provider_api_token"
|
||||
type="password"
|
||||
autoComplete="off"
|
||||
data-1p-ignore
|
||||
data-lpignore="true"
|
||||
data-form-type="other"
|
||||
data-bwignore
|
||||
style={{ WebkitTextSecurity: "disc" } as CSSProperties}
|
||||
className="h-9 font-mono text-[13px]"
|
||||
placeholder={apiKeyPlaceholder}
|
||||
required={requiresAPIKey}
|
||||
value={apiKey}
|
||||
onFocus={handleApiKeyFocus}
|
||||
onChange={(event) => {
|
||||
setApiKey(event.target.value);
|
||||
setApiKeyTouched(true);
|
||||
setApiKeyModified(true);
|
||||
}}
|
||||
disabled={isDisabled}
|
||||
/>
|
||||
{hasAPIKeyWhitespace && (
|
||||
<p className="m-0 text-xs text-content-destructive">
|
||||
API key must not contain leading or trailing whitespace.
|
||||
</p>
|
||||
)}
|
||||
{isBedrockProvider &&
|
||||
providerState.hasManagedAPIKey &&
|
||||
!isDisabled &&
|
||||
(!apiKeyModified || apiKey !== "") && (
|
||||
<div className="flex justify-end">
|
||||
<button
|
||||
type="button"
|
||||
className="appearance-none border-0 bg-transparent p-0 text-xs text-content-link hover:cursor-pointer hover:underline"
|
||||
onClick={() => {
|
||||
setApiKey("");
|
||||
setApiKeyTouched(true);
|
||||
setApiKeyModified(true);
|
||||
}}
|
||||
>
|
||||
Clear stored token
|
||||
</button>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</ProviderField>
|
||||
|
||||
<ProviderField
|
||||
label="Base URL"
|
||||
@@ -378,56 +320,14 @@ export const ProviderForm: FC<ProviderFormProps> = ({
|
||||
name="provider_base_url"
|
||||
className="h-9 text-[13px]"
|
||||
placeholder={baseURLPlaceholder}
|
||||
required
|
||||
autoComplete="off"
|
||||
value={baseURLValue}
|
||||
onChange={(event) => setBaseURLValue(event.target.value)}
|
||||
disabled={isDisabled}
|
||||
/>
|
||||
</ProviderField>
|
||||
|
||||
<div className="space-y-3 rounded-lg border border-solid border-border/70 bg-surface-secondary/30 p-4">
|
||||
<div className="space-y-1">
|
||||
<h3 className="m-0 text-[13px] font-semibold text-content-primary">
|
||||
Key policy
|
||||
</h3>
|
||||
<p className="m-0 text-xs text-content-secondary">
|
||||
Control which credential sources this provider can use.
|
||||
</p>
|
||||
</div>
|
||||
<div className="space-y-3">
|
||||
<ProviderToggleField
|
||||
label="Central API key"
|
||||
description="Use a deployment-managed API key for this provider"
|
||||
checked={centralAPIKeyEnabled}
|
||||
onCheckedChange={setCentralAPIKeyEnabled}
|
||||
disabled={isDisabled}
|
||||
/>
|
||||
<ProviderToggleField
|
||||
label="Allow user API keys"
|
||||
description="Let users provide their own API keys for this provider"
|
||||
checked={allowUserAPIKey}
|
||||
onCheckedChange={setAllowUserAPIKey}
|
||||
disabled={isDisabled}
|
||||
/>
|
||||
{shouldShowFallbackToggle && (
|
||||
<ProviderToggleField
|
||||
label="Use central key as fallback"
|
||||
description="When a user has not saved a personal key, fall back to the central API key"
|
||||
checked={effectiveFallback}
|
||||
onCheckedChange={setAllowCentralAPIKeyFallback}
|
||||
disabled={isDisabled}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
{!hasCredentialSource && (
|
||||
<p className="m-0 text-xs text-content-destructive">
|
||||
At least one credential source must be enabled
|
||||
</p>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Footer, pushed to bottom */}
|
||||
<div className="mt-auto pt-6">
|
||||
<hr className="mb-4 border-0 border-t border-solid border-border" />
|
||||
<div className="flex items-center justify-between">
|
||||
@@ -480,50 +380,6 @@ export const ProviderForm: FC<ProviderFormProps> = ({
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
interface ProviderToggleFieldProps {
|
||||
label: string;
|
||||
description: string;
|
||||
checked: boolean;
|
||||
onCheckedChange: (checked: boolean) => void;
|
||||
disabled?: boolean;
|
||||
}
|
||||
|
||||
const ProviderToggleField: FC<ProviderToggleFieldProps> = ({
|
||||
label,
|
||||
description,
|
||||
checked,
|
||||
onCheckedChange,
|
||||
disabled,
|
||||
}) => {
|
||||
const labelId = useId();
|
||||
const descriptionId = useId();
|
||||
|
||||
return (
|
||||
<div className="flex items-start justify-between gap-4">
|
||||
<div className="min-w-0 space-y-1">
|
||||
<p
|
||||
id={labelId}
|
||||
className="m-0 text-sm font-medium text-content-primary"
|
||||
>
|
||||
{label}
|
||||
</p>
|
||||
<p id={descriptionId} className="m-0 text-xs text-content-secondary">
|
||||
{description}
|
||||
</p>
|
||||
</div>
|
||||
<Switch
|
||||
checked={checked}
|
||||
onCheckedChange={onCheckedChange}
|
||||
disabled={disabled}
|
||||
aria-labelledby={labelId}
|
||||
aria-describedby={descriptionId}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
// Field wrapper.
|
||||
interface ProviderFieldProps {
|
||||
label: string;
|
||||
htmlFor?: string;
|
||||
|
||||
@@ -1,15 +1,69 @@
|
||||
import { CheckCircleIcon, ChevronRightIcon, CircleIcon } from "lucide-react";
|
||||
import {
|
||||
CheckCircleIcon,
|
||||
ChevronRightIcon,
|
||||
CircleIcon,
|
||||
PlusIcon,
|
||||
} from "lucide-react";
|
||||
import type { FC } from "react";
|
||||
import { useLocation, useNavigate, useSearchParams } from "react-router";
|
||||
import type * as TypesGen from "#/api/typesGenerated";
|
||||
import {
|
||||
type AIProviderType,
|
||||
AIProviderTypes,
|
||||
type CreateChatProviderConfigRequest,
|
||||
type UpdateChatProviderConfigRequest,
|
||||
} from "#/api/typesGenerated";
|
||||
import { Badge } from "#/components/Badge/Badge";
|
||||
import { Button } from "#/components/Button/Button";
|
||||
import {
|
||||
DropdownMenu,
|
||||
DropdownMenuContent,
|
||||
DropdownMenuItem,
|
||||
DropdownMenuTrigger,
|
||||
} from "#/components/DropdownMenu/DropdownMenu";
|
||||
import { cn } from "#/utils/cn";
|
||||
import { formatProviderLabel } from "../../utils/modelOptions";
|
||||
import { SectionHeader } from "../SectionHeader";
|
||||
import type { ProviderState } from "./ChatModelAdminPanel";
|
||||
import type {
|
||||
CreateProviderResult,
|
||||
ProviderState,
|
||||
} from "./ChatModelAdminPanel";
|
||||
import { getDefaultProviderBaseURL } from "./helpers";
|
||||
import { ProviderForm } from "./ProviderForm";
|
||||
import { ProviderIcon } from "./ProviderIcon";
|
||||
|
||||
type ProviderView = { mode: "list" } | { mode: "detail"; provider: string };
|
||||
type ProviderView =
|
||||
| { mode: "list" }
|
||||
| { mode: "detail"; provider: string }
|
||||
| { mode: "new"; providerType: AIProviderType };
|
||||
|
||||
const providerTypeOptions = AIProviderTypes.map((providerType) => ({
|
||||
providerType,
|
||||
label: formatProviderLabel(providerType),
|
||||
})).sort((a, b) => a.label.localeCompare(b.label));
|
||||
|
||||
const getAIProviderType = (
|
||||
value: string | null,
|
||||
): AIProviderType | undefined => {
|
||||
if (!value) {
|
||||
return undefined;
|
||||
}
|
||||
return AIProviderTypes.find((providerType) => providerType === value);
|
||||
};
|
||||
|
||||
const newProviderState = (providerType: AIProviderType): ProviderState => ({
|
||||
key: `new:${providerType}`,
|
||||
provider: providerType,
|
||||
label: formatProviderLabel(providerType),
|
||||
providerConfig: undefined,
|
||||
modelConfigs: [],
|
||||
catalogModelCount: 0,
|
||||
hasManagedAPIKey: false,
|
||||
hasCatalogAPIKey: false,
|
||||
hasEffectiveAPIKey: false,
|
||||
allowUserAPIKey: true,
|
||||
isEnvPreset: false,
|
||||
baseURL: getDefaultProviderBaseURL(providerType),
|
||||
});
|
||||
|
||||
interface ProvidersSectionProps {
|
||||
sectionLabel?: string;
|
||||
@@ -18,11 +72,11 @@ interface ProvidersSectionProps {
|
||||
providerConfigsUnavailable: boolean;
|
||||
isProviderMutationPending: boolean;
|
||||
onCreateProvider: (
|
||||
req: TypesGen.CreateChatProviderConfigRequest,
|
||||
) => Promise<unknown>;
|
||||
req: CreateChatProviderConfigRequest,
|
||||
) => Promise<CreateProviderResult>;
|
||||
onUpdateProvider: (
|
||||
providerConfigId: string,
|
||||
req: TypesGen.UpdateChatProviderConfigRequest,
|
||||
req: UpdateChatProviderConfigRequest,
|
||||
) => Promise<unknown>;
|
||||
onDeleteProvider: (providerConfigId: string) => Promise<void>;
|
||||
}
|
||||
@@ -48,11 +102,17 @@ export const ProvidersSection: FC<ProvidersSectionProps> = ({
|
||||
const view: ProviderView = (() => {
|
||||
const providerParam = searchParams.get("provider");
|
||||
if (providerParam) {
|
||||
const exists = providerStates.some((ps) => ps.provider === providerParam);
|
||||
const exists = providerStates.some((ps) => ps.key === providerParam);
|
||||
return exists
|
||||
? { mode: "detail", provider: providerParam }
|
||||
: { mode: "list" };
|
||||
}
|
||||
|
||||
const newProviderType = getAIProviderType(searchParams.get("newProvider"));
|
||||
if (newProviderType) {
|
||||
return { mode: "new", providerType: newProviderType };
|
||||
}
|
||||
|
||||
return { mode: "list" };
|
||||
})();
|
||||
|
||||
@@ -61,17 +121,31 @@ export const ProvidersSection: FC<ProvidersSectionProps> = ({
|
||||
setSearchParams((prev) => {
|
||||
const next = new URLSearchParams(prev);
|
||||
next.delete("provider");
|
||||
next.delete("newProvider");
|
||||
return next;
|
||||
});
|
||||
};
|
||||
|
||||
const openNewProviderView = (providerType: AIProviderType) => {
|
||||
setSearchParams(
|
||||
(prev) => {
|
||||
const next = new URLSearchParams(prev);
|
||||
next.delete("provider");
|
||||
next.set("newProvider", providerType);
|
||||
return next;
|
||||
},
|
||||
{ state: { pushed: true } },
|
||||
);
|
||||
};
|
||||
// Detail view.
|
||||
const detailProvider =
|
||||
view.mode === "detail"
|
||||
? providerStates.find((ps) => ps.provider === view.provider)
|
||||
: undefined;
|
||||
? providerStates.find((ps) => ps.key === view.provider)
|
||||
: view.mode === "new"
|
||||
? newProviderState(view.providerType)
|
||||
: undefined;
|
||||
|
||||
if (view.mode === "detail" && detailProvider) {
|
||||
if ((view.mode === "detail" || view.mode === "new") && detailProvider) {
|
||||
const providerFormKey = [
|
||||
detailProvider.provider,
|
||||
detailProvider.providerConfig?.id ?? "new",
|
||||
@@ -91,7 +165,21 @@ export const ProvidersSection: FC<ProvidersSectionProps> = ({
|
||||
providerState={detailProvider}
|
||||
providerConfigsUnavailable={providerConfigsUnavailable}
|
||||
isProviderMutationPending={isProviderMutationPending}
|
||||
onCreateProvider={onCreateProvider}
|
||||
onCreateProvider={async (req) => {
|
||||
const createdProvider = await onCreateProvider(req);
|
||||
if (createdProvider.id) {
|
||||
setSearchParams(
|
||||
(prev) => {
|
||||
const next = new URLSearchParams(prev);
|
||||
next.set("provider", createdProvider.id);
|
||||
next.delete("newProvider");
|
||||
return next;
|
||||
},
|
||||
{ replace: true, state: location.state },
|
||||
);
|
||||
}
|
||||
return createdProvider;
|
||||
}}
|
||||
onUpdateProvider={onUpdateProvider}
|
||||
onDeleteProvider={async (id) => {
|
||||
await onDeleteProvider(id);
|
||||
@@ -102,6 +190,7 @@ export const ProvidersSection: FC<ProvidersSectionProps> = ({
|
||||
(prev) => {
|
||||
const next = new URLSearchParams(prev);
|
||||
next.delete("provider");
|
||||
next.delete("newProvider");
|
||||
return next;
|
||||
},
|
||||
{ replace: true },
|
||||
@@ -114,66 +203,94 @@ export const ProvidersSection: FC<ProvidersSectionProps> = ({
|
||||
}
|
||||
|
||||
// List view.
|
||||
if (providerStates.length === 0) {
|
||||
return (
|
||||
<div className="rounded-lg border border-dashed border-border bg-surface-primary p-6 text-center text-[13px] text-content-secondary">
|
||||
No provider types were returned by the backend.
|
||||
</div>
|
||||
);
|
||||
}
|
||||
const addProviderAction = (
|
||||
<DropdownMenu>
|
||||
<DropdownMenuTrigger asChild>
|
||||
<Button
|
||||
variant="outline"
|
||||
size="sm"
|
||||
disabled={providerConfigsUnavailable || isProviderMutationPending}
|
||||
>
|
||||
<PlusIcon className="h-4 w-4" />
|
||||
Add provider
|
||||
</Button>
|
||||
</DropdownMenuTrigger>
|
||||
<DropdownMenuContent align="end" className="w-56">
|
||||
{providerTypeOptions.map(({ providerType, label }) => (
|
||||
<DropdownMenuItem
|
||||
key={providerType}
|
||||
className="gap-2"
|
||||
onSelect={() => openNewProviderView(providerType)}
|
||||
>
|
||||
<ProviderIcon provider={providerType} className="h-5 w-5" />
|
||||
<span>{label}</span>
|
||||
</DropdownMenuItem>
|
||||
))}
|
||||
</DropdownMenuContent>
|
||||
</DropdownMenu>
|
||||
);
|
||||
|
||||
const header = sectionLabel ? (
|
||||
<SectionHeader
|
||||
label={sectionLabel}
|
||||
description={
|
||||
sectionDescription ?? "Configure AI providers to use with Agents."
|
||||
}
|
||||
action={addProviderAction}
|
||||
/>
|
||||
) : null;
|
||||
|
||||
return (
|
||||
<>
|
||||
{sectionLabel && (
|
||||
<SectionHeader
|
||||
label={sectionLabel}
|
||||
description={
|
||||
sectionDescription ?? "Configure AI providers to use with Agents."
|
||||
}
|
||||
/>
|
||||
)}
|
||||
<div>
|
||||
{providerStates.map((providerState, i) => (
|
||||
<button
|
||||
type="button"
|
||||
key={providerState.provider}
|
||||
aria-label={providerState.label}
|
||||
onClick={() => {
|
||||
setSearchParams(
|
||||
{ provider: providerState.provider },
|
||||
{ state: { pushed: true } },
|
||||
);
|
||||
}}
|
||||
className={cn(
|
||||
"flex w-full cursor-pointer items-center gap-3.5 border-0 bg-transparent p-0 px-3 py-3 text-left transition-colors hover:bg-surface-secondary/30",
|
||||
i > 0 && "border-0 border-t border-solid border-border/50",
|
||||
)}
|
||||
>
|
||||
<ProviderIcon
|
||||
provider={providerState.provider}
|
||||
className="h-8 w-8 shrink-0"
|
||||
/>
|
||||
<div className="min-w-0 flex-1 space-y-1">
|
||||
<div className="flex flex-wrap items-center gap-2">
|
||||
<span className="min-w-0 truncate text-[15px] font-medium text-content-primary text-left">
|
||||
{providerState.label}
|
||||
</span>
|
||||
{providerState.providerConfig?.allow_user_api_key && (
|
||||
<Badge size="xs" className="text-content-secondary">
|
||||
User keys enabled
|
||||
</Badge>
|
||||
)}
|
||||
{header}
|
||||
{providerStates.length === 0 ? (
|
||||
<div className="rounded-lg border border-dashed border-border bg-surface-primary p-6 text-center text-[13px] text-content-secondary">
|
||||
No providers have been added yet.
|
||||
</div>
|
||||
) : (
|
||||
<div>
|
||||
{providerStates.map((providerState, i) => (
|
||||
<button
|
||||
type="button"
|
||||
key={providerState.key}
|
||||
aria-label={providerState.label}
|
||||
onClick={() => {
|
||||
setSearchParams(
|
||||
{ provider: providerState.key },
|
||||
{ state: { pushed: true } },
|
||||
);
|
||||
}}
|
||||
className={cn(
|
||||
"flex w-full cursor-pointer items-center gap-3.5 border-0 bg-transparent p-0 px-3 py-3 text-left transition-colors hover:bg-surface-secondary/30",
|
||||
i > 0 && "border-0 border-t border-solid border-border/50",
|
||||
)}
|
||||
>
|
||||
<ProviderIcon
|
||||
provider={providerState.provider}
|
||||
className="h-8 w-8 shrink-0"
|
||||
/>
|
||||
<div className="min-w-0 flex-1 space-y-1">
|
||||
<div className="flex flex-wrap items-center gap-2">
|
||||
<span className="min-w-0 truncate text-[15px] font-medium text-content-primary text-left">
|
||||
{providerState.label}
|
||||
</span>
|
||||
{providerState.providerConfig?.allow_user_api_key && (
|
||||
<Badge size="xs" className="text-content-secondary">
|
||||
User keys enabled
|
||||
</Badge>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
{providerState.hasEffectiveAPIKey ? (
|
||||
<CheckCircleIcon className="h-4 w-4 shrink-0 text-content-success" />
|
||||
) : (
|
||||
<CircleIcon className="h-4 w-4 shrink-0 text-content-secondary opacity-40" />
|
||||
)}
|
||||
<ChevronRightIcon className="h-5 w-5 shrink-0 text-content-secondary" />
|
||||
</button>
|
||||
))}
|
||||
</div>
|
||||
{providerState.hasEffectiveAPIKey ? (
|
||||
<CheckCircleIcon className="h-4 w-4 shrink-0 text-content-success" />
|
||||
) : (
|
||||
<CircleIcon className="h-4 w-4 shrink-0 text-content-secondary opacity-40" />
|
||||
)}
|
||||
<ChevronRightIcon className="h-5 w-5 shrink-0 text-content-secondary" />
|
||||
</button>
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -14,3 +14,28 @@ export function readOptionalString(value: unknown): string | undefined {
|
||||
export function normalizeProvider(provider: string): string {
|
||||
return provider.trim().toLowerCase();
|
||||
}
|
||||
|
||||
const canonicalProviderBaseURLs: Record<string, string> = {
|
||||
anthropic: "https://api.anthropic.com",
|
||||
google: "https://generativelanguage.googleapis.com/v1beta",
|
||||
openai: "https://api.openai.com/v1",
|
||||
openrouter: "https://openrouter.ai/api/v1",
|
||||
vercel: "https://ai-gateway.vercel.sh/v1",
|
||||
};
|
||||
|
||||
export function getDefaultProviderBaseURL(provider: string): string {
|
||||
return canonicalProviderBaseURLs[normalizeProvider(provider)] ?? "";
|
||||
}
|
||||
|
||||
export function getProviderBaseURLPlaceholder(provider: string): string {
|
||||
switch (normalizeProvider(provider)) {
|
||||
case "azure":
|
||||
return "https://<resource-name>.openai.azure.com";
|
||||
case "bedrock":
|
||||
return "https://bedrock-runtime.<region>.amazonaws.com";
|
||||
case "openai-compat":
|
||||
return "https://api.example.com/v1";
|
||||
default:
|
||||
return getDefaultProviderBaseURL(provider) || "https://api.example.com";
|
||||
}
|
||||
}
|
||||
|
||||
-55
@@ -1,55 +0,0 @@
|
||||
import { describe, expect, it } from "vitest";
|
||||
import {
|
||||
normalizeProviderPolicyDefaults,
|
||||
type ProviderConfigWithOptionalPolicyFields,
|
||||
} from "./providerPolicyDefaults";
|
||||
|
||||
const baseProviderConfig: ProviderConfigWithOptionalPolicyFields = {
|
||||
id: "provider-1",
|
||||
provider: "openai",
|
||||
display_name: "OpenAI",
|
||||
enabled: true,
|
||||
has_api_key: true,
|
||||
base_url: "https://api.openai.com/v1",
|
||||
source: "database",
|
||||
created_at: "2025-01-01T00:00:00Z",
|
||||
updated_at: "2025-01-01T00:00:00Z",
|
||||
};
|
||||
|
||||
describe("normalizeProviderPolicyDefaults", () => {
|
||||
it("passes through explicit policy fields unchanged", () => {
|
||||
const providerConfig: ProviderConfigWithOptionalPolicyFields = {
|
||||
...baseProviderConfig,
|
||||
central_api_key_enabled: false,
|
||||
allow_user_api_key: true,
|
||||
allow_central_api_key_fallback: true,
|
||||
};
|
||||
|
||||
expect(normalizeProviderPolicyDefaults(providerConfig)).toEqual(
|
||||
providerConfig,
|
||||
);
|
||||
});
|
||||
|
||||
it("defaults omitted policy fields to the expected values", () => {
|
||||
expect(normalizeProviderPolicyDefaults(baseProviderConfig)).toMatchObject({
|
||||
central_api_key_enabled: true,
|
||||
allow_user_api_key: false,
|
||||
allow_central_api_key_fallback: false,
|
||||
});
|
||||
});
|
||||
|
||||
it("defaults undefined policy fields to the expected values", () => {
|
||||
const providerConfig: ProviderConfigWithOptionalPolicyFields = {
|
||||
...baseProviderConfig,
|
||||
central_api_key_enabled: undefined,
|
||||
allow_user_api_key: undefined,
|
||||
allow_central_api_key_fallback: undefined,
|
||||
};
|
||||
|
||||
expect(normalizeProviderPolicyDefaults(providerConfig)).toMatchObject({
|
||||
central_api_key_enabled: true,
|
||||
allow_user_api_key: false,
|
||||
allow_central_api_key_fallback: false,
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -1,26 +0,0 @@
|
||||
import type * as TypesGen from "#/api/typesGenerated";
|
||||
|
||||
type ProviderPolicyFields = Pick<
|
||||
TypesGen.ChatProviderConfig,
|
||||
| "central_api_key_enabled"
|
||||
| "allow_user_api_key"
|
||||
| "allow_central_api_key_fallback"
|
||||
>;
|
||||
|
||||
export type ProviderConfigWithOptionalPolicyFields = Omit<
|
||||
TypesGen.ChatProviderConfig,
|
||||
keyof ProviderPolicyFields
|
||||
> &
|
||||
Partial<ProviderPolicyFields>;
|
||||
|
||||
export function normalizeProviderPolicyDefaults(
|
||||
providerConfig: ProviderConfigWithOptionalPolicyFields,
|
||||
): TypesGen.ChatProviderConfig {
|
||||
return {
|
||||
...providerConfig,
|
||||
central_api_key_enabled: providerConfig.central_api_key_enabled ?? true,
|
||||
allow_user_api_key: providerConfig.allow_user_api_key ?? false,
|
||||
allow_central_api_key_fallback:
|
||||
providerConfig.allow_central_api_key_fallback ?? false,
|
||||
};
|
||||
}
|
||||
@@ -223,6 +223,7 @@ describe("countConfiguredProviderConfigs", () => {
|
||||
|
||||
describe("formatProviderLabel", () => {
|
||||
it("formats OpenAI compatible providers", () => {
|
||||
expect(formatProviderLabel("openai-compat")).toBe("OpenAI-compatible");
|
||||
expect(formatProviderLabel("openai-compatible")).toBe("OpenAI-compatible");
|
||||
});
|
||||
});
|
||||
|
||||
@@ -220,33 +220,7 @@ export const getProviderForModelOption = (
|
||||
): string | undefined =>
|
||||
modelOptions.find((option) => option.id === selectedModel)?.provider;
|
||||
|
||||
export const formatProviderLabel = (provider: string): string => {
|
||||
const normalized = provider.trim().toLowerCase();
|
||||
switch (normalized) {
|
||||
case "openai":
|
||||
return "OpenAI";
|
||||
case "anthropic":
|
||||
return "Anthropic";
|
||||
case "azure":
|
||||
return "Azure OpenAI";
|
||||
case "bedrock":
|
||||
return "AWS Bedrock";
|
||||
case "google":
|
||||
return "Google";
|
||||
case "openai-compatible":
|
||||
case "openai_compatible":
|
||||
return "OpenAI-compatible";
|
||||
case "openrouter":
|
||||
return "OpenRouter";
|
||||
case "vercel":
|
||||
return "Vercel AI Gateway";
|
||||
default:
|
||||
if (!normalized) {
|
||||
return "Unknown";
|
||||
}
|
||||
return `${normalized[0].toUpperCase()}${normalized.slice(1)}`;
|
||||
}
|
||||
};
|
||||
export { formatProviderLabel } from "#/utils/aiProviders";
|
||||
|
||||
export const getModelSelectorPlaceholder = (
|
||||
modelOptions: readonly ModelSelectorOption[],
|
||||
|
||||
@@ -0,0 +1,28 @@
|
||||
export const formatProviderLabel = (provider: string): string => {
|
||||
const normalized = provider.trim().toLowerCase();
|
||||
switch (normalized) {
|
||||
case "openai":
|
||||
return "OpenAI";
|
||||
case "anthropic":
|
||||
return "Anthropic";
|
||||
case "azure":
|
||||
return "Azure OpenAI";
|
||||
case "bedrock":
|
||||
return "AWS Bedrock";
|
||||
case "google":
|
||||
return "Google";
|
||||
case "openai-compat":
|
||||
case "openai-compatible":
|
||||
case "openai_compatible":
|
||||
return "OpenAI-compatible";
|
||||
case "openrouter":
|
||||
return "OpenRouter";
|
||||
case "vercel":
|
||||
return "Vercel AI Gateway";
|
||||
default:
|
||||
if (!normalized) {
|
||||
return "Unknown";
|
||||
}
|
||||
return `${normalized[0].toUpperCase()}${normalized.slice(1)}`;
|
||||
}
|
||||
};
|
||||
Reference in New Issue
Block a user