|
4 | 4 | "bytes" |
5 | 5 | "errors" |
6 | 6 | "io/ioutil" |
| 7 | + "net/http" |
7 | 8 | "reflect" |
8 | 9 | "testing" |
9 | 10 |
|
@@ -91,5 +92,79 @@ func TestRESTError(t *testing.T) { |
91 | 92 | } |
92 | 93 | if httpErr.Error() != "HTTP 422: OH NO (https://api.github.com/repos/branch)" { |
93 | 94 | t.Errorf("got %q", httpErr.Error()) |
| 95 | + |
| 96 | + } |
| 97 | +} |
| 98 | + |
| 99 | +func Test_CheckScopes(t *testing.T) { |
| 100 | + tests := []struct { |
| 101 | + name string |
| 102 | + wantScope string |
| 103 | + responseApp string |
| 104 | + responseScopes string |
| 105 | + expectCallback bool |
| 106 | + }{ |
| 107 | + { |
| 108 | + name: "missing read:org", |
| 109 | + wantScope: "read:org", |
| 110 | + responseApp: "APPID", |
| 111 | + responseScopes: "repo, gist", |
| 112 | + expectCallback: true, |
| 113 | + }, |
| 114 | + { |
| 115 | + name: "has read:org", |
| 116 | + wantScope: "read:org", |
| 117 | + responseApp: "APPID", |
| 118 | + responseScopes: "repo, read:org, gist", |
| 119 | + expectCallback: false, |
| 120 | + }, |
| 121 | + { |
| 122 | + name: "has admin:org", |
| 123 | + wantScope: "read:org", |
| 124 | + responseApp: "APPID", |
| 125 | + responseScopes: "repo, admin:org, gist", |
| 126 | + expectCallback: false, |
| 127 | + }, |
| 128 | + } |
| 129 | + for _, tt := range tests { |
| 130 | + t.Run(tt.name, func(t *testing.T) { |
| 131 | + tr := &httpmock.Registry{} |
| 132 | + tr.Register(httpmock.MatchAny, func(*http.Request) (*http.Response, error) { |
| 133 | + return &http.Response{ |
| 134 | + StatusCode: 200, |
| 135 | + Header: http.Header{ |
| 136 | + "X-Oauth-Client-Id": []string{tt.responseApp}, |
| 137 | + "X-Oauth-Scopes": []string{tt.responseScopes}, |
| 138 | + }, |
| 139 | + }, nil |
| 140 | + }) |
| 141 | + |
| 142 | + callbackInvoked := false |
| 143 | + var gotAppID string |
| 144 | + fn := CheckScopes(tt.wantScope, func(appID string) error { |
| 145 | + callbackInvoked = true |
| 146 | + gotAppID = appID |
| 147 | + return nil |
| 148 | + }) |
| 149 | + |
| 150 | + rt := fn(tr) |
| 151 | + req, err := http.NewRequest("GET", "https://api.github.com/hello", nil) |
| 152 | + if err != nil { |
| 153 | + t.Fatalf("unexpected error: %v", err) |
| 154 | + } |
| 155 | + |
| 156 | + issuedScopesWarning = false |
| 157 | + _, err = rt.RoundTrip(req) |
| 158 | + if err != nil { |
| 159 | + t.Fatalf("unexpected error: %v", err) |
| 160 | + } |
| 161 | + |
| 162 | + if tt.expectCallback != callbackInvoked { |
| 163 | + t.Fatalf("expected CheckScopes callback: %v", tt.expectCallback) |
| 164 | + } |
| 165 | + if tt.expectCallback && gotAppID != tt.responseApp { |
| 166 | + t.Errorf("unexpected app ID: %q", gotAppID) |
| 167 | + } |
| 168 | + }) |
94 | 169 | } |
95 | 170 | } |
0 commit comments