X Tutup
Skip to content

Commit e8c4b30

Browse files
author
Nate Smith
authored
Merge pull request cli#1522 from cli/oauth-device-flow
Implement OAuth Device Authorization flow
2 parents fc1c800 + a6776d0 commit e8c4b30

File tree

3 files changed

+267
-22
lines changed

3 files changed

+267
-22
lines changed

auth/oauth.go

Lines changed: 123 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,11 @@ import (
1010
"net"
1111
"net/http"
1212
"net/url"
13-
"os"
13+
"strconv"
1414
"strings"
15+
"time"
1516

1617
"github.com/cli/cli/internal/ghinstance"
17-
"github.com/cli/cli/pkg/browser"
1818
)
1919

2020
func randomString(length int) (string, error) {
@@ -32,13 +32,130 @@ type OAuthFlow struct {
3232
ClientID string
3333
ClientSecret string
3434
Scopes []string
35+
OpenInBrowser func(string, string) error
3536
WriteSuccessHTML func(io.Writer)
3637
VerboseStream io.Writer
38+
HTTPClient *http.Client
39+
TimeNow func() time.Time
40+
TimeSleep func(time.Duration)
3741
}
3842

3943
// ObtainAccessToken guides the user through the browser OAuth flow on GitHub
4044
// and returns the OAuth access token upon completion.
4145
func (oa *OAuthFlow) ObtainAccessToken() (accessToken string, err error) {
46+
// first, check if OAuth Device Flow is supported
47+
initURL := fmt.Sprintf("https://%s/login/device/code", oa.Hostname)
48+
tokenURL := fmt.Sprintf("https://%s/login/oauth/access_token", oa.Hostname)
49+
50+
oa.logf("POST %s\n", initURL)
51+
resp, err := oa.HTTPClient.PostForm(initURL, url.Values{
52+
"client_id": {oa.ClientID},
53+
"scope": {strings.Join(oa.Scopes, " ")},
54+
})
55+
if err != nil {
56+
return
57+
}
58+
defer resp.Body.Close()
59+
60+
if resp.StatusCode == 401 || resp.StatusCode == 403 || resp.StatusCode == 404 {
61+
// OAuth Device Flow is not available; continue with OAuth browser flow with a
62+
// local server endpoint as callback target
63+
return oa.localServerFlow()
64+
} else if resp.StatusCode != 200 {
65+
return "", fmt.Errorf("error: HTTP %d (%s)", resp.StatusCode, initURL)
66+
}
67+
68+
bb, err := ioutil.ReadAll(resp.Body)
69+
if err != nil {
70+
return
71+
}
72+
values, err := url.ParseQuery(string(bb))
73+
if err != nil {
74+
return
75+
}
76+
77+
timeNow := oa.TimeNow
78+
if timeNow == nil {
79+
timeNow = time.Now
80+
}
81+
timeSleep := oa.TimeSleep
82+
if timeSleep == nil {
83+
timeSleep = time.Sleep
84+
}
85+
86+
intervalSeconds, err := strconv.Atoi(values.Get("interval"))
87+
if err != nil {
88+
return "", fmt.Errorf("could not parse interval=%q as integer: %w", values.Get("interval"), err)
89+
}
90+
checkInterval := time.Duration(intervalSeconds) * time.Second
91+
92+
expiresIn, err := strconv.Atoi(values.Get("expires_in"))
93+
if err != nil {
94+
return "", fmt.Errorf("could not parse expires_in=%q as integer: %w", values.Get("expires_in"), err)
95+
}
96+
expiresAt := timeNow().Add(time.Duration(expiresIn) * time.Second)
97+
98+
err = oa.OpenInBrowser(values.Get("verification_uri"), values.Get("user_code"))
99+
if err != nil {
100+
return
101+
}
102+
103+
for {
104+
timeSleep(checkInterval)
105+
accessToken, err = oa.deviceFlowPing(tokenURL, values.Get("device_code"))
106+
if accessToken == "" && err == nil {
107+
if timeNow().After(expiresAt) {
108+
err = errors.New("authentication timed out")
109+
} else {
110+
continue
111+
}
112+
}
113+
break
114+
}
115+
116+
return
117+
}
118+
119+
func (oa *OAuthFlow) deviceFlowPing(tokenURL, deviceCode string) (accessToken string, err error) {
120+
oa.logf("POST %s\n", tokenURL)
121+
resp, err := oa.HTTPClient.PostForm(tokenURL, url.Values{
122+
"client_id": {oa.ClientID},
123+
"device_code": {deviceCode},
124+
"grant_type": {"urn:ietf:params:oauth:grant-type:device_code"},
125+
})
126+
if err != nil {
127+
return "", err
128+
}
129+
defer resp.Body.Close()
130+
if resp.StatusCode != 200 {
131+
return "", fmt.Errorf("error: HTTP %d (%s)", resp.StatusCode, tokenURL)
132+
}
133+
134+
bb, err := ioutil.ReadAll(resp.Body)
135+
if err != nil {
136+
return "", err
137+
}
138+
values, err := url.ParseQuery(string(bb))
139+
if err != nil {
140+
return "", err
141+
}
142+
143+
if accessToken := values.Get("access_token"); accessToken != "" {
144+
return accessToken, nil
145+
}
146+
147+
errorType := values.Get("error")
148+
if errorType == "authorization_pending" {
149+
return "", nil
150+
}
151+
152+
if errorDescription := values.Get("error_description"); errorDescription != "" {
153+
return "", errors.New(errorDescription)
154+
}
155+
return "", errors.New("OAuth device flow error")
156+
}
157+
158+
func (oa *OAuthFlow) localServerFlow() (accessToken string, err error) {
42159
state, _ := randomString(20)
43160

44161
code := ""
@@ -70,15 +187,9 @@ func (oa *OAuthFlow) ObtainAccessToken() (accessToken string, err error) {
70187

71188
startURL := fmt.Sprintf("https://%s/login/oauth/authorize?%s", oa.Hostname, q.Encode())
72189
oa.logf("open %s\n", startURL)
73-
if err := openInBrowser(startURL); err != nil {
74-
fmt.Fprintf(os.Stderr, "error opening web browser: %s\n", err)
75-
fmt.Fprintf(os.Stderr, "")
76-
fmt.Fprintf(os.Stderr, "Please open the following URL manually:\n%s\n", startURL)
77-
fmt.Fprintf(os.Stderr, "")
78-
// TODO: Temporary workaround for https://github.com/cli/cli/issues/297
79-
fmt.Fprintf(os.Stderr, "If you are on a server or other headless system, use this workaround instead:\n")
80-
fmt.Fprintf(os.Stderr, " 1. Complete authentication on a GUI system;\n")
81-
fmt.Fprintf(os.Stderr, " 2. Copy the contents of `~/.config/gh/hosts.yml` to this system.\n")
190+
err = oa.OpenInBrowser(startURL, "")
191+
if err != nil {
192+
return
82193
}
83194

84195
_ = http.Serve(listener, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@@ -105,7 +216,7 @@ func (oa *OAuthFlow) ObtainAccessToken() (accessToken string, err error) {
105216

106217
tokenURL := fmt.Sprintf("https://%s/login/oauth/access_token", oa.Hostname)
107218
oa.logf("POST %s\n", tokenURL)
108-
resp, err := http.PostForm(tokenURL,
219+
resp, err := oa.HTTPClient.PostForm(tokenURL,
109220
url.Values{
110221
"client_id": {oa.ClientID},
111222
"client_secret": {oa.ClientSecret},
@@ -143,11 +254,3 @@ func (oa *OAuthFlow) logf(format string, args ...interface{}) {
143254
}
144255
fmt.Fprintf(oa.VerboseStream, format, args...)
145256
}
146-
147-
func openInBrowser(url string) error {
148-
cmd, err := browser.Command(url)
149-
if err != nil {
150-
return err
151-
}
152-
return cmd.Run()
153-
}

auth/oauth_test.go

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
package auth
2+
3+
import (
4+
"bytes"
5+
"fmt"
6+
"io/ioutil"
7+
"net/http"
8+
"net/url"
9+
"testing"
10+
"time"
11+
)
12+
13+
type roundTripper func(*http.Request) (*http.Response, error)
14+
15+
func (rt roundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
16+
return rt(req)
17+
}
18+
19+
func TestObtainAccessToken_deviceFlow(t *testing.T) {
20+
requestCount := 0
21+
rt := func(req *http.Request) (*http.Response, error) {
22+
route := fmt.Sprintf("%s %s", req.Method, req.URL)
23+
switch route {
24+
case "POST https://github.com/login/device/code":
25+
if err := req.ParseForm(); err != nil {
26+
return nil, err
27+
}
28+
if req.PostForm.Get("client_id") != "CLIENT-ID" {
29+
t.Errorf("expected POST /login/device/code to supply client_id=%q, got %q", "CLIENT-ID", req.PostForm.Get("client_id"))
30+
}
31+
if req.PostForm.Get("scope") != "repo gist" {
32+
t.Errorf("expected POST /login/device/code to supply scope=%q, got %q", "repo gist", req.PostForm.Get("scope"))
33+
}
34+
35+
responseData := url.Values{}
36+
responseData.Set("device_code", "DEVICE-CODE")
37+
responseData.Set("user_code", "1234-ABCD")
38+
responseData.Set("verification_uri", "https://github.com/login/device")
39+
responseData.Set("interval", "5")
40+
responseData.Set("expires_in", "899")
41+
42+
return &http.Response{
43+
StatusCode: 200,
44+
Body: ioutil.NopCloser(bytes.NewBufferString(responseData.Encode())),
45+
}, nil
46+
case "POST https://github.com/login/oauth/access_token":
47+
if err := req.ParseForm(); err != nil {
48+
return nil, err
49+
}
50+
if req.PostForm.Get("client_id") != "CLIENT-ID" {
51+
t.Errorf("expected POST /login/oauth/access_token to supply client_id=%q, got %q", "CLIENT-ID", req.PostForm.Get("client_id"))
52+
}
53+
if req.PostForm.Get("device_code") != "DEVICE-CODE" {
54+
t.Errorf("expected POST /login/oauth/access_token to supply device_code=%q, got %q", "DEVICE-CODE", req.PostForm.Get("scope"))
55+
}
56+
if req.PostForm.Get("grant_type") != "urn:ietf:params:oauth:grant-type:device_code" {
57+
t.Errorf("expected POST /login/oauth/access_token to supply grant_type=%q, got %q", "urn:ietf:params:oauth:grant-type:device_code", req.PostForm.Get("grant_type"))
58+
}
59+
60+
responseData := url.Values{}
61+
requestCount++
62+
if requestCount == 1 {
63+
responseData.Set("error", "authorization_pending")
64+
} else {
65+
responseData.Set("access_token", "OTOKEN")
66+
}
67+
68+
return &http.Response{
69+
StatusCode: 200,
70+
Body: ioutil.NopCloser(bytes.NewBufferString(responseData.Encode())),
71+
}, nil
72+
default:
73+
return nil, fmt.Errorf("unstubbed HTTP request: %v", route)
74+
}
75+
}
76+
httpClient := &http.Client{
77+
Transport: roundTripper(rt),
78+
}
79+
80+
slept := time.Duration(0)
81+
var browseURL string
82+
var browseCode string
83+
84+
oa := &OAuthFlow{
85+
Hostname: "github.com",
86+
ClientID: "CLIENT-ID",
87+
ClientSecret: "CLIENT-SEKRIT",
88+
Scopes: []string{"repo", "gist"},
89+
OpenInBrowser: func(url, code string) error {
90+
browseURL = url
91+
browseCode = code
92+
return nil
93+
},
94+
HTTPClient: httpClient,
95+
TimeNow: time.Now,
96+
TimeSleep: func(d time.Duration) {
97+
slept += d
98+
},
99+
}
100+
101+
token, err := oa.ObtainAccessToken()
102+
if err != nil {
103+
t.Fatalf("ObtainAccessToken error: %v", err)
104+
}
105+
106+
if token != "OTOKEN" {
107+
t.Errorf("expected token %q, got %q", "OTOKEN", token)
108+
}
109+
if requestCount != 2 {
110+
t.Errorf("expected 2 HTTP pings for token, got %d", requestCount)
111+
}
112+
if slept.String() != "10s" {
113+
t.Errorf("expected total sleep duration of %s, got %s", "10s", slept.String())
114+
}
115+
if browseURL != "https://github.com/login/device" {
116+
t.Errorf("expected to open browser at %s, got %s", "https://github.com/login/device", browseURL)
117+
}
118+
if browseCode != "1234-ABCD" {
119+
t.Errorf("expected to provide user with one-time code %q, got %q", "1234-ABCD", browseCode)
120+
}
121+
}

internal/config/config_setup.go

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,13 @@ import (
44
"bufio"
55
"fmt"
66
"io"
7+
"net/http"
78
"os"
89
"strings"
910

1011
"github.com/cli/cli/api"
1112
"github.com/cli/cli/auth"
13+
"github.com/cli/cli/pkg/browser"
1214
"github.com/cli/cli/utils"
1315
)
1416

@@ -68,11 +70,30 @@ func authFlow(oauthHost, notice string, additionalScopes []string) (string, stri
6870
fmt.Fprintln(w, oauthSuccessPage)
6971
},
7072
VerboseStream: verboseStream,
73+
HTTPClient: http.DefaultClient,
74+
OpenInBrowser: func(url, code string) error {
75+
if code != "" {
76+
fmt.Fprintf(os.Stderr, "%s First copy your one-time code: %s\n", utils.Yellow("!"), utils.Bold(code))
77+
}
78+
fmt.Fprintf(os.Stderr, "- %s to open %s in your browser... ", utils.Bold("Press Enter"), oauthHost)
79+
_ = waitForEnter(os.Stdin)
80+
81+
browseCmd, err := browser.Command(url)
82+
if err != nil {
83+
return err
84+
}
85+
err = browseCmd.Run()
86+
if err != nil {
87+
fmt.Fprintf(os.Stderr, "%s Failed opening a web browser at %s\n", utils.Red("!"), url)
88+
fmt.Fprintf(os.Stderr, " %s\n", err)
89+
fmt.Fprint(os.Stderr, " Please try entering the URL in your browser manually\n")
90+
}
91+
return nil
92+
},
7193
}
7294

7395
fmt.Fprintln(os.Stderr, notice)
74-
fmt.Fprintf(os.Stderr, "- %s to open %s in your browser... ", utils.Bold("Press Enter"), flow.Hostname)
75-
_ = waitForEnter(os.Stdin)
96+
7697
token, err := flow.ObtainAccessToken()
7798
if err != nil {
7899
return "", "", err

0 commit comments

Comments
 (0)
X Tutup