package api
import (
"bytes"
"encoding/json"
"fmt"
"io"
"io/ioutil"
"net/http"
"net/url"
"regexp"
"strings"
"github.com/cli/cli/v2/internal/ghinstance"
graphql "github.com/cli/shurcooL-graphql"
"github.com/henvic/httpretty"
)
// ClientOption is a transport decorator passed to NewClient and NewHTTPClient:
// it receives the current RoundTripper and returns the one to use in its place.
type ClientOption = func(http.RoundTripper) http.RoundTripper

// NewHTTPClient builds an http.Client whose transport starts out as
// http.DefaultTransport and is then passed through every option in the order
// given, so the last option produces the outermost RoundTripper.
func NewHTTPClient(opts ...ClientOption) *http.Client {
	var transport http.RoundTripper = http.DefaultTransport
	for i := range opts {
		transport = opts[i](transport)
	}
	return &http.Client{Transport: transport}
}
// NewClient creates a Client backed by a fresh http.Client that NewHTTPClient
// builds from the supplied options.
func NewClient(opts ...ClientOption) *Client {
	return &Client{http: NewHTTPClient(opts...)}
}
// NewClientFromHTTP wraps an existing http.Client in a Client. The given
// instance is stored as-is; it is not copied.
func NewClientFromHTTP(httpClient *http.Client) *Client {
	return &Client{http: httpClient}
}
// AddHeader returns a ClientOption whose RoundTripper adds the header
// name: value to each outgoing request, unless the request already carries a
// non-empty value for that header, and then delegates to the wrapped transport.
func AddHeader(name, value string) ClientOption {
	return func(next http.RoundTripper) http.RoundTripper {
		addIfAbsent := func(req *http.Request) (*http.Response, error) {
			if existing := req.Header.Get(name); existing == "" {
				req.Header.Add(name, value)
			}
			return next.RoundTrip(req)
		}
		return &funcTripper{roundTrip: addIfAbsent}
	}
}
// AddHeaderFunc is like AddHeader, but computes the header value per request
// by calling getValue.
//
// When the request already has a non-empty value for name, getValue is not
// called. An error from getValue aborts the round trip without contacting the
// wrapped transport. An empty computed value leaves the header unset.
func AddHeaderFunc(name string, getValue func(*http.Request) (string, error)) ClientOption {
	return func(next http.RoundTripper) http.RoundTripper {
		addComputed := func(req *http.Request) (*http.Response, error) {
			if req.Header.Get(name) == "" {
				computed, err := getValue(req)
				if err != nil {
					return nil, err
				}
				if computed != "" {
					req.Header.Add(name, computed)
				}
			}
			return next.RoundTrip(req)
		}
		return &funcTripper{roundTrip: addComputed}
	}
}
// VerboseLog returns a ClientOption that wraps the transport in an httpretty
// logger writing to out. logTraffic toggles logging of request and response
// headers and bodies; colorize toggles colored output.
func VerboseLog(out io.Writer, logTraffic bool, colorize bool) ClientOption {
	// Skip bodies whose Content-Type inspectableMIMEType does not accept.
	skipOpaqueBodies := func(h http.Header) (bool, error) {
		return !inspectableMIMEType(h.Get("Content-Type")), nil
	}

	verbose := &httpretty.Logger{
		Time:            true,
		TLS:             false,
		Colors:          colorize,
		RequestHeader:   logTraffic,
		RequestBody:     logTraffic,
		ResponseHeader:  logTraffic,
		ResponseBody:    logTraffic,
		Formatters:      []httpretty.Formatter{&httpretty.JSONFormatter{}},
		MaxResponseBody: 10000,
	}
	verbose.SetOutput(out)
	verbose.SetBodyFilter(skipOpaqueBodies)

	return verbose.RoundTripper
}
// ReplaceTripper returns a ClientOption that ignores the RoundTripper it is
// given and always yields tr instead.
func ReplaceTripper(tr http.RoundTripper) ClientOption {
	discardPrevious := func(_ http.RoundTripper) http.RoundTripper {
		return tr
	}
	return discardPrevious
}
// ExtractHeader returns a ClientOption that inspects every response obtained
// through the wrapped transport. When the round trip succeeds and the response
// has a non-blank value for the header name, that value is written to *dest.
// Failed round trips and blank values leave *dest untouched.
func ExtractHeader(name string, dest *string) ClientOption {
	return func(next http.RoundTripper) http.RoundTripper {
		capture := func(req *http.Request) (*http.Response, error) {
			res, err := next.RoundTrip(req)
			if err != nil {
				return res, err
			}
			if found := res.Header.Get(name); found != "" {
				*dest = found
			}
			return res, nil
		}
		return &funcTripper{roundTrip: capture}
	}
}
// funcTripper adapts a plain function to the http.RoundTripper interface.
type funcTripper struct {
	// roundTrip is invoked for every request passed to RoundTrip.
	roundTrip func(*http.Request) (*http.Response, error)
}

