chore: add /groups endpoint to filter by organization and/or member (#14260)

* chore: merge get groups sql queries into 1

* Add endpoint for fetching groups with filters
* remove 2 ways to customizing a fake authorizer
This commit is contained in:
Steven Masley
2024-08-15 13:40:15 -05:00
committed by GitHub
parent 83ccdaa755
commit 7b09d98238
24 changed files with 539 additions and 289 deletions
+44
View File
@@ -1033,6 +1033,50 @@ const docTemplate = `{
}
}
},
"/groups": {
"get": {
"security": [
{
"CoderSessionToken": []
}
],
"produces": [
"application/json"
],
"tags": [
"Enterprise"
],
"summary": "Get groups",
"operationId": "get-groups",
"parameters": [
{
"type": "string",
"description": "Organization ID or name",
"name": "organization",
"in": "query",
"required": true
},
{
"type": "string",
"description": "User ID or name",
"name": "has_member",
"in": "query",
"required": true
}
],
"responses": {
"200": {
"description": "OK",
"schema": {
"type": "array",
"items": {
"$ref": "#/definitions/codersdk.Group"
}
}
}
}
}
},
"/groups/{group}": {
"get": {
"security": [
+40
View File
@@ -891,6 +891,46 @@
}
}
},
"/groups": {
"get": {
"security": [
{
"CoderSessionToken": []
}
],
"produces": ["application/json"],
"tags": ["Enterprise"],
"summary": "Get groups",
"operationId": "get-groups",
"parameters": [
{
"type": "string",
"description": "Organization ID or name",
"name": "organization",
"in": "query",
"required": true
},
{
"type": "string",
"description": "User ID or name",
"name": "has_member",
"in": "query",
"required": true
}
],
"responses": {
"200": {
"description": "OK",
"schema": {
"type": "array",
"items": {
"$ref": "#/definitions/codersdk.Group"
}
}
}
}
}
},
"/groups/{group}": {
"get": {
"security": [
+17 -5
View File
@@ -353,16 +353,28 @@ func (s *PreparedRecorder) CompileToSQL(ctx context.Context, cfg regosql.Convert
return s.prepped.CompileToSQL(ctx, cfg)
}
// FakeAuthorizer is an Authorizer that always returns the same error.
// FakeAuthorizer is an Authorizer that will return an error based on the
// "ConditionalReturn" function. By default, **no error** is returned.
// Meaning 'FakeAuthorizer' by default will never return "unauthorized".
type FakeAuthorizer struct {
// AlwaysReturn is the error that will be returned by Authorize.
AlwaysReturn error
ConditionalReturn func(context.Context, rbac.Subject, policy.Action, rbac.Object) error
}
var _ rbac.Authorizer = (*FakeAuthorizer)(nil)
func (d *FakeAuthorizer) Authorize(_ context.Context, _ rbac.Subject, _ policy.Action, _ rbac.Object) error {
return d.AlwaysReturn
// AlwaysReturn is the error that will be returned by Authorize.
func (d *FakeAuthorizer) AlwaysReturn(err error) *FakeAuthorizer {
d.ConditionalReturn = func(_ context.Context, _ rbac.Subject, _ policy.Action, _ rbac.Object) error {
return err
}
return d
}
func (d *FakeAuthorizer) Authorize(ctx context.Context, subject rbac.Subject, action policy.Action, object rbac.Object) error {
if d.ConditionalReturn != nil {
return d.ConditionalReturn(ctx, subject, action, object)
}
return nil
}
func (d *FakeAuthorizer) Prepare(_ context.Context, subject rbac.Subject, action policy.Action, _ string) (rbac.PreparedAuthorized, error) {
+8 -11
View File
@@ -1491,19 +1491,16 @@ func (q *querier) GetGroupMembersCountByGroupID(ctx context.Context, groupID uui
return memberCount, nil
}
func (q *querier) GetGroups(ctx context.Context) ([]database.Group, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err != nil {
return nil, err
func (q *querier) GetGroups(ctx context.Context, arg database.GetGroupsParams) ([]database.Group, error) {
if err := q.authorizeContext(ctx, policy.ActionRead, rbac.ResourceSystem); err == nil {
// Optimize this query for system users as it is used in telemetry.
// Calling authz on all groups in a deployment for telemetry jobs is
// excessive. Most user calls should have some filtering applied to reduce
// the size of the set.
return q.db.GetGroups(ctx, arg)
}
return q.db.GetGroups(ctx)
}
func (q *querier) GetGroupsByOrganizationAndUserID(ctx context.Context, arg database.GetGroupsByOrganizationAndUserIDParams) ([]database.Group, error) {
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetGroupsByOrganizationAndUserID)(ctx, arg)
}
func (q *querier) GetGroupsByOrganizationID(ctx context.Context, organizationID uuid.UUID) ([]database.Group, error) {
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetGroupsByOrganizationID)(ctx, organizationID)
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetGroups)(ctx, arg)
}
func (q *querier) GetHealthSettings(ctx context.Context) (string, error) {
+19 -12
View File
@@ -81,7 +81,7 @@ func TestInTX(t *testing.T) {
db := dbmem.New()
q := dbauthz.New(db, &coderdtest.RecordingAuthorizer{
Wrapped: &coderdtest.FakeAuthorizer{AlwaysReturn: xerrors.New("custom error")},
Wrapped: (&coderdtest.FakeAuthorizer{}).AlwaysReturn(xerrors.New("custom error")),
}, slog.Make(), coderdtest.AccessControlStorePointer())
actor := rbac.Subject{
ID: uuid.NewString(),
@@ -110,7 +110,7 @@ func TestNew(t *testing.T) {
db = dbmem.New()
exp = dbgen.Workspace(t, db, database.Workspace{})
rec = &coderdtest.RecordingAuthorizer{
Wrapped: &coderdtest.FakeAuthorizer{AlwaysReturn: nil},
Wrapped: &coderdtest.FakeAuthorizer{},
}
subj = rbac.Subject{}
ctx = dbauthz.As(context.Background(), rbac.Subject{})
@@ -135,7 +135,7 @@ func TestNew(t *testing.T) {
func TestDBAuthzRecursive(t *testing.T) {
t.Parallel()
q := dbauthz.New(dbmem.New(), &coderdtest.RecordingAuthorizer{
Wrapped: &coderdtest.FakeAuthorizer{AlwaysReturn: nil},
Wrapped: &coderdtest.FakeAuthorizer{},
}, slog.Make(), coderdtest.AccessControlStorePointer())
actor := rbac.Subject{
ID: uuid.NewString(),
@@ -342,18 +342,21 @@ func (s *MethodTestSuite) TestGroup() {
dbgen.GroupMember(s.T(), db, database.GroupMemberTable{GroupID: g.ID, UserID: u.ID})
check.Asserts(rbac.ResourceSystem, policy.ActionRead)
}))
s.Run("GetGroups", s.Subtest(func(db database.Store, check *expects) {
s.Run("System/GetGroups", s.Subtest(func(db database.Store, check *expects) {
_ = dbgen.Group(s.T(), db, database.Group{})
check.Asserts(rbac.ResourceSystem, policy.ActionRead)
check.Args(database.GetGroupsParams{}).
Asserts(rbac.ResourceSystem, policy.ActionRead)
}))
s.Run("GetGroupsByOrganizationAndUserID", s.Subtest(func(db database.Store, check *expects) {
s.Run("GetGroups", s.Subtest(func(db database.Store, check *expects) {
g := dbgen.Group(s.T(), db, database.Group{})
u := dbgen.User(s.T(), db, database.User{})
gm := dbgen.GroupMember(s.T(), db, database.GroupMemberTable{GroupID: g.ID, UserID: u.ID})
check.Args(database.GetGroupsByOrganizationAndUserIDParams{
check.Args(database.GetGroupsParams{
OrganizationID: g.OrganizationID,
UserID: gm.UserID,
}).Asserts(g, policy.ActionRead)
HasMemberID: gm.UserID,
}).Asserts(rbac.ResourceSystem, policy.ActionRead, g, policy.ActionRead).
// Fail the system resource skip
FailSystemObjectChecks()
}))
s.Run("InsertAllUsersGroup", s.Subtest(func(db database.Store, check *expects) {
o := dbgen.Organization(s.T(), db, database.Organization{})
@@ -597,12 +600,16 @@ func (s *MethodTestSuite) TestLicense() {
}
func (s *MethodTestSuite) TestOrganization() {
s.Run("GetGroupsByOrganizationID", s.Subtest(func(db database.Store, check *expects) {
s.Run("ByOrganization/GetGroups", s.Subtest(func(db database.Store, check *expects) {
o := dbgen.Organization(s.T(), db, database.Organization{})
a := dbgen.Group(s.T(), db, database.Group{OrganizationID: o.ID})
b := dbgen.Group(s.T(), db, database.Group{OrganizationID: o.ID})
check.Args(o.ID).Asserts(a, policy.ActionRead, b, policy.ActionRead).
Returns([]database.Group{a, b})
check.Args(database.GetGroupsParams{
OrganizationID: o.ID,
}).Asserts(rbac.ResourceSystem, policy.ActionRead, a, policy.ActionRead, b, policy.ActionRead).
Returns([]database.Group{a, b}).
// Fail the system check shortcut
FailSystemObjectChecks()
}))
s.Run("GetOrganizationByID", s.Subtest(func(db database.Store, check *expects) {
o := dbgen.Organization(s.T(), db, database.Organization{})
+27 -7
View File
@@ -114,9 +114,7 @@ func (s *MethodTestSuite) Subtest(testCaseF func(db database.Store, check *expec
s.methodAccounting[methodName]++
db := dbmem.New()
fakeAuthorizer := &coderdtest.FakeAuthorizer{
AlwaysReturn: nil,
}
fakeAuthorizer := &coderdtest.FakeAuthorizer{}
rec := &coderdtest.RecordingAuthorizer{
Wrapped: fakeAuthorizer,
}
@@ -174,7 +172,11 @@ func (s *MethodTestSuite) Subtest(testCaseF func(db database.Store, check *expec
// Always run
s.Run("Success", func() {
rec.Reset()
fakeAuthorizer.AlwaysReturn = nil
if testCase.successAuthorizer != nil {
fakeAuthorizer.ConditionalReturn = testCase.successAuthorizer
} else {
fakeAuthorizer.AlwaysReturn(nil)
}
outputs, err := callMethod(ctx)
if testCase.err == nil {
@@ -232,7 +234,7 @@ func (s *MethodTestSuite) NoActorErrorTest(callMethod func(ctx context.Context)
// Asserts that the error returned is a NotAuthorizedError.
func (s *MethodTestSuite) NotAuthorizedErrorTest(ctx context.Context, az *coderdtest.FakeAuthorizer, testCase expects, callMethod func(ctx context.Context) ([]reflect.Value, error)) {
s.Run("NotAuthorized", func() {
az.AlwaysReturn = rbac.ForbiddenWithInternal(xerrors.New("Always fail authz"), rbac.Subject{}, "", rbac.Object{}, nil)
az.AlwaysReturn(rbac.ForbiddenWithInternal(xerrors.New("Always fail authz"), rbac.Subject{}, "", rbac.Object{}, nil))
// If we have assertions, that means the method should FAIL
// if RBAC will disallow the request. The returned error should
@@ -257,8 +259,8 @@ func (s *MethodTestSuite) NotAuthorizedErrorTest(ctx context.Context, az *coderd
// Pass in a canceled context
ctx, cancel := context.WithCancel(ctx)
cancel()
az.AlwaysReturn = rbac.ForbiddenWithInternal(&topdown.Error{Code: topdown.CancelErr},
rbac.Subject{}, "", rbac.Object{}, nil)
az.AlwaysReturn(rbac.ForbiddenWithInternal(&topdown.Error{Code: topdown.CancelErr},
rbac.Subject{}, "", rbac.Object{}, nil))
// If we have assertions, that means the method should FAIL
// if RBAC will disallow the request. The returned error should
@@ -324,6 +326,7 @@ type expects struct {
// instead.
notAuthorizedExpect string
cancelledCtxExpect string
successAuthorizer func(ctx context.Context, subject rbac.Subject, action policy.Action, obj rbac.Object) error
}
// Asserts is required. Asserts the RBAC authorize calls that should be made.
@@ -354,6 +357,23 @@ func (m *expects) Errors(err error) *expects {
return m
}
func (m *expects) FailSystemObjectChecks() *expects {
return m.WithSuccessAuthorizer(func(ctx context.Context, subject rbac.Subject, action policy.Action, obj rbac.Object) error {
if obj.Type == rbac.ResourceSystem.Type {
return xerrors.Errorf("hard coded system authz failed")
}
return nil
})
}
// WithSuccessAuthorizer is helpful when an optimization authz check is made
// to skip some RBAC checks. This check in testing would prevent the ability
// to assert the more nuanced RBAC checks.
func (m *expects) WithSuccessAuthorizer(f func(ctx context.Context, subject rbac.Subject, action policy.Action, obj rbac.Object) error) *expects {
m.successAuthorizer = f
return m
}
func (m *expects) WithNotAuthorized(contains string) *expects {
m.notAuthorizedExpect = contains
return m
+27 -32
View File
@@ -2599,16 +2599,7 @@ func (q *FakeQuerier) GetGroupMembersCountByGroupID(ctx context.Context, groupID
return int64(len(users)), nil
}
func (q *FakeQuerier) GetGroups(_ context.Context) ([]database.Group, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()
out := make([]database.Group, len(q.groups))
copy(out, q.groups)
return out, nil
}
func (q *FakeQuerier) GetGroupsByOrganizationAndUserID(_ context.Context, arg database.GetGroupsByOrganizationAndUserIDParams) ([]database.Group, error) {
func (q *FakeQuerier) GetGroups(_ context.Context, arg database.GetGroupsParams) ([]database.Group, error) {
err := validateDatabaseType(arg)
if err != nil {
return nil, err
@@ -2616,34 +2607,38 @@ func (q *FakeQuerier) GetGroupsByOrganizationAndUserID(_ context.Context, arg da
q.mutex.RLock()
defer q.mutex.RUnlock()
var groupIDs []uuid.UUID
for _, member := range q.groupMembers {
if member.UserID == arg.UserID {
groupIDs = append(groupIDs, member.GroupID)
groupIDs := make(map[uuid.UUID]struct{})
if arg.HasMemberID != uuid.Nil {
for _, member := range q.groupMembers {
if member.UserID == arg.HasMemberID {
groupIDs[member.GroupID] = struct{}{}
}
}
// Handle the everyone group
for _, orgMember := range q.organizationMembers {
if orgMember.UserID == arg.HasMemberID {
groupIDs[orgMember.OrganizationID] = struct{}{}
}
}
}
groups := []database.Group{}
filtered := make([]database.Group, 0)
for _, group := range q.groups {
if slices.Contains(groupIDs, group.ID) && group.OrganizationID == arg.OrganizationID {
groups = append(groups, group)
if arg.OrganizationID != uuid.Nil && group.OrganizationID != arg.OrganizationID {
continue
}
_, ok := groupIDs[group.ID]
if arg.HasMemberID != uuid.Nil && !ok {
continue
}
filtered = append(filtered, group)
}
return groups, nil
}
func (q *FakeQuerier) GetGroupsByOrganizationID(_ context.Context, id uuid.UUID) ([]database.Group, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()
groups := make([]database.Group, 0, len(q.groups))
for _, group := range q.groups {
if group.OrganizationID == id {
groups = append(groups, group)
}
}
return groups, nil
return filtered, nil
}
func (q *FakeQuerier) GetHealthSettings(_ context.Context) (string, error) {
+2 -16
View File
@@ -662,27 +662,13 @@ func (m metricsStore) GetGroupMembersCountByGroupID(ctx context.Context, groupID
return r0, r1
}
func (m metricsStore) GetGroups(ctx context.Context) ([]database.Group, error) {
func (m metricsStore) GetGroups(ctx context.Context, arg database.GetGroupsParams) ([]database.Group, error) {
start := time.Now()
r0, r1 := m.s.GetGroups(ctx)
r0, r1 := m.s.GetGroups(ctx, arg)
m.queryLatencies.WithLabelValues("GetGroups").Observe(time.Since(start).Seconds())
return r0, r1
}
func (m metricsStore) GetGroupsByOrganizationAndUserID(ctx context.Context, arg database.GetGroupsByOrganizationAndUserIDParams) ([]database.Group, error) {
start := time.Now()
r0, r1 := m.s.GetGroupsByOrganizationAndUserID(ctx, arg)
m.queryLatencies.WithLabelValues("GetGroupsByOrganizationAndUserID").Observe(time.Since(start).Seconds())
return r0, r1
}
func (m metricsStore) GetGroupsByOrganizationID(ctx context.Context, organizationID uuid.UUID) ([]database.Group, error) {
start := time.Now()
groups, err := m.s.GetGroupsByOrganizationID(ctx, organizationID)
m.queryLatencies.WithLabelValues("GetGroupsByOrganizationID").Observe(time.Since(start).Seconds())
return groups, err
}
func (m metricsStore) GetHealthSettings(ctx context.Context) (string, error) {
start := time.Now()
r0, r1 := m.s.GetHealthSettings(ctx)
+4 -34
View File
@@ -1315,48 +1315,18 @@ func (mr *MockStoreMockRecorder) GetGroupMembersCountByGroupID(arg0, arg1 any) *
}
// GetGroups mocks base method.
func (m *MockStore) GetGroups(arg0 context.Context) ([]database.Group, error) {
func (m *MockStore) GetGroups(arg0 context.Context, arg1 database.GetGroupsParams) ([]database.Group, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetGroups", arg0)
ret := m.ctrl.Call(m, "GetGroups", arg0, arg1)
ret0, _ := ret[0].([]database.Group)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetGroups indicates an expected call of GetGroups.
func (mr *MockStoreMockRecorder) GetGroups(arg0 any) *gomock.Call {
func (mr *MockStoreMockRecorder) GetGroups(arg0, arg1 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroups", reflect.TypeOf((*MockStore)(nil).GetGroups), arg0)
}
// GetGroupsByOrganizationAndUserID mocks base method.
func (m *MockStore) GetGroupsByOrganizationAndUserID(arg0 context.Context, arg1 database.GetGroupsByOrganizationAndUserIDParams) ([]database.Group, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetGroupsByOrganizationAndUserID", arg0, arg1)
ret0, _ := ret[0].([]database.Group)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetGroupsByOrganizationAndUserID indicates an expected call of GetGroupsByOrganizationAndUserID.
func (mr *MockStoreMockRecorder) GetGroupsByOrganizationAndUserID(arg0, arg1 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroupsByOrganizationAndUserID", reflect.TypeOf((*MockStore)(nil).GetGroupsByOrganizationAndUserID), arg0, arg1)
}
// GetGroupsByOrganizationID mocks base method.
func (m *MockStore) GetGroupsByOrganizationID(arg0 context.Context, arg1 uuid.UUID) ([]database.Group, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetGroupsByOrganizationID", arg0, arg1)
ret0, _ := ret[0].([]database.Group)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetGroupsByOrganizationID indicates an expected call of GetGroupsByOrganizationID.
func (mr *MockStoreMockRecorder) GetGroupsByOrganizationID(arg0, arg1 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroupsByOrganizationID", reflect.TypeOf((*MockStore)(nil).GetGroupsByOrganizationID), arg0, arg1)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetGroups", reflect.TypeOf((*MockStore)(nil).GetGroups), arg0, arg1)
}
// GetHealthSettings mocks base method.
+1 -3
View File
@@ -151,9 +151,7 @@ type sqlcQuerier interface {
// count even if the caller does not have read access to ResourceGroupMember.
// They only need ResourceGroup read access.
GetGroupMembersCountByGroupID(ctx context.Context, groupID uuid.UUID) (int64, error)
GetGroups(ctx context.Context) ([]Group, error)
GetGroupsByOrganizationAndUserID(ctx context.Context, arg GetGroupsByOrganizationAndUserIDParams) ([]Group, error)
GetGroupsByOrganizationID(ctx context.Context, organizationID uuid.UUID) ([]Group, error)
GetGroups(ctx context.Context, arg GetGroupsParams) ([]Group, error)
GetHealthSettings(ctx context.Context) (string, error)
GetHungProvisionerJobs(ctx context.Context, updatedAt time.Time) ([]ProvisionerJob, error)
GetJFrogXrayScanByWorkspaceAndAgentID(ctx context.Context, arg GetJFrogXrayScanByWorkspaceAndAgentIDParams) (JfrogXrayScan, error)
+33 -96
View File
@@ -1561,105 +1561,42 @@ func (q *sqlQuerier) GetGroupByOrgAndName(ctx context.Context, arg GetGroupByOrg
}
const getGroups = `-- name: GetGroups :many
SELECT id, name, organization_id, avatar_url, quota_allowance, display_name, source FROM groups
`
func (q *sqlQuerier) GetGroups(ctx context.Context) ([]Group, error) {
rows, err := q.db.QueryContext(ctx, getGroups)
if err != nil {
return nil, err
}
defer rows.Close()
var items []Group
for rows.Next() {
var i Group
if err := rows.Scan(
&i.ID,
&i.Name,
&i.OrganizationID,
&i.AvatarURL,
&i.QuotaAllowance,
&i.DisplayName,
&i.Source,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const getGroupsByOrganizationAndUserID = `-- name: GetGroupsByOrganizationAndUserID :many
SELECT
groups.id, groups.name, groups.organization_id, groups.avatar_url, groups.quota_allowance, groups.display_name, groups.source
id, name, organization_id, avatar_url, quota_allowance, display_name, source
FROM
groups
WHERE
groups.id IN (
SELECT
group_id
FROM
group_members_expanded gme
WHERE
gme.user_id = $1
AND
gme.organization_id = $2
)
true
AND CASE
WHEN $1:: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
groups.organization_id = $1
ELSE true
END
AND CASE
-- Filter to only include groups a user is a member of
WHEN $2::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
EXISTS (
SELECT
1
FROM
-- this view handles the 'everyone' group in orgs.
group_members_expanded
WHERE
group_members_expanded.group_id = groups.id
AND
group_members_expanded.user_id = $2
)
ELSE true
END
`
type GetGroupsByOrganizationAndUserIDParams struct {
UserID uuid.UUID `db:"user_id" json:"user_id"`
type GetGroupsParams struct {
OrganizationID uuid.UUID `db:"organization_id" json:"organization_id"`
HasMemberID uuid.UUID `db:"has_member_id" json:"has_member_id"`
}
func (q *sqlQuerier) GetGroupsByOrganizationAndUserID(ctx context.Context, arg GetGroupsByOrganizationAndUserIDParams) ([]Group, error) {
rows, err := q.db.QueryContext(ctx, getGroupsByOrganizationAndUserID, arg.UserID, arg.OrganizationID)
if err != nil {
return nil, err
}
defer rows.Close()
var items []Group
for rows.Next() {
var i Group
if err := rows.Scan(
&i.ID,
&i.Name,
&i.OrganizationID,
&i.AvatarURL,
&i.QuotaAllowance,
&i.DisplayName,
&i.Source,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const getGroupsByOrganizationID = `-- name: GetGroupsByOrganizationID :many
SELECT
id, name, organization_id, avatar_url, quota_allowance, display_name, source
FROM
groups
WHERE
organization_id = $1
`
func (q *sqlQuerier) GetGroupsByOrganizationID(ctx context.Context, organizationID uuid.UUID) ([]Group, error) {
rows, err := q.db.QueryContext(ctx, getGroupsByOrganizationID, organizationID)
func (q *sqlQuerier) GetGroups(ctx context.Context, arg GetGroupsParams) ([]Group, error) {
rows, err := q.db.QueryContext(ctx, getGroups, arg.OrganizationID, arg.HasMemberID)
if err != nil {
return nil, err
}
@@ -1766,15 +1703,15 @@ INSERT INTO groups (
id,
name,
organization_id,
source
source
)
SELECT
gen_random_uuid(),
group_name,
$1,
$2
gen_random_uuid(),
group_name,
$1,
$2
FROM
UNNEST($3 :: text[]) AS group_name
UNNEST($3 :: text[]) AS group_name
ON CONFLICT DO NOTHING
RETURNING id, name, organization_id, avatar_url, quota_allowance, display_name, source
`
+31 -29
View File
@@ -1,6 +1,3 @@
-- name: GetGroups :many
SELECT * FROM groups;
-- name: GetGroupByID :one
SELECT
*
@@ -23,30 +20,35 @@ AND
LIMIT
1;
-- name: GetGroupsByOrganizationID :many
-- name: GetGroups :many
SELECT
*
FROM
groups
WHERE
organization_id = $1;
-- name: GetGroupsByOrganizationAndUserID :many
SELECT
groups.*
*
FROM
groups
WHERE
groups.id IN (
SELECT
group_id
FROM
group_members_expanded gme
WHERE
gme.user_id = @user_id
AND
gme.organization_id = @organization_id
);
true
AND CASE
WHEN @organization_id:: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
groups.organization_id = @organization_id
ELSE true
END
AND CASE
-- Filter to only include groups a user is a member of
WHEN @has_member_id::uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
EXISTS (
SELECT
1
FROM
-- this view handles the 'everyone' group in orgs.
group_members_expanded
WHERE
group_members_expanded.group_id = groups.id
AND
group_members_expanded.user_id = @has_member_id
)
ELSE true
END
;
-- name: InsertGroup :one
INSERT INTO groups (
@@ -68,15 +70,15 @@ INSERT INTO groups (
id,
name,
organization_id,
source
source
)
SELECT
gen_random_uuid(),
group_name,
@organization_id,
@source
gen_random_uuid(),
group_name,
@organization_id,
@source
FROM
UNNEST(@group_names :: text[]) AS group_name
UNNEST(@group_names :: text[]) AS group_name
-- If the name conflicts, do nothing.
ON CONFLICT DO NOTHING
RETURNING *;
+13
View File
@@ -144,6 +144,19 @@ func (p *QueryParamParser) RequiredNotEmpty(queryParam ...string) *QueryParamPar
return p
}
// UUIDorName will parse a string as a UUID, if it fails, it uses the "fetchByName"
// function to return a UUID based on the value as a string.
// This is useful when fetching something like an organization by ID or by name.
func (p *QueryParamParser) UUIDorName(vals url.Values, def uuid.UUID, queryParam string, fetchByName func(name string) (uuid.UUID, error)) uuid.UUID {
return ParseCustom(p, vals, def, queryParam, func(v string) (uuid.UUID, error) {
id, err := uuid.Parse(v)
if err == nil {
return id, nil
}
return fetchByName(v)
})
}
func (p *QueryParamParser) UUIDorMe(vals url.Values, def uuid.UUID, me uuid.UUID, queryParam string) uuid.UUID {
return ParseCustom(p, vals, def, queryParam, func(v string) (uuid.UUID, error) {
if v == "me" {
+1 -1
View File
@@ -100,7 +100,7 @@ func TestWorkspaceAgentParam(t *testing.T) {
t.Run("NotAuthorized", func(t *testing.T) {
t.Parallel()
db := dbmem.New()
fakeAuthz := &coderdtest.FakeAuthorizer{AlwaysReturn: xerrors.Errorf("constant failure")}
fakeAuthz := (&coderdtest.FakeAuthorizer{}).AlwaysReturn(xerrors.Errorf("constant failure"))
dbFail := dbauthz.New(db, fakeAuthz, slog.Make(), coderdtest.AccessControlStorePointer())
rtr := chi.NewRouter()
@@ -481,8 +481,8 @@ func (s *server) acquireProtoJob(ctx context.Context, job database.ProvisionerJo
ownerSSHPublicKey = ownerSSHKey.PublicKey
ownerSSHPrivateKey = ownerSSHKey.PrivateKey
}
ownerGroups, err := s.Database.GetGroupsByOrganizationAndUserID(ctx, database.GetGroupsByOrganizationAndUserIDParams{
UserID: owner.ID,
ownerGroups, err := s.Database.GetGroups(ctx, database.GetGroupsParams{
HasMemberID: owner.ID,
OrganizationID: s.OrganizationID,
})
if err != nil {
+4 -4
View File
@@ -288,7 +288,7 @@ func benchmarkSetup(orgs []uuid.UUID, users []uuid.UUID, size int, opts ...func(
// BenchmarkCacher benchmarks the performance of the cacher.
func BenchmarkCacher(b *testing.B) {
ctx := context.Background()
authz := rbac.Cacher(&coderdtest.FakeAuthorizer{AlwaysReturn: nil})
authz := rbac.Cacher(&coderdtest.FakeAuthorizer{})
rats := []int{1, 10, 100}
@@ -322,7 +322,7 @@ func TestCache(t *testing.T) {
ctx := context.Background()
rec := &coderdtest.RecordingAuthorizer{
Wrapped: &coderdtest.FakeAuthorizer{AlwaysReturn: nil},
Wrapped: &coderdtest.FakeAuthorizer{},
}
subj, obj, action := coderdtest.RandomRBACSubject(), coderdtest.RandomRBACObject(), coderdtest.RandomRBACAction()
@@ -340,7 +340,7 @@ func TestCache(t *testing.T) {
ctx := context.Background()
rec := &coderdtest.RecordingAuthorizer{
Wrapped: &coderdtest.FakeAuthorizer{AlwaysReturn: nil},
Wrapped: &coderdtest.FakeAuthorizer{},
}
authz := rbac.Cacher(rec)
subj, obj, action := coderdtest.RandomRBACSubject(), coderdtest.RandomRBACObject(), coderdtest.RandomRBACAction()
@@ -400,7 +400,7 @@ func TestCache(t *testing.T) {
ctx := context.Background()
rec := &coderdtest.RecordingAuthorizer{
Wrapped: &coderdtest.FakeAuthorizer{AlwaysReturn: nil},
Wrapped: &coderdtest.FakeAuthorizer{},
}
authz := rbac.Cacher(rec)
subj1, obj1, action1 := coderdtest.RandomRBACSubject(), coderdtest.RandomRBACObject(), coderdtest.RandomRBACAction()
+1 -1
View File
@@ -367,7 +367,7 @@ func (r *remoteReporter) createSnapshot() (*Snapshot, error) {
return nil
})
eg.Go(func() error {
groups, err := r.options.Database.GetGroups(ctx)
groups, err := r.options.Database.GetGroups(ctx, database.GetGroupsParams{})
if err != nil {
return xerrors.Errorf("get groups: %w", err)
}
+23 -1
View File
@@ -5,6 +5,7 @@ import (
"encoding/json"
"fmt"
"net/http"
"net/url"
"github.com/google/uuid"
"golang.org/x/xerrors"
@@ -60,9 +61,30 @@ func (c *Client) CreateGroup(ctx context.Context, orgID uuid.UUID, req CreateGro
return resp, json.NewDecoder(res.Body).Decode(&resp)
}
// GroupsByOrganization
// Deprecated: use Groups with GroupArguments instead.
func (c *Client) GroupsByOrganization(ctx context.Context, orgID uuid.UUID) ([]Group, error) {
return c.Groups(ctx, GroupArguments{Organization: orgID.String()})
}
type GroupArguments struct {
// Organization can be an org UUID or name
Organization string
// HasMember can be a user uuid or username
HasMember string
}
func (c *Client) Groups(ctx context.Context, args GroupArguments) ([]Group, error) {
qp := url.Values{}
if args.Organization != "" {
qp.Set("organization", args.Organization)
}
if args.HasMember != "" {
qp.Set("has_member", args.HasMember)
}
res, err := c.Request(ctx, http.MethodGet,
fmt.Sprintf("/api/v2/organizations/%s/groups", orgID.String()),
fmt.Sprintf("/api/v2/groups?%s", qp.Encode()),
nil,
)
if err != nil {
+105
View File
@@ -173,6 +173,111 @@ curl -X GET http://coder-server:8080/api/v2/entitlements \
To perform this operation, you must be authenticated. [Learn more](authentication.md).
## Get groups
### Code samples
```shell
# Example request using curl
curl -X GET http://coder-server:8080/api/v2/groups?organization=string&has_member=string \
-H 'Accept: application/json' \
-H 'Coder-Session-Token: API_KEY'
```
`GET /groups`
### Parameters
| Name | In | Type | Required | Description |
| -------------- | ----- | ------ | -------- | ----------------------- |
| `organization` | query | string | true | Organization ID or name |
| `has_member` | query | string | true | User ID or name |
### Example responses
> 200 Response
```json
[
{
"avatar_url": "string",
"display_name": "string",
"id": "497f6eca-6276-4993-bfeb-53cbbbba6f08",
"members": [
{
"avatar_url": "http://example.com",
"created_at": "2019-08-24T14:15:22Z",
"email": "user@example.com",
"id": "497f6eca-6276-4993-bfeb-53cbbbba6f08",
"last_seen_at": "2019-08-24T14:15:22Z",
"login_type": "",
"name": "string",
"status": "active",
"theme_preference": "string",
"updated_at": "2019-08-24T14:15:22Z",
"username": "string"
}
],
"name": "string",
"organization_id": "7c60d51f-b44e-4682-87d6-449835ea4de6",
"quota_allowance": 0,
"source": "user",
"total_member_count": 0
}
]
```
### Responses
| Status | Meaning | Description | Schema |
| ------ | ------------------------------------------------------- | ----------- | --------------------------------------------------- |
| 200 | [OK](https://tools.ietf.org/html/rfc7231#section-6.3.1) | OK | array of [codersdk.Group](schemas.md#codersdkgroup) |
<h3 id="get-groups-responseschema">Response Schema</h3>
Status Code **200**
| Name | Type | Required | Restrictions | Description |
| ---------------------- | ------------------------------------------------------ | -------- | ------------ | --------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `[array item]` | array | false | | |
| `» avatar_url` | string | false | | |
| `» display_name` | string | false | | |
| `» id` | string(uuid) | false | | |
| `» members` | array | false | | |
| `»» avatar_url` | string(uri) | false | | |
| `»» created_at` | string(date-time) | true | | |
| `»» email` | string(email) | true | | |
| `»» id` | string(uuid) | true | | |
| `»» last_seen_at` | string(date-time) | false | | |
| `»» login_type` | [codersdk.LoginType](schemas.md#codersdklogintype) | false | | |
| `»» name` | string | false | | |
| `»» status` | [codersdk.UserStatus](schemas.md#codersdkuserstatus) | false | | |
| `»» theme_preference` | string | false | | |
| `»» updated_at` | string(date-time) | false | | |
| `»» username` | string | true | | |
| `» name` | string | false | | |
| `» organization_id` | string(uuid) | false | | |
| `» quota_allowance` | integer | false | | |
| `» source` | [codersdk.GroupSource](schemas.md#codersdkgroupsource) | false | | |
| `» total_member_count` | integer | false | | How many members are in this group. Shows the total count, even if the user is not authorized to read group member details. May be greater than `len(Group.Members)`. |
#### Enumerated Values
| Property | Value |
| ------------ | ----------- |
| `login_type` | `` |
| `login_type` | `password` |
| `login_type` | `github` |
| `login_type` | `oidc` |
| `login_type` | `token` |
| `login_type` | `none` |
| `status` | `active` |
| `status` | `suspended` |
| `source` | `user` |
| `source` | `oidc` |
To perform this operation, you must be authenticated. [Learn more](authentication.md).
## Get group by ID
### Code samples
+10 -5
View File
@@ -343,15 +343,20 @@ func New(ctx context.Context, options *Options) (_ *API, err error) {
r.Get("/", api.templateACL)
r.Patch("/", api.patchTemplateACL)
})
r.Route("/groups/{group}", func(r chi.Router) {
r.Route("/groups", func(r chi.Router) {
r.Use(
api.templateRBACEnabledMW,
apiKeyMiddleware,
httpmw.ExtractGroupParam(api.Database),
)
r.Get("/", api.group)
r.Patch("/", api.patchGroup)
r.Delete("/", api.deleteGroup)
r.Get("/", api.groups)
r.Route("/{group}", func(r chi.Router) {
r.Use(
httpmw.ExtractGroupParam(api.Database),
)
r.Get("/", api.group)
r.Patch("/", api.patchGroup)
r.Delete("/", api.deleteGroup)
})
})
r.Route("/workspace-quota", func(r chi.Router) {
r.Use(
+50 -15
View File
@@ -9,13 +9,11 @@ import (
"github.com/google/uuid"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/coderd"
"github.com/coder/coder/v2/coderd/audit"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/db2sdk"
"github.com/coder/coder/v2/coderd/httpapi"
"github.com/coder/coder/v2/coderd/httpmw"
"github.com/coder/coder/v2/coderd/rbac/policy"
"github.com/coder/coder/v2/codersdk"
)
@@ -394,28 +392,65 @@ func (api *API) group(rw http.ResponseWriter, r *http.Request) {
// @Success 200 {array} codersdk.Group
// @Router /organizations/{organization}/groups [get]
func (api *API) groupsByOrganization(rw http.ResponseWriter, r *http.Request) {
org := httpmw.OrganizationParam(r)
values := r.URL.Query()
values.Set("organization", org.ID.String())
r.URL.RawQuery = values.Encode()
api.groups(rw, r)
}
// @Summary Get groups
// @ID get-groups
// @Security CoderSessionToken
// @Produce json
// @Tags Enterprise
// @Param organization query string true "Organization ID or name"
// @Param has_member query string true "User ID or name"
// @Success 200 {array} codersdk.Group
// @Router /groups [get]
func (api *API) groups(rw http.ResponseWriter, r *http.Request) {
var (
ctx = r.Context()
org = httpmw.OrganizationParam(r)
)
ctx := r.Context()
groups, err := api.Database.GetGroupsByOrganizationID(ctx, org.ID)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
httpapi.InternalServerError(rw, err)
var filter database.GetGroupsParams
parser := httpapi.NewQueryParamParser()
// Organization selector can be an org ID or name
filter.OrganizationID = parser.UUIDorName(r.URL.Query(), uuid.Nil, "organization", func(orgName string) (uuid.UUID, error) {
org, err := api.Database.GetOrganizationByName(ctx, orgName)
if err != nil {
return uuid.Nil, xerrors.Errorf("organization %q not found", orgName)
}
return org.ID, nil
})
// has_member selector can be a user ID or username
filter.HasMemberID = parser.UUIDorName(r.URL.Query(), uuid.Nil, "has_member", func(username string) (uuid.UUID, error) {
user, err := api.Database.GetUserByEmailOrUsername(ctx, database.GetUserByEmailOrUsernameParams{
Username: username,
Email: "",
})
if err != nil {
return uuid.Nil, xerrors.Errorf("user %q not found", username)
}
return user.ID, nil
})
parser.ErrorExcessParams(r.URL.Query())
if len(parser.Errors) > 0 {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: "Query parameters have invalid values.",
Validations: parser.Errors,
})
return
}
// Filter groups based on rbac permissions
groups, err = coderd.AuthorizeFilter(api.AGPL.HTTPAuth, r, policy.ActionRead, groups)
groups, err := api.Database.GetGroups(ctx, filter)
if httpapi.Is404Error(err) {
httpapi.ResourceNotFound(rw)
return
}
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching groups.",
Detail: err.Error(),
})
httpapi.InternalServerError(rw, err)
return
}
+68 -14
View File
@@ -4,6 +4,7 @@ import (
"net/http"
"sort"
"testing"
"time"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
@@ -11,6 +12,7 @@ import (
"github.com/coder/coder/v2/coderd/audit"
"github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/db2sdk"
"github.com/coder/coder/v2/coderd/rbac"
"github.com/coder/coder/v2/coderd/util/ptr"
"github.com/coder/coder/v2/codersdk"
@@ -568,7 +570,19 @@ func TestPatchGroup(t *testing.T) {
})
}
func sortGroupMembers(group *codersdk.Group) {
func normalizeAllGroups(groups []codersdk.Group) {
for i := range groups {
normalizeGroupMembers(&groups[i])
}
}
// normalizeGroupMembers removes comparison noise from the group members.
func normalizeGroupMembers(group *codersdk.Group) {
for i := range group.Members {
group.Members[i].LastSeenAt = time.Time{}
group.Members[i].CreatedAt = time.Time{}
group.Members[i].UpdatedAt = time.Time{}
}
sort.Slice(group.Members, func(i, j int) bool {
return group.Members[i].ID.String() < group.Members[j].ID.String()
})
@@ -645,8 +659,8 @@ func TestGroup(t *testing.T) {
ggroup, err := userAdminClient.Group(ctx, group.ID)
require.NoError(t, err)
sortGroupMembers(&group)
sortGroupMembers(&ggroup)
normalizeGroupMembers(&group)
normalizeGroupMembers(&ggroup)
require.Equal(t, group, ggroup)
})
@@ -793,6 +807,8 @@ func TestGroup(t *testing.T) {
func TestGroups(t *testing.T) {
t.Parallel()
// 5 users
// 2 custom groups + original org group
t.Run("OK", func(t *testing.T) {
t.Parallel()
@@ -805,7 +821,7 @@ func TestGroups(t *testing.T) {
_, user2 := coderdtest.CreateAnotherUser(t, client, user.OrganizationID)
_, user3 := coderdtest.CreateAnotherUser(t, client, user.OrganizationID)
_, user4 := coderdtest.CreateAnotherUser(t, client, user.OrganizationID)
_, user5 := coderdtest.CreateAnotherUser(t, client, user.OrganizationID)
user5Client, user5 := coderdtest.CreateAnotherUser(t, client, user.OrganizationID)
ctx := testutil.Context(t, testutil.WaitLong)
group1, err := userAdminClient.CreateGroup(ctx, user.OrganizationID, codersdk.CreateGroupRequest{
@@ -822,26 +838,64 @@ func TestGroups(t *testing.T) {
AddUsers: []string{user2.ID.String(), user3.ID.String()},
})
require.NoError(t, err)
normalizeGroupMembers(&group1)
group2, err = userAdminClient.PatchGroup(ctx, group2.ID, codersdk.PatchGroupRequest{
AddUsers: []string{user4.ID.String(), user5.ID.String()},
})
require.NoError(t, err)
normalizeGroupMembers(&group2)
groups, err := userAdminClient.GroupsByOrganization(ctx, user.OrganizationID)
// Fetch everyone group for comparison
everyoneGroup, err := userAdminClient.Group(ctx, user.OrganizationID)
require.NoError(t, err)
normalizeGroupMembers(&everyoneGroup)
// sort group members so we can compare them
allGroups := append([]codersdk.Group{}, groups...)
allGroups = append(allGroups, group1, group2)
for i := range allGroups {
sortGroupMembers(&allGroups[i])
}
groups, err := userAdminClient.Groups(ctx, codersdk.GroupArguments{
Organization: user.OrganizationID.String(),
})
require.NoError(t, err)
normalizeAllGroups(groups)
// 'Everyone' group + 2 custom groups.
require.Len(t, groups, 3)
require.Contains(t, groups, group1)
require.Contains(t, groups, group2)
require.ElementsMatch(t, []codersdk.Group{
everyoneGroup,
group1,
group2,
}, groups)
// Filter by user
user5Groups, err := userAdminClient.Groups(ctx, codersdk.GroupArguments{
HasMember: user5.Username,
})
require.NoError(t, err)
normalizeAllGroups(user5Groups)
// Everyone group and group 2
require.ElementsMatch(t, []codersdk.Group{
everyoneGroup,
group2,
}, user5Groups)
// Query from the user's perspective
user5View, err := user5Client.Groups(ctx, codersdk.GroupArguments{})
require.NoError(t, err)
normalizeAllGroups(user5Groups)
// Everyone group and group 2
require.Len(t, user5View, 2)
user5ViewIDs := db2sdk.List(user5View, func(g codersdk.Group) uuid.UUID {
return g.ID
})
require.ElementsMatch(t, []uuid.UUID{
everyoneGroup.ID,
group2.ID,
}, user5ViewIDs)
for _, g := range user5View {
// Only expect the 1 member, themselves
require.Len(t, g.Members, 1)
require.Equal(t, user5.ReducedUser.ID, g.Members[0].MinimalUser.ID)
}
})
}
+3 -1
View File
@@ -50,7 +50,9 @@ func (api *API) templateAvailablePermissions(rw http.ResponseWriter, r *http.Req
// Perm check is the template update check.
// nolint:gocritic
groups, err := api.Database.GetGroupsByOrganizationID(dbauthz.AsSystemRestricted(ctx), template.OrganizationID)
groups, err := api.Database.GetGroups(dbauthz.AsSystemRestricted(ctx), database.GetGroupsParams{
OrganizationID: template.OrganizationID,
})
if err != nil {
httpapi.InternalServerError(rw, err)
return
+6
View File
@@ -628,6 +628,12 @@ export interface Group {
readonly source: GroupSource;
}
// From codersdk/groups.go
export interface GroupArguments {
readonly Organization: string;
readonly HasMember: string;
}
// From codersdk/workspaceapps.go
export interface Healthcheck {
readonly url: string;