feat: Add external provisioner daemons (#4935)

* Start to port over provisioner daemons PR

* Move to Enterprise

* Begin adding tests for external registration

* Move provisioner daemons query to enterprise

* Move around provisioner daemons schema

* Add tags to provisioner daemons

* make gen

* Add user local provisioner daemons

* Add provisioner daemons

* Add feature for external daemons

* Add command to start a provisioner daemon

* Add provisioner tags to template push and create

* Rename migration files

* Fix tests

* Fix entitlements test

* PR comments

* Update migration

* Fix FE types
This commit is contained in:
Kyle Carberry
2022-11-16 16:34:06 -06:00
committed by GitHub
parent 66d20cabac
commit b6703b11c6
51 changed files with 1095 additions and 372 deletions
+1
View File
@@ -17,6 +17,7 @@
"codersdk",
"cronstrue",
"databasefake",
"dbtype",
"DERP",
"derphttp",
"derpmap",
+2 -2
View File
@@ -143,7 +143,7 @@ func newConfig() *codersdk.DeploymentConfig {
Name: "Cache Directory",
Usage: "The directory to cache temporary files. If unspecified and $CACHE_DIRECTORY is set, it will be used for compatibility with systemd.",
Flag: "cache-dir",
Default: defaultCacheDir(),
Default: DefaultCacheDir(),
},
InMemoryDatabase: &codersdk.DeploymentConfigField[bool]{
Name: "In Memory Database",
@@ -672,7 +672,7 @@ func formatEnv(key string) string {
return "CODER_" + strings.ToUpper(strings.NewReplacer("-", "_", ".", "_").Replace(key))
}
func defaultCacheDir() string {
func DefaultCacheDir() string {
defaultCacheDir, err := os.UserCacheDir()
if err != nil {
defaultCacheDir = os.TempDir()
+1 -1
View File
@@ -26,7 +26,7 @@ func gitAskpass() *cobra.Command {
RunE: func(cmd *cobra.Command, args []string) error {
ctx := cmd.Context()
ctx, stop := signal.NotifyContext(ctx, interruptSignals...)
ctx, stop := signal.NotifyContext(ctx, InterruptSignals...)
defer stop()
user, host, err := gitauth.ParseAskpass(args[0])
+1 -1
View File
@@ -29,7 +29,7 @@ func gitssh() *cobra.Command {
// Catch interrupt signals to ensure the temporary private
// key file is cleaned up on most cases.
ctx, stop := signal.NotifyContext(ctx, interruptSignals...)
ctx, stop := signal.NotifyContext(ctx, InterruptSignals...)
defer stop()
// Early check so errors are reported immediately.
+2 -2
View File
@@ -108,7 +108,7 @@ func Server(vip *viper.Viper, newAPI func(context.Context, *coderd.Options) (*co
//
// To get out of a graceful shutdown, the user can send
// SIGQUIT with ctrl+\ or SIGKILL with `kill -9`.
notifyCtx, notifyStop := signal.NotifyContext(ctx, interruptSignals...)
notifyCtx, notifyStop := signal.NotifyContext(ctx, InterruptSignals...)
defer notifyStop()
// Clean up idle connections at the end, e.g.
@@ -946,7 +946,7 @@ func newProvisionerDaemon(
return provisionerd.New(func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
// This debounces calls to listen every second. Read the comment
// in provisionerdserver.go to learn more!
return coderAPI.ListenProvisionerDaemon(ctx, time.Second)
return coderAPI.CreateInMemoryProvisionerDaemon(ctx, time.Second)
}, &provisionerd.Options{
Logger: logger,
PollInterval: 500 * time.Millisecond,
+1 -1
View File
@@ -7,7 +7,7 @@ import (
"syscall"
)
var interruptSignals = []os.Signal{
var InterruptSignals = []os.Signal{
os.Interrupt,
syscall.SIGTERM,
syscall.SIGHUP,
+1 -1
View File
@@ -6,4 +6,4 @@ import (
"os"
)
var interruptSignals = []os.Signal{os.Interrupt}
var InterruptSignals = []os.Signal{os.Interrupt}
+31 -9
View File
@@ -24,10 +24,11 @@ import (
func templateCreate() *cobra.Command {
var (
directory string
provisioner string
parameterFile string
defaultTTL time.Duration
directory string
provisioner string
provisionerTags []string
parameterFile string
defaultTTL time.Duration
)
cmd := &cobra.Command{
Use: "create [name]",
@@ -87,12 +88,18 @@ func templateCreate() *cobra.Command {
}
spin.Stop()
tags, err := ParseProvisionerTags(provisionerTags)
if err != nil {
return err
}
job, _, err := createValidTemplateVersion(cmd, createValidTemplateVersionArgs{
Client: client,
Organization: organization,
Provisioner: database.ProvisionerType(provisioner),
FileID: resp.ID,
ParameterFile: parameterFile,
Client: client,
Organization: organization,
Provisioner: database.ProvisionerType(provisioner),
FileID: resp.ID,
ParameterFile: parameterFile,
ProvisionerTags: tags,
})
if err != nil {
return err
@@ -131,6 +138,7 @@ func templateCreate() *cobra.Command {
cmd.Flags().StringVarP(&directory, "directory", "d", currentDirectory, "Specify the directory to create from")
cmd.Flags().StringVarP(&provisioner, "test.provisioner", "", "terraform", "Customize the provisioner backend")
cmd.Flags().StringVarP(&parameterFile, "parameter-file", "", "", "Specify a file path with parameter values.")
cmd.Flags().StringArrayVarP(&provisionerTags, "provisioner-tag", "", []string{}, "Specify a set of tags to target provisioner daemons.")
cmd.Flags().DurationVarP(&defaultTTL, "default-ttl", "", 24*time.Hour, "Specify a default TTL for workspaces created from this template.")
// This is for testing!
err := cmd.Flags().MarkHidden("test.provisioner")
@@ -154,6 +162,7 @@ type createValidTemplateVersionArgs struct {
// before prompting the user. Set to false to always prompt for param
// values.
ReuseParameters bool
ProvisionerTags map[string]string
}
func createValidTemplateVersion(cmd *cobra.Command, args createValidTemplateVersionArgs, parameters ...codersdk.CreateParameterRequest) (*codersdk.TemplateVersion, []codersdk.CreateParameterRequest, error) {
@@ -165,6 +174,7 @@ func createValidTemplateVersion(cmd *cobra.Command, args createValidTemplateVers
FileID: args.FileID,
Provisioner: codersdk.ProvisionerType(args.Provisioner),
ParameterValues: parameters,
ProvisionerTags: args.ProvisionerTags,
}
if args.Template != nil {
req.TemplateID = args.Template.ID
@@ -334,3 +344,15 @@ func prettyDirectoryPath(dir string) string {
}
return pretty
}
func ParseProvisionerTags(rawTags []string) (map[string]string, error) {
tags := map[string]string{}
for _, rawTag := range rawTags {
parts := strings.SplitN(rawTag, "=", 2)
if len(parts) < 2 {
return nil, xerrors.Errorf("invalid tag format for %q. must be key=value", rawTag)
}
tags[parts[0]] = parts[1]
}
return tags, nil
}
+12 -5
View File
@@ -18,11 +18,12 @@ import (
func templatePush() *cobra.Command {
var (
directory string
versionName string
provisioner string
parameterFile string
alwaysPrompt bool
directory string
versionName string
provisioner string
parameterFile string
alwaysPrompt bool
provisionerTags []string
)
cmd := &cobra.Command{
@@ -75,6 +76,11 @@ func templatePush() *cobra.Command {
}
spin.Stop()
tags, err := ParseProvisionerTags(provisionerTags)
if err != nil {
return err
}
job, _, err := createValidTemplateVersion(cmd, createValidTemplateVersionArgs{
Name: versionName,
Client: client,
@@ -84,6 +90,7 @@ func templatePush() *cobra.Command {
ParameterFile: parameterFile,
Template: &template,
ReuseParameters: !alwaysPrompt,
ProvisionerTags: tags,
})
if err != nil {
return err
@@ -278,6 +278,7 @@ func build(ctx context.Context, store database.Store, workspace database.Workspa
Type: database.ProvisionerJobTypeWorkspaceBuild,
StorageMethod: priorJob.StorageMethod,
FileID: priorJob.FileID,
Tags: priorJob.Tags,
Input: input,
})
if err != nil {
+84 -14
View File
@@ -1,8 +1,10 @@
package coderd
import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/json"
"fmt"
"io"
"net/http"
@@ -18,10 +20,13 @@ import (
"github.com/go-chi/chi/v5/middleware"
"github.com/google/uuid"
"github.com/klauspost/compress/zstd"
"github.com/moby/moby/pkg/namesgenerator"
"github.com/prometheus/client_golang/prometheus"
"go.opentelemetry.io/otel/trace"
"golang.org/x/xerrors"
"google.golang.org/api/idtoken"
"storj.io/drpc/drpcmux"
"storj.io/drpc/drpcserver"
"tailscale.com/derp"
"tailscale.com/derp/derphttp"
"tailscale.com/tailcfg"
@@ -32,17 +37,20 @@ import (
"github.com/coder/coder/coderd/audit"
"github.com/coder/coder/coderd/awsidentity"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/database/dbtype"
"github.com/coder/coder/coderd/gitauth"
"github.com/coder/coder/coderd/gitsshkey"
"github.com/coder/coder/coderd/httpapi"
"github.com/coder/coder/coderd/httpmw"
"github.com/coder/coder/coderd/metricscache"
"github.com/coder/coder/coderd/provisionerdserver"
"github.com/coder/coder/coderd/rbac"
"github.com/coder/coder/coderd/telemetry"
"github.com/coder/coder/coderd/tracing"
"github.com/coder/coder/coderd/wsconncache"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/provisionerd/proto"
"github.com/coder/coder/provisionersdk"
"github.com/coder/coder/site"
"github.com/coder/coder/tailnet"
)
@@ -323,13 +331,6 @@ func New(options *Options) *API {
r.Get("/{fileID}", api.fileByID)
r.Post("/", api.postFile)
})
r.Route("/provisionerdaemons", func(r chi.Router) {
r.Use(
apiKeyMiddleware,
)
r.Get("/", api.provisionerDaemons)
})
r.Route("/organizations", func(r chi.Router) {
r.Use(
apiKeyMiddleware,
@@ -595,18 +596,20 @@ type API struct {
// RootHandler serves "/"
RootHandler chi.Router
metricsCache *metricscache.Cache
siteHandler http.Handler
websocketWaitMutex sync.Mutex
websocketWaitGroup sync.WaitGroup
metricsCache *metricscache.Cache
siteHandler http.Handler
WebsocketWaitMutex sync.Mutex
WebsocketWaitGroup sync.WaitGroup
workspaceAgentCache *wsconncache.Cache
}
// Close waits for all WebSocket connections to drain before returning.
func (api *API) Close() error {
api.websocketWaitMutex.Lock()
api.websocketWaitGroup.Wait()
api.websocketWaitMutex.Unlock()
api.WebsocketWaitMutex.Lock()
api.WebsocketWaitGroup.Wait()
api.WebsocketWaitMutex.Unlock()
api.metricsCache.Close()
coordinator := api.TailnetCoordinator.Load()
@@ -635,3 +638,70 @@ func compressHandler(h http.Handler) http.Handler {
return cmp.Handler(h)
}
// CreateInMemoryProvisionerDaemon is an in-memory connection to a provisionerd. Useful when starting coderd and provisionerd
// in the same process.
func (api *API) CreateInMemoryProvisionerDaemon(ctx context.Context, debounce time.Duration) (client proto.DRPCProvisionerDaemonClient, err error) {
clientSession, serverSession := provisionersdk.TransportPipe()
defer func() {
if err != nil {
_ = clientSession.Close()
_ = serverSession.Close()
}
}()
name := namesgenerator.GetRandomName(1)
daemon, err := api.Database.InsertProvisionerDaemon(ctx, database.InsertProvisionerDaemonParams{
ID: uuid.New(),
CreatedAt: database.Now(),
Name: name,
Provisioners: []database.ProvisionerType{database.ProvisionerTypeEcho, database.ProvisionerTypeTerraform},
Tags: dbtype.StringMap{
provisionerdserver.TagScope: provisionerdserver.ScopeOrganization,
},
})
if err != nil {
return nil, xerrors.Errorf("insert provisioner daemon %q: %w", name, err)
}
tags, err := json.Marshal(daemon.Tags)
if err != nil {
return nil, xerrors.Errorf("marshal tags: %w", err)
}
mux := drpcmux.New()
err = proto.DRPCRegisterProvisionerDaemon(mux, &provisionerdserver.Server{
AccessURL: api.AccessURL,
ID: daemon.ID,
Database: api.Database,
Pubsub: api.Pubsub,
Provisioners: daemon.Provisioners,
Telemetry: api.Telemetry,
Tags: tags,
QuotaCommitter: &api.QuotaCommitter,
AcquireJobDebounce: debounce,
Logger: api.Logger.Named(fmt.Sprintf("provisionerd-%s", daemon.Name)),
})
if err != nil {
return nil, err
}
server := drpcserver.NewWithOptions(mux, drpcserver.Options{
Log: func(err error) {
if xerrors.Is(err, io.EOF) {
return
}
api.Logger.Debug(ctx, "drpc server error", slog.Error(err))
},
})
go func() {
err := server.Serve(ctx, serverSession)
if err != nil && !xerrors.Is(err, io.EOF) {
api.Logger.Debug(ctx, "provisioner daemon disconnected", slog.Error(err))
}
// close the sessions so we don't leak goroutines serving them.
_ = clientSession.Close()
_ = serverSession.Close()
}()
return proto.NewDRPCProvisionerDaemonClient(provisionersdk.Conn(clientSession)), nil
}
-16
View File
@@ -19,7 +19,6 @@ import (
"github.com/coder/coder/codersdk"
"github.com/coder/coder/provisioner/echo"
"github.com/coder/coder/provisionersdk/proto"
"github.com/coder/coder/testutil"
)
func AGPLRoutes(a *AuthTester) (map[string]string, map[string]RouteCheck) {
@@ -204,11 +203,6 @@ func AGPLRoutes(a *AuthTester) (map[string]string, map[string]RouteCheck) {
AssertAction: rbac.ActionRead,
AssertObject: rbac.ResourceTemplate.InOrg(a.Version.OrganizationID),
},
"GET:/api/v2/provisionerdaemons": {
StatusCode: http.StatusOK,
AssertObject: rbac.ResourceProvisionerDaemon,
},
"POST:/api/v2/parameters/{scope}/{id}": {
AssertAction: rbac.ActionUpdate,
AssertObject: rbac.ResourceTemplate,
@@ -303,16 +297,6 @@ func NewAuthTester(ctx context.Context, t *testing.T, client *codersdk.Client, a
if !ok {
t.Fail()
}
// The provisioner will call to coderd and register itself. This is async,
// so we wait for it to occur.
require.Eventually(t, func() bool {
provisionerds, err := client.ProvisionerDaemons(ctx)
return assert.NoError(t, err) && len(provisionerds) > 0
}, testutil.WaitLong, testutil.IntervalSlow)
provisionerds, err := client.ProvisionerDaemons(ctx)
require.NoError(t, err, "fetch provisioners")
require.Len(t, provisionerds, 1)
organization, err := client.Organization(ctx, admin.OrganizationID)
require.NoError(t, err, "fetch org")
+38 -3
View File
@@ -69,7 +69,7 @@ import (
"github.com/coder/coder/cryptorand"
"github.com/coder/coder/provisioner/echo"
"github.com/coder/coder/provisionerd"
"github.com/coder/coder/provisionerd/proto"
provisionerdproto "github.com/coder/coder/provisionerd/proto"
"github.com/coder/coder/provisionersdk"
sdkproto "github.com/coder/coder/provisionersdk/proto"
"github.com/coder/coder/tailnet"
@@ -328,8 +328,43 @@ func NewProvisionerDaemon(t *testing.T, coderAPI *coderd.API) io.Closer {
assert.NoError(t, err)
}()
closer := provisionerd.New(func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
return coderAPI.ListenProvisionerDaemon(ctx, 0)
closer := provisionerd.New(func(ctx context.Context) (provisionerdproto.DRPCProvisionerDaemonClient, error) {
return coderAPI.CreateInMemoryProvisionerDaemon(ctx, 0)
}, &provisionerd.Options{
Filesystem: fs,
Logger: slogtest.Make(t, nil).Named("provisionerd").Leveled(slog.LevelDebug),
PollInterval: 50 * time.Millisecond,
UpdateInterval: 250 * time.Millisecond,
ForceCancelInterval: time.Second,
Provisioners: provisionerd.Provisioners{
string(database.ProvisionerTypeEcho): sdkproto.NewDRPCProvisionerClient(provisionersdk.Conn(echoClient)),
},
WorkDirectory: t.TempDir(),
})
t.Cleanup(func() {
_ = closer.Close()
})
return closer
}
func NewExternalProvisionerDaemon(t *testing.T, client *codersdk.Client, org uuid.UUID, tags map[string]string) io.Closer {
echoClient, echoServer := provisionersdk.TransportPipe()
ctx, cancelFunc := context.WithCancel(context.Background())
t.Cleanup(func() {
_ = echoClient.Close()
_ = echoServer.Close()
cancelFunc()
})
fs := afero.NewMemMapFs()
go func() {
err := echo.Serve(ctx, fs, &provisionersdk.ServeOptions{
Listener: echoServer,
})
assert.NoError(t, err)
}()
closer := provisionerd.New(func(ctx context.Context) (provisionerdproto.DRPCProvisionerDaemonClient, error) {
return client.ServeProvisionerDaemon(ctx, org, []codersdk.ProvisionerType{codersdk.ProvisionerTypeEcho}, tags)
}, &provisionerd.Options{
Filesystem: fs,
Logger: slogtest.Make(t, nil).Named("provisionerd").Leveled(slog.LevelDebug),
@@ -3,6 +3,7 @@ package databasefake
import (
"context"
"database/sql"
"encoding/json"
"sort"
"strings"
"sync"
@@ -146,6 +147,29 @@ func (q *fakeQuerier) AcquireProvisionerJob(_ context.Context, arg database.Acqu
if !found {
continue
}
tags := map[string]string{}
if arg.Tags != nil {
err := json.Unmarshal(arg.Tags, &tags)
if err != nil {
return provisionerJob, xerrors.Errorf("unmarshal: %w", err)
}
}
missing := false
for key, value := range provisionerJob.Tags {
provided, found := tags[key]
if !found {
missing = true
break
}
if provided != value {
missing = true
break
}
}
if missing {
continue
}
provisionerJob.StartedAt = arg.StartedAt
provisionerJob.UpdatedAt = arg.StartedAt.Time
provisionerJob.WorkerID = arg.WorkerID
@@ -2244,6 +2268,7 @@ func (q *fakeQuerier) InsertProvisionerDaemon(_ context.Context, arg database.In
CreatedAt: arg.CreatedAt,
Name: arg.Name,
Provisioners: arg.Provisioners,
Tags: arg.Tags,
}
q.provisionerDaemons = append(q.provisionerDaemons, daemon)
return daemon, nil
@@ -2264,6 +2289,7 @@ func (q *fakeQuerier) InsertProvisionerJob(_ context.Context, arg database.Inser
FileID: arg.FileID,
Type: arg.Type,
Input: arg.Input,
Tags: arg.Tags,
}
q.provisionerJobs = append(q.provisionerJobs, job)
return job, nil
+30
View File
@@ -0,0 +1,30 @@
package dbtype
import (
"database/sql/driver"
"encoding/json"
"golang.org/x/xerrors"
)
type StringMap map[string]string
func (m *StringMap) Scan(src interface{}) error {
if src == nil {
return nil
}
switch src := src.(type) {
case []byte:
err := json.Unmarshal(src, m)
if err != nil {
return err
}
default:
return xerrors.Errorf("unsupported Scan, storing driver.Value type %T into type %T", src, m)
}
return nil
}
func (m StringMap) Value() (driver.Value, error) {
return json.Marshal(m)
}
+4 -2
View File
@@ -269,7 +269,8 @@ CREATE TABLE provisioner_daemons (
updated_at timestamp with time zone,
name character varying(64) NOT NULL,
provisioners provisioner_type[] NOT NULL,
replica_id uuid
replica_id uuid,
tags jsonb DEFAULT '{}'::jsonb NOT NULL
);
CREATE TABLE provisioner_job_logs (
@@ -306,7 +307,8 @@ CREATE TABLE provisioner_jobs (
type provisioner_job_type NOT NULL,
input jsonb NOT NULL,
worker_id uuid,
file_id uuid NOT NULL
file_id uuid NOT NULL,
tags jsonb DEFAULT '{"scope": "organization"}'::jsonb NOT NULL
);
CREATE TABLE replicas (
@@ -0,0 +1,2 @@
ALTER TABLE provisioner_daemons DROP COLUMN tags;
ALTER TABLE provisioner_jobs DROP COLUMN tags;
@@ -0,0 +1,5 @@
ALTER TABLE provisioner_daemons ADD COLUMN tags jsonb NOT NULL DEFAULT '{}';
-- We must add the organization scope by default, otherwise pending jobs
-- could be provisioned on new daemons that don't match the tags.
ALTER TABLE provisioner_jobs ADD COLUMN tags jsonb NOT NULL DEFAULT '{"scope":"organization"}';
+3
View File
@@ -10,6 +10,7 @@ import (
"fmt"
"time"
"github.com/coder/coder/coderd/database/dbtype"
"github.com/google/uuid"
"github.com/lib/pq"
"github.com/tabbed/pqtype"
@@ -525,6 +526,7 @@ type ProvisionerDaemon struct {
Name string `db:"name" json:"name"`
Provisioners []ProvisionerType `db:"provisioners" json:"provisioners"`
ReplicaID uuid.NullUUID `db:"replica_id" json:"replica_id"`
Tags dbtype.StringMap `db:"tags" json:"tags"`
}
type ProvisionerJob struct {
@@ -543,6 +545,7 @@ type ProvisionerJob struct {
Input json.RawMessage `db:"input" json:"input"`
WorkerID uuid.NullUUID `db:"worker_id" json:"worker_id"`
FileID uuid.UUID `db:"file_id" json:"file_id"`
Tags dbtype.StringMap `db:"tags" json:"tags"`
}
type ProvisionerJobLog struct {
+34 -11
View File
@@ -10,6 +10,7 @@ import (
"encoding/json"
"time"
"github.com/coder/coder/coderd/database/dbtype"
"github.com/google/uuid"
"github.com/lib/pq"
"github.com/tabbed/pqtype"
@@ -2243,7 +2244,7 @@ func (q *sqlQuerier) ParameterValues(ctx context.Context, arg ParameterValuesPar
const getProvisionerDaemonByID = `-- name: GetProvisionerDaemonByID :one
SELECT
id, created_at, updated_at, name, provisioners, replica_id
id, created_at, updated_at, name, provisioners, replica_id, tags
FROM
provisioner_daemons
WHERE
@@ -2260,13 +2261,14 @@ func (q *sqlQuerier) GetProvisionerDaemonByID(ctx context.Context, id uuid.UUID)
&i.Name,
pq.Array(&i.Provisioners),
&i.ReplicaID,
&i.Tags,
)
return i, err
}
const getProvisionerDaemons = `-- name: GetProvisionerDaemons :many
SELECT
id, created_at, updated_at, name, provisioners, replica_id
id, created_at, updated_at, name, provisioners, replica_id, tags
FROM
provisioner_daemons
`
@@ -2287,6 +2289,7 @@ func (q *sqlQuerier) GetProvisionerDaemons(ctx context.Context) ([]ProvisionerDa
&i.Name,
pq.Array(&i.Provisioners),
&i.ReplicaID,
&i.Tags,
); err != nil {
return nil, err
}
@@ -2307,10 +2310,11 @@ INSERT INTO
id,
created_at,
"name",
provisioners
provisioners,
tags
)
VALUES
($1, $2, $3, $4) RETURNING id, created_at, updated_at, name, provisioners, replica_id
($1, $2, $3, $4, $5) RETURNING id, created_at, updated_at, name, provisioners, replica_id, tags
`
type InsertProvisionerDaemonParams struct {
@@ -2318,6 +2322,7 @@ type InsertProvisionerDaemonParams struct {
CreatedAt time.Time `db:"created_at" json:"created_at"`
Name string `db:"name" json:"name"`
Provisioners []ProvisionerType `db:"provisioners" json:"provisioners"`
Tags dbtype.StringMap `db:"tags" json:"tags"`
}
func (q *sqlQuerier) InsertProvisionerDaemon(ctx context.Context, arg InsertProvisionerDaemonParams) (ProvisionerDaemon, error) {
@@ -2326,6 +2331,7 @@ func (q *sqlQuerier) InsertProvisionerDaemon(ctx context.Context, arg InsertProv
arg.CreatedAt,
arg.Name,
pq.Array(arg.Provisioners),
arg.Tags,
)
var i ProvisionerDaemon
err := row.Scan(
@@ -2335,6 +2341,7 @@ func (q *sqlQuerier) InsertProvisionerDaemon(ctx context.Context, arg InsertProv
&i.Name,
pq.Array(&i.Provisioners),
&i.ReplicaID,
&i.Tags,
)
return i, err
}
@@ -2487,19 +2494,22 @@ WHERE
AND nested.canceled_at IS NULL
AND nested.completed_at IS NULL
AND nested.provisioner = ANY($3 :: provisioner_type [ ])
-- Ensure the caller satisfies all job tags.
AND nested.tags <@ $4 :: jsonb
ORDER BY
nested.created_at
FOR UPDATE
SKIP LOCKED
LIMIT
1
) RETURNING id, created_at, updated_at, started_at, canceled_at, completed_at, error, organization_id, initiator_id, provisioner, storage_method, type, input, worker_id, file_id
) RETURNING id, created_at, updated_at, started_at, canceled_at, completed_at, error, organization_id, initiator_id, provisioner, storage_method, type, input, worker_id, file_id, tags
`
type AcquireProvisionerJobParams struct {
StartedAt sql.NullTime `db:"started_at" json:"started_at"`
WorkerID uuid.NullUUID `db:"worker_id" json:"worker_id"`
Types []ProvisionerType `db:"types" json:"types"`
Tags json.RawMessage `db:"tags" json:"tags"`
}
// Acquires the lock for a single job that isn't started, completed,
@@ -2509,7 +2519,12 @@ type AcquireProvisionerJobParams struct {
// multiple provisioners from acquiring the same jobs. See:
// https://www.postgresql.org/docs/9.5/sql-select.html#SQL-FOR-UPDATE-SHARE
func (q *sqlQuerier) AcquireProvisionerJob(ctx context.Context, arg AcquireProvisionerJobParams) (ProvisionerJob, error) {
row := q.db.QueryRowContext(ctx, acquireProvisionerJob, arg.StartedAt, arg.WorkerID, pq.Array(arg.Types))
row := q.db.QueryRowContext(ctx, acquireProvisionerJob,
arg.StartedAt,
arg.WorkerID,
pq.Array(arg.Types),
arg.Tags,
)
var i ProvisionerJob
err := row.Scan(
&i.ID,
@@ -2527,13 +2542,14 @@ func (q *sqlQuerier) AcquireProvisionerJob(ctx context.Context, arg AcquireProvi
&i.Input,
&i.WorkerID,
&i.FileID,
&i.Tags,
)
return i, err
}
const getProvisionerJobByID = `-- name: GetProvisionerJobByID :one
SELECT
id, created_at, updated_at, started_at, canceled_at, completed_at, error, organization_id, initiator_id, provisioner, storage_method, type, input, worker_id, file_id
id, created_at, updated_at, started_at, canceled_at, completed_at, error, organization_id, initiator_id, provisioner, storage_method, type, input, worker_id, file_id, tags
FROM
provisioner_jobs
WHERE
@@ -2559,13 +2575,14 @@ func (q *sqlQuerier) GetProvisionerJobByID(ctx context.Context, id uuid.UUID) (P
&i.Input,
&i.WorkerID,
&i.FileID,
&i.Tags,
)
return i, err
}
const getProvisionerJobsByIDs = `-- name: GetProvisionerJobsByIDs :many
SELECT
id, created_at, updated_at, started_at, canceled_at, completed_at, error, organization_id, initiator_id, provisioner, storage_method, type, input, worker_id, file_id
id, created_at, updated_at, started_at, canceled_at, completed_at, error, organization_id, initiator_id, provisioner, storage_method, type, input, worker_id, file_id, tags
FROM
provisioner_jobs
WHERE
@@ -2597,6 +2614,7 @@ func (q *sqlQuerier) GetProvisionerJobsByIDs(ctx context.Context, ids []uuid.UUI
&i.Input,
&i.WorkerID,
&i.FileID,
&i.Tags,
); err != nil {
return nil, err
}
@@ -2612,7 +2630,7 @@ func (q *sqlQuerier) GetProvisionerJobsByIDs(ctx context.Context, ids []uuid.UUI
}
const getProvisionerJobsCreatedAfter = `-- name: GetProvisionerJobsCreatedAfter :many
SELECT id, created_at, updated_at, started_at, canceled_at, completed_at, error, organization_id, initiator_id, provisioner, storage_method, type, input, worker_id, file_id FROM provisioner_jobs WHERE created_at > $1
SELECT id, created_at, updated_at, started_at, canceled_at, completed_at, error, organization_id, initiator_id, provisioner, storage_method, type, input, worker_id, file_id, tags FROM provisioner_jobs WHERE created_at > $1
`
func (q *sqlQuerier) GetProvisionerJobsCreatedAfter(ctx context.Context, createdAt time.Time) ([]ProvisionerJob, error) {
@@ -2640,6 +2658,7 @@ func (q *sqlQuerier) GetProvisionerJobsCreatedAfter(ctx context.Context, created
&i.Input,
&i.WorkerID,
&i.FileID,
&i.Tags,
); err != nil {
return nil, err
}
@@ -2666,10 +2685,11 @@ INSERT INTO
storage_method,
file_id,
"type",
"input"
"input",
tags
)
VALUES
($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) RETURNING id, created_at, updated_at, started_at, canceled_at, completed_at, error, organization_id, initiator_id, provisioner, storage_method, type, input, worker_id, file_id
($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) RETURNING id, created_at, updated_at, started_at, canceled_at, completed_at, error, organization_id, initiator_id, provisioner, storage_method, type, input, worker_id, file_id, tags
`
type InsertProvisionerJobParams struct {
@@ -2683,6 +2703,7 @@ type InsertProvisionerJobParams struct {
FileID uuid.UUID `db:"file_id" json:"file_id"`
Type ProvisionerJobType `db:"type" json:"type"`
Input json.RawMessage `db:"input" json:"input"`
Tags dbtype.StringMap `db:"tags" json:"tags"`
}
func (q *sqlQuerier) InsertProvisionerJob(ctx context.Context, arg InsertProvisionerJobParams) (ProvisionerJob, error) {
@@ -2697,6 +2718,7 @@ func (q *sqlQuerier) InsertProvisionerJob(ctx context.Context, arg InsertProvisi
arg.FileID,
arg.Type,
arg.Input,
arg.Tags,
)
var i ProvisionerJob
err := row.Scan(
@@ -2715,6 +2737,7 @@ func (q *sqlQuerier) InsertProvisionerJob(ctx context.Context, arg InsertProvisi
&i.Input,
&i.WorkerID,
&i.FileID,
&i.Tags,
)
return i, err
}
@@ -18,10 +18,11 @@ INSERT INTO
id,
created_at,
"name",
provisioners
provisioners,
tags
)
VALUES
($1, $2, $3, $4) RETURNING *;
($1, $2, $3, $4, $5) RETURNING *;
-- name: UpdateProvisionerDaemonByID :exec
UPDATE
+5 -2
View File
@@ -22,6 +22,8 @@ WHERE
AND nested.canceled_at IS NULL
AND nested.completed_at IS NULL
AND nested.provisioner = ANY(@types :: provisioner_type [ ])
-- Ensure the caller satisfies all job tags.
AND nested.tags <@ @tags :: jsonb
ORDER BY
nested.created_at
FOR UPDATE
@@ -61,10 +63,11 @@ INSERT INTO
storage_method,
file_id,
"type",
"input"
"input",
tags
)
VALUES
($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) RETURNING *;
($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) RETURNING *;
-- name: UpdateProvisionerJobByID :exec
UPDATE
+4
View File
@@ -17,6 +17,10 @@ packages:
output_db_file_name: db_tmp.go
overrides:
- column: "provisioner_daemons.tags"
go_type: "github.com/coder/coder/coderd/database/dbtype.StringMap"
- column: "provisioner_jobs.tags"
go_type: "github.com/coder/coder/coderd/database/dbtype.StringMap"
- column: "users.rbac_roles"
go_type: "github.com/lib/pq.StringArray"
- column: "templates.user_acl"
-113
View File
@@ -1,113 +0,0 @@
package coderd
import (
"context"
"database/sql"
"errors"
"fmt"
"io"
"net/http"
"time"
"github.com/google/uuid"
"github.com/moby/moby/pkg/namesgenerator"
"golang.org/x/xerrors"
"storj.io/drpc/drpcmux"
"storj.io/drpc/drpcserver"
"cdr.dev/slog"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/httpapi"
"github.com/coder/coder/coderd/provisionerdserver"
"github.com/coder/coder/coderd/rbac"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/provisionerd/proto"
"github.com/coder/coder/provisionersdk"
)
func (api *API) provisionerDaemons(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
daemons, err := api.Database.GetProvisionerDaemons(ctx)
if errors.Is(err, sql.ErrNoRows) {
err = nil
}
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching provisioner daemons.",
Detail: err.Error(),
})
return
}
if daemons == nil {
daemons = []database.ProvisionerDaemon{}
}
daemons, err = AuthorizeFilter(api.HTTPAuth, r, rbac.ActionRead, daemons)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching provisioner daemons.",
Detail: err.Error(),
})
return
}
httpapi.Write(ctx, rw, http.StatusOK, daemons)
}
// ListenProvisionerDaemon is an in-memory connection to a provisionerd. Useful when starting coderd and provisionerd
// in the same process.
func (api *API) ListenProvisionerDaemon(ctx context.Context, acquireJobDebounce time.Duration) (client proto.DRPCProvisionerDaemonClient, err error) {
clientSession, serverSession := provisionersdk.TransportPipe()
defer func() {
if err != nil {
_ = clientSession.Close()
_ = serverSession.Close()
}
}()
name := namesgenerator.GetRandomName(1)
daemon, err := api.Database.InsertProvisionerDaemon(ctx, database.InsertProvisionerDaemonParams{
ID: uuid.New(),
CreatedAt: database.Now(),
Name: name,
Provisioners: []database.ProvisionerType{database.ProvisionerTypeEcho, database.ProvisionerTypeTerraform},
})
if err != nil {
return nil, xerrors.Errorf("insert provisioner daemon %q: %w", name, err)
}
mux := drpcmux.New()
err = proto.DRPCRegisterProvisionerDaemon(mux, &provisionerdserver.Server{
AccessURL: api.AccessURL,
ID: daemon.ID,
Database: api.Database,
Pubsub: api.Pubsub,
Provisioners: daemon.Provisioners,
Telemetry: api.Telemetry,
Logger: api.Logger.Named(fmt.Sprintf("provisionerd-%s", daemon.Name)),
AcquireJobDebounce: acquireJobDebounce,
QuotaCommitter: &api.QuotaCommitter,
})
if err != nil {
return nil, err
}
server := drpcserver.NewWithOptions(mux, drpcserver.Options{
Log: func(err error) {
if xerrors.Is(err, io.EOF) {
return
}
api.Logger.Debug(ctx, "drpc server error", slog.Error(err))
},
})
go func() {
err := server.Serve(ctx, serverSession)
if err != nil && !xerrors.Is(err, io.EOF) {
api.Logger.Debug(ctx, "provisioner daemon disconnected", slog.Error(err))
}
// close the sessions so we don't leak goroutines serving them.
_ = clientSession.Close()
_ = serverSession.Close()
}()
return proto.NewDRPCProvisionerDaemonClient(provisionersdk.Conn(clientSession)), nil
}
-76
View File
@@ -1,76 +0,0 @@
package coderd_test
import (
"context"
"crypto/rand"
"runtime"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/coder/coder/coderd/coderdtest"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/provisionersdk"
"github.com/coder/coder/testutil"
)
func TestProvisionerDaemons(t *testing.T) {
t.Parallel()
t.Run("PayloadTooBig", func(t *testing.T) {
t.Parallel()
if runtime.GOOS == "windows" {
// Takes too long to allocate memory on Windows!
t.Skip()
}
client := coderdtest.New(t, &coderdtest.Options{IncludeProvisionerDaemon: true})
user := coderdtest.CreateFirstUser(t, client)
data := make([]byte, provisionersdk.MaxMessageSize)
rand.Read(data)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
resp, err := client.Upload(ctx, codersdk.ContentTypeTar, data)
require.NoError(t, err)
t.Log(resp.ID)
version, err := client.CreateTemplateVersion(ctx, user.OrganizationID, codersdk.CreateTemplateVersionRequest{
StorageMethod: codersdk.ProvisionerStorageMethodFile,
FileID: resp.ID,
Provisioner: codersdk.ProvisionerTypeEcho,
})
require.NoError(t, err)
require.Eventually(t, func() bool {
var err error
version, err = client.TemplateVersion(ctx, version.ID)
return assert.NoError(t, err) && version.Job.Error != ""
}, testutil.WaitShort, testutil.IntervalFast)
})
}
func TestProvisionerDaemonsByOrganization(t *testing.T) {
t.Parallel()
t.Run("NoAuth", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, nil)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
_, err := client.ProvisionerDaemons(ctx)
require.Error(t, err)
})
t.Run("Get", func(t *testing.T) {
t.Parallel()
client := coderdtest.New(t, nil)
_ = coderdtest.CreateFirstUser(t, client)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel()
_, err := client.ProvisionerDaemons(ctx)
require.NoError(t, err)
})
}
@@ -39,6 +39,7 @@ type Server struct {
ID uuid.UUID
Logger slog.Logger
Provisioners []database.ProvisionerType
Tags json.RawMessage
Database database.Store
Pubsub database.Pubsub
Telemetry telemetry.Reporter
@@ -71,6 +72,7 @@ func (server *Server) AcquireJob(ctx context.Context, _ *proto.Empty) (*proto.Ac
Valid: true,
},
Types: server.Provisioners,
Tags: server.Tags,
})
if errors.Is(err, sql.ErrNoRows) {
// The provisioner daemon assumes no jobs are available if
@@ -0,0 +1,33 @@
package provisionerdserver
import "github.com/google/uuid"
const (
TagScope = "scope"
TagOwner = "owner"
ScopeUser = "user"
ScopeOrganization = "organization"
)
// MutateTags adjusts the "owner" tag dependent on the "scope".
// If the scope is "user", the "owner" is changed to the user ID.
// This is for user-scoped provisioner daemons, where users should
// own their own operations.
func MutateTags(userID uuid.UUID, tags map[string]string) map[string]string {
if tags == nil {
tags = map[string]string{}
}
_, ok := tags[TagScope]
if !ok {
tags[TagScope] = ScopeOrganization
}
switch tags[TagScope] {
case ScopeUser:
tags[TagOwner] = userID.String()
case ScopeOrganization:
default:
tags[TagScope] = ScopeOrganization
}
return tags
}
+5 -4
View File
@@ -131,10 +131,10 @@ func (api *API) provisionerJobLogs(rw http.ResponseWriter, r *http.Request, job
return
}
api.websocketWaitMutex.Lock()
api.websocketWaitGroup.Add(1)
api.websocketWaitMutex.Unlock()
defer api.websocketWaitGroup.Done()
api.WebsocketWaitMutex.Lock()
api.WebsocketWaitGroup.Add(1)
api.WebsocketWaitMutex.Unlock()
defer api.WebsocketWaitGroup.Done()
conn, err := websocket.Accept(rw, r, nil)
if err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
@@ -312,6 +312,7 @@ func convertProvisionerJob(provisionerJob database.ProvisionerJob) codersdk.Prov
CreatedAt: provisionerJob.CreatedAt,
Error: provisionerJob.Error.String,
FileID: provisionerJob.FileID,
Tags: provisionerJob.Tags,
}
// Applying values optional to the struct.
if provisionerJob.StartedAt.Valid {
+6
View File
@@ -291,6 +291,8 @@ func (api *API) postTemplateVersionDryRun(rw http.ResponseWriter, r *http.Reques
FileID: job.FileID,
Type: database.ProvisionerJobTypeTemplateVersionDryRun,
Input: input,
// Copy tags from the previous run.
Tags: job.Tags,
})
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
@@ -764,6 +766,9 @@ func (api *API) postTemplateVersionsByOrganization(rw http.ResponseWriter, r *ht
return
}
// Ensures the "owner" is properly applied.
tags := provisionerdserver.MutateTags(apiKey.UserID, req.ProvisionerTags)
file, err := api.Database.GetFileByID(ctx, req.FileID)
if errors.Is(err, sql.ErrNoRows) {
httpapi.Write(ctx, rw, http.StatusNotFound, codersdk.Response{
@@ -862,6 +867,7 @@ func (api *API) postTemplateVersionsByOrganization(rw http.ResponseWriter, r *ht
FileID: file.ID,
Type: database.ProvisionerJobTypeTemplateVersionImport,
Input: []byte{'{', '}'},
Tags: tags,
})
if err != nil {
return xerrors.Errorf("insert provisioner job: %w", err)
+2
View File
@@ -13,6 +13,7 @@ import (
"github.com/coder/coder/coderd/audit"
"github.com/coder/coder/coderd/coderdtest"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/provisionerdserver"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/provisioner/echo"
"github.com/coder/coder/provisionersdk/proto"
@@ -122,6 +123,7 @@ func TestPostTemplateVersionsByOrganization(t *testing.T) {
})
require.NoError(t, err)
require.Equal(t, "bananas", version.Name)
require.Equal(t, provisionerdserver.ScopeOrganization, version.Job.Tags[provisionerdserver.TagScope])
require.Len(t, auditor.AuditLogs, 1)
assert.Equal(t, database.AuditActionCreate, auditor.AuditLogs[0].Action)
+16 -16
View File
@@ -181,10 +181,10 @@ func (api *API) postWorkspaceAgentVersion(rw http.ResponseWriter, r *http.Reques
func (api *API) workspaceAgentPTY(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
api.websocketWaitMutex.Lock()
api.websocketWaitGroup.Add(1)
api.websocketWaitMutex.Unlock()
defer api.websocketWaitGroup.Done()
api.WebsocketWaitMutex.Lock()
api.WebsocketWaitGroup.Add(1)
api.WebsocketWaitMutex.Unlock()
defer api.WebsocketWaitGroup.Done()
workspaceAgent := httpmw.WorkspaceAgentParam(r)
workspace := httpmw.WorkspaceParam(r)
@@ -442,10 +442,10 @@ func (api *API) workspaceAgentConnection(rw http.ResponseWriter, r *http.Request
func (api *API) workspaceAgentCoordinate(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
api.websocketWaitMutex.Lock()
api.websocketWaitGroup.Add(1)
api.websocketWaitMutex.Unlock()
defer api.websocketWaitGroup.Done()
api.WebsocketWaitMutex.Lock()
api.WebsocketWaitGroup.Add(1)
api.WebsocketWaitMutex.Unlock()
defer api.WebsocketWaitGroup.Done()
workspaceAgent := httpmw.WorkspaceAgent(r)
resource, err := api.Database.GetWorkspaceResourceByID(ctx, workspaceAgent.ResourceID)
if err != nil {
@@ -614,10 +614,10 @@ func (api *API) workspaceAgentClientCoordinate(rw http.ResponseWriter, r *http.R
}
}
api.websocketWaitMutex.Lock()
api.websocketWaitGroup.Add(1)
api.websocketWaitMutex.Unlock()
defer api.websocketWaitGroup.Done()
api.WebsocketWaitMutex.Lock()
api.WebsocketWaitGroup.Add(1)
api.WebsocketWaitMutex.Unlock()
defer api.WebsocketWaitGroup.Done()
workspaceAgent := httpmw.WorkspaceAgentParam(r)
conn, err := websocket.Accept(rw, r, nil)
@@ -759,10 +759,10 @@ func convertWorkspaceAgent(derpMap *tailcfg.DERPMap, coordinator tailnet.Coordin
func (api *API) workspaceAgentReportStats(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
api.websocketWaitMutex.Lock()
api.websocketWaitGroup.Add(1)
api.websocketWaitMutex.Unlock()
defer api.websocketWaitGroup.Done()
api.WebsocketWaitMutex.Lock()
api.WebsocketWaitGroup.Add(1)
api.WebsocketWaitMutex.Unlock()
defer api.WebsocketWaitGroup.Done()
workspaceAgent := httpmw.WorkspaceAgent(r)
resource, err := api.Database.GetWorkspaceResourceByID(ctx, workspaceAgent.ResourceID)
+3
View File
@@ -428,6 +428,8 @@ func (api *API) postWorkspaceBuilds(rw http.ResponseWriter, r *http.Request) {
return
}
tags := provisionerdserver.MutateTags(workspace.OwnerID, templateVersionJob.Tags)
// Store prior build number to compute new build number
var priorBuildNum int32
priorHistory, err := api.Database.GetLatestWorkspaceBuildByWorkspaceID(ctx, workspace.ID)
@@ -513,6 +515,7 @@ func (api *API) postWorkspaceBuilds(rw http.ResponseWriter, r *http.Request) {
StorageMethod: templateVersionJob.StorageMethod,
FileID: templateVersionJob.FileID,
Input: input,
Tags: tags,
})
if err != nil {
return xerrors.Errorf("insert provisioner job: %w", err)
+3
View File
@@ -373,6 +373,8 @@ func (api *API) postWorkspacesByOrganization(rw http.ResponseWriter, r *http.Req
return
}
tags := provisionerdserver.MutateTags(user.ID, templateVersionJob.Tags)
var (
provisionerJob database.ProvisionerJob
workspaceBuild database.WorkspaceBuild
@@ -435,6 +437,7 @@ func (api *API) postWorkspacesByOrganization(rw http.ResponseWriter, r *http.Req
StorageMethod: templateVersionJob.StorageMethod,
FileID: templateVersionJob.FileID,
Input: input,
Tags: tags,
})
if err != nil {
return xerrors.Errorf("insert provisioner job: %w", err)
+9 -7
View File
@@ -15,13 +15,14 @@ const (
)
const (
FeatureUserLimit = "user_limit"
FeatureAuditLog = "audit_log"
FeatureBrowserOnly = "browser_only"
FeatureSCIM = "scim"
FeatureTemplateRBAC = "template_rbac"
FeatureHighAvailability = "high_availability"
FeatureMultipleGitAuth = "multiple_git_auth"
FeatureUserLimit = "user_limit"
FeatureAuditLog = "audit_log"
FeatureBrowserOnly = "browser_only"
FeatureSCIM = "scim"
FeatureTemplateRBAC = "template_rbac"
FeatureHighAvailability = "high_availability"
FeatureMultipleGitAuth = "multiple_git_auth"
FeatureExternalProvisionerDaemons = "external_provisioner_daemons"
)
var FeatureNames = []string{
@@ -32,6 +33,7 @@ var FeatureNames = []string{
FeatureTemplateRBAC,
FeatureHighAvailability,
FeatureMultipleGitAuth,
FeatureExternalProvisionerDaemons,
}
type Feature struct {
+5 -4
View File
@@ -36,11 +36,12 @@ type Organization struct {
type CreateTemplateVersionRequest struct {
Name string `json:"name,omitempty" validate:"omitempty,template_name"`
// TemplateID optionally associates a version with a template.
TemplateID uuid.UUID `json:"template_id,omitempty"`
TemplateID uuid.UUID `json:"template_id,omitempty"`
StorageMethod ProvisionerStorageMethod `json:"storage_method" validate:"oneof=file,required"`
FileID uuid.UUID `json:"file_id" validate:"required"`
Provisioner ProvisionerType `json:"provisioner" validate:"oneof=terraform echo,required"`
ProvisionerTags map[string]string `json:"tags"`
StorageMethod ProvisionerStorageMethod `json:"storage_method" validate:"oneof=file,required"`
FileID uuid.UUID `json:"file_id" validate:"required"`
Provisioner ProvisionerType `json:"provisioner" validate:"oneof=terraform echo,required"`
// ParameterValues allows for additional parameters to be provided
// during the dry-run provision stage.
ParameterValues []CreateParameterRequest `json:"parameter_values,omitempty"`
+57 -5
View File
@@ -13,20 +13,22 @@ import (
"time"
"github.com/google/uuid"
"github.com/hashicorp/yamux"
"golang.org/x/xerrors"
"nhooyr.io/websocket"
"github.com/coder/coder/provisionerd/proto"
"github.com/coder/coder/provisionersdk"
)
type LogSource string
const (
LogSourceProvisionerDaemon LogSource = "provisioner_daemon"
LogSourceProvisioner LogSource = "provisioner"
)
type LogLevel string
const (
LogSourceProvisionerDaemon LogSource = "provisioner_daemon"
LogSourceProvisioner LogSource = "provisioner"
LogLevelTrace LogLevel = "trace"
LogLevelDebug LogLevel = "debug"
LogLevelInfo LogLevel = "info"
@@ -40,6 +42,7 @@ type ProvisionerDaemon struct {
UpdatedAt sql.NullTime `json:"updated_at"`
Name string `json:"name"`
Provisioners []ProvisionerType `json:"provisioners"`
Tags map[string]string `json:"tags"`
}
// ProvisionerJobStatus represents the at-time state of a job.
@@ -73,6 +76,7 @@ type ProvisionerJob struct {
Status ProvisionerJobStatus `json:"status"`
WorkerID *uuid.UUID `json:"worker_id,omitempty"`
FileID uuid.UUID `json:"file_id"`
Tags map[string]string `json:"tags"`
}
type ProvisionerJobLog struct {
@@ -162,3 +166,51 @@ func (c *Client) provisionerJobLogsAfter(ctx context.Context, path string, after
return nil
}), nil
}
// ListenProvisionerDaemon returns the gRPC service for a provisioner daemon implementation.
func (c *Client) ServeProvisionerDaemon(ctx context.Context, organization uuid.UUID, provisioners []ProvisionerType, tags map[string]string) (proto.DRPCProvisionerDaemonClient, error) {
serverURL, err := c.URL.Parse(fmt.Sprintf("/api/v2/organizations/%s/provisionerdaemons/serve", organization))
if err != nil {
return nil, xerrors.Errorf("parse url: %w", err)
}
query := serverURL.Query()
for _, provisioner := range provisioners {
query.Add("provisioner", string(provisioner))
}
for key, value := range tags {
query.Add("tag", fmt.Sprintf("%s=%s", key, value))
}
serverURL.RawQuery = query.Encode()
jar, err := cookiejar.New(nil)
if err != nil {
return nil, xerrors.Errorf("create cookie jar: %w", err)
}
jar.SetCookies(serverURL, []*http.Cookie{{
Name: SessionTokenKey,
Value: c.SessionToken(),
}})
httpClient := &http.Client{
Jar: jar,
}
conn, res, err := websocket.Dial(ctx, serverURL.String(), &websocket.DialOptions{
HTTPClient: httpClient,
// Need to disable compression to avoid a data-race.
CompressionMode: websocket.CompressionDisabled,
})
if err != nil {
if res == nil {
return nil, err
}
return nil, readBodyAsError(res)
}
// Align with the frame size of yamux.
conn.SetReadLimit(256 * 1024)
config := yamux.DefaultConfig()
config.LogOutput = io.Discard
session, err := yamux.Client(websocket.NetConn(ctx, conn, websocket.MessageBinary), config)
if err != nil {
return nil, xerrors.Errorf("multiplex client: %w", err)
}
return proto.NewDRPCProvisionerDaemonClient(provisionersdk.Conn(session)), nil
}
+155
View File
@@ -0,0 +1,155 @@
package cli
import (
"context"
"fmt"
"os"
"os/signal"
"time"
"cdr.dev/slog"
"cdr.dev/slog/sloggers/sloghuman"
agpl "github.com/coder/coder/cli"
"github.com/coder/coder/cli/cliflag"
"github.com/coder/coder/cli/cliui"
"github.com/coder/coder/cli/deployment"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/provisioner/terraform"
"github.com/coder/coder/provisionerd"
provisionerdproto "github.com/coder/coder/provisionerd/proto"
"github.com/coder/coder/provisionersdk"
"github.com/coder/coder/provisionersdk/proto"
"github.com/spf13/cobra"
"golang.org/x/xerrors"
)
func provisionerDaemons() *cobra.Command {
cmd := &cobra.Command{
Use: "provisionerd",
Short: "Manage provisioner daemons",
}
cmd.AddCommand(provisionerDaemonStart())
return cmd
}
func provisionerDaemonStart() *cobra.Command {
var (
cacheDir string
rawTags []string
)
cmd := &cobra.Command{
Use: "start",
Short: "Run a provisioner daemon",
RunE: func(cmd *cobra.Command, args []string) error {
ctx, cancel := context.WithCancel(cmd.Context())
defer cancel()
notifyCtx, notifyStop := signal.NotifyContext(ctx, agpl.InterruptSignals...)
defer notifyStop()
client, err := agpl.CreateClient(cmd)
if err != nil {
return xerrors.Errorf("create client: %w", err)
}
org, err := agpl.CurrentOrganization(cmd, client)
if err != nil {
return xerrors.Errorf("get current organization: %w", err)
}
tags, err := agpl.ParseProvisionerTags(rawTags)
if err != nil {
return err
}
err = os.MkdirAll(cacheDir, 0o700)
if err != nil {
return xerrors.Errorf("mkdir %q: %w", cacheDir, err)
}
terraformClient, terraformServer := provisionersdk.TransportPipe()
go func() {
<-ctx.Done()
_ = terraformClient.Close()
_ = terraformServer.Close()
}()
logger := slog.Make(sloghuman.Sink(cmd.ErrOrStderr()))
errCh := make(chan error, 1)
go func() {
defer cancel()
err := terraform.Serve(ctx, &terraform.ServeOptions{
ServeOptions: &provisionersdk.ServeOptions{
Listener: terraformServer,
},
CachePath: cacheDir,
Logger: logger.Named("terraform"),
})
if err != nil && !xerrors.Is(err, context.Canceled) {
select {
case errCh <- err:
default:
}
}
}()
tempDir, err := os.MkdirTemp("", "provisionerd")
if err != nil {
return err
}
logger.Info(ctx, "starting provisioner daemon", slog.F("tags", tags))
provisioners := provisionerd.Provisioners{
string(database.ProvisionerTypeTerraform): proto.NewDRPCProvisionerClient(provisionersdk.Conn(terraformClient)),
}
srv := provisionerd.New(func(ctx context.Context) (provisionerdproto.DRPCProvisionerDaemonClient, error) {
return client.ServeProvisionerDaemon(ctx, org.ID, []codersdk.ProvisionerType{
codersdk.ProvisionerTypeTerraform,
}, tags)
}, &provisionerd.Options{
Logger: logger,
PollInterval: 500 * time.Millisecond,
UpdateInterval: 500 * time.Millisecond,
Provisioners: provisioners,
WorkDirectory: tempDir,
})
var exitErr error
select {
case <-notifyCtx.Done():
exitErr = notifyCtx.Err()
_, _ = fmt.Fprintln(cmd.OutOrStdout(), cliui.Styles.Bold.Render(
"Interrupt caught, gracefully exiting. Use ctrl+\\ to force quit",
))
case exitErr = <-errCh:
}
if exitErr != nil && !xerrors.Is(exitErr, context.Canceled) {
cmd.Printf("Unexpected error, shutting down server: %s\n", exitErr)
}
shutdown, shutdownCancel := context.WithTimeout(ctx, time.Minute)
defer shutdownCancel()
err = srv.Shutdown(shutdown)
if err != nil {
return xerrors.Errorf("shutdown: %w", err)
}
cancel()
if xerrors.Is(exitErr, context.Canceled) {
return nil
}
return exitErr
},
}
cliflag.StringVarP(cmd.Flags(), &cacheDir, "cache-dir", "c", "CODER_CACHE_DIRECTORY", deployment.DefaultCacheDir(),
"Specify a directory to cache provisioner job files.")
cliflag.StringArrayVarP(cmd.Flags(), &rawTags, "tag", "t", "CODER_PROVISIONERD_TAGS", []string{},
"Specify a list of tags to target provisioner jobs.")
return cmd
}
+1
View File
@@ -12,6 +12,7 @@ func enterpriseOnly() []*cobra.Command {
features(),
licenses(),
groups(),
provisionerDaemons(),
}
}
+16 -9
View File
@@ -90,7 +90,15 @@ func New(ctx context.Context, options *Options) (*API, error) {
r.Get("/", api.group)
})
})
r.Route("/organizations/{organization}/provisionerdaemons", func(r chi.Router) {
r.Use(
api.provisionerDaemonsEnabledMW,
apiKeyMiddleware,
httpmw.ExtractOrganizationParam(api.Database),
)
r.Get("/", api.provisionerDaemons)
r.Get("/serve", api.provisionerDaemonServe)
})
r.Route("/templates/{template}/acl", func(r chi.Router) {
r.Use(
api.templateRBACEnabledMW,
@@ -100,7 +108,6 @@ func New(ctx context.Context, options *Options) (*API, error) {
r.Get("/", api.templateACL)
r.Patch("/", api.patchTemplateACL)
})
r.Route("/groups/{group}", func(r chi.Router) {
r.Use(
api.templateRBACEnabledMW,
@@ -111,7 +118,6 @@ func New(ctx context.Context, options *Options) (*API, error) {
r.Patch("/", api.patchGroup)
r.Delete("/", api.deleteGroup)
})
r.Route("/workspace-quota", func(r chi.Router) {
r.Use(
apiKeyMiddleware,
@@ -222,12 +228,13 @@ func (api *API) updateEntitlements(ctx context.Context) error {
defer api.entitlementsMu.Unlock()
entitlements, err := license.Entitlements(ctx, api.Database, api.Logger, len(api.replicaManager.All()), len(api.GitAuthConfigs), api.Keys, map[string]bool{
codersdk.FeatureAuditLog: api.AuditLogging,
codersdk.FeatureBrowserOnly: api.BrowserOnly,
codersdk.FeatureSCIM: len(api.SCIMAPIKey) != 0,
codersdk.FeatureHighAvailability: api.DERPServerRelayAddress != "",
codersdk.FeatureMultipleGitAuth: len(api.GitAuthConfigs) > 1,
codersdk.FeatureTemplateRBAC: api.RBAC,
codersdk.FeatureAuditLog: api.AuditLogging,
codersdk.FeatureBrowserOnly: api.BrowserOnly,
codersdk.FeatureSCIM: len(api.SCIMAPIKey) != 0,
codersdk.FeatureHighAvailability: api.DERPServerRelayAddress != "",
codersdk.FeatureMultipleGitAuth: len(api.GitAuthConfigs) > 1,
codersdk.FeatureTemplateRBAC: api.RBAC,
codersdk.FeatureExternalProvisionerDaemons: true,
})
if err != nil {
return err
+4 -3
View File
@@ -41,9 +41,10 @@ func TestEntitlements(t *testing.T) {
})
_ = coderdtest.CreateFirstUser(t, client)
coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
UserLimit: 100,
AuditLog: true,
TemplateRBAC: true,
UserLimit: 100,
AuditLog: true,
TemplateRBAC: true,
ExternalProvisionerDaemons: true,
})
res, err := client.Entitlements(context.Background())
require.NoError(t, err)
@@ -99,19 +99,20 @@ func NewWithAPI(t *testing.T, options *Options) (*codersdk.Client, io.Closer, *c
}
type LicenseOptions struct {
AccountType string
AccountID string
Trial bool
AllFeatures bool
GraceAt time.Time
ExpiresAt time.Time
UserLimit int64
AuditLog bool
BrowserOnly bool
SCIM bool
TemplateRBAC bool
HighAvailability bool
MultipleGitAuth bool
AccountType string
AccountID string
Trial bool
AllFeatures bool
GraceAt time.Time
ExpiresAt time.Time
UserLimit int64
AuditLog bool
BrowserOnly bool
SCIM bool
TemplateRBAC bool
HighAvailability bool
MultipleGitAuth bool
ExternalProvisionerDaemons bool
}
// AddLicense generates a new license with the options provided and inserts it.
@@ -158,6 +159,11 @@ func GenerateLicense(t *testing.T, options LicenseOptions) string {
multipleGitAuth = 1
}
externalProvisionerDaemons := int64(0)
if options.ExternalProvisionerDaemons {
externalProvisionerDaemons = 1
}
c := &license.Claims{
RegisteredClaims: jwt.RegisteredClaims{
Issuer: "test@testing.test",
@@ -172,13 +178,14 @@ func GenerateLicense(t *testing.T, options LicenseOptions) string {
Version: license.CurrentVersion,
AllFeatures: options.AllFeatures,
Features: license.Features{
UserLimit: options.UserLimit,
AuditLog: auditLog,
BrowserOnly: browserOnly,
SCIM: scim,
HighAvailability: highAvailability,
TemplateRBAC: rbacEnabled,
MultipleGitAuth: multipleGitAuth,
UserLimit: options.UserLimit,
AuditLog: auditLog,
BrowserOnly: browserOnly,
SCIM: scim,
HighAvailability: highAvailability,
TemplateRBAC: rbacEnabled,
MultipleGitAuth: multipleGitAuth,
ExternalProvisionerDaemons: externalProvisionerDaemons,
},
}
tok := jwt.NewWithClaims(jwt.SigningMethodEdDSA, c)
@@ -33,7 +33,8 @@ func TestAuthorizeAllEndpoints(t *testing.T) {
ctx, _ := testutil.Context(t)
admin := coderdtest.CreateFirstUser(t, client)
license := coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
TemplateRBAC: true,
TemplateRBAC: true,
ExternalProvisionerDaemons: true,
})
group, err := client.CreateGroup(ctx, admin.OrganizationID, codersdk.CreateGroupRequest{
Name: "testgroup",
@@ -47,6 +48,8 @@ func TestAuthorizeAllEndpoints(t *testing.T) {
a.URLParams["{groupName}"] = group.Name
skipRoutes, assertRoute := coderdtest.AGPLRoutes(a)
skipRoutes["GET:/api/v2/organizations/{organization}/provisionerdaemons/serve"] = "This route checks for RBAC dependent on input parameters!"
assertRoute["GET:/api/v2/entitlements"] = coderdtest.RouteCheck{
NoAuthorize: true,
}
@@ -84,6 +87,14 @@ func TestAuthorizeAllEndpoints(t *testing.T) {
AssertAction: rbac.ActionRead,
AssertObject: groupObj,
}
assertRoute["GET:/api/v2/organizations/{organization}/provisionerdaemons"] = coderdtest.RouteCheck{
AssertAction: rbac.ActionRead,
AssertObject: rbac.ResourceProvisionerDaemon,
}
assertRoute["GET:/api/v2/organizations/{organization}/provisionerdaemons"] = coderdtest.RouteCheck{
AssertAction: rbac.ActionRead,
AssertObject: rbac.ResourceProvisionerDaemon,
}
assertRoute["GET:/api/v2/groups/{group}"] = coderdtest.RouteCheck{
AssertAction: rbac.ActionRead,
AssertObject: groupObj,
+14 -7
View File
@@ -117,6 +117,12 @@ func Entitlements(
Enabled: true,
}
}
if claims.Features.ExternalProvisionerDaemons > 0 {
entitlements.Features[codersdk.FeatureExternalProvisionerDaemons] = codersdk.Feature{
Entitlement: entitlement,
Enabled: true,
}
}
if claims.AllFeatures {
allFeatures = true
}
@@ -238,13 +244,14 @@ var (
)
type Features struct {
UserLimit int64 `json:"user_limit"`
AuditLog int64 `json:"audit_log"`
BrowserOnly int64 `json:"browser_only"`
SCIM int64 `json:"scim"`
TemplateRBAC int64 `json:"template_rbac"`
HighAvailability int64 `json:"high_availability"`
MultipleGitAuth int64 `json:"multiple_git_auth"`
UserLimit int64 `json:"user_limit"`
AuditLog int64 `json:"audit_log"`
BrowserOnly int64 `json:"browser_only"`
SCIM int64 `json:"scim"`
TemplateRBAC int64 `json:"template_rbac"`
HighAvailability int64 `json:"high_availability"`
MultipleGitAuth int64 `json:"multiple_git_auth"`
ExternalProvisionerDaemons int64 `json:"external_provisioner_daemons"`
}
type Claims struct {
+24 -21
View File
@@ -20,12 +20,13 @@ import (
func TestEntitlements(t *testing.T) {
t.Parallel()
all := map[string]bool{
codersdk.FeatureAuditLog: true,
codersdk.FeatureBrowserOnly: true,
codersdk.FeatureSCIM: true,
codersdk.FeatureHighAvailability: true,
codersdk.FeatureTemplateRBAC: true,
codersdk.FeatureMultipleGitAuth: true,
codersdk.FeatureAuditLog: true,
codersdk.FeatureBrowserOnly: true,
codersdk.FeatureSCIM: true,
codersdk.FeatureHighAvailability: true,
codersdk.FeatureTemplateRBAC: true,
codersdk.FeatureMultipleGitAuth: true,
codersdk.FeatureExternalProvisionerDaemons: true,
}
t.Run("Defaults", func(t *testing.T) {
@@ -61,13 +62,14 @@ func TestEntitlements(t *testing.T) {
db := databasefake.New()
db.InsertLicense(context.Background(), database.InsertLicenseParams{
JWT: coderdenttest.GenerateLicense(t, coderdenttest.LicenseOptions{
UserLimit: 100,
AuditLog: true,
BrowserOnly: true,
SCIM: true,
HighAvailability: true,
TemplateRBAC: true,
MultipleGitAuth: true,
UserLimit: 100,
AuditLog: true,
BrowserOnly: true,
SCIM: true,
HighAvailability: true,
TemplateRBAC: true,
MultipleGitAuth: true,
ExternalProvisionerDaemons: true,
}),
Exp: time.Now().Add(time.Hour),
})
@@ -84,14 +86,15 @@ func TestEntitlements(t *testing.T) {
db := databasefake.New()
db.InsertLicense(context.Background(), database.InsertLicenseParams{
JWT: coderdenttest.GenerateLicense(t, coderdenttest.LicenseOptions{
UserLimit: 100,
AuditLog: true,
BrowserOnly: true,
SCIM: true,
HighAvailability: true,
TemplateRBAC: true,
GraceAt: time.Now().Add(-time.Hour),
ExpiresAt: time.Now().Add(time.Hour),
UserLimit: 100,
AuditLog: true,
BrowserOnly: true,
SCIM: true,
HighAvailability: true,
TemplateRBAC: true,
ExternalProvisionerDaemons: true,
GraceAt: time.Now().Add(-time.Hour),
ExpiresAt: time.Now().Add(time.Hour),
}),
Exp: time.Now().Add(time.Hour),
})
+16 -14
View File
@@ -101,25 +101,27 @@ func TestGetLicense(t *testing.T) {
assert.Equal(t, int32(1), licenses[0].ID)
assert.Equal(t, "testing", licenses[0].Claims["account_id"])
assert.Equal(t, map[string]interface{}{
codersdk.FeatureUserLimit: json.Number("0"),
codersdk.FeatureAuditLog: json.Number("1"),
codersdk.FeatureSCIM: json.Number("1"),
codersdk.FeatureBrowserOnly: json.Number("1"),
codersdk.FeatureHighAvailability: json.Number("0"),
codersdk.FeatureTemplateRBAC: json.Number("1"),
codersdk.FeatureMultipleGitAuth: json.Number("0"),
codersdk.FeatureUserLimit: json.Number("0"),
codersdk.FeatureAuditLog: json.Number("1"),
codersdk.FeatureSCIM: json.Number("1"),
codersdk.FeatureBrowserOnly: json.Number("1"),
codersdk.FeatureHighAvailability: json.Number("0"),
codersdk.FeatureTemplateRBAC: json.Number("1"),
codersdk.FeatureMultipleGitAuth: json.Number("0"),
codersdk.FeatureExternalProvisionerDaemons: json.Number("0"),
}, licenses[0].Claims["features"])
assert.Equal(t, int32(2), licenses[1].ID)
assert.Equal(t, "testing2", licenses[1].Claims["account_id"])
assert.Equal(t, true, licenses[1].Claims["trial"])
assert.Equal(t, map[string]interface{}{
codersdk.FeatureUserLimit: json.Number("200"),
codersdk.FeatureAuditLog: json.Number("1"),
codersdk.FeatureSCIM: json.Number("1"),
codersdk.FeatureBrowserOnly: json.Number("1"),
codersdk.FeatureHighAvailability: json.Number("0"),
codersdk.FeatureTemplateRBAC: json.Number("0"),
codersdk.FeatureMultipleGitAuth: json.Number("0"),
codersdk.FeatureUserLimit: json.Number("200"),
codersdk.FeatureAuditLog: json.Number("1"),
codersdk.FeatureSCIM: json.Number("1"),
codersdk.FeatureBrowserOnly: json.Number("1"),
codersdk.FeatureHighAvailability: json.Number("0"),
codersdk.FeatureTemplateRBAC: json.Number("0"),
codersdk.FeatureMultipleGitAuth: json.Number("0"),
codersdk.FeatureExternalProvisionerDaemons: json.Number("0"),
}, licenses[1].Claims["features"])
})
}
+245
View File
@@ -0,0 +1,245 @@
package coderd
import (
"database/sql"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"strings"
"github.com/google/uuid"
"github.com/hashicorp/yamux"
"github.com/moby/moby/pkg/namesgenerator"
"golang.org/x/xerrors"
"nhooyr.io/websocket"
"storj.io/drpc/drpcmux"
"storj.io/drpc/drpcserver"
"cdr.dev/slog"
"github.com/coder/coder/coderd"
"github.com/coder/coder/coderd/database"
"github.com/coder/coder/coderd/httpapi"
"github.com/coder/coder/coderd/httpmw"
"github.com/coder/coder/coderd/provisionerdserver"
"github.com/coder/coder/coderd/rbac"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/provisionerd/proto"
)
func (api *API) provisionerDaemonsEnabledMW(next http.Handler) http.Handler {
return http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
api.entitlementsMu.RLock()
epd := api.entitlements.Features[codersdk.FeatureExternalProvisionerDaemons].Enabled
api.entitlementsMu.RUnlock()
if !epd {
httpapi.Write(r.Context(), rw, http.StatusForbidden, codersdk.Response{
Message: "External provisioner daemons is an Enterprise feature. Contact sales!",
})
return
}
next.ServeHTTP(rw, r)
})
}
func (api *API) provisionerDaemons(rw http.ResponseWriter, r *http.Request) {
ctx := r.Context()
org := httpmw.OrganizationParam(r)
if !api.Authorize(r, rbac.ActionRead, rbac.ResourceProvisionerDaemon.InOrg(org.ID)) {
httpapi.Forbidden(rw)
return
}
daemons, err := api.Database.GetProvisionerDaemons(ctx)
if errors.Is(err, sql.ErrNoRows) {
err = nil
}
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching provisioner daemons.",
Detail: err.Error(),
})
return
}
if daemons == nil {
daemons = []database.ProvisionerDaemon{}
}
daemons, err = coderd.AuthorizeFilter(api.AGPL.HTTPAuth, r, rbac.ActionRead, daemons)
if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error fetching provisioner daemons.",
Detail: err.Error(),
})
return
}
apiDaemons := make([]codersdk.ProvisionerDaemon, 0)
for _, daemon := range daemons {
apiDaemons = append(apiDaemons, convertProvisionerDaemon(daemon))
}
httpapi.Write(ctx, rw, http.StatusOK, apiDaemons)
}
// Serves the provisioner daemon protobuf API over a WebSocket.
func (api *API) provisionerDaemonServe(rw http.ResponseWriter, r *http.Request) {
tags := map[string]string{}
if r.URL.Query().Has("tag") {
for _, tag := range r.URL.Query()["tag"] {
parts := strings.SplitN(tag, "=", 2)
if len(parts) < 2 {
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
Message: fmt.Sprintf("Invalid format for tag %q. Key and value must be separated with =.", tag),
})
return
}
tags[parts[0]] = parts[1]
}
}
if !r.URL.Query().Has("provisioner") {
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
Message: "The provisioner query parameter must be specified.",
})
return
}
provisionersMap := map[codersdk.ProvisionerType]struct{}{}
for _, provisioner := range r.URL.Query()["provisioner"] {
switch provisioner {
case string(codersdk.ProvisionerTypeEcho):
provisionersMap[codersdk.ProvisionerTypeEcho] = struct{}{}
case string(codersdk.ProvisionerTypeTerraform):
provisionersMap[codersdk.ProvisionerTypeTerraform] = struct{}{}
default:
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
Message: fmt.Sprintf("Unknown provisioner type %q", provisioner),
})
return
}
}
// Any authenticated user can create provisioner daemons scoped
// for jobs that they own, but only authorized users can create
// globally scoped provisioners that attach to all jobs.
apiKey := httpmw.APIKey(r)
tags = provisionerdserver.MutateTags(apiKey.UserID, tags)
if tags[provisionerdserver.TagScope] == provisionerdserver.ScopeOrganization {
if !api.AGPL.Authorize(r, rbac.ActionCreate, rbac.ResourceProvisionerDaemon) {
httpapi.Write(r.Context(), rw, http.StatusForbidden, codersdk.Response{
Message: "You aren't allowed to create provisioner daemons for the organization.",
})
return
}
}
provisioners := make([]database.ProvisionerType, 0)
for p := range provisionersMap {
switch p {
case codersdk.ProvisionerTypeTerraform:
provisioners = append(provisioners, database.ProvisionerTypeTerraform)
case codersdk.ProvisionerTypeEcho:
provisioners = append(provisioners, database.ProvisionerTypeEcho)
}
}
name := namesgenerator.GetRandomName(1)
daemon, err := api.Database.InsertProvisionerDaemon(r.Context(), database.InsertProvisionerDaemonParams{
ID: uuid.New(),
CreatedAt: database.Now(),
Name: name,
Provisioners: provisioners,
Tags: tags,
})
if err != nil {
httpapi.Write(r.Context(), rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error writing provisioner daemon.",
Detail: err.Error(),
})
return
}
rawTags, err := json.Marshal(daemon.Tags)
if err != nil {
httpapi.Write(r.Context(), rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error marshaling daemon tags.",
Detail: err.Error(),
})
return
}
api.AGPL.WebsocketWaitMutex.Lock()
api.AGPL.WebsocketWaitGroup.Add(1)
api.AGPL.WebsocketWaitMutex.Unlock()
defer api.AGPL.WebsocketWaitGroup.Done()
conn, err := websocket.Accept(rw, r, &websocket.AcceptOptions{
// Need to disable compression to avoid a data-race.
CompressionMode: websocket.CompressionDisabled,
})
if err != nil {
httpapi.Write(r.Context(), rw, http.StatusBadRequest, codersdk.Response{
Message: "Internal error accepting websocket connection.",
Detail: err.Error(),
})
return
}
// Align with the frame size of yamux.
conn.SetReadLimit(256 * 1024)
// Multiplexes the incoming connection using yamux.
// This allows multiple function calls to occur over
// the same connection.
config := yamux.DefaultConfig()
config.LogOutput = io.Discard
session, err := yamux.Server(websocket.NetConn(r.Context(), conn, websocket.MessageBinary), config)
if err != nil {
_ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("multiplex server: %s", err))
return
}
mux := drpcmux.New()
err = proto.DRPCRegisterProvisionerDaemon(mux, &provisionerdserver.Server{
AccessURL: api.AccessURL,
ID: daemon.ID,
Database: api.Database,
Pubsub: api.Pubsub,
Provisioners: daemon.Provisioners,
Telemetry: api.Telemetry,
Logger: api.Logger.Named(fmt.Sprintf("provisionerd-%s", daemon.Name)),
Tags: rawTags,
})
if err != nil {
_ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("drpc register provisioner daemon: %s", err))
return
}
server := drpcserver.NewWithOptions(mux, drpcserver.Options{
Log: func(err error) {
if xerrors.Is(err, io.EOF) {
return
}
api.Logger.Debug(r.Context(), "drpc server error", slog.Error(err))
},
})
err = server.Serve(r.Context(), session)
if err != nil && !xerrors.Is(err, io.EOF) {
api.Logger.Debug(r.Context(), "provisioner daemon disconnected", slog.Error(err))
_ = conn.Close(websocket.StatusInternalError, httpapi.WebsocketCloseSprintf("serve: %s", err))
return
}
_ = conn.Close(websocket.StatusGoingAway, "")
}
func convertProvisionerDaemon(daemon database.ProvisionerDaemon) codersdk.ProvisionerDaemon {
result := codersdk.ProvisionerDaemon{
ID: daemon.ID,
CreatedAt: daemon.CreatedAt,
UpdatedAt: daemon.UpdatedAt,
Name: daemon.Name,
Tags: daemon.Tags,
}
for _, provisionerType := range daemon.Provisioners {
result.Provisioners = append(result.Provisioners, codersdk.ProvisionerType(provisionerType))
}
return result
}
@@ -0,0 +1,139 @@
package coderd_test
import (
"context"
"net/http"
"testing"
"github.com/google/uuid"
"github.com/stretchr/testify/require"
"github.com/coder/coder/coderd/coderdtest"
"github.com/coder/coder/coderd/provisionerdserver"
"github.com/coder/coder/codersdk"
"github.com/coder/coder/enterprise/coderd/coderdenttest"
"github.com/coder/coder/provisioner/echo"
"github.com/coder/coder/provisionersdk/proto"
)
func TestProvisionerDaemonServe(t *testing.T) {
t.Parallel()
t.Run("NoLicense", func(t *testing.T) {
t.Parallel()
client := coderdenttest.New(t, nil)
user := coderdtest.CreateFirstUser(t, client)
_, err := client.ServeProvisionerDaemon(context.Background(), user.OrganizationID, []codersdk.ProvisionerType{
codersdk.ProvisionerTypeEcho,
}, map[string]string{})
require.Error(t, err)
var apiError *codersdk.Error
require.ErrorAs(t, err, &apiError)
require.Equal(t, http.StatusForbidden, apiError.StatusCode())
})
t.Run("Organization", func(t *testing.T) {
t.Parallel()
client := coderdenttest.New(t, nil)
user := coderdtest.CreateFirstUser(t, client)
coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
ExternalProvisionerDaemons: true,
})
srv, err := client.ServeProvisionerDaemon(context.Background(), user.OrganizationID, []codersdk.ProvisionerType{
codersdk.ProvisionerTypeEcho,
}, map[string]string{})
require.NoError(t, err)
srv.DRPCConn().Close()
})
t.Run("OrganizationNoPerms", func(t *testing.T) {
t.Parallel()
client := coderdenttest.New(t, nil)
user := coderdtest.CreateFirstUser(t, client)
coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
ExternalProvisionerDaemons: true,
})
another := coderdtest.CreateAnotherUser(t, client, user.OrganizationID)
_, err := another.ServeProvisionerDaemon(context.Background(), user.OrganizationID, []codersdk.ProvisionerType{
codersdk.ProvisionerTypeEcho,
}, map[string]string{
provisionerdserver.TagScope: provisionerdserver.ScopeOrganization,
})
require.Error(t, err)
var apiError *codersdk.Error
require.ErrorAs(t, err, &apiError)
require.Equal(t, http.StatusForbidden, apiError.StatusCode())
})
t.Run("UserLocal", func(t *testing.T) {
t.Parallel()
client := coderdenttest.New(t, nil)
user := coderdtest.CreateFirstUser(t, client)
coderdenttest.AddLicense(t, client, coderdenttest.LicenseOptions{
ExternalProvisionerDaemons: true,
})
closer := coderdtest.NewExternalProvisionerDaemon(t, client, user.OrganizationID, map[string]string{
provisionerdserver.TagScope: provisionerdserver.ScopeUser,
})
defer closer.Close()
authToken := uuid.NewString()
data, err := echo.Tar(&echo.Responses{
Parse: echo.ParseComplete,
ProvisionPlan: []*proto.Provision_Response{{
Type: &proto.Provision_Response_Complete{
Complete: &proto.Provision_Complete{
Resources: []*proto.Resource{{
Name: "example",
Type: "aws_instance",
Agents: []*proto.Agent{{
Id: uuid.NewString(),
Name: "example",
}},
}},
},
},
}},
ProvisionApply: []*proto.Provision_Response{{
Type: &proto.Provision_Response_Complete{
Complete: &proto.Provision_Complete{
Resources: []*proto.Resource{{
Name: "example",
Type: "aws_instance",
Agents: []*proto.Agent{{
Id: uuid.NewString(),
Name: "example",
Auth: &proto.Agent_Token{
Token: authToken,
},
}},
}},
},
},
}},
})
require.NoError(t, err)
file, err := client.Upload(context.Background(), codersdk.ContentTypeTar, data)
require.NoError(t, err)
version, err := client.CreateTemplateVersion(context.Background(), user.OrganizationID, codersdk.CreateTemplateVersionRequest{
Name: "example",
StorageMethod: codersdk.ProvisionerStorageMethodFile,
FileID: file.ID,
Provisioner: codersdk.ProvisionerTypeEcho,
ProvisionerTags: map[string]string{
provisionerdserver.TagScope: provisionerdserver.ScopeUser,
},
})
require.NoError(t, err)
coderdtest.AwaitTemplateVersionJob(t, client, version.ID)
template := coderdtest.CreateTemplate(t, client, user.OrganizationID, version.ID)
another := coderdtest.CreateAnotherUser(t, client, user.OrganizationID)
_ = closer.Close()
closer = coderdtest.NewExternalProvisionerDaemon(t, another, user.OrganizationID, map[string]string{
provisionerdserver.TagScope: provisionerdserver.ScopeUser,
})
defer closer.Close()
workspace := coderdtest.CreateWorkspace(t, another, user.OrganizationID, template.ID)
coderdtest.AwaitWorkspaceBuildJob(t, client, workspace.LatestBuild.ID)
})
}
+3
View File
@@ -58,6 +58,9 @@ type Options struct {
// New creates and starts a provisioner daemon.
func New(clientDialer Dialer, opts *Options) *Server {
if opts == nil {
opts = &Options{}
}
if opts.PollInterval == 0 {
opts.PollInterval = 5 * time.Second
}
+2
View File
@@ -69,12 +69,14 @@ export const provisioners: TypesGen.ProvisionerDaemon[] = [
name: "Terraform",
created_at: "",
provisioners: [],
tags: {},
},
{
id: "cdr-basic",
name: "Basic",
created_at: "",
provisioners: [],
tags: {},
},
]
+3
View File
@@ -197,6 +197,7 @@ export interface CreateTemplateVersionRequest {
readonly storage_method: ProvisionerStorageMethod
readonly file_id: string
readonly provisioner: ProvisionerType
readonly tags: Record<string, string>
readonly parameter_values?: CreateParameterRequest[]
}
@@ -540,6 +541,7 @@ export interface ProvisionerDaemon {
readonly updated_at?: string
readonly name: string
readonly provisioners: ProvisionerType[]
readonly tags: Record<string, string>
}
// From codersdk/provisionerdaemons.go
@@ -553,6 +555,7 @@ export interface ProvisionerJob {
readonly status: ProvisionerJobStatus
readonly worker_id?: string
readonly file_id: string
readonly tags: Record<string, string>
}
// From codersdk/provisionerdaemons.go
+2
View File
@@ -131,6 +131,7 @@ export const MockProvisioner: TypesGen.ProvisionerDaemon = {
id: "test-provisioner",
name: "Test Provisioner",
provisioners: ["echo"],
tags: {},
}
export const MockProvisionerJob: TypesGen.ProvisionerJob = {
@@ -139,6 +140,7 @@ export const MockProvisionerJob: TypesGen.ProvisionerJob = {
status: "succeeded",
file_id: "fc0774ce-cc9e-48d4-80ae-88f7a4d4a8b0",
completed_at: "2022-05-17T17:39:01.382927298Z",
tags: {},
}
export const MockFailedProvisionerJob: TypesGen.ProvisionerJob = {