// RoundTrip implements http.RoundTripper by delegating to the stored function.
func (ft funcTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	return ft.roundTrip(req)
}
// Client facilitates making HTTP requests to the GitHub API
type Client struct {
// http is the HTTP client that HTTP returns and that requests are sent through.
http *http.Client
}
// HTTP returns the *http.Client stored in this Client.
func (c *Client) HTTP() *http.Client {
return c.http
}
// graphQLResponse is the envelope a GraphQL response body is decoded into.
type graphQLResponse struct {
// Data is the decode destination for the response payload; handleResponse sets it to the
// caller-supplied value before decoding.
Data interface{}
// Errors holds the error entries decoded from the response, if any.
Errors []GraphQLError
}
// GraphQLError is a single error returned in a GraphQL response
type GraphQLError struct {
	Type    string
	Message string
	Path    []interface{} // mixed strings and numbers
}

// PathString renders Path as its elements joined by ".", each formatted with
// the default fmt verb. An empty Path yields an empty string.
func (ge GraphQLError) PathString() string {
	segments := make([]string, 0, len(ge.Path))
	for _, elem := range ge.Path {
		segments = append(segments, fmt.Sprint(elem))
	}
	return strings.Join(segments, ".")
}
// GraphQLErrorResponse contains errors returned in a GraphQL response
type GraphQLErrorResponse struct {
	Errors []GraphQLError
}

// Error formats all contained errors as "GraphQL: msg1, msg2, ...". An error
// that has a non-empty path is rendered as "message (path)".
func (gr GraphQLErrorResponse) Error() string {
	parts := make([]string, len(gr.Errors))
	for i := range gr.Errors {
		e := gr.Errors[i]
		parts[i] = e.Message
		if path := e.PathString(); path != "" {
			parts[i] = fmt.Sprintf("%s (%s)", e.Message, path)
		}
	}
	return fmt.Sprintf("GraphQL: %s", strings.Join(parts, ", "))
}

// Match reports whether every contained error has type expectType and a path
// matching expectPath. If expectPath ends with a ".", subpaths match as well.
// A response with no errors matches trivially.
func (gr GraphQLErrorResponse) Match(expectType, expectPath string) bool {
	for i := range gr.Errors {
		e := gr.Errors[i]
		if e.Type == expectType && matchPath(e.PathString(), expectPath) {
			continue
		}
		return false
	}
	return true
}
// matchPath reports whether path p satisfies expect. Without a trailing "."
// the two must be equal. With a trailing ".", p may be either the path named
// before the dot or any path that starts with expect (a subpath).
func matchPath(p, expect string) bool {
	if !strings.HasSuffix(expect, ".") {
		return p == expect
	}
	parent := strings.TrimSuffix(expect, ".")
	return p == parent || strings.HasPrefix(p, expect)
}
// HTTPError is an error returned by a failed API call
type HTTPError struct {
	StatusCode       int
	RequestURL       *url.URL
	Message          string
	Errors           []HTTPErrorItem
	scopesSuggestion string
}

// HTTPErrorItem is one entry from the error list of a failed API response.
type HTTPErrorItem struct {
	Message  string
	Resource string
	Field    string
	Code     string
}

