From 1357b1dea6463af15fb2211336dc115b13f35e16 Mon Sep 17 00:00:00 2001 From: Matt Ellis Date: Mon, 10 Jul 2023 08:46:52 -0700 Subject: [PATCH] Add WithOpenURL option for AcquireTokenInteractive (#422) --- apps/public/public.go | 42 +++++++++++++++++++++++++++++--------- apps/public/public_test.go | 38 ++++++++-------------------------- 2 files changed, 40 insertions(+), 40 deletions(-) diff --git a/apps/public/public.go b/apps/public/public.go index 7da3ff6e..58b8d250 100644 --- a/apps/public/public.go +++ b/apps/public/public.go @@ -480,6 +480,7 @@ func (pca Client) RemoveAccount(ctx context.Context, account Account) error { // interactiveAuthOptions contains the optional parameters used to acquire an access token for interactive auth code flow. type interactiveAuthOptions struct { claims, domainHint, loginHint, redirectURI, tenantID string + openURL func(url string) error } // AcquireInteractiveOption is implemented by options for AcquireTokenInteractive @@ -565,10 +566,33 @@ func WithRedirectURI(redirectURI string) interface { } } +// WithOpenURL allows you to provide a function to open the browser to complete the interactive login, instead of launching the system default browser. +func WithOpenURL(openURL func(url string) error) interface { + AcquireInteractiveOption + options.CallOption +} { + return struct { + AcquireInteractiveOption + options.CallOption + }{ + CallOption: options.NewCallOption( + func(a any) error { + switch t := a.(type) { + case *interactiveAuthOptions: + t.openURL = openURL + default: + return fmt.Errorf("unexpected options type %T", a) + } + return nil + }, + ), + } +} + // AcquireTokenInteractive acquires a security token from the authority using the default web browser to select the account. // https://docs.microsoft.com/en-us/azure/active-directory/develop/msal-authentication-flows#interactive-and-non-interactive-authentication // -// Options: [WithDomainHint], [WithLoginHint], [WithRedirectURI], [WithTenantID] +// Options: [WithDomainHint], [WithLoginHint], [WithOpenURL], [WithRedirectURI], [WithTenantID] func (pca Client) AcquireTokenInteractive(ctx context.Context, scopes []string, opts ...AcquireInteractiveOption) (AuthResult, error) { o := interactiveAuthOptions{} if err := options.ApplyOptions(&o, opts); err != nil { @@ -587,6 +611,9 @@ func (pca Client) AcquireTokenInteractive(ctx context.Context, scopes []string, return AuthResult{}, err } } + if o.openURL == nil { + o.openURL = browser.OpenURL + } authParams, err := pca.base.AuthParams.WithTenant(o.tenantID) if err != nil { return AuthResult{}, err @@ -600,7 +627,7 @@ func (pca Client) AcquireTokenInteractive(ctx context.Context, scopes []string, authParams.DomainHint = o.domainHint authParams.State = uuid.New().String() authParams.Prompt = "select_account" - res, err := pca.browserLogin(ctx, redirectURL, authParams) + res, err := pca.browserLogin(ctx, redirectURL, authParams, o.openURL) if err != nil { return AuthResult{}, err } @@ -624,11 +651,6 @@ type interactiveAuthResult struct { redirectURI string } -// provides a test hook to simulate opening a browser -var browserOpenURL = func(authURL string) error { - return browser.OpenURL(authURL) -} - // parses the port number from the provided URL. // returns 0 if nil or no port is specified. func parsePort(u *url.URL) (int, error) { @@ -642,8 +664,8 @@ func parsePort(u *url.URL) (int, error) { return strconv.Atoi(p) } -// browserLogin launches the system browser for interactive login -func (pca Client) browserLogin(ctx context.Context, redirectURI *url.URL, params authority.AuthParams) (interactiveAuthResult, error) { +// browserLogin calls openURL and waits for a user to log in +func (pca Client) browserLogin(ctx context.Context, redirectURI *url.URL, params authority.AuthParams, openURL func(string) error) (interactiveAuthResult, error) { // start local redirect server so login can call us back port, err := parsePort(redirectURI) if err != nil { @@ -660,7 +682,7 @@ func (pca Client) browserLogin(ctx context.Context, redirectURI *url.URL, params return interactiveAuthResult{}, err } // open browser window so user can select credentials - if err := browserOpenURL(authURL); err != nil { + if err := openURL(authURL); err != nil { return interactiveAuthResult{}, err } // now wait until the logic calls us back diff --git a/apps/public/public_test.go b/apps/public/public_test.go index d3ccb08c..246f276e 100644 --- a/apps/public/public_test.go +++ b/apps/public/public_test.go @@ -62,9 +62,6 @@ func fakeBrowserOpenURL(authURL string) error { } func TestAcquireTokenInteractive(t *testing.T) { - realBrowserOpenURL := browserOpenURL - defer func() { browserOpenURL = realBrowserOpenURL }() - browserOpenURL = fakeBrowserOpenURL client, err := New("some_client_id") if err != nil { t.Fatal(err) @@ -73,7 +70,7 @@ func TestAcquireTokenInteractive(t *testing.T) { client.base.Token.Authority = &fake.Authority{} client.base.Token.Resolver = &fake.ResolveEndpoints{} client.base.Token.WSTrust = &fake.WSTrust{} - _, err = client.AcquireTokenInteractive(context.Background(), []string{"the_scope"}) + _, err = client.AcquireTokenInteractive(context.Background(), []string{"the_scope"}, WithOpenURL(fakeBrowserOpenURL)) if err != nil { t.Fatal(err) } @@ -198,11 +195,6 @@ func TestAcquireTokenSilentWithoutAccount(t *testing.T) { } func TestAcquireTokenWithTenantID(t *testing.T) { - // replacing browserOpenURL with a fake for the duration of this test enables testing AcquireTokenInteractive - realBrowserOpenURL := browserOpenURL - defer func() { browserOpenURL = realBrowserOpenURL }() - browserOpenURL = fakeBrowserOpenURL - accessToken := "*" clientInfo := base64.RawStdEncoding.EncodeToString([]byte(`{"uid":"uid","utid":"utid"}`)) uuid1 := "00000000-0000-0000-0000-000000000000" @@ -255,7 +247,7 @@ func TestAcquireTokenWithTenantID(t *testing.T) { case "devicecode": dc, err = client.AcquireTokenByDeviceCode(ctx, tokenScope, WithTenantID(test.tenant)) case "interactive": - ar, err = client.AcquireTokenInteractive(ctx, tokenScope, WithTenantID(test.tenant)) + ar, err = client.AcquireTokenInteractive(ctx, tokenScope, WithTenantID(test.tenant), WithOpenURL(fakeBrowserOpenURL)) case "password": ar, err = client.AcquireTokenByUsernamePassword(ctx, tokenScope, "username", "password", WithTenantID(test.tenant)) default: @@ -304,11 +296,6 @@ func TestAcquireTokenWithTenantID(t *testing.T) { } func TestWithInstanceDiscovery(t *testing.T) { - // replacing browserOpenURL with a fake for the duration of this test enables testing AcquireTokenInteractive - realBrowserOpenURL := browserOpenURL - defer func() { browserOpenURL = realBrowserOpenURL }() - browserOpenURL = fakeBrowserOpenURL - accessToken := "*" clientInfo := base64.RawStdEncoding.EncodeToString([]byte(`{"uid":"uid","utid":"utid"}`)) host := "stack.local" @@ -346,7 +333,7 @@ func TestWithInstanceDiscovery(t *testing.T) { case "devicecode": dc, err = client.AcquireTokenByDeviceCode(ctx, tokenScope) case "interactive": - ar, err = client.AcquireTokenInteractive(ctx, tokenScope) + ar, err = client.AcquireTokenInteractive(ctx, tokenScope, WithOpenURL(fakeBrowserOpenURL)) case "password": ar, err = client.AcquireTokenByUsernamePassword(ctx, tokenScope, "username", "password") default: @@ -460,11 +447,6 @@ func TestWithCache(t *testing.T) { } func TestWithClaims(t *testing.T) { - // replacing browserOpenURL with a fake for the duration of this test enables testing AcquireTokenInteractive - realBrowserOpenURL := browserOpenURL - defer func() { browserOpenURL = realBrowserOpenURL }() - browserOpenURL = fakeBrowserOpenURL - clientInfo := base64.RawStdEncoding.EncodeToString([]byte(`{"uid":"uid","utid":"utid"}`)) lmo, tenant := "login.microsoftonline.com", "tenant" authority := fmt.Sprintf(authorityFmt, lmo, tenant) @@ -562,7 +544,7 @@ func TestWithClaims(t *testing.T) { case "devicecode": dc, err = client.AcquireTokenByDeviceCode(ctx, tokenScope, WithClaims(test.claims)) case "interactive": - ar, err = client.AcquireTokenInteractive(ctx, tokenScope, WithClaims(test.claims)) + ar, err = client.AcquireTokenInteractive(ctx, tokenScope, WithClaims(test.claims), WithOpenURL(fakeBrowserOpenURL)) case "password": ar, err = client.AcquireTokenByUsernamePassword(ctx, tokenScope, "username", "password", WithClaims(test.claims)) case "passwordFederated": @@ -662,8 +644,6 @@ func TestWithPortAuthority(t *testing.T) { } func TestWithLoginHint(t *testing.T) { - realBrowserOpenURL := browserOpenURL - defer func() { browserOpenURL = realBrowserOpenURL }() upn := "user@localhost" client, err := New("client-id") if err != nil { @@ -690,7 +670,7 @@ func TestWithLoginHint(t *testing.T) { } return err } - browserOpenURL = func(authURL string) error { + browserOpenURL := func(authURL string) error { called = true parsed, err := url.Parse(authURL) if err != nil { @@ -707,7 +687,7 @@ func TestWithLoginHint(t *testing.T) { // this helper validates the other params and completes the redirect return fakeBrowserOpenURL(authURL) } - acquireOpts := []AcquireInteractiveOption{} + acquireOpts := []AcquireInteractiveOption{WithOpenURL(browserOpenURL)} urlOpts := []AuthCodeURLOption{} if expectHint { acquireOpts = append(acquireOpts, WithLoginHint(upn)) @@ -736,8 +716,6 @@ func TestWithLoginHint(t *testing.T) { } func TestWithDomainHint(t *testing.T) { - realBrowserOpenURL := browserOpenURL - defer func() { browserOpenURL = realBrowserOpenURL }() domain := "contoso.com" client, err := New("client-id") if err != nil { @@ -764,7 +742,7 @@ func TestWithDomainHint(t *testing.T) { } return err } - browserOpenURL = func(authURL string) error { + browserOpenURL := func(authURL string) error { called = true parsed, err := url.Parse(authURL) if err != nil { @@ -781,7 +759,7 @@ func TestWithDomainHint(t *testing.T) { // this helper validates the other params and completes the redirect return fakeBrowserOpenURL(authURL) } - var acquireOpts []AcquireInteractiveOption + acquireOpts := []AcquireInteractiveOption{WithOpenURL(browserOpenURL)} var urlOpts []AuthCodeURLOption if expectHint { acquireOpts = append(acquireOpts, WithDomainHint(domain))