chore(coderd): extract fileszip to package archive for reuse (#15229)

Related to https://github.com/coder/coder/issues/15087
As part of sniffing the workspace tags from an uploaded file, we need to
be able to handle both zip and tar files. Extracting the functions to
a separate `archive` package will be helpful here.
This commit is contained in:
Cian Johnston
2024-10-25 15:14:39 +01:00
committed by GitHub
parent 5ad47471b5
commit df34858c3c
8 changed files with 156 additions and 126 deletions
+14 -11
View File
@@ -1,4 +1,4 @@
package coderd package archive
import ( import (
"archive/tar" "archive/tar"
@@ -10,21 +10,22 @@ import (
"strings" "strings"
) )
func CreateTarFromZip(zipReader *zip.Reader) ([]byte, error) { // CreateTarFromZip converts the given zipReader to a tar archive.
func CreateTarFromZip(zipReader *zip.Reader, maxSize int64) ([]byte, error) {
var tarBuffer bytes.Buffer var tarBuffer bytes.Buffer
err := writeTarArchive(&tarBuffer, zipReader) err := writeTarArchive(&tarBuffer, zipReader, maxSize)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return tarBuffer.Bytes(), nil return tarBuffer.Bytes(), nil
} }
func writeTarArchive(w io.Writer, zipReader *zip.Reader) error { func writeTarArchive(w io.Writer, zipReader *zip.Reader, maxSize int64) error {
tarWriter := tar.NewWriter(w) tarWriter := tar.NewWriter(w)
defer tarWriter.Close() defer tarWriter.Close()
for _, file := range zipReader.File { for _, file := range zipReader.File {
err := processFileInZipArchive(file, tarWriter) err := processFileInZipArchive(file, tarWriter, maxSize)
if err != nil { if err != nil {
return err return err
} }
@@ -32,7 +33,7 @@ func writeTarArchive(w io.Writer, zipReader *zip.Reader) error {
return nil return nil
} }
func processFileInZipArchive(file *zip.File, tarWriter *tar.Writer) error { func processFileInZipArchive(file *zip.File, tarWriter *tar.Writer, maxSize int64) error {
fileReader, err := file.Open() fileReader, err := file.Open()
if err != nil { if err != nil {
return err return err
@@ -52,7 +53,7 @@ func processFileInZipArchive(file *zip.File, tarWriter *tar.Writer) error {
return err return err
} }
n, err := io.CopyN(tarWriter, fileReader, httpFileMaxBytes) n, err := io.CopyN(tarWriter, fileReader, maxSize)
log.Println(file.Name, n, err) log.Println(file.Name, n, err)
if errors.Is(err, io.EOF) { if errors.Is(err, io.EOF) {
err = nil err = nil
@@ -60,16 +61,18 @@ func processFileInZipArchive(file *zip.File, tarWriter *tar.Writer) error {
return err return err
} }
func CreateZipFromTar(tarReader *tar.Reader) ([]byte, error) { // CreateZipFromTar converts the given tarReader to a zip archive.
func CreateZipFromTar(tarReader *tar.Reader, maxSize int64) ([]byte, error) {
var zipBuffer bytes.Buffer var zipBuffer bytes.Buffer
err := WriteZipArchive(&zipBuffer, tarReader) err := WriteZip(&zipBuffer, tarReader, maxSize)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return zipBuffer.Bytes(), nil return zipBuffer.Bytes(), nil
} }
func WriteZipArchive(w io.Writer, tarReader *tar.Reader) error { // WriteZip writes the given tarReader to w.
func WriteZip(w io.Writer, tarReader *tar.Reader, maxSize int64) error {
zipWriter := zip.NewWriter(w) zipWriter := zip.NewWriter(w)
defer zipWriter.Close() defer zipWriter.Close()
@@ -100,7 +103,7 @@ func WriteZipArchive(w io.Writer, tarReader *tar.Reader) error {
return err return err
} }
_, err = io.CopyN(zipEntry, tarReader, httpFileMaxBytes) _, err = io.CopyN(zipEntry, tarReader, maxSize)
if errors.Is(err, io.EOF) { if errors.Is(err, io.EOF) {
err = nil err = nil
} }
@@ -1,10 +1,9 @@
package coderd_test package archive_test
import ( import (
"archive/tar" "archive/tar"
"archive/zip" "archive/zip"
"bytes" "bytes"
"io"
"io/fs" "io/fs"
"os" "os"
"os/exec" "os/exec"
@@ -12,13 +11,12 @@ import (
"runtime" "runtime"
"strings" "strings"
"testing" "testing"
"time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"golang.org/x/xerrors"
"github.com/coder/coder/v2/coderd" "github.com/coder/coder/v2/archive"
"github.com/coder/coder/v2/archive/archivetest"
"github.com/coder/coder/v2/testutil" "github.com/coder/coder/v2/testutil"
) )
@@ -30,18 +28,17 @@ func TestCreateTarFromZip(t *testing.T) {
// Read a zip file we prepared earlier // Read a zip file we prepared earlier
ctx := testutil.Context(t, testutil.WaitShort) ctx := testutil.Context(t, testutil.WaitShort)
zipBytes, err := os.ReadFile(filepath.Join("testdata", "test.zip")) zipBytes := archivetest.TestZipFileBytes()
require.NoError(t, err, "failed to read sample zip file")
// Assert invariant // Assert invariant
assertSampleZipFile(t, zipBytes) archivetest.AssertSampleZipFile(t, zipBytes)
zr, err := zip.NewReader(bytes.NewReader(zipBytes), int64(len(zipBytes))) zr, err := zip.NewReader(bytes.NewReader(zipBytes), int64(len(zipBytes)))
require.NoError(t, err, "failed to parse sample zip file") require.NoError(t, err, "failed to parse sample zip file")
tarBytes, err := coderd.CreateTarFromZip(zr) tarBytes, err := archive.CreateTarFromZip(zr, int64(len(zipBytes)))
require.NoError(t, err, "failed to convert zip to tar") require.NoError(t, err, "failed to convert zip to tar")
assertSampleTarFile(t, tarBytes) archivetest.AssertSampleTarFile(t, tarBytes)
tempDir := t.TempDir() tempDir := t.TempDir()
tempFilePath := filepath.Join(tempDir, "test.tar") tempFilePath := filepath.Join(tempDir, "test.tar")
@@ -60,14 +57,13 @@ func TestCreateZipFromTar(t *testing.T) {
} }
t.Run("OK", func(t *testing.T) { t.Run("OK", func(t *testing.T) {
t.Parallel() t.Parallel()
tarBytes, err := os.ReadFile(filepath.Join(".", "testdata", "test.tar")) tarBytes := archivetest.TestTarFileBytes()
require.NoError(t, err, "failed to read sample tar file")
tr := tar.NewReader(bytes.NewReader(tarBytes)) tr := tar.NewReader(bytes.NewReader(tarBytes))
zipBytes, err := coderd.CreateZipFromTar(tr) zipBytes, err := archive.CreateZipFromTar(tr, int64(len(tarBytes)))
require.NoError(t, err) require.NoError(t, err)
assertSampleZipFile(t, zipBytes) archivetest.AssertSampleZipFile(t, zipBytes)
tempDir := t.TempDir() tempDir := t.TempDir()
tempFilePath := filepath.Join(tempDir, "test.zip") tempFilePath := filepath.Join(tempDir, "test.zip")
@@ -99,7 +95,7 @@ func TestCreateZipFromTar(t *testing.T) {
// When: we convert this to a zip // When: we convert this to a zip
tr := tar.NewReader(&tarBytes) tr := tar.NewReader(&tarBytes)
zipBytes, err := coderd.CreateZipFromTar(tr) zipBytes, err := archive.CreateZipFromTar(tr, int64(tarBytes.Len()))
require.NoError(t, err) require.NoError(t, err)
// Then: the resulting zip should contain a corresponding directory // Then: the resulting zip should contain a corresponding directory
@@ -133,7 +129,7 @@ func assertExtractedFiles(t *testing.T, dir string, checkModePerm bool) {
if checkModePerm { if checkModePerm {
assert.Equal(t, fs.ModePerm&0o755, stat.Mode().Perm(), "expected mode 0755 on directory") assert.Equal(t, fs.ModePerm&0o755, stat.Mode().Perm(), "expected mode 0755 on directory")
} }
assert.Equal(t, archiveRefTime(t).UTC(), stat.ModTime().UTC(), "unexpected modtime of %q", path) assert.Equal(t, archivetest.ArchiveRefTime(t).UTC(), stat.ModTime().UTC(), "unexpected modtime of %q", path)
case "/test/hello.txt": case "/test/hello.txt":
stat, err := os.Stat(path) stat, err := os.Stat(path)
assert.NoError(t, err, "failed to stat path %q", path) assert.NoError(t, err, "failed to stat path %q", path)
@@ -168,84 +164,3 @@ func assertExtractedFiles(t *testing.T, dir string, checkModePerm bool) {
return nil return nil
}) })
} }
// assertSampleTarFile fails the test unless every entry found in tarBytes is
// one of the expected sample entries (test/, test/hello.txt, test/dir/,
// test/dir/world.txt) with the expected type flag, content and modification
// time. It does not check that all of those entries are actually present.
func assertSampleTarFile(t *testing.T, tarBytes []byte) {
t.Helper()
tr := tar.NewReader(bytes.NewReader(tarBytes))
for {
hdr, err := tr.Next()
if err != nil {
// io.EOF marks the end of the archive; any other error fails the test.
if err == io.EOF {
return
}
require.NoError(t, err)
}
// Note: ignoring timezones here.
require.Equal(t, archiveRefTime(t).UTC(), hdr.ModTime.UTC())
switch hdr.Name {
case "test/":
require.Equal(t, hdr.Typeflag, byte(tar.TypeDir))
case "test/hello.txt":
require.Equal(t, hdr.Typeflag, byte(tar.TypeReg))
// Reading from tr yields the content of the current entry only.
bs, err := io.ReadAll(tr)
if err != nil && !xerrors.Is(err, io.EOF) {
require.NoError(t, err)
}
require.Equal(t, "hello", string(bs))
case "test/dir/":
require.Equal(t, hdr.Typeflag, byte(tar.TypeDir))
case "test/dir/world.txt":
require.Equal(t, hdr.Typeflag, byte(tar.TypeReg))
bs, err := io.ReadAll(tr)
if err != nil && !xerrors.Is(err, io.EOF) {
require.NoError(t, err)
}
require.Equal(t, "world", string(bs))
default:
// Any entry outside the known sample layout is a failure.
require.Failf(t, "unexpected file in tar", hdr.Name)
}
}
}
// assertSampleZipFile fails the test unless zipBytes parses as a zip archive
// in which every entry is one of the expected sample entries (test/,
// test/dir/, test/hello.txt, test/dir/world.txt) with the expected content and
// modification time. It does not check that all of those entries are present.
func assertSampleZipFile(t *testing.T, zipBytes []byte) {
t.Helper()
zr, err := zip.NewReader(bytes.NewReader(zipBytes), int64(len(zipBytes)))
require.NoError(t, err)
for _, f := range zr.File {
// Note: ignoring timezones here.
require.Equal(t, archiveRefTime(t).UTC(), f.Modified.UTC())
switch f.Name {
case "test/", "test/dir/":
// directory
case "test/hello.txt":
rc, err := f.Open()
require.NoError(t, err)
bs, err := io.ReadAll(rc)
// The close error is deliberately ignored; the read error is checked below.
_ = rc.Close()
require.NoError(t, err)
require.Equal(t, "hello", string(bs))
case "test/dir/world.txt":
rc, err := f.Open()
require.NoError(t, err)
bs, err := io.ReadAll(rc)
_ = rc.Close()
require.NoError(t, err)
require.Equal(t, "world", string(bs))
default:
// Any entry outside the known sample layout is a failure.
require.Failf(t, "unexpected file in zip", f.Name)
}
}
}
// archiveRefTime returns 2006-01-02 03:04:05 in the "MST" location, i.e. the
// digits of the Go reference layout. The contents of the sample tar and zip
// files in testdata/ all have their modtimes set to this instant in some
// timezone. The test fails if the "MST" location cannot be loaded.
func archiveRefTime(t *testing.T) time.Time {
locMST, err := time.LoadLocation("MST")
require.NoError(t, err, "failed to load MST timezone")
return time.Date(2006, 1, 2, 3, 4, 5, 0, locMST)
}
+113
View File
@@ -0,0 +1,113 @@
package archivetest
import (
"archive/tar"
"archive/zip"
"bytes"
_ "embed"
"io"
"testing"
"time"
"github.com/stretchr/testify/require"
"golang.org/x/xerrors"
)
// testTarFileBytes holds the contents of testdata/test.tar, embedded at build
// time. Callers receive copies via TestTarFileBytes.
//
//go:embed testdata/test.tar
var testTarFileBytes []byte

// testZipFileBytes holds the contents of testdata/test.zip, embedded at build
// time. Callers receive copies via TestZipFileBytes.
//
//go:embed testdata/test.zip
var testZipFileBytes []byte
// TestTarFileBytes returns a fresh copy of the embedded testdata/test.tar, so
// callers may modify the result without affecting other tests.
func TestTarFileBytes() []byte {
	out := make([]byte, len(testTarFileBytes))
	copy(out, testTarFileBytes)
	return out
}
// TestZipFileBytes returns a fresh copy of the embedded testdata/test.zip, so
// callers may modify the result without affecting other tests.
func TestZipFileBytes() []byte {
	out := make([]byte, len(testZipFileBytes))
	copy(out, testZipFileBytes)
	return out
}
// AssertSampleTarfile compares the content of tarBytes against testdata/test.tar.
func AssertSampleTarFile(t *testing.T, tarBytes []byte) {
t.Helper()
tr := tar.NewReader(bytes.NewReader(tarBytes))
for {
hdr, err := tr.Next()
if err != nil {
if err == io.EOF {
return
}
require.NoError(t, err)
}
// Note: ignoring timezones here.
require.Equal(t, ArchiveRefTime(t).UTC(), hdr.ModTime.UTC())
switch hdr.Name {
case "test/":
require.Equal(t, hdr.Typeflag, byte(tar.TypeDir))
case "test/hello.txt":
require.Equal(t, hdr.Typeflag, byte(tar.TypeReg))
bs, err := io.ReadAll(tr)
if err != nil && !xerrors.Is(err, io.EOF) {
require.NoError(t, err)
}
require.Equal(t, "hello", string(bs))
case "test/dir/":
require.Equal(t, hdr.Typeflag, byte(tar.TypeDir))
case "test/dir/world.txt":
require.Equal(t, hdr.Typeflag, byte(tar.TypeReg))
bs, err := io.ReadAll(tr)
if err != nil && !xerrors.Is(err, io.EOF) {
require.NoError(t, err)
}
require.Equal(t, "world", string(bs))
default:
require.Failf(t, "unexpected file in tar", hdr.Name)
}
}
}
// AssertSampleZipFile fails the test unless zipBytes parses as a zip archive
// in which every entry is one of the entries expected from testdata/test.zip
// (test/, test/dir/, test/hello.txt, test/dir/world.txt) with the expected
// content and modification time.
//
// It rejects unexpected entries but does not check that all expected entries
// are present.
func AssertSampleZipFile(t *testing.T, zipBytes []byte) {
	t.Helper()

	zr, err := zip.NewReader(bytes.NewReader(zipBytes), int64(len(zipBytes)))
	require.NoError(t, err)

	// readEntry returns the full decompressed content of a zip entry.
	readEntry := func(f *zip.File) string {
		rc, err := f.Open()
		require.NoError(t, err)
		bs, err := io.ReadAll(rc)
		// The close error is deliberately ignored: the handle is read-only
		// and the read error checked below is what matters.
		_ = rc.Close()
		require.NoError(t, err)
		return string(bs)
	}

	for _, f := range zr.File {
		// Note: ignoring timezones here.
		require.Equal(t, ArchiveRefTime(t).UTC(), f.Modified.UTC())

		switch f.Name {
		case "test/", "test/dir/":
			// Directory entries carry no content to check.
		case "test/hello.txt":
			require.Equal(t, "hello", readEntry(f))
		case "test/dir/world.txt":
			require.Equal(t, "world", readEntry(f))
		default:
			// Pass the name as an argument, not as the format string, so a
			// '%' in an entry name cannot garble the failure message.
			require.Failf(t, "unexpected file in zip", "%s", f.Name)
		}
	}
}
// ArchiveRefTime returns 2006-01-02 03:04:05 in the "MST" location, i.e. the
// digits of the Go reference layout. The contents of the sample tar and zip
// files in testdata/ all have their modtimes set to this instant in some
// timezone, so callers typically compare after converting both sides to UTC.
//
// The test fails immediately if the "MST" location cannot be loaded, which
// requires time zone data to be available to the Go runtime.
func ArchiveRefTime(t *testing.T) time.Time {
	// Attribute a failed timezone load to the caller, matching the other
	// helpers in this package.
	t.Helper()

	locMST, err := time.LoadLocation("MST")
	require.NoError(t, err, "failed to load MST timezone")
	return time.Date(2006, 1, 2, 3, 4, 5, 0, locMST)
}
+2 -1
View File
@@ -13,6 +13,7 @@ import (
"github.com/google/uuid" "github.com/google/uuid"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/coder/coder/v2/archive"
"github.com/coder/coder/v2/cli/clitest" "github.com/coder/coder/v2/cli/clitest"
"github.com/coder/coder/v2/coderd" "github.com/coder/coder/v2/coderd"
"github.com/coder/coder/v2/coderd/coderdtest" "github.com/coder/coder/v2/coderd/coderdtest"
@@ -95,7 +96,7 @@ func TestTemplatePull_Stdout(t *testing.T) {
// Verify .zip format // Verify .zip format
tarReader := tar.NewReader(bytes.NewReader(expected)) tarReader := tar.NewReader(bytes.NewReader(expected))
expectedZip, err := coderd.CreateZipFromTar(tarReader) expectedZip, err := archive.CreateZipFromTar(tarReader, coderd.HTTPFileMaxBytes)
require.NoError(t, err) require.NoError(t, err)
inv, root = clitest.New(t, "templates", "pull", "--zip", template.Name) inv, root = clitest.New(t, "templates", "pull", "--zip", template.Name)
+5 -4
View File
@@ -16,6 +16,7 @@ import (
"github.com/google/uuid" "github.com/google/uuid"
"cdr.dev/slog" "cdr.dev/slog"
"github.com/coder/coder/v2/archive"
"github.com/coder/coder/v2/coderd/database" "github.com/coder/coder/v2/coderd/database"
"github.com/coder/coder/v2/coderd/database/dbtime" "github.com/coder/coder/v2/coderd/database/dbtime"
"github.com/coder/coder/v2/coderd/httpapi" "github.com/coder/coder/v2/coderd/httpapi"
@@ -27,7 +28,7 @@ const (
tarMimeType = "application/x-tar" tarMimeType = "application/x-tar"
zipMimeType = "application/zip" zipMimeType = "application/zip"
httpFileMaxBytes = 10 * (10 << 20) HTTPFileMaxBytes = 10 * (10 << 20)
) )
// @Summary Upload file // @Summary Upload file
@@ -55,7 +56,7 @@ func (api *API) postFile(rw http.ResponseWriter, r *http.Request) {
return return
} }
r.Body = http.MaxBytesReader(rw, r.Body, httpFileMaxBytes) r.Body = http.MaxBytesReader(rw, r.Body, HTTPFileMaxBytes)
data, err := io.ReadAll(r.Body) data, err := io.ReadAll(r.Body)
if err != nil { if err != nil {
httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{ httpapi.Write(ctx, rw, http.StatusBadRequest, codersdk.Response{
@@ -75,7 +76,7 @@ func (api *API) postFile(rw http.ResponseWriter, r *http.Request) {
return return
} }
data, err = CreateTarFromZip(zipReader) data, err = archive.CreateTarFromZip(zipReader, HTTPFileMaxBytes)
if err != nil { if err != nil {
httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{ httpapi.Write(ctx, rw, http.StatusInternalServerError, codersdk.Response{
Message: "Internal error processing .zip archive.", Message: "Internal error processing .zip archive.",
@@ -181,7 +182,7 @@ func (api *API) fileByID(rw http.ResponseWriter, r *http.Request) {
rw.Header().Set("Content-Type", codersdk.ContentTypeZip) rw.Header().Set("Content-Type", codersdk.ContentTypeZip)
rw.WriteHeader(http.StatusOK) rw.WriteHeader(http.StatusOK)
err = WriteZipArchive(rw, tar.NewReader(bytes.NewReader(file.Data))) err = archive.WriteZip(rw, tar.NewReader(bytes.NewReader(file.Data)), HTTPFileMaxBytes)
if err != nil { if err != nil {
api.Logger.Error(ctx, "invalid .zip archive", slog.F("file_id", fileID), slog.F("mimetype", file.Mimetype), slog.Error(err)) api.Logger.Error(ctx, "invalid .zip archive", slog.F("file_id", fileID), slog.F("mimetype", file.Mimetype), slog.Error(err))
} }
+10 -13
View File
@@ -5,14 +5,13 @@ import (
"bytes" "bytes"
"context" "context"
"net/http" "net/http"
"os"
"path/filepath"
"testing" "testing"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/coder/coder/v2/coderd" "github.com/coder/coder/v2/archive"
"github.com/coder/coder/v2/archive/archivetest"
"github.com/coder/coder/v2/coderd/coderdtest" "github.com/coder/coder/v2/coderd/coderdtest"
"github.com/coder/coder/v2/codersdk" "github.com/coder/coder/v2/codersdk"
"github.com/coder/coder/v2/testutil" "github.com/coder/coder/v2/testutil"
@@ -84,8 +83,8 @@ func TestDownload(t *testing.T) {
// given // given
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel() defer cancel()
tarball, err := os.ReadFile(filepath.Join("testdata", "test.tar"))
require.NoError(t, err) tarball := archivetest.TestTarFileBytes()
// when // when
resp, err := client.Upload(ctx, codersdk.ContentTypeTar, bytes.NewReader(tarball)) resp, err := client.Upload(ctx, codersdk.ContentTypeTar, bytes.NewReader(tarball))
@@ -97,7 +96,7 @@ func TestDownload(t *testing.T) {
require.Len(t, data, len(tarball)) require.Len(t, data, len(tarball))
require.Equal(t, codersdk.ContentTypeTar, contentType) require.Equal(t, codersdk.ContentTypeTar, contentType)
require.Equal(t, tarball, data) require.Equal(t, tarball, data)
assertSampleTarFile(t, data) archivetest.AssertSampleTarFile(t, data)
}) })
t.Run("InsertZip_DownloadTar", func(t *testing.T) { t.Run("InsertZip_DownloadTar", func(t *testing.T) {
@@ -106,8 +105,7 @@ func TestDownload(t *testing.T) {
_ = coderdtest.CreateFirstUser(t, client) _ = coderdtest.CreateFirstUser(t, client)
// given // given
zipContent, err := os.ReadFile(filepath.Join("testdata", "test.zip")) zipContent := archivetest.TestZipFileBytes()
require.NoError(t, err)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
defer cancel() defer cancel()
@@ -123,7 +121,7 @@ func TestDownload(t *testing.T) {
// Note: creating a zip from a tar will result in some loss of information // Note: creating a zip from a tar will result in some loss of information
// as zip files do not store UNIX user:group data. // as zip files do not store UNIX user:group data.
assertSampleTarFile(t, data) archivetest.AssertSampleTarFile(t, data)
}) })
t.Run("InsertTar_DownloadZip", func(t *testing.T) { t.Run("InsertTar_DownloadZip", func(t *testing.T) {
@@ -132,11 +130,10 @@ func TestDownload(t *testing.T) {
_ = coderdtest.CreateFirstUser(t, client) _ = coderdtest.CreateFirstUser(t, client)
// given // given
tarball, err := os.ReadFile(filepath.Join("testdata", "test.tar")) tarball := archivetest.TestTarFileBytes()
require.NoError(t, err)
tarReader := tar.NewReader(bytes.NewReader(tarball)) tarReader := tar.NewReader(bytes.NewReader(tarball))
expectedZip, err := coderd.CreateZipFromTar(tarReader) expectedZip, err := archive.CreateZipFromTar(tarReader, 10240)
require.NoError(t, err) require.NoError(t, err)
ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong) ctx, cancel := context.WithTimeout(context.Background(), testutil.WaitLong)
@@ -151,6 +148,6 @@ func TestDownload(t *testing.T) {
// then // then
require.Equal(t, codersdk.ContentTypeZip, contentType) require.Equal(t, codersdk.ContentTypeZip, contentType)
require.Equal(t, expectedZip, data) require.Equal(t, expectedZip, data)
assertSampleZipFile(t, data) archivetest.AssertSampleZipFile(t, data)
}) })
} }