chore: refactor terraform paths to a central structure (#20566)

Refactors all Terraform file path logic into a centralized tfpath package. This consolidates all path construction into a single, testable Layout type. 

Instead of passing around `string` for directories, pass around the `Layout` which has the file location methods on it.
This commit is contained in:
Steven Masley
2025-11-12 09:24:07 -06:00
committed by GitHub
parent c69eb7c157
commit 7c8deaf0d6
10 changed files with 244 additions and 174 deletions
+2 -2
View File
@@ -122,8 +122,8 @@ func readResponses(sess *provisionersdk.Session, trans string, suffix string) ([
for i := 0; ; i++ {
paths := []string{
// Try more specific path first, then fallback to generic.
filepath.Join(sess.WorkDirectory, fmt.Sprintf("%d.%s.%s", i, trans, suffix)),
filepath.Join(sess.WorkDirectory, fmt.Sprintf("%d.%s", i, suffix)),
filepath.Join(sess.Files.WorkDirectory(), fmt.Sprintf("%d.%s.%s", i, trans, suffix)),
filepath.Join(sess.Files.WorkDirectory(), fmt.Sprintf("%d.%s", i, suffix)),
}
for pathIndex, path := range paths {
_, err := os.Stat(path)
+11 -19
View File
@@ -10,7 +10,6 @@ import (
"io"
"os"
"os/exec"
"path/filepath"
"runtime"
"strings"
"sync"
@@ -22,6 +21,7 @@ import (
"golang.org/x/xerrors"
"cdr.dev/slog"
"github.com/coder/coder/v2/provisionersdk/tfpath"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/tracing"
@@ -38,10 +38,10 @@ type executor struct {
server *server
mut *sync.Mutex
binaryPath string
// cachePath and workdir must not be used by multiple processes at once.
// cachePath and files must not be used by multiple processes at once.
cachePath string
cliConfigPath string
workdir string
files tfpath.Layout
// used to capture execution times at various stages
timings *timingAggregator
}
@@ -90,7 +90,7 @@ func (e *executor) execWriteOutput(ctx, killCtx context.Context, args, env []str
// #nosec
cmd := exec.CommandContext(killCtx, e.binaryPath, args...)
cmd.Dir = e.workdir
cmd.Dir = e.files.WorkDirectory()
if env == nil {
// We don't want to passthrough host env when unset.
env = []string{}
@@ -131,7 +131,7 @@ func (e *executor) execParseJSON(ctx, killCtx context.Context, args, env []strin
// #nosec
cmd := exec.CommandContext(killCtx, e.binaryPath, args...)
cmd.Dir = e.workdir
cmd.Dir = e.files.WorkDirectory()
cmd.Env = env
out := &bytes.Buffer{}
stdErr := &bytes.Buffer{}
@@ -225,7 +225,7 @@ func (e *executor) init(ctx, killCtx context.Context, logr logSink) error {
defer e.mut.Unlock()
// Record lock file checksum before init
lockFilePath := filepath.Join(e.workdir, ".terraform.lock.hcl")
lockFilePath := e.files.TerraformLockFile()
preInitChecksum := checksumFileCRC32(ctx, e.logger, lockFilePath)
outWriter, doneOut := e.provisionLogWriter(logr)
@@ -289,14 +289,6 @@ func checksumFileCRC32(ctx context.Context, logger slog.Logger, path string) uin
return crc32.ChecksumIEEE(content)
}
func getPlanFilePath(workdir string) string {
return filepath.Join(workdir, "terraform.tfplan")
}
func getStateFilePath(workdir string) string {
return filepath.Join(workdir, "terraform.tfstate")
}
// revive:disable-next-line:flag-parameter
func (e *executor) plan(ctx, killCtx context.Context, env, vars []string, logr logSink, req *proto.PlanRequest) (*proto.PlanComplete, error) {
ctx, span := e.server.startTrace(ctx, tracing.FuncName())
@@ -307,7 +299,7 @@ func (e *executor) plan(ctx, killCtx context.Context, env, vars []string, logr l
metadata := req.Metadata
planfilePath := getPlanFilePath(e.workdir)
planfilePath := e.files.PlanFilePath()
args := []string{
"plan",
"-no-color",
@@ -359,7 +351,7 @@ func (e *executor) plan(ctx, killCtx context.Context, env, vars []string, logr l
// a workspace build. This removes some added costs of sending the modules
// payload back to coderd if coderd is just going to ignore it.
if !req.OmitModuleFiles {
moduleFiles, err = GetModulesArchive(os.DirFS(e.workdir))
moduleFiles, err = GetModulesArchive(os.DirFS(e.files.WorkDirectory()))
if err != nil {
// TODO: we probably want to persist this error or make it louder eventually
e.logger.Warn(ctx, "failed to archive terraform modules", slog.Error(err))
@@ -551,7 +543,7 @@ func (e *executor) graph(ctx, killCtx context.Context) (string, error) {
var out strings.Builder
cmd := exec.CommandContext(killCtx, e.binaryPath, args...) // #nosec
cmd.Stdout = &out
cmd.Dir = e.workdir
cmd.Dir = e.files.WorkDirectory()
cmd.Env = e.basicEnv()
e.server.logger.Debug(ctx, "executing terraform command graph",
@@ -588,7 +580,7 @@ func (e *executor) apply(
"-auto-approve",
"-input=false",
"-json",
getPlanFilePath(e.workdir),
e.files.PlanFilePath(),
}
outWriter, doneOut := e.provisionLogWriter(logr)
@@ -608,7 +600,7 @@ func (e *executor) apply(
if err != nil {
return nil, err
}
statefilePath := getStateFilePath(e.workdir)
statefilePath := e.files.StateFilePath()
stateContent, err := os.ReadFile(statefilePath)
if err != nil {
return nil, xerrors.Errorf("read statefile %q: %w", statefilePath, err)
+3 -7
View File
@@ -7,7 +7,6 @@ import (
"io"
"io/fs"
"os"
"path/filepath"
"strings"
"time"
@@ -15,6 +14,7 @@ import (
"github.com/coder/coder/v2/coderd/util/xio"
"github.com/coder/coder/v2/provisionersdk/proto"
"github.com/coder/coder/v2/provisionersdk/tfpath"
)
const (
@@ -39,10 +39,6 @@ type modulesFile struct {
Modules []*module `json:"Modules"`
}
func getModulesFilePath(workdir string) string {
return filepath.Join(workdir, ".terraform", "modules", "modules.json")
}
func parseModulesFile(filePath string) ([]*proto.Module, error) {
modules := &modulesFile{}
data, err := os.ReadFile(filePath)
@@ -62,8 +58,8 @@ func parseModulesFile(filePath string) ([]*proto.Module, error) {
// getModules returns the modules from the modules file if it exists.
// It returns nil if the file does not exist.
// Modules become available after terraform init.
func getModules(workdir string) ([]*proto.Module, error) {
filePath := getModulesFilePath(workdir)
func getModules(files tfpath.Layout) ([]*proto.Module, error) {
filePath := files.ModulesFilePath()
if _, err := os.Stat(filePath); os.IsNotExist(err) {
return nil, nil
}
+2 -2
View File
@@ -25,9 +25,9 @@ func (s *server) Parse(sess *provisionersdk.Session, _ *proto.ParseRequest, _ <-
defer span.End()
// Load the module and print any parse errors.
parser, diags := tfparse.New(sess.WorkDirectory, tfparse.WithLogger(s.logger.Named("tfparse")))
parser, diags := tfparse.New(sess.Files.WorkDirectory(), tfparse.WithLogger(s.logger.Named("tfparse")))
if diags.HasErrors() {
return provisionersdk.ParseErrorf("load module: %s", formatDiagnostics(sess.WorkDirectory, diags))
return provisionersdk.ParseErrorf("load module: %s", formatDiagnostics(sess.Files.WorkDirectory(), diags))
}
workspaceTags, _, err := parser.WorkspaceTags(ctx)
+6 -6
View File
@@ -76,7 +76,7 @@ func (s *server) Plan(
defer cancel()
defer kill()
e := s.executor(sess.WorkDirectory, database.ProvisionerJobTimingStagePlan)
e := s.executor(sess.Files, database.ProvisionerJobTimingStagePlan)
if err := e.checkMinVersion(ctx); err != nil {
return provisionersdk.PlanErrorf("%s", err.Error())
}
@@ -92,7 +92,7 @@ func (s *server) Plan(
return &proto.PlanComplete{}
}
statefilePath := getStateFilePath(sess.WorkDirectory)
statefilePath := sess.Files.StateFilePath()
if len(sess.Config.State) > 0 {
err := os.WriteFile(statefilePath, sess.Config.State, 0o600)
if err != nil {
@@ -141,7 +141,7 @@ func (s *server) Plan(
return provisionersdk.PlanErrorf("initialize terraform: %s", err)
}
modules, err := getModules(sess.WorkDirectory)
modules, err := getModules(sess.Files)
if err != nil {
// We allow getModules to fail, as the result is used only
// for telemetry purposes now.
@@ -184,7 +184,7 @@ func (s *server) Apply(
defer cancel()
defer kill()
e := s.executor(sess.WorkDirectory, database.ProvisionerJobTimingStageApply)
e := s.executor(sess.Files, database.ProvisionerJobTimingStageApply)
if err := e.checkMinVersion(ctx); err != nil {
return provisionersdk.ApplyErrorf("%s", err.Error())
}
@@ -201,7 +201,7 @@ func (s *server) Apply(
}
// Earlier in the session, Plan() will have written the state file and the plan file.
statefilePath := getStateFilePath(sess.WorkDirectory)
statefilePath := sess.Files.StateFilePath()
env, err := provisionEnv(sess.Config, request.Metadata, nil, nil, nil)
if err != nil {
return provisionersdk.ApplyErrorf("provision env: %s", err)
@@ -348,7 +348,7 @@ func logTerraformEnvVars(sink logSink) {
// shipped in v1.0.4. It will return the stacktraces of the provider, which will hopefully allow us
// to figure out why it hasn't exited.
func tryGettingCoderProviderStacktrace(sess *provisionersdk.Session) string {
path := filepath.Clean(filepath.Join(sess.WorkDirectory, "../.coder/pprof"))
path := filepath.Clean(filepath.Join(sess.Files.WorkDirectory(), "../.coder/pprof"))
sess.Logger.Info(sess.Context(), "attempting to get stack traces", slog.F("path", path))
c := http.Client{
Transport: &http.Transport{
+3 -2
View File
@@ -14,6 +14,7 @@ import (
"golang.org/x/xerrors"
"cdr.dev/slog"
"github.com/coder/coder/v2/provisionersdk/tfpath"
"github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/jobreaper"
@@ -160,14 +161,14 @@ func (s *server) startTrace(ctx context.Context, name string, opts ...trace.Span
))...)
}
func (s *server) executor(workdir string, stage database.ProvisionerJobTimingStage) *executor {
func (s *server) executor(files tfpath.Layout, stage database.ProvisionerJobTimingStage) *executor {
return &executor{
server: s,
mut: s.execMut,
binaryPath: s.binaryPath,
cachePath: s.cachePath,
cliConfigPath: s.cliConfigPath,
workdir: workdir,
files: files,
logger: s.logger.Named("executor"),
timings: newTimingAggregator(stage),
}
+1 -1
View File
@@ -353,7 +353,7 @@ func TestProvisionerd(t *testing.T) {
_ *sdkproto.ParseRequest,
cancelOrComplete <-chan struct{},
) *sdkproto.ParseComplete {
data, err := os.ReadFile(filepath.Join(s.WorkDirectory, "test.txt"))
data, err := os.ReadFile(filepath.Join(s.Files.WorkDirectory(), "test.txt"))
require.NoError(t, err)
require.Equal(t, "content", string(data))
s.ProvisionLog(sdkproto.LogLevel_INFO, "hello")
+13 -11
View File
@@ -12,6 +12,7 @@ import (
"cdr.dev/slog"
"github.com/coder/coder/v2/provisionersdk"
"github.com/coder/coder/v2/provisionersdk/tfpath"
"github.com/coder/coder/v2/testutil"
)
@@ -40,11 +41,11 @@ func TestStaleSessions(t *testing.T) {
fs, logger := prepare()
// given
first := provisionersdk.SessionDir(uuid.NewString())
first := tfpath.Session(workDirectory, uuid.NewString())
addSessionFolder(t, fs, first, now.Add(-7*24*time.Hour))
second := provisionersdk.SessionDir(uuid.NewString())
second := tfpath.Session(workDirectory, uuid.NewString())
addSessionFolder(t, fs, second, now.Add(-8*24*time.Hour))
third := provisionersdk.SessionDir(uuid.NewString())
third := tfpath.Session(workDirectory, uuid.NewString())
addSessionFolder(t, fs, third, now.Add(-9*24*time.Hour))
// when
@@ -65,9 +66,9 @@ func TestStaleSessions(t *testing.T) {
fs, logger := prepare()
// given
first := provisionersdk.SessionDir(uuid.NewString())
first := tfpath.Session(workDirectory, uuid.NewString())
addSessionFolder(t, fs, first, now.Add(-7*24*time.Hour))
second := provisionersdk.SessionDir(uuid.NewString())
second := tfpath.Session(workDirectory, uuid.NewString())
addSessionFolder(t, fs, second, now.Add(-6*24*time.Hour))
// when
@@ -77,7 +78,7 @@ func TestStaleSessions(t *testing.T) {
entries, err := afero.ReadDir(fs, workDirectory)
require.NoError(t, err)
require.Len(t, entries, 1, "one session should be present")
require.Equal(t, second, entries[0].Name(), 1)
require.Equal(t, second.WorkDirectory(), filepath.Join(workDirectory, entries[0].Name()), 1)
})
t.Run("no stale sessions", func(t *testing.T) {
@@ -89,9 +90,9 @@ func TestStaleSessions(t *testing.T) {
fs, logger := prepare()
// given
first := provisionersdk.SessionDir(uuid.NewString())
first := tfpath.Session(workDirectory, uuid.NewString())
addSessionFolder(t, fs, first, now.Add(-6*24*time.Hour))
second := provisionersdk.SessionDir(uuid.NewString())
second := tfpath.Session(workDirectory, uuid.NewString())
addSessionFolder(t, fs, second, now.Add(-5*24*time.Hour))
// when
@@ -104,9 +105,10 @@ func TestStaleSessions(t *testing.T) {
})
}
func addSessionFolder(t *testing.T, fs afero.Fs, sessionName string, modTime time.Time) {
err := fs.MkdirAll(filepath.Join(workDirectory, sessionName), 0o755)
func addSessionFolder(t *testing.T, fs afero.Fs, files tfpath.Layout, modTime time.Time) {
workdir := files.WorkDirectory()
err := fs.MkdirAll(workdir, 0o755)
require.NoError(t, err, "can't create session folder")
require.NoError(t, fs.Chtimes(filepath.Join(workDirectory, sessionName), now, modTime), "can't chtime of session dir")
require.NoError(t, fs.Chtimes(workdir, now, modTime), "can't chtime of session dir")
require.NoError(t, err, "can't set times")
}
+11 -124
View File
@@ -1,14 +1,10 @@
package provisionersdk
import (
"archive/tar"
"bytes"
"context"
"fmt"
"hash/crc32"
"io"
"os"
"path/filepath"
"strings"
"time"
@@ -18,6 +14,7 @@ import (
"cdr.dev/slog"
"github.com/coder/coder/v2/codersdk/drpcsdk"
"github.com/coder/coder/v2/provisionersdk/tfpath"
protobuf "google.golang.org/protobuf/proto"
@@ -48,34 +45,15 @@ func (p *protoServer) Session(stream proto.DRPCProvisioner_SessionStream) error
err := CleanStaleSessions(s.Context(), p.opts.WorkDirectory, afero.NewOsFs(), time.Now(), s.Logger)
if err != nil {
return xerrors.Errorf("unable to clean stale sessions %q: %w", s.WorkDirectory, err)
return xerrors.Errorf("unable to clean stale sessions %q: %w", s.Files, err)
}
s.WorkDirectory = filepath.Join(p.opts.WorkDirectory, SessionDir(sessID))
err = os.MkdirAll(s.WorkDirectory, 0o700)
if err != nil {
return xerrors.Errorf("create work directory %q: %w", s.WorkDirectory, err)
}
s.Files = tfpath.Session(p.opts.WorkDirectory, sessID)
defer func() {
var err error
// Cleanup the work directory after execution.
for attempt := 0; attempt < 5; attempt++ {
err = os.RemoveAll(s.WorkDirectory)
if err != nil {
// On Windows, open files cannot be removed.
// When the provisioner daemon is shutting down,
// it may take a few milliseconds for processes to exit.
// See: https://github.com/golang/go/issues/50510
s.Logger.Debug(s.Context(), "failed to clean work directory; trying again", slog.Error(err))
time.Sleep(250 * time.Millisecond)
continue
}
s.Logger.Debug(s.Context(), "cleaned up work directory")
return
}
s.Logger.Error(s.Context(), "failed to clean up work directory after multiple attempts",
slog.F("path", s.WorkDirectory), slog.Error(err))
s.Files.Cleanup(s.Context(), s.Logger, afero.NewOsFs())
}()
req, err := stream.Recv()
if err != nil {
return xerrors.Errorf("receive config: %w", err)
@@ -89,7 +67,7 @@ func (p *protoServer) Session(stream proto.DRPCProvisioner_SessionStream) error
s.logLevel = proto.LogLevel_value[strings.ToUpper(s.Config.ProvisionerLogLevel)]
}
err = s.extractArchive()
err = s.Files.ExtractArchive(s.Context(), s.Logger, afero.NewOsFs(), s.Config)
if err != nil {
return xerrors.Errorf("extract archive: %w", err)
}
@@ -144,7 +122,7 @@ func (s *Session) handleRequests() error {
return err
}
// Handle README centrally, so that individual provisioners don't need to mess with it.
readme, err := os.ReadFile(filepath.Join(s.WorkDirectory, ReadmeFile))
readme, err := os.ReadFile(s.Files.ReadmeFilePath())
if err == nil {
complete.Readme = readme
} else {
@@ -220,9 +198,9 @@ func (s *Session) handleRequests() error {
}
type Session struct {
Logger slog.Logger
WorkDirectory string
Config *proto.Config
Logger slog.Logger
Files tfpath.Layout
Config *proto.Config
server Server
stream proto.DRPCProvisioner_SessionStream
@@ -233,92 +211,6 @@ func (s *Session) Context() context.Context {
return s.stream.Context()
}
func (s *Session) extractArchive() error {
ctx := s.Context()
s.Logger.Info(ctx, "unpacking template source archive",
slog.F("size_bytes", len(s.Config.TemplateSourceArchive)),
)
reader := tar.NewReader(bytes.NewBuffer(s.Config.TemplateSourceArchive))
// for safety, nil out the reference on Config, since the reader now owns it.
s.Config.TemplateSourceArchive = nil
for {
header, err := reader.Next()
if err != nil {
if xerrors.Is(err, io.EOF) {
break
}
return xerrors.Errorf("read template source archive: %w", err)
}
s.Logger.Debug(context.Background(), "read archive entry",
slog.F("name", header.Name),
slog.F("mod_time", header.ModTime),
slog.F("size", header.Size))
// Security: don't untar absolute or relative paths, as this can allow a malicious tar to overwrite
// files outside the workdir.
if !filepath.IsLocal(header.Name) {
return xerrors.Errorf("refusing to extract to non-local path")
}
// nolint: gosec
headerPath := filepath.Join(s.WorkDirectory, header.Name)
if !strings.HasPrefix(headerPath, filepath.Clean(s.WorkDirectory)) {
return xerrors.New("tar attempts to target relative upper directory")
}
mode := header.FileInfo().Mode()
if mode == 0 {
mode = 0o600
}
// Always check for context cancellation before reading the next header.
// This is mainly important for unit tests, since a canceled context means
// the underlying directory is going to be deleted. There still exists
// the small race condition that the context is canceled after this, and
// before the disk write.
if ctx.Err() != nil {
return xerrors.Errorf("context canceled: %w", ctx.Err())
}
switch header.Typeflag {
case tar.TypeDir:
err = os.MkdirAll(headerPath, mode)
if err != nil {
return xerrors.Errorf("mkdir %q: %w", headerPath, err)
}
s.Logger.Debug(context.Background(), "extracted directory",
slog.F("path", headerPath),
slog.F("mode", fmt.Sprintf("%O", mode)))
case tar.TypeReg:
file, err := os.OpenFile(headerPath, os.O_CREATE|os.O_RDWR, mode)
if err != nil {
return xerrors.Errorf("create file %q (mode %s): %w", headerPath, mode, err)
}
hash := crc32.NewIEEE()
hashReader := io.TeeReader(reader, hash)
// Max file size of 10MiB.
size, err := io.CopyN(file, hashReader, 10<<20)
if xerrors.Is(err, io.EOF) {
err = nil
}
if err != nil {
_ = file.Close()
return xerrors.Errorf("copy file %q: %w", headerPath, err)
}
err = file.Close()
if err != nil {
return xerrors.Errorf("close file %q: %s", headerPath, err)
}
s.Logger.Debug(context.Background(), "extracted file",
slog.F("size_bytes", size),
slog.F("path", headerPath),
slog.F("mode", mode),
slog.F("checksum", fmt.Sprintf("%x", hash.Sum(nil))))
}
}
return nil
}
func (s *Session) ProvisionLog(level proto.LogLevel, output string) {
if int32(level) < s.logLevel {
return
@@ -379,8 +271,3 @@ func (r *request[R, C]) do() (C, error) {
return c, nil
}
}
// SessionDir returns the directory name with mandatory prefix.
func SessionDir(sessID string) string {
return sessionDirPrefix + sessID
}
+192
View File
@@ -0,0 +1,192 @@
package tfpath
import (
"archive/tar"
"bytes"
"context"
"fmt"
"hash/crc32"
"io"
"os"
"path/filepath"
"strings"
"time"
"github.com/spf13/afero"
"golang.org/x/xerrors"
"cdr.dev/slog"
"github.com/coder/coder/v2/provisionersdk/proto"
)
const (
// ReadmeFile is the location we look for to extract documentation from template versions.
ReadmeFile = "README.md"
sessionDirPrefix = "Session"
)
// Session creates a directory structure layout for terraform execution. The
// SessionID is a unique value for creating an ephemeral working directory inside
// the parentDirPath. All helper functions will return paths for various
// terraform asserts inside this working directory.
func Session(parentDirPath, sessionID string) Layout {
return Layout(filepath.Join(parentDirPath, sessionDirPrefix+sessionID))
}
// Layout is the terraform execution working directory structure.
// It also contains some methods for common file operations within that layout.
// Such as "Cleanup" and "ExtractArchive".
// TODO: Maybe we should include the afero.FS here as well, then all operations
// would be on the same FS?
type Layout string
// WorkDirectory returns the root working directory for Terraform files.
func (l Layout) WorkDirectory() string { return string(l) }
func (l Layout) StateFilePath() string {
return filepath.Join(l.WorkDirectory(), "terraform.tfstate")
}
func (l Layout) PlanFilePath() string {
return filepath.Join(l.WorkDirectory(), "terraform.tfplan")
}
func (l Layout) TerraformLockFile() string {
return filepath.Join(l.WorkDirectory(), ".terraform.lock.hcl")
}
func (l Layout) ReadmeFilePath() string {
return filepath.Join(l.WorkDirectory(), ReadmeFile)
}
func (l Layout) TerraformMetadataDir() string {
return filepath.Join(l.WorkDirectory(), ".terraform")
}
func (l Layout) ModulesDirectory() string {
return filepath.Join(l.TerraformMetadataDir(), "modules")
}
func (l Layout) ModulesFilePath() string {
return filepath.Join(l.ModulesDirectory(), "modules.json")
}
func (l Layout) ExtractArchive(ctx context.Context, logger slog.Logger, fs afero.Fs, cfg *proto.Config) error {
logger.Info(ctx, "unpacking template source archive",
slog.F("size_bytes", len(cfg.TemplateSourceArchive)),
)
err := fs.MkdirAll(l.WorkDirectory(), 0o700)
if err != nil {
return xerrors.Errorf("create work directory %q: %w", l.WorkDirectory(), err)
}
reader := tar.NewReader(bytes.NewBuffer(cfg.TemplateSourceArchive))
// for safety, nil out the reference on Config, since the reader now owns it.
cfg.TemplateSourceArchive = nil
for {
header, err := reader.Next()
if err != nil {
if xerrors.Is(err, io.EOF) {
break
}
return xerrors.Errorf("read template source archive: %w", err)
}
logger.Debug(context.Background(), "read archive entry",
slog.F("name", header.Name),
slog.F("mod_time", header.ModTime),
slog.F("size", header.Size))
// Security: don't untar absolute or relative paths, as this can allow a malicious tar to overwrite
// files outside the workdir.
if !filepath.IsLocal(header.Name) {
return xerrors.Errorf("refusing to extract to non-local path")
}
// nolint: gosec // TODO: Use relative paths inside the workdir only.
headerPath := filepath.Join(l.WorkDirectory(), header.Name)
if !strings.HasPrefix(headerPath, filepath.Clean(l.WorkDirectory())) {
return xerrors.New("tar attempts to target relative upper directory")
}
mode := header.FileInfo().Mode()
if mode == 0 {
mode = 0o600
}
// Always check for context cancellation before reading the next header.
// This is mainly important for unit tests, since a canceled context means
// the underlying directory is going to be deleted. There still exists
// the small race condition that the context is canceled after this, and
// before the disk write.
if ctx.Err() != nil {
return xerrors.Errorf("context canceled: %w", ctx.Err())
}
switch header.Typeflag {
case tar.TypeDir:
err = fs.MkdirAll(headerPath, mode)
if err != nil {
return xerrors.Errorf("mkdir %q: %w", headerPath, err)
}
logger.Debug(context.Background(), "extracted directory",
slog.F("path", headerPath),
slog.F("mode", fmt.Sprintf("%O", mode)))
case tar.TypeReg:
file, err := fs.OpenFile(headerPath, os.O_CREATE|os.O_RDWR, mode)
if err != nil {
return xerrors.Errorf("create file %q (mode %s): %w", headerPath, mode, err)
}
hash := crc32.NewIEEE()
hashReader := io.TeeReader(reader, hash)
// Max file size of 10MiB.
size, err := io.CopyN(file, hashReader, 10<<20)
if xerrors.Is(err, io.EOF) {
err = nil
}
if err != nil {
_ = file.Close()
return xerrors.Errorf("copy file %q: %w", headerPath, err)
}
err = file.Close()
if err != nil {
return xerrors.Errorf("close file %q: %s", headerPath, err)
}
logger.Debug(context.Background(), "extracted file",
slog.F("size_bytes", size),
slog.F("path", headerPath),
slog.F("mode", mode),
slog.F("checksum", fmt.Sprintf("%x", hash.Sum(nil))))
}
}
return nil
}
// Cleanup removes the work directory and all of its contents.
func (l Layout) Cleanup(ctx context.Context, logger slog.Logger, fs afero.Fs) {
var err error
path := l.WorkDirectory()
for attempt := 0; attempt < 5; attempt++ {
err := fs.RemoveAll(path)
if err != nil {
// On Windows, open files cannot be removed.
// When the provisioner daemon is shutting down,
// it may take a few milliseconds for processes to exit.
// See: https://github.com/golang/go/issues/50510
logger.Debug(ctx, "failed to clean work directory; trying again", slog.Error(err))
// TODO: Should we abort earlier if the context is done?
time.Sleep(250 * time.Millisecond)
continue
}
logger.Debug(ctx, "cleaned up work directory")
return
}
// Returning an error at this point cannot do any good. The caller cannot resolve
// this. There is a routine cleanup task that will remove old work directories
// when this fails.
logger.Error(ctx, "failed to clean up work directory after multiple attempts",
slog.F("path", path), slog.Error(err))
}