mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
feat: get org scoped provisioners (#13953)
This commit is contained in:
@@ -1627,6 +1627,10 @@ func (q *querier) GetProvisionerDaemons(ctx context.Context) ([]database.Provisi
|
||||
return fetchWithPostFilter(q.auth, policy.ActionRead, fetch)(ctx, nil)
|
||||
}
|
||||
|
||||
func (q *querier) GetProvisionerDaemonsByOrganization(ctx context.Context, organizationID uuid.UUID) ([]database.ProvisionerDaemon, error) {
|
||||
return fetchWithPostFilter(q.auth, policy.ActionRead, q.db.GetProvisionerDaemonsByOrganization)(ctx, organizationID)
|
||||
}
|
||||
|
||||
func (q *querier) GetProvisionerJobByID(ctx context.Context, id uuid.UUID) (database.ProvisionerJob, error) {
|
||||
job, err := q.db.GetProvisionerJobByID(ctx, id)
|
||||
if err != nil {
|
||||
@@ -3727,7 +3731,7 @@ func (q *querier) UpsertOAuthSigningKey(ctx context.Context, value string) error
|
||||
}
|
||||
|
||||
func (q *querier) UpsertProvisionerDaemon(ctx context.Context, arg database.UpsertProvisionerDaemonParams) (database.ProvisionerDaemon, error) {
|
||||
res := rbac.ResourceProvisionerDaemon.All()
|
||||
res := rbac.ResourceProvisionerDaemon.InOrg(arg.OrganizationID)
|
||||
if arg.Tags[provisionersdk.TagScope] == provisionersdk.ScopeUser {
|
||||
res.Owner = arg.Tags[provisionersdk.TagOwner]
|
||||
}
|
||||
|
||||
@@ -1863,6 +1863,19 @@ func (s *MethodTestSuite) TestExtraMethods() {
|
||||
s.NoError(err, "insert provisioner daemon")
|
||||
check.Args().Asserts(d, policy.ActionRead)
|
||||
}))
|
||||
s.Run("GetProvisionerDaemonsByOrganization", s.Subtest(func(db database.Store, check *expects) {
|
||||
org := dbgen.Organization(s.T(), db, database.Organization{})
|
||||
d, err := db.UpsertProvisionerDaemon(context.Background(), database.UpsertProvisionerDaemonParams{
|
||||
OrganizationID: org.ID,
|
||||
Tags: database.StringMap(map[string]string{
|
||||
provisionersdk.TagScope: provisionersdk.ScopeOrganization,
|
||||
}),
|
||||
})
|
||||
s.NoError(err, "insert provisioner daemon")
|
||||
ds, err := db.GetProvisionerDaemonsByOrganization(context.Background(), org.ID)
|
||||
s.NoError(err, "get provisioner daemon by org")
|
||||
check.Args(org.ID).Asserts(d, policy.ActionRead).Returns(ds)
|
||||
}))
|
||||
s.Run("DeleteOldProvisionerDaemons", s.Subtest(func(db database.Store, check *expects) {
|
||||
_, err := db.UpsertProvisionerDaemon(context.Background(), database.UpsertProvisionerDaemonParams{
|
||||
Tags: database.StringMap(map[string]string{
|
||||
@@ -2328,13 +2341,16 @@ func (s *MethodTestSuite) TestSystemFunctions() {
|
||||
}).Asserts( /*rbac.ResourceSystem, policy.ActionCreate*/ )
|
||||
}))
|
||||
s.Run("UpsertProvisionerDaemon", s.Subtest(func(db database.Store, check *expects) {
|
||||
pd := rbac.ResourceProvisionerDaemon.All()
|
||||
org := dbgen.Organization(s.T(), db, database.Organization{})
|
||||
pd := rbac.ResourceProvisionerDaemon.InOrg(org.ID)
|
||||
check.Args(database.UpsertProvisionerDaemonParams{
|
||||
OrganizationID: org.ID,
|
||||
Tags: database.StringMap(map[string]string{
|
||||
provisionersdk.TagScope: provisionersdk.ScopeOrganization,
|
||||
}),
|
||||
}).Asserts(pd, policy.ActionCreate)
|
||||
check.Args(database.UpsertProvisionerDaemonParams{
|
||||
OrganizationID: org.ID,
|
||||
Tags: database.StringMap(map[string]string{
|
||||
provisionersdk.TagScope: provisionersdk.ScopeUser,
|
||||
provisionersdk.TagOwner: "11111111-1111-1111-1111-111111111111",
|
||||
|
||||
@@ -3140,6 +3140,21 @@ func (q *FakeQuerier) GetProvisionerDaemons(_ context.Context) ([]database.Provi
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (q *FakeQuerier) GetProvisionerDaemonsByOrganization(_ context.Context, organizationID uuid.UUID) ([]database.ProvisionerDaemon, error) {
|
||||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
|
||||
daemons := make([]database.ProvisionerDaemon, 0)
|
||||
for _, daemon := range q.provisionerDaemons {
|
||||
if daemon.OrganizationID == organizationID {
|
||||
daemon.Tags = maps.Clone(daemon.Tags)
|
||||
daemons = append(daemons, daemon)
|
||||
}
|
||||
}
|
||||
|
||||
return daemons, nil
|
||||
}
|
||||
|
||||
func (q *FakeQuerier) GetProvisionerJobByID(ctx context.Context, id uuid.UUID) (database.ProvisionerJob, error) {
|
||||
q.mutex.RLock()
|
||||
defer q.mutex.RUnlock()
|
||||
|
||||
@@ -879,6 +879,13 @@ func (m metricsStore) GetProvisionerDaemons(ctx context.Context) ([]database.Pro
|
||||
return daemons, err
|
||||
}
|
||||
|
||||
func (m metricsStore) GetProvisionerDaemonsByOrganization(ctx context.Context, organizationID uuid.UUID) ([]database.ProvisionerDaemon, error) {
|
||||
start := time.Now()
|
||||
r0, r1 := m.s.GetProvisionerDaemonsByOrganization(ctx, organizationID)
|
||||
m.queryLatencies.WithLabelValues("GetProvisionerDaemonsByOrganization").Observe(time.Since(start).Seconds())
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
func (m metricsStore) GetProvisionerJobByID(ctx context.Context, id uuid.UUID) (database.ProvisionerJob, error) {
|
||||
start := time.Now()
|
||||
job, err := m.s.GetProvisionerJobByID(ctx, id)
|
||||
|
||||
@@ -1765,6 +1765,21 @@ func (mr *MockStoreMockRecorder) GetProvisionerDaemons(arg0 any) *gomock.Call {
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProvisionerDaemons", reflect.TypeOf((*MockStore)(nil).GetProvisionerDaemons), arg0)
|
||||
}
|
||||
|
||||
// GetProvisionerDaemonsByOrganization mocks base method.
|
||||
func (m *MockStore) GetProvisionerDaemonsByOrganization(arg0 context.Context, arg1 uuid.UUID) ([]database.ProvisionerDaemon, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetProvisionerDaemonsByOrganization", arg0, arg1)
|
||||
ret0, _ := ret[0].([]database.ProvisionerDaemon)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetProvisionerDaemonsByOrganization indicates an expected call of GetProvisionerDaemonsByOrganization.
|
||||
func (mr *MockStoreMockRecorder) GetProvisionerDaemonsByOrganization(arg0, arg1 any) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetProvisionerDaemonsByOrganization", reflect.TypeOf((*MockStore)(nil).GetProvisionerDaemonsByOrganization), arg0, arg1)
|
||||
}
|
||||
|
||||
// GetProvisionerJobByID mocks base method.
|
||||
func (m *MockStore) GetProvisionerJobByID(arg0 context.Context, arg1 uuid.UUID) (database.ProvisionerJob, error) {
|
||||
m.ctrl.T.Helper()
|
||||
|
||||
@@ -209,7 +209,9 @@ func (o Organization) RBACObject() rbac.Object {
|
||||
}
|
||||
|
||||
func (p ProvisionerDaemon) RBACObject() rbac.Object {
|
||||
return rbac.ResourceProvisionerDaemon.WithID(p.ID)
|
||||
return rbac.ResourceProvisionerDaemon.
|
||||
WithID(p.ID).
|
||||
InOrg(p.OrganizationID)
|
||||
}
|
||||
|
||||
func (p ProvisionerKey) RBACObject() rbac.Object {
|
||||
|
||||
@@ -181,6 +181,7 @@ type sqlcQuerier interface {
|
||||
GetParameterSchemasByJobID(ctx context.Context, jobID uuid.UUID) ([]ParameterSchema, error)
|
||||
GetPreviousTemplateVersion(ctx context.Context, arg GetPreviousTemplateVersionParams) (TemplateVersion, error)
|
||||
GetProvisionerDaemons(ctx context.Context) ([]ProvisionerDaemon, error)
|
||||
GetProvisionerDaemonsByOrganization(ctx context.Context, organizationID uuid.UUID) ([]ProvisionerDaemon, error)
|
||||
GetProvisionerJobByID(ctx context.Context, id uuid.UUID) (ProvisionerJob, error)
|
||||
GetProvisionerJobsByIDs(ctx context.Context, ids []uuid.UUID) ([]ProvisionerJob, error)
|
||||
GetProvisionerJobsByIDsWithQueuePosition(ctx context.Context, ids []uuid.UUID) ([]GetProvisionerJobsByIDsWithQueuePositionRow, error)
|
||||
|
||||
@@ -4770,6 +4770,49 @@ func (q *sqlQuerier) GetProvisionerDaemons(ctx context.Context) ([]ProvisionerDa
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const getProvisionerDaemonsByOrganization = `-- name: GetProvisionerDaemonsByOrganization :many
|
||||
SELECT
|
||||
id, created_at, name, provisioners, replica_id, tags, last_seen_at, version, api_version, organization_id
|
||||
FROM
|
||||
provisioner_daemons
|
||||
WHERE
|
||||
organization_id = $1
|
||||
`
|
||||
|
||||
func (q *sqlQuerier) GetProvisionerDaemonsByOrganization(ctx context.Context, organizationID uuid.UUID) ([]ProvisionerDaemon, error) {
|
||||
rows, err := q.db.QueryContext(ctx, getProvisionerDaemonsByOrganization, organizationID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
var items []ProvisionerDaemon
|
||||
for rows.Next() {
|
||||
var i ProvisionerDaemon
|
||||
if err := rows.Scan(
|
||||
&i.ID,
|
||||
&i.CreatedAt,
|
||||
&i.Name,
|
||||
pq.Array(&i.Provisioners),
|
||||
&i.ReplicaID,
|
||||
&i.Tags,
|
||||
&i.LastSeenAt,
|
||||
&i.Version,
|
||||
&i.APIVersion,
|
||||
&i.OrganizationID,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
items = append(items, i)
|
||||
}
|
||||
if err := rows.Close(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return items, nil
|
||||
}
|
||||
|
||||
const updateProvisionerDaemonLastSeenAt = `-- name: UpdateProvisionerDaemonLastSeenAt :exec
|
||||
UPDATE provisioner_daemons
|
||||
SET
|
||||
|
||||
@@ -4,6 +4,14 @@ SELECT
|
||||
FROM
|
||||
provisioner_daemons;
|
||||
|
||||
-- name: GetProvisionerDaemonsByOrganization :many
|
||||
SELECT
|
||||
*
|
||||
FROM
|
||||
provisioner_daemons
|
||||
WHERE
|
||||
organization_id = @organization_id;
|
||||
|
||||
-- name: DeleteOldProvisionerDaemons :exec
|
||||
-- Delete provisioner daemons that have been created at least a week ago
|
||||
-- and have not connected to coderd since a week.
|
||||
|
||||
@@ -3,7 +3,6 @@ package coderd
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -21,7 +20,6 @@ import (
|
||||
|
||||
"cdr.dev/slog"
|
||||
|
||||
"github.com/coder/coder/v2/coderd"
|
||||
"github.com/coder/coder/v2/coderd/database"
|
||||
"github.com/coder/coder/v2/coderd/database/db2sdk"
|
||||
"github.com/coder/coder/v2/coderd/database/dbauthz"
|
||||
@@ -65,21 +63,9 @@ func (api *API) provisionerDaemonsEnabledMW(next http.Handler) http.Handler {
|
||||
// @Router /organizations/{organization}/provisionerdaemons [get]
|
||||
func (api *API) provisionerDaemons(rw http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
daemons, err := api.Database.GetProvisionerDaemons(ctx)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
err = nil
|
||||
}
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error fetching provisioner daemons.",
|
||||
Detail: err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
if daemons == nil {
|
||||
daemons = []database.ProvisionerDaemon{}
|
||||
}
|
||||
daemons, err = coderd.AuthorizeFilter(api.AGPL.HTTPAuth, r, policy.ActionRead, daemons)
|
||||
org := httpmw.OrganizationParam(r)
|
||||
|
||||
daemons, err := api.Database.GetProvisionerDaemonsByOrganization(ctx, org.ID)
|
||||
if err != nil {
|
||||
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
|
||||
Message: "Internal error fetching provisioner daemons.",
|
||||
@@ -98,7 +84,7 @@ type provisionerDaemonAuth struct {
|
||||
|
||||
// authorize returns mutated tags and true if the given HTTP request is authorized to access the provisioner daemon
|
||||
// protobuf API, and returns nil, false otherwise.
|
||||
func (p *provisionerDaemonAuth) authorize(r *http.Request, tags map[string]string) (map[string]string, bool) {
|
||||
func (p *provisionerDaemonAuth) authorize(r *http.Request, orgID uuid.UUID, tags map[string]string) (map[string]string, bool) {
|
||||
ctx := r.Context()
|
||||
apiKey, ok := httpmw.APIKeyOptional(r)
|
||||
if ok {
|
||||
@@ -109,7 +95,7 @@ func (p *provisionerDaemonAuth) authorize(r *http.Request, tags map[string]strin
|
||||
return tags, true
|
||||
}
|
||||
ua := httpmw.UserAuthorization(r)
|
||||
if err := p.authorizer.Authorize(ctx, ua, policy.ActionCreate, rbac.ResourceProvisionerDaemon); err == nil {
|
||||
if err := p.authorizer.Authorize(ctx, ua, policy.ActionCreate, rbac.ResourceProvisionerDaemon.InOrg(orgID)); err == nil {
|
||||
// User is allowed to create provisioner daemons
|
||||
return tags, true
|
||||
}
|
||||
@@ -185,7 +171,7 @@ func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request)
|
||||
api.Logger.Warn(ctx, "unnamed provisioner daemon")
|
||||
}
|
||||
|
||||
tags, authorized := api.provisionerDaemonAuth.authorize(r, tags)
|
||||
tags, authorized := api.provisionerDaemonAuth.authorize(r, organization.ID, tags)
|
||||
if !authorized {
|
||||
api.Logger.Warn(ctx, "unauthorized provisioner daemon serve request", slog.F("tags", tags))
|
||||
httpapi.Write(ctx, rw, http.StatusForbidden,
|
||||
|
||||
@@ -211,10 +211,7 @@ func TestProvisionerDaemonServe(t *testing.T) {
|
||||
provisionersdk.TagScope: provisionersdk.ScopeOrganization,
|
||||
},
|
||||
})
|
||||
require.Error(t, err)
|
||||
var apiError *codersdk.Error
|
||||
require.ErrorAs(t, err, &apiError)
|
||||
require.Equal(t, http.StatusForbidden, apiError.StatusCode())
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("OrganizationNoPerms", func(t *testing.T) {
|
||||
@@ -556,3 +553,40 @@ func TestProvisionerDaemonServe(t *testing.T) {
|
||||
require.Len(t, daemons, 0)
|
||||
})
|
||||
}
|
||||
|
||||
func TestGetProvisionerDaemons(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("OK", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
client, _ := coderdenttest.New(t, &coderdenttest.Options{LicenseOptions: &coderdenttest.LicenseOptions{
|
||||
Features: license.Features{
|
||||
codersdk.FeatureExternalProvisionerDaemons: 1,
|
||||
},
|
||||
}})
|
||||
org := coderdtest.CreateOrganization(t, client, coderdtest.CreateOrganizationOptions{})
|
||||
orgAdmin, _ := coderdtest.CreateAnotherUser(t, client, org.ID, rbac.ScopedRoleOrgAdmin(org.ID))
|
||||
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
|
||||
defer cancel()
|
||||
daemonName := testutil.MustRandString(t, 63)
|
||||
srv, err := orgAdmin.ServeProvisionerDaemon(ctx, codersdk.ServeProvisionerDaemonRequest{
|
||||
ID: uuid.New(),
|
||||
Name: daemonName,
|
||||
Organization: org.ID,
|
||||
Provisioners: []codersdk.ProvisionerType{
|
||||
codersdk.ProvisionerTypeEcho,
|
||||
},
|
||||
Tags: map[string]string{},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
srv.DRPCConn().Close()
|
||||
|
||||
daemons, err := orgAdmin.OrganizationProvisionerDaemons(ctx, org.ID)
|
||||
require.NoError(t, err)
|
||||
if assert.Len(t, daemons, 1) {
|
||||
assert.Equal(t, daemonName, daemons[0].Name)
|
||||
assert.Equal(t, buildinfo.Version(), daemons[0].Version)
|
||||
assert.Equal(t, proto.CurrentVersion.String(), daemons[0].APIVersion)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user