Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[management] Refactor users to use store methods #2917

Draft
wants to merge 9 commits into
base: peers-get-account-refactoring
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 59 additions & 84 deletions management/server/account.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ type AccountManager interface {
GetAccountIDByUserID(ctx context.Context, userID, domain string) (string, error)
GetAccountIDFromToken(ctx context.Context, claims jwtclaims.AuthorizationClaims) (string, string, error)
CheckUserAccessByJWTGroups(ctx context.Context, claims jwtclaims.AuthorizationClaims) error
GetAccountFromPAT(ctx context.Context, pat string) (*Account, *User, *PersonalAccessToken, error)
GetAccountInfoFromPAT(ctx context.Context, token string) (*User, *PersonalAccessToken, string, string, error)
DeleteAccount(ctx context.Context, accountID, userID string) error
MarkPATUsed(ctx context.Context, tokenID string) error
GetUserByID(ctx context.Context, id string) (*User, error)
Expand Down Expand Up @@ -1363,13 +1363,13 @@ func (am *DefaultAccountManager) DeleteAccount(ctx context.Context, accountID, u
continue
}

deleteUserErr := am.deleteRegularUser(ctx, account, userID, otherUser.Id)
_, deleteUserErr := am.deleteRegularUser(ctx, accountID, userID, otherUser.Id)
if deleteUserErr != nil {
return deleteUserErr
}
}

