Skip to content

Commit

Permalink
feat: Add retry information headers (#692)
Browse files Browse the repository at this point in the history
* Organize test files and directories to mirror code files and directory structure.

* Add withSocket utility method & AttributeKeys related to TTL (socketTimeout & estimatedSkew).

* Make socketTimeout to default to 60 seconds.

* Add codegen that saves socketTimeout from HttpClientConfiguration into middleware context.

* Add utility functions; one that calculates estimated skew from date string & one that calculates TTL by adding estimated skew and socket timeout to current time according to local machine clock.

* Make DeserializeMiddleware save estimated skew calculated from returned HTTP response's Date header value.

* Make RetryMiddleware add retry information headers as defined in SEP.

* Fix dateFormatter in getTTL utility method to take raw date and convert to string without any adjustments.

* Add tests for the 2 utility methods getTTL & getEstimatedSkew. Augment existing RetryIntegrationTests to check retry information headers in inputs.

* Update codegen test to include socketTimeout addition.

* Add dummy values needed for context used by retry middleware tests.

* Change a couple XCTAssert to XCTAssertEqual for better log message.

* Make socketTimeout non-optional given default value is being set now.

* Log .info level message then proceed with default values instead of throwing an error.

* Fix socket timeout related errors.

* Fix syntax error.

---------

Co-authored-by: Sichan Yoo <[email protected]>
  • Loading branch information
sichanyoo and Sichan Yoo authored Apr 16, 2024
1 parent 1fb3803 commit 7be2bd8
Show file tree
Hide file tree
Showing 31 changed files with 171 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ public extension DefaultSDKRuntimeConfiguration {
return URLSessionHTTPClient(httpClientConfiguration: httpClientConfiguration)
#else
let connectTimeoutMs = httpClientConfiguration.connectTimeout.map { UInt32($0 * 1000) }
let socketTimeout = httpClientConfiguration.connectTimeout.map { UInt32($0) }
let socketTimeout = UInt32(httpClientConfiguration.socketTimeout)
let config = CRTClientEngineConfig(connectTimeoutMs: connectTimeoutMs, socketTimeout: socketTimeout)
return CRTClientEngine(config: config)
#endif
Expand Down
69 changes: 65 additions & 4 deletions Sources/ClientRuntime/Middleware/RetryMiddleware.swift
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,12 @@
// SPDX-License-Identifier: Apache-2.0
//

import class Foundation.DateFormatter
import struct Foundation.Locale
import struct Foundation.TimeInterval
import struct Foundation.TimeZone
import struct Foundation.UUID

public struct RetryMiddleware<Strategy: RetryStrategy,
ErrorInfoProvider: RetryErrorInfoProvider,
OperationStackOutput>: Middleware {
Expand All @@ -16,23 +22,36 @@ public struct RetryMiddleware<Strategy: RetryStrategy,
public var id: String { "Retry" }
public var strategy: Strategy

// The UUID string used to uniquely identify an API call and all of its subsequent retries.
private let invocationID = UUID().uuidString.lowercased()
// Max number of retries configured for retry strategy.
private var maxRetries: Int

public init(options: RetryStrategyOptions) {
self.strategy = Strategy(options: options)
self.maxRetries = options.maxRetriesBase
}

public func handle<H>(context: Context, input: SdkHttpRequestBuilder, next: H) async throws ->
OperationOutput<OperationStackOutput>
where H: Handler, MInput == H.Input, MOutput == H.Output, Context == H.Context {

input.headers.add(name: "amz-sdk-invocation-id", value: invocationID)

let partitionID = try getPartitionID(context: context, input: input)
let token = try await strategy.acquireInitialRetryToken(tokenScope: partitionID)
return try await sendRequest(token: token, context: context, input: input, next: next)
input.headers.add(name: "amz-sdk-request", value: "attempt=1; max=\(maxRetries)")
return try await sendRequest(attemptNumber: 1, token: token, context: context, input: input, next: next)
}

private func sendRequest<H>(token: Strategy.Token, context: Context, input: MInput, next: H) async throws ->
private func sendRequest<H>(
attemptNumber: Int,
token: Strategy.Token,
context: Context,
input: MInput, next: H
) async throws ->
OperationOutput<OperationStackOutput>
where H: Handler, MInput == H.Input, MOutput == H.Output, Context == H.Context {

do {
let serviceResponse = try await next.handle(context: context, input: input)
await strategy.recordSuccess(token: token)
Expand All @@ -45,7 +64,28 @@ public struct RetryMiddleware<Strategy: RetryStrategy,
// TODO: log token error here
throw operationError
}
return try await sendRequest(token: token, context: context, input: input, next: next)
var estimatedSkew = context.attributes.get(key: AttributeKeys.estimatedSkew)
if estimatedSkew == nil {
estimatedSkew = 0
context.getLogger()!.info("Estimated skew not found; defaulting to zero.")
}
var socketTimeout = context.attributes.get(key: AttributeKeys.socketTimeout)
if socketTimeout == nil {
socketTimeout = 60.0
context.getLogger()!.info("Socket timeout value not found; defaulting to 60 seconds.")
}
let ttlDateUTCString = getTTL(now: Date(), estimatedSkew: estimatedSkew!, socketTimeout: socketTimeout!)
input.headers.update(
name: "amz-sdk-request",
value: "ttl=\(ttlDateUTCString); attempt=\(attemptNumber + 1); max=\(maxRetries)"
)
return try await sendRequest(
attemptNumber: attemptNumber + 1,
token: token,
context: context,
input: input,
next: next
)
}
}

Expand All @@ -66,3 +106,24 @@ public struct RetryMiddleware<Strategy: RetryStrategy,
}
}
}

