Skip to content

Commit

Permalink
Support option for headerXRequestID (#8)
Browse files Browse the repository at this point in the history
  • Loading branch information
shuntagami authored Apr 17, 2022
1 parent 941abc2 commit 679ae72
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 6 deletions.
11 changes: 10 additions & 1 deletion options.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,18 @@ type Option func(*config)

type Generator func() string

// WithGenerator set fenerator function
type HeaderStrKey string

// WithGenerator set generator function
func WithGenerator(g Generator) Option {
return func(cfg *config) {
cfg.generator = g
}
}

// WithCustomeHeaderStrKey set custom header key for request id
func WithCustomHeaderStrKey(s HeaderStrKey) Option {
return func(cfg *config) {
cfg.headerKey = s
}
}
10 changes: 5 additions & 5 deletions requestid.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import (
"github.com/google/uuid"
)

const headerXRequestID = "X-Request-ID"
var headerXRequestID string

// Config defines the config for RequestID middleware
type config struct {
Expand All @@ -14,6 +14,7 @@ type config struct {
// return uuid.New().String()
// }
generator Generator
headerKey HeaderStrKey
}

// New initializes the RequestID middleware.
Expand All @@ -22,6 +23,7 @@ func New(opts ...Option) gin.HandlerFunc {
generator: func() string {
return uuid.New().String()
},
headerKey: "X-Request-ID",
}

for _, opt := range opts {
Expand All @@ -30,13 +32,11 @@ func New(opts ...Option) gin.HandlerFunc {

return func(c *gin.Context) {
// Get id from request
rid := c.GetHeader(headerXRequestID)
rid := c.GetHeader(string(cfg.headerKey))
if rid == "" {
rid = cfg.generator()
// Set the id to ensure that the requestid is in the request
c.Request.Header.Add(headerXRequestID, rid)
}

headerXRequestID = string(cfg.headerKey)
// Set the id to ensure that the requestid is in the response
c.Header(headerXRequestID, rid)
c.Next()
Expand Down
18 changes: 18 additions & 0 deletions requestid_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,21 @@ func TestRequestIDWithCustomID(t *testing.T) {
assert.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, testXRequestID, w.Header().Get(headerXRequestID))
}

func TestRequestIDWithCustomHeaderKey(t *testing.T) {
r := gin.New()
r.Use(
New(
WithCustomHeaderStrKey("customKey"),
),
)
r.GET("/", emptySuccessResponse)

w := httptest.NewRecorder()
req, _ := http.NewRequestWithContext(context.Background(), "GET", "/", nil)
req.Header.Set("customKey", testXRequestID)
r.ServeHTTP(w, req)

assert.Equal(t, http.StatusOK, w.Code)
assert.Equal(t, testXRequestID, w.Header().Get("customKey"))
}

0 comments on commit 679ae72

Please sign in to comment.