@@ -23,6 +23,12 @@ type issueTemplate struct {
2323 Gbody string `graphql:"body"`
2424}
2525
26+ type pullRequestTemplate struct {
27+ // I would have un-exported these fields, except `cli/shurcool-graphql` then cannot unmarshal them :/
28+ Gname string `graphql:"filename"`
29+ Gbody string `graphql:"body"`
30+ }
31+
2632func (t * issueTemplate ) Name () string {
2733 return t .Gname
2834}
@@ -35,7 +41,19 @@ func (t *issueTemplate) Body() []byte {
3541 return []byte (t .Gbody )
3642}
3743
38- func listIssueTemplates (httpClient * http.Client , repo ghrepo.Interface ) ([]issueTemplate , error ) {
44+ func (t * pullRequestTemplate ) Name () string {
45+ return t .Gname
46+ }
47+
48+ func (t * pullRequestTemplate ) NameForSubmit () string {
49+ return ""
50+ }
51+
52+ func (t * pullRequestTemplate ) Body () []byte {
53+ return []byte (t .Gbody )
54+ }
55+
56+ func listIssueTemplates (httpClient * http.Client , repo ghrepo.Interface ) ([]Template , error ) {
3957 var query struct {
4058 Repository struct {
4159 IssueTemplates []issueTemplate
@@ -54,10 +72,44 @@ func listIssueTemplates(httpClient *http.Client, repo ghrepo.Interface) ([]issue
5472 return nil , err
5573 }
5674
57- return query .Repository .IssueTemplates , nil
75+ ts := query .Repository .IssueTemplates
76+ templates := make ([]Template , len (ts ))
77+ for i := range templates {
78+ templates [i ] = & ts [i ]
79+ }
80+
81+ return templates , nil
5882}
5983
60- func hasIssueTemplateSupport (httpClient * http.Client , hostname string ) (bool , error ) {
84+ func listPullRequestTemplates (httpClient * http.Client , repo ghrepo.Interface ) ([]Template , error ) {
85+ var query struct {
86+ Repository struct {
87+ PullRequestTemplates []pullRequestTemplate
88+ } `graphql:"repository(owner: $owner, name: $name)"`
89+ }
90+
91+ variables := map [string ]interface {}{
92+ "owner" : githubv4 .String (repo .RepoOwner ()),
93+ "name" : githubv4 .String (repo .RepoName ()),
94+ }
95+
96+ gql := graphql .NewClient (ghinstance .GraphQLEndpoint (repo .RepoHost ()), httpClient )
97+
98+ err := gql .QueryNamed (context .Background (), "PullRequestTemplates" , & query , variables )
99+ if err != nil {
100+ return nil , err
101+ }
102+
103+ ts := query .Repository .PullRequestTemplates
104+ templates := make ([]Template , len (ts ))
105+ for i := range templates {
106+ templates [i ] = & ts [i ]
107+ }
108+
109+ return templates , nil
110+ }
111+
112+ func hasTemplateSupport (httpClient * http.Client , hostname string , isPR bool ) (bool , error ) {
61113 if ! ghinstance .IsEnterprise (hostname ) {
62114 return true , nil
63115 }
@@ -81,20 +133,29 @@ func hasIssueTemplateSupport(httpClient *http.Client, hostname string) (bool, er
81133 return false , err
82134 }
83135
84- var hasQuerySupport bool
85- var hasMutationSupport bool
136+ var hasIssueQuerySupport bool
137+ var hasIssueMutationSupport bool
138+ var hasPullRequestQuerySupport bool
139+
86140 for _ , field := range featureDetection .Repository .Fields {
87141 if field .Name == "issueTemplates" {
88- hasQuerySupport = true
142+ hasIssueQuerySupport = true
143+ }
144+ if field .Name == "pullRequestTemplates" {
145+ hasPullRequestQuerySupport = true
89146 }
90147 }
91148 for _ , field := range featureDetection .CreateIssueInput .InputFields {
92149 if field .Name == "issueTemplate" {
93- hasMutationSupport = true
150+ hasIssueMutationSupport = true
94151 }
95152 }
96153
97- return hasQuerySupport && hasMutationSupport , nil
154+ if isPR {
155+ return hasPullRequestQuerySupport , nil
156+ } else {
157+ return hasIssueQuerySupport && hasIssueMutationSupport , nil
158+ }
98159}
99160
100161type Template interface {
@@ -129,13 +190,10 @@ func NewTemplateManager(httpClient *http.Client, repo ghrepo.Interface, dir stri
129190}
130191
131192func (m * templateManager ) hasAPI () (bool , error ) {
132- if m .isPR {
133- return false , nil
134- }
135193 if m .cachedClient == nil {
136194 m .cachedClient = api .NewCachedClient (m .httpClient , time .Hour * 24 )
137195 }
138- return hasIssueTemplateSupport (m .cachedClient , m .repo .RepoHost ())
196+ return hasTemplateSupport (m .cachedClient , m .repo .RepoHost (), m . isPR )
139197}
140198
141199func (m * templateManager ) HasTemplates () (bool , error ) {
@@ -201,14 +259,15 @@ func (m *templateManager) fetch() error {
201259 }
202260
203261 if hasAPI {
204- issueTemplates , err := listIssueTemplates (m .httpClient , m .repo )
262+ lister := listIssueTemplates
263+ if m .isPR {
264+ lister = listPullRequestTemplates
265+ }
266+ templates , err := lister (m .httpClient , m .repo )
205267 if err != nil {
206268 return err
207269 }
208- m .templates = make ([]Template , len (issueTemplates ))
209- for i := range issueTemplates {
210- m .templates [i ] = & issueTemplates [i ]
211- }
270+ m .templates = templates
212271 }
213272
214273 if ! m .allowFS {
0 commit comments