diff --git a/coderd/database/dbauthz/dbauthz.go b/coderd/database/dbauthz/dbauthz.go index a1a7497153..909d1f6400 100644 --- a/coderd/database/dbauthz/dbauthz.go +++ b/coderd/database/dbauthz/dbauthz.go @@ -2114,6 +2114,24 @@ func (q *querier) DeleteMCPServerConfigByID(ctx context.Context, id uuid.UUID) e return q.db.DeleteMCPServerConfigByID(ctx, id) } +func (q *querier) DeleteMCPServerUserHeaderValues(ctx context.Context, arg database.DeleteMCPServerUserHeaderValuesParams) error { + fetch := func(ctx context.Context, arg database.DeleteMCPServerUserHeaderValuesParams) (database.McpServerUserHeaderValue, error) { + return q.db.GetMCPServerUserHeaderValues(ctx, database.GetMCPServerUserHeaderValuesParams(arg)) + } + return fetchAndExec(q.log, q.auth, policy.ActionUpdatePersonal, fetch, q.db.DeleteMCPServerUserHeaderValues)(ctx, arg) +} + +func (q *querier) DeleteMCPServerUserHeaderValuesByConfigID(ctx context.Context, mcpServerConfigID uuid.UUID) error { + // Admin-only operation. Called from the admin MCP server config + // update path when auth_type or custom_headers_user_keys changes, + // so stale per-user header values do not silently reactivate when + // the key set is restored. + if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil { + return err + } + return q.db.DeleteMCPServerUserHeaderValuesByConfigID(ctx, mcpServerConfigID) +} + func (q *querier) DeleteMCPServerUserToken(ctx context.Context, arg database.DeleteMCPServerUserTokenParams) error { if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil { return err @@ -3712,6 +3730,14 @@ func (q *querier) GetMCPServerConfigsByIDs(ctx context.Context, ids []uuid.UUID) return q.db.GetMCPServerConfigsByIDs(ctx, ids) } +func (q *querier) GetMCPServerUserHeaderValues(ctx context.Context, arg database.GetMCPServerUserHeaderValuesParams) (database.McpServerUserHeaderValue, error) { + return fetchWithAction(q.log, q.auth, policy.ActionReadPersonal, q.db.GetMCPServerUserHeaderValues)(ctx, arg) +} + +func (q *querier) GetMCPServerUserHeaderValuesByUserID(ctx context.Context, userID uuid.UUID) ([]database.McpServerUserHeaderValue, error) { + return fetchWithPostFilter(q.auth, policy.ActionReadPersonal, q.db.GetMCPServerUserHeaderValuesByUserID)(ctx, userID) +} + func (q *querier) GetMCPServerUserToken(ctx context.Context, arg database.GetMCPServerUserTokenParams) (database.MCPServerUserToken, error) { if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceDeploymentConfig); err != nil { return database.MCPServerUserToken{}, err @@ -8261,6 +8287,10 @@ func (q *querier) UpsertLogoURL(ctx context.Context, value string) error { return q.db.UpsertLogoURL(ctx, value) } +func (q *querier) UpsertMCPServerUserHeaderValues(ctx context.Context, arg database.UpsertMCPServerUserHeaderValuesParams) (database.McpServerUserHeaderValue, error) { + return insertWithAction(q.log, q.auth, rbac.ResourceUser.WithID(arg.UserID).WithOwner(arg.UserID.String()), policy.ActionUpdatePersonal, q.db.UpsertMCPServerUserHeaderValues)(ctx, arg) +} + func (q *querier) UpsertMCPServerUserToken(ctx context.Context, arg database.UpsertMCPServerUserTokenParams) (database.MCPServerUserToken, error) { if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceDeploymentConfig); err != nil { return database.MCPServerUserToken{}, err diff --git a/coderd/database/dbauthz/dbauthz_test.go b/coderd/database/dbauthz/dbauthz_test.go index f788fa71e2..9b2917ae6c 100644 --- a/coderd/database/dbauthz/dbauthz_test.go +++ b/coderd/database/dbauthz/dbauthz_test.go @@ -1662,6 +1662,46 @@ func (s *MethodTestSuite) TestChats() { dbm.EXPECT().GetMCPServerUserTokensByUserID(gomock.Any(), userID).Return(tokens, nil).AnyTimes() check.Args(userID).Asserts(rbac.ResourceDeploymentConfig, policy.ActionRead).Returns(tokens) })) + s.Run("GetMCPServerUserHeaderValues", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + arg := database.GetMCPServerUserHeaderValuesParams{ + MCPServerConfigID: uuid.New(), + UserID: uuid.New(), + } + value := testutil.Fake(s.T(), faker, database.McpServerUserHeaderValue{MCPServerConfigID: arg.MCPServerConfigID, UserID: arg.UserID}) + dbm.EXPECT().GetMCPServerUserHeaderValues(gomock.Any(), arg).Return(value, nil).AnyTimes() + check.Args(arg).Asserts(value, policy.ActionReadPersonal).Returns(value) + })) + s.Run("GetMCPServerUserHeaderValuesByUserID", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + userID := uuid.New() + values := []database.McpServerUserHeaderValue{testutil.Fake(s.T(), faker, database.McpServerUserHeaderValue{UserID: userID})} + dbm.EXPECT().GetMCPServerUserHeaderValuesByUserID(gomock.Any(), userID).Return(values, nil).AnyTimes() + check.Args(userID).Asserts(values[0], policy.ActionReadPersonal).Returns(values) + })) + s.Run("UpsertMCPServerUserHeaderValues", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + arg := database.UpsertMCPServerUserHeaderValuesParams{ + MCPServerConfigID: uuid.New(), + UserID: uuid.New(), + HeaderValues: `{"X-User-Token":"secret"}`, + } + value := testutil.Fake(s.T(), faker, database.McpServerUserHeaderValue{MCPServerConfigID: arg.MCPServerConfigID, UserID: arg.UserID}) + dbm.EXPECT().UpsertMCPServerUserHeaderValues(gomock.Any(), arg).Return(value, nil).AnyTimes() + check.Args(arg).Asserts(rbac.ResourceUser.WithID(arg.UserID).WithOwner(arg.UserID.String()), policy.ActionUpdatePersonal).Returns(value) + })) + s.Run("DeleteMCPServerUserHeaderValues", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { + arg := database.DeleteMCPServerUserHeaderValuesParams{ + MCPServerConfigID: uuid.New(), + UserID: uuid.New(), + } + value := testutil.Fake(s.T(), faker, database.McpServerUserHeaderValue{MCPServerConfigID: arg.MCPServerConfigID, UserID: arg.UserID}) + dbm.EXPECT().GetMCPServerUserHeaderValues(gomock.Any(), database.GetMCPServerUserHeaderValuesParams(arg)).Return(value, nil).AnyTimes() + dbm.EXPECT().DeleteMCPServerUserHeaderValues(gomock.Any(), arg).Return(nil).AnyTimes() + check.Args(arg).Asserts(value, policy.ActionUpdatePersonal).Returns() + })) + s.Run("DeleteMCPServerUserHeaderValuesByConfigID", s.Mocked(func(dbm *dbmock.MockStore, _ *gofakeit.Faker, check *expects) { + id := uuid.New() + dbm.EXPECT().DeleteMCPServerUserHeaderValuesByConfigID(gomock.Any(), id).Return(nil).AnyTimes() + check.Args(id).Asserts(rbac.ResourceDeploymentConfig, policy.ActionUpdate).Returns() + })) s.Run("InsertMCPServerConfig", s.Mocked(func(dbm *dbmock.MockStore, faker *gofakeit.Faker, check *expects) { arg := database.InsertMCPServerConfigParams{ DisplayName: "Test MCP Server", diff --git a/coderd/database/dbgen/dbgen.go b/coderd/database/dbgen/dbgen.go index 416a2b7257..096234cd8c 100644 --- a/coderd/database/dbgen/dbgen.go +++ b/coderd/database/dbgen/dbgen.go @@ -331,38 +331,53 @@ func MCPServerConfig(t testing.TB, db database.Store, seed database.MCPServerCon } cfg, err := db.InsertMCPServerConfig(genCtx, database.InsertMCPServerConfigParams{ - DisplayName: takeFirst(seed.DisplayName, "Test MCP Server"), - Slug: takeFirst(seed.Slug, testutil.GetRandomName(t)), - Description: seed.Description, - IconURL: seed.IconURL, - Transport: takeFirst(seed.Transport, "streamable_http"), - Url: takeFirst(seed.Url, "https://mcp.example.com"), - AuthType: takeFirst(seed.AuthType, "none"), - OAuth2ClientID: seed.OAuth2ClientID, - OAuth2ClientSecret: seed.OAuth2ClientSecret, - OAuth2ClientSecretKeyID: seed.OAuth2ClientSecretKeyID, - OAuth2AuthURL: seed.OAuth2AuthURL, - OAuth2TokenURL: seed.OAuth2TokenURL, - OAuth2Scopes: seed.OAuth2Scopes, - APIKeyHeader: seed.APIKeyHeader, - APIKeyValue: seed.APIKeyValue, - APIKeyValueKeyID: seed.APIKeyValueKeyID, - CustomHeaders: seed.CustomHeaders, - CustomHeadersKeyID: seed.CustomHeadersKeyID, - ToolAllowList: takeFirstSlice(seed.ToolAllowList, []string{}), - ToolDenyList: takeFirstSlice(seed.ToolDenyList, []string{}), - Availability: takeFirst(seed.Availability, "default_off"), - Enabled: takeFirst(seed.Enabled, true), - ModelIntent: seed.ModelIntent, - AllowInPlanMode: seed.AllowInPlanMode, - ForwardCoderHeaders: seed.ForwardCoderHeaders, - CreatedBy: createdBy, - UpdatedBy: updatedBy, + DisplayName: takeFirst(seed.DisplayName, "Test MCP Server"), + Slug: takeFirst(seed.Slug, testutil.GetRandomName(t)), + Description: seed.Description, + IconURL: seed.IconURL, + Transport: takeFirst(seed.Transport, "streamable_http"), + Url: takeFirst(seed.Url, "https://mcp.example.com"), + AuthType: takeFirst(seed.AuthType, "none"), + OAuth2ClientID: seed.OAuth2ClientID, + OAuth2ClientSecret: seed.OAuth2ClientSecret, + OAuth2ClientSecretKeyID: seed.OAuth2ClientSecretKeyID, + OAuth2AuthURL: seed.OAuth2AuthURL, + OAuth2TokenURL: seed.OAuth2TokenURL, + OAuth2Scopes: seed.OAuth2Scopes, + APIKeyHeader: seed.APIKeyHeader, + APIKeyValue: seed.APIKeyValue, + APIKeyValueKeyID: seed.APIKeyValueKeyID, + CustomHeaders: seed.CustomHeaders, + CustomHeadersKeyID: seed.CustomHeadersKeyID, + CustomHeadersUserKeys: takeFirstSlice(seed.CustomHeadersUserKeys, []string{}), + CustomHeadersUserKeyDescriptions: takeFirstRawMessage(seed.CustomHeadersUserKeyDescriptions, json.RawMessage("{}")), + ToolAllowList: takeFirstSlice(seed.ToolAllowList, []string{}), + ToolDenyList: takeFirstSlice(seed.ToolDenyList, []string{}), + Availability: takeFirst(seed.Availability, "default_off"), + Enabled: takeFirst(seed.Enabled, true), + ModelIntent: seed.ModelIntent, + AllowInPlanMode: seed.AllowInPlanMode, + ForwardCoderHeaders: seed.ForwardCoderHeaders, + CreatedBy: createdBy, + UpdatedBy: updatedBy, }) require.NoError(t, err, "insert MCP server config") return cfg } +func MCPServerUserHeaderValues(t testing.TB, db database.Store, seed database.McpServerUserHeaderValue) database.McpServerUserHeaderValue { + t.Helper() + + row, err := db.UpsertMCPServerUserHeaderValues(genCtx, database.UpsertMCPServerUserHeaderValuesParams{ + MCPServerConfigID: takeFirst(seed.MCPServerConfigID, uuid.New()), + UserID: takeFirst(seed.UserID, uuid.New()), + HeaderValues: takeFirst(seed.HeaderValues, "{}"), + HeaderValuesKeyID: seed.HeaderValuesKeyID, + }) + require.NoError(t, err, "upsert MCP server user header values") + return row +} + func ConnectionLog(t testing.TB, db database.Store, seed database.UpsertConnectionLogParams) database.ConnectionLog { arg := database.UpsertConnectionLogParams{ ID: takeFirst(seed.ID, uuid.New()), @@ -2157,6 +2172,20 @@ func takeFirstSlice[T any](values ...[]T) []T { }) } +// takeFirstRawMessage returns the first json.RawMessage that is not +// empty and not a JSON null/object literal that signals absence. Use +// this for NOT NULL JSONB columns whose Go zero value would otherwise +// produce a SQL NULL. +func takeFirstRawMessage(values ...json.RawMessage) json.RawMessage { + for _, v := range values { + if len(v) == 0 { + continue + } + return v + } + return nil +} + func takeFirstMap[T, E comparable](values ...map[T]E) map[T]E { return takeFirstF(values, func(v map[T]E) bool { return v != nil diff --git a/coderd/database/dbmetrics/querymetrics.go b/coderd/database/dbmetrics/querymetrics.go index e7120ec588..f980a20cb7 100644 --- a/coderd/database/dbmetrics/querymetrics.go +++ b/coderd/database/dbmetrics/querymetrics.go @@ -577,6 +577,22 @@ func (m queryMetricsStore) DeleteMCPServerConfigByID(ctx context.Context, id uui return r0 } +func (m queryMetricsStore) DeleteMCPServerUserHeaderValues(ctx context.Context, arg database.DeleteMCPServerUserHeaderValuesParams) error { + start := time.Now() + r0 := m.s.DeleteMCPServerUserHeaderValues(ctx, arg) + m.queryLatencies.WithLabelValues("DeleteMCPServerUserHeaderValues").Observe(time.Since(start).Seconds()) + m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteMCPServerUserHeaderValues").Inc() + return r0 +} + +func (m queryMetricsStore) DeleteMCPServerUserHeaderValuesByConfigID(ctx context.Context, mcpServerConfigID uuid.UUID) error { + start := time.Now() + r0 := m.s.DeleteMCPServerUserHeaderValuesByConfigID(ctx, mcpServerConfigID) + m.queryLatencies.WithLabelValues("DeleteMCPServerUserHeaderValuesByConfigID").Observe(time.Since(start).Seconds()) + m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "DeleteMCPServerUserHeaderValuesByConfigID").Inc() + return r0 +} + func (m queryMetricsStore) DeleteMCPServerUserToken(ctx context.Context, arg database.DeleteMCPServerUserTokenParams) error { start := time.Now() r0 := m.s.DeleteMCPServerUserToken(ctx, arg) @@ -2121,6 +2137,22 @@ func (m queryMetricsStore) GetMCPServerConfigsByIDs(ctx context.Context, ids []u return r0, r1 } +func (m queryMetricsStore) GetMCPServerUserHeaderValues(ctx context.Context, arg database.GetMCPServerUserHeaderValuesParams) (database.McpServerUserHeaderValue, error) { + start := time.Now() + r0, r1 := m.s.GetMCPServerUserHeaderValues(ctx, arg) + m.queryLatencies.WithLabelValues("GetMCPServerUserHeaderValues").Observe(time.Since(start).Seconds()) + m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetMCPServerUserHeaderValues").Inc() + return r0, r1 +} + +func (m queryMetricsStore) GetMCPServerUserHeaderValuesByUserID(ctx context.Context, userID uuid.UUID) ([]database.McpServerUserHeaderValue, error) { + start := time.Now() + r0, r1 := m.s.GetMCPServerUserHeaderValuesByUserID(ctx, userID) + m.queryLatencies.WithLabelValues("GetMCPServerUserHeaderValuesByUserID").Observe(time.Since(start).Seconds()) + m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "GetMCPServerUserHeaderValuesByUserID").Inc() + return r0, r1 +} + func (m queryMetricsStore) GetMCPServerUserToken(ctx context.Context, arg database.GetMCPServerUserTokenParams) (database.MCPServerUserToken, error) { start := time.Now() r0, r1 := m.s.GetMCPServerUserToken(ctx, arg) @@ -5953,6 +5985,14 @@ func (m queryMetricsStore) UpsertLogoURL(ctx context.Context, value string) erro return r0 } +func (m queryMetricsStore) UpsertMCPServerUserHeaderValues(ctx context.Context, arg database.UpsertMCPServerUserHeaderValuesParams) (database.McpServerUserHeaderValue, error) { + start := time.Now() + r0, r1 := m.s.UpsertMCPServerUserHeaderValues(ctx, arg) + m.queryLatencies.WithLabelValues("UpsertMCPServerUserHeaderValues").Observe(time.Since(start).Seconds()) + m.queryCounts.WithLabelValues(httpmw.ExtractHTTPRoute(ctx), httpmw.ExtractHTTPMethod(ctx), "UpsertMCPServerUserHeaderValues").Inc() + return r0, r1 +} + func (m queryMetricsStore) UpsertMCPServerUserToken(ctx context.Context, arg database.UpsertMCPServerUserTokenParams) (database.MCPServerUserToken, error) { start := time.Now() r0, r1 := m.s.UpsertMCPServerUserToken(ctx, arg) diff --git a/coderd/database/dbmock/dbmock.go b/coderd/database/dbmock/dbmock.go index 0f6799e638..b981e243b6 100644 --- a/coderd/database/dbmock/dbmock.go +++ b/coderd/database/dbmock/dbmock.go @@ -960,6 +960,34 @@ func (mr *MockStoreMockRecorder) DeleteMCPServerConfigByID(ctx, id any) *gomock. return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteMCPServerConfigByID", reflect.TypeOf((*MockStore)(nil).DeleteMCPServerConfigByID), ctx, id) } +// DeleteMCPServerUserHeaderValues mocks base method. +func (m *MockStore) DeleteMCPServerUserHeaderValues(ctx context.Context, arg database.DeleteMCPServerUserHeaderValuesParams) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteMCPServerUserHeaderValues", ctx, arg) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteMCPServerUserHeaderValues indicates an expected call of DeleteMCPServerUserHeaderValues. +func (mr *MockStoreMockRecorder) DeleteMCPServerUserHeaderValues(ctx, arg any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteMCPServerUserHeaderValues", reflect.TypeOf((*MockStore)(nil).DeleteMCPServerUserHeaderValues), ctx, arg) +} + +// DeleteMCPServerUserHeaderValuesByConfigID mocks base method. +func (m *MockStore) DeleteMCPServerUserHeaderValuesByConfigID(ctx context.Context, mcpServerConfigID uuid.UUID) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DeleteMCPServerUserHeaderValuesByConfigID", ctx, mcpServerConfigID) + ret0, _ := ret[0].(error) + return ret0 +} + +// DeleteMCPServerUserHeaderValuesByConfigID indicates an expected call of DeleteMCPServerUserHeaderValuesByConfigID. +func (mr *MockStoreMockRecorder) DeleteMCPServerUserHeaderValuesByConfigID(ctx, mcpServerConfigID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteMCPServerUserHeaderValuesByConfigID", reflect.TypeOf((*MockStore)(nil).DeleteMCPServerUserHeaderValuesByConfigID), ctx, mcpServerConfigID) +} + // DeleteMCPServerUserToken mocks base method. func (m *MockStore) DeleteMCPServerUserToken(ctx context.Context, arg database.DeleteMCPServerUserTokenParams) error { m.ctrl.T.Helper() @@ -3945,6 +3973,36 @@ func (mr *MockStoreMockRecorder) GetMCPServerConfigsByIDs(ctx, ids any) *gomock. return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMCPServerConfigsByIDs", reflect.TypeOf((*MockStore)(nil).GetMCPServerConfigsByIDs), ctx, ids) } +// GetMCPServerUserHeaderValues mocks base method. +func (m *MockStore) GetMCPServerUserHeaderValues(ctx context.Context, arg database.GetMCPServerUserHeaderValuesParams) (database.McpServerUserHeaderValue, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetMCPServerUserHeaderValues", ctx, arg) + ret0, _ := ret[0].(database.McpServerUserHeaderValue) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetMCPServerUserHeaderValues indicates an expected call of GetMCPServerUserHeaderValues. +func (mr *MockStoreMockRecorder) GetMCPServerUserHeaderValues(ctx, arg any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMCPServerUserHeaderValues", reflect.TypeOf((*MockStore)(nil).GetMCPServerUserHeaderValues), ctx, arg) +} + +// GetMCPServerUserHeaderValuesByUserID mocks base method. +func (m *MockStore) GetMCPServerUserHeaderValuesByUserID(ctx context.Context, userID uuid.UUID) ([]database.McpServerUserHeaderValue, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetMCPServerUserHeaderValuesByUserID", ctx, userID) + ret0, _ := ret[0].([]database.McpServerUserHeaderValue) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// GetMCPServerUserHeaderValuesByUserID indicates an expected call of GetMCPServerUserHeaderValuesByUserID. +func (mr *MockStoreMockRecorder) GetMCPServerUserHeaderValuesByUserID(ctx, userID any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMCPServerUserHeaderValuesByUserID", reflect.TypeOf((*MockStore)(nil).GetMCPServerUserHeaderValuesByUserID), ctx, userID) +} + // GetMCPServerUserToken mocks base method. func (m *MockStore) GetMCPServerUserToken(ctx context.Context, arg database.GetMCPServerUserTokenParams) (database.MCPServerUserToken, error) { m.ctrl.T.Helper() @@ -11172,6 +11230,21 @@ func (mr *MockStoreMockRecorder) UpsertLogoURL(ctx, value any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertLogoURL", reflect.TypeOf((*MockStore)(nil).UpsertLogoURL), ctx, value) } +// UpsertMCPServerUserHeaderValues mocks base method. +func (m *MockStore) UpsertMCPServerUserHeaderValues(ctx context.Context, arg database.UpsertMCPServerUserHeaderValuesParams) (database.McpServerUserHeaderValue, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "UpsertMCPServerUserHeaderValues", ctx, arg) + ret0, _ := ret[0].(database.McpServerUserHeaderValue) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// UpsertMCPServerUserHeaderValues indicates an expected call of UpsertMCPServerUserHeaderValues. +func (mr *MockStoreMockRecorder) UpsertMCPServerUserHeaderValues(ctx, arg any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpsertMCPServerUserHeaderValues", reflect.TypeOf((*MockStore)(nil).UpsertMCPServerUserHeaderValues), ctx, arg) +} + // UpsertMCPServerUserToken mocks base method. func (m *MockStore) UpsertMCPServerUserToken(ctx context.Context, arg database.UpsertMCPServerUserTokenParams) (database.MCPServerUserToken, error) { m.ctrl.T.Helper() diff --git a/coderd/database/dump.sql b/coderd/database/dump.sql index 82aa376d34..fc91df010a 100644 --- a/coderd/database/dump.sql +++ b/coderd/database/dump.sql @@ -2142,11 +2142,23 @@ CREATE TABLE mcp_server_configs ( model_intent boolean DEFAULT false NOT NULL, allow_in_plan_mode boolean DEFAULT false NOT NULL, forward_coder_headers boolean DEFAULT false NOT NULL, + custom_headers_user_keys text[] DEFAULT '{}'::text[] NOT NULL, + custom_headers_user_key_descriptions jsonb DEFAULT '{}'::jsonb NOT NULL, CONSTRAINT mcp_server_configs_auth_type_check CHECK ((auth_type = ANY (ARRAY['none'::text, 'oauth2'::text, 'api_key'::text, 'custom_headers'::text, 'user_oidc'::text]))), CONSTRAINT mcp_server_configs_availability_check CHECK ((availability = ANY (ARRAY['force_on'::text, 'default_on'::text, 'default_off'::text]))), CONSTRAINT mcp_server_configs_transport_check CHECK ((transport = ANY (ARRAY['streamable_http'::text, 'sse'::text]))) ); +CREATE TABLE mcp_server_user_header_values ( + id uuid DEFAULT gen_random_uuid() NOT NULL, + mcp_server_config_id uuid NOT NULL, + user_id uuid NOT NULL, + header_values text DEFAULT '{}'::text NOT NULL, + header_values_key_id text, + created_at timestamp with time zone DEFAULT now() NOT NULL, + updated_at timestamp with time zone DEFAULT now() NOT NULL +); + CREATE TABLE mcp_server_user_tokens ( id uuid DEFAULT gen_random_uuid() NOT NULL, mcp_server_config_id uuid NOT NULL, @@ -3895,6 +3907,12 @@ ALTER TABLE ONLY mcp_server_configs ALTER TABLE ONLY mcp_server_configs ADD CONSTRAINT mcp_server_configs_slug_key UNIQUE (slug); +ALTER TABLE ONLY mcp_server_user_header_values + ADD CONSTRAINT mcp_server_user_header_values_mcp_server_config_id_user_id_key UNIQUE (mcp_server_config_id, user_id); + +ALTER TABLE ONLY mcp_server_user_header_values + ADD CONSTRAINT mcp_server_user_header_values_pkey PRIMARY KEY (id); + ALTER TABLE ONLY mcp_server_user_tokens ADD CONSTRAINT mcp_server_user_tokens_mcp_server_config_id_user_id_key UNIQUE (mcp_server_config_id, user_id); @@ -4309,6 +4327,8 @@ CREATE INDEX idx_mcp_server_configs_enabled ON mcp_server_configs USING btree (e CREATE INDEX idx_mcp_server_configs_forced ON mcp_server_configs USING btree (enabled, availability) WHERE ((enabled = true) AND (availability = 'force_on'::text)); +CREATE INDEX idx_mcp_server_user_header_values_user_id ON mcp_server_user_header_values USING btree (user_id); + CREATE INDEX idx_mcp_server_user_tokens_user_id ON mcp_server_user_tokens USING btree (user_id); CREATE INDEX idx_notification_messages_status ON notification_messages USING btree (status); @@ -4713,6 +4733,15 @@ ALTER TABLE ONLY mcp_server_configs ALTER TABLE ONLY mcp_server_configs ADD CONSTRAINT mcp_server_configs_updated_by_fkey FOREIGN KEY (updated_by) REFERENCES users(id) ON DELETE SET NULL; +ALTER TABLE ONLY mcp_server_user_header_values + ADD CONSTRAINT mcp_server_user_header_values_header_values_key_id_fkey FOREIGN KEY (header_values_key_id) REFERENCES dbcrypt_keys(active_key_digest); + +ALTER TABLE ONLY mcp_server_user_header_values + ADD CONSTRAINT mcp_server_user_header_values_mcp_server_config_id_fkey FOREIGN KEY (mcp_server_config_id) REFERENCES mcp_server_configs(id) ON DELETE CASCADE; + +ALTER TABLE ONLY mcp_server_user_header_values + ADD CONSTRAINT mcp_server_user_header_values_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE; + ALTER TABLE ONLY mcp_server_user_tokens ADD CONSTRAINT mcp_server_user_tokens_access_token_key_id_fkey FOREIGN KEY (access_token_key_id) REFERENCES dbcrypt_keys(active_key_digest); diff --git a/coderd/database/foreign_key_constraint.go b/coderd/database/foreign_key_constraint.go index 5eeb24587a..41353c9baa 100644 --- a/coderd/database/foreign_key_constraint.go +++ b/coderd/database/foreign_key_constraint.go @@ -60,6 +60,9 @@ const ( ForeignKeyMcpServerConfigsCustomHeadersKeyID ForeignKeyConstraint = "mcp_server_configs_custom_headers_key_id_fkey" // ALTER TABLE ONLY mcp_server_configs ADD CONSTRAINT mcp_server_configs_custom_headers_key_id_fkey FOREIGN KEY (custom_headers_key_id) REFERENCES dbcrypt_keys(active_key_digest); ForeignKeyMcpServerConfigsOauth2ClientSecretKeyID ForeignKeyConstraint = "mcp_server_configs_oauth2_client_secret_key_id_fkey" // ALTER TABLE ONLY mcp_server_configs ADD CONSTRAINT mcp_server_configs_oauth2_client_secret_key_id_fkey FOREIGN KEY (oauth2_client_secret_key_id) REFERENCES dbcrypt_keys(active_key_digest); ForeignKeyMcpServerConfigsUpdatedBy ForeignKeyConstraint = "mcp_server_configs_updated_by_fkey" // ALTER TABLE ONLY mcp_server_configs ADD CONSTRAINT mcp_server_configs_updated_by_fkey FOREIGN KEY (updated_by) REFERENCES users(id) ON DELETE SET NULL; + ForeignKeyMcpServerUserHeaderValuesHeaderValuesKeyID ForeignKeyConstraint = "mcp_server_user_header_values_header_values_key_id_fkey" // ALTER TABLE ONLY mcp_server_user_header_values ADD CONSTRAINT mcp_server_user_header_values_header_values_key_id_fkey FOREIGN KEY (header_values_key_id) REFERENCES dbcrypt_keys(active_key_digest); + ForeignKeyMcpServerUserHeaderValuesMcpServerConfigID ForeignKeyConstraint = "mcp_server_user_header_values_mcp_server_config_id_fkey" // ALTER TABLE ONLY mcp_server_user_header_values ADD CONSTRAINT mcp_server_user_header_values_mcp_server_config_id_fkey FOREIGN KEY (mcp_server_config_id) REFERENCES mcp_server_configs(id) ON DELETE CASCADE; + ForeignKeyMcpServerUserHeaderValuesUserID ForeignKeyConstraint = "mcp_server_user_header_values_user_id_fkey" // ALTER TABLE ONLY mcp_server_user_header_values ADD CONSTRAINT mcp_server_user_header_values_user_id_fkey FOREIGN KEY (user_id) REFERENCES users(id) ON DELETE CASCADE; ForeignKeyMcpServerUserTokensAccessTokenKeyID ForeignKeyConstraint = "mcp_server_user_tokens_access_token_key_id_fkey" // ALTER TABLE ONLY mcp_server_user_tokens ADD CONSTRAINT mcp_server_user_tokens_access_token_key_id_fkey FOREIGN KEY (access_token_key_id) REFERENCES dbcrypt_keys(active_key_digest); ForeignKeyMcpServerUserTokensMcpServerConfigID ForeignKeyConstraint = "mcp_server_user_tokens_mcp_server_config_id_fkey" // ALTER TABLE ONLY mcp_server_user_tokens ADD CONSTRAINT mcp_server_user_tokens_mcp_server_config_id_fkey FOREIGN KEY (mcp_server_config_id) REFERENCES mcp_server_configs(id) ON DELETE CASCADE; ForeignKeyMcpServerUserTokensRefreshTokenKeyID ForeignKeyConstraint = "mcp_server_user_tokens_refresh_token_key_id_fkey" // ALTER TABLE ONLY mcp_server_user_tokens ADD CONSTRAINT mcp_server_user_tokens_refresh_token_key_id_fkey FOREIGN KEY (refresh_token_key_id) REFERENCES dbcrypt_keys(active_key_digest); diff --git a/coderd/database/migrations/000514_mcp_server_custom_headers_user_keys.down.sql b/coderd/database/migrations/000514_mcp_server_custom_headers_user_keys.down.sql new file mode 100644 index 0000000000..7e5bb1aab1 --- /dev/null +++ b/coderd/database/migrations/000514_mcp_server_custom_headers_user_keys.down.sql @@ -0,0 +1,6 @@ +DROP INDEX IF EXISTS idx_mcp_server_user_header_values_user_id; +DROP TABLE IF EXISTS mcp_server_user_header_values; + +ALTER TABLE mcp_server_configs + DROP COLUMN IF EXISTS custom_headers_user_keys, + DROP COLUMN IF EXISTS custom_headers_user_key_descriptions; diff --git a/coderd/database/migrations/000514_mcp_server_custom_headers_user_keys.up.sql b/coderd/database/migrations/000514_mcp_server_custom_headers_user_keys.up.sql new file mode 100644 index 0000000000..ad3dd1c388 --- /dev/null +++ b/coderd/database/migrations/000514_mcp_server_custom_headers_user_keys.up.sql @@ -0,0 +1,26 @@ +ALTER TABLE mcp_server_configs + ADD COLUMN custom_headers_user_keys TEXT[] NOT NULL DEFAULT '{}', + -- Optional admin-supplied helper text per user-set custom header key. + -- Shown to end users in the settings UI when they fill in their value. + -- Keys must be a subset of custom_headers_user_keys (case-insensitive). + ADD COLUMN custom_headers_user_key_descriptions JSONB NOT NULL DEFAULT '{}'::jsonb; + +CREATE TABLE mcp_server_user_header_values ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + mcp_server_config_id UUID NOT NULL REFERENCES mcp_server_configs(id) ON DELETE CASCADE, + user_id UUID NOT NULL REFERENCES users(id) ON DELETE CASCADE, + + -- JSON object {header: value} of values supplied by the user for the + -- headers listed in mcp_server_configs.custom_headers_user_keys. Stored + -- encrypted at rest via dbcrypt (the key id is header_values_key_id). + header_values TEXT NOT NULL DEFAULT '{}', + header_values_key_id TEXT REFERENCES dbcrypt_keys(active_key_digest), + + created_at TIMESTAMPTZ NOT NULL DEFAULT now(), + updated_at TIMESTAMPTZ NOT NULL DEFAULT now(), + + UNIQUE (mcp_server_config_id, user_id) +); + +CREATE INDEX idx_mcp_server_user_header_values_user_id + ON mcp_server_user_header_values(user_id); diff --git a/coderd/database/migrations/testdata/fixtures/000514_mcp_server_custom_headers_user_keys.up.sql b/coderd/database/migrations/testdata/fixtures/000514_mcp_server_custom_headers_user_keys.up.sql new file mode 100644 index 0000000000..e69fb9b605 --- /dev/null +++ b/coderd/database/migrations/testdata/fixtures/000514_mcp_server_custom_headers_user_keys.up.sql @@ -0,0 +1,53 @@ +-- Migration 514 adds custom_headers_user_keys to mcp_server_configs and +-- creates mcp_server_user_header_values. Insert a fixture row exercising +-- the user-set header values flow. + +INSERT INTO mcp_server_configs ( + id, + display_name, + slug, + url, + transport, + auth_type, + custom_headers, + custom_headers_user_keys, + custom_headers_user_key_descriptions, + availability, + enabled, + created_by, + updated_by, + created_at, + updated_at +) VALUES ( + 'c3d4e5f6-a7b8-9012-cdef-123456789012', + 'Fixture User-Set Headers MCP Server', + 'fixture-user-set-headers-mcp-server', + 'https://mcp.example.com/streamable', + 'streamable_http', + 'custom_headers', + '{"X-Org-ID":"acme"}', + ARRAY['X-User-Token'], + '{"X-User-Token":"Personal access token for the upstream MCP server."}'::jsonb, + 'default_off', + TRUE, + '30095c71-380b-457a-8995-97b8ee6e5307', -- admin@coder.com + '30095c71-380b-457a-8995-97b8ee6e5307', -- admin@coder.com + '2024-01-01 00:00:00+00', + '2024-01-01 00:00:00+00' +); + +INSERT INTO mcp_server_user_header_values ( + id, + mcp_server_config_id, + user_id, + header_values, + created_at, + updated_at +) VALUES ( + 'd4e5f6a7-b8c9-0123-defa-234567890123', + 'c3d4e5f6-a7b8-9012-cdef-123456789012', + '30095c71-380b-457a-8995-97b8ee6e5307', -- admin@coder.com + '{"X-User-Token":"user-supplied-token"}', + '2024-01-01 00:00:00+00', + '2024-01-01 00:00:00+00' +); diff --git a/coderd/database/modelmethods.go b/coderd/database/modelmethods.go index 62eb12a1d2..741228d200 100644 --- a/coderd/database/modelmethods.go +++ b/coderd/database/modelmethods.go @@ -621,6 +621,9 @@ func (u GetUsersRow) RBACObject() rbac.Object { func (u GitSSHKey) RBACObject() rbac.Object { return rbac.ResourceUserObject(u.UserID) } func (u ExternalAuthLink) RBACObject() rbac.Object { return rbac.ResourceUserObject(u.UserID) } func (u UserLink) RBACObject() rbac.Object { return rbac.ResourceUserObject(u.UserID) } +func (u McpServerUserHeaderValue) RBACObject() rbac.Object { + return rbac.ResourceUserObject(u.UserID) +} func (u ExternalAuthLink) OAuthToken() *oauth2.Token { return &oauth2.Token{ diff --git a/coderd/database/models.go b/coderd/database/models.go index ebfaa7a051..a0a282cb64 100644 --- a/coderd/database/models.go +++ b/coderd/database/models.go @@ -4971,36 +4971,38 @@ type License struct { } type MCPServerConfig struct { - ID uuid.UUID `db:"id" json:"id"` - DisplayName string `db:"display_name" json:"display_name"` - Slug string `db:"slug" json:"slug"` - Description string `db:"description" json:"description"` - IconURL string `db:"icon_url" json:"icon_url"` - Transport string `db:"transport" json:"transport"` - Url string `db:"url" json:"url"` - AuthType string `db:"auth_type" json:"auth_type"` - OAuth2ClientID string `db:"oauth2_client_id" json:"oauth2_client_id"` - OAuth2ClientSecret string `db:"oauth2_client_secret" json:"oauth2_client_secret"` - OAuth2ClientSecretKeyID sql.NullString `db:"oauth2_client_secret_key_id" json:"oauth2_client_secret_key_id"` - OAuth2AuthURL string `db:"oauth2_auth_url" json:"oauth2_auth_url"` - OAuth2TokenURL string `db:"oauth2_token_url" json:"oauth2_token_url"` - OAuth2Scopes string `db:"oauth2_scopes" json:"oauth2_scopes"` - APIKeyHeader string `db:"api_key_header" json:"api_key_header"` - APIKeyValue string `db:"api_key_value" json:"api_key_value"` - APIKeyValueKeyID sql.NullString `db:"api_key_value_key_id" json:"api_key_value_key_id"` - CustomHeaders string `db:"custom_headers" json:"custom_headers"` - CustomHeadersKeyID sql.NullString `db:"custom_headers_key_id" json:"custom_headers_key_id"` - ToolAllowList []string `db:"tool_allow_list" json:"tool_allow_list"` - ToolDenyList []string `db:"tool_deny_list" json:"tool_deny_list"` - Availability string `db:"availability" json:"availability"` - Enabled bool `db:"enabled" json:"enabled"` - CreatedBy uuid.NullUUID `db:"created_by" json:"created_by"` - UpdatedBy uuid.NullUUID `db:"updated_by" json:"updated_by"` - CreatedAt time.Time `db:"created_at" json:"created_at"` - UpdatedAt time.Time `db:"updated_at" json:"updated_at"` - ModelIntent bool `db:"model_intent" json:"model_intent"` - AllowInPlanMode bool `db:"allow_in_plan_mode" json:"allow_in_plan_mode"` - ForwardCoderHeaders bool `db:"forward_coder_headers" json:"forward_coder_headers"` + ID uuid.UUID `db:"id" json:"id"` + DisplayName string `db:"display_name" json:"display_name"` + Slug string `db:"slug" json:"slug"` + Description string `db:"description" json:"description"` + IconURL string `db:"icon_url" json:"icon_url"` + Transport string `db:"transport" json:"transport"` + Url string `db:"url" json:"url"` + AuthType string `db:"auth_type" json:"auth_type"` + OAuth2ClientID string `db:"oauth2_client_id" json:"oauth2_client_id"` + OAuth2ClientSecret string `db:"oauth2_client_secret" json:"oauth2_client_secret"` + OAuth2ClientSecretKeyID sql.NullString `db:"oauth2_client_secret_key_id" json:"oauth2_client_secret_key_id"` + OAuth2AuthURL string `db:"oauth2_auth_url" json:"oauth2_auth_url"` + OAuth2TokenURL string `db:"oauth2_token_url" json:"oauth2_token_url"` + OAuth2Scopes string `db:"oauth2_scopes" json:"oauth2_scopes"` + APIKeyHeader string `db:"api_key_header" json:"api_key_header"` + APIKeyValue string `db:"api_key_value" json:"api_key_value"` + APIKeyValueKeyID sql.NullString `db:"api_key_value_key_id" json:"api_key_value_key_id"` + CustomHeaders string `db:"custom_headers" json:"custom_headers"` + CustomHeadersKeyID sql.NullString `db:"custom_headers_key_id" json:"custom_headers_key_id"` + ToolAllowList []string `db:"tool_allow_list" json:"tool_allow_list"` + ToolDenyList []string `db:"tool_deny_list" json:"tool_deny_list"` + Availability string `db:"availability" json:"availability"` + Enabled bool `db:"enabled" json:"enabled"` + CreatedBy uuid.NullUUID `db:"created_by" json:"created_by"` + UpdatedBy uuid.NullUUID `db:"updated_by" json:"updated_by"` + CreatedAt time.Time `db:"created_at" json:"created_at"` + UpdatedAt time.Time `db:"updated_at" json:"updated_at"` + ModelIntent bool `db:"model_intent" json:"model_intent"` + AllowInPlanMode bool `db:"allow_in_plan_mode" json:"allow_in_plan_mode"` + ForwardCoderHeaders bool `db:"forward_coder_headers" json:"forward_coder_headers"` + CustomHeadersUserKeys []string `db:"custom_headers_user_keys" json:"custom_headers_user_keys"` + CustomHeadersUserKeyDescriptions json.RawMessage `db:"custom_headers_user_key_descriptions" json:"custom_headers_user_key_descriptions"` } type MCPServerUserToken struct { @@ -5017,6 +5019,16 @@ type MCPServerUserToken struct { UpdatedAt time.Time `db:"updated_at" json:"updated_at"` } +type McpServerUserHeaderValue struct { + ID uuid.UUID `db:"id" json:"id"` + MCPServerConfigID uuid.UUID `db:"mcp_server_config_id" json:"mcp_server_config_id"` + UserID uuid.UUID `db:"user_id" json:"user_id"` + HeaderValues string `db:"header_values" json:"header_values"` + HeaderValuesKeyID sql.NullString `db:"header_values_key_id" json:"header_values_key_id"` + CreatedAt time.Time `db:"created_at" json:"created_at"` + UpdatedAt time.Time `db:"updated_at" json:"updated_at"` +} + type NotificationMessage struct { ID uuid.UUID `db:"id" json:"id"` NotificationTemplateID uuid.UUID `db:"notification_template_id" json:"notification_template_id"` diff --git a/coderd/database/querier.go b/coderd/database/querier.go index a6c8f3e7db..6ba5435732 100644 --- a/coderd/database/querier.go +++ b/coderd/database/querier.go @@ -138,6 +138,12 @@ type sqlcQuerier interface { DeleteGroupMemberFromGroup(ctx context.Context, arg DeleteGroupMemberFromGroupParams) error DeleteLicense(ctx context.Context, id int32) (int32, error) DeleteMCPServerConfigByID(ctx context.Context, id uuid.UUID) error + DeleteMCPServerUserHeaderValues(ctx context.Context, arg DeleteMCPServerUserHeaderValuesParams) error + // Deletes every user's stored header values for the given MCP server + // config. Use when the admin changes auth_type away from custom_headers + // or alters custom_headers_user_keys so stale credentials do not + // silently reactivate when the key set is restored. + DeleteMCPServerUserHeaderValuesByConfigID(ctx context.Context, mcpServerConfigID uuid.UUID) error DeleteMCPServerUserToken(ctx context.Context, arg DeleteMCPServerUserTokenParams) error DeleteOAuth2ProviderAppByClientID(ctx context.Context, id uuid.UUID) error DeleteOAuth2ProviderAppByID(ctx context.Context, id uuid.UUID) error @@ -515,6 +521,8 @@ type sqlcQuerier interface { GetMCPServerConfigBySlug(ctx context.Context, slug string) (MCPServerConfig, error) GetMCPServerConfigs(ctx context.Context) ([]MCPServerConfig, error) GetMCPServerConfigsByIDs(ctx context.Context, ids []uuid.UUID) ([]MCPServerConfig, error) + GetMCPServerUserHeaderValues(ctx context.Context, arg GetMCPServerUserHeaderValuesParams) (McpServerUserHeaderValue, error) + GetMCPServerUserHeaderValuesByUserID(ctx context.Context, userID uuid.UUID) ([]McpServerUserHeaderValue, error) GetMCPServerUserToken(ctx context.Context, arg GetMCPServerUserTokenParams) (MCPServerUserToken, error) GetMCPServerUserTokensByUserID(ctx context.Context, userID uuid.UUID) ([]MCPServerUserToken, error) GetNotificationMessagesByStatus(ctx context.Context, arg GetNotificationMessagesByStatusParams) ([]NotificationMessage, error) @@ -1390,6 +1398,7 @@ type sqlcQuerier interface { UpsertHealthSettings(ctx context.Context, value string) error UpsertLastUpdateCheck(ctx context.Context, value string) error UpsertLogoURL(ctx context.Context, value string) error + UpsertMCPServerUserHeaderValues(ctx context.Context, arg UpsertMCPServerUserHeaderValuesParams) (McpServerUserHeaderValue, error) UpsertMCPServerUserToken(ctx context.Context, arg UpsertMCPServerUserTokenParams) (MCPServerUserToken, error) // Insert or update notification report generator logs with recent activity. UpsertNotificationReportGeneratorLog(ctx context.Context, arg UpsertNotificationReportGeneratorLogParams) error diff --git a/coderd/database/queries.sql.go b/coderd/database/queries.sql.go index dc646121dc..b53cd07702 100644 --- a/coderd/database/queries.sql.go +++ b/coderd/database/queries.sql.go @@ -15039,6 +15039,40 @@ func (q *sqlQuerier) DeleteMCPServerConfigByID(ctx context.Context, id uuid.UUID return err } +const deleteMCPServerUserHeaderValues = `-- name: DeleteMCPServerUserHeaderValues :exec +DELETE FROM + mcp_server_user_header_values +WHERE + mcp_server_config_id = $1::uuid + AND user_id = $2::uuid +` + +type DeleteMCPServerUserHeaderValuesParams struct { + MCPServerConfigID uuid.UUID `db:"mcp_server_config_id" json:"mcp_server_config_id"` + UserID uuid.UUID `db:"user_id" json:"user_id"` +} + +func (q *sqlQuerier) DeleteMCPServerUserHeaderValues(ctx context.Context, arg DeleteMCPServerUserHeaderValuesParams) error { + _, err := q.db.ExecContext(ctx, deleteMCPServerUserHeaderValues, arg.MCPServerConfigID, arg.UserID) + return err +} + +const deleteMCPServerUserHeaderValuesByConfigID = `-- name: DeleteMCPServerUserHeaderValuesByConfigID :exec +DELETE FROM + mcp_server_user_header_values +WHERE + mcp_server_config_id = $1::uuid +` + +// Deletes every user's stored header values for the given MCP server +// config. Use when the admin changes auth_type away from custom_headers +// or alters custom_headers_user_keys so stale credentials do not +// silently reactivate when the key set is restored. +func (q *sqlQuerier) DeleteMCPServerUserHeaderValuesByConfigID(ctx context.Context, mcpServerConfigID uuid.UUID) error { + _, err := q.db.ExecContext(ctx, deleteMCPServerUserHeaderValuesByConfigID, mcpServerConfigID) + return err +} + const deleteMCPServerUserToken = `-- name: DeleteMCPServerUserToken :exec DELETE FROM mcp_server_user_tokens @@ -15059,7 +15093,7 @@ func (q *sqlQuerier) DeleteMCPServerUserToken(ctx context.Context, arg DeleteMCP const getEnabledMCPServerConfigs = `-- name: GetEnabledMCPServerConfigs :many SELECT - id, display_name, slug, description, icon_url, transport, url, auth_type, oauth2_client_id, oauth2_client_secret, oauth2_client_secret_key_id, oauth2_auth_url, oauth2_token_url, oauth2_scopes, api_key_header, api_key_value, api_key_value_key_id, custom_headers, custom_headers_key_id, tool_allow_list, tool_deny_list, availability, enabled, created_by, updated_by, created_at, updated_at, model_intent, allow_in_plan_mode, forward_coder_headers + id, display_name, slug, description, icon_url, transport, url, auth_type, oauth2_client_id, oauth2_client_secret, oauth2_client_secret_key_id, oauth2_auth_url, oauth2_token_url, oauth2_scopes, api_key_header, api_key_value, api_key_value_key_id, custom_headers, custom_headers_key_id, tool_allow_list, tool_deny_list, availability, enabled, created_by, updated_by, created_at, updated_at, model_intent, allow_in_plan_mode, forward_coder_headers, custom_headers_user_keys, custom_headers_user_key_descriptions FROM mcp_server_configs WHERE @@ -15108,6 +15142,8 @@ func (q *sqlQuerier) GetEnabledMCPServerConfigs(ctx context.Context) ([]MCPServe &i.ModelIntent, &i.AllowInPlanMode, &i.ForwardCoderHeaders, + pq.Array(&i.CustomHeadersUserKeys), + &i.CustomHeadersUserKeyDescriptions, ); err != nil { return nil, err } @@ -15124,7 +15160,7 @@ func (q *sqlQuerier) GetEnabledMCPServerConfigs(ctx context.Context) ([]MCPServe const getForcedMCPServerConfigs = `-- name: GetForcedMCPServerConfigs :many SELECT - id, display_name, slug, description, icon_url, transport, url, auth_type, oauth2_client_id, oauth2_client_secret, oauth2_client_secret_key_id, oauth2_auth_url, oauth2_token_url, oauth2_scopes, api_key_header, api_key_value, api_key_value_key_id, custom_headers, custom_headers_key_id, tool_allow_list, tool_deny_list, availability, enabled, created_by, updated_by, created_at, updated_at, model_intent, allow_in_plan_mode, forward_coder_headers + id, display_name, slug, description, icon_url, transport, url, auth_type, oauth2_client_id, oauth2_client_secret, oauth2_client_secret_key_id, oauth2_auth_url, oauth2_token_url, oauth2_scopes, api_key_header, api_key_value, api_key_value_key_id, custom_headers, custom_headers_key_id, tool_allow_list, tool_deny_list, availability, enabled, created_by, updated_by, created_at, updated_at, model_intent, allow_in_plan_mode, forward_coder_headers, custom_headers_user_keys, custom_headers_user_key_descriptions FROM mcp_server_configs WHERE @@ -15174,6 +15210,8 @@ func (q *sqlQuerier) GetForcedMCPServerConfigs(ctx context.Context) ([]MCPServer &i.ModelIntent, &i.AllowInPlanMode, &i.ForwardCoderHeaders, + pq.Array(&i.CustomHeadersUserKeys), + &i.CustomHeadersUserKeyDescriptions, ); err != nil { return nil, err } @@ -15190,7 +15228,7 @@ func (q *sqlQuerier) GetForcedMCPServerConfigs(ctx context.Context) ([]MCPServer const getMCPServerConfigByID = `-- name: GetMCPServerConfigByID :one SELECT - id, display_name, slug, description, icon_url, transport, url, auth_type, oauth2_client_id, oauth2_client_secret, oauth2_client_secret_key_id, oauth2_auth_url, oauth2_token_url, oauth2_scopes, api_key_header, api_key_value, api_key_value_key_id, custom_headers, custom_headers_key_id, tool_allow_list, tool_deny_list, availability, enabled, created_by, updated_by, created_at, updated_at, model_intent, allow_in_plan_mode, forward_coder_headers + id, display_name, slug, description, icon_url, transport, url, auth_type, oauth2_client_id, oauth2_client_secret, oauth2_client_secret_key_id, oauth2_auth_url, oauth2_token_url, oauth2_scopes, api_key_header, api_key_value, api_key_value_key_id, custom_headers, custom_headers_key_id, tool_allow_list, tool_deny_list, availability, enabled, created_by, updated_by, created_at, updated_at, model_intent, allow_in_plan_mode, forward_coder_headers, custom_headers_user_keys, custom_headers_user_key_descriptions FROM mcp_server_configs WHERE @@ -15231,13 +15269,15 @@ func (q *sqlQuerier) GetMCPServerConfigByID(ctx context.Context, id uuid.UUID) ( &i.ModelIntent, &i.AllowInPlanMode, &i.ForwardCoderHeaders, + pq.Array(&i.CustomHeadersUserKeys), + &i.CustomHeadersUserKeyDescriptions, ) return i, err } const getMCPServerConfigBySlug = `-- name: GetMCPServerConfigBySlug :one SELECT - id, display_name, slug, description, icon_url, transport, url, auth_type, oauth2_client_id, oauth2_client_secret, oauth2_client_secret_key_id, oauth2_auth_url, oauth2_token_url, oauth2_scopes, api_key_header, api_key_value, api_key_value_key_id, custom_headers, custom_headers_key_id, tool_allow_list, tool_deny_list, availability, enabled, created_by, updated_by, created_at, updated_at, model_intent, allow_in_plan_mode, forward_coder_headers + id, display_name, slug, description, icon_url, transport, url, auth_type, oauth2_client_id, oauth2_client_secret, oauth2_client_secret_key_id, oauth2_auth_url, oauth2_token_url, oauth2_scopes, api_key_header, api_key_value, api_key_value_key_id, custom_headers, custom_headers_key_id, tool_allow_list, tool_deny_list, availability, enabled, created_by, updated_by, created_at, updated_at, model_intent, allow_in_plan_mode, forward_coder_headers, custom_headers_user_keys, custom_headers_user_key_descriptions FROM mcp_server_configs WHERE @@ -15278,13 +15318,15 @@ func (q *sqlQuerier) GetMCPServerConfigBySlug(ctx context.Context, slug string) &i.ModelIntent, &i.AllowInPlanMode, &i.ForwardCoderHeaders, + pq.Array(&i.CustomHeadersUserKeys), + &i.CustomHeadersUserKeyDescriptions, ) return i, err } const getMCPServerConfigs = `-- name: GetMCPServerConfigs :many SELECT - id, display_name, slug, description, icon_url, transport, url, auth_type, oauth2_client_id, oauth2_client_secret, oauth2_client_secret_key_id, oauth2_auth_url, oauth2_token_url, oauth2_scopes, api_key_header, api_key_value, api_key_value_key_id, custom_headers, custom_headers_key_id, tool_allow_list, tool_deny_list, availability, enabled, created_by, updated_by, created_at, updated_at, model_intent, allow_in_plan_mode, forward_coder_headers + id, display_name, slug, description, icon_url, transport, url, auth_type, oauth2_client_id, oauth2_client_secret, oauth2_client_secret_key_id, oauth2_auth_url, oauth2_token_url, oauth2_scopes, api_key_header, api_key_value, api_key_value_key_id, custom_headers, custom_headers_key_id, tool_allow_list, tool_deny_list, availability, enabled, created_by, updated_by, created_at, updated_at, model_intent, allow_in_plan_mode, forward_coder_headers, custom_headers_user_keys, custom_headers_user_key_descriptions FROM mcp_server_configs ORDER BY @@ -15331,6 +15373,8 @@ func (q *sqlQuerier) GetMCPServerConfigs(ctx context.Context) ([]MCPServerConfig &i.ModelIntent, &i.AllowInPlanMode, &i.ForwardCoderHeaders, + pq.Array(&i.CustomHeadersUserKeys), + &i.CustomHeadersUserKeyDescriptions, ); err != nil { return nil, err } @@ -15347,7 +15391,7 @@ func (q *sqlQuerier) GetMCPServerConfigs(ctx context.Context) ([]MCPServerConfig const getMCPServerConfigsByIDs = `-- name: GetMCPServerConfigsByIDs :many SELECT - id, display_name, slug, description, icon_url, transport, url, auth_type, oauth2_client_id, oauth2_client_secret, oauth2_client_secret_key_id, oauth2_auth_url, oauth2_token_url, oauth2_scopes, api_key_header, api_key_value, api_key_value_key_id, custom_headers, custom_headers_key_id, tool_allow_list, tool_deny_list, availability, enabled, created_by, updated_by, created_at, updated_at, model_intent, allow_in_plan_mode, forward_coder_headers + id, display_name, slug, description, icon_url, transport, url, auth_type, oauth2_client_id, oauth2_client_secret, oauth2_client_secret_key_id, oauth2_auth_url, oauth2_token_url, oauth2_scopes, api_key_header, api_key_value, api_key_value_key_id, custom_headers, custom_headers_key_id, tool_allow_list, tool_deny_list, availability, enabled, created_by, updated_by, created_at, updated_at, model_intent, allow_in_plan_mode, forward_coder_headers, custom_headers_user_keys, custom_headers_user_key_descriptions FROM mcp_server_configs WHERE @@ -15396,6 +15440,78 @@ func (q *sqlQuerier) GetMCPServerConfigsByIDs(ctx context.Context, ids []uuid.UU &i.ModelIntent, &i.AllowInPlanMode, &i.ForwardCoderHeaders, + pq.Array(&i.CustomHeadersUserKeys), + &i.CustomHeadersUserKeyDescriptions, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const getMCPServerUserHeaderValues = `-- name: GetMCPServerUserHeaderValues :one +SELECT + id, mcp_server_config_id, user_id, header_values, header_values_key_id, created_at, updated_at +FROM + mcp_server_user_header_values +WHERE + mcp_server_config_id = $1::uuid + AND user_id = $2::uuid +` + +type GetMCPServerUserHeaderValuesParams struct { + MCPServerConfigID uuid.UUID `db:"mcp_server_config_id" json:"mcp_server_config_id"` + UserID uuid.UUID `db:"user_id" json:"user_id"` +} + +func (q *sqlQuerier) GetMCPServerUserHeaderValues(ctx context.Context, arg GetMCPServerUserHeaderValuesParams) (McpServerUserHeaderValue, error) { + row := q.db.QueryRowContext(ctx, getMCPServerUserHeaderValues, arg.MCPServerConfigID, arg.UserID) + var i McpServerUserHeaderValue + err := row.Scan( + &i.ID, + &i.MCPServerConfigID, + &i.UserID, + &i.HeaderValues, + &i.HeaderValuesKeyID, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + +const getMCPServerUserHeaderValuesByUserID = `-- name: GetMCPServerUserHeaderValuesByUserID :many +SELECT + id, mcp_server_config_id, user_id, header_values, header_values_key_id, created_at, updated_at +FROM + mcp_server_user_header_values +WHERE + user_id = $1::uuid +` + +func (q *sqlQuerier) GetMCPServerUserHeaderValuesByUserID(ctx context.Context, userID uuid.UUID) ([]McpServerUserHeaderValue, error) { + rows, err := q.db.QueryContext(ctx, getMCPServerUserHeaderValuesByUserID, userID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []McpServerUserHeaderValue + for rows.Next() { + var i McpServerUserHeaderValue + if err := rows.Scan( + &i.ID, + &i.MCPServerConfigID, + &i.UserID, + &i.HeaderValues, + &i.HeaderValuesKeyID, + &i.CreatedAt, + &i.UpdatedAt, ); err != nil { return nil, err } @@ -15508,6 +15624,8 @@ INSERT INTO mcp_server_configs ( api_key_value_key_id, custom_headers, custom_headers_key_id, + custom_headers_user_keys, + custom_headers_user_key_descriptions, tool_allow_list, tool_deny_list, availability, @@ -15537,47 +15655,51 @@ INSERT INTO mcp_server_configs ( $17::text, $18::text, $19::text[], - $20::text[], - $21::text, - $22::boolean, - $23::boolean, + $20::jsonb, + $21::text[], + $22::text[], + $23::text, $24::boolean, $25::boolean, - $26::uuid, - $27::uuid + $26::boolean, + $27::boolean, + $28::uuid, + $29::uuid ) RETURNING - id, display_name, slug, description, icon_url, transport, url, auth_type, oauth2_client_id, oauth2_client_secret, oauth2_client_secret_key_id, oauth2_auth_url, oauth2_token_url, oauth2_scopes, api_key_header, api_key_value, api_key_value_key_id, custom_headers, custom_headers_key_id, tool_allow_list, tool_deny_list, availability, enabled, created_by, updated_by, created_at, updated_at, model_intent, allow_in_plan_mode, forward_coder_headers + id, display_name, slug, description, icon_url, transport, url, auth_type, oauth2_client_id, oauth2_client_secret, oauth2_client_secret_key_id, oauth2_auth_url, oauth2_token_url, oauth2_scopes, api_key_header, api_key_value, api_key_value_key_id, custom_headers, custom_headers_key_id, tool_allow_list, tool_deny_list, availability, enabled, created_by, updated_by, created_at, updated_at, model_intent, allow_in_plan_mode, forward_coder_headers, custom_headers_user_keys, custom_headers_user_key_descriptions ` type InsertMCPServerConfigParams struct { - DisplayName string `db:"display_name" json:"display_name"` - Slug string `db:"slug" json:"slug"` - Description string `db:"description" json:"description"` - IconURL string `db:"icon_url" json:"icon_url"` - Transport string `db:"transport" json:"transport"` - Url string `db:"url" json:"url"` - AuthType string `db:"auth_type" json:"auth_type"` - OAuth2ClientID string `db:"oauth2_client_id" json:"oauth2_client_id"` - OAuth2ClientSecret string `db:"oauth2_client_secret" json:"oauth2_client_secret"` - OAuth2ClientSecretKeyID sql.NullString `db:"oauth2_client_secret_key_id" json:"oauth2_client_secret_key_id"` - OAuth2AuthURL string `db:"oauth2_auth_url" json:"oauth2_auth_url"` - OAuth2TokenURL string `db:"oauth2_token_url" json:"oauth2_token_url"` - OAuth2Scopes string `db:"oauth2_scopes" json:"oauth2_scopes"` - APIKeyHeader string `db:"api_key_header" json:"api_key_header"` - APIKeyValue string `db:"api_key_value" json:"api_key_value"` - APIKeyValueKeyID sql.NullString `db:"api_key_value_key_id" json:"api_key_value_key_id"` - CustomHeaders string `db:"custom_headers" json:"custom_headers"` - CustomHeadersKeyID sql.NullString `db:"custom_headers_key_id" json:"custom_headers_key_id"` - ToolAllowList []string `db:"tool_allow_list" json:"tool_allow_list"` - ToolDenyList []string `db:"tool_deny_list" json:"tool_deny_list"` - Availability string `db:"availability" json:"availability"` - Enabled bool `db:"enabled" json:"enabled"` - ModelIntent bool `db:"model_intent" json:"model_intent"` - AllowInPlanMode bool `db:"allow_in_plan_mode" json:"allow_in_plan_mode"` - ForwardCoderHeaders bool `db:"forward_coder_headers" json:"forward_coder_headers"` - CreatedBy uuid.UUID `db:"created_by" json:"created_by"` - UpdatedBy uuid.UUID `db:"updated_by" json:"updated_by"` + DisplayName string `db:"display_name" json:"display_name"` + Slug string `db:"slug" json:"slug"` + Description string `db:"description" json:"description"` + IconURL string `db:"icon_url" json:"icon_url"` + Transport string `db:"transport" json:"transport"` + Url string `db:"url" json:"url"` + AuthType string `db:"auth_type" json:"auth_type"` + OAuth2ClientID string `db:"oauth2_client_id" json:"oauth2_client_id"` + OAuth2ClientSecret string `db:"oauth2_client_secret" json:"oauth2_client_secret"` + OAuth2ClientSecretKeyID sql.NullString `db:"oauth2_client_secret_key_id" json:"oauth2_client_secret_key_id"` + OAuth2AuthURL string `db:"oauth2_auth_url" json:"oauth2_auth_url"` + OAuth2TokenURL string `db:"oauth2_token_url" json:"oauth2_token_url"` + OAuth2Scopes string `db:"oauth2_scopes" json:"oauth2_scopes"` + APIKeyHeader string `db:"api_key_header" json:"api_key_header"` + APIKeyValue string `db:"api_key_value" json:"api_key_value"` + APIKeyValueKeyID sql.NullString `db:"api_key_value_key_id" json:"api_key_value_key_id"` + CustomHeaders string `db:"custom_headers" json:"custom_headers"` + CustomHeadersKeyID sql.NullString `db:"custom_headers_key_id" json:"custom_headers_key_id"` + CustomHeadersUserKeys []string `db:"custom_headers_user_keys" json:"custom_headers_user_keys"` + CustomHeadersUserKeyDescriptions json.RawMessage `db:"custom_headers_user_key_descriptions" json:"custom_headers_user_key_descriptions"` + ToolAllowList []string `db:"tool_allow_list" json:"tool_allow_list"` + ToolDenyList []string `db:"tool_deny_list" json:"tool_deny_list"` + Availability string `db:"availability" json:"availability"` + Enabled bool `db:"enabled" json:"enabled"` + ModelIntent bool `db:"model_intent" json:"model_intent"` + AllowInPlanMode bool `db:"allow_in_plan_mode" json:"allow_in_plan_mode"` + ForwardCoderHeaders bool `db:"forward_coder_headers" json:"forward_coder_headers"` + CreatedBy uuid.UUID `db:"created_by" json:"created_by"` + UpdatedBy uuid.UUID `db:"updated_by" json:"updated_by"` } func (q *sqlQuerier) InsertMCPServerConfig(ctx context.Context, arg InsertMCPServerConfigParams) (MCPServerConfig, error) { @@ -15600,6 +15722,8 @@ func (q *sqlQuerier) InsertMCPServerConfig(ctx context.Context, arg InsertMCPSer arg.APIKeyValueKeyID, arg.CustomHeaders, arg.CustomHeadersKeyID, + pq.Array(arg.CustomHeadersUserKeys), + arg.CustomHeadersUserKeyDescriptions, pq.Array(arg.ToolAllowList), pq.Array(arg.ToolDenyList), arg.Availability, @@ -15642,6 +15766,8 @@ func (q *sqlQuerier) InsertMCPServerConfig(ctx context.Context, arg InsertMCPSer &i.ModelIntent, &i.AllowInPlanMode, &i.ForwardCoderHeaders, + pq.Array(&i.CustomHeadersUserKeys), + &i.CustomHeadersUserKeyDescriptions, ) return i, err } @@ -15668,49 +15794,53 @@ SET api_key_value_key_id = $16::text, custom_headers = $17::text, custom_headers_key_id = $18::text, - tool_allow_list = $19::text[], - tool_deny_list = $20::text[], - availability = $21::text, - enabled = $22::boolean, - model_intent = $23::boolean, - allow_in_plan_mode = $24::boolean, - forward_coder_headers = $25::boolean, - updated_by = $26::uuid, + custom_headers_user_keys = $19::text[], + custom_headers_user_key_descriptions = $20::jsonb, + tool_allow_list = $21::text[], + tool_deny_list = $22::text[], + availability = $23::text, + enabled = $24::boolean, + model_intent = $25::boolean, + allow_in_plan_mode = $26::boolean, + forward_coder_headers = $27::boolean, + updated_by = $28::uuid, updated_at = NOW() WHERE - id = $27::uuid + id = $29::uuid RETURNING - id, display_name, slug, description, icon_url, transport, url, auth_type, oauth2_client_id, oauth2_client_secret, oauth2_client_secret_key_id, oauth2_auth_url, oauth2_token_url, oauth2_scopes, api_key_header, api_key_value, api_key_value_key_id, custom_headers, custom_headers_key_id, tool_allow_list, tool_deny_list, availability, enabled, created_by, updated_by, created_at, updated_at, model_intent, allow_in_plan_mode, forward_coder_headers + id, display_name, slug, description, icon_url, transport, url, auth_type, oauth2_client_id, oauth2_client_secret, oauth2_client_secret_key_id, oauth2_auth_url, oauth2_token_url, oauth2_scopes, api_key_header, api_key_value, api_key_value_key_id, custom_headers, custom_headers_key_id, tool_allow_list, tool_deny_list, availability, enabled, created_by, updated_by, created_at, updated_at, model_intent, allow_in_plan_mode, forward_coder_headers, custom_headers_user_keys, custom_headers_user_key_descriptions ` type UpdateMCPServerConfigParams struct { - DisplayName string `db:"display_name" json:"display_name"` - Slug string `db:"slug" json:"slug"` - Description string `db:"description" json:"description"` - IconURL string `db:"icon_url" json:"icon_url"` - Transport string `db:"transport" json:"transport"` - Url string `db:"url" json:"url"` - AuthType string `db:"auth_type" json:"auth_type"` - OAuth2ClientID string `db:"oauth2_client_id" json:"oauth2_client_id"` - OAuth2ClientSecret string `db:"oauth2_client_secret" json:"oauth2_client_secret"` - OAuth2ClientSecretKeyID sql.NullString `db:"oauth2_client_secret_key_id" json:"oauth2_client_secret_key_id"` - OAuth2AuthURL string `db:"oauth2_auth_url" json:"oauth2_auth_url"` - OAuth2TokenURL string `db:"oauth2_token_url" json:"oauth2_token_url"` - OAuth2Scopes string `db:"oauth2_scopes" json:"oauth2_scopes"` - APIKeyHeader string `db:"api_key_header" json:"api_key_header"` - APIKeyValue string `db:"api_key_value" json:"api_key_value"` - APIKeyValueKeyID sql.NullString `db:"api_key_value_key_id" json:"api_key_value_key_id"` - CustomHeaders string `db:"custom_headers" json:"custom_headers"` - CustomHeadersKeyID sql.NullString `db:"custom_headers_key_id" json:"custom_headers_key_id"` - ToolAllowList []string `db:"tool_allow_list" json:"tool_allow_list"` - ToolDenyList []string `db:"tool_deny_list" json:"tool_deny_list"` - Availability string `db:"availability" json:"availability"` - Enabled bool `db:"enabled" json:"enabled"` - ModelIntent bool `db:"model_intent" json:"model_intent"` - AllowInPlanMode bool `db:"allow_in_plan_mode" json:"allow_in_plan_mode"` - ForwardCoderHeaders bool `db:"forward_coder_headers" json:"forward_coder_headers"` - UpdatedBy uuid.UUID `db:"updated_by" json:"updated_by"` - ID uuid.UUID `db:"id" json:"id"` + DisplayName string `db:"display_name" json:"display_name"` + Slug string `db:"slug" json:"slug"` + Description string `db:"description" json:"description"` + IconURL string `db:"icon_url" json:"icon_url"` + Transport string `db:"transport" json:"transport"` + Url string `db:"url" json:"url"` + AuthType string `db:"auth_type" json:"auth_type"` + OAuth2ClientID string `db:"oauth2_client_id" json:"oauth2_client_id"` + OAuth2ClientSecret string `db:"oauth2_client_secret" json:"oauth2_client_secret"` + OAuth2ClientSecretKeyID sql.NullString `db:"oauth2_client_secret_key_id" json:"oauth2_client_secret_key_id"` + OAuth2AuthURL string `db:"oauth2_auth_url" json:"oauth2_auth_url"` + OAuth2TokenURL string `db:"oauth2_token_url" json:"oauth2_token_url"` + OAuth2Scopes string `db:"oauth2_scopes" json:"oauth2_scopes"` + APIKeyHeader string `db:"api_key_header" json:"api_key_header"` + APIKeyValue string `db:"api_key_value" json:"api_key_value"` + APIKeyValueKeyID sql.NullString `db:"api_key_value_key_id" json:"api_key_value_key_id"` + CustomHeaders string `db:"custom_headers" json:"custom_headers"` + CustomHeadersKeyID sql.NullString `db:"custom_headers_key_id" json:"custom_headers_key_id"` + CustomHeadersUserKeys []string `db:"custom_headers_user_keys" json:"custom_headers_user_keys"` + CustomHeadersUserKeyDescriptions json.RawMessage `db:"custom_headers_user_key_descriptions" json:"custom_headers_user_key_descriptions"` + ToolAllowList []string `db:"tool_allow_list" json:"tool_allow_list"` + ToolDenyList []string `db:"tool_deny_list" json:"tool_deny_list"` + Availability string `db:"availability" json:"availability"` + Enabled bool `db:"enabled" json:"enabled"` + ModelIntent bool `db:"model_intent" json:"model_intent"` + AllowInPlanMode bool `db:"allow_in_plan_mode" json:"allow_in_plan_mode"` + ForwardCoderHeaders bool `db:"forward_coder_headers" json:"forward_coder_headers"` + UpdatedBy uuid.UUID `db:"updated_by" json:"updated_by"` + ID uuid.UUID `db:"id" json:"id"` } func (q *sqlQuerier) UpdateMCPServerConfig(ctx context.Context, arg UpdateMCPServerConfigParams) (MCPServerConfig, error) { @@ -15733,6 +15863,8 @@ func (q *sqlQuerier) UpdateMCPServerConfig(ctx context.Context, arg UpdateMCPSer arg.APIKeyValueKeyID, arg.CustomHeaders, arg.CustomHeadersKeyID, + pq.Array(arg.CustomHeadersUserKeys), + arg.CustomHeadersUserKeyDescriptions, pq.Array(arg.ToolAllowList), pq.Array(arg.ToolDenyList), arg.Availability, @@ -15775,6 +15907,55 @@ func (q *sqlQuerier) UpdateMCPServerConfig(ctx context.Context, arg UpdateMCPSer &i.ModelIntent, &i.AllowInPlanMode, &i.ForwardCoderHeaders, + pq.Array(&i.CustomHeadersUserKeys), + &i.CustomHeadersUserKeyDescriptions, + ) + return i, err +} + +const upsertMCPServerUserHeaderValues = `-- name: UpsertMCPServerUserHeaderValues :one +INSERT INTO mcp_server_user_header_values ( + mcp_server_config_id, + user_id, + header_values, + header_values_key_id +) VALUES ( + $1::uuid, + $2::uuid, + $3::text, + $4::text +) +ON CONFLICT (mcp_server_config_id, user_id) DO UPDATE SET + header_values = $3::text, + header_values_key_id = $4::text, + updated_at = NOW() +RETURNING + id, mcp_server_config_id, user_id, header_values, header_values_key_id, created_at, updated_at +` + +type UpsertMCPServerUserHeaderValuesParams struct { + MCPServerConfigID uuid.UUID `db:"mcp_server_config_id" json:"mcp_server_config_id"` + UserID uuid.UUID `db:"user_id" json:"user_id"` + HeaderValues string `db:"header_values" json:"header_values"` + HeaderValuesKeyID sql.NullString `db:"header_values_key_id" json:"header_values_key_id"` +} + +func (q *sqlQuerier) UpsertMCPServerUserHeaderValues(ctx context.Context, arg UpsertMCPServerUserHeaderValuesParams) (McpServerUserHeaderValue, error) { + row := q.db.QueryRowContext(ctx, upsertMCPServerUserHeaderValues, + arg.MCPServerConfigID, + arg.UserID, + arg.HeaderValues, + arg.HeaderValuesKeyID, + ) + var i McpServerUserHeaderValue + err := row.Scan( + &i.ID, + &i.MCPServerConfigID, + &i.UserID, + &i.HeaderValues, + &i.HeaderValuesKeyID, + &i.CreatedAt, + &i.UpdatedAt, ) return i, err } diff --git a/coderd/database/queries/mcpserverconfigs.sql b/coderd/database/queries/mcpserverconfigs.sql index 3d05a2b102..57a290e731 100644 --- a/coderd/database/queries/mcpserverconfigs.sql +++ b/coderd/database/queries/mcpserverconfigs.sql @@ -73,6 +73,8 @@ INSERT INTO mcp_server_configs ( api_key_value_key_id, custom_headers, custom_headers_key_id, + custom_headers_user_keys, + custom_headers_user_key_descriptions, tool_allow_list, tool_deny_list, availability, @@ -101,6 +103,8 @@ INSERT INTO mcp_server_configs ( sqlc.narg('api_key_value_key_id')::text, @custom_headers::text, sqlc.narg('custom_headers_key_id')::text, + @custom_headers_user_keys::text[], + @custom_headers_user_key_descriptions::jsonb, @tool_allow_list::text[], @tool_deny_list::text[], @availability::text, @@ -136,6 +140,8 @@ SET api_key_value_key_id = sqlc.narg('api_key_value_key_id')::text, custom_headers = @custom_headers::text, custom_headers_key_id = sqlc.narg('custom_headers_key_id')::text, + custom_headers_user_keys = @custom_headers_user_keys::text[], + custom_headers_user_key_descriptions = @custom_headers_user_key_descriptions::jsonb, tool_allow_list = @tool_allow_list::text[], tool_deny_list = @tool_deny_list::text[], availability = @availability::text, @@ -211,6 +217,59 @@ WHERE mcp_server_config_id = @mcp_server_config_id::uuid AND user_id = @user_id::uuid; +-- name: GetMCPServerUserHeaderValues :one +SELECT + * +FROM + mcp_server_user_header_values +WHERE + mcp_server_config_id = @mcp_server_config_id::uuid + AND user_id = @user_id::uuid; + +-- name: GetMCPServerUserHeaderValuesByUserID :many +SELECT + * +FROM + mcp_server_user_header_values +WHERE + user_id = @user_id::uuid; + +-- name: UpsertMCPServerUserHeaderValues :one +INSERT INTO mcp_server_user_header_values ( + mcp_server_config_id, + user_id, + header_values, + header_values_key_id +) VALUES ( + @mcp_server_config_id::uuid, + @user_id::uuid, + @header_values::text, + sqlc.narg('header_values_key_id')::text +) +ON CONFLICT (mcp_server_config_id, user_id) DO UPDATE SET + header_values = @header_values::text, + header_values_key_id = sqlc.narg('header_values_key_id')::text, + updated_at = NOW() +RETURNING + *; + +-- name: DeleteMCPServerUserHeaderValues :exec +DELETE FROM + mcp_server_user_header_values +WHERE + mcp_server_config_id = @mcp_server_config_id::uuid + AND user_id = @user_id::uuid; + +-- name: DeleteMCPServerUserHeaderValuesByConfigID :exec +-- Deletes every user's stored header values for the given MCP server +-- config. Use when the admin changes auth_type away from custom_headers +-- or alters custom_headers_user_keys so stale credentials do not +-- silently reactivate when the key set is restored. +DELETE FROM + mcp_server_user_header_values +WHERE + mcp_server_config_id = @mcp_server_config_id::uuid; + -- name: CleanupDeletedMCPServerIDsFromChats :exec UPDATE chats SET mcp_server_ids = ( diff --git a/coderd/database/unique_constraint.go b/coderd/database/unique_constraint.go index 3d5e5dabcf..a83e6064c7 100644 --- a/coderd/database/unique_constraint.go +++ b/coderd/database/unique_constraint.go @@ -51,6 +51,8 @@ const ( UniqueLicensesPkey UniqueConstraint = "licenses_pkey" // ALTER TABLE ONLY licenses ADD CONSTRAINT licenses_pkey PRIMARY KEY (id); UniqueMcpServerConfigsPkey UniqueConstraint = "mcp_server_configs_pkey" // ALTER TABLE ONLY mcp_server_configs ADD CONSTRAINT mcp_server_configs_pkey PRIMARY KEY (id); UniqueMcpServerConfigsSlugKey UniqueConstraint = "mcp_server_configs_slug_key" // ALTER TABLE ONLY mcp_server_configs ADD CONSTRAINT mcp_server_configs_slug_key UNIQUE (slug); + UniqueMcpServerUserHeaderValuesMcpServerConfigIDUserIDKey UniqueConstraint = "mcp_server_user_header_values_mcp_server_config_id_user_id_key" // ALTER TABLE ONLY mcp_server_user_header_values ADD CONSTRAINT mcp_server_user_header_values_mcp_server_config_id_user_id_key UNIQUE (mcp_server_config_id, user_id); + UniqueMcpServerUserHeaderValuesPkey UniqueConstraint = "mcp_server_user_header_values_pkey" // ALTER TABLE ONLY mcp_server_user_header_values ADD CONSTRAINT mcp_server_user_header_values_pkey PRIMARY KEY (id); UniqueMcpServerUserTokensMcpServerConfigIDUserIDKey UniqueConstraint = "mcp_server_user_tokens_mcp_server_config_id_user_id_key" // ALTER TABLE ONLY mcp_server_user_tokens ADD CONSTRAINT mcp_server_user_tokens_mcp_server_config_id_user_id_key UNIQUE (mcp_server_config_id, user_id); UniqueMcpServerUserTokensPkey UniqueConstraint = "mcp_server_user_tokens_pkey" // ALTER TABLE ONLY mcp_server_user_tokens ADD CONSTRAINT mcp_server_user_tokens_pkey PRIMARY KEY (id); UniqueNotificationMessagesPkey UniqueConstraint = "notification_messages_pkey" // ALTER TABLE ONLY notification_messages ADD CONSTRAINT notification_messages_pkey PRIMARY KEY (id); diff --git a/coderd/mcp.go b/coderd/mcp.go index 3e0a5829f7..c2beef973c 100644 --- a/coderd/mcp.go +++ b/coderd/mcp.go @@ -259,33 +259,35 @@ func (api *API) createMCPServerConfig(rw http.ResponseWriter, r *http.Request) { } inserted, err := api.Database.InsertMCPServerConfig(ctx, database.InsertMCPServerConfigParams{ - DisplayName: strings.TrimSpace(req.DisplayName), - Slug: strings.TrimSpace(req.Slug), - Description: strings.TrimSpace(req.Description), - IconURL: strings.TrimSpace(req.IconURL), - Transport: strings.TrimSpace(req.Transport), - Url: strings.TrimSpace(req.URL), - AuthType: strings.TrimSpace(req.AuthType), - OAuth2ClientID: "", - OAuth2ClientSecret: "", - OAuth2ClientSecretKeyID: sql.NullString{}, - OAuth2AuthURL: "", - OAuth2TokenURL: "", - OAuth2Scopes: "", - APIKeyHeader: strings.TrimSpace(req.APIKeyHeader), - APIKeyValue: strings.TrimSpace(req.APIKeyValue), - APIKeyValueKeyID: sql.NullString{}, - CustomHeaders: customHeadersJSON, - CustomHeadersKeyID: sql.NullString{}, - ToolAllowList: coalesceStringSlice(trimStringSlice(req.ToolAllowList)), - ToolDenyList: coalesceStringSlice(trimStringSlice(req.ToolDenyList)), - Availability: strings.TrimSpace(req.Availability), - Enabled: req.Enabled, - ModelIntent: req.ModelIntent, - AllowInPlanMode: req.AllowInPlanMode, - ForwardCoderHeaders: req.ForwardCoderHeaders, - CreatedBy: apiKey.UserID, - UpdatedBy: apiKey.UserID, + DisplayName: strings.TrimSpace(req.DisplayName), + Slug: strings.TrimSpace(req.Slug), + Description: strings.TrimSpace(req.Description), + IconURL: strings.TrimSpace(req.IconURL), + Transport: strings.TrimSpace(req.Transport), + Url: strings.TrimSpace(req.URL), + AuthType: strings.TrimSpace(req.AuthType), + OAuth2ClientID: "", + OAuth2ClientSecret: "", + OAuth2ClientSecretKeyID: sql.NullString{}, + OAuth2AuthURL: "", + OAuth2TokenURL: "", + OAuth2Scopes: "", + APIKeyHeader: strings.TrimSpace(req.APIKeyHeader), + APIKeyValue: strings.TrimSpace(req.APIKeyValue), + APIKeyValueKeyID: sql.NullString{}, + CustomHeaders: customHeadersJSON, + CustomHeadersKeyID: sql.NullString{}, + CustomHeadersUserKeys: nil, + CustomHeadersUserKeyDescriptions: nil, + ToolAllowList: coalesceStringSlice(trimStringSlice(req.ToolAllowList)), + ToolDenyList: coalesceStringSlice(trimStringSlice(req.ToolDenyList)), + Availability: strings.TrimSpace(req.Availability), + Enabled: req.Enabled, + ModelIntent: req.ModelIntent, + AllowInPlanMode: req.AllowInPlanMode, + ForwardCoderHeaders: req.ForwardCoderHeaders, + CreatedBy: apiKey.UserID, + UpdatedBy: apiKey.UserID, }) if err != nil { switch { @@ -347,33 +349,35 @@ func (api *API) createMCPServerConfig(rw http.ResponseWriter, r *http.Request) { // Update the record with discovered OAuth2 credentials. updated, err := api.Database.UpdateMCPServerConfig(ctx, database.UpdateMCPServerConfigParams{ - ID: inserted.ID, - DisplayName: inserted.DisplayName, - Slug: inserted.Slug, - Description: inserted.Description, - IconURL: inserted.IconURL, - Transport: inserted.Transport, - Url: inserted.Url, - AuthType: inserted.AuthType, - OAuth2ClientID: result.clientID, - OAuth2ClientSecret: result.clientSecret, - OAuth2ClientSecretKeyID: sql.NullString{}, - OAuth2AuthURL: result.authURL, - OAuth2TokenURL: result.tokenURL, - OAuth2Scopes: oauth2Scopes, - APIKeyHeader: inserted.APIKeyHeader, - APIKeyValue: inserted.APIKeyValue, - APIKeyValueKeyID: inserted.APIKeyValueKeyID, - CustomHeaders: inserted.CustomHeaders, - CustomHeadersKeyID: inserted.CustomHeadersKeyID, - ToolAllowList: inserted.ToolAllowList, - ToolDenyList: inserted.ToolDenyList, - Availability: inserted.Availability, - Enabled: inserted.Enabled, - ModelIntent: inserted.ModelIntent, - AllowInPlanMode: inserted.AllowInPlanMode, - ForwardCoderHeaders: inserted.ForwardCoderHeaders, - UpdatedBy: apiKey.UserID, + ID: inserted.ID, + DisplayName: inserted.DisplayName, + Slug: inserted.Slug, + Description: inserted.Description, + IconURL: inserted.IconURL, + Transport: inserted.Transport, + Url: inserted.Url, + AuthType: inserted.AuthType, + OAuth2ClientID: result.clientID, + OAuth2ClientSecret: result.clientSecret, + OAuth2ClientSecretKeyID: sql.NullString{}, + OAuth2AuthURL: result.authURL, + OAuth2TokenURL: result.tokenURL, + OAuth2Scopes: oauth2Scopes, + APIKeyHeader: inserted.APIKeyHeader, + APIKeyValue: inserted.APIKeyValue, + APIKeyValueKeyID: inserted.APIKeyValueKeyID, + CustomHeaders: inserted.CustomHeaders, + CustomHeadersKeyID: inserted.CustomHeadersKeyID, + CustomHeadersUserKeys: inserted.CustomHeadersUserKeys, + CustomHeadersUserKeyDescriptions: inserted.CustomHeadersUserKeyDescriptions, + ToolAllowList: inserted.ToolAllowList, + ToolDenyList: inserted.ToolDenyList, + Availability: inserted.Availability, + Enabled: inserted.Enabled, + ModelIntent: inserted.ModelIntent, + AllowInPlanMode: inserted.AllowInPlanMode, + ForwardCoderHeaders: inserted.ForwardCoderHeaders, + UpdatedBy: apiKey.UserID, }) if err != nil { httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ @@ -418,33 +422,35 @@ func (api *API) createMCPServerConfig(rw http.ResponseWriter, r *http.Request) { } inserted, err := api.Database.InsertMCPServerConfig(ctx, database.InsertMCPServerConfigParams{ - DisplayName: strings.TrimSpace(req.DisplayName), - Slug: strings.TrimSpace(req.Slug), - Description: strings.TrimSpace(req.Description), - IconURL: strings.TrimSpace(req.IconURL), - Transport: strings.TrimSpace(req.Transport), - Url: strings.TrimSpace(req.URL), - AuthType: strings.TrimSpace(req.AuthType), - OAuth2ClientID: strings.TrimSpace(req.OAuth2ClientID), - OAuth2ClientSecret: strings.TrimSpace(req.OAuth2ClientSecret), - OAuth2ClientSecretKeyID: sql.NullString{}, - OAuth2AuthURL: strings.TrimSpace(req.OAuth2AuthURL), - OAuth2TokenURL: strings.TrimSpace(req.OAuth2TokenURL), - OAuth2Scopes: strings.TrimSpace(req.OAuth2Scopes), - APIKeyHeader: strings.TrimSpace(req.APIKeyHeader), - APIKeyValue: strings.TrimSpace(req.APIKeyValue), - APIKeyValueKeyID: sql.NullString{}, - CustomHeaders: customHeadersJSON, - CustomHeadersKeyID: sql.NullString{}, - ToolAllowList: coalesceStringSlice(trimStringSlice(req.ToolAllowList)), - ToolDenyList: coalesceStringSlice(trimStringSlice(req.ToolDenyList)), - Availability: strings.TrimSpace(req.Availability), - Enabled: req.Enabled, - ModelIntent: req.ModelIntent, - AllowInPlanMode: req.AllowInPlanMode, - ForwardCoderHeaders: req.ForwardCoderHeaders, - CreatedBy: apiKey.UserID, - UpdatedBy: apiKey.UserID, + DisplayName: strings.TrimSpace(req.DisplayName), + Slug: strings.TrimSpace(req.Slug), + Description: strings.TrimSpace(req.Description), + IconURL: strings.TrimSpace(req.IconURL), + Transport: strings.TrimSpace(req.Transport), + Url: strings.TrimSpace(req.URL), + AuthType: strings.TrimSpace(req.AuthType), + OAuth2ClientID: strings.TrimSpace(req.OAuth2ClientID), + OAuth2ClientSecret: strings.TrimSpace(req.OAuth2ClientSecret), + OAuth2ClientSecretKeyID: sql.NullString{}, + OAuth2AuthURL: strings.TrimSpace(req.OAuth2AuthURL), + OAuth2TokenURL: strings.TrimSpace(req.OAuth2TokenURL), + OAuth2Scopes: strings.TrimSpace(req.OAuth2Scopes), + APIKeyHeader: strings.TrimSpace(req.APIKeyHeader), + APIKeyValue: strings.TrimSpace(req.APIKeyValue), + APIKeyValueKeyID: sql.NullString{}, + CustomHeaders: customHeadersJSON, + CustomHeadersKeyID: sql.NullString{}, + CustomHeadersUserKeys: nil, + CustomHeadersUserKeyDescriptions: nil, + ToolAllowList: coalesceStringSlice(trimStringSlice(req.ToolAllowList)), + ToolDenyList: coalesceStringSlice(trimStringSlice(req.ToolDenyList)), + Availability: strings.TrimSpace(req.Availability), + Enabled: req.Enabled, + ModelIntent: req.ModelIntent, + AllowInPlanMode: req.AllowInPlanMode, + ForwardCoderHeaders: req.ForwardCoderHeaders, + CreatedBy: apiKey.UserID, + UpdatedBy: apiKey.UserID, }) if err != nil { switch { @@ -767,33 +773,35 @@ func (api *API) updateMCPServerConfig(rw http.ResponseWriter, r *http.Request) { } updated, err = tx.UpdateMCPServerConfig(ctx, database.UpdateMCPServerConfigParams{ - DisplayName: displayName, - Slug: slug, - Description: description, - IconURL: iconURL, - Transport: transport, - Url: serverURL, - AuthType: authType, - OAuth2ClientID: oauth2ClientID, - OAuth2ClientSecret: oauth2ClientSecret, - OAuth2ClientSecretKeyID: oauth2ClientSecretKeyID, - OAuth2AuthURL: oauth2AuthURL, - OAuth2TokenURL: oauth2TokenURL, - OAuth2Scopes: oauth2Scopes, - APIKeyHeader: apiKeyHeader, - APIKeyValue: apiKeyValue, - APIKeyValueKeyID: apiKeyValueKeyID, - CustomHeaders: customHeaders, - CustomHeadersKeyID: customHeadersKeyID, - ToolAllowList: toolAllowList, - ToolDenyList: toolDenyList, - Availability: availability, - Enabled: enabled, - ModelIntent: modelIntent, - AllowInPlanMode: allowInPlanMode, - ForwardCoderHeaders: forwardCoderHeaders, - UpdatedBy: apiKey.UserID, - ID: existing.ID, + DisplayName: displayName, + Slug: slug, + Description: description, + IconURL: iconURL, + Transport: transport, + Url: serverURL, + AuthType: authType, + OAuth2ClientID: oauth2ClientID, + OAuth2ClientSecret: oauth2ClientSecret, + OAuth2ClientSecretKeyID: oauth2ClientSecretKeyID, + OAuth2AuthURL: oauth2AuthURL, + OAuth2TokenURL: oauth2TokenURL, + OAuth2Scopes: oauth2Scopes, + APIKeyHeader: apiKeyHeader, + APIKeyValue: apiKeyValue, + APIKeyValueKeyID: apiKeyValueKeyID, + CustomHeaders: customHeaders, + CustomHeadersKeyID: customHeadersKeyID, + CustomHeadersUserKeys: existing.CustomHeadersUserKeys, + CustomHeadersUserKeyDescriptions: existing.CustomHeadersUserKeyDescriptions, + ToolAllowList: toolAllowList, + ToolDenyList: toolDenyList, + Availability: availability, + Enabled: enabled, + ModelIntent: modelIntent, + AllowInPlanMode: allowInPlanMode, + ForwardCoderHeaders: forwardCoderHeaders, + UpdatedBy: apiKey.UserID, + ID: existing.ID, }) return err }, nil) diff --git a/enterprise/dbcrypt/dbcrypt.go b/enterprise/dbcrypt/dbcrypt.go index 44cdb5554e..4239bd899c 100644 --- a/enterprise/dbcrypt/dbcrypt.go +++ b/enterprise/dbcrypt/dbcrypt.go @@ -702,6 +702,12 @@ func (db *dbCrypt) decryptMCPServerUserToken(tok *database.MCPServerUserToken) e return db.decryptField(&tok.RefreshToken, tok.RefreshTokenKeyID) } +// decryptMCPServerUserHeaderValues decrypts all encrypted fields on a +// single McpServerUserHeaderValue in place. +func (db *dbCrypt) decryptMCPServerUserHeaderValues(row *database.McpServerUserHeaderValue) error { + return db.decryptField(&row.HeaderValues, row.HeaderValuesKeyID) +} + func (db *dbCrypt) GetMCPServerConfigByID(ctx context.Context, id uuid.UUID) (database.MCPServerConfig, error) { cfg, err := db.Store.GetMCPServerConfigByID(ctx, id) if err != nil { @@ -876,6 +882,47 @@ func (db *dbCrypt) UpsertMCPServerUserToken(ctx context.Context, params database return tok, nil } +func (db *dbCrypt) GetMCPServerUserHeaderValues(ctx context.Context, arg database.GetMCPServerUserHeaderValuesParams) (database.McpServerUserHeaderValue, error) { + row, err := db.Store.GetMCPServerUserHeaderValues(ctx, arg) + if err != nil { + return database.McpServerUserHeaderValue{}, err + } + if err := db.decryptMCPServerUserHeaderValues(&row); err != nil { + return database.McpServerUserHeaderValue{}, err + } + return row, nil +} + +func (db *dbCrypt) GetMCPServerUserHeaderValuesByUserID(ctx context.Context, userID uuid.UUID) ([]database.McpServerUserHeaderValue, error) { + rows, err := db.Store.GetMCPServerUserHeaderValuesByUserID(ctx, userID) + if err != nil { + return nil, err + } + for i := range rows { + if err := db.decryptMCPServerUserHeaderValues(&rows[i]); err != nil { + return nil, err + } + } + return rows, nil +} + +func (db *dbCrypt) UpsertMCPServerUserHeaderValues(ctx context.Context, params database.UpsertMCPServerUserHeaderValuesParams) (database.McpServerUserHeaderValue, error) { + if strings.TrimSpace(params.HeaderValues) == "" { + params.HeaderValuesKeyID = sql.NullString{} + } else if err := db.encryptField(¶ms.HeaderValues, ¶ms.HeaderValuesKeyID); err != nil { + return database.McpServerUserHeaderValue{}, err + } + + row, err := db.Store.UpsertMCPServerUserHeaderValues(ctx, params) + if err != nil { + return database.McpServerUserHeaderValue{}, err + } + if err := db.decryptMCPServerUserHeaderValues(&row); err != nil { + return database.McpServerUserHeaderValue{}, err + } + return row, nil +} + func (db *dbCrypt) CreateUserSecret(ctx context.Context, params database.CreateUserSecretParams) (database.UserSecret, error) { if err := db.encryptField(¶ms.Value, ¶ms.ValueKeyID); err != nil { return database.UserSecret{}, err diff --git a/enterprise/dbcrypt/dbcrypt_internal_test.go b/enterprise/dbcrypt/dbcrypt_internal_test.go index e5a433399b..7e01905a92 100644 --- a/enterprise/dbcrypt/dbcrypt_internal_test.go +++ b/enterprise/dbcrypt/dbcrypt_internal_test.go @@ -1032,22 +1032,24 @@ func TestMCPServerConfigs(t *testing.T) { newHeaders = `{"X-New":"new-value"}` ) updated, err := crypt.UpdateMCPServerConfig(ctx, database.UpdateMCPServerConfigParams{ - ID: cfg.ID, - DisplayName: cfg.DisplayName, - Slug: cfg.Slug, - Description: cfg.Description, - Url: cfg.Url, - Transport: cfg.Transport, - AuthType: cfg.AuthType, - OAuth2ClientID: cfg.OAuth2ClientID, - OAuth2ClientSecret: newSecret, - APIKeyValue: newAPIKey, - CustomHeaders: newHeaders, - ToolAllowList: cfg.ToolAllowList, - ToolDenyList: cfg.ToolDenyList, - Availability: cfg.Availability, - Enabled: cfg.Enabled, - UpdatedBy: cfg.CreatedBy.UUID, + ID: cfg.ID, + DisplayName: cfg.DisplayName, + Slug: cfg.Slug, + Description: cfg.Description, + Url: cfg.Url, + Transport: cfg.Transport, + AuthType: cfg.AuthType, + OAuth2ClientID: cfg.OAuth2ClientID, + OAuth2ClientSecret: newSecret, + APIKeyValue: newAPIKey, + CustomHeaders: newHeaders, + CustomHeadersUserKeys: cfg.CustomHeadersUserKeys, + CustomHeadersUserKeyDescriptions: cfg.CustomHeadersUserKeyDescriptions, + ToolAllowList: cfg.ToolAllowList, + ToolDenyList: cfg.ToolDenyList, + Availability: cfg.Availability, + Enabled: cfg.Enabled, + UpdatedBy: cfg.CreatedBy.UUID, }) require.NoError(t, err) requireMCPServerConfigDecrypted(t, updated, ciphers, newSecret, newAPIKey, newHeaders) @@ -1570,6 +1572,98 @@ func TestMCPServerUserTokens(t *testing.T) { }) } +func TestMCPServerUserHeaderValues(t *testing.T) { + t.Parallel() + ctx := context.Background() + + const headerValues = `{"X-User-Token":"super-secret-user-token"}` + + // insertConfigAndValues creates a user, an MCP server config with a + // user-set custom header, and the user-supplied values row through the + // encrypted store. + insertConfigAndValues := func( + t *testing.T, + crypt *dbCrypt, + ciphers []Cipher, + ) (database.MCPServerConfig, database.McpServerUserHeaderValue) { + t.Helper() + user := dbgen.User(t, crypt, database.User{}) + cfg := dbgen.MCPServerConfig(t, crypt, database.MCPServerConfig{ + DisplayName: "Header Values Test MCP", + AuthType: "custom_headers", + CustomHeadersUserKeys: []string{"X-User-Token"}, + CreatedBy: uuid.NullUUID{UUID: user.ID, Valid: true}, + UpdatedBy: uuid.NullUUID{UUID: user.ID, Valid: true}, + }) + + row, err := crypt.UpsertMCPServerUserHeaderValues(ctx, database.UpsertMCPServerUserHeaderValuesParams{ + MCPServerConfigID: cfg.ID, + UserID: user.ID, + HeaderValues: headerValues, + }) + require.NoError(t, err) + require.Equal(t, headerValues, row.HeaderValues) + require.Equal(t, ciphers[0].HexDigest(), row.HeaderValuesKeyID.String) + return cfg, row + } + + t.Run("UpsertMCPServerUserHeaderValues", func(t *testing.T) { + t.Parallel() + db, crypt, ciphers := setup(t) + cfg, row := insertConfigAndValues(t, crypt, ciphers) + + // Verify the raw DB value is encrypted. + rawRow, err := db.GetMCPServerUserHeaderValues(ctx, database.GetMCPServerUserHeaderValuesParams{ + MCPServerConfigID: cfg.ID, + UserID: row.UserID, + }) + require.NoError(t, err) + requireEncryptedEquals(t, ciphers[0], rawRow.HeaderValues, headerValues) + }) + + t.Run("GetMCPServerUserHeaderValues", func(t *testing.T) { + t.Parallel() + db, crypt, ciphers := setup(t) + cfg, row := insertConfigAndValues(t, crypt, ciphers) + + got, err := crypt.GetMCPServerUserHeaderValues(ctx, database.GetMCPServerUserHeaderValuesParams{ + MCPServerConfigID: cfg.ID, + UserID: row.UserID, + }) + require.NoError(t, err) + require.Equal(t, headerValues, got.HeaderValues) + require.Equal(t, ciphers[0].HexDigest(), got.HeaderValuesKeyID.String) + + // Raw values must be encrypted. + rawRow, err := db.GetMCPServerUserHeaderValues(ctx, database.GetMCPServerUserHeaderValuesParams{ + MCPServerConfigID: cfg.ID, + UserID: row.UserID, + }) + require.NoError(t, err) + requireEncryptedEquals(t, ciphers[0], rawRow.HeaderValues, headerValues) + }) + + t.Run("GetMCPServerUserHeaderValuesByUserID", func(t *testing.T) { + t.Parallel() + db, crypt, ciphers := setup(t) + cfg, row := insertConfigAndValues(t, crypt, ciphers) + + rows, err := crypt.GetMCPServerUserHeaderValuesByUserID(ctx, row.UserID) + require.NoError(t, err) + require.Len(t, rows, 1) + require.Equal(t, headerValues, rows[0].HeaderValues) + require.Equal(t, ciphers[0].HexDigest(), rows[0].HeaderValuesKeyID.String) + + // Raw values must be encrypted. + rawRow, err := db.GetMCPServerUserHeaderValues(ctx, database.GetMCPServerUserHeaderValuesParams{ + MCPServerConfigID: cfg.ID, + UserID: row.UserID, + }) + require.NoError(t, err) + requireEncryptedEquals(t, ciphers[0], rawRow.HeaderValues, headerValues) + }) +} + func TestUserSecrets(t *testing.T) { t.Parallel() ctx := context.Background()