44 "errors"
55 "fmt"
66 "net/url"
7+ "strings"
78 "time"
89
910 "github.com/cli/cli/api"
@@ -75,7 +76,27 @@ func prCreate(cmd *cobra.Command, _ []string) error {
7576 if err != nil {
7677 return fmt .Errorf ("could not determine the current branch: %w" , err )
7778 }
78- headRepo , headRepoErr := repoContext .HeadRepo ()
79+
80+ var headRepo ghrepo.Interface
81+ var headRemote * context.Remote
82+
83+ // determine whether the head branch is already pushed to a remote
84+ headBranchPushedTo := determineTrackingBranch (remotes , headBranch )
85+ if headBranchPushedTo != nil {
86+ for _ , r := range remotes {
87+ if r .Name != headBranchPushedTo .RemoteName {
88+ continue
89+ }
90+ headRepo = r
91+ headRemote = r
92+ break
93+ }
94+ }
95+
96+ // otherwise, determine the head repository with info obtained from the API
97+ if headRepo == nil {
98+ headRepo , _ = repoContext .HeadRepo ()
99+ }
79100
80101 baseBranch , err := cmd .Flags ().GetString ("base" )
81102 if err != nil {
@@ -193,8 +214,9 @@ func prCreate(cmd *cobra.Command, _ []string) error {
193214 }
194215
195216 didForkRepo := false
196- var headRemote * context.Remote
197- if headRepoErr != nil {
217+ // if a head repository could not be determined so far, automatically create
218+ // one by forking the base repository
219+ if headRepo == nil {
198220 if baseRepo .IsPrivate {
199221 return fmt .Errorf ("cannot fork private repository '%s'" , ghrepo .FullName (baseRepo ))
200222 }
@@ -203,11 +225,25 @@ func prCreate(cmd *cobra.Command, _ []string) error {
203225 return fmt .Errorf ("error forking repo: %w" , err )
204226 }
205227 didForkRepo = true
228+ }
229+
230+ headBranchLabel := headBranch
231+ if ! ghrepo .IsSame (baseRepo , headRepo ) {
232+ headBranchLabel = fmt .Sprintf ("%s:%s" , headRepo .RepoOwner (), headBranch )
233+ }
234+
235+ // There are two cases when an existing remote for the head repo will be
236+ // missing:
237+ // 1. the head repo was just created by auto-forking;
238+ // 2. an existing fork was discovered by quering the API.
239+ //
240+ // In either case, we want to add the head repo as a new git remote so we
241+ // can push to it.
242+ if err != nil {
206243 // TODO: support non-HTTPS git remote URLs
207- baseRepoURL := fmt .Sprintf ("https://github.com/%s.git" , ghrepo .FullName (baseRepo ))
208244 headRepoURL := fmt .Sprintf ("https://github.com/%s.git" , ghrepo .FullName (headRepo ))
209- // TODO: figure out what to name the new git remote
210- gitRemote , err := git .AddRemote ("fork" , baseRepoURL , headRepoURL )
245+ // TODO: prevent clashes with another remote of a same name
246+ gitRemote , err := git .AddRemote ("fork" , headRepoURL )
211247 if err != nil {
212248 return fmt .Errorf ("error adding remote: %w" , err )
213249 }
@@ -218,34 +254,31 @@ func prCreate(cmd *cobra.Command, _ []string) error {
218254 }
219255 }
220256
221- headBranchLabel := headBranch
222- if ! ghrepo .IsSame (baseRepo , headRepo ) {
223- headBranchLabel = fmt .Sprintf ("%s:%s" , headRepo .RepoOwner (), headBranch )
224- }
225-
226- if headRemote == nil {
227- headRemote , err = repoContext .RemoteForRepo (headRepo )
228- if err != nil {
229- return fmt .Errorf ("git remote not found for head repository: %w" , err )
257+ // automatically push the branch if it hasn't been pushed anywhere yet
258+ if headBranchPushedTo == nil {
259+ if headRemote == nil {
260+ headRemote , err = repoContext .RemoteForRepo (headRepo )
261+ if err != nil {
262+ return fmt .Errorf ("git remote not found for head repository: %w" , err )
263+ }
230264 }
231- }
232265
233- pushTries := 0
234- maxPushTries := 3
235- for {
236- // TODO: respect existing upstream configuration of the current branch
237- if err := git .Push (headRemote .Name , fmt .Sprintf ("HEAD:%s" , headBranch )); err != nil {
238- if didForkRepo && pushTries < maxPushTries {
239- pushTries ++
240- // first wait 2 seconds after forking, then 4s, then 6s
241- waitSeconds := 2 * pushTries
242- fmt .Fprintf (cmd .ErrOrStderr (), "waiting %s before retrying...\n " , utils .Pluralize (waitSeconds , "second" ))
243- time .Sleep (time .Duration (waitSeconds ) * time .Second )
244- continue
266+ pushTries := 0
267+ maxPushTries := 3
268+ for {
269+ if err := git .Push (headRemote .Name , fmt .Sprintf ("HEAD:%s" , headBranch )); err != nil {
270+ if didForkRepo && pushTries < maxPushTries {
271+ pushTries ++
272+ // first wait 2 seconds after forking, then 4s, then 6s
273+ waitSeconds := 2 * pushTries
274+ fmt .Fprintf (cmd .ErrOrStderr (), "waiting %s before retrying...\n " , utils .Pluralize (waitSeconds , "second" ))
275+ time .Sleep (time .Duration (waitSeconds ) * time .Second )
276+ continue
277+ }
278+ return err
245279 }
246- return err
280+ break
247281 }
248- break
249282 }
250283
251284 if action == SubmitAction {
@@ -275,6 +308,47 @@ func prCreate(cmd *cobra.Command, _ []string) error {
275308 return nil
276309}
277310
311+ func determineTrackingBranch (remotes context.Remotes , headBranch string ) * git.TrackingRef {
312+ refsForLookup := []string {"HEAD" }
313+ var trackingRefs []git.TrackingRef
314+
315+ headBranchConfig := git .ReadBranchConfig (headBranch )
316+ if headBranchConfig .RemoteName != "" {
317+ tr := git.TrackingRef {
318+ RemoteName : headBranchConfig .RemoteName ,
319+ BranchName : strings .TrimPrefix (headBranchConfig .MergeRef , "refs/heads/" ),
320+ }
321+ trackingRefs = append (trackingRefs , tr )
322+ refsForLookup = append (refsForLookup , tr .String ())
323+ }
324+
325+ for _ , remote := range remotes {
326+ tr := git.TrackingRef {
327+ RemoteName : remote .Name ,
328+ BranchName : headBranch ,
329+ }
330+ trackingRefs = append (trackingRefs , tr )
331+ refsForLookup = append (refsForLookup , tr .String ())
332+ }
333+
334+ resolvedRefs , _ := git .ShowRefs (refsForLookup ... )
335+ if len (resolvedRefs ) > 1 {
336+ for _ , r := range resolvedRefs [1 :] {
337+ if r .Hash != resolvedRefs [0 ].Hash {
338+ continue
339+ }
340+ for _ , tr := range trackingRefs {
341+ if tr .String () != r .Name {
342+ continue
343+ }
344+ return & tr
345+ }
346+ }
347+ }
348+
349+ return nil
350+ }
351+
278352func generateCompareURL (r ghrepo.Interface , base , head , title , body string ) string {
279353 u := fmt .Sprintf (
280354 "https://github.com/%s/compare/%s...%s?expand=1" ,
0 commit comments