diff --git a/beszel/internal/ghupdate/extract.go b/beszel/internal/ghupdate/extract.go new file mode 100644 index 0000000..38da6bb --- /dev/null +++ b/beszel/internal/ghupdate/extract.go @@ -0,0 +1,140 @@ +package ghupdate + +import ( + "archive/tar" + "archive/zip" + "compress/gzip" + "fmt" + "io" + "os" + "path/filepath" + "strings" +) + +// extract extracts an archive file to the destination directory. +// Supports .zip and .tar.gz files based on the file extension. +func extract(srcPath, destDir string) error { + if strings.HasSuffix(srcPath, ".tar.gz") { + return extractTarGz(srcPath, destDir) + } + // Default to zip extraction + return extractZip(srcPath, destDir) +} + +// extractTarGz extracts a tar.gz archive to the destination directory. +func extractTarGz(srcPath, destDir string) error { + src, err := os.Open(srcPath) + if err != nil { + return err + } + defer src.Close() + + gz, err := gzip.NewReader(src) + if err != nil { + return err + } + defer gz.Close() + + tr := tar.NewReader(gz) + + for { + header, err := tr.Next() + if err == io.EOF { + break + } + if err != nil { + return err + } + + if header.Typeflag == tar.TypeDir { + if err := os.MkdirAll(filepath.Join(destDir, header.Name), 0755); err != nil { + return err + } + continue + } + + if err := os.MkdirAll(filepath.Dir(filepath.Join(destDir, header.Name)), 0755); err != nil { + return err + } + + outFile, err := os.Create(filepath.Join(destDir, header.Name)) + if err != nil { + return err + } + + if _, err := io.Copy(outFile, tr); err != nil { + outFile.Close() + return err + } + outFile.Close() + } + + return nil +} + +// extractZip extracts the zip archive at "src" to "dest". +// +// Note that only dirs and regular files will be extracted. +// Symbolic links, named pipes, sockets, or any other irregular files +// are skipped because they come with too many edge cases and ambiguities. +func extractZip(src, dest string) error { + zr, err := zip.OpenReader(src) + if err != nil { + return err + } + defer zr.Close() + + // normalize dest path to check later for Zip Slip + dest = filepath.Clean(dest) + string(os.PathSeparator) + + for _, f := range zr.File { + err := extractFile(f, dest) + if err != nil { + return err + } + } + + return nil +} + +// extractFile extracts the provided zipFile into "basePath/zipFileName" path, +// creating all the necessary path directories. +func extractFile(zipFile *zip.File, basePath string) error { + path := filepath.Join(basePath, zipFile.Name) + + // check for Zip Slip + if !strings.HasPrefix(path, basePath) { + return fmt.Errorf("invalid file path: %s", path) + } + + r, err := zipFile.Open() + if err != nil { + return err + } + defer r.Close() + + // allow only dirs or regular files + if zipFile.FileInfo().IsDir() { + if err := os.MkdirAll(path, os.ModePerm); err != nil { + return err + } + } else if zipFile.FileInfo().Mode().IsRegular() { + // ensure that the file path directories are created + if err := os.MkdirAll(filepath.Dir(path), os.ModePerm); err != nil { + return err + } + + f, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, zipFile.Mode()) + if err != nil { + return err + } + defer f.Close() + + _, err = io.Copy(f, r) + if err != nil { + return err + } + } + + return nil +} diff --git a/beszel/internal/ghupdate/ghupdate.go b/beszel/internal/ghupdate/ghupdate.go index 9c99af2..5421465 100644 --- a/beszel/internal/ghupdate/ghupdate.go +++ b/beszel/internal/ghupdate/ghupdate.go @@ -4,9 +4,7 @@ package ghupdate import ( - "archive/tar" "beszel" - "compress/gzip" "context" "encoding/json" "fmt" @@ -19,8 +17,6 @@ import ( "strings" "github.com/blang/semver" - "github.com/pocketbase/pocketbase/core" - "github.com/pocketbase/pocketbase/tools/archive" ) // Minimal color functions using ANSI escape codes @@ -143,7 +139,7 @@ func (p *plugin) update() (updated bool, err error) { return false, err } - releaseDir := filepath.Join(p.config.DataDir, core.LocalTempDirName) + releaseDir := filepath.Join(p.config.DataDir, ".beszel_update") defer os.RemoveAll(releaseDir) ColorPrintf(ColorYellow, "Downloading %s...", asset.Name) @@ -159,15 +155,9 @@ func (p *plugin) update() (updated bool, err error) { extractDir := filepath.Join(releaseDir, "extracted_"+asset.Name) defer os.RemoveAll(extractDir) - // Extract based on file extension - if strings.HasSuffix(asset.Name, ".tar.gz") { - if err := extractTarGz(assetPath, extractDir); err != nil { - return false, err - } - } else { - if err := archive.Extract(assetPath, extractDir); err != nil { - return false, err - } + // Extract the archive (automatically detects format) + if err := extract(assetPath, extractDir); err != nil { + return false, err } ColorPrint(ColorYellow, "Replacing the executable...") @@ -357,52 +347,3 @@ func archiveSuffix(binaryName, goos, goarch string) string { } return fmt.Sprintf("%s_%s_%s.tar.gz", binaryName, goos, goarch) } - -func extractTarGz(srcPath, destDir string) error { - src, err := os.Open(srcPath) - if err != nil { - return err - } - defer src.Close() - - gz, err := gzip.NewReader(src) - if err != nil { - return err - } - defer gz.Close() - - tr := tar.NewReader(gz) - - for { - header, err := tr.Next() - if err == io.EOF { - break - } - if err != nil { - return err - } - - if header.Typeflag == tar.TypeDir { - if err := os.MkdirAll(filepath.Join(destDir, header.Name), 0755); err != nil { - return err - } - continue - } - - if err := os.MkdirAll(filepath.Dir(filepath.Join(destDir, header.Name)), 0755); err != nil { - return err - } - - outFile, err := os.Create(filepath.Join(destDir, header.Name)) - if err != nil { - return err - } - defer outFile.Close() - - if _, err := io.Copy(outFile, tr); err != nil { - return err - } - } - - return nil -} diff --git a/beszel/internal/ghupdate/ghupdate_test.go b/beszel/internal/ghupdate/ghupdate_test.go index 8fa6fa1..a93b102 100644 --- a/beszel/internal/ghupdate/ghupdate_test.go +++ b/beszel/internal/ghupdate/ghupdate_test.go @@ -1,6 +1,9 @@ package ghupdate -import "testing" +import ( + "path/filepath" + "testing" +) func TestReleaseFindAssetBySuffix(t *testing.T) { r := release{ @@ -21,3 +24,22 @@ func TestReleaseFindAssetBySuffix(t *testing.T) { t.Fatalf("Expected asset with id %d, got %v", 2, asset) } } + +func TestExtractFailure(t *testing.T) { + testDir := t.TempDir() + + // Test with missing zip file + missingZipPath := filepath.Join(testDir, "missing_test.zip") + extractedPath := filepath.Join(testDir, "zip_extract") + + if err := extract(missingZipPath, extractedPath); err == nil { + t.Fatal("Expected Extract to fail due to missing zip file") + } + + // Test with missing tar.gz file + missingTarPath := filepath.Join(testDir, "missing_test.tar.gz") + + if err := extract(missingTarPath, extractedPath); err == nil { + t.Fatal("Expected Extract to fail due to missing tar.gz file") + } +}