err = am.deleteRegularUser(ctx, account, userID, userID)
_, err = am.deleteRegularUser(ctx, accountID, userID, userID)
if err != nil {
log.WithContext(ctx).Errorf("failed deleting user %s. error: %s", userID, err)
return err
Expand Down Expand Up @@ -1426,20 +1426,8 @@ func isNil(i idp.Manager) bool {
// addAccountIDToIDPAppMeta update user's app metadata in idp manager
func (am *DefaultAccountManager) addAccountIDToIDPAppMeta(ctx context.Context, userID string, accountID string) error {
if !isNil(am.idpManager) {
accountUsers, err := am.Store.GetAccountUsers(ctx, LockingStrengthShare, accountID)
if err != nil {
return err
}
cachedAccount := &Account{
Id: accountID,
Users: make(map[string]*User),
}
for _, user := range accountUsers {
cachedAccount.Users[user.Id] = user
}

// user can be nil if it wasn't found (e.g., just created)
user, err := am.lookupUserInCache(ctx, userID, cachedAccount)
user, err := am.lookupUserInCache(ctx, userID, accountID)
if err != nil {
return err
}
Expand Down Expand Up @@ -1515,10 +1503,15 @@ func (am *DefaultAccountManager) lookupUserInCacheByEmail(ctx context.Context, e
}

// lookupUserInCache looks up user in the IdP cache and returns it. If the user wasn't found, the function returns nil
func (am *DefaultAccountManager) lookupUserInCache(ctx context.Context, userID string, account *Account) (*idp.UserData, error) {
users := make(map[string]userLoggedInOnce, len(account.Users))
func (am *DefaultAccountManager) lookupUserInCache(ctx context.Context, userID string, accountID string) (*idp.UserData, error) {
accountUsers, err := am.Store.GetAccountUsers(ctx, LockingStrengthShare, accountID)
if err != nil {
return nil, err
}

users := make(map[string]userLoggedInOnce, len(accountUsers))
// ignore service users and users provisioned by integrations than are never logged in
for _, user := range account.Users {
for _, user := range accountUsers {
if user.IsServiceUser {
continue
}
Expand All @@ -1527,8 +1520,8 @@ func (am *DefaultAccountManager) lookupUserInCache(ctx context.Context, userID s
}
users[user.Id] = userLoggedInOnce(!user.LastLogin.IsZero())
}
log.WithContext(ctx).Debugf("looking up user %s of account %s in cache", userID, account.Id)
userData, err := am.lookupCache(ctx, users, account.Id)
log.WithContext(ctx).Debugf("looking up user %s of account %s in cache", userID, accountID)
userData, err := am.lookupCache(ctx, users, accountID)
if err != nil {
return nil, err
}
Expand All @@ -1541,13 +1534,13 @@ func (am *DefaultAccountManager) lookupUserInCache(ctx context.Context, userID s

// add extra check on external cache manager. We may get to this point when the user is not yet findable in IDP,
// or it didn't have its metadata updated with am.addAccountIDToIDPAppMeta
user, err := account.FindUser(userID)
user, err := am.Store.GetUserByUserID(ctx, LockingStrengthShare, userID)
if err != nil {
log.WithContext(ctx).Errorf("failed finding user %s in account %s", userID, account.Id)
log.WithContext(ctx).Errorf("failed finding user %s in account %s", userID, accountID)
return nil, err
}

key := user.IntegrationReference.CacheKey(account.Id, userID)
key := user.IntegrationReference.CacheKey(accountID, userID)
ud, err := am.externalCacheManager.Get(am.ctx, key)
if err != nil {
log.WithContext(ctx).Debugf("failed to get externalCache for key: %s, error: %s", key, err)
Expand Down Expand Up @@ -1787,9 +1780,9 @@ func (am *DefaultAccountManager) addNewUserToDomainAccount(ctx context.Context,
unlockAccount := am.Store.AcquireWriteLockByUID(ctx, domainAccountID)
defer unlockAccount()

usersMap := make(map[string]*User)
usersMap[claims.UserId] = NewRegularUser(claims.UserId)
err := am.Store.SaveUsers(domainAccountID, usersMap)
newUser := NewRegularUser(claims.UserId)
newUser.AccountID = domainAccountID
err := am.Store.SaveUser(ctx, LockingStrengthUpdate, newUser)
if err != nil {
return "", err
}
Expand All @@ -1812,12 +1805,7 @@ func (am *DefaultAccountManager) redeemInvite(ctx context.Context, accountID str
return nil
}

account, err := am.Store.GetAccount(ctx, accountID)
if err != nil {
return err
}

user, err := am.lookupUserInCache(ctx, userID, account)
user, err := am.lookupUserInCache(ctx, userID, accountID)
if err != nil {
return err
}
Expand All @@ -1827,17 +1815,17 @@ func (am *DefaultAccountManager) redeemInvite(ctx context.Context, accountID str
}

if user.AppMetadata.WTPendingInvite != nil && *user.AppMetadata.WTPendingInvite {
log.WithContext(ctx).Infof("redeeming invite for user %s account %s", userID, account.Id)
log.WithContext(ctx).Infof("redeeming invite for user %s account %s", userID, accountID)
// User has already logged in, meaning that IdP should have set wt_pending_invite to false.
// Our job is to just reload cache.
go func() {
_, err = am.refreshCache(ctx, account.Id)
_, err = am.refreshCache(ctx, accountID)
if err != nil {
log.WithContext(ctx).Warnf("failed reloading cache when redeeming user %s under account %s", userID, account.Id)
log.WithContext(ctx).Warnf("failed reloading cache when redeeming user %s under account %s", userID, accountID)
return
}
log.WithContext(ctx).Debugf("user %s of account %s redeemed invite", user.ID, account.Id)
am.StoreEvent(ctx, userID, userID, account.Id, activity.UserJoined, nil)
log.WithContext(ctx).Debugf("user %s of account %s redeemed invite", user.ID, accountID)
am.StoreEvent(ctx, userID, userID, accountID, activity.UserJoined, nil)
}()
}

Expand All @@ -1846,86 +1834,73 @@ func (am *DefaultAccountManager) redeemInvite(ctx context.Context, accountID str

// MarkPATUsed marks a personal access token as used
func (am *DefaultAccountManager) MarkPATUsed(ctx context.Context, tokenID string) error {
return am.Store.MarkPATUsed(ctx, LockingStrengthUpdate, tokenID)
}

user, err := am.Store.GetUserByTokenID(ctx, tokenID)
if err != nil {
return err
}
// GetAccount returns an account associated with this account ID.
func (am *DefaultAccountManager) GetAccount(ctx context.Context, accountID string) (*Account, error) {
return am.Store.GetAccount(ctx, accountID)
}

account, err := am.Store.GetAccountByUser(ctx, user.Id)
// GetAccountInfoFromPAT retrieves user, personal access token, domain, and category details from a personal access token.
func (am *DefaultAccountManager) GetAccountInfoFromPAT(ctx context.Context, token string) (user *User, pat *PersonalAccessToken, domain string, category string, err error) {
user, pat, err = am.extractPATFromToken(ctx, token)
if err != nil {
return err
return nil, nil, "", "", err
}

unlock := am.Store.AcquireWriteLockByUID(ctx, account.Id)
defer unlock()

account, err = am.Store.GetAccountByUser(ctx, user.Id)
domain, category, err = am.Store.GetAccountDomainAndCategory(ctx, LockingStrengthShare, user.AccountID)
if err != nil {
return err
return nil, nil, "", "", err
}

pat, ok := account.Users[user.Id].PATs[tokenID]
if !ok {
return fmt.Errorf("token not found")
}

pat.LastUsed = time.Now().UTC()

return am.Store.SaveAccount(ctx, account)
return user, pat, domain, category, nil
}

// GetAccount returns an account associated with this account ID.
func (am *DefaultAccountManager) GetAccount(ctx context.Context, accountID string) (*Account, error) {
return am.Store.GetAccount(ctx, accountID)
}

// GetAccountFromPAT returns Account and User associated with a personal access token
func (am *DefaultAccountManager) GetAccountFromPAT(ctx context.Context, token string) (*Account, *User, *PersonalAccessToken, error) {
// extractPATFromToken validates the token structure and retrieves associated User and PAT.
func (am *DefaultAccountManager) extractPATFromToken(ctx context.Context, token string) (*User, *PersonalAccessToken, error) {
if len(token) != PATLength {
return nil, nil, nil, fmt.Errorf("token has wrong length")
return nil, nil, fmt.Errorf("token has incorrect length")
}

prefix := token[:len(PATPrefix)]
if prefix != PATPrefix {
return nil, nil, nil, fmt.Errorf("token has wrong prefix")
return nil, nil, fmt.Errorf("token has incorrect prefix")
}

secret := token[len(PATPrefix) : len(PATPrefix)+PATSecretLength]
encodedChecksum := token[len(PATPrefix)+PATSecretLength : len(PATPrefix)+PATSecretLength+PATChecksumLength]

verificationChecksum, err := base62.Decode(encodedChecksum)
if err != nil {
return nil, nil, nil, fmt.Errorf("token checksum decoding failed: %w", err)
return nil, nil, fmt.Errorf("token checksum decoding failed: %w", err)
}

secretChecksum := crc32.ChecksumIEEE([]byte(secret))
if secretChecksum != verificationChecksum {
return nil, nil, nil, fmt.Errorf("token checksum does not match")
return nil, nil, fmt.Errorf("token checksum does not match")
}

hashedToken := sha256.Sum256([]byte(token))
encodedHashedToken := b64.StdEncoding.EncodeToString(hashedToken[:])
tokenID, err := am.Store.GetTokenIDByHashedToken(ctx, encodedHashedToken)
if err != nil {
return nil, nil, nil, err
}

user, err := am.Store.GetUserByTokenID(ctx, tokenID)
if err != nil {
return nil, nil, nil, err
}
var user *User
var pat *PersonalAccessToken

account, err := am.Store.GetAccountByUser(ctx, user.Id)
if err != nil {
return nil, nil, nil, err
}
err = am.Store.ExecuteInTransaction(ctx, func(transaction Store) error {
pat, err = transaction.GetPATByHashedToken(ctx, LockingStrengthShare, encodedHashedToken)
if err != nil {
return err
}

pat := user.PATs[tokenID]
if pat == nil {
return nil, nil, nil, fmt.Errorf("personal access token not found")
user, err = transaction.GetUserByPATID(ctx, LockingStrengthShare, pat.ID)
return err
})
if err != nil {
return nil, nil, err
}

return account, user, pat, nil
return user, pat, nil
}

// GetAccountByID returns an account associated with this account ID.
Expand Down
7 changes: 4 additions & 3 deletions management/server/account_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -758,6 +758,7 @@ func TestAccountManager_GetAccountFromPAT(t *testing.T) {
PATs: map[string]*PersonalAccessToken{
"tokenId": {
ID: "tokenId",
UserID: "someUser",
HashedToken: encodedHashedToken,
},
},
Expand All @@ -771,14 +772,14 @@ func TestAccountManager_GetAccountFromPAT(t *testing.T) {
Store: store,
}

account, user, pat, err := am.GetAccountFromPAT(context.Background(), token)
user, pat, _, _, err := am.GetAccountInfoFromPAT(context.Background(), token)
if err != nil {
t.Fatalf("Error when getting Account from PAT: %s", err)
}

assert.Equal(t, "account_id", account.Id)
assert.Equal(t, "account_id", user.AccountID)
assert.Equal(t, "someUser", user.Id)
assert.Equal(t, account.Users["someUser"].PATs["tokenId"], pat)
assert.Equal(t, account.Users["someUser"].PATs["tokenId"].ID, pat.ID)
}

func TestDefaultAccountManager_MarkPATUsed(t *testing.T) {
Expand Down
2 changes: 1 addition & 1 deletion management/server/http/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func APIHandler(ctx context.Context, accountManager s.AccountManager, LocationMa
)

authMiddleware := middleware.NewAuthMiddleware(
accountManager.GetAccountFromPAT,
accountManager.GetAccountInfoFromPAT,
jwtValidator.ValidateAndParse,
accountManager.MarkPATUsed,
accountManager.CheckUserAccessByJWTGroups,
Expand Down
22 changes: 10 additions & 12 deletions management/server/http/middleware/auth_middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ import (
"github.com/netbirdio/netbird/management/server/status"
)

// GetAccountFromPATFunc function
type GetAccountFromPATFunc func(ctx context.Context, token string) (*server.Account, *server.User, *server.PersonalAccessToken, error)
// GetAccountInfoFromPATFunc function
type GetAccountInfoFromPATFunc func(ctx context.Context, token string) (user *server.User, pat *server.PersonalAccessToken, domain string, category string, err error)

// ValidateAndParseTokenFunc function
type ValidateAndParseTokenFunc func(ctx context.Context, token string) (*jwt.Token, error)
Expand All @@ -33,7 +33,7 @@ type CheckUserAccessByJWTGroupsFunc func(ctx context.Context, claims jwtclaims.A

// AuthMiddleware middleware to verify personal access tokens (PAT) and JWT tokens
type AuthMiddleware struct {
getAccountFromPAT GetAccountFromPATFunc
getAccountInfoFromPAT GetAccountInfoFromPATFunc
validateAndParseToken ValidateAndParseTokenFunc
markPATUsed MarkPATUsedFunc
checkUserAccessByJWTGroups CheckUserAccessByJWTGroupsFunc
Expand All @@ -47,15 +47,15 @@ const (
)

// NewAuthMiddleware instance constructor
func NewAuthMiddleware(getAccountFromPAT GetAccountFromPATFunc, validateAndParseToken ValidateAndParseTokenFunc,
func NewAuthMiddleware(getAccountInfoFromPAT GetAccountInfoFromPATFunc, validateAndParseToken ValidateAndParseTokenFunc,
markPATUsed MarkPATUsedFunc, checkUserAccessByJWTGroups CheckUserAccessByJWTGroupsFunc, claimsExtractor *jwtclaims.ClaimsExtractor,
audience string, userIdClaim string) *AuthMiddleware {
if userIdClaim == "" {
userIdClaim = jwtclaims.UserIDClaim
}

return &AuthMiddleware{
getAccountFromPAT: getAccountFromPAT,
getAccountInfoFromPAT: getAccountInfoFromPAT,
validateAndParseToken: validateAndParseToken,
markPATUsed: markPATUsed,
checkUserAccessByJWTGroups: checkUserAccessByJWTGroups,
Expand Down Expand Up @@ -151,13 +151,11 @@ func (m *AuthMiddleware) verifyUserAccess(ctx context.Context, validatedToken *j
// CheckPATFromRequest checks if the PAT is valid
func (m *AuthMiddleware) checkPATFromRequest(w http.ResponseWriter, r *http.Request, auth []string) error {
token, err := getTokenFromPATRequest(auth)

// If an error occurs, call the error handler and return an error
if err != nil {
return fmt.Errorf("Error extracting token: %w", err)
return fmt.Errorf("error extracting token: %w", err)
}

account, user, pat, err := m.getAccountFromPAT(r.Context(), token)
user, pat, accDomain, accCategory, err := m.getAccountInfoFromPAT(r.Context(), token)
if err != nil {
return fmt.Errorf("invalid Token: %w", err)
}
Expand All @@ -172,9 +170,9 @@ func (m *AuthMiddleware) checkPATFromRequest(w http.ResponseWriter, r *http.Requ

claimMaps := jwt.MapClaims{}
claimMaps[m.userIDClaim] = user.Id
claimMaps[m.audience+jwtclaims.AccountIDSuffix] = account.Id
claimMaps[m.audience+jwtclaims.DomainIDSuffix] = account.Domain
claimMaps[m.audience+jwtclaims.DomainCategorySuffix] = account.DomainCategory
claimMaps[m.audience+jwtclaims.AccountIDSuffix] = user.AccountID
claimMaps[m.audience+jwtclaims.DomainIDSuffix] = accDomain
claimMaps[m.audience+jwtclaims.DomainCategorySuffix] = accCategory
jwtToken := jwt.NewWithClaims(jwt.SigningMethodHS256, claimMaps)
newRequest := r.WithContext(context.WithValue(r.Context(), jwtclaims.TokenUserProperty, jwtToken)) //nolint
// Update the current request with the new context information.
Expand Down
11 changes: 6 additions & 5 deletions management/server/http/middleware/auth_middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ var testAccount = &server.Account{
Domain: domain,
Users: map[string]*server.User{
userID: {
Id: userID,
Id: userID,
AccountID: accountID,
PATs: map[string]*server.PersonalAccessToken{
tokenID: {
ID: tokenID,
Expand All @@ -49,11 +50,11 @@ var testAccount = &server.Account{
},
}

func mockGetAccountFromPAT(_ context.Context, token string) (*server.Account, *server.User, *server.PersonalAccessToken, error) {
func mockGetAccountInfoFromPAT(_ context.Context, token string) (user *server.User, pat *server.PersonalAccessToken, domain string, category string, err error) {
if token == PAT {
return testAccount, testAccount.Users[userID], testAccount.Users[userID].PATs[tokenID], nil
return testAccount.Users[userID], testAccount.Users[userID].PATs[tokenID], testAccount.Domain, testAccount.DomainCategory, nil
}
return nil, nil, nil, fmt.Errorf("PAT invalid")
return nil, nil, "", "", fmt.Errorf("PAT invalid")
}

func mockValidateAndParseToken(_ context.Context, token string) (*jwt.Token, error) {
Expand Down Expand Up @@ -165,7 +166,7 @@ func TestAuthMiddleware_Handler(t *testing.T) {
)

authMiddleware := NewAuthMiddleware(
mockGetAccountFromPAT,
mockGetAccountInfoFromPAT,
mockValidateAndParseToken,
mockMarkPATUsed,
mockCheckUserAccessByJWTGroups,
Expand Down
Loading
Loading