Skip to content

Commit

Permalink
Merge pull request #756 from micahhausler/filecache-sdk-upgrade
Browse files Browse the repository at this point in the history
Filecache sdk upgrade
  • Loading branch information
k8s-ci-robot authored Sep 6, 2024
2 parents 98eb3f6 + 9cdf38d commit 90beff7
Show file tree
Hide file tree
Showing 8 changed files with 246 additions and 193 deletions.
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ go 1.22.5

require (
github.com/aws/aws-sdk-go v1.54.6
github.com/aws/aws-sdk-go-v2 v1.30.4
github.com/fsnotify/fsnotify v1.7.0
github.com/gofrs/flock v0.8.1
github.com/google/go-cmp v0.6.0
Expand All @@ -25,6 +26,7 @@ require (
)

require (
github.com/aws/smithy-go v1.20.4 // indirect
github.com/beorn7/perks v1.0.1 // indirect
github.com/blang/semver/v4 v4.0.0 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
Expand Down
4 changes: 4 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
github.com/aws/aws-sdk-go v1.54.6 h1:HEYUib3yTt8E6vxjMWM3yAq5b+qjj/6aKA62mkgux9g=
github.com/aws/aws-sdk-go v1.54.6/go.mod h1:eRwEWoyTWFMVYVQzKMNHWP5/RV4xIUGMQfXQHfHkpNU=
github.com/aws/aws-sdk-go-v2 v1.30.4 h1:frhcagrVNrzmT95RJImMHgabt99vkXGslubDaDagTk8=
github.com/aws/aws-sdk-go-v2 v1.30.4/go.mod h1:CT+ZPWXbYrci8chcARI3OmI/qgd+f6WtuLOoaIA8PR0=
github.com/aws/smithy-go v1.20.4 h1:2HK1zBdPgRbjFOHlfeQZfpC4r72MOb9bZkiFwggKO+4=
github.com/aws/smithy-go v1.20.4/go.mod h1:irrKGvNn1InZwb2d7fkIRNucdfwR8R+Ts3wxYa/cJHg=
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw=
github.com/blang/semver/v4 v4.0.0 h1:1PFHFE6yCCTv8C1TeyNNarDzntLi7wMI5i/pzqYIsAM=
Expand Down
55 changes: 55 additions & 0 deletions pkg/filecache/converter.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package filecache

import (
"context"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
)

type v2 struct {
creds *credentials.Credentials
}

var _ aws.CredentialsProvider = &v2{}

func (p *v2) Retrieve(ctx context.Context) (aws.Credentials, error) {
val, err := p.creds.GetWithContext(ctx)
if err != nil {
return aws.Credentials{}, err
}
resp := aws.Credentials{
AccessKeyID: val.AccessKeyID,
SecretAccessKey: val.SecretAccessKey,
SessionToken: val.SessionToken,
Source: val.ProviderName,
CanExpire: false,
// Don't have account ID
}

if expiration, err := p.creds.ExpiresAt(); err != nil {
resp.CanExpire = true
resp.Expires = expiration
}
return resp, nil
}

// V1ProviderToV2Provider converts a v1 credentials.Provider to a v2 aws.CredentialsProvider
func V1ProviderToV2Provider(p credentials.Provider) aws.CredentialsProvider {
return V1CredentialToV2Provider(credentials.NewCredentials(p))
}

// V1CredentialToV2Provider converts a v1 credentials.Credential to a v2 aws.CredentialProvider
func V1CredentialToV2Provider(c *credentials.Credentials) aws.CredentialsProvider {
return &v2{creds: c}
}

// V2CredentialToV1Value converts a v2 aws.Credentials to a v1 credentials.Value
func V2CredentialToV1Value(cred aws.Credentials) credentials.Value {
return credentials.Value{
AccessKeyID: cred.AccessKeyID,
SecretAccessKey: cred.SecretAccessKey,
SessionToken: cred.SessionToken,
ProviderName: cred.Source,
}
}
82 changes: 34 additions & 48 deletions pkg/filecache/filecache.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"runtime"
"time"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/gofrs/flock"
"github.com/spf13/afero"
Expand All @@ -34,7 +35,7 @@ func NewFileLocker(filename string) FileLocker {
// cacheFile is a map of clusterID/roleARNs to cached credentials
type cacheFile struct {
// a map of clusterIDs/profiles/roleARNs to cachedCredentials
ClusterMap map[string]map[string]map[string]cachedCredential `yaml:"clusters"`
ClusterMap map[string]map[string]map[string]aws.Credentials `yaml:"clusters"`
}

// a utility type for dealing with compound cache keys
Expand All @@ -44,19 +45,19 @@ type cacheKey struct {
roleARN string
}

func (c *cacheFile) Put(key cacheKey, credential cachedCredential) {
func (c *cacheFile) Put(key cacheKey, credential aws.Credentials) {
if _, ok := c.ClusterMap[key.clusterID]; !ok {
// first use of this cluster id
c.ClusterMap[key.clusterID] = map[string]map[string]cachedCredential{}
c.ClusterMap[key.clusterID] = map[string]map[string]aws.Credentials{}
}
if _, ok := c.ClusterMap[key.clusterID][key.profile]; !ok {
// first use of this profile
c.ClusterMap[key.clusterID][key.profile] = map[string]cachedCredential{}
c.ClusterMap[key.clusterID][key.profile] = map[string]aws.Credentials{}
}
c.ClusterMap[key.clusterID][key.profile][key.roleARN] = credential
}

func (c *cacheFile) Get(key cacheKey) (credential cachedCredential) {
func (c *cacheFile) Get(key cacheKey) (credential aws.Credentials) {
if _, ok := c.ClusterMap[key.clusterID]; ok {
if _, ok := c.ClusterMap[key.clusterID][key.profile]; ok {
// we at least have this cluster and profile combo in the map, if no matching roleARN, map will
Expand All @@ -67,31 +68,12 @@ func (c *cacheFile) Get(key cacheKey) (credential cachedCredential) {
return
}

// cachedCredential is a single cached credential entry, along with expiration time
type cachedCredential struct {
Credential credentials.Value
Expiration time.Time
// If set will be used by IsExpired to determine the current time.
// Defaults to time.Now if CurrentTime is not set. Available for testing
// to be able to mock out the current time.
currentTime func() time.Time
}

// IsExpired determines if the cached credential has expired
func (c *cachedCredential) IsExpired() bool {
curTime := c.currentTime
if curTime == nil {
curTime = time.Now
}
return c.Expiration.Before(curTime())
}

// readCacheWhileLocked reads the contents of the credential cache and returns the
// parsed yaml as a cacheFile object. This method must be called while a shared
// lock is held on the filename.
func readCacheWhileLocked(fs afero.Fs, filename string) (cache cacheFile, err error) {
cache = cacheFile{
map[string]map[string]map[string]cachedCredential{},
map[string]map[string]map[string]aws.Credentials{},
}
data, err := afero.ReadFile(fs, filename)
if err != nil {
Expand Down Expand Up @@ -149,9 +131,9 @@ type FileCacheProvider struct {
fs afero.Fs
filelockCreator func(string) FileLocker
filename string
credentials *credentials.Credentials // the underlying implementation that has the *real* Provider
cacheKey cacheKey // cache key parameters used to create Provider
cachedCredential cachedCredential // the cached credential, if it exists
provider aws.CredentialsProvider // the underlying implementation that has the *real* Provider
cacheKey cacheKey // cache key parameters used to create Provider
cachedCredential aws.Credentials // the cached credential, if it exists
}

var _ credentials.Provider = &FileCacheProvider{}
Expand All @@ -160,18 +142,18 @@ var _ credentials.Provider = &FileCacheProvider{}
// and works with an on disk cache to speed up credential usage when the cached copy is not expired.
// If there are any problems accessing or initializing the cache, an error will be returned, and
// callers should just use the existing credentials provider.
func NewFileCacheProvider(clusterID, profile, roleARN string, creds *credentials.Credentials, opts ...FileCacheOpt) (*FileCacheProvider, error) {
if creds == nil {
func NewFileCacheProvider(clusterID, profile, roleARN string, provider aws.CredentialsProvider, opts ...FileCacheOpt) (*FileCacheProvider, error) {
if provider == nil {
return nil, errors.New("no underlying Credentials object provided")
}

resp := &FileCacheProvider{
fs: afero.NewOsFs(),
filelockCreator: NewFileLocker,
filename: defaultCacheFilename(),
credentials: creds,
provider: provider,
cacheKey: cacheKey{clusterID, profile, roleARN},
cachedCredential: cachedCredential{},
cachedCredential: aws.Credentials{},
}

// override defaults
Expand Down Expand Up @@ -222,36 +204,40 @@ func NewFileCacheProvider(clusterID, profile, roleARN string, creds *credentials
// otherwise fetching the credential from the underlying Provider and caching the results on disk
// with an expiration time.
func (f *FileCacheProvider) Retrieve() (credentials.Value, error) {
if !f.cachedCredential.IsExpired() {
return f.RetrieveWithContext(context.Background())
}

// Retrieve() implements the Provider interface, returning the cached credential if is not expired,
// otherwise fetching the credential from the underlying Provider and caching the results on disk
// with an expiration time.
func (f *FileCacheProvider) RetrieveWithContext(ctx context.Context) (credentials.Value, error) {
if !f.cachedCredential.Expired() && f.cachedCredential.HasKeys() {
// use the cached credential
return f.cachedCredential.Credential, nil
return V2CredentialToV1Value(f.cachedCredential), nil
} else {
_, _ = fmt.Fprintf(os.Stderr, "No cached credential available. Refreshing...\n")
// fetch the credentials from the underlying Provider
credential, err := f.credentials.Get()
credential, err := f.provider.Retrieve(ctx)
if err != nil {
return credential, err
return V2CredentialToV1Value(credential), err
}
if expiration, err := f.credentials.ExpiresAt(); err == nil {
// underlying provider supports Expirer interface, so we can cache

if credential.CanExpire {
// Credential supports expiration, so we can cache

// do file locking on cache to prevent inconsistent writes
lock := f.filelockCreator(f.filename)
defer lock.Unlock()
// wait up to a second for the file to lock
ctx, cancel := context.WithTimeout(context.TODO(), time.Second)
ctx, cancel := context.WithTimeout(ctx, time.Second)
defer cancel()
ok, err := lock.TryLockContext(ctx, 250*time.Millisecond) // try to lock every 1/4 second
if !ok {
// can't get write lock to create/update cache, but still return the credential
_, _ = fmt.Fprintf(os.Stderr, "Unable to write lock file %s: %v\n", f.filename, err)
return credential, nil
}
f.cachedCredential = cachedCredential{
credential,
expiration,
nil,
return V2CredentialToV1Value(credential), nil
}
f.cachedCredential = credential
// don't really care about read error. Either read the cache, or we create a new cache.
cache, _ := readCacheWhileLocked(f.fs, f.filename)
cache.Put(f.cacheKey, f.cachedCredential)
Expand All @@ -268,19 +254,19 @@ func (f *FileCacheProvider) Retrieve() (credentials.Value, error) {
_, _ = fmt.Fprintf(os.Stderr, "Unable to cache credential: %v\n", err)
err = nil
}
return credential, err
return V2CredentialToV1Value(credential), err
}
}

// IsExpired() implements the Provider interface, deferring to the cached credential first,
// but fall back to the underlying Provider if it is expired.
func (f *FileCacheProvider) IsExpired() bool {
return f.cachedCredential.IsExpired() && f.credentials.IsExpired()
return f.cachedCredential.CanExpire && f.cachedCredential.Expired()
}

// ExpiresAt implements the Expirer interface, and gives access to the expiration time of the credential
func (f *FileCacheProvider) ExpiresAt() time.Time {
return f.cachedCredential.Expiration
return f.cachedCredential.Expires
}

// defaultCacheFilename returns the name of the credential cache file, which can either be
Expand Down
Loading

0 comments on commit 90beff7

Please sign in to comment.