From 34acdb636de6bc031e59b7fd3b0997073905327f Mon Sep 17 00:00:00 2001 From: mohammadne Date: Sun, 12 Sep 2021 09:21:32 +0430 Subject: [PATCH] feat: update jwt --- services/auth/internal/jwt/failures.go | 95 ++++++++++++++++++++++++++ services/auth/internal/jwt/token.go | 46 +++++++++---- 2 files changed, 126 insertions(+), 15 deletions(-) create mode 100644 services/auth/internal/jwt/failures.go diff --git a/services/auth/internal/jwt/failures.go b/services/auth/internal/jwt/failures.go new file mode 100644 index 0000000..893ffef --- /dev/null +++ b/services/auth/internal/jwt/failures.go @@ -0,0 +1,95 @@ +package jwt + +import ( + "fmt" + "net/http" + + "github.com/mohammadne/bookman/auth/pkg/failures" +) + +type failure struct { + FailureMessage string `json:"message"` + FailureStatus int `json:"status"` + FailureCauses []string `json:"causes"` +} + +// ==============================================================> methods + +func (f *failure) Message() string { + return f.FailureMessage +} + +func (f *failure) Status() int { + return f.FailureStatus +} + +func (f *failure) Causes() []string { + return f.FailureCauses +} + +func (f *failure) Error() string { + return fmt.Sprintf( + "message: %s - status: %d - causes: %v", + f.FailureMessage, f.FailureStatus, f.FailureCauses, + ) +} + +// ==============================================================> constructors + +type Failure struct{} + +func (Failure) New(message string, status int, causes []string) failures.Failure { + return &failure{ + FailureMessage: message, + FailureStatus: status, + FailureCauses: causes, + } +} + +func (Failure) NewBadRequest(message string) failures.Failure { + return &failure{ + FailureMessage: message, + FailureStatus: http.StatusBadRequest, + } +} + +func (Failure) NewNotFound(message string) failures.Failure { + return &failure{ + FailureMessage: message, + FailureStatus: http.StatusNotFound, + } +} + +func (Failure) NewUnauthorized(message string) failures.Failure { + return &failure{ + FailureMessage: message, + FailureStatus: http.StatusUnauthorized, + } +} + +func (Failure) NewUnprocessableEntity(message string) failures.Failure { + return &failure{ + FailureMessage: message, + FailureStatus: http.StatusUnprocessableEntity, + } +} + +func (Failure) NewNotImplemented() failures.Failure { + return &failure{ + FailureMessage: "not implemented", + FailureStatus: http.StatusNotImplemented, + } +} + +func (Failure) NewInternalServer(message string, errors ...error) failures.Failure { + causes := make([]string, 0, len(errors)) + for index := 0; index < len(errors); index++ { + causes = append(causes, errors[index].Error()) + } + + return &failure{ + FailureMessage: message, + FailureStatus: http.StatusInternalServerError, + FailureCauses: causes, + } +} diff --git a/services/auth/internal/jwt/token.go b/services/auth/internal/jwt/token.go index 9ffdfd5..361d33f 100644 --- a/services/auth/internal/jwt/token.go +++ b/services/auth/internal/jwt/token.go @@ -1,6 +1,7 @@ package jwt import ( + "context" "errors" "fmt" "strconv" @@ -22,9 +23,9 @@ const ( ) type Jwt interface { - CreateJwt(userId uint64) (*models.Jwt, failures.Failure) - ExtractTokenMetadata(tokenString string, tokenType TokenType) (*models.AccessDetails, failures.Failure) - TokenValid(tokenString string, tokenType TokenType) error + CreateJwt(context.Context, uint64) (*models.Jwt, failures.Failure) + ExtractTokenMetadata(context.Context, string, TokenType) (*models.AccessDetails, failures.Failure) + TokenValid(context.Context, string, TokenType) failures.Failure } type jwt struct { @@ -37,8 +38,10 @@ func New(cfg *Config, lg logger.Logger, tr trace.Tracer) Jwt { return &jwt{config: cfg, logger: lg, tracer: tr} } -// failureUnprocessableEntity -func (jwt *jwt) CreateJwt(userId uint64) (*models.Jwt, failures.Failure) { +func (jwt *jwt) CreateJwt(ctx context.Context, userId uint64) (*models.Jwt, failures.Failure) { + ctx, span := jwt.tracer.Start(ctx, "jwt.create_jwt") + defer span.End() + accessExpires := time.Duration(jwt.config.AccessExpires) * time.Hour refreshExpires := time.Duration(jwt.config.RefreshExpires) * time.Hour @@ -55,12 +58,18 @@ func (jwt *jwt) CreateJwt(userId uint64) (*models.Jwt, failures.Failure) { accessErr := createToken(userId, jwt.config.AccessSecret, tokenDetail.AccessToken) if accessErr != nil { - return nil, accessErr + failure := Failure{}.NewUnprocessableEntity("unprocessable access token") + jwt.logger.Error(failure.Message(), logger.Error(accessErr)) + span.RecordError(accessErr) + return nil, failure } refreshErr := createToken(userId, jwt.config.RefreshSecret, tokenDetail.RefreshToken) if refreshErr != nil { - return nil, refreshErr + failure := Failure{}.NewUnprocessableEntity("unprocessable refresh token") + jwt.logger.Error(failure.Message(), logger.Error(refreshErr)) + span.RecordError(refreshErr) + return nil, failure } return tokenDetail, nil @@ -87,10 +96,14 @@ func createToken(userId uint64, secret string, token *models.Token) error { // failureUnprocessableEntity // failureUnautorized = failures.Network{}.NewUnauthorized("unauthorized") // failureUnautorized -func (jwt *jwt) ExtractTokenMetadata(tokenString string, tokenType TokenType) (*models.AccessDetails, failures.Failure) { - token, err := jwt.verifyToken(tokenString, tokenType) - if err != nil { - return nil, err +func (jwt *jwt) ExtractTokenMetadata(ctx context.Context, tokenString string, tokenType TokenType) (*models.AccessDetails, failures.Failure) { + ctx, span := jwt.tracer.Start(ctx, "jwt.extract_token_metadata") + defer span.End() + + token, failure := jwt.verifyToken(tokenString, tokenType) + if failure != nil { + span.RecordError(failure) + return nil, failure } claims, ok := token.Claims.(jwtPkg.MapClaims) @@ -103,16 +116,19 @@ func (jwt *jwt) ExtractTokenMetadata(tokenString string, tokenType TokenType) (* userIdStr := fmt.Sprintf("%.f", claims["user_id"]) userId, err := strconv.ParseUint(userIdStr, 10, 64) if err != nil { + span.RecordError(err) return nil, err } return &models.AccessDetails{TokenUuid: tokenUuid, UserId: userId}, nil } - return nil, errors.New("invalid token") + err = errors.New("invalid token") + span.RecordError(err) + return nil, err } -func (jwt *jwt) TokenValid(tokenString string, tokenType TokenType) error { +func (jwt *jwt) TokenValid(ctx context.Context, tokenString string, tokenType TokenType) failures.Failure { token, err := jwt.verifyToken(tokenString, tokenType) if err != nil { return err @@ -125,7 +141,7 @@ func (jwt *jwt) TokenValid(tokenString string, tokenType TokenType) error { return nil } -func (jwt *jwt) verifyToken(tokenString string, tokenType TokenType) (*jwtPkg.Token, error) { +func (jwt *jwt) verifyToken(tokenString string, tokenType TokenType) (*jwtPkg.Token, failures.Failure) { var secret string if tokenType == Access { secret = jwt.config.AccessSecret @@ -141,7 +157,7 @@ func (jwt *jwt) verifyToken(tokenString string, tokenType TokenType) (*jwtPkg.To return token, nil } -//checkSigningMethod checks the token method conform to "SigningMethodHMAC" +// checkSigningMethod checks the token method conform to "SigningMethodHMAC" func (jwt *jwt) checkSigningMethod(secret string) jwtPkg.Keyfunc { return func(token *jwtPkg.Token) (interface{}, error) { if _, ok := token.Method.(*jwtPkg.SigningMethodHMAC); !ok {