// Calculates & returns TTL datetime in strftime format `YYYYmmddTHHMMSSZ`.
func getTTL(now: Date, estimatedSkew: TimeInterval, socketTimeout: TimeInterval) -> String {
let dateFormatter = DateFormatter()
dateFormatter.dateFormat = "yyyyMMdd'T'HHmmss'Z'"
dateFormatter.locale = Locale(identifier: "en_US_POSIX")
dateFormatter.timeZone = TimeZone(abbreviation: "UTC")
let ttlDate = now.addingTimeInterval(estimatedSkew + socketTimeout)
return dateFormatter.string(from: ttlDate)
}

// Calculates & returns estimated skew.
func getEstimatedSkew(now: Date, responseDateString: String) -> TimeInterval {
let dateFormatter = DateFormatter()
dateFormatter.dateFormat = "EEE, dd MMM yyyy HH:mm:ss z"
dateFormatter.locale = Locale(identifier: "en_US_POSIX")
dateFormatter.timeZone = TimeZone(abbreviation: "GMT")
let responseDate: Date = dateFormatter.date(from: responseDateString) ?? now
// (Estimated skew) = (Date header from HTTP response) - (client's current time)).
return responseDate.timeIntervalSince(now)
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ public class HttpClientConfiguration {
/// Sets maximum time to wait between two data packets.
/// Used to close stale connections that have no activity.
///
/// If no value is provided, the defaut client won't have a socket timeout.
public var socketTimeout: TimeInterval?
/// Defaults to 60 seconds if no value is provided.
public var socketTimeout: TimeInterval

/// HTTP headers to be submitted with every HTTP request.
///
Expand All @@ -45,7 +45,7 @@ public class HttpClientConfiguration {
/// - protocolType: The HTTP scheme (`http` or `https`) to be used for API requests. Defaults to the operation's standard configuration.
public init(
connectTimeout: TimeInterval? = nil,
socketTimeout: TimeInterval? = nil,
socketTimeout: TimeInterval = 60.0,
protocolType: ProtocolType = .https,
defaultHeaders: Headers = Headers()
) {
Expand Down
10 changes: 10 additions & 0 deletions Sources/ClientRuntime/Networking/Http/HttpContext.swift
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,12 @@ public class HttpContextBuilder {
return self
}

@discardableResult
public func withSocketTimeout(value: TimeInterval?) -> HttpContextBuilder {
self.attributes.set(key: AttributeKeys.socketTimeout, value: value)
return self
}

@discardableResult
public func withUnsignedPayloadTrait(value: Bool) -> HttpContextBuilder {
self.attributes.set(key: AttributeKeys.hasUnsignedPayloadTrait, value: value)
Expand Down Expand Up @@ -356,6 +362,10 @@ public enum AttributeKeys {

// Streams
public static let isChunkedEligibleStream = AttributeKey<Bool>(name: "isChunkedEligibleStream")

// TTL calculation in retries.
public static let estimatedSkew = AttributeKey<TimeInterval>(name: "EstimatedSkew")
public static let socketTimeout = AttributeKey<TimeInterval>(name: "SocketTimeout")
}

// The type of flow the mdidleware context is being constructed for
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@ public struct DeserializeMiddleware<OperationStackOutput>: Middleware {

let response = try await next.handle(context: context, input: input) // call handler to get http response

if let responseDateString = response.httpResponse.headers.value(for: "Date") {
let estimatedSkew = getEstimatedSkew(now: Date(), responseDateString: responseDateString)
context.attributes.set(key: AttributeKeys.estimatedSkew, value: estimatedSkew)
}

// check if the response body was effected by a previous middleware
if let contextBody = context.response?.body {
response.httpResponse.body = contextBody
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,7 @@ extension URLSessionConfiguration {

public static func from(httpClientConfiguration: HttpClientConfiguration) -> URLSessionConfiguration {
let config = URLSessionConfiguration.default
if let socketTimeout = httpClientConfiguration.socketTimeout {
config.timeoutIntervalForRequest = socketTimeout
}
config.timeoutIntervalForRequest = httpClientConfiguration.socketTimeout
return config
}
}
Expand Down
79 changes: 79 additions & 0 deletions Tests/ClientRuntimeTests/Retry/RetryIntegrationTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ final class RetryIntegrationTests: XCTestCase {
// Setup the HTTP context, used by the retry middleware
context = HttpContext(attributes: Attributes())
context.attributes.set(key: partitionIDKey, value: partitionID)
context.attributes.set(key: AttributeKeys.socketTimeout, value: 60.0)
context.attributes.set(key: AttributeKeys.estimatedSkew, value: 30.0)

// Create the test output handler, which is the "next" middleware called by the retry middleware
next = TestOutputHandler()
Expand Down Expand Up @@ -116,6 +118,39 @@ final class RetryIntegrationTests: XCTestCase {
}
try await next.verifyResult()
}

// Test getEstimatedSkew utility method.
func test_getEstimatedSkew() {
let responseDateString = "Mon, 15 Jul 2024 01:24:12 GMT"
let dateFormatter = DateFormatter()
dateFormatter.dateFormat = "EEE, dd MMM yyyy HH:mm:ss z"
dateFormatter.locale = Locale(identifier: "en_US_POSIX")
dateFormatter.timeZone = TimeZone(abbreviation: "GMT")
let responseDate: Date = dateFormatter.date(from: responseDateString)!

let responseDateStringPlusTen = "Mon, 15 Jul 2024 01:24:22 GMT"
let estimatedSkew = getEstimatedSkew(now: responseDate, responseDateString: responseDateStringPlusTen)

XCTAssertEqual(estimatedSkew, 10.0)
}

// Test getTTLutility method.
func test_getTTL() {
let nowDateString = "Mon, 15 Jul 2024 01:24:12 GMT"
let dateFormatter = DateFormatter()
dateFormatter.dateFormat = "EEE, dd MMM yyyy HH:mm:ss z"
dateFormatter.locale = Locale(identifier: "en_US_POSIX")
dateFormatter.timeZone = TimeZone(abbreviation: "GMT")
let nowDate: Date = dateFormatter.date(from: nowDateString)!

// The two timeintervals below add up to 34 minutes 59 seconds, rounding to closest second.
let estimatedSkew = 2039.34
let socketTimeout = 60.0

// Verify calculated TTL is nowDate + (34 minutes and 59 seconds).
let ttl = getTTL(now: nowDate, estimatedSkew: estimatedSkew, socketTimeout: socketTimeout)
XCTAssertEqual(ttl, "20240715T015911Z")
}
}

private struct TestStep {
Expand Down Expand Up @@ -174,12 +209,16 @@ private class TestOutputHandler: Handler {
var quota: RetryQuota!
var actualDelay: TimeInterval?
var finalError: Error?
var invocationID = ""
var prevAttemptNum = 0

func handle(context: ClientRuntime.HttpContext, input: SdkHttpRequestBuilder) async throws -> OperationOutput<TestOutputResponse> {
if index == testSteps.count { throw RetryIntegrationTestError.maxAttemptsExceeded }

// Verify the results of the previous test step, if there was one.
try await verifyResult(atEnd: false)
// Verify the input's retry information headers.
try await verifyInput(input: input)

// Get the latest test step, then advance the index.
let testStep = testSteps[index]
Expand Down Expand Up @@ -218,6 +257,46 @@ private class TestOutputHandler: Handler {
XCTFail("Test should not end on retry", file: testStep.file, line: testStep.line)
}
}

func verifyInput(input: SdkHttpRequestBuilder) async throws {
// Get invocation ID of the request off of amz-sdk-invocation-id header.
let invocationID = try XCTUnwrap(input.headers.value(for: "amz-sdk-invocation-id"))
// If this is the first request, save the retrieved ID.
if (self.invocationID.isEmpty) { self.invocationID = invocationID }

// Retrieved IDs from all requests under a same call must be the same.
XCTAssertEqual(self.invocationID, invocationID)

// Get retry information off of amz-sdk-request header.
let amzSdkRequestHeaderValue = try XCTUnwrap(input.headers.value(for: "amz-sdk-request"))
// Extract request pair values from amz-sdk-request header value.
let requestPairs = amzSdkRequestHeaderValue.components(separatedBy: "; ")
var ttl: String = ""
let attemptNum: Int = try XCTUnwrap(
Int(
try XCTUnwrap(requestPairs.first { $0.hasPrefix("attempt=") })
.components(separatedBy: "=")[1]
)
)
_ = try XCTUnwrap(
Int(
try XCTUnwrap(requestPairs.first { $0.hasPrefix("max=") })
.components(separatedBy: "=")[1]
)
)
// For attempts 2+, TTL must be present.
if (attemptNum > 1) {
ttl = try XCTUnwrap(requestPairs.first { $0.hasPrefix("ttl") }).components(separatedBy: "=")[1]
// Check that TTL date is in strftime format.
let dateFormatter = DateFormatter()
dateFormatter.dateFormat = "yyyyMMdd'T'HHmmss'Z'"
XCTAssertNotNil(dateFormatter.date(from: ttl))
}

// Verify attempt number was incremented by 1 from previous request.
XCTAssertEqual(attemptNum, (self.prevAttemptNum + 1))
self.prevAttemptNum = attemptNum
}
}

// Thrown during a test to simulate a server response with a given HTTP status code.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ class MiddlewareExecutionGenerator(
writer.write(" .withAuthSchemes(value: config.authSchemes ?? [])")
writer.write(" .withAuthSchemeResolver(value: config.authSchemeResolver)")
writer.write(" .withUnsignedPayloadTrait(value: ${op.hasTrait(UnsignedPayloadTrait::class.java)})")
writer.write(" .withSocketTimeout(value: config.httpClientConfiguration.socketTimeout)")

// Add flag for presign / presign-url flows
if (flowType == ContextAttributeCodegenFlowType.PRESIGN_REQUEST) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class ContentMd5MiddlewareTests {
.withAuthSchemes(value: config.authSchemes ?? [])
.withAuthSchemeResolver(value: config.authSchemeResolver)
.withUnsignedPayloadTrait(value: false)
.withSocketTimeout(value: config.httpClientConfiguration.socketTimeout)
.build()
var operation = ClientRuntime.OperationStack<IdempotencyTokenWithStructureInput, IdempotencyTokenWithStructureOutput>(id: "idempotencyTokenWithStructure")
operation.initializeStep.intercept(position: .after, middleware: ClientRuntime.IdempotencyTokenMiddleware<IdempotencyTokenWithStructureInput, IdempotencyTokenWithStructureOutput>(keyPath: \.token))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ class HttpProtocolClientGeneratorTests {
.withAuthSchemes(value: config.authSchemes ?? [])
.withAuthSchemeResolver(value: config.authSchemeResolver)
.withUnsignedPayloadTrait(value: false)
.withSocketTimeout(value: config.httpClientConfiguration.socketTimeout)
.build()
var operation = ClientRuntime.OperationStack<AllocateWidgetInput, AllocateWidgetOutput>(id: "allocateWidget")
operation.initializeStep.intercept(position: .after, middleware: ClientRuntime.IdempotencyTokenMiddleware<AllocateWidgetInput, AllocateWidgetOutput>(keyPath: \.clientToken))
Expand Down Expand Up @@ -193,6 +194,7 @@ class HttpProtocolClientGeneratorTests {
.withAuthSchemes(value: config.authSchemes ?? [])
.withAuthSchemeResolver(value: config.authSchemeResolver)
.withUnsignedPayloadTrait(value: true)
.withSocketTimeout(value: config.httpClientConfiguration.socketTimeout)
.build()
var operation = ClientRuntime.OperationStack<UnsignedFooBlobStreamInput, UnsignedFooBlobStreamOutput>(id: "unsignedFooBlobStream")
operation.initializeStep.intercept(position: .after, middleware: ClientRuntime.URLPathMiddleware<UnsignedFooBlobStreamInput, UnsignedFooBlobStreamOutput>(UnsignedFooBlobStreamInput.urlPathProvider(_:)))
Expand Down Expand Up @@ -231,6 +233,7 @@ class HttpProtocolClientGeneratorTests {
.withAuthSchemes(value: config.authSchemes ?? [])
.withAuthSchemeResolver(value: config.authSchemeResolver)
.withUnsignedPayloadTrait(value: false)
.withSocketTimeout(value: config.httpClientConfiguration.socketTimeout)
.build()
var operation = ClientRuntime.OperationStack<ExplicitBlobStreamWithLengthInput, ExplicitBlobStreamWithLengthOutput>(id: "explicitBlobStreamWithLength")
operation.initializeStep.intercept(position: .after, middleware: ClientRuntime.URLPathMiddleware<ExplicitBlobStreamWithLengthInput, ExplicitBlobStreamWithLengthOutput>(ExplicitBlobStreamWithLengthInput.urlPathProvider(_:)))
Expand Down Expand Up @@ -269,6 +272,7 @@ class HttpProtocolClientGeneratorTests {
.withAuthSchemes(value: config.authSchemes ?? [])
.withAuthSchemeResolver(value: config.authSchemeResolver)
.withUnsignedPayloadTrait(value: true)
.withSocketTimeout(value: config.httpClientConfiguration.socketTimeout)
.build()
var operation = ClientRuntime.OperationStack<UnsignedFooBlobStreamWithLengthInput, UnsignedFooBlobStreamWithLengthOutput>(id: "unsignedFooBlobStreamWithLength")
operation.initializeStep.intercept(position: .after, middleware: ClientRuntime.URLPathMiddleware<UnsignedFooBlobStreamWithLengthInput, UnsignedFooBlobStreamWithLengthOutput>(UnsignedFooBlobStreamWithLengthInput.urlPathProvider(_:)))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ class IdempotencyTokenTraitTests {
.withAuthSchemes(value: config.authSchemes ?? [])
.withAuthSchemeResolver(value: config.authSchemeResolver)
.withUnsignedPayloadTrait(value: false)
.withSocketTimeout(value: config.httpClientConfiguration.socketTimeout)
.build()
var operation = ClientRuntime.OperationStack<IdempotencyTokenWithStructureInput, IdempotencyTokenWithStructureOutput>(id: "idempotencyTokenWithStructure")
operation.initializeStep.intercept(position: .after, middleware: ClientRuntime.IdempotencyTokenMiddleware<IdempotencyTokenWithStructureInput, IdempotencyTokenWithStructureOutput>(keyPath: \.token))
Expand Down

0 comments on commit 7be2bd8

Please sign in to comment.