From 9c30cf886eb22d50ef7c92d006367281a0dbfa2d Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Thu, 28 May 2026 22:46:06 +0000 Subject: [PATCH] feat(coderd/database): queries + dbcrypt for mcp_server_user_header_values Wires the per-user MCP custom header values store from migration 000510 through the data layer: sqlc queries, dbauthz wrappers (ActionRead/UpdatePersonal mirroring ExternalAuthLink), dbcrypt envelope encryption around header_values, dbgen fakes, and dbmock + dbmetrics regeneration. Adds CustomHeadersUserKeys to InsertMCPServerConfig and UpdateMCPServerConfig so the admin-configured set of user-set header names round-trips with the existing custom_headers JSON. Subsequent commits will surface this via the SDK, HTTP handlers, runtime overlay in chatd, and the admin + user-settings UI. --- coderd/database/dbauthz/dbauthz.go | 25 +++ coderd/database/dbauthz/dbauthz_test.go | 39 +++++ coderd/database/dbgen/dbgen.go | 14 ++ coderd/database/dbmetrics/querymetrics.go | 32 ++++ coderd/database/dbmock/dbmock.go | 59 +++++++ coderd/database/modelmethods.go | 3 + coderd/database/querier.go | 4 + coderd/database/queries.sql.go | 168 +++++++++++++++++-- coderd/database/queries/mcpserverconfigs.sql | 46 +++++ coderd/mcp.go | 4 + enterprise/dbcrypt/dbcrypt.go | 47 ++++++ enterprise/dbcrypt/dbcrypt_internal_test.go | 92 ++++++++++ 12 files changed, 520 insertions(+), 13 deletions(-) diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index a1a7497153..7efd42e404 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -2114,6 +2114,13 @@ func (q *querier) DeleteMCPServerConfigByID(ctx context.Context, id uuid.UUID) e return q.db.DeleteMCPServerConfigByID(ctx, id) } +func (q *querier) DeleteMCPServerUserHeaderValues(ctx context.Context, arg database.DeleteMCPServerUserHeaderValuesParams) error { + fetch := func(ctx context.Context, arg database.DeleteMCPServerUserHeaderValuesParams) (database.McpServerUserHeaderValue, error) { + return q.db.GetMCPServerUserHeaderValues(ctx, database.GetMCPServerUserHeaderValuesParams(arg)) + } + return fetchAndExec(q.log, q.auth, policy.ActionUpdatePersonal, fetch, q.db.DeleteMCPServerUserHeaderValues)(ctx, arg) +} + func (q *querier) DeleteMCPServerUserToken(ctx context.Context, arg database.DeleteMCPServerUserTokenParams) error { if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil { return err @@ -3712,6 +3719,14 @@ func (q *querier) GetMCPServerConfigsByIDs(ctx context.Context, ids []uuid.UUID) return q.db.GetMCPServerConfigsByIDs(ctx, ids) } +func (q *querier) GetMCPServerUserHeaderValues(ctx context.Context, arg database.GetMCPServerUserHeaderValuesParams) (database.McpServerUserHeaderValue, error) { + return fetchWithAction(q.log, q.auth, policy.ActionReadPersonal, q.db.GetMCPServerUserHeaderValues)(ctx, arg) +} + +func (q *querier) GetMCPServerUserHeaderValuesByUserID(ctx context.Context, userID uuid.UUID) ([]database.McpServerUserHeaderValue, error) { + return fetchWithPostFilter(q.auth, policy.ActionReadPersonal, q.db.GetMCPServerUserHeaderValuesByUserID)(ctx, userID) +} + func (q *querier) GetMCPServerUserToken(ctx context.Context, arg database.GetMCPServerUserTokenParams) (database.MCPServerUserToken, error) { if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil { return database.MCPServerUserToken{}, err @@ -8261,6 +8276,16 @@ func (q *querier) UpsertLogoURL(ctx context.Context, value string) error { return q.db.UpsertLogoURL(ctx, value) } +func (q *querier) UpsertMCPServerUserHeaderValues(ctx context.Context, arg database.UpsertMCPServerUserHeaderValuesParams) (database.McpServerUserHeaderValue, error) { + fetch := func(ctx context.Context, arg database.UpsertMCPServerUserHeaderValuesParams) (database.McpServerUserHeaderValue, error) { + return q.db.GetMCPServerUserHeaderValues(ctx, database.GetMCPServerUserHeaderValuesParams{ + MCPServerConfigID: arg.MCPServerConfigID, + UserID: arg.UserID, + }) + } + return fetchAndQuery(q.log, q.auth, policy.ActionUpdatePersonal, fetch, q.db.UpsertMCPServerUserHeaderValues)(ctx, arg) +} + func (q *querier) UpsertMCPServerUserToken(ctx context.Context, arg database.UpsertMCPServerUserTokenParams) (database.MCPServerUserToken, error) { if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil { return database.MCPServerUserToken{}, err diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go index f788fa71e2..c8d65fb049 100644 --- a/coderd/database/dbauthz/dbauthz_test.go +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -1662,6 +1662,45 @@ func (s *MethodTestSuite) TestChats() { dbm.EXPECT().GetMCPServerUserTokensByUserID(gomock.Any(), userID).Return(tokens, nil).AnyTimes() check.Args(userID).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(tokens) })) + s.Run("GetMCPServerUserHeaderValues", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + arg := database.GetMCPServerUserHeaderValuesParams{ + MCPServerConfigID: uuid.New(), + UserID: uuid.New(), + } + value := testutil.Fake(s.T(), faker, database.McpServerUserHeaderValue{MCPServerConfigID: arg.MCPServerConfigID, UserID: arg.UserID}) + dbm.EXPECT().GetMCPServerUserHeaderValues(gomock.Any(), arg).Return(value, nil).AnyTimes() + check.Args(arg).Asserts(value, policy.ActionReadPersonal).Returns(value) + })) + s.Run("GetMCPServerUserHeaderValuesByUserID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + userID := uuid.New() + values := []database.McpServerUserHeaderValue{testutil.Fake(s.T(), faker, database.McpServerUserHeaderValue{UserID: userID})} + dbm.EXPECT().GetMCPServerUserHeaderValuesByUserID(gomock.Any(), userID).Return(values, nil).AnyTimes() + check.Args(userID).Asserts(values[0], policy.ActionReadPersonal).Returns(values) + })) + s.Run("UpsertMCPServerUserHeaderValues", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + arg := database.UpsertMCPServerUserHeaderValuesParams{ + MCPServerConfigID: uuid.New(), + UserID: uuid.New(), + HeaderValues: `{"X-User-Token":"secret"}`, + } + value := testutil.Fake(s.T(), faker, database.McpServerUserHeaderValue{MCPServerConfigID: arg.MCPServerConfigID, UserID: arg.UserID}) + dbm.EXPECT().GetMCPServerUserHeaderValues(gomock.Any(), database.GetMCPServerUserHeaderValuesParams{ + MCPServerConfigID: arg.MCPServerConfigID, + UserID: arg.UserID, + }).Return(value, nil).AnyTimes() + dbm.EXPECT().UpsertMCPServerUserHeaderValues(gomock.Any(), arg).Return(value, nil).AnyTimes() + check.Args(arg).Asserts(value, policy.ActionUpdatePersonal).Returns(value) + })) + s.Run("DeleteMCPServerUserHeaderValues", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + arg := database.DeleteMCPServerUserHeaderValuesParams{ + MCPServerConfigID: uuid.New(), + UserID: uuid.New(), + } + value := testutil.Fake(s.T(), faker, database.McpServerUserHeaderValue{MCPServerConfigID: arg.MCPServerConfigID, UserID: arg.UserID}) + dbm.EXPECT().GetMCPServerUserHeaderValues(gomock.Any(), database.GetMCPServerUserHeaderValuesParams(arg)).Return(value, nil).AnyTimes() + dbm.EXPECT().DeleteMCPServerUserHeaderValues(gomock.Any(), arg).Return(nil).AnyTimes() + check.Args(arg).Asserts(value, policy.ActionUpdatePersonal).Returns() + })) s.Run("InsertMCPServerConfig", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { arg := database.InsertMCPServerConfigParams{ DisplayName: "Test MCP Server", diff --git a/coderd/database/dbgen/dbgen.go b/coderd/database/dbgen/dbgen.go index 416a2b7257..452eaaa2db 100644 --- a/coderd/database/dbgen/dbgen.go +++ b/coderd/database/dbgen/dbgen.go @@ -349,6 +349,7 @@ func MCPServerConfig(t testing.TB, db database.Store, seed database.MCPServerCon APIKeyValueKeyID: seed.APIKeyValueKeyID, CustomHeaders: seed.CustomHeaders, CustomHeadersKeyID: seed.CustomHeadersKeyID, + CustomHeadersUserKeys: takeFirstSlice(seed.CustomHeadersUserKeys, []string{}), ToolAllowList: takeFirstSlice(seed.ToolAllowList, []string{}), ToolDenyList: takeFirstSlice(seed.ToolDenyList, []string{}), Availability: takeFirst(seed.Availability, "default_off"), @@ -363,6 +364,19 @@ func MCPServerConfig(t testing.TB, db database.Store, seed database.MCPServerCon return cfg } +func MCPServerUserHeaderValues(t testing.TB, db database.Store, seed database.McpServerUserHeaderValue) database.McpServerUserHeaderValue { + t.Helper() + + row, err := db.UpsertMCPServerUserHeaderValues(genCtx, database.UpsertMCPServerUserHeaderValuesParams{ + MCPServerConfigID: takeFirst(seed.MCPServerConfigID, uuid.New()), + UserID: takeFirst(seed.UserID, uuid.New()), + HeaderValues: takeFirst(seed.HeaderValues, "{}"), + HeaderValuesKeyID: seed.HeaderValuesKeyID, + }) + require.NoError(t, err, "upsert MCP server user header values") + return row +} + func ConnectionLog(t testing.TB, db database.Store, seed database.UpsertConnectionLogParams) database.ConnectionLog { arg := database.UpsertConnectionLogParams{ ID: takeFirst(seed.ID, uuid.New()), diff --git a/coderd/database/dbmetrics/querymetrics.go b/coderd/database/dbmetrics/querymetrics.go index e7120ec588..0e7a333cc0 100644 --- a/coderd/database/dbmetrics/querymetrics.go +++ b/coderd/database/dbmetrics/querymetrics.go @@ -577,6 +577,14 @@ func (m queryMetricsStore) DeleteMCPServerConfigByID(ctx context.Context, id uui return r0 } +func (m queryMetricsStore) DeleteMCPServerUserHeaderValues(ctx context.Context, arg database.DeleteMCPServerUserHeaderValuesParams) error { + start := time.Now() + r0 := m.s.DeleteMCPServerUserHeaderValues(ctx, arg) + m.queryLatencies.WithLabelValues("DeleteMCPServerUserHeaderValues").Observe(time.Since(start).Seconds()) + m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteMCPServerUserHeaderValues").Inc() + return r0 +} + func (m queryMetricsStore) DeleteMCPServerUserToken(ctx context.Context, arg database.DeleteMCPServerUserTokenParams) error { start := time.Now() r0 := m.s.DeleteMCPServerUserToken(ctx, arg) @@ -2121,6 +2129,22 @@ func (m queryMetricsStore) GetMCPServerConfigsByIDs(ctx context.Context, ids []u return r0, r1 } +func (m queryMetricsStore) GetMCPServerUserHeaderValues(ctx context.Context, arg database.GetMCPServerUserHeaderValuesParams) (database.McpServerUserHeaderValue, error) { + start := time.Now() + r0, r1 := m.s.GetMCPServerUserHeaderValues(ctx, arg) + m.queryLatencies.WithLabelValues("GetMCPServerUserHeaderValues").Observe(time.Since(start).Seconds()) + m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetMCPServerUserHeaderValues").Inc() + return r0, r1 +} + +func (m queryMetricsStore) GetMCPServerUserHeaderValuesByUserID(ctx context.Context, userID uuid.UUID) ([]database.McpServerUserHeaderValue, error) { + start := time.Now() + r0, r1 := m.s.GetMCPServerUserHeaderValuesByUserID(ctx, userID) + m.queryLatencies.WithLabelValues("GetMCPServerUserHeaderValuesByUserID").Observe(time.Since(start).Seconds()) + m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetMCPServerUserHeaderValuesByUserID").Inc() + return r0, r1 +} + func (m queryMetricsStore) GetMCPServerUserToken(ctx context.Context, arg database.GetMCPServerUserTokenParams) (database.MCPServerUserToken, error) { start := time.Now() r0, r1 := m.s.GetMCPServerUserToken(ctx, arg) @@ -5953,6 +5977,14 @@ func (m queryMetricsStore) UpsertLogoURL(ctx context.Context, value string) erro return r0 } +func (m queryMetricsStore) UpsertMCPServerUserHeaderValues(ctx context.Context, arg database.UpsertMCPServerUserHeaderValuesParams) (database.McpServerUserHeaderValue, error) { + start := time.Now() + r0, r1 := m.s.UpsertMCPServerUserHeaderValues(ctx, arg) + m.queryLatencies.WithLabelValues("UpsertMCPServerUserHeaderValues").Observe(time.Since(start).Seconds()) + m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpsertMCPServerUserHeaderValues").Inc() + return r0, r1 +} + func (m queryMetricsStore) UpsertMCPServerUserToken(ctx context.Context, arg database.UpsertMCPServerUserTokenParams) (database.MCPServerUserToken, error) { start := time.Now() r0, r1 := m.s.UpsertMCPServerUserToken(ctx, arg) diff --git a/coderd/database/dbmock/dbmock.go b/coderd/database/dbmock/dbmock.go index 0f6799e638..bb09bdb26c 100644 --- a/coderd/database/dbmock/dbmock.go +++ b/coderd/database/dbmock/dbmock.go @@ -960,6 +960,20 @@ func (mr *MockStoreMockRecorder) DeleteMCPServerConfigByID(ctx, id any) *gomock. return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteMCPServerConfigByID", reflect.TypeOf((*MockStore)(nil).DeleteMCPServerConfigByID), ctx, id) } +// DeleteMCPServerUserHeaderValues mocks base method. +func (m *MockStore) DeleteMCPServerUserHeaderValues(ctx context.Context, arg database.DeleteMCPServerUserHeaderValuesParams) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteMCPServerUserHeaderValues", ctx, arg) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteMCPServerUserHeaderValues indicates an expected call of DeleteMCPServerUserHeaderValues. +func (mr *MockStoreMockRecorder) DeleteMCPServerUserHeaderValues(ctx, arg any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteMCPServerUserHeaderValues", reflect.TypeOf((*MockStore)(nil).DeleteMCPServerUserHeaderValues), ctx, arg) +} + // DeleteMCPServerUserToken mocks base method. func (m *MockStore) DeleteMCPServerUserToken(ctx context.Context, arg database.DeleteMCPServerUserTokenParams) error { m.ctrl.T.Helper() @@ -3945,6 +3959,36 @@ func (mr *MockStoreMockRecorder) GetMCPServerConfigsByIDs(ctx, ids any) *gomock. return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMCPServerConfigsByIDs", reflect.TypeOf((*MockStore)(nil).GetMCPServerConfigsByIDs), ctx, ids) } +// GetMCPServerUserHeaderValues mocks base method. +func (m *MockStore) GetMCPServerUserHeaderValues(ctx context.Context, arg database.GetMCPServerUserHeaderValuesParams) (database.McpServerUserHeaderValue, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetMCPServerUserHeaderValues", ctx, arg) + ret0, _ := ret[0].(database.McpServerUserHeaderValue) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetMCPServerUserHeaderValues indicates an expected call of GetMCPServerUserHeaderValues. +func (mr *MockStoreMockRecorder) GetMCPServerUserHeaderValues(ctx, arg any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMCPServerUserHeaderValues", reflect.TypeOf((*MockStore)(nil).GetMCPServerUserHeaderValues), ctx, arg) +} + +// GetMCPServerUserHeaderValuesByUserID mocks base method. +func (m *MockStore) GetMCPServerUserHeaderValuesByUserID(ctx context.Context, userID uuid.UUID) ([]database.McpServerUserHeaderValue, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetMCPServerUserHeaderValuesByUserID", ctx, userID) + ret0, _ := ret[0].([]database.McpServerUserHeaderValue) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetMCPServerUserHeaderValuesByUserID indicates an expected call of GetMCPServerUserHeaderValuesByUserID. +func (mr *MockStoreMockRecorder) GetMCPServerUserHeaderValuesByUserID(ctx, userID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMCPServerUserHeaderValuesByUserID", reflect.TypeOf((*MockStore)(nil).GetMCPServerUserHeaderValuesByUserID), ctx, userID) +} + // GetMCPServerUserToken mocks base method. func (m *MockStore) GetMCPServerUserToken(ctx context.Context, arg database.GetMCPServerUserTokenParams) (database.MCPServerUserToken, error) { m.ctrl.T.Helper() @@ -11172,6 +11216,21 @@ func (mr *MockStoreMockRecorder) UpsertLogoURL(ctx, value any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertLogoURL", reflect.TypeOf((*MockStore)(nil).UpsertLogoURL), ctx, value) } +// UpsertMCPServerUserHeaderValues mocks base method. +func (m *MockStore) UpsertMCPServerUserHeaderValues(ctx context.Context, arg database.UpsertMCPServerUserHeaderValuesParams) (database.McpServerUserHeaderValue, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpsertMCPServerUserHeaderValues", ctx, arg) + ret0, _ := ret[0].(database.McpServerUserHeaderValue) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// UpsertMCPServerUserHeaderValues indicates an expected call of UpsertMCPServerUserHeaderValues. +func (mr *MockStoreMockRecorder) UpsertMCPServerUserHeaderValues(ctx, arg any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertMCPServerUserHeaderValues", reflect.TypeOf((*MockStore)(nil).UpsertMCPServerUserHeaderValues), ctx, arg) +} + // UpsertMCPServerUserToken mocks base method. func (m *MockStore) UpsertMCPServerUserToken(ctx context.Context, arg database.UpsertMCPServerUserTokenParams) (database.MCPServerUserToken, error) { m.ctrl.T.Helper() diff --git a/coderd/database/modelmethods.go b/coderd/database/modelmethods.go index 62eb12a1d2..741228d200 100644 --- a/coderd/database/modelmethods.go +++ b/coderd/database/modelmethods.go @@ -621,6 +621,9 @@ func (u GetUsersRow) RBACObject() rbac.Object { func (u GitSSHKey) RBACObject() rbac.Object { return rbac.ResourceUserObject(u.UserID) } func (u ExternalAuthLink) RBACObject() rbac.Object { return rbac.ResourceUserObject(u.UserID) } func (u UserLink) RBACObject() rbac.Object { return rbac.ResourceUserObject(u.UserID) } +func (u McpServerUserHeaderValue) RBACObject() rbac.Object { + return rbac.ResourceUserObject(u.UserID) +} func (u ExternalAuthLink) OAuthToken() *oauth2.Token { return &oauth2.Token{ diff --git a/coderd/database/querier.go b/coderd/database/querier.go index a6c8f3e7db..312b7deddc 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -138,6 +138,7 @@ type sqlcQuerier interface { DeleteGroupMemberFromGroup(ctx context.Context, arg DeleteGroupMemberFromGroupParams) error DeleteLicense(ctx context.Context, id int32) (int32, error) DeleteMCPServerConfigByID(ctx context.Context, id uuid.UUID) error + DeleteMCPServerUserHeaderValues(ctx context.Context, arg DeleteMCPServerUserHeaderValuesParams) error DeleteMCPServerUserToken(ctx context.Context, arg DeleteMCPServerUserTokenParams) error DeleteOAuth2ProviderAppByClientID(ctx context.Context, id uuid.UUID) error DeleteOAuth2ProviderAppByID(ctx context.Context, id uuid.UUID) error @@ -515,6 +516,8 @@ type sqlcQuerier interface { GetMCPServerConfigBySlug(ctx context.Context, slug string) (MCPServerConfig, error) GetMCPServerConfigs(ctx context.Context) ([]MCPServerConfig, error) GetMCPServerConfigsByIDs(ctx context.Context, ids []uuid.UUID) ([]MCPServerConfig, error) + GetMCPServerUserHeaderValues(ctx context.Context, arg GetMCPServerUserHeaderValuesParams) (McpServerUserHeaderValue, error) + GetMCPServerUserHeaderValuesByUserID(ctx context.Context, userID uuid.UUID) ([]McpServerUserHeaderValue, error) GetMCPServerUserToken(ctx context.Context, arg GetMCPServerUserTokenParams) (MCPServerUserToken, error) GetMCPServerUserTokensByUserID(ctx context.Context, userID uuid.UUID) ([]MCPServerUserToken, error) GetNotificationMessagesByStatus(ctx context.Context, arg GetNotificationMessagesByStatusParams) ([]NotificationMessage, error) @@ -1390,6 +1393,7 @@ type sqlcQuerier interface { UpsertHealthSettings(ctx context.Context, value string) error UpsertLastUpdateCheck(ctx context.Context, value string) error UpsertLogoURL(ctx context.Context, value string) error + UpsertMCPServerUserHeaderValues(ctx context.Context, arg UpsertMCPServerUserHeaderValuesParams) (McpServerUserHeaderValue, error) UpsertMCPServerUserToken(ctx context.Context, arg UpsertMCPServerUserTokenParams) (MCPServerUserToken, error) // Insert or update notification report generator logs with recent activity. UpsertNotificationReportGeneratorLog(ctx context.Context, arg UpsertNotificationReportGeneratorLogParams) error diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index 2876d79d2d..fbf994fcae 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -15039,6 +15039,24 @@ func (q *sqlQuerier) DeleteMCPServerConfigByID(ctx context.Context, id uuid.UUID return err } +const deleteMCPServerUserHeaderValues = `-- name: DeleteMCPServerUserHeaderValues :exec +DELETE FROM + mcp_server_user_header_values +WHERE + mcp_server_config_id = $1::uuid + AND user_id = $2::uuid +` + +type DeleteMCPServerUserHeaderValuesParams struct { + MCPServerConfigID uuid.UUID `db:"mcp_server_config_id" json:"mcp_server_config_id"` + UserID uuid.UUID `db:"user_id" json:"user_id"` +} + +func (q *sqlQuerier) DeleteMCPServerUserHeaderValues(ctx context.Context, arg DeleteMCPServerUserHeaderValuesParams) error { + _, err := q.db.ExecContext(ctx, deleteMCPServerUserHeaderValues, arg.MCPServerConfigID, arg.UserID) + return err +} + const deleteMCPServerUserToken = `-- name: DeleteMCPServerUserToken :exec DELETE FROM mcp_server_user_tokens @@ -15416,6 +15434,76 @@ func (q *sqlQuerier) GetMCPServerConfigsByIDs(ctx context.Context, ids []uuid.UU return items, nil } +const getMCPServerUserHeaderValues = `-- name: GetMCPServerUserHeaderValues :one +SELECT + id, mcp_server_config_id, user_id, header_values, header_values_key_id, created_at, updated_at +FROM + mcp_server_user_header_values +WHERE + mcp_server_config_id = $1::uuid + AND user_id = $2::uuid +` + +type GetMCPServerUserHeaderValuesParams struct { + MCPServerConfigID uuid.UUID `db:"mcp_server_config_id" json:"mcp_server_config_id"` + UserID uuid.UUID `db:"user_id" json:"user_id"` +} + +func (q *sqlQuerier) GetMCPServerUserHeaderValues(ctx context.Context, arg GetMCPServerUserHeaderValuesParams) (McpServerUserHeaderValue, error) { + row := q.db.QueryRowContext(ctx, getMCPServerUserHeaderValues, arg.MCPServerConfigID, arg.UserID) + var i McpServerUserHeaderValue + err := row.Scan( + &i.ID, + &i.MCPServerConfigID, + &i.UserID, + &i.HeaderValues, + &i.HeaderValuesKeyID, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + +const getMCPServerUserHeaderValuesByUserID = `-- name: GetMCPServerUserHeaderValuesByUserID :many +SELECT + id, mcp_server_config_id, user_id, header_values, header_values_key_id, created_at, updated_at +FROM + mcp_server_user_header_values +WHERE + user_id = $1::uuid +` + +func (q *sqlQuerier) GetMCPServerUserHeaderValuesByUserID(ctx context.Context, userID uuid.UUID) ([]McpServerUserHeaderValue, error) { + rows, err := q.db.QueryContext(ctx, getMCPServerUserHeaderValuesByUserID, userID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []McpServerUserHeaderValue + for rows.Next() { + var i McpServerUserHeaderValue + if err := rows.Scan( + &i.ID, + &i.MCPServerConfigID, + &i.UserID, + &i.HeaderValues, + &i.HeaderValuesKeyID, + &i.CreatedAt, + &i.UpdatedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const getMCPServerUserToken = `-- name: GetMCPServerUserToken :one SELECT id, mcp_server_config_id, user_id, access_token, access_token_key_id, refresh_token, refresh_token_key_id, token_type, expiry, created_at, updated_at @@ -15514,6 +15602,7 @@ INSERT INTO mcp_server_configs ( api_key_value_key_id, custom_headers, custom_headers_key_id, + custom_headers_user_keys, tool_allow_list, tool_deny_list, availability, @@ -15544,13 +15633,14 @@ INSERT INTO mcp_server_configs ( $18::text, $19::text[], $20::text[], - $21::text, - $22::boolean, + $21::text[], + $22::text, $23::boolean, $24::boolean, $25::boolean, - $26::uuid, - $27::uuid + $26::boolean, + $27::uuid, + $28::uuid ) RETURNING id, display_name, slug, description, icon_url, transport, url, auth_type, oauth2_client_id, oauth2_client_secret, oauth2_client_secret_key_id, oauth2_auth_url, oauth2_token_url, oauth2_scopes, api_key_header, api_key_value, api_key_value_key_id, custom_headers, custom_headers_key_id, tool_allow_list, tool_deny_list, availability, enabled, created_by, updated_by, created_at, updated_at, model_intent, allow_in_plan_mode, forward_coder_headers, custom_headers_user_keys @@ -15575,6 +15665,7 @@ type InsertMCPServerConfigParams struct { APIKeyValueKeyID sql.NullString `db:"api_key_value_key_id" json:"api_key_value_key_id"` CustomHeaders string `db:"custom_headers" json:"custom_headers"` CustomHeadersKeyID sql.NullString `db:"custom_headers_key_id" json:"custom_headers_key_id"` + CustomHeadersUserKeys []string `db:"custom_headers_user_keys" json:"custom_headers_user_keys"` ToolAllowList []string `db:"tool_allow_list" json:"tool_allow_list"` ToolDenyList []string `db:"tool_deny_list" json:"tool_deny_list"` Availability string `db:"availability" json:"availability"` @@ -15606,6 +15697,7 @@ func (q *sqlQuerier) InsertMCPServerConfig(ctx context.Context, arg InsertMCPSer arg.APIKeyValueKeyID, arg.CustomHeaders, arg.CustomHeadersKeyID, + pq.Array(arg.CustomHeadersUserKeys), pq.Array(arg.ToolAllowList), pq.Array(arg.ToolDenyList), arg.Availability, @@ -15675,17 +15767,18 @@ SET api_key_value_key_id = $16::text, custom_headers = $17::text, custom_headers_key_id = $18::text, - tool_allow_list = $19::text[], - tool_deny_list = $20::text[], - availability = $21::text, - enabled = $22::boolean, - model_intent = $23::boolean, - allow_in_plan_mode = $24::boolean, - forward_coder_headers = $25::boolean, - updated_by = $26::uuid, + custom_headers_user_keys = $19::text[], + tool_allow_list = $20::text[], + tool_deny_list = $21::text[], + availability = $22::text, + enabled = $23::boolean, + model_intent = $24::boolean, + allow_in_plan_mode = $25::boolean, + forward_coder_headers = $26::boolean, + updated_by = $27::uuid, updated_at = NOW() WHERE - id = $27::uuid + id = $28::uuid RETURNING id, display_name, slug, description, icon_url, transport, url, auth_type, oauth2_client_id, oauth2_client_secret, oauth2_client_secret_key_id, oauth2_auth_url, oauth2_token_url, oauth2_scopes, api_key_header, api_key_value, api_key_value_key_id, custom_headers, custom_headers_key_id, tool_allow_list, tool_deny_list, availability, enabled, created_by, updated_by, created_at, updated_at, model_intent, allow_in_plan_mode, forward_coder_headers, custom_headers_user_keys ` @@ -15709,6 +15802,7 @@ type UpdateMCPServerConfigParams struct { APIKeyValueKeyID sql.NullString `db:"api_key_value_key_id" json:"api_key_value_key_id"` CustomHeaders string `db:"custom_headers" json:"custom_headers"` CustomHeadersKeyID sql.NullString `db:"custom_headers_key_id" json:"custom_headers_key_id"` + CustomHeadersUserKeys []string `db:"custom_headers_user_keys" json:"custom_headers_user_keys"` ToolAllowList []string `db:"tool_allow_list" json:"tool_allow_list"` ToolDenyList []string `db:"tool_deny_list" json:"tool_deny_list"` Availability string `db:"availability" json:"availability"` @@ -15740,6 +15834,7 @@ func (q *sqlQuerier) UpdateMCPServerConfig(ctx context.Context, arg UpdateMCPSer arg.APIKeyValueKeyID, arg.CustomHeaders, arg.CustomHeadersKeyID, + pq.Array(arg.CustomHeadersUserKeys), pq.Array(arg.ToolAllowList), pq.Array(arg.ToolDenyList), arg.Availability, @@ -15787,6 +15882,53 @@ func (q *sqlQuerier) UpdateMCPServerConfig(ctx context.Context, arg UpdateMCPSer return i, err } +const upsertMCPServerUserHeaderValues = `-- name: UpsertMCPServerUserHeaderValues :one +INSERT INTO mcp_server_user_header_values ( + mcp_server_config_id, + user_id, + header_values, + header_values_key_id +) VALUES ( + $1::uuid, + $2::uuid, + $3::text, + $4::text +) +ON CONFLICT (mcp_server_config_id, user_id) DO UPDATE SET + header_values = $3::text, + header_values_key_id = $4::text, + updated_at = NOW() +RETURNING + id, mcp_server_config_id, user_id, header_values, header_values_key_id, created_at, updated_at +` + +type UpsertMCPServerUserHeaderValuesParams struct { + MCPServerConfigID uuid.UUID `db:"mcp_server_config_id" json:"mcp_server_config_id"` + UserID uuid.UUID `db:"user_id" json:"user_id"` + HeaderValues string `db:"header_values" json:"header_values"` + HeaderValuesKeyID sql.NullString `db:"header_values_key_id" json:"header_values_key_id"` +} + +func (q *sqlQuerier) UpsertMCPServerUserHeaderValues(ctx context.Context, arg UpsertMCPServerUserHeaderValuesParams) (McpServerUserHeaderValue, error) { + row := q.db.QueryRowContext(ctx, upsertMCPServerUserHeaderValues, + arg.MCPServerConfigID, + arg.UserID, + arg.HeaderValues, + arg.HeaderValuesKeyID, + ) + var i McpServerUserHeaderValue + err := row.Scan( + &i.ID, + &i.MCPServerConfigID, + &i.UserID, + &i.HeaderValues, + &i.HeaderValuesKeyID, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + const upsertMCPServerUserToken = `-- name: UpsertMCPServerUserToken :one INSERT INTO mcp_server_user_tokens ( mcp_server_config_id, diff --git a/coderd/database/queries/mcpserverconfigs.sql b/coderd/database/queries/mcpserverconfigs.sql index 3d05a2b102..2f3b17b97e 100644 --- a/coderd/database/queries/mcpserverconfigs.sql +++ b/coderd/database/queries/mcpserverconfigs.sql @@ -73,6 +73,7 @@ INSERT INTO mcp_server_configs ( api_key_value_key_id, custom_headers, custom_headers_key_id, + custom_headers_user_keys, tool_allow_list, tool_deny_list, availability, @@ -101,6 +102,7 @@ INSERT INTO mcp_server_configs ( sqlc.narg('api_key_value_key_id')::text, @custom_headers::text, sqlc.narg('custom_headers_key_id')::text, + @custom_headers_user_keys::text[], @tool_allow_list::text[], @tool_deny_list::text[], @availability::text, @@ -136,6 +138,7 @@ SET api_key_value_key_id = sqlc.narg('api_key_value_key_id')::text, custom_headers = @custom_headers::text, custom_headers_key_id = sqlc.narg('custom_headers_key_id')::text, + custom_headers_user_keys = @custom_headers_user_keys::text[], tool_allow_list = @tool_allow_list::text[], tool_deny_list = @tool_deny_list::text[], availability = @availability::text, @@ -211,6 +214,49 @@ WHERE mcp_server_config_id = @mcp_server_config_id::uuid AND user_id = @user_id::uuid; +-- name: GetMCPServerUserHeaderValues :one +SELECT + * +FROM + mcp_server_user_header_values +WHERE + mcp_server_config_id = @mcp_server_config_id::uuid + AND user_id = @user_id::uuid; + +-- name: GetMCPServerUserHeaderValuesByUserID :many +SELECT + * +FROM + mcp_server_user_header_values +WHERE + user_id = @user_id::uuid; + +-- name: UpsertMCPServerUserHeaderValues :one +INSERT INTO mcp_server_user_header_values ( + mcp_server_config_id, + user_id, + header_values, + header_values_key_id +) VALUES ( + @mcp_server_config_id::uuid, + @user_id::uuid, + @header_values::text, + sqlc.narg('header_values_key_id')::text +) +ON CONFLICT (mcp_server_config_id, user_id) DO UPDATE SET + header_values = @header_values::text, + header_values_key_id = sqlc.narg('header_values_key_id')::text, + updated_at = NOW() +RETURNING + *; + +-- name: DeleteMCPServerUserHeaderValues :exec +DELETE FROM + mcp_server_user_header_values +WHERE + mcp_server_config_id = @mcp_server_config_id::uuid + AND user_id = @user_id::uuid; + -- name: CleanupDeletedMCPServerIDsFromChats :exec UPDATE chats SET mcp_server_ids = ( diff --git a/coderd/mcp.go b/coderd/mcp.go index 3e0a5829f7..576613af11 100644 --- a/coderd/mcp.go +++ b/coderd/mcp.go @@ -277,6 +277,7 @@ func (api *API) createMCPServerConfig(rw http.ResponseWriter, r *http.Request) { APIKeyValueKeyID: sql.NullString{}, CustomHeaders: customHeadersJSON, CustomHeadersKeyID: sql.NullString{}, + CustomHeadersUserKeys: nil, ToolAllowList: coalesceStringSlice(trimStringSlice(req.ToolAllowList)), ToolDenyList: coalesceStringSlice(trimStringSlice(req.ToolDenyList)), Availability: strings.TrimSpace(req.Availability), @@ -366,6 +367,7 @@ func (api *API) createMCPServerConfig(rw http.ResponseWriter, r *http.Request) { APIKeyValueKeyID: inserted.APIKeyValueKeyID, CustomHeaders: inserted.CustomHeaders, CustomHeadersKeyID: inserted.CustomHeadersKeyID, + CustomHeadersUserKeys: inserted.CustomHeadersUserKeys, ToolAllowList: inserted.ToolAllowList, ToolDenyList: inserted.ToolDenyList, Availability: inserted.Availability, @@ -436,6 +438,7 @@ func (api *API) createMCPServerConfig(rw http.ResponseWriter, r *http.Request) { APIKeyValueKeyID: sql.NullString{}, CustomHeaders: customHeadersJSON, CustomHeadersKeyID: sql.NullString{}, + CustomHeadersUserKeys: nil, ToolAllowList: coalesceStringSlice(trimStringSlice(req.ToolAllowList)), ToolDenyList: coalesceStringSlice(trimStringSlice(req.ToolDenyList)), Availability: strings.TrimSpace(req.Availability), @@ -785,6 +788,7 @@ func (api *API) updateMCPServerConfig(rw http.ResponseWriter, r *http.Request) { APIKeyValueKeyID: apiKeyValueKeyID, CustomHeaders: customHeaders, CustomHeadersKeyID: customHeadersKeyID, + CustomHeadersUserKeys: existing.CustomHeadersUserKeys, ToolAllowList: toolAllowList, ToolDenyList: toolDenyList, Availability: availability, diff --git a/enterprise/dbcrypt/dbcrypt.go b/enterprise/dbcrypt/dbcrypt.go index 44cdb5554e..4239bd899c 100644 --- a/enterprise/dbcrypt/dbcrypt.go +++ b/enterprise/dbcrypt/dbcrypt.go @@ -702,6 +702,12 @@ func (db *dbCrypt) decryptMCPServerUserToken(tok *database.MCPServerUserToken) e return db.decryptField(&tok.RefreshToken, tok.RefreshTokenKeyID) } +// decryptMCPServerUserHeaderValues decrypts all encrypted fields on a +// single McpServerUserHeaderValue in place. +func (db *dbCrypt) decryptMCPServerUserHeaderValues(row *database.McpServerUserHeaderValue) error { + return db.decryptField(&row.HeaderValues, row.HeaderValuesKeyID) +} + func (db *dbCrypt) GetMCPServerConfigByID(ctx context.Context, id uuid.UUID) (database.MCPServerConfig, error) { cfg, err := db.Store.GetMCPServerConfigByID(ctx, id) if err != nil { @@ -876,6 +882,47 @@ func (db *dbCrypt) UpsertMCPServerUserToken(ctx context.Context, params database return tok, nil } +func (db *dbCrypt) GetMCPServerUserHeaderValues(ctx context.Context, arg database.GetMCPServerUserHeaderValuesParams) (database.McpServerUserHeaderValue, error) { + row, err := db.Store.GetMCPServerUserHeaderValues(ctx, arg) + if err != nil { + return database.McpServerUserHeaderValue{}, err + } + if err := db.decryptMCPServerUserHeaderValues(&row); err != nil { + return database.McpServerUserHeaderValue{}, err + } + return row, nil +} + +func (db *dbCrypt) GetMCPServerUserHeaderValuesByUserID(ctx context.Context, userID uuid.UUID) ([]database.McpServerUserHeaderValue, error) { + rows, err := db.Store.GetMCPServerUserHeaderValuesByUserID(ctx, userID) + if err != nil { + return nil, err + } + for i := range rows { + if err := db.decryptMCPServerUserHeaderValues(&rows[i]); err != nil { + return nil, err + } + } + return rows, nil +} + +func (db *dbCrypt) UpsertMCPServerUserHeaderValues(ctx context.Context, params database.UpsertMCPServerUserHeaderValuesParams) (database.McpServerUserHeaderValue, error) { + if strings.TrimSpace(params.HeaderValues) == "" { + params.HeaderValuesKeyID = sql.NullString{} + } else if err := db.encryptField(¶ms.HeaderValues, ¶ms.HeaderValuesKeyID); err != nil { + return database.McpServerUserHeaderValue{}, err + } + + row, err := db.Store.UpsertMCPServerUserHeaderValues(ctx, params) + if err != nil { + return database.McpServerUserHeaderValue{}, err + } + if err := db.decryptMCPServerUserHeaderValues(&row); err != nil { + return database.McpServerUserHeaderValue{}, err + } + return row, nil +} + func (db *dbCrypt) CreateUserSecret(ctx context.Context, params database.CreateUserSecretParams) (database.UserSecret, error) { if err := db.encryptField(¶ms.Value, ¶ms.ValueKeyID); err != nil { return database.UserSecret{}, err diff --git a/enterprise/dbcrypt/dbcrypt_internal_test.go b/enterprise/dbcrypt/dbcrypt_internal_test.go index e5a433399b..ec74e92dad 100644 --- a/enterprise/dbcrypt/dbcrypt_internal_test.go +++ b/enterprise/dbcrypt/dbcrypt_internal_test.go @@ -1570,6 +1570,98 @@ func TestMCPServerUserTokens(t *testing.T) { }) } +func TestMCPServerUserHeaderValues(t *testing.T) { + t.Parallel() + ctx := context.Background() + + const headerValues = `{"X-User-Token":"super-secret-user-token"}` + + // insertConfigAndValues creates a user, an MCP server config with a + // user-set custom header, and the user-supplied values row through the + // encrypted store. + insertConfigAndValues := func( + t *testing.T, + crypt *dbCrypt, + ciphers []Cipher, + ) (database.MCPServerConfig, database.McpServerUserHeaderValue) { + t.Helper() + user := dbgen.User(t, crypt, database.User{}) + cfg := dbgen.MCPServerConfig(t, crypt, database.MCPServerConfig{ + DisplayName: "Header Values Test MCP", + AuthType: "custom_headers", + CustomHeadersUserKeys: []string{"X-User-Token"}, + CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true}, + UpdatedBy: uuid.NullUUID{UUID: user.ID, Valid: true}, + }) + + row, err := crypt.UpsertMCPServerUserHeaderValues(ctx, database.UpsertMCPServerUserHeaderValuesParams{ + MCPServerConfigID: cfg.ID, + UserID: user.ID, + HeaderValues: headerValues, + }) + require.NoError(t, err) + require.Equal(t, headerValues, row.HeaderValues) + require.Equal(t, ciphers[0].HexDigest(), row.HeaderValuesKeyID.String) + return cfg, row + } + + t.Run("UpsertMCPServerUserHeaderValues", func(t *testing.T) { + t.Parallel() + db, crypt, ciphers := setup(t) + cfg, row := insertConfigAndValues(t, crypt, ciphers) + + // Verify the raw DB value is encrypted. + rawRow, err := db.GetMCPServerUserHeaderValues(ctx, database.GetMCPServerUserHeaderValuesParams{ + MCPServerConfigID: cfg.ID, + UserID: row.UserID, + }) + require.NoError(t, err) + requireEncryptedEquals(t, ciphers[0], rawRow.HeaderValues, headerValues) + }) + + t.Run("GetMCPServerUserHeaderValues", func(t *testing.T) { + t.Parallel() + db, crypt, ciphers := setup(t) + cfg, row := insertConfigAndValues(t, crypt, ciphers) + + got, err := crypt.GetMCPServerUserHeaderValues(ctx, database.GetMCPServerUserHeaderValuesParams{ + MCPServerConfigID: cfg.ID, + UserID: row.UserID, + }) + require.NoError(t, err) + require.Equal(t, headerValues, got.HeaderValues) + require.Equal(t, ciphers[0].HexDigest(), got.HeaderValuesKeyID.String) + + // Raw values must be encrypted. + rawRow, err := db.GetMCPServerUserHeaderValues(ctx, database.GetMCPServerUserHeaderValuesParams{ + MCPServerConfigID: cfg.ID, + UserID: row.UserID, + }) + require.NoError(t, err) + requireEncryptedEquals(t, ciphers[0], rawRow.HeaderValues, headerValues) + }) + + t.Run("GetMCPServerUserHeaderValuesByUserID", func(t *testing.T) { + t.Parallel() + db, crypt, ciphers := setup(t) + cfg, row := insertConfigAndValues(t, crypt, ciphers) + + rows, err := crypt.GetMCPServerUserHeaderValuesByUserID(ctx, row.UserID) + require.NoError(t, err) + require.Len(t, rows, 1) + require.Equal(t, headerValues, rows[0].HeaderValues) + require.Equal(t, ciphers[0].HexDigest(), rows[0].HeaderValuesKeyID.String) + + // Raw values must be encrypted. + rawRow, err := db.GetMCPServerUserHeaderValues(ctx, database.GetMCPServerUserHeaderValuesParams{ + MCPServerConfigID: cfg.ID, + UserID: row.UserID, + }) + require.NoError(t, err) + requireEncryptedEquals(t, ciphers[0], rawRow.HeaderValues, headerValues) + }) +} + func TestUserSecrets(t *testing.T) { t.Parallel() ctx := context.Background()