feat: get org scoped provisioners (#13953)

This commit is contained in:
Garrett Delfosse
2024-07-23 10:56:46 -04:00
committed by GitHub
parent 695afb80e6
commit 0a07c7e554
11 changed files with 158 additions and 27 deletions
+5 -1
View File
@@ -1627,6 +1627,10 @@ func (q *querier) GetProvisionerDaemons(ctx context.Context) ([]database.Provisi
return fetchWithPostFilter(q.auth, policy.ActionRead, fetch)(ctx, nil)
}
func (q *querier) GetProvisionerDaemonsByOrganization(ctx context.Context, organizationID uuid.UUID) ([]database.ProvisionerDaemon, error) {
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetProvisionerDaemonsByOrganization)(ctx, organizationID)
}
func (q *querier) GetProvisionerJobByID(ctx context.Context, id uuid.UUID) (database.ProvisionerJob, error) {
job, err := q.db.GetProvisionerJobByID(ctx, id)
if err != nil {
@@ -3727,7 +3731,7 @@ func (q *querier) UpsertOAuthSigningKey(ctx context.Context, value string) error
}
func (q *querier) UpsertProvisionerDaemon(ctx context.Context, arg database.UpsertProvisionerDaemonParams) (database.ProvisionerDaemon, error) {
res := rbac.ResourceProvisionerDaemon.All()
res := rbac.ResourceProvisionerDaemon.InOrg(arg.OrganizationID)
if arg.Tags[provisionersdk.TagScope] == provisionersdk.ScopeUser {
res.Owner = arg.Tags[provisionersdk.TagOwner]
}
+17 -1
View File
@@ -1863,6 +1863,19 @@ func (s *MethodTestSuite) TestExtraMethods() {
s.NoError(err, "insert provisioner daemon")
check.Args().Asserts(d, policy.ActionRead)
}))
s.Run("GetProvisionerDaemonsByOrganization", s.Subtest(func(db database.Store, check *expects) {
org := dbgen.Organization(s.T(), db, database.Organization{})
d, err := db.UpsertProvisionerDaemon(context.Background(), database.UpsertProvisionerDaemonParams{
OrganizationID: org.ID,
Tags: database.StringMap(map[string]string{
provisionersdk.TagScope: provisionersdk.ScopeOrganization,
}),
})
s.NoError(err, "insert provisioner daemon")
ds, err := db.GetProvisionerDaemonsByOrganization(context.Background(), org.ID)
s.NoError(err, "get provisioner daemon by org")
check.Args(org.ID).Asserts(d, policy.ActionRead).Returns(ds)
}))
s.Run("DeleteOldProvisionerDaemons", s.Subtest(func(db database.Store, check *expects) {
_, err := db.UpsertProvisionerDaemon(context.Background(), database.UpsertProvisionerDaemonParams{
Tags: database.StringMap(map[string]string{
@@ -2328,13 +2341,16 @@ func (s *MethodTestSuite) TestSystemFunctions() {
}).Asserts( /*rbac.ResourceSystem, policy.ActionCreate*/ )
}))
s.Run("UpsertProvisionerDaemon", s.Subtest(func(db database.Store, check *expects) {
pd := rbac.ResourceProvisionerDaemon.All()
org := dbgen.Organization(s.T(), db, database.Organization{})
pd := rbac.ResourceProvisionerDaemon.InOrg(org.ID)
check.Args(database.UpsertProvisionerDaemonParams{
OrganizationID: org.ID,
Tags: database.StringMap(map[string]string{
provisionersdk.TagScope: provisionersdk.ScopeOrganization,
}),
}).Asserts(pd, policy.ActionCreate)
check.Args(database.UpsertProvisionerDaemonParams{
OrganizationID: org.ID,
Tags: database.StringMap(map[string]string{
provisionersdk.TagScope: provisionersdk.ScopeUser,
provisionersdk.TagOwner: "11111111-1111-1111-1111-111111111111",
+15
View File
@@ -3140,6 +3140,21 @@ func (q *FakeQuerier) GetProvisionerDaemons(_ context.Context) ([]database.Provi
return out, nil
}
func (q *FakeQuerier) GetProvisionerDaemonsByOrganization(_ context.Context, organizationID uuid.UUID) ([]database.ProvisionerDaemon, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()
daemons := make([]database.ProvisionerDaemon, 0)
for _, daemon := range q.provisionerDaemons {
if daemon.OrganizationID == organizationID {
daemon.Tags = maps.Clone(daemon.Tags)
daemons = append(daemons, daemon)
}
}
return daemons, nil
}
func (q *FakeQuerier) GetProvisionerJobByID(ctx context.Context, id uuid.UUID) (database.ProvisionerJob, error) {
q.mutex.RLock()
defer q.mutex.RUnlock()
+7
View File
@@ -879,6 +879,13 @@ func (m metricsStore) GetProvisionerDaemons(ctx context.Context) ([]database.Pro
return daemons, err
}
func (m metricsStore) GetProvisionerDaemonsByOrganization(ctx context.Context, organizationID uuid.UUID) ([]database.ProvisionerDaemon, error) {
start := time.Now()
r0, r1 := m.s.GetProvisionerDaemonsByOrganization(ctx, organizationID)
m.queryLatencies.WithLabelValues("GetProvisionerDaemonsByOrganization").Observe(time.Since(start).Seconds())
return r0, r1
}
func (m metricsStore) GetProvisionerJobByID(ctx context.Context, id uuid.UUID) (database.ProvisionerJob, error) {
start := time.Now()
job, err := m.s.GetProvisionerJobByID(ctx, id)
+15
View File
@@ -1765,6 +1765,21 @@ func (mr *MockStoreMockRecorder) GetProvisionerDaemons(arg0 any) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProvisionerDaemons", reflect.TypeOf((*MockStore)(nil).GetProvisionerDaemons), arg0)
}
// GetProvisionerDaemonsByOrganization mocks base method.
func (m *MockStore) GetProvisionerDaemonsByOrganization(arg0 context.Context, arg1 uuid.UUID) ([]database.ProvisionerDaemon, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetProvisionerDaemonsByOrganization", arg0, arg1)
ret0, _ := ret[0].([]database.ProvisionerDaemon)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetProvisionerDaemonsByOrganization indicates an expected call of GetProvisionerDaemonsByOrganization.
func (mr *MockStoreMockRecorder) GetProvisionerDaemonsByOrganization(arg0, arg1 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProvisionerDaemonsByOrganization", reflect.TypeOf((*MockStore)(nil).GetProvisionerDaemonsByOrganization), arg0, arg1)
}
// GetProvisionerJobByID mocks base method.
func (m *MockStore) GetProvisionerJobByID(arg0 context.Context, arg1 uuid.UUID) (database.ProvisionerJob, error) {
m.ctrl.T.Helper()
+3 -1
View File
@@ -209,7 +209,9 @@ func (o Organization) RBACObject() rbac.Object {
}
func (p ProvisionerDaemon) RBACObject() rbac.Object {
return rbac.ResourceProvisionerDaemon.WithID(p.ID)
return rbac.ResourceProvisionerDaemon.
WithID(p.ID).
InOrg(p.OrganizationID)
}
func (p ProvisionerKey) RBACObject() rbac.Object {
+1
View File
@@ -181,6 +181,7 @@ type sqlcQuerier interface {
GetParameterSchemasByJobID(ctx context.Context, jobID uuid.UUID) ([]ParameterSchema, error)
GetPreviousTemplateVersion(ctx context.Context, arg GetPreviousTemplateVersionParams) (TemplateVersion, error)
GetProvisionerDaemons(ctx context.Context) ([]ProvisionerDaemon, error)
GetProvisionerDaemonsByOrganization(ctx context.Context, organizationID uuid.UUID) ([]ProvisionerDaemon, error)
GetProvisionerJobByID(ctx context.Context, id uuid.UUID) (ProvisionerJob, error)
GetProvisionerJobsByIDs(ctx context.Context, ids []uuid.UUID) ([]ProvisionerJob, error)
GetProvisionerJobsByIDsWithQueuePosition(ctx context.Context, ids []uuid.UUID) ([]GetProvisionerJobsByIDsWithQueuePositionRow, error)
+43
View File
@@ -4770,6 +4770,49 @@ func (q *sqlQuerier) GetProvisionerDaemons(ctx context.Context) ([]ProvisionerDa
return items, nil
}
const getProvisionerDaemonsByOrganization = `-- name: GetProvisionerDaemonsByOrganization :many
SELECT
id, created_at, name, provisioners, replica_id, tags, last_seen_at, version, api_version, organization_id
FROM
provisioner_daemons
WHERE
organization_id = $1
`
func (q *sqlQuerier) GetProvisionerDaemonsByOrganization(ctx context.Context, organizationID uuid.UUID) ([]ProvisionerDaemon, error) {
rows, err := q.db.QueryContext(ctx, getProvisionerDaemonsByOrganization, organizationID)
if err != nil {
return nil, err
}
defer rows.Close()
var items []ProvisionerDaemon
for rows.Next() {
var i ProvisionerDaemon
if err := rows.Scan(
&i.ID,
&i.CreatedAt,
&i.Name,
pq.Array(&i.Provisioners),
&i.ReplicaID,
&i.Tags,
&i.LastSeenAt,
&i.Version,
&i.APIVersion,
&i.OrganizationID,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}
const updateProvisionerDaemonLastSeenAt = `-- name: UpdateProvisionerDaemonLastSeenAt :exec
UPDATE provisioner_daemons
SET
@@ -4,6 +4,14 @@ SELECT
FROM
provisioner_daemons;
-- name: GetProvisionerDaemonsByOrganization :many
SELECT
*
FROM
provisioner_daemons
WHERE
organization_id = @organization_id;
-- name: DeleteOldProvisionerDaemons :exec
-- Delete provisioner daemons that have been created at least a week ago
-- and have not connected to coderd since a week.
+6 -20
View File
@@ -3,7 +3,6 @@ package coderd
import (
"context"
"database/sql"
"errors"
"fmt"
"io"
"net/http"
@@ -21,7 +20,6 @@ import (
"cdr.dev/slog"
"github.com/coder/coder/v2/coderd"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/db2sdk"
"github.com/coder/coder/v2/coderd/database/dbauthz"
@@ -65,21 +63,9 @@ func (api *API) provisionerDaemonsEnabledMW(next http.Handler) http.Handler {
// @Router /organizations/{organization}/provisionerdaemons [get]
func (api *API) provisionerDaemons(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
daemons, err := api.Database.GetProvisionerDaemons(ctx)
if errors.Is(err, sql.ErrNoRows) {
err = nil
}
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching provisioner daemons.",
Detail: err.Error(),
})
return
}
if daemons == nil {
daemons = []database.ProvisionerDaemon{}
}
daemons, err = coderd.AuthorizeFilter(api.AGPL.HTTPAuth, r, policy.ActionRead, daemons)
org := httpmw.OrganizationParam(r)
daemons, err := api.Database.GetProvisionerDaemonsByOrganization(ctx, org.ID)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching provisioner daemons.",
@@ -98,7 +84,7 @@ type provisionerDaemonAuth struct {
// authorize returns mutated tags and true if the given HTTP request is authorized to access the provisioner daemon
// protobuf API, and returns nil, false otherwise.
func (p *provisionerDaemonAuth) authorize(r *http.Request, tags map[string]string) (map[string]string, bool) {
func (p *provisionerDaemonAuth) authorize(r *http.Request, orgID uuid.UUID, tags map[string]string) (map[string]string, bool) {
ctx := r.Context()
apiKey, ok := httpmw.APIKeyOptional(r)
if ok {
@@ -109,7 +95,7 @@ func (p *provisionerDaemonAuth) authorize(r *http.Request, tags map[string]strin
return tags, true
}
ua := httpmw.UserAuthorization(r)
if err := p.authorizer.Authorize(ctx, ua, policy.ActionCreate, rbac.ResourceProvisionerDaemon); err == nil {
if err := p.authorizer.Authorize(ctx, ua, policy.ActionCreate, rbac.ResourceProvisionerDaemon.InOrg(orgID)); err == nil {
// User is allowed to create provisioner daemons
return tags, true
}
@@ -185,7 +171,7 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
api.Logger.Warn(ctx, "unnamed provisioner daemon")
}
tags, authorized := api.provisionerDaemonAuth.authorize(r, tags)
tags, authorized := api.provisionerDaemonAuth.authorize(r, organization.ID, tags)
if !authorized {
api.Logger.Warn(ctx, "unauthorized provisioner daemon serve request", slog.F("tags", tags))
httpapi.Write(ctx, rw, http.StatusForbidden,
+38 -4
View File
@@ -211,10 +211,7 @@ func TestProvisionerDaemonServe(t *testing.T) {
provisionersdk.TagScope: provisionersdk.ScopeOrganization,
},
})
require.Error(t, err)
var apiError *codersdk.Error
require.ErrorAs(t, err, &apiError)
require.Equal(t, http.StatusForbidden, apiError.StatusCode())
require.NoError(t, err)
})
t.Run("OrganizationNoPerms", func(t *testing.T) {
@@ -556,3 +553,40 @@ func TestProvisionerDaemonServe(t *testing.T) {
require.Len(t, daemons, 0)
})
}
func TestGetProvisionerDaemons(t *testing.T) {
t.Parallel()
t.Run("OK", func(t *testing.T) {
t.Parallel()
client, _ := coderdenttest.New(t, &coderdenttest.Options{LicenseOptions: &coderdenttest.LicenseOptions{
Features: license.Features{
codersdk.FeatureExternalProvisionerDaemons: 1,
},
}})
org := coderdtest.CreateOrganization(t, client, coderdtest.CreateOrganizationOptions{})
orgAdmin, _ := coderdtest.CreateAnotherUser(t, client, org.ID, rbac.ScopedRoleOrgAdmin(org.ID))
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
daemonName := testutil.MustRandString(t, 63)
srv, err := orgAdmin.ServeProvisionerDaemon(ctx, codersdk.ServeProvisionerDaemonRequest{
ID: uuid.New(),
Name: daemonName,
Organization: org.ID,
Provisioners: []codersdk.ProvisionerType{
codersdk.ProvisionerTypeEcho,
},
Tags: map[string]string{},
})
require.NoError(t, err)
srv.DRPCConn().Close()
daemons, err := orgAdmin.OrganizationProvisionerDaemons(ctx, org.ID)
require.NoError(t, err)
if assert.Len(t, daemons, 1) {
assert.Equal(t, daemonName, daemons[0].Name)
assert.Equal(t, buildinfo.Version(), daemons[0].Version)
assert.Equal(t, proto.CurrentVersion.String(), daemons[0].APIVersion)
}
})
}