chore: merge organization member db queries (#13542)

Merge members queries into 1 that also joins in the user table for username.
Required to list organization members on UI/cli
This commit is contained in:
Steven Masley
2024-06-12 09:23:48 -10:00
committed by GitHub
parent 1ca5dc0328
commit de9e6889bb
18 changed files with 293 additions and 214 deletions
+3 -3
View File
@@ -67,12 +67,12 @@ func TestServerCreateAdminUser(t *testing.T) {
orgIDs[org.ID] = struct{}{}
}
orgMemberships, err := db.GetOrganizationMembershipsByUserID(ctx, user.ID)
orgMemberships, err := db.OrganizationMembers(ctx, database.OrganizationMembersParams{UserID: user.ID})
require.NoError(t, err)
orgIDs2 := make(map[uuid.UUID]struct{}, len(orgMemberships))
for _, membership := range orgMemberships {
orgIDs2[membership.OrganizationID] = struct{}{}
assert.Equal(t, []string{rbac.RoleOrgAdmin()}, membership.Roles, "user is not org admin")
orgIDs2[membership.OrganizationMember.OrganizationID] = struct{}{}
assert.Equal(t, []string{rbac.RoleOrgAdmin()}, membership.OrganizationMember.Roles, "user is not org admin")
}
require.Equal(t, orgIDs, orgIDs2, "user is not in all orgs")
+7 -11
View File
@@ -1476,14 +1476,6 @@ func (q *querier) GetOrganizationIDsByMemberIDs(ctx context.Context, ids []uuid.
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetOrganizationIDsByMemberIDs)(ctx, ids)
}
func (q *querier) GetOrganizationMemberByUserID(ctx context.Context, arg database.GetOrganizationMemberByUserIDParams) (database.OrganizationMember, error) {
return fetch(q.log, q.auth, q.db.GetOrganizationMemberByUserID)(ctx, arg)
}
func (q *querier) GetOrganizationMembershipsByUserID(ctx context.Context, userID uuid.UUID) ([]database.OrganizationMember, error) {
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetOrganizationMembershipsByUserID)(ctx, userID)
}
func (q *querier) GetOrganizations(ctx context.Context) ([]database.Organization, error) {
fetch := func(ctx context.Context, _ interface{}) ([]database.Organization, error) {
return q.db.GetOrganizations(ctx)
@@ -2771,6 +2763,10 @@ func (q *querier) ListWorkspaceAgentPortShares(ctx context.Context, workspaceID
return q.db.ListWorkspaceAgentPortShares(ctx, workspaceID)
}
func (q *querier) OrganizationMembers(ctx context.Context, arg database.OrganizationMembersParams) ([]database.OrganizationMembersRow, error) {
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.OrganizationMembers)(ctx, arg)
}
func (q *querier) ReduceWorkspaceAgentShareLevelToAuthenticatedByTemplate(ctx context.Context, templateID uuid.UUID) error {
template, err := q.db.GetTemplateByID(ctx, templateID)
if err != nil {
@@ -2870,15 +2866,15 @@ func (q *querier) UpdateInactiveUsersToDormant(ctx context.Context, lastSeenAfte
func (q *querier) UpdateMemberRoles(ctx context.Context, arg database.UpdateMemberRolesParams) (database.OrganizationMember, error) {
// Authorized fetch will check that the actor has read access to the org member since the org member is returned.
member, err := q.GetOrganizationMemberByUserID(ctx, database.GetOrganizationMemberByUserIDParams{
member, err := database.ExpectOne(q.OrganizationMembers(ctx, database.OrganizationMembersParams{
OrganizationID: arg.OrgID,
UserID: arg.UserID,
})
}))
if err != nil {
return database.OrganizationMember{}, err
}
originalRoles, err := q.convertToOrganizationRoles(member.OrganizationID, member.Roles)
originalRoles, err := q.convertToOrganizationRoles(member.OrganizationMember.OrganizationID, member.OrganizationMember.Roles)
if err != nil {
return database.OrganizationMember{}, xerrors.Errorf("convert original roles: %w", err)
}
+24 -18
View File
@@ -596,19 +596,6 @@ func (s *MethodTestSuite) TestOrganization() {
check.Args([]uuid.UUID{ma.UserID, mb.UserID}).
Asserts(rbac.ResourceUserObject(ma.UserID), policy.ActionRead, rbac.ResourceUserObject(mb.UserID), policy.ActionRead)
}))
s.Run("GetOrganizationMemberByUserID", s.Subtest(func(db database.Store, check *expects) {
mem := dbgen.OrganizationMember(s.T(), db, database.OrganizationMember{})
check.Args(database.GetOrganizationMemberByUserIDParams{
OrganizationID: mem.OrganizationID,
UserID: mem.UserID,
}).Asserts(mem, policy.ActionRead).Returns(mem)
}))
s.Run("GetOrganizationMembershipsByUserID", s.Subtest(func(db database.Store, check *expects) {
u := dbgen.User(s.T(), db, database.User{})
a := dbgen.OrganizationMember(s.T(), db, database.OrganizationMember{UserID: u.ID})
b := dbgen.OrganizationMember(s.T(), db, database.OrganizationMember{UserID: u.ID})
check.Args(u.ID).Asserts(a, policy.ActionRead, b, policy.ActionRead).Returns(slice.New(a, b))
}))
s.Run("GetOrganizations", s.Subtest(func(db database.Store, check *expects) {
def, _ := db.GetDefaultOrganization(context.Background())
a := dbgen.Organization(s.T(), db, database.Organization{})
@@ -658,6 +645,22 @@ func (s *MethodTestSuite) TestOrganization() {
o.ID,
).Asserts(o, policy.ActionDelete)
}))
s.Run("OrganizationMembers", s.Subtest(func(db database.Store, check *expects) {
o := dbgen.Organization(s.T(), db, database.Organization{})
u := dbgen.User(s.T(), db, database.User{})
mem := dbgen.OrganizationMember(s.T(), db, database.OrganizationMember{
OrganizationID: o.ID,
UserID: u.ID,
Roles: []string{rbac.RoleOrgAdmin()},
})
check.Args(database.OrganizationMembersParams{
OrganizationID: uuid.UUID{},
UserID: uuid.UUID{},
}).Asserts(
mem, policy.ActionRead,
)
}))
s.Run("UpdateMemberRoles", s.Subtest(func(db database.Store, check *expects) {
o := dbgen.Organization(s.T(), db, database.Organization{})
u := dbgen.User(s.T(), db, database.User{})
@@ -673,11 +676,14 @@ func (s *MethodTestSuite) TestOrganization() {
GrantedRoles: []string{},
UserID: u.ID,
OrgID: o.ID,
}).Asserts(
mem, policy.ActionRead,
rbac.ResourceAssignRole.InOrg(o.ID), policy.ActionAssign, // org-mem
rbac.ResourceAssignRole.InOrg(o.ID), policy.ActionDelete, // org-admin
).Returns(out)
}).
WithNotAuthorized(sql.ErrNoRows.Error()).
WithCancelled(sql.ErrNoRows.Error()).
Asserts(
mem, policy.ActionRead,
rbac.ResourceAssignRole.InOrg(o.ID), policy.ActionAssign, // org-mem
rbac.ResourceAssignRole.InOrg(o.ID), policy.ActionDelete, // org-admin
).Returns(out)
}))
}
+33 -7
View File
@@ -157,7 +157,7 @@ func (s *MethodTestSuite) Subtest(testCaseF func(db database.Store, check *expec
if len(testCase.assertions) > 0 {
// Only run these tests if we know the underlying call makes
// rbac assertions.
s.NotAuthorizedErrorTest(ctx, fakeAuthorizer, callMethod)
s.NotAuthorizedErrorTest(ctx, fakeAuthorizer, testCase, callMethod)
}
if len(testCase.assertions) > 0 ||
@@ -230,7 +230,7 @@ func (s *MethodTestSuite) NoActorErrorTest(callMethod func(ctx context.Context)
// NotAuthorizedErrorTest runs the given method with an authorizer that will fail authz.
// Asserts that the error returned is a NotAuthorizedError.
func (s *MethodTestSuite) NotAuthorizedErrorTest(ctx context.Context, az *coderdtest.FakeAuthorizer, callMethod func(ctx context.Context) ([]reflect.Value, error)) {
func (s *MethodTestSuite) NotAuthorizedErrorTest(ctx context.Context, az *coderdtest.FakeAuthorizer, testCase expects, callMethod func(ctx context.Context) ([]reflect.Value, error)) {
s.Run("NotAuthorized", func() {
az.AlwaysReturn = rbac.ForbiddenWithInternal(xerrors.New("Always fail authz"), rbac.Subject{}, "", rbac.Object{}, nil)
@@ -242,9 +242,14 @@ func (s *MethodTestSuite) NotAuthorizedErrorTest(ctx context.Context, az *coderd
// This is unfortunate, but if we are using `Filter` the error returned will be nil. So filter out
// any case where the error is nil and the response is an empty slice.
if err != nil || !hasEmptySliceResponse(resp) {
s.ErrorContainsf(err, "unauthorized", "error string should have a good message")
s.Errorf(err, "method should an error with disallow authz")
s.ErrorAs(err, &dbauthz.NotAuthorizedError{}, "error should be NotAuthorizedError")
// Expect the default error
if testCase.notAuthorizedExpect == "" {
s.ErrorContainsf(err, "unauthorized", "error string should have a good message")
s.Errorf(err, "method should an error with disallow authz")
s.ErrorAs(err, &dbauthz.NotAuthorizedError{}, "error should be NotAuthorizedError")
} else {
s.ErrorContains(err, testCase.notAuthorizedExpect)
}
}
})
@@ -263,8 +268,12 @@ func (s *MethodTestSuite) NotAuthorizedErrorTest(ctx context.Context, az *coderd
// This is unfortunate, but if we are using `Filter` the error returned will be nil. So filter out
// any case where the error is nil and the response is an empty slice.
if err != nil || !hasEmptySliceResponse(resp) {
s.Errorf(err, "method should an error with cancellation")
s.ErrorIsf(err, context.Canceled, "error should match context.Canceled")
if testCase.cancelledCtxExpect == "" {
s.Errorf(err, "method should an error with cancellation")
s.ErrorIsf(err, context.Canceled, "error should match context.Canceled")
} else {
s.ErrorContains(err, testCase.cancelledCtxExpect)
}
}
})
}
@@ -308,6 +317,13 @@ type expects struct {
// outputs is optional. Can assert non-error return values.
outputs []reflect.Value
err error
// Optional override of the default error checks.
// By default, we search for the expected error strings.
// If these strings are present, these strings will be searched
// instead.
notAuthorizedExpect string
cancelledCtxExpect string
}
// Asserts is required. Asserts the RBAC authorize calls that should be made.
@@ -338,6 +354,16 @@ func (m *expects) Errors(err error) *expects {
return m
}
func (m *expects) WithNotAuthorized(contains string) *expects {
m.notAuthorizedExpect = contains
return m
}
func (m *expects) WithCancelled(contains string) *expects {
m.cancelledCtxExpect = contains
return m
}
// AssertRBAC contains the object and actions to be asserted.
type AssertRBAC struct {
Object rbac.Object
+2 -2
View File
@@ -119,10 +119,10 @@ func TestGenerator(t *testing.T) {
t.Parallel()
db := dbmem.New()
exp := dbgen.OrganizationMember(t, db, database.OrganizationMember{})
require.Equal(t, exp, must(db.GetOrganizationMemberByUserID(context.Background(), database.GetOrganizationMemberByUserIDParams{
require.Equal(t, exp, must(database.ExpectOne(db.OrganizationMembers(context.Background(), database.OrganizationMembersParams{
OrganizationID: exp.OrganizationID,
UserID: exp.UserID,
})))
}))).OrganizationMember)
})
t.Run("Workspace", func(t *testing.T) {
+28 -35
View File
@@ -2760,41 +2760,6 @@ func (q *FakeQuerier) GetOrganizationIDsByMemberIDs(_ context.Context, ids []uui
return getOrganizationIDsByMemberIDRows, nil
}
func (q *FakeQuerier) GetOrganizationMemberByUserID(_ context.Context, arg database.GetOrganizationMemberByUserIDParams) (database.OrganizationMember, error) {
if err := validateDatabaseType(arg); err != nil {
return database.OrganizationMember{}, err
}
q.mutex.RLock()
defer q.mutex.RUnlock()
for _, organizationMember := range q.organizationMembers {
if organizationMember.OrganizationID != arg.OrganizationID {
continue
}
if organizationMember.UserID != arg.UserID {
continue
}
return organizationMember, nil
}
return database.OrganizationMember{}, sql.ErrNoRows
}
func (q *FakeQuerier) GetOrganizationMembershipsByUserID(_ context.Context, userID uuid.UUID) ([]database.OrganizationMember, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()
var memberships []database.OrganizationMember
for _, organizationMember := range q.organizationMembers {
mem := organizationMember
if mem.UserID != userID {
continue
}
memberships = append(memberships, mem)
}
return memberships, nil
}
func (q *FakeQuerier) GetOrganizations(_ context.Context) ([]database.Organization, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()
@@ -6965,6 +6930,34 @@ func (q *FakeQuerier) ListWorkspaceAgentPortShares(_ context.Context, workspaceI
return shares, nil
}
func (q *FakeQuerier) OrganizationMembers(_ context.Context, arg database.OrganizationMembersParams) ([]database.OrganizationMembersRow, error) {
if err := validateDatabaseType(arg); err != nil {
return []database.OrganizationMembersRow{}, err
}
q.mutex.RLock()
defer q.mutex.RUnlock()
tmp := make([]database.OrganizationMembersRow, 0)
for _, organizationMember := range q.organizationMembers {
if arg.OrganizationID != uuid.Nil && organizationMember.OrganizationID != arg.OrganizationID {
continue
}
if arg.UserID != uuid.Nil && organizationMember.UserID != arg.UserID {
continue
}
organizationMember := organizationMember
user, _ := q.getUserByIDNoLock(organizationMember.UserID)
tmp = append(tmp, database.OrganizationMembersRow{
OrganizationMember: organizationMember,
Username: user.Username,
})
}
return tmp, nil
}
func (q *FakeQuerier) ReduceWorkspaceAgentShareLevelToAuthenticatedByTemplate(_ context.Context, templateID uuid.UUID) error {
err := validateDatabaseType(templateID)
if err != nil {
+7 -14
View File
@@ -760,20 +760,6 @@ func (m metricsStore) GetOrganizationIDsByMemberIDs(ctx context.Context, ids []u
return organizations, err
}
func (m metricsStore) GetOrganizationMemberByUserID(ctx context.Context, arg database.GetOrganizationMemberByUserIDParams) (database.OrganizationMember, error) {
start := time.Now()
member, err := m.s.GetOrganizationMemberByUserID(ctx, arg)
m.queryLatencies.WithLabelValues("GetOrganizationMemberByUserID").Observe(time.Since(start).Seconds())
return member, err
}
func (m metricsStore) GetOrganizationMembershipsByUserID(ctx context.Context, userID uuid.UUID) ([]database.OrganizationMember, error) {
start := time.Now()
memberships, err := m.s.GetOrganizationMembershipsByUserID(ctx, userID)
m.queryLatencies.WithLabelValues("GetOrganizationMembershipsByUserID").Observe(time.Since(start).Seconds())
return memberships, err
}
func (m metricsStore) GetOrganizations(ctx context.Context) ([]database.Organization, error) {
start := time.Now()
organizations, err := m.s.GetOrganizations(ctx)
@@ -1747,6 +1733,13 @@ func (m metricsStore) ListWorkspaceAgentPortShares(ctx context.Context, workspac
return r0, r1
}
func (m metricsStore) OrganizationMembers(ctx context.Context, arg database.OrganizationMembersParams) ([]database.OrganizationMembersRow, error) {
start := time.Now()
r0, r1 := m.s.OrganizationMembers(ctx, arg)
m.queryLatencies.WithLabelValues("OrganizationMembers").Observe(time.Since(start).Seconds())
return r0, r1
}
func (m metricsStore) ReduceWorkspaceAgentShareLevelToAuthenticatedByTemplate(ctx context.Context, templateID uuid.UUID) error {
start := time.Now()
r0 := m.s.ReduceWorkspaceAgentShareLevelToAuthenticatedByTemplate(ctx, templateID)
+15 -30
View File
@@ -1514,36 +1514,6 @@ func (mr *MockStoreMockRecorder) GetOrganizationIDsByMemberIDs(arg0, arg1 any) *
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOrganizationIDsByMemberIDs", reflect.TypeOf((*MockStore)(nil).GetOrganizationIDsByMemberIDs), arg0, arg1)
}
// GetOrganizationMemberByUserID mocks base method.
func (m *MockStore) GetOrganizationMemberByUserID(arg0 context.Context, arg1 database.GetOrganizationMemberByUserIDParams) (database.OrganizationMember, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetOrganizationMemberByUserID", arg0, arg1)
ret0, _ := ret[0].(database.OrganizationMember)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetOrganizationMemberByUserID indicates an expected call of GetOrganizationMemberByUserID.
func (mr *MockStoreMockRecorder) GetOrganizationMemberByUserID(arg0, arg1 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOrganizationMemberByUserID", reflect.TypeOf((*MockStore)(nil).GetOrganizationMemberByUserID), arg0, arg1)
}
// GetOrganizationMembershipsByUserID mocks base method.
func (m *MockStore) GetOrganizationMembershipsByUserID(arg0 context.Context, arg1 uuid.UUID) ([]database.OrganizationMember, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetOrganizationMembershipsByUserID", arg0, arg1)
ret0, _ := ret[0].([]database.OrganizationMember)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetOrganizationMembershipsByUserID indicates an expected call of GetOrganizationMembershipsByUserID.
func (mr *MockStoreMockRecorder) GetOrganizationMembershipsByUserID(arg0, arg1 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOrganizationMembershipsByUserID", reflect.TypeOf((*MockStore)(nil).GetOrganizationMembershipsByUserID), arg0, arg1)
}
// GetOrganizations mocks base method.
func (m *MockStore) GetOrganizations(arg0 context.Context) ([]database.Organization, error) {
m.ctrl.T.Helper()
@@ -3661,6 +3631,21 @@ func (mr *MockStoreMockRecorder) ListWorkspaceAgentPortShares(arg0, arg1 any) *g
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListWorkspaceAgentPortShares", reflect.TypeOf((*MockStore)(nil).ListWorkspaceAgentPortShares), arg0, arg1)
}
// OrganizationMembers mocks base method.
func (m *MockStore) OrganizationMembers(arg0 context.Context, arg1 database.OrganizationMembersParams) ([]database.OrganizationMembersRow, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "OrganizationMembers", arg0, arg1)
ret0, _ := ret[0].([]database.OrganizationMembersRow)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// OrganizationMembers indicates an expected call of OrganizationMembers.
func (mr *MockStoreMockRecorder) OrganizationMembers(arg0, arg1 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OrganizationMembers", reflect.TypeOf((*MockStore)(nil).OrganizationMembers), arg0, arg1)
}
// Ping mocks base method.
func (m *MockStore) Ping(arg0 context.Context) (time.Duration, error) {
m.ctrl.T.Helper()
+4
View File
@@ -179,6 +179,10 @@ func (m OrganizationMember) RBACObject() rbac.Object {
WithOwner(m.UserID.String())
}
func (m OrganizationMembersRow) RBACObject() rbac.Object {
return m.OrganizationMember.RBACObject()
}
func (m GetOrganizationIDsByMemberIDsRow) RBACObject() rbac.Object {
// TODO: This feels incorrect as we are really returning a list of orgmembers.
// This return type should be refactored to return a list of orgmembers, not this
+24
View File
@@ -2,6 +2,7 @@ package database
import (
"context"
"database/sql"
"fmt"
"strings"
@@ -17,6 +18,29 @@ const (
authorizedQueryPlaceholder = "-- @authorize_filter"
)
// ExpectOne can be used to convert a ':many:' query into a ':one'
// query. To reduce the quantity of SQL queries, a :many with a filter is used.
// These filters sometimes are expected to return just 1 row.
//
// A :many query will never return a sql.ErrNoRows, but a :one does.
// This function will correct the error for the empty set.
func ExpectOne[T any](ret []T, err error) (T, error) {
var empty T
if err != nil {
return empty, err
}
if len(ret) == 0 {
return empty, sql.ErrNoRows
}
if len(ret) > 1 {
return empty, xerrors.Errorf("too many rows returned, expected 1")
}
return ret[0], nil
}
// customQuerier encompasses all non-generated queries.
// It provides a flexible way to write queries for cases
// where sqlc proves inadequate.
+5 -2
View File
@@ -151,8 +151,6 @@ type sqlcQuerier interface {
GetOrganizationByID(ctx context.Context, id uuid.UUID) (Organization, error)
GetOrganizationByName(ctx context.Context, name string) (Organization, error)
GetOrganizationIDsByMemberIDs(ctx context.Context, ids []uuid.UUID) ([]GetOrganizationIDsByMemberIDsRow, error)
GetOrganizationMemberByUserID(ctx context.Context, arg GetOrganizationMemberByUserIDParams) (OrganizationMember, error)
GetOrganizationMembershipsByUserID(ctx context.Context, userID uuid.UUID) ([]OrganizationMember, error)
GetOrganizations(ctx context.Context) ([]Organization, error)
GetOrganizationsByUserID(ctx context.Context, userID uuid.UUID) ([]Organization, error)
GetParameterSchemasByJobID(ctx context.Context, jobID uuid.UUID) ([]ParameterSchema, error)
@@ -349,6 +347,11 @@ type sqlcQuerier interface {
InsertWorkspaceResource(ctx context.Context, arg InsertWorkspaceResourceParams) (WorkspaceResource, error)
InsertWorkspaceResourceMetadata(ctx context.Context, arg InsertWorkspaceResourceMetadataParams) ([]WorkspaceResourceMetadatum, error)
ListWorkspaceAgentPortShares(ctx context.Context, workspaceID uuid.UUID) ([]WorkspaceAgentPortShare, error)
// Arguments are optional with uuid.Nil to ignore.
// - Use just 'organization_id' to get all members of an org
// - Use just 'user_id' to get all orgs a user is a member of
// - Use both to get a specific org member row
OrganizationMembers(ctx context.Context, arg OrganizationMembersParams) ([]OrganizationMembersRow, error)
ReduceWorkspaceAgentShareLevelToAuthenticatedByTemplate(ctx context.Context, templateID uuid.UUID) error
RegisterWorkspaceProxy(ctx context.Context, arg RegisterWorkspaceProxyParams) (WorkspaceProxy, error)
RemoveUserFromAllGroups(ctx context.Context, userID uuid.UUID) error
+36
View File
@@ -903,6 +903,42 @@ func TestArchiveVersions(t *testing.T) {
})
}
func TestExpectOne(t *testing.T) {
t.Parallel()
if testing.Short() {
t.SkipNow()
}
t.Run("ErrNoRows", func(t *testing.T) {
t.Parallel()
sqlDB := testSQLDB(t)
err := migrations.Up(sqlDB)
require.NoError(t, err)
db := database.New(sqlDB)
ctx := context.Background()
_, err = database.ExpectOne(db.GetUsers(ctx, database.GetUsersParams{}))
require.ErrorIs(t, err, sql.ErrNoRows)
})
t.Run("TooMany", func(t *testing.T) {
t.Parallel()
sqlDB := testSQLDB(t)
err := migrations.Up(sqlDB)
require.NoError(t, err)
db := database.New(sqlDB)
ctx := context.Background()
// Create 2 organizations so the query returns >1
dbgen.Organization(t, db, database.Organization{})
dbgen.Organization(t, db, database.Organization{})
// Organizations is an easy table without foreign key dependencies
_, err = database.ExpectOne(db.GetOrganizations(ctx))
require.ErrorContains(t, err, "too many rows returned")
})
}
func requireUsersMatch(t testing.TB, expected []database.User, found []database.GetUsersRow, msg string) {
t.Helper()
require.ElementsMatch(t, expected, database.ConvertUserRows(found), msg)
+67 -68
View File
@@ -3795,74 +3795,6 @@ func (q *sqlQuerier) GetOrganizationIDsByMemberIDs(ctx context.Context, ids []uu
return items, nil
}
const getOrganizationMemberByUserID = `-- name: GetOrganizationMemberByUserID :one
SELECT
user_id, organization_id, created_at, updated_at, roles
FROM
organization_members
WHERE
organization_id = $1
AND user_id = $2
LIMIT
1
`
type GetOrganizationMemberByUserIDParams struct {
OrganizationID uuid.UUID `db:"organization_id" json:"organization_id"`
UserID uuid.UUID `db:"user_id" json:"user_id"`
}
func (q *sqlQuerier) GetOrganizationMemberByUserID(ctx context.Context, arg GetOrganizationMemberByUserIDParams) (OrganizationMember, error) {
row := q.db.QueryRowContext(ctx, getOrganizationMemberByUserID, arg.OrganizationID, arg.UserID)
var i OrganizationMember
err := row.Scan(
&i.UserID,
&i.OrganizationID,
&i.CreatedAt,
&i.UpdatedAt,
pq.Array(&i.Roles),
)
return i, err
}
const getOrganizationMembershipsByUserID = `-- name: GetOrganizationMembershipsByUserID :many
SELECT
user_id, organization_id, created_at, updated_at, roles
FROM
organization_members
WHERE
user_id = $1
`
func (q *sqlQuerier) GetOrganizationMembershipsByUserID(ctx context.Context, userID uuid.UUID) ([]OrganizationMember, error) {
rows, err := q.db.QueryContext(ctx, getOrganizationMembershipsByUserID, userID)
if err != nil {
return nil, err
}
defer rows.Close()
var items []OrganizationMember
for rows.Next() {
var i OrganizationMember
if err := rows.Scan(
&i.UserID,
&i.OrganizationID,
&i.CreatedAt,
&i.UpdatedAt,
pq.Array(&i.Roles),
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const insertOrganizationMember = `-- name: InsertOrganizationMember :one
INSERT INTO
organization_members (
@@ -3903,6 +3835,73 @@ func (q *sqlQuerier) InsertOrganizationMember(ctx context.Context, arg InsertOrg
return i, err
}
const organizationMembers = `-- name: OrganizationMembers :many
SELECT
organization_members.user_id, organization_members.organization_id, organization_members.created_at, organization_members.updated_at, organization_members.roles,
users.username
FROM
organization_members
INNER JOIN
users ON organization_members.user_id = users.id
WHERE
-- Filter by organization id
CASE
WHEN $1 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
organization_id = $1
ELSE true
END
-- Filter by user id
AND CASE
WHEN $2 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
user_id = $2
ELSE true
END
`
type OrganizationMembersParams struct {
OrganizationID uuid.UUID `db:"organization_id" json:"organization_id"`
UserID uuid.UUID `db:"user_id" json:"user_id"`
}
type OrganizationMembersRow struct {
OrganizationMember OrganizationMember `db:"organization_member" json:"organization_member"`
Username string `db:"username" json:"username"`
}
// Arguments are optional with uuid.Nil to ignore.
// - Use just 'organization_id' to get all members of an org
// - Use just 'user_id' to get all orgs a user is a member of
// - Use both to get a specific org member row
func (q *sqlQuerier) OrganizationMembers(ctx context.Context, arg OrganizationMembersParams) ([]OrganizationMembersRow, error) {
rows, err := q.db.QueryContext(ctx, organizationMembers, arg.OrganizationID, arg.UserID)
if err != nil {
return nil, err
}
defer rows.Close()
var items []OrganizationMembersRow
for rows.Next() {
var i OrganizationMembersRow
if err := rows.Scan(
&i.OrganizationMember.UserID,
&i.OrganizationMember.OrganizationID,
&i.OrganizationMember.CreatedAt,
&i.OrganizationMember.UpdatedAt,
pq.Array(&i.OrganizationMember.Roles),
&i.Username,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const updateMemberRoles = `-- name: UpdateMemberRoles :one
UPDATE
organization_members
+21 -14
View File
@@ -1,13 +1,28 @@
-- name: GetOrganizationMemberByUserID :one
-- name: OrganizationMembers :many
-- Arguments are optional with uuid.Nil to ignore.
-- - Use just 'organization_id' to get all members of an org
-- - Use just 'user_id' to get all orgs a user is a member of
-- - Use both to get a specific org member row
SELECT
*
sqlc.embed(organization_members),
users.username
FROM
organization_members
INNER JOIN
users ON organization_members.user_id = users.id
WHERE
organization_id = $1
AND user_id = $2
LIMIT
1;
-- Filter by organization id
CASE
WHEN @organization_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
organization_id = @organization_id
ELSE true
END
-- Filter by user id
AND CASE
WHEN @user_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
user_id = @user_id
ELSE true
END;
-- name: InsertOrganizationMember :one
INSERT INTO
@@ -22,14 +37,6 @@ VALUES
($1, $2, $3, $4, $5) RETURNING *;
-- name: GetOrganizationMembershipsByUserID :many
SELECT
*
FROM
organization_members
WHERE
user_id = $1;
-- name: GetOrganizationIDsByMemberIDs :many
SELECT
user_id, array_agg(organization_id) :: uuid [ ] AS "organization_IDs"
+3 -3
View File
@@ -124,10 +124,10 @@ func ExtractOrganizationMemberParam(db database.Store) func(http.Handler) http.H
}
organization := OrganizationParam(r)
organizationMember, err := db.GetOrganizationMemberByUserID(ctx, database.GetOrganizationMemberByUserIDParams{
organizationMember, err := database.ExpectOne(db.OrganizationMembers(ctx, database.OrganizationMembersParams{
OrganizationID: organization.ID,
UserID: user.ID,
})
}))
if httpapi.Is404Error(err) {
httpapi.ResourceNotFound(rw)
return
@@ -141,7 +141,7 @@ func ExtractOrganizationMemberParam(db database.Store) func(http.Handler) http.H
}
ctx = context.WithValue(ctx, organizationMemberParamContextKey{}, OrganizationMember{
OrganizationMember: organizationMember,
OrganizationMember: organizationMember.OrganizationMember,
// Here we're making two exceptions to the rule about not leaking data about the user
// to the API handler, which is to include the username and avatar URL.
// If the caller has permission to read the OrganizationMember, then we're explicitly
+6 -3
View File
@@ -1518,15 +1518,18 @@ func (api *API) oauthLogin(r *http.Request, params *oauthLoginParams) ([]*http.C
}
//nolint:gocritic // No user present in the context.
memberships, err := tx.GetOrganizationMembershipsByUserID(dbauthz.AsSystemRestricted(ctx), user.ID)
memberships, err := tx.OrganizationMembers(dbauthz.AsSystemRestricted(ctx), database.OrganizationMembersParams{
UserID: user.ID,
OrganizationID: uuid.Nil,
})
if err != nil {
return xerrors.Errorf("get organization memberships: %w", err)
}
// If the user is not in the default organization, then we can't assign groups.
// A user cannot be in groups to an org they are not a member of.
if !slices.ContainsFunc(memberships, func(member database.OrganizationMember) bool {
return member.OrganizationID == defaultOrganization.ID
if !slices.ContainsFunc(memberships, func(member database.OrganizationMembersRow) bool {
return member.OrganizationMember.OrganizationID == defaultOrganization.ID
}) {
return xerrors.Errorf("user %s is not a member of the default organization, cannot assign to groups in the org", user.ID)
}
+6 -2
View File
@@ -1027,12 +1027,16 @@ func (api *API) userRoles(rw http.ResponseWriter, r *http.Request) {
return
}
// TODO: Replace this with "GetAuthorizationUserRoles"
resp := codersdk.UserRoles{
Roles: user.RBACRoles,
OrganizationRoles: make(map[uuid.UUID][]string),
}
memberships, err := api.Database.GetOrganizationMembershipsByUserID(ctx, user.ID)
memberships, err := api.Database.OrganizationMembers(ctx, database.OrganizationMembersParams{
UserID: user.ID,
OrganizationID: uuid.Nil,
})
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching user's organization memberships.",
@@ -1042,7 +1046,7 @@ func (api *API) userRoles(rw http.ResponseWriter, r *http.Request) {
}
for _, mem := range memberships {
resp.OrganizationRoles[mem.OrganizationID] = mem.Roles
resp.OrganizationRoles[mem.OrganizationMember.OrganizationID] = mem.OrganizationMember.Roles
}
httpapi.Write(ctx, rw, http.StatusOK, resp)
+2 -2
View File
@@ -166,10 +166,10 @@ func (api *API) patchGroup(rw http.ResponseWriter, r *http.Request) {
}
// TODO: It would be nice to enforce this at the schema level
// but unfortunately our org_members table does not have an ID.
_, err := api.Database.GetOrganizationMemberByUserID(ctx, database.GetOrganizationMemberByUserIDParams{
_, err := database.ExpectOne(api.Database.OrganizationMembers(ctx, database.OrganizationMembersParams{
OrganizationID: group.OrganizationID,
UserID: uuid.MustParse(id),
})
}))
if xerrors.Is(err, sql.ErrNoRows) {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
Message: fmt.Sprintf("User %q must be a member of organization %q", id, group.ID),