// Error renders the status code, message and request URL on one line. If the
// message spans several lines, only its first line goes on that line and the
// remainder follows after a line break. With no message, only the status code
// and URL are shown.
func (err HTTPError) Error() string {
	switch nl := strings.Index(err.Message, "\n"); {
	case nl >= 0:
		headline, details := err.Message[:nl], err.Message[nl+1:]
		return fmt.Sprintf("HTTP %d: %s (%s)\n%s", err.StatusCode, headline, err.RequestURL, details)
	case err.Message != "":
		return fmt.Sprintf("HTTP %d: %s (%s)", err.StatusCode, err.Message, err.RequestURL)
	default:
		return fmt.Sprintf("HTTP %d (%s)", err.StatusCode, err.RequestURL)
	}
}

// ScopesSuggestion returns the OAuth scopes hint stored on this error, or an
// empty string if there is none.
func (err HTTPError) ScopesSuggestion() string {
	return err.scopesSuggestion
}
// ScopesSuggestion is an error messaging utility that builds the suggestion to
// request additional OAuth scopes when a server response indicates that scopes
// are missing.
//
// It returns "" unless the status is 4xx other than 422 and the response has an
// X-Oauth-Scopes header. Otherwise it returns a hint naming the first scope in
// X-Accepted-Oauth-Scopes that the token neither has nor gets implied by a
// scope it has, or "" when there is no such scope.
func ScopesSuggestion(resp *http.Response) string {
	if code := resp.StatusCode; code < 400 || code > 499 || code == 422 {
		return ""
	}

	granted := resp.Header.Get("X-Oauth-Scopes")
	if granted == "" {
		return ""
	}
	required := resp.Header.Get("X-Accepted-Oauth-Scopes")

	have := map[string]struct{}{}
	grant := func(scopes ...string) {
		for _, sc := range scopes {
			have[sc] = struct{}{}
		}
	}

	for _, scope := range strings.Split(granted, ",") {
		scope = strings.TrimSpace(scope)
		grant(scope)

		// Certain scopes may be grouped under a single "top-level" scope. The cases below
		// include these grouped/implied scopes when the top-level scope is encountered.
		// See https://docs.github.com/en/developers/apps/building-oauth-apps/scopes-for-oauth-apps.
		switch {
		case scope == "repo":
			grant("repo:status", "repo_deployment", "public_repo", "repo:invite", "security_events")
		case scope == "user":
			grant("read:user", "user:email", "user:follow")
		case scope == "codespace":
			grant("codespace:secrets")
		case strings.HasPrefix(scope, "admin:"):
			suffix := strings.TrimPrefix(scope, "admin:")
			grant("read:"+suffix, "write:"+suffix)
		case strings.HasPrefix(scope, "write:"):
			grant("read:" + strings.TrimPrefix(scope, "write:"))
		}
	}

	for _, scope := range strings.Split(required, ",") {
		scope = strings.TrimSpace(scope)
		if scope == "" {
			continue
		}
		if _, ok := have[scope]; ok {
			continue
		}
		return fmt.Sprintf(
			"This API operation needs the %[1]q scope. To request it, run: gh auth refresh -h %[2]s -s %[1]s",
			scope,
			ghinstance.NormalizeHostname(resp.Request.URL.Hostname()),
		)
	}

	return ""
}
// EndpointNeedsScopes adds additional OAuth scopes to an HTTP response as if they were returned from the
// server endpoint. This improves HTTP 4xx error messaging for endpoints that don't explicitly list the
// OAuth scopes they need.
//
// Only 4xx responses are modified. The scope s is appended to any existing X-Accepted-Oauth-Scopes
// value using ", " as separator; when the header is absent or empty, s becomes the whole value, so no
// dangling leading separator is produced. The same resp is returned to allow chaining.
func EndpointNeedsScopes(resp *http.Response, s string) *http.Response {
	if resp.StatusCode < 400 || resp.StatusCode >= 500 {
		return resp
	}

	// Hand-built responses may have no header map; writing to a nil map would panic.
	if resp.Header == nil {
		resp.Header = make(http.Header)
	}

	scopes := s
	if oldScopes := resp.Header.Get("X-Accepted-Oauth-Scopes"); oldScopes != "" {
		scopes = fmt.Sprintf("%s, %s", oldScopes, s)
	}
	resp.Header.Set("X-Accepted-Oauth-Scopes", scopes)
	return resp
}
// GraphQL sends query and variables as a JSON POST to the GraphQL endpoint of
// hostname and decodes the response into data. If the response lists errors,
// a *GraphQLErrorResponse is returned, but whatever data the response carried
// is still decoded into data.
func (c Client) GraphQL(hostname string, query string, variables map[string]interface{}, data interface{}) error {
	payload := map[string]interface{}{
		"query":     query,
		"variables": variables,
	}
	encoded, err := json.Marshal(payload)
	if err != nil {
		return err
	}

	endpoint := ghinstance.GraphQLEndpoint(hostname)
	req, err := http.NewRequest("POST", endpoint, bytes.NewBuffer(encoded))
	if err != nil {
		return err
	}
	req.Header.Set("Content-Type", "application/json; charset=utf-8")

	resp, err := c.http.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	return handleResponse(resp, data)
}
// graphQLClient creates a graphql.Client that talks to the GraphQL endpoint of
// hostname through the HTTP client h.
func graphQLClient(h *http.Client, hostname string) *graphql.Client {
	endpoint := ghinstance.GraphQLEndpoint(hostname)
	return graphql.NewClient(endpoint, h)
}
// REST performs a REST request and decodes the response into data. It behaves
// like RESTWithNext but drops the next-page URL.
func (c Client) REST(hostname string, method string, p string, body io.Reader, data interface{}) error {
	if _, err := c.RESTWithNext(hostname, method, p, body, data); err != nil {
		return err
	}
	return nil
}
// RESTWithNext performs a REST request against hostname and decodes the JSON
// response into data. p is either a path relative to the host's REST prefix or
// an absolute http(s) URL.
//
// A non-2xx status is turned into an error by HandleHTTPError. A 204 response
// returns without decoding anything. Otherwise the returned string is the
// target of the Link header entry whose rel is "next", or "" if there is none.
func (c Client) RESTWithNext(hostname string, method string, p string, body io.Reader, data interface{}) (string, error) {
	req, err := http.NewRequest(method, restURL(hostname, p), body)
	if err != nil {
		return "", err
	}
	req.Header.Set("Content-Type", "application/json; charset=utf-8")

	resp, err := c.http.Do(req)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()

	switch code := resp.StatusCode; {
	case code < 200 || code >= 300:
		return "", HandleHTTPError(resp)
	case code == http.StatusNoContent:
		return "", nil
	}

	payload, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		return "", err
	}
	if err := json.Unmarshal(payload, &data); err != nil {
		return "", err
	}

	// Scan every Link entry; if several carry rel="next", the last one wins.
	nextURL := ""
	for _, match := range linkRE.FindAllStringSubmatch(resp.Header.Get("Link"), -1) {
		if len(match) <= 2 || match[2] != "next" {
			continue
		}
		nextURL = match[1]
	}
	return nextURL, nil
}
// linkRE matches one `<target>; rel="name"` entry of a Link header, capturing the target in
// group 1 and the rel value in group 2.
var linkRE = regexp.MustCompile(`<([^>]+)>;\s*rel="([^"]+)"`)
// restURL resolves pathOrURL to a full request URL. Values that already start
// with "https://" or "http://" are returned unchanged; anything else is
// appended to the REST prefix of hostname.
func restURL(hostname string, pathOrURL string) string {
	for _, scheme := range []string{"https://", "http://"} {
		if strings.HasPrefix(pathOrURL, scheme) {
			return pathOrURL
		}
	}
	return ghinstance.RESTPrefix(hostname) + pathOrURL
}
// handleResponse processes the response to a GraphQL request.
//
// A non-2xx status is converted to an error by HandleHTTPError. Otherwise the
// body is decoded as a GraphQL envelope whose data is stored into data. If the
// envelope lists any errors, a *GraphQLErrorResponse holding them is returned;
// data has still been populated with whatever the response carried.
func handleResponse(resp *http.Response, data interface{}) error {
	success := resp.StatusCode >= 200 && resp.StatusCode < 300
	if !success {
		return HandleHTTPError(resp)
	}

	body, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		return err
	}

	// Decode straight into the envelope value. Handing Unmarshal a pointer to
	// a pointer would let a JSON `null` body reset that pointer to nil, and the
	// Errors access below would then panic on a nil dereference.
	gr := graphQLResponse{Data: data}
	if err := json.Unmarshal(body, &gr); err != nil {
		return err
	}

	if len(gr.Errors) > 0 {
		return &GraphQLErrorResponse{Errors: gr.Errors}
	}
	return nil
}
// HandleHTTPError converts a failed HTTP response into an HTTPError carrying
// the status code, the request URL and any scopes suggestion for the response.
//
// If the response is not declared as JSON, the message is the HTTP status
// line. If reading the body fails, the message is the read error. Otherwise
// the message is built from the body's top-level "message" plus one line per
// entry of its "errors" list, and the decoded entries are stored in Errors.
// When the body is not valid JSON the message is left empty. When the body is
// valid JSON but a field has an unexpected shape, the top-level message is
// still reported.
func HandleHTTPError(resp *http.Response) error {
	httpError := HTTPError{
		StatusCode:       resp.StatusCode,
		RequestURL:       resp.Request.URL,
		scopesSuggestion: ScopesSuggestion(resp),
	}

	if !jsonTypeRE.MatchString(resp.Header.Get("Content-Type")) {
		httpError.Message = resp.Status
		return httpError
	}

	body, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		httpError.Message = err.Error()
		return httpError
	}

	var parsedBody struct {
		Message string `json:"message"`
		Errors  []json.RawMessage
	}
	if err := json.Unmarshal(body, &parsedBody); err != nil {
		// On a type mismatch (e.g. "errors" is not a list) Unmarshal still
		// fills the fields it could decode, so keep the message instead of
		// dropping it. On a syntax error nothing was decoded and this is "".
		httpError.Message = parsedBody.Message
		return httpError
	}

	var messages []string
	if parsedBody.Message != "" {
		messages = append(messages, parsedBody.Message)
	}
	for _, raw := range parsedBody.Errors {
		if len(raw) == 0 {
			continue
		}
		// Entries are either plain strings or objects; other kinds are ignored.
		switch raw[0] {
		case '"':
			var errString string
			// Best effort: raw is already known to be syntactically valid JSON.
			_ = json.Unmarshal(raw, &errString)
			messages = append(messages, errString)
			httpError.Errors = append(httpError.Errors, HTTPErrorItem{Message: errString})
		case '{':
			var errInfo HTTPErrorItem
			// Best effort: fields of an unexpected type are left at their zero value.
			_ = json.Unmarshal(raw, &errInfo)
			msg := errInfo.Message
			if errInfo.Code != "" && errInfo.Code != "custom" {
				msg = fmt.Sprintf("%s.%s %s", errInfo.Resource, errInfo.Field, errorCodeToMessage(errInfo.Code))
			}
			if msg != "" {
				messages = append(messages, msg)
			}
			httpError.Errors = append(httpError.Errors, errInfo)
		}
	}
	httpError.Message = strings.Join(messages, "\n")

	return httpError
}
// errorCodeToMessage translates a REST API error code into a short phrase.
// Codes without a known translation are returned unchanged.
func errorCodeToMessage(code string) string {
	// https://docs.github.com/en/rest/overview/resources-in-the-rest-api#client-errors
	if code == "missing" || code == "missing_field" {
		return "is missing"
	}
	if code == "invalid" || code == "unprocessable" {
		return "is invalid"
	}
	if code == "already_exists" {
		return "already exists"
	}
	return code
}
// jsonTypeRE matches content types in which "/json" or "+json" is followed by
// the end of the string or by a ";".
var jsonTypeRE = regexp.MustCompile(`[/+]json($|;)`)

// inspectableMIMEType reports whether content type t is textual: either it
// starts with "text/" or it is matched by jsonTypeRE.
func inspectableMIMEType(t string) bool {
	if strings.HasPrefix(t, "text/") {
		return true
	}
	return jsonTypeRE.MatchString(t)
}