diff --git a/Sources/Core/AWSClientRuntime/Auth/AuthSchemes/SigV4AAuthScheme.swift b/Sources/Core/AWSClientRuntime/Auth/AuthSchemes/SigV4AAuthScheme.swift new file mode 100644 index 00000000000..779609b6ff6 --- /dev/null +++ b/Sources/Core/AWSClientRuntime/Auth/AuthSchemes/SigV4AAuthScheme.swift @@ -0,0 +1,76 @@ +// +// Copyright Amazon.com Inc. or its affiliates. +// All Rights Reserved. +// +// SPDX-License-Identifier: Apache-2.0 +// + +import ClientRuntime + +public struct SigV4AAuthScheme: ClientRuntime.AuthScheme { + public let schemeID: String = "aws.auth#sigv4a" + public let signer: ClientRuntime.Signer = AWSSigV4Signer() + public let idKind: ClientRuntime.IdentityKind = .aws + + public init() {} + + public func customizeSigningProperties(signingProperties: Attributes, context: HttpContext) throws -> Attributes { + var updatedSigningProperties = signingProperties + + // Set signing algorithm flag + updatedSigningProperties.set(key: AttributeKeys.signingAlgorithm, value: .sigv4a) + + // Set bidirectional streaming flag + updatedSigningProperties.set( + key: AttributeKeys.bidirectionalStreaming, + value: context.isBidirectionalStreamingEnabled() + ) + + // Set signing name and signing region flags + updatedSigningProperties.set(key: AttributeKeys.signingName, value: context.getSigningName()) + updatedSigningProperties.set(key: AttributeKeys.signingRegion, value: context.getSigningRegion()) + + // Set expiration flag + // + // Expiration is only used for presigning (presign request flow or presign URL flow). + updatedSigningProperties.set(key: AttributeKeys.expiration, value: context.getExpiration()) + + // Set signature type flag + // + // AWSSignatureType.requestQueryParams is only used for presign URL flow. + // Out of the AWSSignatureType enum cases, only two are used. .requestHeaders and .requestQueryParams. + // .requestHeaders is the deafult signing used for AWS operations. + let serviceName = context.getServiceName() + let isPresignURLFlow = context.getFlowType() == .PRESIGN_URL + updatedSigningProperties.set( + key: AttributeKeys.signatureType, + value: isPresignURLFlow ? .requestQueryParams : .requestHeaders + ) + + // Operation name is guaranteed to be in middleware context from generic codegen, but check just in case. + guard let operationName = context.getOperation() else { + throw ClientError.dataNotFound("Operation name must be configured on middleware context.") + } + + // Set unsignedBody flag + let shouldForceUnsignedBody = SigV4Util.shouldForceUnsignedBody( + flow: context.getFlowType(), + serviceName: serviceName, + opName: operationName + ) + let unsignedBody = context.hasUnsignedPayloadTrait() || shouldForceUnsignedBody + updatedSigningProperties.set(key: AttributeKeys.unsignedBody, value: unsignedBody) + + // Set signedBodyHeader flag + let useSignedBodyHeader = SigV4Util.serviceUsesUnsignedBodyHeader(serviceName: serviceName) && !unsignedBody + updatedSigningProperties.set( + key: AttributeKeys.signedBodyHeader, + value: useSignedBodyHeader ? .contentSha256 : AWSSignedBodyHeader.none + ) + + // Set flags in SigningFlags object (S3 customizations) + SigV4Util.setS3SpecificFlags(signingProperties: &updatedSigningProperties, serviceName: serviceName) + + return updatedSigningProperties + } +} diff --git a/Sources/Core/AWSClientRuntime/Auth/AuthSchemes/SigV4AuthScheme.swift b/Sources/Core/AWSClientRuntime/Auth/AuthSchemes/SigV4AuthScheme.swift index 99fe3e4bb87..9e937fb449e 100644 --- a/Sources/Core/AWSClientRuntime/Auth/AuthSchemes/SigV4AuthScheme.swift +++ b/Sources/Core/AWSClientRuntime/Auth/AuthSchemes/SigV4AuthScheme.swift @@ -14,36 +14,63 @@ public struct SigV4AuthScheme: ClientRuntime.AuthScheme { public init() {} - public func customizeSigningProperties(signingProperties: Attributes, context: HttpContext) -> Attributes { + public func customizeSigningProperties(signingProperties: Attributes, context: HttpContext) throws -> Attributes { var updatedSigningProperties = signingProperties - updatedSigningProperties.set(key: AttributeKeys.bidirectionalStreaming, value: context.isBidirectionalStreamingEnabled()) + + // Set signing algorithm flag + updatedSigningProperties.set(key: AttributeKeys.signingAlgorithm, value: .sigv4) + + // Set bidirectional streaming flag + updatedSigningProperties.set( + key: AttributeKeys.bidirectionalStreaming, + value: context.isBidirectionalStreamingEnabled() + ) + + // Set signing name and signing region flags updatedSigningProperties.set(key: AttributeKeys.signingName, value: context.getSigningName()) updatedSigningProperties.set(key: AttributeKeys.signingRegion, value: context.getSigningRegion()) - updatedSigningProperties.set(key: AttributeKeys.signingAlgorithm, value: .sigv4) - // Expiration is only used for presigning URLs. E.g., in AWSS3, and in AWSPolly. - updatedSigningProperties.set(key: AttributeKeys.expiration, value: 0) - // AWSSignatureType.requestQueryParams is only used for S3 GetObject and PutObject - // Out of all AWSSignatureType cases, only two are used. .requestHeaders and .requestQueryParams. - // .requestHeaders is the deafult signing used for all AWS operations except S3 customizations. - updatedSigningProperties.set(key: AttributeKeys.signatureType, value: .requestHeaders) + // Set expiration flag + // + // Expiration is only used for presigning (presign request flow or presign URL flow). + updatedSigningProperties.set(key: AttributeKeys.expiration, value: context.getExpiration()) - // SigningFlags + // Set signature type flag + // + // AWSSignatureType.requestQueryParams is only used for presign URL flow. + // Out of the AWSSignatureType enum cases, only two are used. .requestHeaders and .requestQueryParams. + // .requestHeaders is the deafult signing used for AWS operations. let serviceName = context.getServiceName() - // Set useDoubleURIEncode to false IFF service is S3 - updatedSigningProperties.set(key: AttributeKeys.useDoubleURIEncode, value: serviceName != "S3") - // Set shouldNormalizeURIPath to false IFF service is S3 - updatedSigningProperties.set(key: AttributeKeys.shouldNormalizeURIPath, value: serviceName != "S3") - // FIXME: Flag doesn't seem to be used by anything - investigate - updatedSigningProperties.set(key: AttributeKeys.omitSessionToken, value: false) - - /* - * The boolean flag .unsignedBody for AWSSigningConfig.signedBodyValue & - * the AWSSignedBodyHeader enum value for AWSSigningConfig.signedBodyHeader - * will be generated into signingProperties during service specific auth scheme resolver codegen and be part of - * the returned auth option's signing properties. - * By the time the call chain arrives here, code-generated flags are already included in signingProperties. - */ + let isPresignURLFlow = context.getFlowType() == .PRESIGN_URL + updatedSigningProperties.set( + key: AttributeKeys.signatureType, + value: isPresignURLFlow ? .requestQueryParams : .requestHeaders + ) + + // Operation name is guaranteed to be in middleware context from generic codegen, but check just in case. + guard let operationName = context.getOperation() else { + throw ClientError.dataNotFound("Operation name must be configured on middleware context.") + } + + // Set unsignedBody flag + let shouldForceUnsignedBody = SigV4Util.shouldForceUnsignedBody( + flow: context.getFlowType(), + serviceName: serviceName, + opName: operationName + ) + let unsignedBody = context.hasUnsignedPayloadTrait() || shouldForceUnsignedBody + updatedSigningProperties.set(key: AttributeKeys.unsignedBody, value: unsignedBody) + + // Set signedBodyHeader flag + let useSignedBodyHeader = SigV4Util.serviceUsesUnsignedBodyHeader(serviceName: serviceName) && !unsignedBody + updatedSigningProperties.set( + key: AttributeKeys.signedBodyHeader, + value: useSignedBodyHeader ? .contentSha256 : AWSSignedBodyHeader.none + ) + + // Set flags in SigningFlags object (S3 customizations) + SigV4Util.setS3SpecificFlags(signingProperties: &updatedSigningProperties, serviceName: serviceName) + return updatedSigningProperties } } diff --git a/Sources/Core/AWSClientRuntime/Auth/AuthSchemes/SigV4Util.swift b/Sources/Core/AWSClientRuntime/Auth/AuthSchemes/SigV4Util.swift new file mode 100644 index 00000000000..61d16f7ab75 --- /dev/null +++ b/Sources/Core/AWSClientRuntime/Auth/AuthSchemes/SigV4Util.swift @@ -0,0 +1,35 @@ +// +// Copyright Amazon.com Inc. or its affiliates. +// All Rights Reserved. +// +// SPDX-License-Identifier: Apache-2.0 +// + +import Foundation +import ClientRuntime + +public class SigV4Util { + static let unsignedBodyHeader = ["S3", "Glacier"] + static let forceUnsignedBodyForPresigningURL = [ + "S3": ["getObject", "putObject"] + ] + + static func shouldForceUnsignedBody(flow: FlowType, serviceName: String, opName: String) -> Bool { + let serviceQualifies = forceUnsignedBodyForPresigningURL.keys.contains(serviceName) + let flowQualifies = flow == .PRESIGN_URL + return serviceQualifies && flowQualifies && forceUnsignedBodyForPresigningURL[serviceName]!.contains(opName) + } + + static func serviceUsesUnsignedBodyHeader(serviceName: String) -> Bool { + return unsignedBodyHeader.contains(serviceName) + } + + static func setS3SpecificFlags(signingProperties: inout Attributes, serviceName: String) { + let serviceIsS3 = serviceName == "S3" + // Set useDoubleURIEncode to false IFF service is S3 + signingProperties.set(key: AttributeKeys.useDoubleURIEncode, value: !serviceIsS3) + // Set shouldNormalizeURIPath to false IFF service is S3 + signingProperties.set(key: AttributeKeys.shouldNormalizeURIPath, value: !serviceIsS3) + signingProperties.set(key: AttributeKeys.omitSessionToken, value: false) + } +} diff --git a/Sources/Core/AWSClientRuntime/EventStream/AWSMessageSigner.swift b/Sources/Core/AWSClientRuntime/EventStream/AWSMessageSigner.swift index a0b3f52db5c..0037f8035b8 100644 --- a/Sources/Core/AWSClientRuntime/EventStream/AWSMessageSigner.swift +++ b/Sources/Core/AWSClientRuntime/EventStream/AWSMessageSigner.swift @@ -11,8 +11,12 @@ extension AWSEventStream { /// Signs a `Message` using the AWS SigV4 signing algorithm public class AWSMessageSigner: MessageSigner { let encoder: MessageEncoder + let signer: () async throws -> ClientRuntime.Signer let signingConfig: () async throws -> AWSSigningConfig let requestSignature: () -> String + // Attribute key used to save AWSSigningConfig into signingProperties argument + // for AWSSigV4Signer::signEvent call that conforms to Signer::signEvent. + static let signingConfigKey = AttributeKey(name: "EventStreamSigningConfig") private var _previousSignature: String? @@ -35,9 +39,11 @@ extension AWSEventStream { } public init(encoder: MessageEncoder, + signer: @escaping () async throws -> ClientRuntime.Signer, signingConfig: @escaping () async throws -> AWSSigningConfig, requestSignature: @escaping () -> String) { self.encoder = encoder + self.signer = signer self.signingConfig = signingConfig self.requestSignature = requestSignature } @@ -49,11 +55,15 @@ extension AWSEventStream { // encode to bytes let encodedMessage = try encoder.encode(message: message) let signingConfig = try await self.signingConfig() - - // sign encoded bytes - let signingResult = try await AWSSigV4Signer.signEvent(payload: encodedMessage, - previousSignature: previousSignature, - signingConfig: signingConfig) + // Fetch signer + let signer = try await self.signer() + // Wrap config into signingProperties: Attributes + var configWrapper = Attributes() + configWrapper.set(key: AWSMessageSigner.signingConfigKey, value: signingConfig) + // Sign encoded bytes + let signingResult = try await signer.signEvent(payload: encodedMessage, + previousSignature: previousSignature, + signingProperties: configWrapper) previousSignature = signingResult.signature return signingResult.output } @@ -62,9 +72,15 @@ extension AWSEventStream { /// - Returns: Signed `Message` with `:chunk-signature` & `:date` headers public func signEmpty() async throws -> ClientRuntime.EventStream.Message { let signingConfig = try await self.signingConfig() - let signingResult = try await AWSSigV4Signer.signEvent(payload: .init(), - previousSignature: previousSignature, - signingConfig: signingConfig) + // Fetch signer + let signer = try await self.signer() + // Wrap config into signingProperties: Attributes + var configWrapper = Attributes() + configWrapper.set(key: AWSMessageSigner.signingConfigKey, value: signingConfig) + // Sign empty payload + let signingResult = try await signer.signEvent(payload: .init(), + previousSignature: previousSignature, + signingProperties: configWrapper) return signingResult.output } } diff --git a/Sources/Core/AWSClientRuntime/HttpContextBuilder+Extension.swift b/Sources/Core/AWSClientRuntime/HttpContextBuilder+Extension.swift index eef19f1cf29..f2e8bc2127d 100644 --- a/Sources/Core/AWSClientRuntime/HttpContextBuilder+Extension.swift +++ b/Sources/Core/AWSClientRuntime/HttpContextBuilder+Extension.swift @@ -14,10 +14,6 @@ extension HttpContext { return attributes.get(key: AttributeKeys.credentialsProvider) } - public func getRequestSignature() -> String { - return attributes.get(key: AttributeKeys.requestSignature)! - } - public func getSigningAlgorithm() -> AWSSigningAlgorithm? { return attributes.get(key: AttributeKeys.signingAlgorithm) } @@ -55,17 +51,30 @@ extension HttpContext { public func setupBidirectionalStreaming() throws { // setup client to server let messageEncoder = AWSClientRuntime.AWSEventStream.AWSMessageEncoder() - let messageSigner = AWSClientRuntime.AWSEventStream.AWSMessageSigner(encoder: messageEncoder) { - try await self.makeEventStreamSigningConfig() - } requestSignature: { - self.getRequestSignature() - } + let messageSigner = AWSClientRuntime.AWSEventStream.AWSMessageSigner( + encoder: messageEncoder, + signer: { try self.fetchSigner() }, + signingConfig: { try await self.makeEventStreamSigningConfig() }, + requestSignature: { self.getRequestSignature() } + ) attributes.set(key: AttributeKeys.messageEncoder, value: messageEncoder) attributes.set(key: AttributeKeys.messageSigner, value: messageSigner) // enable the flag attributes.set(key: AttributeKeys.bidirectionalStreaming, value: true) } + + func fetchSigner() throws -> ClientRuntime.Signer { + guard let authScheme = self.getSelectedAuthScheme() else { + throw ClientError.authError( + "Signer for event stream could not be loaded because auth scheme was not configured." + ) + } + guard let signer = authScheme.signer else { + throw ClientError.authError("Signer was not configured for the selected auth scheme.") + } + return signer + } } extension HttpContextBuilder { @@ -75,14 +84,6 @@ extension HttpContextBuilder { return self } - /// Sets the request signature for the event stream operation - /// - Parameter value: `String` request signature - @discardableResult - public func withRequestSignature(value: String) -> HttpContextBuilder { - self.attributes.set(key: AttributeKeys.requestSignature, value: value) - return self - } - @discardableResult public func withSigningAlgorithm(value: AWSSigningAlgorithm) -> HttpContextBuilder { self.attributes.set(key: AttributeKeys.signingAlgorithm, value: value) @@ -93,7 +94,6 @@ extension HttpContextBuilder { extension AttributeKeys { public static let credentialsProvider = AttributeKey<(any CredentialsProviding)>(name: "CredentialsProvider") public static let signingAlgorithm = AttributeKey(name: "SigningAlgorithm") - public static let requestSignature = AttributeKey(name: "AWS_HTTP_SIGNATURE") // Keys used to store/retrieve AWSSigningConfig fields in/from signingProperties passed to AWSSigV4Signer public static let unsignedBody = AttributeKey(name: "UnsignedBody") diff --git a/Sources/Core/AWSClientRuntime/Signing/AWSSigV4Signer.swift b/Sources/Core/AWSClientRuntime/Signing/AWSSigV4Signer.swift index 54be47160fd..3533ad86e18 100644 --- a/Sources/Core/AWSClientRuntime/Signing/AWSSigV4Signer.swift +++ b/Sources/Core/AWSClientRuntime/Signing/AWSSigV4Signer.swift @@ -10,7 +10,7 @@ import ClientRuntime import Foundation public class AWSSigV4Signer: ClientRuntime.Signer { - public func sign( + public func signRequest( requestBuilder: SdkHttpRequestBuilder, identity: IdentityT, signingProperties: ClientRuntime.Attributes @@ -94,6 +94,18 @@ public class AWSSigV4Signer: ClientRuntime.Signer { ) } + public func signEvent( + payload: Data, + previousSignature: String, + signingProperties: Attributes + ) async throws -> SigningResult { + let signingConfig = signingProperties.get(key: AWSEventStream.AWSMessageSigner.signingConfigKey) + guard let signingConfig else { + throw ClientError.dataNotFound("Failed to sign event stream message due to missing signing config.") + } + return try await signEvent(payload: payload, previousSignature: previousSignature, signingConfig: signingConfig) + } + static let logger: SwiftLogger = SwiftLogger(label: "AWSSigV4Signer") public static func sigV4SignedURL( @@ -143,7 +155,7 @@ public class AWSSigV4Signer: ClientRuntime.Signer { /// the current event payload like a rolling signature calculation. /// - signingConfig: The signing configuration /// - Returns: The signed event with :date and :chunk-signature headers - static func signEvent(payload: Data, + public func signEvent(payload: Data, previousSignature: String, signingConfig: AWSSigningConfig) async throws -> SigningResult { let signature = try await Signer.signEvent(event: payload, @@ -183,8 +195,3 @@ public class AWSSigV4Signer: ClientRuntime.Signer { } } } - -public struct SigningResult { - public let output: T - public let signature: String -} diff --git a/Tests/Core/AWSClientRuntimeTests/EventStream/AWSMessageEncoderStreamTests.swift b/Tests/Core/AWSClientRuntimeTests/EventStream/AWSMessageEncoderStreamTests.swift index b1350209f5f..21f330fcbad 100644 --- a/Tests/Core/AWSClientRuntimeTests/EventStream/AWSMessageEncoderStreamTests.swift +++ b/Tests/Core/AWSClientRuntimeTests/EventStream/AWSMessageEncoderStreamTests.swift @@ -40,6 +40,8 @@ final class AWSMessageEncoderStreamTests: XCTestCase { .build() let messageSigner = AWSEventStream.AWSMessageSigner(encoder: messageEncoder) { + return AWSSigV4Signer() + } signingConfig: { return try await context.makeEventStreamSigningConfig() } requestSignature: { return context.getRequestSignature() @@ -68,6 +70,8 @@ final class AWSMessageEncoderStreamTests: XCTestCase { .build() let messageSigner = AWSEventStream.AWSMessageSigner(encoder: messageEncoder) { + return AWSSigV4Signer() + } signingConfig: { return try await context.makeEventStreamSigningConfig() } requestSignature: { return context.getRequestSignature() diff --git a/Tests/Core/AWSClientRuntimeTests/Sigv4/SigV4SigningTests.swift b/Tests/Core/AWSClientRuntimeTests/Sigv4/SigV4SigningTests.swift index ba84365c3ac..d94f3e850a8 100644 --- a/Tests/Core/AWSClientRuntimeTests/Sigv4/SigV4SigningTests.swift +++ b/Tests/Core/AWSClientRuntimeTests/Sigv4/SigV4SigningTests.swift @@ -80,7 +80,7 @@ class Sigv4SigningTests: XCTestCase { let messagePayload = try! encoder.encode(message: message) - let result = try! await AWSSigV4Signer.signEvent(payload: messagePayload, + let result = try! await AWSSigV4Signer().signEvent(payload: messagePayload, previousSignature: prevSignature, signingConfig: signingConfig) XCTAssertEqual(":date", result.output.headers[0].name) diff --git a/codegen/smithy-aws-swift-codegen/src/main/kotlin/software/amazon/smithy/aws/swift/codegen/AWSHttpProtocolCustomizations.kt b/codegen/smithy-aws-swift-codegen/src/main/kotlin/software/amazon/smithy/aws/swift/codegen/AWSHttpProtocolCustomizations.kt index 548748f3e83..4088c262fa6 100644 --- a/codegen/smithy-aws-swift-codegen/src/main/kotlin/software/amazon/smithy/aws/swift/codegen/AWSHttpProtocolCustomizations.kt +++ b/codegen/smithy-aws-swift-codegen/src/main/kotlin/software/amazon/smithy/aws/swift/codegen/AWSHttpProtocolCustomizations.kt @@ -5,9 +5,11 @@ package software.amazon.smithy.aws.swift.codegen +import software.amazon.smithy.aws.swift.codegen.customization.RulesBasedAuthSchemeResolverGenerator import software.amazon.smithy.aws.swift.codegen.middleware.AWSSigningMiddleware import software.amazon.smithy.model.shapes.OperationShape import software.amazon.smithy.model.shapes.ServiceShape +import software.amazon.smithy.swift.codegen.AuthSchemeResolverGenerator import software.amazon.smithy.swift.codegen.SwiftWriter import software.amazon.smithy.swift.codegen.integration.ClientProperty import software.amazon.smithy.swift.codegen.integration.DefaultHttpProtocolCustomizations @@ -45,6 +47,10 @@ abstract class AWSHttpProtocolCustomizations : DefaultHttpProtocolCustomizations override fun renderInternals(ctx: ProtocolGenerator.GenerationContext) { super.renderInternals(ctx) + // Generate rules-based auth scheme resolver for services that depend on endpoint resolver for auth scheme resolution + if (AuthSchemeResolverGenerator.usesRulesBasedAuthResolver(ctx)) { + RulesBasedAuthSchemeResolverGenerator().render(ctx) + } EndpointResolverGenerator().render(ctx) } diff --git a/codegen/smithy-aws-swift-codegen/src/main/kotlin/software/amazon/smithy/aws/swift/codegen/AWSServiceConfig.kt b/codegen/smithy-aws-swift-codegen/src/main/kotlin/software/amazon/smithy/aws/swift/codegen/AWSServiceConfig.kt index e01a47d3882..7e8c5a21047 100644 --- a/codegen/smithy-aws-swift-codegen/src/main/kotlin/software/amazon/smithy/aws/swift/codegen/AWSServiceConfig.kt +++ b/codegen/smithy-aws-swift-codegen/src/main/kotlin/software/amazon/smithy/aws/swift/codegen/AWSServiceConfig.kt @@ -90,7 +90,7 @@ class AWSServiceConfig(writer: SwiftWriter, val ctx: ProtocolGenerator.Generatio } override fun serviceConfigProperties(): List { - var configs = mutableListOf() + val configs = mutableListOf() // service specific EndpointResolver configs.add(ConfigField(ENDPOINT_RESOLVER, AWSServiceTypes.EndpointResolver, "\$N", "Endpoint resolver")) diff --git a/codegen/smithy-aws-swift-codegen/src/main/kotlin/software/amazon/smithy/aws/swift/codegen/EndpointParamsGenerator.kt b/codegen/smithy-aws-swift-codegen/src/main/kotlin/software/amazon/smithy/aws/swift/codegen/EndpointParamsGenerator.kt index 5b9fb59dc9d..763a5e33db8 100644 --- a/codegen/smithy-aws-swift-codegen/src/main/kotlin/software/amazon/smithy/aws/swift/codegen/EndpointParamsGenerator.kt +++ b/codegen/smithy-aws-swift-codegen/src/main/kotlin/software/amazon/smithy/aws/swift/codegen/EndpointParamsGenerator.kt @@ -10,23 +10,30 @@ import software.amazon.smithy.codegen.core.Symbol import software.amazon.smithy.rulesengine.language.EndpointRuleSet import software.amazon.smithy.rulesengine.language.syntax.parameters.Parameter import software.amazon.smithy.rulesengine.language.syntax.parameters.ParameterType +import software.amazon.smithy.swift.codegen.AuthSchemeResolverGenerator import software.amazon.smithy.swift.codegen.SwiftTypes import software.amazon.smithy.swift.codegen.SwiftWriter import software.amazon.smithy.swift.codegen.getOrNull +import software.amazon.smithy.swift.codegen.integration.ProtocolGenerator import software.amazon.smithy.swift.codegen.model.boxed import software.amazon.smithy.swift.codegen.model.defaultValue +import software.amazon.smithy.swift.codegen.utils.clientName import software.amazon.smithy.swift.codegen.utils.toLowerCamelCase /** * Generates EndpointParams struct for the service */ class EndpointParamsGenerator(private val endpointRules: EndpointRuleSet?) { - fun render(writer: SwiftWriter) { + fun render(writer: SwiftWriter, ctx: ProtocolGenerator.GenerationContext? = null) { writer.openBlock("public struct EndpointParams {", "}") { endpointRules?.parameters?.toList()?.sortedBy { it.name.toString() }?.let { sortedParameters -> renderMembers(writer, sortedParameters) writer.write("") renderInit(writer, sortedParameters) + // Convert auth scheme params to endpoint params for rules-based auth scheme resolvers + if (ctx != null && AuthSchemeResolverGenerator.usesRulesBasedAuthResolver(ctx)) { + renderConversionInit(writer, sortedParameters, ctx) + } } } } @@ -58,6 +65,22 @@ class EndpointParamsGenerator(private val endpointRules: EndpointRuleSet?) { writer.write("public let \$L: \$L$optional", memberName, memberSymbol) } } + + private fun renderConversionInit( + writer: SwiftWriter, + parameters: List, + ctx: ProtocolGenerator.GenerationContext + ) { + writer.apply { + val paramsType = ctx.service.sdkId.clientName() + "AuthSchemeResolverParameters" + openBlock("public init (authSchemeParams: \$L) {", "}", paramsType) { + parameters.forEach { + val memberName = it.name.toString().toLowerCamelCase() + writer.write("self.\$1L = authSchemeParams.\$1L", memberName) + } + } + } + } } fun Parameter.toSymbol(): Symbol { diff --git a/codegen/smithy-aws-swift-codegen/src/main/kotlin/software/amazon/smithy/aws/swift/codegen/EndpointResolverGenerator.kt b/codegen/smithy-aws-swift-codegen/src/main/kotlin/software/amazon/smithy/aws/swift/codegen/EndpointResolverGenerator.kt index 08060b7911b..efcb864a043 100644 --- a/codegen/smithy-aws-swift-codegen/src/main/kotlin/software/amazon/smithy/aws/swift/codegen/EndpointResolverGenerator.kt +++ b/codegen/smithy-aws-swift-codegen/src/main/kotlin/software/amazon/smithy/aws/swift/codegen/EndpointResolverGenerator.kt @@ -30,7 +30,7 @@ class EndpointResolverGenerator() { ctx.delegator.useFileWriter("./$rootNamespace/EndpointResolver.swift") { val endpointParamsGenerator = EndpointParamsGenerator(ruleSet) - endpointParamsGenerator.render(it) + endpointParamsGenerator.render(it, ctx) } ctx.delegator.useFileWriter("./$rootNamespace/EndpointResolver.swift") { diff --git a/codegen/smithy-aws-swift-codegen/src/main/kotlin/software/amazon/smithy/aws/swift/codegen/PresignerGenerator.kt b/codegen/smithy-aws-swift-codegen/src/main/kotlin/software/amazon/smithy/aws/swift/codegen/PresignerGenerator.kt index 9e84d35e457..fbdf825c481 100644 --- a/codegen/smithy-aws-swift-codegen/src/main/kotlin/software/amazon/smithy/aws/swift/codegen/PresignerGenerator.kt +++ b/codegen/smithy-aws-swift-codegen/src/main/kotlin/software/amazon/smithy/aws/swift/codegen/PresignerGenerator.kt @@ -20,8 +20,7 @@ import software.amazon.smithy.swift.codegen.core.toProtocolGenerationContext import software.amazon.smithy.swift.codegen.integration.ProtocolGenerator import software.amazon.smithy.swift.codegen.integration.SwiftIntegration import software.amazon.smithy.swift.codegen.middleware.MiddlewareExecutionGenerator -import software.amazon.smithy.swift.codegen.middleware.MiddlewareStep -import software.amazon.smithy.swift.codegen.middleware.OperationMiddleware +import software.amazon.smithy.swift.codegen.middleware.MiddlewareExecutionGenerator.Companion.ContextAttributeCodegenFlowType.PRESIGN_REQUEST import software.amazon.smithy.swift.codegen.model.expectShape import software.amazon.smithy.swift.codegen.model.toUpperCamelCase @@ -44,9 +43,9 @@ class PresignerGenerator : SwiftIntegration { } presignOperations.forEach { presignableOperation -> val op = ctx.model.expectShape(presignableOperation.operationId) - val inputType = op.input.get().getName() + val inputType = op.input.get().name delegator.useFileWriter("${ctx.settings.moduleName}/models/$inputType+Presigner.swift") { writer -> - var serviceConfig = AWSServiceConfig(writer, protoCtx) + val serviceConfig = AWSServiceConfig(writer, protoCtx) renderPresigner(writer, ctx, delegator, op, inputType, serviceConfig) } // Expose presign-request as a method for service client object @@ -76,7 +75,7 @@ class PresignerGenerator : SwiftIntegration { val serviceShape = ctx.model.expectShape(ctx.settings.service) val protocolGenerator = ctx.protocolGenerator?.let { it } ?: run { return } val protocolGeneratorContext = ctx.toProtocolGenerationContext(serviceShape, delegator)?.let { it } ?: run { return } - val operationMiddleware = resolveOperationMiddleware(protocolGenerator, op, ctx) + val operationMiddleware = protocolGenerator.operationMiddleware writer.addImport(AWSClientConfiguration) writer.addImport(SdkHttpRequest) @@ -103,7 +102,7 @@ class PresignerGenerator : SwiftIntegration { operationMiddleware, operationStackName ) - generator.render(op) { writer, _ -> + generator.render(op, PRESIGN_REQUEST) { writer, _ -> writer.write("return nil") } val requestBuilderName = "presignedRequestBuilder" @@ -153,23 +152,4 @@ class PresignerGenerator : SwiftIntegration { write("/// - Returns: `URLRequest`: The presigned request for ${op.toUpperCamelCase()} operation.") } } - - private fun resolveOperationMiddleware(protocolGenerator: ProtocolGenerator, op: OperationShape, ctx: CodegenContext): OperationMiddleware { - val operationMiddlewareCopy = protocolGenerator.operationMiddleware.clone() - operationMiddlewareCopy.removeMiddleware(op, MiddlewareStep.FINALIZESTEP, "AWSSigningMiddleware") - val service = ctx.model.expectShape(ctx.settings.service) - val operation = ctx.model.expectShape(op.id) - if (AWSSigningMiddleware.hasSigV4AuthScheme(ctx.model, service, operation)) { - val params = AWSSigningParams( - service, - op, - useSignatureTypeQueryString = false, - forceUnsignedBody = false, - useExpiration = true, - signingAlgorithm = SigningAlgorithm.SigV4 - ) - operationMiddlewareCopy.appendMiddleware(op, AWSSigningMiddleware(ctx.model, ctx.symbolProvider, params)) - } - return operationMiddlewareCopy - } } diff --git a/codegen/smithy-aws-swift-codegen/src/main/kotlin/software/amazon/smithy/aws/swift/codegen/customization/RulesBasedAuthSchemeResolverGenerator.kt b/codegen/smithy-aws-swift-codegen/src/main/kotlin/software/amazon/smithy/aws/swift/codegen/customization/RulesBasedAuthSchemeResolverGenerator.kt new file mode 100644 index 00000000000..5ba7311392d --- /dev/null +++ b/codegen/smithy-aws-swift-codegen/src/main/kotlin/software/amazon/smithy/aws/swift/codegen/customization/RulesBasedAuthSchemeResolverGenerator.kt @@ -0,0 +1,148 @@ +package software.amazon.smithy.aws.swift.codegen.customization + +import software.amazon.smithy.aws.traits.auth.SigV4Trait +import software.amazon.smithy.rulesengine.language.EndpointRuleSet +import software.amazon.smithy.rulesengine.traits.EndpointRuleSetTrait +import software.amazon.smithy.swift.codegen.AuthSchemeResolverGenerator +import software.amazon.smithy.swift.codegen.ClientRuntimeTypes +import software.amazon.smithy.swift.codegen.SwiftDependency +import software.amazon.smithy.swift.codegen.SwiftWriter +import software.amazon.smithy.swift.codegen.integration.ProtocolGenerator +import software.amazon.smithy.swift.codegen.integration.ServiceTypes +import software.amazon.smithy.swift.codegen.model.getTrait +import software.amazon.smithy.swift.codegen.utils.toLowerCamelCase + +class RulesBasedAuthSchemeResolverGenerator { + fun render(ctx: ProtocolGenerator.GenerationContext) { + val rootNamespace = ctx.settings.moduleName + + ctx.delegator.useFileWriter("./$rootNamespace/${ClientRuntimeTypes.Core.AuthSchemeResolver.name}.swift") { + renderServiceSpecificDefaultResolver(ctx, it) + it.write("") + it.addImport(SwiftDependency.CLIENT_RUNTIME.target) + it.addIndividualTypeImport("enum", "AWSClientRuntime", "AuthScheme") + } + } + + private fun renderServiceSpecificDefaultResolver(ctx: ProtocolGenerator.GenerationContext, writer: SwiftWriter) { + val sdkId = AuthSchemeResolverGenerator.getSdkId(ctx) + val serviceSpecificDefaultResolverName = "Default$sdkId${ClientRuntimeTypes.Core.AuthSchemeResolver.name}" + val serviceSpecificAuthResolverProtocol = sdkId + ClientRuntimeTypes.Core.AuthSchemeResolver.name + + writer.apply { + writer.openBlock( + "public struct \$L: \$L {", + "}", + serviceSpecificDefaultResolverName, + serviceSpecificAuthResolverProtocol + ) { + renderResolveAuthSchemeMethod(ctx, writer) + write("") + renderConstructParametersMethod( + ctx, + sdkId + ClientRuntimeTypes.Core.AuthSchemeResolverParameters.name, + writer + ) + } + } + } + + private fun renderResolveAuthSchemeMethod(ctx: ProtocolGenerator.GenerationContext, writer: SwiftWriter) { + val sdkId = AuthSchemeResolverGenerator.getSdkId(ctx) + val serviceParamsName = sdkId + ClientRuntimeTypes.Core.AuthSchemeResolverParameters.name + + writer.apply { + openBlock( + "public func resolveAuthScheme(params: \$L) throws -> [AuthOption] {", + "}", + ServiceTypes.AuthSchemeResolverParams + ) { + // Return value of array of auth options + write("var validAuthOptions = [AuthOption]()") + + // Cast params to service specific params object + openBlock( + "guard let serviceParams = params as? \$L else {", + "}", + serviceParamsName + ) { + write("throw ClientError.authError(\"Service specific auth scheme parameters type must be passed to auth scheme resolver.\")") + } + + // Construct endpoint params from auth params + write("let endpointParams = EndpointParams(authSchemeParams: serviceParams)") + // Resolve endpoint, and retrieve auth schemes valid for the resolved endpoint + write("let endpoint = try DefaultEndpointResolver().resolve(params: endpointParams)") + openBlock("guard let authSchemes = endpoint.authSchemes() else {", "}") { + // Call internal modeled model-based auth scheme resolver as fall-back if no auth schemes + // are returned by endpoint resolver. + write("return try InternalModeled${sdkId + ClientRuntimeTypes.Core.AuthSchemeResolver.name}().resolveAuthScheme(params: params)") + } + writer.write("let schemes = try authSchemes.map { (input) -> AWSClientRuntime.AuthScheme in try AWSClientRuntime.AuthScheme(from: input) }") + // If endpoint resolver returned auth schemes, iterate over them and save each as valid auth option to return + openBlock("for scheme in schemes {", "}") { + openBlock("switch scheme {", "}") { + // SigV4 case + write("case .sigV4(let param):") + indent() + write("var sigV4Option = AuthOption(schemeID: \"${SigV4Trait.ID}\")") + write("sigV4Option.signingProperties.set(key: AttributeKeys.signingName, value: param.signingName)") + write("sigV4Option.signingProperties.set(key: AttributeKeys.signingRegion, value: param.signingRegion)") + write("validAuthOptions.append(sigV4Option)") + dedent() + // SigV4A case + write("case .sigV4A(let param):") + indent() + // sigv4a trait is not yet implemented by Smithy + // This is a SDK-level customization until the trait is added + write("var sigV4Option = AuthOption(schemeID: \"${SigV4Trait.ID}a\")") + write("sigV4Option.signingProperties.set(key: AttributeKeys.signingName, value: param.signingName)") + write("sigV4Option.signingProperties.set(key: AttributeKeys.signingRegion, value: param.signingRegionSet?[0])") + write("validAuthOptions.append(sigV4Option)") + dedent() + // Default case: throw error if returned auth scheme is neither SigV4 nor SigV4A + write("default:") + indent() + write("throw ClientError.authError(\"Unknown auth scheme name: \\(scheme.name)\")") + dedent() + } + } + // Return result + write("return validAuthOptions") + } + } + } + + private fun renderConstructParametersMethod(ctx: ProtocolGenerator.GenerationContext, returnTypeName: String, writer: SwiftWriter) { + writer.apply { + openBlock( + "public func constructParameters(context: HttpContext) throws -> \$L {", + "}", + ServiceTypes.AuthSchemeResolverParams + ) { + openBlock("guard let opName = context.getOperation() else {", "}") { + write("throw ClientError.dataNotFound(\"Operation name not configured in middleware context for auth scheme resolver params construction.\")") + } + + // Get endpoint param from middleware context + openBlock("guard let endpointParam = context.attributes.get(key: AttributeKey(name: \"EndpointParams\")) else {", "}") { + write("throw ClientError.dataNotFound(\"Endpoint param not configured in middleware context for rules-based auth scheme resolver params construction.\")") + } + + // Copy over endpoint param fields to auth param fields + val ruleSetNode = ctx.service.getTrait()?.ruleSet + val ruleSet = if (ruleSetNode != null) EndpointRuleSet.fromNode(ruleSetNode) else null + val paramList = ArrayList() + ruleSet?.parameters?.toList()?.sortedBy { it.name.toString() }?.let { sortedParameters -> + sortedParameters.forEach { param -> + val memberName = param.name.toString().toLowerCamelCase() + paramList.add("$memberName: endpointParam.$memberName") + } + } + + val argStringToAppend = if (paramList.isEmpty()) "" else ", " + paramList.joinToString() + write("return $returnTypeName(operation: opName$argStringToAppend)") + } + } + } +} diff --git a/codegen/smithy-aws-swift-codegen/src/main/kotlin/software/amazon/smithy/aws/swift/codegen/customization/presignable/PresignableUrlIntegration.kt b/codegen/smithy-aws-swift-codegen/src/main/kotlin/software/amazon/smithy/aws/swift/codegen/customization/presignable/PresignableUrlIntegration.kt index b953730aea9..316964ee0af 100644 --- a/codegen/smithy-aws-swift-codegen/src/main/kotlin/software/amazon/smithy/aws/swift/codegen/customization/presignable/PresignableUrlIntegration.kt +++ b/codegen/smithy-aws-swift-codegen/src/main/kotlin/software/amazon/smithy/aws/swift/codegen/customization/presignable/PresignableUrlIntegration.kt @@ -2,9 +2,7 @@ package software.amazon.smithy.aws.swift.codegen.customization.presignable import software.amazon.smithy.aws.swift.codegen.AWSClientRuntimeTypes import software.amazon.smithy.aws.swift.codegen.AWSServiceConfig -import software.amazon.smithy.aws.swift.codegen.AWSSigningParams import software.amazon.smithy.aws.swift.codegen.PresignableOperation -import software.amazon.smithy.aws.swift.codegen.SigningAlgorithm import software.amazon.smithy.aws.swift.codegen.customization.InputTypeGETQueryItemMiddleware import software.amazon.smithy.aws.swift.codegen.customization.PutObjectPresignedURLMiddleware import software.amazon.smithy.aws.swift.codegen.middleware.AWSSigningMiddleware @@ -28,6 +26,7 @@ import software.amazon.smithy.swift.codegen.integration.ProtocolGenerator import software.amazon.smithy.swift.codegen.integration.SwiftIntegration import software.amazon.smithy.swift.codegen.integration.middlewares.handlers.MiddlewareShapeUtils import software.amazon.smithy.swift.codegen.middleware.MiddlewareExecutionGenerator +import software.amazon.smithy.swift.codegen.middleware.MiddlewareExecutionGenerator.Companion.ContextAttributeCodegenFlowType.PRESIGN_URL import software.amazon.smithy.swift.codegen.middleware.MiddlewareStep import software.amazon.smithy.swift.codegen.middleware.OperationMiddleware import software.amazon.smithy.swift.codegen.model.expectShape @@ -131,7 +130,7 @@ class PresignableUrlIntegration(private val presignedOperations: Map + generator.render(op, PRESIGN_URL) { writer, _ -> writer.write("return nil") } @@ -195,21 +194,6 @@ class PresignableUrlIntegration(private val presignedOperations: Map { operationMiddlewareCopy.removeMiddleware(op, MiddlewareStep.SERIALIZESTEP, "OperationInputBodyMiddleware") diff --git a/codegen/smithy-aws-swift-codegen/src/main/kotlin/software/amazon/smithy/aws/swift/codegen/middleware/OperationEndpointResolverMiddleware.kt b/codegen/smithy-aws-swift-codegen/src/main/kotlin/software/amazon/smithy/aws/swift/codegen/middleware/OperationEndpointResolverMiddleware.kt index e59cd7217f1..79b7a869aaa 100644 --- a/codegen/smithy-aws-swift-codegen/src/main/kotlin/software/amazon/smithy/aws/swift/codegen/middleware/OperationEndpointResolverMiddleware.kt +++ b/codegen/smithy-aws-swift-codegen/src/main/kotlin/software/amazon/smithy/aws/swift/codegen/middleware/OperationEndpointResolverMiddleware.kt @@ -18,6 +18,7 @@ import software.amazon.smithy.rulesengine.traits.ContextParamTrait import software.amazon.smithy.rulesengine.traits.EndpointRuleSetTrait import software.amazon.smithy.rulesengine.traits.StaticContextParamDefinition import software.amazon.smithy.rulesengine.traits.StaticContextParamsTrait +import software.amazon.smithy.swift.codegen.AuthSchemeResolverGenerator import software.amazon.smithy.swift.codegen.SwiftWriter import software.amazon.smithy.swift.codegen.integration.ProtocolGenerator import software.amazon.smithy.swift.codegen.integration.middlewares.handlers.MiddlewareShapeUtils @@ -70,6 +71,10 @@ class OperationEndpointResolverMiddleware( } } writer.write("let endpointParams = EndpointParams(${params.joinToString(separator = ", ")})") + // Write code that saves endpoint params to middleware context for use in auth scheme middleware when using rules-based auth scheme resolvers + if (AuthSchemeResolverGenerator.usesRulesBasedAuthResolver(ctx)) { + writer.write("context.attributes.set(key: AttributeKey(name: \"EndpointParams\"), value: endpointParams)") + } val middlewareParamsString = "endpointResolver: config.serviceSpecific.endpointResolver, endpointParams: endpointParams" writer.write("$operationStackName.${middlewareStep.stringValue()}.intercept(position: ${position.stringValue()}, middleware: \$N<\$N, \$N>($middlewareParamsString))", AWSServiceTypes.EndpointResolverMiddleware, output, outputError) } diff --git a/codegen/smithy-aws-swift-codegen/src/test/kotlin/software/amazon/smithy/aws/swift/codegen/EventStreamTests.kt b/codegen/smithy-aws-swift-codegen/src/test/kotlin/software/amazon/smithy/aws/swift/codegen/EventStreamTests.kt index 855d31eaffe..41d901ad294 100644 --- a/codegen/smithy-aws-swift-codegen/src/test/kotlin/software/amazon/smithy/aws/swift/codegen/EventStreamTests.kt +++ b/codegen/smithy-aws-swift-codegen/src/test/kotlin/software/amazon/smithy/aws/swift/codegen/EventStreamTests.kt @@ -216,6 +216,7 @@ extension EventStreamTestClient: EventStreamTestClientProtocol { .withPartitionID(value: config.partitionID) .withAuthSchemes(value: config.authSchemes!) .withAuthSchemeResolver(value: config.serviceSpecific.authSchemeResolver) + .withUnsignedPayloadTrait(value: false) .withCredentialsProvider(value: config.credentialsProvider) .withIdentityResolver(value: config.credentialsProvider, type: IdentityKind.aws) .withRegion(value: config.region) diff --git a/codegen/smithy-aws-swift-codegen/src/test/kotlin/software/amazon/smithy/aws/swift/codegen/PresignerGeneratorTests.kt b/codegen/smithy-aws-swift-codegen/src/test/kotlin/software/amazon/smithy/aws/swift/codegen/PresignerGeneratorTests.kt index a8915c02821..dfef8493df3 100644 --- a/codegen/smithy-aws-swift-codegen/src/test/kotlin/software/amazon/smithy/aws/swift/codegen/PresignerGeneratorTests.kt +++ b/codegen/smithy-aws-swift-codegen/src/test/kotlin/software/amazon/smithy/aws/swift/codegen/PresignerGeneratorTests.kt @@ -41,6 +41,9 @@ class PresignerGeneratorTests { .withPartitionID(value: config.partitionID) .withAuthSchemes(value: config.authSchemes!) .withAuthSchemeResolver(value: config.serviceSpecific.authSchemeResolver) + .withUnsignedPayloadTrait(value: false) + .withFlowType(value: .PRESIGN_REQUEST) + .withExpiration(value: expiration) .withCredentialsProvider(value: config.credentialsProvider) .withIdentityResolver(value: config.credentialsProvider, type: IdentityKind.aws) .withRegion(value: config.region) @@ -56,8 +59,6 @@ class PresignerGeneratorTests { operation.buildStep.intercept(position: .before, middleware: ClientRuntime.AuthSchemeMiddleware()) operation.finalizeStep.intercept(position: .after, middleware: ClientRuntime.RetryMiddleware(options: config.retryStrategyOptions)) operation.finalizeStep.intercept(position: .before, middleware: ClientRuntime.SignerMiddleware()) - let sigv4Config = AWSClientRuntime.SigV4Config(expiration: expiration, unsignedBody: false, signingAlgorithm: .sigv4) - operation.finalizeStep.intercept(position: .before, middleware: AWSClientRuntime.SigV4Middleware(config: sigv4Config)) operation.deserializeStep.intercept(position: .after, middleware: ClientRuntime.DeserializeMiddleware()) operation.deserializeStep.intercept(position: .after, middleware: ClientRuntime.LoggerMiddleware(clientLogMode: config.clientLogMode)) let presignedRequestBuilder = try await operation.presignedRequest(context: context, input: input, next: ClientRuntime.NoopHandler()) @@ -103,6 +104,9 @@ class PresignerGeneratorTests { .withPartitionID(value: config.partitionID) .withAuthSchemes(value: config.authSchemes!) .withAuthSchemeResolver(value: config.serviceSpecific.authSchemeResolver) + .withUnsignedPayloadTrait(value: false) + .withFlowType(value: .PRESIGN_REQUEST) + .withExpiration(value: expiration) .withCredentialsProvider(value: config.credentialsProvider) .withIdentityResolver(value: config.credentialsProvider, type: IdentityKind.aws) .withRegion(value: config.region) @@ -121,8 +125,6 @@ class PresignerGeneratorTests { operation.finalizeStep.intercept(position: .before, middleware: ClientRuntime.ContentLengthMiddleware()) operation.finalizeStep.intercept(position: .after, middleware: ClientRuntime.RetryMiddleware(options: config.retryStrategyOptions)) operation.finalizeStep.intercept(position: .before, middleware: ClientRuntime.SignerMiddleware()) - let sigv4Config = AWSClientRuntime.SigV4Config(expiration: expiration, unsignedBody: false, signingAlgorithm: .sigv4) - operation.finalizeStep.intercept(position: .before, middleware: AWSClientRuntime.SigV4Middleware(config: sigv4Config)) operation.deserializeStep.intercept(position: .after, middleware: ClientRuntime.DeserializeMiddleware()) operation.deserializeStep.intercept(position: .after, middleware: ClientRuntime.LoggerMiddleware(clientLogMode: config.clientLogMode)) let presignedRequestBuilder = try await operation.presignedRequest(context: context, input: input, next: ClientRuntime.NoopHandler()) @@ -168,6 +170,9 @@ class PresignerGeneratorTests { .withPartitionID(value: config.partitionID) .withAuthSchemes(value: config.authSchemes!) .withAuthSchemeResolver(value: config.serviceSpecific.authSchemeResolver) + .withUnsignedPayloadTrait(value: false) + .withFlowType(value: .PRESIGN_REQUEST) + .withExpiration(value: expiration) .withCredentialsProvider(value: config.credentialsProvider) .withIdentityResolver(value: config.credentialsProvider, type: IdentityKind.aws) .withRegion(value: config.region) @@ -186,8 +191,6 @@ class PresignerGeneratorTests { operation.finalizeStep.intercept(position: .before, middleware: ClientRuntime.ContentLengthMiddleware()) operation.finalizeStep.intercept(position: .after, middleware: ClientRuntime.RetryMiddleware(options: config.retryStrategyOptions)) operation.finalizeStep.intercept(position: .before, middleware: ClientRuntime.SignerMiddleware()) - let sigv4Config = AWSClientRuntime.SigV4Config(expiration: expiration, unsignedBody: false, signingAlgorithm: .sigv4) - operation.finalizeStep.intercept(position: .before, middleware: AWSClientRuntime.SigV4Middleware(config: sigv4Config)) operation.deserializeStep.intercept(position: .after, middleware: ClientRuntime.DeserializeMiddleware()) operation.deserializeStep.intercept(position: .after, middleware: ClientRuntime.LoggerMiddleware(clientLogMode: config.clientLogMode)) let presignedRequestBuilder = try await operation.presignedRequest(context: context, input: input, next: ClientRuntime.NoopHandler()) @@ -233,6 +236,9 @@ class PresignerGeneratorTests { .withPartitionID(value: config.partitionID) .withAuthSchemes(value: config.authSchemes!) .withAuthSchemeResolver(value: config.serviceSpecific.authSchemeResolver) + .withUnsignedPayloadTrait(value: false) + .withFlowType(value: .PRESIGN_REQUEST) + .withExpiration(value: expiration) .withCredentialsProvider(value: config.credentialsProvider) .withIdentityResolver(value: config.credentialsProvider, type: IdentityKind.aws) .withRegion(value: config.region) @@ -243,6 +249,7 @@ class PresignerGeneratorTests { operation.initializeStep.intercept(position: .after, middleware: ClientRuntime.URLPathMiddleware()) operation.initializeStep.intercept(position: .after, middleware: ClientRuntime.URLHostMiddleware()) let endpointParams = EndpointParams() + context.attributes.set(key: AttributeKey(name: "EndpointParams"), value: endpointParams) operation.buildStep.intercept(position: .before, middleware: EndpointResolverMiddleware(endpointResolver: config.serviceSpecific.endpointResolver, endpointParams: endpointParams)) operation.buildStep.intercept(position: .before, middleware: AWSClientRuntime.UserAgentMiddleware(metadata: AWSClientRuntime.AWSUserAgentMetadata.fromConfig(serviceID: serviceName, version: "1.0.0", config: config))) operation.buildStep.intercept(position: .before, middleware: ClientRuntime.AuthSchemeMiddleware()) @@ -251,8 +258,6 @@ class PresignerGeneratorTests { operation.finalizeStep.intercept(position: .before, middleware: ClientRuntime.ContentLengthMiddleware()) operation.finalizeStep.intercept(position: .after, middleware: ClientRuntime.RetryMiddleware(options: config.retryStrategyOptions)) operation.finalizeStep.intercept(position: .before, middleware: ClientRuntime.SignerMiddleware()) - let sigv4Config = AWSClientRuntime.SigV4Config(useDoubleURIEncode: false, shouldNormalizeURIPath: false, expiration: expiration, signedBodyHeader: .contentSha256, unsignedBody: false, signingAlgorithm: .sigv4) - operation.finalizeStep.intercept(position: .before, middleware: AWSClientRuntime.SigV4Middleware(config: sigv4Config)) operation.deserializeStep.intercept(position: .after, middleware: ClientRuntime.DeserializeMiddleware()) operation.deserializeStep.intercept(position: .after, middleware: ClientRuntime.LoggerMiddleware(clientLogMode: config.clientLogMode)) let presignedRequestBuilder = try await operation.presignedRequest(context: context, input: input, next: ClientRuntime.NoopHandler()) diff --git a/codegen/smithy-aws-swift-codegen/src/test/kotlin/software/amazon/smithy/aws/swift/codegen/awsquery/AWSQueryOperationStackTest.kt b/codegen/smithy-aws-swift-codegen/src/test/kotlin/software/amazon/smithy/aws/swift/codegen/awsquery/AWSQueryOperationStackTest.kt index a1256770998..a23a72081ff 100644 --- a/codegen/smithy-aws-swift-codegen/src/test/kotlin/software/amazon/smithy/aws/swift/codegen/awsquery/AWSQueryOperationStackTest.kt +++ b/codegen/smithy-aws-swift-codegen/src/test/kotlin/software/amazon/smithy/aws/swift/codegen/awsquery/AWSQueryOperationStackTest.kt @@ -40,6 +40,7 @@ extension QueryProtocolClient: QueryProtocolClientProtocol { .withPartitionID(value: config.partitionID) .withAuthSchemes(value: config.authSchemes!) .withAuthSchemeResolver(value: config.serviceSpecific.authSchemeResolver) + .withUnsignedPayloadTrait(value: false) .withCredentialsProvider(value: config.credentialsProvider) .withIdentityResolver(value: config.credentialsProvider, type: IdentityKind.aws) .withRegion(value: config.region) diff --git a/codegen/smithy-aws-swift-codegen/src/test/kotlin/software/amazon/smithy/aws/swift/codegen/customizations/PresignableUrlIntegrationTests.kt b/codegen/smithy-aws-swift-codegen/src/test/kotlin/software/amazon/smithy/aws/swift/codegen/customizations/PresignableUrlIntegrationTests.kt index 8df7fec721e..495315e0bbb 100644 --- a/codegen/smithy-aws-swift-codegen/src/test/kotlin/software/amazon/smithy/aws/swift/codegen/customizations/PresignableUrlIntegrationTests.kt +++ b/codegen/smithy-aws-swift-codegen/src/test/kotlin/software/amazon/smithy/aws/swift/codegen/customizations/PresignableUrlIntegrationTests.kt @@ -13,43 +13,6 @@ import software.amazon.smithy.swift.codegen.core.GenerationContext import software.amazon.smithy.swift.codegen.integration.ProtocolGenerator class PresignableUrlIntegrationTests { - - @Test - fun `codesign is configured correctly for Polly SynthesizeSpeech`() { - val context = setupTests("presign-urls-polly.smithy", "com.amazonaws.polly#Parrot_v1") - val contents = TestContextGenerator.getFileContents(context.manifest, "/Example/models/SynthesizeSpeechInput+Presigner.swift") - contents.shouldSyntacticSanityCheck() - val expectedContents = """ - let sigv4Config = AWSClientRuntime.SigV4Config(signatureType: .requestQueryParams, expiration: expiration, unsignedBody: false, signingAlgorithm: .sigv4) - operation.finalizeStep.intercept(position: .before, middleware: AWSClientRuntime.SigV4Middleware(config: sigv4Config)) - """ - contents.shouldContainOnlyOnce(expectedContents) - } - - @Test - fun `codesign is configured correctly for S3 GetObject`() { - val context = setupTests("presign-urls-s3.smithy", "com.amazonaws.s3#AmazonS3") - val contents = TestContextGenerator.getFileContents(context.manifest, "/Example/models/GetObjectInput+Presigner.swift") - contents.shouldSyntacticSanityCheck() - val expectedContents = """ - let sigv4Config = AWSClientRuntime.SigV4Config(signatureType: .requestQueryParams, useDoubleURIEncode: false, shouldNormalizeURIPath: false, expiration: expiration, unsignedBody: true, signingAlgorithm: .sigv4) - operation.finalizeStep.intercept(position: .before, middleware: AWSClientRuntime.SigV4Middleware(config: sigv4Config)) - """ - contents.shouldContainOnlyOnce(expectedContents) - } - - @Test - fun `codesign is configured correctly for S3 PutObject`() { - val context = setupTests("presign-urls-s3.smithy", "com.amazonaws.s3#AmazonS3") - val contents = TestContextGenerator.getFileContents(context.manifest, "/Example/models/PutObjectInput+Presigner.swift") - contents.shouldSyntacticSanityCheck() - val expectedContents = """ - let sigv4Config = AWSClientRuntime.SigV4Config(signatureType: .requestQueryParams, useDoubleURIEncode: false, shouldNormalizeURIPath: false, expiration: expiration, unsignedBody: true, signingAlgorithm: .sigv4) - operation.finalizeStep.intercept(position: .before, middleware: AWSClientRuntime.SigV4Middleware(config: sigv4Config)) - """ - contents.shouldContainOnlyOnce(expectedContents) - } - @Test fun `S3 PutObject operation stack contains the PutObjectPresignedURLMiddleware`() { val context = setupTests("presign-urls-s3.smithy", "com.amazonaws.s3#AmazonS3") diff --git a/codegen/smithy-aws-swift-codegen/src/test/kotlin/software/amazon/smithy/aws/swift/codegen/customizations/RulesBasedAuthSchemeResolverGeneratorTests.kt b/codegen/smithy-aws-swift-codegen/src/test/kotlin/software/amazon/smithy/aws/swift/codegen/customizations/RulesBasedAuthSchemeResolverGeneratorTests.kt new file mode 100644 index 00000000000..fd53293fcfe --- /dev/null +++ b/codegen/smithy-aws-swift-codegen/src/test/kotlin/software/amazon/smithy/aws/swift/codegen/customizations/RulesBasedAuthSchemeResolverGeneratorTests.kt @@ -0,0 +1,151 @@ +package software.amazon.smithy.aws.swift.codegen.customizations + +import io.kotest.matchers.string.shouldContainOnlyOnce +import org.junit.jupiter.api.Test +import software.amazon.smithy.aws.swift.codegen.TestContext +import software.amazon.smithy.aws.swift.codegen.TestContextGenerator +import software.amazon.smithy.aws.swift.codegen.restjson.AWSRestJson1ProtocolGenerator +import software.amazon.smithy.aws.swift.codegen.shouldSyntacticSanityCheck +import software.amazon.smithy.aws.traits.protocols.RestJson1Trait + +class RulesBasedAuthSchemeResolverGeneratorTests { + // Note that there's no region field in + // auth scheme resolver params despite the service using SigV4 as one of its auth schemes, + // because no endpoint ruleset was provided for this test case. + // It's assumed in codden that endopint ruleset has the region field contained within. + @Test + fun `rules based auth scheme resolver generation test with fake S3 smithy model`() { + val context = setupTests("rules-based-auth-resolver-test.smithy", "com.test#S3") + val contents = + TestContextGenerator.getFileContents(context.manifest, "Example/AuthSchemeResolver.swift") + contents.shouldSyntacticSanityCheck() + val expectedContents = + """ + public struct S3AuthSchemeResolverParameters: ClientRuntime.AuthSchemeResolverParameters { + public let operation: String + } + + public protocol S3AuthSchemeResolver: ClientRuntime.AuthSchemeResolver { + // Intentionally empty. + // This is the parent protocol that all auth scheme resolver implementations of + // the service S3 must conform to. + } + + private struct InternalModeledS3AuthSchemeResolver: S3AuthSchemeResolver { + public func resolveAuthScheme(params: ClientRuntime.AuthSchemeResolverParameters) throws -> [AuthOption] { + var validAuthOptions = [AuthOption]() + guard let serviceParams = params as? S3AuthSchemeResolverParameters else { + throw ClientError.authError("Service specific auth scheme parameters type must be passed to auth scheme resolver.") + } + switch serviceParams.operation { + case "onlyHttpApiKeyAuth": + validAuthOptions.append(AuthOption(schemeID: "smithy.api#httpApiKeyAuth")) + case "onlyHttpApiKeyAuthOptional": + validAuthOptions.append(AuthOption(schemeID: "smithy.api#httpApiKeyAuth")) + validAuthOptions.append(AuthOption(schemeID: "smithy.api#noAuth")) + case "onlyHttpBearerAuth": + validAuthOptions.append(AuthOption(schemeID: "smithy.api#httpBearerAuth")) + case "onlyHttpBearerAuthOptional": + validAuthOptions.append(AuthOption(schemeID: "smithy.api#httpBearerAuth")) + validAuthOptions.append(AuthOption(schemeID: "smithy.api#noAuth")) + case "onlyHttpApiKeyAndBearerAuth": + validAuthOptions.append(AuthOption(schemeID: "smithy.api#httpApiKeyAuth")) + validAuthOptions.append(AuthOption(schemeID: "smithy.api#httpBearerAuth")) + case "onlyHttpApiKeyAndBearerAuthReversed": + validAuthOptions.append(AuthOption(schemeID: "smithy.api#httpBearerAuth")) + validAuthOptions.append(AuthOption(schemeID: "smithy.api#httpApiKeyAuth")) + case "onlySigv4Auth": + var sigV4Option = AuthOption(schemeID: "aws.auth#sigv4") + sigV4Option.signingProperties.set(key: AttributeKeys.signingName, value: "weather") + guard let region = serviceParams.region else { + throw ClientError.authError("Missing region in auth scheme parameters for SigV4 auth scheme.") + } + sigV4Option.signingProperties.set(key: AttributeKeys.signingRegion, value: region) + validAuthOptions.append(sigV4Option) + case "onlySigv4AuthOptional": + var sigV4Option = AuthOption(schemeID: "aws.auth#sigv4") + sigV4Option.signingProperties.set(key: AttributeKeys.signingName, value: "weather") + guard let region = serviceParams.region else { + throw ClientError.authError("Missing region in auth scheme parameters for SigV4 auth scheme.") + } + sigV4Option.signingProperties.set(key: AttributeKeys.signingRegion, value: region) + validAuthOptions.append(sigV4Option) + validAuthOptions.append(AuthOption(schemeID: "smithy.api#noAuth")) + case "onlyCustomAuth": + validAuthOptions.append(AuthOption(schemeID: "com.test#customAuth")) + case "onlyCustomAuthOptional": + validAuthOptions.append(AuthOption(schemeID: "com.test#customAuth")) + validAuthOptions.append(AuthOption(schemeID: "smithy.api#noAuth")) + default: + var sigV4Option = AuthOption(schemeID: "aws.auth#sigv4") + sigV4Option.signingProperties.set(key: AttributeKeys.signingName, value: "weather") + guard let region = serviceParams.region else { + throw ClientError.authError("Missing region in auth scheme parameters for SigV4 auth scheme.") + } + sigV4Option.signingProperties.set(key: AttributeKeys.signingRegion, value: region) + validAuthOptions.append(sigV4Option) + } + return validAuthOptions + } + + public func constructParameters(context: HttpContext) throws -> ClientRuntime.AuthSchemeResolverParameters { + return try DefaultS3AuthSchemeResolver().constructParameters(context: context) + } + } + + public struct DefaultS3AuthSchemeResolver: S3AuthSchemeResolver { + public func resolveAuthScheme(params: ClientRuntime.AuthSchemeResolverParameters) throws -> [AuthOption] { + var validAuthOptions = [AuthOption]() + guard let serviceParams = params as? S3AuthSchemeResolverParameters else { + throw ClientError.authError("Service specific auth scheme parameters type must be passed to auth scheme resolver.") + } + let endpointParams = EndpointParams(authSchemeParams: serviceParams) + let endpoint = try DefaultEndpointResolver().resolve(params: endpointParams) + guard let authSchemes = endpoint.authSchemes() else { + return try InternalModeledS3AuthSchemeResolver().resolveAuthScheme(params: params) + } + let schemes = try authSchemes.map { (input) -> AWSClientRuntime.AuthScheme in try AWSClientRuntime.AuthScheme(from: input) } + for scheme in schemes { + switch scheme { + case .sigV4(let param): + var sigV4Option = AuthOption(schemeID: "aws.auth#sigv4") + sigV4Option.signingProperties.set(key: AttributeKeys.signingName, value: param.signingName) + sigV4Option.signingProperties.set(key: AttributeKeys.signingRegion, value: param.signingRegion) + validAuthOptions.append(sigV4Option) + case .sigV4A(let param): + var sigV4Option = AuthOption(schemeID: "aws.auth#sigv4a") + sigV4Option.signingProperties.set(key: AttributeKeys.signingName, value: param.signingName) + sigV4Option.signingProperties.set(key: AttributeKeys.signingRegion, value: param.signingRegionSet?[0]) + validAuthOptions.append(sigV4Option) + default: + throw ClientError.authError("Unknown auth scheme name: \(scheme.name)") + } + } + return validAuthOptions + } + + public func constructParameters(context: HttpContext) throws -> ClientRuntime.AuthSchemeResolverParameters { + guard let opName = context.getOperation() else { + throw ClientError.dataNotFound("Operation name not configured in middleware context for auth scheme resolver params construction.") + } + guard let endpointParam = context.attributes.get(key: AttributeKey(name: "EndpointParams")) else { + throw ClientError.dataNotFound("Endpoint param not configured in middleware context for rules-based auth scheme resolver params construction.") + } + return S3AuthSchemeResolverParameters(operation: opName) + } + } + """.trimIndent() + contents.shouldContainOnlyOnce(expectedContents) + } + + private fun setupTests(smithyFile: String, serviceShapeId: String): TestContext { + val context = TestContextGenerator.initContextFrom(smithyFile, serviceShapeId, RestJson1Trait.ID) + + val generator = AWSRestJson1ProtocolGenerator() + generator.initializeMiddleware(context.ctx) + generator.generateProtocolClient(context.ctx) + generator.generateSerializers(context.ctx) + context.ctx.delegator.flushWriters() + return context + } +} diff --git a/codegen/smithy-aws-swift-codegen/src/test/resources/software.amazon.smithy.aws.swift.codegen/rules-based-auth-resolver-test.smithy b/codegen/smithy-aws-swift-codegen/src/test/resources/software.amazon.smithy.aws.swift.codegen/rules-based-auth-resolver-test.smithy new file mode 100644 index 00000000000..15419d4d612 --- /dev/null +++ b/codegen/smithy-aws-swift-codegen/src/test/resources/software.amazon.smithy.aws.swift.codegen/rules-based-auth-resolver-test.smithy @@ -0,0 +1,87 @@ +$version: "2.0" + +namespace com.test + +use aws.auth#sigv4 +use aws.protocols#restJson1 +use aws.api#service + +@authDefinition +@trait +structure customAuth {} + +@restJson1 +@httpApiKeyAuth(name: "X-Api-Key", in: "header") +@httpBearerAuth +@sigv4(name: "weather") +@customAuth +@auth([sigv4]) +@service(sdkId: "S3") +service S3 { + version: "2023-11-10" + operations: [ + // experimentalIdentityAndAuth + OnlyHttpApiKeyAuth + OnlyHttpApiKeyAuthOptional + OnlyHttpBearerAuth + OnlyHttpBearerAuthOptional + OnlyHttpApiKeyAndBearerAuth + OnlyHttpApiKeyAndBearerAuthReversed + OnlySigv4Auth + OnlySigv4AuthOptional + OnlyCustomAuth + OnlyCustomAuthOptional + SameAsService + ] +} + +@http(method: "GET", uri: "/OnlyHttpApiKeyAuth") +@auth([httpApiKeyAuth]) +operation OnlyHttpApiKeyAuth {} + +@http(method: "GET", uri: "/OnlyHttpBearerAuth") +@auth([httpBearerAuth]) +operation OnlyHttpBearerAuth {} + +@http(method: "GET", uri: "/OnlySigv4Auth") +@auth([sigv4]) +operation OnlySigv4Auth {} + +@http(method: "GET", uri: "/OnlyHttpApiKeyAndBearerAuth") +@auth([httpApiKeyAuth, httpBearerAuth]) +operation OnlyHttpApiKeyAndBearerAuth {} + +@http(method: "GET", uri: "/OnlyHttpApiKeyAndBearerAuthReversed") +@auth([httpBearerAuth, httpApiKeyAuth]) +operation OnlyHttpApiKeyAndBearerAuthReversed {} + +@http(method: "GET", uri: "/OnlyHttpApiKeyAuthOptional") +@auth([httpApiKeyAuth]) +@optionalAuth +operation OnlyHttpApiKeyAuthOptional {} + +@http(method: "GET", uri: "/OnlyHttpBearerAuthOptional") +@auth([httpBearerAuth]) +@optionalAuth +operation OnlyHttpBearerAuthOptional {} + +@http(method: "GET", uri: "/OnlySigv4AuthOptional") +@auth([sigv4]) +@optionalAuth +operation OnlySigv4AuthOptional {} + +@http(method: "GET", uri: "/OnlyCustomAuth") +@auth([customAuth]) +operation OnlyCustomAuth {} + +@http(method: "GET", uri: "/OnlyCustomAuthOptional") +@auth([customAuth]) +@optionalAuth +operation OnlyCustomAuthOptional {} + +@http(method: "GET", uri: "/SameAsService") +operation SameAsService { + output := { + service: String + } +} \ No newline at end of file