Skip to content

Commit

Permalink
Merge pull request #330 from AzureAD/release-0.5.3
Browse files Browse the repository at this point in the history
Fix silent auth scopes and refresh behavior (#327)
  • Loading branch information
siddhijain authored Jul 18, 2022
2 parents de1c8ae + 26c6116 commit f149faa
Show file tree
Hide file tree
Showing 7 changed files with 201 additions and 18 deletions.
4 changes: 2 additions & 2 deletions apps/internal/base/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -262,15 +262,15 @@ func (b Client) AcquireTokenSilent(ctx context.Context, silent AcquireTokenSilen
result, err := AuthResultFromStorage(storageTokenResponse)
if err != nil {
if reflect.ValueOf(storageTokenResponse.RefreshToken).IsZero() {
return AuthResult{}, errors.New("no refresh token found")
return AuthResult{}, errors.New("no token found")
}

var cc *accesstokens.Credential
if silent.RequestType == accesstokens.ATConfidential {
cc = silent.Credential
}

token, err := b.Token.Refresh(ctx, silent.RequestType, b.AuthParams, cc, storageTokenResponse.RefreshToken)
token, err := b.Token.Refresh(ctx, silent.RequestType, authParams, cc, storageTokenResponse.RefreshToken)
if err != nil {
return AuthResult{}, err
}
Expand Down
179 changes: 178 additions & 1 deletion apps/internal/base/base_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,194 @@
package base

import (
"context"
"fmt"
"reflect"
"testing"
"time"

"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/base/internal/storage"
internalTime "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/json/types/time"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/fake"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/accesstokens"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/authority"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/shared"

"github.com/kylelemons/godebug/pretty"
)

const (
fakeAccessToken = "fake-access-token"
fakeAuthority = "fake_authority"
fakeClientID = "fake-client-id"
fakeRefreshToken = "fake-refresh-token"
fakeTenantID = "fake-tenant-id"
fakeUsername = "fake-username"
)

var (
fakeIDToken = accesstokens.IDToken{
Oid: "oid",
PreferredUsername: fakeUsername,
RawToken: "x.e30",
TenantID: fakeTenantID,
UPN: fakeUsername,
}
testScopes = []string{"scope"}
)

func fakeClient(t *testing.T) Client {
client, err := New(fakeClientID, fmt.Sprintf("https://%s/%s", fakeAuthority, fakeTenantID), &oauth.Client{})
if err != nil {
t.Fatal(err)
}
client.Token.AccessTokens = &fake.AccessTokens{
AccessToken: accesstokens.TokenResponse{
AccessToken: fakeAccessToken,
ExpiresOn: internalTime.DurationTime{T: time.Now().Add(time.Hour)},
FamilyID: "family-id",
GrantedScopes: accesstokens.Scopes{Slice: testScopes},
IDToken: fakeIDToken,
RefreshToken: fakeRefreshToken,
},
}
client.Token.Authority = &fake.Authority{
InstanceResp: authority.InstanceDiscoveryResponse{
Metadata: []authority.InstanceDiscoveryMetadata{
{Aliases: []string{fakeAuthority}, PreferredNetwork: fakeAuthority},
},
TenantDiscoveryEndpoint: fmt.Sprintf("https://%s/fake/discovery/endpoint", fakeAuthority),
},
}
client.Token.Resolver = &fake.ResolveEndpoints{
Endpoints: authority.NewEndpoints(
fmt.Sprintf("https://%s/fake/auth", fakeAuthority),
fmt.Sprintf("https://%s/fake/token", fakeAuthority),
fmt.Sprintf("https://%s/fake/jwt", fakeAuthority),
fakeAuthority,
),
}
return client
}

func TestAcquireTokenSilentEmptyCache(t *testing.T) {
client := fakeClient(t)
_, err := client.AcquireTokenSilent(context.Background(), AcquireTokenSilentParameters{
Account: shared.NewAccount("homeAccountID", "env", "realm", "localAccountID", authority.AAD, "username"),
Scopes: testScopes,
})
if err == nil {
t.Fatal("expected an error because the cache is empty")
}
}

func TestAcquireTokenSilentScopes(t *testing.T) {
// ensure fakeIDToken.RawToken unmarshals (doesn't matter to what) because an unmarshalling
// error can conceal a test bug by making an "err != nil" check true for the wrong reason
var idToken accesstokens.IDToken
if err := idToken.UnmarshalJSON([]byte(fakeIDToken.RawToken)); err != nil {
t.Fatal(err)
}
for _, test := range []struct {
desc string
cachedTokenScopes []string
}{
{"expired access token", testScopes},
{"no access token", []string{"other-" + testScopes[0]}},
} {
t.Run(test.desc, func(t *testing.T) {
client := fakeClient(t)
validated := false
client.Token.AccessTokens.(*fake.AccessTokens).FromRefreshTokenCallback = func(at accesstokens.AppType, ap authority.AuthParams, cc *accesstokens.Credential, rt string) {
validated = true
if !reflect.DeepEqual(ap.Scopes, testScopes) {
t.Fatalf("unexpected scopes: %v", ap.Scopes)
}
if cc != nil {
t.Fatal("client shouldn't have a credential")
}
if rt != fakeRefreshToken {
t.Fatal("unexpected refresh token")
}
}

// cache a refresh token and an expired access token for the given scopes
// (testing only the public client code path)
storage.FakeValidate = func(storage.AccessToken) error { return nil }
account, err := client.manager.Write(
authority.AuthParams{
AuthorityInfo: authority.Info{
AuthorityType: authority.AAD,
Host: fakeAuthority,
Tenant: fakeIDToken.TenantID,
},
ClientID: fakeClientID,
Scopes: test.cachedTokenScopes,
Username: fakeIDToken.PreferredUsername,
},
accesstokens.TokenResponse{
AccessToken: fakeAccessToken,
ExpiresOn: internalTime.DurationTime{T: time.Now().Add(-time.Hour)},
GrantedScopes: accesstokens.Scopes{Slice: test.cachedTokenScopes},
IDToken: fakeIDToken,
RefreshToken: fakeRefreshToken,
},
)
storage.FakeValidate = nil
if err != nil {
t.Fatal(err)
}

// AcquireTokenSilent should redeem the refresh token for a new access token
ar, err := client.AcquireTokenSilent(context.Background(), AcquireTokenSilentParameters{Account: account, Scopes: testScopes})
if err != nil {
t.Fatal(err)
}
if ar.AccessToken != fakeAccessToken {
t.Fatal("unexpected access token")
}
if !validated {
t.Fatal("FromRefreshTokenCallback wasn't called")
}
})
}
}

func TestAcquireTokenSilentGrantedScopes(t *testing.T) {
client := fakeClient(t)
grantedScopes := []string{"scope1", "scope2"}
expectedToken := "not-" + fakeAccessToken
account, err := client.manager.Write(
authority.AuthParams{
AuthorityInfo: authority.Info{
AuthorityType: authority.AAD,
Host: fakeAuthority,
Tenant: fakeIDToken.TenantID,
},
ClientID: fakeClientID,
Scopes: grantedScopes[1:],
},
accesstokens.TokenResponse{
AccessToken: expectedToken,
ExpiresOn: internalTime.DurationTime{T: time.Now().Add(time.Hour)},
GrantedScopes: accesstokens.Scopes{Slice: grantedScopes},
},
)
if err != nil {
t.Fatal(err)
}

for _, scope := range grantedScopes {
ar, err := client.AcquireTokenSilent(context.Background(), AcquireTokenSilentParameters{Account: account, Scopes: []string{scope}})
if err != nil {
t.Fatal(err)
}
if ar.AccessToken != expectedToken {
t.Fatal("unexpected access token")
}
}
}

func TestCreateAuthenticationResult(t *testing.T) {
future := time.Now().Add(400 * time.Second)

Expand Down
6 changes: 6 additions & 0 deletions apps/internal/base/internal/storage/items.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,14 @@ func (a AccessToken) Key() string {
)
}

// FakeValidate enables tests to fake access token validation
var FakeValidate func(AccessToken) error

// Validate validates that this AccessToken can be used.
func (a AccessToken) Validate() error {
if FakeValidate != nil {
return FakeValidate(a)
}
if a.CachedAt.T.After(time.Now()) {
return errors.New("access token isn't valid, it was cached at a future time")
}
Expand Down
11 changes: 4 additions & 7 deletions apps/internal/base/internal/storage/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,10 +94,7 @@ func (m *Manager) Read(ctx context.Context, authParameters authority.AuthParams,
return TokenResponse{}, err
}

accessToken, err := m.readAccessToken(homeAccountID, metadata.Aliases, realm, clientID, scopes)
if err != nil {
return TokenResponse{}, err
}
accessToken := m.readAccessToken(homeAccountID, metadata.Aliases, realm, clientID, scopes)

if account.IsZero() {
return TokenResponse{
Expand Down Expand Up @@ -249,7 +246,7 @@ func (m *Manager) aadMetadata(ctx context.Context, authorityInfo authority.Info)
return m.aadCache[authorityInfo.Host], nil
}

func (m *Manager) readAccessToken(homeID string, envAliases []string, realm, clientID string, scopes []string) (AccessToken, error) {
func (m *Manager) readAccessToken(homeID string, envAliases []string, realm, clientID string, scopes []string) AccessToken {
m.contractMu.RLock()
defer m.contractMu.RUnlock()
// TODO: linear search (over a map no less) is slow for a large number (thousands) of tokens.
Expand All @@ -259,12 +256,12 @@ func (m *Manager) readAccessToken(homeID string, envAliases []string, realm, cli
if at.HomeAccountID == homeID && at.Realm == realm && at.ClientID == clientID {
if checkAlias(at.Environment, envAliases) {
if isMatchingScopes(scopes, at.Scopes) {
return at, nil
return at
}
}
}
}
return AccessToken{}, fmt.Errorf("access token not found")
return AccessToken{}
}

func (m *Manager) writeAccessToken(accessToken AccessToken) error {
Expand Down
11 changes: 4 additions & 7 deletions apps/internal/base/internal/storage/storage_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -136,28 +136,25 @@ func TestReadAccessToken(t *testing.T) {
storageManager := newForTest(nil)
storageManager.update(cache)

retAccessToken, err := storageManager.readAccessToken(
retAccessToken := storageManager.readAccessToken(
"hid",
[]string{"hello", "env", "test"},
"realm",
"cid",
[]string{"user.read", "openid"},
)
if err != nil {
t.Errorf("readAccessToken(): got err == %s, want err == nil", err)
}
if diff := pretty.Compare(testAccessToken, retAccessToken); diff != "" {
t.Fatalf("Returned access token is not the same as expected access token: -want/+got:\n%s", diff)
}
_, err = storageManager.readAccessToken(
retAccessToken = storageManager.readAccessToken(
"this_should_break_it",
[]string{"hello", "env", "test"},
"realm",
"cid",
[]string{"user.read", "openid"},
)
if err == nil {
t.Errorf("readAccessToken(): got err == nil, want err != nil")
if !reflect.ValueOf(retAccessToken).IsZero() {
t.Fatal("expected to find no access token")
}
}

Expand Down
6 changes: 6 additions & 0 deletions apps/internal/oauth/fake/fake.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ type AccessTokens struct {

// fake result to return
DeviceCode accesstokens.DeviceCodeResult

// FromRefreshTokenCallback is an optional callback invoked by FromRefreshToken
FromRefreshTokenCallback func(appType accesstokens.AppType, authParams authority.AuthParams, cc *accesstokens.Credential, refreshToken string)
}

func (f *AccessTokens) FromUsernamePassword(ctx context.Context, authParameters authority.AuthParams) (accesstokens.TokenResponse, error) {
Expand All @@ -62,6 +65,9 @@ func (f *AccessTokens) FromAuthCode(ctx context.Context, req accesstokens.AuthCo
return f.AccessToken, nil
}
func (f *AccessTokens) FromRefreshToken(ctx context.Context, appType accesstokens.AppType, authParams authority.AuthParams, cc *accesstokens.Credential, refreshToken string) (accesstokens.TokenResponse, error) {
if f.FromRefreshTokenCallback != nil {
f.FromRefreshTokenCallback(appType, authParams, cc, refreshToken)
}
if f.Err {
return accesstokens.TokenResponse{}, fmt.Errorf("error")
}
Expand Down
2 changes: 1 addition & 1 deletion apps/internal/version/version.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@
package version

// Version is the version of this client package that is communicated to the server.
const Version = "0.5.2"
const Version = "0.5.3"

0 comments on commit f149faa

Please sign in to comment.