Skip to content
This repository has been archived by the owner on May 2, 2023. It is now read-only.

Commit

Permalink
feat: update jwt
Browse files Browse the repository at this point in the history
  • Loading branch information
mohammadne committed Sep 12, 2021
1 parent aade55c commit 34acdb6
Show file tree
Hide file tree
Showing 2 changed files with 126 additions and 15 deletions.
95 changes: 95 additions & 0 deletions services/auth/internal/jwt/failures.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
package jwt

import (
"fmt"
"net/http"

"github.com/mohammadne/bookman/auth/pkg/failures"
)

type failure struct {
FailureMessage string `json:"message"`
FailureStatus int `json:"status"`
FailureCauses []string `json:"causes"`
}

// ==============================================================> methods

func (f *failure) Message() string {
return f.FailureMessage
}

func (f *failure) Status() int {
return f.FailureStatus
}

func (f *failure) Causes() []string {
return f.FailureCauses
}

func (f *failure) Error() string {
return fmt.Sprintf(
"message: %s - status: %d - causes: %v",
f.FailureMessage, f.FailureStatus, f.FailureCauses,
)
}

// ==============================================================> constructors

type Failure struct{}

func (Failure) New(message string, status int, causes []string) failures.Failure {
return &failure{
FailureMessage: message,
FailureStatus: status,
FailureCauses: causes,
}
}

func (Failure) NewBadRequest(message string) failures.Failure {
return &failure{
FailureMessage: message,
FailureStatus: http.StatusBadRequest,
}
}

func (Failure) NewNotFound(message string) failures.Failure {
return &failure{
FailureMessage: message,
FailureStatus: http.StatusNotFound,
}
}

func (Failure) NewUnauthorized(message string) failures.Failure {
return &failure{
FailureMessage: message,
FailureStatus: http.StatusUnauthorized,
}
}

func (Failure) NewUnprocessableEntity(message string) failures.Failure {
return &failure{
FailureMessage: message,
FailureStatus: http.StatusUnprocessableEntity,
}
}

func (Failure) NewNotImplemented() failures.Failure {
return &failure{
FailureMessage: "not implemented",
FailureStatus: http.StatusNotImplemented,
}
}

func (Failure) NewInternalServer(message string, errors ...error) failures.Failure {
causes := make([]string, 0, len(errors))
for index := 0; index < len(errors); index++ {
causes = append(causes, errors[index].Error())
}

return &failure{
FailureMessage: message,
FailureStatus: http.StatusInternalServerError,
FailureCauses: causes,
}
}
46 changes: 31 additions & 15 deletions services/auth/internal/jwt/token.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package jwt

