diff --git a/pkg/token/token.go b/pkg/token/token.go index 16ab8d92b..351049a17 100644 --- a/pkg/token/token.go +++ b/pkg/token/token.go @@ -316,6 +316,11 @@ func (g generator) GetWithSTS(clusterID string, stsAPI stsiface.STSAPI) (Token, // override the Sign handler so we can control the now time for testing. request.Handlers.Sign.Swap("v4.SignRequestHandler", getNamedSigningHandler(g.nowFunc)) + // Fetch the timestamp when the credentials we're going to use for signing will not be valid anymore + // This operation is potentially racey, but the worst case is that we expire a token early + // Not all credential providers support this, so we ignore any returned errors + credentialsExpiration, _ := request.Config.Credentials.ExpiresAt() + // Sign the request. The expires parameter (sets the x-amz-expires header) is // currently ignored by STS, and the token expires 15 minutes after the x-amz-date // timestamp regardless. We set it to 60 seconds for backwards compatibility (the @@ -329,6 +334,9 @@ func (g generator) GetWithSTS(clusterID string, stsAPI stsiface.STSAPI) (Token, // Set token expiration to 1 minute before the presigned URL expires for some cushion tokenExpiration := g.nowFunc().Local().Add(presignedURLExpiration - 1*time.Minute) + if !credentialsExpiration.IsZero() && credentialsExpiration.Before(tokenExpiration) { + tokenExpiration = credentialsExpiration.Add(-1 * time.Minute) + } // TODO: this may need to be a constant-time base64 encoding return Token{v1Prefix + base64.RawURLEncoding.EncodeToString([]byte(presignedURLString)), tokenExpiration}, nil } diff --git a/pkg/token/token_test.go b/pkg/token/token_test.go index a8e997c86..7487f469b 100644 --- a/pkg/token/token_test.go +++ b/pkg/token/token_test.go @@ -590,6 +590,10 @@ func Test_getDefaultHostNameForRegion(t *testing.T) { func TestGetWithSTS(t *testing.T) { clusterID := "test-cluster" + // Example non-real credentials + decodedAkid, _ := base64.StdEncoding.DecodeString("QVNJQVIyVEc0NFY2QVMzWlpFN0M=") + decodedSk, _ := base64.StdEncoding.DecodeString("NEtENWNudEdjVm1MV1JkRjV3dk5SdXpOTDVReG1wNk9LVlk2RnovUQ==") + cases := []struct { name string creds *credentials.Credentials @@ -598,23 +602,39 @@ func TestGetWithSTS(t *testing.T) { wantErr error }{ { - "Non-zero time", - // Example non-real credentials - func() *credentials.Credentials { - decodedAkid, _ := base64.StdEncoding.DecodeString("QVNJQVIyVEc0NFY2QVMzWlpFN0M=") - decodedSk, _ := base64.StdEncoding.DecodeString("NEtENWNudEdjVm1MV1JkRjV3dk5SdXpOTDVReG1wNk9LVlk2RnovUQ==") - return credentials.NewStaticCredentials( - string(decodedAkid), - string(decodedSk), - "", - ) - }(), - time.Unix(1682640000, 0), - Token{ + name: "Non-zero time", + creds: credentials.NewStaticCredentials( + string(decodedAkid), + string(decodedSk), + "", + ), + nowTime: time.Unix(1682640000, 0), + want: Token{ Token: "k8s-aws-v1.aHR0cHM6Ly9zdHMudXMtd2VzdC0yLmFtYXpvbmF3cy5jb20vP0FjdGlvbj1HZXRDYWxsZXJJZGVudGl0eSZWZXJzaW9uPTIwMTEtMDYtMTUmWC1BbXotQWxnb3JpdGhtPUFXUzQtSE1BQy1TSEEyNTYmWC1BbXotQ3JlZGVudGlhbD1BU0lBUjJURzQ0VjZBUzNaWkU3QyUyRjIwMjMwNDI4JTJGdXMtd2VzdC0yJTJGc3RzJTJGYXdzNF9yZXF1ZXN0JlgtQW16LURhdGU9MjAyMzA0MjhUMDAwMDAwWiZYLUFtei1FeHBpcmVzPTYwJlgtQW16LVNpZ25lZEhlYWRlcnM9aG9zdCUzQngtazhzLWF3cy1pZCZYLUFtei1TaWduYXR1cmU9ZTIxMWRiYTc3YWJhOWRjNDRiMGI2YmUzOGI4ZWFhZDA5MjU5OWM1MTU3ZjYzMTQ0NDRjNWI5ZDg1NzQ3ZjVjZQ", Expiration: time.Unix(1682640000, 0).Local().Add(time.Minute * 14), }, - nil, + wantErr: nil, + }, + { + name: "Signing creds expire before token", + creds: func() *credentials.Credentials { + + c := credentials.NewCredentials(&fakeCredentialProvider{ + value: credentials.Value{ + AccessKeyID: string(decodedAkid), + SecretAccessKey: string(decodedSk), + }, + expiresAt: time.Unix(1682640000, 0).Local().Add(time.Minute * 10), + }) + _, _ = c.Get() + return c + }(), + nowTime: time.Unix(1682640000, 0), + want: Token{ + Token: "k8s-aws-v1.aHR0cHM6Ly9zdHMudXMtd2VzdC0yLmFtYXpvbmF3cy5jb20vP0FjdGlvbj1HZXRDYWxsZXJJZGVudGl0eSZWZXJzaW9uPTIwMTEtMDYtMTUmWC1BbXotQWxnb3JpdGhtPUFXUzQtSE1BQy1TSEEyNTYmWC1BbXotQ3JlZGVudGlhbD1BU0lBUjJURzQ0VjZBUzNaWkU3QyUyRjIwMjMwNDI4JTJGdXMtd2VzdC0yJTJGc3RzJTJGYXdzNF9yZXF1ZXN0JlgtQW16LURhdGU9MjAyMzA0MjhUMDAwMDAwWiZYLUFtei1FeHBpcmVzPTYwJlgtQW16LVNpZ25lZEhlYWRlcnM9aG9zdCUzQngtazhzLWF3cy1pZCZYLUFtei1TaWduYXR1cmU9ZTIxMWRiYTc3YWJhOWRjNDRiMGI2YmUzOGI4ZWFhZDA5MjU5OWM1MTU3ZjYzMTQ0NDRjNWI5ZDg1NzQ3ZjVjZQ", + Expiration: time.Unix(1682640000, 0).Local().Add(time.Minute * 9), + }, + wantErr: nil, }, } @@ -646,3 +666,22 @@ func TestGetWithSTS(t *testing.T) { }) } } + +type fakeCredentialProvider struct { + value credentials.Value + expiresAt time.Time +} + +func (f *fakeCredentialProvider) Retrieve() (credentials.Value, error) { + return f.value, nil +} + +func (f *fakeCredentialProvider) IsExpired() bool { + return false +} + +var _ credentials.Expirer = (*fakeCredentialProvider)(nil) + +func (f *fakeCredentialProvider) ExpiresAt() time.Time { + return f.expiresAt +}