Skip to content

Commit

Permalink
Update filecache to use AWS SDK Go V2 with wrappers
Browse files Browse the repository at this point in the history
This changes updates filecache's internal types to use the AWS SDK Go
v2's types, while preserving the external interface used by /pkg/token.
This will simplify the future project-wide change for AWS SDK Go v2.

Signed-off-by: Micah Hausler <[email protected]>
  • Loading branch information
micahhausler committed Aug 29, 2024
1 parent 8464316 commit 69ff02d
Show file tree
Hide file tree
Showing 6 changed files with 240 additions and 193 deletions.
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ go 1.22.5

require (
github.com/aws/aws-sdk-go v1.54.6
github.com/aws/aws-sdk-go-v2 v1.30.4
github.com/fsnotify/fsnotify v1.7.0
github.com/gofrs/flock v0.8.1
github.com/google/go-cmp v0.6.0
Expand All @@ -25,6 +26,7 @@ require (
)

require (
github.com/aws/smithy-go v1.20.4 // indirect
github.com/beorn7/perks v1.0.1 // indirect
github.com/blang/semver/v4 v4.0.0 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
Expand Down
4 changes: 4 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
github.com/aws/aws-sdk-go v1.54.6 h1:HEYUib3yTt8E6vxjMWM3yAq5b+qjj/6aKA62mkgux9g=
github.com/aws/aws-sdk-go v1.54.6/go.mod h1:eRwEWoyTWFMVYVQzKMNHWP5/RV4xIUGMQfXQHfHkpNU=
github.com/aws/aws-sdk-go-v2 v1.30.4 h1:frhcagrVNrzmT95RJImMHgabt99vkXGslubDaDagTk8=
github.com/aws/aws-sdk-go-v2 v1.30.4/go.mod h1:CT+ZPWXbYrci8chcARI3OmI/qgd+f6WtuLOoaIA8PR0=
github.com/aws/smithy-go v1.20.4 h1:2HK1zBdPgRbjFOHlfeQZfpC4r72MOb9bZkiFwggKO+4=
github.com/aws/smithy-go v1.20.4/go.mod h1:irrKGvNn1InZwb2d7fkIRNucdfwR8R+Ts3wxYa/cJHg=
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
github.com/blang/semver/v4 v4.0.0 h1:1PFHFE6yCCTv8C1TeyNNarDzntLi7wMI5i/pzqYIsAM=
Expand Down
55 changes: 55 additions & 0 deletions pkg/filecache/converter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package filecache

import (
"context"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
)

type v2 struct {
creds *credentials.Credentials
}

var _ aws.CredentialsProvider = &v2{}

func (p *v2) Retrieve(ctx context.Context) (aws.Credentials, error) {
val, err := p.creds.GetWithContext(ctx)
if err != nil {
return aws.Credentials{}, err
}
resp := aws.Credentials{
AccessKeyID: val.AccessKeyID,
SecretAccessKey: val.SecretAccessKey,
SessionToken: val.SessionToken,
Source: val.ProviderName,
CanExpire: false,
// Don't have account ID
}

if expiration, err := p.creds.ExpiresAt(); err != nil {
resp.CanExpire = true
resp.Expires = expiration
}
return resp, nil
}

// V1ProviderToV2Provider converts a v1 credentials.Provider to a v2 aws.CredentialsProvider
func V1ProviderToV2Provider(p credentials.Provider) aws.CredentialsProvider {
return V1CredentialToV2Provider(credentials.NewCredentials(p))
}

// V1CredentialToV2Provider converts a v1 credentials.Credential to a v2 aws.CredentialProvider
func V1CredentialToV2Provider(c *credentials.Credentials) aws.CredentialsProvider {
return &v2{creds: c}
}

// V2CredentialToV1Value converts a v2 aws.Credentials to a v1 credentials.Value
func V2CredentialToV1Value(cred aws.Credentials) credentials.Value {
return credentials.Value{
AccessKeyID: cred.AccessKeyID,
SecretAccessKey: cred.SecretAccessKey,
SessionToken: cred.SessionToken,
ProviderName: cred.Source,
}
}
82 changes: 34 additions & 48 deletions pkg/filecache/filecache.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"runtime"
"time"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/gofrs/flock"
"github.com/spf13/afero"
Expand All @@ -34,7 +35,7 @@ func NewFileLocker(filename string) FileLocker {
// cacheFile is a map of clusterID/roleARNs to cached credentials
type cacheFile struct {
// a map of clusterIDs/profiles/roleARNs to cachedCredentials
ClusterMap map[string]map[string]map[string]cachedCredential `yaml:"clusters"`
ClusterMap map[string]map[string]map[string]aws.Credentials `yaml:"clusters"`
}

// a utility type for dealing with compound cache keys
Expand All @@ -44,19 +45,19 @@ type cacheKey struct {
roleARN string
}

func (c *cacheFile) Put(key cacheKey, credential cachedCredential) {
func (c *cacheFile) Put(key cacheKey, credential aws.Credentials) {
if _, ok := c.ClusterMap[key.clusterID]; !ok {
// first use of this cluster id
c.ClusterMap[key.clusterID] = map[string]map[string]cachedCredential{}
c.ClusterMap[key.clusterID] = map[string]map[string]aws.Credentials{}
}
if _, ok := c.ClusterMap[key.clusterID][key.profile]; !ok {
// first use of this profile
c.ClusterMap[key.clusterID][key.profile] = map[string]cachedCredential{}
c.ClusterMap[key.clusterID][key.profile] = map[string]aws.Credentials{}
}
c.ClusterMap[key.clusterID][key.profile][key.roleARN] = credential
}

func (c *cacheFile) Get(key cacheKey) (credential cachedCredential) {
func (c *cacheFile) Get(key cacheKey) (credential aws.Credentials) {
if _, ok := c.ClusterMap[key.clusterID]; ok {
if _, ok := c.ClusterMap[key.clusterID][key.profile]; ok {
// we at least have this cluster and profile combo in the map, if no matching roleARN, map will
Expand All @@ -67,31 +68,12 @@ func (c *cacheFile) Get(key cacheKey) (credential cachedCredential) {
return
}

// cachedCredential is a single cached credential entry, along with expiration time
type cachedCredential struct {
Credential credentials.Value
Expiration time.Time
// If set will be used by IsExpired to determine the current time.
// Defaults to time.Now if CurrentTime is not set. Available for testing
// to be able to mock out the current time.
currentTime func() time.Time
}

// IsExpired determines if the cached credential has expired
func (c *cachedCredential) IsExpired() bool {
curTime := c.currentTime
if curTime == nil {
curTime = time.Now
}
return c.Expiration.Before(curTime())
}

// readCacheWhileLocked reads the contents of the credential cache and returns the
// parsed yaml as a cacheFile object. This method must be called while a shared
// lock is held on the filename.
func readCacheWhileLocked(fs afero.Fs, filename string) (cache cacheFile, err error) {
cache = cacheFile{
map[string]map[string]map[string]cachedCredential{},
map[string]map[string]map[string]aws.Credentials{},
}
data, err := afero.ReadFile(fs, filename)
if err != nil {
Expand Down Expand Up @@ -149,9 +131,9 @@ type FileCacheProvider struct {
fs afero.Fs
filelockCreator func(string) FileLocker
filename string
credentials *credentials.Credentials // the underlying implementation that has the *real* Provider
cacheKey cacheKey // cache key parameters used to create Provider
cachedCredential cachedCredential // the cached credential, if it exists
provider aws.CredentialsProvider // the underlying implementation that has the *real* Provider
cacheKey cacheKey // cache key parameters used to create Provider
cachedCredential aws.Credentials // the cached credential, if it exists
}

var _ credentials.Provider = &FileCacheProvider{}
Expand All @@ -160,18 +142,18 @@ var _ credentials.Provider = &FileCacheProvider{}
// and works with an on disk cache to speed up credential usage when the cached copy is not expired.
// If there are any problems accessing or initializing the cache, an error will be returned, and
// callers should just use the existing credentials provider.
func NewFileCacheProvider(clusterID, profile, roleARN string, creds *credentials.Credentials, opts ...FileCacheOpt) (*FileCacheProvider, error) {
if creds == nil {
func NewFileCacheProvider(clusterID, profile, roleARN string, provider aws.CredentialsProvider, opts ...FileCacheOpt) (*FileCacheProvider, error) {
if provider == nil {
return nil, errors.New("no underlying Credentials object provided")
}

resp := &FileCacheProvider{
fs: afero.NewOsFs(),
filelockCreator: NewFileLocker,
filename: defaultCacheFilename(),
credentials: creds,
provider: provider,
cacheKey: cacheKey{clusterID, profile, roleARN},
cachedCredential: cachedCredential{},
cachedCredential: aws.Credentials{},
}

// override defaults
Expand Down Expand Up @@ -222,36 +204,40 @@ func NewFileCacheProvider(clusterID, profile, roleARN string, creds *credentials
// otherwise fetching the credential from the underlying Provider and caching the results on disk
// with an expiration time.
func (f *FileCacheProvider) Retrieve() (credentials.Value, error) {
if !f.cachedCredential.IsExpired() {
return f.RetrieveWithContext(context.Background())
}

// Retrieve() implements the Provider interface, returning the cached credential if is not expired,
// otherwise fetching the credential from the underlying Provider and caching the results on disk
// with an expiration time.
func (f *FileCacheProvider) RetrieveWithContext(ctx context.Context) (credentials.Value, error) {
if !f.cachedCredential.Expired() && f.cachedCredential.HasKeys() {
// use the cached credential
return f.cachedCredential.Credential, nil
return V2CredentialToV1Value(f.cachedCredential), nil
} else {
_, _ = fmt.Fprintf(os.Stderr, "No cached credential available. Refreshing...\n")
// fetch the credentials from the underlying Provider
credential, err := f.credentials.Get()
credential, err := f.provider.Retrieve(ctx)
if err != nil {
return credential, err
return V2CredentialToV1Value(credential), err
}
if expiration, err := f.credentials.ExpiresAt(); err == nil {
// underlying provider supports Expirer interface, so we can cache

if credential.CanExpire {
// Credential supports expiration, so we can cache

// do file locking on cache to prevent inconsistent writes
lock := f.filelockCreator(f.filename)
defer lock.Unlock()
// wait up to a second for the file to lock
ctx, cancel := context.WithTimeout(context.TODO(), time.Second)
ctx, cancel := context.WithTimeout(ctx, time.Second)
defer cancel()
ok, err := lock.TryLockContext(ctx, 250*time.Millisecond) // try to lock every 1/4 second
if !ok {
// can't get write lock to create/update cache, but still return the credential
_, _ = fmt.Fprintf(os.Stderr, "Unable to write lock file %s: %v\n", f.filename, err)
return credential, nil
}
f.cachedCredential = cachedCredential{
credential,
expiration,
nil,
return V2CredentialToV1Value(credential), nil
}
f.cachedCredential = credential
// don't really care about read error. Either read the cache, or we create a new cache.
cache, _ := readCacheWhileLocked(f.fs, f.filename)
cache.Put(f.cacheKey, f.cachedCredential)
Expand All @@ -268,19 +254,19 @@ func (f *FileCacheProvider) Retrieve() (credentials.Value, error) {
_, _ = fmt.Fprintf(os.Stderr, "Unable to cache credential: %v\n", err)
err = nil
}
return credential, err
return V2CredentialToV1Value(credential), err
}
}

// IsExpired() implements the Provider interface, deferring to the cached credential first,
// but fall back to the underlying Provider if it is expired.
func (f *FileCacheProvider) IsExpired() bool {
return f.cachedCredential.IsExpired() && f.credentials.IsExpired()
return f.cachedCredential.CanExpire && f.cachedCredential.Expired()
}

// ExpiresAt implements the Expirer interface, and gives access to the expiration time of the credential
func (f *FileCacheProvider) ExpiresAt() time.Time {
return f.cachedCredential.Expiration
return f.cachedCredential.Expires
}

// defaultCacheFilename returns the name of the credential cache file, which can either be
Expand Down
Loading

0 comments on commit 69ff02d

Please sign in to comment.