Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ensure identity resolver is set when a credentials provider is given only at operation level #3156

Merged
6 changes: 6 additions & 0 deletions CHANGELOG.next.toml
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,9 @@ An operation output that supports receiving events from stream now provides a ne
references = ["smithy-rs#3100", "smithy-rs#3114"]
meta = { "breaking" = true, "tada" = false, "bug" = false }
author = "ysaito1001"

[[aws-sdk-rust]]
message = "Fix exclusively setting the credentials provider at operation config-override time. It's now possible to set the credentials when an operation is sent (via `.config_override()`), rather than at client-creation time."
references = ["smithy-rs#3156", "aws-sdk-rust#901"]
meta = { "breaking" = false, "tada" = false, "bug" = true }
author = "ysaito1001"
Original file line number Diff line number Diff line change
Expand Up @@ -10,17 +10,13 @@ import software.amazon.smithy.rust.codegen.client.smithy.ClientRustModule
import software.amazon.smithy.rust.codegen.client.smithy.customize.ClientCodegenDecorator
import software.amazon.smithy.rust.codegen.client.smithy.customize.TestUtilFeature
import software.amazon.smithy.rust.codegen.client.smithy.endpoint.supportedAuthSchemes
import software.amazon.smithy.rust.codegen.client.smithy.generators.ServiceRuntimePluginCustomization
import software.amazon.smithy.rust.codegen.client.smithy.generators.ServiceRuntimePluginSection
import software.amazon.smithy.rust.codegen.client.smithy.generators.config.ConfigCustomization
import software.amazon.smithy.rust.codegen.client.smithy.generators.config.ServiceConfig
import software.amazon.smithy.rust.codegen.core.rustlang.Writable
import software.amazon.smithy.rust.codegen.core.rustlang.featureGateBlock
import software.amazon.smithy.rust.codegen.core.rustlang.rust
import software.amazon.smithy.rust.codegen.core.rustlang.rustBlockTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.rustlang.writable
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType.Companion.preludeScope
import software.amazon.smithy.rust.codegen.core.smithy.RustCrate
import software.amazon.smithy.rust.codegen.core.smithy.customize.AdHocCustomization
Expand All @@ -30,12 +26,6 @@ class CredentialsProviderDecorator : ClientCodegenDecorator {
override val name: String = "CredentialsProvider"
override val order: Byte = 0

override fun serviceRuntimePluginCustomizations(
codegenContext: ClientCodegenContext,
baseCustomizations: List<ServiceRuntimePluginCustomization>,
): List<ServiceRuntimePluginCustomization> =
baseCustomizations + listOf(CredentialsIdentityResolverRegistration(codegenContext))

override fun configCustomizations(
codegenContext: ClientCodegenContext,
baseCustomizations: List<ConfigCustomization>,
Expand Down Expand Up @@ -65,7 +55,7 @@ class CredentialsProviderDecorator : ClientCodegenDecorator {
/**
* Add a `.credentials_provider` field and builder to the `Config` for a given service
*/
class CredentialProviderConfig(codegenContext: ClientCodegenContext) : ConfigCustomization() {
class CredentialProviderConfig(private val codegenContext: ClientCodegenContext) : ConfigCustomization() {
private val runtimeConfig = codegenContext.runtimeConfig
private val codegenScope = arrayOf(
*preludeScope,
Expand All @@ -74,6 +64,10 @@ class CredentialProviderConfig(codegenContext: ClientCodegenContext) : ConfigCus
.resolve("provider::ProvideCredentials"),
"SharedCredentialsProvider" to AwsRuntimeType.awsCredentialTypes(runtimeConfig)
.resolve("provider::SharedCredentialsProvider"),
"SIGV4A_SCHEME_ID" to AwsRuntimeType.awsRuntime(runtimeConfig)
.resolve("auth::sigv4a::SCHEME_ID"),
"SIGV4_SCHEME_ID" to AwsRuntimeType.awsRuntime(runtimeConfig)
.resolve("auth::sigv4::SCHEME_ID"),
"TestCredentials" to AwsRuntimeType.awsCredentialTypesTestUtil(runtimeConfig).resolve("Credentials"),
)

Expand Down Expand Up @@ -103,16 +97,34 @@ class CredentialProviderConfig(codegenContext: ClientCodegenContext) : ConfigCus
*codegenScope,
)

rustTemplate(
rustBlockTemplate(
"""
/// Sets the credentials provider for this service
pub fn set_credentials_provider(&mut self, credentials_provider: #{Option}<#{SharedCredentialsProvider}>) -> &mut Self {
self.config.store_or_unset(credentials_provider);
self
}
pub fn set_credentials_provider(&mut self, credentials_provider: #{Option}<#{SharedCredentialsProvider}>) -> &mut Self
""",
*codegenScope,
)
) {
rustBlockTemplate(
"""
if let Some(credentials_provider) = credentials_provider
""",
*codegenScope,
) {
if (codegenContext.serviceShape.supportedAuthSchemes().contains("sigv4a")) {
featureGateBlock("sigv4a") {
rustTemplate(
"self.runtime_components.push_identity_resolver(#{SIGV4_SCHEME_ID}, credentials_provider.clone());",
*codegenScope,
)
}
}
rustTemplate(
"self.runtime_components.push_identity_resolver(#{SIGV4_SCHEME_ID}, credentials_provider);",
*codegenScope,
)
}
rust("self")
}
}

is ServiceConfig.DefaultForTests -> rustTemplate(
Expand All @@ -124,39 +136,3 @@ class CredentialProviderConfig(codegenContext: ClientCodegenContext) : ConfigCus
}
}
}

class CredentialsIdentityResolverRegistration(
private val codegenContext: ClientCodegenContext,
) : ServiceRuntimePluginCustomization() {
private val runtimeConfig = codegenContext.runtimeConfig

override fun section(section: ServiceRuntimePluginSection): Writable = writable {
when (section) {
is ServiceRuntimePluginSection.RegisterRuntimeComponents -> {
rustBlockTemplate("if let Some(creds_provider) = ${section.serviceConfigName}.credentials_provider()") {
val codegenScope = arrayOf(
"SharedIdentityResolver" to RuntimeType.smithyRuntimeApi(runtimeConfig)
.resolve("client::identity::SharedIdentityResolver"),
"SIGV4A_SCHEME_ID" to AwsRuntimeType.awsRuntime(runtimeConfig)
.resolve("auth::sigv4a::SCHEME_ID"),
"SIGV4_SCHEME_ID" to AwsRuntimeType.awsRuntime(runtimeConfig)
.resolve("auth::sigv4::SCHEME_ID"),
)

if (codegenContext.serviceShape.supportedAuthSchemes().contains("sigv4a")) {
featureGateBlock("sigv4a") {
section.registerIdentityResolver(this) {
rustTemplate("#{SIGV4A_SCHEME_ID}, creds_provider.clone()", *codegenScope)
}
}
}
section.registerIdentityResolver(this) {
rustTemplate("#{SIGV4_SCHEME_ID}, creds_provider,", *codegenScope)
}
}
}

else -> {}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,65 @@ package software.amazon.smithy.rustsdk

import org.junit.jupiter.api.Test
import software.amazon.smithy.rust.codegen.client.testutil.validateConfigCustomizations
import software.amazon.smithy.rust.codegen.core.rustlang.rustTemplate
import software.amazon.smithy.rust.codegen.core.smithy.RuntimeType
import software.amazon.smithy.rust.codegen.core.testutil.integrationTest
import software.amazon.smithy.rust.codegen.core.testutil.tokioTest

internal class CredentialProviderConfigTest {
@Test
fun `generates a valid config`() {
val codegenContext = awsTestCodegenContext()
validateConfigCustomizations(codegenContext, CredentialProviderConfig(codegenContext))
}

@Test
fun `configuring credentials provider at operation level should work`() {
awsSdkIntegrationTest(SdkCodegenIntegrationTest.model) { ctx, rustCrate ->
val rc = ctx.runtimeConfig
val codegenScope = arrayOf(
*RuntimeType.preludeScope,
"capture_request" to RuntimeType.captureRequest(rc),
"Credentials" to AwsRuntimeType.awsCredentialTypesTestUtil(rc)
.resolve("Credentials"),
"Region" to AwsRuntimeType.awsTypes(rc).resolve("region::Region"),
)
rustCrate.integrationTest("credentials_provider") {
// per https://github.com/awslabs/aws-sdk-rust/issues/901
tokioTest("configuring_credentials_provider_at_operation_level_should_work") {
val moduleName = ctx.moduleUseName()
rustTemplate(
"""
let (http_client, _rx) = #{capture_request}(None);
let client_config = $moduleName::Config::builder()
.http_client(http_client)
.build();

let client = $moduleName::Client::from_conf(client_config);

let credentials = #{Credentials}::new(
"test",
"test",
#{None},
#{None},
"test",
);
let operation_config_override = $moduleName::Config::builder()
.credentials_provider(credentials.clone())
.region(#{Region}::new("us-west-2"));

let _ = client
.some_operation()
.customize()
.config_override(operation_config_override)
.send()
.await
.expect("success");
""",
*codegenScope,
)
}
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,6 @@ sealed class ServiceRuntimePluginSection(name: String) : Section(name) {
writer.rust("runtime_components.set_endpoint_resolver(Some(#T));", resolver)
}

fun registerIdentityResolver(writer: RustWriter, identityResolver: Writable) {
writer.rust("runtime_components.push_identity_resolver(#T);", identityResolver)
}

fun registerRetryClassifier(writer: RustWriter, classifier: Writable) {
writer.rust("runtime_components.push_retry_classifier(#T);", classifier)
}
Expand Down