Skip to content

Commit

Permalink
fix: use auth DetectDefault over oauth2 FindDefaultCredentials (#909)
Browse files Browse the repository at this point in the history
The new Google auth library should be used over the old oauth2 library.

- DetectDefault should be used over the old FindDefaultCredentials to
source ADC from the environment.
- TokenProvider should be used over the old TokenSource
- Use auth library's httptransport.NewClient to gain advantage of built-in universe domain checks

This will help fix certain non-GDU paths as FindDefaultCredentials does not support
self-signed JWTs.
  • Loading branch information
jackwotherspoon authored Jan 14, 2025
1 parent b5c249b commit 52fef27
Show file tree
Hide file tree
Showing 9 changed files with 136 additions and 233 deletions.
105 changes: 66 additions & 39 deletions dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,20 +24,22 @@ import (
"fmt"
"io"
"net"
"os"
"strings"
"sync"
"sync/atomic"
"time"

"cloud.google.com/go/auth"
"cloud.google.com/go/auth/credentials"
"cloud.google.com/go/auth/httptransport"
"cloud.google.com/go/cloudsqlconn/debug"
"cloud.google.com/go/cloudsqlconn/errtype"
"cloud.google.com/go/cloudsqlconn/instance"
"cloud.google.com/go/cloudsqlconn/internal/cloudsql"
"cloud.google.com/go/cloudsqlconn/internal/trace"
"github.com/google/uuid"
"golang.org/x/net/proxy"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google"
"google.golang.org/api/option"
sqladmin "google.golang.org/api/sqladmin/v1beta4"
)
Expand All @@ -50,6 +52,12 @@ const (
// iamLoginScope is the OAuth2 scope used for tokens embedded in the ephemeral
// certificate.
iamLoginScope = "https://www.googleapis.com/auth/sqlservice.login"
// universeDomainEnvVar is the environment variable for setting the default
// service domain for a given Cloud universe.
universeDomainEnvVar = "GOOGLE_CLOUD_UNIVERSE_DOMAIN"
// defaultUniverseDomain is the default value for universe domain.
// Universe domain is the default service domain for a given Cloud universe.
defaultUniverseDomain = "googleapis.com"
)

var (
Expand Down Expand Up @@ -117,6 +125,25 @@ type cacheKey struct {
name string
}

// getClientUniverseDomain returns the default service domain for a given Cloud
// universe, with the following precedence:
//
// 1. A non-empty option.WithUniverseDomain or similar client option.
// 2. A non-empty environment variable GOOGLE_CLOUD_UNIVERSE_DOMAIN.
// 3. The default value "googleapis.com".
//
// This is the universe domain configured for the client, which will be compared
// to the universe domain that is separately configured for the credentials.
func (c *dialerConfig) getClientUniverseDomain() string {
if c.clientUniverseDomain != "" {
return c.clientUniverseDomain
}
if envUD := os.Getenv(universeDomainEnvVar); envUD != "" {
return envUD
}
return defaultUniverseDomain
}

// A Dialer is used to create connections to Cloud SQL instances.
//
// Use NewDialer to initialize a Dialer.
Expand Down Expand Up @@ -150,8 +177,8 @@ type Dialer struct {
// network. By default, it is golang.org/x/net/proxy#Dial.
dialFunc func(cxt context.Context, network, addr string) (net.Conn, error)

// iamTokenSource supplies the OAuth2 token used for IAM DB Authn.
iamTokenSource oauth2.TokenSource
// iamTokenProvider supplies the OAuth2 token used for IAM DB Authn.
iamTokenProvider auth.TokenProvider

// resolver converts instance names into DNS names.
resolver instance.ConnectionNameResolver
Expand All @@ -174,12 +201,11 @@ func (nullLogger) Debugf(_ context.Context, _ string, _ ...interface{}) {}
// RSA keypair is generated will be faster.
func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) {
cfg := &dialerConfig{
refreshTimeout: cloudsql.RefreshTimeout,
dialFunc: proxy.Dial,
logger: nullLogger{},
useragents: []string{userAgent},
serviceUniverse: "googleapis.com",
failoverPeriod: cloudsql.FailoverPeriod,
refreshTimeout: cloudsql.RefreshTimeout,
dialFunc: proxy.Dial,
logger: nullLogger{},
useragents: []string{userAgent},
failoverPeriod: cloudsql.FailoverPeriod,
}
for _, opt := range opts {
opt(cfg)
Expand All @@ -197,40 +223,41 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) {
// Add this to the end to make sure it's not overridden
cfg.sqladminOpts = append(cfg.sqladminOpts, option.WithUserAgent(strings.Join(cfg.useragents, " ")))

// If callers have not provided a token source, either explicitly with
// WithTokenSource or implicitly with WithCredentialsJSON etc., then use the
// default token source.
// If callers have not provided a credential source, either explicitly with
// WithTokenSource or implicitly with WithCredentialsJSON etc., then use
// default credentials
if !cfg.setCredentials {
c, err := google.FindDefaultCredentials(ctx, sqladmin.SqlserviceAdminScope)
c, err := credentials.DetectDefault(&credentials.DetectOptions{
Scopes: []string{sqladmin.SqlserviceAdminScope},
})
if err != nil {
return nil, fmt.Errorf("failed to create default credentials: %v", err)
}
ud, err := c.GetUniverseDomain()
if err != nil {
return nil, fmt.Errorf("failed to get universe domain: %v", err)
}
cfg.credentialsUniverse = ud
cfg.sqladminOpts = append(cfg.sqladminOpts, option.WithTokenSource(c.TokenSource))
scoped, err := google.DefaultTokenSource(ctx, iamLoginScope)
cfg.authCredentials = c
// create second set of credentials, scoped for IAM AuthN login only
scoped, err := credentials.DetectDefault(&credentials.DetectOptions{
Scopes: []string{iamLoginScope},
})
if err != nil {
return nil, fmt.Errorf("failed to create scoped token source: %v", err)
return nil, fmt.Errorf("failed to create scoped credentials: %v", err)
}
cfg.iamLoginTokenSource = scoped
}

if cfg.setUniverseDomain && cfg.setAdminAPIEndpoint {
return nil, errors.New(
"can not use WithAdminAPIEndpoint and WithUniverseDomain Options together, " +
"use WithAdminAPIEndpoint (it already contains the universe domain)",
)
cfg.iamLoginTokenProvider = scoped.TokenProvider
}

if cfg.credentialsUniverse != "" && cfg.serviceUniverse != "" {
if cfg.credentialsUniverse != cfg.serviceUniverse {
return nil, fmt.Errorf(
"the configured service universe domain (%s) does not match the credential universe domain (%s)",
cfg.serviceUniverse, cfg.credentialsUniverse,
)
// For all credential paths, use auth library's built-in
// httptransport.NewClient
if cfg.authCredentials != nil {
authClient, err := httptransport.NewClient(&httptransport.Options{
Credentials: cfg.authCredentials,
UniverseDomain: cfg.getClientUniverseDomain(),
})
if err != nil {
return nil, fmt.Errorf("failed to create auth client: %v", err)
}
// If callers have not provided an HTTPClient explicitly with
// WithHTTPClient, then use auth client
if !cfg.setHTTPClient {
cfg.sqladminOpts = append(cfg.sqladminOpts, option.WithHTTPClient(authClient))
}
}

Expand Down Expand Up @@ -273,7 +300,7 @@ func NewDialer(ctx context.Context, opts ...Option) (*Dialer, error) {
logger: cfg.logger,
defaultDialConfig: dc,
dialerID: uuid.New().String(),
iamTokenSource: cfg.iamLoginTokenSource,
iamTokenProvider: cfg.iamLoginTokenProvider,
dialFunc: cfg.dialFunc,
resolver: r,
failoverPeriod: cfg.failoverPeriod,
Expand Down Expand Up @@ -636,15 +663,15 @@ func (d *Dialer) connectionInfoCache(
cn,
d.logger,
d.sqladmin, rsaKey,
d.refreshTimeout, d.iamTokenSource,
d.refreshTimeout, d.iamTokenProvider,
d.dialerID, useIAMAuthNDial,
)
} else {
cache = cloudsql.NewRefreshAheadCache(
cn,
d.logger,
d.sqladmin, rsaKey,
d.refreshTimeout, d.iamTokenSource,
d.refreshTimeout, d.iamTokenProvider,
d.dialerID, useIAMAuthNDial,
)
}
Expand Down
66 changes: 0 additions & 66 deletions dialer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -280,72 +280,6 @@ func TestSQLServerFailsOnIAMAuthN(t *testing.T) {
}
}

func TestUniverseDomain(t *testing.T) {
tcs := []struct {
desc string
opts Option
}{
{
desc: "When universe domain matches GDU",
opts: WithOptions(
WithUniverseDomain("googleapis.com"),
WithCredentialsJSON(fakeServiceAccount("")),
),
},
{
desc: "When TPC universe matches TPC credential domain",
opts: WithOptions(
WithUniverseDomain("test-universe.test"),
WithCredentialsJSON(fakeServiceAccount("test-universe.test")),
),
},
}

for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
_, err := NewDialer(context.Background(), tc.opts)
if err != nil {
t.Fatalf("NewDialer failed with error = %v", err)
}
})
}
}

func TestUniverseDomainErrors(t *testing.T) {
tcs := []struct {
desc string
opts Option
}{
{
desc: "When universe domain does not match ADC credentials from GDU",
opts: WithOptions(WithUniverseDomain("test-universe.test")),
},
{
desc: "When GDU does not match credential domain",
opts: WithOptions(WithCredentialsJSON(
fakeServiceAccount("test-universe.test"),
)),
},
{
desc: "WithUniverseDomain used alongside WithAdminAPIEndpoint",
opts: WithOptions(
WithUniverseDomain("googleapis.com"),
WithAdminAPIEndpoint("https://sqladmin.googleapis.com"),
),
},
}

for _, tc := range tcs {
t.Run(tc.desc, func(t *testing.T) {
_, err := NewDialer(context.Background(), tc.opts)
t.Log(err)
if err == nil {
t.Fatalf("Wanted universe domain mismatch, want error, got nil")
}
})
}
}

func TestDialerWithCustomDialFunc(t *testing.T) {
inst := mock.NewFakeCSQLInstance("proj", "region", "inst",
mock.WithEngineVersion("SQLSERVER"),
Expand Down
4 changes: 2 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ module cloud.google.com/go/cloudsqlconn
go 1.22

require (
cloud.google.com/go/auth v0.13.0
cloud.google.com/go/auth/oauth2adapt v0.2.6
github.com/go-sql-driver/mysql v1.8.1
github.com/google/uuid v1.6.0
github.com/jackc/pgx/v4 v4.18.3
Expand All @@ -18,8 +20,6 @@ require (
)

require (
cloud.google.com/go/auth v0.13.0 // indirect
cloud.google.com/go/auth/oauth2adapt v0.2.6 // indirect
cloud.google.com/go/compute/metadata v0.6.0 // indirect
filippo.io/edwards25519 v1.1.0 // indirect
github.com/felixge/httpsnoop v1.0.4 // indirect
Expand Down
6 changes: 3 additions & 3 deletions internal/cloudsql/instance.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@ import (
"sync"
"time"

"cloud.google.com/go/auth"
"cloud.google.com/go/cloudsqlconn/debug"
"cloud.google.com/go/cloudsqlconn/errtype"
"cloud.google.com/go/cloudsqlconn/instance"
"golang.org/x/oauth2"
"golang.org/x/time/rate"
sqladmin "google.golang.org/api/sqladmin/v1beta4"
)
Expand Down Expand Up @@ -129,7 +129,7 @@ func NewRefreshAheadCache(
client *sqladmin.Service,
key *rsa.PrivateKey,
refreshTimeout time.Duration,
ts oauth2.TokenSource,
tp auth.TokenProvider,
dialerID string,
useIAMAuthNDial bool,
) *RefreshAheadCache {
Expand All @@ -142,7 +142,7 @@ func NewRefreshAheadCache(
l,
client,
key,
ts,
tp,
dialerID,
),
refreshTimeout: refreshTimeout,
Expand Down
6 changes: 3 additions & 3 deletions internal/cloudsql/lazy.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@ import (
"sync"
"time"

"cloud.google.com/go/auth"
"cloud.google.com/go/cloudsqlconn/debug"
"cloud.google.com/go/cloudsqlconn/instance"
"golang.org/x/oauth2"
sqladmin "google.golang.org/api/sqladmin/v1beta4"
)

Expand All @@ -45,7 +45,7 @@ func NewLazyRefreshCache(
client *sqladmin.Service,
key *rsa.PrivateKey,
_ time.Duration,
ts oauth2.TokenSource,
tp auth.TokenProvider,
dialerID string,
useIAMAuthNDial bool,
) *LazyRefreshCache {
Expand All @@ -56,7 +56,7 @@ func NewLazyRefreshCache(
l,
client,
key,
ts,
tp,
dialerID,
),
useIAMAuthNDial: useIAMAuthNDial,
Expand Down
Loading

0 comments on commit 52fef27

Please sign in to comment.