mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
chore: merge organization member db queries (#13542)
Merge members queries into 1 that also joins in the user table for username. Required to list organization members on UI/cli
This commit is contained in:
@@ -67,12 +67,12 @@ func TestServerCreateAdminUser(t *testing.T) {
|
||||
orgIDs[org.ID] = struct{}{}
|
||||
}
|
||||
|
||||
orgMemberships, err := db.GetOrganizationMembershipsByUserID(ctx, user.ID)
|
||||
orgMemberships, err := db.OrganizationMembers(ctx, database.OrganizationMembersParams{UserID: user.ID})
|
||||
require.NoError(t, err)
|
||||
orgIDs2 := make(map[uuid.UUID]struct{}, len(orgMemberships))
|
||||
for _, membership := range orgMemberships {
|
||||
orgIDs2[membership.OrganizationID] = struct{}{}
|
||||
assert.Equal(t, []string{rbac.RoleOrgAdmin()}, membership.Roles, "user is not org admin")
|
||||
orgIDs2[membership.OrganizationMember.OrganizationID] = struct{}{}
|
||||
assert.Equal(t, []string{rbac.RoleOrgAdmin()}, membership.OrganizationMember.Roles, "user is not org admin")
|
||||
}
|
||||
|
||||
require.Equal(t, orgIDs, orgIDs2, "user is not in all orgs")
|
||||
|
||||
@@ -1476,14 +1476,6 @@ func (q *querier) GetOrganizationIDsByMemberIDs(ctx context.Context, ids []uuid.
|
||||
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetOrganizationIDsByMemberIDs)(ctx, ids)
|
||||
}
|
||||
|
||||
func (q *querier) GetOrganizationMemberByUserID(ctx context.Context, arg database.GetOrganizationMemberByUserIDParams) (database.OrganizationMember, error) {
|
||||
return fetch(q.log, q.auth, q.db.GetOrganizationMemberByUserID)(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) GetOrganizationMembershipsByUserID(ctx context.Context, userID uuid.UUID) ([]database.OrganizationMember, error) {
|
||||
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetOrganizationMembershipsByUserID)(ctx, userID)
|
||||
}
|
||||
|
||||
func (q *querier) GetOrganizations(ctx context.Context) ([]database.Organization, error) {
|
||||
fetch := func(ctx context.Context, _ interface{}) ([]database.Organization, error) {
|
||||
return q.db.GetOrganizations(ctx)
|
||||
@@ -2771,6 +2763,10 @@ func (q *querier) ListWorkspaceAgentPortShares(ctx context.Context, workspaceID
|
||||
return q.db.ListWorkspaceAgentPortShares(ctx, workspaceID)
|
||||
}
|
||||
|
||||
func (q *querier) OrganizationMembers(ctx context.Context, arg database.OrganizationMembersParams) ([]database.OrganizationMembersRow, error) {
|
||||
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.OrganizationMembers)(ctx, arg)
|
||||
}
|
||||
|
||||
func (q *querier) ReduceWorkspaceAgentShareLevelToAuthenticatedByTemplate(ctx context.Context, templateID uuid.UUID) error {
|
||||
template, err := q.db.GetTemplateByID(ctx, templateID)
|
||||
if err != nil {
|
||||
@@ -2870,15 +2866,15 @@ func (q *querier) UpdateInactiveUsersToDormant(ctx context.Context, lastSeenAfte
|
||||
|
||||
func (q *querier) UpdateMemberRoles(ctx context.Context, arg database.UpdateMemberRolesParams) (database.OrganizationMember, error) {
|
||||
// Authorized fetch will check that the actor has read access to the org member since the org member is returned.
|
||||
member, err := q.GetOrganizationMemberByUserID(ctx, database.GetOrganizationMemberByUserIDParams{
|
||||
member, err := database.ExpectOne(q.OrganizationMembers(ctx, database.OrganizationMembersParams{
|
||||
OrganizationID: arg.OrgID,
|
||||
UserID: arg.UserID,
|
||||
})
|
||||
}))
|
||||
if err != nil {
|
||||
return database.OrganizationMember{}, err
|
||||
}
|
||||
|
||||
originalRoles, err := q.convertToOrganizationRoles(member.OrganizationID, member.Roles)
|
||||
originalRoles, err := q.convertToOrganizationRoles(member.OrganizationMember.OrganizationID, member.OrganizationMember.Roles)
|
||||
if err != nil {
|
||||
return database.OrganizationMember{}, xerrors.Errorf("convert original roles: %w", err)
|
||||
}
|
||||
|
||||
@@ -596,19 +596,6 @@ func (s *MethodTestSuite) TestOrganization() {
|
||||
check.Args([]uuid.UUID{ma.UserID, mb.UserID}).
|
||||
Asserts(rbac.ResourceUserObject(ma.UserID), policy.ActionRead, rbac.ResourceUserObject(mb.UserID), policy.ActionRead)
|
||||
}))
|
||||
s.Run("GetOrganizationMemberByUserID", s.Subtest(func(db database.Store, check *expects) {
|
||||
mem := dbgen.OrganizationMember(s.T(), db, database.OrganizationMember{})
|
||||
check.Args(database.GetOrganizationMemberByUserIDParams{
|
||||
OrganizationID: mem.OrganizationID,
|
||||
UserID: mem.UserID,
|
||||
}).Asserts(mem, policy.ActionRead).Returns(mem)
|
||||
}))
|
||||
s.Run("GetOrganizationMembershipsByUserID", s.Subtest(func(db database.Store, check *expects) {
|
||||
u := dbgen.User(s.T(), db, database.User{})
|
||||
a := dbgen.OrganizationMember(s.T(), db, database.OrganizationMember{UserID: u.ID})
|
||||
b := dbgen.OrganizationMember(s.T(), db, database.OrganizationMember{UserID: u.ID})
|
||||
check.Args(u.ID).Asserts(a, policy.ActionRead, b, policy.ActionRead).Returns(slice.New(a, b))
|
||||
}))
|
||||
s.Run("GetOrganizations", s.Subtest(func(db database.Store, check *expects) {
|
||||
def, _ := db.GetDefaultOrganization(context.Background())
|
||||
a := dbgen.Organization(s.T(), db, database.Organization{})
|
||||
@@ -658,6 +645,22 @@ func (s *MethodTestSuite) TestOrganization() {
|
||||
o.ID,
|
||||
).Asserts(o, policy.ActionDelete)
|
||||
}))
|
||||
s.Run("OrganizationMembers", s.Subtest(func(db database.Store, check *expects) {
|
||||
o := dbgen.Organization(s.T(), db, database.Organization{})
|
||||
u := dbgen.User(s.T(), db, database.User{})
|
||||
mem := dbgen.OrganizationMember(s.T(), db, database.OrganizationMember{
|
||||
OrganizationID: o.ID,
|
||||
UserID: u.ID,
|
||||
Roles: []string{rbac.RoleOrgAdmin()},
|
||||
})
|
||||
|
||||
check.Args(database.OrganizationMembersParams{
|
||||
OrganizationID: uuid.UUID{},
|
||||
UserID: uuid.UUID{},
|
||||
}).Asserts(
|
||||
mem, policy.ActionRead,
|
||||
)
|
||||
}))
|
||||
s.Run("UpdateMemberRoles", s.Subtest(func(db database.Store, check *expects) {
|
||||
o := dbgen.Organization(s.T(), db, database.Organization{})
|
||||
u := dbgen.User(s.T(), db, database.User{})
|
||||
@@ -673,11 +676,14 @@ func (s *MethodTestSuite) TestOrganization() {
|
||||
GrantedRoles: []string{},
|
||||
UserID: u.ID,
|
||||
OrgID: o.ID,
|
||||
}).Asserts(
|
||||
mem, policy.ActionRead,
|
||||
rbac.ResourceAssignRole.InOrg(o.ID), policy.ActionAssign, // org-mem
|
||||
rbac.ResourceAssignRole.InOrg(o.ID), policy.ActionDelete, // org-admin
|
||||
).Returns(out)
|
||||
}).
|
||||
WithNotAuthorized(sql.ErrNoRows.Error()).
|
||||
WithCancelled(sql.ErrNoRows.Error()).
|
||||
Asserts(
|
||||
mem, policy.ActionRead,
|
||||
rbac.ResourceAssignRole.InOrg(o.ID), policy.ActionAssign, // org-mem
|
||||
rbac.ResourceAssignRole.InOrg(o.ID), policy.ActionDelete, // org-admin
|
||||
).Returns(out)
|
||||
}))
|
||||
}
|
||||
|
||||
|
||||
@@ -157,7 +157,7 @@ func (s *MethodTestSuite) Subtest(testCaseF func(db database.Store, check *expec
|
||||
if len(testCase.assertions) > 0 {
|
||||
// Only run these tests if we know the underlying call makes
|
||||
// rbac assertions.
|
||||
s.NotAuthorizedErrorTest(ctx, fakeAuthorizer, callMethod)
|
||||
s.NotAuthorizedErrorTest(ctx, fakeAuthorizer, testCase, callMethod)
|
||||
}
|
||||
|
||||
if len(testCase.assertions) > 0 ||
|
||||
@@ -230,7 +230,7 @@ func (s *MethodTestSuite) NoActorErrorTest(callMethod func(ctx context.Context)
|
||||
|
||||
// NotAuthorizedErrorTest runs the given method with an authorizer that will fail authz.
|
||||
// Asserts that the error returned is a NotAuthorizedError.
|
||||
func (s *MethodTestSuite) NotAuthorizedErrorTest(ctx context.Context, az *coderdtest.FakeAuthorizer, callMethod func(ctx context.Context) ([]reflect.Value, error)) {
|
||||
func (s *MethodTestSuite) NotAuthorizedErrorTest(ctx context.Context, az *coderdtest.FakeAuthorizer, testCase expects, callMethod func(ctx context.Context) ([]reflect.Value, error)) {
|
||||
s.Run("NotAuthorized", func() {
|
||||
az.AlwaysReturn = rbac.ForbiddenWithInternal(xerrors.New("Always fail authz"), rbac.Subject{}, "", rbac.Object{}, nil)
|
||||
|
||||
@@ -242,9 +242,14 @@ func (s *MethodTestSuite) NotAuthorizedErrorTest(ctx context.Context, az *coderd
|
||||
// This is unfortunate, but if we are using `Filter` the error returned will be nil. So filter out
|
||||
// any case where the error is nil and the response is an empty slice.
|
||||
if err != nil || !hasEmptySliceResponse(resp) {
|
||||
s.ErrorContainsf(err, "unauthorized", "error string should have a good message")
|
||||
s.Errorf(err, "method should an error with disallow authz")
|
||||
s.ErrorAs(err, &dbauthz.NotAuthorizedError{}, "error should be NotAuthorizedError")
|
||||
// Expect the default error
|
||||
if testCase.notAuthorizedExpect == "" {
|
||||
s.ErrorContainsf(err, "unauthorized", "error string should have a good message")
|
||||
s.Errorf(err, "method should an error with disallow authz")
|
||||
s.ErrorAs(err, &dbauthz.NotAuthorizedError{}, "error should be NotAuthorizedError")
|
||||
} else {
|
||||
s.ErrorContains(err, testCase.notAuthorizedExpect)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
@@ -263,8 +268,12 @@ func (s *MethodTestSuite) NotAuthorizedErrorTest(ctx context.Context, az *coderd
|
||||
// This is unfortunate, but if we are using `Filter` the error returned will be nil. So filter out
|
||||
// any case where the error is nil and the response is an empty slice.
|
||||
if err != nil || !hasEmptySliceResponse(resp) {
|
||||
s.Errorf(err, "method should an error with cancellation")
|
||||
s.ErrorIsf(err, context.Canceled, "error should match context.Canceled")
|
||||
if testCase.cancelledCtxExpect == "" {
|
||||
s.Errorf(err, "method should an error with cancellation")
|
||||
s.ErrorIsf(err, context.Canceled, "error should match context.Canceled")
|
||||
} else {
|
||||
s.ErrorContains(err, testCase.cancelledCtxExpect)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -308,6 +317,13 @@ type expects struct {
|
||||
// outputs is optional. Can assert non-error return values.
|
||||
outputs []reflect.Value
|
||||
err error
|
||||
|
||||
// Optional override of the default error checks.
|
||||
// By default, we search for the expected error strings.
|
||||
// If these strings are present, these strings will be searched
|
||||
// instead.
|
||||
notAuthorizedExpect string
|
||||
cancelledCtxExpect string
|
||||
}
|
||||
|
||||
// Asserts is required. Asserts the RBAC authorize calls that should be made.
|
||||
@@ -338,6 +354,16 @@ func (m *expects) Errors(err error) *expects {
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *expects) WithNotAuthorized(contains string) *expects {
|
||||
m.notAuthorizedExpect = contains
|
||||
return m
|
||||
}
|
||||
|
||||
func (m *expects) WithCancelled(contains string) *expects {
|
||||
m.cancelledCtxExpect = contains
|
||||
return m
|
||||
}
|
||||
|
||||
// AssertRBAC contains the object and actions to be asserted.
|
||||
type AssertRBAC struct {
|
||||
Object rbac.Object
|
||||
|
||||
@@ -119,10 +119,10 @@ func TestGenerator(t *testing.T) {
|
||||
t.Parallel()
|
||||
db := dbmem.New()
|
||||
exp := dbgen.OrganizationMember(t, db, database.OrganizationMember{})
|
||||
require.Equal(t, exp, must(db.GetOrganizationMemberByUserID(context.Background(), database.GetOrganizationMemberByUserIDParams{
|
||||
require.Equal(t, exp, must(database.ExpectOne(db.OrganizationMembers(context.Background(), database.OrganizationMembersParams{
|
||||
OrganizationID: exp.OrganizationID,
|
||||
UserID: exp.UserID,
|
||||
})))
|
||||
}))).OrganizationMember)
|
||||
})
|
||||
|
||||
t.Run("Workspace", func(t *testing.T) {
|
||||
|
||||
@@ -2760,41 +2760,6 @@ func (q *FakeQuerier) GetOrganizationIDsByMemberIDs(_ context.Context, ids []uui
|
||||
return getOrganizationIDsByMemberIDRows, nil
|
||||
}
|
||||
|
||||
func (q *FakeQuerier) GetOrganizationMemberByUserID(_ context.Context, arg database.GetOrganizationMemberByUserIDParams) (database.OrganizationMember, error) {
|
||||
if err := validateDatabaseType(arg); err != nil {
|
||||
return database.OrganizationMember{}, err
|
||||
}
|
||||
|
||||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
|
||||
for _, organizationMember := range q.organizationMembers {
|
||||
if organizationMember.OrganizationID != arg.OrganizationID {
|
||||
continue
|
||||
}
|
||||
if organizationMember.UserID != arg.UserID {
|
||||
continue
|
||||
}
|
||||
return organizationMember, nil
|
||||
}
|
||||
return database.OrganizationMember{}, sql.ErrNoRows
|
||||
}
|
||||
|
||||
func (q *FakeQuerier) GetOrganizationMembershipsByUserID(_ context.Context, userID uuid.UUID) ([]database.OrganizationMember, error) {
|
||||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
|
||||
var memberships []database.OrganizationMember
|
||||
for _, organizationMember := range q.organizationMembers {
|
||||
mem := organizationMember
|
||||
if mem.UserID != userID {
|
||||
continue
|
||||
}
|
||||
memberships = append(memberships, mem)
|
||||
}
|
||||
return memberships, nil
|
||||
}
|
||||
|
||||
func (q *FakeQuerier) GetOrganizations(_ context.Context) ([]database.Organization, error) {
|
||||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
@@ -6965,6 +6930,34 @@ func (q *FakeQuerier) ListWorkspaceAgentPortShares(_ context.Context, workspaceI
|
||||
return shares, nil
|
||||
}
|
||||
|
||||
func (q *FakeQuerier) OrganizationMembers(_ context.Context, arg database.OrganizationMembersParams) ([]database.OrganizationMembersRow, error) {
|
||||
if err := validateDatabaseType(arg); err != nil {
|
||||
return []database.OrganizationMembersRow{}, err
|
||||
}
|
||||
|
||||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
|
||||
tmp := make([]database.OrganizationMembersRow, 0)
|
||||
for _, organizationMember := range q.organizationMembers {
|
||||
if arg.OrganizationID != uuid.Nil && organizationMember.OrganizationID != arg.OrganizationID {
|
||||
continue
|
||||
}
|
||||
|
||||
if arg.UserID != uuid.Nil && organizationMember.UserID != arg.UserID {
|
||||
continue
|
||||
}
|
||||
|
||||
organizationMember := organizationMember
|
||||
user, _ := q.getUserByIDNoLock(organizationMember.UserID)
|
||||
tmp = append(tmp, database.OrganizationMembersRow{
|
||||
OrganizationMember: organizationMember,
|
||||
Username: user.Username,
|
||||
})
|
||||
}
|
||||
return tmp, nil
|
||||
}
|
||||
|
||||
func (q *FakeQuerier) ReduceWorkspaceAgentShareLevelToAuthenticatedByTemplate(_ context.Context, templateID uuid.UUID) error {
|
||||
err := validateDatabaseType(templateID)
|
||||
if err != nil {
|
||||
|
||||
@@ -760,20 +760,6 @@ func (m metricsStore) GetOrganizationIDsByMemberIDs(ctx context.Context, ids []u
|
||||
return organizations, err
|
||||
}
|
||||
|
||||
func (m metricsStore) GetOrganizationMemberByUserID(ctx context.Context, arg database.GetOrganizationMemberByUserIDParams) (database.OrganizationMember, error) {
|
||||
start := time.Now()
|
||||
member, err := m.s.GetOrganizationMemberByUserID(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("GetOrganizationMemberByUserID").Observe(time.Since(start).Seconds())
|
||||
return member, err
|
||||
}
|
||||
|
||||
func (m metricsStore) GetOrganizationMembershipsByUserID(ctx context.Context, userID uuid.UUID) ([]database.OrganizationMember, error) {
|
||||
start := time.Now()
|
||||
memberships, err := m.s.GetOrganizationMembershipsByUserID(ctx, userID)
|
||||
m.queryLatencies.WithLabelValues("GetOrganizationMembershipsByUserID").Observe(time.Since(start).Seconds())
|
||||
return memberships, err
|
||||
}
|
||||
|
||||
func (m metricsStore) GetOrganizations(ctx context.Context) ([]database.Organization, error) {
|
||||
start := time.Now()
|
||||
organizations, err := m.s.GetOrganizations(ctx)
|
||||
@@ -1747,6 +1733,13 @@ func (m metricsStore) ListWorkspaceAgentPortShares(ctx context.Context, workspac
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m metricsStore) OrganizationMembers(ctx context.Context, arg database.OrganizationMembersParams) ([]database.OrganizationMembersRow, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.OrganizationMembers(ctx, arg)
|
||||
m.queryLatencies.WithLabelValues("OrganizationMembers").Observe(time.Since(start).Seconds())
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m metricsStore) ReduceWorkspaceAgentShareLevelToAuthenticatedByTemplate(ctx context.Context, templateID uuid.UUID) error {
|
||||
start := time.Now()
|
||||
r0 := m.s.ReduceWorkspaceAgentShareLevelToAuthenticatedByTemplate(ctx, templateID)
|
||||
|
||||
@@ -1514,36 +1514,6 @@ func (mr *MockStoreMockRecorder) GetOrganizationIDsByMemberIDs(arg0, arg1 any) *
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOrganizationIDsByMemberIDs", reflect.TypeOf((*MockStore)(nil).GetOrganizationIDsByMemberIDs), arg0, arg1)
|
||||
}
|
||||
|
||||
// GetOrganizationMemberByUserID mocks base method.
|
||||
func (m *MockStore) GetOrganizationMemberByUserID(arg0 context.Context, arg1 database.GetOrganizationMemberByUserIDParams) (database.OrganizationMember, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetOrganizationMemberByUserID", arg0, arg1)
|
||||
ret0, _ := ret[0].(database.OrganizationMember)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetOrganizationMemberByUserID indicates an expected call of GetOrganizationMemberByUserID.
|
||||
func (mr *MockStoreMockRecorder) GetOrganizationMemberByUserID(arg0, arg1 any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOrganizationMemberByUserID", reflect.TypeOf((*MockStore)(nil).GetOrganizationMemberByUserID), arg0, arg1)
|
||||
}
|
||||
|
||||
// GetOrganizationMembershipsByUserID mocks base method.
|
||||
func (m *MockStore) GetOrganizationMembershipsByUserID(arg0 context.Context, arg1 uuid.UUID) ([]database.OrganizationMember, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetOrganizationMembershipsByUserID", arg0, arg1)
|
||||
ret0, _ := ret[0].([]database.OrganizationMember)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetOrganizationMembershipsByUserID indicates an expected call of GetOrganizationMembershipsByUserID.
|
||||
func (mr *MockStoreMockRecorder) GetOrganizationMembershipsByUserID(arg0, arg1 any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOrganizationMembershipsByUserID", reflect.TypeOf((*MockStore)(nil).GetOrganizationMembershipsByUserID), arg0, arg1)
|
||||
}
|
||||
|
||||
// GetOrganizations mocks base method.
|
||||
func (m *MockStore) GetOrganizations(arg0 context.Context) ([]database.Organization, error) {
|
||||
m.ctrl.T.Helper()
|
||||
@@ -3661,6 +3631,21 @@ func (mr *MockStoreMockRecorder) ListWorkspaceAgentPortShares(arg0, arg1 any) *g
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListWorkspaceAgentPortShares", reflect.TypeOf((*MockStore)(nil).ListWorkspaceAgentPortShares), arg0, arg1)
|
||||
}
|
||||
|
||||
// OrganizationMembers mocks base method.
|
||||
func (m *MockStore) OrganizationMembers(arg0 context.Context, arg1 database.OrganizationMembersParams) ([]database.OrganizationMembersRow, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "OrganizationMembers", arg0, arg1)
|
||||
ret0, _ := ret[0].([]database.OrganizationMembersRow)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// OrganizationMembers indicates an expected call of OrganizationMembers.
|
||||
func (mr *MockStoreMockRecorder) OrganizationMembers(arg0, arg1 any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OrganizationMembers", reflect.TypeOf((*MockStore)(nil).OrganizationMembers), arg0, arg1)
|
||||
}
|
||||
|
||||
// Ping mocks base method.
|
||||
func (m *MockStore) Ping(arg0 context.Context) (time.Duration, error) {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
@@ -179,6 +179,10 @@ func (m OrganizationMember) RBACObject() rbac.Object {
|
||||
WithOwner(m.UserID.String())
|
||||
}
|
||||
|
||||
func (m OrganizationMembersRow) RBACObject() rbac.Object {
|
||||
return m.OrganizationMember.RBACObject()
|
||||
}
|
||||
|
||||
func (m GetOrganizationIDsByMemberIDsRow) RBACObject() rbac.Object {
|
||||
// TODO: This feels incorrect as we are really returning a list of orgmembers.
|
||||
// This return type should be refactored to return a list of orgmembers, not this
|
||||
|
||||
@@ -2,6 +2,7 @@ package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
@@ -17,6 +18,29 @@ const (
|
||||
authorizedQueryPlaceholder = "-- @authorize_filter"
|
||||
)
|
||||
|
||||
// ExpectOne can be used to convert a ':many:' query into a ':one'
|
||||
// query. To reduce the quantity of SQL queries, a :many with a filter is used.
|
||||
// These filters sometimes are expected to return just 1 row.
|
||||
//
|
||||
// A :many query will never return a sql.ErrNoRows, but a :one does.
|
||||
// This function will correct the error for the empty set.
|
||||
func ExpectOne[T any](ret []T, err error) (T, error) {
|
||||
var empty T
|
||||
if err != nil {
|
||||
return empty, err
|
||||
}
|
||||
|
||||
if len(ret) == 0 {
|
||||
return empty, sql.ErrNoRows
|
||||
}
|
||||
|
||||
if len(ret) > 1 {
|
||||
return empty, xerrors.Errorf("too many rows returned, expected 1")
|
||||
}
|
||||
|
||||
return ret[0], nil
|
||||
}
|
||||
|
||||
// customQuerier encompasses all non-generated queries.
|
||||
// It provides a flexible way to write queries for cases
|
||||
// where sqlc proves inadequate.
|
||||
|
||||
@@ -151,8 +151,6 @@ type sqlcQuerier interface {
|
||||
GetOrganizationByID(ctx context.Context, id uuid.UUID) (Organization, error)
|
||||
GetOrganizationByName(ctx context.Context, name string) (Organization, error)
|
||||
GetOrganizationIDsByMemberIDs(ctx context.Context, ids []uuid.UUID) ([]GetOrganizationIDsByMemberIDsRow, error)
|
||||
GetOrganizationMemberByUserID(ctx context.Context, arg GetOrganizationMemberByUserIDParams) (OrganizationMember, error)
|
||||
GetOrganizationMembershipsByUserID(ctx context.Context, userID uuid.UUID) ([]OrganizationMember, error)
|
||||
GetOrganizations(ctx context.Context) ([]Organization, error)
|
||||
GetOrganizationsByUserID(ctx context.Context, userID uuid.UUID) ([]Organization, error)
|
||||
GetParameterSchemasByJobID(ctx context.Context, jobID uuid.UUID) ([]ParameterSchema, error)
|
||||
@@ -349,6 +347,11 @@ type sqlcQuerier interface {
|
||||
InsertWorkspaceResource(ctx context.Context, arg InsertWorkspaceResourceParams) (WorkspaceResource, error)
|
||||
InsertWorkspaceResourceMetadata(ctx context.Context, arg InsertWorkspaceResourceMetadataParams) ([]WorkspaceResourceMetadatum, error)
|
||||
ListWorkspaceAgentPortShares(ctx context.Context, workspaceID uuid.UUID) ([]WorkspaceAgentPortShare, error)
|
||||
// Arguments are optional with uuid.Nil to ignore.
|
||||
// - Use just 'organization_id' to get all members of an org
|
||||
// - Use just 'user_id' to get all orgs a user is a member of
|
||||
// - Use both to get a specific org member row
|
||||
OrganizationMembers(ctx context.Context, arg OrganizationMembersParams) ([]OrganizationMembersRow, error)
|
||||
ReduceWorkspaceAgentShareLevelToAuthenticatedByTemplate(ctx context.Context, templateID uuid.UUID) error
|
||||
RegisterWorkspaceProxy(ctx context.Context, arg RegisterWorkspaceProxyParams) (WorkspaceProxy, error)
|
||||
RemoveUserFromAllGroups(ctx context.Context, userID uuid.UUID) error
|
||||
|
||||
@@ -903,6 +903,42 @@ func TestArchiveVersions(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
func TestExpectOne(t *testing.T) {
|
||||
t.Parallel()
|
||||
if testing.Short() {
|
||||
t.SkipNow()
|
||||
}
|
||||
|
||||
t.Run("ErrNoRows", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
sqlDB := testSQLDB(t)
|
||||
err := migrations.Up(sqlDB)
|
||||
require.NoError(t, err)
|
||||
db := database.New(sqlDB)
|
||||
ctx := context.Background()
|
||||
|
||||
_, err = database.ExpectOne(db.GetUsers(ctx, database.GetUsersParams{}))
|
||||
require.ErrorIs(t, err, sql.ErrNoRows)
|
||||
})
|
||||
|
||||
t.Run("TooMany", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
sqlDB := testSQLDB(t)
|
||||
err := migrations.Up(sqlDB)
|
||||
require.NoError(t, err)
|
||||
db := database.New(sqlDB)
|
||||
ctx := context.Background()
|
||||
|
||||
// Create 2 organizations so the query returns >1
|
||||
dbgen.Organization(t, db, database.Organization{})
|
||||
dbgen.Organization(t, db, database.Organization{})
|
||||
|
||||
// Organizations is an easy table without foreign key dependencies
|
||||
_, err = database.ExpectOne(db.GetOrganizations(ctx))
|
||||
require.ErrorContains(t, err, "too many rows returned")
|
||||
})
|
||||
}
|
||||
|
||||
func requireUsersMatch(t testing.TB, expected []database.User, found []database.GetUsersRow, msg string) {
|
||||
t.Helper()
|
||||
require.ElementsMatch(t, expected, database.ConvertUserRows(found), msg)
|
||||
|
||||
@@ -3795,74 +3795,6 @@ func (q *sqlQuerier) GetOrganizationIDsByMemberIDs(ctx context.Context, ids []uu
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const getOrganizationMemberByUserID = `-- name: GetOrganizationMemberByUserID :one
|
||||
SELECT
|
||||
user_id, organization_id, created_at, updated_at, roles
|
||||
FROM
|
||||
organization_members
|
||||
WHERE
|
||||
organization_id = $1
|
||||
AND user_id = $2
|
||||
LIMIT
|
||||
1
|
||||
`
|
||||
|
||||
type GetOrganizationMemberByUserIDParams struct {
|
||||
OrganizationID uuid.UUID `db:"organization_id" json:"organization_id"`
|
||||
UserID uuid.UUID `db:"user_id" json:"user_id"`
|
||||
}
|
||||
|
||||
func (q *sqlQuerier) GetOrganizationMemberByUserID(ctx context.Context, arg GetOrganizationMemberByUserIDParams) (OrganizationMember, error) {
|
||||
row := q.db.QueryRowContext(ctx, getOrganizationMemberByUserID, arg.OrganizationID, arg.UserID)
|
||||
var i OrganizationMember
|
||||
err := row.Scan(
|
||||
&i.UserID,
|
||||
&i.OrganizationID,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
pq.Array(&i.Roles),
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const getOrganizationMembershipsByUserID = `-- name: GetOrganizationMembershipsByUserID :many
|
||||
SELECT
|
||||
user_id, organization_id, created_at, updated_at, roles
|
||||
FROM
|
||||
organization_members
|
||||
WHERE
|
||||
user_id = $1
|
||||
`
|
||||
|
||||
func (q *sqlQuerier) GetOrganizationMembershipsByUserID(ctx context.Context, userID uuid.UUID) ([]OrganizationMember, error) {
|
||||
rows, err := q.db.QueryContext(ctx, getOrganizationMembershipsByUserID, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []OrganizationMember
|
||||
for rows.Next() {
|
||||
var i OrganizationMember
|
||||
if err := rows.Scan(
|
||||
&i.UserID,
|
||||
&i.OrganizationID,
|
||||
&i.CreatedAt,
|
||||
&i.UpdatedAt,
|
||||
pq.Array(&i.Roles),
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const insertOrganizationMember = `-- name: InsertOrganizationMember :one
|
||||
INSERT INTO
|
||||
organization_members (
|
||||
@@ -3903,6 +3835,73 @@ func (q *sqlQuerier) InsertOrganizationMember(ctx context.Context, arg InsertOrg
|
||||
return i, err
|
||||
}
|
||||
|
||||
const organizationMembers = `-- name: OrganizationMembers :many
|
||||
SELECT
|
||||
organization_members.user_id, organization_members.organization_id, organization_members.created_at, organization_members.updated_at, organization_members.roles,
|
||||
users.username
|
||||
FROM
|
||||
organization_members
|
||||
INNER JOIN
|
||||
users ON organization_members.user_id = users.id
|
||||
WHERE
|
||||
-- Filter by organization id
|
||||
CASE
|
||||
WHEN $1 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
organization_id = $1
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by user id
|
||||
AND CASE
|
||||
WHEN $2 :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
user_id = $2
|
||||
ELSE true
|
||||
END
|
||||
`
|
||||
|
||||
type OrganizationMembersParams struct {
|
||||
OrganizationID uuid.UUID `db:"organization_id" json:"organization_id"`
|
||||
UserID uuid.UUID `db:"user_id" json:"user_id"`
|
||||
}
|
||||
|
||||
type OrganizationMembersRow struct {
|
||||
OrganizationMember OrganizationMember `db:"organization_member" json:"organization_member"`
|
||||
Username string `db:"username" json:"username"`
|
||||
}
|
||||
|
||||
// Arguments are optional with uuid.Nil to ignore.
|
||||
// - Use just 'organization_id' to get all members of an org
|
||||
// - Use just 'user_id' to get all orgs a user is a member of
|
||||
// - Use both to get a specific org member row
|
||||
func (q *sqlQuerier) OrganizationMembers(ctx context.Context, arg OrganizationMembersParams) ([]OrganizationMembersRow, error) {
|
||||
rows, err := q.db.QueryContext(ctx, organizationMembers, arg.OrganizationID, arg.UserID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []OrganizationMembersRow
|
||||
for rows.Next() {
|
||||
var i OrganizationMembersRow
|
||||
if err := rows.Scan(
|
||||
&i.OrganizationMember.UserID,
|
||||
&i.OrganizationMember.OrganizationID,
|
||||
&i.OrganizationMember.CreatedAt,
|
||||
&i.OrganizationMember.UpdatedAt,
|
||||
pq.Array(&i.OrganizationMember.Roles),
|
||||
&i.Username,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const updateMemberRoles = `-- name: UpdateMemberRoles :one
|
||||
UPDATE
|
||||
organization_members
|
||||
|
||||
@@ -1,13 +1,28 @@
|
||||
-- name: GetOrganizationMemberByUserID :one
|
||||
-- name: OrganizationMembers :many
|
||||
-- Arguments are optional with uuid.Nil to ignore.
|
||||
-- - Use just 'organization_id' to get all members of an org
|
||||
-- - Use just 'user_id' to get all orgs a user is a member of
|
||||
-- - Use both to get a specific org member row
|
||||
SELECT
|
||||
*
|
||||
sqlc.embed(organization_members),
|
||||
users.username
|
||||
FROM
|
||||
organization_members
|
||||
INNER JOIN
|
||||
users ON organization_members.user_id = users.id
|
||||
WHERE
|
||||
organization_id = $1
|
||||
AND user_id = $2
|
||||
LIMIT
|
||||
1;
|
||||
-- Filter by organization id
|
||||
CASE
|
||||
WHEN @organization_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
organization_id = @organization_id
|
||||
ELSE true
|
||||
END
|
||||
-- Filter by user id
|
||||
AND CASE
|
||||
WHEN @user_id :: uuid != '00000000-0000-0000-0000-000000000000'::uuid THEN
|
||||
user_id = @user_id
|
||||
ELSE true
|
||||
END;
|
||||
|
||||
-- name: InsertOrganizationMember :one
|
||||
INSERT INTO
|
||||
@@ -22,14 +37,6 @@ VALUES
|
||||
($1, $2, $3, $4, $5) RETURNING *;
|
||||
|
||||
|
||||
-- name: GetOrganizationMembershipsByUserID :many
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
organization_members
|
||||
WHERE
|
||||
user_id = $1;
|
||||
|
||||
-- name: GetOrganizationIDsByMemberIDs :many
|
||||
SELECT
|
||||
user_id, array_agg(organization_id) :: uuid [ ] AS "organization_IDs"
|
||||
|
||||
@@ -124,10 +124,10 @@ func ExtractOrganizationMemberParam(db database.Store) func(http.Handler) http.H
|
||||
}
|
||||
organization := OrganizationParam(r)
|
||||
|
||||
organizationMember, err := db.GetOrganizationMemberByUserID(ctx, database.GetOrganizationMemberByUserIDParams{
|
||||
organizationMember, err := database.ExpectOne(db.OrganizationMembers(ctx, database.OrganizationMembersParams{
|
||||
OrganizationID: organization.ID,
|
||||
UserID: user.ID,
|
||||
})
|
||||
}))
|
||||
if httpapi.Is404Error(err) {
|
||||
httpapi.ResourceNotFound(rw)
|
||||
return
|
||||
@@ -141,7 +141,7 @@ func ExtractOrganizationMemberParam(db database.Store) func(http.Handler) http.H
|
||||
}
|
||||
|
||||
ctx = context.WithValue(ctx, organizationMemberParamContextKey{}, OrganizationMember{
|
||||
OrganizationMember: organizationMember,
|
||||
OrganizationMember: organizationMember.OrganizationMember,
|
||||
// Here we're making two exceptions to the rule about not leaking data about the user
|
||||
// to the API handler, which is to include the username and avatar URL.
|
||||
// If the caller has permission to read the OrganizationMember, then we're explicitly
|
||||
|
||||
+6
-3
@@ -1518,15 +1518,18 @@ func (api *API) oauthLogin(r *http.Request, params *oauthLoginParams) ([]*http.C
|
||||
}
|
||||
|
||||
//nolint:gocritic // No user present in the context.
|
||||
memberships, err := tx.GetOrganizationMembershipsByUserID(dbauthz.AsSystemRestricted(ctx), user.ID)
|
||||
memberships, err := tx.OrganizationMembers(dbauthz.AsSystemRestricted(ctx), database.OrganizationMembersParams{
|
||||
UserID: user.ID,
|
||||
OrganizationID: uuid.Nil,
|
||||
})
|
||||
if err != nil {
|
||||
return xerrors.Errorf("get organization memberships: %w", err)
|
||||
}
|
||||
|
||||
// If the user is not in the default organization, then we can't assign groups.
|
||||
// A user cannot be in groups to an org they are not a member of.
|
||||
if !slices.ContainsFunc(memberships, func(member database.OrganizationMember) bool {
|
||||
return member.OrganizationID == defaultOrganization.ID
|
||||
if !slices.ContainsFunc(memberships, func(member database.OrganizationMembersRow) bool {
|
||||
return member.OrganizationMember.OrganizationID == defaultOrganization.ID
|
||||
}) {
|
||||
return xerrors.Errorf("user %s is not a member of the default organization, cannot assign to groups in the org", user.ID)
|
||||
}
|
||||
|
||||
+6
-2
@@ -1027,12 +1027,16 @@ func (api *API) userRoles(rw http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
// TODO: Replace this with "GetAuthorizationUserRoles"
|
||||
resp := codersdk.UserRoles{
|
||||
Roles: user.RBACRoles,
|
||||
OrganizationRoles: make(map[uuid.UUID][]string),
|
||||
}
|
||||
|
||||
memberships, err := api.Database.GetOrganizationMembershipsByUserID(ctx, user.ID)
|
||||
memberships, err := api.Database.OrganizationMembers(ctx, database.OrganizationMembersParams{
|
||||
UserID: user.ID,
|
||||
OrganizationID: uuid.Nil,
|
||||
})
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error fetching user's organization memberships.",
|
||||
@@ -1042,7 +1046,7 @@ func (api *API) userRoles(rw http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
for _, mem := range memberships {
|
||||
resp.OrganizationRoles[mem.OrganizationID] = mem.Roles
|
||||
resp.OrganizationRoles[mem.OrganizationMember.OrganizationID] = mem.OrganizationMember.Roles
|
||||
}
|
||||
|
||||
httpapi.Write(ctx, rw, http.StatusOK, resp)
|
||||
|
||||
@@ -166,10 +166,10 @@ func (api *API) patchGroup(rw http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
// TODO: It would be nice to enforce this at the schema level
|
||||
// but unfortunately our org_members table does not have an ID.
|
||||
_, err := api.Database.GetOrganizationMemberByUserID(ctx, database.GetOrganizationMemberByUserIDParams{
|
||||
_, err := database.ExpectOne(api.Database.OrganizationMembers(ctx, database.OrganizationMembersParams{
|
||||
OrganizationID: group.OrganizationID,
|
||||
UserID: uuid.MustParse(id),
|
||||
})
|
||||
}))
|
||||
if xerrors.Is(err, sql.ErrNoRows) {
|
||||
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
|
||||
Message: fmt.Sprintf("User %q must be a member of organization %q", id, group.ID),
|
||||
|
||||
Reference in New Issue
Block a user