import (
"context"
"errors"
"fmt"
"strconv"
Expand All @@ -22,9 +23,9 @@ const (
)

type Jwt interface {
CreateJwt(userId uint64) (*models.Jwt, failures.Failure)
ExtractTokenMetadata(tokenString string, tokenType TokenType) (*models.AccessDetails, failures.Failure)
TokenValid(tokenString string, tokenType TokenType) error
CreateJwt(context.Context, uint64) (*models.Jwt, failures.Failure)
ExtractTokenMetadata(context.Context, string, TokenType) (*models.AccessDetails, failures.Failure)
TokenValid(context.Context, string, TokenType) failures.Failure
}

type jwt struct {
Expand All @@ -37,8 +38,10 @@ func New(cfg *Config, lg logger.Logger, tr trace.Tracer) Jwt {
return &jwt{config: cfg, logger: lg, tracer: tr}
}

// failureUnprocessableEntity
func (jwt *jwt) CreateJwt(userId uint64) (*models.Jwt, failures.Failure) {
func (jwt *jwt) CreateJwt(ctx context.Context, userId uint64) (*models.Jwt, failures.Failure) {
ctx, span := jwt.tracer.Start(ctx, "jwt.create_jwt")
defer span.End()

accessExpires := time.Duration(jwt.config.AccessExpires) * time.Hour
refreshExpires := time.Duration(jwt.config.RefreshExpires) * time.Hour

Expand All @@ -55,12 +58,18 @@ func (jwt *jwt) CreateJwt(userId uint64) (*models.Jwt, failures.Failure) {

accessErr := createToken(userId, jwt.config.AccessSecret, tokenDetail.AccessToken)
if accessErr != nil {
return nil, accessErr
failure := Failure{}.NewUnprocessableEntity("unprocessable access token")
jwt.logger.Error(failure.Message(), logger.Error(accessErr))
span.RecordError(accessErr)
return nil, failure
}

refreshErr := createToken(userId, jwt.config.RefreshSecret, tokenDetail.RefreshToken)
if refreshErr != nil {
return nil, refreshErr
failure := Failure{}.NewUnprocessableEntity("unprocessable refresh token")
jwt.logger.Error(failure.Message(), logger.Error(refreshErr))
span.RecordError(refreshErr)
return nil, failure
}

return tokenDetail, nil
Expand All @@ -87,10 +96,14 @@ func createToken(userId uint64, secret string, token *models.Token) error {
// failureUnprocessableEntity
// failureUnautorized = failures.Network{}.NewUnauthorized("unauthorized")
// failureUnautorized
func (jwt *jwt) ExtractTokenMetadata(tokenString string, tokenType TokenType) (*models.AccessDetails, failures.Failure) {
token, err := jwt.verifyToken(tokenString, tokenType)
if err != nil {
return nil, err
func (jwt *jwt) ExtractTokenMetadata(ctx context.Context, tokenString string, tokenType TokenType) (*models.AccessDetails, failures.Failure) {
ctx, span := jwt.tracer.Start(ctx, "jwt.extract_token_metadata")
defer span.End()

token, failure := jwt.verifyToken(tokenString, tokenType)
if failure != nil {
span.RecordError(failure)
return nil, failure
}

claims, ok := token.Claims.(jwtPkg.MapClaims)
Expand All @@ -103,16 +116,19 @@ func (jwt *jwt) ExtractTokenMetadata(tokenString string, tokenType TokenType) (*
userIdStr := fmt.Sprintf("%.f", claims["user_id"])
userId, err := strconv.ParseUint(userIdStr, 10, 64)
if err != nil {
span.RecordError(err)
return nil, err
}

return &models.AccessDetails{TokenUuid: tokenUuid, UserId: userId}, nil
}

return nil, errors.New("invalid token")
err = errors.New("invalid token")
span.RecordError(err)
return nil, err
}

func (jwt *jwt) TokenValid(tokenString string, tokenType TokenType) error {
func (jwt *jwt) TokenValid(ctx context.Context, tokenString string, tokenType TokenType) failures.Failure {
token, err := jwt.verifyToken(tokenString, tokenType)
if err != nil {
return err
Expand All @@ -125,7 +141,7 @@ func (jwt *jwt) TokenValid(tokenString string, tokenType TokenType) error {
return nil
}

func (jwt *jwt) verifyToken(tokenString string, tokenType TokenType) (*jwtPkg.Token, error) {
func (jwt *jwt) verifyToken(tokenString string, tokenType TokenType) (*jwtPkg.Token, failures.Failure) {
var secret string
if tokenType == Access {
secret = jwt.config.AccessSecret
Expand All @@ -141,7 +157,7 @@ func (jwt *jwt) verifyToken(tokenString string, tokenType TokenType) (*jwtPkg.To
return token, nil
}

//checkSigningMethod checks the token method conform to "SigningMethodHMAC"
// checkSigningMethod checks the token method conform to "SigningMethodHMAC"
func (jwt *jwt) checkSigningMethod(secret string) jwtPkg.Keyfunc {
return func(token *jwtPkg.Token) (interface{}, error) {
if _, ok := token.Method.(*jwtPkg.SigningMethodHMAC); !ok {
Expand Down

0 comments on commit 34acdb6

Please sign in to comment.