mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
chore: support multi-org group sync with runtime configuration (#14578)
- Implement multi-org group sync - Implement runtime configuration to change sync behavior - Legacy group sync migrated to new package
This commit is contained in:
@@ -187,11 +187,6 @@ func createOIDCConfig(ctx context.Context, logger slog.Logger, vals *codersdk.De
|
|||||||
EmailField: vals.OIDC.EmailField.String(),
|
EmailField: vals.OIDC.EmailField.String(),
|
||||||
AuthURLParams: vals.OIDC.AuthURLParams.Value,
|
AuthURLParams: vals.OIDC.AuthURLParams.Value,
|
||||||
IgnoreUserInfo: vals.OIDC.IgnoreUserInfo.Value(),
|
IgnoreUserInfo: vals.OIDC.IgnoreUserInfo.Value(),
|
||||||
GroupField: vals.OIDC.GroupField.String(),
|
|
||||||
GroupFilter: vals.OIDC.GroupRegexFilter.Value(),
|
|
||||||
GroupAllowList: groupAllowList,
|
|
||||||
CreateMissingGroups: vals.OIDC.GroupAutoCreate.Value(),
|
|
||||||
GroupMapping: vals.OIDC.GroupMapping.Value,
|
|
||||||
UserRoleField: vals.OIDC.UserRoleField.String(),
|
UserRoleField: vals.OIDC.UserRoleField.String(),
|
||||||
UserRoleMapping: vals.OIDC.UserRoleMapping.Value,
|
UserRoleMapping: vals.OIDC.UserRoleMapping.Value,
|
||||||
UserRolesDefault: vals.OIDC.UserRolesDefault.GetSlice(),
|
UserRolesDefault: vals.OIDC.UserRolesDefault.GetSlice(),
|
||||||
|
|||||||
+4
-18
@@ -181,7 +181,6 @@ type Options struct {
|
|||||||
NetworkTelemetryBatchFrequency time.Duration
|
NetworkTelemetryBatchFrequency time.Duration
|
||||||
NetworkTelemetryBatchMaxSize int
|
NetworkTelemetryBatchMaxSize int
|
||||||
SwaggerEndpoint bool
|
SwaggerEndpoint bool
|
||||||
SetUserGroups func(ctx context.Context, logger slog.Logger, tx database.Store, userID uuid.UUID, orgGroupNames map[uuid.UUID][]string, createMissingGroups bool) error
|
|
||||||
SetUserSiteRoles func(ctx context.Context, logger slog.Logger, tx database.Store, userID uuid.UUID, roles []string) error
|
SetUserSiteRoles func(ctx context.Context, logger slog.Logger, tx database.Store, userID uuid.UUID, roles []string) error
|
||||||
TemplateScheduleStore *atomic.Pointer[schedule.TemplateScheduleStore]
|
TemplateScheduleStore *atomic.Pointer[schedule.TemplateScheduleStore]
|
||||||
UserQuietHoursScheduleStore *atomic.Pointer[schedule.UserQuietHoursScheduleStore]
|
UserQuietHoursScheduleStore *atomic.Pointer[schedule.UserQuietHoursScheduleStore]
|
||||||
@@ -276,13 +275,6 @@ func New(options *Options) *API {
|
|||||||
if options.Entitlements == nil {
|
if options.Entitlements == nil {
|
||||||
options.Entitlements = entitlements.New()
|
options.Entitlements = entitlements.New()
|
||||||
}
|
}
|
||||||
if options.IDPSync == nil {
|
|
||||||
options.IDPSync = idpsync.NewAGPLSync(options.Logger, idpsync.SyncSettings{
|
|
||||||
OrganizationField: options.DeploymentValues.OIDC.OrganizationField.Value(),
|
|
||||||
OrganizationMapping: options.DeploymentValues.OIDC.OrganizationMapping.Value,
|
|
||||||
OrganizationAssignDefault: options.DeploymentValues.OIDC.OrganizationAssignDefault.Value(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
if options.NewTicker == nil {
|
if options.NewTicker == nil {
|
||||||
options.NewTicker = func(duration time.Duration) (tick <-chan time.Time, done func()) {
|
options.NewTicker = func(duration time.Duration) (tick <-chan time.Time, done func()) {
|
||||||
ticker := time.NewTicker(duration)
|
ticker := time.NewTicker(duration)
|
||||||
@@ -318,6 +310,10 @@ func New(options *Options) *API {
|
|||||||
options.AccessControlStore,
|
options.AccessControlStore,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if options.IDPSync == nil {
|
||||||
|
options.IDPSync = idpsync.NewAGPLSync(options.Logger, options.RuntimeConfig, idpsync.FromDeploymentValues(options.DeploymentValues))
|
||||||
|
}
|
||||||
|
|
||||||
experiments := ReadExperiments(
|
experiments := ReadExperiments(
|
||||||
options.Logger, options.DeploymentValues.Experiments.Value(),
|
options.Logger, options.DeploymentValues.Experiments.Value(),
|
||||||
)
|
)
|
||||||
@@ -377,16 +373,6 @@ func New(options *Options) *API {
|
|||||||
if options.TracerProvider == nil {
|
if options.TracerProvider == nil {
|
||||||
options.TracerProvider = trace.NewNoopTracerProvider()
|
options.TracerProvider = trace.NewNoopTracerProvider()
|
||||||
}
|
}
|
||||||
if options.SetUserGroups == nil {
|
|
||||||
options.SetUserGroups = func(ctx context.Context, logger slog.Logger, _ database.Store, userID uuid.UUID, orgGroupNames map[uuid.UUID][]string, createMissingGroups bool) error {
|
|
||||||
logger.Warn(ctx, "attempted to assign OIDC groups without enterprise license",
|
|
||||||
slog.F("user_id", userID),
|
|
||||||
slog.F("groups", orgGroupNames),
|
|
||||||
slog.F("create_missing_groups", createMissingGroups),
|
|
||||||
)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if options.SetUserSiteRoles == nil {
|
if options.SetUserSiteRoles == nil {
|
||||||
options.SetUserSiteRoles = func(ctx context.Context, logger slog.Logger, _ database.Store, userID uuid.UUID, roles []string) error {
|
options.SetUserSiteRoles = func(ctx context.Context, logger slog.Logger, _ database.Store, userID uuid.UUID, roles []string) error {
|
||||||
logger.Warn(ctx, "attempted to assign OIDC user roles without enterprise license",
|
logger.Warn(ctx, "attempted to assign OIDC user roles without enterprise license",
|
||||||
|
|||||||
@@ -0,0 +1,25 @@
|
|||||||
|
package coderdtest
|
||||||
|
|
||||||
|
import "github.com/google/uuid"
|
||||||
|
|
||||||
|
// DeterministicUUIDGenerator allows "naming" uuids for unit tests.
|
||||||
|
// An example of where this is useful, is when a tabled test references
|
||||||
|
// a UUID that is not yet known. An alternative to this would be to
|
||||||
|
// hard code some UUID strings, but these strings are not human friendly.
|
||||||
|
type DeterministicUUIDGenerator struct {
|
||||||
|
Named map[string]uuid.UUID
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewDeterministicUUIDGenerator() *DeterministicUUIDGenerator {
|
||||||
|
return &DeterministicUUIDGenerator{
|
||||||
|
Named: make(map[string]uuid.UUID),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DeterministicUUIDGenerator) ID(name string) uuid.UUID {
|
||||||
|
if v, ok := d.Named[name]; ok {
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
d.Named[name] = uuid.New()
|
||||||
|
return d.Named[name]
|
||||||
|
}
|
||||||
@@ -0,0 +1,17 @@
|
|||||||
|
package coderdtest_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/coder/coder/v2/coderd/coderdtest"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestDeterministicUUIDGenerator(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ids := coderdtest.NewDeterministicUUIDGenerator()
|
||||||
|
require.Equal(t, ids.ID("g1"), ids.ID("g1"))
|
||||||
|
require.NotEqual(t, ids.ID("g1"), ids.ID("g2"))
|
||||||
|
}
|
||||||
@@ -2892,6 +2892,14 @@ func (q *querier) InsertUser(ctx context.Context, arg database.InsertUserParams)
|
|||||||
return insert(q.log, q.auth, obj, q.db.InsertUser)(ctx, arg)
|
return insert(q.log, q.auth, obj, q.db.InsertUser)(ctx, arg)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (q *querier) InsertUserGroupsByID(ctx context.Context, arg database.InsertUserGroupsByIDParams) ([]uuid.UUID, error) {
|
||||||
|
// This is used by OIDC sync. So only used by a system user.
|
||||||
|
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceSystem); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return q.db.InsertUserGroupsByID(ctx, arg)
|
||||||
|
}
|
||||||
|
|
||||||
func (q *querier) InsertUserGroupsByName(ctx context.Context, arg database.InsertUserGroupsByNameParams) error {
|
func (q *querier) InsertUserGroupsByName(ctx context.Context, arg database.InsertUserGroupsByNameParams) error {
|
||||||
// This will add the user to all named groups. This counts as updating a group.
|
// This will add the user to all named groups. This counts as updating a group.
|
||||||
// NOTE: instead of checking if the user has permission to update each group, we instead
|
// NOTE: instead of checking if the user has permission to update each group, we instead
|
||||||
@@ -3100,6 +3108,14 @@ func (q *querier) RemoveUserFromAllGroups(ctx context.Context, userID uuid.UUID)
|
|||||||
return q.db.RemoveUserFromAllGroups(ctx, userID)
|
return q.db.RemoveUserFromAllGroups(ctx, userID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (q *querier) RemoveUserFromGroups(ctx context.Context, arg database.RemoveUserFromGroupsParams) ([]uuid.UUID, error) {
|
||||||
|
// This is a system function to clear user groups in group sync.
|
||||||
|
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceSystem); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return q.db.RemoveUserFromGroups(ctx, arg)
|
||||||
|
}
|
||||||
|
|
||||||
func (q *querier) RevokeDBCryptKey(ctx context.Context, activeKeyDigest string) error {
|
func (q *querier) RevokeDBCryptKey(ctx context.Context, activeKeyDigest string) error {
|
||||||
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceSystem); err != nil {
|
if err := q.authorizeContext(ctx, policy.ActionUpdate, rbac.ResourceSystem); err != nil {
|
||||||
return err
|
return err
|
||||||
|
|||||||
@@ -388,6 +388,17 @@ func (s *MethodTestSuite) TestGroup() {
|
|||||||
GroupNames: slice.New(g1.Name, g2.Name),
|
GroupNames: slice.New(g1.Name, g2.Name),
|
||||||
}).Asserts(rbac.ResourceGroup.InOrg(o.ID), policy.ActionUpdate).Returns()
|
}).Asserts(rbac.ResourceGroup.InOrg(o.ID), policy.ActionUpdate).Returns()
|
||||||
}))
|
}))
|
||||||
|
s.Run("InsertUserGroupsByID", s.Subtest(func(db database.Store, check *expects) {
|
||||||
|
o := dbgen.Organization(s.T(), db, database.Organization{})
|
||||||
|
u1 := dbgen.User(s.T(), db, database.User{})
|
||||||
|
g1 := dbgen.Group(s.T(), db, database.Group{OrganizationID: o.ID})
|
||||||
|
g2 := dbgen.Group(s.T(), db, database.Group{OrganizationID: o.ID})
|
||||||
|
_ = dbgen.GroupMember(s.T(), db, database.GroupMemberTable{GroupID: g1.ID, UserID: u1.ID})
|
||||||
|
check.Args(database.InsertUserGroupsByIDParams{
|
||||||
|
UserID: u1.ID,
|
||||||
|
GroupIds: slice.New(g1.ID, g2.ID),
|
||||||
|
}).Asserts(rbac.ResourceSystem, policy.ActionUpdate).Returns(slice.New(g1.ID, g2.ID))
|
||||||
|
}))
|
||||||
s.Run("RemoveUserFromAllGroups", s.Subtest(func(db database.Store, check *expects) {
|
s.Run("RemoveUserFromAllGroups", s.Subtest(func(db database.Store, check *expects) {
|
||||||
o := dbgen.Organization(s.T(), db, database.Organization{})
|
o := dbgen.Organization(s.T(), db, database.Organization{})
|
||||||
u1 := dbgen.User(s.T(), db, database.User{})
|
u1 := dbgen.User(s.T(), db, database.User{})
|
||||||
@@ -397,6 +408,18 @@ func (s *MethodTestSuite) TestGroup() {
|
|||||||
_ = dbgen.GroupMember(s.T(), db, database.GroupMemberTable{GroupID: g2.ID, UserID: u1.ID})
|
_ = dbgen.GroupMember(s.T(), db, database.GroupMemberTable{GroupID: g2.ID, UserID: u1.ID})
|
||||||
check.Args(u1.ID).Asserts(rbac.ResourceSystem, policy.ActionUpdate).Returns()
|
check.Args(u1.ID).Asserts(rbac.ResourceSystem, policy.ActionUpdate).Returns()
|
||||||
}))
|
}))
|
||||||
|
s.Run("RemoveUserFromGroups", s.Subtest(func(db database.Store, check *expects) {
|
||||||
|
o := dbgen.Organization(s.T(), db, database.Organization{})
|
||||||
|
u1 := dbgen.User(s.T(), db, database.User{})
|
||||||
|
g1 := dbgen.Group(s.T(), db, database.Group{OrganizationID: o.ID})
|
||||||
|
g2 := dbgen.Group(s.T(), db, database.Group{OrganizationID: o.ID})
|
||||||
|
_ = dbgen.GroupMember(s.T(), db, database.GroupMemberTable{GroupID: g1.ID, UserID: u1.ID})
|
||||||
|
_ = dbgen.GroupMember(s.T(), db, database.GroupMemberTable{GroupID: g2.ID, UserID: u1.ID})
|
||||||
|
check.Args(database.RemoveUserFromGroupsParams{
|
||||||
|
UserID: u1.ID,
|
||||||
|
GroupIds: []uuid.UUID{g1.ID, g2.ID},
|
||||||
|
}).Asserts(rbac.ResourceSystem, policy.ActionUpdate).Returns(slice.New(g1.ID, g2.ID))
|
||||||
|
}))
|
||||||
s.Run("UpdateGroupByID", s.Subtest(func(db database.Store, check *expects) {
|
s.Run("UpdateGroupByID", s.Subtest(func(db database.Store, check *expects) {
|
||||||
g := dbgen.Group(s.T(), db, database.Group{})
|
g := dbgen.Group(s.T(), db, database.Group{})
|
||||||
check.Args(database.UpdateGroupByIDParams{
|
check.Args(database.UpdateGroupByIDParams{
|
||||||
|
|||||||
@@ -2695,18 +2695,18 @@ func (q *FakeQuerier) GetGroups(_ context.Context, arg database.GetGroupsParams)
|
|||||||
q.mutex.RLock()
|
q.mutex.RLock()
|
||||||
defer q.mutex.RUnlock()
|
defer q.mutex.RUnlock()
|
||||||
|
|
||||||
groupIDs := make(map[uuid.UUID]struct{})
|
userGroupIDs := make(map[uuid.UUID]struct{})
|
||||||
if arg.HasMemberID != uuid.Nil {
|
if arg.HasMemberID != uuid.Nil {
|
||||||
for _, member := range q.groupMembers {
|
for _, member := range q.groupMembers {
|
||||||
if member.UserID == arg.HasMemberID {
|
if member.UserID == arg.HasMemberID {
|
||||||
groupIDs[member.GroupID] = struct{}{}
|
userGroupIDs[member.GroupID] = struct{}{}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle the everyone group
|
// Handle the everyone group
|
||||||
for _, orgMember := range q.organizationMembers {
|
for _, orgMember := range q.organizationMembers {
|
||||||
if orgMember.UserID == arg.HasMemberID {
|
if orgMember.UserID == arg.HasMemberID {
|
||||||
groupIDs[orgMember.OrganizationID] = struct{}{}
|
userGroupIDs[orgMember.OrganizationID] = struct{}{}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -2718,11 +2718,15 @@ func (q *FakeQuerier) GetGroups(_ context.Context, arg database.GetGroupsParams)
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
_, ok := groupIDs[group.ID]
|
_, ok := userGroupIDs[group.ID]
|
||||||
if arg.HasMemberID != uuid.Nil && !ok {
|
if arg.HasMemberID != uuid.Nil && !ok {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(arg.GroupNames) > 0 && !slices.Contains(arg.GroupNames, group.Name) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
orgDetails, ok := orgDetailsCache[group.ID]
|
orgDetails, ok := orgDetailsCache[group.ID]
|
||||||
if !ok {
|
if !ok {
|
||||||
for _, org := range q.organizations {
|
for _, org := range q.organizations {
|
||||||
@@ -7015,7 +7019,37 @@ func (q *FakeQuerier) InsertUser(_ context.Context, arg database.InsertUserParam
|
|||||||
return user, nil
|
return user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (q *FakeQuerier) InsertUserGroupsByID(_ context.Context, arg database.InsertUserGroupsByIDParams) ([]uuid.UUID, error) {
|
||||||
|
err := validateDatabaseType(arg)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
q.mutex.Lock()
|
||||||
|
defer q.mutex.Unlock()
|
||||||
|
|
||||||
|
var groupIDs []uuid.UUID
|
||||||
|
for _, group := range q.groups {
|
||||||
|
for _, groupID := range arg.GroupIds {
|
||||||
|
if group.ID == groupID {
|
||||||
|
q.groupMembers = append(q.groupMembers, database.GroupMemberTable{
|
||||||
|
UserID: arg.UserID,
|
||||||
|
GroupID: groupID,
|
||||||
|
})
|
||||||
|
groupIDs = append(groupIDs, group.ID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return groupIDs, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (q *FakeQuerier) InsertUserGroupsByName(_ context.Context, arg database.InsertUserGroupsByNameParams) error {
|
func (q *FakeQuerier) InsertUserGroupsByName(_ context.Context, arg database.InsertUserGroupsByNameParams) error {
|
||||||
|
err := validateDatabaseType(arg)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
q.mutex.Lock()
|
q.mutex.Lock()
|
||||||
defer q.mutex.Unlock()
|
defer q.mutex.Unlock()
|
||||||
|
|
||||||
@@ -7607,6 +7641,34 @@ func (q *FakeQuerier) RemoveUserFromAllGroups(_ context.Context, userID uuid.UUI
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (q *FakeQuerier) RemoveUserFromGroups(_ context.Context, arg database.RemoveUserFromGroupsParams) ([]uuid.UUID, error) {
|
||||||
|
err := validateDatabaseType(arg)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
q.mutex.Lock()
|
||||||
|
defer q.mutex.Unlock()
|
||||||
|
|
||||||
|
removed := make([]uuid.UUID, 0)
|
||||||
|
q.data.groupMembers = slices.DeleteFunc(q.data.groupMembers, func(groupMember database.GroupMemberTable) bool {
|
||||||
|
// Delete all group members that match the arguments.
|
||||||
|
if groupMember.UserID != arg.UserID {
|
||||||
|
// Not the right user, ignore.
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if !slices.Contains(arg.GroupIds, groupMember.GroupID) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
removed = append(removed, groupMember.GroupID)
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
|
||||||
|
return removed, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (q *FakeQuerier) RevokeDBCryptKey(_ context.Context, activeKeyDigest string) error {
|
func (q *FakeQuerier) RevokeDBCryptKey(_ context.Context, activeKeyDigest string) error {
|
||||||
q.mutex.Lock()
|
q.mutex.Lock()
|
||||||
defer q.mutex.Unlock()
|
defer q.mutex.Unlock()
|
||||||
|
|||||||
@@ -1789,6 +1789,13 @@ func (m metricsStore) InsertUser(ctx context.Context, arg database.InsertUserPar
|
|||||||
return user, err
|
return user, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m metricsStore) InsertUserGroupsByID(ctx context.Context, arg database.InsertUserGroupsByIDParams) ([]uuid.UUID, error) {
|
||||||
|
start := time.Now()
|
||||||
|
r0, r1 := m.s.InsertUserGroupsByID(ctx, arg)
|
||||||
|
m.queryLatencies.WithLabelValues("InsertUserGroupsByID").Observe(time.Since(start).Seconds())
|
||||||
|
return r0, r1
|
||||||
|
}
|
||||||
|
|
||||||
func (m metricsStore) InsertUserGroupsByName(ctx context.Context, arg database.InsertUserGroupsByNameParams) error {
|
func (m metricsStore) InsertUserGroupsByName(ctx context.Context, arg database.InsertUserGroupsByNameParams) error {
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
err := m.s.InsertUserGroupsByName(ctx, arg)
|
err := m.s.InsertUserGroupsByName(ctx, arg)
|
||||||
@@ -1943,6 +1950,13 @@ func (m metricsStore) RemoveUserFromAllGroups(ctx context.Context, userID uuid.U
|
|||||||
return r0
|
return r0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m metricsStore) RemoveUserFromGroups(ctx context.Context, arg database.RemoveUserFromGroupsParams) ([]uuid.UUID, error) {
|
||||||
|
start := time.Now()
|
||||||
|
r0, r1 := m.s.RemoveUserFromGroups(ctx, arg)
|
||||||
|
m.queryLatencies.WithLabelValues("RemoveUserFromGroups").Observe(time.Since(start).Seconds())
|
||||||
|
return r0, r1
|
||||||
|
}
|
||||||
|
|
||||||
func (m metricsStore) RevokeDBCryptKey(ctx context.Context, activeKeyDigest string) error {
|
func (m metricsStore) RevokeDBCryptKey(ctx context.Context, activeKeyDigest string) error {
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
r0 := m.s.RevokeDBCryptKey(ctx, activeKeyDigest)
|
r0 := m.s.RevokeDBCryptKey(ctx, activeKeyDigest)
|
||||||
|
|||||||
@@ -3766,6 +3766,21 @@ func (mr *MockStoreMockRecorder) InsertUser(arg0, arg1 any) *gomock.Call {
|
|||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertUser", reflect.TypeOf((*MockStore)(nil).InsertUser), arg0, arg1)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertUser", reflect.TypeOf((*MockStore)(nil).InsertUser), arg0, arg1)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// InsertUserGroupsByID mocks base method.
|
||||||
|
func (m *MockStore) InsertUserGroupsByID(arg0 context.Context, arg1 database.InsertUserGroupsByIDParams) ([]uuid.UUID, error) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "InsertUserGroupsByID", arg0, arg1)
|
||||||
|
ret0, _ := ret[0].([]uuid.UUID)
|
||||||
|
ret1, _ := ret[1].(error)
|
||||||
|
return ret0, ret1
|
||||||
|
}
|
||||||
|
|
||||||
|
// InsertUserGroupsByID indicates an expected call of InsertUserGroupsByID.
|
||||||
|
func (mr *MockStoreMockRecorder) InsertUserGroupsByID(arg0, arg1 any) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InsertUserGroupsByID", reflect.TypeOf((*MockStore)(nil).InsertUserGroupsByID), arg0, arg1)
|
||||||
|
}
|
||||||
|
|
||||||
// InsertUserGroupsByName mocks base method.
|
// InsertUserGroupsByName mocks base method.
|
||||||
func (m *MockStore) InsertUserGroupsByName(arg0 context.Context, arg1 database.InsertUserGroupsByNameParams) error {
|
func (m *MockStore) InsertUserGroupsByName(arg0 context.Context, arg1 database.InsertUserGroupsByNameParams) error {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
@@ -4103,6 +4118,21 @@ func (mr *MockStoreMockRecorder) RemoveUserFromAllGroups(arg0, arg1 any) *gomock
|
|||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveUserFromAllGroups", reflect.TypeOf((*MockStore)(nil).RemoveUserFromAllGroups), arg0, arg1)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveUserFromAllGroups", reflect.TypeOf((*MockStore)(nil).RemoveUserFromAllGroups), arg0, arg1)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RemoveUserFromGroups mocks base method.
|
||||||
|
func (m *MockStore) RemoveUserFromGroups(arg0 context.Context, arg1 database.RemoveUserFromGroupsParams) ([]uuid.UUID, error) {
|
||||||
|
m.ctrl.T.Helper()
|
||||||
|
ret := m.ctrl.Call(m, "RemoveUserFromGroups", arg0, arg1)
|
||||||
|
ret0, _ := ret[0].([]uuid.UUID)
|
||||||
|
ret1, _ := ret[1].(error)
|
||||||
|
return ret0, ret1
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveUserFromGroups indicates an expected call of RemoveUserFromGroups.
|
||||||
|
func (mr *MockStoreMockRecorder) RemoveUserFromGroups(arg0, arg1 any) *gomock.Call {
|
||||||
|
mr.mock.ctrl.T.Helper()
|
||||||
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemoveUserFromGroups", reflect.TypeOf((*MockStore)(nil).RemoveUserFromGroups), arg0, arg1)
|
||||||
|
}
|
||||||
|
|
||||||
// RevokeDBCryptKey mocks base method.
|
// RevokeDBCryptKey mocks base method.
|
||||||
func (m *MockStore) RevokeDBCryptKey(arg0 context.Context, arg1 string) error {
|
func (m *MockStore) RevokeDBCryptKey(arg0 context.Context, arg1 string) error {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
|
|||||||
@@ -369,6 +369,9 @@ type sqlcQuerier interface {
|
|||||||
InsertTemplateVersionVariable(ctx context.Context, arg InsertTemplateVersionVariableParams) (TemplateVersionVariable, error)
|
InsertTemplateVersionVariable(ctx context.Context, arg InsertTemplateVersionVariableParams) (TemplateVersionVariable, error)
|
||||||
InsertTemplateVersionWorkspaceTag(ctx context.Context, arg InsertTemplateVersionWorkspaceTagParams) (TemplateVersionWorkspaceTag, error)
|
InsertTemplateVersionWorkspaceTag(ctx context.Context, arg InsertTemplateVersionWorkspaceTagParams) (TemplateVersionWorkspaceTag, error)
|
||||||
InsertUser(ctx context.Context, arg InsertUserParams) (User, error)
|
InsertUser(ctx context.Context, arg InsertUserParams) (User, error)
|
||||||
|
// InsertUserGroupsByID adds a user to all provided groups, if they exist.
|
||||||
|
// If there is a conflict, the user is already a member
|
||||||
|
InsertUserGroupsByID(ctx context.Context, arg InsertUserGroupsByIDParams) ([]uuid.UUID, error)
|
||||||
// InsertUserGroupsByName adds a user to all provided groups, if they exist.
|
// InsertUserGroupsByName adds a user to all provided groups, if they exist.
|
||||||
InsertUserGroupsByName(ctx context.Context, arg InsertUserGroupsByNameParams) error
|
InsertUserGroupsByName(ctx context.Context, arg InsertUserGroupsByNameParams) error
|
||||||
InsertUserLink(ctx context.Context, arg InsertUserLinkParams) (UserLink, error)
|
InsertUserLink(ctx context.Context, arg InsertUserLinkParams) (UserLink, error)
|
||||||
@@ -396,6 +399,7 @@ type sqlcQuerier interface {
|
|||||||
ReduceWorkspaceAgentShareLevelToAuthenticatedByTemplate(ctx context.Context, templateID uuid.UUID) error
|
ReduceWorkspaceAgentShareLevelToAuthenticatedByTemplate(ctx context.Context, templateID uuid.UUID) error
|
||||||
RegisterWorkspaceProxy(ctx context.Context, arg RegisterWorkspaceProxyParams) (WorkspaceProxy, error)
|
RegisterWorkspaceProxy(ctx context.Context, arg RegisterWorkspaceProxyParams) (WorkspaceProxy, error)
|
||||||
RemoveUserFromAllGroups(ctx context.Context, userID uuid.UUID) error
|
RemoveUserFromAllGroups(ctx context.Context, userID uuid.UUID) error
|
||||||
|
RemoveUserFromGroups(ctx context.Context, arg RemoveUserFromGroupsParams) ([]uuid.UUID, error)
|
||||||
RevokeDBCryptKey(ctx context.Context, activeKeyDigest string) error
|
RevokeDBCryptKey(ctx context.Context, activeKeyDigest string) error
|
||||||
// Non blocking lock. Returns true if the lock was acquired, false otherwise.
|
// Non blocking lock. Returns true if the lock was acquired, false otherwise.
|
||||||
//
|
//
|
||||||
|
|||||||
@@ -1446,6 +1446,56 @@ func (q *sqlQuerier) InsertGroupMember(ctx context.Context, arg InsertGroupMembe
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const insertUserGroupsByID = `-- name: InsertUserGroupsByID :many
|
||||||
|
WITH groups AS (
|
||||||
|
SELECT
|
||||||
|
id
|
||||||
|
FROM
|
||||||
|
groups
|
||||||
|
WHERE
|
||||||
|
groups.id = ANY($2 :: uuid [])
|
||||||
|
)
|
||||||
|
INSERT INTO
|
||||||
|
group_members (user_id, group_id)
|
||||||
|
SELECT
|
||||||
|
$1,
|
||||||
|
groups.id
|
||||||
|
FROM
|
||||||
|
groups
|
||||||
|
ON CONFLICT DO NOTHING
|
||||||
|
RETURNING group_id
|
||||||
|
`
|
||||||
|
|
||||||
|
type InsertUserGroupsByIDParams struct {
|
||||||
|
UserID uuid.UUID `db:"user_id" json:"user_id"`
|
||||||
|
GroupIds []uuid.UUID `db:"group_ids" json:"group_ids"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// InsertUserGroupsByID adds a user to all provided groups, if they exist.
|
||||||
|
// If there is a conflict, the user is already a member
|
||||||
|
func (q *sqlQuerier) InsertUserGroupsByID(ctx context.Context, arg InsertUserGroupsByIDParams) ([]uuid.UUID, error) {
|
||||||
|
rows, err := q.db.QueryContext(ctx, insertUserGroupsByID, arg.UserID, pq.Array(arg.GroupIds))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
var items []uuid.UUID
|
||||||
|
for rows.Next() {
|
||||||
|
var group_id uuid.UUID
|
||||||
|
if err := rows.Scan(&group_id); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
items = append(items, group_id)
|
||||||
|
}
|
||||||
|
if err := rows.Close(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return items, nil
|
||||||
|
}
|
||||||
|
|
||||||
const insertUserGroupsByName = `-- name: InsertUserGroupsByName :exec
|
const insertUserGroupsByName = `-- name: InsertUserGroupsByName :exec
|
||||||
WITH groups AS (
|
WITH groups AS (
|
||||||
SELECT
|
SELECT
|
||||||
@@ -1489,6 +1539,43 @@ func (q *sqlQuerier) RemoveUserFromAllGroups(ctx context.Context, userID uuid.UU
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const removeUserFromGroups = `-- name: RemoveUserFromGroups :many
|
||||||
|
DELETE FROM
|
||||||
|
group_members
|
||||||
|
WHERE
|
||||||
|
user_id = $1 AND
|
||||||
|
group_id = ANY($2 :: uuid [])
|
||||||
|
RETURNING group_id
|
||||||
|
`
|
||||||
|
|
||||||
|
type RemoveUserFromGroupsParams struct {
|
||||||
|
UserID uuid.UUID `db:"user_id" json:"user_id"`
|
||||||
|
GroupIds []uuid.UUID `db:"group_ids" json:"group_ids"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *sqlQuerier) RemoveUserFromGroups(ctx context.Context, arg RemoveUserFromGroupsParams) ([]uuid.UUID, error) {
|
||||||
|
rows, err := q.db.QueryContext(ctx, removeUserFromGroups, arg.UserID, pq.Array(arg.GroupIds))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
var items []uuid.UUID
|
||||||
|
for rows.Next() {
|
||||||
|
var group_id uuid.UUID
|
||||||
|
if err := rows.Scan(&group_id); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
items = append(items, group_id)
|
||||||
|
}
|
||||||
|
if err := rows.Close(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if err := rows.Err(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return items, nil
|
||||||
|
}
|
||||||
|
|
||||||
const deleteGroupByID = `-- name: DeleteGroupByID :exec
|
const deleteGroupByID = `-- name: DeleteGroupByID :exec
|
||||||
DELETE FROM
|
DELETE FROM
|
||||||
groups
|
groups
|
||||||
@@ -1592,11 +1679,16 @@ WHERE
|
|||||||
)
|
)
|
||||||
ELSE true
|
ELSE true
|
||||||
END
|
END
|
||||||
|
AND CASE WHEN array_length($3 :: text[], 1) > 0 THEN
|
||||||
|
groups.name = ANY($3)
|
||||||
|
ELSE true
|
||||||
|
END
|
||||||
`
|
`
|
||||||
|
|
||||||
type GetGroupsParams struct {
|
type GetGroupsParams struct {
|
||||||
OrganizationID uuid.UUID `db:"organization_id" json:"organization_id"`
|
OrganizationID uuid.UUID `db:"organization_id" json:"organization_id"`
|
||||||
HasMemberID uuid.UUID `db:"has_member_id" json:"has_member_id"`
|
HasMemberID uuid.UUID `db:"has_member_id" json:"has_member_id"`
|
||||||
|
GroupNames []string `db:"group_names" json:"group_names"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type GetGroupsRow struct {
|
type GetGroupsRow struct {
|
||||||
@@ -1606,7 +1698,7 @@ type GetGroupsRow struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (q *sqlQuerier) GetGroups(ctx context.Context, arg GetGroupsParams) ([]GetGroupsRow, error) {
|
func (q *sqlQuerier) GetGroups(ctx context.Context, arg GetGroupsParams) ([]GetGroupsRow, error) {
|
||||||
rows, err := q.db.QueryContext(ctx, getGroups, arg.OrganizationID, arg.HasMemberID)
|
rows, err := q.db.QueryContext(ctx, getGroups, arg.OrganizationID, arg.HasMemberID, pq.Array(arg.GroupNames))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -29,12 +29,41 @@ SELECT
|
|||||||
FROM
|
FROM
|
||||||
groups;
|
groups;
|
||||||
|
|
||||||
|
-- InsertUserGroupsByID adds a user to all provided groups, if they exist.
|
||||||
|
-- name: InsertUserGroupsByID :many
|
||||||
|
WITH groups AS (
|
||||||
|
SELECT
|
||||||
|
id
|
||||||
|
FROM
|
||||||
|
groups
|
||||||
|
WHERE
|
||||||
|
groups.id = ANY(@group_ids :: uuid [])
|
||||||
|
)
|
||||||
|
INSERT INTO
|
||||||
|
group_members (user_id, group_id)
|
||||||
|
SELECT
|
||||||
|
@user_id,
|
||||||
|
groups.id
|
||||||
|
FROM
|
||||||
|
groups
|
||||||
|
-- If there is a conflict, the user is already a member
|
||||||
|
ON CONFLICT DO NOTHING
|
||||||
|
RETURNING group_id;
|
||||||
|
|
||||||
-- name: RemoveUserFromAllGroups :exec
|
-- name: RemoveUserFromAllGroups :exec
|
||||||
DELETE FROM
|
DELETE FROM
|
||||||
group_members
|
group_members
|
||||||
WHERE
|
WHERE
|
||||||
user_id = @user_id;
|
user_id = @user_id;
|
||||||
|
|
||||||
|
-- name: RemoveUserFromGroups :many
|
||||||
|
DELETE FROM
|
||||||
|
group_members
|
||||||
|
WHERE
|
||||||
|
user_id = @user_id AND
|
||||||
|
group_id = ANY(@group_ids :: uuid [])
|
||||||
|
RETURNING group_id;
|
||||||
|
|
||||||
-- name: InsertGroupMember :exec
|
-- name: InsertGroupMember :exec
|
||||||
INSERT INTO
|
INSERT INTO
|
||||||
group_members (user_id, group_id)
|
group_members (user_id, group_id)
|
||||||
|
|||||||
@@ -52,6 +52,10 @@ WHERE
|
|||||||
)
|
)
|
||||||
ELSE true
|
ELSE true
|
||||||
END
|
END
|
||||||
|
AND CASE WHEN array_length(@group_names :: text[], 1) > 0 THEN
|
||||||
|
groups.name = ANY(@group_names)
|
||||||
|
ELSE true
|
||||||
|
END
|
||||||
;
|
;
|
||||||
|
|
||||||
-- name: InsertGroup :one
|
-- name: InsertGroup :one
|
||||||
|
|||||||
@@ -0,0 +1,416 @@
|
|||||||
|
package idpsync
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"regexp"
|
||||||
|
|
||||||
|
"github.com/golang-jwt/jwt/v4"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"golang.org/x/xerrors"
|
||||||
|
|
||||||
|
"cdr.dev/slog"
|
||||||
|
"github.com/coder/coder/v2/coderd/database"
|
||||||
|
"github.com/coder/coder/v2/coderd/database/db2sdk"
|
||||||
|
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||||
|
"github.com/coder/coder/v2/coderd/runtimeconfig"
|
||||||
|
"github.com/coder/coder/v2/coderd/util/slice"
|
||||||
|
)
|
||||||
|
|
||||||
|
type GroupParams struct {
|
||||||
|
// SyncEnabled if false will skip syncing the user's groups
|
||||||
|
SyncEnabled bool
|
||||||
|
MergedClaims jwt.MapClaims
|
||||||
|
}
|
||||||
|
|
||||||
|
func (AGPLIDPSync) GroupSyncEnabled() bool {
|
||||||
|
// AGPL does not support syncing groups.
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s AGPLIDPSync) GroupSyncSettings() runtimeconfig.RuntimeEntry[*GroupSyncSettings] {
|
||||||
|
return s.Group
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s AGPLIDPSync) ParseGroupClaims(_ context.Context, _ jwt.MapClaims) (GroupParams, *HTTPError) {
|
||||||
|
return GroupParams{
|
||||||
|
SyncEnabled: s.GroupSyncEnabled(),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s AGPLIDPSync) SyncGroups(ctx context.Context, db database.Store, user database.User, params GroupParams) error {
|
||||||
|
// Nothing happens if sync is not enabled
|
||||||
|
if !params.SyncEnabled {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// nolint:gocritic // all syncing is done as a system user
|
||||||
|
ctx = dbauthz.AsSystemRestricted(ctx)
|
||||||
|
|
||||||
|
// Only care about the default org for deployment settings if the
|
||||||
|
// legacy deployment settings exist.
|
||||||
|
defaultOrgID := uuid.Nil
|
||||||
|
// Default organization is configured via legacy deployment values
|
||||||
|
if s.DeploymentSyncSettings.Legacy.GroupField != "" {
|
||||||
|
defaultOrganization, err := db.GetDefaultOrganization(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return xerrors.Errorf("get default organization: %w", err)
|
||||||
|
}
|
||||||
|
defaultOrgID = defaultOrganization.ID
|
||||||
|
}
|
||||||
|
|
||||||
|
err := db.InTx(func(tx database.Store) error {
|
||||||
|
userGroups, err := tx.GetGroups(ctx, database.GetGroupsParams{
|
||||||
|
HasMemberID: user.ID,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return xerrors.Errorf("get user groups: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Figure out which organizations the user is a member of.
|
||||||
|
// The "Everyone" group is always included, so we can infer organization
|
||||||
|
// membership via the groups the user is in.
|
||||||
|
userOrgs := make(map[uuid.UUID][]database.GetGroupsRow)
|
||||||
|
for _, g := range userGroups {
|
||||||
|
g := g
|
||||||
|
userOrgs[g.Group.OrganizationID] = append(userOrgs[g.Group.OrganizationID], g)
|
||||||
|
}
|
||||||
|
|
||||||
|
// For each org, we need to fetch the sync settings
|
||||||
|
// This loop also handles any legacy settings for the default
|
||||||
|
// organization.
|
||||||
|
orgSettings := make(map[uuid.UUID]GroupSyncSettings)
|
||||||
|
for orgID := range userOrgs {
|
||||||
|
orgResolver := s.Manager.OrganizationResolver(tx, orgID)
|
||||||
|
settings, err := s.SyncSettings.Group.Resolve(ctx, orgResolver)
|
||||||
|
if err != nil {
|
||||||
|
if !xerrors.Is(err, runtimeconfig.ErrEntryNotFound) {
|
||||||
|
return xerrors.Errorf("resolve group sync settings: %w", err)
|
||||||
|
}
|
||||||
|
// Default to not being configured
|
||||||
|
settings = &GroupSyncSettings{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Legacy deployment settings will override empty settings.
|
||||||
|
if orgID == defaultOrgID && settings.Field == "" {
|
||||||
|
settings = &GroupSyncSettings{
|
||||||
|
Field: s.Legacy.GroupField,
|
||||||
|
LegacyNameMapping: s.Legacy.GroupMapping,
|
||||||
|
RegexFilter: s.Legacy.GroupFilter,
|
||||||
|
AutoCreateMissing: s.Legacy.CreateMissingGroups,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
orgSettings[orgID] = *settings
|
||||||
|
}
|
||||||
|
|
||||||
|
// groupIDsToAdd & groupIDsToRemove are the final group differences
|
||||||
|
// needed to be applied to user. The loop below will iterate over all
|
||||||
|
// organizations the user is in, and determine the diffs.
|
||||||
|
// The diffs are applied as a batch sql query, rather than each
|
||||||
|
// organization having to execute a query.
|
||||||
|
groupIDsToAdd := make([]uuid.UUID, 0)
|
||||||
|
groupIDsToRemove := make([]uuid.UUID, 0)
|
||||||
|
// For each org, determine which groups the user should land in
|
||||||
|
for orgID, settings := range orgSettings {
|
||||||
|
if settings.Field == "" {
|
||||||
|
// No group sync enabled for this org, so do nothing.
|
||||||
|
// The user can remain in their groups for this org.
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// expectedGroups is the set of groups the IDP expects the
|
||||||
|
// user to be a member of.
|
||||||
|
expectedGroups, err := settings.ParseClaims(orgID, params.MergedClaims)
|
||||||
|
if err != nil {
|
||||||
|
s.Logger.Debug(ctx, "failed to parse claims for groups",
|
||||||
|
slog.F("organization_field", s.GroupField),
|
||||||
|
slog.F("organization_id", orgID),
|
||||||
|
slog.Error(err),
|
||||||
|
)
|
||||||
|
// Unsure where to raise this error on the UI or database.
|
||||||
|
// TODO: This error prevents group sync, but we have no way
|
||||||
|
// to raise this to an org admin. Come up with a solution to
|
||||||
|
// notify the admin and user of this issue.
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// Everyone group is always implied, so include it.
|
||||||
|
expectedGroups = append(expectedGroups, ExpectedGroup{
|
||||||
|
OrganizationID: orgID,
|
||||||
|
GroupID: &orgID,
|
||||||
|
})
|
||||||
|
|
||||||
|
// Now we know what groups the user should be in for a given org,
|
||||||
|
// determine if we have to do any group updates to sync the user's
|
||||||
|
// state.
|
||||||
|
existingGroups := userOrgs[orgID]
|
||||||
|
existingGroupsTyped := db2sdk.List(existingGroups, func(f database.GetGroupsRow) ExpectedGroup {
|
||||||
|
return ExpectedGroup{
|
||||||
|
OrganizationID: orgID,
|
||||||
|
GroupID: &f.Group.ID,
|
||||||
|
GroupName: &f.Group.Name,
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
add, remove := slice.SymmetricDifferenceFunc(existingGroupsTyped, expectedGroups, func(a, b ExpectedGroup) bool {
|
||||||
|
return a.Equal(b)
|
||||||
|
})
|
||||||
|
|
||||||
|
for _, r := range remove {
|
||||||
|
if r.GroupID == nil {
|
||||||
|
// This should never happen. All group removals come from the
|
||||||
|
// existing set, which come from the db. All groups from the
|
||||||
|
// database have IDs. This code is purely defensive.
|
||||||
|
detail := "user:" + user.Username
|
||||||
|
if r.GroupName != nil {
|
||||||
|
detail += fmt.Sprintf(" from group %s", *r.GroupName)
|
||||||
|
}
|
||||||
|
return xerrors.Errorf("removal group has nil ID, which should never happen: %s", detail)
|
||||||
|
}
|
||||||
|
groupIDsToRemove = append(groupIDsToRemove, *r.GroupID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// HandleMissingGroups will add the new groups to the org if
|
||||||
|
// the settings specify. It will convert all group names into uuids
|
||||||
|
// for easier assignment.
|
||||||
|
// TODO: This code should be batched at the end of the for loop.
|
||||||
|
// Optimizing this is being pushed because if AutoCreate is disabled,
|
||||||
|
// this code will only add cost on the first login for each user.
|
||||||
|
// AutoCreate is usually disabled for large deployments.
|
||||||
|
// For small deployments, this is less of a problem.
|
||||||
|
assignGroups, err := settings.HandleMissingGroups(ctx, tx, orgID, add)
|
||||||
|
if err != nil {
|
||||||
|
return xerrors.Errorf("handle missing groups: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
groupIDsToAdd = append(groupIDsToAdd, assignGroups...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ApplyGroupDifference will take the total adds and removes, and apply
|
||||||
|
// them.
|
||||||
|
err = s.ApplyGroupDifference(ctx, tx, user, groupIDsToAdd, groupIDsToRemove)
|
||||||
|
if err != nil {
|
||||||
|
return xerrors.Errorf("apply group difference: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}, nil)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ApplyGroupDifference will add and remove the user from the specified groups.
|
||||||
|
func (s AGPLIDPSync) ApplyGroupDifference(ctx context.Context, tx database.Store, user database.User, add []uuid.UUID, removeIDs []uuid.UUID) error {
|
||||||
|
if len(removeIDs) > 0 {
|
||||||
|
removedGroupIDs, err := tx.RemoveUserFromGroups(ctx, database.RemoveUserFromGroupsParams{
|
||||||
|
UserID: user.ID,
|
||||||
|
GroupIds: removeIDs,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return xerrors.Errorf("remove user from %d groups: %w", len(removeIDs), err)
|
||||||
|
}
|
||||||
|
if len(removedGroupIDs) != len(removeIDs) {
|
||||||
|
s.Logger.Debug(ctx, "user not removed from expected number of groups",
|
||||||
|
slog.F("user_id", user.ID),
|
||||||
|
slog.F("groups_removed_count", len(removedGroupIDs)),
|
||||||
|
slog.F("expected_count", len(removeIDs)),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(add) > 0 {
|
||||||
|
add = slice.Unique(add)
|
||||||
|
// Defensive programming to only insert uniques.
|
||||||
|
assignedGroupIDs, err := tx.InsertUserGroupsByID(ctx, database.InsertUserGroupsByIDParams{
|
||||||
|
UserID: user.ID,
|
||||||
|
GroupIds: add,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return xerrors.Errorf("insert user into %d groups: %w", len(add), err)
|
||||||
|
}
|
||||||
|
if len(assignedGroupIDs) != len(add) {
|
||||||
|
s.Logger.Debug(ctx, "user not assigned to expected number of groups",
|
||||||
|
slog.F("user_id", user.ID),
|
||||||
|
slog.F("groups_assigned_count", len(assignedGroupIDs)),
|
||||||
|
slog.F("expected_count", len(add)),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type GroupSyncSettings struct {
|
||||||
|
// Field selects the claim field to be used as the created user's
|
||||||
|
// groups. If the group field is the empty string, then no group updates
|
||||||
|
// will ever come from the OIDC provider.
|
||||||
|
Field string `json:"field"`
|
||||||
|
// Mapping maps from an OIDC group --> Coder group ID
|
||||||
|
Mapping map[string][]uuid.UUID `json:"mapping"`
|
||||||
|
// RegexFilter is a regular expression that filters the groups returned by
|
||||||
|
// the OIDC provider. Any group not matched by this regex will be ignored.
|
||||||
|
// If the group filter is nil, then no group filtering will occur.
|
||||||
|
RegexFilter *regexp.Regexp `json:"regex_filter"`
|
||||||
|
// AutoCreateMissing controls whether groups returned by the OIDC provider
|
||||||
|
// are automatically created in Coder if they are missing.
|
||||||
|
AutoCreateMissing bool `json:"auto_create_missing_groups"`
|
||||||
|
// LegacyNameMapping is deprecated. It remaps an IDP group name to
|
||||||
|
// a Coder group name. Since configuration is now done at runtime,
|
||||||
|
// group IDs are used to account for group renames.
|
||||||
|
// For legacy configurations, this config option has to remain.
|
||||||
|
// Deprecated: Use Mapping instead.
|
||||||
|
LegacyNameMapping map[string]string `json:"legacy_group_name_mapping,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *GroupSyncSettings) Set(v string) error {
|
||||||
|
return json.Unmarshal([]byte(v), s)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *GroupSyncSettings) String() string {
|
||||||
|
return runtimeconfig.JSONString(s)
|
||||||
|
}
|
||||||
|
|
||||||
|
type ExpectedGroup struct {
|
||||||
|
OrganizationID uuid.UUID
|
||||||
|
GroupID *uuid.UUID
|
||||||
|
GroupName *string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Equal compares two ExpectedGroups. The org id must be the same.
|
||||||
|
// If the group ID is set, it will be compared and take priority, ignoring the
|
||||||
|
// name value. So 2 groups with the same ID but different names will be
|
||||||
|
// considered equal.
|
||||||
|
func (a ExpectedGroup) Equal(b ExpectedGroup) bool {
|
||||||
|
// Must match
|
||||||
|
if a.OrganizationID != b.OrganizationID {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
// Only the name or the name needs to be checked, priority is given to the ID.
|
||||||
|
if a.GroupID != nil && b.GroupID != nil {
|
||||||
|
return *a.GroupID == *b.GroupID
|
||||||
|
}
|
||||||
|
if a.GroupName != nil && b.GroupName != nil {
|
||||||
|
return *a.GroupName == *b.GroupName
|
||||||
|
}
|
||||||
|
|
||||||
|
// If everything is nil, it is equal. Although a bit pointless
|
||||||
|
if a.GroupID == nil && b.GroupID == nil &&
|
||||||
|
a.GroupName == nil && b.GroupName == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseClaims will take the merged claims from the IDP and return the groups
|
||||||
|
// the user is expected to be a member of. The expected group can either be a
|
||||||
|
// name or an ID.
|
||||||
|
// It is unfortunate we cannot use exclusively names or exclusively IDs.
|
||||||
|
// When configuring though, if a group is mapped from "A" -> "UUID 1234", and
|
||||||
|
// the group "UUID 1234" is renamed, we want to maintain the mapping.
|
||||||
|
// We have to keep names because group sync supports syncing groups by name if
|
||||||
|
// the external IDP group name matches the Coder one.
|
||||||
|
func (s GroupSyncSettings) ParseClaims(orgID uuid.UUID, mergedClaims jwt.MapClaims) ([]ExpectedGroup, error) {
|
||||||
|
groupsRaw, ok := mergedClaims[s.Field]
|
||||||
|
if !ok {
|
||||||
|
return []ExpectedGroup{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
parsedGroups, err := ParseStringSliceClaim(groupsRaw)
|
||||||
|
if err != nil {
|
||||||
|
return nil, xerrors.Errorf("parse groups field, unexpected type %T: %w", groupsRaw, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
groups := make([]ExpectedGroup, 0)
|
||||||
|
for _, group := range parsedGroups {
|
||||||
|
group := group
|
||||||
|
|
||||||
|
// Legacy group mappings happen before the regex filter.
|
||||||
|
mappedGroupName, ok := s.LegacyNameMapping[group]
|
||||||
|
if ok {
|
||||||
|
group = mappedGroupName
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only allow through groups that pass the regex
|
||||||
|
if s.RegexFilter != nil {
|
||||||
|
if !s.RegexFilter.MatchString(group) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
mappedGroupIDs, ok := s.Mapping[group]
|
||||||
|
if ok {
|
||||||
|
for _, gid := range mappedGroupIDs {
|
||||||
|
gid := gid
|
||||||
|
groups = append(groups, ExpectedGroup{OrganizationID: orgID, GroupID: &gid})
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
groups = append(groups, ExpectedGroup{OrganizationID: orgID, GroupName: &group})
|
||||||
|
}
|
||||||
|
|
||||||
|
return groups, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// HandleMissingGroups ensures all ExpectedGroups convert to uuids.
|
||||||
|
// Groups can be referenced by name via legacy params or IDP group names.
|
||||||
|
// These group names are converted to IDs for easier assignment.
|
||||||
|
// Missing groups are created if AutoCreate is enabled.
|
||||||
|
// TODO: Batching this would be better, as this is 1 or 2 db calls per organization.
|
||||||
|
func (s GroupSyncSettings) HandleMissingGroups(ctx context.Context, tx database.Store, orgID uuid.UUID, add []ExpectedGroup) ([]uuid.UUID, error) {
|
||||||
|
// All expected that are missing IDs means the group does not exist
|
||||||
|
// in the database, or it is a legacy mapping, and we need to do a lookup.
|
||||||
|
var missingGroups []string
|
||||||
|
addIDs := make([]uuid.UUID, 0)
|
||||||
|
|
||||||
|
for _, expected := range add {
|
||||||
|
if expected.GroupID == nil && expected.GroupName != nil {
|
||||||
|
missingGroups = append(missingGroups, *expected.GroupName)
|
||||||
|
} else if expected.GroupID != nil {
|
||||||
|
// Keep the IDs to sync the groups.
|
||||||
|
addIDs = append(addIDs, *expected.GroupID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.AutoCreateMissing && len(missingGroups) > 0 {
|
||||||
|
// Insert any missing groups. If the groups already exist, this is a noop.
|
||||||
|
_, err := tx.InsertMissingGroups(ctx, database.InsertMissingGroupsParams{
|
||||||
|
OrganizationID: orgID,
|
||||||
|
Source: database.GroupSourceOidc,
|
||||||
|
GroupNames: missingGroups,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, xerrors.Errorf("insert missing groups: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fetch any missing groups by name. If they exist, their IDs will be
|
||||||
|
// matched and returned.
|
||||||
|
if len(missingGroups) > 0 {
|
||||||
|
// Do name lookups for all groups that are missing IDs.
|
||||||
|
newGroups, err := tx.GetGroups(ctx, database.GetGroupsParams{
|
||||||
|
OrganizationID: orgID,
|
||||||
|
HasMemberID: uuid.UUID{},
|
||||||
|
GroupNames: missingGroups,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, xerrors.Errorf("get groups by names: %w", err)
|
||||||
|
}
|
||||||
|
for _, g := range newGroups {
|
||||||
|
addIDs = append(addIDs, g.Group.ID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return addIDs, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func ConvertAllowList(allowList []string) map[string]struct{} {
|
||||||
|
allowMap := make(map[string]struct{}, len(allowList))
|
||||||
|
for _, group := range allowList {
|
||||||
|
allowMap[group] = struct{}{}
|
||||||
|
}
|
||||||
|
return allowMap
|
||||||
|
}
|
||||||
@@ -0,0 +1,814 @@
|
|||||||
|
package idpsync_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"database/sql"
|
||||||
|
"regexp"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/golang-jwt/jwt/v4"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"golang.org/x/exp/slices"
|
||||||
|
"golang.org/x/xerrors"
|
||||||
|
|
||||||
|
"cdr.dev/slog/sloggers/slogtest"
|
||||||
|
"github.com/coder/coder/v2/coderd/coderdtest"
|
||||||
|
"github.com/coder/coder/v2/coderd/database"
|
||||||
|
"github.com/coder/coder/v2/coderd/database/db2sdk"
|
||||||
|
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||||
|
"github.com/coder/coder/v2/coderd/database/dbgen"
|
||||||
|
"github.com/coder/coder/v2/coderd/database/dbtestutil"
|
||||||
|
"github.com/coder/coder/v2/coderd/idpsync"
|
||||||
|
"github.com/coder/coder/v2/coderd/runtimeconfig"
|
||||||
|
"github.com/coder/coder/v2/coderd/util/ptr"
|
||||||
|
"github.com/coder/coder/v2/testutil"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestParseGroupClaims(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
t.Run("EmptyConfig", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
s := idpsync.NewAGPLSync(slogtest.Make(t, &slogtest.Options{}),
|
||||||
|
runtimeconfig.NewManager(),
|
||||||
|
idpsync.DeploymentSyncSettings{})
|
||||||
|
|
||||||
|
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||||
|
|
||||||
|
params, err := s.ParseGroupClaims(ctx, jwt.MapClaims{})
|
||||||
|
require.Nil(t, err)
|
||||||
|
|
||||||
|
require.False(t, params.SyncEnabled)
|
||||||
|
})
|
||||||
|
|
||||||
|
// AllowList has no effect in AGPL
|
||||||
|
t.Run("AllowList", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
s := idpsync.NewAGPLSync(slogtest.Make(t, &slogtest.Options{}),
|
||||||
|
runtimeconfig.NewManager(),
|
||||||
|
idpsync.DeploymentSyncSettings{
|
||||||
|
GroupField: "groups",
|
||||||
|
GroupAllowList: map[string]struct{}{
|
||||||
|
"foo": {},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||||
|
|
||||||
|
params, err := s.ParseGroupClaims(ctx, jwt.MapClaims{})
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.False(t, params.SyncEnabled)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGroupSyncTable(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
// Last checked, takes 30s with postgres on a fast machine.
|
||||||
|
if dbtestutil.WillUsePostgres() {
|
||||||
|
t.Skip("Skipping test because it populates a lot of db entries, which is slow on postgres.")
|
||||||
|
}
|
||||||
|
|
||||||
|
userClaims := jwt.MapClaims{
|
||||||
|
"groups": []string{
|
||||||
|
"foo", "bar", "baz",
|
||||||
|
"create-bar", "create-baz",
|
||||||
|
"legacy-bar",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
ids := coderdtest.NewDeterministicUUIDGenerator()
|
||||||
|
testCases := []orgSetupDefinition{
|
||||||
|
{
|
||||||
|
Name: "SwitchGroups",
|
||||||
|
Settings: &idpsync.GroupSyncSettings{
|
||||||
|
Field: "groups",
|
||||||
|
Mapping: map[string][]uuid.UUID{
|
||||||
|
"foo": {ids.ID("sg-foo"), ids.ID("sg-foo-2")},
|
||||||
|
"bar": {ids.ID("sg-bar")},
|
||||||
|
"baz": {ids.ID("sg-baz")},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Groups: map[uuid.UUID]bool{
|
||||||
|
uuid.New(): true,
|
||||||
|
uuid.New(): true,
|
||||||
|
// Extra groups
|
||||||
|
ids.ID("sg-foo"): false,
|
||||||
|
ids.ID("sg-foo-2"): false,
|
||||||
|
ids.ID("sg-bar"): false,
|
||||||
|
ids.ID("sg-baz"): false,
|
||||||
|
},
|
||||||
|
ExpectedGroups: []uuid.UUID{
|
||||||
|
ids.ID("sg-foo"),
|
||||||
|
ids.ID("sg-foo-2"),
|
||||||
|
ids.ID("sg-bar"),
|
||||||
|
ids.ID("sg-baz"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "StayInGroup",
|
||||||
|
Settings: &idpsync.GroupSyncSettings{
|
||||||
|
Field: "groups",
|
||||||
|
// Only match foo, so bar does not map
|
||||||
|
RegexFilter: regexp.MustCompile("^foo$"),
|
||||||
|
Mapping: map[string][]uuid.UUID{
|
||||||
|
"foo": {ids.ID("gg-foo"), uuid.New()},
|
||||||
|
"bar": {ids.ID("gg-bar")},
|
||||||
|
"baz": {ids.ID("gg-baz")},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Groups: map[uuid.UUID]bool{
|
||||||
|
ids.ID("gg-foo"): true,
|
||||||
|
ids.ID("gg-bar"): false,
|
||||||
|
},
|
||||||
|
ExpectedGroups: []uuid.UUID{
|
||||||
|
ids.ID("gg-foo"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "UserJoinsGroups",
|
||||||
|
Settings: &idpsync.GroupSyncSettings{
|
||||||
|
Field: "groups",
|
||||||
|
Mapping: map[string][]uuid.UUID{
|
||||||
|
"foo": {ids.ID("ng-foo"), uuid.New()},
|
||||||
|
"bar": {ids.ID("ng-bar"), ids.ID("ng-bar-2")},
|
||||||
|
"baz": {ids.ID("ng-baz")},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Groups: map[uuid.UUID]bool{
|
||||||
|
ids.ID("ng-foo"): false,
|
||||||
|
ids.ID("ng-bar"): false,
|
||||||
|
ids.ID("ng-bar-2"): false,
|
||||||
|
ids.ID("ng-baz"): false,
|
||||||
|
},
|
||||||
|
ExpectedGroups: []uuid.UUID{
|
||||||
|
ids.ID("ng-foo"),
|
||||||
|
ids.ID("ng-bar"),
|
||||||
|
ids.ID("ng-bar-2"),
|
||||||
|
ids.ID("ng-baz"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "CreateGroups",
|
||||||
|
Settings: &idpsync.GroupSyncSettings{
|
||||||
|
Field: "groups",
|
||||||
|
RegexFilter: regexp.MustCompile("^create"),
|
||||||
|
AutoCreateMissing: true,
|
||||||
|
},
|
||||||
|
Groups: map[uuid.UUID]bool{},
|
||||||
|
ExpectedGroupNames: []string{
|
||||||
|
"create-bar",
|
||||||
|
"create-baz",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "GroupNamesNoMapping",
|
||||||
|
Settings: &idpsync.GroupSyncSettings{
|
||||||
|
Field: "groups",
|
||||||
|
RegexFilter: regexp.MustCompile(".*"),
|
||||||
|
AutoCreateMissing: false,
|
||||||
|
},
|
||||||
|
GroupNames: map[string]bool{
|
||||||
|
"foo": false,
|
||||||
|
"bar": false,
|
||||||
|
"goob": true,
|
||||||
|
},
|
||||||
|
ExpectedGroupNames: []string{
|
||||||
|
"foo",
|
||||||
|
"bar",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "NoUser",
|
||||||
|
Settings: &idpsync.GroupSyncSettings{
|
||||||
|
Field: "groups",
|
||||||
|
Mapping: map[string][]uuid.UUID{
|
||||||
|
// Extra ID that does not map to a group
|
||||||
|
"foo": {ids.ID("ow-foo"), uuid.New()},
|
||||||
|
},
|
||||||
|
RegexFilter: nil,
|
||||||
|
AutoCreateMissing: false,
|
||||||
|
},
|
||||||
|
NotMember: true,
|
||||||
|
Groups: map[uuid.UUID]bool{
|
||||||
|
ids.ID("ow-foo"): false,
|
||||||
|
ids.ID("ow-bar"): false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "NoSettingsNoUser",
|
||||||
|
Settings: nil,
|
||||||
|
Groups: map[uuid.UUID]bool{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "LegacyMapping",
|
||||||
|
Settings: &idpsync.GroupSyncSettings{
|
||||||
|
Field: "groups",
|
||||||
|
RegexFilter: regexp.MustCompile("^legacy"),
|
||||||
|
LegacyNameMapping: map[string]string{
|
||||||
|
"create-bar": "legacy-bar",
|
||||||
|
"foo": "legacy-foo",
|
||||||
|
"bop": "legacy-bop",
|
||||||
|
},
|
||||||
|
AutoCreateMissing: true,
|
||||||
|
},
|
||||||
|
Groups: map[uuid.UUID]bool{
|
||||||
|
ids.ID("lg-foo"): true,
|
||||||
|
},
|
||||||
|
GroupNames: map[string]bool{
|
||||||
|
"legacy-foo": false,
|
||||||
|
"extra": true,
|
||||||
|
"legacy-bop": true,
|
||||||
|
},
|
||||||
|
ExpectedGroupNames: []string{
|
||||||
|
"legacy-bar",
|
||||||
|
"legacy-foo",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
tc := tc
|
||||||
|
t.Run(tc.Name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
db, _ := dbtestutil.NewDB(t)
|
||||||
|
manager := runtimeconfig.NewManager()
|
||||||
|
s := idpsync.NewAGPLSync(slogtest.Make(t, &slogtest.Options{}),
|
||||||
|
manager,
|
||||||
|
idpsync.DeploymentSyncSettings{
|
||||||
|
GroupField: "groups",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
ctx := testutil.Context(t, testutil.WaitSuperLong)
|
||||||
|
user := dbgen.User(t, db, database.User{})
|
||||||
|
orgID := uuid.New()
|
||||||
|
SetupOrganization(t, s, db, user, orgID, tc)
|
||||||
|
|
||||||
|
// Do the group sync!
|
||||||
|
err := s.SyncGroups(ctx, db, user, idpsync.GroupParams{
|
||||||
|
SyncEnabled: true,
|
||||||
|
MergedClaims: userClaims,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
tc.Assert(t, orgID, db, user)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// AllTogether runs the entire tabled test as a singular user and
|
||||||
|
// deployment. This tests all organizations being synced together.
|
||||||
|
// The reason we do them individually, is that it is much easier to
|
||||||
|
// debug a single test case.
|
||||||
|
t.Run("AllTogether", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
db, _ := dbtestutil.NewDB(t)
|
||||||
|
manager := runtimeconfig.NewManager()
|
||||||
|
s := idpsync.NewAGPLSync(slogtest.Make(t, &slogtest.Options{}),
|
||||||
|
manager,
|
||||||
|
// Also sync the default org!
|
||||||
|
idpsync.DeploymentSyncSettings{
|
||||||
|
GroupField: "groups",
|
||||||
|
Legacy: idpsync.DefaultOrgLegacySettings{
|
||||||
|
GroupField: "groups",
|
||||||
|
GroupMapping: map[string]string{
|
||||||
|
"foo": "legacy-foo",
|
||||||
|
"baz": "legacy-baz",
|
||||||
|
},
|
||||||
|
GroupFilter: regexp.MustCompile("^legacy"),
|
||||||
|
CreateMissingGroups: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
ctx := testutil.Context(t, testutil.WaitSuperLong)
|
||||||
|
user := dbgen.User(t, db, database.User{})
|
||||||
|
|
||||||
|
var asserts []func(t *testing.T)
|
||||||
|
// The default org is also going to do something
|
||||||
|
def := orgSetupDefinition{
|
||||||
|
Name: "DefaultOrg",
|
||||||
|
GroupNames: map[string]bool{
|
||||||
|
"legacy-foo": false,
|
||||||
|
"legacy-baz": true,
|
||||||
|
"random": true,
|
||||||
|
},
|
||||||
|
// No settings, because they come from the deployment values
|
||||||
|
Settings: nil,
|
||||||
|
ExpectedGroups: nil,
|
||||||
|
ExpectedGroupNames: []string{"legacy-foo", "legacy-baz", "legacy-bar"},
|
||||||
|
}
|
||||||
|
|
||||||
|
//nolint:gocritic // testing
|
||||||
|
defOrg, err := db.GetDefaultOrganization(dbauthz.AsSystemRestricted(ctx))
|
||||||
|
require.NoError(t, err)
|
||||||
|
SetupOrganization(t, s, db, user, defOrg.ID, def)
|
||||||
|
asserts = append(asserts, func(t *testing.T) {
|
||||||
|
t.Run(def.Name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
def.Assert(t, defOrg.ID, db, user)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
tc := tc
|
||||||
|
|
||||||
|
orgID := uuid.New()
|
||||||
|
SetupOrganization(t, s, db, user, orgID, tc)
|
||||||
|
asserts = append(asserts, func(t *testing.T) {
|
||||||
|
t.Run(tc.Name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
tc.Assert(t, orgID, db, user)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
asserts = append(asserts, func(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
def.Assert(t, defOrg.ID, db, user)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Do the group sync!
|
||||||
|
err = s.SyncGroups(ctx, db, user, idpsync.GroupParams{
|
||||||
|
SyncEnabled: true,
|
||||||
|
MergedClaims: userClaims,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
for _, assert := range asserts {
|
||||||
|
assert(t)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSyncDisabled(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
if dbtestutil.WillUsePostgres() {
|
||||||
|
t.Skip("Skipping test because it populates a lot of db entries, which is slow on postgres.")
|
||||||
|
}
|
||||||
|
|
||||||
|
db, _ := dbtestutil.NewDB(t)
|
||||||
|
manager := runtimeconfig.NewManager()
|
||||||
|
s := idpsync.NewAGPLSync(slogtest.Make(t, &slogtest.Options{}),
|
||||||
|
manager,
|
||||||
|
idpsync.DeploymentSyncSettings{},
|
||||||
|
)
|
||||||
|
|
||||||
|
ids := coderdtest.NewDeterministicUUIDGenerator()
|
||||||
|
ctx := testutil.Context(t, testutil.WaitSuperLong)
|
||||||
|
user := dbgen.User(t, db, database.User{})
|
||||||
|
orgID := uuid.New()
|
||||||
|
|
||||||
|
def := orgSetupDefinition{
|
||||||
|
Name: "SyncDisabled",
|
||||||
|
Groups: map[uuid.UUID]bool{
|
||||||
|
ids.ID("foo"): true,
|
||||||
|
ids.ID("bar"): true,
|
||||||
|
ids.ID("baz"): false,
|
||||||
|
ids.ID("bop"): false,
|
||||||
|
},
|
||||||
|
Settings: &idpsync.GroupSyncSettings{
|
||||||
|
Field: "groups",
|
||||||
|
Mapping: map[string][]uuid.UUID{
|
||||||
|
"foo": {ids.ID("foo")},
|
||||||
|
"baz": {ids.ID("baz")},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
ExpectedGroups: []uuid.UUID{
|
||||||
|
ids.ID("foo"),
|
||||||
|
ids.ID("bar"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
SetupOrganization(t, s, db, user, orgID, def)
|
||||||
|
|
||||||
|
// Do the group sync!
|
||||||
|
err := s.SyncGroups(ctx, db, user, idpsync.GroupParams{
|
||||||
|
SyncEnabled: false,
|
||||||
|
MergedClaims: jwt.MapClaims{
|
||||||
|
"groups": []string{"baz", "bop"},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
def.Assert(t, orgID, db, user)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestApplyGroupDifference is mainly testing the database functions
|
||||||
|
func TestApplyGroupDifference(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ids := coderdtest.NewDeterministicUUIDGenerator()
|
||||||
|
testCase := []struct {
|
||||||
|
Name string
|
||||||
|
Before map[uuid.UUID]bool
|
||||||
|
Add []uuid.UUID
|
||||||
|
Remove []uuid.UUID
|
||||||
|
Expect []uuid.UUID
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
Name: "Empty",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "AddFromNone",
|
||||||
|
Before: map[uuid.UUID]bool{
|
||||||
|
ids.ID("g1"): false,
|
||||||
|
},
|
||||||
|
Add: []uuid.UUID{
|
||||||
|
ids.ID("g1"),
|
||||||
|
},
|
||||||
|
Expect: []uuid.UUID{
|
||||||
|
ids.ID("g1"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "AddSome",
|
||||||
|
Before: map[uuid.UUID]bool{
|
||||||
|
ids.ID("g1"): true,
|
||||||
|
ids.ID("g2"): false,
|
||||||
|
ids.ID("g3"): false,
|
||||||
|
uuid.New(): false,
|
||||||
|
},
|
||||||
|
Add: []uuid.UUID{
|
||||||
|
ids.ID("g2"),
|
||||||
|
ids.ID("g3"),
|
||||||
|
},
|
||||||
|
Expect: []uuid.UUID{
|
||||||
|
ids.ID("g1"),
|
||||||
|
ids.ID("g2"),
|
||||||
|
ids.ID("g3"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "RemoveAll",
|
||||||
|
Before: map[uuid.UUID]bool{
|
||||||
|
uuid.New(): false,
|
||||||
|
ids.ID("g2"): true,
|
||||||
|
ids.ID("g3"): true,
|
||||||
|
},
|
||||||
|
Remove: []uuid.UUID{
|
||||||
|
ids.ID("g2"),
|
||||||
|
ids.ID("g3"),
|
||||||
|
},
|
||||||
|
Expect: []uuid.UUID{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "Mixed",
|
||||||
|
Before: map[uuid.UUID]bool{
|
||||||
|
// adds
|
||||||
|
ids.ID("a1"): true,
|
||||||
|
ids.ID("a2"): true,
|
||||||
|
ids.ID("a3"): false,
|
||||||
|
ids.ID("a4"): false,
|
||||||
|
// removes
|
||||||
|
ids.ID("r1"): true,
|
||||||
|
ids.ID("r2"): true,
|
||||||
|
ids.ID("r3"): false,
|
||||||
|
ids.ID("r4"): false,
|
||||||
|
// stable
|
||||||
|
ids.ID("s1"): true,
|
||||||
|
ids.ID("s2"): true,
|
||||||
|
// noise
|
||||||
|
uuid.New(): false,
|
||||||
|
uuid.New(): false,
|
||||||
|
},
|
||||||
|
Add: []uuid.UUID{
|
||||||
|
ids.ID("a1"), ids.ID("a2"),
|
||||||
|
ids.ID("a3"), ids.ID("a4"),
|
||||||
|
// Double up to try and confuse
|
||||||
|
ids.ID("a1"),
|
||||||
|
ids.ID("a4"),
|
||||||
|
},
|
||||||
|
Remove: []uuid.UUID{
|
||||||
|
ids.ID("r1"), ids.ID("r2"),
|
||||||
|
ids.ID("r3"), ids.ID("r4"),
|
||||||
|
// Double up to try and confuse
|
||||||
|
ids.ID("r1"),
|
||||||
|
ids.ID("r4"),
|
||||||
|
},
|
||||||
|
Expect: []uuid.UUID{
|
||||||
|
ids.ID("a1"), ids.ID("a2"), ids.ID("a3"), ids.ID("a4"),
|
||||||
|
ids.ID("s1"), ids.ID("s2"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCase {
|
||||||
|
tc := tc
|
||||||
|
t.Run(tc.Name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
mgr := runtimeconfig.NewManager()
|
||||||
|
db, _ := dbtestutil.NewDB(t)
|
||||||
|
|
||||||
|
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||||
|
//nolint:gocritic // testing
|
||||||
|
ctx = dbauthz.AsSystemRestricted(ctx)
|
||||||
|
|
||||||
|
org := dbgen.Organization(t, db, database.Organization{})
|
||||||
|
_, err := db.InsertAllUsersGroup(ctx, org.ID)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
user := dbgen.User(t, db, database.User{})
|
||||||
|
_ = dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
||||||
|
UserID: user.ID,
|
||||||
|
OrganizationID: org.ID,
|
||||||
|
})
|
||||||
|
|
||||||
|
for gid, in := range tc.Before {
|
||||||
|
group := dbgen.Group(t, db, database.Group{
|
||||||
|
ID: gid,
|
||||||
|
OrganizationID: org.ID,
|
||||||
|
})
|
||||||
|
if in {
|
||||||
|
_ = dbgen.GroupMember(t, db, database.GroupMemberTable{
|
||||||
|
UserID: user.ID,
|
||||||
|
GroupID: group.ID,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
s := idpsync.NewAGPLSync(slogtest.Make(t, &slogtest.Options{}), mgr, idpsync.FromDeploymentValues(coderdtest.DeploymentValues(t)))
|
||||||
|
err = s.ApplyGroupDifference(context.Background(), db, user, tc.Add, tc.Remove)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
userGroups, err := db.GetGroups(ctx, database.GetGroupsParams{
|
||||||
|
HasMemberID: user.ID,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// assert
|
||||||
|
found := db2sdk.List(userGroups, func(g database.GetGroupsRow) uuid.UUID {
|
||||||
|
return g.Group.ID
|
||||||
|
})
|
||||||
|
|
||||||
|
// Add everyone group
|
||||||
|
require.ElementsMatch(t, append(tc.Expect, org.ID), found)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExpectedGroupEqual(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ids := coderdtest.NewDeterministicUUIDGenerator()
|
||||||
|
testCases := []struct {
|
||||||
|
Name string
|
||||||
|
A idpsync.ExpectedGroup
|
||||||
|
B idpsync.ExpectedGroup
|
||||||
|
Equal bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
Name: "Empty",
|
||||||
|
A: idpsync.ExpectedGroup{},
|
||||||
|
B: idpsync.ExpectedGroup{},
|
||||||
|
Equal: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "DifferentOrgs",
|
||||||
|
A: idpsync.ExpectedGroup{
|
||||||
|
OrganizationID: uuid.New(),
|
||||||
|
GroupID: ptr.Ref(ids.ID("g1")),
|
||||||
|
GroupName: nil,
|
||||||
|
},
|
||||||
|
B: idpsync.ExpectedGroup{
|
||||||
|
OrganizationID: uuid.New(),
|
||||||
|
GroupID: ptr.Ref(ids.ID("g1")),
|
||||||
|
GroupName: nil,
|
||||||
|
},
|
||||||
|
Equal: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "SameID",
|
||||||
|
A: idpsync.ExpectedGroup{
|
||||||
|
OrganizationID: ids.ID("org"),
|
||||||
|
GroupID: ptr.Ref(ids.ID("g1")),
|
||||||
|
GroupName: nil,
|
||||||
|
},
|
||||||
|
B: idpsync.ExpectedGroup{
|
||||||
|
OrganizationID: ids.ID("org"),
|
||||||
|
GroupID: ptr.Ref(ids.ID("g1")),
|
||||||
|
GroupName: nil,
|
||||||
|
},
|
||||||
|
Equal: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "DifferentIDs",
|
||||||
|
A: idpsync.ExpectedGroup{
|
||||||
|
OrganizationID: ids.ID("org"),
|
||||||
|
GroupID: ptr.Ref(uuid.New()),
|
||||||
|
GroupName: nil,
|
||||||
|
},
|
||||||
|
B: idpsync.ExpectedGroup{
|
||||||
|
OrganizationID: ids.ID("org"),
|
||||||
|
GroupID: ptr.Ref(uuid.New()),
|
||||||
|
GroupName: nil,
|
||||||
|
},
|
||||||
|
Equal: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "SameName",
|
||||||
|
A: idpsync.ExpectedGroup{
|
||||||
|
OrganizationID: ids.ID("org"),
|
||||||
|
GroupID: nil,
|
||||||
|
GroupName: ptr.Ref("foo"),
|
||||||
|
},
|
||||||
|
B: idpsync.ExpectedGroup{
|
||||||
|
OrganizationID: ids.ID("org"),
|
||||||
|
GroupID: nil,
|
||||||
|
GroupName: ptr.Ref("foo"),
|
||||||
|
},
|
||||||
|
Equal: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "DifferentName",
|
||||||
|
A: idpsync.ExpectedGroup{
|
||||||
|
OrganizationID: ids.ID("org"),
|
||||||
|
GroupID: nil,
|
||||||
|
GroupName: ptr.Ref("foo"),
|
||||||
|
},
|
||||||
|
B: idpsync.ExpectedGroup{
|
||||||
|
OrganizationID: ids.ID("org"),
|
||||||
|
GroupID: nil,
|
||||||
|
GroupName: ptr.Ref("bar"),
|
||||||
|
},
|
||||||
|
Equal: false,
|
||||||
|
},
|
||||||
|
// Edge cases
|
||||||
|
{
|
||||||
|
// A bit strange, but valid as ID takes priority.
|
||||||
|
// We assume 2 groups with the same ID are equal, even if
|
||||||
|
// their names are different. Names are mutable, IDs are not,
|
||||||
|
// so there is 0% chance they are different groups.
|
||||||
|
Name: "DifferentIDSameName",
|
||||||
|
A: idpsync.ExpectedGroup{
|
||||||
|
OrganizationID: ids.ID("org"),
|
||||||
|
GroupID: ptr.Ref(ids.ID("g1")),
|
||||||
|
GroupName: ptr.Ref("foo"),
|
||||||
|
},
|
||||||
|
B: idpsync.ExpectedGroup{
|
||||||
|
OrganizationID: ids.ID("org"),
|
||||||
|
GroupID: ptr.Ref(ids.ID("g1")),
|
||||||
|
GroupName: ptr.Ref("bar"),
|
||||||
|
},
|
||||||
|
Equal: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "MixedNils",
|
||||||
|
A: idpsync.ExpectedGroup{
|
||||||
|
OrganizationID: ids.ID("org"),
|
||||||
|
GroupID: ptr.Ref(ids.ID("g1")),
|
||||||
|
GroupName: nil,
|
||||||
|
},
|
||||||
|
B: idpsync.ExpectedGroup{
|
||||||
|
OrganizationID: ids.ID("org"),
|
||||||
|
GroupID: nil,
|
||||||
|
GroupName: ptr.Ref("bar"),
|
||||||
|
},
|
||||||
|
Equal: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "NoComparable",
|
||||||
|
A: idpsync.ExpectedGroup{
|
||||||
|
OrganizationID: ids.ID("org"),
|
||||||
|
GroupID: ptr.Ref(ids.ID("g1")),
|
||||||
|
GroupName: nil,
|
||||||
|
},
|
||||||
|
B: idpsync.ExpectedGroup{
|
||||||
|
OrganizationID: ids.ID("org"),
|
||||||
|
GroupID: nil,
|
||||||
|
GroupName: nil,
|
||||||
|
},
|
||||||
|
Equal: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
tc := tc
|
||||||
|
t.Run(tc.Name, func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
require.Equal(t, tc.Equal, tc.A.Equal(tc.B))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func SetupOrganization(t *testing.T, s *idpsync.AGPLIDPSync, db database.Store, user database.User, orgID uuid.UUID, def orgSetupDefinition) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
// Account that the org might be the default organization
|
||||||
|
org, err := db.GetOrganizationByID(context.Background(), orgID)
|
||||||
|
if xerrors.Is(err, sql.ErrNoRows) {
|
||||||
|
org = dbgen.Organization(t, db, database.Organization{
|
||||||
|
ID: orgID,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = db.InsertAllUsersGroup(context.Background(), org.ID)
|
||||||
|
if !database.IsUniqueViolation(err) {
|
||||||
|
require.NoError(t, err, "Everyone group for an org")
|
||||||
|
}
|
||||||
|
|
||||||
|
manager := runtimeconfig.NewManager()
|
||||||
|
orgResolver := manager.OrganizationResolver(db, org.ID)
|
||||||
|
err = s.Group.SetRuntimeValue(context.Background(), orgResolver, def.Settings)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
if !def.NotMember {
|
||||||
|
dbgen.OrganizationMember(t, db, database.OrganizationMember{
|
||||||
|
UserID: user.ID,
|
||||||
|
OrganizationID: org.ID,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
for groupID, in := range def.Groups {
|
||||||
|
dbgen.Group(t, db, database.Group{
|
||||||
|
ID: groupID,
|
||||||
|
OrganizationID: org.ID,
|
||||||
|
})
|
||||||
|
if in {
|
||||||
|
dbgen.GroupMember(t, db, database.GroupMemberTable{
|
||||||
|
UserID: user.ID,
|
||||||
|
GroupID: groupID,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for groupName, in := range def.GroupNames {
|
||||||
|
group := dbgen.Group(t, db, database.Group{
|
||||||
|
Name: groupName,
|
||||||
|
OrganizationID: org.ID,
|
||||||
|
})
|
||||||
|
if in {
|
||||||
|
dbgen.GroupMember(t, db, database.GroupMemberTable{
|
||||||
|
UserID: user.ID,
|
||||||
|
GroupID: group.ID,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type orgSetupDefinition struct {
|
||||||
|
Name string
|
||||||
|
// True if the user is a member of the group
|
||||||
|
Groups map[uuid.UUID]bool
|
||||||
|
GroupNames map[string]bool
|
||||||
|
NotMember bool
|
||||||
|
|
||||||
|
Settings *idpsync.GroupSyncSettings
|
||||||
|
ExpectedGroups []uuid.UUID
|
||||||
|
ExpectedGroupNames []string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o orgSetupDefinition) Assert(t *testing.T, orgID uuid.UUID, db database.Store, user database.User) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
members, err := db.OrganizationMembers(ctx, database.OrganizationMembersParams{
|
||||||
|
OrganizationID: orgID,
|
||||||
|
UserID: user.ID,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
if o.NotMember {
|
||||||
|
require.Len(t, members, 0, "should not be a member")
|
||||||
|
} else {
|
||||||
|
require.Len(t, members, 1, "should be a member")
|
||||||
|
}
|
||||||
|
|
||||||
|
userGroups, err := db.GetGroups(ctx, database.GetGroupsParams{
|
||||||
|
OrganizationID: orgID,
|
||||||
|
HasMemberID: user.ID,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
if o.ExpectedGroups == nil {
|
||||||
|
o.ExpectedGroups = make([]uuid.UUID, 0)
|
||||||
|
}
|
||||||
|
if len(o.ExpectedGroupNames) > 0 && len(o.ExpectedGroups) > 0 {
|
||||||
|
t.Fatal("ExpectedGroups and ExpectedGroupNames are mutually exclusive")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Everyone groups mess up our asserts
|
||||||
|
userGroups = slices.DeleteFunc(userGroups, func(row database.GetGroupsRow) bool {
|
||||||
|
return row.Group.ID == row.Group.OrganizationID
|
||||||
|
})
|
||||||
|
|
||||||
|
if len(o.ExpectedGroupNames) > 0 {
|
||||||
|
found := db2sdk.List(userGroups, func(g database.GetGroupsRow) string {
|
||||||
|
return g.Group.Name
|
||||||
|
})
|
||||||
|
require.ElementsMatch(t, o.ExpectedGroupNames, found, "user groups by name")
|
||||||
|
require.Len(t, o.ExpectedGroups, 0, "ExpectedGroups should be empty")
|
||||||
|
} else {
|
||||||
|
// Check by ID, recommended
|
||||||
|
found := db2sdk.List(userGroups, func(g database.GetGroupsRow) uuid.UUID {
|
||||||
|
return g.Group.ID
|
||||||
|
})
|
||||||
|
require.ElementsMatch(t, o.ExpectedGroups, found, "user groups")
|
||||||
|
require.Len(t, o.ExpectedGroupNames, 0, "ExpectedGroupNames should be empty")
|
||||||
|
}
|
||||||
|
}
|
||||||
+69
-15
@@ -3,6 +3,7 @@ package idpsync
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/golang-jwt/jwt/v4"
|
"github.com/golang-jwt/jwt/v4"
|
||||||
@@ -12,6 +13,7 @@ import (
|
|||||||
"cdr.dev/slog"
|
"cdr.dev/slog"
|
||||||
"github.com/coder/coder/v2/coderd/database"
|
"github.com/coder/coder/v2/coderd/database"
|
||||||
"github.com/coder/coder/v2/coderd/httpapi"
|
"github.com/coder/coder/v2/coderd/httpapi"
|
||||||
|
"github.com/coder/coder/v2/coderd/runtimeconfig"
|
||||||
"github.com/coder/coder/v2/codersdk"
|
"github.com/coder/coder/v2/codersdk"
|
||||||
"github.com/coder/coder/v2/site"
|
"github.com/coder/coder/v2/site"
|
||||||
)
|
)
|
||||||
@@ -25,21 +27,34 @@ type IDPSync interface {
|
|||||||
OrganizationSyncEnabled() bool
|
OrganizationSyncEnabled() bool
|
||||||
// ParseOrganizationClaims takes claims from an OIDC provider, and returns the
|
// ParseOrganizationClaims takes claims from an OIDC provider, and returns the
|
||||||
// organization sync params for assigning users into organizations.
|
// organization sync params for assigning users into organizations.
|
||||||
ParseOrganizationClaims(ctx context.Context, _ jwt.MapClaims) (OrganizationParams, *HTTPError)
|
ParseOrganizationClaims(ctx context.Context, mergedClaims jwt.MapClaims) (OrganizationParams, *HTTPError)
|
||||||
// SyncOrganizations assigns and removed users from organizations based on the
|
// SyncOrganizations assigns and removed users from organizations based on the
|
||||||
// provided params.
|
// provided params.
|
||||||
SyncOrganizations(ctx context.Context, tx database.Store, user database.User, params OrganizationParams) error
|
SyncOrganizations(ctx context.Context, tx database.Store, user database.User, params OrganizationParams) error
|
||||||
|
|
||||||
|
GroupSyncEnabled() bool
|
||||||
|
// ParseGroupClaims takes claims from an OIDC provider, and returns the params
|
||||||
|
// for group syncing. Most of the logic happens in SyncGroups.
|
||||||
|
ParseGroupClaims(ctx context.Context, mergedClaims jwt.MapClaims) (GroupParams, *HTTPError)
|
||||||
|
// SyncGroups assigns and removes users from groups based on the provided params.
|
||||||
|
SyncGroups(ctx context.Context, db database.Store, user database.User, params GroupParams) error
|
||||||
|
// GroupSyncSettings is exposed for the API to implement CRUD operations
|
||||||
|
// on the settings used by IDPSync. This entry is thread safe and can be
|
||||||
|
// accessed concurrently. The settings are stored in the database.
|
||||||
|
GroupSyncSettings() runtimeconfig.RuntimeEntry[*GroupSyncSettings]
|
||||||
}
|
}
|
||||||
|
|
||||||
// AGPLIDPSync is the configuration for syncing user information from an external
|
// AGPLIDPSync is the configuration for syncing user information from an external
|
||||||
// IDP. All related code to syncing user information should be in this package.
|
// IDP. All related code to syncing user information should be in this package.
|
||||||
type AGPLIDPSync struct {
|
type AGPLIDPSync struct {
|
||||||
Logger slog.Logger
|
Logger slog.Logger
|
||||||
|
Manager *runtimeconfig.Manager
|
||||||
|
|
||||||
SyncSettings
|
SyncSettings
|
||||||
}
|
}
|
||||||
|
|
||||||
type SyncSettings struct {
|
// DeploymentSyncSettings are static and are sourced from the deployment config.
|
||||||
|
type DeploymentSyncSettings struct {
|
||||||
// OrganizationField selects the claim field to be used as the created user's
|
// OrganizationField selects the claim field to be used as the created user's
|
||||||
// organizations. If the field is the empty string, then no organization updates
|
// organizations. If the field is the empty string, then no organization updates
|
||||||
// will ever come from the OIDC provider.
|
// will ever come from the OIDC provider.
|
||||||
@@ -50,23 +65,62 @@ type SyncSettings struct {
|
|||||||
// placed into the default organization. This is mostly a hack to support
|
// placed into the default organization. This is mostly a hack to support
|
||||||
// legacy deployments.
|
// legacy deployments.
|
||||||
OrganizationAssignDefault bool
|
OrganizationAssignDefault bool
|
||||||
|
|
||||||
|
// GroupField at the deployment level is used for deployment level group claim
|
||||||
|
// settings.
|
||||||
|
GroupField string
|
||||||
|
// GroupAllowList (if set) will restrict authentication to only users who
|
||||||
|
// have at least one group in this list.
|
||||||
|
// A map representation is used for easier lookup.
|
||||||
|
GroupAllowList map[string]struct{}
|
||||||
|
// Legacy deployment settings that only apply to the default org.
|
||||||
|
Legacy DefaultOrgLegacySettings
|
||||||
}
|
}
|
||||||
|
|
||||||
type OrganizationParams struct {
|
type DefaultOrgLegacySettings struct {
|
||||||
// SyncEnabled if false will skip syncing the user's organizations.
|
GroupField string
|
||||||
SyncEnabled bool
|
GroupMapping map[string]string
|
||||||
// IncludeDefault is primarily for single org deployments. It will ensure
|
GroupFilter *regexp.Regexp
|
||||||
// a user is always inserted into the default org.
|
CreateMissingGroups bool
|
||||||
IncludeDefault bool
|
|
||||||
// Organizations is the list of organizations the user should be a member of
|
|
||||||
// assuming syncing is turned on.
|
|
||||||
Organizations []uuid.UUID
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewAGPLSync(logger slog.Logger, settings SyncSettings) *AGPLIDPSync {
|
func FromDeploymentValues(dv *codersdk.DeploymentValues) DeploymentSyncSettings {
|
||||||
|
if dv == nil {
|
||||||
|
panic("Developer error: DeploymentValues should not be nil")
|
||||||
|
}
|
||||||
|
return DeploymentSyncSettings{
|
||||||
|
OrganizationField: dv.OIDC.OrganizationField.Value(),
|
||||||
|
OrganizationMapping: dv.OIDC.OrganizationMapping.Value,
|
||||||
|
OrganizationAssignDefault: dv.OIDC.OrganizationAssignDefault.Value(),
|
||||||
|
|
||||||
|
// TODO: Separate group field for allow list from default org.
|
||||||
|
// Right now you cannot disable group sync from the default org and
|
||||||
|
// configure an allow list.
|
||||||
|
GroupField: dv.OIDC.GroupField.Value(),
|
||||||
|
GroupAllowList: ConvertAllowList(dv.OIDC.GroupAllowList.Value()),
|
||||||
|
Legacy: DefaultOrgLegacySettings{
|
||||||
|
GroupField: dv.OIDC.GroupField.Value(),
|
||||||
|
GroupMapping: dv.OIDC.GroupMapping.Value,
|
||||||
|
GroupFilter: dv.OIDC.GroupRegexFilter.Value(),
|
||||||
|
CreateMissingGroups: dv.OIDC.GroupAutoCreate.Value(),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type SyncSettings struct {
|
||||||
|
DeploymentSyncSettings
|
||||||
|
|
||||||
|
Group runtimeconfig.RuntimeEntry[*GroupSyncSettings]
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewAGPLSync(logger slog.Logger, manager *runtimeconfig.Manager, settings DeploymentSyncSettings) *AGPLIDPSync {
|
||||||
return &AGPLIDPSync{
|
return &AGPLIDPSync{
|
||||||
Logger: logger.Named("idp-sync"),
|
Logger: logger.Named("idp-sync"),
|
||||||
SyncSettings: settings,
|
Manager: manager,
|
||||||
|
SyncSettings: SyncSettings{
|
||||||
|
DeploymentSyncSettings: settings,
|
||||||
|
Group: runtimeconfig.MustNew[*GroupSyncSettings]("group-sync-settings"),
|
||||||
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -16,6 +16,17 @@ import (
|
|||||||
"github.com/coder/coder/v2/coderd/util/slice"
|
"github.com/coder/coder/v2/coderd/util/slice"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type OrganizationParams struct {
|
||||||
|
// SyncEnabled if false will skip syncing the user's organizations.
|
||||||
|
SyncEnabled bool
|
||||||
|
// IncludeDefault is primarily for single org deployments. It will ensure
|
||||||
|
// a user is always inserted into the default org.
|
||||||
|
IncludeDefault bool
|
||||||
|
// Organizations is the list of organizations the user should be a member of
|
||||||
|
// assuming syncing is turned on.
|
||||||
|
Organizations []uuid.UUID
|
||||||
|
}
|
||||||
|
|
||||||
func (AGPLIDPSync) OrganizationSyncEnabled() bool {
|
func (AGPLIDPSync) OrganizationSyncEnabled() bool {
|
||||||
// AGPL does not support syncing organizations.
|
// AGPL does not support syncing organizations.
|
||||||
return false
|
return false
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
|
|
||||||
"cdr.dev/slog/sloggers/slogtest"
|
"cdr.dev/slog/sloggers/slogtest"
|
||||||
"github.com/coder/coder/v2/coderd/idpsync"
|
"github.com/coder/coder/v2/coderd/idpsync"
|
||||||
|
"github.com/coder/coder/v2/coderd/runtimeconfig"
|
||||||
"github.com/coder/coder/v2/testutil"
|
"github.com/coder/coder/v2/testutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -18,11 +19,13 @@ func TestParseOrganizationClaims(t *testing.T) {
|
|||||||
t.Run("SingleOrgDeployment", func(t *testing.T) {
|
t.Run("SingleOrgDeployment", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
s := idpsync.NewAGPLSync(slogtest.Make(t, &slogtest.Options{}), idpsync.SyncSettings{
|
s := idpsync.NewAGPLSync(slogtest.Make(t, &slogtest.Options{}),
|
||||||
OrganizationField: "",
|
runtimeconfig.NewManager(),
|
||||||
OrganizationMapping: nil,
|
idpsync.DeploymentSyncSettings{
|
||||||
OrganizationAssignDefault: true,
|
OrganizationField: "",
|
||||||
})
|
OrganizationMapping: nil,
|
||||||
|
OrganizationAssignDefault: true,
|
||||||
|
})
|
||||||
|
|
||||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||||
|
|
||||||
@@ -38,13 +41,15 @@ func TestParseOrganizationClaims(t *testing.T) {
|
|||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
// AGPL has limited behavior
|
// AGPL has limited behavior
|
||||||
s := idpsync.NewAGPLSync(slogtest.Make(t, &slogtest.Options{}), idpsync.SyncSettings{
|
s := idpsync.NewAGPLSync(slogtest.Make(t, &slogtest.Options{}),
|
||||||
OrganizationField: "orgs",
|
runtimeconfig.NewManager(),
|
||||||
OrganizationMapping: map[string][]uuid.UUID{
|
idpsync.DeploymentSyncSettings{
|
||||||
"random": {uuid.New()},
|
OrganizationField: "orgs",
|
||||||
},
|
OrganizationMapping: map[string][]uuid.UUID{
|
||||||
OrganizationAssignDefault: false,
|
"random": {uuid.New()},
|
||||||
})
|
},
|
||||||
|
OrganizationAssignDefault: false,
|
||||||
|
})
|
||||||
|
|
||||||
ctx := testutil.Context(t, testutil.WaitMedium)
|
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||||
|
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package runtimeconfig
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"golang.org/x/xerrors"
|
"golang.org/x/xerrors"
|
||||||
@@ -93,3 +94,11 @@ func (e *RuntimeEntry[T]) name() (string, error) {
|
|||||||
|
|
||||||
return e.n, nil
|
return e.n, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func JSONString(v any) string {
|
||||||
|
s, err := json.Marshal(v)
|
||||||
|
if err != nil {
|
||||||
|
return "decode failed: " + err.Error()
|
||||||
|
}
|
||||||
|
return string(s)
|
||||||
|
}
|
||||||
|
|||||||
+29
-173
@@ -8,7 +8,6 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/mail"
|
"net/mail"
|
||||||
"regexp"
|
|
||||||
"sort"
|
"sort"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -20,7 +19,6 @@ import (
|
|||||||
"github.com/google/go-github/v43/github"
|
"github.com/google/go-github/v43/github"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/moby/moby/pkg/namesgenerator"
|
"github.com/moby/moby/pkg/namesgenerator"
|
||||||
"golang.org/x/exp/slices"
|
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
"golang.org/x/xerrors"
|
"golang.org/x/xerrors"
|
||||||
|
|
||||||
@@ -659,6 +657,9 @@ func (api *API) userOAuth2Github(rw http.ResponseWriter, r *http.Request) {
|
|||||||
AvatarURL: ghUser.GetAvatarURL(),
|
AvatarURL: ghUser.GetAvatarURL(),
|
||||||
Name: normName,
|
Name: normName,
|
||||||
DebugContext: OauthDebugContext{},
|
DebugContext: OauthDebugContext{},
|
||||||
|
GroupSync: idpsync.GroupParams{
|
||||||
|
SyncEnabled: false,
|
||||||
|
},
|
||||||
OrganizationSync: idpsync.OrganizationParams{
|
OrganizationSync: idpsync.OrganizationParams{
|
||||||
SyncEnabled: false,
|
SyncEnabled: false,
|
||||||
IncludeDefault: true,
|
IncludeDefault: true,
|
||||||
@@ -739,27 +740,6 @@ type OIDCConfig struct {
|
|||||||
// support the userinfo endpoint, or if the userinfo endpoint causes
|
// support the userinfo endpoint, or if the userinfo endpoint causes
|
||||||
// undesirable behavior.
|
// undesirable behavior.
|
||||||
IgnoreUserInfo bool
|
IgnoreUserInfo bool
|
||||||
|
|
||||||
// TODO: Move all idp fields into the IDPSync struct
|
|
||||||
// GroupField selects the claim field to be used as the created user's
|
|
||||||
// groups. If the group field is the empty string, then no group updates
|
|
||||||
// will ever come from the OIDC provider.
|
|
||||||
GroupField string
|
|
||||||
// CreateMissingGroups controls whether groups returned by the OIDC provider
|
|
||||||
// are automatically created in Coder if they are missing.
|
|
||||||
CreateMissingGroups bool
|
|
||||||
// GroupFilter is a regular expression that filters the groups returned by
|
|
||||||
// the OIDC provider. Any group not matched by this regex will be ignored.
|
|
||||||
// If the group filter is nil, then no group filtering will occur.
|
|
||||||
GroupFilter *regexp.Regexp
|
|
||||||
// GroupAllowList is a list of groups that are allowed to log in.
|
|
||||||
// If the list length is 0, then the allow list will not be applied and
|
|
||||||
// this feature is disabled.
|
|
||||||
GroupAllowList map[string]bool
|
|
||||||
// GroupMapping controls how groups returned by the OIDC provider get mapped
|
|
||||||
// to groups within Coder.
|
|
||||||
// map[oidcGroupName]coderGroupName
|
|
||||||
GroupMapping map[string]string
|
|
||||||
// UserRoleField selects the claim field to be used as the created user's
|
// UserRoleField selects the claim field to be used as the created user's
|
||||||
// roles. If the field is the empty string, then no role updates
|
// roles. If the field is the empty string, then no role updates
|
||||||
// will ever come from the OIDC provider.
|
// will ever come from the OIDC provider.
|
||||||
@@ -1002,11 +982,6 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
ctx = slog.With(ctx, slog.F("email", email), slog.F("username", username), slog.F("name", name))
|
ctx = slog.With(ctx, slog.F("email", email), slog.F("username", username), slog.F("name", name))
|
||||||
usingGroups, groups, groupErr := api.oidcGroups(ctx, mergedClaims)
|
|
||||||
if groupErr != nil {
|
|
||||||
groupErr.Write(rw, r)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
roles, roleErr := api.oidcRoles(ctx, mergedClaims)
|
roles, roleErr := api.oidcRoles(ctx, mergedClaims)
|
||||||
if roleErr != nil {
|
if roleErr != nil {
|
||||||
@@ -1030,6 +1005,12 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
groupSync, groupSyncErr := api.IDPSync.ParseGroupClaims(ctx, mergedClaims)
|
||||||
|
if groupSyncErr != nil {
|
||||||
|
groupSyncErr.Write(rw, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// If a new user is authenticating for the first time
|
// If a new user is authenticating for the first time
|
||||||
// the audit action is 'register', not 'login'
|
// the audit action is 'register', not 'login'
|
||||||
if user.ID == uuid.Nil {
|
if user.ID == uuid.Nil {
|
||||||
@@ -1037,23 +1018,20 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
params := (&oauthLoginParams{
|
params := (&oauthLoginParams{
|
||||||
User: user,
|
User: user,
|
||||||
Link: link,
|
Link: link,
|
||||||
State: state,
|
State: state,
|
||||||
LinkedID: oidcLinkedID(idToken),
|
LinkedID: oidcLinkedID(idToken),
|
||||||
LoginType: database.LoginTypeOIDC,
|
LoginType: database.LoginTypeOIDC,
|
||||||
AllowSignups: api.OIDCConfig.AllowSignups,
|
AllowSignups: api.OIDCConfig.AllowSignups,
|
||||||
Email: email,
|
Email: email,
|
||||||
Username: username,
|
Username: username,
|
||||||
Name: name,
|
Name: name,
|
||||||
AvatarURL: picture,
|
AvatarURL: picture,
|
||||||
UsingRoles: api.OIDCConfig.RoleSyncEnabled(),
|
UsingRoles: api.OIDCConfig.RoleSyncEnabled(),
|
||||||
Roles: roles,
|
Roles: roles,
|
||||||
UsingGroups: usingGroups,
|
OrganizationSync: orgSync,
|
||||||
Groups: groups,
|
GroupSync: groupSync,
|
||||||
OrganizationSync: orgSync,
|
|
||||||
CreateMissingGroups: api.OIDCConfig.CreateMissingGroups,
|
|
||||||
GroupFilter: api.OIDCConfig.GroupFilter,
|
|
||||||
DebugContext: OauthDebugContext{
|
DebugContext: OauthDebugContext{
|
||||||
IDTokenClaims: idtokenClaims,
|
IDTokenClaims: idtokenClaims,
|
||||||
UserInfoClaims: userInfoClaims,
|
UserInfoClaims: userInfoClaims,
|
||||||
@@ -1089,79 +1067,6 @@ func (api *API) userOIDC(rw http.ResponseWriter, r *http.Request) {
|
|||||||
http.Redirect(rw, r, redirect, http.StatusTemporaryRedirect)
|
http.Redirect(rw, r, redirect, http.StatusTemporaryRedirect)
|
||||||
}
|
}
|
||||||
|
|
||||||
// oidcGroups returns the groups for the user from the OIDC claims.
|
|
||||||
func (api *API) oidcGroups(ctx context.Context, mergedClaims map[string]interface{}) (bool, []string, *idpsync.HTTPError) {
|
|
||||||
logger := api.Logger.Named(userAuthLoggerName)
|
|
||||||
usingGroups := false
|
|
||||||
var groups []string
|
|
||||||
|
|
||||||
// If the GroupField is the empty string, then groups from OIDC are not used.
|
|
||||||
// This is so we can support manual group assignment.
|
|
||||||
if api.OIDCConfig.GroupField != "" {
|
|
||||||
// If the allow list is empty, then the user is allowed to log in.
|
|
||||||
// Otherwise, they must belong to at least 1 group in the allow list.
|
|
||||||
inAllowList := len(api.OIDCConfig.GroupAllowList) == 0
|
|
||||||
|
|
||||||
usingGroups = true
|
|
||||||
groupsRaw, ok := mergedClaims[api.OIDCConfig.GroupField]
|
|
||||||
if ok {
|
|
||||||
parsedGroups, err := idpsync.ParseStringSliceClaim(groupsRaw)
|
|
||||||
if err != nil {
|
|
||||||
api.Logger.Debug(ctx, "groups field was an unknown type in oidc claims",
|
|
||||||
slog.F("type", fmt.Sprintf("%T", groupsRaw)),
|
|
||||||
slog.Error(err),
|
|
||||||
)
|
|
||||||
return false, nil, &idpsync.HTTPError{
|
|
||||||
Code: http.StatusBadRequest,
|
|
||||||
Msg: "Failed to sync groups from OIDC claims",
|
|
||||||
Detail: err.Error(),
|
|
||||||
RenderStaticPage: false,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
api.Logger.Debug(ctx, "groups returned in oidc claims",
|
|
||||||
slog.F("len", len(parsedGroups)),
|
|
||||||
slog.F("groups", parsedGroups),
|
|
||||||
)
|
|
||||||
|
|
||||||
for _, group := range parsedGroups {
|
|
||||||
if mappedGroup, ok := api.OIDCConfig.GroupMapping[group]; ok {
|
|
||||||
group = mappedGroup
|
|
||||||
}
|
|
||||||
if _, ok := api.OIDCConfig.GroupAllowList[group]; ok {
|
|
||||||
inAllowList = true
|
|
||||||
}
|
|
||||||
groups = append(groups, group)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if !inAllowList {
|
|
||||||
logger.Debug(ctx, "oidc group claim not in allow list, rejecting login",
|
|
||||||
slog.F("allow_list_count", len(api.OIDCConfig.GroupAllowList)),
|
|
||||||
slog.F("user_group_count", len(groups)),
|
|
||||||
)
|
|
||||||
detail := "Ask an administrator to add one of your groups to the whitelist"
|
|
||||||
if len(groups) == 0 {
|
|
||||||
detail = "You are currently not a member of any groups! Ask an administrator to add you to an authorized group to login."
|
|
||||||
}
|
|
||||||
return usingGroups, groups, &idpsync.HTTPError{
|
|
||||||
Code: http.StatusForbidden,
|
|
||||||
Msg: "Not a member of an allowed group",
|
|
||||||
Detail: detail,
|
|
||||||
RenderStaticPage: true,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// This conditional is purely to warn the user they might have misconfigured their OIDC
|
|
||||||
// configuration.
|
|
||||||
if _, groupClaimExists := mergedClaims["groups"]; !usingGroups && groupClaimExists {
|
|
||||||
logger.Debug(ctx, "claim 'groups' was returned, but 'oidc-group-field' is not set, check your coder oidc settings")
|
|
||||||
}
|
|
||||||
|
|
||||||
return usingGroups, groups, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// oidcRoles returns the roles for the user from the OIDC claims.
|
// oidcRoles returns the roles for the user from the OIDC claims.
|
||||||
// If the function returns false, then the caller should return early.
|
// If the function returns false, then the caller should return early.
|
||||||
// All writes to the response writer are handled by this function.
|
// All writes to the response writer are handled by this function.
|
||||||
@@ -1276,14 +1181,7 @@ type oauthLoginParams struct {
|
|||||||
AvatarURL string
|
AvatarURL string
|
||||||
// OrganizationSync has the organizations that the user will be assigned to.
|
// OrganizationSync has the organizations that the user will be assigned to.
|
||||||
OrganizationSync idpsync.OrganizationParams
|
OrganizationSync idpsync.OrganizationParams
|
||||||
// Is UsingGroups is true, then the user will be assigned
|
GroupSync idpsync.GroupParams
|
||||||
// to the Groups provided.
|
|
||||||
UsingGroups bool
|
|
||||||
CreateMissingGroups bool
|
|
||||||
// These are the group names from the IDP. Internally, they will map to
|
|
||||||
// some organization groups.
|
|
||||||
Groups []string
|
|
||||||
GroupFilter *regexp.Regexp
|
|
||||||
// Is UsingRoles is true, then the user will be assigned
|
// Is UsingRoles is true, then the user will be assigned
|
||||||
// the roles provided.
|
// the roles provided.
|
||||||
UsingRoles bool
|
UsingRoles bool
|
||||||
@@ -1489,53 +1387,11 @@ func (api *API) oauthLogin(r *http.Request, params *oauthLoginParams) ([]*http.C
|
|||||||
return xerrors.Errorf("sync organizations: %w", err)
|
return xerrors.Errorf("sync organizations: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ensure groups are correct.
|
// Group sync needs to occur after org sync, since a user can join an org,
|
||||||
// This places all groups into the default organization.
|
// then have their groups sync to said org.
|
||||||
// To go multi-org, we need to add a mapping feature here to know which
|
err = api.IDPSync.SyncGroups(ctx, tx, user, params.GroupSync)
|
||||||
// groups go to which orgs.
|
if err != nil {
|
||||||
if params.UsingGroups {
|
return xerrors.Errorf("sync groups: %w", err)
|
||||||
filtered := params.Groups
|
|
||||||
if params.GroupFilter != nil {
|
|
||||||
filtered = make([]string, 0, len(params.Groups))
|
|
||||||
for _, group := range params.Groups {
|
|
||||||
if params.GroupFilter.MatchString(group) {
|
|
||||||
filtered = append(filtered, group)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
//nolint:gocritic // No user present in the context.
|
|
||||||
defaultOrganization, err := tx.GetDefaultOrganization(dbauthz.AsSystemRestricted(ctx))
|
|
||||||
if err != nil {
|
|
||||||
// If there is no default org, then we can't assign groups.
|
|
||||||
// By default, we assume all groups belong to the default org.
|
|
||||||
return xerrors.Errorf("get default organization: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
//nolint:gocritic // No user present in the context.
|
|
||||||
memberships, err := tx.OrganizationMembers(dbauthz.AsSystemRestricted(ctx), database.OrganizationMembersParams{
|
|
||||||
UserID: user.ID,
|
|
||||||
OrganizationID: uuid.Nil,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return xerrors.Errorf("get organization memberships: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// If the user is not in the default organization, then we can't assign groups.
|
|
||||||
// A user cannot be in groups to an org they are not a member of.
|
|
||||||
if !slices.ContainsFunc(memberships, func(member database.OrganizationMembersRow) bool {
|
|
||||||
return member.OrganizationMember.OrganizationID == defaultOrganization.ID
|
|
||||||
}) {
|
|
||||||
return xerrors.Errorf("user %s is not a member of the default organization, cannot assign to groups in the org", user.ID)
|
|
||||||
}
|
|
||||||
|
|
||||||
//nolint:gocritic
|
|
||||||
err = api.Options.SetUserGroups(dbauthz.AsSystemRestricted(ctx), logger, tx, user.ID, map[uuid.UUID][]string{
|
|
||||||
defaultOrganization.ID: filtered,
|
|
||||||
}, params.CreateMissingGroups)
|
|
||||||
if err != nil {
|
|
||||||
return xerrors.Errorf("set user groups: %w", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ensure roles are correct.
|
// Ensure roles are correct.
|
||||||
|
|||||||
@@ -80,13 +80,6 @@ func New(ctx context.Context, options *Options) (_ *API, err error) {
|
|||||||
if options.Entitlements == nil {
|
if options.Entitlements == nil {
|
||||||
options.Entitlements = entitlements.New()
|
options.Entitlements = entitlements.New()
|
||||||
}
|
}
|
||||||
if options.IDPSync == nil {
|
|
||||||
options.IDPSync = enidpsync.NewSync(options.Logger, options.Entitlements, idpsync.SyncSettings{
|
|
||||||
OrganizationField: options.DeploymentValues.OIDC.OrganizationField.Value(),
|
|
||||||
OrganizationMapping: options.DeploymentValues.OIDC.OrganizationMapping.Value,
|
|
||||||
OrganizationAssignDefault: options.DeploymentValues.OIDC.OrganizationAssignDefault.Value(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx, cancelFunc := context.WithCancel(ctx)
|
ctx, cancelFunc := context.WithCancel(ctx)
|
||||||
|
|
||||||
@@ -118,6 +111,11 @@ func New(ctx context.Context, options *Options) (_ *API, err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
options.Database = cryptDB
|
options.Database = cryptDB
|
||||||
|
|
||||||
|
if options.IDPSync == nil {
|
||||||
|
options.IDPSync = enidpsync.NewSync(options.Logger, options.RuntimeConfig, options.Entitlements, idpsync.FromDeploymentValues(options.DeploymentValues))
|
||||||
|
}
|
||||||
|
|
||||||
api := &API{
|
api := &API{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
cancel: cancelFunc,
|
cancel: cancelFunc,
|
||||||
@@ -147,7 +145,6 @@ func New(ctx context.Context, options *Options) (_ *API, err error) {
|
|||||||
}
|
}
|
||||||
return c.Subject, c.Trial, nil
|
return c.Subject, c.Trial, nil
|
||||||
}
|
}
|
||||||
api.AGPL.Options.SetUserGroups = api.setUserGroups
|
|
||||||
api.AGPL.Options.SetUserSiteRoles = api.setUserSiteRoles
|
api.AGPL.Options.SetUserSiteRoles = api.setUserSiteRoles
|
||||||
api.AGPL.SiteHandler.RegionsFetcher = func(ctx context.Context) (any, error) {
|
api.AGPL.SiteHandler.RegionsFetcher = func(ctx context.Context) (any, error) {
|
||||||
// If the user can read the workspace proxy resource, return that.
|
// If the user can read the workspace proxy resource, return that.
|
||||||
|
|||||||
@@ -2,9 +2,9 @@ package enidpsync
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"cdr.dev/slog"
|
"cdr.dev/slog"
|
||||||
|
|
||||||
"github.com/coder/coder/v2/coderd/entitlements"
|
"github.com/coder/coder/v2/coderd/entitlements"
|
||||||
"github.com/coder/coder/v2/coderd/idpsync"
|
"github.com/coder/coder/v2/coderd/idpsync"
|
||||||
|
"github.com/coder/coder/v2/coderd/runtimeconfig"
|
||||||
)
|
)
|
||||||
|
|
||||||
// EnterpriseIDPSync enabled syncing user information from an external IDP.
|
// EnterpriseIDPSync enabled syncing user information from an external IDP.
|
||||||
@@ -17,9 +17,9 @@ type EnterpriseIDPSync struct {
|
|||||||
*idpsync.AGPLIDPSync
|
*idpsync.AGPLIDPSync
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewSync(logger slog.Logger, set *entitlements.Set, settings idpsync.SyncSettings) *EnterpriseIDPSync {
|
func NewSync(logger slog.Logger, manager *runtimeconfig.Manager, set *entitlements.Set, settings idpsync.DeploymentSyncSettings) *EnterpriseIDPSync {
|
||||||
return &EnterpriseIDPSync{
|
return &EnterpriseIDPSync{
|
||||||
entitlements: set,
|
entitlements: set,
|
||||||
AGPLIDPSync: idpsync.NewAGPLSync(logger.With(slog.F("enterprise_capable", "true")), settings),
|
AGPLIDPSync: idpsync.NewAGPLSync(logger.With(slog.F("enterprise_capable", "true")), manager, settings),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,70 @@
|
|||||||
|
package enidpsync
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/golang-jwt/jwt/v4"
|
||||||
|
|
||||||
|
"github.com/coder/coder/v2/coderd/idpsync"
|
||||||
|
"github.com/coder/coder/v2/codersdk"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (e EnterpriseIDPSync) GroupSyncEnabled() bool {
|
||||||
|
return e.entitlements.Enabled(codersdk.FeatureTemplateRBAC)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ParseGroupClaims parses the user claims and handles deployment wide group behavior.
|
||||||
|
// Almost all behavior is deferred since each organization configures it's own
|
||||||
|
// group sync settings.
|
||||||
|
// GroupAllowList is implemented here to prevent login by unauthorized users.
|
||||||
|
// TODO: GroupAllowList overlaps with the default organization group sync settings.
|
||||||
|
func (e EnterpriseIDPSync) ParseGroupClaims(ctx context.Context, mergedClaims jwt.MapClaims) (idpsync.GroupParams, *idpsync.HTTPError) {
|
||||||
|
if !e.GroupSyncEnabled() {
|
||||||
|
return e.AGPLIDPSync.ParseGroupClaims(ctx, mergedClaims)
|
||||||
|
}
|
||||||
|
|
||||||
|
if e.GroupField != "" && len(e.GroupAllowList) > 0 {
|
||||||
|
groupsRaw, ok := mergedClaims[e.GroupField]
|
||||||
|
if !ok {
|
||||||
|
return idpsync.GroupParams{}, &idpsync.HTTPError{
|
||||||
|
Code: http.StatusForbidden,
|
||||||
|
Msg: "Not a member of an allowed group",
|
||||||
|
Detail: "You have no groups in your claims!",
|
||||||
|
RenderStaticPage: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
parsedGroups, err := idpsync.ParseStringSliceClaim(groupsRaw)
|
||||||
|
if err != nil {
|
||||||
|
return idpsync.GroupParams{}, &idpsync.HTTPError{
|
||||||
|
Code: http.StatusBadRequest,
|
||||||
|
Msg: "Failed read groups from claims for allow list check. Ask an administrator for help.",
|
||||||
|
Detail: err.Error(),
|
||||||
|
RenderStaticPage: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
inAllowList := false
|
||||||
|
AllowListCheckLoop:
|
||||||
|
for _, group := range parsedGroups {
|
||||||
|
if _, ok := e.GroupAllowList[group]; ok {
|
||||||
|
inAllowList = true
|
||||||
|
break AllowListCheckLoop
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !inAllowList {
|
||||||
|
return idpsync.GroupParams{}, &idpsync.HTTPError{
|
||||||
|
Code: http.StatusForbidden,
|
||||||
|
Msg: "Not a member of an allowed group",
|
||||||
|
Detail: "Ask an administrator to add one of your groups to the allow list.",
|
||||||
|
RenderStaticPage: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return idpsync.GroupParams{
|
||||||
|
SyncEnabled: true,
|
||||||
|
MergedClaims: mergedClaims,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,96 @@
|
|||||||
|
package enidpsync_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/golang-jwt/jwt/v4"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"cdr.dev/slog/sloggers/slogtest"
|
||||||
|
"github.com/coder/coder/v2/coderd/entitlements"
|
||||||
|
"github.com/coder/coder/v2/coderd/idpsync"
|
||||||
|
"github.com/coder/coder/v2/coderd/runtimeconfig"
|
||||||
|
"github.com/coder/coder/v2/codersdk"
|
||||||
|
"github.com/coder/coder/v2/enterprise/coderd/enidpsync"
|
||||||
|
"github.com/coder/coder/v2/testutil"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestEnterpriseParseGroupClaims(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
entitled := entitlements.New()
|
||||||
|
entitled.Update(func(entitlements *codersdk.Entitlements) {
|
||||||
|
entitlements.Features[codersdk.FeatureTemplateRBAC] = codersdk.Feature{
|
||||||
|
Entitlement: codersdk.EntitlementEntitled,
|
||||||
|
Enabled: true,
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("NoEntitlements", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
s := enidpsync.NewSync(slogtest.Make(t, &slogtest.Options{}),
|
||||||
|
runtimeconfig.NewManager(),
|
||||||
|
entitlements.New(),
|
||||||
|
idpsync.DeploymentSyncSettings{})
|
||||||
|
|
||||||
|
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||||
|
|
||||||
|
params, err := s.ParseGroupClaims(ctx, jwt.MapClaims{})
|
||||||
|
require.Nil(t, err)
|
||||||
|
|
||||||
|
require.False(t, params.SyncEnabled)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("NotInAllowList", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
s := enidpsync.NewSync(slogtest.Make(t, &slogtest.Options{}),
|
||||||
|
runtimeconfig.NewManager(),
|
||||||
|
entitled,
|
||||||
|
idpsync.DeploymentSyncSettings{
|
||||||
|
GroupField: "groups",
|
||||||
|
GroupAllowList: map[string]struct{}{
|
||||||
|
"foo": {},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||||
|
|
||||||
|
// Try with incorrect group
|
||||||
|
_, err := s.ParseGroupClaims(ctx, jwt.MapClaims{
|
||||||
|
"groups": []string{"bar"},
|
||||||
|
})
|
||||||
|
require.NotNil(t, err)
|
||||||
|
require.Equal(t, 403, err.Code)
|
||||||
|
|
||||||
|
// Try with no groups
|
||||||
|
_, err = s.ParseGroupClaims(ctx, jwt.MapClaims{})
|
||||||
|
require.NotNil(t, err)
|
||||||
|
require.Equal(t, 403, err.Code)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("InAllowList", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
s := enidpsync.NewSync(slogtest.Make(t, &slogtest.Options{}),
|
||||||
|
runtimeconfig.NewManager(),
|
||||||
|
entitled,
|
||||||
|
idpsync.DeploymentSyncSettings{
|
||||||
|
GroupField: "groups",
|
||||||
|
GroupAllowList: map[string]struct{}{
|
||||||
|
"foo": {},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
ctx := testutil.Context(t, testutil.WaitMedium)
|
||||||
|
|
||||||
|
claims := jwt.MapClaims{
|
||||||
|
"groups": []string{"foo", "bar"},
|
||||||
|
}
|
||||||
|
params, err := s.ParseGroupClaims(ctx, claims)
|
||||||
|
require.Nil(t, err)
|
||||||
|
require.True(t, params.SyncEnabled)
|
||||||
|
require.Equal(t, claims, params.MergedClaims)
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -19,6 +19,7 @@ import (
|
|||||||
"github.com/coder/coder/v2/coderd/entitlements"
|
"github.com/coder/coder/v2/coderd/entitlements"
|
||||||
"github.com/coder/coder/v2/coderd/idpsync"
|
"github.com/coder/coder/v2/coderd/idpsync"
|
||||||
"github.com/coder/coder/v2/coderd/rbac"
|
"github.com/coder/coder/v2/coderd/rbac"
|
||||||
|
"github.com/coder/coder/v2/coderd/runtimeconfig"
|
||||||
"github.com/coder/coder/v2/codersdk"
|
"github.com/coder/coder/v2/codersdk"
|
||||||
"github.com/coder/coder/v2/enterprise/coderd/enidpsync"
|
"github.com/coder/coder/v2/enterprise/coderd/enidpsync"
|
||||||
"github.com/coder/coder/v2/testutil"
|
"github.com/coder/coder/v2/testutil"
|
||||||
@@ -41,7 +42,7 @@ type Expectations struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type OrganizationSyncTestCase struct {
|
type OrganizationSyncTestCase struct {
|
||||||
Settings idpsync.SyncSettings
|
Settings idpsync.DeploymentSyncSettings
|
||||||
Entitlements *entitlements.Set
|
Entitlements *entitlements.Set
|
||||||
Exps []Expectations
|
Exps []Expectations
|
||||||
}
|
}
|
||||||
@@ -89,7 +90,7 @@ func TestOrganizationSync(t *testing.T) {
|
|||||||
other := dbgen.Organization(t, db, database.Organization{})
|
other := dbgen.Organization(t, db, database.Organization{})
|
||||||
return OrganizationSyncTestCase{
|
return OrganizationSyncTestCase{
|
||||||
Entitlements: entitled,
|
Entitlements: entitled,
|
||||||
Settings: idpsync.SyncSettings{
|
Settings: idpsync.DeploymentSyncSettings{
|
||||||
OrganizationField: "",
|
OrganizationField: "",
|
||||||
OrganizationMapping: nil,
|
OrganizationMapping: nil,
|
||||||
OrganizationAssignDefault: true,
|
OrganizationAssignDefault: true,
|
||||||
@@ -142,7 +143,7 @@ func TestOrganizationSync(t *testing.T) {
|
|||||||
three := dbgen.Organization(t, db, database.Organization{})
|
three := dbgen.Organization(t, db, database.Organization{})
|
||||||
return OrganizationSyncTestCase{
|
return OrganizationSyncTestCase{
|
||||||
Entitlements: entitled,
|
Entitlements: entitled,
|
||||||
Settings: idpsync.SyncSettings{
|
Settings: idpsync.DeploymentSyncSettings{
|
||||||
OrganizationField: "organizations",
|
OrganizationField: "organizations",
|
||||||
OrganizationMapping: map[string][]uuid.UUID{
|
OrganizationMapping: map[string][]uuid.UUID{
|
||||||
"first": {one.ID},
|
"first": {one.ID},
|
||||||
@@ -236,7 +237,7 @@ func TestOrganizationSync(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Create a new sync object
|
// Create a new sync object
|
||||||
sync := enidpsync.NewSync(logger, caseData.Entitlements, caseData.Settings)
|
sync := enidpsync.NewSync(logger, runtimeconfig.NewManager(), caseData.Entitlements, caseData.Settings)
|
||||||
for _, exp := range caseData.Exps {
|
for _, exp := range caseData.Exps {
|
||||||
t.Run(exp.Name, func(t *testing.T) {
|
t.Run(exp.Name, func(t *testing.T) {
|
||||||
params, httpErr := sync.ParseOrganizationClaims(ctx, exp.Claims)
|
params, httpErr := sync.ParseOrganizationClaims(ctx, exp.Claims)
|
||||||
|
|||||||
@@ -8,75 +8,9 @@ import (
|
|||||||
|
|
||||||
"cdr.dev/slog"
|
"cdr.dev/slog"
|
||||||
"github.com/coder/coder/v2/coderd/database"
|
"github.com/coder/coder/v2/coderd/database"
|
||||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
|
||||||
"github.com/coder/coder/v2/codersdk"
|
"github.com/coder/coder/v2/codersdk"
|
||||||
)
|
)
|
||||||
|
|
||||||
// nolint: revive
|
|
||||||
func (api *API) setUserGroups(ctx context.Context, logger slog.Logger, db database.Store, userID uuid.UUID, orgGroupNames map[uuid.UUID][]string, createMissingGroups bool) error {
|
|
||||||
if !api.Entitlements.Enabled(codersdk.FeatureTemplateRBAC) {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return db.InTx(func(tx database.Store) error {
|
|
||||||
// When setting the user's groups, it's easier to just clear their groups and re-add them.
|
|
||||||
// This ensures that the user's groups are always in sync with the auth provider.
|
|
||||||
orgs, err := tx.GetOrganizationsByUserID(ctx, userID)
|
|
||||||
if err != nil {
|
|
||||||
return xerrors.Errorf("get user orgs: %w", err)
|
|
||||||
}
|
|
||||||
if len(orgs) != 1 {
|
|
||||||
return xerrors.Errorf("expected 1 org, got %d", len(orgs))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Delete all groups the user belongs to.
|
|
||||||
// nolint:gocritic // Requires system context to remove user from all groups.
|
|
||||||
err = tx.RemoveUserFromAllGroups(dbauthz.AsSystemRestricted(ctx), userID)
|
|
||||||
if err != nil {
|
|
||||||
return xerrors.Errorf("delete user groups: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: This could likely be improved by making these single queries.
|
|
||||||
// Either by batching or some other means. This for loop could be really
|
|
||||||
// inefficient if there are a lot of organizations. There was deployments
|
|
||||||
// on v1 with >100 orgs.
|
|
||||||
for orgID, groupNames := range orgGroupNames {
|
|
||||||
// Create the missing groups for each organization.
|
|
||||||
if createMissingGroups {
|
|
||||||
// This is the system creating these additional groups, so we use the system restricted context.
|
|
||||||
// nolint:gocritic
|
|
||||||
created, err := tx.InsertMissingGroups(dbauthz.AsSystemRestricted(ctx), database.InsertMissingGroupsParams{
|
|
||||||
OrganizationID: orgID,
|
|
||||||
GroupNames: groupNames,
|
|
||||||
Source: database.GroupSourceOidc,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return xerrors.Errorf("insert missing groups: %w", err)
|
|
||||||
}
|
|
||||||
if len(created) > 0 {
|
|
||||||
logger.Debug(ctx, "auto created missing groups",
|
|
||||||
slog.F("org_id", orgID.ID),
|
|
||||||
slog.F("created", created),
|
|
||||||
slog.F("num", len(created)),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Re-add the user to all groups returned by the auth provider.
|
|
||||||
err = tx.InsertUserGroupsByName(ctx, database.InsertUserGroupsByNameParams{
|
|
||||||
UserID: userID,
|
|
||||||
OrganizationID: orgID,
|
|
||||||
GroupNames: groupNames,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return xerrors.Errorf("insert user groups: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}, nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (api *API) setUserSiteRoles(ctx context.Context, logger slog.Logger, db database.Store, userID uuid.UUID, roles []string) error {
|
func (api *API) setUserSiteRoles(ctx context.Context, logger slog.Logger, db database.Store, userID uuid.UUID, roles []string) error {
|
||||||
if !api.Entitlements.Enabled(codersdk.FeatureUserRoleManagement) {
|
if !api.Entitlements.Enabled(codersdk.FeatureUserRoleManagement) {
|
||||||
logger.Warn(ctx, "attempted to assign OIDC user roles without enterprise entitlement, roles left unchanged",
|
logger.Warn(ctx, "attempted to assign OIDC user roles without enterprise entitlement, roles left unchanged",
|
||||||
|
|||||||
@@ -402,7 +402,9 @@ func TestUserOIDC(t *testing.T) {
|
|||||||
runner := setupOIDCTest(t, oidcTestConfig{
|
runner := setupOIDCTest(t, oidcTestConfig{
|
||||||
Config: func(cfg *coderd.OIDCConfig) {
|
Config: func(cfg *coderd.OIDCConfig) {
|
||||||
cfg.AllowSignups = true
|
cfg.AllowSignups = true
|
||||||
cfg.GroupField = groupClaim
|
},
|
||||||
|
DeploymentValues: func(dv *codersdk.DeploymentValues) {
|
||||||
|
dv.OIDC.GroupField = groupClaim
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -433,8 +435,10 @@ func TestUserOIDC(t *testing.T) {
|
|||||||
runner := setupOIDCTest(t, oidcTestConfig{
|
runner := setupOIDCTest(t, oidcTestConfig{
|
||||||
Config: func(cfg *coderd.OIDCConfig) {
|
Config: func(cfg *coderd.OIDCConfig) {
|
||||||
cfg.AllowSignups = true
|
cfg.AllowSignups = true
|
||||||
cfg.GroupField = groupClaim
|
},
|
||||||
cfg.GroupMapping = map[string]string{oidcGroupName: coderGroupName}
|
DeploymentValues: func(dv *codersdk.DeploymentValues) {
|
||||||
|
dv.OIDC.GroupField = groupClaim
|
||||||
|
dv.OIDC.GroupMapping = serpent.Struct[map[string]string]{Value: map[string]string{oidcGroupName: coderGroupName}}
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -468,7 +472,9 @@ func TestUserOIDC(t *testing.T) {
|
|||||||
runner := setupOIDCTest(t, oidcTestConfig{
|
runner := setupOIDCTest(t, oidcTestConfig{
|
||||||
Config: func(cfg *coderd.OIDCConfig) {
|
Config: func(cfg *coderd.OIDCConfig) {
|
||||||
cfg.AllowSignups = true
|
cfg.AllowSignups = true
|
||||||
cfg.GroupField = groupClaim
|
},
|
||||||
|
DeploymentValues: func(dv *codersdk.DeploymentValues) {
|
||||||
|
dv.OIDC.GroupField = groupClaim
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -502,7 +508,9 @@ func TestUserOIDC(t *testing.T) {
|
|||||||
runner := setupOIDCTest(t, oidcTestConfig{
|
runner := setupOIDCTest(t, oidcTestConfig{
|
||||||
Config: func(cfg *coderd.OIDCConfig) {
|
Config: func(cfg *coderd.OIDCConfig) {
|
||||||
cfg.AllowSignups = true
|
cfg.AllowSignups = true
|
||||||
cfg.GroupField = groupClaim
|
},
|
||||||
|
DeploymentValues: func(dv *codersdk.DeploymentValues) {
|
||||||
|
dv.OIDC.GroupField = groupClaim
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -537,7 +545,9 @@ func TestUserOIDC(t *testing.T) {
|
|||||||
runner := setupOIDCTest(t, oidcTestConfig{
|
runner := setupOIDCTest(t, oidcTestConfig{
|
||||||
Config: func(cfg *coderd.OIDCConfig) {
|
Config: func(cfg *coderd.OIDCConfig) {
|
||||||
cfg.AllowSignups = true
|
cfg.AllowSignups = true
|
||||||
cfg.GroupField = groupClaim
|
},
|
||||||
|
DeploymentValues: func(dv *codersdk.DeploymentValues) {
|
||||||
|
dv.OIDC.GroupField = groupClaim
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -559,8 +569,10 @@ func TestUserOIDC(t *testing.T) {
|
|||||||
runner := setupOIDCTest(t, oidcTestConfig{
|
runner := setupOIDCTest(t, oidcTestConfig{
|
||||||
Config: func(cfg *coderd.OIDCConfig) {
|
Config: func(cfg *coderd.OIDCConfig) {
|
||||||
cfg.AllowSignups = true
|
cfg.AllowSignups = true
|
||||||
cfg.GroupField = groupClaim
|
},
|
||||||
cfg.CreateMissingGroups = true
|
DeploymentValues: func(dv *codersdk.DeploymentValues) {
|
||||||
|
dv.OIDC.GroupField = groupClaim
|
||||||
|
dv.OIDC.GroupAutoCreate = true
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -582,8 +594,10 @@ func TestUserOIDC(t *testing.T) {
|
|||||||
runner := setupOIDCTest(t, oidcTestConfig{
|
runner := setupOIDCTest(t, oidcTestConfig{
|
||||||
Config: func(cfg *coderd.OIDCConfig) {
|
Config: func(cfg *coderd.OIDCConfig) {
|
||||||
cfg.AllowSignups = true
|
cfg.AllowSignups = true
|
||||||
cfg.GroupField = groupClaim
|
},
|
||||||
cfg.CreateMissingGroups = true
|
DeploymentValues: func(dv *codersdk.DeploymentValues) {
|
||||||
|
dv.OIDC.GroupField = groupClaim
|
||||||
|
dv.OIDC.GroupAutoCreate = true
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -606,8 +620,10 @@ func TestUserOIDC(t *testing.T) {
|
|||||||
runner := setupOIDCTest(t, oidcTestConfig{
|
runner := setupOIDCTest(t, oidcTestConfig{
|
||||||
Config: func(cfg *coderd.OIDCConfig) {
|
Config: func(cfg *coderd.OIDCConfig) {
|
||||||
cfg.AllowSignups = true
|
cfg.AllowSignups = true
|
||||||
cfg.GroupField = groupClaim
|
},
|
||||||
cfg.GroupAllowList = map[string]bool{allowedGroup: true}
|
DeploymentValues: func(dv *codersdk.DeploymentValues) {
|
||||||
|
dv.OIDC.GroupField = groupClaim
|
||||||
|
dv.OIDC.GroupAllowList = []string{allowedGroup}
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -697,6 +713,7 @@ func TestGroupSync(t *testing.T) {
|
|||||||
testCases := []struct {
|
testCases := []struct {
|
||||||
name string
|
name string
|
||||||
modCfg func(cfg *coderd.OIDCConfig)
|
modCfg func(cfg *coderd.OIDCConfig)
|
||||||
|
modDV func(dv *codersdk.DeploymentValues)
|
||||||
// initialOrgGroups is initial groups in the org
|
// initialOrgGroups is initial groups in the org
|
||||||
initialOrgGroups []string
|
initialOrgGroups []string
|
||||||
// initialUserGroups is initial groups for the user
|
// initialUserGroups is initial groups for the user
|
||||||
@@ -718,10 +735,10 @@ func TestGroupSync(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "GroupSyncDisabled",
|
name: "GroupSyncDisabled",
|
||||||
modCfg: func(cfg *coderd.OIDCConfig) {
|
modDV: func(dv *codersdk.DeploymentValues) {
|
||||||
// Disable group sync
|
// Disable group sync
|
||||||
cfg.GroupField = ""
|
dv.OIDC.GroupField = ""
|
||||||
cfg.GroupFilter = regexp.MustCompile(".*")
|
dv.OIDC.GroupRegexFilter = serpent.Regexp(*regexp.MustCompile(".*"))
|
||||||
},
|
},
|
||||||
initialOrgGroups: []string{"a", "b", "c", "d"},
|
initialOrgGroups: []string{"a", "b", "c", "d"},
|
||||||
initialUserGroups: []string{"b", "c", "d"},
|
initialUserGroups: []string{"b", "c", "d"},
|
||||||
@@ -732,10 +749,8 @@ func TestGroupSync(t *testing.T) {
|
|||||||
{
|
{
|
||||||
// From a,c,b -> b,c,d
|
// From a,c,b -> b,c,d
|
||||||
name: "ChangeUserGroups",
|
name: "ChangeUserGroups",
|
||||||
modCfg: func(cfg *coderd.OIDCConfig) {
|
modDV: func(dv *codersdk.DeploymentValues) {
|
||||||
cfg.GroupMapping = map[string]string{
|
dv.OIDC.GroupMapping = serpent.Struct[map[string]string]{Value: map[string]string{"D": "d"}}
|
||||||
"D": "d",
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
initialOrgGroups: []string{"a", "b", "c", "d"},
|
initialOrgGroups: []string{"a", "b", "c", "d"},
|
||||||
initialUserGroups: []string{"a", "b", "c"},
|
initialUserGroups: []string{"a", "b", "c"},
|
||||||
@@ -749,8 +764,8 @@ func TestGroupSync(t *testing.T) {
|
|||||||
{
|
{
|
||||||
// From a,c,b -> []
|
// From a,c,b -> []
|
||||||
name: "RemoveAllGroups",
|
name: "RemoveAllGroups",
|
||||||
modCfg: func(cfg *coderd.OIDCConfig) {
|
modDV: func(dv *codersdk.DeploymentValues) {
|
||||||
cfg.GroupFilter = regexp.MustCompile(".*")
|
dv.OIDC.GroupRegexFilter = serpent.Regexp(*regexp.MustCompile(".*"))
|
||||||
},
|
},
|
||||||
initialOrgGroups: []string{"a", "b", "c", "d"},
|
initialOrgGroups: []string{"a", "b", "c", "d"},
|
||||||
initialUserGroups: []string{"a", "b", "c"},
|
initialUserGroups: []string{"a", "b", "c"},
|
||||||
@@ -763,8 +778,8 @@ func TestGroupSync(t *testing.T) {
|
|||||||
{
|
{
|
||||||
// From a,c,b -> b,c,d,e,f
|
// From a,c,b -> b,c,d,e,f
|
||||||
name: "CreateMissingGroups",
|
name: "CreateMissingGroups",
|
||||||
modCfg: func(cfg *coderd.OIDCConfig) {
|
modDV: func(dv *codersdk.DeploymentValues) {
|
||||||
cfg.CreateMissingGroups = true
|
dv.OIDC.GroupAutoCreate = true
|
||||||
},
|
},
|
||||||
initialOrgGroups: []string{"a", "b", "c", "d"},
|
initialOrgGroups: []string{"a", "b", "c", "d"},
|
||||||
initialUserGroups: []string{"a", "b", "c"},
|
initialUserGroups: []string{"a", "b", "c"},
|
||||||
@@ -777,14 +792,11 @@ func TestGroupSync(t *testing.T) {
|
|||||||
{
|
{
|
||||||
// From a,c,b -> b,c,d,e,f
|
// From a,c,b -> b,c,d,e,f
|
||||||
name: "CreateMissingGroupsFilter",
|
name: "CreateMissingGroupsFilter",
|
||||||
modCfg: func(cfg *coderd.OIDCConfig) {
|
modDV: func(dv *codersdk.DeploymentValues) {
|
||||||
cfg.CreateMissingGroups = true
|
dv.OIDC.GroupAutoCreate = true
|
||||||
// Only single letter groups
|
// Only single letter groups
|
||||||
cfg.GroupFilter = regexp.MustCompile("^[a-z]$")
|
dv.OIDC.GroupRegexFilter = serpent.Regexp(*regexp.MustCompile("^[a-z]$"))
|
||||||
cfg.GroupMapping = map[string]string{
|
dv.OIDC.GroupMapping = serpent.Struct[map[string]string]{Value: map[string]string{"zebra": "z"}}
|
||||||
// Does not match the filter, but does after being mapped!
|
|
||||||
"zebra": "z",
|
|
||||||
}
|
|
||||||
},
|
},
|
||||||
initialOrgGroups: []string{"a", "b", "c", "d"},
|
initialOrgGroups: []string{"a", "b", "c", "d"},
|
||||||
initialUserGroups: []string{"a", "b", "c"},
|
initialUserGroups: []string{"a", "b", "c"},
|
||||||
@@ -806,8 +818,15 @@ func TestGroupSync(t *testing.T) {
|
|||||||
t.Parallel()
|
t.Parallel()
|
||||||
runner := setupOIDCTest(t, oidcTestConfig{
|
runner := setupOIDCTest(t, oidcTestConfig{
|
||||||
Config: func(cfg *coderd.OIDCConfig) {
|
Config: func(cfg *coderd.OIDCConfig) {
|
||||||
cfg.GroupField = "groups"
|
if tc.modCfg != nil {
|
||||||
tc.modCfg(cfg)
|
tc.modCfg(cfg)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
DeploymentValues: func(dv *codersdk.DeploymentValues) {
|
||||||
|
dv.OIDC.GroupField = "groups"
|
||||||
|
if tc.modDV != nil {
|
||||||
|
tc.modDV(dv)
|
||||||
|
}
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user