diff --git a/README.md b/README.md index aedc06f..1f800f1 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,8 @@ Krakend plugin for caching backend responses "name": "onliner/krakend-http-cache", "onliner/krakend-http-cache": { "ttl": 180, - "connection": "redis" + "connection": "redis", + "headers": [] } } ... @@ -18,6 +19,7 @@ Krakend plugin for caching backend responses `ttl` - cache ttl in seconds `connection` - name of cache connection +`headers` - headers used for cache key ## Cache connections @@ -74,7 +76,8 @@ Krakend plugin for caching backend responses "name": "onliner/krakend-http-cache", "onliner/krakend-http-cache": { "ttl": 180, - "connection": "redis" + "connection": "redis", + "headers": ["X-Custom-Headers"] } } } diff --git a/handler.go b/handler.go index e558e7f..638468f 100644 --- a/handler.go +++ b/handler.go @@ -7,6 +7,7 @@ import ( "io" "net/http" "net/http/httputil" + "strings" "time" "github.com/google/uuid" @@ -21,8 +22,9 @@ type CacheHandler struct { } type ClientConfig struct { - Ttl uint64 - Conn string `mapstructure:"connection"` + Ttl uint64 + Conn string `mapstructure:"connection"` + Headers []string } func NewCacheHandler(client *http.Client, logger Logger) *CacheHandler { @@ -75,7 +77,7 @@ func (h *CacheHandler) loadFromCache(req *http.Request, cnf *ClientConfig) *http return nil } - v, err := conn.Fetch(cacheKey(req)) + v, err := conn.Fetch(cacheKey(req, cnf)) if err != nil { return nil } @@ -103,7 +105,7 @@ func (h *CacheHandler) saveToCache(res *http.Response, cnf *ClientConfig) { return } - err = conn.Save(cacheKey(res.Request), string(dump), time.Duration(cnf.Ttl)*time.Second) + err = conn.Save(cacheKey(res.Request, cnf), string(dump), time.Duration(cnf.Ttl)*time.Second) if err != nil { h.logger.Error(fmt.Sprintf("failed save to cache: %v", err)) } @@ -146,8 +148,20 @@ func cloneRequest(req *http.Request) *http.Request { return clone } -func cacheKey(req *http.Request) string { +func cacheKey(req *http.Request, cnf *ClientConfig) string { url := req.URL.RequestURI() + var headers []string + for _, h := range cnf.Headers { + val := req.Header.Values(h) + if val != nil { + headers = append(headers, fmt.Sprintf("%s:%s", strings.ToLower(h), strings.Join(val, ","))) + } + } + + if len(headers) > 0 { + url = fmt.Sprintf("%s|headers:%s", url, strings.Join(headers, "/")) + } + return fmt.Sprintf("krakend-hc:%s", uuid.NewSHA1(uuid.NameSpaceURL, []byte(url))) } diff --git a/handler_test.go b/handler_test.go index 749f69e..22d7a84 100644 --- a/handler_test.go +++ b/handler_test.go @@ -97,7 +97,7 @@ func TestHandle(t *testing.T) { } func TestHandleNotSupportedMethods(t *testing.T) { - methods := []string{http.MethodPost, http.MethodPut} + methods := []string{http.MethodHead, http.MethodPost, http.MethodPut, http.MethodPatch, http.MethodDelete} for _, method := range methods { req := newRequest(method, nil) registerResponse(method, http.StatusOK, nil) @@ -148,3 +148,43 @@ func TestHandleEtag(t *testing.T) { assert.Equal(t, 1, httpmock.GetTotalCallCount()) } + +func TestHandleHeaders(t *testing.T) { + setup() + defer teardown() + + requests := []struct { + headers1 map[string]string + headers2 map[string]string + callCount int + }{ + {map[string]string{"X-Custom-Header": "1"}, map[string]string{"x-custom-header": "1"}, 1}, + {map[string]string{"X-Custom-Header": "1"}, map[string]string{"X-Custom-Header": "2"}, 2}, + {map[string]string{"X-Custom-Header": "1"}, map[string]string{"X-Custom-Header": ""}, 2}, + {map[string]string{"X-Custom-Header": "1"}, nil, 2}, + } + + handler := NewCacheHandler(http.DefaultClient, noopLogger{}).Handle(&ClientConfig{Ttl: 1, Conn: conn, Headers: []string{"X-Custom-Header"}}) + + for _, request := range requests { + req := newRequest(http.MethodGet, request.headers1) + registerResponse(http.MethodGet, http.StatusOK, nil) + + rr := httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusOK, rr.Result().StatusCode, "Status code mismatch") + assert.Equal(t, body, rr.Body.String(), "Body mismatch") + + req = newRequest(http.MethodGet, request.headers2) + rr = httptest.NewRecorder() + handler.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusOK, rr.Result().StatusCode, "Status code mismatch") + assert.Equal(t, body, rr.Body.String(), "Body mismatch") + + assert.Equal(t, request.callCount, httpmock.GetTotalCallCount()) + httpmock.Reset() + GetCache(conn).Flush() + } +}