Skip to content

Commit

Permalink
Add header for caching key
Browse files Browse the repository at this point in the history
  • Loading branch information
akarashchuk committed Apr 24, 2024
1 parent c59f977 commit b71e209
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 8 deletions.
7 changes: 5 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,16 @@ Krakend plugin for caching backend responses
"name": "onliner/krakend-http-cache",
"onliner/krakend-http-cache": {
"ttl": 180,
"connection": "redis"
"connection": "redis",
"headers": []
}
}
...
```

`ttl` - cache ttl in seconds
`connection` - name of cache connection
`headers` - headers used for cache key

## Cache connections

Expand Down Expand Up @@ -74,7 +76,8 @@ Krakend plugin for caching backend responses
"name": "onliner/krakend-http-cache",
"onliner/krakend-http-cache": {
"ttl": 180,
"connection": "redis"
"connection": "redis",
"headers": ["X-Custom-Headers"]
}
}
}
Expand Down
24 changes: 19 additions & 5 deletions handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"io"
"net/http"
"net/http/httputil"
"strings"
"time"

"github.com/google/uuid"
Expand All @@ -21,8 +22,9 @@ type CacheHandler struct {
}

type ClientConfig struct {
Ttl uint64
Conn string `mapstructure:"connection"`
Ttl uint64
Conn string `mapstructure:"connection"`
Headers []string
}

func NewCacheHandler(client *http.Client, logger Logger) *CacheHandler {
Expand Down Expand Up @@ -75,7 +77,7 @@ func (h *CacheHandler) loadFromCache(req *http.Request, cnf *ClientConfig) *http
return nil
}

v, err := conn.Fetch(cacheKey(req))
v, err := conn.Fetch(cacheKey(req, cnf))
if err != nil {
return nil
}
Expand Down Expand Up @@ -103,7 +105,7 @@ func (h *CacheHandler) saveToCache(res *http.Response, cnf *ClientConfig) {
return
}

err = conn.Save(cacheKey(res.Request), string(dump), time.Duration(cnf.Ttl)*time.Second)
err = conn.Save(cacheKey(res.Request, cnf), string(dump), time.Duration(cnf.Ttl)*time.Second)
if err != nil {
h.logger.Error(fmt.Sprintf("failed save to cache: %v", err))
}
Expand Down Expand Up @@ -146,8 +148,20 @@ func cloneRequest(req *http.Request) *http.Request {
return clone
}

func cacheKey(req *http.Request) string {
func cacheKey(req *http.Request, cnf *ClientConfig) string {
url := req.URL.RequestURI()

var headers []string
for _, h := range cnf.Headers {
val := req.Header.Values(h)
if val != nil {
headers = append(headers, fmt.Sprintf("%s:%s", strings.ToLower(h), strings.Join(val, ",")))
}
}

if len(headers) > 0 {
url = fmt.Sprintf("%s|headers:%s", url, strings.Join(headers, "/"))
}

return fmt.Sprintf("krakend-hc:%s", uuid.NewSHA1(uuid.NameSpaceURL, []byte(url)))
}
42 changes: 41 additions & 1 deletion handler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ func TestHandle(t *testing.T) {
}

func TestHandleNotSupportedMethods(t *testing.T) {
methods := []string{http.MethodPost, http.MethodPut}
methods := []string{http.MethodHead, http.MethodPost, http.MethodPut, http.MethodPatch, http.MethodDelete}
for _, method := range methods {
req := newRequest(method, nil)
registerResponse(method, http.StatusOK, nil)
Expand Down Expand Up @@ -148,3 +148,43 @@ func TestHandleEtag(t *testing.T) {

assert.Equal(t, 1, httpmock.GetTotalCallCount())
}

func TestHandleHeaders(t *testing.T) {
setup()
defer teardown()

requests := []struct {
headers1 map[string]string
headers2 map[string]string
callCount int
}{
{map[string]string{"X-Custom-Header": "1"}, map[string]string{"x-custom-header": "1"}, 1},
{map[string]string{"X-Custom-Header": "1"}, map[string]string{"X-Custom-Header": "2"}, 2},
{map[string]string{"X-Custom-Header": "1"}, map[string]string{"X-Custom-Header": ""}, 2},
{map[string]string{"X-Custom-Header": "1"}, nil, 2},
}

handler := NewCacheHandler(http.DefaultClient, noopLogger{}).Handle(&ClientConfig{Ttl: 1, Conn: conn, Headers: []string{"X-Custom-Header"}})

for _, request := range requests {
req := newRequest(http.MethodGet, request.headers1)
registerResponse(http.MethodGet, http.StatusOK, nil)

rr := httptest.NewRecorder()
handler.ServeHTTP(rr, req)

assert.Equal(t, http.StatusOK, rr.Result().StatusCode, "Status code mismatch")
assert.Equal(t, body, rr.Body.String(), "Body mismatch")

req = newRequest(http.MethodGet, request.headers2)
rr = httptest.NewRecorder()
handler.ServeHTTP(rr, req)

assert.Equal(t, http.StatusOK, rr.Result().StatusCode, "Status code mismatch")
assert.Equal(t, body, rr.Body.String(), "Body mismatch")

assert.Equal(t, request.callCount, httpmock.GetTotalCallCount())
httpmock.Reset()
GetCache(conn).Flush()
}
}

0 comments on commit b71e209

Please sign in to comment.