diff --git a/coderd/provisionerdserver/provisionerdserver.go b/coderd/provisionerdserver/provisionerdserver.go index 6a06edb357..1ce46670a9 100644 --- a/coderd/provisionerdserver/provisionerdserver.go +++ b/coderd/provisionerdserver/provisionerdserver.go @@ -509,6 +509,18 @@ func (s *server) acquireProtoJob(ctx context.Context, job database.ProvisionerJo if err != nil { return nil, failJob(fmt.Sprintf("get owner: %s", err)) } + + // Fetch the file id of the cached module files if it exists. + versionModulesFile := "" + tfvals, err := s.Database.GetTemplateVersionTerraformValues(ctx, templateVersion.ID) + if err != nil && !xerrors.Is(err, sql.ErrNoRows) { + // Older templates (before dynamic parameters) will not have cached module files. + return nil, failJob(fmt.Sprintf("get template version terraform values: %s", err)) + } + if err == nil && tfvals.CachedModuleFiles.Valid { + versionModulesFile = tfvals.CachedModuleFiles.UUID.String() + } + var ownerSSHPublicKey, ownerSSHPrivateKey string if ownerSSHKey, err := s.Database.GetGitSSHKey(ctx, owner.ID); err != nil { if !xerrors.Is(err, sql.ErrNoRows) { @@ -732,6 +744,7 @@ func (s *server) acquireProtoJob(ctx context.Context, job database.ProvisionerJo PrebuiltWorkspaceBuildStage: input.PrebuiltWorkspaceBuildStage, TaskId: task.ID.String(), TaskPrompt: task.Prompt, + TemplateVersionModulesFile: versionModulesFile, }, LogLevel: input.LogLevel, }, @@ -1423,54 +1436,12 @@ func (s *server) prepareForNotifyWorkspaceManualBuildFailed(ctx context.Context, func (s *server) UploadFile(stream proto.DRPCProvisionerDaemon_UploadFileStream) error { var file *sdkproto.DataBuilder // Always terminate the stream with an empty response. + //nolint:errcheck // We can't do much about send errors here. defer stream.SendAndClose(&proto.Empty{}) -UploadFileStream: - for { - msg, err := stream.Recv() - if err != nil { - return xerrors.Errorf("receive complete job with files: %w", err) - } - - switch typed := msg.Type.(type) { - case *sdkproto.FileUpload_DataUpload: - if file != nil { - return xerrors.New("unexpected file upload while waiting for file completion") - } - - file, err = sdkproto.NewDataBuilder(&sdkproto.DataUpload{ - UploadType: typed.DataUpload.UploadType, - DataHash: typed.DataUpload.DataHash, - FileSize: typed.DataUpload.FileSize, - Chunks: typed.DataUpload.Chunks, - }) - if err != nil { - return xerrors.Errorf("unable to create file upload: %w", err) - } - - if file.IsDone() { - // If a file is 0 bytes, we can consider it done immediately. - // This should never really happen in practice, but we handle it gracefully. - break UploadFileStream - } - case *sdkproto.FileUpload_ChunkPiece: - if file == nil { - return xerrors.New("unexpected chunk piece while waiting for file upload") - } - - done, err := file.Add(&sdkproto.ChunkPiece{ - Data: typed.ChunkPiece.Data, - FullDataHash: typed.ChunkPiece.FullDataHash, - PieceIndex: typed.ChunkPiece.PieceIndex, - }) - if err != nil { - return xerrors.Errorf("unable to add chunk piece: %w", err) - } - - if done { - break UploadFileStream - } - } + file, err := provisionersdk.HandleReceivingDataUpload(stream) + if err != nil { + return err } fileData, err := file.Complete() @@ -1518,9 +1489,71 @@ UploadFileStream: return nil } -func (*server) DownloadFile(_ *proto.FileRequest, _ proto.DRPCProvisionerDaemon_DownloadFileStream) error { - // TODO implemented in follow up PR - panic("implement me") +// DownloadFile pulls the requested file from the database and sends it over the protobuf stream in chunks. +func (s *server) DownloadFile(request *proto.FileRequest, stream proto.DRPCProvisionerDaemon_DownloadFileStream) error { + //nolint:errcheck + defer stream.CloseSend() + //nolint:gocritic // Provisionerd is the actor here. + ctx := dbauthz.AsProvisionerd(stream.Context()) + + // A graceful error message will help debugging. + fail := func(err error) error { + _ = stream.Send(&sdkproto.FileUpload{ + Type: &sdkproto.FileUpload_Error{ + Error: &sdkproto.FailedFile{ + Error: err.Error(), + }, + }, + }) + return err + } + if request.FileId == "" || request.FileId == uuid.Nil.String() { + return fail(xerrors.New("file id is required")) + } + + fid, err := uuid.Parse(request.FileId) + if err != nil { + return fail(xerrors.Errorf("invalid file id: %w", err)) + } + + file, err := s.Database.GetFileByID(ctx, fid) + if err != nil { + return fail(xerrors.Errorf("get file: %w", err)) + } + + switch request.UploadType { + case sdkproto.DataUploadType_UPLOAD_TYPE_MODULE_FILES: + // This check is not perfect. If these conditions are not true, then the file is not a modules file. + if file.CreatedBy != uuid.Nil || file.Mimetype != tarMimeType { + return fail(xerrors.Errorf("file %s is not a modules file", fid)) + } + default: + return fail(xerrors.Errorf("unsupported file upload type: %s", request.UploadType)) + } + + upload, chunks := sdkproto.BytesToDataUpload(sdkproto.DataUploadType_UPLOAD_TYPE_MODULE_FILES, file.Data) + + err = stream.Send(&sdkproto.FileUpload{ + Type: &sdkproto.FileUpload_DataUpload{DataUpload: upload}, + }) + if err != nil { + return fail(xerrors.Errorf("send file upload: %w", err)) + } + + for i, c := range chunks { + if ctx.Err() != nil { + return fail(ctx.Err()) + } + + err = stream.Send(&sdkproto.FileUpload{ + Type: &sdkproto.FileUpload_ChunkPiece{ChunkPiece: c}, + }) + if err != nil { + return fail(xerrors.Errorf("send chunk piece %d: %w", i, err)) + } + } + + return nil } // CompleteJob is triggered by a provision daemon to mark a provisioner job as completed. diff --git a/coderd/provisionerdserver/upload_file_test.go b/coderd/provisionerdserver/upload_file_test.go index b503a6781f..d041bb9f98 100644 --- a/coderd/provisionerdserver/upload_file_test.go +++ b/coderd/provisionerdserver/upload_file_test.go @@ -120,7 +120,7 @@ func TestUploadFileErrorScenarios(t *testing.T) { stream.messages <- up err := server.UploadFile(stream) - require.ErrorContains(t, err, "unexpected file upload while waiting for file completion") + require.ErrorContains(t, err, "unexpected file download while waiting for file completion") require.True(t, stream.isDone(), "stream should be done after error") }) diff --git a/provisioner/echo/serve.go b/provisioner/echo/serve.go index 2a695aebaf..ced9f89586 100644 --- a/provisioner/echo/serve.go +++ b/provisioner/echo/serve.go @@ -205,8 +205,8 @@ func (*echo) Parse(sess *provisionersdk.Session, _ *proto.ParseRequest, _ <-chan return provisionersdk.ParseErrorf("complete response missing") } -func (*echo) Init(sess *provisionersdk.Session, req *proto.InitRequest, canceledOrComplete <-chan struct{}) *proto.InitComplete { - err := sess.Files.ExtractArchive(sess.Context(), sess.Logger, afero.NewOsFs(), req.TemplateSourceArchive) +func (*echo) Init(sess *provisionersdk.Session, req *provisionersdk.InitRequest, canceledOrComplete <-chan struct{}) *proto.InitComplete { + err := sess.Files.ExtractArchive(sess.Context(), sess.Logger, afero.NewOsFs(), req.TemplateSourceArchive, nil) if err != nil { return provisionersdk.InitErrorf("extract archive: %s", err.Error()) } diff --git a/provisioner/terraform/provision.go b/provisioner/terraform/provision.go index 8f428525f5..4b95d6d2f2 100644 --- a/provisioner/terraform/provision.go +++ b/provisioner/terraform/provision.go @@ -69,7 +69,7 @@ func (s *server) setupContexts(parent context.Context, canceledOrComplete <-chan } func (s *server) Init( - sess *provisionersdk.Session, request *proto.InitRequest, canceledOrComplete <-chan struct{}, + sess *provisionersdk.Session, request *provisionersdk.InitRequest, canceledOrComplete <-chan struct{}, ) *proto.InitComplete { ctx, span := s.startTrace(sess.Context(), tracing.FuncName()) defer span.End() @@ -84,7 +84,7 @@ func (s *server) Init( logTerraformEnvVars(sess) // TODO: These logs should probably be streamed back to the provisioner runner. - err := sess.Files.ExtractArchive(ctx, s.logger, afero.NewOsFs(), request.GetTemplateSourceArchive()) + err := sess.Files.ExtractArchive(ctx, s.logger, afero.NewOsFs(), request.GetTemplateSourceArchive(), request.ModuleArchive) if err != nil { return provisionersdk.InitErrorf("extract template archive: %s", err) } diff --git a/provisionerd/provisionerd.go b/provisionerd/provisionerd.go index bf219fbe1b..769bdb8446 100644 --- a/provisionerd/provisionerd.go +++ b/provisionerd/provisionerd.go @@ -27,6 +27,7 @@ import ( "github.com/coder/coder/v2/codersdk/drpcsdk" "github.com/coder/coder/v2/provisionerd/proto" "github.com/coder/coder/v2/provisionerd/runner" + "github.com/coder/coder/v2/provisionersdk" sdkproto "github.com/coder/coder/v2/provisionersdk/proto" "github.com/coder/retry" ) @@ -417,6 +418,7 @@ func (p *Server) acquireAndRunOne(client proto.DRPCProvisionerDaemonClient) erro runner.Options{ Updater: p, QuotaCommitter: p, + FileDownloader: p, Logger: p.opts.Logger.Named("runner"), Provisioner: resp.Client, UpdateInterval: p.opts.UpdateInterval, @@ -527,7 +529,7 @@ func (p *Server) UploadModuleFiles(ctx context.Context, moduleFiles []byte) erro stream, err := client.UploadFile(ctx) if err != nil { - return nil, xerrors.Errorf("failed to start CompleteJobWithFiles stream: %w", err) + return nil, xerrors.Errorf("failed to start UploadModuleFiles stream: %w", err) } defer stream.Close() @@ -567,6 +569,36 @@ func (p *Server) UploadModuleFiles(ctx context.Context, moduleFiles []byte) erro return nil } +// DownloadFile will download a module file from coderd. +func (p *Server) DownloadFile(ctx context.Context, request *proto.FileRequest) ([]byte, error) { + data, err := clientDoWithRetries(ctx, p.client, func(ctx context.Context, client proto.DRPCProvisionerDaemonClient) ([]byte, error) { + // Add some timeout to prevent the stream from hanging indefinitely if something goes wrong. + ctx, cancel := context.WithTimeout(ctx, 5*time.Minute) + defer cancel() + + stream, err := client.DownloadFile(ctx, request) + if err != nil { + return nil, xerrors.Errorf("failed to start DownloadFile stream: %w", err) + } + defer stream.Close() + + file, err := provisionersdk.HandleReceivingDataUpload(stream) + if err != nil { + return nil, xerrors.Errorf("failed to handle receiving data upload: %w", err) + } + data, err := file.Complete() + if err != nil { + return nil, xerrors.Errorf("failed to download file: %w", err) + } + return data, nil + }) + if err != nil { + return nil, xerrors.Errorf("download file %s: %w", request.FileId, err) + } + + return data, nil +} + func (p *Server) CompleteJob(ctx context.Context, in *proto.CompletedJob) error { // If the moduleFiles exceed the max message size, we need to upload them separately. if ti, ok := in.Type.(*proto.CompletedJob_TemplateImport_); ok { diff --git a/provisionerd/provisionerd_test.go b/provisionerd/provisionerd_test.go index 913c291162..4ac7553e80 100644 --- a/provisionerd/provisionerd_test.go +++ b/provisionerd/provisionerd_test.go @@ -132,7 +132,7 @@ func TestProvisionerd(t *testing.T) { } return c }, - init: func(s *provisionersdk.Session, r *sdkproto.InitRequest, canceledOrComplete <-chan struct{}) *sdkproto.InitComplete { + init: func(s *provisionersdk.Session, r *provisionersdk.InitRequest, canceledOrComplete <-chan struct{}) *sdkproto.InitComplete { closerMutex.Lock() defer closerMutex.Unlock() err := closer.Close() @@ -197,7 +197,7 @@ func TestProvisionerd(t *testing.T) { Readme: make([]byte, largeSize), } }, - init: func(s *provisionersdk.Session, r *sdkproto.InitRequest, canceledOrComplete <-chan struct{}) *sdkproto.InitComplete { + init: func(s *provisionersdk.Session, r *provisionersdk.InitRequest, canceledOrComplete <-chan struct{}) *sdkproto.InitComplete { return &sdkproto.InitComplete{} }, plan: func( @@ -276,7 +276,7 @@ func TestProvisionerd(t *testing.T) { <-cancelOrComplete return &sdkproto.ParseComplete{} }, - init: func(s *provisionersdk.Session, r *sdkproto.InitRequest, canceledOrComplete <-chan struct{}) *sdkproto.InitComplete { + init: func(s *provisionersdk.Session, r *provisionersdk.InitRequest, canceledOrComplete <-chan struct{}) *sdkproto.InitComplete { return &sdkproto.InitComplete{} }, }), @@ -417,7 +417,7 @@ func TestProvisionerd(t *testing.T) { }), nil }, provisionerd.LocalProvisioners{ "someprovisioner": createProvisionerClient(t, done, provisionerTestServer{ - init: func(s *provisionersdk.Session, r *sdkproto.InitRequest, canceledOrComplete <-chan struct{}) *sdkproto.InitComplete { + init: func(s *provisionersdk.Session, r *provisionersdk.InitRequest, canceledOrComplete <-chan struct{}) *sdkproto.InitComplete { return &sdkproto.InitComplete{} }, plan: func( @@ -488,7 +488,7 @@ func TestProvisionerd(t *testing.T) { }), nil }, provisionerd.LocalProvisioners{ "someprovisioner": createProvisionerClient(t, done, provisionerTestServer{ - init: func(s *provisionersdk.Session, r *sdkproto.InitRequest, canceledOrComplete <-chan struct{}) *sdkproto.InitComplete { + init: func(s *provisionersdk.Session, r *provisionersdk.InitRequest, canceledOrComplete <-chan struct{}) *sdkproto.InitComplete { return &sdkproto.InitComplete{} }, plan: func( @@ -566,7 +566,7 @@ func TestProvisionerd(t *testing.T) { }), nil }, provisionerd.LocalProvisioners{ "someprovisioner": createProvisionerClient(t, done, provisionerTestServer{ - init: func(s *provisionersdk.Session, r *sdkproto.InitRequest, canceledOrComplete <-chan struct{}) *sdkproto.InitComplete { + init: func(s *provisionersdk.Session, r *provisionersdk.InitRequest, canceledOrComplete <-chan struct{}) *sdkproto.InitComplete { return &sdkproto.InitComplete{} }, plan: func( @@ -641,7 +641,7 @@ func TestProvisionerd(t *testing.T) { }), nil }, provisionerd.LocalProvisioners{ "someprovisioner": createProvisionerClient(t, done, provisionerTestServer{ - init: func(s *provisionersdk.Session, r *sdkproto.InitRequest, canceledOrComplete <-chan struct{}) *sdkproto.InitComplete { + init: func(s *provisionersdk.Session, r *provisionersdk.InitRequest, canceledOrComplete <-chan struct{}) *sdkproto.InitComplete { return &sdkproto.InitComplete{} }, graph: func(s *provisionersdk.Session, r *sdkproto.GraphRequest, canceledOrComplete <-chan struct{}) *sdkproto.GraphComplete { @@ -757,7 +757,7 @@ func TestProvisionerd(t *testing.T) { }), nil }, provisionerd.LocalProvisioners{ "someprovisioner": createProvisionerClient(t, done, provisionerTestServer{ - init: func(s *provisionersdk.Session, r *sdkproto.InitRequest, canceledOrComplete <-chan struct{}) *sdkproto.InitComplete { + init: func(s *provisionersdk.Session, r *provisionersdk.InitRequest, canceledOrComplete <-chan struct{}) *sdkproto.InitComplete { return &sdkproto.InitComplete{} }, plan: func( @@ -848,7 +848,7 @@ func TestProvisionerd(t *testing.T) { }), nil }, provisionerd.LocalProvisioners{ "someprovisioner": createProvisionerClient(t, done, provisionerTestServer{ - init: func(s *provisionersdk.Session, r *sdkproto.InitRequest, canceledOrComplete <-chan struct{}) *sdkproto.InitComplete { + init: func(s *provisionersdk.Session, r *provisionersdk.InitRequest, canceledOrComplete <-chan struct{}) *sdkproto.InitComplete { return &sdkproto.InitComplete{} }, plan: func( @@ -945,7 +945,7 @@ func TestProvisionerd(t *testing.T) { return client, nil }, provisionerd.LocalProvisioners{ "someprovisioner": createProvisionerClient(t, done, provisionerTestServer{ - init: func(s *provisionersdk.Session, r *sdkproto.InitRequest, canceledOrComplete <-chan struct{}) *sdkproto.InitComplete { + init: func(s *provisionersdk.Session, r *provisionersdk.InitRequest, canceledOrComplete <-chan struct{}) *sdkproto.InitComplete { return &sdkproto.InitComplete{} }, plan: func( @@ -1041,7 +1041,7 @@ func TestProvisionerd(t *testing.T) { return client, nil }, provisionerd.LocalProvisioners{ "someprovisioner": createProvisionerClient(t, done, provisionerTestServer{ - init: func(s *provisionersdk.Session, r *sdkproto.InitRequest, canceledOrComplete <-chan struct{}) *sdkproto.InitComplete { + init: func(s *provisionersdk.Session, r *provisionersdk.InitRequest, canceledOrComplete <-chan struct{}) *sdkproto.InitComplete { return &sdkproto.InitComplete{} }, plan: func( @@ -1141,7 +1141,7 @@ func TestProvisionerd(t *testing.T) { }), nil }, provisionerd.LocalProvisioners{ "someprovisioner": createProvisionerClient(t, done, provisionerTestServer{ - init: func(s *provisionersdk.Session, r *sdkproto.InitRequest, canceledOrComplete <-chan struct{}) *sdkproto.InitComplete { + init: func(s *provisionersdk.Session, r *provisionersdk.InitRequest, canceledOrComplete <-chan struct{}) *sdkproto.InitComplete { return &sdkproto.InitComplete{} }, graph: func(s *provisionersdk.Session, r *sdkproto.GraphRequest, canceledOrComplete <-chan struct{}) *sdkproto.GraphComplete { @@ -1275,14 +1275,14 @@ func createProvisionerClient(t *testing.T, done <-chan struct{}, server provisio } type provisionerTestServer struct { - init func(s *provisionersdk.Session, r *sdkproto.InitRequest, canceledOrComplete <-chan struct{}) *sdkproto.InitComplete + init func(s *provisionersdk.Session, r *provisionersdk.InitRequest, canceledOrComplete <-chan struct{}) *sdkproto.InitComplete parse func(s *provisionersdk.Session, r *sdkproto.ParseRequest, canceledOrComplete <-chan struct{}) *sdkproto.ParseComplete plan func(s *provisionersdk.Session, r *sdkproto.PlanRequest, canceledOrComplete <-chan struct{}) *sdkproto.PlanComplete apply func(s *provisionersdk.Session, r *sdkproto.ApplyRequest, canceledOrComplete <-chan struct{}) *sdkproto.ApplyComplete graph func(s *provisionersdk.Session, r *sdkproto.GraphRequest, canceledOrComplete <-chan struct{}) *sdkproto.GraphComplete } -func (p *provisionerTestServer) Init(s *provisionersdk.Session, r *sdkproto.InitRequest, canceledOrComplete <-chan struct{}) *sdkproto.InitComplete { +func (p *provisionerTestServer) Init(s *provisionersdk.Session, r *provisionersdk.InitRequest, canceledOrComplete <-chan struct{}) *sdkproto.InitComplete { return p.init(s, r, canceledOrComplete) } @@ -1399,10 +1399,10 @@ func (a *acquireOne) acquireWithCancel(stream proto.DRPCProvisionerDaemon_Acquir return nil } -func extractInit(t *testing.T) func(s *provisionersdk.Session, r *sdkproto.InitRequest, canceledOrComplete <-chan struct{}) *sdkproto.InitComplete { +func extractInit(t *testing.T) func(s *provisionersdk.Session, r *provisionersdk.InitRequest, canceledOrComplete <-chan struct{}) *sdkproto.InitComplete { logger := slogtest.Make(t, nil) - return func(s *provisionersdk.Session, r *sdkproto.InitRequest, canceledOrComplete <-chan struct{}) *sdkproto.InitComplete { - err := s.Files.ExtractArchive(s.Context(), logger, afero.NewOsFs(), r.TemplateSourceArchive) + return func(s *provisionersdk.Session, r *provisionersdk.InitRequest, canceledOrComplete <-chan struct{}) *sdkproto.InitComplete { + err := s.Files.ExtractArchive(s.Context(), logger, afero.NewOsFs(), r.TemplateSourceArchive, nil) if err != nil { return &sdkproto.InitComplete{ Error: fmt.Sprintf("failed to extract template source archive: %v", err), diff --git a/provisionerd/runner/init.go b/provisionerd/runner/init.go index 558ed302aa..45c762b7fa 100644 --- a/provisionerd/runner/init.go +++ b/provisionerd/runner/init.go @@ -12,18 +12,60 @@ import ( ) //nolint:revive -func (r *Runner) init(ctx context.Context, omitModules bool, templateArchive []byte) (*sdkproto.InitComplete, *proto.FailedJob) { +func (r *Runner) init(ctx context.Context, omitModules bool, templateArchive []byte, moduleTar []byte) (*sdkproto.InitComplete, *proto.FailedJob) { ctx, span := r.startTrace(ctx, tracing.FuncName()) defer span.End() + // If `moduleTar` is populated, `init` will send it over in multiple parts. This + // It must be called before the initial request to populate the correct hash if + // there is data to send. This is safe to call on nil or empty slices. + data, chunks := sdkproto.BytesToDataUpload(sdkproto.DataUploadType_UPLOAD_TYPE_MODULE_FILES, moduleTar) + + hash := []byte{} + if len(moduleTar) > 0 { + hash = data.DataHash + } + err := r.session.Send(&sdkproto.Request{Type: &sdkproto.Request_Init{Init: &sdkproto.InitRequest{ TemplateSourceArchive: templateArchive, OmitModuleFiles: omitModules, + InitialModuleTarHash: hash, }}}) if err != nil { return nil, r.failedJobf("send init request: %v", err) } + // If the module tar exists, send over the data. + if len(moduleTar) > 0 { + err = r.session.Send(&sdkproto.Request{ + Type: &sdkproto.Request_File{ + File: &sdkproto.FileUpload{ + Type: &sdkproto.FileUpload_DataUpload{ + DataUpload: data, + }, + }, + }, + }) + if err != nil { + return nil, r.failedJobf("send module files data upload: %v", err) + } + + for _, c := range chunks { + err = r.session.Send(&sdkproto.Request{ + Type: &sdkproto.Request_File{ + File: &sdkproto.FileUpload{ + Type: &sdkproto.FileUpload_ChunkPiece{ + ChunkPiece: c, + }, + }, + }, + }) + if err != nil { + return nil, r.failedJobf("send module files chunk: %v", err) + } + } + } + nevermind := make(chan struct{}) defer close(nevermind) go func() { diff --git a/provisionerd/runner/runner.go b/provisionerd/runner/runner.go index da542f1da8..b8c5dc6df5 100644 --- a/provisionerd/runner/runner.go +++ b/provisionerd/runner/runner.go @@ -48,6 +48,7 @@ type Runner struct { job *proto.AcquiredJob sender JobUpdater quotaCommitter QuotaCommitter + fileDownloader FileDownloader logger slog.Logger provisioner sdkproto.DRPCProvisionerClient lastUpdate atomic.Pointer[time.Time] @@ -96,13 +97,19 @@ type JobUpdater interface { FailJob(ctx context.Context, in *proto.FailedJob) error CompleteJob(ctx context.Context, in *proto.CompletedJob) error } + type QuotaCommitter interface { CommitQuota(ctx context.Context, in *proto.CommitQuotaRequest) (*proto.CommitQuotaResponse, error) } +type FileDownloader interface { + DownloadFile(ctx context.Context, req *proto.FileRequest) ([]byte, error) +} + type Options struct { Updater JobUpdater QuotaCommitter QuotaCommitter + FileDownloader FileDownloader Logger slog.Logger Provisioner sdkproto.DRPCProvisionerClient UpdateInterval time.Duration @@ -142,6 +149,7 @@ func New( job: job, sender: opts.Updater, quotaCommitter: opts.QuotaCommitter, + fileDownloader: opts.FileDownloader, logger: logger, provisioner: opts.Provisioner, updateInterval: opts.UpdateInterval, @@ -521,7 +529,7 @@ func (r *Runner) runTemplateImport(ctx context.Context) (*proto.CompletedJob, *p } // Initialize the Terraform working directory - initResp, failedInit := r.init(ctx, false, r.job.GetTemplateSourceArchive()) + initResp, failedInit := r.init(ctx, false, r.job.GetTemplateSourceArchive(), nil) if failedInit != nil { return nil, failedInit } @@ -787,7 +795,7 @@ func (r *Runner) runTemplateDryRun(ctx context.Context) (*proto.CompletedJob, *p } // Initialize the Terraform working directory - initResp, failedJob := r.init(ctx, false, r.job.GetTemplateSourceArchive()) + initResp, failedJob := r.init(ctx, false, r.job.GetTemplateSourceArchive(), nil) if failedJob != nil { return nil, failedJob } @@ -901,8 +909,25 @@ func (r *Runner) runWorkspaceBuild(ctx context.Context) (*proto.CompletedJob, *p // timings collects all timings from each phase of the build timings := make([]*sdkproto.Timing, 0) + var cachedModulesTar []byte + // Download modules if cached in coderd + if r.job.GetWorkspaceBuild().Metadata.TemplateVersionModulesFile != "" { + fileID, err := uuid.Parse(r.job.GetWorkspaceBuild().Metadata.TemplateVersionModulesFile) + if err != nil { + return nil, r.failedWorkspaceBuildf("invalid template version modules file ID: %s", err) + } + // Download the module tar file + cachedModulesTar, err = r.fileDownloader.DownloadFile(ctx, &proto.FileRequest{ + FileId: fileID.String(), + UploadType: sdkproto.DataUploadType_UPLOAD_TYPE_MODULE_FILES, + }) + if err != nil { + return nil, r.failedWorkspaceBuildf("failed to download template version modules file: %s", err) + } + } + // Initialize the Terraform working directory - initComplete, failedJob := r.init(ctx, true, r.job.GetTemplateSourceArchive()) + initComplete, failedJob := r.init(ctx, true, r.job.GetTemplateSourceArchive(), cachedModulesTar) if failedJob != nil { return nil, failedJob } diff --git a/provisionersdk/dataupload.go b/provisionersdk/dataupload.go new file mode 100644 index 0000000000..be7716c4b9 --- /dev/null +++ b/provisionersdk/dataupload.go @@ -0,0 +1,85 @@ +package provisionersdk + +import ( + "io" + + "golang.org/x/xerrors" + + sdkproto "github.com/coder/coder/v2/provisionersdk/proto" +) + +// HandleReceivingDataUpload can download a multi-part file from a proto stream. +// The stream is expected to be closed by the caller. +func HandleReceivingDataUpload(stream interface { + Recv() (*sdkproto.FileUpload, error) +}, +) (*sdkproto.DataBuilder, error) { + var file *sdkproto.DataBuilder +UploadFileStream: + for { + msg, err := stream.Recv() + if err != nil { + if xerrors.Is(err, io.EOF) { + // Do not return an EOF here, as it is a "retryable error" in the client context. + // This failure indicates the download stream was closed prematurely, and it is a + // fatal error. + return nil, xerrors.Errorf("stream closed before file download complete") + } + return nil, xerrors.Errorf("receive file download: %w", err) + } + + switch typed := msg.Type.(type) { + case *sdkproto.FileUpload_Error: + return nil, xerrors.Errorf("download file: %s", typed.Error.Error) + case *sdkproto.FileUpload_DataUpload: + if file != nil { + return nil, xerrors.New("unexpected file download while waiting for file completion") + } + + file, err = sdkproto.NewDataBuilder(&sdkproto.DataUpload{ + UploadType: typed.DataUpload.UploadType, + DataHash: typed.DataUpload.DataHash, + FileSize: typed.DataUpload.FileSize, + Chunks: typed.DataUpload.Chunks, + }) + if err != nil { + return nil, xerrors.Errorf("unable to create file download: %w", err) + } + + if file.IsDone() { + // If a file is 0 bytes, we can consider it done immediately. + // This should never really happen in practice, but we handle it gracefully. + break UploadFileStream + } + case *sdkproto.FileUpload_ChunkPiece: + if file == nil { + return nil, xerrors.New("unexpected chunk piece while waiting for file upload") + } + + done, err := file.Add(&sdkproto.ChunkPiece{ + Data: typed.ChunkPiece.Data, + FullDataHash: typed.ChunkPiece.FullDataHash, + PieceIndex: typed.ChunkPiece.PieceIndex, + }) + if err != nil { + return nil, xerrors.Errorf("unable to add a chunk piece: %w", err) + } + + if done { + break UploadFileStream + } + default: + // This should never happen + return nil, xerrors.Errorf("received unknown file upload message type: %T", msg.Type) + } + } + + // This needs to be called again by the caller to retrieve the final payload. + // It is called here to do a hash check and ensure the file is correct. + _, err := file.Complete() + if err != nil { + return nil, xerrors.Errorf("complete file upload: %w", err) + } + + return file, nil +} diff --git a/provisionersdk/serve.go b/provisionersdk/serve.go index d8f52374c3..4afcee9626 100644 --- a/provisionersdk/serve.go +++ b/provisionersdk/serve.go @@ -33,8 +33,15 @@ type ServeOptions struct { Experiments codersdk.Experiments } +// InitRequest wraps the InitRequest proto with the module archive bytes, which +// is downloaded by the SDK from the hash field in the InitRequest proto. +type InitRequest struct { + *proto.InitRequest + ModuleArchive []byte +} + type Server interface { - Init(s *Session, r *proto.InitRequest, canceledOrComplete <-chan struct{}) *proto.InitComplete + Init(s *Session, r *InitRequest, canceledOrComplete <-chan struct{}) *proto.InitComplete Parse(s *Session, r *proto.ParseRequest, canceledOrComplete <-chan struct{}) *proto.ParseComplete Plan(s *Session, r *proto.PlanRequest, canceledOrComplete <-chan struct{}) *proto.PlanComplete Apply(s *Session, r *proto.ApplyRequest, canceledOrComplete <-chan struct{}) *proto.ApplyComplete diff --git a/provisionersdk/serve_test.go b/provisionersdk/serve_test.go index ff79c3c16b..dd86167d67 100644 --- a/provisionersdk/serve_test.go +++ b/provisionersdk/serve_test.go @@ -149,7 +149,7 @@ var _ provisionersdk.Server = unimplementedServer{} type unimplementedServer struct{} -func (unimplementedServer) Init(s *provisionersdk.Session, r *proto.InitRequest, canceledOrComplete <-chan struct{}) *proto.InitComplete { +func (unimplementedServer) Init(s *provisionersdk.Session, r *provisionersdk.InitRequest, canceledOrComplete <-chan struct{}) *proto.InitComplete { return &proto.InitComplete{} } diff --git a/provisionersdk/session.go b/provisionersdk/session.go index 1359a7a66f..094fe38aba 100644 --- a/provisionersdk/session.go +++ b/provisionersdk/session.go @@ -125,6 +125,7 @@ func (s *Session) handleRequests() error { if s.initialized { return xerrors.New("cannot init more than once per session") } + initResp, err := s.handleInitRequest(init, requests) if err != nil { return err @@ -185,9 +186,47 @@ func (s *Session) handleRequests() error { return nil } +// fromChannel implements the `Recv` api using an underlying channel for +// downloading files. +type fromChannel struct { + requests <-chan *proto.Request +} + +func (f *fromChannel) Recv() (*proto.FileUpload, error) { + next, ok := <-f.requests + if !ok { + return nil, xerrors.New("channel closed") + } + + // Only file download messages are expected here. + file := next.GetFile() + if file == nil { + return nil, xerrors.Errorf("expected file upload") + } + + return file, nil +} + func (s *Session) handleInitRequest(init *proto.InitRequest, requests <-chan *proto.Request) (*proto.InitComplete, error) { - r := &request[*proto.InitRequest, *proto.InitComplete]{ - req: init, + req := &InitRequest{ + InitRequest: init, + ModuleArchive: nil, + } + if len(init.GetInitialModuleTarHash()) > 0 { + file, err := HandleReceivingDataUpload(&fromChannel{requests: requests}) + if err != nil { + return nil, err + } + + data, err := file.Complete() + if err != nil { + return nil, err + } + req.ModuleArchive = data + } + + r := &request[*InitRequest, *proto.InitComplete]{ + req: req, session: s, serverFn: s.server.Init, cancels: requests, @@ -279,7 +318,7 @@ func (s *Session) ProvisionLog(level proto.LogLevel, output string) { } type pRequest interface { - *proto.ParseRequest | *proto.InitRequest | *proto.PlanRequest | *proto.ApplyRequest | *proto.GraphRequest + *proto.ParseRequest | *InitRequest | *proto.PlanRequest | *proto.ApplyRequest | *proto.GraphRequest } type pComplete interface { diff --git a/provisionersdk/tfpath/tfpath.go b/provisionersdk/tfpath/tfpath.go index e9fd099bb7..fc13bc17d0 100644 --- a/provisionersdk/tfpath/tfpath.go +++ b/provisionersdk/tfpath/tfpath.go @@ -72,17 +72,39 @@ func (l Layout) ModulesFilePath() string { return filepath.Join(l.ModulesDirectory(), "modules.json") } -func (l Layout) ExtractArchive(ctx context.Context, logger slog.Logger, fs afero.Fs, templateSourceArchive []byte) error { - logger.Info(ctx, "unpacking template source archive", - slog.F("size_bytes", len(templateSourceArchive)), - ) - - err := fs.MkdirAll(l.WorkDirectory(), 0o700) +// ExtractArchive extracts the provided template source archive and modules archive into the working directory. +// `modulesArchive` is optional and can be nil or empty. +func (l Layout) ExtractArchive(ctx context.Context, logger slog.Logger, fs afero.Fs, templateSourceArchive, modulesArchive []byte) error { + err := extractArchive(ctx, logger, fs, l.WorkDirectory(), templateSourceArchive) if err != nil { - return xerrors.Errorf("create work directory %q: %w", l.WorkDirectory(), err) + return xerrors.Errorf("extract template source archive: %w", err) } - reader := tar.NewReader(bytes.NewBuffer(templateSourceArchive)) + if len(modulesArchive) > 0 { + err = extractArchive(ctx, logger, fs, l.WorkDirectory(), modulesArchive) + if err != nil { + return xerrors.Errorf("extract modules archive: %w", err) + } + } + return nil +} + +func isValidSessionDir(dirName string) bool { + match, err := filepath.Match(sessionDirPrefix+"*", dirName) + return err == nil && match +} + +func extractArchive(ctx context.Context, logger slog.Logger, fs afero.Fs, directory string, archive []byte) error { + logger.Info(ctx, "unpacking source archive", + slog.F("size_bytes", len(archive)), + ) + + err := fs.MkdirAll(directory, 0o700) + if err != nil { + return xerrors.Errorf("create work directory %q: %w", directory, err) + } + + reader := tar.NewReader(bytes.NewBuffer(archive)) for { header, err := reader.Next() if err != nil { @@ -103,8 +125,8 @@ func (l Layout) ExtractArchive(ctx context.Context, logger slog.Logger, fs afero } // nolint: gosec // Safe to no-lint because the filepath.IsLocal check above. - headerPath := filepath.Join(l.WorkDirectory(), header.Name) - if !strings.HasPrefix(headerPath, filepath.Clean(l.WorkDirectory())) { + headerPath := filepath.Join(directory, header.Name) + if !strings.HasPrefix(headerPath, filepath.Clean(directory)) { return xerrors.New("tar attempts to target relative upper directory") } mode := header.FileInfo().Mode() @@ -220,8 +242,3 @@ func (l Layout) CleanStaleSessions(ctx context.Context, logger slog.Logger, fs a } return nil } - -func isValidSessionDir(dirName string) bool { - match, err := filepath.Match(sessionDirPrefix+"*", dirName) - return err == nil && match -}