X Tutup
Skip to content

Commit 7663acd

Browse files
committed
Improve HTTP caching layer
- make thread-safe - only cache GET, HEAD, and GraphQL requests - only cache non-5xx, non-403 responses - include `Accept` and `Authorization` headers in cache key
1 parent f7a82a2 commit 7663acd

File tree

2 files changed

+116
-33
lines changed

2 files changed

+116
-33
lines changed

api/cache.go

Lines changed: 85 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ import (
1111
"net/http"
1212
"os"
1313
"path/filepath"
14+
"strings"
15+
"sync"
1416
"time"
1517
)
1618

@@ -21,39 +23,79 @@ func makeCachedClient(httpClient *http.Client, cacheTTL time.Duration) *http.Cli
2123
}
2224
}
2325

26+
func isCacheableRequest(req *http.Request) bool {
27+
if strings.EqualFold(req.Method, "GET") || strings.EqualFold(req.Method, "HEAD") {
28+
return true
29+
}
30+
31+
if strings.EqualFold(req.Method, "POST") && (req.URL.Path == "/graphql" || req.URL.Path == "/api/graphql") {
32+
return true
33+
}
34+
35+
return false
36+
}
37+
38+
func isCacheableResponse(res *http.Response) bool {
39+
return res.StatusCode < 500 && res.StatusCode != 403
40+
}
41+
2442
// CacheReponse produces a RoundTripper that caches HTTP responses to disk for a specified amount of time
2543
func CacheReponse(ttl time.Duration, dir string) ClientOption {
44+
fs := fileStorage{
45+
dir: dir,
46+
ttl: ttl,
47+
mu: &sync.RWMutex{},
48+
}
49+
2650
return func(tr http.RoundTripper) http.RoundTripper {
2751
return &funcTripper{roundTrip: func(req *http.Request) (*http.Response, error) {
52+
if !isCacheableRequest(req) {
53+
return tr.RoundTrip(req)
54+
}
55+
2856
key, keyErr := cacheKey(req)
29-
cacheFile := filepath.Join(dir, key)
3057
if keyErr == nil {
31-
// TODO: make thread-safe
32-
if res, err := readCache(ttl, cacheFile, req); err == nil {
58+
if res, err := fs.read(key); err == nil {
59+
res.Request = req
3360
return res, nil
3461
}
3562
}
63+
3664
res, err := tr.RoundTrip(req)
37-
if err == nil && keyErr == nil {
38-
// TODO: make thread-safe
39-
_ = writeCache(cacheFile, res)
65+
if err == nil && keyErr == nil && isCacheableResponse(res) {
66+
_ = fs.store(key, res)
4067
}
4168
return res, err
4269
}}
4370
}
4471
}
4572

73+
func copyStream(r io.ReadCloser) (io.ReadCloser, io.ReadCloser) {
74+
b := &bytes.Buffer{}
75+
nr := io.TeeReader(r, b)
76+
return ioutil.NopCloser(b), &readCloser{
77+
Reader: nr,
78+
Closer: r,
79+
}
80+
}
81+
82+
type readCloser struct {
83+
io.Reader
84+
io.Closer
85+
}
86+
4687
func cacheKey(req *http.Request) (string, error) {
4788
h := sha256.New()
4889
fmt.Fprintf(h, "%s:", req.Method)
4990
fmt.Fprintf(h, "%s:", req.URL.String())
91+
fmt.Fprintf(h, "%s:", req.Header.Get("Accept"))
92+
fmt.Fprintf(h, "%s:", req.Header.Get("Authorization"))
5093

5194
if req.Body != nil {
52-
bodyCopy := &bytes.Buffer{}
53-
defer req.Body.Close()
54-
_, err := io.Copy(h, io.TeeReader(req.Body, bodyCopy))
55-
req.Body = ioutil.NopCloser(bodyCopy)
56-
if err != nil {
95+
var bodyCopy io.ReadCloser
96+
req.Body, bodyCopy = copyStream(req.Body)
97+
defer bodyCopy.Close()
98+
if _, err := io.Copy(h, bodyCopy); err != nil {
5799
return "", err
58100
}
59101
}
@@ -62,20 +104,38 @@ func cacheKey(req *http.Request) (string, error) {
62104
return fmt.Sprintf("%x", digest), nil
63105
}
64106

65-
func readCache(ttl time.Duration, cacheFile string, req *http.Request) (*http.Response, error) {
107+
type fileStorage struct {
108+
dir string
109+
ttl time.Duration
110+
mu *sync.RWMutex
111+
}
112+
113+
func (fs *fileStorage) filePath(key string) string {
114+
if len(key) >= 6 {
115+
return filepath.Join(fs.dir, key[0:2], key[2:4], key[4:])
116+
}
117+
return filepath.Join(fs.dir, key)
118+
}
119+
120+
func (fs *fileStorage) read(key string) (*http.Response, error) {
121+
cacheFile := fs.filePath(key)
122+
123+
fs.mu.RLock()
124+
defer fs.mu.RUnlock()
125+
66126
f, err := os.Open(cacheFile)
67127
if err != nil {
68128
return nil, err
69129
}
70130
defer f.Close()
71131

72-
fs, err := f.Stat()
132+
stat, err := f.Stat()
73133
if err != nil {
74134
return nil, err
75135
}
76136

77-
age := time.Since(fs.ModTime())
78-
if age > ttl {
137+
age := time.Since(stat.ModTime())
138+
if age > fs.ttl {
79139
return nil, errors.New("cache expired")
80140
}
81141

@@ -85,11 +145,16 @@ func readCache(ttl time.Duration, cacheFile string, req *http.Request) (*http.Re
85145
return nil, err
86146
}
87147

88-
res, err := http.ReadResponse(bufio.NewReader(body), req)
148+
res, err := http.ReadResponse(bufio.NewReader(body), nil)
89149
return res, err
90150
}
91151

92-
func writeCache(cacheFile string, res *http.Response) error {
152+
func (fs *fileStorage) store(key string, res *http.Response) error {
153+
cacheFile := fs.filePath(key)
154+
155+
fs.mu.Lock()
156+
defer fs.mu.Unlock()
157+
93158
err := os.MkdirAll(filepath.Dir(cacheFile), 0755)
94159
if err != nil {
95160
return err
@@ -101,10 +166,10 @@ func writeCache(cacheFile string, res *http.Response) error {
101166
}
102167
defer f.Close()
103168

104-
bodyCopy := &bytes.Buffer{}
169+
var origBody io.ReadCloser
170+
origBody, res.Body = copyStream(res.Body)
105171
defer res.Body.Close()
106-
res.Body = ioutil.NopCloser(io.TeeReader(res.Body, bodyCopy))
107172
err = res.Write(f)
108-
res.Body = ioutil.NopCloser(bodyCopy)
173+
res.Body = origBody
109174
return err
110175
}

api/cache_test.go

Lines changed: 31 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,12 @@ func Test_CacheReponse(t *testing.T) {
2020
roundTrip: func(req *http.Request) (*http.Response, error) {
2121
counter += 1
2222
body := fmt.Sprintf("%d: %s %s", counter, req.Method, req.URL.String())
23+
status := 200
24+
if req.URL.Path == "/error" {
25+
status = 500
26+
}
2327
return &http.Response{
24-
StatusCode: 200,
28+
StatusCode: status,
2529
Body: ioutil.NopCloser(bytes.NewBufferString(body)),
2630
}, nil
2731
},
@@ -47,25 +51,39 @@ func Test_CacheReponse(t *testing.T) {
4751
return string(resBody), err
4852
}
4953

50-
res1, err := do("GET", "http://example.com/path", nil)
54+
var res string
55+
var err error
56+
57+
res, err = do("GET", "http://example.com/path", nil)
58+
require.NoError(t, err)
59+
assert.Equal(t, "1: GET http://example.com/path", res)
60+
res, err = do("GET", "http://example.com/path", nil)
5161
require.NoError(t, err)
52-
assert.Equal(t, "1: GET http://example.com/path", res1)
53-
res2, err := do("GET", "http://example.com/path", nil)
62+
assert.Equal(t, "1: GET http://example.com/path", res)
63+
64+
res, err = do("GET", "http://example.com/path2", nil)
5465
require.NoError(t, err)
55-
assert.Equal(t, "1: GET http://example.com/path", res2)
66+
assert.Equal(t, "2: GET http://example.com/path2", res)
5667

57-
res3, err := do("GET", "http://example.com/path2", nil)
68+
res, err = do("POST", "http://example.com/path2", nil)
5869
require.NoError(t, err)
59-
assert.Equal(t, "2: GET http://example.com/path2", res3)
70+
assert.Equal(t, "3: POST http://example.com/path2", res)
6071

61-
res4, err := do("POST", "http://example.com/path", bytes.NewBufferString(`hello`))
72+
res, err = do("POST", "http://example.com/graphql", bytes.NewBufferString(`hello`))
6273
require.NoError(t, err)
63-
assert.Equal(t, "3: POST http://example.com/path", res4)
64-
res5, err := do("POST", "http://example.com/path", bytes.NewBufferString(`hello`))
74+
assert.Equal(t, "4: POST http://example.com/graphql", res)
75+
res, err = do("POST", "http://example.com/graphql", bytes.NewBufferString(`hello`))
6576
require.NoError(t, err)
66-
assert.Equal(t, "3: POST http://example.com/path", res5)
77+
assert.Equal(t, "4: POST http://example.com/graphql", res)
6778

68-
res6, err := do("POST", "http://example.com/path", bytes.NewBufferString(`hello2`))
79+
res, err = do("POST", "http://example.com/graphql", bytes.NewBufferString(`hello2`))
80+
require.NoError(t, err)
81+
assert.Equal(t, "5: POST http://example.com/graphql", res)
82+
83+
res, err = do("GET", "http://example.com/error", nil)
84+
require.NoError(t, err)
85+
assert.Equal(t, "6: GET http://example.com/error", res)
86+
res, err = do("GET", "http://example.com/error", nil)
6987
require.NoError(t, err)
70-
assert.Equal(t, "4: POST http://example.com/path", res6)
88+
assert.Equal(t, "7: GET http://example.com/error", res)
7189
}

0 commit comments

Comments
 (0)
X Tutup