diff --git a/cmd/profiles.go b/cmd/profiles.go index 046364c..c19d3ca 100644 --- a/cmd/profiles.go +++ b/cmd/profiles.go @@ -1,18 +1,21 @@ package cmd import ( - "bytes" + "archive/tar" "context" - "encoding/json" + "errors" "fmt" "io" "net/http" "os" + "path/filepath" + "strings" "github.com/kernel/cli/pkg/util" "github.com/kernel/kernel-go-sdk" "github.com/kernel/kernel-go-sdk/option" "github.com/kernel/kernel-go-sdk/packages/pagination" + "github.com/klauspost/compress/zstd" "github.com/pterm/pterm" "github.com/samber/lo" "github.com/spf13/cobra" @@ -51,8 +54,7 @@ type ProfilesDeleteInput struct { type ProfilesDownloadInput struct { Identifier string - Output string - Pretty bool + To string } // ProfilesCmd handles profile operations independent of cobra. @@ -246,48 +248,91 @@ func (p ProfilesCmd) Delete(ctx context.Context, in ProfilesDeleteInput) error { } func (p ProfilesCmd) Download(ctx context.Context, in ProfilesDownloadInput) error { + if in.To == "" { + return fmt.Errorf("missing required --to for extraction directory") + } + res, err := p.profiles.Download(ctx, in.Identifier) if err != nil { return util.CleanedUpSdkError{Err: err} } defer res.Body.Close() - if in.Output == "" { - pterm.Error.Println("Missing --to output file path") + if res.StatusCode == http.StatusAccepted { _, _ = io.Copy(io.Discard, res.Body) + pterm.Info.Printf("Profile '%s' has no saved data yet. Use it in a browser session first to capture state.\n", in.Identifier) return nil } - f, err := os.Create(in.Output) + if res.StatusCode != http.StatusOK { + body, _ := io.ReadAll(res.Body) + return fmt.Errorf("unexpected status %d from profile download: %s", res.StatusCode, strings.TrimSpace(string(body))) + } + + if err := extractProfileArchive(res.Body, in.To); err != nil { + return fmt.Errorf("extract profile archive: %w", err) + } + + pterm.Success.Printf("Extracted profile '%s' to %s\n", in.Identifier, in.To) + return nil +} + +// extractProfileArchive streams a zstd-compressed tar archive into destDir. +// Files and directories are created relative to destDir; symlinks and other +// special entry types are skipped. Path-traversal entries are rejected. +func extractProfileArchive(r io.Reader, destDir string) error { + if err := os.MkdirAll(destDir, 0o755); err != nil { + return fmt.Errorf("create destination: %w", err) + } + + cleanedDest, err := filepath.Abs(destDir) if err != nil { - pterm.Error.Printf("Failed to create file: %v\n", err) - return nil + return fmt.Errorf("resolve destination: %w", err) } - defer f.Close() - if in.Pretty { - var buf bytes.Buffer - body, _ := io.ReadAll(res.Body) - if len(body) == 0 { - pterm.Error.Println("Empty response body") - return nil + + decoder, err := zstd.NewReader(r) + if err != nil { + return fmt.Errorf("zstd init: %w", err) + } + defer decoder.Close() + + tr := tar.NewReader(decoder) + for { + header, err := tr.Next() + if errors.Is(err, io.EOF) { + break } - if err := json.Indent(&buf, body, "", " "); err != nil { - pterm.Error.Printf("Failed to pretty-print JSON: %v\n", err) - return nil + if err != nil { + return fmt.Errorf("tar read: %w", err) } - if _, err := io.Copy(f, &buf); err != nil { - pterm.Error.Printf("Failed to write pretty-printed JSON: %v\n", err) - return nil + + destPath := filepath.Join(cleanedDest, header.Name) + if !strings.HasPrefix(destPath, cleanedDest+string(os.PathSeparator)) && destPath != cleanedDest { + return fmt.Errorf("illegal entry path: %s", header.Name) } - return nil - } else { - if _, err := io.Copy(f, res.Body); err != nil { - pterm.Error.Printf("Failed to write file: %v\n", err) - return nil + + switch header.Typeflag { + case tar.TypeDir: + if err := os.MkdirAll(destPath, 0o755); err != nil { + return fmt.Errorf("mkdir %s: %w", destPath, err) + } + case tar.TypeReg: + if err := os.MkdirAll(filepath.Dir(destPath), 0o755); err != nil { + return fmt.Errorf("mkdir parent of %s: %w", destPath, err) + } + f, err := os.OpenFile(destPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, os.FileMode(header.Mode)&0o777) + if err != nil { + return fmt.Errorf("create %s: %w", destPath, err) + } + if _, err := io.Copy(f, tr); err != nil { + f.Close() + return fmt.Errorf("write %s: %w", destPath, err) + } + if err := f.Close(); err != nil { + return fmt.Errorf("close %s: %w", destPath, err) + } } } - - pterm.Success.Printf("Saved profile to %s\n", in.Output) return nil } @@ -329,8 +374,9 @@ var profilesDeleteCmd = &cobra.Command{ } var profilesDownloadCmd = &cobra.Command{ - Use: "download ", - Short: "Download a profile as a ZIP archive", + Use: "download --to ", + Short: "Download a profile and extract it to a directory", + Long: "Download a profile and extract its zstd-compressed user-data tar archive into the directory given by --to. The directory is created if it does not exist.", Args: cobra.ExactArgs(1), RunE: runProfilesDownload, } @@ -350,8 +396,8 @@ func init() { profilesCreateCmd.Flags().StringP("output", "o", "", "Output format: json for raw API response") profilesCreateCmd.Flags().String("name", "", "Optional unique profile name") profilesDeleteCmd.Flags().BoolP("yes", "y", false, "Skip confirmation prompt") - profilesDownloadCmd.Flags().String("to", "", "Output zip file path") - profilesDownloadCmd.Flags().Bool("pretty", false, "Pretty-print JSON to file") + profilesDownloadCmd.Flags().String("to", "", "Directory to extract the profile into (required)") + _ = profilesDownloadCmd.MarkFlagRequired("to") } func runProfilesList(cmd *cobra.Command, args []string) error { @@ -398,9 +444,8 @@ func runProfilesDelete(cmd *cobra.Command, args []string) error { func runProfilesDownload(cmd *cobra.Command, args []string) error { client := getKernelClient(cmd) - out, _ := cmd.Flags().GetString("to") - pretty, _ := cmd.Flags().GetBool("pretty") + to, _ := cmd.Flags().GetString("to") svc := client.Profiles p := ProfilesCmd{profiles: &svc} - return p.Download(cmd.Context(), ProfilesDownloadInput{Identifier: args[0], Output: out, Pretty: pretty}) + return p.Download(cmd.Context(), ProfilesDownloadInput{Identifier: args[0], To: to}) } diff --git a/cmd/profiles_test.go b/cmd/profiles_test.go index 8995b75..73f84c0 100644 --- a/cmd/profiles_test.go +++ b/cmd/profiles_test.go @@ -1,6 +1,7 @@ package cmd import ( + "archive/tar" "bytes" "context" "errors" @@ -8,6 +9,7 @@ import ( "io" "net/http" "os" + "path/filepath" "strings" "testing" "time" @@ -15,6 +17,7 @@ import ( "github.com/kernel/kernel-go-sdk" "github.com/kernel/kernel-go-sdk/option" "github.com/kernel/kernel-go-sdk/packages/pagination" + "github.com/klauspost/compress/zstd" "github.com/pterm/pterm" "github.com/stretchr/testify/assert" ) @@ -224,86 +227,92 @@ func TestProfilesDelete_SkipConfirm(t *testing.T) { assert.Contains(t, buf.String(), "Deleted profile: a") } -func TestProfilesDownload_MissingOutput(t *testing.T) { - buf := captureProfilesOutput(t) - fake := &FakeProfilesService{DownloadFunc: func(ctx context.Context, idOrName string, opts ...option.RequestOption) (*http.Response, error) { - return &http.Response{StatusCode: 200, Body: io.NopCloser(strings.NewReader("content")), Header: http.Header{}}, nil - }} +// makeProfileArchive builds a zstd-compressed tar archive from a map of file +// paths to contents, for use in download tests. +func makeProfileArchive(t *testing.T, files map[string]string) []byte { + t.Helper() + var buf bytes.Buffer + zw, err := zstd.NewWriter(&buf) + assert.NoError(t, err) + tw := tar.NewWriter(zw) + for name, content := range files { + hdr := &tar.Header{Name: name, Mode: 0o644, Size: int64(len(content)), Typeflag: tar.TypeReg} + assert.NoError(t, tw.WriteHeader(hdr)) + _, err := tw.Write([]byte(content)) + assert.NoError(t, err) + } + assert.NoError(t, tw.Close()) + assert.NoError(t, zw.Close()) + return buf.Bytes() +} + +func TestProfilesDownload_MissingTo(t *testing.T) { + fake := &FakeProfilesService{} p := ProfilesCmd{profiles: fake} - _ = p.Download(context.Background(), ProfilesDownloadInput{Identifier: "p1", Output: "", Pretty: false}) - assert.Contains(t, buf.String(), "Missing --to output file path") + err := p.Download(context.Background(), ProfilesDownloadInput{Identifier: "p1", To: ""}) + assert.Error(t, err) + assert.Contains(t, err.Error(), "missing required --to") } -func TestProfilesDownload_RawSuccess(t *testing.T) { +func TestProfilesDownload_ExtractSuccess(t *testing.T) { buf := captureProfilesOutput(t) - f, err := os.CreateTemp("", "profile-*.zip") + dir, err := os.MkdirTemp("", "profile-*") assert.NoError(t, err) - name := f.Name() - _ = f.Close() - defer os.Remove(name) + defer os.RemoveAll(dir) - content := "hello" + archive := makeProfileArchive(t, map[string]string{ + "Default/Preferences": "{\"k\":1}", + "Local State": "local", + }) fake := &FakeProfilesService{DownloadFunc: func(ctx context.Context, idOrName string, opts ...option.RequestOption) (*http.Response, error) { - return &http.Response{StatusCode: 200, Body: io.NopCloser(strings.NewReader(content)), Header: http.Header{}}, nil + return &http.Response{StatusCode: 200, Body: io.NopCloser(bytes.NewReader(archive)), Header: http.Header{}}, nil }} p := ProfilesCmd{profiles: fake} - _ = p.Download(context.Background(), ProfilesDownloadInput{Identifier: "p1", Output: name, Pretty: false}) - - b, readErr := os.ReadFile(name) - assert.NoError(t, readErr) - assert.Equal(t, content, string(b)) - assert.Contains(t, buf.String(), "Saved profile to "+name) -} - -func TestProfilesDownload_PrettySuccess(t *testing.T) { - f, err := os.CreateTemp("", "profile-*.json") + err = p.Download(context.Background(), ProfilesDownloadInput{Identifier: "p1", To: dir}) assert.NoError(t, err) - name := f.Name() - _ = f.Close() - defer os.Remove(name) - jsonBody := "{\"a\":1}" - fake := &FakeProfilesService{DownloadFunc: func(ctx context.Context, idOrName string, opts ...option.RequestOption) (*http.Response, error) { - return &http.Response{StatusCode: 200, Body: io.NopCloser(strings.NewReader(jsonBody)), Header: http.Header{}}, nil - }} - p := ProfilesCmd{profiles: fake} - _ = p.Download(context.Background(), ProfilesDownloadInput{Identifier: "p1", Output: name, Pretty: true}) + b, readErr := os.ReadFile(filepath.Join(dir, "Default", "Preferences")) + assert.NoError(t, readErr) + assert.Equal(t, "{\"k\":1}", string(b)) - b, readErr := os.ReadFile(name) + b2, readErr := os.ReadFile(filepath.Join(dir, "Local State")) assert.NoError(t, readErr) - out := string(b) - assert.Contains(t, out, "\n") - assert.Contains(t, out, "\"a\": 1") + assert.Equal(t, "local", string(b2)) + + assert.Contains(t, buf.String(), "Extracted profile 'p1' to "+dir) } -func TestProfilesDownload_PrettyEmptyBody(t *testing.T) { +func TestProfilesDownload_202NoData(t *testing.T) { buf := captureProfilesOutput(t) - f, err := os.CreateTemp("", "profile-*.json") + dir, err := os.MkdirTemp("", "profile-*") assert.NoError(t, err) - name := f.Name() - _ = f.Close() - defer os.Remove(name) + defer os.RemoveAll(dir) fake := &FakeProfilesService{DownloadFunc: func(ctx context.Context, idOrName string, opts ...option.RequestOption) (*http.Response, error) { - return &http.Response{StatusCode: 200, Body: io.NopCloser(strings.NewReader("")), Header: http.Header{}}, nil + return &http.Response{StatusCode: http.StatusAccepted, Body: io.NopCloser(strings.NewReader("")), Header: http.Header{}}, nil }} p := ProfilesCmd{profiles: fake} - _ = p.Download(context.Background(), ProfilesDownloadInput{Identifier: "p1", Output: name, Pretty: true}) - assert.Contains(t, buf.String(), "Empty response body") + err = p.Download(context.Background(), ProfilesDownloadInput{Identifier: "fresh", To: dir}) + assert.NoError(t, err) + assert.Contains(t, buf.String(), "no saved data yet") + + entries, _ := os.ReadDir(dir) + assert.Empty(t, entries) } -func TestProfilesDownload_PrettyInvalidJSON(t *testing.T) { - buf := captureProfilesOutput(t) - f, err := os.CreateTemp("", "profile-*.json") +func TestProfilesDownload_PathTraversalRejected(t *testing.T) { + dir, err := os.MkdirTemp("", "profile-*") assert.NoError(t, err) - name := f.Name() - _ = f.Close() - defer os.Remove(name) + defer os.RemoveAll(dir) + archive := makeProfileArchive(t, map[string]string{ + "../escape": "nope", + }) fake := &FakeProfilesService{DownloadFunc: func(ctx context.Context, idOrName string, opts ...option.RequestOption) (*http.Response, error) { - return &http.Response{StatusCode: 200, Body: io.NopCloser(strings.NewReader("not json")), Header: http.Header{}}, nil + return &http.Response{StatusCode: 200, Body: io.NopCloser(bytes.NewReader(archive)), Header: http.Header{}}, nil }} p := ProfilesCmd{profiles: fake} - _ = p.Download(context.Background(), ProfilesDownloadInput{Identifier: "p1", Output: name, Pretty: true}) - assert.Contains(t, buf.String(), "Failed to pretty-print JSON") + err = p.Download(context.Background(), ProfilesDownloadInput{Identifier: "p1", To: dir}) + assert.Error(t, err) + assert.Contains(t, err.Error(), "illegal entry path") } diff --git a/go.mod b/go.mod index 8da7e41..53e8873 100644 --- a/go.mod +++ b/go.mod @@ -9,7 +9,8 @@ require ( github.com/charmbracelet/lipgloss/v2 v2.0.0-beta.1 github.com/golang-jwt/jwt/v5 v5.2.2 github.com/joho/godotenv v1.5.1 - github.com/kernel/kernel-go-sdk v0.48.0 + github.com/kernel/kernel-go-sdk v0.52.0 + github.com/klauspost/compress v1.18.5 github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c github.com/pterm/pterm v0.12.80 github.com/samber/lo v1.51.0 @@ -20,6 +21,7 @@ require ( golang.org/x/crypto v0.47.0 golang.org/x/oauth2 v0.30.0 golang.org/x/sync v0.19.0 + golang.org/x/term v0.39.0 ) require ( @@ -55,7 +57,6 @@ require ( github.com/tidwall/sjson v1.2.5 // indirect github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect golang.org/x/sys v0.40.0 // indirect - golang.org/x/term v0.39.0 // indirect golang.org/x/text v0.33.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 2b1d6dd..12e488c 100644 --- a/go.sum +++ b/go.sum @@ -64,8 +64,10 @@ github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2 github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= -github.com/kernel/kernel-go-sdk v0.48.0 h1:XX1VVs8D5q+rBMkZovXmKAQa94w+6oEJzxBLikfPaxw= -github.com/kernel/kernel-go-sdk v0.48.0/go.mod h1:EeZzSuHZVeHKxKCPUzxou2bovNGhXaz0RXrSqKNf1AQ= +github.com/kernel/kernel-go-sdk v0.52.0 h1:ChRAMo6oMAEmazC610FtcqKFO/cqHzU9v1ECF0MiR8E= +github.com/kernel/kernel-go-sdk v0.52.0/go.mod h1:EeZzSuHZVeHKxKCPUzxou2bovNGhXaz0RXrSqKNf1AQ= +github.com/klauspost/compress v1.18.5 h1:/h1gH5Ce+VWNLSWqPzOVn6XBO+vJbCNGvjoaGBFW2IE= +github.com/klauspost/compress v1.18.5/go.mod h1:cwPg85FWrGar70rWktvGQj8/hthj3wpl0PGDogxkrSQ= github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= github.com/klauspost/cpuid/v2 v2.0.10/go.mod h1:g2LTdtYhdyuGPqyWyv7qRAmj1WBqxuObKfj5c0PQa7c= github.com/klauspost/cpuid/v2 v2.0.12/go.mod h1:g2LTdtYhdyuGPqyWyv7qRAmj1WBqxuObKfj5c0PQa7c=