From 69ff02dc507da4a7d059dbb85d19d7a176e14f3b Mon Sep 17 00:00:00 2001 From: Micah Hausler Date: Thu, 29 Aug 2024 14:36:13 -0500 Subject: [PATCH] Update filecache to use AWS SDK Go V2 with wrappers This changes updates filecache's internal types to use the AWS SDK Go v2's types, while preserving the external interface used by /pkg/token. This will simplify the future project-wide change for AWS SDK Go v2. Signed-off-by: Micah Hausler --- go.mod | 2 + go.sum | 4 + pkg/filecache/converter.go | 55 +++++++ pkg/filecache/filecache.go | 82 ++++----- pkg/filecache/filecache_test.go | 284 ++++++++++++++++---------------- pkg/token/token.go | 6 +- 6 files changed, 240 insertions(+), 193 deletions(-) create mode 100644 pkg/filecache/converter.go diff --git a/go.mod b/go.mod index 3d866c974..58736482f 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.22.5 require ( github.com/aws/aws-sdk-go v1.54.6 + github.com/aws/aws-sdk-go-v2 v1.30.4 github.com/fsnotify/fsnotify v1.7.0 github.com/gofrs/flock v0.8.1 github.com/google/go-cmp v0.6.0 @@ -25,6 +26,7 @@ require ( ) require ( + github.com/aws/smithy-go v1.20.4 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/blang/semver/v4 v4.0.0 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect diff --git a/go.sum b/go.sum index 2a3bcb3a0..9d1fc5538 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,9 @@ github.com/aws/aws-sdk-go v1.54.6 h1:HEYUib3yTt8E6vxjMWM3yAq5b+qjj/6aKA62mkgux9g= github.com/aws/aws-sdk-go v1.54.6/go.mod h1:eRwEWoyTWFMVYVQzKMNHWP5/RV4xIUGMQfXQHfHkpNU= +github.com/aws/aws-sdk-go-v2 v1.30.4 h1:frhcagrVNrzmT95RJImMHgabt99vkXGslubDaDagTk8= +github.com/aws/aws-sdk-go-v2 v1.30.4/go.mod h1:CT+ZPWXbYrci8chcARI3OmI/qgd+f6WtuLOoaIA8PR0= +github.com/aws/smithy-go v1.20.4 h1:2HK1zBdPgRbjFOHlfeQZfpC4r72MOb9bZkiFwggKO+4= +github.com/aws/smithy-go v1.20.4/go.mod h1:irrKGvNn1InZwb2d7fkIRNucdfwR8R+Ts3wxYa/cJHg= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/blang/semver/v4 v4.0.0 h1:1PFHFE6yCCTv8C1TeyNNarDzntLi7wMI5i/pzqYIsAM= diff --git a/pkg/filecache/converter.go b/pkg/filecache/converter.go new file mode 100644 index 000000000..ec2f16bde --- /dev/null +++ b/pkg/filecache/converter.go @@ -0,0 +1,55 @@ +package filecache + +import ( + "context" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go/aws/credentials" +) + +type v2 struct { + creds *credentials.Credentials +} + +var _ aws.CredentialsProvider = &v2{} + +func (p *v2) Retrieve(ctx context.Context) (aws.Credentials, error) { + val, err := p.creds.GetWithContext(ctx) + if err != nil { + return aws.Credentials{}, err + } + resp := aws.Credentials{ + AccessKeyID: val.AccessKeyID, + SecretAccessKey: val.SecretAccessKey, + SessionToken: val.SessionToken, + Source: val.ProviderName, + CanExpire: false, + // Don't have account ID + } + + if expiration, err := p.creds.ExpiresAt(); err != nil { + resp.CanExpire = true + resp.Expires = expiration + } + return resp, nil +} + +// V1ProviderToV2Provider converts a v1 credentials.Provider to a v2 aws.CredentialsProvider +func V1ProviderToV2Provider(p credentials.Provider) aws.CredentialsProvider { + return V1CredentialToV2Provider(credentials.NewCredentials(p)) +} + +// V1CredentialToV2Provider converts a v1 credentials.Credential to a v2 aws.CredentialProvider +func V1CredentialToV2Provider(c *credentials.Credentials) aws.CredentialsProvider { + return &v2{creds: c} +} + +// V2CredentialToV1Value converts a v2 aws.Credentials to a v1 credentials.Value +func V2CredentialToV1Value(cred aws.Credentials) credentials.Value { + return credentials.Value{ + AccessKeyID: cred.AccessKeyID, + SecretAccessKey: cred.SecretAccessKey, + SessionToken: cred.SessionToken, + ProviderName: cred.Source, + } +} diff --git a/pkg/filecache/filecache.go b/pkg/filecache/filecache.go index 41597edaa..64092b9f4 100644 --- a/pkg/filecache/filecache.go +++ b/pkg/filecache/filecache.go @@ -10,6 +10,7 @@ import ( "runtime" "time" + "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go/aws/credentials" "github.com/gofrs/flock" "github.com/spf13/afero" @@ -34,7 +35,7 @@ func NewFileLocker(filename string) FileLocker { // cacheFile is a map of clusterID/roleARNs to cached credentials type cacheFile struct { // a map of clusterIDs/profiles/roleARNs to cachedCredentials - ClusterMap map[string]map[string]map[string]cachedCredential `yaml:"clusters"` + ClusterMap map[string]map[string]map[string]aws.Credentials `yaml:"clusters"` } // a utility type for dealing with compound cache keys @@ -44,19 +45,19 @@ type cacheKey struct { roleARN string } -func (c *cacheFile) Put(key cacheKey, credential cachedCredential) { +func (c *cacheFile) Put(key cacheKey, credential aws.Credentials) { if _, ok := c.ClusterMap[key.clusterID]; !ok { // first use of this cluster id - c.ClusterMap[key.clusterID] = map[string]map[string]cachedCredential{} + c.ClusterMap[key.clusterID] = map[string]map[string]aws.Credentials{} } if _, ok := c.ClusterMap[key.clusterID][key.profile]; !ok { // first use of this profile - c.ClusterMap[key.clusterID][key.profile] = map[string]cachedCredential{} + c.ClusterMap[key.clusterID][key.profile] = map[string]aws.Credentials{} } c.ClusterMap[key.clusterID][key.profile][key.roleARN] = credential } -func (c *cacheFile) Get(key cacheKey) (credential cachedCredential) { +func (c *cacheFile) Get(key cacheKey) (credential aws.Credentials) { if _, ok := c.ClusterMap[key.clusterID]; ok { if _, ok := c.ClusterMap[key.clusterID][key.profile]; ok { // we at least have this cluster and profile combo in the map, if no matching roleARN, map will @@ -67,31 +68,12 @@ func (c *cacheFile) Get(key cacheKey) (credential cachedCredential) { return } -// cachedCredential is a single cached credential entry, along with expiration time -type cachedCredential struct { - Credential credentials.Value - Expiration time.Time - // If set will be used by IsExpired to determine the current time. - // Defaults to time.Now if CurrentTime is not set. Available for testing - // to be able to mock out the current time. - currentTime func() time.Time -} - -// IsExpired determines if the cached credential has expired -func (c *cachedCredential) IsExpired() bool { - curTime := c.currentTime - if curTime == nil { - curTime = time.Now - } - return c.Expiration.Before(curTime()) -} - // readCacheWhileLocked reads the contents of the credential cache and returns the // parsed yaml as a cacheFile object. This method must be called while a shared // lock is held on the filename. func readCacheWhileLocked(fs afero.Fs, filename string) (cache cacheFile, err error) { cache = cacheFile{ - map[string]map[string]map[string]cachedCredential{}, + map[string]map[string]map[string]aws.Credentials{}, } data, err := afero.ReadFile(fs, filename) if err != nil { @@ -149,9 +131,9 @@ type FileCacheProvider struct { fs afero.Fs filelockCreator func(string) FileLocker filename string - credentials *credentials.Credentials // the underlying implementation that has the *real* Provider - cacheKey cacheKey // cache key parameters used to create Provider - cachedCredential cachedCredential // the cached credential, if it exists + provider aws.CredentialsProvider // the underlying implementation that has the *real* Provider + cacheKey cacheKey // cache key parameters used to create Provider + cachedCredential aws.Credentials // the cached credential, if it exists } var _ credentials.Provider = &FileCacheProvider{} @@ -160,8 +142,8 @@ var _ credentials.Provider = &FileCacheProvider{} // and works with an on disk cache to speed up credential usage when the cached copy is not expired. // If there are any problems accessing or initializing the cache, an error will be returned, and // callers should just use the existing credentials provider. -func NewFileCacheProvider(clusterID, profile, roleARN string, creds *credentials.Credentials, opts ...FileCacheOpt) (*FileCacheProvider, error) { - if creds == nil { +func NewFileCacheProvider(clusterID, profile, roleARN string, provider aws.CredentialsProvider, opts ...FileCacheOpt) (*FileCacheProvider, error) { + if provider == nil { return nil, errors.New("no underlying Credentials object provided") } @@ -169,9 +151,9 @@ func NewFileCacheProvider(clusterID, profile, roleARN string, creds *credentials fs: afero.NewOsFs(), filelockCreator: NewFileLocker, filename: defaultCacheFilename(), - credentials: creds, + provider: provider, cacheKey: cacheKey{clusterID, profile, roleARN}, - cachedCredential: cachedCredential{}, + cachedCredential: aws.Credentials{}, } // override defaults @@ -222,36 +204,40 @@ func NewFileCacheProvider(clusterID, profile, roleARN string, creds *credentials // otherwise fetching the credential from the underlying Provider and caching the results on disk // with an expiration time. func (f *FileCacheProvider) Retrieve() (credentials.Value, error) { - if !f.cachedCredential.IsExpired() { + return f.RetrieveWithContext(context.Background()) +} + +// Retrieve() implements the Provider interface, returning the cached credential if is not expired, +// otherwise fetching the credential from the underlying Provider and caching the results on disk +// with an expiration time. +func (f *FileCacheProvider) RetrieveWithContext(ctx context.Context) (credentials.Value, error) { + if !f.cachedCredential.Expired() && f.cachedCredential.HasKeys() { // use the cached credential - return f.cachedCredential.Credential, nil + return V2CredentialToV1Value(f.cachedCredential), nil } else { _, _ = fmt.Fprintf(os.Stderr, "No cached credential available. Refreshing...\n") // fetch the credentials from the underlying Provider - credential, err := f.credentials.Get() + credential, err := f.provider.Retrieve(ctx) if err != nil { - return credential, err + return V2CredentialToV1Value(credential), err } - if expiration, err := f.credentials.ExpiresAt(); err == nil { - // underlying provider supports Expirer interface, so we can cache + + if credential.CanExpire { + // Credential supports expiration, so we can cache // do file locking on cache to prevent inconsistent writes lock := f.filelockCreator(f.filename) defer lock.Unlock() // wait up to a second for the file to lock - ctx, cancel := context.WithTimeout(context.TODO(), time.Second) + ctx, cancel := context.WithTimeout(ctx, time.Second) defer cancel() ok, err := lock.TryLockContext(ctx, 250*time.Millisecond) // try to lock every 1/4 second if !ok { // can't get write lock to create/update cache, but still return the credential _, _ = fmt.Fprintf(os.Stderr, "Unable to write lock file %s: %v\n", f.filename, err) - return credential, nil - } - f.cachedCredential = cachedCredential{ - credential, - expiration, - nil, + return V2CredentialToV1Value(credential), nil } + f.cachedCredential = credential // don't really care about read error. Either read the cache, or we create a new cache. cache, _ := readCacheWhileLocked(f.fs, f.filename) cache.Put(f.cacheKey, f.cachedCredential) @@ -268,19 +254,19 @@ func (f *FileCacheProvider) Retrieve() (credentials.Value, error) { _, _ = fmt.Fprintf(os.Stderr, "Unable to cache credential: %v\n", err) err = nil } - return credential, err + return V2CredentialToV1Value(credential), err } } // IsExpired() implements the Provider interface, deferring to the cached credential first, // but fall back to the underlying Provider if it is expired. func (f *FileCacheProvider) IsExpired() bool { - return f.cachedCredential.IsExpired() && f.credentials.IsExpired() + return f.cachedCredential.CanExpire && f.cachedCredential.Expired() } // ExpiresAt implements the Expirer interface, and gives access to the expiration time of the credential func (f *FileCacheProvider) ExpiresAt() time.Time { - return f.cachedCredential.Expiration + return f.cachedCredential.Expires } // defaultCacheFilename returns the name of the credential cache file, which can either be diff --git a/pkg/filecache/filecache_test.go b/pkg/filecache/filecache_test.go index 60b4a8771..f2db98556 100644 --- a/pkg/filecache/filecache_test.go +++ b/pkg/filecache/filecache_test.go @@ -1,7 +1,6 @@ package filecache import ( - "bytes" "context" "errors" "fmt" @@ -10,7 +9,8 @@ import ( "testing" "time" - "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/google/go-cmp/cmp" "github.com/spf13/afero" ) @@ -20,35 +20,17 @@ const ( // stubProvider implements credentials.Provider with configurable response values type stubProvider struct { - creds credentials.Value - expired bool - err error + creds aws.Credentials + err error } -var _ credentials.Provider = &stubProvider{} +var _ aws.CredentialsProvider = &stubProvider{} -func (s *stubProvider) Retrieve() (credentials.Value, error) { - s.expired = false - s.creds.ProviderName = "stubProvider" +func (s *stubProvider) Retrieve(_ context.Context) (aws.Credentials, error) { + s.creds.Source = "stubProvider" return s.creds, s.err } -func (s *stubProvider) IsExpired() bool { - return s.expired -} - -// stubProviderExpirer implements credentials.Expirer with configurable expiration -type stubProviderExpirer struct { - stubProvider - expiration time.Time -} - -var _ credentials.Expirer = &stubProviderExpirer{} - -func (s *stubProviderExpirer) ExpiresAt() time.Time { - return s.expiration -} - // testFileInfo implements fs.FileInfo with configurable response values type testFileInfo struct { name string @@ -116,22 +98,34 @@ func getMocks() (*testFS, *testFilelock) { } // makeCredential returns a dummy AWS crdential -func makeCredential() credentials.Value { - return credentials.Value{ +func makeCredential() aws.Credentials { + return aws.Credentials{ AccessKeyID: "AKID", SecretAccessKey: "SECRET", SessionToken: "TOKEN", - ProviderName: "stubProvider", + Source: "stubProvider", + CanExpire: false, + } +} + +func makeExpiringCredential(e time.Time) aws.Credentials { + return aws.Credentials{ + AccessKeyID: "AKID", + SecretAccessKey: "SECRET", + SessionToken: "TOKEN", + Source: "stubProvider", + CanExpire: true, + Expires: e, } } // validateFileCacheProvider ensures that the cache provider is properly initialized -func validateFileCacheProvider(t *testing.T, p *FileCacheProvider, err error, c *credentials.Credentials) { +func validateFileCacheProvider(t *testing.T, p *FileCacheProvider, err error, c aws.CredentialsProvider) { t.Helper() if err != nil { t.Errorf("Unexpected error: %v", err) } - if p.credentials != c { + if p.provider != c { t.Errorf("Credentials not copied") } if p.cacheKey.clusterID != "CLUSTER" { @@ -184,24 +178,24 @@ func TestCacheFilename(t *testing.T) { } func TestNewFileCacheProvider_Missing(t *testing.T) { - c := credentials.NewCredentials(&stubProvider{}) + provider := &stubProvider{} tfs, tfl := getMocks() - p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c, + p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", provider, WithFilename(testFilename), WithFs(tfs), WithFileLockerCreator(func(string) FileLocker { return tfl })) - validateFileCacheProvider(t, p, err, c) - if !p.cachedCredential.IsExpired() { - t.Errorf("missing cache file should result in expired cached credential") + validateFileCacheProvider(t, p, err, provider) + if p.cachedCredential.HasKeys() { + t.Errorf("missing cache file should result in empty cached credential") } } func TestNewFileCacheProvider_BadPermissions(t *testing.T) { - c := credentials.NewCredentials(&stubProvider{}) + provider := &stubProvider{} tfs, _ := getMocks() // afero.MemMapFs always returns tempfile FileInfo, @@ -209,7 +203,7 @@ func TestNewFileCacheProvider_BadPermissions(t *testing.T) { tfs.fileinfo = &testFileInfo{mode: 0777} // bad permissions - _, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c, + _, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", provider, WithFilename(testFilename), WithFs(tfs), ) @@ -223,7 +217,7 @@ func TestNewFileCacheProvider_BadPermissions(t *testing.T) { } func TestNewFileCacheProvider_Unlockable(t *testing.T) { - c := credentials.NewCredentials(&stubProvider{}) + provider := &stubProvider{} tfs, tfl := getMocks() tfs.Create(testFilename) @@ -232,7 +226,7 @@ func TestNewFileCacheProvider_Unlockable(t *testing.T) { tfl.success = false tfl.err = errors.New("lock stuck, needs wd-40") - _, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c, + _, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", provider, WithFilename(testFilename), WithFs(tfs), WithFileLockerCreator(func(string) FileLocker { @@ -245,14 +239,14 @@ func TestNewFileCacheProvider_Unlockable(t *testing.T) { } func TestNewFileCacheProvider_Unreadable(t *testing.T) { - c := credentials.NewCredentials(&stubProvider{}) + provider := &stubProvider{} tfs, tfl := getMocks() tfs.Create(testFilename) tfl.err = fmt.Errorf("open %s: permission denied", testFilename) tfl.success = false - _, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c, + _, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", provider, WithFilename(testFilename), WithFs(tfs), WithFileLockerCreator(func(string) FileLocker { @@ -270,12 +264,12 @@ func TestNewFileCacheProvider_Unreadable(t *testing.T) { } func TestNewFileCacheProvider_Unparseable(t *testing.T) { - c := credentials.NewCredentials(&stubProvider{}) + provider := &stubProvider{} tfs, tfl := getMocks() tfs.Create(testFilename) - _, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c, + _, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", provider, WithFilename(testFilename), WithFs(tfs), WithFileLockerCreator(func(string) FileLocker { @@ -297,12 +291,12 @@ func TestNewFileCacheProvider_Unparseable(t *testing.T) { } func TestNewFileCacheProvider_Empty(t *testing.T) { - c := credentials.NewCredentials(&stubProvider{}) + provider := &stubProvider{} tfs, tfl := getMocks() // successfully parse existing but empty cache file - p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c, + p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", provider, WithFilename(testFilename), WithFs(tfs), WithFileLockerCreator(func(string) FileLocker { @@ -313,58 +307,60 @@ func TestNewFileCacheProvider_Empty(t *testing.T) { t.Errorf("Unexpected error: %v", err) return } - validateFileCacheProvider(t, p, err, c) - if !p.cachedCredential.IsExpired() { - t.Errorf("empty cache file should result in expired cached credential") + validateFileCacheProvider(t, p, err, provider) + if p.cachedCredential.HasKeys() { + t.Errorf("empty cache file should result in empty cached credential") } } func TestNewFileCacheProvider_ExistingCluster(t *testing.T) { - c := credentials.NewCredentials(&stubProvider{}) + provider := &stubProvider{} tfs, tfl := getMocks() - afero.WriteFile( - tfs, - testFilename, - []byte(`clusters: - CLUSTER: - ARN2: {} -`), - 0700) + tfs.Create(testFilename) + // successfully parse existing cluster without matching arn - p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c, + p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", provider, WithFilename(testFilename), WithFs(tfs), WithFileLockerCreator(func(string) FileLocker { - tfs.Create(testFilename) + + afero.WriteFile( + tfs, + testFilename, + []byte(`clusters: + CLUSTER: + PROFILE2: {} +`), + 0700) return tfl }), ) - validateFileCacheProvider(t, p, err, c) - if !p.cachedCredential.IsExpired() { - t.Errorf("missing arn in cache file should result in expired cached credential") + validateFileCacheProvider(t, p, err, provider) + if p.cachedCredential.HasKeys() { + t.Errorf("missing profile in cache file should result in empty cached credential") } } func TestNewFileCacheProvider_ExistingARN(t *testing.T) { - c := credentials.NewCredentials(&stubProvider{}) + provider := &stubProvider{} + expiry := time.Now().Add(time.Hour * 6) content := []byte(`clusters: CLUSTER: PROFILE: ARN: - credential: - accesskeyid: ABC - secretaccesskey: DEF - sessiontoken: GHI - providername: JKL - expiration: 2018-01-02T03:04:56.789Z + accesskeyid: ABC + secretaccesskey: DEF + sessiontoken: GHI + source: JKL + expires: ` + expiry.Format(time.RFC3339Nano) + ` `) tfs, tfl := getMocks() tfs.Create(testFilename) // successfully parse cluster with matching arn - p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c, + p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", provider, WithFilename(testFilename), WithFs(tfs), WithFileLockerCreator(func(string) FileLocker { @@ -377,38 +373,31 @@ func TestNewFileCacheProvider_ExistingARN(t *testing.T) { t.Errorf("Unexpected error: %v", err) return } - validateFileCacheProvider(t, p, err, c) - if p.cachedCredential.Credential.AccessKeyID != "ABC" || p.cachedCredential.Credential.SecretAccessKey != "DEF" || - p.cachedCredential.Credential.SessionToken != "GHI" || p.cachedCredential.Credential.ProviderName != "JKL" { + validateFileCacheProvider(t, p, err, provider) + if p.cachedCredential.AccessKeyID != "ABC" || p.cachedCredential.SecretAccessKey != "DEF" || + p.cachedCredential.SessionToken != "GHI" || p.cachedCredential.Source != "JKL" { t.Errorf("cached credential not extracted correctly, got %v", p.cachedCredential) } - // fiddle with clock - p.cachedCredential.currentTime = func() time.Time { - return time.Date(2017, 12, 25, 12, 23, 45, 678, time.UTC) - } - if p.cachedCredential.IsExpired() { + + if p.cachedCredential.Expired() { t.Errorf("Cached credential should not be expired") } - if p.IsExpired() { - t.Errorf("Cache credential should not be expired") - } - expectedExpiration := time.Date(2018, 01, 02, 03, 04, 56, 789000000, time.UTC) - if p.ExpiresAt() != expectedExpiration { + + if p.ExpiresAt() != p.cachedCredential.Expires { t.Errorf("Credential expiration time is not correct, expected %v, got %v", - expectedExpiration, p.ExpiresAt()) + p.cachedCredential.Expires, p.ExpiresAt()) } } func TestFileCacheProvider_Retrieve_NoExpirer(t *testing.T) { - providerCredential := makeCredential() - c := credentials.NewCredentials(&stubProvider{ - creds: providerCredential, - }) + provider := &stubProvider{ + creds: makeCredential(), + } tfs, tfl := getMocks() // don't create the empty cache file, create it in the filelock creator - p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c, + p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", provider, WithFilename(testFilename), WithFs(tfs), WithFileLockerCreator(func(string) FileLocker { @@ -416,45 +405,37 @@ func TestFileCacheProvider_Retrieve_NoExpirer(t *testing.T) { return tfl }), ) - validateFileCacheProvider(t, p, err, c) + validateFileCacheProvider(t, p, err, provider) credential, err := p.Retrieve() if err != nil { t.Errorf("Unexpected error: %v", err) } - if credential != providerCredential { + if credential.AccessKeyID != provider.creds.AccessKeyID || + credential.SecretAccessKey != provider.creds.SecretAccessKey || + credential.SessionToken != provider.creds.SessionToken { t.Errorf("Cache did not return provider credential, got %v, expected %v", - credential, providerCredential) + credential, provider.creds) } } -// makeExpirerCredentials returns an expiring credential -func makeExpirerCredentials() (providerCredential credentials.Value, expiration time.Time, c *credentials.Credentials) { - providerCredential = makeCredential() - expiration = time.Date(2020, 9, 19, 13, 14, 0, 1000000, time.UTC) - c = credentials.NewCredentials(&stubProviderExpirer{ - stubProvider{ - creds: providerCredential, - }, - expiration, - }) - return -} - func TestFileCacheProvider_Retrieve_WithExpirer_Unlockable(t *testing.T) { - providerCredential, _, c := makeExpirerCredentials() + expires := time.Now().Add(time.Hour * 6) + provider := &stubProvider{ + creds: makeExpiringCredential(expires), + } tfs, tfl := getMocks() // don't create the empty cache file, create it in the filelock creator - p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c, + p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", provider, WithFilename(testFilename), WithFs(tfs), WithFileLockerCreator(func(string) FileLocker { tfs.Create(testFilename) return tfl })) - validateFileCacheProvider(t, p, err, c) + validateFileCacheProvider(t, p, err, provider) // retrieve credential, which will fetch from underlying Provider // fail to get write lock @@ -465,19 +446,22 @@ func TestFileCacheProvider_Retrieve_WithExpirer_Unlockable(t *testing.T) { if err != nil { t.Errorf("Unexpected error: %v", err) } - if credential != providerCredential { - t.Errorf("Cache did not return provider credential, got %v, expected %v", - credential, providerCredential) + if credential.AccessKeyID != "AKID" || credential.SecretAccessKey != "SECRET" || + credential.SessionToken != "TOKEN" || credential.ProviderName != "stubProvider" { + t.Errorf("cached credential not extracted correctly, got %v", p.cachedCredential) } } func TestFileCacheProvider_Retrieve_WithExpirer_Unwritable(t *testing.T) { - providerCredential, expiration, c := makeExpirerCredentials() + expires := time.Now().Add(time.Hour * 6) + provider := &stubProvider{ + creds: makeExpiringCredential(expires), + } tfs, tfl := getMocks() // don't create the file, let the FileLocker create it - p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c, + p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", provider, WithFilename(testFilename), WithFs(tfs), WithFileLockerCreator(func(string) FileLocker { @@ -485,45 +469,50 @@ func TestFileCacheProvider_Retrieve_WithExpirer_Unwritable(t *testing.T) { return tfl }), ) - validateFileCacheProvider(t, p, err, c) + validateFileCacheProvider(t, p, err, provider) credential, err := p.Retrieve() if err != nil { t.Errorf("Unexpected error: %v", err) } - if credential != providerCredential { - t.Errorf("Cache did not return provider credential, got %v, expected %v", - credential, providerCredential) + if credential.AccessKeyID != provider.creds.AccessKeyID || + credential.SecretAccessKey != provider.creds.SecretAccessKey || + credential.SessionToken != provider.creds.SessionToken || + credential.ProviderName != provider.creds.Source { + t.Errorf("cached credential not extracted correctly, got %v", p.cachedCredential) } expectedData := []byte(`clusters: CLUSTER: PROFILE: ARN: - credential: - accesskeyid: AKID - secretaccesskey: SECRET - sessiontoken: TOKEN - providername: stubProvider - expiration: ` + expiration.Format(time.RFC3339Nano) + ` + accesskeyid: AKID + secretaccesskey: SECRET + sessiontoken: TOKEN + source: stubProvider + canexpire: true + expires: ` + expires.Format(time.RFC3339Nano) + ` + accountid: "" `) got, err := afero.ReadFile(tfs, testFilename) if err != nil { t.Errorf("unexpected error reading generated file: %v", err) } - if !bytes.Equal(got, expectedData) { - t.Errorf("Wrong data written to cache, expected: %s, got %s", - expectedData, got) + if diff := cmp.Diff(got, expectedData); diff != "" { + t.Errorf("Wrong data written to cache, %s", diff) } } func TestFileCacheProvider_Retrieve_WithExpirer_Writable(t *testing.T) { - providerCredential, _, c := makeExpirerCredentials() + expires := time.Now().Add(time.Hour * 6) + provider := &stubProvider{ + creds: makeExpiringCredential(expires), + } tfs, tfl := getMocks() // don't create the file, let the FileLocker create it - p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c, + p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", provider, WithFilename(testFilename), WithFs(tfs), WithFileLockerCreator(func(string) FileLocker { @@ -531,7 +520,7 @@ func TestFileCacheProvider_Retrieve_WithExpirer_Writable(t *testing.T) { return tfl }), ) - validateFileCacheProvider(t, p, err, c) + validateFileCacheProvider(t, p, err, provider) // retrieve credential, which will fetch from underlying Provider // same as TestFileCacheProvider_Retrieve_WithExpirer_Unwritable, @@ -540,15 +529,17 @@ func TestFileCacheProvider_Retrieve_WithExpirer_Writable(t *testing.T) { if err != nil { t.Errorf("Unexpected error: %v", err) } - if credential != providerCredential { - t.Errorf("Cache did not return provider credential, got %v, expected %v", - credential, providerCredential) + if credential.AccessKeyID != provider.creds.AccessKeyID || + credential.SecretAccessKey != provider.creds.SecretAccessKey || + credential.SessionToken != provider.creds.SessionToken || + credential.ProviderName != provider.creds.Source { + t.Errorf("cached credential not extracted correctly, got %v", p.cachedCredential) } } func TestFileCacheProvider_Retrieve_CacheHit(t *testing.T) { - c := credentials.NewCredentials(&stubProvider{}) - currentTime := time.Date(2017, 12, 25, 12, 23, 45, 678, time.UTC) + provider := &stubProvider{} + currentTime := time.Now() tfs, tfl := getMocks() tfs.Create(testFilename) @@ -559,13 +550,14 @@ func TestFileCacheProvider_Retrieve_CacheHit(t *testing.T) { PROFILE: ARN: credential: - accesskeyid: ABC - secretaccesskey: DEF - sessiontoken: GHI - providername: JKL - expiration: ` + currentTime.Add(time.Hour*6).Format(time.RFC3339Nano) + ` + accesskeyid: ABC + secretaccesskey: DEF + sessiontoken: GHI + source: JKL + canexpire: true + expires: ` + currentTime.Add(time.Hour*6).Format(time.RFC3339Nano) + ` `) - p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", c, + p, err := NewFileCacheProvider("CLUSTER", "PROFILE", "ARN", provider, WithFilename(testFilename), WithFs(tfs), WithFileLockerCreator(func(string) FileLocker { @@ -573,10 +565,7 @@ func TestFileCacheProvider_Retrieve_CacheHit(t *testing.T) { afero.WriteFile(tfs, testFilename, content, 0700) return tfl })) - validateFileCacheProvider(t, p, err, c) - - // fiddle with clock - p.cachedCredential.currentTime = func() time.Time { return currentTime } + validateFileCacheProvider(t, p, err, provider) credential, err := p.Retrieve() if err != nil { @@ -586,4 +575,11 @@ func TestFileCacheProvider_Retrieve_CacheHit(t *testing.T) { credential.SessionToken != "GHI" || credential.ProviderName != "JKL" { t.Errorf("cached credential not returned") } + + if !p.ExpiresAt().Equal(currentTime.Add(time.Hour * 6)) { + t.Errorf("unexpected expiration time: got %s, wanted %s", + p.ExpiresAt().Format(time.RFC3339Nano), + currentTime.Add(time.Hour*6).Format(time.RFC3339Nano), + ) + } } diff --git a/pkg/token/token.go b/pkg/token/token.go index d9d7fd2e8..716a8cb12 100644 --- a/pkg/token/token.go +++ b/pkg/token/token.go @@ -248,7 +248,11 @@ func (g generator) GetWithOptions(options *GetTokenOptions) (Token, error) { profile = session.DefaultSharedConfigProfile } // create a cacheing Provider wrapper around the Credentials - if cacheProvider, err := filecache.NewFileCacheProvider(options.ClusterID, profile, options.AssumeRoleARN, sess.Config.Credentials); err == nil { + if cacheProvider, err := filecache.NewFileCacheProvider( + options.ClusterID, + profile, + options.AssumeRoleARN, + filecache.V1CredentialToV2Provider(sess.Config.Credentials)); err == nil { sess.Config.Credentials = credentials.NewCredentials(cacheProvider) } else { fmt.Fprintf(os.Stderr, "unable to use cache: %v\n", err)