From 7c8deaf0d6bfb4d9fb77127345888f53554517a8 Mon Sep 17 00:00:00 2001 From: Steven Masley Date: Wed, 12 Nov 2025 09:24:07 -0600 Subject: [PATCH] chore: refactor terraform paths to a central structure (#20566) Refactors all Terraform file path logic into a centralized tfpath package. This consolidates all path construction into a single, testable Layout type. Instead of passing around `string` for directories, pass around the `Layout` which has the file location methods on it. --- provisioner/echo/serve.go | 4 +- provisioner/terraform/executor.go | 30 ++--- provisioner/terraform/modules.go | 10 +- provisioner/terraform/parse.go | 4 +- provisioner/terraform/provision.go | 12 +- provisioner/terraform/serve.go | 5 +- provisionerd/provisionerd_test.go | 2 +- provisionersdk/cleanup_test.go | 24 ++-- provisionersdk/session.go | 135 ++------------------ provisionersdk/tfpath/tfpath.go | 192 +++++++++++++++++++++++++++++ 10 files changed, 244 insertions(+), 174 deletions(-) create mode 100644 provisionersdk/tfpath/tfpath.go diff --git a/provisioner/echo/serve.go b/provisioner/echo/serve.go index 5069424156..26d1fcbe3a 100644 --- a/provisioner/echo/serve.go +++ b/provisioner/echo/serve.go @@ -122,8 +122,8 @@ func readResponses(sess *provisionersdk.Session, trans string, suffix string) ([ for i := 0; ; i++ { paths := []string{ // Try more specific path first, then fallback to generic. - filepath.Join(sess.WorkDirectory, fmt.Sprintf("%d.%s.%s", i, trans, suffix)), - filepath.Join(sess.WorkDirectory, fmt.Sprintf("%d.%s", i, suffix)), + filepath.Join(sess.Files.WorkDirectory(), fmt.Sprintf("%d.%s.%s", i, trans, suffix)), + filepath.Join(sess.Files.WorkDirectory(), fmt.Sprintf("%d.%s", i, suffix)), } for pathIndex, path := range paths { _, err := os.Stat(path) diff --git a/provisioner/terraform/executor.go b/provisioner/terraform/executor.go index c7811fe272..345b0e72fb 100644 --- a/provisioner/terraform/executor.go +++ b/provisioner/terraform/executor.go @@ -10,7 +10,6 @@ import ( "io" "os" "os/exec" - "path/filepath" "runtime" "strings" "sync" @@ -22,6 +21,7 @@ import ( "golang.org/x/xerrors" "cdr.dev/slog" + "github.com/coder/coder/v2/provisionersdk/tfpath" "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/tracing" @@ -38,10 +38,10 @@ type executor struct { server *server mut *sync.Mutex binaryPath string - // cachePath and workdir must not be used by multiple processes at once. + // cachePath and files must not be used by multiple processes at once. cachePath string cliConfigPath string - workdir string + files tfpath.Layout // used to capture execution times at various stages timings *timingAggregator } @@ -90,7 +90,7 @@ func (e *executor) execWriteOutput(ctx, killCtx context.Context, args, env []str // #nosec cmd := exec.CommandContext(killCtx, e.binaryPath, args...) - cmd.Dir = e.workdir + cmd.Dir = e.files.WorkDirectory() if env == nil { // We don't want to passthrough host env when unset. env = []string{} @@ -131,7 +131,7 @@ func (e *executor) execParseJSON(ctx, killCtx context.Context, args, env []strin // #nosec cmd := exec.CommandContext(killCtx, e.binaryPath, args...) - cmd.Dir = e.workdir + cmd.Dir = e.files.WorkDirectory() cmd.Env = env out := &bytes.Buffer{} stdErr := &bytes.Buffer{} @@ -225,7 +225,7 @@ func (e *executor) init(ctx, killCtx context.Context, logr logSink) error { defer e.mut.Unlock() // Record lock file checksum before init - lockFilePath := filepath.Join(e.workdir, ".terraform.lock.hcl") + lockFilePath := e.files.TerraformLockFile() preInitChecksum := checksumFileCRC32(ctx, e.logger, lockFilePath) outWriter, doneOut := e.provisionLogWriter(logr) @@ -289,14 +289,6 @@ func checksumFileCRC32(ctx context.Context, logger slog.Logger, path string) uin return crc32.ChecksumIEEE(content) } -func getPlanFilePath(workdir string) string { - return filepath.Join(workdir, "terraform.tfplan") -} - -func getStateFilePath(workdir string) string { - return filepath.Join(workdir, "terraform.tfstate") -} - // revive:disable-next-line:flag-parameter func (e *executor) plan(ctx, killCtx context.Context, env, vars []string, logr logSink, req *proto.PlanRequest) (*proto.PlanComplete, error) { ctx, span := e.server.startTrace(ctx, tracing.FuncName()) @@ -307,7 +299,7 @@ func (e *executor) plan(ctx, killCtx context.Context, env, vars []string, logr l metadata := req.Metadata - planfilePath := getPlanFilePath(e.workdir) + planfilePath := e.files.PlanFilePath() args := []string{ "plan", "-no-color", @@ -359,7 +351,7 @@ func (e *executor) plan(ctx, killCtx context.Context, env, vars []string, logr l // a workspace build. This removes some added costs of sending the modules // payload back to coderd if coderd is just going to ignore it. if !req.OmitModuleFiles { - moduleFiles, err = GetModulesArchive(os.DirFS(e.workdir)) + moduleFiles, err = GetModulesArchive(os.DirFS(e.files.WorkDirectory())) if err != nil { // TODO: we probably want to persist this error or make it louder eventually e.logger.Warn(ctx, "failed to archive terraform modules", slog.Error(err)) @@ -551,7 +543,7 @@ func (e *executor) graph(ctx, killCtx context.Context) (string, error) { var out strings.Builder cmd := exec.CommandContext(killCtx, e.binaryPath, args...) // #nosec cmd.Stdout = &out - cmd.Dir = e.workdir + cmd.Dir = e.files.WorkDirectory() cmd.Env = e.basicEnv() e.server.logger.Debug(ctx, "executing terraform command graph", @@ -588,7 +580,7 @@ func (e *executor) apply( "-auto-approve", "-input=false", "-json", - getPlanFilePath(e.workdir), + e.files.PlanFilePath(), } outWriter, doneOut := e.provisionLogWriter(logr) @@ -608,7 +600,7 @@ func (e *executor) apply( if err != nil { return nil, err } - statefilePath := getStateFilePath(e.workdir) + statefilePath := e.files.StateFilePath() stateContent, err := os.ReadFile(statefilePath) if err != nil { return nil, xerrors.Errorf("read statefile %q: %w", statefilePath, err) diff --git a/provisioner/terraform/modules.go b/provisioner/terraform/modules.go index f0b40ea951..38bfd65e84 100644 --- a/provisioner/terraform/modules.go +++ b/provisioner/terraform/modules.go @@ -7,7 +7,6 @@ import ( "io" "io/fs" "os" - "path/filepath" "strings" "time" @@ -15,6 +14,7 @@ import ( "github.com/coder/coder/v2/coderd/util/xio" "github.com/coder/coder/v2/provisionersdk/proto" + "github.com/coder/coder/v2/provisionersdk/tfpath" ) const ( @@ -39,10 +39,6 @@ type modulesFile struct { Modules []*module `json:"Modules"` } -func getModulesFilePath(workdir string) string { - return filepath.Join(workdir, ".terraform", "modules", "modules.json") -} - func parseModulesFile(filePath string) ([]*proto.Module, error) { modules := &modulesFile{} data, err := os.ReadFile(filePath) @@ -62,8 +58,8 @@ func parseModulesFile(filePath string) ([]*proto.Module, error) { // getModules returns the modules from the modules file if it exists. // It returns nil if the file does not exist. // Modules become available after terraform init. -func getModules(workdir string) ([]*proto.Module, error) { - filePath := getModulesFilePath(workdir) +func getModules(files tfpath.Layout) ([]*proto.Module, error) { + filePath := files.ModulesFilePath() if _, err := os.Stat(filePath); os.IsNotExist(err) { return nil, nil } diff --git a/provisioner/terraform/parse.go b/provisioner/terraform/parse.go index d5b59df327..2f5a8c7f5c 100644 --- a/provisioner/terraform/parse.go +++ b/provisioner/terraform/parse.go @@ -25,9 +25,9 @@ func (s *server) Parse(sess *provisionersdk.Session, _ *proto.ParseRequest, _ <- defer span.End() // Load the module and print any parse errors. - parser, diags := tfparse.New(sess.WorkDirectory, tfparse.WithLogger(s.logger.Named("tfparse"))) + parser, diags := tfparse.New(sess.Files.WorkDirectory(), tfparse.WithLogger(s.logger.Named("tfparse"))) if diags.HasErrors() { - return provisionersdk.ParseErrorf("load module: %s", formatDiagnostics(sess.WorkDirectory, diags)) + return provisionersdk.ParseErrorf("load module: %s", formatDiagnostics(sess.Files.WorkDirectory(), diags)) } workspaceTags, _, err := parser.WorkspaceTags(ctx) diff --git a/provisioner/terraform/provision.go b/provisioner/terraform/provision.go index ec9f96c3ed..2445a396f6 100644 --- a/provisioner/terraform/provision.go +++ b/provisioner/terraform/provision.go @@ -76,7 +76,7 @@ func (s *server) Plan( defer cancel() defer kill() - e := s.executor(sess.WorkDirectory, database.ProvisionerJobTimingStagePlan) + e := s.executor(sess.Files, database.ProvisionerJobTimingStagePlan) if err := e.checkMinVersion(ctx); err != nil { return provisionersdk.PlanErrorf("%s", err.Error()) } @@ -92,7 +92,7 @@ func (s *server) Plan( return &proto.PlanComplete{} } - statefilePath := getStateFilePath(sess.WorkDirectory) + statefilePath := sess.Files.StateFilePath() if len(sess.Config.State) > 0 { err := os.WriteFile(statefilePath, sess.Config.State, 0o600) if err != nil { @@ -141,7 +141,7 @@ func (s *server) Plan( return provisionersdk.PlanErrorf("initialize terraform: %s", err) } - modules, err := getModules(sess.WorkDirectory) + modules, err := getModules(sess.Files) if err != nil { // We allow getModules to fail, as the result is used only // for telemetry purposes now. @@ -184,7 +184,7 @@ func (s *server) Apply( defer cancel() defer kill() - e := s.executor(sess.WorkDirectory, database.ProvisionerJobTimingStageApply) + e := s.executor(sess.Files, database.ProvisionerJobTimingStageApply) if err := e.checkMinVersion(ctx); err != nil { return provisionersdk.ApplyErrorf("%s", err.Error()) } @@ -201,7 +201,7 @@ func (s *server) Apply( } // Earlier in the session, Plan() will have written the state file and the plan file. - statefilePath := getStateFilePath(sess.WorkDirectory) + statefilePath := sess.Files.StateFilePath() env, err := provisionEnv(sess.Config, request.Metadata, nil, nil, nil) if err != nil { return provisionersdk.ApplyErrorf("provision env: %s", err) @@ -348,7 +348,7 @@ func logTerraformEnvVars(sink logSink) { // shipped in v1.0.4. It will return the stacktraces of the provider, which will hopefully allow us // to figure out why it hasn't exited. func tryGettingCoderProviderStacktrace(sess *provisionersdk.Session) string { - path := filepath.Clean(filepath.Join(sess.WorkDirectory, "../.coder/pprof")) + path := filepath.Clean(filepath.Join(sess.Files.WorkDirectory(), "../.coder/pprof")) sess.Logger.Info(sess.Context(), "attempting to get stack traces", slog.F("path", path)) c := http.Client{ Transport: &http.Transport{ diff --git a/provisioner/terraform/serve.go b/provisioner/terraform/serve.go index 3e671b0c68..60951a8da1 100644 --- a/provisioner/terraform/serve.go +++ b/provisioner/terraform/serve.go @@ -14,6 +14,7 @@ import ( "golang.org/x/xerrors" "cdr.dev/slog" + "github.com/coder/coder/v2/provisionersdk/tfpath" "github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/jobreaper" @@ -160,14 +161,14 @@ func (s *server) startTrace(ctx context.Context, name string, opts ...trace.Span ))...) } -func (s *server) executor(workdir string, stage database.ProvisionerJobTimingStage) *executor { +func (s *server) executor(files tfpath.Layout, stage database.ProvisionerJobTimingStage) *executor { return &executor{ server: s, mut: s.execMut, binaryPath: s.binaryPath, cachePath: s.cachePath, cliConfigPath: s.cliConfigPath, - workdir: workdir, + files: files, logger: s.logger.Named("executor"), timings: newTimingAggregator(stage), } diff --git a/provisionerd/provisionerd_test.go b/provisionerd/provisionerd_test.go index 1b4b6720b4..f9977d0e8e 100644 --- a/provisionerd/provisionerd_test.go +++ b/provisionerd/provisionerd_test.go @@ -353,7 +353,7 @@ func TestProvisionerd(t *testing.T) { _ *sdkproto.ParseRequest, cancelOrComplete <-chan struct{}, ) *sdkproto.ParseComplete { - data, err := os.ReadFile(filepath.Join(s.WorkDirectory, "test.txt")) + data, err := os.ReadFile(filepath.Join(s.Files.WorkDirectory(), "test.txt")) require.NoError(t, err) require.Equal(t, "content", string(data)) s.ProvisionLog(sdkproto.LogLevel_INFO, "hello") diff --git a/provisionersdk/cleanup_test.go b/provisionersdk/cleanup_test.go index e23c7a9f78..d60ef55a7c 100644 --- a/provisionersdk/cleanup_test.go +++ b/provisionersdk/cleanup_test.go @@ -12,6 +12,7 @@ import ( "cdr.dev/slog" "github.com/coder/coder/v2/provisionersdk" + "github.com/coder/coder/v2/provisionersdk/tfpath" "github.com/coder/coder/v2/testutil" ) @@ -40,11 +41,11 @@ func TestStaleSessions(t *testing.T) { fs, logger := prepare() // given - first := provisionersdk.SessionDir(uuid.NewString()) + first := tfpath.Session(workDirectory, uuid.NewString()) addSessionFolder(t, fs, first, now.Add(-7*24*time.Hour)) - second := provisionersdk.SessionDir(uuid.NewString()) + second := tfpath.Session(workDirectory, uuid.NewString()) addSessionFolder(t, fs, second, now.Add(-8*24*time.Hour)) - third := provisionersdk.SessionDir(uuid.NewString()) + third := tfpath.Session(workDirectory, uuid.NewString()) addSessionFolder(t, fs, third, now.Add(-9*24*time.Hour)) // when @@ -65,9 +66,9 @@ func TestStaleSessions(t *testing.T) { fs, logger := prepare() // given - first := provisionersdk.SessionDir(uuid.NewString()) + first := tfpath.Session(workDirectory, uuid.NewString()) addSessionFolder(t, fs, first, now.Add(-7*24*time.Hour)) - second := provisionersdk.SessionDir(uuid.NewString()) + second := tfpath.Session(workDirectory, uuid.NewString()) addSessionFolder(t, fs, second, now.Add(-6*24*time.Hour)) // when @@ -77,7 +78,7 @@ func TestStaleSessions(t *testing.T) { entries, err := afero.ReadDir(fs, workDirectory) require.NoError(t, err) require.Len(t, entries, 1, "one session should be present") - require.Equal(t, second, entries[0].Name(), 1) + require.Equal(t, second.WorkDirectory(), filepath.Join(workDirectory, entries[0].Name()), 1) }) t.Run("no stale sessions", func(t *testing.T) { @@ -89,9 +90,9 @@ func TestStaleSessions(t *testing.T) { fs, logger := prepare() // given - first := provisionersdk.SessionDir(uuid.NewString()) + first := tfpath.Session(workDirectory, uuid.NewString()) addSessionFolder(t, fs, first, now.Add(-6*24*time.Hour)) - second := provisionersdk.SessionDir(uuid.NewString()) + second := tfpath.Session(workDirectory, uuid.NewString()) addSessionFolder(t, fs, second, now.Add(-5*24*time.Hour)) // when @@ -104,9 +105,10 @@ func TestStaleSessions(t *testing.T) { }) } -func addSessionFolder(t *testing.T, fs afero.Fs, sessionName string, modTime time.Time) { - err := fs.MkdirAll(filepath.Join(workDirectory, sessionName), 0o755) +func addSessionFolder(t *testing.T, fs afero.Fs, files tfpath.Layout, modTime time.Time) { + workdir := files.WorkDirectory() + err := fs.MkdirAll(workdir, 0o755) require.NoError(t, err, "can't create session folder") - require.NoError(t, fs.Chtimes(filepath.Join(workDirectory, sessionName), now, modTime), "can't chtime of session dir") + require.NoError(t, fs.Chtimes(workdir, now, modTime), "can't chtime of session dir") require.NoError(t, err, "can't set times") } diff --git a/provisionersdk/session.go b/provisionersdk/session.go index 3fd2362885..68a7219007 100644 --- a/provisionersdk/session.go +++ b/provisionersdk/session.go @@ -1,14 +1,10 @@ package provisionersdk import ( - "archive/tar" - "bytes" "context" "fmt" - "hash/crc32" "io" "os" - "path/filepath" "strings" "time" @@ -18,6 +14,7 @@ import ( "cdr.dev/slog" "github.com/coder/coder/v2/codersdk/drpcsdk" + "github.com/coder/coder/v2/provisionersdk/tfpath" protobuf "google.golang.org/protobuf/proto" @@ -48,34 +45,15 @@ func (p *protoServer) Session(stream proto.DRPCProvisioner_SessionStream) error err := CleanStaleSessions(s.Context(), p.opts.WorkDirectory, afero.NewOsFs(), time.Now(), s.Logger) if err != nil { - return xerrors.Errorf("unable to clean stale sessions %q: %w", s.WorkDirectory, err) + return xerrors.Errorf("unable to clean stale sessions %q: %w", s.Files, err) } - s.WorkDirectory = filepath.Join(p.opts.WorkDirectory, SessionDir(sessID)) - err = os.MkdirAll(s.WorkDirectory, 0o700) - if err != nil { - return xerrors.Errorf("create work directory %q: %w", s.WorkDirectory, err) - } + s.Files = tfpath.Session(p.opts.WorkDirectory, sessID) + defer func() { - var err error - // Cleanup the work directory after execution. - for attempt := 0; attempt < 5; attempt++ { - err = os.RemoveAll(s.WorkDirectory) - if err != nil { - // On Windows, open files cannot be removed. - // When the provisioner daemon is shutting down, - // it may take a few milliseconds for processes to exit. - // See: https://github.com/golang/go/issues/50510 - s.Logger.Debug(s.Context(), "failed to clean work directory; trying again", slog.Error(err)) - time.Sleep(250 * time.Millisecond) - continue - } - s.Logger.Debug(s.Context(), "cleaned up work directory") - return - } - s.Logger.Error(s.Context(), "failed to clean up work directory after multiple attempts", - slog.F("path", s.WorkDirectory), slog.Error(err)) + s.Files.Cleanup(s.Context(), s.Logger, afero.NewOsFs()) }() + req, err := stream.Recv() if err != nil { return xerrors.Errorf("receive config: %w", err) @@ -89,7 +67,7 @@ func (p *protoServer) Session(stream proto.DRPCProvisioner_SessionStream) error s.logLevel = proto.LogLevel_value[strings.ToUpper(s.Config.ProvisionerLogLevel)] } - err = s.extractArchive() + err = s.Files.ExtractArchive(s.Context(), s.Logger, afero.NewOsFs(), s.Config) if err != nil { return xerrors.Errorf("extract archive: %w", err) } @@ -144,7 +122,7 @@ func (s *Session) handleRequests() error { return err } // Handle README centrally, so that individual provisioners don't need to mess with it. - readme, err := os.ReadFile(filepath.Join(s.WorkDirectory, ReadmeFile)) + readme, err := os.ReadFile(s.Files.ReadmeFilePath()) if err == nil { complete.Readme = readme } else { @@ -220,9 +198,9 @@ func (s *Session) handleRequests() error { } type Session struct { - Logger slog.Logger - WorkDirectory string - Config *proto.Config + Logger slog.Logger + Files tfpath.Layout + Config *proto.Config server Server stream proto.DRPCProvisioner_SessionStream @@ -233,92 +211,6 @@ func (s *Session) Context() context.Context { return s.stream.Context() } -func (s *Session) extractArchive() error { - ctx := s.Context() - - s.Logger.Info(ctx, "unpacking template source archive", - slog.F("size_bytes", len(s.Config.TemplateSourceArchive)), - ) - - reader := tar.NewReader(bytes.NewBuffer(s.Config.TemplateSourceArchive)) - // for safety, nil out the reference on Config, since the reader now owns it. - s.Config.TemplateSourceArchive = nil - for { - header, err := reader.Next() - if err != nil { - if xerrors.Is(err, io.EOF) { - break - } - return xerrors.Errorf("read template source archive: %w", err) - } - s.Logger.Debug(context.Background(), "read archive entry", - slog.F("name", header.Name), - slog.F("mod_time", header.ModTime), - slog.F("size", header.Size)) - - // Security: don't untar absolute or relative paths, as this can allow a malicious tar to overwrite - // files outside the workdir. - if !filepath.IsLocal(header.Name) { - return xerrors.Errorf("refusing to extract to non-local path") - } - // nolint: gosec - headerPath := filepath.Join(s.WorkDirectory, header.Name) - if !strings.HasPrefix(headerPath, filepath.Clean(s.WorkDirectory)) { - return xerrors.New("tar attempts to target relative upper directory") - } - mode := header.FileInfo().Mode() - if mode == 0 { - mode = 0o600 - } - - // Always check for context cancellation before reading the next header. - // This is mainly important for unit tests, since a canceled context means - // the underlying directory is going to be deleted. There still exists - // the small race condition that the context is canceled after this, and - // before the disk write. - if ctx.Err() != nil { - return xerrors.Errorf("context canceled: %w", ctx.Err()) - } - switch header.Typeflag { - case tar.TypeDir: - err = os.MkdirAll(headerPath, mode) - if err != nil { - return xerrors.Errorf("mkdir %q: %w", headerPath, err) - } - s.Logger.Debug(context.Background(), "extracted directory", - slog.F("path", headerPath), - slog.F("mode", fmt.Sprintf("%O", mode))) - case tar.TypeReg: - file, err := os.OpenFile(headerPath, os.O_CREATE|os.O_RDWR, mode) - if err != nil { - return xerrors.Errorf("create file %q (mode %s): %w", headerPath, mode, err) - } - - hash := crc32.NewIEEE() - hashReader := io.TeeReader(reader, hash) - // Max file size of 10MiB. - size, err := io.CopyN(file, hashReader, 10<<20) - if xerrors.Is(err, io.EOF) { - err = nil - } - if err != nil { - _ = file.Close() - return xerrors.Errorf("copy file %q: %w", headerPath, err) - } - err = file.Close() - if err != nil { - return xerrors.Errorf("close file %q: %s", headerPath, err) - } - s.Logger.Debug(context.Background(), "extracted file", - slog.F("size_bytes", size), - slog.F("path", headerPath), - slog.F("mode", mode), - slog.F("checksum", fmt.Sprintf("%x", hash.Sum(nil)))) - } - } - return nil -} - func (s *Session) ProvisionLog(level proto.LogLevel, output string) { if int32(level) < s.logLevel { return @@ -379,8 +271,3 @@ func (r *request[R, C]) do() (C, error) { return c, nil } } - -// SessionDir returns the directory name with mandatory prefix. -func SessionDir(sessID string) string { - return sessionDirPrefix + sessID -} diff --git a/provisionersdk/tfpath/tfpath.go b/provisionersdk/tfpath/tfpath.go new file mode 100644 index 0000000000..57129e6242 --- /dev/null +++ b/provisionersdk/tfpath/tfpath.go @@ -0,0 +1,192 @@ +package tfpath + +import ( + "archive/tar" + "bytes" + "context" + "fmt" + "hash/crc32" + "io" + "os" + "path/filepath" + "strings" + "time" + + "github.com/spf13/afero" + "golang.org/x/xerrors" + + "cdr.dev/slog" + "github.com/coder/coder/v2/provisionersdk/proto" +) + +const ( + // ReadmeFile is the location we look for to extract documentation from template versions. + ReadmeFile = "README.md" + + sessionDirPrefix = "Session" +) + +// Session creates a directory structure layout for terraform execution. The +// SessionID is a unique value for creating an ephemeral working directory inside +// the parentDirPath. All helper functions will return paths for various +// terraform asserts inside this working directory. +func Session(parentDirPath, sessionID string) Layout { + return Layout(filepath.Join(parentDirPath, sessionDirPrefix+sessionID)) +} + +// Layout is the terraform execution working directory structure. +// It also contains some methods for common file operations within that layout. +// Such as "Cleanup" and "ExtractArchive". +// TODO: Maybe we should include the afero.FS here as well, then all operations +// would be on the same FS? +type Layout string + +// WorkDirectory returns the root working directory for Terraform files. +func (l Layout) WorkDirectory() string { return string(l) } + +func (l Layout) StateFilePath() string { + return filepath.Join(l.WorkDirectory(), "terraform.tfstate") +} + +func (l Layout) PlanFilePath() string { + return filepath.Join(l.WorkDirectory(), "terraform.tfplan") +} + +func (l Layout) TerraformLockFile() string { + return filepath.Join(l.WorkDirectory(), ".terraform.lock.hcl") +} + +func (l Layout) ReadmeFilePath() string { + return filepath.Join(l.WorkDirectory(), ReadmeFile) +} + +func (l Layout) TerraformMetadataDir() string { + return filepath.Join(l.WorkDirectory(), ".terraform") +} + +func (l Layout) ModulesDirectory() string { + return filepath.Join(l.TerraformMetadataDir(), "modules") +} + +func (l Layout) ModulesFilePath() string { + return filepath.Join(l.ModulesDirectory(), "modules.json") +} + +func (l Layout) ExtractArchive(ctx context.Context, logger slog.Logger, fs afero.Fs, cfg *proto.Config) error { + logger.Info(ctx, "unpacking template source archive", + slog.F("size_bytes", len(cfg.TemplateSourceArchive)), + ) + + err := fs.MkdirAll(l.WorkDirectory(), 0o700) + if err != nil { + return xerrors.Errorf("create work directory %q: %w", l.WorkDirectory(), err) + } + + reader := tar.NewReader(bytes.NewBuffer(cfg.TemplateSourceArchive)) + // for safety, nil out the reference on Config, since the reader now owns it. + cfg.TemplateSourceArchive = nil + for { + header, err := reader.Next() + if err != nil { + if xerrors.Is(err, io.EOF) { + break + } + return xerrors.Errorf("read template source archive: %w", err) + } + logger.Debug(context.Background(), "read archive entry", + slog.F("name", header.Name), + slog.F("mod_time", header.ModTime), + slog.F("size", header.Size)) + + // Security: don't untar absolute or relative paths, as this can allow a malicious tar to overwrite + // files outside the workdir. + if !filepath.IsLocal(header.Name) { + return xerrors.Errorf("refusing to extract to non-local path") + } + + // nolint: gosec // TODO: Use relative paths inside the workdir only. + headerPath := filepath.Join(l.WorkDirectory(), header.Name) + if !strings.HasPrefix(headerPath, filepath.Clean(l.WorkDirectory())) { + return xerrors.New("tar attempts to target relative upper directory") + } + mode := header.FileInfo().Mode() + if mode == 0 { + mode = 0o600 + } + + // Always check for context cancellation before reading the next header. + // This is mainly important for unit tests, since a canceled context means + // the underlying directory is going to be deleted. There still exists + // the small race condition that the context is canceled after this, and + // before the disk write. + if ctx.Err() != nil { + return xerrors.Errorf("context canceled: %w", ctx.Err()) + } + switch header.Typeflag { + case tar.TypeDir: + err = fs.MkdirAll(headerPath, mode) + if err != nil { + return xerrors.Errorf("mkdir %q: %w", headerPath, err) + } + logger.Debug(context.Background(), "extracted directory", + slog.F("path", headerPath), + slog.F("mode", fmt.Sprintf("%O", mode))) + case tar.TypeReg: + file, err := fs.OpenFile(headerPath, os.O_CREATE|os.O_RDWR, mode) + if err != nil { + return xerrors.Errorf("create file %q (mode %s): %w", headerPath, mode, err) + } + + hash := crc32.NewIEEE() + hashReader := io.TeeReader(reader, hash) + // Max file size of 10MiB. + size, err := io.CopyN(file, hashReader, 10<<20) + if xerrors.Is(err, io.EOF) { + err = nil + } + if err != nil { + _ = file.Close() + return xerrors.Errorf("copy file %q: %w", headerPath, err) + } + err = file.Close() + if err != nil { + return xerrors.Errorf("close file %q: %s", headerPath, err) + } + logger.Debug(context.Background(), "extracted file", + slog.F("size_bytes", size), + slog.F("path", headerPath), + slog.F("mode", mode), + slog.F("checksum", fmt.Sprintf("%x", hash.Sum(nil)))) + } + } + + return nil +} + +// Cleanup removes the work directory and all of its contents. +func (l Layout) Cleanup(ctx context.Context, logger slog.Logger, fs afero.Fs) { + var err error + path := l.WorkDirectory() + + for attempt := 0; attempt < 5; attempt++ { + err := fs.RemoveAll(path) + if err != nil { + // On Windows, open files cannot be removed. + // When the provisioner daemon is shutting down, + // it may take a few milliseconds for processes to exit. + // See: https://github.com/golang/go/issues/50510 + logger.Debug(ctx, "failed to clean work directory; trying again", slog.Error(err)) + // TODO: Should we abort earlier if the context is done? + time.Sleep(250 * time.Millisecond) + continue + } + logger.Debug(ctx, "cleaned up work directory") + return + } + + // Returning an error at this point cannot do any good. The caller cannot resolve + // this. There is a routine cleanup task that will remove old work directories + // when this fails. + logger.Error(ctx, "failed to clean up work directory after multiple attempts", + slog.F("path", path), slog.Error(err)) +}