chore: support multi-org group sync with runtime configuration (#14578)

- Implement multi-org group sync
- Implement runtime configuration to change sync behavior
- Legacy group sync migrated to new package
This commit is contained in:
Steven Masley
2024-09-11 13:43:50 -05:00
committed by GitHub
parent 7de576b596
commit 6a846cdbb8
27 changed files with 1920 additions and 341 deletions
-5
View File
@@ -187,11 +187,6 @@ func createOIDCConfig(ctx context.Context, logger slog.Logger, vals *codersdk.De
EmailField: vals.OIDC.EmailField.String(), EmailField: vals.OIDC.EmailField.String(),
AuthURLParams: vals.OIDC.AuthURLParams.Value, AuthURLParams: vals.OIDC.AuthURLParams.Value,
IgnoreUserInfo: vals.OIDC.IgnoreUserInfo.Value(), IgnoreUserInfo: vals.OIDC.IgnoreUserInfo.Value(),
GroupField: vals.OIDC.GroupField.String(),
GroupFilter: vals.OIDC.GroupRegexFilter.Value(),
GroupAllowList: groupAllowList,
CreateMissingGroups: vals.OIDC.GroupAutoCreate.Value(),
GroupMapping: vals.OIDC.GroupMapping.Value,
UserRoleField: vals.OIDC.UserRoleField.String(), UserRoleField: vals.OIDC.UserRoleField.String(),
UserRoleMapping: vals.OIDC.UserRoleMapping.Value, UserRoleMapping: vals.OIDC.UserRoleMapping.Value,
UserRolesDefault: vals.OIDC.UserRolesDefault.GetSlice(), UserRolesDefault: vals.OIDC.UserRolesDefault.GetSlice(),
+4 -18
View File
@@ -181,7 +181,6 @@ type Options struct {
NetworkTelemetryBatchFrequency time.Duration NetworkTelemetryBatchFrequency time.Duration
NetworkTelemetryBatchMaxSize int NetworkTelemetryBatchMaxSize int
SwaggerEndpoint bool SwaggerEndpoint bool
SetUserGroups func(ctx context.Context, logger slog.Logger, tx database.Store, userID uuid.UUID, orgGroupNames map[uuid.UUID][]string, createMissingGroups bool) error
SetUserSiteRoles func(ctx context.Context, logger slog.Logger, tx database.Store, userID uuid.UUID, roles []string) error SetUserSiteRoles func(ctx context.Context, logger slog.Logger, tx database.Store, userID uuid.UUID, roles []string) error
TemplateScheduleStore *atomic.Pointer[schedule.TemplateScheduleStore] TemplateScheduleStore *atomic.Pointer[schedule.TemplateScheduleStore]
UserQuietHoursScheduleStore *atomic.Pointer[schedule.UserQuietHoursScheduleStore] UserQuietHoursScheduleStore *atomic.Pointer[schedule.UserQuietHoursScheduleStore]
@@ -276,13 +275,6 @@ func New(options *Options) *API {
if options.Entitlements == nil { if options.Entitlements == nil {
options.Entitlements = entitlements.New() options.Entitlements = entitlements.New()
} }
if options.IDPSync == nil {
options.IDPSync = idpsync.NewAGPLSync(options.Logger, idpsync.SyncSettings{
OrganizationField: options.DeploymentValues.OIDC.OrganizationField.Value(),
OrganizationMapping: options.DeploymentValues.OIDC.OrganizationMapping.Value,
OrganizationAssignDefault: options.DeploymentValues.OIDC.OrganizationAssignDefault.Value(),
})
}
if options.NewTicker == nil { if options.NewTicker == nil {
options.NewTicker = func(duration time.Duration) (tick <-chan time.Time, done func()) { options.NewTicker = func(duration time.Duration) (tick <-chan time.Time, done func()) {
ticker := time.NewTicker(duration) ticker := time.NewTicker(duration)
@@ -318,6 +310,10 @@ func New(options *Options) *API {
options.AccessControlStore, options.AccessControlStore,
) )
if options.IDPSync == nil {
options.IDPSync = idpsync.NewAGPLSync(options.Logger, options.RuntimeConfig, idpsync.FromDeploymentValues(options.DeploymentValues))
}
experiments := ReadExperiments( experiments := ReadExperiments(
options.Logger, options.DeploymentValues.Experiments.Value(), options.Logger, options.DeploymentValues.Experiments.Value(),
) )
@@ -377,16 +373,6 @@ func New(options *Options) *API {
if options.TracerProvider == nil { if options.TracerProvider == nil {
options.TracerProvider = trace.NewNoopTracerProvider() options.TracerProvider = trace.NewNoopTracerProvider()
} }
if options.SetUserGroups == nil {
options.SetUserGroups = func(ctx context.Context, logger slog.Logger, _ database.Store, userID uuid.UUID, orgGroupNames map[uuid.UUID][]string, createMissingGroups bool) error {
logger.Warn(ctx, "attempted to assign OIDC groups without enterprise license",
slog.F("user_id", userID),
slog.F("groups", orgGroupNames),
slog.F("create_missing_groups", createMissingGroups),
)
return nil
}
}
if options.SetUserSiteRoles == nil { if options.SetUserSiteRoles == nil {
options.SetUserSiteRoles = func(ctx context.Context, logger slog.Logger, _ database.Store, userID uuid.UUID, roles []string) error { options.SetUserSiteRoles = func(ctx context.Context, logger slog.Logger, _ database.Store, userID uuid.UUID, roles []string) error {
logger.Warn(ctx, "attempted to assign OIDC user roles without enterprise license", logger.Warn(ctx, "attempted to assign OIDC user roles without enterprise license",
+25
View File
@@ -0,0 +1,25 @@
package coderdtest
import "github.com/google/uuid"
// DeterministicUUIDGenerator allows "naming" uuids for unit tests.
// An example of where this is useful, is when a tabled test references
// a UUID that is not yet known. An alternative to this would be to
// hard code some UUID strings, but these strings are not human friendly.
type DeterministicUUIDGenerator struct {
Named map[string]uuid.UUID
}
func NewDeterministicUUIDGenerator() *DeterministicUUIDGenerator {
return &DeterministicUUIDGenerator{
Named: make(map[string]uuid.UUID),
}
}
func (d *DeterministicUUIDGenerator) ID(name string) uuid.UUID {
if v, ok := d.Named[name]; ok {
return v
}
d.Named[name] = uuid.New()
return d.Named[name]
}
+17
View File
@@ -0,0 +1,17 @@
package coderdtest_test
import (
"testing"
"github.com/stretchr/testify/require"
"github.com/coder/coder/v2/coderd/coderdtest"
)
func TestDeterministicUUIDGenerator(t *testing.T) {
t.Parallel()
ids := coderdtest.NewDeterministicUUIDGenerator()
require.Equal(t, ids.ID("g1"), ids.ID("g1"))
require.NotEqual(t, ids.ID("g1"), ids.ID("g2"))
}
+16
View File
@@ -2892,6 +2892,14 @@ func (q *querier) InsertUser(ctx context.Context, arg database.InsertUserParams)
return insert(q.log, q.auth, obj, q.db.InsertUser)(ctx, arg) return insert(q.log, q.auth, obj, q.db.InsertUser)(ctx, arg)
} }
func (q *querier) InsertUserGroupsByID(ctx context.Context, arg database.InsertUserGroupsByIDParams) ([]uuid.UUID, error) {
// This is used by OIDC sync. So only used by a system user.
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceSystem); err != nil {
return nil, err
}
return q.db.InsertUserGroupsByID(ctx, arg)
}
func (q *querier) InsertUserGroupsByName(ctx context.Context, arg database.InsertUserGroupsByNameParams) error { func (q *querier) InsertUserGroupsByName(ctx context.Context, arg database.InsertUserGroupsByNameParams) error {
// This will add the user to all named groups. This counts as updating a group. // This will add the user to all named groups. This counts as updating a group.
// NOTE: instead of checking if the user has permission to update each group, we instead // NOTE: instead of checking if the user has permission to update each group, we instead
@@ -3100,6 +3108,14 @@ func (q *querier) RemoveUserFromAllGroups(ctx context.Context, userID uuid.UUID)
return q.db.RemoveUserFromAllGroups(ctx, userID) return q.db.RemoveUserFromAllGroups(ctx, userID)
} }
func (q *querier) RemoveUserFromGroups(ctx context.Context, arg database.RemoveUserFromGroupsParams) ([]uuid.UUID, error) {
// This is a system function to clear user groups in group sync.
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceSystem); err != nil {
return nil, err
}
return q.db.RemoveUserFromGroups(ctx, arg)
}
func (q *querier) RevokeDBCryptKey(ctx context.Context, activeKeyDigest string) error { func (q *querier) RevokeDBCryptKey(ctx context.Context, activeKeyDigest string) error {
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceSystem); err != nil { if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceSystem); err != nil {
return err return err
+23
View File
@@ -388,6 +388,17 @@ func (s *MethodTestSuite) TestGroup() {
GroupNames: slice.New(g1.Name, g2.Name), GroupNames: slice.New(g1.Name, g2.Name),
}).Asserts(rbac.ResourceGroup.InOrg(o.ID), policy.ActionUpdate).Returns() }).Asserts(rbac.ResourceGroup.InOrg(o.ID), policy.ActionUpdate).Returns()
})) }))
s.Run("InsertUserGroupsByID", s.Subtest(func(db database.Store, check *expects) {
o := dbgen.Organization(s.T(), db, database.Organization{})
u1 := dbgen.User(s.T(), db, database.User{})
g1 := dbgen.Group(s.T(), db, database.Group{OrganizationID: o.ID})
g2 := dbgen.Group(s.T(), db, database.Group{OrganizationID: o.ID})
_ = dbgen.GroupMember(s.T(), db, database.GroupMemberTable{GroupID: g1.ID, UserID: u1.ID})
check.Args(database.InsertUserGroupsByIDParams{
UserID: u1.ID,
GroupIds: slice.New(g1.ID, g2.ID),
}).Asserts(rbac.ResourceSystem, policy.ActionUpdate).Returns(slice.New(g1.ID, g2.ID))
}))
s.Run("RemoveUserFromAllGroups", s.Subtest(func(db database.Store, check *expects) { s.Run("RemoveUserFromAllGroups", s.Subtest(func(db database.Store, check *expects) {
o := dbgen.Organization(s.T(), db, database.Organization{}) o := dbgen.Organization(s.T(), db, database.Organization{})
u1 := dbgen.User(s.T(), db, database.User{}) u1 := dbgen.User(s.T(), db, database.User{})
@@ -397,6 +408,18 @@ func (s *MethodTestSuite) TestGroup() {
_ = dbgen.GroupMember(s.T(), db, database.GroupMemberTable{GroupID: g2.ID, UserID: u1.ID}) _ = dbgen.GroupMember(s.T(), db, database.GroupMemberTable{GroupID: g2.ID, UserID: u1.ID})
check.Args(u1.ID).Asserts(rbac.ResourceSystem, policy.ActionUpdate).Returns() check.Args(u1.ID).Asserts(rbac.ResourceSystem, policy.ActionUpdate).Returns()
})) }))
s.Run("RemoveUserFromGroups", s.Subtest(func(db database.Store, check *expects) {
o := dbgen.Organization(s.T(), db, database.Organization{})
u1 := dbgen.User(s.T(), db, database.User{})
g1 := dbgen.Group(s.T(), db, database.Group{OrganizationID: o.ID})
g2 := dbgen.Group(s.T(), db, database.Group{OrganizationID: o.ID})
_ = dbgen.GroupMember(s.T(), db, database.GroupMemberTable{GroupID: g1.ID, UserID: u1.ID})
_ = dbgen.GroupMember(s.T(), db, database.GroupMemberTable{GroupID: g2.ID, UserID: u1.ID})
check.Args(database.RemoveUserFromGroupsParams{
UserID: u1.ID,
GroupIds: []uuid.UUID{g1.ID, g2.ID},
}).Asserts(rbac.ResourceSystem, policy.ActionUpdate).Returns(slice.New(g1.ID, g2.ID))
}))
s.Run("UpdateGroupByID", s.Subtest(func(db database.Store, check *expects) { s.Run("UpdateGroupByID", s.Subtest(func(db database.Store, check *expects) {
g := dbgen.Group(s.T(), db, database.Group{}) g := dbgen.Group(s.T(), db, database.Group{})
check.Args(database.UpdateGroupByIDParams{ check.Args(database.UpdateGroupByIDParams{
+66 -4
View File
@@ -2695,18 +2695,18 @@ func (q *FakeQuerier) GetGroups(_ context.Context, arg database.GetGroupsParams)
q.mutex.RLock() q.mutex.RLock()
defer q.mutex.RUnlock() defer q.mutex.RUnlock()
groupIDs := make(map[uuid.UUID]struct{}) userGroupIDs := make(map[uuid.UUID]struct{})
if arg.HasMemberID != uuid.Nil { if arg.HasMemberID != uuid.Nil {
for _, member := range q.groupMembers { for _, member := range q.groupMembers {
if member.UserID == arg.HasMemberID { if member.UserID == arg.HasMemberID {
groupIDs[member.GroupID] = struct{}{} userGroupIDs[member.GroupID] = struct{}{}
} }
} }
// Handle the everyone group // Handle the everyone group
for _, orgMember := range q.organizationMembers { for _, orgMember := range q.organizationMembers {
if orgMember.UserID == arg.HasMemberID { if orgMember.UserID == arg.HasMemberID {
groupIDs[orgMember.OrganizationID] = struct{}{} userGroupIDs[orgMember.OrganizationID] = struct{}{}
} }
} }
} }
@@ -2718,11 +2718,15 @@ func (q *FakeQuerier) GetGroups(_ context.Context, arg database.GetGroupsParams)
continue continue
} }
_, ok := groupIDs[group.ID] _, ok := userGroupIDs[group.ID]
if arg.HasMemberID != uuid.Nil && !ok { if arg.HasMemberID != uuid.Nil && !ok {
continue continue
} }
if len(arg.GroupNames) > 0 && !slices.Contains(arg.GroupNames, group.Name) {
continue
}
orgDetails, ok := orgDetailsCache[group.ID] orgDetails, ok := orgDetailsCache[group.ID]
if !ok { if !ok {
for _, org := range q.organizations { for _, org := range q.organizations {
@@ -7015,7 +7019,37 @@ func (q *FakeQuerier) InsertUser(_ context.Context, arg database.InsertUserParam
return user, nil return user, nil
} }
func (q *FakeQuerier) InsertUserGroupsByID(_ context.Context, arg database.InsertUserGroupsByIDParams) ([]uuid.UUID, error) {
err := validateDatabaseType(arg)
if err != nil {
return nil, err
}
q.mutex.Lock()
defer q.mutex.Unlock()
var groupIDs []uuid.UUID
for _, group := range q.groups {
for _, groupID := range arg.GroupIds {
if group.ID == groupID {
q.groupMembers = append(q.groupMembers, database.GroupMemberTable{
UserID: arg.UserID,
GroupID: groupID,
})
groupIDs = append(groupIDs, group.ID)
}
}
}
return groupIDs, nil
}
func (q *FakeQuerier) InsertUserGroupsByName(_ context.Context, arg database.InsertUserGroupsByNameParams) error { func (q *FakeQuerier) InsertUserGroupsByName(_ context.Context, arg database.InsertUserGroupsByNameParams) error {
err := validateDatabaseType(arg)
if err != nil {
return err
}
q.mutex.Lock() q.mutex.Lock()
defer q.mutex.Unlock() defer q.mutex.Unlock()
@@ -7607,6 +7641,34 @@ func (q *FakeQuerier) RemoveUserFromAllGroups(_ context.Context, userID uuid.UUI
return nil return nil
} }
func (q *FakeQuerier) RemoveUserFromGroups(_ context.Context, arg database.RemoveUserFromGroupsParams) ([]uuid.UUID, error) {
err := validateDatabaseType(arg)
if err != nil {
return nil, err
}
q.mutex.Lock()
defer q.mutex.Unlock()
removed := make([]uuid.UUID, 0)
q.data.groupMembers = slices.DeleteFunc(q.data.groupMembers, func(groupMember database.GroupMemberTable) bool {
// Delete all group members that match the arguments.
if groupMember.UserID != arg.UserID {
// Not the right user, ignore.
return false
}
if !slices.Contains(arg.GroupIds, groupMember.GroupID) {
return false
}
removed = append(removed, groupMember.GroupID)
return true
})
return removed, nil
}
func (q *FakeQuerier) RevokeDBCryptKey(_ context.Context, activeKeyDigest string) error { func (q *FakeQuerier) RevokeDBCryptKey(_ context.Context, activeKeyDigest string) error {
q.mutex.Lock() q.mutex.Lock()
defer q.mutex.Unlock() defer q.mutex.Unlock()
+14
View File
@@ -1789,6 +1789,13 @@ func (m metricsStore) InsertUser(ctx context.Context, arg database.InsertUserPar
return user, err return user, err
} }
func (m metricsStore) InsertUserGroupsByID(ctx context.Context, arg database.InsertUserGroupsByIDParams) ([]uuid.UUID, error) {
start := time.Now()
r0, r1 := m.s.InsertUserGroupsByID(ctx, arg)
m.queryLatencies.WithLabelValues("InsertUserGroupsByID").Observe(time.Since(start).Seconds())
return r0, r1
}
func (m metricsStore) InsertUserGroupsByName(ctx context.Context, arg database.InsertUserGroupsByNameParams) error { func (m metricsStore) InsertUserGroupsByName(ctx context.Context, arg database.InsertUserGroupsByNameParams) error {
start := time.Now() start := time.Now()
err := m.s.InsertUserGroupsByName(ctx, arg) err := m.s.InsertUserGroupsByName(ctx, arg)
@@ -1943,6 +1950,13 @@ func (m metricsStore) RemoveUserFromAllGroups(ctx context.Context, userID uuid.U
return r0 return r0
} }
func (m metricsStore) RemoveUserFromGroups(ctx context.Context, arg database.RemoveUserFromGroupsParams) ([]uuid.UUID, error) {
start := time.Now()
r0, r1 := m.s.RemoveUserFromGroups(ctx, arg)
m.queryLatencies.WithLabelValues("RemoveUserFromGroups").Observe(time.Since(start).Seconds())
return r0, r1
}
func (m metricsStore) RevokeDBCryptKey(ctx context.Context, activeKeyDigest string) error { func (m metricsStore) RevokeDBCryptKey(ctx context.Context, activeKeyDigest string) error {
start := time.Now() start := time.Now()
r0 := m.s.RevokeDBCryptKey(ctx, activeKeyDigest) r0 := m.s.RevokeDBCryptKey(ctx, activeKeyDigest)
+30
View File
@@ -3766,6 +3766,21 @@ func (mr *MockStoreMockRecorder) InsertUser(arg0, arg1 any) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertUser", reflect.TypeOf((*MockStore)(nil).InsertUser), arg0, arg1) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertUser", reflect.TypeOf((*MockStore)(nil).InsertUser), arg0, arg1)
} }
// InsertUserGroupsByID mocks base method.
func (m *MockStore) InsertUserGroupsByID(arg0 context.Context, arg1 database.InsertUserGroupsByIDParams) ([]uuid.UUID, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "InsertUserGroupsByID", arg0, arg1)
ret0, _ := ret[0].([]uuid.UUID)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// InsertUserGroupsByID indicates an expected call of InsertUserGroupsByID.
func (mr *MockStoreMockRecorder) InsertUserGroupsByID(arg0, arg1 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertUserGroupsByID", reflect.TypeOf((*MockStore)(nil).InsertUserGroupsByID), arg0, arg1)
}
// InsertUserGroupsByName mocks base method. // InsertUserGroupsByName mocks base method.
func (m *MockStore) InsertUserGroupsByName(arg0 context.Context, arg1 database.InsertUserGroupsByNameParams) error { func (m *MockStore) InsertUserGroupsByName(arg0 context.Context, arg1 database.InsertUserGroupsByNameParams) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
@@ -4103,6 +4118,21 @@ func (mr *MockStoreMockRecorder) RemoveUserFromAllGroups(arg0, arg1 any) *gomock
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveUserFromAllGroups", reflect.TypeOf((*MockStore)(nil).RemoveUserFromAllGroups), arg0, arg1) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveUserFromAllGroups", reflect.TypeOf((*MockStore)(nil).RemoveUserFromAllGroups), arg0, arg1)
} }
// RemoveUserFromGroups mocks base method.
func (m *MockStore) RemoveUserFromGroups(arg0 context.Context, arg1 database.RemoveUserFromGroupsParams) ([]uuid.UUID, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RemoveUserFromGroups", arg0, arg1)
ret0, _ := ret[0].([]uuid.UUID)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// RemoveUserFromGroups indicates an expected call of RemoveUserFromGroups.
func (mr *MockStoreMockRecorder) RemoveUserFromGroups(arg0, arg1 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveUserFromGroups", reflect.TypeOf((*MockStore)(nil).RemoveUserFromGroups), arg0, arg1)
}
// RevokeDBCryptKey mocks base method. // RevokeDBCryptKey mocks base method.
func (m *MockStore) RevokeDBCryptKey(arg0 context.Context, arg1 string) error { func (m *MockStore) RevokeDBCryptKey(arg0 context.Context, arg1 string) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
+4
View File
@@ -369,6 +369,9 @@ type sqlcQuerier interface {
InsertTemplateVersionVariable(ctx context.Context, arg InsertTemplateVersionVariableParams) (TemplateVersionVariable, error) InsertTemplateVersionVariable(ctx context.Context, arg InsertTemplateVersionVariableParams) (TemplateVersionVariable, error)
InsertTemplateVersionWorkspaceTag(ctx context.Context, arg InsertTemplateVersionWorkspaceTagParams) (TemplateVersionWorkspaceTag, error) InsertTemplateVersionWorkspaceTag(ctx context.Context, arg InsertTemplateVersionWorkspaceTagParams) (TemplateVersionWorkspaceTag, error)
InsertUser(ctx context.Context, arg InsertUserParams) (User, error) InsertUser(ctx context.Context, arg InsertUserParams) (User, error)
// InsertUserGroupsByID adds a user to all provided groups, if they exist.
// If there is a conflict, the user is already a member
InsertUserGroupsByID(ctx context.Context, arg InsertUserGroupsByIDParams) ([]uuid.UUID, error)
// InsertUserGroupsByName adds a user to all provided groups, if they exist. // InsertUserGroupsByName adds a user to all provided groups, if they exist.
InsertUserGroupsByName(ctx context.Context, arg InsertUserGroupsByNameParams) error InsertUserGroupsByName(ctx context.Context, arg InsertUserGroupsByNameParams) error
InsertUserLink(ctx context.Context, arg InsertUserLinkParams) (UserLink, error) InsertUserLink(ctx context.Context, arg InsertUserLinkParams) (UserLink, error)
@@ -396,6 +399,7 @@ type sqlcQuerier interface {
ReduceWorkspaceAgentShareLevelToAuthenticatedByTemplate(ctx context.Context, templateID uuid.UUID) error ReduceWorkspaceAgentShareLevelToAuthenticatedByTemplate(ctx context.Context, templateID uuid.UUID) error
RegisterWorkspaceProxy(ctx context.Context, arg RegisterWorkspaceProxyParams) (WorkspaceProxy, error) RegisterWorkspaceProxy(ctx context.Context, arg RegisterWorkspaceProxyParams) (WorkspaceProxy, error)
RemoveUserFromAllGroups(ctx context.Context, userID uuid.UUID) error RemoveUserFromAllGroups(ctx context.Context, userID uuid.UUID) error
RemoveUserFromGroups(ctx context.Context, arg RemoveUserFromGroupsParams) ([]uuid.UUID, error)
RevokeDBCryptKey(ctx context.Context, activeKeyDigest string) error RevokeDBCryptKey(ctx context.Context, activeKeyDigest string) error
// Non blocking lock. Returns true if the lock was acquired, false otherwise. // Non blocking lock. Returns true if the lock was acquired, false otherwise.
// //
+93 -1
View File
@@ -1446,6 +1446,56 @@ func (q *sqlQuerier) InsertGroupMember(ctx context.Context, arg InsertGroupMembe
return err return err
} }
const insertUserGroupsByID = `-- name: InsertUserGroupsByID :many
WITH groups AS (
SELECT
id
FROM
groups
WHERE
groups.id = ANY($2 :: uuid [])
)
INSERT INTO
group_members (user_id, group_id)
SELECT
$1,
groups.id
FROM
groups
ON CONFLICT DO NOTHING
RETURNING group_id
`
type InsertUserGroupsByIDParams struct {
UserID uuid.UUID `db:"user_id" json:"user_id"`
GroupIds []uuid.UUID `db:"group_ids" json:"group_ids"`
}
// InsertUserGroupsByID adds a user to all provided groups, if they exist.
// If there is a conflict, the user is already a member
func (q *sqlQuerier) InsertUserGroupsByID(ctx context.Context, arg InsertUserGroupsByIDParams) ([]uuid.UUID, error) {
rows, err := q.db.QueryContext(ctx, insertUserGroupsByID, arg.UserID, pq.Array(arg.GroupIds))
if err != nil {
return nil, err
}
defer rows.Close()
var items []uuid.UUID
for rows.Next() {
var group_id uuid.UUID
if err := rows.Scan(&group_id); err != nil {
return nil, err
}
items = append(items, group_id)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const insertUserGroupsByName = `-- name: InsertUserGroupsByName :exec const insertUserGroupsByName = `-- name: InsertUserGroupsByName :exec
WITH groups AS ( WITH groups AS (
SELECT SELECT
@@ -1489,6 +1539,43 @@ func (q *sqlQuerier) RemoveUserFromAllGroups(ctx context.Context, userID uuid.UU
return err return err
} }
const removeUserFromGroups = `-- name: RemoveUserFromGroups :many
DELETE FROM
group_members
WHERE
user_id = $1 AND
group_id = ANY($2 :: uuid [])
RETURNING group_id
`
type RemoveUserFromGroupsParams struct {
UserID uuid.UUID `db:"user_id" json:"user_id"`
GroupIds []uuid.UUID `db:"group_ids" json:"group_ids"`
}
func (q *sqlQuerier) RemoveUserFromGroups(ctx context.Context, arg RemoveUserFromGroupsParams) ([]uuid.UUID, error) {
rows, err := q.db.QueryContext(ctx, removeUserFromGroups, arg.UserID, pq.Array(arg.GroupIds))
if err != nil {
return nil, err
}
defer rows.Close()
var items []uuid.UUID
for rows.Next() {
var group_id uuid.UUID
if err := rows.Scan(&group_id); err != nil {
return nil, err
}
items = append(items, group_id)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const deleteGroupByID = `-- name: DeleteGroupByID :exec const deleteGroupByID = `-- name: DeleteGroupByID :exec
DELETE FROM DELETE FROM
groups groups
@@ -1592,11 +1679,16 @@ WHERE
) )
ELSE true ELSE true
END END
AND CASE WHEN array_length($3 :: text[], 1) > 0 THEN
groups.name = ANY($3)
ELSE true
END
` `
type GetGroupsParams struct { type GetGroupsParams struct {
OrganizationID uuid.UUID `db:"organization_id" json:"organization_id"` OrganizationID uuid.UUID `db:"organization_id" json:"organization_id"`
HasMemberID uuid.UUID `db:"has_member_id" json:"has_member_id"` HasMemberID uuid.UUID `db:"has_member_id" json:"has_member_id"`
GroupNames []string `db:"group_names" json:"group_names"`
} }
type GetGroupsRow struct { type GetGroupsRow struct {
@@ -1606,7 +1698,7 @@ type GetGroupsRow struct {
} }
func (q *sqlQuerier) GetGroups(ctx context.Context, arg GetGroupsParams) ([]GetGroupsRow, error) { func (q *sqlQuerier) GetGroups(ctx context.Context, arg GetGroupsParams) ([]GetGroupsRow, error) {
rows, err := q.db.QueryContext(ctx, getGroups, arg.OrganizationID, arg.HasMemberID) rows, err := q.db.QueryContext(ctx, getGroups, arg.OrganizationID, arg.HasMemberID, pq.Array(arg.GroupNames))
if err != nil { if err != nil {
return nil, err return nil, err
} }
+29
View File
@@ -29,12 +29,41 @@ SELECT
FROM FROM
groups; groups;
-- InsertUserGroupsByID adds a user to all provided groups, if they exist.
-- name: InsertUserGroupsByID :many
WITH groups AS (
SELECT
id
FROM
groups
WHERE
groups.id = ANY(@group_ids :: uuid [])
)
INSERT INTO
group_members (user_id, group_id)
SELECT
@user_id,
groups.id
FROM
groups
-- If there is a conflict, the user is already a member
ON CONFLICT DO NOTHING
RETURNING group_id;
-- name: RemoveUserFromAllGroups :exec -- name: RemoveUserFromAllGroups :exec
DELETE FROM DELETE FROM
group_members group_members
WHERE WHERE
user_id = @user_id; user_id = @user_id;
-- name: RemoveUserFromGroups :many
DELETE FROM
group_members
WHERE
user_id = @user_id AND
group_id = ANY(@group_ids :: uuid [])
RETURNING group_id;
-- name: InsertGroupMember :exec -- name: InsertGroupMember :exec
INSERT INTO INSERT INTO
group_members (user_id, group_id) group_members (user_id, group_id)
+4
View File
@@ -52,6 +52,10 @@ WHERE
) )
ELSE true ELSE true
END END
AND CASE WHEN array_length(@group_names :: text[], 1) > 0 THEN
groups.name = ANY(@group_names)
ELSE true
END
; ;
-- name: InsertGroup :one -- name: InsertGroup :one
+416
View File
@@ -0,0 +1,416 @@
package idpsync
import (
"context"
"encoding/json"
"fmt"
"regexp"
"github.com/golang-jwt/jwt/v4"
"github.com/google/uuid"
"golang.org/x/xerrors"
"cdr.dev/slog"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/db2sdk"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/runtimeconfig"
"github.com/coder/coder/v2/coderd/util/slice"
)
type GroupParams struct {
// SyncEnabled if false will skip syncing the user's groups
SyncEnabled bool
MergedClaims jwt.MapClaims
}
func (AGPLIDPSync) GroupSyncEnabled() bool {
// AGPL does not support syncing groups.
return false
}
func (s AGPLIDPSync) GroupSyncSettings() runtimeconfig.RuntimeEntry[*GroupSyncSettings] {
return s.Group
}
func (s AGPLIDPSync) ParseGroupClaims(_ context.Context, _ jwt.MapClaims) (GroupParams, *HTTPError) {
return GroupParams{
SyncEnabled: s.GroupSyncEnabled(),
}, nil
}
func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user database.User, params GroupParams) error {
// Nothing happens if sync is not enabled
if !params.SyncEnabled {
return nil
}
// nolint:gocritic // all syncing is done as a system user
ctx = dbauthz.AsSystemRestricted(ctx)
// Only care about the default org for deployment settings if the
// legacy deployment settings exist.
defaultOrgID := uuid.Nil
// Default organization is configured via legacy deployment values
if s.DeploymentSyncSettings.Legacy.GroupField != "" {
defaultOrganization, err := db.GetDefaultOrganization(ctx)
if err != nil {
return xerrors.Errorf("get default organization: %w", err)
}
defaultOrgID = defaultOrganization.ID
}
err := db.InTx(func(tx database.Store) error {
userGroups, err := tx.GetGroups(ctx, database.GetGroupsParams{
HasMemberID: user.ID,
})
if err != nil {
return xerrors.Errorf("get user groups: %w", err)
}
// Figure out which organizations the user is a member of.
// The "Everyone" group is always included, so we can infer organization
// membership via the groups the user is in.
userOrgs := make(map[uuid.UUID][]database.GetGroupsRow)
for _, g := range userGroups {
g := g
userOrgs[g.Group.OrganizationID] = append(userOrgs[g.Group.OrganizationID], g)
}
// For each org, we need to fetch the sync settings
// This loop also handles any legacy settings for the default
// organization.
orgSettings := make(map[uuid.UUID]GroupSyncSettings)
for orgID := range userOrgs {
orgResolver := s.Manager.OrganizationResolver(tx, orgID)
settings, err := s.SyncSettings.Group.Resolve(ctx, orgResolver)
if err != nil {
if !xerrors.Is(err, runtimeconfig.ErrEntryNotFound) {
return xerrors.Errorf("resolve group sync settings: %w", err)
}
// Default to not being configured
settings = &GroupSyncSettings{}
}
// Legacy deployment settings will override empty settings.
if orgID == defaultOrgID && settings.Field == "" {
settings = &GroupSyncSettings{
Field: s.Legacy.GroupField,
LegacyNameMapping: s.Legacy.GroupMapping,
RegexFilter: s.Legacy.GroupFilter,
AutoCreateMissing: s.Legacy.CreateMissingGroups,
}
}
orgSettings[orgID] = *settings
}
// groupIDsToAdd & groupIDsToRemove are the final group differences
// needed to be applied to user. The loop below will iterate over all
// organizations the user is in, and determine the diffs.
// The diffs are applied as a batch sql query, rather than each
// organization having to execute a query.
groupIDsToAdd := make([]uuid.UUID, 0)
groupIDsToRemove := make([]uuid.UUID, 0)
// For each org, determine which groups the user should land in
for orgID, settings := range orgSettings {
if settings.Field == "" {
// No group sync enabled for this org, so do nothing.
// The user can remain in their groups for this org.
continue
}
// expectedGroups is the set of groups the IDP expects the
// user to be a member of.
expectedGroups, err := settings.ParseClaims(orgID, params.MergedClaims)
if err != nil {
s.Logger.Debug(ctx, "failed to parse claims for groups",
slog.F("organization_field", s.GroupField),
slog.F("organization_id", orgID),
slog.Error(err),
)
// Unsure where to raise this error on the UI or database.
// TODO: This error prevents group sync, but we have no way
// to raise this to an org admin. Come up with a solution to
// notify the admin and user of this issue.
continue
}
// Everyone group is always implied, so include it.
expectedGroups = append(expectedGroups, ExpectedGroup{
OrganizationID: orgID,
GroupID: &orgID,
})
// Now we know what groups the user should be in for a given org,
// determine if we have to do any group updates to sync the user's
// state.
existingGroups := userOrgs[orgID]
existingGroupsTyped := db2sdk.List(existingGroups, func(f database.GetGroupsRow) ExpectedGroup {
return ExpectedGroup{
OrganizationID: orgID,
GroupID: &f.Group.ID,
GroupName: &f.Group.Name,
}
})
add, remove := slice.SymmetricDifferenceFunc(existingGroupsTyped, expectedGroups, func(a, b ExpectedGroup) bool {
return a.Equal(b)
})
for _, r := range remove {
if r.GroupID == nil {
// This should never happen. All group removals come from the
// existing set, which come from the db. All groups from the
// database have IDs. This code is purely defensive.
detail := "user:" + user.Username
if r.GroupName != nil {
detail += fmt.Sprintf(" from group %s", *r.GroupName)
}
return xerrors.Errorf("removal group has nil ID, which should never happen: %s", detail)
}
groupIDsToRemove = append(groupIDsToRemove, *r.GroupID)
}
// HandleMissingGroups will add the new groups to the org if
// the settings specify. It will convert all group names into uuids
// for easier assignment.
// TODO: This code should be batched at the end of the for loop.
// Optimizing this is being pushed because if AutoCreate is disabled,
// this code will only add cost on the first login for each user.
// AutoCreate is usually disabled for large deployments.
// For small deployments, this is less of a problem.
assignGroups, err := settings.HandleMissingGroups(ctx, tx, orgID, add)
if err != nil {
return xerrors.Errorf("handle missing groups: %w", err)
}
groupIDsToAdd = append(groupIDsToAdd, assignGroups...)
}
// ApplyGroupDifference will take the total adds and removes, and apply
// them.
err = s.ApplyGroupDifference(ctx, tx, user, groupIDsToAdd, groupIDsToRemove)
if err != nil {
return xerrors.Errorf("apply group difference: %w", err)
}
return nil
}, nil)
if err != nil {
return err
}
return nil
}
// ApplyGroupDifference will add and remove the user from the specified groups.
func (s AGPLIDPSync) ApplyGroupDifference(ctx context.Context, tx database.Store, user database.User, add []uuid.UUID, removeIDs []uuid.UUID) error {
if len(removeIDs) > 0 {
removedGroupIDs, err := tx.RemoveUserFromGroups(ctx, database.RemoveUserFromGroupsParams{
UserID: user.ID,
GroupIds: removeIDs,
})
if err != nil {
return xerrors.Errorf("remove user from %d groups: %w", len(removeIDs), err)
}
if len(removedGroupIDs) != len(removeIDs) {
s.Logger.Debug(ctx, "user not removed from expected number of groups",
slog.F("user_id", user.ID),
slog.F("groups_removed_count", len(removedGroupIDs)),
slog.F("expected_count", len(removeIDs)),
)
}
}
if len(add) > 0 {
add = slice.Unique(add)
// Defensive programming to only insert uniques.
assignedGroupIDs, err := tx.InsertUserGroupsByID(ctx, database.InsertUserGroupsByIDParams{
UserID: user.ID,
GroupIds: add,
})
if err != nil {
return xerrors.Errorf("insert user into %d groups: %w", len(add), err)
}
if len(assignedGroupIDs) != len(add) {
s.Logger.Debug(ctx, "user not assigned to expected number of groups",
slog.F("user_id", user.ID),
slog.F("groups_assigned_count", len(assignedGroupIDs)),
slog.F("expected_count", len(add)),
)
}
}
return nil
}
type GroupSyncSettings struct {
// Field selects the claim field to be used as the created user's
// groups. If the group field is the empty string, then no group updates
// will ever come from the OIDC provider.
Field string `json:"field"`
// Mapping maps from an OIDC group --> Coder group ID
Mapping map[string][]uuid.UUID `json:"mapping"`
// RegexFilter is a regular expression that filters the groups returned by
// the OIDC provider. Any group not matched by this regex will be ignored.
// If the group filter is nil, then no group filtering will occur.
RegexFilter *regexp.Regexp `json:"regex_filter"`
// AutoCreateMissing controls whether groups returned by the OIDC provider
// are automatically created in Coder if they are missing.
AutoCreateMissing bool `json:"auto_create_missing_groups"`
// LegacyNameMapping is deprecated. It remaps an IDP group name to
// a Coder group name. Since configuration is now done at runtime,
// group IDs are used to account for group renames.
// For legacy configurations, this config option has to remain.
// Deprecated: Use Mapping instead.
LegacyNameMapping map[string]string `json:"legacy_group_name_mapping,omitempty"`
}
func (s *GroupSyncSettings) Set(v string) error {
return json.Unmarshal([]byte(v), s)
}
func (s *GroupSyncSettings) String() string {
return runtimeconfig.JSONString(s)
}
type ExpectedGroup struct {
OrganizationID uuid.UUID
GroupID *uuid.UUID
GroupName *string
}
// Equal compares two ExpectedGroups. The org id must be the same.
// If the group ID is set, it will be compared and take priority, ignoring the
// name value. So 2 groups with the same ID but different names will be
// considered equal.
func (a ExpectedGroup) Equal(b ExpectedGroup) bool {
// Must match
if a.OrganizationID != b.OrganizationID {
return false
}
// Only the name or the name needs to be checked, priority is given to the ID.
if a.GroupID != nil && b.GroupID != nil {
return *a.GroupID == *b.GroupID
}
if a.GroupName != nil && b.GroupName != nil {
return *a.GroupName == *b.GroupName
}
// If everything is nil, it is equal. Although a bit pointless
if a.GroupID == nil && b.GroupID == nil &&
a.GroupName == nil && b.GroupName == nil {
return true
}
return false
}
// ParseClaims will take the merged claims from the IDP and return the groups
// the user is expected to be a member of. The expected group can either be a
// name or an ID.
// It is unfortunate we cannot use exclusively names or exclusively IDs.
// When configuring though, if a group is mapped from "A" -> "UUID 1234", and
// the group "UUID 1234" is renamed, we want to maintain the mapping.
// We have to keep names because group sync supports syncing groups by name if
// the external IDP group name matches the Coder one.
func (s GroupSyncSettings) ParseClaims(orgID uuid.UUID, mergedClaims jwt.MapClaims) ([]ExpectedGroup, error) {
groupsRaw, ok := mergedClaims[s.Field]
if !ok {
return []ExpectedGroup{}, nil
}
parsedGroups, err := ParseStringSliceClaim(groupsRaw)
if err != nil {
return nil, xerrors.Errorf("parse groups field, unexpected type %T: %w", groupsRaw, err)
}
groups := make([]ExpectedGroup, 0)
for _, group := range parsedGroups {
group := group
// Legacy group mappings happen before the regex filter.
mappedGroupName, ok := s.LegacyNameMapping[group]
if ok {
group = mappedGroupName
}
// Only allow through groups that pass the regex
if s.RegexFilter != nil {
if !s.RegexFilter.MatchString(group) {
continue
}
}
mappedGroupIDs, ok := s.Mapping[group]
if ok {
for _, gid := range mappedGroupIDs {
gid := gid
groups = append(groups, ExpectedGroup{OrganizationID: orgID, GroupID: &gid})
}
continue
}
groups = append(groups, ExpectedGroup{OrganizationID: orgID, GroupName: &group})
}
return groups, nil
}
// HandleMissingGroups ensures all ExpectedGroups convert to uuids.
// Groups can be referenced by name via legacy params or IDP group names.
// These group names are converted to IDs for easier assignment.
// Missing groups are created if AutoCreate is enabled.
// TODO: Batching this would be better, as this is 1 or 2 db calls per organization.
func (s GroupSyncSettings) HandleMissingGroups(ctx context.Context, tx database.Store, orgID uuid.UUID, add []ExpectedGroup) ([]uuid.UUID, error) {
// All expected that are missing IDs means the group does not exist
// in the database, or it is a legacy mapping, and we need to do a lookup.
var missingGroups []string
addIDs := make([]uuid.UUID, 0)
for _, expected := range add {
if expected.GroupID == nil && expected.GroupName != nil {
missingGroups = append(missingGroups, *expected.GroupName)
} else if expected.GroupID != nil {
// Keep the IDs to sync the groups.
addIDs = append(addIDs, *expected.GroupID)
}
}
if s.AutoCreateMissing && len(missingGroups) > 0 {
// Insert any missing groups. If the groups already exist, this is a noop.
_, err := tx.InsertMissingGroups(ctx, database.InsertMissingGroupsParams{
OrganizationID: orgID,
Source: database.GroupSourceOidc,
GroupNames: missingGroups,
})
if err != nil {
return nil, xerrors.Errorf("insert missing groups: %w", err)
}
}
// Fetch any missing groups by name. If they exist, their IDs will be
// matched and returned.
if len(missingGroups) > 0 {
// Do name lookups for all groups that are missing IDs.
newGroups, err := tx.GetGroups(ctx, database.GetGroupsParams{
OrganizationID: orgID,
HasMemberID: uuid.UUID{},
GroupNames: missingGroups,
})
if err != nil {
return nil, xerrors.Errorf("get groups by names: %w", err)
}
for _, g := range newGroups {
addIDs = append(addIDs, g.Group.ID)
}
}
return addIDs, nil
}
func ConvertAllowList(allowList []string) map[string]struct{} {
allowMap := make(map[string]struct{}, len(allowList))
for _, group := range allowList {
allowMap[group] = struct{}{}
}
return allowMap
}
+814
View File
@@ -0,0 +1,814 @@
package idpsync_test
import (
"context"
"database/sql"
"regexp"
"testing"
"github.com/golang-jwt/jwt/v4"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"golang.org/x/exp/slices"
"golang.org/x/xerrors"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/db2sdk"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/coderd/database/dbgen"
"github.com/coder/coder/v2/coderd/database/dbtestutil"
"github.com/coder/coder/v2/coderd/idpsync"
"github.com/coder/coder/v2/coderd/runtimeconfig"
"github.com/coder/coder/v2/coderd/util/ptr"
"github.com/coder/coder/v2/testutil"
)
func TestParseGroupClaims(t *testing.T) {
t.Parallel()
t.Run("EmptyConfig", func(t *testing.T) {
t.Parallel()
s := idpsync.NewAGPLSync(slogtest.Make(t, &slogtest.Options{}),
runtimeconfig.NewManager(),
idpsync.DeploymentSyncSettings{})
ctx := testutil.Context(t, testutil.WaitMedium)
params, err := s.ParseGroupClaims(ctx, jwt.MapClaims{})
require.Nil(t, err)
require.False(t, params.SyncEnabled)
})
// AllowList has no effect in AGPL
t.Run("AllowList", func(t *testing.T) {
t.Parallel()
s := idpsync.NewAGPLSync(slogtest.Make(t, &slogtest.Options{}),
runtimeconfig.NewManager(),
idpsync.DeploymentSyncSettings{
GroupField: "groups",
GroupAllowList: map[string]struct{}{
"foo": {},
},
})
ctx := testutil.Context(t, testutil.WaitMedium)
params, err := s.ParseGroupClaims(ctx, jwt.MapClaims{})
require.Nil(t, err)
require.False(t, params.SyncEnabled)
})
}
func TestGroupSyncTable(t *testing.T) {
t.Parallel()
// Last checked, takes 30s with postgres on a fast machine.
if dbtestutil.WillUsePostgres() {
t.Skip("Skipping test because it populates a lot of db entries, which is slow on postgres.")
}
userClaims := jwt.MapClaims{
"groups": []string{
"foo", "bar", "baz",
"create-bar", "create-baz",
"legacy-bar",
},
}
ids := coderdtest.NewDeterministicUUIDGenerator()
testCases := []orgSetupDefinition{
{
Name: "SwitchGroups",
Settings: &idpsync.GroupSyncSettings{
Field: "groups",
Mapping: map[string][]uuid.UUID{
"foo": {ids.ID("sg-foo"), ids.ID("sg-foo-2")},
"bar": {ids.ID("sg-bar")},
"baz": {ids.ID("sg-baz")},
},
},
Groups: map[uuid.UUID]bool{
uuid.New(): true,
uuid.New(): true,
// Extra groups
ids.ID("sg-foo"): false,
ids.ID("sg-foo-2"): false,
ids.ID("sg-bar"): false,
ids.ID("sg-baz"): false,
},
ExpectedGroups: []uuid.UUID{
ids.ID("sg-foo"),
ids.ID("sg-foo-2"),
ids.ID("sg-bar"),
ids.ID("sg-baz"),
},
},
{
Name: "StayInGroup",
Settings: &idpsync.GroupSyncSettings{
Field: "groups",
// Only match foo, so bar does not map
RegexFilter: regexp.MustCompile("^foo$"),
Mapping: map[string][]uuid.UUID{
"foo": {ids.ID("gg-foo"), uuid.New()},
"bar": {ids.ID("gg-bar")},
"baz": {ids.ID("gg-baz")},
},
},
Groups: map[uuid.UUID]bool{
ids.ID("gg-foo"): true,
ids.ID("gg-bar"): false,
},
ExpectedGroups: []uuid.UUID{
ids.ID("gg-foo"),
},
},
{
Name: "UserJoinsGroups",
Settings: &idpsync.GroupSyncSettings{
Field: "groups",
Mapping: map[string][]uuid.UUID{
"foo": {ids.ID("ng-foo"), uuid.New()},
"bar": {ids.ID("ng-bar"), ids.ID("ng-bar-2")},
"baz": {ids.ID("ng-baz")},
},
},
Groups: map[uuid.UUID]bool{
ids.ID("ng-foo"): false,
ids.ID("ng-bar"): false,
ids.ID("ng-bar-2"): false,
ids.ID("ng-baz"): false,
},
ExpectedGroups: []uuid.UUID{
ids.ID("ng-foo"),
ids.ID("ng-bar"),
ids.ID("ng-bar-2"),
ids.ID("ng-baz"),
},
},
{
Name: "CreateGroups",
Settings: &idpsync.GroupSyncSettings{
Field: "groups",
RegexFilter: regexp.MustCompile("^create"),
AutoCreateMissing: true,
},
Groups: map[uuid.UUID]bool{},
ExpectedGroupNames: []string{
"create-bar",
"create-baz",
},
},
{
Name: "GroupNamesNoMapping",
Settings: &idpsync.GroupSyncSettings{
Field: "groups",
RegexFilter: regexp.MustCompile(".*"),
AutoCreateMissing: false,
},
GroupNames: map[string]bool{
"foo": false,
"bar": false,
"goob": true,
},
ExpectedGroupNames: []string{
"foo",
"bar",
},
},
{
Name: "NoUser",
Settings: &idpsync.GroupSyncSettings{
Field: "groups",
Mapping: map[string][]uuid.UUID{
// Extra ID that does not map to a group
"foo": {ids.ID("ow-foo"), uuid.New()},
},
RegexFilter: nil,
AutoCreateMissing: false,
},
NotMember: true,
Groups: map[uuid.UUID]bool{
ids.ID("ow-foo"): false,
ids.ID("ow-bar"): false,
},
},
{
Name: "NoSettingsNoUser",
Settings: nil,
Groups: map[uuid.UUID]bool{},
},
{
Name: "LegacyMapping",
Settings: &idpsync.GroupSyncSettings{
Field: "groups",
RegexFilter: regexp.MustCompile("^legacy"),
LegacyNameMapping: map[string]string{
"create-bar": "legacy-bar",
"foo": "legacy-foo",
"bop": "legacy-bop",
},
AutoCreateMissing: true,
},
Groups: map[uuid.UUID]bool{
ids.ID("lg-foo"): true,
},
GroupNames: map[string]bool{
"legacy-foo": false,
"extra": true,
"legacy-bop": true,
},
ExpectedGroupNames: []string{
"legacy-bar",
"legacy-foo",
},
},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.Name, func(t *testing.T) {
t.Parallel()
db, _ := dbtestutil.NewDB(t)
manager := runtimeconfig.NewManager()
s := idpsync.NewAGPLSync(slogtest.Make(t, &slogtest.Options{}),
manager,
idpsync.DeploymentSyncSettings{
GroupField: "groups",
},
)
ctx := testutil.Context(t, testutil.WaitSuperLong)
user := dbgen.User(t, db, database.User{})
orgID := uuid.New()
SetupOrganization(t, s, db, user, orgID, tc)
// Do the group sync!
err := s.SyncGroups(ctx, db, user, idpsync.GroupParams{
SyncEnabled: true,
MergedClaims: userClaims,
})
require.NoError(t, err)
tc.Assert(t, orgID, db, user)
})
}
// AllTogether runs the entire tabled test as a singular user and
// deployment. This tests all organizations being synced together.
// The reason we do them individually, is that it is much easier to
// debug a single test case.
t.Run("AllTogether", func(t *testing.T) {
t.Parallel()
db, _ := dbtestutil.NewDB(t)
manager := runtimeconfig.NewManager()
s := idpsync.NewAGPLSync(slogtest.Make(t, &slogtest.Options{}),
manager,
// Also sync the default org!
idpsync.DeploymentSyncSettings{
GroupField: "groups",
Legacy: idpsync.DefaultOrgLegacySettings{
GroupField: "groups",
GroupMapping: map[string]string{
"foo": "legacy-foo",
"baz": "legacy-baz",
},
GroupFilter: regexp.MustCompile("^legacy"),
CreateMissingGroups: true,
},
},
)
ctx := testutil.Context(t, testutil.WaitSuperLong)
user := dbgen.User(t, db, database.User{})
var asserts []func(t *testing.T)
// The default org is also going to do something
def := orgSetupDefinition{
Name: "DefaultOrg",
GroupNames: map[string]bool{
"legacy-foo": false,
"legacy-baz": true,
"random": true,
},
// No settings, because they come from the deployment values
Settings: nil,
ExpectedGroups: nil,
ExpectedGroupNames: []string{"legacy-foo", "legacy-baz", "legacy-bar"},
}
//nolint:gocritic // testing
defOrg, err := db.GetDefaultOrganization(dbauthz.AsSystemRestricted(ctx))
require.NoError(t, err)
SetupOrganization(t, s, db, user, defOrg.ID, def)
asserts = append(asserts, func(t *testing.T) {
t.Run(def.Name, func(t *testing.T) {
t.Parallel()
def.Assert(t, defOrg.ID, db, user)
})
})
for _, tc := range testCases {
tc := tc
orgID := uuid.New()
SetupOrganization(t, s, db, user, orgID, tc)
asserts = append(asserts, func(t *testing.T) {
t.Run(tc.Name, func(t *testing.T) {
t.Parallel()
tc.Assert(t, orgID, db, user)
})
})
}
asserts = append(asserts, func(t *testing.T) {
t.Helper()
def.Assert(t, defOrg.ID, db, user)
})
// Do the group sync!
err = s.SyncGroups(ctx, db, user, idpsync.GroupParams{
SyncEnabled: true,
MergedClaims: userClaims,
})
require.NoError(t, err)
for _, assert := range asserts {
assert(t)
}
})
}
func TestSyncDisabled(t *testing.T) {
t.Parallel()
if dbtestutil.WillUsePostgres() {
t.Skip("Skipping test because it populates a lot of db entries, which is slow on postgres.")
}
db, _ := dbtestutil.NewDB(t)
manager := runtimeconfig.NewManager()
s := idpsync.NewAGPLSync(slogtest.Make(t, &slogtest.Options{}),
manager,
idpsync.DeploymentSyncSettings{},
)
ids := coderdtest.NewDeterministicUUIDGenerator()
ctx := testutil.Context(t, testutil.WaitSuperLong)
user := dbgen.User(t, db, database.User{})
orgID := uuid.New()
def := orgSetupDefinition{
Name: "SyncDisabled",
Groups: map[uuid.UUID]bool{
ids.ID("foo"): true,
ids.ID("bar"): true,
ids.ID("baz"): false,
ids.ID("bop"): false,
},
Settings: &idpsync.GroupSyncSettings{
Field: "groups",
Mapping: map[string][]uuid.UUID{
"foo": {ids.ID("foo")},
"baz": {ids.ID("baz")},
},
},
ExpectedGroups: []uuid.UUID{
ids.ID("foo"),
ids.ID("bar"),
},
}
SetupOrganization(t, s, db, user, orgID, def)
// Do the group sync!
err := s.SyncGroups(ctx, db, user, idpsync.GroupParams{
SyncEnabled: false,
MergedClaims: jwt.MapClaims{
"groups": []string{"baz", "bop"},
},
})
require.NoError(t, err)
def.Assert(t, orgID, db, user)
}
// TestApplyGroupDifference is mainly testing the database functions
func TestApplyGroupDifference(t *testing.T) {
t.Parallel()
ids := coderdtest.NewDeterministicUUIDGenerator()
testCase := []struct {
Name string
Before map[uuid.UUID]bool
Add []uuid.UUID
Remove []uuid.UUID
Expect []uuid.UUID
}{
{
Name: "Empty",
},
{
Name: "AddFromNone",
Before: map[uuid.UUID]bool{
ids.ID("g1"): false,
},
Add: []uuid.UUID{
ids.ID("g1"),
},
Expect: []uuid.UUID{
ids.ID("g1"),
},
},
{
Name: "AddSome",
Before: map[uuid.UUID]bool{
ids.ID("g1"): true,
ids.ID("g2"): false,
ids.ID("g3"): false,
uuid.New(): false,
},
Add: []uuid.UUID{
ids.ID("g2"),
ids.ID("g3"),
},
Expect: []uuid.UUID{
ids.ID("g1"),
ids.ID("g2"),
ids.ID("g3"),
},
},
{
Name: "RemoveAll",
Before: map[uuid.UUID]bool{
uuid.New(): false,
ids.ID("g2"): true,
ids.ID("g3"): true,
},
Remove: []uuid.UUID{
ids.ID("g2"),
ids.ID("g3"),
},
Expect: []uuid.UUID{},
},
{
Name: "Mixed",
Before: map[uuid.UUID]bool{
// adds
ids.ID("a1"): true,
ids.ID("a2"): true,
ids.ID("a3"): false,
ids.ID("a4"): false,
// removes
ids.ID("r1"): true,
ids.ID("r2"): true,
ids.ID("r3"): false,
ids.ID("r4"): false,
// stable
ids.ID("s1"): true,
ids.ID("s2"): true,
// noise
uuid.New(): false,
uuid.New(): false,
},
Add: []uuid.UUID{
ids.ID("a1"), ids.ID("a2"),
ids.ID("a3"), ids.ID("a4"),
// Double up to try and confuse
ids.ID("a1"),
ids.ID("a4"),
},
Remove: []uuid.UUID{
ids.ID("r1"), ids.ID("r2"),
ids.ID("r3"), ids.ID("r4"),
// Double up to try and confuse
ids.ID("r1"),
ids.ID("r4"),
},
Expect: []uuid.UUID{
ids.ID("a1"), ids.ID("a2"), ids.ID("a3"), ids.ID("a4"),
ids.ID("s1"), ids.ID("s2"),
},
},
}
for _, tc := range testCase {
tc := tc
t.Run(tc.Name, func(t *testing.T) {
t.Parallel()
mgr := runtimeconfig.NewManager()
db, _ := dbtestutil.NewDB(t)
ctx := testutil.Context(t, testutil.WaitMedium)
//nolint:gocritic // testing
ctx = dbauthz.AsSystemRestricted(ctx)
org := dbgen.Organization(t, db, database.Organization{})
_, err := db.InsertAllUsersGroup(ctx, org.ID)
require.NoError(t, err)
user := dbgen.User(t, db, database.User{})
_ = dbgen.OrganizationMember(t, db, database.OrganizationMember{
UserID: user.ID,
OrganizationID: org.ID,
})
for gid, in := range tc.Before {
group := dbgen.Group(t, db, database.Group{
ID: gid,
OrganizationID: org.ID,
})
if in {
_ = dbgen.GroupMember(t, db, database.GroupMemberTable{
UserID: user.ID,
GroupID: group.ID,
})
}
}
s := idpsync.NewAGPLSync(slogtest.Make(t, &slogtest.Options{}), mgr, idpsync.FromDeploymentValues(coderdtest.DeploymentValues(t)))
err = s.ApplyGroupDifference(context.Background(), db, user, tc.Add, tc.Remove)
require.NoError(t, err)
userGroups, err := db.GetGroups(ctx, database.GetGroupsParams{
HasMemberID: user.ID,
})
require.NoError(t, err)
// assert
found := db2sdk.List(userGroups, func(g database.GetGroupsRow) uuid.UUID {
return g.Group.ID
})
// Add everyone group
require.ElementsMatch(t, append(tc.Expect, org.ID), found)
})
}
}
func TestExpectedGroupEqual(t *testing.T) {
t.Parallel()
ids := coderdtest.NewDeterministicUUIDGenerator()
testCases := []struct {
Name string
A idpsync.ExpectedGroup
B idpsync.ExpectedGroup
Equal bool
}{
{
Name: "Empty",
A: idpsync.ExpectedGroup{},
B: idpsync.ExpectedGroup{},
Equal: true,
},
{
Name: "DifferentOrgs",
A: idpsync.ExpectedGroup{
OrganizationID: uuid.New(),
GroupID: ptr.Ref(ids.ID("g1")),
GroupName: nil,
},
B: idpsync.ExpectedGroup{
OrganizationID: uuid.New(),
GroupID: ptr.Ref(ids.ID("g1")),
GroupName: nil,
},
Equal: false,
},
{
Name: "SameID",
A: idpsync.ExpectedGroup{
OrganizationID: ids.ID("org"),
GroupID: ptr.Ref(ids.ID("g1")),
GroupName: nil,
},
B: idpsync.ExpectedGroup{
OrganizationID: ids.ID("org"),
GroupID: ptr.Ref(ids.ID("g1")),
GroupName: nil,
},
Equal: true,
},
{
Name: "DifferentIDs",
A: idpsync.ExpectedGroup{
OrganizationID: ids.ID("org"),
GroupID: ptr.Ref(uuid.New()),
GroupName: nil,
},
B: idpsync.ExpectedGroup{
OrganizationID: ids.ID("org"),
GroupID: ptr.Ref(uuid.New()),
GroupName: nil,
},
Equal: false,
},
{
Name: "SameName",
A: idpsync.ExpectedGroup{
OrganizationID: ids.ID("org"),
GroupID: nil,
GroupName: ptr.Ref("foo"),
},
B: idpsync.ExpectedGroup{
OrganizationID: ids.ID("org"),
GroupID: nil,
GroupName: ptr.Ref("foo"),
},
Equal: true,
},
{
Name: "DifferentName",
A: idpsync.ExpectedGroup{
OrganizationID: ids.ID("org"),
GroupID: nil,
GroupName: ptr.Ref("foo"),
},
B: idpsync.ExpectedGroup{
OrganizationID: ids.ID("org"),
GroupID: nil,
GroupName: ptr.Ref("bar"),
},
Equal: false,
},
// Edge cases
{
// A bit strange, but valid as ID takes priority.
// We assume 2 groups with the same ID are equal, even if
// their names are different. Names are mutable, IDs are not,
// so there is 0% chance they are different groups.
Name: "DifferentIDSameName",
A: idpsync.ExpectedGroup{
OrganizationID: ids.ID("org"),
GroupID: ptr.Ref(ids.ID("g1")),
GroupName: ptr.Ref("foo"),
},
B: idpsync.ExpectedGroup{
OrganizationID: ids.ID("org"),
GroupID: ptr.Ref(ids.ID("g1")),
GroupName: ptr.Ref("bar"),
},
Equal: true,
},
{
Name: "MixedNils",
A: idpsync.ExpectedGroup{
OrganizationID: ids.ID("org"),
GroupID: ptr.Ref(ids.ID("g1")),
GroupName: nil,
},
B: idpsync.ExpectedGroup{
OrganizationID: ids.ID("org"),
GroupID: nil,
GroupName: ptr.Ref("bar"),
},
Equal: false,
},
{
Name: "NoComparable",
A: idpsync.ExpectedGroup{
OrganizationID: ids.ID("org"),
GroupID: ptr.Ref(ids.ID("g1")),
GroupName: nil,
},
B: idpsync.ExpectedGroup{
OrganizationID: ids.ID("org"),
GroupID: nil,
GroupName: nil,
},
Equal: false,
},
}
for _, tc := range testCases {
tc := tc
t.Run(tc.Name, func(t *testing.T) {
t.Parallel()
require.Equal(t, tc.Equal, tc.A.Equal(tc.B))
})
}
}
func SetupOrganization(t *testing.T, s *idpsync.AGPLIDPSync, db database.Store, user database.User, orgID uuid.UUID, def orgSetupDefinition) {
t.Helper()
// Account that the org might be the default organization
org, err := db.GetOrganizationByID(context.Background(), orgID)
if xerrors.Is(err, sql.ErrNoRows) {
org = dbgen.Organization(t, db, database.Organization{
ID: orgID,
})
}
_, err = db.InsertAllUsersGroup(context.Background(), org.ID)
if !database.IsUniqueViolation(err) {
require.NoError(t, err, "Everyone group for an org")
}
manager := runtimeconfig.NewManager()
orgResolver := manager.OrganizationResolver(db, org.ID)
err = s.Group.SetRuntimeValue(context.Background(), orgResolver, def.Settings)
require.NoError(t, err)
if !def.NotMember {
dbgen.OrganizationMember(t, db, database.OrganizationMember{
UserID: user.ID,
OrganizationID: org.ID,
})
}
for groupID, in := range def.Groups {
dbgen.Group(t, db, database.Group{
ID: groupID,
OrganizationID: org.ID,
})
if in {
dbgen.GroupMember(t, db, database.GroupMemberTable{
UserID: user.ID,
GroupID: groupID,
})
}
}
for groupName, in := range def.GroupNames {
group := dbgen.Group(t, db, database.Group{
Name: groupName,
OrganizationID: org.ID,
})
if in {
dbgen.GroupMember(t, db, database.GroupMemberTable{
UserID: user.ID,
GroupID: group.ID,
})
}
}
}
type orgSetupDefinition struct {
Name string
// True if the user is a member of the group
Groups map[uuid.UUID]bool
GroupNames map[string]bool
NotMember bool
Settings *idpsync.GroupSyncSettings
ExpectedGroups []uuid.UUID
ExpectedGroupNames []string
}
func (o orgSetupDefinition) Assert(t *testing.T, orgID uuid.UUID, db database.Store, user database.User) {
t.Helper()
ctx := context.Background()
members, err := db.OrganizationMembers(ctx, database.OrganizationMembersParams{
OrganizationID: orgID,
UserID: user.ID,
})
require.NoError(t, err)
if o.NotMember {
require.Len(t, members, 0, "should not be a member")
} else {
require.Len(t, members, 1, "should be a member")
}
userGroups, err := db.GetGroups(ctx, database.GetGroupsParams{
OrganizationID: orgID,
HasMemberID: user.ID,
})
require.NoError(t, err)
if o.ExpectedGroups == nil {
o.ExpectedGroups = make([]uuid.UUID, 0)
}
if len(o.ExpectedGroupNames) > 0 && len(o.ExpectedGroups) > 0 {
t.Fatal("ExpectedGroups and ExpectedGroupNames are mutually exclusive")
}
// Everyone groups mess up our asserts
userGroups = slices.DeleteFunc(userGroups, func(row database.GetGroupsRow) bool {
return row.Group.ID == row.Group.OrganizationID
})
if len(o.ExpectedGroupNames) > 0 {
found := db2sdk.List(userGroups, func(g database.GetGroupsRow) string {
return g.Group.Name
})
require.ElementsMatch(t, o.ExpectedGroupNames, found, "user groups by name")
require.Len(t, o.ExpectedGroups, 0, "ExpectedGroups should be empty")
} else {
// Check by ID, recommended
found := db2sdk.List(userGroups, func(g database.GetGroupsRow) uuid.UUID {
return g.Group.ID
})
require.ElementsMatch(t, o.ExpectedGroups, found, "user groups")
require.Len(t, o.ExpectedGroupNames, 0, "ExpectedGroupNames should be empty")
}
}
+69 -15
View File
@@ -3,6 +3,7 @@ package idpsync
import ( import (
"context" "context"
"net/http" "net/http"
"regexp"
"strings" "strings"
"github.com/golang-jwt/jwt/v4" "github.com/golang-jwt/jwt/v4"
@@ -12,6 +13,7 @@ import (
"cdr.dev/slog" "cdr.dev/slog"
"github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/httpapi" "github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/coderd/runtimeconfig"
"github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/site" "github.com/coder/coder/v2/site"
) )
@@ -25,21 +27,34 @@ type IDPSync interface {
OrganizationSyncEnabled() bool OrganizationSyncEnabled() bool
// ParseOrganizationClaims takes claims from an OIDC provider, and returns the // ParseOrganizationClaims takes claims from an OIDC provider, and returns the
// organization sync params for assigning users into organizations. // organization sync params for assigning users into organizations.
ParseOrganizationClaims(ctx context.Context, _ jwt.MapClaims) (OrganizationParams, *HTTPError) ParseOrganizationClaims(ctx context.Context, mergedClaims jwt.MapClaims) (OrganizationParams, *HTTPError)
// SyncOrganizations assigns and removed users from organizations based on the // SyncOrganizations assigns and removed users from organizations based on the
// provided params. // provided params.
SyncOrganizations(ctx context.Context, tx database.Store, user database.User, params OrganizationParams) error SyncOrganizations(ctx context.Context, tx database.Store, user database.User, params OrganizationParams) error
GroupSyncEnabled() bool
// ParseGroupClaims takes claims from an OIDC provider, and returns the params
// for group syncing. Most of the logic happens in SyncGroups.
ParseGroupClaims(ctx context.Context, mergedClaims jwt.MapClaims) (GroupParams, *HTTPError)
// SyncGroups assigns and removes users from groups based on the provided params.
SyncGroups(ctx context.Context, db database.Store, user database.User, params GroupParams) error
// GroupSyncSettings is exposed for the API to implement CRUD operations
// on the settings used by IDPSync. This entry is thread safe and can be
// accessed concurrently. The settings are stored in the database.
GroupSyncSettings() runtimeconfig.RuntimeEntry[*GroupSyncSettings]
} }
// AGPLIDPSync is the configuration for syncing user information from an external // AGPLIDPSync is the configuration for syncing user information from an external
// IDP. All related code to syncing user information should be in this package. // IDP. All related code to syncing user information should be in this package.
type AGPLIDPSync struct { type AGPLIDPSync struct {
Logger slog.Logger Logger slog.Logger
Manager *runtimeconfig.Manager
SyncSettings SyncSettings
} }
type SyncSettings struct { // DeploymentSyncSettings are static and are sourced from the deployment config.
type DeploymentSyncSettings struct {
// OrganizationField selects the claim field to be used as the created user's // OrganizationField selects the claim field to be used as the created user's
// organizations. If the field is the empty string, then no organization updates // organizations. If the field is the empty string, then no organization updates
// will ever come from the OIDC provider. // will ever come from the OIDC provider.
@@ -50,23 +65,62 @@ type SyncSettings struct {
// placed into the default organization. This is mostly a hack to support // placed into the default organization. This is mostly a hack to support
// legacy deployments. // legacy deployments.
OrganizationAssignDefault bool OrganizationAssignDefault bool
// GroupField at the deployment level is used for deployment level group claim
// settings.
GroupField string
// GroupAllowList (if set) will restrict authentication to only users who
// have at least one group in this list.
// A map representation is used for easier lookup.
GroupAllowList map[string]struct{}
// Legacy deployment settings that only apply to the default org.
Legacy DefaultOrgLegacySettings
} }
type OrganizationParams struct { type DefaultOrgLegacySettings struct {
// SyncEnabled if false will skip syncing the user's organizations. GroupField string
SyncEnabled bool GroupMapping map[string]string
// IncludeDefault is primarily for single org deployments. It will ensure GroupFilter *regexp.Regexp
// a user is always inserted into the default org. CreateMissingGroups bool
IncludeDefault bool
// Organizations is the list of organizations the user should be a member of
// assuming syncing is turned on.
Organizations []uuid.UUID
} }
func NewAGPLSync(logger slog.Logger, settings SyncSettings) *AGPLIDPSync { func FromDeploymentValues(dv *codersdk.DeploymentValues) DeploymentSyncSettings {
if dv == nil {
panic("Developer error: DeploymentValues should not be nil")
}
return DeploymentSyncSettings{
OrganizationField: dv.OIDC.OrganizationField.Value(),
OrganizationMapping: dv.OIDC.OrganizationMapping.Value,
OrganizationAssignDefault: dv.OIDC.OrganizationAssignDefault.Value(),
// TODO: Separate group field for allow list from default org.
// Right now you cannot disable group sync from the default org and
// configure an allow list.
GroupField: dv.OIDC.GroupField.Value(),
GroupAllowList: ConvertAllowList(dv.OIDC.GroupAllowList.Value()),
Legacy: DefaultOrgLegacySettings{
GroupField: dv.OIDC.GroupField.Value(),
GroupMapping: dv.OIDC.GroupMapping.Value,
GroupFilter: dv.OIDC.GroupRegexFilter.Value(),
CreateMissingGroups: dv.OIDC.GroupAutoCreate.Value(),
},
}
}
type SyncSettings struct {
DeploymentSyncSettings
Group runtimeconfig.RuntimeEntry[*GroupSyncSettings]
}
func NewAGPLSync(logger slog.Logger, manager *runtimeconfig.Manager, settings DeploymentSyncSettings) *AGPLIDPSync {
return &AGPLIDPSync{ return &AGPLIDPSync{
Logger: logger.Named("idp-sync"), Logger: logger.Named("idp-sync"),
SyncSettings: settings, Manager: manager,
SyncSettings: SyncSettings{
DeploymentSyncSettings: settings,
Group: runtimeconfig.MustNew[*GroupSyncSettings]("group-sync-settings"),
},
} }
} }
+11
View File
@@ -16,6 +16,17 @@ import (
"github.com/coder/coder/v2/coderd/util/slice" "github.com/coder/coder/v2/coderd/util/slice"
) )
type OrganizationParams struct {
// SyncEnabled if false will skip syncing the user's organizations.
SyncEnabled bool
// IncludeDefault is primarily for single org deployments. It will ensure
// a user is always inserted into the default org.
IncludeDefault bool
// Organizations is the list of organizations the user should be a member of
// assuming syncing is turned on.
Organizations []uuid.UUID
}
func (AGPLIDPSync) OrganizationSyncEnabled() bool { func (AGPLIDPSync) OrganizationSyncEnabled() bool {
// AGPL does not support syncing organizations. // AGPL does not support syncing organizations.
return false return false
+17 -12
View File
@@ -9,6 +9,7 @@ import (
"cdr.dev/slog/sloggers/slogtest" "cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/v2/coderd/idpsync" "github.com/coder/coder/v2/coderd/idpsync"
"github.com/coder/coder/v2/coderd/runtimeconfig"
"github.com/coder/coder/v2/testutil" "github.com/coder/coder/v2/testutil"
) )
@@ -18,11 +19,13 @@ func TestParseOrganizationClaims(t *testing.T) {
t.Run("SingleOrgDeployment", func(t *testing.T) { t.Run("SingleOrgDeployment", func(t *testing.T) {
t.Parallel() t.Parallel()
s := idpsync.NewAGPLSync(slogtest.Make(t, &slogtest.Options{}), idpsync.SyncSettings{ s := idpsync.NewAGPLSync(slogtest.Make(t, &slogtest.Options{}),
OrganizationField: "", runtimeconfig.NewManager(),
OrganizationMapping: nil, idpsync.DeploymentSyncSettings{
OrganizationAssignDefault: true, OrganizationField: "",
}) OrganizationMapping: nil,
OrganizationAssignDefault: true,
})
ctx := testutil.Context(t, testutil.WaitMedium) ctx := testutil.Context(t, testutil.WaitMedium)
@@ -38,13 +41,15 @@ func TestParseOrganizationClaims(t *testing.T) {
t.Parallel() t.Parallel()
// AGPL has limited behavior // AGPL has limited behavior
s := idpsync.NewAGPLSync(slogtest.Make(t, &slogtest.Options{}), idpsync.SyncSettings{ s := idpsync.NewAGPLSync(slogtest.Make(t, &slogtest.Options{}),
OrganizationField: "orgs", runtimeconfig.NewManager(),
OrganizationMapping: map[string][]uuid.UUID{ idpsync.DeploymentSyncSettings{
"random": {uuid.New()}, OrganizationField: "orgs",
}, OrganizationMapping: map[string][]uuid.UUID{
OrganizationAssignDefault: false, "random": {uuid.New()},
}) },
OrganizationAssignDefault: false,
})
ctx := testutil.Context(t, testutil.WaitMedium) ctx := testutil.Context(t, testutil.WaitMedium)
+9
View File
@@ -2,6 +2,7 @@ package runtimeconfig
import ( import (
"context" "context"
"encoding/json"
"fmt" "fmt"
"golang.org/x/xerrors" "golang.org/x/xerrors"
@@ -93,3 +94,11 @@ func (e *RuntimeEntry[T]) name() (string, error) {
return e.n, nil return e.n, nil
} }
func JSONString(v any) string {
s, err := json.Marshal(v)
if err != nil {
return "decode failed: " + err.Error()
}
return string(s)
}
+29 -173
View File
@@ -8,7 +8,6 @@ import (
"fmt" "fmt"
"net/http" "net/http"
"net/mail" "net/mail"
"regexp"
"sort" "sort"
"strconv" "strconv"
"strings" "strings"
@@ -20,7 +19,6 @@ import (
"github.com/google/go-github/v43/github" "github.com/google/go-github/v43/github"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/moby/moby/pkg/namesgenerator" "github.com/moby/moby/pkg/namesgenerator"
"golang.org/x/exp/slices"
"golang.org/x/oauth2" "golang.org/x/oauth2"
"golang.org/x/xerrors" "golang.org/x/xerrors"
@@ -659,6 +657,9 @@ func (api *API) userOAuth2Github(rw http.ResponseWriter, r *http.Request) {
AvatarURL: ghUser.GetAvatarURL(), AvatarURL: ghUser.GetAvatarURL(),
Name: normName, Name: normName,
DebugContext: OauthDebugContext{}, DebugContext: OauthDebugContext{},
GroupSync: idpsync.GroupParams{
SyncEnabled: false,
},
OrganizationSync: idpsync.OrganizationParams{ OrganizationSync: idpsync.OrganizationParams{
SyncEnabled: false, SyncEnabled: false,
IncludeDefault: true, IncludeDefault: true,
@@ -739,27 +740,6 @@ type OIDCConfig struct {
// support the userinfo endpoint, or if the userinfo endpoint causes // support the userinfo endpoint, or if the userinfo endpoint causes
// undesirable behavior. // undesirable behavior.
IgnoreUserInfo bool IgnoreUserInfo bool
// TODO: Move all idp fields into the IDPSync struct
// GroupField selects the claim field to be used as the created user's
// groups. If the group field is the empty string, then no group updates
// will ever come from the OIDC provider.
GroupField string
// CreateMissingGroups controls whether groups returned by the OIDC provider
// are automatically created in Coder if they are missing.
CreateMissingGroups bool
// GroupFilter is a regular expression that filters the groups returned by
// the OIDC provider. Any group not matched by this regex will be ignored.
// If the group filter is nil, then no group filtering will occur.
GroupFilter *regexp.Regexp
// GroupAllowList is a list of groups that are allowed to log in.
// If the list length is 0, then the allow list will not be applied and
// this feature is disabled.
GroupAllowList map[string]bool
// GroupMapping controls how groups returned by the OIDC provider get mapped
// to groups within Coder.
// map[oidcGroupName]coderGroupName
GroupMapping map[string]string
// UserRoleField selects the claim field to be used as the created user's // UserRoleField selects the claim field to be used as the created user's
// roles. If the field is the empty string, then no role updates // roles. If the field is the empty string, then no role updates
// will ever come from the OIDC provider. // will ever come from the OIDC provider.
@@ -1002,11 +982,6 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) {
} }
ctx = slog.With(ctx, slog.F("email", email), slog.F("username", username), slog.F("name", name)) ctx = slog.With(ctx, slog.F("email", email), slog.F("username", username), slog.F("name", name))
usingGroups, groups, groupErr := api.oidcGroups(ctx, mergedClaims)
if groupErr != nil {
groupErr.Write(rw, r)
return
}
roles, roleErr := api.oidcRoles(ctx, mergedClaims) roles, roleErr := api.oidcRoles(ctx, mergedClaims)
if roleErr != nil { if roleErr != nil {
@@ -1030,6 +1005,12 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) {
return return
} }
groupSync, groupSyncErr := api.IDPSync.ParseGroupClaims(ctx, mergedClaims)
if groupSyncErr != nil {
groupSyncErr.Write(rw, r)
return
}
// If a new user is authenticating for the first time // If a new user is authenticating for the first time
// the audit action is 'register', not 'login' // the audit action is 'register', not 'login'
if user.ID == uuid.Nil { if user.ID == uuid.Nil {
@@ -1037,23 +1018,20 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) {
} }
params := (&oauthLoginParams{ params := (&oauthLoginParams{
User: user, User: user,
Link: link, Link: link,
State: state, State: state,
LinkedID: oidcLinkedID(idToken), LinkedID: oidcLinkedID(idToken),
LoginType: database.LoginTypeOIDC, LoginType: database.LoginTypeOIDC,
AllowSignups: api.OIDCConfig.AllowSignups, AllowSignups: api.OIDCConfig.AllowSignups,
Email: email, Email: email,
Username: username, Username: username,
Name: name, Name: name,
AvatarURL: picture, AvatarURL: picture,
UsingRoles: api.OIDCConfig.RoleSyncEnabled(), UsingRoles: api.OIDCConfig.RoleSyncEnabled(),
Roles: roles, Roles: roles,
UsingGroups: usingGroups, OrganizationSync: orgSync,
Groups: groups, GroupSync: groupSync,
OrganizationSync: orgSync,
CreateMissingGroups: api.OIDCConfig.CreateMissingGroups,
GroupFilter: api.OIDCConfig.GroupFilter,
DebugContext: OauthDebugContext{ DebugContext: OauthDebugContext{
IDTokenClaims: idtokenClaims, IDTokenClaims: idtokenClaims,
UserInfoClaims: userInfoClaims, UserInfoClaims: userInfoClaims,
@@ -1089,79 +1067,6 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) {
http.Redirect(rw, r, redirect, http.StatusTemporaryRedirect) http.Redirect(rw, r, redirect, http.StatusTemporaryRedirect)
} }
// oidcGroups returns the groups for the user from the OIDC claims.
func (api *API) oidcGroups(ctx context.Context, mergedClaims map[string]interface{}) (bool, []string, *idpsync.HTTPError) {
logger := api.Logger.Named(userAuthLoggerName)
usingGroups := false
var groups []string
// If the GroupField is the empty string, then groups from OIDC are not used.
// This is so we can support manual group assignment.
if api.OIDCConfig.GroupField != "" {
// If the allow list is empty, then the user is allowed to log in.
// Otherwise, they must belong to at least 1 group in the allow list.
inAllowList := len(api.OIDCConfig.GroupAllowList) == 0
usingGroups = true
groupsRaw, ok := mergedClaims[api.OIDCConfig.GroupField]
if ok {
parsedGroups, err := idpsync.ParseStringSliceClaim(groupsRaw)
if err != nil {
api.Logger.Debug(ctx, "groups field was an unknown type in oidc claims",
slog.F("type", fmt.Sprintf("%T", groupsRaw)),
slog.Error(err),
)
return false, nil, &idpsync.HTTPError{
Code: http.StatusBadRequest,
Msg: "Failed to sync groups from OIDC claims",
Detail: err.Error(),
RenderStaticPage: false,
}
}
api.Logger.Debug(ctx, "groups returned in oidc claims",
slog.F("len", len(parsedGroups)),
slog.F("groups", parsedGroups),
)
for _, group := range parsedGroups {
if mappedGroup, ok := api.OIDCConfig.GroupMapping[group]; ok {
group = mappedGroup
}
if _, ok := api.OIDCConfig.GroupAllowList[group]; ok {
inAllowList = true
}
groups = append(groups, group)
}
}
if !inAllowList {
logger.Debug(ctx, "oidc group claim not in allow list, rejecting login",
slog.F("allow_list_count", len(api.OIDCConfig.GroupAllowList)),
slog.F("user_group_count", len(groups)),
)
detail := "Ask an administrator to add one of your groups to the whitelist"
if len(groups) == 0 {
detail = "You are currently not a member of any groups! Ask an administrator to add you to an authorized group to login."
}
return usingGroups, groups, &idpsync.HTTPError{
Code: http.StatusForbidden,
Msg: "Not a member of an allowed group",
Detail: detail,
RenderStaticPage: true,
}
}
}
// This conditional is purely to warn the user they might have misconfigured their OIDC
// configuration.
if _, groupClaimExists := mergedClaims["groups"]; !usingGroups && groupClaimExists {
logger.Debug(ctx, "claim 'groups' was returned, but 'oidc-group-field' is not set, check your coder oidc settings")
}
return usingGroups, groups, nil
}
// oidcRoles returns the roles for the user from the OIDC claims. // oidcRoles returns the roles for the user from the OIDC claims.
// If the function returns false, then the caller should return early. // If the function returns false, then the caller should return early.
// All writes to the response writer are handled by this function. // All writes to the response writer are handled by this function.
@@ -1276,14 +1181,7 @@ type oauthLoginParams struct {
AvatarURL string AvatarURL string
// OrganizationSync has the organizations that the user will be assigned to. // OrganizationSync has the organizations that the user will be assigned to.
OrganizationSync idpsync.OrganizationParams OrganizationSync idpsync.OrganizationParams
// Is UsingGroups is true, then the user will be assigned GroupSync idpsync.GroupParams
// to the Groups provided.
UsingGroups bool
CreateMissingGroups bool
// These are the group names from the IDP. Internally, they will map to
// some organization groups.
Groups []string
GroupFilter *regexp.Regexp
// Is UsingRoles is true, then the user will be assigned // Is UsingRoles is true, then the user will be assigned
// the roles provided. // the roles provided.
UsingRoles bool UsingRoles bool
@@ -1489,53 +1387,11 @@ func (api *API) oauthLogin(r *http.Request, params *oauthLoginParams) ([]*http.C
return xerrors.Errorf("sync organizations: %w", err) return xerrors.Errorf("sync organizations: %w", err)
} }
// Ensure groups are correct. // Group sync needs to occur after org sync, since a user can join an org,
// This places all groups into the default organization. // then have their groups sync to said org.
// To go multi-org, we need to add a mapping feature here to know which err = api.IDPSync.SyncGroups(ctx, tx, user, params.GroupSync)
// groups go to which orgs. if err != nil {
if params.UsingGroups { return xerrors.Errorf("sync groups: %w", err)
filtered := params.Groups
if params.GroupFilter != nil {
filtered = make([]string, 0, len(params.Groups))
for _, group := range params.Groups {
if params.GroupFilter.MatchString(group) {
filtered = append(filtered, group)
}
}
}
//nolint:gocritic // No user present in the context.
defaultOrganization, err := tx.GetDefaultOrganization(dbauthz.AsSystemRestricted(ctx))
if err != nil {
// If there is no default org, then we can't assign groups.
// By default, we assume all groups belong to the default org.
return xerrors.Errorf("get default organization: %w", err)
}
//nolint:gocritic // No user present in the context.
memberships, err := tx.OrganizationMembers(dbauthz.AsSystemRestricted(ctx), database.OrganizationMembersParams{
UserID: user.ID,
OrganizationID: uuid.Nil,
})
if err != nil {
return xerrors.Errorf("get organization memberships: %w", err)
}
// If the user is not in the default organization, then we can't assign groups.
// A user cannot be in groups to an org they are not a member of.
if !slices.ContainsFunc(memberships, func(member database.OrganizationMembersRow) bool {
return member.OrganizationMember.OrganizationID == defaultOrganization.ID
}) {
return xerrors.Errorf("user %s is not a member of the default organization, cannot assign to groups in the org", user.ID)
}
//nolint:gocritic
err = api.Options.SetUserGroups(dbauthz.AsSystemRestricted(ctx), logger, tx, user.ID, map[uuid.UUID][]string{
defaultOrganization.ID: filtered,
}, params.CreateMissingGroups)
if err != nil {
return xerrors.Errorf("set user groups: %w", err)
}
} }
// Ensure roles are correct. // Ensure roles are correct.
+5 -8
View File
@@ -80,13 +80,6 @@ func New(ctx context.Context, options *Options) (_ *API, err error) {
if options.Entitlements == nil { if options.Entitlements == nil {
options.Entitlements = entitlements.New() options.Entitlements = entitlements.New()
} }
if options.IDPSync == nil {
options.IDPSync = enidpsync.NewSync(options.Logger, options.Entitlements, idpsync.SyncSettings{
OrganizationField: options.DeploymentValues.OIDC.OrganizationField.Value(),
OrganizationMapping: options.DeploymentValues.OIDC.OrganizationMapping.Value,
OrganizationAssignDefault: options.DeploymentValues.OIDC.OrganizationAssignDefault.Value(),
})
}
ctx, cancelFunc := context.WithCancel(ctx) ctx, cancelFunc := context.WithCancel(ctx)
@@ -118,6 +111,11 @@ func New(ctx context.Context, options *Options) (_ *API, err error) {
} }
options.Database = cryptDB options.Database = cryptDB
if options.IDPSync == nil {
options.IDPSync = enidpsync.NewSync(options.Logger, options.RuntimeConfig, options.Entitlements, idpsync.FromDeploymentValues(options.DeploymentValues))
}
api := &API{ api := &API{
ctx: ctx, ctx: ctx,
cancel: cancelFunc, cancel: cancelFunc,
@@ -147,7 +145,6 @@ func New(ctx context.Context, options *Options) (_ *API, err error) {
} }
return c.Subject, c.Trial, nil return c.Subject, c.Trial, nil
} }
api.AGPL.Options.SetUserGroups = api.setUserGroups
api.AGPL.Options.SetUserSiteRoles = api.setUserSiteRoles api.AGPL.Options.SetUserSiteRoles = api.setUserSiteRoles
api.AGPL.SiteHandler.RegionsFetcher = func(ctx context.Context) (any, error) { api.AGPL.SiteHandler.RegionsFetcher = func(ctx context.Context) (any, error) {
// If the user can read the workspace proxy resource, return that. // If the user can read the workspace proxy resource, return that.
+3 -3
View File
@@ -2,9 +2,9 @@ package enidpsync
import ( import (
"cdr.dev/slog" "cdr.dev/slog"
"github.com/coder/coder/v2/coderd/entitlements" "github.com/coder/coder/v2/coderd/entitlements"
"github.com/coder/coder/v2/coderd/idpsync" "github.com/coder/coder/v2/coderd/idpsync"
"github.com/coder/coder/v2/coderd/runtimeconfig"
) )
// EnterpriseIDPSync enabled syncing user information from an external IDP. // EnterpriseIDPSync enabled syncing user information from an external IDP.
@@ -17,9 +17,9 @@ type EnterpriseIDPSync struct {
*idpsync.AGPLIDPSync *idpsync.AGPLIDPSync
} }
func NewSync(logger slog.Logger, set *entitlements.Set, settings idpsync.SyncSettings) *EnterpriseIDPSync { func NewSync(logger slog.Logger, manager *runtimeconfig.Manager, set *entitlements.Set, settings idpsync.DeploymentSyncSettings) *EnterpriseIDPSync {
return &EnterpriseIDPSync{ return &EnterpriseIDPSync{
entitlements: set, entitlements: set,
AGPLIDPSync: idpsync.NewAGPLSync(logger.With(slog.F("enterprise_capable", "true")), settings), AGPLIDPSync: idpsync.NewAGPLSync(logger.With(slog.F("enterprise_capable", "true")), manager, settings),
} }
} }
+70
View File
@@ -0,0 +1,70 @@
package enidpsync
import (
"context"
"net/http"
"github.com/golang-jwt/jwt/v4"
"github.com/coder/coder/v2/coderd/idpsync"
"github.com/coder/coder/v2/codersdk"
)
func (e EnterpriseIDPSync) GroupSyncEnabled() bool {
return e.entitlements.Enabled(codersdk.FeatureTemplateRBAC)
}
// ParseGroupClaims parses the user claims and handles deployment wide group behavior.
// Almost all behavior is deferred since each organization configures it's own
// group sync settings.
// GroupAllowList is implemented here to prevent login by unauthorized users.
// TODO: GroupAllowList overlaps with the default organization group sync settings.
func (e EnterpriseIDPSync) ParseGroupClaims(ctx context.Context, mergedClaims jwt.MapClaims) (idpsync.GroupParams, *idpsync.HTTPError) {
if !e.GroupSyncEnabled() {
return e.AGPLIDPSync.ParseGroupClaims(ctx, mergedClaims)
}
if e.GroupField != "" && len(e.GroupAllowList) > 0 {
groupsRaw, ok := mergedClaims[e.GroupField]
if !ok {
return idpsync.GroupParams{}, &idpsync.HTTPError{
Code: http.StatusForbidden,
Msg: "Not a member of an allowed group",
Detail: "You have no groups in your claims!",
RenderStaticPage: true,
}
}
parsedGroups, err := idpsync.ParseStringSliceClaim(groupsRaw)
if err != nil {
return idpsync.GroupParams{}, &idpsync.HTTPError{
Code: http.StatusBadRequest,
Msg: "Failed read groups from claims for allow list check. Ask an administrator for help.",
Detail: err.Error(),
RenderStaticPage: true,
}
}
inAllowList := false
AllowListCheckLoop:
for _, group := range parsedGroups {
if _, ok := e.GroupAllowList[group]; ok {
inAllowList = true
break AllowListCheckLoop
}
}
if !inAllowList {
return idpsync.GroupParams{}, &idpsync.HTTPError{
Code: http.StatusForbidden,
Msg: "Not a member of an allowed group",
Detail: "Ask an administrator to add one of your groups to the allow list.",
RenderStaticPage: true,
}
}
}
return idpsync.GroupParams{
SyncEnabled: true,
MergedClaims: mergedClaims,
}, nil
}
@@ -0,0 +1,96 @@
package enidpsync_test
import (
"testing"
"github.com/golang-jwt/jwt/v4"
"github.com/stretchr/testify/require"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/v2/coderd/entitlements"
"github.com/coder/coder/v2/coderd/idpsync"
"github.com/coder/coder/v2/coderd/runtimeconfig"
"github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/enterprise/coderd/enidpsync"
"github.com/coder/coder/v2/testutil"
)
func TestEnterpriseParseGroupClaims(t *testing.T) {
t.Parallel()
entitled := entitlements.New()
entitled.Update(func(entitlements *codersdk.Entitlements) {
entitlements.Features[codersdk.FeatureTemplateRBAC] = codersdk.Feature{
Entitlement: codersdk.EntitlementEntitled,
Enabled: true,
}
})
t.Run("NoEntitlements", func(t *testing.T) {
t.Parallel()
s := enidpsync.NewSync(slogtest.Make(t, &slogtest.Options{}),
runtimeconfig.NewManager(),
entitlements.New(),
idpsync.DeploymentSyncSettings{})
ctx := testutil.Context(t, testutil.WaitMedium)
params, err := s.ParseGroupClaims(ctx, jwt.MapClaims{})
require.Nil(t, err)
require.False(t, params.SyncEnabled)
})
t.Run("NotInAllowList", func(t *testing.T) {
t.Parallel()
s := enidpsync.NewSync(slogtest.Make(t, &slogtest.Options{}),
runtimeconfig.NewManager(),
entitled,
idpsync.DeploymentSyncSettings{
GroupField: "groups",
GroupAllowList: map[string]struct{}{
"foo": {},
},
})
ctx := testutil.Context(t, testutil.WaitMedium)
// Try with incorrect group
_, err := s.ParseGroupClaims(ctx, jwt.MapClaims{
"groups": []string{"bar"},
})
require.NotNil(t, err)
require.Equal(t, 403, err.Code)
// Try with no groups
_, err = s.ParseGroupClaims(ctx, jwt.MapClaims{})
require.NotNil(t, err)
require.Equal(t, 403, err.Code)
})
t.Run("InAllowList", func(t *testing.T) {
t.Parallel()
s := enidpsync.NewSync(slogtest.Make(t, &slogtest.Options{}),
runtimeconfig.NewManager(),
entitled,
idpsync.DeploymentSyncSettings{
GroupField: "groups",
GroupAllowList: map[string]struct{}{
"foo": {},
},
})
ctx := testutil.Context(t, testutil.WaitMedium)
claims := jwt.MapClaims{
"groups": []string{"foo", "bar"},
}
params, err := s.ParseGroupClaims(ctx, claims)
require.Nil(t, err)
require.True(t, params.SyncEnabled)
require.Equal(t, claims, params.MergedClaims)
})
}
@@ -19,6 +19,7 @@ import (
"github.com/coder/coder/v2/coderd/entitlements" "github.com/coder/coder/v2/coderd/entitlements"
"github.com/coder/coder/v2/coderd/idpsync" "github.com/coder/coder/v2/coderd/idpsync"
"github.com/coder/coder/v2/coderd/rbac" "github.com/coder/coder/v2/coderd/rbac"
"github.com/coder/coder/v2/coderd/runtimeconfig"
"github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/enterprise/coderd/enidpsync" "github.com/coder/coder/v2/enterprise/coderd/enidpsync"
"github.com/coder/coder/v2/testutil" "github.com/coder/coder/v2/testutil"
@@ -41,7 +42,7 @@ type Expectations struct {
} }
type OrganizationSyncTestCase struct { type OrganizationSyncTestCase struct {
Settings idpsync.SyncSettings Settings idpsync.DeploymentSyncSettings
Entitlements *entitlements.Set Entitlements *entitlements.Set
Exps []Expectations Exps []Expectations
} }
@@ -89,7 +90,7 @@ func TestOrganizationSync(t *testing.T) {
other := dbgen.Organization(t, db, database.Organization{}) other := dbgen.Organization(t, db, database.Organization{})
return OrganizationSyncTestCase{ return OrganizationSyncTestCase{
Entitlements: entitled, Entitlements: entitled,
Settings: idpsync.SyncSettings{ Settings: idpsync.DeploymentSyncSettings{
OrganizationField: "", OrganizationField: "",
OrganizationMapping: nil, OrganizationMapping: nil,
OrganizationAssignDefault: true, OrganizationAssignDefault: true,
@@ -142,7 +143,7 @@ func TestOrganizationSync(t *testing.T) {
three := dbgen.Organization(t, db, database.Organization{}) three := dbgen.Organization(t, db, database.Organization{})
return OrganizationSyncTestCase{ return OrganizationSyncTestCase{
Entitlements: entitled, Entitlements: entitled,
Settings: idpsync.SyncSettings{ Settings: idpsync.DeploymentSyncSettings{
OrganizationField: "organizations", OrganizationField: "organizations",
OrganizationMapping: map[string][]uuid.UUID{ OrganizationMapping: map[string][]uuid.UUID{
"first": {one.ID}, "first": {one.ID},
@@ -236,7 +237,7 @@ func TestOrganizationSync(t *testing.T) {
} }
// Create a new sync object // Create a new sync object
sync := enidpsync.NewSync(logger, caseData.Entitlements, caseData.Settings) sync := enidpsync.NewSync(logger, runtimeconfig.NewManager(), caseData.Entitlements, caseData.Settings)
for _, exp := range caseData.Exps { for _, exp := range caseData.Exps {
t.Run(exp.Name, func(t *testing.T) { t.Run(exp.Name, func(t *testing.T) {
params, httpErr := sync.ParseOrganizationClaims(ctx, exp.Claims) params, httpErr := sync.ParseOrganizationClaims(ctx, exp.Claims)
-66
View File
@@ -8,75 +8,9 @@ import (
"cdr.dev/slog" "cdr.dev/slog"
"github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbauthz"
"github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/codersdk"
) )
// nolint: revive
func (api *API) setUserGroups(ctx context.Context, logger slog.Logger, db database.Store, userID uuid.UUID, orgGroupNames map[uuid.UUID][]string, createMissingGroups bool) error {
if !api.Entitlements.Enabled(codersdk.FeatureTemplateRBAC) {
return nil
}
return db.InTx(func(tx database.Store) error {
// When setting the user's groups, it's easier to just clear their groups and re-add them.
// This ensures that the user's groups are always in sync with the auth provider.
orgs, err := tx.GetOrganizationsByUserID(ctx, userID)
if err != nil {
return xerrors.Errorf("get user orgs: %w", err)
}
if len(orgs) != 1 {
return xerrors.Errorf("expected 1 org, got %d", len(orgs))
}
// Delete all groups the user belongs to.
// nolint:gocritic // Requires system context to remove user from all groups.
err = tx.RemoveUserFromAllGroups(dbauthz.AsSystemRestricted(ctx), userID)
if err != nil {
return xerrors.Errorf("delete user groups: %w", err)
}
// TODO: This could likely be improved by making these single queries.
// Either by batching or some other means. This for loop could be really
// inefficient if there are a lot of organizations. There was deployments
// on v1 with >100 orgs.
for orgID, groupNames := range orgGroupNames {
// Create the missing groups for each organization.
if createMissingGroups {
// This is the system creating these additional groups, so we use the system restricted context.
// nolint:gocritic
created, err := tx.InsertMissingGroups(dbauthz.AsSystemRestricted(ctx), database.InsertMissingGroupsParams{
OrganizationID: orgID,
GroupNames: groupNames,
Source: database.GroupSourceOidc,
})
if err != nil {
return xerrors.Errorf("insert missing groups: %w", err)
}
if len(created) > 0 {
logger.Debug(ctx, "auto created missing groups",
slog.F("org_id", orgID.ID),
slog.F("created", created),
slog.F("num", len(created)),
)
}
}
// Re-add the user to all groups returned by the auth provider.
err = tx.InsertUserGroupsByName(ctx, database.InsertUserGroupsByNameParams{
UserID: userID,
OrganizationID: orgID,
GroupNames: groupNames,
})
if err != nil {
return xerrors.Errorf("insert user groups: %w", err)
}
}
return nil
}, nil)
}
func (api *API) setUserSiteRoles(ctx context.Context, logger slog.Logger, db database.Store, userID uuid.UUID, roles []string) error { func (api *API) setUserSiteRoles(ctx context.Context, logger slog.Logger, db database.Store, userID uuid.UUID, roles []string) error {
if !api.Entitlements.Enabled(codersdk.FeatureUserRoleManagement) { if !api.Entitlements.Enabled(codersdk.FeatureUserRoleManagement) {
logger.Warn(ctx, "attempted to assign OIDC user roles without enterprise entitlement, roles left unchanged", logger.Warn(ctx, "attempted to assign OIDC user roles without enterprise entitlement, roles left unchanged",
+51 -32
View File
@@ -402,7 +402,9 @@ func TestUserOIDC(t *testing.T) {
runner := setupOIDCTest(t, oidcTestConfig{ runner := setupOIDCTest(t, oidcTestConfig{
Config: func(cfg *coderd.OIDCConfig) { Config: func(cfg *coderd.OIDCConfig) {
cfg.AllowSignups = true cfg.AllowSignups = true
cfg.GroupField = groupClaim },
DeploymentValues: func(dv *codersdk.DeploymentValues) {
dv.OIDC.GroupField = groupClaim
}, },
}) })
@@ -433,8 +435,10 @@ func TestUserOIDC(t *testing.T) {
runner := setupOIDCTest(t, oidcTestConfig{ runner := setupOIDCTest(t, oidcTestConfig{
Config: func(cfg *coderd.OIDCConfig) { Config: func(cfg *coderd.OIDCConfig) {
cfg.AllowSignups = true cfg.AllowSignups = true
cfg.GroupField = groupClaim },
cfg.GroupMapping = map[string]string{oidcGroupName: coderGroupName} DeploymentValues: func(dv *codersdk.DeploymentValues) {
dv.OIDC.GroupField = groupClaim
dv.OIDC.GroupMapping = serpent.Struct[map[string]string]{Value: map[string]string{oidcGroupName: coderGroupName}}
}, },
}) })
@@ -468,7 +472,9 @@ func TestUserOIDC(t *testing.T) {
runner := setupOIDCTest(t, oidcTestConfig{ runner := setupOIDCTest(t, oidcTestConfig{
Config: func(cfg *coderd.OIDCConfig) { Config: func(cfg *coderd.OIDCConfig) {
cfg.AllowSignups = true cfg.AllowSignups = true
cfg.GroupField = groupClaim },
DeploymentValues: func(dv *codersdk.DeploymentValues) {
dv.OIDC.GroupField = groupClaim
}, },
}) })
@@ -502,7 +508,9 @@ func TestUserOIDC(t *testing.T) {
runner := setupOIDCTest(t, oidcTestConfig{ runner := setupOIDCTest(t, oidcTestConfig{
Config: func(cfg *coderd.OIDCConfig) { Config: func(cfg *coderd.OIDCConfig) {
cfg.AllowSignups = true cfg.AllowSignups = true
cfg.GroupField = groupClaim },
DeploymentValues: func(dv *codersdk.DeploymentValues) {
dv.OIDC.GroupField = groupClaim
}, },
}) })
@@ -537,7 +545,9 @@ func TestUserOIDC(t *testing.T) {
runner := setupOIDCTest(t, oidcTestConfig{ runner := setupOIDCTest(t, oidcTestConfig{
Config: func(cfg *coderd.OIDCConfig) { Config: func(cfg *coderd.OIDCConfig) {
cfg.AllowSignups = true cfg.AllowSignups = true
cfg.GroupField = groupClaim },
DeploymentValues: func(dv *codersdk.DeploymentValues) {
dv.OIDC.GroupField = groupClaim
}, },
}) })
@@ -559,8 +569,10 @@ func TestUserOIDC(t *testing.T) {
runner := setupOIDCTest(t, oidcTestConfig{ runner := setupOIDCTest(t, oidcTestConfig{
Config: func(cfg *coderd.OIDCConfig) { Config: func(cfg *coderd.OIDCConfig) {
cfg.AllowSignups = true cfg.AllowSignups = true
cfg.GroupField = groupClaim },
cfg.CreateMissingGroups = true DeploymentValues: func(dv *codersdk.DeploymentValues) {
dv.OIDC.GroupField = groupClaim
dv.OIDC.GroupAutoCreate = true
}, },
}) })
@@ -582,8 +594,10 @@ func TestUserOIDC(t *testing.T) {
runner := setupOIDCTest(t, oidcTestConfig{ runner := setupOIDCTest(t, oidcTestConfig{
Config: func(cfg *coderd.OIDCConfig) { Config: func(cfg *coderd.OIDCConfig) {
cfg.AllowSignups = true cfg.AllowSignups = true
cfg.GroupField = groupClaim },
cfg.CreateMissingGroups = true DeploymentValues: func(dv *codersdk.DeploymentValues) {
dv.OIDC.GroupField = groupClaim
dv.OIDC.GroupAutoCreate = true
}, },
}) })
@@ -606,8 +620,10 @@ func TestUserOIDC(t *testing.T) {
runner := setupOIDCTest(t, oidcTestConfig{ runner := setupOIDCTest(t, oidcTestConfig{
Config: func(cfg *coderd.OIDCConfig) { Config: func(cfg *coderd.OIDCConfig) {
cfg.AllowSignups = true cfg.AllowSignups = true
cfg.GroupField = groupClaim },
cfg.GroupAllowList = map[string]bool{allowedGroup: true} DeploymentValues: func(dv *codersdk.DeploymentValues) {
dv.OIDC.GroupField = groupClaim
dv.OIDC.GroupAllowList = []string{allowedGroup}
}, },
}) })
@@ -697,6 +713,7 @@ func TestGroupSync(t *testing.T) {
testCases := []struct { testCases := []struct {
name string name string
modCfg func(cfg *coderd.OIDCConfig) modCfg func(cfg *coderd.OIDCConfig)
modDV func(dv *codersdk.DeploymentValues)
// initialOrgGroups is initial groups in the org // initialOrgGroups is initial groups in the org
initialOrgGroups []string initialOrgGroups []string
// initialUserGroups is initial groups for the user // initialUserGroups is initial groups for the user
@@ -718,10 +735,10 @@ func TestGroupSync(t *testing.T) {
}, },
{ {
name: "GroupSyncDisabled", name: "GroupSyncDisabled",
modCfg: func(cfg *coderd.OIDCConfig) { modDV: func(dv *codersdk.DeploymentValues) {
// Disable group sync // Disable group sync
cfg.GroupField = "" dv.OIDC.GroupField = ""
cfg.GroupFilter = regexp.MustCompile(".*") dv.OIDC.GroupRegexFilter = serpent.Regexp(*regexp.MustCompile(".*"))
}, },
initialOrgGroups: []string{"a", "b", "c", "d"}, initialOrgGroups: []string{"a", "b", "c", "d"},
initialUserGroups: []string{"b", "c", "d"}, initialUserGroups: []string{"b", "c", "d"},
@@ -732,10 +749,8 @@ func TestGroupSync(t *testing.T) {
{ {
// From a,c,b -> b,c,d // From a,c,b -> b,c,d
name: "ChangeUserGroups", name: "ChangeUserGroups",
modCfg: func(cfg *coderd.OIDCConfig) { modDV: func(dv *codersdk.DeploymentValues) {
cfg.GroupMapping = map[string]string{ dv.OIDC.GroupMapping = serpent.Struct[map[string]string]{Value: map[string]string{"D": "d"}}
"D": "d",
}
}, },
initialOrgGroups: []string{"a", "b", "c", "d"}, initialOrgGroups: []string{"a", "b", "c", "d"},
initialUserGroups: []string{"a", "b", "c"}, initialUserGroups: []string{"a", "b", "c"},
@@ -749,8 +764,8 @@ func TestGroupSync(t *testing.T) {
{ {
// From a,c,b -> [] // From a,c,b -> []
name: "RemoveAllGroups", name: "RemoveAllGroups",
modCfg: func(cfg *coderd.OIDCConfig) { modDV: func(dv *codersdk.DeploymentValues) {
cfg.GroupFilter = regexp.MustCompile(".*") dv.OIDC.GroupRegexFilter = serpent.Regexp(*regexp.MustCompile(".*"))
}, },
initialOrgGroups: []string{"a", "b", "c", "d"}, initialOrgGroups: []string{"a", "b", "c", "d"},
initialUserGroups: []string{"a", "b", "c"}, initialUserGroups: []string{"a", "b", "c"},
@@ -763,8 +778,8 @@ func TestGroupSync(t *testing.T) {
{ {
// From a,c,b -> b,c,d,e,f // From a,c,b -> b,c,d,e,f
name: "CreateMissingGroups", name: "CreateMissingGroups",
modCfg: func(cfg *coderd.OIDCConfig) { modDV: func(dv *codersdk.DeploymentValues) {
cfg.CreateMissingGroups = true dv.OIDC.GroupAutoCreate = true
}, },
initialOrgGroups: []string{"a", "b", "c", "d"}, initialOrgGroups: []string{"a", "b", "c", "d"},
initialUserGroups: []string{"a", "b", "c"}, initialUserGroups: []string{"a", "b", "c"},
@@ -777,14 +792,11 @@ func TestGroupSync(t *testing.T) {
{ {
// From a,c,b -> b,c,d,e,f // From a,c,b -> b,c,d,e,f
name: "CreateMissingGroupsFilter", name: "CreateMissingGroupsFilter",
modCfg: func(cfg *coderd.OIDCConfig) { modDV: func(dv *codersdk.DeploymentValues) {
cfg.CreateMissingGroups = true dv.OIDC.GroupAutoCreate = true
// Only single letter groups // Only single letter groups
cfg.GroupFilter = regexp.MustCompile("^[a-z]$") dv.OIDC.GroupRegexFilter = serpent.Regexp(*regexp.MustCompile("^[a-z]$"))
cfg.GroupMapping = map[string]string{ dv.OIDC.GroupMapping = serpent.Struct[map[string]string]{Value: map[string]string{"zebra": "z"}}
// Does not match the filter, but does after being mapped!
"zebra": "z",
}
}, },
initialOrgGroups: []string{"a", "b", "c", "d"}, initialOrgGroups: []string{"a", "b", "c", "d"},
initialUserGroups: []string{"a", "b", "c"}, initialUserGroups: []string{"a", "b", "c"},
@@ -806,8 +818,15 @@ func TestGroupSync(t *testing.T) {
t.Parallel() t.Parallel()
runner := setupOIDCTest(t, oidcTestConfig{ runner := setupOIDCTest(t, oidcTestConfig{
Config: func(cfg *coderd.OIDCConfig) { Config: func(cfg *coderd.OIDCConfig) {
cfg.GroupField = "groups" if tc.modCfg != nil {
tc.modCfg(cfg) tc.modCfg(cfg)
}
},
DeploymentValues: func(dv *codersdk.DeploymentValues) {
dv.OIDC.GroupField = "groups"
if tc.modDV != nil {
tc.modDV(dv)
}
}, },
}) })