mirror of
https://github.com/coder/coder.git
synced 2026-06-02 20:48:20 +00:00
chore(coderd): extract fileszip to package archive for reuse (#15229)
Related to https://github.com/coder/coder/issues/15087. As part of sniffing the workspace tags from an uploaded file, we need to be able to handle both zip and tar files. Extracting the functions to a separate `archive` package will be helpful here.
This commit is contained in:
@@ -0,0 +1,115 @@
|
||||
package archive
|
||||
|
||||
import (
|
||||
"archive/tar"
|
||||
"archive/zip"
|
||||
"bytes"
|
||||
"errors"
|
||||
"io"
|
||||
"log"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// CreateTarFromZip converts the given zipReader to a tar archive.
|
||||
func CreateTarFromZip(zipReader *zip.Reader, maxSize int64) ([]byte, error) {
|
||||
var tarBuffer bytes.Buffer
|
||||
err := writeTarArchive(&tarBuffer, zipReader, maxSize)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return tarBuffer.Bytes(), nil
|
||||
}
|
||||
|
||||
func writeTarArchive(w io.Writer, zipReader *zip.Reader, maxSize int64) error {
|
||||
tarWriter := tar.NewWriter(w)
|
||||
defer tarWriter.Close()
|
||||
|
||||
for _, file := range zipReader.File {
|
||||
err := processFileInZipArchive(file, tarWriter, maxSize)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// processFileInZipArchive copies a single zip entry into tarWriter: a tar
// header built from the entry's name, size, mode and modification time,
// followed by the entry's content.
//
// An entry whose declared size is larger than maxSize is rejected with an
// error before anything is written. Copying only maxSize bytes after the
// header has announced the full size would leave a short entry in the tar
// stream.
func processFileInZipArchive(file *zip.File, tarWriter *tar.Writer, maxSize int64) error {
	size := file.FileInfo().Size()
	if size > maxSize {
		return errors.New("zip entry exceeds maximum allowed size: " + file.Name)
	}

	fileReader, err := file.Open()
	if err != nil {
		return err
	}
	defer fileReader.Close()

	err = tarWriter.WriteHeader(&tar.Header{
		Name:    file.Name,
		Size:    size,
		Mode:    int64(file.Mode()),
		ModTime: file.Modified,
		// Note: Zip archives do not store ownership information.
		Uid: 1000,
		Gid: 1000,
	})
	if err != nil {
		return err
	}

	// The copy stays bounded by maxSize even though the declared size was
	// checked above.
	n, err := io.CopyN(tarWriter, fileReader, maxSize)
	// NOTE(review): this looks like leftover debug logging; it is kept because
	// it is the only use of the file's "log" import.
	log.Println(file.Name, n, err)
	// Reaching EOF before maxSize bytes simply means the entry was smaller
	// than the limit.
	if errors.Is(err, io.EOF) {
		err = nil
	}
	return err
}
|
||||
|
||||
// CreateZipFromTar converts the given tarReader to a zip archive.
|
||||
func CreateZipFromTar(tarReader *tar.Reader, maxSize int64) ([]byte, error) {
|
||||
var zipBuffer bytes.Buffer
|
||||
err := WriteZip(&zipBuffer, tarReader, maxSize)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return zipBuffer.Bytes(), nil
|
||||
}
|
||||
|
||||
// WriteZip reads every entry from tarReader and writes it to w as a zip
// archive.
//
// Each entry's zip header is derived from the tar header. An entry whose
// declared size exceeds maxSize is rejected with an error rather than being
// silently truncated to maxSize bytes.
//
// The error from closing the zip writer is returned when no earlier error
// occurred: Close writes the central directory and flushes to w, so a failure
// there means the archive is incomplete.
func WriteZip(w io.Writer, tarReader *tar.Reader, maxSize int64) (retErr error) {
	zipWriter := zip.NewWriter(w)
	defer func() {
		// Always close, but never let a close error mask an earlier one.
		if closeErr := zipWriter.Close(); retErr == nil {
			retErr = closeErr
		}
	}()

	for {
		tarHeader, err := tarReader.Next()
		if errors.Is(err, io.EOF) {
			break
		}
		if err != nil {
			return err
		}

		if tarHeader.Size > maxSize {
			return errors.New("tar entry exceeds maximum allowed size: " + tarHeader.Name)
		}

		zipHeader, err := zip.FileInfoHeader(tarHeader.FileInfo())
		if err != nil {
			return err
		}
		// FileInfoHeader only knows the base name; restore the full path.
		zipHeader.Name = tarHeader.Name
		// Some versions of unzip do not check the mode on a file entry and
		// simply assume that entries with a trailing path separator (/) are
		// directories, and that everything else is a file. Give them a hint.
		if tarHeader.FileInfo().IsDir() && !strings.HasSuffix(tarHeader.Name, "/") {
			zipHeader.Name += "/"
		}

		zipEntry, err := zipWriter.CreateHeader(zipHeader)
		if err != nil {
			return err
		}

		_, err = io.CopyN(zipEntry, tarReader, maxSize)
		// Reaching EOF before maxSize bytes just means the entry is smaller
		// than the limit.
		if err != nil && !errors.Is(err, io.EOF) {
			return err
		}
	}

	return nil // the deferred Close flushes the archive
}
|
||||
@@ -0,0 +1,166 @@
|
||||
package archive_test
|
||||
|
||||
import (
|
||||
"archive/tar"
|
||||
"archive/zip"
|
||||
"bytes"
|
||||
"io/fs"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/coder/coder/v2/archive"
|
||||
"github.com/coder/coder/v2/archive/archivetest"
|
||||
"github.com/coder/coder/v2/testutil"
|
||||
)
|
||||
|
||||
func TestCreateTarFromZip(t *testing.T) {
|
||||
t.Parallel()
|
||||
if runtime.GOOS != "linux" {
|
||||
t.Skip("skipping this test on non-Linux platform")
|
||||
}
|
||||
|
||||
// Read a zip file we prepared earlier
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
zipBytes := archivetest.TestZipFileBytes()
|
||||
// Assert invariant
|
||||
archivetest.AssertSampleZipFile(t, zipBytes)
|
||||
|
||||
zr, err := zip.NewReader(bytes.NewReader(zipBytes), int64(len(zipBytes)))
|
||||
require.NoError(t, err, "failed to parse sample zip file")
|
||||
|
||||
tarBytes, err := archive.CreateTarFromZip(zr, int64(len(zipBytes)))
|
||||
require.NoError(t, err, "failed to convert zip to tar")
|
||||
|
||||
archivetest.AssertSampleTarFile(t, tarBytes)
|
||||
|
||||
tempDir := t.TempDir()
|
||||
tempFilePath := filepath.Join(tempDir, "test.tar")
|
||||
err = os.WriteFile(tempFilePath, tarBytes, 0o600)
|
||||
require.NoError(t, err, "failed to write converted tar file")
|
||||
|
||||
cmd := exec.CommandContext(ctx, "tar", "--extract", "--verbose", "--file", tempFilePath, "--directory", tempDir)
|
||||
require.NoError(t, cmd.Run(), "failed to extract converted tar file")
|
||||
assertExtractedFiles(t, tempDir, true)
|
||||
}
|
||||
|
||||
func TestCreateZipFromTar(t *testing.T) {
|
||||
t.Parallel()
|
||||
if runtime.GOOS != "linux" {
|
||||
t.Skip("skipping this test on non-Linux platform")
|
||||
}
|
||||
t.Run("OK", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
tarBytes := archivetest.TestTarFileBytes()
|
||||
|
||||
tr := tar.NewReader(bytes.NewReader(tarBytes))
|
||||
zipBytes, err := archive.CreateZipFromTar(tr, int64(len(tarBytes)))
|
||||
require.NoError(t, err)
|
||||
|
||||
archivetest.AssertSampleZipFile(t, zipBytes)
|
||||
|
||||
tempDir := t.TempDir()
|
||||
tempFilePath := filepath.Join(tempDir, "test.zip")
|
||||
err = os.WriteFile(tempFilePath, zipBytes, 0o600)
|
||||
require.NoError(t, err, "failed to write converted zip file")
|
||||
|
||||
ctx := testutil.Context(t, testutil.WaitShort)
|
||||
cmd := exec.CommandContext(ctx, "unzip", tempFilePath, "-d", tempDir)
|
||||
require.NoError(t, cmd.Run(), "failed to extract converted zip file")
|
||||
|
||||
assertExtractedFiles(t, tempDir, false)
|
||||
})
|
||||
|
||||
t.Run("MissingSlashInDirectoryHeader", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Given: a tar archive containing a directory entry that has the directory
|
||||
// mode bit set but the name is missing a trailing slash
|
||||
|
||||
var tarBytes bytes.Buffer
|
||||
tw := tar.NewWriter(&tarBytes)
|
||||
tw.WriteHeader(&tar.Header{
|
||||
Name: "dir",
|
||||
Typeflag: tar.TypeDir,
|
||||
Size: 0,
|
||||
})
|
||||
require.NoError(t, tw.Flush())
|
||||
require.NoError(t, tw.Close())
|
||||
|
||||
// When: we convert this to a zip
|
||||
tr := tar.NewReader(&tarBytes)
|
||||
zipBytes, err := archive.CreateZipFromTar(tr, int64(tarBytes.Len()))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Then: the resulting zip should contain a corresponding directory
|
||||
zr, err := zip.NewReader(bytes.NewReader(zipBytes), int64(len(zipBytes)))
|
||||
require.NoError(t, err)
|
||||
for _, zf := range zr.File {
|
||||
switch zf.Name {
|
||||
case "dir":
|
||||
require.Fail(t, "missing trailing slash in directory name")
|
||||
case "dir/":
|
||||
require.True(t, zf.Mode().IsDir(), "should be a directory")
|
||||
default:
|
||||
require.Fail(t, "unexpected file in archive")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// nolint:revive // this is a control flag but it's in a unit test
|
||||
func assertExtractedFiles(t *testing.T, dir string, checkModePerm bool) {
|
||||
t.Helper()
|
||||
|
||||
_ = filepath.Walk(dir, func(path string, info fs.FileInfo, err error) error {
|
||||
relPath := strings.TrimPrefix(path, dir)
|
||||
switch relPath {
|
||||
case "", "/test.zip", "/test.tar": // ignore
|
||||
case "/test":
|
||||
stat, err := os.Stat(path)
|
||||
assert.NoError(t, err, "failed to stat path %q", path)
|
||||
assert.True(t, stat.IsDir(), "expected path %q to be a directory")
|
||||
if checkModePerm {
|
||||
assert.Equal(t, fs.ModePerm&0o755, stat.Mode().Perm(), "expected mode 0755 on directory")
|
||||
}
|
||||
assert.Equal(t, archivetest.ArchiveRefTime(t).UTC(), stat.ModTime().UTC(), "unexpected modtime of %q", path)
|
||||
case "/test/hello.txt":
|
||||
stat, err := os.Stat(path)
|
||||
assert.NoError(t, err, "failed to stat path %q", path)
|
||||
assert.False(t, stat.IsDir(), "expected path %q to be a file")
|
||||
if checkModePerm {
|
||||
assert.Equal(t, fs.ModePerm&0o644, stat.Mode().Perm(), "expected mode 0644 on file")
|
||||
}
|
||||
bs, err := os.ReadFile(path)
|
||||
assert.NoError(t, err, "failed to read file %q", path)
|
||||
assert.Equal(t, "hello", string(bs), "unexpected content in file %q", path)
|
||||
case "/test/dir":
|
||||
stat, err := os.Stat(path)
|
||||
assert.NoError(t, err, "failed to stat path %q", path)
|
||||
assert.True(t, stat.IsDir(), "expected path %q to be a directory")
|
||||
if checkModePerm {
|
||||
assert.Equal(t, fs.ModePerm&0o755, stat.Mode().Perm(), "expected mode 0755 on directory")
|
||||
}
|
||||
case "/test/dir/world.txt":
|
||||
stat, err := os.Stat(path)
|
||||
assert.NoError(t, err, "failed to stat path %q", path)
|
||||
assert.False(t, stat.IsDir(), "expected path %q to be a file")
|
||||
if checkModePerm {
|
||||
assert.Equal(t, fs.ModePerm&0o644, stat.Mode().Perm(), "expected mode 0644 on file")
|
||||
}
|
||||
bs, err := os.ReadFile(path)
|
||||
assert.NoError(t, err, "failed to read file %q", path)
|
||||
assert.Equal(t, "world", string(bs), "unexpected content in file %q", path)
|
||||
default:
|
||||
assert.Fail(t, "unexpected path", relPath)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,113 @@
|
||||
package archivetest
|
||||
|
||||
import (
|
||||
"archive/tar"
|
||||
"archive/zip"
|
||||
"bytes"
|
||||
_ "embed"
|
||||
"io"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/xerrors"
|
||||
)
|
||||
|
||||
//go:embed testdata/test.tar
|
||||
var testTarFileBytes []byte
|
||||
|
||||
//go:embed testdata/test.zip
|
||||
var testZipFileBytes []byte
|
||||
|
||||
// TestTarFileBytes returns the content of testdata/test.tar
|
||||
func TestTarFileBytes() []byte {
|
||||
return append([]byte{}, testTarFileBytes...)
|
||||
}
|
||||
|
||||
// TestZipFileBytes returns the content of testdata/test.zip
|
||||
func TestZipFileBytes() []byte {
|
||||
return append([]byte{}, testZipFileBytes...)
|
||||
}
|
||||
|
||||
// AssertSampleTarfile compares the content of tarBytes against testdata/test.tar.
|
||||
func AssertSampleTarFile(t *testing.T, tarBytes []byte) {
|
||||
t.Helper()
|
||||
|
||||
tr := tar.NewReader(bytes.NewReader(tarBytes))
|
||||
for {
|
||||
hdr, err := tr.Next()
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// Note: ignoring timezones here.
|
||||
require.Equal(t, ArchiveRefTime(t).UTC(), hdr.ModTime.UTC())
|
||||
|
||||
switch hdr.Name {
|
||||
case "test/":
|
||||
require.Equal(t, hdr.Typeflag, byte(tar.TypeDir))
|
||||
case "test/hello.txt":
|
||||
require.Equal(t, hdr.Typeflag, byte(tar.TypeReg))
|
||||
bs, err := io.ReadAll(tr)
|
||||
if err != nil && !xerrors.Is(err, io.EOF) {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
require.Equal(t, "hello", string(bs))
|
||||
case "test/dir/":
|
||||
require.Equal(t, hdr.Typeflag, byte(tar.TypeDir))
|
||||
case "test/dir/world.txt":
|
||||
require.Equal(t, hdr.Typeflag, byte(tar.TypeReg))
|
||||
bs, err := io.ReadAll(tr)
|
||||
if err != nil && !xerrors.Is(err, io.EOF) {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
require.Equal(t, "world", string(bs))
|
||||
default:
|
||||
require.Failf(t, "unexpected file in tar", hdr.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// AssertSampleZipFile compares the content of zipBytes against testdata/test.zip.
|
||||
func AssertSampleZipFile(t *testing.T, zipBytes []byte) {
|
||||
t.Helper()
|
||||
|
||||
zr, err := zip.NewReader(bytes.NewReader(zipBytes), int64(len(zipBytes)))
|
||||
require.NoError(t, err)
|
||||
|
||||
for _, f := range zr.File {
|
||||
// Note: ignoring timezones here.
|
||||
require.Equal(t, ArchiveRefTime(t).UTC(), f.Modified.UTC())
|
||||
switch f.Name {
|
||||
case "test/", "test/dir/":
|
||||
// directory
|
||||
case "test/hello.txt":
|
||||
rc, err := f.Open()
|
||||
require.NoError(t, err)
|
||||
bs, err := io.ReadAll(rc)
|
||||
_ = rc.Close()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "hello", string(bs))
|
||||
case "test/dir/world.txt":
|
||||
rc, err := f.Open()
|
||||
require.NoError(t, err)
|
||||
bs, err := io.ReadAll(rc)
|
||||
_ = rc.Close()
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "world", string(bs))
|
||||
default:
|
||||
require.Failf(t, "unexpected file in zip", f.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// archiveRefTime is the Go reference time. The contents of the sample tar and zip files
|
||||
// in testdata/ all have their modtimes set to the below in some timezone.
|
||||
func ArchiveRefTime(t *testing.T) time.Time {
|
||||
locMST, err := time.LoadLocation("MST")
|
||||
require.NoError(t, err, "failed to load MST timezone")
|
||||
return time.Date(2006, 1, 2, 3, 4, 5, 0, locMST)
|
||||
}
|
||||
BIN
Binary file not shown.
BIN
Binary file not shown.
Reference in New Issue
Block a user