From 40878eeba4b3d2ab11edc97fc8e5a0d1f92ba585 Mon Sep 17 00:00:00 2001 From: Michael Suchacz <203725896+ibetitsmike@users.noreply.github.com> Date: Fri, 22 May 2026 02:16:01 +0200 Subject: [PATCH] feat: add AI provider schema expansion (#25412) --- coderd/database/check_constraint.go | 1 + coderd/database/dbauthz/dbauthz.go | 79 +++++ coderd/database/dbauthz/dbauthz_test.go | 58 ++++ coderd/database/dbmetrics/querymetrics.go | 64 ++++ coderd/database/dbmock/dbmock.go | 118 +++++++ coderd/database/dump.sql | 46 +++ coderd/database/foreign_key_constraint.go | 4 + ...000503_ai_providers_schema_expand.down.sql | 46 +++ .../000503_ai_providers_schema_expand.up.sql | 72 ++++ .../000504_ai_providers_backfill.down.sql | 48 +++ .../000504_ai_providers_backfill.up.sql | 78 +++++ coderd/database/migrations/migrate_test.go | 316 ++++++++++++++++++ .../000503_ai_providers_schema_expand.up.sql | 11 + coderd/database/models.go | 14 + coderd/database/querier.go | 13 + coderd/database/queries.sql.go | 308 ++++++++++++++++- .../queries/user_ai_provider_keys.sql | 100 ++++++ coderd/database/sqlc.yaml | 1 + coderd/database/unique_constraint.go | 2 + codersdk/aiproviders.go | 21 +- enterprise/cli/server_dbcrypt_test.go | 50 +++ enterprise/dbcrypt/cliutil.go | 45 +++ enterprise/dbcrypt/dbcrypt.go | 92 +++++ enterprise/dbcrypt/dbcrypt_internal_test.go | 162 +++++++++ site/src/api/typesGenerated.ts | 6 +- 25 files changed, 1737 insertions(+), 18 deletions(-) create mode 100644 coderd/database/migrations/000503_ai_providers_schema_expand.down.sql create mode 100644 coderd/database/migrations/000503_ai_providers_schema_expand.up.sql create mode 100644 coderd/database/migrations/000504_ai_providers_backfill.down.sql create mode 100644 coderd/database/migrations/000504_ai_providers_backfill.up.sql create mode 100644 coderd/database/migrations/testdata/fixtures/000503_ai_providers_schema_expand.up.sql create mode 100644 coderd/database/queries/user_ai_provider_keys.sql diff --git a/coderd/database/check_constraint.go b/coderd/database/check_constraint.go index f2d49c326f..c5e109af38 100644 --- a/coderd/database/check_constraint.go +++ b/coderd/database/check_constraint.go @@ -44,6 +44,7 @@ const ( CheckTelemetryLockEventTypeConstraint CheckConstraint = "telemetry_lock_event_type_constraint" // telemetry_locks CheckValidationMonotonicOrder CheckConstraint = "validation_monotonic_order" // template_version_parameters CheckUsageEventTypeCheck CheckConstraint = "usage_event_type_check" // usage_events + CheckUserAiProviderKeysAPIKeyCheck CheckConstraint = "user_ai_provider_keys_api_key_check" // user_ai_provider_keys CheckUserChatProviderKeysAPIKeyCheck CheckConstraint = "user_chat_provider_keys_api_key_check" // user_chat_provider_keys CheckUserSkillsContentSize CheckConstraint = "user_skills_content_size" // user_skills CheckUserSkillsDescriptionSize CheckConstraint = "user_skills_description_size" // user_skills diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index b1464c5ca6..a1ac393df5 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -2258,6 +2258,24 @@ func (q *querier) DeleteTask(ctx context.Context, arg database.DeleteTaskParams) return q.db.DeleteTask(ctx, arg) } +func (q *querier) DeleteUserAIProviderKey(ctx context.Context, arg database.DeleteUserAIProviderKeyParams) error { + u, err := q.db.GetUserByID(ctx, arg.UserID) + if err != nil { + return err + } + if err := q.authorizeContext(ctx, policy.ActionUpdatePersonal, u); err != nil { + return err + } + return q.db.DeleteUserAIProviderKey(ctx, arg) +} + +func (q *querier) DeleteUserAIProviderKeysByProviderID(ctx context.Context, aiProviderID uuid.UUID) error { + if err := q.authorizeContext(ctx, policy.ActionDelete, rbac.ResourceAIProvider); err != nil { + return err + } + return q.db.DeleteUserAIProviderKeysByProviderID(ctx, aiProviderID) +} + func (q *querier) DeleteUserChatCompactionThreshold(ctx context.Context, arg database.DeleteUserChatCompactionThresholdParams) error { u, err := q.db.GetUserByID(ctx, arg.UserID) if err != nil { @@ -4471,6 +4489,35 @@ func (q *querier) GetUnexpiredLicenses(ctx context.Context) ([]database.License, return q.db.GetUnexpiredLicenses(ctx) } +func (q *querier) GetUserAIProviderKeyByProviderID(ctx context.Context, arg database.GetUserAIProviderKeyByProviderIDParams) (database.UserAiProviderKey, error) { + u, err := q.db.GetUserByID(ctx, arg.UserID) + if err != nil { + return database.UserAiProviderKey{}, err + } + if err := q.authorizeContext(ctx, policy.ActionReadPersonal, u); err != nil { + return database.UserAiProviderKey{}, err + } + return q.db.GetUserAIProviderKeyByProviderID(ctx, arg) +} + +func (q *querier) GetUserAIProviderKeys(ctx context.Context) ([]database.UserAiProviderKey, error) { + if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceAIProvider); err != nil { + return nil, err + } + return q.db.GetUserAIProviderKeys(ctx) +} + +func (q *querier) GetUserAIProviderKeysByUserID(ctx context.Context, userID uuid.UUID) ([]database.UserAiProviderKey, error) { + u, err := q.db.GetUserByID(ctx, userID) + if err != nil { + return nil, err + } + if err := q.authorizeContext(ctx, policy.ActionReadPersonal, u); err != nil { + return nil, err + } + return q.db.GetUserAIProviderKeysByUserID(ctx, userID) +} + func (q *querier) GetUserAISeatStates(ctx context.Context, userIDs []uuid.UUID) ([]uuid.UUID, error) { if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceAiSeat); err != nil { return nil, err @@ -6795,6 +6842,16 @@ func (q *querier) UpdateEncryptedAIProviderSettings(ctx context.Context, arg dat return q.db.UpdateEncryptedAIProviderSettings(ctx, arg) } +func (q *querier) UpdateEncryptedUserAIProviderKey(ctx context.Context, arg database.UpdateEncryptedUserAIProviderKeyParams) (database.UserAiProviderKey, error) { + // Encrypted user-owned provider keys can be rewritten on any row so + // dbcrypt rotation can move every key to a new digest. This is a + // maintenance path, not the self-service user key API. + if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceAIProvider); err != nil { + return database.UserAiProviderKey{}, err + } + return q.db.UpdateEncryptedUserAIProviderKey(ctx, arg) +} + func (q *querier) UpdateExternalAuthLink(ctx context.Context, arg database.UpdateExternalAuthLinkParams) (database.ExternalAuthLink, error) { fetch := func(ctx context.Context, arg database.UpdateExternalAuthLinkParams) (database.ExternalAuthLink, error) { return q.db.GetExternalAuthLink(ctx, database.GetExternalAuthLinkParams{UserID: arg.UserID, ProviderID: arg.ProviderID}) @@ -7288,6 +7345,17 @@ func (q *querier) UpdateUsageEventsPostPublish(ctx context.Context, arg database return q.db.UpdateUsageEventsPostPublish(ctx, arg) } +func (q *querier) UpdateUserAIProviderKey(ctx context.Context, arg database.UpdateUserAIProviderKeyParams) (database.UserAiProviderKey, error) { + u, err := q.db.GetUserByID(ctx, arg.UserID) + if err != nil { + return database.UserAiProviderKey{}, err + } + if err := q.authorizeContext(ctx, policy.ActionUpdatePersonal, u); err != nil { + return database.UserAiProviderKey{}, err + } + return q.db.UpdateUserAIProviderKey(ctx, arg) +} + func (q *querier) UpdateUserAgentChatSendShortcut(ctx context.Context, arg database.UpdateUserAgentChatSendShortcutParams) (string, error) { user, err := q.db.GetUserByID(ctx, arg.UserID) if err != nil { @@ -8230,6 +8298,17 @@ func (q *querier) UpsertTemplateUsageStats(ctx context.Context) error { return q.db.UpsertTemplateUsageStats(ctx) } +func (q *querier) UpsertUserAIProviderKey(ctx context.Context, arg database.UpsertUserAIProviderKeyParams) (database.UserAiProviderKey, error) { + u, err := q.db.GetUserByID(ctx, arg.UserID) + if err != nil { + return database.UserAiProviderKey{}, err + } + if err := q.authorizeContext(ctx, policy.ActionUpdatePersonal, u); err != nil { + return database.UserAiProviderKey{}, err + } + return q.db.UpsertUserAIProviderKey(ctx, arg) +} + func (q *querier) UpsertUserChatDebugLoggingEnabled(ctx context.Context, arg database.UpsertUserChatDebugLoggingEnabledParams) error { u, err := q.db.GetUserByID(ctx, arg.UserID) if err != nil { diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go index 9004ff428e..1b24cf06b8 100644 --- a/coderd/database/dbauthz/dbauthz_test.go +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -2963,6 +2963,49 @@ func (s *MethodTestSuite) TestUser() { dbm.EXPECT().UpsertUserChatProviderKey(gomock.Any(), arg).Return(key, nil).AnyTimes() check.Args(arg).Asserts(u, policy.ActionUpdatePersonal).Returns(key) })) + s.Run("GetUserAIProviderKeyByProviderID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + u := testutil.Fake(s.T(), faker, database.User{}) + arg := database.GetUserAIProviderKeyByProviderIDParams{UserID: u.ID, AIProviderID: uuid.New()} + key := testutil.Fake(s.T(), faker, database.UserAiProviderKey{UserID: u.ID, AIProviderID: arg.AIProviderID}) + dbm.EXPECT().GetUserByID(gomock.Any(), u.ID).Return(u, nil).AnyTimes() + dbm.EXPECT().GetUserAIProviderKeyByProviderID(gomock.Any(), arg).Return(key, nil).AnyTimes() + check.Args(arg).Asserts(u, policy.ActionReadPersonal).Returns(key) + })) + s.Run("GetUserAIProviderKeysByUserID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + u := testutil.Fake(s.T(), faker, database.User{}) + key := testutil.Fake(s.T(), faker, database.UserAiProviderKey{UserID: u.ID}) + dbm.EXPECT().GetUserByID(gomock.Any(), u.ID).Return(u, nil).AnyTimes() + dbm.EXPECT().GetUserAIProviderKeysByUserID(gomock.Any(), u.ID).Return([]database.UserAiProviderKey{key}, nil).AnyTimes() + check.Args(u.ID).Asserts(u, policy.ActionReadPersonal).Returns([]database.UserAiProviderKey{key}) + })) + s.Run("DeleteUserAIProviderKeysByProviderID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + providerID := uuid.New() + dbm.EXPECT().DeleteUserAIProviderKeysByProviderID(gomock.Any(), providerID).Return(nil).AnyTimes() + check.Args(providerID).Asserts(rbac.ResourceAIProvider, policy.ActionDelete).Returns() + })) + s.Run("DeleteUserAIProviderKey", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + u := testutil.Fake(s.T(), faker, database.User{}) + arg := database.DeleteUserAIProviderKeyParams{UserID: u.ID, AIProviderID: uuid.New()} + dbm.EXPECT().GetUserByID(gomock.Any(), u.ID).Return(u, nil).AnyTimes() + dbm.EXPECT().DeleteUserAIProviderKey(gomock.Any(), arg).Return(nil).AnyTimes() + check.Args(arg).Asserts(u, policy.ActionUpdatePersonal).Returns() + })) + s.Run("UpdateUserAIProviderKey", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + u := testutil.Fake(s.T(), faker, database.User{}) + arg := database.UpdateUserAIProviderKeyParams{UserID: u.ID, AIProviderID: uuid.New(), APIKey: "updated-api-key"} + key := testutil.Fake(s.T(), faker, database.UserAiProviderKey{UserID: u.ID, AIProviderID: arg.AIProviderID, APIKey: arg.APIKey}) + dbm.EXPECT().GetUserByID(gomock.Any(), u.ID).Return(u, nil).AnyTimes() + dbm.EXPECT().UpdateUserAIProviderKey(gomock.Any(), arg).Return(key, nil).AnyTimes() + check.Args(arg).Asserts(u, policy.ActionUpdatePersonal).Returns(key) + })) + s.Run("UpsertUserAIProviderKey", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + u := testutil.Fake(s.T(), faker, database.User{}) + arg := database.UpsertUserAIProviderKeyParams{UserID: u.ID, AIProviderID: uuid.New(), APIKey: "upserted-api-key"} + key := testutil.Fake(s.T(), faker, database.UserAiProviderKey{UserID: u.ID, AIProviderID: arg.AIProviderID, APIKey: arg.APIKey}) + dbm.EXPECT().GetUserByID(gomock.Any(), u.ID).Return(u, nil).AnyTimes() + dbm.EXPECT().UpsertUserAIProviderKey(gomock.Any(), arg).Return(key, nil).AnyTimes() + check.Args(arg).Asserts(u, policy.ActionUpdatePersonal).Returns(key) + })) s.Run("GetUserChatDebugLoggingEnabled", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { u := testutil.Fake(s.T(), faker, database.User{}) dbm.EXPECT().GetUserByID(gomock.Any(), u.ID).Return(u, nil).AnyTimes() @@ -6557,6 +6600,21 @@ func (s *MethodTestSuite) TestAIBridge() { dbm.EXPECT().UpdateEncryptedAIProviderKey(gomock.Any(), arg).Return(key, nil).AnyTimes() check.Args(arg).Asserts(rbac.ResourceAIProvider, policy.ActionUpdate).Returns(key) })) + s.Run("GetUserAIProviderKeys", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + keyA := testutil.Fake(s.T(), faker, database.UserAiProviderKey{}) + keyB := testutil.Fake(s.T(), faker, database.UserAiProviderKey{}) + dbm.EXPECT().GetUserAIProviderKeys(gomock.Any()).Return([]database.UserAiProviderKey{keyA, keyB}, nil).AnyTimes() + check.Args().Asserts(rbac.ResourceAIProvider, policy.ActionRead).Returns([]database.UserAiProviderKey{keyA, keyB}) + })) + s.Run("UpdateEncryptedUserAIProviderKey", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + key := testutil.Fake(s.T(), faker, database.UserAiProviderKey{}) + arg := database.UpdateEncryptedUserAIProviderKeyParams{ + ID: key.ID, + APIKey: "encrypted-api-key", + } + dbm.EXPECT().UpdateEncryptedUserAIProviderKey(gomock.Any(), arg).Return(key, nil).AnyTimes() + check.Args(arg).Asserts(rbac.ResourceAIProvider, policy.ActionUpdate).Returns(key) + })) } func (s *MethodTestSuite) TestTelemetry() { diff --git a/coderd/database/dbmetrics/querymetrics.go b/coderd/database/dbmetrics/querymetrics.go index fc07c3e6a8..1a21da8143 100644 --- a/coderd/database/dbmetrics/querymetrics.go +++ b/coderd/database/dbmetrics/querymetrics.go @@ -785,6 +785,22 @@ func (m queryMetricsStore) DeleteTask(ctx context.Context, arg database.DeleteTa return r0, r1 } +func (m queryMetricsStore) DeleteUserAIProviderKey(ctx context.Context, arg database.DeleteUserAIProviderKeyParams) error { + start := time.Now() + r0 := m.s.DeleteUserAIProviderKey(ctx, arg) + m.queryLatencies.WithLabelValues("DeleteUserAIProviderKey").Observe(time.Since(start).Seconds()) + m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteUserAIProviderKey").Inc() + return r0 +} + +func (m queryMetricsStore) DeleteUserAIProviderKeysByProviderID(ctx context.Context, aiProviderID uuid.UUID) error { + start := time.Now() + r0 := m.s.DeleteUserAIProviderKeysByProviderID(ctx, aiProviderID) + m.queryLatencies.WithLabelValues("DeleteUserAIProviderKeysByProviderID").Observe(time.Since(start).Seconds()) + m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteUserAIProviderKeysByProviderID").Inc() + return r0 +} + func (m queryMetricsStore) DeleteUserChatCompactionThreshold(ctx context.Context, arg database.DeleteUserChatCompactionThresholdParams) error { start := time.Now() r0 := m.s.DeleteUserChatCompactionThreshold(ctx, arg) @@ -2897,6 +2913,30 @@ func (m queryMetricsStore) GetUnexpiredLicenses(ctx context.Context) ([]database return r0, r1 } +func (m queryMetricsStore) GetUserAIProviderKeyByProviderID(ctx context.Context, arg database.GetUserAIProviderKeyByProviderIDParams) (database.UserAiProviderKey, error) { + start := time.Now() + r0, r1 := m.s.GetUserAIProviderKeyByProviderID(ctx, arg) + m.queryLatencies.WithLabelValues("GetUserAIProviderKeyByProviderID").Observe(time.Since(start).Seconds()) + m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetUserAIProviderKeyByProviderID").Inc() + return r0, r1 +} + +func (m queryMetricsStore) GetUserAIProviderKeys(ctx context.Context) ([]database.UserAiProviderKey, error) { + start := time.Now() + r0, r1 := m.s.GetUserAIProviderKeys(ctx) + m.queryLatencies.WithLabelValues("GetUserAIProviderKeys").Observe(time.Since(start).Seconds()) + m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetUserAIProviderKeys").Inc() + return r0, r1 +} + +func (m queryMetricsStore) GetUserAIProviderKeysByUserID(ctx context.Context, userID uuid.UUID) ([]database.UserAiProviderKey, error) { + start := time.Now() + r0, r1 := m.s.GetUserAIProviderKeysByUserID(ctx, userID) + m.queryLatencies.WithLabelValues("GetUserAIProviderKeysByUserID").Observe(time.Since(start).Seconds()) + m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetUserAIProviderKeysByUserID").Inc() + return r0, r1 +} + func (m queryMetricsStore) GetUserAISeatStates(ctx context.Context, userIds []uuid.UUID) ([]uuid.UUID, error) { start := time.Now() r0, r1 := m.s.GetUserAISeatStates(ctx, userIds) @@ -4897,6 +4937,14 @@ func (m queryMetricsStore) UpdateEncryptedAIProviderSettings(ctx context.Context return r0, r1 } +func (m queryMetricsStore) UpdateEncryptedUserAIProviderKey(ctx context.Context, arg database.UpdateEncryptedUserAIProviderKeyParams) (database.UserAiProviderKey, error) { + start := time.Now() + r0, r1 := m.s.UpdateEncryptedUserAIProviderKey(ctx, arg) + m.queryLatencies.WithLabelValues("UpdateEncryptedUserAIProviderKey").Observe(time.Since(start).Seconds()) + m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateEncryptedUserAIProviderKey").Inc() + return r0, r1 +} + func (m queryMetricsStore) UpdateExternalAuthLink(ctx context.Context, arg database.UpdateExternalAuthLinkParams) (database.ExternalAuthLink, error) { start := time.Now() r0, r1 := m.s.UpdateExternalAuthLink(ctx, arg) @@ -5225,6 +5273,14 @@ func (m queryMetricsStore) UpdateUsageEventsPostPublish(ctx context.Context, arg return r0 } +func (m queryMetricsStore) UpdateUserAIProviderKey(ctx context.Context, arg database.UpdateUserAIProviderKeyParams) (database.UserAiProviderKey, error) { + start := time.Now() + r0, r1 := m.s.UpdateUserAIProviderKey(ctx, arg) + m.queryLatencies.WithLabelValues("UpdateUserAIProviderKey").Observe(time.Since(start).Seconds()) + m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpdateUserAIProviderKey").Inc() + return r0, r1 +} + func (m queryMetricsStore) UpdateUserAgentChatSendShortcut(ctx context.Context, arg database.UpdateUserAgentChatSendShortcutParams) (string, error) { start := time.Now() r0, r1 := m.s.UpdateUserAgentChatSendShortcut(ctx, arg) @@ -6009,6 +6065,14 @@ func (m queryMetricsStore) UpsertTemplateUsageStats(ctx context.Context) error { return r0 } +func (m queryMetricsStore) UpsertUserAIProviderKey(ctx context.Context, arg database.UpsertUserAIProviderKeyParams) (database.UserAiProviderKey, error) { + start := time.Now() + r0, r1 := m.s.UpsertUserAIProviderKey(ctx, arg) + m.queryLatencies.WithLabelValues("UpsertUserAIProviderKey").Observe(time.Since(start).Seconds()) + m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpsertUserAIProviderKey").Inc() + return r0, r1 +} + func (m queryMetricsStore) UpsertUserChatDebugLoggingEnabled(ctx context.Context, arg database.UpsertUserChatDebugLoggingEnabledParams) error { start := time.Now() r0 := m.s.UpsertUserChatDebugLoggingEnabled(ctx, arg) diff --git a/coderd/database/dbmock/dbmock.go b/coderd/database/dbmock/dbmock.go index 1832525d7c..92d9101f6a 100644 --- a/coderd/database/dbmock/dbmock.go +++ b/coderd/database/dbmock/dbmock.go @@ -1334,6 +1334,34 @@ func (mr *MockStoreMockRecorder) DeleteTask(ctx, arg any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteTask", reflect.TypeOf((*MockStore)(nil).DeleteTask), ctx, arg) } +// DeleteUserAIProviderKey mocks base method. +func (m *MockStore) DeleteUserAIProviderKey(ctx context.Context, arg database.DeleteUserAIProviderKeyParams) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteUserAIProviderKey", ctx, arg) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteUserAIProviderKey indicates an expected call of DeleteUserAIProviderKey. +func (mr *MockStoreMockRecorder) DeleteUserAIProviderKey(ctx, arg any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteUserAIProviderKey", reflect.TypeOf((*MockStore)(nil).DeleteUserAIProviderKey), ctx, arg) +} + +// DeleteUserAIProviderKeysByProviderID mocks base method. +func (m *MockStore) DeleteUserAIProviderKeysByProviderID(ctx context.Context, aiProviderID uuid.UUID) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteUserAIProviderKeysByProviderID", ctx, aiProviderID) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteUserAIProviderKeysByProviderID indicates an expected call of DeleteUserAIProviderKeysByProviderID. +func (mr *MockStoreMockRecorder) DeleteUserAIProviderKeysByProviderID(ctx, aiProviderID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteUserAIProviderKeysByProviderID", reflect.TypeOf((*MockStore)(nil).DeleteUserAIProviderKeysByProviderID), ctx, aiProviderID) +} + // DeleteUserChatCompactionThreshold mocks base method. func (m *MockStore) DeleteUserChatCompactionThreshold(ctx context.Context, arg database.DeleteUserChatCompactionThresholdParams) error { m.ctrl.T.Helper() @@ -5431,6 +5459,51 @@ func (mr *MockStoreMockRecorder) GetUnexpiredLicenses(ctx any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUnexpiredLicenses", reflect.TypeOf((*MockStore)(nil).GetUnexpiredLicenses), ctx) } +// GetUserAIProviderKeyByProviderID mocks base method. +func (m *MockStore) GetUserAIProviderKeyByProviderID(ctx context.Context, arg database.GetUserAIProviderKeyByProviderIDParams) (database.UserAiProviderKey, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetUserAIProviderKeyByProviderID", ctx, arg) + ret0, _ := ret[0].(database.UserAiProviderKey) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetUserAIProviderKeyByProviderID indicates an expected call of GetUserAIProviderKeyByProviderID. +func (mr *MockStoreMockRecorder) GetUserAIProviderKeyByProviderID(ctx, arg any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserAIProviderKeyByProviderID", reflect.TypeOf((*MockStore)(nil).GetUserAIProviderKeyByProviderID), ctx, arg) +} + +// GetUserAIProviderKeys mocks base method. +func (m *MockStore) GetUserAIProviderKeys(ctx context.Context) ([]database.UserAiProviderKey, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetUserAIProviderKeys", ctx) + ret0, _ := ret[0].([]database.UserAiProviderKey) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetUserAIProviderKeys indicates an expected call of GetUserAIProviderKeys. +func (mr *MockStoreMockRecorder) GetUserAIProviderKeys(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserAIProviderKeys", reflect.TypeOf((*MockStore)(nil).GetUserAIProviderKeys), ctx) +} + +// GetUserAIProviderKeysByUserID mocks base method. +func (m *MockStore) GetUserAIProviderKeysByUserID(ctx context.Context, userID uuid.UUID) ([]database.UserAiProviderKey, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetUserAIProviderKeysByUserID", ctx, userID) + ret0, _ := ret[0].([]database.UserAiProviderKey) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetUserAIProviderKeysByUserID indicates an expected call of GetUserAIProviderKeysByUserID. +func (mr *MockStoreMockRecorder) GetUserAIProviderKeysByUserID(ctx, userID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserAIProviderKeysByUserID", reflect.TypeOf((*MockStore)(nil).GetUserAIProviderKeysByUserID), ctx, userID) +} + // GetUserAISeatStates mocks base method. func (m *MockStore) GetUserAISeatStates(ctx context.Context, userIds []uuid.UUID) ([]uuid.UUID, error) { m.ctrl.T.Helper() @@ -9266,6 +9339,21 @@ func (mr *MockStoreMockRecorder) UpdateEncryptedAIProviderSettings(ctx, arg any) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateEncryptedAIProviderSettings", reflect.TypeOf((*MockStore)(nil).UpdateEncryptedAIProviderSettings), ctx, arg) } +// UpdateEncryptedUserAIProviderKey mocks base method. +func (m *MockStore) UpdateEncryptedUserAIProviderKey(ctx context.Context, arg database.UpdateEncryptedUserAIProviderKeyParams) (database.UserAiProviderKey, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateEncryptedUserAIProviderKey", ctx, arg) + ret0, _ := ret[0].(database.UserAiProviderKey) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// UpdateEncryptedUserAIProviderKey indicates an expected call of UpdateEncryptedUserAIProviderKey. +func (mr *MockStoreMockRecorder) UpdateEncryptedUserAIProviderKey(ctx, arg any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateEncryptedUserAIProviderKey", reflect.TypeOf((*MockStore)(nil).UpdateEncryptedUserAIProviderKey), ctx, arg) +} + // UpdateExternalAuthLink mocks base method. func (m *MockStore) UpdateExternalAuthLink(ctx context.Context, arg database.UpdateExternalAuthLinkParams) (database.ExternalAuthLink, error) { m.ctrl.T.Helper() @@ -9857,6 +9945,21 @@ func (mr *MockStoreMockRecorder) UpdateUsageEventsPostPublish(ctx, arg any) *gom return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateUsageEventsPostPublish", reflect.TypeOf((*MockStore)(nil).UpdateUsageEventsPostPublish), ctx, arg) } +// UpdateUserAIProviderKey mocks base method. +func (m *MockStore) UpdateUserAIProviderKey(ctx context.Context, arg database.UpdateUserAIProviderKeyParams) (database.UserAiProviderKey, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpdateUserAIProviderKey", ctx, arg) + ret0, _ := ret[0].(database.UserAiProviderKey) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// UpdateUserAIProviderKey indicates an expected call of UpdateUserAIProviderKey. +func (mr *MockStoreMockRecorder) UpdateUserAIProviderKey(ctx, arg any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateUserAIProviderKey", reflect.TypeOf((*MockStore)(nil).UpdateUserAIProviderKey), ctx, arg) +} + // UpdateUserAgentChatSendShortcut mocks base method. func (m *MockStore) UpdateUserAgentChatSendShortcut(ctx context.Context, arg database.UpdateUserAgentChatSendShortcutParams) (string, error) { m.ctrl.T.Helper() @@ -11270,6 +11373,21 @@ func (mr *MockStoreMockRecorder) UpsertTemplateUsageStats(ctx any) *gomock.Call return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertTemplateUsageStats", reflect.TypeOf((*MockStore)(nil).UpsertTemplateUsageStats), ctx) } +// UpsertUserAIProviderKey mocks base method. +func (m *MockStore) UpsertUserAIProviderKey(ctx context.Context, arg database.UpsertUserAIProviderKeyParams) (database.UserAiProviderKey, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpsertUserAIProviderKey", ctx, arg) + ret0, _ := ret[0].(database.UserAiProviderKey) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// UpsertUserAIProviderKey indicates an expected call of UpsertUserAIProviderKey. +func (mr *MockStoreMockRecorder) UpsertUserAIProviderKey(ctx, arg any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertUserAIProviderKey", reflect.TypeOf((*MockStore)(nil).UpsertUserAIProviderKey), ctx, arg) +} + // UpsertUserChatDebugLoggingEnabled mocks base method. func (m *MockStore) UpsertUserChatDebugLoggingEnabled(ctx context.Context, arg database.UpsertUserChatDebugLoggingEnabledParams) error { m.ctrl.T.Helper() diff --git a/coderd/database/dump.sql b/coderd/database/dump.sql index 89148fb938..e502f69aed 100644 --- a/coderd/database/dump.sql +++ b/coderd/database/dump.sql @@ -795,6 +795,12 @@ BEGIN DELETE FROM user_secrets WHERE user_id = OLD.id; + -- Remove their user AI provider keys. + -- user_ai_provider_keys.user_id has ON DELETE CASCADE, but soft-delete + -- does not remove the users row so the FK cascade never fires. + DELETE FROM user_ai_provider_keys + WHERE user_id = OLD.id; + -- Remove their organization memberships. -- This also triggers group membership cleanup via -- trigger_delete_group_members_on_org_member_delete. @@ -1524,6 +1530,7 @@ CREATE TABLE chat_model_configs ( context_limit bigint NOT NULL, compression_threshold integer NOT NULL, options jsonb DEFAULT '{}'::jsonb NOT NULL, + ai_provider_id uuid, CONSTRAINT chat_model_configs_compression_threshold_check CHECK (((compression_threshold >= 0) AND (compression_threshold <= 100))), CONSTRAINT chat_model_configs_context_limit_check CHECK ((context_limit > 0)) ); @@ -3016,6 +3023,23 @@ COMMENT ON TABLE usage_events_daily IS 'usage_events_daily is a daily rollup of COMMENT ON COLUMN usage_events_daily.day IS 'The date of the summed usage events, always in UTC.'; +CREATE TABLE user_ai_provider_keys ( + id uuid DEFAULT gen_random_uuid() NOT NULL, + user_id uuid NOT NULL, + ai_provider_id uuid NOT NULL, + api_key text NOT NULL, + api_key_key_id text, + created_at timestamp with time zone DEFAULT now() NOT NULL, + updated_at timestamp with time zone DEFAULT now() NOT NULL, + CONSTRAINT user_ai_provider_keys_api_key_check CHECK ((api_key <> ''::text)) +); + +COMMENT ON TABLE user_ai_provider_keys IS 'User-owned API keys associated with AI providers. These keys are used only when BYOK is enabled.'; + +COMMENT ON COLUMN user_ai_provider_keys.api_key IS 'User-owned API key used to authenticate with the upstream AI provider. Encrypted at rest via dbcrypt when api_key_key_id is set.'; + +COMMENT ON COLUMN user_ai_provider_keys.api_key_key_id IS 'The ID of the key used to encrypt the user-owned provider API key. If this is NULL, the API key is not encrypted.'; + CREATE TABLE user_chat_provider_keys ( id uuid DEFAULT gen_random_uuid() NOT NULL, user_id uuid NOT NULL, @@ -3859,6 +3883,12 @@ ALTER TABLE ONLY usage_events_daily ALTER TABLE ONLY usage_events ADD CONSTRAINT usage_events_pkey PRIMARY KEY (id); +ALTER TABLE ONLY user_ai_provider_keys + ADD CONSTRAINT user_ai_provider_keys_pkey PRIMARY KEY (id); + +ALTER TABLE ONLY user_ai_provider_keys + ADD CONSTRAINT user_ai_provider_keys_user_id_ai_provider_id_key UNIQUE (user_id, ai_provider_id); + ALTER TABLE ONLY user_chat_provider_keys ADD CONSTRAINT user_chat_provider_keys_pkey PRIMARY KEY (id); @@ -4072,6 +4102,8 @@ CREATE INDEX idx_chat_messages_owner_spend ON chat_messages USING btree (chat_id CREATE INDEX idx_chat_messages_user_prompts ON chat_messages USING btree (chat_id, id DESC) WHERE ((deleted = false) AND (role = 'user'::chat_message_role) AND (visibility = ANY (ARRAY['user'::chat_message_visibility, 'both'::chat_message_visibility]))); +CREATE INDEX idx_chat_model_configs_ai_provider_id ON chat_model_configs USING btree (ai_provider_id); + CREATE INDEX idx_chat_model_configs_enabled ON chat_model_configs USING btree (enabled); CREATE INDEX idx_chat_model_configs_provider ON chat_model_configs USING btree (provider); @@ -4162,6 +4194,8 @@ CREATE INDEX idx_usage_events_ai_seats ON usage_events USING btree (event_type, CREATE INDEX idx_usage_events_select_for_publishing ON usage_events USING btree (published_at, publish_started_at, created_at); +CREATE INDEX idx_user_ai_provider_keys_ai_provider_id ON user_ai_provider_keys USING btree (ai_provider_id); + CREATE INDEX idx_user_deleted_deleted_at ON user_deleted USING btree (deleted_at); CREATE INDEX idx_user_status_changes_changed_at ON user_status_changes USING btree (changed_at); @@ -4401,6 +4435,9 @@ ALTER TABLE ONLY chat_messages ALTER TABLE ONLY chat_messages ADD CONSTRAINT chat_messages_model_config_id_fkey FOREIGN KEY (model_config_id) REFERENCES chat_model_configs(id); +ALTER TABLE ONLY chat_model_configs + ADD CONSTRAINT chat_model_configs_ai_provider_id_fkey FOREIGN KEY (ai_provider_id) REFERENCES ai_providers(id); + ALTER TABLE ONLY chat_model_configs ADD CONSTRAINT chat_model_configs_created_by_fkey FOREIGN KEY (created_by) REFERENCES users(id); @@ -4641,6 +4678,15 @@ ALTER TABLE ONLY templates ALTER TABLE ONLY templates ADD CONSTRAINT templates_organization_id_fkey FOREIGN KEY (organization_id) REFERENCES organizations(id) ON DELETE CASCADE; +ALTER TABLE ONLY user_ai_provider_keys + ADD CONSTRAINT user_ai_provider_keys_ai_provider_id_fkey FOREIGN KEY (ai_provider_id) REFERENCES ai_providers(id) ON DELETE CASCADE; + +ALTER TABLE ONLY user_ai_provider_keys + ADD CONSTRAINT user_ai_provider_keys_api_key_key_id_fkey FOREIGN KEY (api_key_key_id) REFERENCES dbcrypt_keys(active_key_digest); + +ALTER TABLE ONLY user_ai_provider_keys + ADD CONSTRAINT user_ai_provider_keys_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE; + ALTER TABLE ONLY user_chat_provider_keys ADD CONSTRAINT user_chat_provider_keys_api_key_key_id_fkey FOREIGN KEY (api_key_key_id) REFERENCES dbcrypt_keys(active_key_digest); diff --git a/coderd/database/foreign_key_constraint.go b/coderd/database/foreign_key_constraint.go index 7bbf92d9be..27eadbe88f 100644 --- a/coderd/database/foreign_key_constraint.go +++ b/coderd/database/foreign_key_constraint.go @@ -21,6 +21,7 @@ const ( ForeignKeyChatFilesOwnerID ForeignKeyConstraint = "chat_files_owner_id_fkey" // ALTER TABLE ONLY chat_files ADD CONSTRAINT chat_files_owner_id_fkey FOREIGN KEY (owner_id) REFERENCES users(id) ON DELETE CASCADE; ForeignKeyChatMessagesChatID ForeignKeyConstraint = "chat_messages_chat_id_fkey" // ALTER TABLE ONLY chat_messages ADD CONSTRAINT chat_messages_chat_id_fkey FOREIGN KEY (chat_id) REFERENCES chats(id) ON DELETE CASCADE; ForeignKeyChatMessagesModelConfigID ForeignKeyConstraint = "chat_messages_model_config_id_fkey" // ALTER TABLE ONLY chat_messages ADD CONSTRAINT chat_messages_model_config_id_fkey FOREIGN KEY (model_config_id) REFERENCES chat_model_configs(id); + ForeignKeyChatModelConfigsAiProviderID ForeignKeyConstraint = "chat_model_configs_ai_provider_id_fkey" // ALTER TABLE ONLY chat_model_configs ADD CONSTRAINT chat_model_configs_ai_provider_id_fkey FOREIGN KEY (ai_provider_id) REFERENCES ai_providers(id); ForeignKeyChatModelConfigsCreatedBy ForeignKeyConstraint = "chat_model_configs_created_by_fkey" // ALTER TABLE ONLY chat_model_configs ADD CONSTRAINT chat_model_configs_created_by_fkey FOREIGN KEY (created_by) REFERENCES users(id); ForeignKeyChatModelConfigsUpdatedBy ForeignKeyConstraint = "chat_model_configs_updated_by_fkey" // ALTER TABLE ONLY chat_model_configs ADD CONSTRAINT chat_model_configs_updated_by_fkey FOREIGN KEY (updated_by) REFERENCES users(id); ForeignKeyChatProvidersAPIKeyKeyID ForeignKeyConstraint = "chat_providers_api_key_key_id_fkey" // ALTER TABLE ONLY chat_providers ADD CONSTRAINT chat_providers_api_key_key_id_fkey FOREIGN KEY (api_key_key_id) REFERENCES dbcrypt_keys(active_key_digest); @@ -101,6 +102,9 @@ const ( ForeignKeyTemplateVersionsTemplateID ForeignKeyConstraint = "template_versions_template_id_fkey" // ALTER TABLE ONLY template_versions ADD CONSTRAINT template_versions_template_id_fkey FOREIGN KEY (template_id) REFERENCES templates(id) ON DELETE CASCADE; ForeignKeyTemplatesCreatedBy ForeignKeyConstraint = "templates_created_by_fkey" // ALTER TABLE ONLY templates ADD CONSTRAINT templates_created_by_fkey FOREIGN KEY (created_by) REFERENCES users(id) ON DELETE RESTRICT; ForeignKeyTemplatesOrganizationID ForeignKeyConstraint = "templates_organization_id_fkey" // ALTER TABLE ONLY templates ADD CONSTRAINT templates_organization_id_fkey FOREIGN KEY (organization_id) REFERENCES organizations(id) ON DELETE CASCADE; + ForeignKeyUserAiProviderKeysAiProviderID ForeignKeyConstraint = "user_ai_provider_keys_ai_provider_id_fkey" // ALTER TABLE ONLY user_ai_provider_keys ADD CONSTRAINT user_ai_provider_keys_ai_provider_id_fkey FOREIGN KEY (ai_provider_id) REFERENCES ai_providers(id) ON DELETE CASCADE; + ForeignKeyUserAiProviderKeysAPIKeyKeyID ForeignKeyConstraint = "user_ai_provider_keys_api_key_key_id_fkey" // ALTER TABLE ONLY user_ai_provider_keys ADD CONSTRAINT user_ai_provider_keys_api_key_key_id_fkey FOREIGN KEY (api_key_key_id) REFERENCES dbcrypt_keys(active_key_digest); + ForeignKeyUserAiProviderKeysUserID ForeignKeyConstraint = "user_ai_provider_keys_user_id_fkey" // ALTER TABLE ONLY user_ai_provider_keys ADD CONSTRAINT user_ai_provider_keys_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE; ForeignKeyUserChatProviderKeysAPIKeyKeyID ForeignKeyConstraint = "user_chat_provider_keys_api_key_key_id_fkey" // ALTER TABLE ONLY user_chat_provider_keys ADD CONSTRAINT user_chat_provider_keys_api_key_key_id_fkey FOREIGN KEY (api_key_key_id) REFERENCES dbcrypt_keys(active_key_digest); ForeignKeyUserChatProviderKeysChatProviderID ForeignKeyConstraint = "user_chat_provider_keys_chat_provider_id_fkey" // ALTER TABLE ONLY user_chat_provider_keys ADD CONSTRAINT user_chat_provider_keys_chat_provider_id_fkey FOREIGN KEY (chat_provider_id) REFERENCES chat_providers(id) ON DELETE CASCADE; ForeignKeyUserChatProviderKeysUserID ForeignKeyConstraint = "user_chat_provider_keys_user_id_fkey" // ALTER TABLE ONLY user_chat_provider_keys ADD CONSTRAINT user_chat_provider_keys_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE; diff --git a/coderd/database/migrations/000503_ai_providers_schema_expand.down.sql b/coderd/database/migrations/000503_ai_providers_schema_expand.down.sql new file mode 100644 index 0000000000..3932e112a1 --- /dev/null +++ b/coderd/database/migrations/000503_ai_providers_schema_expand.down.sql @@ -0,0 +1,46 @@ +DROP INDEX IF EXISTS idx_chat_model_configs_ai_provider_id; + +ALTER TABLE chat_model_configs + DROP COLUMN IF EXISTS ai_provider_id; + +CREATE OR REPLACE FUNCTION delete_deleted_user_resources() RETURNS trigger + LANGUAGE plpgsql +AS $$ +DECLARE +BEGIN + IF (NEW.deleted) THEN + -- Remove their api_keys. + DELETE FROM api_keys + WHERE user_id = OLD.id; + + -- Remove their user_links. + -- Their login_type is preserved in the users table. + -- Matching this user back to the link can still be done by their + -- email if the account is undeleted. Although that is not a guarantee. + DELETE FROM user_links + WHERE user_id = OLD.id; + + -- Remove their user_secrets. + -- user_secrets.user_id has ON DELETE CASCADE, but soft-delete + -- does not remove the users row so the FK cascade never fires. + DELETE FROM user_secrets + WHERE user_id = OLD.id; + + -- Remove their organization memberships. + -- This also triggers group membership cleanup via + -- trigger_delete_group_members_on_org_member_delete. + DELETE FROM organization_members + WHERE user_id = OLD.id; + + -- Remove their user_skills. + -- user_skills.user_id has ON DELETE CASCADE, but soft-delete + -- does not remove the users row so the FK cascade never fires. + DELETE FROM user_skills + WHERE user_id = OLD.id; + END IF; + RETURN NEW; +END; +$$; + +DROP INDEX IF EXISTS idx_user_ai_provider_keys_ai_provider_id; +DROP TABLE IF EXISTS user_ai_provider_keys; diff --git a/coderd/database/migrations/000503_ai_providers_schema_expand.up.sql b/coderd/database/migrations/000503_ai_providers_schema_expand.up.sql new file mode 100644 index 0000000000..137d26fcfd --- /dev/null +++ b/coderd/database/migrations/000503_ai_providers_schema_expand.up.sql @@ -0,0 +1,72 @@ +CREATE TABLE user_ai_provider_keys ( + id uuid PRIMARY KEY DEFAULT gen_random_uuid(), + user_id uuid NOT NULL REFERENCES users(id) ON DELETE CASCADE, + ai_provider_id uuid NOT NULL REFERENCES ai_providers(id) ON DELETE CASCADE, + api_key text NOT NULL CHECK (api_key != ''), + api_key_key_id text REFERENCES dbcrypt_keys(active_key_digest), + created_at timestamp with time zone NOT NULL DEFAULT NOW(), + updated_at timestamp with time zone NOT NULL DEFAULT NOW(), + UNIQUE (user_id, ai_provider_id) +); + +COMMENT ON TABLE user_ai_provider_keys IS 'User-owned API keys associated with AI providers. These keys are used only when BYOK is enabled.'; + +COMMENT ON COLUMN user_ai_provider_keys.api_key IS 'User-owned API key used to authenticate with the upstream AI provider. Encrypted at rest via dbcrypt when api_key_key_id is set.'; + +COMMENT ON COLUMN user_ai_provider_keys.api_key_key_id IS 'The ID of the key used to encrypt the user-owned provider API key. If this is NULL, the API key is not encrypted.'; + +CREATE INDEX idx_user_ai_provider_keys_ai_provider_id + ON user_ai_provider_keys (ai_provider_id); + +-- user_ai_provider_keys.user_id has ON DELETE CASCADE, but user deletion +-- normally soft-deletes the users row, so the FK cascade does not fire. +CREATE OR REPLACE FUNCTION delete_deleted_user_resources() RETURNS trigger + LANGUAGE plpgsql +AS $$ +DECLARE +BEGIN + IF (NEW.deleted) THEN + -- Remove their api_keys. + DELETE FROM api_keys + WHERE user_id = OLD.id; + + -- Remove their user_links. + -- Their login_type is preserved in the users table. + -- Matching this user back to the link can still be done by their + -- email if the account is undeleted. Although that is not a guarantee. + DELETE FROM user_links + WHERE user_id = OLD.id; + + -- Remove their user_secrets. + -- user_secrets.user_id has ON DELETE CASCADE, but soft-delete + -- does not remove the users row so the FK cascade never fires. + DELETE FROM user_secrets + WHERE user_id = OLD.id; + + -- Remove their user AI provider keys. + -- user_ai_provider_keys.user_id has ON DELETE CASCADE, but soft-delete + -- does not remove the users row so the FK cascade never fires. + DELETE FROM user_ai_provider_keys + WHERE user_id = OLD.id; + + -- Remove their organization memberships. + -- This also triggers group membership cleanup via + -- trigger_delete_group_members_on_org_member_delete. + DELETE FROM organization_members + WHERE user_id = OLD.id; + + -- Remove their user_skills. + -- user_skills.user_id has ON DELETE CASCADE, but soft-delete + -- does not remove the users row so the FK cascade never fires. + DELETE FROM user_skills + WHERE user_id = OLD.id; + END IF; + RETURN NEW; +END; +$$; + +ALTER TABLE chat_model_configs + ADD COLUMN ai_provider_id uuid REFERENCES ai_providers(id); + +CREATE INDEX idx_chat_model_configs_ai_provider_id + ON chat_model_configs (ai_provider_id); diff --git a/coderd/database/migrations/000504_ai_providers_backfill.down.sql b/coderd/database/migrations/000504_ai_providers_backfill.down.sql new file mode 100644 index 0000000000..eb99d53906 --- /dev/null +++ b/coderd/database/migrations/000504_ai_providers_backfill.down.sql @@ -0,0 +1,48 @@ +WITH migrated_provider_ids AS ( + SELECT id + FROM chat_providers + UNION + SELECT id + FROM ai_providers + WHERE name LIKE 'agents-%' + AND deleted = TRUE +) +UPDATE chat_model_configs +SET ai_provider_id = NULL +WHERE ai_provider_id IN (SELECT id FROM migrated_provider_ids); + +WITH migrated_provider_ids AS ( + SELECT id + FROM chat_providers + UNION + SELECT id + FROM ai_providers + WHERE name LIKE 'agents-%' + AND deleted = TRUE +) +DELETE FROM user_ai_provider_keys +WHERE ai_provider_id IN (SELECT id FROM migrated_provider_ids); + +WITH migrated_provider_ids AS ( + SELECT id + FROM chat_providers + UNION + SELECT id + FROM ai_providers + WHERE name LIKE 'agents-%' + AND deleted = TRUE +) +DELETE FROM ai_provider_keys +WHERE provider_id IN (SELECT id FROM migrated_provider_ids); + +WITH migrated_provider_ids AS ( + SELECT id + FROM chat_providers + UNION + SELECT id + FROM ai_providers + WHERE name LIKE 'agents-%' + AND deleted = TRUE +) +DELETE FROM ai_providers +WHERE id IN (SELECT id FROM migrated_provider_ids); diff --git a/coderd/database/migrations/000504_ai_providers_backfill.up.sql b/coderd/database/migrations/000504_ai_providers_backfill.up.sql new file mode 100644 index 0000000000..176f5ddb97 --- /dev/null +++ b/coderd/database/migrations/000504_ai_providers_backfill.up.sql @@ -0,0 +1,78 @@ +-- Override any pre-existing live AI providers whose names collide with the +-- backfill below. No other process should write to ai_providers before this +-- migration, so any conflicting live row is treated as stale and soft-deleted +-- to free the name for the chat_providers row inserted below, which becomes +-- authoritative. +UPDATE ai_providers +SET deleted = TRUE, + enabled = FALSE, + updated_at = NOW() +WHERE deleted = FALSE + AND name IN ( + SELECT 'agents-' || cp.provider + FROM chat_providers cp + ); + +INSERT INTO ai_providers ( + id, + type, + name, + display_name, + enabled, + base_url, + created_at, + updated_at +) +SELECT + cp.id, + cp.provider::ai_provider_type, + 'agents-' || cp.provider, + NULLIF(cp.display_name, ''), + cp.enabled, + cp.base_url, + cp.created_at, + cp.updated_at +FROM chat_providers cp; + +INSERT INTO ai_provider_keys ( + id, + provider_id, + api_key, + api_key_key_id, + created_at, + updated_at +) +SELECT + gen_random_uuid(), + cp.id, + cp.api_key, + cp.api_key_key_id, + cp.created_at, + cp.updated_at +FROM chat_providers cp +WHERE cp.api_key != ''; + +INSERT INTO user_ai_provider_keys ( + id, + user_id, + ai_provider_id, + api_key, + api_key_key_id, + created_at, + updated_at +) +SELECT + ucpk.id, + ucpk.user_id, + ucpk.chat_provider_id, + ucpk.api_key, + ucpk.api_key_key_id, + ucpk.created_at, + ucpk.updated_at +FROM user_chat_provider_keys ucpk; + +UPDATE chat_model_configs cmc +SET ai_provider_id = cp.id +FROM chat_providers cp +WHERE cmc.provider = cp.provider + AND cmc.ai_provider_id IS NULL; diff --git a/coderd/database/migrations/migrate_test.go b/coderd/database/migrations/migrate_test.go index 8496a338eb..0b3e0c240c 100644 --- a/coderd/database/migrations/migrate_test.go +++ b/coderd/database/migrations/migrate_test.go @@ -1186,6 +1186,322 @@ func TestMigration000475AgentsAccessOrgRole(t *testing.T) { ) } +func TestMigration000504AIProvidersBackfill(t *testing.T) { + t.Parallel() + + const migrationVersion = 504 + + sqlDB := testSQLDB(t) + + next, err := migrations.Stepper(sqlDB) + require.NoError(t, err) + for { + version, more, err := next() + require.NoError(t, err) + if !more { + t.Fatalf("migration %d not found", migrationVersion) + } + if version == migrationVersion-1 { + break + } + } + + ctx := testutil.Context(t, testutil.WaitSuperLong) + now := time.Now().UTC().Truncate(time.Microsecond) + userID := uuid.New() + openAIProviderID := uuid.New() + anthropicProviderID := uuid.New() + openAIUserKeyID := uuid.New() + anthropicUserKeyID := uuid.New() + openAIModelConfigID := uuid.New() + anthropicModelConfigID := uuid.New() + + tx, err := sqlDB.BeginTx(ctx, nil) + require.NoError(t, err) + defer tx.Rollback() + + _, err = tx.ExecContext(ctx, + `INSERT INTO users (id, username, email, hashed_password, created_at, updated_at, status, rbac_roles, login_type) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)`, + userID, "ai-provider-backfill", "ai-provider-backfill@test.com", []byte{}, now, now, "active", pq.StringArray{}, "password", + ) + require.NoError(t, err) + _, err = tx.ExecContext(ctx, ` + INSERT INTO chat_providers (id, provider, display_name, api_key, enabled, base_url, created_at, updated_at) + VALUES + ($1, 'openai', 'OpenAI', 'sk-provider-openai', TRUE, 'https://api.openai.example.com/v1', $3, $3), + ($2, 'anthropic', '', '', FALSE, '', $3, $3) + `, openAIProviderID, anthropicProviderID, now) + require.NoError(t, err) + _, err = tx.ExecContext(ctx, ` + INSERT INTO user_chat_provider_keys (id, user_id, chat_provider_id, api_key, created_at, updated_at) + VALUES + ($1, $3, $4, 'sk-user-openai', $6, $6), + ($2, $3, $5, 'sk-user-anthropic', $6, $6) + `, openAIUserKeyID, anthropicUserKeyID, userID, openAIProviderID, anthropicProviderID, now) + require.NoError(t, err) + _, err = tx.ExecContext(ctx, ` + INSERT INTO chat_model_configs (id, provider, model, display_name, enabled, context_limit, compression_threshold, created_at, updated_at) + VALUES + ($1, 'openai', 'gpt-4', 'GPT 4', TRUE, 100000, 70, $3, $3), + ($2, 'anthropic', 'claude-3-5-sonnet-latest', 'Claude 3.5 Sonnet', TRUE, 200000, 70, $3, $3) + `, openAIModelConfigID, anthropicModelConfigID, now) + require.NoError(t, err) + require.NoError(t, tx.Commit()) + + var preBackfillCount int + err = sqlDB.QueryRowContext(ctx, ` + SELECT COUNT(*) + FROM ai_providers + WHERE id IN ($1, $2) + `, openAIProviderID, anthropicProviderID).Scan(&preBackfillCount) + require.NoError(t, err) + require.Zero(t, preBackfillCount, "test setup should start before the legacy chat providers are backfilled") + + var preBackfillModelConfigCount int + err = sqlDB.QueryRowContext(ctx, ` + SELECT COUNT(*) + FROM chat_model_configs + WHERE id IN ($1, $2) + AND ai_provider_id IS NOT NULL + `, openAIModelConfigID, anthropicModelConfigID).Scan(&preBackfillModelConfigCount) + require.NoError(t, err) + require.Zero(t, preBackfillModelConfigCount, "test setup should start before model configs point at AI providers") + + version, more, err := next() + require.NoError(t, err) + require.True(t, more) + require.EqualValues(t, migrationVersion, version) + + assertBackfilledProvider := func(providerID uuid.UUID, providerType, name string, displayName sql.NullString, enabled bool, baseURL string) { + t.Helper() + var provider struct { + Typ string + Name string + DisplayName sql.NullString + Enabled bool + BaseURL string + } + err = sqlDB.QueryRowContext(ctx, ` + SELECT type, name, display_name, enabled, base_url + FROM ai_providers + WHERE id = $1 + `, providerID).Scan(&provider.Typ, &provider.Name, &provider.DisplayName, &provider.Enabled, &provider.BaseURL) + require.NoError(t, err) + require.Equal(t, providerType, provider.Typ) + require.Equal(t, name, provider.Name) + require.Equal(t, displayName, provider.DisplayName) + require.Equal(t, enabled, provider.Enabled) + require.Equal(t, baseURL, provider.BaseURL) + } + assertBackfilledProvider( + openAIProviderID, + "openai", + "agents-openai", + sql.NullString{String: "OpenAI", Valid: true}, + true, + "https://api.openai.example.com/v1", + ) + assertBackfilledProvider( + anthropicProviderID, + "anthropic", + "agents-anthropic", + sql.NullString{}, + false, + "", + ) + + var providerKeyCount int + err = sqlDB.QueryRowContext(ctx, ` + SELECT COUNT(*) + FROM ai_provider_keys + WHERE provider_id = $1 AND api_key = 'sk-provider-openai' + `, openAIProviderID).Scan(&providerKeyCount) + require.NoError(t, err) + require.Equal(t, 1, providerKeyCount, "non-empty legacy provider API key should be copied") + + err = sqlDB.QueryRowContext(ctx, ` + SELECT COUNT(*) + FROM ai_provider_keys + WHERE provider_id = $1 + `, anthropicProviderID).Scan(&providerKeyCount) + require.NoError(t, err) + require.Zero(t, providerKeyCount, "empty legacy provider API key should not create an AI provider key") + + assertBackfilledUserKey := func(userKeyID, providerID uuid.UUID, apiKey string) { + t.Helper() + var userKeyCount int + err = sqlDB.QueryRowContext(ctx, ` + SELECT COUNT(*) + FROM user_ai_provider_keys + WHERE id = $1 AND user_id = $2 AND ai_provider_id = $3 AND api_key = $4 + `, userKeyID, userID, providerID, apiKey).Scan(&userKeyCount) + require.NoError(t, err) + require.Equal(t, 1, userKeyCount) + } + assertBackfilledUserKey(openAIUserKeyID, openAIProviderID, "sk-user-openai") + assertBackfilledUserKey(anthropicUserKeyID, anthropicProviderID, "sk-user-anthropic") + + assertModelConfigProviderID := func(modelConfigID, providerID uuid.UUID) { + t.Helper() + var aiProviderID sql.NullString + err = sqlDB.QueryRowContext(ctx, + `SELECT ai_provider_id::text FROM chat_model_configs WHERE id = $1`, + modelConfigID, + ).Scan(&aiProviderID) + require.NoError(t, err) + require.Equal(t, sql.NullString{String: providerID.String(), Valid: true}, aiProviderID) + } + assertModelConfigProviderID(openAIModelConfigID, openAIProviderID) + assertModelConfigProviderID(anthropicModelConfigID, anthropicProviderID) + + var legacyProviderCount int + err = sqlDB.QueryRowContext(ctx, ` + SELECT COUNT(*) + FROM chat_providers + WHERE id IN ($1, $2) + `, openAIProviderID, anthropicProviderID).Scan(&legacyProviderCount) + require.NoError(t, err) + require.Equal(t, 2, legacyProviderCount, "backfill should leave legacy rows for the rest of the stack") + + downSQL, err := os.ReadFile("000504_ai_providers_backfill.down.sql") + require.NoError(t, err) + _, err = sqlDB.ExecContext(ctx, string(downSQL)) + require.NoError(t, err) + + err = sqlDB.QueryRowContext(ctx, ` + SELECT COUNT(*) + FROM ai_providers + WHERE id IN ($1, $2) + `, openAIProviderID, anthropicProviderID).Scan(&providerKeyCount) + require.NoError(t, err) + require.Zero(t, providerKeyCount, "down migration should remove backfilled AI providers") + + err = sqlDB.QueryRowContext(ctx, ` + SELECT COUNT(*) + FROM ai_provider_keys + WHERE provider_id IN ($1, $2) + `, openAIProviderID, anthropicProviderID).Scan(&providerKeyCount) + require.NoError(t, err) + require.Zero(t, providerKeyCount, "down migration should remove backfilled provider keys") + + var userKeyCount int + err = sqlDB.QueryRowContext(ctx, ` + SELECT COUNT(*) + FROM user_ai_provider_keys + WHERE id IN ($1, $2) + `, openAIUserKeyID, anthropicUserKeyID).Scan(&userKeyCount) + require.NoError(t, err) + require.Zero(t, userKeyCount, "down migration should remove backfilled user keys") + + err = sqlDB.QueryRowContext(ctx, ` + SELECT COUNT(*) + FROM chat_model_configs + WHERE id IN ($1, $2) + AND ai_provider_id IS NOT NULL + `, openAIModelConfigID, anthropicModelConfigID).Scan(&preBackfillModelConfigCount) + require.NoError(t, err) + require.Zero(t, preBackfillModelConfigCount, "down migration should clear model config AI provider references") + + err = sqlDB.QueryRowContext(ctx, ` + SELECT COUNT(*) + FROM chat_providers + WHERE id IN ($1, $2) + `, openAIProviderID, anthropicProviderID).Scan(&legacyProviderCount) + require.NoError(t, err) + require.Equal(t, 2, legacyProviderCount, "down migration should leave the legacy source rows intact") +} + +// TestMigration000504AIProvidersBackfillOverridesNameConflict verifies that a +// pre-existing live ai_providers row whose name collides with the backfill +// (for example, agents-openai) is soft-deleted so the chat_providers-derived +// row inserted by the migration becomes authoritative. This scenario should +// not occur in practice since no other process writes to ai_providers before +// this migration runs, but the migration tolerates it rather than failing. +func TestMigration000504AIProvidersBackfillOverridesNameConflict(t *testing.T) { + t.Parallel() + + const migrationVersion = 504 + + sqlDB := testSQLDB(t) + + next, err := migrations.Stepper(sqlDB) + require.NoError(t, err) + for { + version, more, err := next() + require.NoError(t, err) + if !more { + t.Fatalf("migration %d not found", migrationVersion) + } + if version == migrationVersion-1 { + break + } + } + + ctx := testutil.Context(t, testutil.WaitSuperLong) + now := time.Now().UTC().Truncate(time.Microsecond) + chatProviderID := uuid.New() + staleProviderID := uuid.New() + + tx, err := sqlDB.BeginTx(ctx, nil) + require.NoError(t, err) + defer tx.Rollback() + + // Pre-existing live ai_providers row that collides on name. + _, err = tx.ExecContext(ctx, + `INSERT INTO ai_providers (id, type, name, display_name, enabled, base_url, created_at, updated_at) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8)`, + staleProviderID, "openai", "agents-openai", "Stale OpenAI", true, "https://stale.example.com/v1", now, now, + ) + require.NoError(t, err) + + // chat_providers row whose backfill will collide with the stale row above. + _, err = tx.ExecContext(ctx, + `INSERT INTO chat_providers (id, provider, display_name, api_key, enabled, base_url, created_at, updated_at) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8)`, + chatProviderID, "openai", "OpenAI", "sk-provider", true, "https://api.openai.example.com/v1", now, now, + ) + require.NoError(t, err) + require.NoError(t, tx.Commit()) + + version, more, err := next() + require.NoError(t, err) + require.True(t, more) + require.EqualValues(t, migrationVersion, version) + + // The stale row must be soft-deleted and disabled so the unique name index + // (which is partial WHERE deleted = FALSE) no longer covers it. + var stale struct { + Deleted bool + Enabled bool + } + err = sqlDB.QueryRowContext(ctx, + `SELECT deleted, enabled FROM ai_providers WHERE id = $1`, + staleProviderID, + ).Scan(&stale.Deleted, &stale.Enabled) + require.NoError(t, err) + require.True(t, stale.Deleted, "pre-existing conflicting ai_providers row should be soft-deleted") + require.False(t, stale.Enabled, "pre-existing conflicting ai_providers row should be disabled") + + // The new authoritative row must exist with the chat_providers id, the + // agents-openai name, and the chat_providers base_url. + var fresh struct { + Name string + BaseURL string + Deleted bool + Enabled bool + } + err = sqlDB.QueryRowContext(ctx, + `SELECT name, base_url, deleted, enabled FROM ai_providers WHERE id = $1`, + chatProviderID, + ).Scan(&fresh.Name, &fresh.BaseURL, &fresh.Deleted, &fresh.Enabled) + require.NoError(t, err) + require.Equal(t, "agents-openai", fresh.Name) + require.Equal(t, "https://api.openai.example.com/v1", fresh.BaseURL) + require.False(t, fresh.Deleted) + require.True(t, fresh.Enabled) +} + func TestMigration000498SoftDeleteStaleWorkspaceAgents(t *testing.T) { t.Parallel() diff --git a/coderd/database/migrations/testdata/fixtures/000503_ai_providers_schema_expand.up.sql b/coderd/database/migrations/testdata/fixtures/000503_ai_providers_schema_expand.up.sql new file mode 100644 index 0000000000..dcdf649aed --- /dev/null +++ b/coderd/database/migrations/testdata/fixtures/000503_ai_providers_schema_expand.up.sql @@ -0,0 +1,11 @@ +INSERT INTO user_ai_provider_keys ( + id, + user_id, + ai_provider_id, + api_key +) VALUES ( + '8e3c6e18-2b75-4c3f-9b35-9d1c6f4e1c01', + '30095c71-380b-457a-8995-97b8ee6e5307', + '8e3c6e18-2b75-4c3f-9b35-9d1c6f4e1a01', + 'fixture-user-openai-key' +); diff --git a/coderd/database/models.go b/coderd/database/models.go index f801be833b..169080f8fc 100644 --- a/coderd/database/models.go +++ b/coderd/database/models.go @@ -4679,6 +4679,7 @@ type ChatModelConfig struct { ContextLimit int64 `db:"context_limit" json:"context_limit"` CompressionThreshold int32 `db:"compression_threshold" json:"compression_threshold"` Options json.RawMessage `db:"options" json:"options"` + AIProviderID uuid.NullUUID `db:"ai_provider_id" json:"ai_provider_id"` } type ChatProvider struct { @@ -5692,6 +5693,19 @@ type User struct { ChatSpendLimitMicros sql.NullInt64 `db:"chat_spend_limit_micros" json:"chat_spend_limit_micros"` } +// User-owned API keys associated with AI providers. These keys are used only when BYOK is enabled. +type UserAiProviderKey struct { + ID uuid.UUID `db:"id" json:"id"` + UserID uuid.UUID `db:"user_id" json:"user_id"` + AIProviderID uuid.UUID `db:"ai_provider_id" json:"ai_provider_id"` + // User-owned API key used to authenticate with the upstream AI provider. Encrypted at rest via dbcrypt when api_key_key_id is set. + APIKey string `db:"api_key" json:"api_key"` + // The ID of the key used to encrypt the user-owned provider API key. If this is NULL, the API key is not encrypted. + ApiKeyKeyID sql.NullString `db:"api_key_key_id" json:"api_key_key_id"` + CreatedAt time.Time `db:"created_at" json:"created_at"` + UpdatedAt time.Time `db:"updated_at" json:"updated_at"` +} + type UserChatProviderKey struct { ID uuid.UUID `db:"id" json:"id"` UserID uuid.UUID `db:"user_id" json:"user_id"` diff --git a/coderd/database/querier.go b/coderd/database/querier.go index 32e38750ec..7f1b2d0b1b 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -193,6 +193,8 @@ type sqlcQuerier interface { DeleteTailnetPeer(ctx context.Context, arg DeleteTailnetPeerParams) (DeleteTailnetPeerRow, error) DeleteTailnetTunnel(ctx context.Context, arg DeleteTailnetTunnelParams) (DeleteTailnetTunnelRow, error) DeleteTask(ctx context.Context, arg DeleteTaskParams) (uuid.UUID, error) + DeleteUserAIProviderKey(ctx context.Context, arg DeleteUserAIProviderKeyParams) error + DeleteUserAIProviderKeysByProviderID(ctx context.Context, aiProviderID uuid.UUID) error DeleteUserChatCompactionThreshold(ctx context.Context, arg DeleteUserChatCompactionThresholdParams) error DeleteUserChatProviderKey(ctx context.Context, arg DeleteUserChatProviderKeyParams) error DeleteUserSecretByUserIDAndName(ctx context.Context, arg DeleteUserSecretByUserIDAndNameParams) (UserSecret, error) @@ -727,6 +729,11 @@ type sqlcQuerier interface { // inclusive. GetTotalUsageDCManagedAgentsV1(ctx context.Context, arg GetTotalUsageDCManagedAgentsV1Params) (int64, error) GetUnexpiredLicenses(ctx context.Context) ([]License, error) + GetUserAIProviderKeyByProviderID(ctx context.Context, arg GetUserAIProviderKeyByProviderIDParams) (UserAiProviderKey, error) + // GetUserAIProviderKeys is used by dbcrypt key rotation. Request paths should use + // user-scoped lookups instead of this bulk accessor. + GetUserAIProviderKeys(ctx context.Context) ([]UserAiProviderKey, error) + GetUserAIProviderKeysByUserID(ctx context.Context, userID uuid.UUID) ([]UserAiProviderKey, error) // Returns user IDs from the provided list that are consuming an AI seat. // Filters to active, non-deleted, non-system users to match the canonical // seat count query (GetActiveAISeatCount). @@ -1217,6 +1224,7 @@ type sqlcQuerier interface { // Used by the dbcrypt key rotation utility to re-encrypt or decrypt // rows in place. UpdateEncryptedAIProviderSettings(ctx context.Context, arg UpdateEncryptedAIProviderSettingsParams) (AIProvider, error) + UpdateEncryptedUserAIProviderKey(ctx context.Context, arg UpdateEncryptedUserAIProviderKeyParams) (UserAiProviderKey, error) UpdateExternalAuthLink(ctx context.Context, arg UpdateExternalAuthLinkParams) (ExternalAuthLink, error) // Optimistic lock: only update the row if the refresh token in the database // still matches the one we read before attempting the refresh. This prevents @@ -1265,6 +1273,7 @@ type sqlcQuerier interface { UpdateTemplateVersionFlagsByJobID(ctx context.Context, arg UpdateTemplateVersionFlagsByJobIDParams) error UpdateTemplateWorkspacesLastUsedAt(ctx context.Context, arg UpdateTemplateWorkspacesLastUsedAtParams) error UpdateUsageEventsPostPublish(ctx context.Context, arg UpdateUsageEventsPostPublishParams) error + UpdateUserAIProviderKey(ctx context.Context, arg UpdateUserAIProviderKeyParams) (UserAiProviderKey, error) UpdateUserAgentChatSendShortcut(ctx context.Context, arg UpdateUserAgentChatSendShortcutParams) (string, error) UpdateUserChatCompactionThreshold(ctx context.Context, arg UpdateUserChatCompactionThresholdParams) (UserConfig, error) UpdateUserChatCustomPrompt(ctx context.Context, arg UpdateUserChatCustomPromptParams) (UserConfig, error) @@ -1387,6 +1396,10 @@ type sqlcQuerier interface { // used to store the data, and the minutes are summed for each user and template // combination. The result is stored in the template_usage_stats table. UpsertTemplateUsageStats(ctx context.Context) error + // UpsertUserAIProviderKey preserves the original id and created_at when the + // user/provider pair already exists. On conflict, callers provide id and + // created_at for the insert path only. + UpsertUserAIProviderKey(ctx context.Context, arg UpsertUserAIProviderKeyParams) (UserAiProviderKey, error) UpsertUserChatDebugLoggingEnabled(ctx context.Context, arg UpsertUserChatDebugLoggingEnabledParams) error UpsertUserChatPersonalModelOverride(ctx context.Context, arg UpsertUserChatPersonalModelOverrideParams) error UpsertUserChatProviderKey(ctx context.Context, arg UpsertUserChatProviderKeyParams) (UserChatProviderKey, error) diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 6ff1d2dc76..b70f0a388d 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -4972,7 +4972,7 @@ func (q *sqlQuerier) DeleteChatModelConfigsByProvider(ctx context.Context, provi const getChatModelConfigByID = `-- name: GetChatModelConfigByID :one SELECT - id, provider, model, display_name, created_by, updated_by, enabled, is_default, deleted, deleted_at, created_at, updated_at, context_limit, compression_threshold, options + id, provider, model, display_name, created_by, updated_by, enabled, is_default, deleted, deleted_at, created_at, updated_at, context_limit, compression_threshold, options, ai_provider_id FROM chat_model_configs WHERE @@ -4999,13 +4999,14 @@ func (q *sqlQuerier) GetChatModelConfigByID(ctx context.Context, id uuid.UUID) ( &i.ContextLimit, &i.CompressionThreshold, &i.Options, + &i.AIProviderID, ) return i, err } const getChatModelConfigs = `-- name: GetChatModelConfigs :many SELECT - id, provider, model, display_name, created_by, updated_by, enabled, is_default, deleted, deleted_at, created_at, updated_at, context_limit, compression_threshold, options + id, provider, model, display_name, created_by, updated_by, enabled, is_default, deleted, deleted_at, created_at, updated_at, context_limit, compression_threshold, options, ai_provider_id FROM chat_model_configs WHERE @@ -5042,6 +5043,7 @@ func (q *sqlQuerier) GetChatModelConfigs(ctx context.Context) ([]ChatModelConfig &i.ContextLimit, &i.CompressionThreshold, &i.Options, + &i.AIProviderID, ); err != nil { return nil, err } @@ -5058,7 +5060,7 @@ func (q *sqlQuerier) GetChatModelConfigs(ctx context.Context) ([]ChatModelConfig const getDefaultChatModelConfig = `-- name: GetDefaultChatModelConfig :one SELECT - id, provider, model, display_name, created_by, updated_by, enabled, is_default, deleted, deleted_at, created_at, updated_at, context_limit, compression_threshold, options + id, provider, model, display_name, created_by, updated_by, enabled, is_default, deleted, deleted_at, created_at, updated_at, context_limit, compression_threshold, options, ai_provider_id FROM chat_model_configs WHERE @@ -5085,13 +5087,14 @@ func (q *sqlQuerier) GetDefaultChatModelConfig(ctx context.Context) (ChatModelCo &i.ContextLimit, &i.CompressionThreshold, &i.Options, + &i.AIProviderID, ) return i, err } const getEnabledChatModelConfigByID = `-- name: GetEnabledChatModelConfigByID :one SELECT - cmc.id, cmc.provider, cmc.model, cmc.display_name, cmc.created_by, cmc.updated_by, cmc.enabled, cmc.is_default, cmc.deleted, cmc.deleted_at, cmc.created_at, cmc.updated_at, cmc.context_limit, cmc.compression_threshold, cmc.options + cmc.id, cmc.provider, cmc.model, cmc.display_name, cmc.created_by, cmc.updated_by, cmc.enabled, cmc.is_default, cmc.deleted, cmc.deleted_at, cmc.created_at, cmc.updated_at, cmc.context_limit, cmc.compression_threshold, cmc.options, cmc.ai_provider_id FROM chat_model_configs cmc JOIN @@ -5124,13 +5127,14 @@ func (q *sqlQuerier) GetEnabledChatModelConfigByID(ctx context.Context, id uuid. &i.ContextLimit, &i.CompressionThreshold, &i.Options, + &i.AIProviderID, ) return i, err } const getEnabledChatModelConfigs = `-- name: GetEnabledChatModelConfigs :many SELECT - cmc.id, cmc.provider, cmc.model, cmc.display_name, cmc.created_by, cmc.updated_by, cmc.enabled, cmc.is_default, cmc.deleted, cmc.deleted_at, cmc.created_at, cmc.updated_at, cmc.context_limit, cmc.compression_threshold, cmc.options + cmc.id, cmc.provider, cmc.model, cmc.display_name, cmc.created_by, cmc.updated_by, cmc.enabled, cmc.is_default, cmc.deleted, cmc.deleted_at, cmc.created_at, cmc.updated_at, cmc.context_limit, cmc.compression_threshold, cmc.options, cmc.ai_provider_id FROM chat_model_configs cmc JOIN @@ -5171,6 +5175,7 @@ func (q *sqlQuerier) GetEnabledChatModelConfigs(ctx context.Context) ([]ChatMode &i.ContextLimit, &i.CompressionThreshold, &i.Options, + &i.AIProviderID, ); err != nil { return nil, err } @@ -5210,7 +5215,7 @@ INSERT INTO chat_model_configs ( $10::jsonb ) RETURNING - id, provider, model, display_name, created_by, updated_by, enabled, is_default, deleted, deleted_at, created_at, updated_at, context_limit, compression_threshold, options + id, provider, model, display_name, created_by, updated_by, enabled, is_default, deleted, deleted_at, created_at, updated_at, context_limit, compression_threshold, options, ai_provider_id ` type InsertChatModelConfigParams struct { @@ -5256,6 +5261,7 @@ func (q *sqlQuerier) InsertChatModelConfig(ctx context.Context, arg InsertChatMo &i.ContextLimit, &i.CompressionThreshold, &i.Options, + &i.AIProviderID, ) return i, err } @@ -5294,7 +5300,7 @@ WHERE id = $10::uuid AND deleted = FALSE RETURNING - id, provider, model, display_name, created_by, updated_by, enabled, is_default, deleted, deleted_at, created_at, updated_at, context_limit, compression_threshold, options + id, provider, model, display_name, created_by, updated_by, enabled, is_default, deleted, deleted_at, created_at, updated_at, context_limit, compression_threshold, options, ai_provider_id ` type UpdateChatModelConfigParams struct { @@ -5340,6 +5346,7 @@ func (q *sqlQuerier) UpdateChatModelConfig(ctx context.Context, arg UpdateChatMo &i.ContextLimit, &i.CompressionThreshold, &i.Options, + &i.AIProviderID, ) return i, err } @@ -26306,6 +26313,293 @@ func (q *sqlQuerier) UsageEventExistsByID(ctx context.Context, id string) (bool, return column_1, err } +const deleteUserAIProviderKey = `-- name: DeleteUserAIProviderKey :exec +DELETE FROM + user_ai_provider_keys +WHERE + user_id = $1::uuid + AND ai_provider_id = $2::uuid +` + +type DeleteUserAIProviderKeyParams struct { + UserID uuid.UUID `db:"user_id" json:"user_id"` + AIProviderID uuid.UUID `db:"ai_provider_id" json:"ai_provider_id"` +} + +func (q *sqlQuerier) DeleteUserAIProviderKey(ctx context.Context, arg DeleteUserAIProviderKeyParams) error { + _, err := q.db.ExecContext(ctx, deleteUserAIProviderKey, arg.UserID, arg.AIProviderID) + return err +} + +const deleteUserAIProviderKeysByProviderID = `-- name: DeleteUserAIProviderKeysByProviderID :exec +DELETE FROM + user_ai_provider_keys +WHERE + ai_provider_id = $1::uuid +` + +func (q *sqlQuerier) DeleteUserAIProviderKeysByProviderID(ctx context.Context, aiProviderID uuid.UUID) error { + _, err := q.db.ExecContext(ctx, deleteUserAIProviderKeysByProviderID, aiProviderID) + return err +} + +const getUserAIProviderKeyByProviderID = `-- name: GetUserAIProviderKeyByProviderID :one +SELECT + id, user_id, ai_provider_id, api_key, api_key_key_id, created_at, updated_at +FROM + user_ai_provider_keys +WHERE + user_id = $1::uuid + AND ai_provider_id = $2::uuid +` + +type GetUserAIProviderKeyByProviderIDParams struct { + UserID uuid.UUID `db:"user_id" json:"user_id"` + AIProviderID uuid.UUID `db:"ai_provider_id" json:"ai_provider_id"` +} + +func (q *sqlQuerier) GetUserAIProviderKeyByProviderID(ctx context.Context, arg GetUserAIProviderKeyByProviderIDParams) (UserAiProviderKey, error) { + row := q.db.QueryRowContext(ctx, getUserAIProviderKeyByProviderID, arg.UserID, arg.AIProviderID) + var i UserAiProviderKey + err := row.Scan( + &i.ID, + &i.UserID, + &i.AIProviderID, + &i.APIKey, + &i.ApiKeyKeyID, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + +const getUserAIProviderKeys = `-- name: GetUserAIProviderKeys :many +SELECT + id, user_id, ai_provider_id, api_key, api_key_key_id, created_at, updated_at +FROM + user_ai_provider_keys +ORDER BY + user_id ASC, + ai_provider_id ASC, + created_at ASC, + id ASC +` + +// GetUserAIProviderKeys is used by dbcrypt key rotation. Request paths should use +// user-scoped lookups instead of this bulk accessor. +func (q *sqlQuerier) GetUserAIProviderKeys(ctx context.Context) ([]UserAiProviderKey, error) { + rows, err := q.db.QueryContext(ctx, getUserAIProviderKeys) + if err != nil { + return nil, err + } + defer rows.Close() + var items []UserAiProviderKey + for rows.Next() { + var i UserAiProviderKey + if err := rows.Scan( + &i.ID, + &i.UserID, + &i.AIProviderID, + &i.APIKey, + &i.ApiKeyKeyID, + &i.CreatedAt, + &i.UpdatedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const getUserAIProviderKeysByUserID = `-- name: GetUserAIProviderKeysByUserID :many +SELECT + id, user_id, ai_provider_id, api_key, api_key_key_id, created_at, updated_at +FROM + user_ai_provider_keys +WHERE + user_id = $1::uuid +ORDER BY + ai_provider_id ASC, + created_at ASC, + id ASC +` + +func (q *sqlQuerier) GetUserAIProviderKeysByUserID(ctx context.Context, userID uuid.UUID) ([]UserAiProviderKey, error) { + rows, err := q.db.QueryContext(ctx, getUserAIProviderKeysByUserID, userID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []UserAiProviderKey + for rows.Next() { + var i UserAiProviderKey + if err := rows.Scan( + &i.ID, + &i.UserID, + &i.AIProviderID, + &i.APIKey, + &i.ApiKeyKeyID, + &i.CreatedAt, + &i.UpdatedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const updateEncryptedUserAIProviderKey = `-- name: UpdateEncryptedUserAIProviderKey :one +UPDATE + user_ai_provider_keys +SET + api_key = $1::text, + api_key_key_id = $2::text, + updated_at = NOW() +WHERE + id = $3::uuid +RETURNING + id, user_id, ai_provider_id, api_key, api_key_key_id, created_at, updated_at +` + +type UpdateEncryptedUserAIProviderKeyParams struct { + APIKey string `db:"api_key" json:"api_key"` + ApiKeyKeyID sql.NullString `db:"api_key_key_id" json:"api_key_key_id"` + ID uuid.UUID `db:"id" json:"id"` +} + +func (q *sqlQuerier) UpdateEncryptedUserAIProviderKey(ctx context.Context, arg UpdateEncryptedUserAIProviderKeyParams) (UserAiProviderKey, error) { + row := q.db.QueryRowContext(ctx, updateEncryptedUserAIProviderKey, arg.APIKey, arg.ApiKeyKeyID, arg.ID) + var i UserAiProviderKey + err := row.Scan( + &i.ID, + &i.UserID, + &i.AIProviderID, + &i.APIKey, + &i.ApiKeyKeyID, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + +const updateUserAIProviderKey = `-- name: UpdateUserAIProviderKey :one +UPDATE + user_ai_provider_keys +SET + api_key = $1::text, + api_key_key_id = $2::text, + updated_at = NOW() +WHERE + user_id = $3::uuid + AND ai_provider_id = $4::uuid +RETURNING + id, user_id, ai_provider_id, api_key, api_key_key_id, created_at, updated_at +` + +type UpdateUserAIProviderKeyParams struct { + APIKey string `db:"api_key" json:"api_key"` + ApiKeyKeyID sql.NullString `db:"api_key_key_id" json:"api_key_key_id"` + UserID uuid.UUID `db:"user_id" json:"user_id"` + AIProviderID uuid.UUID `db:"ai_provider_id" json:"ai_provider_id"` +} + +func (q *sqlQuerier) UpdateUserAIProviderKey(ctx context.Context, arg UpdateUserAIProviderKeyParams) (UserAiProviderKey, error) { + row := q.db.QueryRowContext(ctx, updateUserAIProviderKey, + arg.APIKey, + arg.ApiKeyKeyID, + arg.UserID, + arg.AIProviderID, + ) + var i UserAiProviderKey + err := row.Scan( + &i.ID, + &i.UserID, + &i.AIProviderID, + &i.APIKey, + &i.ApiKeyKeyID, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + +const upsertUserAIProviderKey = `-- name: UpsertUserAIProviderKey :one +INSERT INTO user_ai_provider_keys ( + id, + user_id, + ai_provider_id, + api_key, + api_key_key_id, + created_at, + updated_at +) VALUES ( + $1::uuid, + $2::uuid, + $3::uuid, + $4::text, + $5::text, + $6::timestamptz, + $7::timestamptz +) +ON CONFLICT (user_id, ai_provider_id) DO UPDATE +SET + api_key = EXCLUDED.api_key, + api_key_key_id = EXCLUDED.api_key_key_id, + updated_at = EXCLUDED.updated_at +RETURNING + id, user_id, ai_provider_id, api_key, api_key_key_id, created_at, updated_at +` + +type UpsertUserAIProviderKeyParams struct { + ID uuid.UUID `db:"id" json:"id"` + UserID uuid.UUID `db:"user_id" json:"user_id"` + AIProviderID uuid.UUID `db:"ai_provider_id" json:"ai_provider_id"` + APIKey string `db:"api_key" json:"api_key"` + ApiKeyKeyID sql.NullString `db:"api_key_key_id" json:"api_key_key_id"` + CreatedAt time.Time `db:"created_at" json:"created_at"` + UpdatedAt time.Time `db:"updated_at" json:"updated_at"` +} + +// UpsertUserAIProviderKey preserves the original id and created_at when the +// user/provider pair already exists. On conflict, callers provide id and +// created_at for the insert path only. +func (q *sqlQuerier) UpsertUserAIProviderKey(ctx context.Context, arg UpsertUserAIProviderKeyParams) (UserAiProviderKey, error) { + row := q.db.QueryRowContext(ctx, upsertUserAIProviderKey, + arg.ID, + arg.UserID, + arg.AIProviderID, + arg.APIKey, + arg.ApiKeyKeyID, + arg.CreatedAt, + arg.UpdatedAt, + ) + var i UserAiProviderKey + err := row.Scan( + &i.ID, + &i.UserID, + &i.AIProviderID, + &i.APIKey, + &i.ApiKeyKeyID, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + const getUserLinkByLinkedID = `-- name: GetUserLinkByLinkedID :one SELECT user_links.user_id, user_links.login_type, user_links.linked_id, user_links.oauth_access_token, user_links.oauth_refresh_token, user_links.oauth_expiry, user_links.oauth_access_token_key_id, user_links.oauth_refresh_token_key_id, user_links.claims diff --git a/coderd/database/queries/user_ai_provider_keys.sql b/coderd/database/queries/user_ai_provider_keys.sql new file mode 100644 index 0000000000..ba3bbc9fc0 --- /dev/null +++ b/coderd/database/queries/user_ai_provider_keys.sql @@ -0,0 +1,100 @@ +-- name: GetUserAIProviderKeyByProviderID :one +SELECT + * +FROM + user_ai_provider_keys +WHERE + user_id = @user_id::uuid + AND ai_provider_id = @ai_provider_id::uuid; + +-- name: GetUserAIProviderKeysByUserID :many +SELECT + * +FROM + user_ai_provider_keys +WHERE + user_id = @user_id::uuid +ORDER BY + ai_provider_id ASC, + created_at ASC, + id ASC; + +-- GetUserAIProviderKeys is used by dbcrypt key rotation. Request paths should use +-- user-scoped lookups instead of this bulk accessor. +-- name: GetUserAIProviderKeys :many +SELECT + * +FROM + user_ai_provider_keys +ORDER BY + user_id ASC, + ai_provider_id ASC, + created_at ASC, + id ASC; + +-- UpsertUserAIProviderKey preserves the original id and created_at when the +-- user/provider pair already exists. On conflict, callers provide id and +-- created_at for the insert path only. +-- name: UpsertUserAIProviderKey :one +INSERT INTO user_ai_provider_keys ( + id, + user_id, + ai_provider_id, + api_key, + api_key_key_id, + created_at, + updated_at +) VALUES ( + @id::uuid, + @user_id::uuid, + @ai_provider_id::uuid, + @api_key::text, + sqlc.narg('api_key_key_id')::text, + @created_at::timestamptz, + @updated_at::timestamptz +) +ON CONFLICT (user_id, ai_provider_id) DO UPDATE +SET + api_key = EXCLUDED.api_key, + api_key_key_id = EXCLUDED.api_key_key_id, + updated_at = EXCLUDED.updated_at +RETURNING + *; + +-- name: UpdateUserAIProviderKey :one +UPDATE + user_ai_provider_keys +SET + api_key = @api_key::text, + api_key_key_id = sqlc.narg('api_key_key_id')::text, + updated_at = NOW() +WHERE + user_id = @user_id::uuid + AND ai_provider_id = @ai_provider_id::uuid +RETURNING + *; + +-- name: DeleteUserAIProviderKey :exec +DELETE FROM + user_ai_provider_keys +WHERE + user_id = @user_id::uuid + AND ai_provider_id = @ai_provider_id::uuid; + +-- name: DeleteUserAIProviderKeysByProviderID :exec +DELETE FROM + user_ai_provider_keys +WHERE + ai_provider_id = @ai_provider_id::uuid; + +-- name: UpdateEncryptedUserAIProviderKey :one +UPDATE + user_ai_provider_keys +SET + api_key = @api_key::text, + api_key_key_id = sqlc.narg('api_key_key_id')::text, + updated_at = NOW() +WHERE + id = @id::uuid +RETURNING + *; diff --git a/coderd/database/sqlc.yaml b/coderd/database/sqlc.yaml index 62143bb247..18c738c992 100644 --- a/coderd/database/sqlc.yaml +++ b/coderd/database/sqlc.yaml @@ -178,6 +178,7 @@ sql: type: "NullDecimal" package: "decimal" rename: + ai_provider_id: AIProviderID chat: ChatTable chats_expanded: Chat group_member: GroupMemberTable diff --git a/coderd/database/unique_constraint.go b/coderd/database/unique_constraint.go index 6ed84e110d..1afd078b8b 100644 --- a/coderd/database/unique_constraint.go +++ b/coderd/database/unique_constraint.go @@ -97,6 +97,8 @@ const ( UniqueTemplatesPkey UniqueConstraint = "templates_pkey" // ALTER TABLE ONLY templates ADD CONSTRAINT templates_pkey PRIMARY KEY (id); UniqueUsageEventsDailyPkey UniqueConstraint = "usage_events_daily_pkey" // ALTER TABLE ONLY usage_events_daily ADD CONSTRAINT usage_events_daily_pkey PRIMARY KEY (day, event_type); UniqueUsageEventsPkey UniqueConstraint = "usage_events_pkey" // ALTER TABLE ONLY usage_events ADD CONSTRAINT usage_events_pkey PRIMARY KEY (id); + UniqueUserAiProviderKeysPkey UniqueConstraint = "user_ai_provider_keys_pkey" // ALTER TABLE ONLY user_ai_provider_keys ADD CONSTRAINT user_ai_provider_keys_pkey PRIMARY KEY (id); + UniqueUserAiProviderKeysUserIDAiProviderIDKey UniqueConstraint = "user_ai_provider_keys_user_id_ai_provider_id_key" // ALTER TABLE ONLY user_ai_provider_keys ADD CONSTRAINT user_ai_provider_keys_user_id_ai_provider_id_key UNIQUE (user_id, ai_provider_id); UniqueUserChatProviderKeysPkey UniqueConstraint = "user_chat_provider_keys_pkey" // ALTER TABLE ONLY user_chat_provider_keys ADD CONSTRAINT user_chat_provider_keys_pkey PRIMARY KEY (id); UniqueUserChatProviderKeysUserIDChatProviderIDKey UniqueConstraint = "user_chat_provider_keys_user_id_chat_provider_id_key" // ALTER TABLE ONLY user_chat_provider_keys ADD CONSTRAINT user_chat_provider_keys_user_id_chat_provider_id_key UNIQUE (user_id, chat_provider_id); UniqueUserConfigsPkey UniqueConstraint = "user_configs_pkey" // ALTER TABLE ONLY user_configs ADD CONSTRAINT user_configs_pkey PRIMARY KEY (user_id, key); diff --git a/codersdk/aiproviders.go b/codersdk/aiproviders.go index e3da8ee97e..01a89f766c 100644 --- a/codersdk/aiproviders.go +++ b/codersdk/aiproviders.go @@ -183,9 +183,9 @@ type AIProviderKey struct { } // CreateAIProviderRequest is the payload for creating a new AI -// provider. Name, Type, and BaseURL are required. APIKeys carries -// the plaintext keys for OpenAI/Anthropic providers; Bedrock -// providers authenticate via Settings and must omit APIKeys. +// provider. Name and Type are required. APIKeys carries the plaintext +// keys for OpenAI/Anthropic providers; Bedrock providers authenticate +// via Settings and must omit APIKeys. type CreateAIProviderRequest struct { Type AIProviderType `json:"type"` Name string `json:"name"` @@ -201,19 +201,24 @@ type CreateAIProviderRequest struct { func (req CreateAIProviderRequest) Validate() []ValidationError { var validations []ValidationError switch req.Type { - case AIProviderTypeOpenAI, AIProviderTypeAnthropic: + case AIProviderTypeOpenAI, + AIProviderTypeAnthropic, + AIProviderTypeAzure, + AIProviderTypeBedrock, + AIProviderTypeGoogle, + AIProviderTypeOpenAICompat, + AIProviderTypeOpenrouter, + AIProviderTypeVercel: case "": validations = append(validations, ValidationError{Field: "type", Detail: "type is required"}) default: validations = append(validations, ValidationError{ Field: "type", - Detail: fmt.Sprintf("unsupported provider type %q; expected one of: openai, anthropic", req.Type), + Detail: fmt.Sprintf("unsupported provider type %q", req.Type), }) } validations = append(validations, validateAIProviderName(req.Name)...) - if req.BaseURL == "" { - validations = append(validations, ValidationError{Field: "base_url", Detail: "base_url is required"}) - } else { + if req.BaseURL != "" { validations = append(validations, validateAIProviderBaseURL(req.BaseURL)...) } validations = append(validations, validateAIProviderAPIKeys(req.APIKeys)...) diff --git a/enterprise/cli/server_dbcrypt_test.go b/enterprise/cli/server_dbcrypt_test.go index 61579db50a..3893cfb6d3 100644 --- a/enterprise/cli/server_dbcrypt_test.go +++ b/enterprise/cli/server_dbcrypt_test.go @@ -5,6 +5,7 @@ import ( "database/sql" "encoding/base64" "testing" + "time" "github.com/google/uuid" "github.com/lib/pq" @@ -234,6 +235,25 @@ func genData(t *testing.T, db database.Store) []database.User { OAuthAccessToken: "access-" + usr.ID.String(), OAuthRefreshToken: "refresh-" + usr.ID.String(), }) + provider := dbgen.AIProvider(t, db, database.AIProvider{ + Name: "ai-provider-" + usr.ID.String(), + Settings: sql.NullString{String: "settings-" + usr.ID.String(), Valid: true}, + }) + _ = dbgen.AIProviderKey(t, db, database.AIProviderKey{ + ProviderID: provider.ID, + APIKey: "provider-key-" + usr.ID.String(), + }) + now := time.Now() + _, err := db.UpsertUserAIProviderKey(context.Background(), database.UpsertUserAIProviderKeyParams{ + ID: uuid.New(), + UserID: usr.ID, + AIProviderID: provider.ID, + APIKey: "user-ai-provider-key-" + usr.ID.String(), + CreatedAt: now, + UpdatedAt: now, + }) + require.NoError(t, err) + // Deleted users cannot have user_links or user_secrets. if !deleted { // Fun fact: our schema allows _all_ login types to have @@ -302,6 +322,36 @@ func requireEncryptedWithCipher(ctx context.Context, t *testing.T, db database.S requireEncryptedEquals(t, c, "value-"+userID.String(), s.Value) require.Equal(t, c.HexDigest(), s.ValueKeyID.String) } + + providers, err := db.GetAIProviders(ctx, database.GetAIProvidersParams{ + IncludeDeleted: true, + IncludeDisabled: true, + }) + require.NoError(t, err, "failed to get ai providers") + providerName := "ai-provider-" + userID.String() + var provider database.AIProvider + for _, p := range providers { + if p.Name == providerName { + provider = p + break + } + } + require.NotEqual(t, uuid.Nil, provider.ID, "expected ai provider for user %s", userID) + require.True(t, provider.Settings.Valid) + requireEncryptedEquals(t, c, "settings-"+userID.String(), provider.Settings.String) + require.Equal(t, c.HexDigest(), provider.SettingsKeyID.String) + + providerKeys, err := db.GetAIProviderKeysByProviderID(ctx, provider.ID) + require.NoError(t, err, "failed to get ai provider keys for provider %s", provider.ID) + require.Len(t, providerKeys, 1) + requireEncryptedEquals(t, c, "provider-key-"+userID.String(), providerKeys[0].APIKey) + require.Equal(t, c.HexDigest(), providerKeys[0].ApiKeyKeyID.String) + + userAIProviderKeys, err := db.GetUserAIProviderKeysByUserID(ctx, userID) + require.NoError(t, err, "failed to get user ai provider keys for user %s", userID) + require.Len(t, userAIProviderKeys, 1) + requireEncryptedEquals(t, c, "user-ai-provider-key-"+userID.String(), userAIProviderKeys[0].APIKey) + require.Equal(t, c.HexDigest(), userAIProviderKeys[0].ApiKeyKeyID.String) } // nullCipher is a dbcrypt.Cipher that does not encrypt or decrypt. diff --git a/enterprise/dbcrypt/cliutil.go b/enterprise/dbcrypt/cliutil.go index 4760b3309e..f573d080e6 100644 --- a/enterprise/dbcrypt/cliutil.go +++ b/enterprise/dbcrypt/cliutil.go @@ -209,6 +209,29 @@ func Rotate(ctx context.Context, log slog.Logger, sqlDB *sql.DB, ciphers []Ciphe log.Debug(ctx, "encrypted ai provider key", slog.F("ai_provider_key_id", apk.ID), slog.F("provider_id", apk.ProviderID), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest())) } + userAIProviderKeys, err := cryptDB.GetUserAIProviderKeys(ctx) + if err != nil { + return xerrors.Errorf("get user ai provider keys: %w", err) + } + log.Info(ctx, "encrypting user ai provider keys", slog.F("key_count", len(userAIProviderKeys))) + for idx, key := range userAIProviderKeys { + if strings.TrimSpace(key.APIKey) == "" { + continue + } + if key.ApiKeyKeyID.Valid && key.ApiKeyKeyID.String == ciphers[0].HexDigest() { + log.Debug(ctx, "skipping user ai provider key", slog.F("user_ai_provider_key_id", key.ID), slog.F("ai_provider_id", key.AIProviderID), slog.F("user_id", key.UserID), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest())) + continue + } + if _, err := cryptDB.UpdateEncryptedUserAIProviderKey(ctx, database.UpdateEncryptedUserAIProviderKeyParams{ + ID: key.ID, + APIKey: key.APIKey, + ApiKeyKeyID: sql.NullString{}, // dbcrypt will update as required + }); err != nil { + return xerrors.Errorf("update user ai provider key id=%s ai_provider_id=%s user_id=%s: %w", key.ID, key.AIProviderID, key.UserID, err) + } + log.Debug(ctx, "encrypted user ai provider key", slog.F("user_ai_provider_key_id", key.ID), slog.F("ai_provider_id", key.AIProviderID), slog.F("user_id", key.UserID), slog.F("current", idx+1), slog.F("cipher", ciphers[0].HexDigest())) + } + // Revoke old keys for _, c := range ciphers[1:] { if err := db.RevokeDBCryptKey(ctx, c.HexDigest()); err != nil { @@ -412,6 +435,26 @@ func Decrypt(ctx context.Context, log slog.Logger, sqlDB *sql.DB, ciphers []Ciph log.Debug(ctx, "decrypted ai provider key", slog.F("ai_provider_key_id", apk.ID), slog.F("provider_id", apk.ProviderID), slog.F("current", idx+1)) } + userAIProviderKeys, err := cryptDB.GetUserAIProviderKeys(ctx) + if err != nil { + return xerrors.Errorf("get user ai provider keys: %w", err) + } + log.Info(ctx, "decrypting user ai provider keys", slog.F("key_count", len(userAIProviderKeys))) + for idx, key := range userAIProviderKeys { + if !key.ApiKeyKeyID.Valid { + log.Debug(ctx, "skipping user ai provider key", slog.F("user_ai_provider_key_id", key.ID), slog.F("ai_provider_id", key.AIProviderID), slog.F("user_id", key.UserID), slog.F("current", idx+1)) + continue + } + if _, err := cryptDB.UpdateEncryptedUserAIProviderKey(ctx, database.UpdateEncryptedUserAIProviderKeyParams{ + ID: key.ID, + APIKey: key.APIKey, + ApiKeyKeyID: sql.NullString{}, // explicitly clear the key id + }); err != nil { + return xerrors.Errorf("decrypt user ai provider key id=%s ai_provider_id=%s user_id=%s: %w", key.ID, key.AIProviderID, key.UserID, err) + } + log.Debug(ctx, "decrypted user ai provider key", slog.F("user_ai_provider_key_id", key.ID), slog.F("ai_provider_id", key.AIProviderID), slog.F("user_id", key.UserID), slog.F("current", idx+1)) + } + // Revoke _all_ keys for _, c := range ciphers { if err := db.RevokeDBCryptKey(ctx, c.HexDigest()); err != nil { @@ -434,6 +477,8 @@ DELETE FROM external_auth_links OR oauth_refresh_token_key_id IS NOT NULL; DELETE FROM user_chat_provider_keys WHERE api_key_key_id IS NOT NULL; +DELETE FROM user_ai_provider_keys + WHERE api_key_key_id IS NOT NULL; DELETE FROM user_secrets WHERE value_key_id IS NOT NULL; UPDATE chat_providers diff --git a/enterprise/dbcrypt/dbcrypt.go b/enterprise/dbcrypt/dbcrypt.go index ac3256c9cd..b66ed6b3de 100644 --- a/enterprise/dbcrypt/dbcrypt.go +++ b/enterprise/dbcrypt/dbcrypt.go @@ -662,6 +662,98 @@ func (db *dbCrypt) UpdateChatProvider(ctx context.Context, params database.Updat return provider, nil } +func (db *dbCrypt) decryptUserAIProviderKey(key *database.UserAiProviderKey) error { + return db.decryptField(&key.APIKey, key.ApiKeyKeyID) +} + +func (db *dbCrypt) GetUserAIProviderKeyByProviderID(ctx context.Context, params database.GetUserAIProviderKeyByProviderIDParams) (database.UserAiProviderKey, error) { + key, err := db.Store.GetUserAIProviderKeyByProviderID(ctx, params) + if err != nil { + return database.UserAiProviderKey{}, err + } + if err := db.decryptUserAIProviderKey(&key); err != nil { + return database.UserAiProviderKey{}, err + } + return key, nil +} + +func (db *dbCrypt) GetUserAIProviderKeysByUserID(ctx context.Context, userID uuid.UUID) ([]database.UserAiProviderKey, error) { + keys, err := db.Store.GetUserAIProviderKeysByUserID(ctx, userID) + if err != nil { + return nil, err + } + for i := range keys { + if err := db.decryptUserAIProviderKey(&keys[i]); err != nil { + return nil, err + } + } + return keys, nil +} + +func (db *dbCrypt) GetUserAIProviderKeys(ctx context.Context) ([]database.UserAiProviderKey, error) { + keys, err := db.Store.GetUserAIProviderKeys(ctx) + if err != nil { + return nil, err + } + for i := range keys { + if err := db.decryptUserAIProviderKey(&keys[i]); err != nil { + return nil, err + } + } + return keys, nil +} + +func (db *dbCrypt) UpsertUserAIProviderKey(ctx context.Context, params database.UpsertUserAIProviderKeyParams) (database.UserAiProviderKey, error) { + if strings.TrimSpace(params.APIKey) == "" { + params.ApiKeyKeyID = sql.NullString{} + } else if err := db.encryptField(¶ms.APIKey, ¶ms.ApiKeyKeyID); err != nil { + return database.UserAiProviderKey{}, err + } + + key, err := db.Store.UpsertUserAIProviderKey(ctx, params) + if err != nil { + return database.UserAiProviderKey{}, err + } + if err := db.decryptUserAIProviderKey(&key); err != nil { + return database.UserAiProviderKey{}, err + } + return key, nil +} + +func (db *dbCrypt) UpdateUserAIProviderKey(ctx context.Context, params database.UpdateUserAIProviderKeyParams) (database.UserAiProviderKey, error) { + if strings.TrimSpace(params.APIKey) == "" { + params.ApiKeyKeyID = sql.NullString{} + } else if err := db.encryptField(¶ms.APIKey, ¶ms.ApiKeyKeyID); err != nil { + return database.UserAiProviderKey{}, err + } + + key, err := db.Store.UpdateUserAIProviderKey(ctx, params) + if err != nil { + return database.UserAiProviderKey{}, err + } + if err := db.decryptUserAIProviderKey(&key); err != nil { + return database.UserAiProviderKey{}, err + } + return key, nil +} + +func (db *dbCrypt) UpdateEncryptedUserAIProviderKey(ctx context.Context, params database.UpdateEncryptedUserAIProviderKeyParams) (database.UserAiProviderKey, error) { + if strings.TrimSpace(params.APIKey) == "" { + params.ApiKeyKeyID = sql.NullString{} + } else if err := db.encryptField(¶ms.APIKey, ¶ms.ApiKeyKeyID); err != nil { + return database.UserAiProviderKey{}, err + } + + key, err := db.Store.UpdateEncryptedUserAIProviderKey(ctx, params) + if err != nil { + return database.UserAiProviderKey{}, err + } + if err := db.decryptUserAIProviderKey(&key); err != nil { + return database.UserAiProviderKey{}, err + } + return key, nil +} + func (db *dbCrypt) decryptUserChatProviderKey(key *database.UserChatProviderKey) error { return db.decryptField(&key.APIKey, key.ApiKeyKeyID) } diff --git a/enterprise/dbcrypt/dbcrypt_internal_test.go b/enterprise/dbcrypt/dbcrypt_internal_test.go index fea3a4eeb6..5f70562863 100644 --- a/enterprise/dbcrypt/dbcrypt_internal_test.go +++ b/enterprise/dbcrypt/dbcrypt_internal_test.go @@ -1292,6 +1292,168 @@ func TestAIProviderKeys(t *testing.T) { }) } +func TestUserAIProviderKeys(t *testing.T) { + t.Parallel() + ctx := context.Background() + + const ( + //nolint:gosec // test credentials + initialAPIKey = "sk-initial-ai-provider-key-value" + //nolint:gosec // test credentials + updatedAPIKey = "sk-updated-ai-provider-key-value" + //nolint:gosec // test credentials + rotatedAPIKey = "sk-rotated-ai-provider-key-value" + ) + + insertProviderAndKey := func( + t *testing.T, + crypt *dbCrypt, + ciphers []Cipher, + ) (database.AIProvider, database.UserAiProviderKey) { + t.Helper() + user := dbgen.User(t, crypt, database.User{}) + provider := dbgen.AIProvider(t, crypt, database.AIProvider{}) + now := dbtime.Now() + + key, err := crypt.UpsertUserAIProviderKey(ctx, database.UpsertUserAIProviderKeyParams{ + ID: uuid.New(), + UserID: user.ID, + AIProviderID: provider.ID, + APIKey: initialAPIKey, + CreatedAt: now, + UpdatedAt: now, + }) + require.NoError(t, err) + require.Equal(t, initialAPIKey, key.APIKey) + require.Equal(t, ciphers[0].HexDigest(), key.ApiKeyKeyID.String) + return provider, key + } + + getRawUserAIProviderKey := func(t *testing.T, store database.Store, userID uuid.UUID, providerID uuid.UUID) database.UserAiProviderKey { + t.Helper() + key, err := store.GetUserAIProviderKeyByProviderID(ctx, database.GetUserAIProviderKeyByProviderIDParams{ + UserID: userID, + AIProviderID: providerID, + }) + require.NoError(t, err) + return key + } + + t.Run("UpsertUserAIProviderKeyCreatesValue", func(t *testing.T) { + t.Parallel() + db, crypt, ciphers := setup(t) + provider, key := insertProviderAndKey(t, crypt, ciphers) + + got, err := crypt.GetUserAIProviderKeyByProviderID(ctx, database.GetUserAIProviderKeyByProviderIDParams{ + UserID: key.UserID, + AIProviderID: provider.ID, + }) + require.NoError(t, err) + require.Equal(t, key.ID, got.ID) + require.Equal(t, initialAPIKey, got.APIKey) + require.Equal(t, ciphers[0].HexDigest(), got.ApiKeyKeyID.String) + + rawKey := getRawUserAIProviderKey(t, db, key.UserID, provider.ID) + require.NotEqual(t, initialAPIKey, rawKey.APIKey) + requireEncryptedEquals(t, ciphers[0], rawKey.APIKey, initialAPIKey) + }) + + t.Run("GetUserAIProviderKeysByUserID", func(t *testing.T) { + t.Parallel() + _, crypt, ciphers := setup(t) + provider, key := insertProviderAndKey(t, crypt, ciphers) + + keys, err := crypt.GetUserAIProviderKeysByUserID(ctx, key.UserID) + require.NoError(t, err) + require.Len(t, keys, 1) + require.Equal(t, key.ID, keys[0].ID) + require.Equal(t, provider.ID, keys[0].AIProviderID) + require.Equal(t, initialAPIKey, keys[0].APIKey) + require.Equal(t, ciphers[0].HexDigest(), keys[0].ApiKeyKeyID.String) + }) + + t.Run("GetUserAIProviderKeys", func(t *testing.T) { + t.Parallel() + _, crypt, ciphers := setup(t) + provider, key := insertProviderAndKey(t, crypt, ciphers) + + keys, err := crypt.GetUserAIProviderKeys(ctx) + require.NoError(t, err) + require.Len(t, keys, 1) + require.Equal(t, key.ID, keys[0].ID) + require.Equal(t, key.UserID, keys[0].UserID) + require.Equal(t, provider.ID, keys[0].AIProviderID) + require.Equal(t, initialAPIKey, keys[0].APIKey) + require.Equal(t, ciphers[0].HexDigest(), keys[0].ApiKeyKeyID.String) + }) + + t.Run("UpsertUserAIProviderKeyUpdatesValue", func(t *testing.T) { + t.Parallel() + db, crypt, ciphers := setup(t) + provider, key := insertProviderAndKey(t, crypt, ciphers) + updatedAt := key.UpdatedAt.Add(time.Minute) + + updated, err := crypt.UpsertUserAIProviderKey(ctx, database.UpsertUserAIProviderKeyParams{ + ID: uuid.New(), + UserID: key.UserID, + AIProviderID: provider.ID, + APIKey: updatedAPIKey, + CreatedAt: key.CreatedAt.Add(time.Minute), + UpdatedAt: updatedAt, + }) + require.NoError(t, err) + require.Equal(t, key.ID, updated.ID) + require.Equal(t, key.CreatedAt, updated.CreatedAt) + require.Equal(t, updatedAt, updated.UpdatedAt) + require.Equal(t, updatedAPIKey, updated.APIKey) + require.Equal(t, ciphers[0].HexDigest(), updated.ApiKeyKeyID.String) + + rawKey := getRawUserAIProviderKey(t, db, key.UserID, provider.ID) + require.NotEqual(t, updatedAPIKey, rawKey.APIKey) + requireEncryptedEquals(t, ciphers[0], rawKey.APIKey, updatedAPIKey) + }) + + t.Run("UpdateUserAIProviderKey", func(t *testing.T) { + t.Parallel() + db, crypt, ciphers := setup(t) + provider, key := insertProviderAndKey(t, crypt, ciphers) + + updated, err := crypt.UpdateUserAIProviderKey(ctx, database.UpdateUserAIProviderKeyParams{ + UserID: key.UserID, + AIProviderID: provider.ID, + APIKey: updatedAPIKey, + }) + require.NoError(t, err) + require.Equal(t, key.ID, updated.ID) + require.WithinDuration(t, dbtime.Now(), updated.UpdatedAt, time.Minute) + require.Equal(t, updatedAPIKey, updated.APIKey) + require.Equal(t, ciphers[0].HexDigest(), updated.ApiKeyKeyID.String) + + rawKey := getRawUserAIProviderKey(t, db, key.UserID, provider.ID) + require.NotEqual(t, updatedAPIKey, rawKey.APIKey) + requireEncryptedEquals(t, ciphers[0], rawKey.APIKey, updatedAPIKey) + }) + + t.Run("UpdateEncryptedUserAIProviderKey", func(t *testing.T) { + t.Parallel() + db, crypt, ciphers := setup(t) + provider, key := insertProviderAndKey(t, crypt, ciphers) + + updated, err := crypt.UpdateEncryptedUserAIProviderKey(ctx, database.UpdateEncryptedUserAIProviderKeyParams{ + ID: key.ID, + APIKey: rotatedAPIKey, + }) + require.NoError(t, err) + require.Equal(t, key.ID, updated.ID) + require.Equal(t, rotatedAPIKey, updated.APIKey) + require.Equal(t, ciphers[0].HexDigest(), updated.ApiKeyKeyID.String) + + rawKey := getRawUserAIProviderKey(t, db, key.UserID, provider.ID) + require.NotEqual(t, rotatedAPIKey, rawKey.APIKey) + requireEncryptedEquals(t, ciphers[0], rawKey.APIKey, rotatedAPIKey) + }) +} + func TestMCPServerUserTokens(t *testing.T) { t.Parallel() ctx := context.Background() diff --git a/site/src/api/typesGenerated.ts b/site/src/api/typesGenerated.ts index c66edeedf7..6bf3f9e51a 100644 --- a/site/src/api/typesGenerated.ts +++ b/site/src/api/typesGenerated.ts @@ -3205,9 +3205,9 @@ export interface ConvertLoginRequest { // From codersdk/aiproviders.go /** * CreateAIProviderRequest is the payload for creating a new AI - * provider. Name, Type, and BaseURL are required. APIKeys carries - * the plaintext keys for OpenAI/Anthropic providers; Bedrock - * providers authenticate via Settings and must omit APIKeys. + * provider. Name and Type are required. APIKeys carries the plaintext + * keys for OpenAI/Anthropic providers; Bedrock providers authenticate + * via Settings and must omit APIKeys. */ export interface CreateAIProviderRequest { readonly type: AIProviderType;