feat(coderd/database): queries + dbcrypt for mcp_server_user_header_values

Wires the per-user MCP custom header values store from migration 000510
through the data layer: sqlc queries, dbauthz wrappers (ActionRead/UpdatePersonal
mirroring ExternalAuthLink), dbcrypt envelope encryption around header_values,
dbgen fakes, and dbmock + dbmetrics regeneration.

Adds CustomHeadersUserKeys to InsertMCPServerConfig and UpdateMCPServerConfig
so the admin-configured set of user-set header names round-trips with the
existing custom_headers JSON.

Subsequent commits will surface this via the SDK, HTTP handlers, runtime
overlay in chatd, and the admin + user-settings UI.
This commit is contained in:
Steven Masley
2026-05-28 22:46:06 +00:00
parent dc4ff00956
commit 9c30cf886e
12 changed files with 520 additions and 13 deletions
+25
View File
@@ -2114,6 +2114,13 @@ func (q *querier) DeleteMCPServerConfigByID(ctx context.Context, id uuid.UUID) e
return q.db.DeleteMCPServerConfigByID(ctx, id)
}
func (q *querier) DeleteMCPServerUserHeaderValues(ctx context.Context, arg database.DeleteMCPServerUserHeaderValuesParams) error {
fetch := func(ctx context.Context, arg database.DeleteMCPServerUserHeaderValuesParams) (database.McpServerUserHeaderValue, error) {
return q.db.GetMCPServerUserHeaderValues(ctx, database.GetMCPServerUserHeaderValuesParams(arg))
}
return fetchAndExec(q.log, q.auth, policy.ActionUpdatePersonal, fetch, q.db.DeleteMCPServerUserHeaderValues)(ctx, arg)
}
func (q *querier) DeleteMCPServerUserToken(ctx context.Context, arg database.DeleteMCPServerUserTokenParams) error {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
return err
@@ -3712,6 +3719,14 @@ func (q *querier) GetMCPServerConfigsByIDs(ctx context.Context, ids []uuid.UUID)
return q.db.GetMCPServerConfigsByIDs(ctx, ids)
}
func (q *querier) GetMCPServerUserHeaderValues(ctx context.Context, arg database.GetMCPServerUserHeaderValuesParams) (database.McpServerUserHeaderValue, error) {
return fetchWithAction(q.log, q.auth, policy.ActionReadPersonal, q.db.GetMCPServerUserHeaderValues)(ctx, arg)
}
func (q *querier) GetMCPServerUserHeaderValuesByUserID(ctx context.Context, userID uuid.UUID) ([]database.McpServerUserHeaderValue, error) {
return fetchWithPostFilter(q.auth, policy.ActionReadPersonal, q.db.GetMCPServerUserHeaderValuesByUserID)(ctx, userID)
}
func (q *querier) GetMCPServerUserToken(ctx context.Context, arg database.GetMCPServerUserTokenParams) (database.MCPServerUserToken, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil {
return database.MCPServerUserToken{}, err
@@ -8261,6 +8276,16 @@ func (q *querier) UpsertLogoURL(ctx context.Context, value string) error {
return q.db.UpsertLogoURL(ctx, value)
}
func (q *querier) UpsertMCPServerUserHeaderValues(ctx context.Context, arg database.UpsertMCPServerUserHeaderValuesParams) (database.McpServerUserHeaderValue, error) {
fetch := func(ctx context.Context, arg database.UpsertMCPServerUserHeaderValuesParams) (database.McpServerUserHeaderValue, error) {
return q.db.GetMCPServerUserHeaderValues(ctx, database.GetMCPServerUserHeaderValuesParams{
MCPServerConfigID: arg.MCPServerConfigID,
UserID: arg.UserID,
})
}
return fetchAndQuery(q.log, q.auth, policy.ActionUpdatePersonal, fetch, q.db.UpsertMCPServerUserHeaderValues)(ctx, arg)
}
func (q *querier) UpsertMCPServerUserToken(ctx context.Context, arg database.UpsertMCPServerUserTokenParams) (database.MCPServerUserToken, error) {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil {
return database.MCPServerUserToken{}, err
+39
View File
@@ -1662,6 +1662,45 @@ func (s *MethodTestSuite) TestChats() {
dbm.EXPECT().GetMCPServerUserTokensByUserID(gomock.Any(), userID).Return(tokens, nil).AnyTimes()
check.Args(userID).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(tokens)
}))
s.Run("GetMCPServerUserHeaderValues", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
arg := database.GetMCPServerUserHeaderValuesParams{
MCPServerConfigID: uuid.New(),
UserID: uuid.New(),
}
value := testutil.Fake(s.T(), faker, database.McpServerUserHeaderValue{MCPServerConfigID: arg.MCPServerConfigID, UserID: arg.UserID})
dbm.EXPECT().GetMCPServerUserHeaderValues(gomock.Any(), arg).Return(value, nil).AnyTimes()
check.Args(arg).Asserts(value, policy.ActionReadPersonal).Returns(value)
}))
s.Run("GetMCPServerUserHeaderValuesByUserID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
userID := uuid.New()
values := []database.McpServerUserHeaderValue{testutil.Fake(s.T(), faker, database.McpServerUserHeaderValue{UserID: userID})}
dbm.EXPECT().GetMCPServerUserHeaderValuesByUserID(gomock.Any(), userID).Return(values, nil).AnyTimes()
check.Args(userID).Asserts(values[0], policy.ActionReadPersonal).Returns(values)
}))
s.Run("UpsertMCPServerUserHeaderValues", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
arg := database.UpsertMCPServerUserHeaderValuesParams{
MCPServerConfigID: uuid.New(),
UserID: uuid.New(),
HeaderValues: `{"X-User-Token":"secret"}`,
}
value := testutil.Fake(s.T(), faker, database.McpServerUserHeaderValue{MCPServerConfigID: arg.MCPServerConfigID, UserID: arg.UserID})
dbm.EXPECT().GetMCPServerUserHeaderValues(gomock.Any(), database.GetMCPServerUserHeaderValuesParams{
MCPServerConfigID: arg.MCPServerConfigID,
UserID: arg.UserID,
}).Return(value, nil).AnyTimes()
dbm.EXPECT().UpsertMCPServerUserHeaderValues(gomock.Any(), arg).Return(value, nil).AnyTimes()
check.Args(arg).Asserts(value, policy.ActionUpdatePersonal).Returns(value)
}))
s.Run("DeleteMCPServerUserHeaderValues", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
arg := database.DeleteMCPServerUserHeaderValuesParams{
MCPServerConfigID: uuid.New(),
UserID: uuid.New(),
}
value := testutil.Fake(s.T(), faker, database.McpServerUserHeaderValue{MCPServerConfigID: arg.MCPServerConfigID, UserID: arg.UserID})
dbm.EXPECT().GetMCPServerUserHeaderValues(gomock.Any(), database.GetMCPServerUserHeaderValuesParams(arg)).Return(value, nil).AnyTimes()
dbm.EXPECT().DeleteMCPServerUserHeaderValues(gomock.Any(), arg).Return(nil).AnyTimes()
check.Args(arg).Asserts(value, policy.ActionUpdatePersonal).Returns()
}))
s.Run("InsertMCPServerConfig", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) {
arg := database.InsertMCPServerConfigParams{
DisplayName: "Test MCP Server",
+14
View File
@@ -349,6 +349,7 @@ func MCPServerConfig(t testing.TB, db database.Store, seed database.MCPServerCon
APIKeyValueKeyID: seed.APIKeyValueKeyID,
CustomHeaders: seed.CustomHeaders,
CustomHeadersKeyID: seed.CustomHeadersKeyID,
CustomHeadersUserKeys: takeFirstSlice(seed.CustomHeadersUserKeys, []string{}),
ToolAllowList: takeFirstSlice(seed.ToolAllowList, []string{}),
ToolDenyList: takeFirstSlice(seed.ToolDenyList, []string{}),
Availability: takeFirst(seed.Availability, "default_off"),
@@ -363,6 +364,19 @@ func MCPServerConfig(t testing.TB, db database.Store, seed database.MCPServerCon
return cfg
}
func MCPServerUserHeaderValues(t testing.TB, db database.Store, seed database.McpServerUserHeaderValue) database.McpServerUserHeaderValue {
t.Helper()
row, err := db.UpsertMCPServerUserHeaderValues(genCtx, database.UpsertMCPServerUserHeaderValuesParams{
MCPServerConfigID: takeFirst(seed.MCPServerConfigID, uuid.New()),
UserID: takeFirst(seed.UserID, uuid.New()),
HeaderValues: takeFirst(seed.HeaderValues, "{}"),
HeaderValuesKeyID: seed.HeaderValuesKeyID,
})
require.NoError(t, err, "upsert MCP server user header values")
return row
}
func ConnectionLog(t testing.TB, db database.Store, seed database.UpsertConnectionLogParams) database.ConnectionLog {
arg := database.UpsertConnectionLogParams{
ID: takeFirst(seed.ID, uuid.New()),
+32
View File
@@ -577,6 +577,14 @@ func (m queryMetricsStore) DeleteMCPServerConfigByID(ctx context.Context, id uui
return r0
}
func (m queryMetricsStore) DeleteMCPServerUserHeaderValues(ctx context.Context, arg database.DeleteMCPServerUserHeaderValuesParams) error {
start := time.Now()
r0 := m.s.DeleteMCPServerUserHeaderValues(ctx, arg)
m.queryLatencies.WithLabelValues("DeleteMCPServerUserHeaderValues").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteMCPServerUserHeaderValues").Inc()
return r0
}
func (m queryMetricsStore) DeleteMCPServerUserToken(ctx context.Context, arg database.DeleteMCPServerUserTokenParams) error {
start := time.Now()
r0 := m.s.DeleteMCPServerUserToken(ctx, arg)
@@ -2121,6 +2129,22 @@ func (m queryMetricsStore) GetMCPServerConfigsByIDs(ctx context.Context, ids []u
return r0, r1
}
func (m queryMetricsStore) GetMCPServerUserHeaderValues(ctx context.Context, arg database.GetMCPServerUserHeaderValuesParams) (database.McpServerUserHeaderValue, error) {
start := time.Now()
r0, r1 := m.s.GetMCPServerUserHeaderValues(ctx, arg)
m.queryLatencies.WithLabelValues("GetMCPServerUserHeaderValues").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetMCPServerUserHeaderValues").Inc()
return r0, r1
}
func (m queryMetricsStore) GetMCPServerUserHeaderValuesByUserID(ctx context.Context, userID uuid.UUID) ([]database.McpServerUserHeaderValue, error) {
start := time.Now()
r0, r1 := m.s.GetMCPServerUserHeaderValuesByUserID(ctx, userID)
m.queryLatencies.WithLabelValues("GetMCPServerUserHeaderValuesByUserID").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetMCPServerUserHeaderValuesByUserID").Inc()
return r0, r1
}
func (m queryMetricsStore) GetMCPServerUserToken(ctx context.Context, arg database.GetMCPServerUserTokenParams) (database.MCPServerUserToken, error) {
start := time.Now()
r0, r1 := m.s.GetMCPServerUserToken(ctx, arg)
@@ -5953,6 +5977,14 @@ func (m queryMetricsStore) UpsertLogoURL(ctx context.Context, value string) erro
return r0
}
func (m queryMetricsStore) UpsertMCPServerUserHeaderValues(ctx context.Context, arg database.UpsertMCPServerUserHeaderValuesParams) (database.McpServerUserHeaderValue, error) {
start := time.Now()
r0, r1 := m.s.UpsertMCPServerUserHeaderValues(ctx, arg)
m.queryLatencies.WithLabelValues("UpsertMCPServerUserHeaderValues").Observe(time.Since(start).Seconds())
m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpsertMCPServerUserHeaderValues").Inc()
return r0, r1
}
func (m queryMetricsStore) UpsertMCPServerUserToken(ctx context.Context, arg database.UpsertMCPServerUserTokenParams) (database.MCPServerUserToken, error) {
start := time.Now()
r0, r1 := m.s.UpsertMCPServerUserToken(ctx, arg)
+59
View File
@@ -960,6 +960,20 @@ func (mr *MockStoreMockRecorder) DeleteMCPServerConfigByID(ctx, id any) *gomock.
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteMCPServerConfigByID", reflect.TypeOf((*MockStore)(nil).DeleteMCPServerConfigByID), ctx, id)
}
// DeleteMCPServerUserHeaderValues mocks base method.
func (m *MockStore) DeleteMCPServerUserHeaderValues(ctx context.Context, arg database.DeleteMCPServerUserHeaderValuesParams) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteMCPServerUserHeaderValues", ctx, arg)
ret0, _ := ret[0].(error)
return ret0
}
// DeleteMCPServerUserHeaderValues indicates an expected call of DeleteMCPServerUserHeaderValues.
func (mr *MockStoreMockRecorder) DeleteMCPServerUserHeaderValues(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteMCPServerUserHeaderValues", reflect.TypeOf((*MockStore)(nil).DeleteMCPServerUserHeaderValues), ctx, arg)
}
// DeleteMCPServerUserToken mocks base method.
func (m *MockStore) DeleteMCPServerUserToken(ctx context.Context, arg database.DeleteMCPServerUserTokenParams) error {
m.ctrl.T.Helper()
@@ -3945,6 +3959,36 @@ func (mr *MockStoreMockRecorder) GetMCPServerConfigsByIDs(ctx, ids any) *gomock.
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMCPServerConfigsByIDs", reflect.TypeOf((*MockStore)(nil).GetMCPServerConfigsByIDs), ctx, ids)
}
// GetMCPServerUserHeaderValues mocks base method.
func (m *MockStore) GetMCPServerUserHeaderValues(ctx context.Context, arg database.GetMCPServerUserHeaderValuesParams) (database.McpServerUserHeaderValue, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetMCPServerUserHeaderValues", ctx, arg)
ret0, _ := ret[0].(database.McpServerUserHeaderValue)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetMCPServerUserHeaderValues indicates an expected call of GetMCPServerUserHeaderValues.
func (mr *MockStoreMockRecorder) GetMCPServerUserHeaderValues(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMCPServerUserHeaderValues", reflect.TypeOf((*MockStore)(nil).GetMCPServerUserHeaderValues), ctx, arg)
}
// GetMCPServerUserHeaderValuesByUserID mocks base method.
func (m *MockStore) GetMCPServerUserHeaderValuesByUserID(ctx context.Context, userID uuid.UUID) ([]database.McpServerUserHeaderValue, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetMCPServerUserHeaderValuesByUserID", ctx, userID)
ret0, _ := ret[0].([]database.McpServerUserHeaderValue)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetMCPServerUserHeaderValuesByUserID indicates an expected call of GetMCPServerUserHeaderValuesByUserID.
func (mr *MockStoreMockRecorder) GetMCPServerUserHeaderValuesByUserID(ctx, userID any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMCPServerUserHeaderValuesByUserID", reflect.TypeOf((*MockStore)(nil).GetMCPServerUserHeaderValuesByUserID), ctx, userID)
}
// GetMCPServerUserToken mocks base method.
func (m *MockStore) GetMCPServerUserToken(ctx context.Context, arg database.GetMCPServerUserTokenParams) (database.MCPServerUserToken, error) {
m.ctrl.T.Helper()
@@ -11172,6 +11216,21 @@ func (mr *MockStoreMockRecorder) UpsertLogoURL(ctx, value any) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertLogoURL", reflect.TypeOf((*MockStore)(nil).UpsertLogoURL), ctx, value)
}
// UpsertMCPServerUserHeaderValues mocks base method.
func (m *MockStore) UpsertMCPServerUserHeaderValues(ctx context.Context, arg database.UpsertMCPServerUserHeaderValuesParams) (database.McpServerUserHeaderValue, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpsertMCPServerUserHeaderValues", ctx, arg)
ret0, _ := ret[0].(database.McpServerUserHeaderValue)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// UpsertMCPServerUserHeaderValues indicates an expected call of UpsertMCPServerUserHeaderValues.
func (mr *MockStoreMockRecorder) UpsertMCPServerUserHeaderValues(ctx, arg any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertMCPServerUserHeaderValues", reflect.TypeOf((*MockStore)(nil).UpsertMCPServerUserHeaderValues), ctx, arg)
}
// UpsertMCPServerUserToken mocks base method.
func (m *MockStore) UpsertMCPServerUserToken(ctx context.Context, arg database.UpsertMCPServerUserTokenParams) (database.MCPServerUserToken, error) {
m.ctrl.T.Helper()
+3
View File
@@ -621,6 +621,9 @@ func (u GetUsersRow) RBACObject() rbac.Object {
func (u GitSSHKey) RBACObject() rbac.Object { return rbac.ResourceUserObject(u.UserID) }
func (u ExternalAuthLink) RBACObject() rbac.Object { return rbac.ResourceUserObject(u.UserID) }
func (u UserLink) RBACObject() rbac.Object { return rbac.ResourceUserObject(u.UserID) }
func (u McpServerUserHeaderValue) RBACObject() rbac.Object {
return rbac.ResourceUserObject(u.UserID)
}
func (u ExternalAuthLink) OAuthToken() *oauth2.Token {
return &oauth2.Token{
+4
View File
@@ -138,6 +138,7 @@ type sqlcQuerier interface {
DeleteGroupMemberFromGroup(ctx context.Context, arg DeleteGroupMemberFromGroupParams) error
DeleteLicense(ctx context.Context, id int32) (int32, error)
DeleteMCPServerConfigByID(ctx context.Context, id uuid.UUID) error
DeleteMCPServerUserHeaderValues(ctx context.Context, arg DeleteMCPServerUserHeaderValuesParams) error
DeleteMCPServerUserToken(ctx context.Context, arg DeleteMCPServerUserTokenParams) error
DeleteOAuth2ProviderAppByClientID(ctx context.Context, id uuid.UUID) error
DeleteOAuth2ProviderAppByID(ctx context.Context, id uuid.UUID) error
@@ -515,6 +516,8 @@ type sqlcQuerier interface {
GetMCPServerConfigBySlug(ctx context.Context, slug string) (MCPServerConfig, error)
GetMCPServerConfigs(ctx context.Context) ([]MCPServerConfig, error)
GetMCPServerConfigsByIDs(ctx context.Context, ids []uuid.UUID) ([]MCPServerConfig, error)
GetMCPServerUserHeaderValues(ctx context.Context, arg GetMCPServerUserHeaderValuesParams) (McpServerUserHeaderValue, error)
GetMCPServerUserHeaderValuesByUserID(ctx context.Context, userID uuid.UUID) ([]McpServerUserHeaderValue, error)
GetMCPServerUserToken(ctx context.Context, arg GetMCPServerUserTokenParams) (MCPServerUserToken, error)
GetMCPServerUserTokensByUserID(ctx context.Context, userID uuid.UUID) ([]MCPServerUserToken, error)
GetNotificationMessagesByStatus(ctx context.Context, arg GetNotificationMessagesByStatusParams) ([]NotificationMessage, error)
@@ -1390,6 +1393,7 @@ type sqlcQuerier interface {
UpsertHealthSettings(ctx context.Context, value string) error
UpsertLastUpdateCheck(ctx context.Context, value string) error
UpsertLogoURL(ctx context.Context, value string) error
UpsertMCPServerUserHeaderValues(ctx context.Context, arg UpsertMCPServerUserHeaderValuesParams) (McpServerUserHeaderValue, error)
UpsertMCPServerUserToken(ctx context.Context, arg UpsertMCPServerUserTokenParams) (MCPServerUserToken, error)
// Insert or update notification report generator logs with recent activity.
UpsertNotificationReportGeneratorLog(ctx context.Context, arg UpsertNotificationReportGeneratorLogParams) error
+155 -13
View File
@@ -15039,6 +15039,24 @@ func (q *sqlQuerier) DeleteMCPServerConfigByID(ctx context.Context, id uuid.UUID
return err
}
const deleteMCPServerUserHeaderValues = `-- name: DeleteMCPServerUserHeaderValues :exec
DELETE FROM
mcp_server_user_header_values
WHERE
mcp_server_config_id = $1::uuid
AND user_id = $2::uuid
`
type DeleteMCPServerUserHeaderValuesParams struct {
MCPServerConfigID uuid.UUID `db:"mcp_server_config_id" json:"mcp_server_config_id"`
UserID uuid.UUID `db:"user_id" json:"user_id"`
}
func (q *sqlQuerier) DeleteMCPServerUserHeaderValues(ctx context.Context, arg DeleteMCPServerUserHeaderValuesParams) error {
_, err := q.db.ExecContext(ctx, deleteMCPServerUserHeaderValues, arg.MCPServerConfigID, arg.UserID)
return err
}
const deleteMCPServerUserToken = `-- name: DeleteMCPServerUserToken :exec
DELETE FROM
mcp_server_user_tokens
@@ -15416,6 +15434,76 @@ func (q *sqlQuerier) GetMCPServerConfigsByIDs(ctx context.Context, ids []uuid.UU
return items, nil
}
const getMCPServerUserHeaderValues = `-- name: GetMCPServerUserHeaderValues :one
SELECT
id, mcp_server_config_id, user_id, header_values, header_values_key_id, created_at, updated_at
FROM
mcp_server_user_header_values
WHERE
mcp_server_config_id = $1::uuid
AND user_id = $2::uuid
`
type GetMCPServerUserHeaderValuesParams struct {
MCPServerConfigID uuid.UUID `db:"mcp_server_config_id" json:"mcp_server_config_id"`
UserID uuid.UUID `db:"user_id" json:"user_id"`
}
func (q *sqlQuerier) GetMCPServerUserHeaderValues(ctx context.Context, arg GetMCPServerUserHeaderValuesParams) (McpServerUserHeaderValue, error) {
row := q.db.QueryRowContext(ctx, getMCPServerUserHeaderValues, arg.MCPServerConfigID, arg.UserID)
var i McpServerUserHeaderValue
err := row.Scan(
&i.ID,
&i.MCPServerConfigID,
&i.UserID,
&i.HeaderValues,
&i.HeaderValuesKeyID,
&i.CreatedAt,
&i.UpdatedAt,
)
return i, err
}
const getMCPServerUserHeaderValuesByUserID = `-- name: GetMCPServerUserHeaderValuesByUserID :many
SELECT
id, mcp_server_config_id, user_id, header_values, header_values_key_id, created_at, updated_at
FROM
mcp_server_user_header_values
WHERE
user_id = $1::uuid
`
func (q *sqlQuerier) GetMCPServerUserHeaderValuesByUserID(ctx context.Context, userID uuid.UUID) ([]McpServerUserHeaderValue, error) {
rows, err := q.db.QueryContext(ctx, getMCPServerUserHeaderValuesByUserID, userID)
if err != nil {
return nil, err
}
defer rows.Close()
var items []McpServerUserHeaderValue
for rows.Next() {
var i McpServerUserHeaderValue
if err := rows.Scan(
&i.ID,
&i.MCPServerConfigID,
&i.UserID,
&i.HeaderValues,
&i.HeaderValuesKeyID,
&i.CreatedAt,
&i.UpdatedAt,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const getMCPServerUserToken = `-- name: GetMCPServerUserToken :one
SELECT
id, mcp_server_config_id, user_id, access_token, access_token_key_id, refresh_token, refresh_token_key_id, token_type, expiry, created_at, updated_at
@@ -15514,6 +15602,7 @@ INSERT INTO mcp_server_configs (
api_key_value_key_id,
custom_headers,
custom_headers_key_id,
custom_headers_user_keys,
tool_allow_list,
tool_deny_list,
availability,
@@ -15544,13 +15633,14 @@ INSERT INTO mcp_server_configs (
$18::text,
$19::text[],
$20::text[],
$21::text,
$22::boolean,
$21::text[],
$22::text,
$23::boolean,
$24::boolean,
$25::boolean,
$26::uuid,
$27::uuid
$26::boolean,
$27::uuid,
$28::uuid
)
RETURNING
id, display_name, slug, description, icon_url, transport, url, auth_type, oauth2_client_id, oauth2_client_secret, oauth2_client_secret_key_id, oauth2_auth_url, oauth2_token_url, oauth2_scopes, api_key_header, api_key_value, api_key_value_key_id, custom_headers, custom_headers_key_id, tool_allow_list, tool_deny_list, availability, enabled, created_by, updated_by, created_at, updated_at, model_intent, allow_in_plan_mode, forward_coder_headers, custom_headers_user_keys
@@ -15575,6 +15665,7 @@ type InsertMCPServerConfigParams struct {
APIKeyValueKeyID sql.NullString `db:"api_key_value_key_id" json:"api_key_value_key_id"`
CustomHeaders string `db:"custom_headers" json:"custom_headers"`
CustomHeadersKeyID sql.NullString `db:"custom_headers_key_id" json:"custom_headers_key_id"`
CustomHeadersUserKeys []string `db:"custom_headers_user_keys" json:"custom_headers_user_keys"`
ToolAllowList []string `db:"tool_allow_list" json:"tool_allow_list"`
ToolDenyList []string `db:"tool_deny_list" json:"tool_deny_list"`
Availability string `db:"availability" json:"availability"`
@@ -15606,6 +15697,7 @@ func (q *sqlQuerier) InsertMCPServerConfig(ctx context.Context, arg InsertMCPSer
arg.APIKeyValueKeyID,
arg.CustomHeaders,
arg.CustomHeadersKeyID,
pq.Array(arg.CustomHeadersUserKeys),
pq.Array(arg.ToolAllowList),
pq.Array(arg.ToolDenyList),
arg.Availability,
@@ -15675,17 +15767,18 @@ SET
api_key_value_key_id = $16::text,
custom_headers = $17::text,
custom_headers_key_id = $18::text,
tool_allow_list = $19::text[],
tool_deny_list = $20::text[],
availability = $21::text,
enabled = $22::boolean,
model_intent = $23::boolean,
allow_in_plan_mode = $24::boolean,
forward_coder_headers = $25::boolean,
updated_by = $26::uuid,
custom_headers_user_keys = $19::text[],
tool_allow_list = $20::text[],
tool_deny_list = $21::text[],
availability = $22::text,
enabled = $23::boolean,
model_intent = $24::boolean,
allow_in_plan_mode = $25::boolean,
forward_coder_headers = $26::boolean,
updated_by = $27::uuid,
updated_at = NOW()
WHERE
id = $27::uuid
id = $28::uuid
RETURNING
id, display_name, slug, description, icon_url, transport, url, auth_type, oauth2_client_id, oauth2_client_secret, oauth2_client_secret_key_id, oauth2_auth_url, oauth2_token_url, oauth2_scopes, api_key_header, api_key_value, api_key_value_key_id, custom_headers, custom_headers_key_id, tool_allow_list, tool_deny_list, availability, enabled, created_by, updated_by, created_at, updated_at, model_intent, allow_in_plan_mode, forward_coder_headers, custom_headers_user_keys
`
@@ -15709,6 +15802,7 @@ type UpdateMCPServerConfigParams struct {
APIKeyValueKeyID sql.NullString `db:"api_key_value_key_id" json:"api_key_value_key_id"`
CustomHeaders string `db:"custom_headers" json:"custom_headers"`
CustomHeadersKeyID sql.NullString `db:"custom_headers_key_id" json:"custom_headers_key_id"`
CustomHeadersUserKeys []string `db:"custom_headers_user_keys" json:"custom_headers_user_keys"`
ToolAllowList []string `db:"tool_allow_list" json:"tool_allow_list"`
ToolDenyList []string `db:"tool_deny_list" json:"tool_deny_list"`
Availability string `db:"availability" json:"availability"`
@@ -15740,6 +15834,7 @@ func (q *sqlQuerier) UpdateMCPServerConfig(ctx context.Context, arg UpdateMCPSer
arg.APIKeyValueKeyID,
arg.CustomHeaders,
arg.CustomHeadersKeyID,
pq.Array(arg.CustomHeadersUserKeys),
pq.Array(arg.ToolAllowList),
pq.Array(arg.ToolDenyList),
arg.Availability,
@@ -15787,6 +15882,53 @@ func (q *sqlQuerier) UpdateMCPServerConfig(ctx context.Context, arg UpdateMCPSer
return i, err
}
const upsertMCPServerUserHeaderValues = `-- name: UpsertMCPServerUserHeaderValues :one
INSERT INTO mcp_server_user_header_values (
mcp_server_config_id,
user_id,
header_values,
header_values_key_id
) VALUES (
$1::uuid,
$2::uuid,
$3::text,
$4::text
)
ON CONFLICT (mcp_server_config_id, user_id) DO UPDATE SET
header_values = $3::text,
header_values_key_id = $4::text,
updated_at = NOW()
RETURNING
id, mcp_server_config_id, user_id, header_values, header_values_key_id, created_at, updated_at
`
type UpsertMCPServerUserHeaderValuesParams struct {
MCPServerConfigID uuid.UUID `db:"mcp_server_config_id" json:"mcp_server_config_id"`
UserID uuid.UUID `db:"user_id" json:"user_id"`
HeaderValues string `db:"header_values" json:"header_values"`
HeaderValuesKeyID sql.NullString `db:"header_values_key_id" json:"header_values_key_id"`
}
func (q *sqlQuerier) UpsertMCPServerUserHeaderValues(ctx context.Context, arg UpsertMCPServerUserHeaderValuesParams) (McpServerUserHeaderValue, error) {
row := q.db.QueryRowContext(ctx, upsertMCPServerUserHeaderValues,
arg.MCPServerConfigID,
arg.UserID,
arg.HeaderValues,
arg.HeaderValuesKeyID,
)
var i McpServerUserHeaderValue
err := row.Scan(
&i.ID,
&i.MCPServerConfigID,
&i.UserID,
&i.HeaderValues,
&i.HeaderValuesKeyID,
&i.CreatedAt,
&i.UpdatedAt,
)
return i, err
}
const upsertMCPServerUserToken = `-- name: UpsertMCPServerUserToken :one
INSERT INTO mcp_server_user_tokens (
mcp_server_config_id,
@@ -73,6 +73,7 @@ INSERT INTO mcp_server_configs (
api_key_value_key_id,
custom_headers,
custom_headers_key_id,
custom_headers_user_keys,
tool_allow_list,
tool_deny_list,
availability,
@@ -101,6 +102,7 @@ INSERT INTO mcp_server_configs (
sqlc.narg('api_key_value_key_id')::text,
@custom_headers::text,
sqlc.narg('custom_headers_key_id')::text,
@custom_headers_user_keys::text[],
@tool_allow_list::text[],
@tool_deny_list::text[],
@availability::text,
@@ -136,6 +138,7 @@ SET
api_key_value_key_id = sqlc.narg('api_key_value_key_id')::text,
custom_headers = @custom_headers::text,
custom_headers_key_id = sqlc.narg('custom_headers_key_id')::text,
custom_headers_user_keys = @custom_headers_user_keys::text[],
tool_allow_list = @tool_allow_list::text[],
tool_deny_list = @tool_deny_list::text[],
availability = @availability::text,
@@ -211,6 +214,49 @@ WHERE
mcp_server_config_id = @mcp_server_config_id::uuid
AND user_id = @user_id::uuid;
-- name: GetMCPServerUserHeaderValues :one
SELECT
*
FROM
mcp_server_user_header_values
WHERE
mcp_server_config_id = @mcp_server_config_id::uuid
AND user_id = @user_id::uuid;
-- name: GetMCPServerUserHeaderValuesByUserID :many
SELECT
*
FROM
mcp_server_user_header_values
WHERE
user_id = @user_id::uuid;
-- name: UpsertMCPServerUserHeaderValues :one
INSERT INTO mcp_server_user_header_values (
mcp_server_config_id,
user_id,
header_values,
header_values_key_id
) VALUES (
@mcp_server_config_id::uuid,
@user_id::uuid,
@header_values::text,
sqlc.narg('header_values_key_id')::text
)
ON CONFLICT (mcp_server_config_id, user_id) DO UPDATE SET
header_values = @header_values::text,
header_values_key_id = sqlc.narg('header_values_key_id')::text,
updated_at = NOW()
RETURNING
*;
-- name: DeleteMCPServerUserHeaderValues :exec
DELETE FROM
mcp_server_user_header_values
WHERE
mcp_server_config_id = @mcp_server_config_id::uuid
AND user_id = @user_id::uuid;
-- name: CleanupDeletedMCPServerIDsFromChats :exec
UPDATE chats
SET mcp_server_ids = (
+4
View File
@@ -277,6 +277,7 @@ func (api *API) createMCPServerConfig(rw http.ResponseWriter, r *http.Request) {
APIKeyValueKeyID: sql.NullString{},
CustomHeaders: customHeadersJSON,
CustomHeadersKeyID: sql.NullString{},
CustomHeadersUserKeys: nil,
ToolAllowList: coalesceStringSlice(trimStringSlice(req.ToolAllowList)),
ToolDenyList: coalesceStringSlice(trimStringSlice(req.ToolDenyList)),
Availability: strings.TrimSpace(req.Availability),
@@ -366,6 +367,7 @@ func (api *API) createMCPServerConfig(rw http.ResponseWriter, r *http.Request) {
APIKeyValueKeyID: inserted.APIKeyValueKeyID,
CustomHeaders: inserted.CustomHeaders,
CustomHeadersKeyID: inserted.CustomHeadersKeyID,
CustomHeadersUserKeys: inserted.CustomHeadersUserKeys,
ToolAllowList: inserted.ToolAllowList,
ToolDenyList: inserted.ToolDenyList,
Availability: inserted.Availability,
@@ -436,6 +438,7 @@ func (api *API) createMCPServerConfig(rw http.ResponseWriter, r *http.Request) {
APIKeyValueKeyID: sql.NullString{},
CustomHeaders: customHeadersJSON,
CustomHeadersKeyID: sql.NullString{},
CustomHeadersUserKeys: nil,
ToolAllowList: coalesceStringSlice(trimStringSlice(req.ToolAllowList)),
ToolDenyList: coalesceStringSlice(trimStringSlice(req.ToolDenyList)),
Availability: strings.TrimSpace(req.Availability),
@@ -785,6 +788,7 @@ func (api *API) updateMCPServerConfig(rw http.ResponseWriter, r *http.Request) {
APIKeyValueKeyID: apiKeyValueKeyID,
CustomHeaders: customHeaders,
CustomHeadersKeyID: customHeadersKeyID,
CustomHeadersUserKeys: existing.CustomHeadersUserKeys,
ToolAllowList: toolAllowList,
ToolDenyList: toolDenyList,
Availability: availability,
+47
View File
@@ -702,6 +702,12 @@ func (db *dbCrypt) decryptMCPServerUserToken(tok *database.MCPServerUserToken) e
return db.decryptField(&tok.RefreshToken, tok.RefreshTokenKeyID)
}
// decryptMCPServerUserHeaderValues decrypts all encrypted fields on a
// single McpServerUserHeaderValue in place.
func (db *dbCrypt) decryptMCPServerUserHeaderValues(row *database.McpServerUserHeaderValue) error {
return db.decryptField(&row.HeaderValues, row.HeaderValuesKeyID)
}
func (db *dbCrypt) GetMCPServerConfigByID(ctx context.Context, id uuid.UUID) (database.MCPServerConfig, error) {
cfg, err := db.Store.GetMCPServerConfigByID(ctx, id)
if err != nil {
@@ -876,6 +882,47 @@ func (db *dbCrypt) UpsertMCPServerUserToken(ctx context.Context, params database
return tok, nil
}
func (db *dbCrypt) GetMCPServerUserHeaderValues(ctx context.Context, arg database.GetMCPServerUserHeaderValuesParams) (database.McpServerUserHeaderValue, error) {
row, err := db.Store.GetMCPServerUserHeaderValues(ctx, arg)
if err != nil {
return database.McpServerUserHeaderValue{}, err
}
if err := db.decryptMCPServerUserHeaderValues(&row); err != nil {
return database.McpServerUserHeaderValue{}, err
}
return row, nil
}
func (db *dbCrypt) GetMCPServerUserHeaderValuesByUserID(ctx context.Context, userID uuid.UUID) ([]database.McpServerUserHeaderValue, error) {
rows, err := db.Store.GetMCPServerUserHeaderValuesByUserID(ctx, userID)
if err != nil {
return nil, err
}
for i := range rows {
if err := db.decryptMCPServerUserHeaderValues(&rows[i]); err != nil {
return nil, err
}
}
return rows, nil
}
func (db *dbCrypt) UpsertMCPServerUserHeaderValues(ctx context.Context, params database.UpsertMCPServerUserHeaderValuesParams) (database.McpServerUserHeaderValue, error) {
if strings.TrimSpace(params.HeaderValues) == "" {
params.HeaderValuesKeyID = sql.NullString{}
} else if err := db.encryptField(&params.HeaderValues, &params.HeaderValuesKeyID); err != nil {
return database.McpServerUserHeaderValue{}, err
}
row, err := db.Store.UpsertMCPServerUserHeaderValues(ctx, params)
if err != nil {
return database.McpServerUserHeaderValue{}, err
}
if err := db.decryptMCPServerUserHeaderValues(&row); err != nil {
return database.McpServerUserHeaderValue{}, err
}
return row, nil
}
func (db *dbCrypt) CreateUserSecret(ctx context.Context, params database.CreateUserSecretParams) (database.UserSecret, error) {
if err := db.encryptField(&params.Value, &params.ValueKeyID); err != nil {
return database.UserSecret{}, err
@@ -1570,6 +1570,98 @@ func TestMCPServerUserTokens(t *testing.T) {
})
}
func TestMCPServerUserHeaderValues(t *testing.T) {
t.Parallel()
ctx := context.Background()
const headerValues = `{"X-User-Token":"super-secret-user-token"}`
// insertConfigAndValues creates a user, an MCP server config with a
// user-set custom header, and the user-supplied values row through the
// encrypted store.
insertConfigAndValues := func(
t *testing.T,
crypt *dbCrypt,
ciphers []Cipher,
) (database.MCPServerConfig, database.McpServerUserHeaderValue) {
t.Helper()
user := dbgen.User(t, crypt, database.User{})
cfg := dbgen.MCPServerConfig(t, crypt, database.MCPServerConfig{
DisplayName: "Header Values Test MCP",
AuthType: "custom_headers",
CustomHeadersUserKeys: []string{"X-User-Token"},
CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
UpdatedBy: uuid.NullUUID{UUID: user.ID, Valid: true},
})
row, err := crypt.UpsertMCPServerUserHeaderValues(ctx, database.UpsertMCPServerUserHeaderValuesParams{
MCPServerConfigID: cfg.ID,
UserID: user.ID,
HeaderValues: headerValues,
})
require.NoError(t, err)
require.Equal(t, headerValues, row.HeaderValues)
require.Equal(t, ciphers[0].HexDigest(), row.HeaderValuesKeyID.String)
return cfg, row
}
t.Run("UpsertMCPServerUserHeaderValues", func(t *testing.T) {
t.Parallel()
db, crypt, ciphers := setup(t)
cfg, row := insertConfigAndValues(t, crypt, ciphers)
// Verify the raw DB value is encrypted.
rawRow, err := db.GetMCPServerUserHeaderValues(ctx, database.GetMCPServerUserHeaderValuesParams{
MCPServerConfigID: cfg.ID,
UserID: row.UserID,
})
require.NoError(t, err)
requireEncryptedEquals(t, ciphers[0], rawRow.HeaderValues, headerValues)
})
t.Run("GetMCPServerUserHeaderValues", func(t *testing.T) {
t.Parallel()
db, crypt, ciphers := setup(t)
cfg, row := insertConfigAndValues(t, crypt, ciphers)
got, err := crypt.GetMCPServerUserHeaderValues(ctx, database.GetMCPServerUserHeaderValuesParams{
MCPServerConfigID: cfg.ID,
UserID: row.UserID,
})
require.NoError(t, err)
require.Equal(t, headerValues, got.HeaderValues)
require.Equal(t, ciphers[0].HexDigest(), got.HeaderValuesKeyID.String)
// Raw values must be encrypted.
rawRow, err := db.GetMCPServerUserHeaderValues(ctx, database.GetMCPServerUserHeaderValuesParams{
MCPServerConfigID: cfg.ID,
UserID: row.UserID,
})
require.NoError(t, err)
requireEncryptedEquals(t, ciphers[0], rawRow.HeaderValues, headerValues)
})
t.Run("GetMCPServerUserHeaderValuesByUserID", func(t *testing.T) {
t.Parallel()
db, crypt, ciphers := setup(t)
cfg, row := insertConfigAndValues(t, crypt, ciphers)
rows, err := crypt.GetMCPServerUserHeaderValuesByUserID(ctx, row.UserID)
require.NoError(t, err)
require.Len(t, rows, 1)
require.Equal(t, headerValues, rows[0].HeaderValues)
require.Equal(t, ciphers[0].HexDigest(), rows[0].HeaderValuesKeyID.String)
// Raw values must be encrypted.
rawRow, err := db.GetMCPServerUserHeaderValues(ctx, database.GetMCPServerUserHeaderValuesParams{
MCPServerConfigID: cfg.ID,
UserID: row.UserID,
})
require.NoError(t, err)
requireEncryptedEquals(t, ciphers[0], rawRow.HeaderValues, headerValues)
})
}
func TestUserSecrets(t *testing.T) {
t.Parallel()
ctx := context.Background()