Files
coder/provisionerd/provisionerd_test.go
T
Ethan ef0151601e feat: report insufficient quota build failures in chat tools (#24956)
## Summary

When a workspace build fails because the user is over their group quota,
the chat tools currently surface the failure as a bare `"workspace build
failed: insufficient quota"` string with no machine-readable error code
and no visibility into the user's current usage. Agents and the UI
cannot distinguish quota failures from any other Terraform error, so
users see an opaque message and have no clear path to recovery.

This PR tags quota failures with a typed error code at the source and
propagates it through the chat tool layer so callers can react to it
explicitly.

Relates to CODAGT-20

## Changes

**Provisioner runner**

- Add `InsufficientQuotaErrorCode = "INSUFFICIENT_QUOTA"` and set it
explicitly at the `commitQuota` failure site via a new
`failedWorkspaceBuildfCode` helper, so `provisioner_jobs.error_code` is
populated only on the genuine quota path. The substring matcher used for
externally produced sentinels (e.g. `"missing parameter"`, `"required
template variables"`) is intentionally not extended; provider errors
that happen to mention "insufficient quota" stay classified as generic
build failures.

**SDK and API contract**

- Add `JobErrorCodeInsufficientQuota` and a
`JobIsInsufficientQuotaErrorCode` helper to `codersdk`.
- Extend the swagger `enums` tag on `ProvisionerJob.ErrorCode` to
include `INSUFFICIENT_QUOTA`.
- Regenerate `coderd/apidoc`, `docs/reference/api/*`, and
`site/src/api/typesGenerated.ts`.

**chattool create_workspace / start_workspace**

- `waitForBuild` now returns a typed `*workspaceBuildError` carrying
both the message and the `JobErrorCode`, instead of a bare error string.
- New `quotaerror.go` introduces a structured `quotaErrorResult` (with
`error_code`, `title`, `message`, `build_id`, and optional `quota`) and
a best-effort `workspaceQuotaDetails` lookup that wraps owner
authorization internally and fetches `credits_consumed` and `budget`
from the database. Quota lookup failures (including authorization
failures) never block the failure payload.
- On quota-coded build failures, both `create_workspace` and
`start_workspace` now return the structured response (with the recovery
guidance inlined into `message`) instead of the bare `"insufficient
quota"` string. This applies to all three failure paths: post-creation,
an in-progress existing build, and a freshly triggered start build.
Non-quota build failures continue to use the existing
`buildToolResponse` / `newBuildError` path.
- Owner authorization is wrapped only on the call sites that need it
(the `CreateFn` and `StartFn` invocations and the quota-detail lookup),
so idempotent fast paths (already running, already in progress,
existing-workspace early returns) do not pay for an extra RBAC
round-trip or fail when role lookup is transient.

## Out of scope

- No changes to quota math, allowances, or bypass behavior.
- No automatic retries.
- No new quota-inspection tools and no changes to MCP
`coder_create_workspace` (which returns immediately and never observed
the build outcome here).
- No frontend UI changes; those will land in a follow-up PR that
consumes the new `INSUFFICIENT_QUOTA` code.
2026-05-07 15:01:58 +10:00

1418 lines
47 KiB
Go

package provisionerd_test
import (
"context"
"fmt"
"io"
"os"
"path/filepath"
"sync"
"testing"
"time"
"github.com/hashicorp/yamux"
"github.com/spf13/afero"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.uber.org/atomic"
"go.uber.org/goleak"
"golang.org/x/xerrors"
"storj.io/drpc/drpcmux"
"storj.io/drpc/drpcserver"
"cdr.dev/slog/v3"
"cdr.dev/slog/v3/sloggers/slogtest"
"github.com/coder/coder/v2/codersdk/drpcsdk"
"github.com/coder/coder/v2/provisionerd"
"github.com/coder/coder/v2/provisionerd/proto"
"github.com/coder/coder/v2/provisionerd/runner"
"github.com/coder/coder/v2/provisionersdk"
sdkproto "github.com/coder/coder/v2/provisionersdk/proto"
"github.com/coder/coder/v2/provisionersdk/tfpath"
"github.com/coder/coder/v2/testutil"
)
func TestMain(m *testing.M) {
goleak.VerifyTestMain(m, testutil.GoleakOptions...)
}
// closedWithin returns a predicate (suitable for require.Condition) that
// blocks until either a receive on c succeeds, in which case it reports true,
// or the duration d elapses first, in which case it reports false.
func closedWithin(c chan struct{}, d time.Duration) func() bool {
	waitForSignal := func() (signaled bool) {
		// The deadline starts when the predicate is invoked, not when
		// closedWithin itself is called.
		deadline := time.After(d)
		select {
		case <-deadline:
			// Timed out before c delivered anything.
		case <-c:
			signaled = true
		}
		return signaled
	}
	return waitForSignal
}
func TestProvisionerd(t *testing.T) {
t.Parallel()
noopUpdateJob := func(ctx context.Context, update *proto.UpdateJobRequest) (*proto.UpdateJobResponse, error) {
return &proto.UpdateJobResponse{}, nil
}
t.Run("InstantClose", func(t *testing.T) {
t.Parallel()
done := make(chan struct{})
t.Cleanup(func() {
close(done)
})
closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
return createProvisionerDaemonClient(t, done, provisionerDaemonTestServer{}), nil
}, provisionerd.LocalProvisioners{})
require.NoError(t, closer.Close())
})
t.Run("ConnectErrorClose", func(t *testing.T) {
t.Parallel()
done := make(chan struct{})
t.Cleanup(func() {
close(done)
})
completeChan := make(chan struct{})
var completed sync.Once
closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
completed.Do(func() {
defer close(completeChan)
})
return nil, xerrors.New("an error")
}, provisionerd.LocalProvisioners{})
require.Condition(t, closedWithin(completeChan, testutil.WaitShort))
require.NoError(t, closer.Close())
})
t.Run("CloseCancelsJob", func(t *testing.T) {
t.Parallel()
done := make(chan struct{})
t.Cleanup(func() {
close(done)
})
completeChan := make(chan struct{})
var completed sync.Once
var closer io.Closer
var closerMutex sync.Mutex
closerMutex.Lock()
closer = createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
return createProvisionerDaemonClient(t, done, provisionerDaemonTestServer{
acquireJobWithCancel: func(stream proto.DRPCProvisionerDaemon_AcquireJobWithCancelStream) error {
err := stream.Send(&proto.AcquiredJob{
JobId: "test",
Provisioner: "someprovisioner",
TemplateSourceArchive: testutil.CreateTar(t, map[string]string{
"test.txt": "content",
}),
Type: &proto.AcquiredJob_TemplateImport_{
TemplateImport: &proto.AcquiredJob_TemplateImport{
Metadata: &sdkproto.Metadata{},
},
},
})
assert.NoError(t, err)
return nil
},
updateJob: noopUpdateJob,
failJob: func(ctx context.Context, job *proto.FailedJob) (*proto.Empty, error) {
completed.Do(func() {
close(completeChan)
})
return &proto.Empty{}, nil
},
}), nil
}, provisionerd.LocalProvisioners{
"someprovisioner": createProvisionerClient(t, done, provisionerTestServer{
parse: func(_ *provisionersdk.Session, _ *sdkproto.ParseRequest, _ <-chan struct{}) *sdkproto.ParseComplete {
closerMutex.Lock()
defer closerMutex.Unlock()
err := closer.Close()
c := &sdkproto.ParseComplete{}
if err != nil {
c.Error = err.Error()
}
return c
},
init: func(s *provisionersdk.Session, r *provisionersdk.InitRequest, canceledOrComplete <-chan struct{}) *sdkproto.InitComplete {
closerMutex.Lock()
defer closerMutex.Unlock()
err := closer.Close()
c := &sdkproto.InitComplete{}
if err != nil {
c.Error = err.Error()
}
return c
},
}),
})
closerMutex.Unlock()
require.Condition(t, closedWithin(completeChan, testutil.WaitShort))
require.NoError(t, closer.Close())
})
// LargePayloads sends a 3mb tar file to the provisioner. The provisioner also
// returns large payload messages back. The limit should be 4mb, so all
// these messages should work.
t.Run("LargePayloads", func(t *testing.T) {
t.Parallel()
done := make(chan struct{})
t.Cleanup(func() {
close(done)
})
var (
largeSize = 3 * 1024 * 1024
completeChan = make(chan struct{})
completeOnce sync.Once
acq = newAcquireOne(t, &proto.AcquiredJob{
JobId: "test",
Provisioner: "someprovisioner",
TemplateSourceArchive: testutil.CreateTar(t, map[string]string{
"toolarge.txt": string(make([]byte, largeSize)),
}),
Type: &proto.AcquiredJob_TemplateImport_{
TemplateImport: &proto.AcquiredJob_TemplateImport{
Metadata: &sdkproto.Metadata{},
},
},
})
)
closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
return createProvisionerDaemonClient(t, done, provisionerDaemonTestServer{
acquireJobWithCancel: acq.acquireWithCancel,
updateJob: noopUpdateJob,
completeJob: func(ctx context.Context, job *proto.CompletedJob) (*proto.Empty, error) {
completeOnce.Do(func() { close(completeChan) })
return &proto.Empty{}, nil
},
}), nil
}, provisionerd.LocalProvisioners{
"someprovisioner": createProvisionerClient(t, done, provisionerTestServer{
parse: func(
s *provisionersdk.Session,
_ *sdkproto.ParseRequest,
cancelOrComplete <-chan struct{},
) *sdkproto.ParseComplete {
return &sdkproto.ParseComplete{
// 3mb readme
Readme: make([]byte, largeSize),
}
},
init: func(s *provisionersdk.Session, r *provisionersdk.InitRequest, canceledOrComplete <-chan struct{}) *sdkproto.InitComplete {
return &sdkproto.InitComplete{}
},
plan: func(
_ *provisionersdk.Session,
_ *sdkproto.PlanRequest,
_ <-chan struct{},
) *sdkproto.PlanComplete {
return &sdkproto.PlanComplete{
Plan: make([]byte, largeSize),
}
},
apply: func(
_ *provisionersdk.Session,
_ *sdkproto.ApplyRequest,
_ <-chan struct{},
) *sdkproto.ApplyComplete {
return &sdkproto.ApplyComplete{
State: make([]byte, largeSize),
}
},
graph: func(s *provisionersdk.Session, r *sdkproto.GraphRequest, canceledOrComplete <-chan struct{}) *sdkproto.GraphComplete {
return &sdkproto.GraphComplete{
Resources: []*sdkproto.Resource{},
}
},
}),
})
require.Condition(t, closedWithin(completeChan, testutil.WaitShort))
require.NoError(t, closer.Close())
})
t.Run("RunningPeriodicUpdate", func(t *testing.T) {
t.Parallel()
done := make(chan struct{})
t.Cleanup(func() {
close(done)
})
var (
completeChan = make(chan struct{})
completeOnce sync.Once
)
closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
return createProvisionerDaemonClient(t, done, provisionerDaemonTestServer{
acquireJobWithCancel: func(stream proto.DRPCProvisionerDaemon_AcquireJobWithCancelStream) error {
err := stream.Send(&proto.AcquiredJob{
JobId: "test",
Provisioner: "someprovisioner",
TemplateSourceArchive: testutil.CreateTar(t, map[string]string{
"test.txt": "content",
}),
Type: &proto.AcquiredJob_TemplateImport_{
TemplateImport: &proto.AcquiredJob_TemplateImport{
Metadata: &sdkproto.Metadata{},
},
},
})
assert.NoError(t, err)
return nil
},
updateJob: func(ctx context.Context, update *proto.UpdateJobRequest) (*proto.UpdateJobResponse, error) {
completeOnce.Do(func() { close(completeChan) })
return &proto.UpdateJobResponse{}, nil
},
failJob: func(ctx context.Context, job *proto.FailedJob) (*proto.Empty, error) {
return &proto.Empty{}, nil
},
}), nil
}, provisionerd.LocalProvisioners{
"someprovisioner": createProvisionerClient(t, done, provisionerTestServer{
parse: func(
_ *provisionersdk.Session,
_ *sdkproto.ParseRequest,
cancelOrComplete <-chan struct{},
) *sdkproto.ParseComplete {
<-cancelOrComplete
return &sdkproto.ParseComplete{}
},
init: func(s *provisionersdk.Session, r *provisionersdk.InitRequest, canceledOrComplete <-chan struct{}) *sdkproto.InitComplete {
return &sdkproto.InitComplete{}
},
}),
})
require.Condition(t, closedWithin(completeChan, testutil.WaitShort))
require.NoError(t, closer.Close())
})
t.Run("TemplateImport", func(t *testing.T) {
t.Parallel()
done := make(chan struct{})
t.Cleanup(func() {
close(done)
})
var (
didComplete atomic.Bool
didLog atomic.Bool
didReadme atomic.Bool
acq = newAcquireOne(t, &proto.AcquiredJob{
JobId: "test",
Provisioner: "someprovisioner",
TemplateSourceArchive: testutil.CreateTar(t, map[string]string{
"test.txt": "content",
tfpath.ReadmeFile: "# A cool template 😎\n",
}),
Type: &proto.AcquiredJob_TemplateImport_{
TemplateImport: &proto.AcquiredJob_TemplateImport{
Metadata: &sdkproto.Metadata{},
},
},
})
)
closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
return createProvisionerDaemonClient(t, done, provisionerDaemonTestServer{
acquireJobWithCancel: acq.acquireWithCancel,
updateJob: func(ctx context.Context, update *proto.UpdateJobRequest) (*proto.UpdateJobResponse, error) {
if len(update.Logs) > 0 {
didLog.Store(true)
}
if len(update.Readme) > 0 {
didReadme.Store(true)
}
return &proto.UpdateJobResponse{}, nil
},
completeJob: func(ctx context.Context, job *proto.CompletedJob) (*proto.Empty, error) {
didComplete.Store(true)
return &proto.Empty{}, nil
},
}), nil
}, provisionerd.LocalProvisioners{
"someprovisioner": createProvisionerClient(t, done, provisionerTestServer{
init: extractInit(t),
parse: func(
s *provisionersdk.Session,
_ *sdkproto.ParseRequest,
cancelOrComplete <-chan struct{},
) *sdkproto.ParseComplete {
data, err := os.ReadFile(filepath.Join(s.Files.WorkDirectory(), "test.txt"))
require.NoError(t, err)
require.Equal(t, "content", string(data))
s.ProvisionLog(sdkproto.LogLevel_INFO, "hello")
return &sdkproto.ParseComplete{}
},
plan: func(
s *provisionersdk.Session,
_ *sdkproto.PlanRequest,
cancelOrComplete <-chan struct{},
) *sdkproto.PlanComplete {
s.ProvisionLog(sdkproto.LogLevel_INFO, "hello")
return &sdkproto.PlanComplete{}
},
apply: func(
_ *provisionersdk.Session,
_ *sdkproto.ApplyRequest,
_ <-chan struct{},
) *sdkproto.ApplyComplete {
t.Error("dry run should not apply")
return &sdkproto.ApplyComplete{}
},
graph: func(s *provisionersdk.Session, r *sdkproto.GraphRequest, canceledOrComplete <-chan struct{}) *sdkproto.GraphComplete {
return &sdkproto.GraphComplete{
Resources: []*sdkproto.Resource{},
}
},
}),
})
require.Condition(t, closedWithin(acq.complete, testutil.WaitShort))
require.NoError(t, closer.Close())
assert.True(t, didLog.Load(), "should log some updates")
assert.True(t, didComplete.Load(), "should complete the job")
})
t.Run("TemplateDryRun", func(t *testing.T) {
t.Parallel()
done := make(chan struct{})
t.Cleanup(func() {
close(done)
})
var (
didComplete atomic.Bool
didLog atomic.Bool
metadata = &sdkproto.Metadata{}
acq = newAcquireOne(t, &proto.AcquiredJob{
JobId: "test",
Provisioner: "someprovisioner",
TemplateSourceArchive: testutil.CreateTar(t, map[string]string{
"test.txt": "content",
}),
Type: &proto.AcquiredJob_TemplateDryRun_{
TemplateDryRun: &proto.AcquiredJob_TemplateDryRun{
Metadata: metadata,
},
},
})
)
closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
return createProvisionerDaemonClient(t, done, provisionerDaemonTestServer{
acquireJobWithCancel: acq.acquireWithCancel,
updateJob: func(ctx context.Context, update *proto.UpdateJobRequest) (*proto.UpdateJobResponse, error) {
if len(update.Logs) == 0 {
t.Log("provisionerDaemonTestServer: no log messages")
return &proto.UpdateJobResponse{}, nil
}
didLog.Store(true)
for _, msg := range update.Logs {
t.Log("provisionerDaemonTestServer", "msg:", msg)
}
return &proto.UpdateJobResponse{}, nil
},
completeJob: func(ctx context.Context, job *proto.CompletedJob) (*proto.Empty, error) {
didComplete.Store(true)
return &proto.Empty{}, nil
},
}), nil
}, provisionerd.LocalProvisioners{
"someprovisioner": createProvisionerClient(t, done, provisionerTestServer{
init: func(s *provisionersdk.Session, r *provisionersdk.InitRequest, canceledOrComplete <-chan struct{}) *sdkproto.InitComplete {
return &sdkproto.InitComplete{}
},
plan: func(
_ *provisionersdk.Session,
_ *sdkproto.PlanRequest,
_ <-chan struct{},
) *sdkproto.PlanComplete {
return &sdkproto.PlanComplete{}
},
apply: func(
_ *provisionersdk.Session,
_ *sdkproto.ApplyRequest,
_ <-chan struct{},
) *sdkproto.ApplyComplete {
t.Error("dry run should not apply")
return &sdkproto.ApplyComplete{}
},
graph: func(s *provisionersdk.Session, r *sdkproto.GraphRequest, canceledOrComplete <-chan struct{}) *sdkproto.GraphComplete {
return &sdkproto.GraphComplete{
Resources: []*sdkproto.Resource{},
}
},
}),
})
require.Condition(t, closedWithin(acq.complete, testutil.WaitShort))
require.NoError(t, closer.Close())
assert.True(t, didLog.Load(), "should log some updates")
assert.True(t, didComplete.Load(), "should complete the job")
})
t.Run("WorkspaceBuild", func(t *testing.T) {
t.Parallel()
done := make(chan struct{})
t.Cleanup(func() {
close(done)
})
var (
didComplete atomic.Bool
didLog atomic.Bool
acq = newAcquireOne(t, &proto.AcquiredJob{
JobId: "test",
Provisioner: "someprovisioner",
TemplateSourceArchive: testutil.CreateTar(t, map[string]string{
"test.txt": "content",
}),
Type: &proto.AcquiredJob_WorkspaceBuild_{
WorkspaceBuild: &proto.AcquiredJob_WorkspaceBuild{
Metadata: &sdkproto.Metadata{},
},
},
})
)
closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
return createProvisionerDaemonClient(t, done, provisionerDaemonTestServer{
acquireJobWithCancel: acq.acquireWithCancel,
updateJob: func(ctx context.Context, update *proto.UpdateJobRequest) (*proto.UpdateJobResponse, error) {
if len(update.Logs) != 0 {
didLog.Store(true)
}
return &proto.UpdateJobResponse{}, nil
},
completeJob: func(ctx context.Context, job *proto.CompletedJob) (*proto.Empty, error) {
didComplete.Store(true)
return &proto.Empty{}, nil
},
}), nil
}, provisionerd.LocalProvisioners{
"someprovisioner": createProvisionerClient(t, done, provisionerTestServer{
init: func(s *provisionersdk.Session, r *provisionersdk.InitRequest, canceledOrComplete <-chan struct{}) *sdkproto.InitComplete {
return &sdkproto.InitComplete{}
},
plan: func(
s *provisionersdk.Session,
_ *sdkproto.PlanRequest,
cancelOrComplete <-chan struct{},
) *sdkproto.PlanComplete {
s.ProvisionLog(sdkproto.LogLevel_DEBUG, "wow")
return &sdkproto.PlanComplete{}
},
apply: func(
_ *provisionersdk.Session,
_ *sdkproto.ApplyRequest,
_ <-chan struct{},
) *sdkproto.ApplyComplete {
return &sdkproto.ApplyComplete{}
},
graph: func(s *provisionersdk.Session, r *sdkproto.GraphRequest, canceledOrComplete <-chan struct{}) *sdkproto.GraphComplete {
return &sdkproto.GraphComplete{}
},
}),
})
require.Condition(t, closedWithin(acq.complete, testutil.WaitShort))
require.NoError(t, closer.Close())
assert.True(t, didLog.Load(), "should log some updates")
assert.True(t, didComplete.Load(), "should complete the job")
})
t.Run("WorkspaceBuildQuotaExceeded", func(t *testing.T) {
t.Parallel()
done := make(chan struct{})
t.Cleanup(func() {
close(done)
})
var (
didComplete atomic.Bool
didLog atomic.Bool
didFail atomic.Bool
failedCode = atomic.NewString("")
acq = newAcquireOne(t, &proto.AcquiredJob{
JobId: "test",
Provisioner: "someprovisioner",
TemplateSourceArchive: testutil.CreateTar(t, map[string]string{
"test.txt": "content",
}),
Type: &proto.AcquiredJob_WorkspaceBuild_{
WorkspaceBuild: &proto.AcquiredJob_WorkspaceBuild{
Metadata: &sdkproto.Metadata{},
},
},
})
)
closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
return createProvisionerDaemonClient(t, done, provisionerDaemonTestServer{
acquireJobWithCancel: acq.acquireWithCancel,
updateJob: func(ctx context.Context, update *proto.UpdateJobRequest) (*proto.UpdateJobResponse, error) {
if len(update.Logs) != 0 {
didLog.Store(true)
}
return &proto.UpdateJobResponse{}, nil
},
completeJob: func(ctx context.Context, job *proto.CompletedJob) (*proto.Empty, error) {
didComplete.Store(true)
return &proto.Empty{}, nil
},
commitQuota: func(ctx context.Context, com *proto.CommitQuotaRequest) (*proto.CommitQuotaResponse, error) {
return &proto.CommitQuotaResponse{
Ok: com.DailyCost < 20,
}, nil
},
failJob: func(ctx context.Context, job *proto.FailedJob) (*proto.Empty, error) {
didFail.Store(true)
failedCode.Store(job.ErrorCode)
return &proto.Empty{}, nil
},
}), nil
}, provisionerd.LocalProvisioners{
"someprovisioner": createProvisionerClient(t, done, provisionerTestServer{
init: func(s *provisionersdk.Session, r *provisionersdk.InitRequest, canceledOrComplete <-chan struct{}) *sdkproto.InitComplete {
return &sdkproto.InitComplete{}
},
plan: func(
s *provisionersdk.Session,
_ *sdkproto.PlanRequest,
cancelOrComplete <-chan struct{},
) *sdkproto.PlanComplete {
s.ProvisionLog(sdkproto.LogLevel_DEBUG, "wow")
return &sdkproto.PlanComplete{
DailyCost: 25,
}
},
apply: func(
_ *provisionersdk.Session,
_ *sdkproto.ApplyRequest,
_ <-chan struct{},
) *sdkproto.ApplyComplete {
t.Error("should not apply when resources exceed quota")
return &sdkproto.ApplyComplete{}
},
graph: func(s *provisionersdk.Session, r *sdkproto.GraphRequest, canceledOrComplete <-chan struct{}) *sdkproto.GraphComplete {
return &sdkproto.GraphComplete{
Resources: []*sdkproto.Resource{
{
DailyCost: 10,
},
{
DailyCost: 15,
},
},
}
},
}),
})
require.Condition(t, closedWithin(acq.complete, testutil.WaitShort))
require.NoError(t, closer.Close())
assert.True(t, didLog.Load(), "should log some updates")
assert.False(t, didComplete.Load(), "should not complete the job")
assert.Equal(t, runner.InsufficientQuotaErrorCode, failedCode.Load())
assert.True(t, didFail.Load(), "should fail the job")
})
t.Run("WorkspaceBuildFailComplete", func(t *testing.T) {
t.Parallel()
done := make(chan struct{})
t.Cleanup(func() {
close(done)
})
var (
didFail atomic.Bool
acq = newAcquireOne(t, &proto.AcquiredJob{
JobId: "test",
Provisioner: "someprovisioner",
TemplateSourceArchive: testutil.CreateTar(t, map[string]string{
"test.txt": "content",
}),
Type: &proto.AcquiredJob_WorkspaceBuild_{
WorkspaceBuild: &proto.AcquiredJob_WorkspaceBuild{
Metadata: &sdkproto.Metadata{},
},
},
})
)
closer := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
return createProvisionerDaemonClient(t, done, provisionerDaemonTestServer{
acquireJobWithCancel: acq.acquireWithCancel,
updateJob: noopUpdateJob,
failJob: func(ctx context.Context, job *proto.FailedJob) (*proto.Empty, error) {
didFail.Store(true)
return &proto.Empty{}, nil
},
}), nil
}, provisionerd.LocalProvisioners{
"someprovisioner": createProvisionerClient(t, done, provisionerTestServer{
init: func(s *provisionersdk.Session, r *provisionersdk.InitRequest, canceledOrComplete <-chan struct{}) *sdkproto.InitComplete {
return &sdkproto.InitComplete{}
},
graph: func(s *provisionersdk.Session, r *sdkproto.GraphRequest, canceledOrComplete <-chan struct{}) *sdkproto.GraphComplete {
return &sdkproto.GraphComplete{}
},
plan: func(
s *provisionersdk.Session,
_ *sdkproto.PlanRequest,
cancelOrComplete <-chan struct{},
) *sdkproto.PlanComplete {
return &sdkproto.PlanComplete{
Error: "some error",
}
},
apply: func(
_ *provisionersdk.Session,
_ *sdkproto.ApplyRequest,
_ <-chan struct{},
) *sdkproto.ApplyComplete {
t.Error("should not apply when plan errors")
return &sdkproto.ApplyComplete{
Error: "some error",
}
},
}),
})
require.Condition(t, closedWithin(acq.complete, testutil.WaitShort))
require.NoError(t, closer.Close())
assert.True(t, didFail.Load(), "should fail the job")
})
// Simulates when there is no coderd to connect to. So the client connection
// will never be established.
t.Run("ShutdownNoCoderd", func(t *testing.T) {
t.Parallel()
done := make(chan struct{})
t.Cleanup(func() {
close(done)
})
connectAttemptedClose := sync.Once{}
connectAttempted := make(chan struct{})
server := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
// This is the dial out to Coderd, which in this unit test will always fail.
connectAttemptedClose.Do(func() { close(connectAttempted) })
return nil, xerrors.New("client connection always fails")
}, provisionerd.LocalProvisioners{
"someprovisioner": createProvisionerClient(t, done, provisionerTestServer{}),
})
// Wait for at least 1 attempt to connect to ensure the connect go routine
// is running.
require.Condition(t, closedWithin(connectAttempted, testutil.WaitShort))
// The test is ensuring this Shutdown call does not block indefinitely.
// If it does, the context will return with an error, and the test will
// fail.
shutdownCtx := testutil.Context(t, testutil.WaitShort)
err := server.Shutdown(shutdownCtx, true)
require.NoError(t, err, "shutdown did not unblock. Failed to close the server gracefully.")
require.NoError(t, server.Close())
})
t.Run("Shutdown", func(t *testing.T) {
t.Parallel()
done := make(chan struct{})
t.Cleanup(func() {
close(done)
})
var updated sync.Once
var completed sync.Once
updateChan := make(chan struct{})
completeChan := make(chan struct{})
server := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
return createProvisionerDaemonClient(t, done, provisionerDaemonTestServer{
acquireJobWithCancel: func(stream proto.DRPCProvisionerDaemon_AcquireJobWithCancelStream) error {
err := stream.Send(&proto.AcquiredJob{
JobId: "test",
Provisioner: "someprovisioner",
TemplateSourceArchive: testutil.CreateTar(t, map[string]string{
"test.txt": "content",
}),
Type: &proto.AcquiredJob_WorkspaceBuild_{
WorkspaceBuild: &proto.AcquiredJob_WorkspaceBuild{
Metadata: &sdkproto.Metadata{},
},
},
})
assert.NoError(t, err)
return nil
},
updateJob: func(ctx context.Context, update *proto.UpdateJobRequest) (*proto.UpdateJobResponse, error) {
if len(update.Logs) > 0 {
for _, log := range update.Logs {
if log.Source != proto.LogSource_PROVISIONER {
continue
}
// Close on a log so we know when the job is in progress!
updated.Do(func() {
close(updateChan)
})
break
}
}
return &proto.UpdateJobResponse{}, nil
},
failJob: func(ctx context.Context, job *proto.FailedJob) (*proto.Empty, error) {
completed.Do(func() {
close(completeChan)
})
return &proto.Empty{}, nil
},
}), nil
}, provisionerd.LocalProvisioners{
"someprovisioner": createProvisionerClient(t, done, provisionerTestServer{
init: func(s *provisionersdk.Session, r *provisionersdk.InitRequest, canceledOrComplete <-chan struct{}) *sdkproto.InitComplete {
return &sdkproto.InitComplete{}
},
plan: func(
s *provisionersdk.Session,
_ *sdkproto.PlanRequest,
canceledOrComplete <-chan struct{},
) *sdkproto.PlanComplete {
s.ProvisionLog(sdkproto.LogLevel_DEBUG, "in progress")
<-canceledOrComplete
return &sdkproto.PlanComplete{
Error: "some error",
}
},
apply: func(
_ *provisionersdk.Session,
_ *sdkproto.ApplyRequest,
_ <-chan struct{},
) *sdkproto.ApplyComplete {
t.Error("should not apply when shut down during plan")
return &sdkproto.ApplyComplete{}
},
}),
})
require.Condition(t, closedWithin(updateChan, testutil.WaitShort))
err := server.Shutdown(context.Background(), true)
require.NoError(t, err)
require.Condition(t, closedWithin(completeChan, testutil.WaitShort))
require.NoError(t, server.Close())
})
t.Run("ShutdownFromJob", func(t *testing.T) {
t.Parallel()
done := make(chan struct{})
t.Cleanup(func() {
close(done)
})
var completed sync.Once
var updated sync.Once
updateChan := make(chan struct{})
completeChan := make(chan struct{})
server := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
return createProvisionerDaemonClient(t, done, provisionerDaemonTestServer{
acquireJobWithCancel: func(stream proto.DRPCProvisionerDaemon_AcquireJobWithCancelStream) error {
err := stream.Send(&proto.AcquiredJob{
JobId: "test",
Provisioner: "someprovisioner",
TemplateSourceArchive: testutil.CreateTar(t, map[string]string{
"test.txt": "content",
}),
Type: &proto.AcquiredJob_WorkspaceBuild_{
WorkspaceBuild: &proto.AcquiredJob_WorkspaceBuild{
Metadata: &sdkproto.Metadata{},
},
},
})
assert.NoError(t, err)
return nil
},
updateJob: func(ctx context.Context, update *proto.UpdateJobRequest) (*proto.UpdateJobResponse, error) {
resp := &proto.UpdateJobResponse{}
if len(update.Logs) > 0 {
for _, log := range update.Logs {
if log.Source != proto.LogSource_PROVISIONER {
continue
}
// Close on a log so we know when the job is in progress!
updated.Do(func() {
close(updateChan)
})
break
}
}
// start returning Canceled once we've gotten at least one log.
select {
case <-updateChan:
resp.Canceled = true
default:
// pass
}
return resp, nil
},
failJob: func(ctx context.Context, job *proto.FailedJob) (*proto.Empty, error) {
completed.Do(func() {
close(completeChan)
})
return &proto.Empty{}, nil
},
}), nil
}, provisionerd.LocalProvisioners{
"someprovisioner": createProvisionerClient(t, done, provisionerTestServer{
init: func(s *provisionersdk.Session, r *provisionersdk.InitRequest, canceledOrComplete <-chan struct{}) *sdkproto.InitComplete {
return &sdkproto.InitComplete{}
},
plan: func(
s *provisionersdk.Session,
_ *sdkproto.PlanRequest,
canceledOrComplete <-chan struct{},
) *sdkproto.PlanComplete {
s.ProvisionLog(sdkproto.LogLevel_DEBUG, "in progress")
<-canceledOrComplete
return &sdkproto.PlanComplete{
Error: "some error",
}
},
apply: func(
_ *provisionersdk.Session,
_ *sdkproto.ApplyRequest,
_ <-chan struct{},
) *sdkproto.ApplyComplete {
t.Error("should not apply when shut down during plan")
return &sdkproto.ApplyComplete{}
},
}),
})
require.Condition(t, closedWithin(updateChan, testutil.WaitShort))
require.Condition(t, closedWithin(completeChan, testutil.WaitShort))
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
defer cancel()
require.NoError(t, server.Shutdown(ctx, true))
require.NoError(t, server.Close())
})
t.Run("ReconnectAndFail", func(t *testing.T) {
t.Parallel()
done := make(chan struct{})
t.Cleanup(func() {
close(done)
})
var (
second atomic.Bool
failChan = make(chan struct{})
failOnce sync.Once
failedChan = make(chan struct{})
failedOnce sync.Once
completeChan = make(chan struct{})
completeOnce sync.Once
)
server := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
client := createProvisionerDaemonClient(t, done, provisionerDaemonTestServer{
acquireJobWithCancel: func(stream proto.DRPCProvisionerDaemon_AcquireJobWithCancelStream) error {
job := &proto.AcquiredJob{
JobId: "test",
Provisioner: "someprovisioner",
TemplateSourceArchive: testutil.CreateTar(t, map[string]string{
"test.txt": "content",
}),
Type: &proto.AcquiredJob_WorkspaceBuild_{
WorkspaceBuild: &proto.AcquiredJob_WorkspaceBuild{
Metadata: &sdkproto.Metadata{},
},
},
}
if second.Load() {
job = &proto.AcquiredJob{}
_, err := stream.Recv()
assert.NoError(t, err)
}
err := stream.Send(job)
assert.NoError(t, err)
return nil
},
updateJob: func(ctx context.Context, update *proto.UpdateJobRequest) (*proto.UpdateJobResponse, error) {
return &proto.UpdateJobResponse{}, nil
},
failJob: func(ctx context.Context, job *proto.FailedJob) (*proto.Empty, error) {
assert.Equal(t, job.JobId, "test")
if second.Load() {
completeOnce.Do(func() { close(completeChan) })
return &proto.Empty{}, nil
}
failOnce.Do(func() { close(failChan) })
<-failedChan
return &proto.Empty{}, nil
},
})
if !second.Load() {
go func() {
<-failChan
_ = client.DRPCConn().Close()
second.Store(true)
time.Sleep(50 * time.Millisecond)
failedOnce.Do(func() { close(failedChan) })
}()
}
return client, nil
}, provisionerd.LocalProvisioners{
"someprovisioner": createProvisionerClient(t, done, provisionerTestServer{
init: func(s *provisionersdk.Session, r *provisionersdk.InitRequest, canceledOrComplete <-chan struct{}) *sdkproto.InitComplete {
return &sdkproto.InitComplete{}
},
plan: func(
_ *provisionersdk.Session,
_ *sdkproto.PlanRequest,
_ <-chan struct{},
) *sdkproto.PlanComplete {
return &sdkproto.PlanComplete{
Error: "some error",
}
},
apply: func(
_ *provisionersdk.Session,
_ *sdkproto.ApplyRequest,
_ <-chan struct{},
) *sdkproto.ApplyComplete {
t.Error("should not apply when error during plan")
return &sdkproto.ApplyComplete{}
},
}),
})
require.Condition(t, closedWithin(completeChan, testutil.WaitShort))
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
defer cancel()
require.NoError(t, server.Shutdown(ctx, true))
require.NoError(t, server.Close())
})
// ReconnectAndComplete drops the daemon connection while the first
// completeJob call is in flight and passes once a later acquire call, made
// after the drop, is observed (completeChan closes within WaitShort).
// Presumably this exercises provisionerd redialing and retrying the
// completion — confirm against the provisionerd implementation.
t.Run("ReconnectAndComplete", func(t *testing.T) {
t.Parallel()
done := make(chan struct{})
t.Cleanup(func() {
close(done)
})
var (
// second becomes true once the first daemon connection has been closed.
second atomic.Bool
// failChan is closed by the first completeJob call to trigger the disconnect.
failChan = make(chan struct{})
failOnce sync.Once
// failedChan is closed after the disconnect to release the stalled completeJob call.
failedChan = make(chan struct{})
failedOnce sync.Once
// completeChan is closed by the first acquire call that sees second == true.
completeChan = make(chan struct{})
completeOnce sync.Once
)
server := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
client := createProvisionerDaemonClient(t, done, provisionerDaemonTestServer{
acquireJobWithCancel: func(stream proto.DRPCProvisionerDaemon_AcquireJobWithCancelStream) error {
if second.Load() {
// After the disconnect: signal success, then wait on Recv instead of
// handing out another job.
completeOnce.Do(func() { close(completeChan) })
_, err := stream.Recv()
assert.NoError(t, err)
return nil
}
job := &proto.AcquiredJob{
JobId: "test",
Provisioner: "someprovisioner",
TemplateSourceArchive: testutil.CreateTar(t, map[string]string{
"test.txt": "content",
}),
Type: &proto.AcquiredJob_WorkspaceBuild_{
WorkspaceBuild: &proto.AcquiredJob_WorkspaceBuild{
Metadata: &sdkproto.Metadata{},
},
},
}
err := stream.Send(job)
assert.NoError(t, err)
return nil
},
// failJob always reports a session-shutdown error rather than accepting the failure.
failJob: func(ctx context.Context, job *proto.FailedJob) (*proto.Empty, error) {
return nil, yamux.ErrSessionShutdown
},
updateJob: func(ctx context.Context, update *proto.UpdateJobRequest) (*proto.UpdateJobResponse, error) {
return &proto.UpdateJobResponse{}, nil
},
completeJob: func(ctx context.Context, job *proto.CompletedJob) (*proto.Empty, error) {
if second.Load() {
return &proto.Empty{}, nil
}
// First attempt: kick off the disconnect goroutine and stall until it
// has closed the connection.
failOnce.Do(func() { close(failChan) })
<-failedChan
return &proto.Empty{}, nil
},
})
// Only dials made before the disconnect spawn the goroutine that closes
// the connection.
if !second.Load() {
go func() {
<-failChan
_ = client.DRPCConn().Close()
second.Store(true)
// Brief pause before releasing the stalled completeJob call.
time.Sleep(50 * time.Millisecond)
failedOnce.Do(func() { close(failedChan) })
}()
}
return client, nil
}, provisionerd.LocalProvisioners{
// Every provisioner stage succeeds immediately with an empty response.
"someprovisioner": createProvisionerClient(t, done, provisionerTestServer{
init: func(s *provisionersdk.Session, r *provisionersdk.InitRequest, canceledOrComplete <-chan struct{}) *sdkproto.InitComplete {
return &sdkproto.InitComplete{}
},
plan: func(
_ *provisionersdk.Session,
_ *sdkproto.PlanRequest,
_ <-chan struct{},
) *sdkproto.PlanComplete {
return &sdkproto.PlanComplete{}
},
apply: func(
_ *provisionersdk.Session,
_ *sdkproto.ApplyRequest,
_ <-chan struct{},
) *sdkproto.ApplyComplete {
return &sdkproto.ApplyComplete{}
},
graph: func(s *provisionersdk.Session, r *sdkproto.GraphRequest, canceledOrComplete <-chan struct{}) *sdkproto.GraphComplete {
return &sdkproto.GraphComplete{}
},
}),
})
require.Condition(t, closedWithin(completeChan, testutil.WaitShort))
t.Log("completeChan closed")
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
defer cancel()
require.NoError(t, server.Shutdown(ctx, true))
require.NoError(t, server.Close())
})
// UpdatesBeforeComplete records every daemon RPC (and every log line carried
// by UpdateJob) in call order, then asserts that CompleteJob is the final
// recorded operation and that a "Cleaning Up" log with empty output was
// recorded before it.
t.Run("UpdatesBeforeComplete", func(t *testing.T) {
t.Parallel()
done := make(chan struct{})
t.Cleanup(func() {
close(done)
})
logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true})
// m guards ops, which is appended to from the daemon RPC handlers below.
m := sync.Mutex{}
var ops []string
completeChan := make(chan struct{})
completeOnce := sync.Once{}
server := createProvisionerd(t, func(ctx context.Context) (proto.DRPCProvisionerDaemonClient, error) {
return createProvisionerDaemonClient(t, done, provisionerDaemonTestServer{
acquireJobWithCancel: func(stream proto.DRPCProvisionerDaemon_AcquireJobWithCancelStream) error {
m.Lock()
defer m.Unlock()
logger.Info(ctx, "provisioner stage: AcquiredJob")
// Once any operation has been recorded the job was already handed out:
// wait on Recv and answer with an empty job.
if len(ops) > 0 {
_, err := stream.Recv()
assert.NoError(t, err)
err = stream.Send(&proto.AcquiredJob{})
assert.NoError(t, err)
return nil
}
ops = append(ops, "AcquireJob")
err := stream.Send(&proto.AcquiredJob{
JobId: "test",
Provisioner: "someprovisioner",
TemplateSourceArchive: testutil.CreateTar(t, map[string]string{
"test.txt": "content",
}),
Type: &proto.AcquiredJob_WorkspaceBuild_{
WorkspaceBuild: &proto.AcquiredJob_WorkspaceBuild{
Metadata: &sdkproto.Metadata{},
},
},
})
assert.NoError(t, err)
return nil
},
updateJob: func(ctx context.Context, update *proto.UpdateJobRequest) (*proto.UpdateJobResponse, error) {
m.Lock()
defer m.Unlock()
logger.Info(ctx, "provisioner stage: UpdateJob")
ops = append(ops, "UpdateJob")
// Record each log line as "Log: <stage> | <output>" so the assertions
// below can match on it.
for _, log := range update.Logs {
ops = append(ops, fmt.Sprintf("Log: %s | %s", log.Stage, log.Output))
}
return &proto.UpdateJobResponse{}, nil
},
completeJob: func(ctx context.Context, job *proto.CompletedJob) (*proto.Empty, error) {
m.Lock()
defer m.Unlock()
logger.Info(ctx, "provisioner stage: CompleteJob")
ops = append(ops, "CompleteJob")
completeOnce.Do(func() { close(completeChan) })
return &proto.Empty{}, nil
},
failJob: func(ctx context.Context, job *proto.FailedJob) (*proto.Empty, error) {
m.Lock()
defer m.Unlock()
logger.Info(ctx, "provisioner stage: FailJob")
ops = append(ops, "FailJob")
return &proto.Empty{}, nil
},
}), nil
}, provisionerd.LocalProvisioners{
"someprovisioner": createProvisionerClient(t, done, provisionerTestServer{
init: func(s *provisionersdk.Session, r *provisionersdk.InitRequest, canceledOrComplete <-chan struct{}) *sdkproto.InitComplete {
return &sdkproto.InitComplete{}
},
graph: func(s *provisionersdk.Session, r *sdkproto.GraphRequest, canceledOrComplete <-chan struct{}) *sdkproto.GraphComplete {
return &sdkproto.GraphComplete{}
},
plan: func(
s *provisionersdk.Session,
_ *sdkproto.PlanRequest,
_ <-chan struct{},
) *sdkproto.PlanComplete {
// Emit a provisioner log during plan.
s.ProvisionLog(sdkproto.LogLevel_DEBUG, "wow")
return &sdkproto.PlanComplete{}
},
apply: func(
s *provisionersdk.Session,
_ *sdkproto.ApplyRequest,
_ <-chan struct{},
) *sdkproto.ApplyComplete {
// Emit a provisioner log during apply.
s.ProvisionLog(sdkproto.LogLevel_DEBUG, "wow")
return &sdkproto.ApplyComplete{}
},
}),
})
require.Condition(t, closedWithin(completeChan, testutil.WaitShort))
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitShort)
defer cancel()
require.NoError(t, server.Shutdown(ctx, true))
require.NoError(t, server.Close())
// NOTE(review): ops is read here without holding m; presumably safe because
// the server has already been shut down and closed — confirm.
// NOTE(review): testify's Equal takes (expected, actual); the arguments
// below appear to be in the opposite order, which only affects the failure
// message.
assert.Equal(t, ops[len(ops)-1], "CompleteJob")
// The "Cleaning Up" log is not emitted anywhere in this test; presumably it
// comes from the provisionerd job runner — verify.
assert.Contains(t, ops[0:len(ops)-1], "Log: Cleaning Up | ")
})
}
// createProvisionerd builds a provisionerd server from the supplied dialer and
// local provisioners. The server logs at debug level through a test logger
// configured with IgnoreErrors and uses a 50ms UpdateInterval. A test cleanup
// shuts the server down (bounded by testutil.WaitShort) and closes it,
// discarding both errors.
func createProvisionerd(t *testing.T, dialer provisionerd.Dialer, connector provisionerd.LocalProvisioners) *provisionerd.Server {
	logger := slogtest.Make(t, &slogtest.Options{IgnoreErrors: true}).
		Named("provisionerd").
		Leveled(slog.LevelDebug)
	opts := &provisionerd.Options{
		Logger:         logger,
		UpdateInterval: 50 * time.Millisecond,
		Connector:      connector,
	}
	srv := provisionerd.New(dialer, opts)

	t.Cleanup(func() {
		shutdownCtx, stop := context.WithTimeout(context.Background(), testutil.WaitShort)
		defer stop()
		// Best effort: tests that care about these errors shut the server
		// down themselves and assert on the result.
		_ = srv.Shutdown(shutdownCtx, true)
		_ = srv.Close()
	})
	return srv
}
// createProvisionerDaemonClient returns a provisionerd protobuf client that
// talks over an in-memory pipe to the given server implementation.
//
// If server.failJob is unset it is replaced with a handler that fails the test
// with the job's error. The serving goroutine is stopped and awaited in a test
// cleanup. Both this function and that cleanup report a test error if done is
// already closed when they run.
func createProvisionerDaemonClient(t *testing.T, done <-chan struct{}, server provisionerDaemonTestServer) proto.DRPCProvisionerDaemonClient {
	t.Helper()

	if server.failJob == nil {
		// Without an explicit handler, surface the job's error as a test
		// failure so it cannot go unnoticed.
		server.failJob = func(_ context.Context, failed *proto.FailedJob) (*proto.Empty, error) {
			assert.Fail(t, failed.Error)
			return &proto.Empty{}, nil
		}
	}

	clientConn, serverConn := drpcsdk.MemTransportPipe()
	t.Cleanup(func() {
		_ = clientConn.Close()
		_ = serverConn.Close()
	})

	router := drpcmux.New()
	registerErr := proto.DRPCRegisterProvisionerDaemon(router, &server)
	require.NoError(t, registerErr)
	drpcSrv := drpcserver.NewWithOptions(router, drpcserver.Options{
		Manager: drpcsdk.DefaultDRPCOptions(nil),
	})

	serveCtx, stopServing := context.WithCancel(context.Background())
	served := make(chan struct{})
	go func() {
		defer close(served)
		_ = drpcSrv.Serve(serveCtx, serverConn)
	}()
	// Registered after the pipe cleanup, so it runs first: stop serving and
	// wait for the goroutine before the pipes are closed.
	t.Cleanup(func() {
		stopServing()
		<-served
		select {
		case <-done:
			t.Error("createProvisionerDaemonClient cleanup after test was done!")
		default:
		}
	})

	// Non-blocking check that the caller is not dialing after the test ended.
	select {
	case <-done:
		t.Error("called createProvisionerDaemonClient after test was done!")
	default:
	}
	return proto.NewDRPCProvisionerDaemonClient(clientConn)
}
// createProvisionerClient returns a provisioner protobuf client that talks
// over an in-memory pipe to the given server implementation, which is served
// by provisionersdk.Serve with a per-test temporary work directory.
//
// The serving goroutine is stopped and awaited in a test cleanup. Both this
// function and that cleanup report a test error if done is already closed when
// they run.
func createProvisionerClient(t *testing.T, done <-chan struct{}, server provisionerTestServer) sdkproto.DRPCProvisionerClient {
	t.Helper()

	clientConn, serverConn := drpcsdk.MemTransportPipe()
	t.Cleanup(func() {
		_ = clientConn.Close()
		_ = serverConn.Close()
	})

	serveCtx, stopServing := context.WithCancel(context.Background())
	served := make(chan struct{})
	workDir := t.TempDir()
	go func() {
		defer close(served)
		opts := &provisionersdk.ServeOptions{
			Listener:      serverConn,
			Logger:        testutil.Logger(t).Named("test-provisioner"),
			WorkDirectory: workDir,
		}
		_ = provisionersdk.Serve(serveCtx, &server, opts)
	}()
	// Registered last, so it runs first: stop serving and wait for the
	// goroutine before the earlier cleanups run.
	t.Cleanup(func() {
		stopServing()
		<-served
		select {
		case <-done:
			t.Error("createProvisionerClient cleanup after test was done!")
		default:
		}
	})

	// Non-blocking check that the caller is not connecting after the test ended.
	select {
	case <-done:
		t.Error("called createProvisionerClient after test was done!")
	default:
	}
	return sdkproto.NewDRPCProvisionerClient(clientConn)
}
// provisionerTestServer adapts a set of per-test callbacks to the server
// value handed to provisionersdk.Serve by createProvisionerClient. Each
// exported method forwards to the field of the matching lowercase name
// without a nil check, so calling a method whose callback is unset panics.
type provisionerTestServer struct {
// init backs Init.
init func(s *provisionersdk.Session, r *provisionersdk.InitRequest, canceledOrComplete <-chan struct{}) *sdkproto.InitComplete
// parse backs Parse.
parse func(s *provisionersdk.Session, r *sdkproto.ParseRequest, canceledOrComplete <-chan struct{}) *sdkproto.ParseComplete
// plan backs Plan.
plan func(s *provisionersdk.Session, r *sdkproto.PlanRequest, canceledOrComplete <-chan struct{}) *sdkproto.PlanComplete
// apply backs Apply.
apply func(s *provisionersdk.Session, r *sdkproto.ApplyRequest, canceledOrComplete <-chan struct{}) *sdkproto.ApplyComplete
// graph backs Graph.
graph func(s *provisionersdk.Session, r *sdkproto.GraphRequest, canceledOrComplete <-chan struct{}) *sdkproto.GraphComplete
}
// Init forwards the request to the test's init callback and returns its
// result. It panics if no init callback was set.
func (p *provisionerTestServer) Init(s *provisionersdk.Session, r *provisionersdk.InitRequest, canceledOrComplete <-chan struct{}) *sdkproto.InitComplete {
	initFn := p.init
	return initFn(s, r, canceledOrComplete)
}
// Parse forwards the request to the test's parse callback and returns its
// result. It panics if no parse callback was set.
func (p *provisionerTestServer) Parse(s *provisionersdk.Session, r *sdkproto.ParseRequest, canceledOrComplete <-chan struct{}) *sdkproto.ParseComplete {
	parseFn := p.parse
	return parseFn(s, r, canceledOrComplete)
}
// Plan forwards the request to the test's plan callback and returns its
// result. It panics if no plan callback was set.
func (p *provisionerTestServer) Plan(s *provisionersdk.Session, r *sdkproto.PlanRequest, canceledOrComplete <-chan struct{}) *sdkproto.PlanComplete {
	planFn := p.plan
	return planFn(s, r, canceledOrComplete)
}
// Apply forwards the request to the test's apply callback and returns its
// result. It panics if no apply callback was set.
func (p *provisionerTestServer) Apply(s *provisionersdk.Session, r *sdkproto.ApplyRequest, canceledOrComplete <-chan struct{}) *sdkproto.ApplyComplete {
	applyFn := p.apply
	return applyFn(s, r, canceledOrComplete)
}
// Graph forwards the request to the test's graph callback and returns its
// result. It panics if no graph callback was set.
func (p *provisionerTestServer) Graph(s *provisionersdk.Session, r *sdkproto.GraphRequest, canceledOrComplete <-chan struct{}) *sdkproto.GraphComplete {
	graphFn := p.graph
	return graphFn(s, r, canceledOrComplete)
}
// UploadFile hands the stream to the test's uploadFile callback and returns
// its error. It panics if no uploadFile callback was set.
func (p *provisionerDaemonTestServer) UploadFile(stream proto.DRPCProvisionerDaemon_UploadFileStream) error {
	uploadFn := p.uploadFile
	return uploadFn(stream)
}
// Compile-time check that *provisionerDaemonTestServer implements
// proto.DRPCProvisionerDaemonServer.
var _ proto.DRPCProvisionerDaemonServer = (*provisionerDaemonTestServer)(nil)
// provisionerDaemonTestServer fulfills the protobuf interface for a
// ProvisionerDaemon by delegating each RPC to an optional callback, so tests
// can supply behavior per case.
//
// Only acquireJobWithCancel and commitQuota have built-in defaults in their
// methods; createProvisionerDaemonClient additionally installs a default
// failJob. The remaining callbacks are invoked without a nil check.
type provisionerDaemonTestServer struct {
// acquireJobWithCancel backs AcquireJobWithCancel. When nil, the method
// waits on Recv and then sends an empty job.
acquireJobWithCancel func(stream proto.DRPCProvisionerDaemon_AcquireJobWithCancelStream) error
// commitQuota backs CommitQuota. When nil, the method replies Ok: true.
commitQuota func(ctx context.Context, com *proto.CommitQuotaRequest) (*proto.CommitQuotaResponse, error)
// updateJob backs UpdateJob.
updateJob func(ctx context.Context, update *proto.UpdateJobRequest) (*proto.UpdateJobResponse, error)
// failJob backs FailJob.
failJob func(ctx context.Context, job *proto.FailedJob) (*proto.Empty, error)
// completeJob backs CompleteJob.
completeJob func(ctx context.Context, job *proto.CompletedJob) (*proto.Empty, error)
// uploadFile backs UploadFile.
uploadFile func(stream proto.DRPCProvisionerDaemon_UploadFileStream) error
// downloadFile backs DownloadFile.
downloadFile func(request *proto.FileRequest, stream proto.DRPCProvisionerDaemon_DownloadFileStream) error
}
// DownloadFile hands the request and stream to the test's downloadFile
// callback and returns its error. It panics if no downloadFile callback was
// set.
func (p *provisionerDaemonTestServer) DownloadFile(request *proto.FileRequest, stream proto.DRPCProvisionerDaemon_DownloadFileStream) error {
	downloadFn := p.downloadFile
	return downloadFn(request, stream)
}
// AcquireJob is not supported by this test server: it never returns a job and
// always fails with the error "deprecated!".
func (*provisionerDaemonTestServer) AcquireJob(context.Context, *proto.Empty) (*proto.AcquiredJob, error) {
	var noJob *proto.AcquiredJob
	return noJob, xerrors.New("deprecated!")
}
// AcquireJobWithCancel delegates to the test's acquireJobWithCancel callback
// when one is set. Otherwise it waits for a message from the client (the
// cancel), replies with an empty job, and returns nil, ignoring any stream
// errors along the way.
func (p *provisionerDaemonTestServer) AcquireJobWithCancel(stream proto.DRPCProvisionerDaemon_AcquireJobWithCancelStream) error {
	if p.acquireJobWithCancel == nil {
		// default behavior is to wait for cancel
		_, _ = stream.Recv()
		_ = stream.Send(&proto.AcquiredJob{})
		return nil
	}
	return p.acquireJobWithCancel(stream)
}
// CommitQuota delegates to the test's commitQuota callback when one is set;
// otherwise it approves the request by replying with Ok set to true.
func (p *provisionerDaemonTestServer) CommitQuota(ctx context.Context, com *proto.CommitQuotaRequest) (*proto.CommitQuotaResponse, error) {
	if p.commitQuota != nil {
		return p.commitQuota(ctx, com)
	}
	approved := &proto.CommitQuotaResponse{Ok: true}
	return approved, nil
}
// UpdateJob forwards the update to the test's updateJob callback and returns
// its result. It panics if no updateJob callback was set.
func (p *provisionerDaemonTestServer) UpdateJob(ctx context.Context, update *proto.UpdateJobRequest) (*proto.UpdateJobResponse, error) {
	updateFn := p.updateJob
	return updateFn(ctx, update)
}
// FailJob forwards the failed job to the test's failJob callback and returns
// its result. It panics if no failJob callback was set.
func (p *provisionerDaemonTestServer) FailJob(ctx context.Context, job *proto.FailedJob) (*proto.Empty, error) {
	failFn := p.failJob
	return failFn(ctx, job)
}
// CompleteJob forwards the completed job to the test's completeJob callback
// and returns its result. It panics if no completeJob callback was set.
func (p *provisionerDaemonTestServer) CompleteJob(ctx context.Context, job *proto.CompletedJob) (*proto.Empty, error) {
	completeFn := p.completeJob
	return completeFn(ctx, job)
}
// acquireOne provides an acquire handler (acquireWithCancel) that sends a
// single provisioner job on its first call; every later call waits on the
// stream for a message from the client and then replies with an empty job.
// The complete channel is closed on the 2nd call.
type acquireOne struct {
// t receives assertion failures from acquireWithCancel.
t *testing.T
// mu guards called and is held for the whole of each acquireWithCancel call.
mu sync.Mutex
// job is the job sent on the first call.
job *proto.AcquiredJob
// called counts acquireWithCancel invocations.
called int
// complete is closed when called reaches 2.
complete chan struct{}
}
// newAcquireOne returns an acquireOne that will hand out job once and report
// assertion failures to t. Its complete channel starts open.
func newAcquireOne(t *testing.T, job *proto.AcquiredJob) *acquireOne {
	one := &acquireOne{complete: make(chan struct{})}
	one.t = t
	one.job = job
	return one
}
// acquireWithCancel sends a.job on the first call. On every later call it
// waits for a message from the client and replies with an empty job, ignoring
// stream errors; the second call also closes a.complete. It always returns
// nil. The mutex is held for the entire call, so concurrent calls are
// serialized.
func (a *acquireOne) acquireWithCancel(stream proto.DRPCProvisionerDaemon_AcquireJobWithCancelStream) error {
	a.mu.Lock()
	defer a.mu.Unlock()

	a.called++
	if a.called >= 2 {
		if a.called == 2 {
			close(a.complete)
		}
		_, _ = stream.Recv()
		_ = stream.Send(&proto.AcquiredJob{})
		return nil
	}

	// dRPC is racy: if the test cancels right away, Send can report
	// context.Canceled even though the message was delivered. Ignore that one
	// error; if the cancel really did beat the job, something else in the test
	// is expected to fail.
	if err := stream.Send(a.job); err != nil && !xerrors.Is(err, context.Canceled) {
		assert.NoError(a.t, err)
	}
	return nil
}
// extractInit returns an init callback that extracts the request's template
// source archive through the session's Files onto the OS filesystem. The
// returned InitComplete carries a formatted error message if extraction fails
// and is empty otherwise.
func extractInit(t *testing.T) func(s *provisionersdk.Session, r *provisionersdk.InitRequest, canceledOrComplete <-chan struct{}) *sdkproto.InitComplete {
	logger := slogtest.Make(t, nil)
	return func(s *provisionersdk.Session, r *provisionersdk.InitRequest, canceledOrComplete <-chan struct{}) *sdkproto.InitComplete {
		complete := &sdkproto.InitComplete{}
		if err := s.Files.ExtractArchive(s.Context(), logger, afero.NewOsFs(), r.TemplateSourceArchive, nil); err != nil {
			complete.Error = fmt.Sprintf("failed to extract template source archive: %v", err)
		}
		return complete
	}
}