From 51bf911b467700e2910625aeb9a59836f85faa9a Mon Sep 17 00:00:00 2001 From: Micah Hausler Date: Mon, 26 Aug 2024 12:53:24 -0500 Subject: [PATCH] WIP: Update to aws-sdk-go-v2 --- cmd/aws-iam-authenticator/verify.go | 16 ++- go.mod | 15 ++- go.sum | 28 ++++- pkg/arn/arn.go | 2 +- pkg/ec2provider/ec2provider.go | 143 +++++++++--------------- pkg/ec2provider/ec2provider_test.go | 79 ++++--------- pkg/ec2provider/source_headers.go | 66 +++++++++++ pkg/filecache/converter.go | 55 --------- pkg/filecache/filecache.go | 35 ++---- pkg/filecache/filecache_test.go | 29 ++--- pkg/server/server.go | 44 +++++--- pkg/token/cluster_id_header.go | 45 ++++++++ pkg/token/token.go | 167 ++++++++++++++-------------- pkg/token/token_test.go | 42 ++++--- 14 files changed, 390 insertions(+), 376 deletions(-) create mode 100644 pkg/ec2provider/source_headers.go delete mode 100644 pkg/filecache/converter.go create mode 100644 pkg/token/cluster_id_header.go diff --git a/cmd/aws-iam-authenticator/verify.go b/cmd/aws-iam-authenticator/verify.go index 9bb37f4f4..839b19ef8 100644 --- a/cmd/aws-iam-authenticator/verify.go +++ b/cmd/aws-iam-authenticator/verify.go @@ -25,9 +25,9 @@ import ( "sigs.k8s.io/aws-iam-authenticator/pkg/token" - "github.com/aws/aws-sdk-go/aws/ec2metadata" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/feature/ec2/imds" "github.com/aws/aws-sdk-go/aws/endpoints" - "github.com/aws/aws-sdk-go/aws/session" "github.com/spf13/cobra" "github.com/spf13/viper" ) @@ -54,14 +54,18 @@ var verifyCmd = &cobra.Command{ os.Exit(1) } - sess := session.Must(session.NewSession()) - ec2metadata := ec2metadata.New(sess) - instanceRegion, err := ec2metadata.Region() + cfg, err := config.LoadDefaultConfig(cmd.Context()) + if err != nil { + fmt.Printf("Error constructing aws config: %v", err) + os.Exit(1) + } + client := imds.NewFromConfig(cfg) + resp, err := client.GetRegion(cmd.Context(), nil) if err != nil { fmt.Printf("[Warn] Region not found in instance metadata, err: %v", err) } - id, err := token.NewVerifier(clusterID, partition, instanceRegion).Verify(tok) + id, err := token.NewVerifier(clusterID, partition, resp.Region).Verify(tok) if err != nil { fmt.Fprintf(os.Stderr, "could not verify token: %v\n", err) os.Exit(1) diff --git a/go.mod b/go.mod index 58736482f..8fbd7f9a3 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,12 @@ go 1.22.5 require ( github.com/aws/aws-sdk-go v1.54.6 github.com/aws/aws-sdk-go-v2 v1.30.4 + github.com/aws/aws-sdk-go-v2/config v1.27.30 + github.com/aws/aws-sdk-go-v2/credentials v1.17.29 + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.12 + github.com/aws/aws-sdk-go-v2/service/ec2 v1.176.0 + github.com/aws/aws-sdk-go-v2/service/sts v1.30.5 + github.com/aws/smithy-go v1.20.4 github.com/fsnotify/fsnotify v1.7.0 github.com/gofrs/flock v0.8.1 github.com/google/go-cmp v0.6.0 @@ -26,13 +32,20 @@ require ( ) require ( + github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.16 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.16 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.4 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.18 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.22.5 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.26.5 // indirect github.com/aws/smithy-go v1.20.4 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/blang/semver/v4 v4.0.0 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect - github.com/emicklei/go-restful/v3 v3.11.1 // indirect + github.com/emicklei/go-restful/v3 v3.11.3 // indirect github.com/fxamacker/cbor/v2 v2.7.0 // indirect github.com/go-logr/logr v1.4.2 // indirect github.com/go-openapi/jsonpointer v0.20.2 // indirect diff --git a/go.sum b/go.sum index 9d1fc5538..05290a6e2 100644 --- a/go.sum +++ b/go.sum @@ -2,6 +2,30 @@ github.com/aws/aws-sdk-go v1.54.6 h1:HEYUib3yTt8E6vxjMWM3yAq5b+qjj/6aKA62mkgux9g github.com/aws/aws-sdk-go v1.54.6/go.mod h1:eRwEWoyTWFMVYVQzKMNHWP5/RV4xIUGMQfXQHfHkpNU= github.com/aws/aws-sdk-go-v2 v1.30.4 h1:frhcagrVNrzmT95RJImMHgabt99vkXGslubDaDagTk8= github.com/aws/aws-sdk-go-v2 v1.30.4/go.mod h1:CT+ZPWXbYrci8chcARI3OmI/qgd+f6WtuLOoaIA8PR0= +github.com/aws/aws-sdk-go-v2/config v1.27.30 h1:AQF3/+rOgeJBQP3iI4vojlPib5X6eeOYoa/af7OxAYg= +github.com/aws/aws-sdk-go-v2/config v1.27.30/go.mod h1:yxqvuubha9Vw8stEgNiStO+yZpP68Wm9hLmcm+R/Qk4= +github.com/aws/aws-sdk-go-v2/credentials v1.17.29 h1:CwGsupsXIlAFYuDVHv1nnK0wnxO0wZ/g1L8DSK/xiIw= +github.com/aws/aws-sdk-go-v2/credentials v1.17.29/go.mod h1:BPJ/yXV92ZVq6G8uYvbU0gSl8q94UB63nMT5ctNO38g= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.12 h1:yjwoSyDZF8Jth+mUk5lSPJCkMC0lMy6FaCD51jm6ayE= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.12/go.mod h1:fuR57fAgMk7ot3WcNQfb6rSEn+SUffl7ri+aa8uKysI= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.16 h1:TNyt/+X43KJ9IJJMjKfa3bNTiZbUP7DeCxfbTROESwY= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.3.16/go.mod h1:2DwJF39FlNAUiX5pAc0UNeiz16lK2t7IaFcm0LFHEgc= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.16 h1:jYfy8UPmd+6kJW5YhY0L1/KftReOGxI/4NtVSTh9O/I= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.6.16/go.mod h1:7ZfEPZxkW42Afq4uQB8H2E2e6ebh6mXTueEpYzjCzcs= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1 h1:VaRN3TlFdd6KxX1x3ILT5ynH6HvKgqdiXoTxAF4HQcQ= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1/go.mod h1:FbtygfRFze9usAadmnGJNc8KsP346kEe+y2/oyhGAGc= +github.com/aws/aws-sdk-go-v2/service/ec2 v1.176.0 h1:fWhkSvaQqa5eWiRwBw10FUnk1YatAQ9We4GdGxKiCtg= +github.com/aws/aws-sdk-go-v2/service/ec2 v1.176.0/go.mod h1:ISODge3zgdwOEa4Ou6WM9PKbxJWJ15DYKnr2bfmCAIA= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.4 h1:KypMCbLPPHEmf9DgMGw51jMj77VfGPAN2Kv4cfhlfgI= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.4/go.mod h1:Vz1JQXliGcQktFTN/LN6uGppAIRoLBR2bMvIMP0gOjc= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.18 h1:tJ5RnkHCiSH0jyd6gROjlJtNwov0eGYNz8s8nFcR0jQ= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.18/go.mod h1:++NHzT+nAF7ZPrHPsA+ENvsXkOO8wEu+C6RXltAG4/c= +github.com/aws/aws-sdk-go-v2/service/sso v1.22.5 h1:zCsFCKvbj25i7p1u94imVoO447I/sFv8qq+lGJhRN0c= +github.com/aws/aws-sdk-go-v2/service/sso v1.22.5/go.mod h1:ZeDX1SnKsVlejeuz41GiajjZpRSWR7/42q/EyA/QEiM= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.26.5 h1:SKvPgvdvmiTWoi0GAJ7AsJfOz3ngVkD/ERbs5pUnHNI= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.26.5/go.mod h1:20sz31hv/WsPa3HhU3hfrIet2kxM4Pe0r20eBZ20Tac= +github.com/aws/aws-sdk-go-v2/service/sts v1.30.5 h1:OMsEmCyz2i89XwRwPouAJvhj81wINh+4UK+k/0Yo/q8= +github.com/aws/aws-sdk-go-v2/service/sts v1.30.5/go.mod h1:vmSqFK+BVIwVpDAGZB3CoCXHzurt4qBE8lf+I/kRTh0= github.com/aws/smithy-go v1.20.4 h1:2HK1zBdPgRbjFOHlfeQZfpC4r72MOb9bZkiFwggKO+4= github.com/aws/smithy-go v1.20.4/go.mod h1:irrKGvNn1InZwb2d7fkIRNucdfwR8R+Ts3wxYa/cJHg= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= @@ -21,8 +45,8 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/emicklei/go-restful/v3 v3.11.1 h1:S+9bSbua1z3FgCnV0KKOSSZ3mDthb5NyEPL5gEpCvyk= -github.com/emicklei/go-restful/v3 v3.11.1/go.mod h1:6n3XBCmQQb25CM2LCACGz8ukIrRry+4bhvbpWn3mrbc= +github.com/emicklei/go-restful/v3 v3.11.3 h1:yagOQz/38xJmcNeZJtrUcKjkHRltIaIFXKWeG1SkWGE= +github.com/emicklei/go-restful/v3 v3.11.3/go.mod h1:6n3XBCmQQb25CM2LCACGz8ukIrRry+4bhvbpWn3mrbc= github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= diff --git a/pkg/arn/arn.go b/pkg/arn/arn.go index e9b73b587..22900c96d 100644 --- a/pkg/arn/arn.go +++ b/pkg/arn/arn.go @@ -4,7 +4,7 @@ import ( "fmt" "strings" - awsarn "github.com/aws/aws-sdk-go/aws/arn" + awsarn "github.com/aws/aws-sdk-go-v2/aws/arn" "github.com/aws/aws-sdk-go/aws/endpoints" ) diff --git a/pkg/ec2provider/ec2provider.go b/pkg/ec2provider/ec2provider.go index d760f0bda..5d18ec720 100644 --- a/pkg/ec2provider/ec2provider.go +++ b/pkg/ec2provider/ec2provider.go @@ -1,25 +1,24 @@ package ec2provider import ( + "context" "errors" "fmt" "sync" "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/arn" - "github.com/aws/aws-sdk-go/aws/credentials" - "github.com/aws/aws-sdk-go/aws/credentials/stscreds" - "github.com/aws/aws-sdk-go/aws/endpoints" - "github.com/aws/aws-sdk-go/aws/request" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/ec2" - "github.com/aws/aws-sdk-go/service/ec2/ec2iface" - "github.com/aws/aws-sdk-go/service/sts" - "github.com/sirupsen/logrus" "sigs.k8s.io/aws-iam-authenticator/pkg" "sigs.k8s.io/aws-iam-authenticator/pkg/httputil" "sigs.k8s.io/aws-iam-authenticator/pkg/metrics" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/aws/middleware" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/credentials/stscreds" + "github.com/aws/aws-sdk-go-v2/service/ec2" + "github.com/aws/aws-sdk-go-v2/service/sts" + smithymiddleware "github.com/aws/smithy-go/middleware" + "github.com/sirupsen/logrus" ) const ( @@ -36,11 +35,6 @@ const ( // Maximum time in Milliseconds to wait for a new batch call this also depends on if the instance size has // already become 100 then it will not respect this limit maxWaitIntervalForBatch = 200 - - // Headers for STS request for source ARN - headerSourceArn = "x-amz-source-arn" - // Headers for STS request for source account - headerSourceAccount = "x-amz-source-account" ) // Get a node name from instance ID @@ -60,13 +54,13 @@ type ec2Requests struct { } type ec2ProviderImpl struct { - ec2 ec2iface.EC2API + ec2 ec2.DescribeInstancesAPIClient privateDNSCache ec2PrivateDNSCache ec2Requests ec2Requests instanceIdsChannel chan string } -func New(roleARN, sourceARN, region string, qps int, burst int) EC2Provider { +func New(roleARN, sourceARN, region string, qps int, burst int) (EC2Provider, error) { dnsCache := ec2PrivateDNSCache{ cache: make(map[string]string), lock: sync.RWMutex{}, @@ -75,50 +69,56 @@ func New(roleARN, sourceARN, region string, qps int, burst int) EC2Provider { set: make(map[string]bool), lock: sync.RWMutex{}, } + cfg, err := newConfig(roleARN, sourceARN, region, qps, burst) + if err != nil { + return nil, err + } + return &ec2ProviderImpl{ - ec2: ec2.New(newSession(roleARN, sourceARN, region, qps, burst)), + ec2: ec2.NewFromConfig(cfg), privateDNSCache: dnsCache, ec2Requests: ec2Requests, instanceIdsChannel: make(chan string, maxChannelSize), - } + }, nil } -// Initial credentials loaded from SDK's default credential chain, such as -// the environment, shared credentials (~/.aws/credentials), or EC2 Instance -// Role. - -func newSession(roleARN, sourceARN, region string, qps int, burst int) *session.Session { - sess := session.Must(session.NewSession()) - sess.Handlers.Build.PushFrontNamed(request.NamedHandler{ - Name: "authenticatorUserAgent", - Fn: request.MakeAddToUserAgentHandler( - "aws-iam-authenticator", pkg.Version), - }) - if aws.StringValue(sess.Config.Region) == "" { - sess.Config.Region = aws.String(region) +func newConfig(roleARN, sourceArn, region string, qps, burst int) (aws.Config, error) { + rateLimitedClient, err := httputil.NewRateLimitedClient(qps, burst) + if err != nil { + logrus.Errorf("error creating rate limited client %s", err) + return aws.Config{}, err + } + loadOpts := []func(*config.LoadOptions) error{ + config.WithRegion(region), + config.WithAPIOptions( + []func(*smithymiddleware.Stack) error{ + middleware.AddUserAgentKeyValue("aws-iam-authenticator", pkg.Version), + }), + config.WithHTTPClient(rateLimitedClient), } - if roleARN != "" { logrus.WithFields(logrus.Fields{ "roleARN": roleARN, }).Infof("Using assumed role for EC2 API") - rateLimitedClient, err := httputil.NewRateLimitedClient(qps, burst) - + cfg, err := config.LoadDefaultConfig(context.Background(), loadOpts...) if err != nil { - logrus.Errorf("Getting error = %s while creating rate limited client ", err) + logrus.Errorf("error loading AWS config %s", err) + return aws.Config{}, err } - - stsClient := applySTSRequestHeaders(sts.New(sess, aws.NewConfig().WithHTTPClient(rateLimitedClient).WithSTSRegionalEndpoint(endpoints.RegionalSTSEndpoint)), sourceARN) - ap := &stscreds.AssumeRoleProvider{ - Client: stsClient, - RoleARN: roleARN, - Duration: time.Duration(60) * time.Minute, + stsOpts := []func(*sts.Options){} + if sourceArn != "" { + stsOpts = append(stsOpts, WithSourceHeaders(sourceArn)) } - sess.Config.Credentials = credentials.NewCredentials(ap) + stsCli := sts.NewFromConfig(cfg, stsOpts...) + creds := stscreds.NewAssumeRoleProvider(stsCli, roleARN, + func(o *stscreds.AssumeRoleOptions) { + o.Duration = time.Duration(60) * time.Minute + }) + loadOpts = append(loadOpts, config.WithCredentialsProvider(creds)) } - return sess + return config.LoadDefaultConfig(context.Background(), loadOpts...) } func (p *ec2ProviderImpl) setPrivateDNSNameCache(id string, privateDNSName string) { @@ -197,8 +197,8 @@ func (p *ec2ProviderImpl) GetPrivateDNSName(id string) (string, error) { logrus.Infof("Calling ec2:DescribeInstances for the InstanceId = %s ", id) metrics.Get().EC2DescribeInstanceCallCount.Inc() // Look up instance from EC2 API - output, err := p.ec2.DescribeInstances(&ec2.DescribeInstancesInput{ - InstanceIds: aws.StringSlice([]string{id}), + output, err := p.ec2.DescribeInstances(context.Background(), &ec2.DescribeInstancesInput{ + InstanceIds: []string{id}, }) if err != nil { p.unsetRequestInFlightForInstanceId(id) @@ -206,8 +206,8 @@ func (p *ec2ProviderImpl) GetPrivateDNSName(id string) (string, error) { } for _, reservation := range output.Reservations { for _, instance := range reservation.Instances { - if aws.StringValue(instance.InstanceId) == id { - privateDNSName = aws.StringValue(instance.PrivateDnsName) + if aws.ToString(instance.InstanceId) == id { + privateDNSName = aws.ToString(instance.PrivateDnsName) p.setPrivateDNSNameCache(id, privateDNSName) p.unsetRequestInFlightForInstanceId(id) } @@ -258,8 +258,8 @@ func (p *ec2ProviderImpl) getPrivateDnsAndPublishToCache(instanceIdList []string // Look up instance from EC2 API logrus.Infof("Making Batch Query to DescribeInstances for %v instances ", len(instanceIdList)) metrics.Get().EC2DescribeInstanceCallCount.Inc() - output, err := p.ec2.DescribeInstances(&ec2.DescribeInstancesInput{ - InstanceIds: aws.StringSlice(instanceIdList), + output, err := p.ec2.DescribeInstances(context.Background(), &ec2.DescribeInstancesInput{ + InstanceIds: instanceIdList, }) if err != nil { logrus.Errorf("Batch call failed querying private DNS from EC2 API for nodes [%s] : with error = []%s ", instanceIdList, err.Error()) @@ -272,8 +272,8 @@ func (p *ec2ProviderImpl) getPrivateDnsAndPublishToCache(instanceIdList []string // Adding the result to privateDNSChache as well as removing from the requestQueueMap. for _, reservation := range output.Reservations { for _, instance := range reservation.Instances { - id := aws.StringValue(instance.InstanceId) - privateDNSName := aws.StringValue(instance.PrivateDnsName) + id := aws.ToString(instance.InstanceId) + privateDNSName := aws.ToString(instance.PrivateDnsName) p.setPrivateDNSNameCache(id, privateDNSName) } } @@ -284,40 +284,3 @@ func (p *ec2ProviderImpl) getPrivateDnsAndPublishToCache(instanceIdList []string p.unsetRequestInFlightForInstanceId(id) } } - -func applySTSRequestHeaders(stsClient *sts.STS, sourceARN string) *sts.STS { - // parse both source account and source arn from the sourceARN, and add them as headers to the STS client - if sourceARN != "" { - sourceAcct, err := getSourceAccount(sourceARN) - if err != nil { - panic(fmt.Sprintf("%s is not a valid arn, err: %v", sourceARN, err)) - } - reqHeaders := map[string]string{ - headerSourceAccount: sourceAcct, - headerSourceArn: sourceARN, - } - stsClient.Handlers.Sign.PushFront(func(s *request.Request) { - s.ApplyOptions(request.WithSetRequestHeaders(reqHeaders)) - }) - logrus.Infof("configuring STS client with extra headers, %v", reqHeaders) - } - return stsClient -} - -// getSourceAccount constructs source acct and return them for use -func getSourceAccount(roleARN string) (string, error) { - // ARN format (https://docs.aws.amazon.com/IAM/latest/UserGuide/reference-arns.html) - // arn:partition:service:region:account-id:resource-type/resource-id - // IAM format, region is always blank - // arn:aws:iam::account:role/role-name-with-path - if !arn.IsARN(roleARN) { - return "", fmt.Errorf("incorrect ARN format for role %s", roleARN) - } - - parsedArn, err := arn.Parse(roleARN) - if err != nil { - return "", err - } - - return parsedArn.AccountID, nil -} diff --git a/pkg/ec2provider/ec2provider_test.go b/pkg/ec2provider/ec2provider_test.go index 912d73c8c..21ebb94aa 100644 --- a/pkg/ec2provider/ec2provider_test.go +++ b/pkg/ec2provider/ec2provider_test.go @@ -1,14 +1,15 @@ package ec2provider import ( + "context" "strconv" "sync" "testing" "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/service/ec2" - "github.com/aws/aws-sdk-go/service/ec2/ec2iface" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/ec2" + ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types" "github.com/prometheus/client_golang/prometheus" "sigs.k8s.io/aws-iam-authenticator/pkg/metrics" ) @@ -18,25 +19,24 @@ const ( ) type mockEc2Client struct { - ec2iface.EC2API - Reservations []*ec2.Reservation + Reservations []ec2types.Reservation } -func (c *mockEc2Client) DescribeInstances(in *ec2.DescribeInstancesInput) (*ec2.DescribeInstancesOutput, error) { +func (c *mockEc2Client) DescribeInstances(ctx context.Context, in *ec2.DescribeInstancesInput, _ ...func(*ec2.Options)) (*ec2.DescribeInstancesOutput, error) { // simulate the time it takes for aws to return time.Sleep(DescribeDelay * time.Millisecond) - var reservations []*ec2.Reservation + var reservations []ec2types.Reservation for _, res := range c.Reservations { - var reservation ec2.Reservation + var reservation ec2types.Reservation for _, inst := range res.Instances { for _, id := range in.InstanceIds { - if aws.StringValue(id) == aws.StringValue(inst.InstanceId) { + if id == aws.ToString(inst.InstanceId) { reservation.Instances = append(reservation.Instances, inst) } } } if len(reservation.Instances) > 0 { - reservations = append(reservations, &reservation) + reservations = append(reservations, reservation) } } return &ec2.DescribeInstancesOutput{ @@ -76,12 +76,12 @@ func TestGetPrivateDNSName(t *testing.T) { } } -func prepareSingleInstanceOutput() []*ec2.Reservation { - reservations := []*ec2.Reservation{ +func prepareSingleInstanceOutput() []ec2types.Reservation { + reservations := []ec2types.Reservation{ { Groups: nil, - Instances: []*ec2.Instance{ - &ec2.Instance{ + Instances: []ec2types.Instance{ + ec2types.Instance{ InstanceId: aws.String("ec2-1"), PrivateDnsName: aws.String("ec2-dns-1"), }, @@ -125,20 +125,20 @@ func getPrivateDNSName(ec2provider *ec2ProviderImpl, instanceString string, dnsS } } -func prepare100InstanceOutput() []*ec2.Reservation { +func prepare100InstanceOutput() []ec2types.Reservation { - var reservations []*ec2.Reservation + var reservations []ec2types.Reservation for i := 1; i < 101; i++ { instanceString := "ec2-" + strconv.Itoa(i) dnsString := "ec2-dns-" + strconv.Itoa(i) - instance := &ec2.Instance{ + instance := ec2types.Instance{ InstanceId: aws.String(instanceString), PrivateDnsName: aws.String(dnsString), } - var instances []*ec2.Instance + var instances []ec2types.Instance instances = append(instances, instance) - res1 := &ec2.Reservation{ + res1 := ec2types.Reservation{ Groups: nil, Instances: instances, OwnerId: nil, @@ -150,44 +150,3 @@ func prepare100InstanceOutput() []*ec2.Reservation { return reservations } - -func TestGetSourceAcctAndArn(t *testing.T) { - type args struct { - roleARN string - } - tests := []struct { - name string - args args - want string - wantErr bool - }{ - { - name: "corect role arn", - args: args{ - roleARN: "arn:aws:iam::123456789876:role/test-cluster", - }, - want: "123456789876", - wantErr: false, - }, - { - name: "incorect role arn", - args: args{ - roleARN: "arn:aws:iam::123456789876", - }, - want: "", - wantErr: true, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := getSourceAccount(tt.args.roleARN) - if (err != nil) != tt.wantErr { - t.Errorf("GetSourceAccount() error = %v, wantErr %v", err, tt.wantErr) - return - } - if got != tt.want { - t.Errorf("GetSourceAccount() got = %v, want %v", got, tt.want) - } - }) - } -} diff --git a/pkg/ec2provider/source_headers.go b/pkg/ec2provider/source_headers.go new file mode 100644 index 000000000..f2473a689 --- /dev/null +++ b/pkg/ec2provider/source_headers.go @@ -0,0 +1,66 @@ +package ec2provider + +import ( + "context" + "fmt" + + "github.com/aws/aws-sdk-go-v2/aws/arn" + smithyhttp "github.com/aws/smithy-go/transport/http" + + "github.com/aws/aws-sdk-go-v2/service/sts" + smithymiddleware "github.com/aws/smithy-go/middleware" +) + +const ( + // Headers for STS request for source ARN + headerSourceArn = "x-amz-source-arn" + // Headers for STS request for source account + headerSourceAccount = "x-amz-source-account" +) + +type withSourceHeaders struct { + sourceARN string +} + +// implements middleware.BuildMiddleware, which runs AFTER a request has been +// serialized and can operate on the transport request +var _ smithymiddleware.BuildMiddleware = (*withSourceHeaders)(nil) + +func (*withSourceHeaders) ID() string { + return "withSourceHeaders" +} + +func (m *withSourceHeaders) HandleBuild(ctx context.Context, in smithymiddleware.BuildInput, next smithymiddleware.BuildHandler) ( + out smithymiddleware.BuildOutput, metadata smithymiddleware.Metadata, err error, +) { + req, ok := in.Request.(*smithyhttp.Request) + if !ok { + return out, metadata, fmt.Errorf("unrecognized transport type %T", in.Request) + } + + if arn.IsARN(m.sourceARN) { + req.Header.Set(headerSourceArn, m.sourceARN) + } + + if parsedArn, err := arn.Parse(m.sourceARN); err == nil && parsedArn.AccountID != "" { + req.Header.Set(headerSourceAccount, parsedArn.AccountID) + } + + return next.HandleBuild(ctx, in) +} + +// WithSourceHeaders adds the x-amz-source-arn and x-amz-source-account headers to the request. +// These can be referenced in an IAM role trust policy document with the condition keys +// aws:SourceArn and aws:SourceAccount for sts:AssumeRole calls +// +// If the sourceARN is invalid, the source arn header is skipped. If the ARN is valid but doesn't +// contain an account ID, the source account header is skipped +func WithSourceHeaders(sourceARN string) func(*sts.Options) { + return func(o *sts.Options) { + o.APIOptions = append(o.APIOptions, func(s *smithymiddleware.Stack) error { + return s.Build.Add(&withSourceHeaders{ + sourceARN: sourceARN, + }, smithymiddleware.After) + }) + } +} diff --git a/pkg/filecache/converter.go b/pkg/filecache/converter.go deleted file mode 100644 index ec2f16bde..000000000 --- a/pkg/filecache/converter.go +++ /dev/null @@ -1,55 +0,0 @@ -package filecache - -import ( - "context" - - "github.com/aws/aws-sdk-go-v2/aws" - "github.com/aws/aws-sdk-go/aws/credentials" -) - -type v2 struct { - creds *credentials.Credentials -} - -var _ aws.CredentialsProvider = &v2{} - -func (p *v2) Retrieve(ctx context.Context) (aws.Credentials, error) { - val, err := p.creds.GetWithContext(ctx) - if err != nil { - return aws.Credentials{}, err - } - resp := aws.Credentials{ - AccessKeyID: val.AccessKeyID, - SecretAccessKey: val.SecretAccessKey, - SessionToken: val.SessionToken, - Source: val.ProviderName, - CanExpire: false, - // Don't have account ID - } - - if expiration, err := p.creds.ExpiresAt(); err != nil { - resp.CanExpire = true - resp.Expires = expiration - } - return resp, nil -} - -// V1ProviderToV2Provider converts a v1 credentials.Provider to a v2 aws.CredentialsProvider -func V1ProviderToV2Provider(p credentials.Provider) aws.CredentialsProvider { - return V1CredentialToV2Provider(credentials.NewCredentials(p)) -} - -// V1CredentialToV2Provider converts a v1 credentials.Credential to a v2 aws.CredentialProvider -func V1CredentialToV2Provider(c *credentials.Credentials) aws.CredentialsProvider { - return &v2{creds: c} -} - -// V2CredentialToV1Value converts a v2 aws.Credentials to a v1 credentials.Value -func V2CredentialToV1Value(cred aws.Credentials) credentials.Value { - return credentials.Value{ - AccessKeyID: cred.AccessKeyID, - SecretAccessKey: cred.SecretAccessKey, - SessionToken: cred.SessionToken, - ProviderName: cred.Source, - } -} diff --git a/pkg/filecache/filecache.go b/pkg/filecache/filecache.go index 64092b9f4..3d8624a2c 100644 --- a/pkg/filecache/filecache.go +++ b/pkg/filecache/filecache.go @@ -11,7 +11,6 @@ import ( "time" "github.com/aws/aws-sdk-go-v2/aws" - "github.com/aws/aws-sdk-go/aws/credentials" "github.com/gofrs/flock" "github.com/spf13/afero" "gopkg.in/yaml.v2" @@ -136,7 +135,7 @@ type FileCacheProvider struct { cachedCredential aws.Credentials // the cached credential, if it exists } -var _ credentials.Provider = &FileCacheProvider{} +var _ aws.CredentialsProvider = &FileCacheProvider{} // NewFileCacheProvider creates a new Provider implementation that wraps a provided Credentials, // and works with an on disk cache to speed up credential usage when the cached copy is not expired. @@ -200,26 +199,19 @@ func NewFileCacheProvider(clusterID, profile, roleARN string, provider aws.Crede return resp, nil } -// Retrieve() implements the Provider interface, returning the cached credential if is not expired, -// otherwise fetching the credential from the underlying Provider and caching the results on disk +// Retrieve() implements the aws.CredentialsProvider interface, returning the cached credential if is not expired, +// otherwise fetching the credential from the underlying CredentialProvider and caching the results on disk // with an expiration time. -func (f *FileCacheProvider) Retrieve() (credentials.Value, error) { - return f.RetrieveWithContext(context.Background()) -} - -// Retrieve() implements the Provider interface, returning the cached credential if is not expired, -// otherwise fetching the credential from the underlying Provider and caching the results on disk -// with an expiration time. -func (f *FileCacheProvider) RetrieveWithContext(ctx context.Context) (credentials.Value, error) { +func (f *FileCacheProvider) Retrieve(ctx context.Context) (aws.Credentials, error) { if !f.cachedCredential.Expired() && f.cachedCredential.HasKeys() { // use the cached credential - return V2CredentialToV1Value(f.cachedCredential), nil + return f.cachedCredential, nil } else { _, _ = fmt.Fprintf(os.Stderr, "No cached credential available. Refreshing...\n") // fetch the credentials from the underlying Provider credential, err := f.provider.Retrieve(ctx) if err != nil { - return V2CredentialToV1Value(credential), err + return credential, err } if credential.CanExpire { @@ -235,7 +227,7 @@ func (f *FileCacheProvider) RetrieveWithContext(ctx context.Context) (credential if !ok { // can't get write lock to create/update cache, but still return the credential _, _ = fmt.Fprintf(os.Stderr, "Unable to write lock file %s: %v\n", f.filename, err) - return V2CredentialToV1Value(credential), nil + return credential, nil } f.cachedCredential = credential // don't really care about read error. Either read the cache, or we create a new cache. @@ -254,21 +246,10 @@ func (f *FileCacheProvider) RetrieveWithContext(ctx context.Context) (credential _, _ = fmt.Fprintf(os.Stderr, "Unable to cache credential: %v\n", err) err = nil } - return V2CredentialToV1Value(credential), err + return credential, err } } -// IsExpired() implements the Provider interface, deferring to the cached credential first, -// but fall back to the underlying Provider if it is expired. -func (f *FileCacheProvider) IsExpired() bool { - return f.cachedCredential.CanExpire && f.cachedCredential.Expired() -} - -// ExpiresAt implements the Expirer interface, and gives access to the expiration time of the credential -func (f *FileCacheProvider) ExpiresAt() time.Time { - return f.cachedCredential.Expires -} - // defaultCacheFilename returns the name of the credential cache file, which can either be // set by environment variable, or use the default of ~/.kube/cache/aws-iam-authenticator/credentials.yaml func defaultCacheFilename() string { diff --git a/pkg/filecache/filecache_test.go b/pkg/filecache/filecache_test.go index f2db98556..7c8fbaafd 100644 --- a/pkg/filecache/filecache_test.go +++ b/pkg/filecache/filecache_test.go @@ -383,10 +383,6 @@ func TestNewFileCacheProvider_ExistingARN(t *testing.T) { t.Errorf("Cached credential should not be expired") } - if p.ExpiresAt() != p.cachedCredential.Expires { - t.Errorf("Credential expiration time is not correct, expected %v, got %v", - p.cachedCredential.Expires, p.ExpiresAt()) - } } func TestFileCacheProvider_Retrieve_NoExpirer(t *testing.T) { @@ -407,7 +403,7 @@ func TestFileCacheProvider_Retrieve_NoExpirer(t *testing.T) { ) validateFileCacheProvider(t, p, err, provider) - credential, err := p.Retrieve() + credential, err := p.Retrieve(context.Background()) if err != nil { t.Errorf("Unexpected error: %v", err) } @@ -442,12 +438,12 @@ func TestFileCacheProvider_Retrieve_WithExpirer_Unlockable(t *testing.T) { tfl.success = false tfl.err = errors.New("lock stuck, needs wd-40") - credential, err := p.Retrieve() + credential, err := p.Retrieve(context.Background()) if err != nil { t.Errorf("Unexpected error: %v", err) } if credential.AccessKeyID != "AKID" || credential.SecretAccessKey != "SECRET" || - credential.SessionToken != "TOKEN" || credential.ProviderName != "stubProvider" { + credential.SessionToken != "TOKEN" || credential.Source != "stubProvider" { t.Errorf("cached credential not extracted correctly, got %v", p.cachedCredential) } } @@ -471,14 +467,14 @@ func TestFileCacheProvider_Retrieve_WithExpirer_Unwritable(t *testing.T) { ) validateFileCacheProvider(t, p, err, provider) - credential, err := p.Retrieve() + credential, err := p.Retrieve(context.Background()) if err != nil { t.Errorf("Unexpected error: %v", err) } if credential.AccessKeyID != provider.creds.AccessKeyID || credential.SecretAccessKey != provider.creds.SecretAccessKey || credential.SessionToken != provider.creds.SessionToken || - credential.ProviderName != provider.creds.Source { + credential.Source != provider.creds.Source { t.Errorf("cached credential not extracted correctly, got %v", p.cachedCredential) } @@ -525,14 +521,14 @@ func TestFileCacheProvider_Retrieve_WithExpirer_Writable(t *testing.T) { // retrieve credential, which will fetch from underlying Provider // same as TestFileCacheProvider_Retrieve_WithExpirer_Unwritable, // but write to disk (code coverage) - credential, err := p.Retrieve() + credential, err := p.Retrieve(context.Background()) if err != nil { t.Errorf("Unexpected error: %v", err) } if credential.AccessKeyID != provider.creds.AccessKeyID || credential.SecretAccessKey != provider.creds.SecretAccessKey || credential.SessionToken != provider.creds.SessionToken || - credential.ProviderName != provider.creds.Source { + credential.Source != provider.creds.Source { t.Errorf("cached credential not extracted correctly, got %v", p.cachedCredential) } } @@ -567,19 +563,14 @@ func TestFileCacheProvider_Retrieve_CacheHit(t *testing.T) { })) validateFileCacheProvider(t, p, err, provider) - credential, err := p.Retrieve() + credential, err := p.Retrieve(context.Background()) if err != nil { t.Errorf("Unexpected error: %v", err) } if credential.AccessKeyID != "ABC" || credential.SecretAccessKey != "DEF" || - credential.SessionToken != "GHI" || credential.ProviderName != "JKL" { + credential.SessionToken != "GHI" || credential.Source != "JKL" || + !credential.Expires.Equal(currentTime.Add(time.Hour*6)) { t.Errorf("cached credential not returned") } - if !p.ExpiresAt().Equal(currentTime.Add(time.Hour * 6)) { - t.Errorf("unexpected expiration time: got %s, wanted %s", - p.ExpiresAt().Format(time.RFC3339Nano), - currentTime.Add(time.Hour*6).Format(time.RFC3339Nano), - ) - } } diff --git a/pkg/server/server.go b/pkg/server/server.go index 045f948c2..ce47f0f7d 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -28,8 +28,6 @@ import ( "sync" "time" - "github.com/aws/aws-sdk-go/aws/ec2metadata" - "github.com/aws/aws-sdk-go/aws/session" "sigs.k8s.io/aws-iam-authenticator/pkg/config" "sigs.k8s.io/aws-iam-authenticator/pkg/ec2provider" "sigs.k8s.io/aws-iam-authenticator/pkg/errutil" @@ -42,7 +40,9 @@ import ( "sigs.k8s.io/aws-iam-authenticator/pkg/metrics" "sigs.k8s.io/aws-iam-authenticator/pkg/token" - awsarn "github.com/aws/aws-sdk-go/aws/arn" + awsarn "github.com/aws/aws-sdk-go-v2/aws/arn" + awsconfig "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/feature/ec2/imds" "github.com/prometheus/client_golang/prometheus/promhttp" "github.com/sirupsen/logrus" authenticationv1beta1 "k8s.io/api/authentication/v1beta1" @@ -88,7 +88,7 @@ func New(cfg config.Config, stopCh <-chan struct{}) *Server { backendMapper, err := BuildMapperChain(cfg, cfg.BackendMode) if err != nil { - logrus.Fatalf("failed to build mapper chain: %v", err) + logrus.WithError(err).Fatal("failed to build mapper chain") } for _, mapping := range c.RoleMappings { @@ -144,7 +144,11 @@ func New(cfg config.Config, stopCh <-chan struct{}) *Server { logrus.Infof("listening on %s", listener.Addr()) logrus.Infof("reconfigure your apiserver with `--authentication-token-webhook-config-file=%s` to enable (assuming default hostPath mounts)", c.GenerateKubeconfigPath) - internalHandler := c.getHandler(backendMapper, c.EC2DescribeInstancesQps, c.EC2DescribeInstancesBurst, stopCh) + internalHandler, err := c.getHandler(backendMapper, c.EC2DescribeInstancesQps, c.EC2DescribeInstancesBurst, stopCh) + if err != nil { + logrus.WithError(err).Fatal("Failed to create handlers") + } + c.httpServer = http.Server{ ErrorLog: log.New(errLog, "", 0), Handler: internalHandler, @@ -191,23 +195,35 @@ type healthzHandler struct{} func (m *healthzHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { fmt.Fprintf(w, "ok") } -func (c *Server) getHandler(backendMapper BackendMapper, ec2DescribeQps int, ec2DescribeBurst int, stopCh <-chan struct{}) *handler { + +func (c *Server) getHandler(backendMapper BackendMapper, ec2DescribeQps int, ec2DescribeBurst int, stopCh <-chan struct{}) (*handler, error) { if c.ServerEC2DescribeInstancesRoleARN != "" { _, err := awsarn.Parse(c.ServerEC2DescribeInstancesRoleARN) if err != nil { - panic(fmt.Sprintf("describeinstancesrole %s is not a valid arn", c.ServerEC2DescribeInstancesRoleARN)) + logrus.WithError(err).Errorf("describeinstancesrole %s is not a valid arn", c.ServerEC2DescribeInstancesRoleARN) + return nil, err } } - sess := session.Must(session.NewSession()) - ec2metadata := ec2metadata.New(sess) - instanceRegion, err := ec2metadata.Region() + ctx := context.Background() + cfg, err := awsconfig.LoadDefaultConfig(ctx) + if err != nil { + logrus.WithError(err).Error("EC2 instance metadata not configured") + } + cli := imds.NewFromConfig(cfg) + resp, err := cli.GetRegion(ctx, nil) + if err != nil { + logrus.WithError(err).Error("region not found in instance metadata.") + } + + ec2Prov, err := ec2provider.New(c.ServerEC2DescribeInstancesRoleARN, c.SourceARN, resp.Region, ec2DescribeQps, ec2DescribeBurst) if err != nil { - logrus.WithError(err).Errorln("Region not found in instance metadata.") + logrus.WithError(err).Errorln("error initializing EC2 provider") + return nil, err } h := &handler{ - verifier: token.NewVerifier(c.ClusterID, c.PartitionID, instanceRegion), - ec2Provider: ec2provider.New(c.ServerEC2DescribeInstancesRoleARN, c.SourceARN, instanceRegion, ec2DescribeQps, ec2DescribeBurst), + verifier: token.NewVerifier(c.ClusterID, c.PartitionID, resp.Region), + ec2Provider: ec2Prov, clusterID: c.ClusterID, backendMapper: backendMapper, scrubbedAccounts: c.Config.ScrubbedAWSAccounts, @@ -226,7 +242,7 @@ func (c *Server) getHandler(backendMapper BackendMapper, ec2DescribeQps int, ec2 fileutil.StartLoadDynamicFile(c.DynamicBackendModePath, h, stopCh) } - return h + return h, nil } func BuildMapperChain(cfg config.Config, modes []string) (BackendMapper, error) { diff --git a/pkg/token/cluster_id_header.go b/pkg/token/cluster_id_header.go new file mode 100644 index 000000000..1f6a43d32 --- /dev/null +++ b/pkg/token/cluster_id_header.go @@ -0,0 +1,45 @@ +package token + +import ( + "context" + "fmt" + + smithyhttp "github.com/aws/smithy-go/transport/http" + + "github.com/aws/aws-sdk-go-v2/service/sts" + smithymiddleware "github.com/aws/smithy-go/middleware" +) + +type withClusterIDHeader struct { + clusterID string +} + +// implements middleware.BuildMiddleware, which runs AFTER a request has been +// serialized and can operate on the transport request +var _ smithymiddleware.BuildMiddleware = (*withClusterIDHeader)(nil) + +func (*withClusterIDHeader) ID() string { + return "withClusterIDHeader" +} + +func (m *withClusterIDHeader) HandleBuild(ctx context.Context, in smithymiddleware.BuildInput, next smithymiddleware.BuildHandler) ( + out smithymiddleware.BuildOutput, metadata smithymiddleware.Metadata, err error, +) { + req, ok := in.Request.(*smithyhttp.Request) + if !ok { + return out, metadata, fmt.Errorf("unrecognized transport type %T", in.Request) + } + req.Header.Set(clusterIDHeader, m.clusterID) + return next.HandleBuild(ctx, in) +} + +// WithClusterIDHeader adds the clusterID header to the request befor signing +func WithClusterIDHeader(clusterID string) func(*sts.Options) { + return func(o *sts.Options) { + o.APIOptions = append(o.APIOptions, func(s *smithymiddleware.Stack) error { + return s.Build.Add(&withClusterIDHeader{ + clusterID: clusterID, + }, smithymiddleware.After) + }) + } +} diff --git a/pkg/token/token.go b/pkg/token/token.go index 716a8cb12..9ad001f1c 100644 --- a/pkg/token/token.go +++ b/pkg/token/token.go @@ -17,6 +17,7 @@ limitations under the License. package token import ( + "context" "encoding/base64" "encoding/json" "fmt" @@ -28,24 +29,24 @@ import ( "strings" "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/credentials" - "github.com/aws/aws-sdk-go/aws/credentials/stscreds" + "sigs.k8s.io/aws-iam-authenticator/pkg" + "sigs.k8s.io/aws-iam-authenticator/pkg/arn" + "sigs.k8s.io/aws-iam-authenticator/pkg/filecache" + "sigs.k8s.io/aws-iam-authenticator/pkg/metrics" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/aws/middleware" + v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/credentials/stscreds" + "github.com/aws/aws-sdk-go-v2/service/sts" "github.com/aws/aws-sdk-go/aws/endpoints" - "github.com/aws/aws-sdk-go/aws/request" - "github.com/aws/aws-sdk-go/aws/session" - v4 "github.com/aws/aws-sdk-go/aws/signer/v4" - "github.com/aws/aws-sdk-go/service/sts" - "github.com/aws/aws-sdk-go/service/sts/stsiface" + smithymiddleware "github.com/aws/smithy-go/middleware" "github.com/prometheus/client_golang/prometheus" "github.com/sirupsen/logrus" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" "k8s.io/client-go/pkg/apis/clientauthentication" clientauthv1beta1 "k8s.io/client-go/pkg/apis/clientauthentication/v1beta1" - "sigs.k8s.io/aws-iam-authenticator/pkg" - "sigs.k8s.io/aws-iam-authenticator/pkg/arn" - "sigs.k8s.io/aws-iam-authenticator/pkg/filecache" - "sigs.k8s.io/aws-iam-authenticator/pkg/metrics" ) // Identity is returned on successful Verify() results. It contains a parsed @@ -180,12 +181,16 @@ type getCallerIdentityWrapper struct { } `json:"GetCallerIdentityResponse"` } +type GCIPresigner interface { + PresignGetCallerIdentity(ctx context.Context, params *sts.GetCallerIdentityInput, optFns ...func(*sts.PresignOptions)) (*v4.PresignedHTTPRequest, error) +} + // Generator provides new tokens for the AWS IAM Authenticator. type Generator interface { // Get a token using the provided options - GetWithOptions(options *GetTokenOptions) (Token, error) - // GetWithSTS returns a token valid for clusterID using the given STS client. - GetWithSTS(clusterID string, stsAPI stsiface.STSAPI) (Token, error) + GetWithOptions(*GetTokenOptions) (Token, error) + // Presign returns a Token using the given STS client + Presign(GCIPresigner) (Token, error) // FormatJSON returns the client auth formatted json for the ExecCredential auth FormatJSON(Token) string } @@ -220,23 +225,14 @@ func (g generator) GetWithOptions(options *GetTokenOptions) (Token, error) { if options.ClusterID == "" { return Token{}, fmt.Errorf("ClusterID is required") } - - // create a session with the "base" credentials available - // (from environment variable, profile files, EC2 metadata, etc) - sess, err := session.NewSessionWithOptions(session.Options{ - AssumeRoleTokenProvider: StdinStderrTokenProvider, - SharedConfigState: session.SharedConfigEnable, - }) - if err != nil { - return Token{}, fmt.Errorf("could not create session: %v", err) + loadOpts := []func(*config.LoadOptions) error{ + config.WithAPIOptions( + []func(*smithymiddleware.Stack) error{ + middleware.AddUserAgentKeyValue("aws-iam-authenticator", pkg.Version), + }), } - sess.Handlers.Build.PushFrontNamed(request.NamedHandler{ - Name: "authenticatorUserAgent", - Fn: request.MakeAddToUserAgentHandler( - "aws-iam-authenticator", pkg.Version), - }) if options.Region != "" { - sess = sess.Copy(aws.NewConfig().WithRegion(options.Region).WithSTSRegionalEndpoint(endpoints.RegionalSTSEndpoint)) + loadOpts = append(loadOpts, config.WithRegion(options.Region)) } if g.cache { @@ -244,90 +240,91 @@ func (g generator) GetWithOptions(options *GetTokenOptions) (Token, error) { var profile string if v := os.Getenv("AWS_PROFILE"); len(v) > 0 { profile = v - } else { - profile = session.DefaultSharedConfigProfile } - // create a cacheing Provider wrapper around the Credentials - if cacheProvider, err := filecache.NewFileCacheProvider( + // Create a new config to get the default cred chain + cfg, err := config.LoadDefaultConfig(context.Background()) + if err != nil { + return Token{}, fmt.Errorf("could not create config: %v", err) + } + // create a caching Provider wrapper around the Credentials + cacheProvider, err := filecache.NewFileCacheProvider( options.ClusterID, profile, options.AssumeRoleARN, - filecache.V1CredentialToV2Provider(sess.Config.Credentials)); err == nil { - sess.Config.Credentials = credentials.NewCredentials(cacheProvider) + cfg.Credentials, + ) + if err == nil { + loadOpts = append(loadOpts, config.WithCredentialsProvider(cacheProvider)) } else { fmt.Fprintf(os.Stderr, "unable to use cache: %v\n", err) } } - // use an STS client based on the direct credentials - stsAPI := sts.New(sess) - - // if a roleARN was specified, replace the STS client with one that uses - // temporary credentials from that role. - if options.AssumeRoleARN != "" { - var sessionSetters []func(*stscreds.AssumeRoleProvider) + cfg, err := config.LoadDefaultConfig(context.Background(), loadOpts...) + if err != nil { + return Token{}, fmt.Errorf("could not create config: %v", err) - if options.AssumeRoleExternalID != "" { - sessionSetters = append(sessionSetters, func(provider *stscreds.AssumeRoleProvider) { - provider.ExternalID = &options.AssumeRoleExternalID - }) - } + } + if options.AssumeRoleARN != "" { + var sessionName = options.SessionName if g.forwardSessionName { - // If the current session is already a federated identity, carry through - // this session name onto the new session to provide better debugging - // capabilities - resp, err := stsAPI.GetCallerIdentity(&sts.GetCallerIdentityInput{}) + stsSvc := sts.NewFromConfig(cfg) + gciResp, err := stsSvc.GetCallerIdentity(context.Background(), nil) if err != nil { return Token{}, err } - - userIDParts := strings.Split(*resp.UserId, ":") + userIDParts := strings.Split(aws.ToString(gciResp.UserId), ":") if len(userIDParts) == 2 { - sessionSetters = append(sessionSetters, func(provider *stscreds.AssumeRoleProvider) { - provider.RoleSessionName = userIDParts[1] - }) + sessionName = userIDParts[1] } - } else if options.SessionName != "" { - sessionSetters = append(sessionSetters, func(provider *stscreds.AssumeRoleProvider) { - provider.RoleSessionName = options.SessionName - }) } - // create STS-based credentials that will assume the given role - creds := stscreds.NewCredentials(sess, options.AssumeRoleARN, sessionSetters...) + creds := stscreds.NewAssumeRoleProvider(sts.NewFromConfig(cfg), options.AssumeRoleARN, func(o *stscreds.AssumeRoleOptions) { + o.RoleSessionName = sessionName + o.ExternalID = &options.AssumeRoleExternalID + // TODO: Can we get the serial number from the client? + // o.SerialNumber = aws.String("myTokenSerialNumber") + o.TokenProvider = stscreds.StdinTokenProvider + }) - // create an STS API interface that uses the assumed role's temporary credentials - stsAPI = sts.New(sess, &aws.Config{Credentials: creds}) + cfg.Credentials = aws.NewCredentialsCache(creds) } - return g.GetWithSTS(options.ClusterID, stsAPI) + stsSvc := sts.NewFromConfig(cfg, WithClusterIDHeader(options.ClusterID)) + presigner := timedPresigner{v4.NewSigner(), g.nowFunc} + presignClient := sts.NewPresignClient(stsSvc, func(o *sts.PresignOptions) { + o.Presigner = &presigner + }) + return g.Presign(presignClient) } -func getNamedSigningHandler(nowFunc func() time.Time) request.NamedHandler { - return request.NamedHandler{ - Name: "v4.SignRequestHandler", Fn: func(req *request.Request) { - v4.SignSDKRequestWithCurrentTime(req, nowFunc) - }, - } +// timedPresigner exists to wrap PresignHTTP() with a specific time +// and set the x-amz-expires header in the query url +type timedPresigner struct { + signer *v4.Signer + timeFunc func() time.Time } -// GetWithSTS returns a token valid for clusterID using the given STS client. -func (g generator) GetWithSTS(clusterID string, stsAPI stsiface.STSAPI) (Token, error) { - // generate an sts:GetCallerIdentity request and add our custom cluster ID header - request, _ := stsAPI.GetCallerIdentityRequest(&sts.GetCallerIdentityInput{}) - request.HTTPRequest.Header.Add(clusterIDHeader, clusterID) - - // override the Sign handler so we can control the now time for testing. - request.Handlers.Sign.Swap("v4.SignRequestHandler", getNamedSigningHandler(g.nowFunc)) +func (p *timedPresigner) PresignHTTP( + ctx context.Context, credentials aws.Credentials, r *http.Request, + payloadHash string, service string, region string, _ time.Time, + optFns ...func(*v4.SignerOptions), +) (url string, signedHeader http.Header, err error) { + query := r.URL.Query() + query.Set("X-Amz-Expires", strconv.Itoa(requestPresignParam)) + r.URL.RawQuery = query.Encode() + return p.signer.PresignHTTP(ctx, credentials, r, payloadHash, service, region, p.timeFunc(), optFns...) +} +// Presign returns a token valid for clusterID using the given STS client. +func (g generator) Presign(presigner GCIPresigner) (Token, error) { // Sign the request. The expires parameter (sets the x-amz-expires header) is // currently ignored by STS, and the token expires 15 minutes after the x-amz-date - // timestamp regardless. We set it to 60 seconds for backwards compatibility (the - // parameter is a required argument to Presign(), and authenticators 0.3.0 and older are expecting a value between - // 0 and 60 on the server side). + // timestamp regardless. // https://github.com/aws/aws-sdk-go/issues/2167 - presignedURLString, err := request.Presign(requestPresignParam * time.Second) + + req, err := presigner.PresignGetCallerIdentity(context.Background(), nil) if err != nil { return Token{}, err } @@ -335,7 +332,7 @@ func (g generator) GetWithSTS(clusterID string, stsAPI stsiface.STSAPI) (Token, // Set token expiration to 1 minute before the presigned URL expires for some cushion tokenExpiration := g.nowFunc().Local().Add(presignedURLExpiration - 1*time.Minute) // TODO: this may need to be a constant-time base64 encoding - return Token{v1Prefix + base64.RawURLEncoding.EncodeToString([]byte(presignedURLString)), tokenExpiration}, nil + return Token{v1Prefix + base64.RawURLEncoding.EncodeToString([]byte(req.URL)), tokenExpiration}, nil } // FormatJSON formats the json to support ExecCredential authentication diff --git a/pkg/token/token_test.go b/pkg/token/token_test.go index a8e997c86..ce43cca45 100644 --- a/pkg/token/token_test.go +++ b/pkg/token/token_test.go @@ -2,6 +2,7 @@ package token import ( "bytes" + "context" "encoding/base64" "encoding/json" "errors" @@ -14,11 +15,12 @@ import ( "testing" "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/credentials" + "github.com/aws/aws-sdk-go-v2/aws" + v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/aws/aws-sdk-go-v2/service/sts" "github.com/aws/aws-sdk-go/aws/endpoints" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/sts" "github.com/google/go-cmp/cmp" "github.com/prometheus/client_golang/prometheus" v1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -587,12 +589,12 @@ func Test_getDefaultHostNameForRegion(t *testing.T) { } } -func TestGetWithSTS(t *testing.T) { +func TestPresign(t *testing.T) { clusterID := "test-cluster" cases := []struct { name string - creds *credentials.Credentials + creds aws.CredentialsProvider nowTime time.Time want Token wantErr error @@ -600,10 +602,10 @@ func TestGetWithSTS(t *testing.T) { { "Non-zero time", // Example non-real credentials - func() *credentials.Credentials { + func() credentials.StaticCredentialsProvider { decodedAkid, _ := base64.StdEncoding.DecodeString("QVNJQVIyVEc0NFY2QVMzWlpFN0M=") decodedSk, _ := base64.StdEncoding.DecodeString("NEtENWNudEdjVm1MV1JkRjV3dk5SdXpOTDVReG1wNk9LVlk2RnovUQ==") - return credentials.NewStaticCredentials( + return credentials.NewStaticCredentialsProvider( string(decodedAkid), string(decodedSk), "", @@ -620,13 +622,15 @@ func TestGetWithSTS(t *testing.T) { for _, tc := range cases { t.Run(tc.name, func(t *testing.T) { - svc := sts.New(session.Must(session.NewSession( - &aws.Config{ - Credentials: tc.creds, - Region: aws.String("us-west-2"), - STSRegionalEndpoint: endpoints.RegionalSTSEndpoint, - }, - ))) + cfg, err := config.LoadDefaultConfig( + context.Background(), + config.WithRegion("us-west-2"), + config.WithCredentialsProvider(tc.creds), + ) + if err != nil { + t.Errorf("unexpected error initialzing config: %v", err) + return + } gen := &generator{ forwardSessionName: false, @@ -634,7 +638,13 @@ func TestGetWithSTS(t *testing.T) { nowFunc: func() time.Time { return tc.nowTime }, } - got, err := gen.GetWithSTS(clusterID, svc) + stsSvc := sts.NewFromConfig(cfg, WithClusterIDHeader(clusterID)) + presigner := timedPresigner{v4.NewSigner(), gen.nowFunc} + presignClient := sts.NewPresignClient(stsSvc, func(o *sts.PresignOptions) { + o.Presigner = &presigner + }) + + got, err := gen.Presign(presignClient) if diff := cmp.Diff(err, tc.wantErr); diff != "" { t.Errorf("Unexpected error: %s", diff) }