Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Override RegionSet in EnpointResolverInterceptor after fetching the Signing Properties from Endpoint rules #5825

Open
wants to merge 6 commits into
base: feature/master/multi-auth-sigv4a
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,8 @@ public static Metadata constructMetadata(ServiceModel serviceModel,
.withJsonVersion(serviceMetadata.getJsonVersion())
.withEndpointPrefix(serviceMetadata.getEndpointPrefix())
.withSigningName(serviceMetadata.getSigningName())
.withAuthType(AuthType.fromValue(serviceMetadata.getSignatureVersion()))
.withAuthType(serviceMetadata.getSignatureVersion() != null ?
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When can this be null?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The specifications states

The authType and signatureVersion traits will be deprecated in favor of the auth and unsignedPayload traits.

Thus for new services which will be completed based on multi-auth signatureVersion will be null.
I added a test case for in codegen-tst when a new service is added with just multi-auth supporting only sigv4a

AuthType.fromValue(serviceMetadata.getSignatureVersion()) : null)
.withUid(serviceMetadata.getUid())
.withServiceId(serviceMetadata.getServiceId())
.withSupportsH2(supportsH2(serviceMetadata))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
import software.amazon.awssdk.metrics.MetricCollector;
import software.amazon.awssdk.metrics.SdkMetric;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.utils.CollectionUtils;
import software.amazon.awssdk.utils.Logger;
import software.amazon.awssdk.utils.Validate;

Expand Down Expand Up @@ -196,7 +197,6 @@ private MethodSpec generateAuthSchemeParams() {
builder.addStatement("(($T)builder).endpointProvider(($T)endpointProvider)", paramsBuilderClass, endpointProviderClass);
builder.endControlFlow();
builder.endControlFlow();
// TODO: Implement addRegionSet() for legacy services that resolve authentication from endpoints in one of next PRs.
builder.addStatement("return builder.build()");
return builder.build();
}
Expand Down Expand Up @@ -452,19 +452,13 @@ private TypeName toTypeName(Object valueType) {
private void generateSigv4aRegionSet(MethodSpec.Builder builder) {
if (authSchemeSpecUtils.usesSigV4a()) {
builder.addStatement(
"$T regionSet = executionAttributes.getOptionalAttribute($T.AWS_SIGV4A_SIGNING_REGION_SET)\n" +
" .filter(regions -> !regions.isEmpty())\n" +
" .map(regions -> $T.create(String.join(\", \", regions)))\n" +
" .orElseGet(() -> {\n" +
" $T fallbackRegion = executionAttributes.getAttribute($T.AWS_REGION);\n" +
" return fallbackRegion != null ? $T.create(fallbackRegion.toString()) : null;\n" +
" });",
RegionSet.class, AwsExecutionAttribute.class,
RegionSet.class, Region.class, AwsExecutionAttribute.class,
"executionAttributes.getOptionalAttribute($T.AWS_SIGV4A_SIGNING_REGION_SET)\n" +
" .filter(regionSet -> !$T.isNullOrEmpty(regionSet))\n" +
" .ifPresent(nonEmptyRegionSet -> builder.regionSet($T.create(nonEmptyRegionSet)))",
AwsExecutionAttribute.class,
CollectionUtils.class,
RegionSet.class
);

builder.addStatement("builder.regionSet(regionSet)");
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ public class EndpointResolverInterceptorSpec implements ClassSpec {
private final JmesPathAcceptorGenerator jmesPathGenerator;
private final boolean dependsOnHttpAuthAws;
private final boolean useSraAuth;
private final boolean multiAuthSigv4a;


public EndpointResolverInterceptorSpec(IntermediateModel model) {
Expand All @@ -116,6 +117,7 @@ public EndpointResolverInterceptorSpec(IntermediateModel model) {
supportedAuthSchemes.contains(AwsV4aAuthScheme.class);

this.useSraAuth = new AuthSchemeSpecUtils(model).useSraAuth();
this.multiAuthSigv4a = new AuthSchemeSpecUtils(model).usesSigV4a();
}

@Override
Expand Down Expand Up @@ -155,6 +157,10 @@ public TypeSpec poetSpec() {
b.addMethod(signerProviderMethod());
}

if (multiAuthSigv4a) {
b.addMethod(createHasRegionSetMethod());
b.addMethod(createUpdateAuthSchemeWithRegionSetMethod());
}
endpointParamsKnowledgeIndex.addAccountIdMethodsIfPresent(b);
return b.build();
}
Expand Down Expand Up @@ -192,7 +198,9 @@ private MethodSpec modifyRequestMethod(String endpointAuthSchemeStrategyFieldNam
endpointRulesSpecUtils.providerInterfaceName(), providerVar, SdkInternalExecutionAttribute.class);
b.beginControlFlow("try");
b.addStatement("long resolveEndpointStart = $T.nanoTime()", System.class);
b.addStatement("$T endpoint = $N.resolveEndpoint(ruleParams(result, executionAttributes)).join()",
b.addStatement("$T endpointParams = ruleParams(result, executionAttributes)",
endpointRulesSpecUtils.parametersClassName());
b.addStatement("$T endpoint = $N.resolveEndpoint(endpointParams).join()",
Endpoint.class, providerVar);
b.addStatement("$1T resolveEndpointDuration = $1T.ofNanos($2T.nanoTime() - resolveEndpointStart)", Duration.class,
System.class);
Expand All @@ -219,7 +227,11 @@ private MethodSpec modifyRequestMethod(String endpointAuthSchemeStrategyFieldNam
SelectedAuthScheme.class, SdkInternalExecutionAttribute.class);
b.beginControlFlow("if (endpointAuthSchemes != null && selectedAuthScheme != null)");
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is selectedAuthScheme ever null nowadays? It feels like this should be a dead code branch now?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed this check

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My bad , we need this check , the S3 integ test failed , will revert this

java.lang.NullPointerException: Cannot invoke "software.amazon.awssdk.core.SelectedAuthScheme.authSchemeOption()" because "selectedAuthScheme" is null
    at software.amazon.awssdk.services.s3.endpoints.internal.S3ResolveEndpointInterceptor.authSchemeWithEndpointSignerProperties(S3ResolveEndpointInterceptor.java:1383)
    at software.amazon.awssdk.services.s3.endpoints.internal.S3ResolveEndpointInterceptor.modifyRequest(S3ResolveEndpointInterceptor.java:177)
    at software.amazon.awssdk.core.interceptor.ExecutionInterceptorChain.modifyRequest(ExecutionInterceptorC

b.addStatement("selectedAuthScheme = authSchemeWithEndpointSignerProperties(endpointAuthSchemes, selectedAuthScheme)");

if (multiAuthSigv4a) {
b.beginControlFlow("if(!hasRegionSet(selectedAuthScheme))");
b.addStatement("selectedAuthScheme = updateAuthSchemeWithRegionSet(selectedAuthScheme, endpointParams)");
b.endControlFlow();
}
b.addStatement("executionAttributes.putAttribute($T.SELECTED_AUTH_SCHEME, selectedAuthScheme)",
SdkInternalExecutionAttribute.class);
b.endControlFlow();
Expand Down Expand Up @@ -774,7 +786,7 @@ private static CodeBlock copyV4EndpointSignerPropertiesToAuth() {
return code.build();
}

private static CodeBlock copyV4aEndpointSignerPropertiesToAuth() {
private CodeBlock copyV4aEndpointSignerPropertiesToAuth() {
CodeBlock.Builder code = CodeBlock.builder();

code.beginControlFlow("if (endpointAuthScheme instanceof $T)", SigV4aAuthScheme.class);
Expand All @@ -784,10 +796,12 @@ private static CodeBlock copyV4aEndpointSignerPropertiesToAuth() {
code.addStatement("option.putSignerProperty($T.DOUBLE_URL_ENCODE, !v4aAuthScheme.disableDoubleEncoding())",
AwsV4aHttpSigner.class);
code.endControlFlow();

code.beginControlFlow("if (v4aAuthScheme.signingRegionSet() != null)");
if (multiAuthSigv4a) {
code.beginControlFlow("if (!hasRegionSet(selectedAuthScheme) && v4aAuthScheme.signingRegionSet() != null)");
} else {
code.beginControlFlow("if (v4aAuthScheme.signingRegionSet() != null)");
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hasRegionSet seems like a relatively small method. Same with updateAuthSchemeWithRegionSet. Do we want to just always generate those methods, so that we don't have to do this branching in the code generator? It feels simpler to have a single code path regardless of sigv4a or sigv4 (both in the code generator, and in the generated code).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have made updateAuthSchemeWithRegionSet and hasRegionSet inline. The reason I added a specific check for multiAuthSigv4a is because for non-Sigv4a existing features, that check is dead code and might confuse readers of generated packages while debugging or during code walks.

code.addStatement("$1T regionSet = $1T.create(v4aAuthScheme.signingRegionSet())", RegionSet.class);

code.addStatement("option.putSignerProperty($T.REGION_SET, regionSet)", AwsV4aHttpSigner.class);
code.endControlFlow();

Expand Down Expand Up @@ -882,4 +896,51 @@ private MethodSpec constructorMethodSpec(String endpointAuthSchemeFieldName) {
return b.build();
}

private MethodSpec createHasRegionSetMethod() {
TypeVariableName tExtendsIdentity = TypeVariableName.get("T", Identity.class);
TypeName selectedAuthSchemeOfT = ParameterizedTypeName.get(ClassName.get(SelectedAuthScheme.class),
TypeVariableName.get("T"));

return
MethodSpec.methodBuilder("hasRegionSet")
.addModifiers(Modifier.PRIVATE)
.addTypeVariable(tExtendsIdentity)
.returns(boolean.class)
.addParameter(selectedAuthSchemeOfT, "selectedAuthScheme")
.addCode(
CodeBlock.builder()
.addStatement("return selectedAuthScheme.authSchemeOption().schemeId().equals($T.SCHEME_ID)"
+ " && selectedAuthScheme.authSchemeOption().signerProperty($T.REGION_SET) != "
+ "null", AwsV4aAuthScheme.class, AwsV4aHttpSigner.class)
.build())
.build();
}

private MethodSpec createUpdateAuthSchemeWithRegionSetMethod() {
TypeVariableName tExtendsIdentity = TypeVariableName.get("T", Identity.class);
TypeName selectedAuthSchemeOfT = ParameterizedTypeName.get(
ClassName.get(SelectedAuthScheme.class),
TypeVariableName.get("T")
);

return MethodSpec.methodBuilder("updateAuthSchemeWithRegionSet")
.addModifiers(Modifier.PRIVATE)
.addTypeVariable(tExtendsIdentity)
.returns(selectedAuthSchemeOfT)
.addParameter(selectedAuthSchemeOfT, "selectedAuthScheme")
.addParameter(endpointRulesSpecUtils.parametersClassName(), "endpointParams")
.addCode(CodeBlock.builder()
.addStatement("$T optionBuilder = selectedAuthScheme.authSchemeOption().toBuilder()",
ClassName.get(AuthSchemeOption.Builder.class))
.addStatement("$T regionSet = $T.create(endpointParams.region().id())",
RegionSet.class, RegionSet.class)
.addStatement("optionBuilder.putSignerProperty($T.REGION_SET, regionSet)",
AwsV4aHttpSigner.class)
.addStatement("return new $T<>(selectedAuthScheme.identity(), " +
"selectedAuthScheme.signer(), optionBuilder.build())",
SelectedAuthScheme.class)
.build())
.build();
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import software.amazon.awssdk.codegen.model.intermediate.IntermediateModel;
import software.amazon.awssdk.codegen.model.intermediate.OperationModel;
import software.amazon.awssdk.codegen.model.service.AuthType;
import software.amazon.awssdk.utils.CollectionUtils;

public final class AuthUtils {
private AuthUtils() {
Expand Down Expand Up @@ -76,6 +77,12 @@ private static boolean isServiceSigv4a(IntermediateModel model) {

private static boolean isServiceAwsAuthType(IntermediateModel model) {
AuthType authType = model.getMetadata().getAuthType();
if (authType == null && !CollectionUtils.isNullOrEmpty(model.getMetadata().getAuth())) {
return model.getMetadata().getAuth().stream()
.map(AuthType::value)
.map(AuthType::fromValue)
.anyMatch(AuthUtils::isAuthTypeAws);
}
return isAuthTypeAws(authType);
}

Expand All @@ -85,6 +92,7 @@ private static boolean isAuthTypeAws(AuthType authType) {
}

switch (authType) {
case V4A:
case V4:
case S3:
case S3V4:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,10 @@ private static IntermediateModel getModel(boolean useSraAuth) {
model.getCustomizationConfig().setUseSraAuth(useSraAuth);
return model;
}

@Test
void endpointResolverInterceptorClassWithSigv4aMultiAuth() {
ClassSpec endpointProviderInterceptor = new EndpointResolverInterceptorSpec(ClientTestModels.opsWithSigv4a());
assertThat(endpointProviderInterceptor, generatesTo("endpoint-resolve-interceptor-with-multiauthsigv4a.java"));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.services.database.auth.scheme.DatabaseAuthSchemeParams;
import software.amazon.awssdk.services.database.auth.scheme.DatabaseAuthSchemeProvider;
import software.amazon.awssdk.utils.CollectionUtils;
import software.amazon.awssdk.utils.Logger;
import software.amazon.awssdk.utils.Validate;

Expand Down Expand Up @@ -88,14 +89,9 @@ private DatabaseAuthSchemeParams authSchemeParams(SdkRequest request, ExecutionA
DatabaseAuthSchemeParams.Builder builder = DatabaseAuthSchemeParams.builder().operation(operation);
Region region = executionAttributes.getAttribute(AwsExecutionAttribute.AWS_REGION);
builder.region(region);
RegionSet regionSet = executionAttributes.getOptionalAttribute(AwsExecutionAttribute.AWS_SIGV4A_SIGNING_REGION_SET)
.filter(regions -> !regions.isEmpty()).map(regions -> RegionSet.create(String.join(", ", regions)))
.orElseGet(() -> {
Region fallbackRegion = executionAttributes.getAttribute(AwsExecutionAttribute.AWS_REGION);
return fallbackRegion != null ? RegionSet.create(fallbackRegion.toString()) : null;
});
;
builder.regionSet(regionSet);
executionAttributes.getOptionalAttribute(AwsExecutionAttribute.AWS_SIGV4A_SIGNING_REGION_SET)
.filter(regionSet -> !CollectionUtils.isNullOrEmpty(regionSet))
.ifPresent(nonEmptyRegionSet -> builder.regionSet(RegionSet.create(nonEmptyRegionSet)));
return builder.build();
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
{
"version": "1.2",
"serviceId": "Database Service",
"parameters": {
"region": {
"type": "string",
"builtIn": "AWS::Region",
"required": true,
"documentation": "The region to send requests to"
},
"useDualStackEndpoint": {
"type": "boolean",
"builtIn": "AWS::UseDualStack"
},
"useFIPSEndpoint": {
"type": "boolean",
"builtIn": "AWS::UseFIPS"
},
"AccountId": {
"type": "String",
"builtIn": "AWS::Auth::AccountId"
},
"operationContextParam": {
"type": "string"
}
},
"rules": [
{
"conditions": [
{
"fn": "aws.partition",
"argv": [
{
"ref": "region"
}
],
"assign": "partitionResult"
}
],
"rules": [
{
"conditions": [
{
"fn": "isSet",
"argv": [
{
"ref": "endpointId"
}
]
}
],
"rules": [
{
"conditions": [
{
"fn": "isSet",
"argv": [
{
"ref": "useFIPSEndpoint"
}
]
}
],
"error": "FIPS endpoints not supported with multi-region endpoints",
"type": "error"
},
{
"endpoint": {
"url": "https://{endpointId}.query.{partitionResult#dualStackDnsSuffix}",
"properties": {
"authSchemes": [
{
"name": "sigv4a",
"signingName": "query",
"signingRegionSet": ["*"]
}
]
}
},
"type": "endpoint"
}
],
"type": "tree"
}
],
"type": "tree"
}
]
}

Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,8 @@ public SdkRequest modifyRequest(Context.ModifyRequest context, ExecutionAttribut
.getAttribute(SdkInternalExecutionAttribute.ENDPOINT_PROVIDER);
try {
long resolveEndpointStart = System.nanoTime();
Endpoint endpoint = provider.resolveEndpoint(ruleParams(result, executionAttributes)).join();
QueryEndpointParams endpointParams = ruleParams(result, executionAttributes);
Endpoint endpoint = provider.resolveEndpoint(endpointParams).join();
Duration resolveEndpointDuration = Duration.ofNanos(System.nanoTime() - resolveEndpointStart);
Optional<MetricCollector> metricCollector = executionAttributes
.getOptionalAttribute(SdkExecutionAttribute.API_CALL_METRIC_COLLECTOR);
Expand Down
Loading
Loading