diff --git a/internal/cli/pr.go b/internal/cli/pr.go index 31c7da7..f9b221d 100644 --- a/internal/cli/pr.go +++ b/internal/cli/pr.go @@ -5,6 +5,7 @@ import ( "fmt" "os" "os/exec" + "regexp" "strconv" "strings" @@ -48,20 +49,32 @@ func prViewCmd() *cobra.Command { ) cmd := &cobra.Command{ - Use: "view ", + Use: "view []", Short: "View a pull request", - Args: cobra.ExactArgs(1), - RunE: func(cmd *cobra.Command, args []string) error { - number, err := strconv.Atoi(args[0]) - if err != nil { - return fmt.Errorf("invalid PR number: %s", args[0]) - } + Long: `View a pull request by number, or view the PR for the current branch. +If no number is given, finds and displays the pull request whose head +branch matches the current git branch.`, + Args: cobra.MaximumNArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { forge, owner, repoName, _, err := resolve.Repo(flagRepo, flagForgeType) if err != nil { return err } + var number int + if len(args) > 0 { + number, err = strconv.Atoi(args[0]) + if err != nil { + return fmt.Errorf("invalid PR number: %s", args[0]) + } + } else { + number, err = findPRForCurrentBranch(cmd.Context(), forge, owner, repoName) + if err != nil { + return err + } + } + pr, err := forge.PullRequests().Get(cmd.Context(), owner, repoName, number) if err != nil { return fmt.Errorf("getting PR #%d: %w", number, err) @@ -610,10 +623,13 @@ The argument can be a PR number or a full URL: localBranch = flagBranch } + if !flagDetach { + _ = storePRForBranch(ctx, localBranch, number) + } + if pr.Head.Fork != nil { return checkoutForkPR(ctx, domain, pr, remoteRef, localBranch, flagRemoteName, flagDetach, flagForce) } - return checkoutSameRepoPR(ctx, remoteRef, localBranch, flagDetach, flagForce) }, } @@ -744,3 +760,96 @@ func gitCheckout(ctx context.Context, remote, remoteRef, localBranch string, det resetCmd.Stderr = os.Stderr return resetCmd.Run() } + +func findPRForCurrentBranch(ctx context.Context, f forges.Forge, owner, repo string) (int, error) { + out, err := exec.CommandContext(ctx, "git", "branch", "--show-current").Output() + if err != nil { + return 0, fmt.Errorf("getting current branch: %w (not in a git repository?)", err) + } + localBranch := strings.TrimSpace(string(out)) + if localBranch == "" { + return 0, fmt.Errorf("not on a branch (detached HEAD state)") + } + + // Check cache first (set by 'pr checkout') + if n, err := loadPRForBranch(ctx, localBranch); err == nil { + return n, nil + } + + // If that yields nothing, fall back to API query. This API call is really + // slow for Gitea since the Head filter is not actually implemented. + headOwner := owner + if remoteOwner, err := resolve.OwnerForBranch(ctx, localBranch); err == nil { + headOwner = remoteOwner + } + + // TODO: Limit 100 with no pagination means repos with >100 PRs may miss + // the match on a fresh checkout (cache hides this in normal use). + prs, err := f.PullRequests().List(ctx, owner, repo, forges.ListPROpts{ + Head: headOwner + ":" + localBranch, + State: "all", + Limit: 100, + }) + if err != nil { + return 0, fmt.Errorf("listing PRs for branch %q: %w", localBranch, err) + } + + // Need to filter the results again for owner:branch since the API results + // don't respect the filter in case of Gitea. + var matching []forges.PullRequest + for _, pr := range prs { + prHeadOwner := owner + if pr.Head.Fork != nil { + prHeadOwner = pr.Head.Fork.Owner + } + if pr.Head.Ref == localBranch && prHeadOwner == headOwner { + matching = append(matching, pr) + } + } + + if len(matching) < 1 { + return 0, fmt.Errorf("no pull request found for branch %q", localBranch) + } + + // Prefer open PRs over closed/merged ones (a branch may be reused) + for _, pr := range matching { + if pr.State == "open" { + // Store the PR number into local git config so that the next 'forge + // pr view' call is a lot faster. + _ = storePRForBranch(ctx, localBranch, pr.Number) + return pr.Number, nil + } + } + + // No open PR, return the first match (closed/merged) but don't cache it + return matching[0].Number, nil +} + +func storePRForBranch(ctx context.Context, branch string, number int) error { + key := fmt.Sprintf("branch.%s.forge-pr", branch) + return exec.CommandContext(ctx, "git", "config", "--local", key, strconv.Itoa(number)).Run() +} + +var prRefRE = regexp.MustCompile(`^refs/pull/(\d+)/head$`) + +func loadPRForBranch(ctx context.Context, branch string) (int, error) { + key := fmt.Sprintf("branch.%s.forge-pr", branch) + out, err := exec.CommandContext(ctx, "git", "config", "--get", key).Output() + if err == nil { + return strconv.Atoi(strings.TrimSpace(string(out))) + } + + // Fall back to gh CLI's format (refs/pull//head in branch..merge). + // The regex only matches refs/pull//head, so refs/heads/* values are + // safely rejected. + mergeKey := fmt.Sprintf("branch.%s.merge", branch) + out, err = exec.CommandContext(ctx, "git", "config", "--get", mergeKey).Output() + if err != nil { + return 0, err + } + m := prRefRE.FindStringSubmatch(strings.TrimSpace(string(out))) + if m == nil { + return 0, fmt.Errorf("not a PR ref") + } + return strconv.Atoi(m[1]) +} diff --git a/internal/cli/pr_checkout_test.go b/internal/cli/pr_checkout_test.go index f3b37ac..17ca998 100644 --- a/internal/cli/pr_checkout_test.go +++ b/internal/cli/pr_checkout_test.go @@ -15,8 +15,10 @@ import ( // mockPRService implements forges.PullRequestService for testing. type mockPRService struct { - pr *forges.PullRequest - err error + pr *forges.PullRequest + err error + listResult []forges.PullRequest + listErr error } func (m *mockPRService) Get(_ context.Context, _, _ string, _ int) (*forges.PullRequest, error) { @@ -24,7 +26,7 @@ func (m *mockPRService) Get(_ context.Context, _, _ string, _ int) (*forges.Pull } func (m *mockPRService) List(_ context.Context, _, _ string, _ forges.ListPROpts) ([]forges.PullRequest, error) { - return nil, nil + return m.listResult, m.listErr } func (m *mockPRService) Create(_ context.Context, _, _ string, _ forges.CreatePROpts) (*forges.PullRequest, error) { diff --git a/internal/cli/pr_test.go b/internal/cli/pr_test.go index 4852128..a5aad0d 100644 --- a/internal/cli/pr_test.go +++ b/internal/cli/pr_test.go @@ -2,8 +2,15 @@ package cli import ( "bytes" + "context" + "os" + "os/exec" + "path/filepath" "strings" "testing" + + "github.com/git-pkgs/forge" + "github.com/git-pkgs/forge/internal/resolve" ) func TestPRCmd(t *testing.T) { @@ -74,18 +81,6 @@ func TestPRViewInvalidNumber(t *testing.T) { } } -func TestPRViewRequiresArg(t *testing.T) { - var buf bytes.Buffer - rootCmd.SetOut(&buf) - rootCmd.SetErr(&buf) - rootCmd.SetArgs([]string{"pr", "view"}) - - err := rootCmd.Execute() - if err == nil { - t.Fatal("expected error for missing argument") - } -} - func TestPRCreateRequiresTitleAndHead(t *testing.T) { var buf bytes.Buffer rootCmd.SetOut(&buf) @@ -142,3 +137,276 @@ func TestPRDiffRequiresArg(t *testing.T) { t.Fatal("expected error for missing argument") } } + +func TestStorePRForBranch(t *testing.T) { + if _, err := exec.LookPath("git"); err != nil { + t.Skip("git not installed") + } + + dir := t.TempDir() + t.Chdir(dir) + + mustGitCmd(t, "init", "-q") + mustGitCmd(t, "config", "user.email", "test@test.com") + mustGitCmd(t, "config", "user.name", "Test") + + if err := os.WriteFile(filepath.Join(dir, "README"), []byte("test"), 0644); err != nil { + t.Fatal(err) + } + mustGitCmd(t, "add", "README") + mustGitCmd(t, "commit", "-m", "init") + mustGitCmd(t, "checkout", "-b", "feature") + + ctx := context.Background() + + // Store PR number + if err := storePRForBranch(ctx, "feature", 42); err != nil { + t.Fatalf("storePRForBranch: %v", err) + } + + // Load it back + n, err := loadPRForBranch(ctx, "feature") + if err != nil { + t.Fatalf("loadPRForBranch: %v", err) + } + if n != 42 { + t.Errorf("got %d, want 42", n) + } +} + +func TestLoadPRForBranchGhFormat(t *testing.T) { + if _, err := exec.LookPath("git"); err != nil { + t.Skip("git not installed") + } + + dir := t.TempDir() + t.Chdir(dir) + + mustGitCmd(t, "init", "-q") + mustGitCmd(t, "config", "user.email", "test@test.com") + mustGitCmd(t, "config", "user.name", "Test") + + if err := os.WriteFile(filepath.Join(dir, "README"), []byte("test"), 0644); err != nil { + t.Fatal(err) + } + mustGitCmd(t, "add", "README") + mustGitCmd(t, "commit", "-m", "init") + mustGitCmd(t, "checkout", "-b", "pr-branch") + + // Set up gh CLI's format: branch..merge = refs/pull//head + mustGitCmd(t, "config", "branch.pr-branch.merge", "refs/pull/123/head") + + ctx := context.Background() + n, err := loadPRForBranch(ctx, "pr-branch") + if err != nil { + t.Fatalf("loadPRForBranch: %v", err) + } + if n != 123 { + t.Errorf("got %d, want 123", n) + } +} + +func TestFindPRForCurrentBranch(t *testing.T) { + if _, err := exec.LookPath("git"); err != nil { + t.Skip("git not installed") + } + + dir := t.TempDir() + t.Chdir(dir) + + mustGitCmd(t, "init", "-q") + mustGitCmd(t, "config", "user.email", "test@test.com") + mustGitCmd(t, "config", "user.name", "Test") + mustGitCmd(t, "remote", "add", "origin", "https://github.com/testowner/testrepo.git") + + if err := os.WriteFile(filepath.Join(dir, "README"), []byte("test"), 0644); err != nil { + t.Fatal(err) + } + mustGitCmd(t, "add", "README") + mustGitCmd(t, "commit", "-m", "init") + mustGitCmd(t, "checkout", "-b", "feature") + mustGitCmd(t, "config", "branch.feature.remote", "origin") + + // Set up mock forge that returns a PR for our branch + mockPR := &mockPRService{ + listResult: []forges.PullRequest{ + { + Number: 99, + State: "open", + Head: forges.PRBranch{ + Ref: "feature", + }, + }, + }, + } + resolve.SetTestForge( + &mockForge{prService: mockPR}, + "testowner", "testrepo", "github.com", + ) + t.Cleanup(resolve.ResetTestForge) + + ctx := context.Background() + forge, owner, repo, _, err := resolve.Repo("", "") + if err != nil { + t.Fatalf("resolve.Repo: %v", err) + } + + n, err := findPRForCurrentBranch(ctx, forge, owner, repo) + if err != nil { + t.Fatalf("findPRForCurrentBranch: %v", err) + } + if n != 99 { + t.Errorf("got %d, want 99", n) + } + + // The PR number should now be cached + cached, err := loadPRForBranch(ctx, "feature") + if err != nil { + t.Fatalf("loadPRForBranch after find: %v", err) + } + if cached != 99 { + t.Errorf("cached PR = %d, want 99", cached) + } +} + +func TestFindPRForCurrentBranch_OpenWinsOverClosed(t *testing.T) { + if _, err := exec.LookPath("git"); err != nil { + t.Skip("git not installed") + } + + dir := t.TempDir() + t.Chdir(dir) + + mustGitCmd(t, "init", "-q") + mustGitCmd(t, "config", "user.email", "test@test.com") + mustGitCmd(t, "config", "user.name", "Test") + mustGitCmd(t, "remote", "add", "origin", "https://github.com/testowner/testrepo.git") + + if err := os.WriteFile(filepath.Join(dir, "README"), []byte("test"), 0644); err != nil { + t.Fatal(err) + } + mustGitCmd(t, "add", "README") + mustGitCmd(t, "commit", "-m", "init") + mustGitCmd(t, "checkout", "-b", "feature") + mustGitCmd(t, "config", "branch.feature.remote", "origin") + + // Mock forge returns both a closed and open PR for the same branch + mockPR := &mockPRService{ + listResult: []forges.PullRequest{ + { + Number: 50, + State: "closed", + Head: forges.PRBranch{ + Ref: "feature", + }, + }, + { + Number: 99, + State: "open", + Head: forges.PRBranch{ + Ref: "feature", + }, + }, + }, + } + resolve.SetTestForge( + &mockForge{prService: mockPR}, + "testowner", "testrepo", "github.com", + ) + t.Cleanup(resolve.ResetTestForge) + + ctx := context.Background() + forge, owner, repo, _, err := resolve.Repo("", "") + if err != nil { + t.Fatalf("resolve.Repo: %v", err) + } + + n, err := findPRForCurrentBranch(ctx, forge, owner, repo) + if err != nil { + t.Fatalf("findPRForCurrentBranch: %v", err) + } + if n != 99 { + t.Errorf("got %d, want 99 (the open PR should win over closed)", n) + } + + // The open PR should be cached + cached, err := loadPRForBranch(ctx, "feature") + if err != nil { + t.Fatalf("loadPRForBranch after find: %v", err) + } + if cached != 99 { + t.Errorf("cached PR = %d, want 99", cached) + } +} + +func TestFindPRForCurrentBranch_ClosedPRNotCached(t *testing.T) { + if _, err := exec.LookPath("git"); err != nil { + t.Skip("git not installed") + } + + dir := t.TempDir() + t.Chdir(dir) + + mustGitCmd(t, "init", "-q") + mustGitCmd(t, "config", "user.email", "test@test.com") + mustGitCmd(t, "config", "user.name", "Test") + mustGitCmd(t, "remote", "add", "origin", "https://github.com/testowner/testrepo.git") + + if err := os.WriteFile(filepath.Join(dir, "README"), []byte("test"), 0644); err != nil { + t.Fatal(err) + } + mustGitCmd(t, "add", "README") + mustGitCmd(t, "commit", "-m", "init") + mustGitCmd(t, "checkout", "-b", "feature") + mustGitCmd(t, "config", "branch.feature.remote", "origin") + + // Mock forge returns only a closed PR + mockPR := &mockPRService{ + listResult: []forges.PullRequest{ + { + Number: 42, + State: "closed", + Head: forges.PRBranch{ + Ref: "feature", + }, + }, + }, + } + resolve.SetTestForge( + &mockForge{prService: mockPR}, + "testowner", "testrepo", "github.com", + ) + t.Cleanup(resolve.ResetTestForge) + + ctx := context.Background() + forge, owner, repo, _, err := resolve.Repo("", "") + if err != nil { + t.Fatalf("resolve.Repo: %v", err) + } + + n, err := findPRForCurrentBranch(ctx, forge, owner, repo) + if err != nil { + t.Fatalf("findPRForCurrentBranch: %v", err) + } + if n != 42 { + t.Errorf("got %d, want 42 (the closed PR should be returned)", n) + } + + // Closed PRs should NOT be cached - loadPRForBranch should find nothing + _, err = loadPRForBranch(ctx, "feature") + if err == nil { + t.Error("expected loadPRForBranch to return error for uncached closed PR, got nil") + } +} + +func mustGitCmd(t *testing.T, args ...string) { + t.Helper() + cmd := exec.Command("git", args...) + cmd.Env = append(os.Environ(), + "GIT_CONFIG_GLOBAL=/dev/null", + "GIT_CONFIG_SYSTEM=/dev/null", + ) + if out, err := cmd.CombinedOutput(); err != nil { + t.Fatalf("git %v: %v\n%s", args, err, out) + } +} diff --git a/internal/resolve/resolve.go b/internal/resolve/resolve.go index 106d6c0..bd2b169 100644 --- a/internal/resolve/resolve.go +++ b/internal/resolve/resolve.go @@ -211,6 +211,32 @@ func gitRemoteURL(name string) (string, error) { return strings.TrimSpace(string(out)), nil } +// OwnerForBranch returns the repository owner for the remote that the given +// branch tracks. This is useful for determining which fork a branch was pushed +// to when creating pull requests. +func OwnerForBranch(ctx context.Context, branch string) (string, error) { + remoteKey := fmt.Sprintf("branch.%s.remote", branch) + out, err := exec.CommandContext(ctx, "git", "config", "--get", remoteKey).Output() + if err != nil { + return "", err + } + remote := strings.TrimSpace(string(out)) + if remote == "" { + return "", fmt.Errorf("no remote configured") + } + + remoteURL, err := gitRemoteURL(remote) + if err != nil { + return "", err + } + + _, owner, _, err := forges.ParseRepoURL(remoteURL) + if err != nil { + return "", err + } + return owner, nil +} + func newClient(domain string) *forges.Client { token := TokenForDomain(domain) var opts []forges.Option diff --git a/internal/resolve/resolve_test.go b/internal/resolve/resolve_test.go index b08c44f..86f8754 100644 --- a/internal/resolve/resolve_test.go +++ b/internal/resolve/resolve_test.go @@ -1,6 +1,7 @@ package resolve import ( + "context" "os" "os/exec" "path/filepath" @@ -367,6 +368,64 @@ func TestRemoteUnknownNameError(t *testing.T) { } } +func TestOwnerForBranch(t *testing.T) { + if _, err := exec.LookPath("git"); err != nil { + t.Skip("git not installed") + } + + dir := t.TempDir() + t.Chdir(dir) + + mustGit(t, "init", "-q") + mustGit(t, "config", "user.email", "test@test.com") + mustGit(t, "config", "user.name", "Test") + mustGit(t, "remote", "add", "origin", "https://github.com/mainowner/repo.git") + mustGit(t, "remote", "add", "fork", "https://github.com/forkowner/repo.git") + + // Create initial commit so we can create branches + if err := os.WriteFile(filepath.Join(dir, "README"), []byte("test"), 0644); err != nil { + t.Fatal(err) + } + mustGit(t, "add", "README") + mustGit(t, "commit", "-m", "init") + + // Create a branch tracking origin + mustGit(t, "checkout", "-b", "origin-branch") + mustGit(t, "config", "branch.origin-branch.remote", "origin") + + // Create a branch tracking fork + mustGit(t, "checkout", "-b", "fork-branch") + mustGit(t, "config", "branch.fork-branch.remote", "fork") + + tests := []struct { + branch string + wantOwner string + wantErr bool + }{ + {"origin-branch", "mainowner", false}, + {"fork-branch", "forkowner", false}, + {"nonexistent", "", true}, + } + + for _, tt := range tests { + t.Run(tt.branch, func(t *testing.T) { + owner, err := OwnerForBranch(context.Background(), tt.branch) + if tt.wantErr { + if err == nil { + t.Fatal("expected error") + } + return + } + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if owner != tt.wantOwner { + t.Errorf("owner = %q, want %q", owner, tt.wantOwner) + } + }) + } +} + func mustGit(t *testing.T, args ...string) { t.Helper() cmd := exec.Command("git", args...)