X Tutup
Skip to content

Commit dc0de14

Browse files
committed
Add pr checkout command
1 parent 1479ff4 commit dc0de14

File tree

4 files changed

+325
-1
lines changed

4 files changed

+325
-1
lines changed

api/queries.go

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,20 @@ type PullRequest struct {
1717
State string
1818
URL string
1919
HeadRefName string
20-
Reviews struct {
20+
21+
HeadRepositoryOwner struct {
22+
Login string
23+
}
24+
HeadRepository struct {
25+
Name string
26+
DefaultBranchRef struct {
27+
Name string
28+
}
29+
}
30+
IsCrossRepository bool
31+
MaintainerCanModify bool
32+
33+
Reviews struct {
2134
Nodes []struct {
2235
State string
2336
Author struct {
@@ -355,6 +368,48 @@ func PullRequests(client *Client, ghRepo Repo, currentBranch, currentUsername st
355368
return &payload, nil
356369
}
357370

371+
func PullRequestByNumber(client *Client, ghRepo Repo, number int) (*PullRequest, error) {
372+
type response struct {
373+
Repository struct {
374+
PullRequest PullRequest
375+
}
376+
}
377+
378+
query := `
379+
query($owner: String!, $repo: String!, $pr_number: Int!) {
380+
repository(owner: $owner, name: $repo) {
381+
pullRequest(number: $pr_number) {
382+
headRefName
383+
headRepositoryOwner {
384+
login
385+
}
386+
headRepository {
387+
name
388+
defaultBranchRef {
389+
name
390+
}
391+
}
392+
isCrossRepository
393+
maintainerCanModify
394+
}
395+
}
396+
}`
397+
398+
variables := map[string]interface{}{
399+
"owner": ghRepo.RepoOwner(),
400+
"repo": ghRepo.RepoName(),
401+
"pr_number": number,
402+
}
403+
404+
var resp response
405+
err := client.GraphQL(query, variables, &resp)
406+
if err != nil {
407+
return nil, err
408+
}
409+
410+
return &resp.Repository.PullRequest, nil
411+
}
412+
358413
func PullRequestsForBranch(client *Client, ghRepo Repo, branch string) ([]PullRequest, error) {
359414
type response struct {
360415
Repository struct {

command/pr.go

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,19 @@ package command
33
import (
44
"fmt"
55
"os"
6+
"os/exec"
67
"strconv"
78

89
"github.com/github/gh-cli/api"
10+
"github.com/github/gh-cli/git"
911
"github.com/github/gh-cli/utils"
1012
"github.com/spf13/cobra"
1113
"golang.org/x/crypto/ssh/terminal"
1214
)
1315

1416
func init() {
1517
RootCmd.AddCommand(prCmd)
18+
prCmd.AddCommand(prCheckoutCmd)
1619
prCmd.AddCommand(prCreateCmd)
1720
prCmd.AddCommand(prListCmd)
1821
prCmd.AddCommand(prStatusCmd)
@@ -29,6 +32,12 @@ var prCmd = &cobra.Command{
2932
Short: "Work with pull requests",
3033
Long: `Helps you work with pull requests.`,
3134
}
35+
var prCheckoutCmd = &cobra.Command{
36+
Use: "checkout <pr-number>",
37+
Short: "check out a pull request in git",
38+
Args: cobra.MinimumNArgs(1),
39+
RunE: prCheckout,
40+
}
3241
var prListCmd = &cobra.Command{
3342
Use: "list",
3443
Short: "List pull requests",
@@ -247,6 +256,103 @@ func prView(cmd *cobra.Command, args []string) error {
247256
return utils.OpenInBrowser(openURL)
248257
}
249258

259+
func prCheckout(cmd *cobra.Command, args []string) error {
260+
prNumber, err := strconv.Atoi(args[0])
261+
if err != nil {
262+
return err
263+
}
264+
265+
ctx := contextForCommand(cmd)
266+
currentBranch, err := ctx.Branch()
267+
if err != nil {
268+
return err
269+
}
270+
remotes, err := ctx.Remotes()
271+
if err != nil {
272+
return err
273+
}
274+
// FIXME: duplicates logic from fsContext.BaseRepo
275+
baseRemote, err := remotes.FindByName("upstream", "github", "origin", "*")
276+
if err != nil {
277+
return err
278+
}
279+
apiClient, err := apiClientForContext(ctx)
280+
if err != nil {
281+
return err
282+
}
283+
284+
pr, err := api.PullRequestByNumber(apiClient, baseRemote, prNumber)
285+
if err != nil {
286+
return err
287+
}
288+
289+
headRemote := baseRemote
290+
if pr.IsCrossRepository {
291+
headRemote, _ = remotes.FindByRepo(pr.HeadRepositoryOwner.Login, pr.HeadRepository.Name)
292+
}
293+
294+
cmdQueue := [][]string{}
295+
296+
newBranchName := pr.HeadRefName
297+
if headRemote != nil {
298+
// there is an existing git remote for PR head
299+
remoteBranch := fmt.Sprintf("%s/%s", headRemote.Name, pr.HeadRefName)
300+
refSpec := fmt.Sprintf("+refs/heads/%s:refs/remotes/%s", pr.HeadRefName, remoteBranch)
301+
302+
cmdQueue = append(cmdQueue, []string{"git", "fetch", headRemote.Name, refSpec})
303+
304+
// local branch already exists
305+
if git.HasFile("refs", "heads", newBranchName) {
306+
cmdQueue = append(cmdQueue, []string{"git", "checkout", newBranchName})
307+
cmdQueue = append(cmdQueue, []string{"git", "merge", "--ff-only", fmt.Sprintf("refs/remotes/%s", remoteBranch)})
308+
} else {
309+
cmdQueue = append(cmdQueue, []string{"git", "checkout", "-b", newBranchName, "--no-track", remoteBranch})
310+
cmdQueue = append(cmdQueue, []string{"git", "config", fmt.Sprintf("branch.%s.remote", newBranchName), headRemote.Name})
311+
cmdQueue = append(cmdQueue, []string{"git", "config", fmt.Sprintf("branch.%s.merge", newBranchName), "refs/heads/" + pr.HeadRefName})
312+
}
313+
} else {
314+
// no git remote for PR head
315+
316+
// avoid naming the new branch the same as the default branch
317+
if newBranchName == pr.HeadRepository.DefaultBranchRef.Name {
318+
newBranchName = fmt.Sprintf("%s/%s", pr.HeadRepositoryOwner.Login, newBranchName)
319+
}
320+
321+
ref := fmt.Sprintf("refs/pull/%d/head", prNumber)
322+
if newBranchName == currentBranch {
323+
// PR head matches currently checked out branch
324+
cmdQueue = append(cmdQueue, []string{"git", "fetch", baseRemote.Name, ref})
325+
cmdQueue = append(cmdQueue, []string{"git", "merge", "--ff-only", "FETCH_HEAD"})
326+
} else {
327+
// create a new branch
328+
cmdQueue = append(cmdQueue, []string{"git", "fetch", baseRemote.Name, fmt.Sprintf("%s:%s", ref, newBranchName)})
329+
cmdQueue = append(cmdQueue, []string{"git", "checkout", newBranchName})
330+
}
331+
332+
remote := baseRemote.Name
333+
mergeRef := ref
334+
if pr.MaintainerCanModify {
335+
remote = fmt.Sprintf("https://github.com/%s/%s.git", pr.HeadRepositoryOwner.Login, pr.HeadRepository.Name)
336+
mergeRef = fmt.Sprintf("refs/heads/%s", pr.HeadRefName)
337+
}
338+
if mc, err := git.Config(fmt.Sprintf("branch.%s.merge", newBranchName)); err != nil || mc == "" {
339+
cmdQueue = append(cmdQueue, []string{"git", "config", fmt.Sprintf("branch.%s.remote", newBranchName), remote})
340+
cmdQueue = append(cmdQueue, []string{"git", "config", fmt.Sprintf("branch.%s.merge", newBranchName), mergeRef})
341+
}
342+
}
343+
344+
for _, args := range cmdQueue {
345+
cmd := exec.Command(args[0], args[1:]...)
346+
cmd.Stdout = os.Stdout
347+
cmd.Stderr = os.Stderr
348+
if err := utils.PrepareCmd(cmd).Run(); err != nil {
349+
return err
350+
}
351+
}
352+
353+
return nil
354+
}
355+
250356
func printPrs(prs ...api.PullRequest) {
251357
for _, pr := range prs {
252358
prNumber := fmt.Sprintf("#%d", pr.Number)

command/pr_checkout_test.go

Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
package command
2+
3+
import (
4+
"bytes"
5+
"os/exec"
6+
"strings"
7+
"testing"
8+
9+
"github.com/github/gh-cli/context"
10+
"github.com/github/gh-cli/utils"
11+
)
12+
13+
func TestPRCheckout_sameRepo(t *testing.T) {
14+
ctx := context.NewBlank()
15+
ctx.SetBranch("master")
16+
ctx.SetRemotes(map[string]string{
17+
"origin": "OWNER/REPO",
18+
})
19+
initContext = func() context.Context {
20+
return ctx
21+
}
22+
http := initFakeHTTP()
23+
24+
http.StubResponse(200, bytes.NewBufferString(`
25+
{ "data": { "repository": { "pullRequest": {
26+
"headRefName": "feature",
27+
"headRepositoryOwner": {
28+
"login": "hubot"
29+
},
30+
"headRepository": {
31+
"name": "REPO",
32+
"defaultBranchRef": {
33+
"name": "master"
34+
}
35+
},
36+
"isCrossRepository": false,
37+
"maintainerCanModify": false
38+
} } } }
39+
`))
40+
41+
ranCommands := [][]string{}
42+
restoreCmd := utils.SetPrepareCmd(func(cmd *exec.Cmd) utils.Runnable {
43+
ranCommands = append(ranCommands, cmd.Args)
44+
return &outputStub{}
45+
})
46+
defer restoreCmd()
47+
48+
RootCmd.SetArgs([]string{"pr", "checkout", "123"})
49+
_, err := prCheckoutCmd.ExecuteC()
50+
eq(t, err, nil)
51+
52+
eq(t, len(ranCommands), 6)
53+
eq(t, strings.Join(ranCommands[0], " "), "git rev-parse -q --git-path refs/heads/feature")
54+
eq(t, strings.Join(ranCommands[1], " "), "git rev-parse -q --git-dir")
55+
eq(t, strings.Join(ranCommands[2], " "), "git fetch origin +refs/heads/feature:refs/remotes/origin/feature")
56+
eq(t, strings.Join(ranCommands[3], " "), "git checkout -b feature --no-track origin/feature")
57+
eq(t, strings.Join(ranCommands[4], " "), "git config branch.feature.remote origin")
58+
eq(t, strings.Join(ranCommands[5], " "), "git config branch.feature.merge refs/heads/feature")
59+
}
60+
61+
func TestPRCheckout_differentRepo(t *testing.T) {
62+
ctx := context.NewBlank()
63+
ctx.SetBranch("master")
64+
ctx.SetRemotes(map[string]string{
65+
"origin": "OWNER/REPO",
66+
})
67+
initContext = func() context.Context {
68+
return ctx
69+
}
70+
http := initFakeHTTP()
71+
72+
http.StubResponse(200, bytes.NewBufferString(`
73+
{ "data": { "repository": { "pullRequest": {
74+
"headRefName": "feature",
75+
"headRepositoryOwner": {
76+
"login": "hubot"
77+
},
78+
"headRepository": {
79+
"name": "REPO",
80+
"defaultBranchRef": {
81+
"name": "master"
82+
}
83+
},
84+
"isCrossRepository": true,
85+
"maintainerCanModify": false
86+
} } } }
87+
`))
88+
89+
ranCommands := [][]string{}
90+
restoreCmd := utils.SetPrepareCmd(func(cmd *exec.Cmd) utils.Runnable {
91+
ranCommands = append(ranCommands, cmd.Args)
92+
return &outputStub{}
93+
})
94+
defer restoreCmd()
95+
96+
RootCmd.SetArgs([]string{"pr", "checkout", "123"})
97+
_, err := prCheckoutCmd.ExecuteC()
98+
eq(t, err, nil)
99+
100+
eq(t, len(ranCommands), 5)
101+
eq(t, strings.Join(ranCommands[0], " "), "git config branch.feature.merge")
102+
eq(t, strings.Join(ranCommands[1], " "), "git fetch origin refs/pull/123/head:feature")
103+
eq(t, strings.Join(ranCommands[2], " "), "git checkout feature")
104+
eq(t, strings.Join(ranCommands[3], " "), "git config branch.feature.remote origin")
105+
eq(t, strings.Join(ranCommands[4], " "), "git config branch.feature.merge refs/pull/123/head")
106+
}
107+
108+
func TestPRCheckout_maintainerCanModify(t *testing.T) {
109+
ctx := context.NewBlank()
110+
ctx.SetBranch("master")
111+
ctx.SetRemotes(map[string]string{
112+
"origin": "OWNER/REPO",
113+
})
114+
initContext = func() context.Context {
115+
return ctx
116+
}
117+
http := initFakeHTTP()
118+
119+
http.StubResponse(200, bytes.NewBufferString(`
120+
{ "data": { "repository": { "pullRequest": {
121+
"headRefName": "feature",
122+
"headRepositoryOwner": {
123+
"login": "hubot"
124+
},
125+
"headRepository": {
126+
"name": "REPO",
127+
"defaultBranchRef": {
128+
"name": "master"
129+
}
130+
},
131+
"isCrossRepository": true,
132+
"maintainerCanModify": true
133+
} } } }
134+
`))
135+
136+
ranCommands := [][]string{}
137+
restoreCmd := utils.SetPrepareCmd(func(cmd *exec.Cmd) utils.Runnable {
138+
ranCommands = append(ranCommands, cmd.Args)
139+
return &outputStub{}
140+
})
141+
defer restoreCmd()
142+
143+
RootCmd.SetArgs([]string{"pr", "checkout", "123"})
144+
_, err := prCheckoutCmd.ExecuteC()
145+
eq(t, err, nil)
146+
147+
eq(t, len(ranCommands), 5)
148+
eq(t, strings.Join(ranCommands[0], " "), "git config branch.feature.merge")
149+
eq(t, strings.Join(ranCommands[1], " "), "git fetch origin refs/pull/123/head:feature")
150+
eq(t, strings.Join(ranCommands[2], " "), "git checkout feature")
151+
eq(t, strings.Join(ranCommands[3], " "), "git config branch.feature.remote https://github.com/hubot/REPO.git")
152+
eq(t, strings.Join(ranCommands[4], " "), "git config branch.feature.merge refs/heads/feature")
153+
}

context/remote.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,16 @@ func (r Remotes) FindByName(names ...string) (*Remote, error) {
2525
return nil, fmt.Errorf("no GitHub remotes found")
2626
}
2727

28+
// FindByRepo returns the first Remote that points to a specific GitHub repository
29+
func (r Remotes) FindByRepo(owner, name string) (*Remote, error) {
30+
for _, rem := range r {
31+
if strings.EqualFold(rem.Owner, owner) && strings.EqualFold(rem.Name, name) {
32+
return rem, nil
33+
}
34+
}
35+
return nil, fmt.Errorf("no matching remote found")
36+
}
37+
2838
// Remote represents a git remote mapped to a GitHub repository
2939
type Remote struct {
3040
*git.Remote

0 commit comments

Comments
 (0)
X Tutup