From aeb2f71ba99e2f8c947f9433c46fdaf67c016ce2 Mon Sep 17 00:00:00 2001 From: fitomad Date: Sat, 22 Jun 2024 22:49:57 +0200 Subject: [PATCH 01/15] Fixing access modifier (From vapor project to independent package) --- .../Limiters/FixedWindowCounter.swift | 15 +++--- .../Checkpoint/Limiters/LeakingBucket.swift | 36 +++++++-------- Sources/Checkpoint/Limiters/Limiter.swift | 7 +-- .../Limiters/SlidingWindowLog.swift | 10 ++-- Sources/Checkpoint/Limiters/TokenBucket.swift | 33 +++++++------ .../Limiters/WindowBasedLimiter.swift | 6 +-- Sources/Checkpoint/Model/Configuration.swift | 46 +++++++++---------- Sources/Checkpoint/Model/Field.swift | 2 +- Sources/Checkpoint/Model/RateLimitScope.swift | 2 +- Sources/Checkpoint/Model/TimeWindow.swift | 2 +- 10 files changed, 83 insertions(+), 76 deletions(-) diff --git a/Sources/Checkpoint/Limiters/FixedWindowCounter.swift b/Sources/Checkpoint/Limiters/FixedWindowCounter.swift index de1c549..220a0e8 100644 --- a/Sources/Checkpoint/Limiters/FixedWindowCounter.swift +++ b/Sources/Checkpoint/Limiters/FixedWindowCounter.swift @@ -18,16 +18,15 @@ import Vapor 4. If the counter is greater than the rate limit, the request is rejected. 5. If the counter is less than the rate limit, the request is accepted. */ -final class FixedWindowCounter { +public final class FixedWindowCounter { private let configuration: FixedWindowCounterConfiguration - let storage: Application.Redis - let logging: Logger? + public let storage: Application.Redis + public let logging: Logger? private var cancellable: AnyCancellable? private var keys = Set() - - init(configuration: () -> FixedWindowCounterConfiguration, storage: StorageAction, logging: LoggerAction? = nil) { + public init(configuration: () -> FixedWindowCounterConfiguration, storage: StorageAction, logging: LoggerAction? = nil) { self.configuration = configuration() self.storage = storage() self.logging = logging?() @@ -42,7 +41,7 @@ final class FixedWindowCounter { } extension FixedWindowCounter: WindowBasedLimiter { - func checkRequest(_ request: Request) async throws { + public func checkRequest(_ request: Request) async throws { guard let requestKey = try? valueFor(field: configuration.appliedField, in: request, inside: configuration.scope) else { return } @@ -52,8 +51,8 @@ extension FixedWindowCounter: WindowBasedLimiter { let redisKey = RedisKey(requestKey) let timestamp = Date.now.timeIntervalSince1970 - // If window request is full then drop storage.rpush([ timestamp ], into: redisKey) + let requestCount = try await storage.llen(of: redisKey).get() if requestCount > configuration.requestPerWindow { @@ -61,7 +60,7 @@ extension FixedWindowCounter: WindowBasedLimiter { } } - func resetWindow() { + public func resetWindow() { keys.forEach { key in let redisKey = RedisKey(key) diff --git a/Sources/Checkpoint/Limiters/LeakingBucket.swift b/Sources/Checkpoint/Limiters/LeakingBucket.swift index 11c4888..a236c66 100644 --- a/Sources/Checkpoint/Limiters/LeakingBucket.swift +++ b/Sources/Checkpoint/Limiters/LeakingBucket.swift @@ -18,19 +18,15 @@ import Vapor 4. If the bucket is not full, we allow the request and add 1 token from the bucket. 5. Tokens are removed at a fixed rate of r tokens per second. Letโ€™s say 1 token per second. */ -final class LeakingBucket { +public final class LeakingBucket { private let configuration: LeakingBucketConfiguration - let storage: Application.Redis - let logging: Logger? + public let storage: Application.Redis + public let logging: Logger? private var cancellable: AnyCancellable? private var keys = Set() - private var redisKey: RedisKey { - RedisKey("") - } - - init(configuration: () -> LeakingBucketConfiguration, storage: StorageAction, logging: LoggerAction? = nil) { + public init(configuration: () -> LeakingBucketConfiguration, storage: StorageAction, logging: LoggerAction? = nil) { self.configuration = configuration() self.storage = storage() self.logging = logging?() @@ -56,7 +52,7 @@ extension LeakingBucket: WindowBasedLimiter { return true } - func checkRequest(_ request: Request) async throws { + public func checkRequest(_ request: Request) async throws { guard let requestKey = try? valueFor(field: configuration.appliedField, in: request, inside: configuration.scope) else { return } @@ -79,17 +75,21 @@ extension LeakingBucket: WindowBasedLimiter { } } - func resetWindow() throws { - Task(priority: .userInitiated) { - let respValue = try await storage.get(redisKey).get() - - var newBucketSize = 0 + public func resetWindow() throws { + keys.forEach { key in + Task(priority: .userInitiated) { + let redisKey = RedisKey(key) + + let respValue = try await storage.get(redisKey).get() - if let currentBucketSize = respValue.int { - newBucketSize = currentBucketSize < configuration.tokenRemovingRate ? 0 : (currentBucketSize - configuration.tokenRemovingRate) + var newBucketSize = 0 + + if let currentBucketSize = respValue.int { + newBucketSize = currentBucketSize < configuration.tokenRemovingRate ? 0 : (currentBucketSize - configuration.tokenRemovingRate) + } + + try await storage.decrement(redisKey, by: newBucketSize).get() } - - try await storage.decrement(redisKey, by: newBucketSize).get() } } } diff --git a/Sources/Checkpoint/Limiters/Limiter.swift b/Sources/Checkpoint/Limiters/Limiter.swift index c346be6..3a17589 100644 --- a/Sources/Checkpoint/Limiters/Limiter.swift +++ b/Sources/Checkpoint/Limiters/Limiter.swift @@ -6,12 +6,13 @@ // import Combine +import Redis import Vapor -typealias StorageAction = () -> Application.Redis -typealias LoggerAction = () -> Logger +public typealias StorageAction = () -> Application.Redis +public typealias LoggerAction = () -> Logger -protocol Limiter: Sendable { +public protocol Limiter: Sendable { associatedtype ConfigurationType var storage: Application.Redis { get } diff --git a/Sources/Checkpoint/Limiters/SlidingWindowLog.swift b/Sources/Checkpoint/Limiters/SlidingWindowLog.swift index 272af7f..68bef31 100644 --- a/Sources/Checkpoint/Limiters/SlidingWindowLog.swift +++ b/Sources/Checkpoint/Limiters/SlidingWindowLog.swift @@ -8,12 +8,12 @@ import Redis import Vapor -final class SlidingWindowLog { +public final class SlidingWindowLog { private let configuration: SlidingWindowLogConfiguration - let storage: Application.Redis - let logging: Logger? + public let storage: Application.Redis + public let logging: Logger? - init(configuration: () -> SlidingWindowLogConfiguration, storage: StorageAction, logging: LoggerAction? = nil) { + public init(configuration: () -> SlidingWindowLogConfiguration, storage: StorageAction, logging: LoggerAction? = nil) { self.configuration = configuration() self.storage = storage() self.logging = logging?() @@ -25,7 +25,7 @@ extension SlidingWindowLog: Limiter { return true } - func checkRequest(_ request: Request) async throws { + public func checkRequest(_ request: Request) async throws { guard let apiKey = try? valueFor(field: configuration.appliedField, in: request, inside: configuration.scope) else { throw Abort(.unauthorized, reason: Checkpoint.HTTPErrorDescription.unauthorized) } diff --git a/Sources/Checkpoint/Limiters/TokenBucket.swift b/Sources/Checkpoint/Limiters/TokenBucket.swift index 0d11449..77cf6ca 100644 --- a/Sources/Checkpoint/Limiters/TokenBucket.swift +++ b/Sources/Checkpoint/Limiters/TokenBucket.swift @@ -18,15 +18,15 @@ import Vapor โ€ข We take 1 token out for each request and if there are enough tokens, then the request is processed. โ€ข The request is dropped if there arenโ€™t enough tokens. */ -final class TokenBucket { +public final class TokenBucket { private let configuration: TokenBucketConfiguration - let storage: Application.Redis - let logging: Logger? + public let storage: Application.Redis + public let logging: Logger? private var cancellable: AnyCancellable? private var keys = Set() - init(configuration: () -> TokenBucketConfiguration, storage: StorageAction, logging: LoggerAction? = nil) { + public init(configuration: () -> TokenBucketConfiguration, storage: StorageAction, logging: LoggerAction? = nil) { self.configuration = configuration() self.storage = storage() self.logging = logging?() @@ -49,7 +49,7 @@ final class TokenBucket { } extension TokenBucket: WindowBasedLimiter { - func checkRequest(_ request: Request) async throws { + public func checkRequest(_ request: Request) async throws { guard let requestKey = try? valueFor(field: configuration.appliedField, in: request, inside: configuration.scope) else { return } @@ -72,14 +72,21 @@ extension TokenBucket: WindowBasedLimiter { } } - func resetWindow() throws { - let redisKeys = keys.map { RedisKey($0) } - - Task { - do { - try await storage.delete(redisKeys).get() - } catch let redisError { - logging?.error("๐Ÿšจ Problem deleting keys: \(redisError.localizedDescription)") + public func resetWindow() throws { + keys.forEach { key in + Task(priority: .userInitiated) { + let redisKey = RedisKey(key) + + let respValue = try await storage.get(redisKey).get() + + var newRefillSize = configuration.refillTokenRate + + if let currentBucketSize = respValue.int { + let expectedBucketSize = currentBucketSize + configuration.refillTokenRate + newRefillSize = (expectedBucketSize >= configuration.bucketSize) ? 0 : configuration.refillTokenRate + } + + try await storage.increment(redisKey, by: newRefillSize).get() } } } diff --git a/Sources/Checkpoint/Limiters/WindowBasedLimiter.swift b/Sources/Checkpoint/Limiters/WindowBasedLimiter.swift index bf96be6..f401544 100644 --- a/Sources/Checkpoint/Limiters/WindowBasedLimiter.swift +++ b/Sources/Checkpoint/Limiters/WindowBasedLimiter.swift @@ -8,15 +8,15 @@ import Combine import Foundation -typealias WindowBasedAction = () throws -> Void +public typealias WindowBasedAction = () throws -> Void -protocol WindowBasedLimiter: Limiter { +public protocol WindowBasedLimiter: Limiter { func startWindow(havingDuration seconds: Double, performing action: @escaping WindowBasedAction) -> AnyCancellable func resetWindow() async throws } extension WindowBasedLimiter { - func startWindow(havingDuration seconds: Double, performing action: @escaping WindowBasedAction) -> AnyCancellable { + public func startWindow(havingDuration seconds: Double, performing action: @escaping WindowBasedAction) -> AnyCancellable { var cancellable = Timer.publish(every: seconds, on: .main, in: .common) .autoconnect() .sink { _ in diff --git a/Sources/Checkpoint/Model/Configuration.swift b/Sources/Checkpoint/Model/Configuration.swift index ac72665..98779a0 100644 --- a/Sources/Checkpoint/Model/Configuration.swift +++ b/Sources/Checkpoint/Model/Configuration.swift @@ -6,41 +6,41 @@ // -protocol Configuration { +public protocol Configuration { var appliedField: Field { get } var scope: RateLimitScope { get } } -struct FixedWindowCounterConfiguration: Configuration { - var requestPerWindow = 5 - var timeWindowDuration: TimeWindow = .seconds(count: 10) +public struct FixedWindowCounterConfiguration: Configuration { + public var requestPerWindow = 5 + public var timeWindowDuration: TimeWindow = .seconds(count: 10) - var appliedField: Field - var scope: RateLimitScope + public var appliedField: Field + public var scope: RateLimitScope } -struct LeakingBucketConfiguration: Configuration { - var bucketSize = 10 - var tokenRemovingRate = 5 - var timeWindowDuration: TimeWindow = .seconds(count: 10) +public struct LeakingBucketConfiguration: Configuration { + public var bucketSize = 10 + public var tokenRemovingRate = 5 + public var timeWindowDuration: TimeWindow = .seconds(count: 10) - var appliedField: Field - var scope: RateLimitScope + public var appliedField: Field + public var scope: RateLimitScope } -struct SlidingWindowLogConfiguration: Configuration { - var requestPerWindow = 10 - var timeWindowDuration: TimeWindow +public struct SlidingWindowLogConfiguration: Configuration { + public var requestPerWindow = 10 + public var timeWindowDuration: TimeWindow - var appliedField: Field - var scope: RateLimitScope + public var appliedField: Field + public var scope: RateLimitScope } -struct TokenBucketConfiguration: Configuration { - var bucketSize: Int - var refillRate: Int - var refillTimeInterval: TimeWindow +public struct TokenBucketConfiguration: Configuration { + public var bucketSize: Int + public var refillTokenRate: Int + public var refillTimeInterval: TimeWindow - var appliedField: Field - var scope: RateLimitScope + public var appliedField: Field + public var scope: RateLimitScope } diff --git a/Sources/Checkpoint/Model/Field.swift b/Sources/Checkpoint/Model/Field.swift index 74c02e9..600e943 100644 --- a/Sources/Checkpoint/Model/Field.swift +++ b/Sources/Checkpoint/Model/Field.swift @@ -7,7 +7,7 @@ import Foundation -enum Field { +public enum Field { case header(key: String) case queryItem(key: String) case none diff --git a/Sources/Checkpoint/Model/RateLimitScope.swift b/Sources/Checkpoint/Model/RateLimitScope.swift index 75a04fe..78ef01f 100644 --- a/Sources/Checkpoint/Model/RateLimitScope.swift +++ b/Sources/Checkpoint/Model/RateLimitScope.swift @@ -7,7 +7,7 @@ import Foundation -enum RateLimitScope { +public enum RateLimitScope { case api case endpoint case nonScope diff --git a/Sources/Checkpoint/Model/TimeWindow.swift b/Sources/Checkpoint/Model/TimeWindow.swift index 0735a7b..dd76770 100644 --- a/Sources/Checkpoint/Model/TimeWindow.swift +++ b/Sources/Checkpoint/Model/TimeWindow.swift @@ -7,7 +7,7 @@ import Foundation -enum TimeWindow { +public enum TimeWindow { case seconds(count: Int = 10) case minutes(count: Int = 1) case hours(count: Int = 1) From 86dcc791d2f06d609ee0391e70633ed1ce7a024d Mon Sep 17 00:00:00 2001 From: fitomad Date: Sat, 22 Jun 2024 23:01:19 +0200 Subject: [PATCH 02/15] Renaming... --- Package.swift | 5 ++++- .../Limiter.swift => Algorithms/Algorithm.swift} | 6 +++--- .../FixedWindowCounter.swift | 2 +- .../{Limiters => Algorithms}/LeakingBucket.swift | 2 +- .../SlidingWindowLog.swift | 2 +- .../{Limiters => Algorithms}/TokenBucket.swift | 2 +- .../WindowBasedAlgorithm.swift} | 4 ++-- Sources/Checkpoint/Checkpoint.swift | 16 ++++++++-------- Tests/CheckpointTests/CheckpointTests.swift | 1 + 9 files changed, 22 insertions(+), 18 deletions(-) rename Sources/Checkpoint/{Limiters/Limiter.swift => Algorithms/Algorithm.swift} (95%) rename Sources/Checkpoint/{Limiters => Algorithms}/FixedWindowCounter.swift (97%) rename Sources/Checkpoint/{Limiters => Algorithms}/LeakingBucket.swift (98%) rename Sources/Checkpoint/{Limiters => Algorithms}/SlidingWindowLog.swift (98%) rename Sources/Checkpoint/{Limiters => Algorithms}/TokenBucket.swift (98%) rename Sources/Checkpoint/{Limiters/WindowBasedLimiter.swift => Algorithms/WindowBasedAlgorithm.swift} (90%) diff --git a/Package.swift b/Package.swift index 392f715..f96e617 100644 --- a/Package.swift +++ b/Package.swift @@ -32,7 +32,10 @@ let package = Package( ), .testTarget( name: "CheckpointTests", - dependencies: ["Checkpoint"] + dependencies: [ + "Checkpoint", + .product(name: "XCTVapor", package: "vapor") + ] ), ] ) diff --git a/Sources/Checkpoint/Limiters/Limiter.swift b/Sources/Checkpoint/Algorithms/Algorithm.swift similarity index 95% rename from Sources/Checkpoint/Limiters/Limiter.swift rename to Sources/Checkpoint/Algorithms/Algorithm.swift index 3a17589..32c50a5 100644 --- a/Sources/Checkpoint/Limiters/Limiter.swift +++ b/Sources/Checkpoint/Algorithms/Algorithm.swift @@ -12,7 +12,7 @@ import Vapor public typealias StorageAction = () -> Application.Redis public typealias LoggerAction = () -> Logger -public protocol Limiter: Sendable { +public protocol Algorithm: Sendable { associatedtype ConfigurationType var storage: Application.Redis { get } @@ -23,7 +23,7 @@ public protocol Limiter: Sendable { func checkRequest(_ request: Request) async throws } -extension Limiter { +extension Algorithm { func valueFor(field: Field, in request: Request) throws -> String { switch field { case .header(let key): @@ -68,7 +68,7 @@ extension Limiter { } } -extension Limiter { +extension Algorithm { static var none: String { "no-key" } diff --git a/Sources/Checkpoint/Limiters/FixedWindowCounter.swift b/Sources/Checkpoint/Algorithms/FixedWindowCounter.swift similarity index 97% rename from Sources/Checkpoint/Limiters/FixedWindowCounter.swift rename to Sources/Checkpoint/Algorithms/FixedWindowCounter.swift index 220a0e8..a9d8519 100644 --- a/Sources/Checkpoint/Limiters/FixedWindowCounter.swift +++ b/Sources/Checkpoint/Algorithms/FixedWindowCounter.swift @@ -40,7 +40,7 @@ public final class FixedWindowCounter { } } -extension FixedWindowCounter: WindowBasedLimiter { +extension FixedWindowCounter: WindowBasedAlgorithm { public func checkRequest(_ request: Request) async throws { guard let requestKey = try? valueFor(field: configuration.appliedField, in: request, inside: configuration.scope) else { return diff --git a/Sources/Checkpoint/Limiters/LeakingBucket.swift b/Sources/Checkpoint/Algorithms/LeakingBucket.swift similarity index 98% rename from Sources/Checkpoint/Limiters/LeakingBucket.swift rename to Sources/Checkpoint/Algorithms/LeakingBucket.swift index a236c66..cc2f15e 100644 --- a/Sources/Checkpoint/Limiters/LeakingBucket.swift +++ b/Sources/Checkpoint/Algorithms/LeakingBucket.swift @@ -47,7 +47,7 @@ public final class LeakingBucket { } } -extension LeakingBucket: WindowBasedLimiter { +extension LeakingBucket: WindowBasedAlgorithm { var isValidRequest: Bool { return true } diff --git a/Sources/Checkpoint/Limiters/SlidingWindowLog.swift b/Sources/Checkpoint/Algorithms/SlidingWindowLog.swift similarity index 98% rename from Sources/Checkpoint/Limiters/SlidingWindowLog.swift rename to Sources/Checkpoint/Algorithms/SlidingWindowLog.swift index 68bef31..f4e9b8c 100644 --- a/Sources/Checkpoint/Limiters/SlidingWindowLog.swift +++ b/Sources/Checkpoint/Algorithms/SlidingWindowLog.swift @@ -20,7 +20,7 @@ public final class SlidingWindowLog { } } -extension SlidingWindowLog: Limiter { +extension SlidingWindowLog: Algorithm { var isValidRequest: Bool { return true } diff --git a/Sources/Checkpoint/Limiters/TokenBucket.swift b/Sources/Checkpoint/Algorithms/TokenBucket.swift similarity index 98% rename from Sources/Checkpoint/Limiters/TokenBucket.swift rename to Sources/Checkpoint/Algorithms/TokenBucket.swift index 77cf6ca..b20444b 100644 --- a/Sources/Checkpoint/Limiters/TokenBucket.swift +++ b/Sources/Checkpoint/Algorithms/TokenBucket.swift @@ -48,7 +48,7 @@ public final class TokenBucket { } } -extension TokenBucket: WindowBasedLimiter { +extension TokenBucket: WindowBasedAlgorithm { public func checkRequest(_ request: Request) async throws { guard let requestKey = try? valueFor(field: configuration.appliedField, in: request, inside: configuration.scope) else { return diff --git a/Sources/Checkpoint/Limiters/WindowBasedLimiter.swift b/Sources/Checkpoint/Algorithms/WindowBasedAlgorithm.swift similarity index 90% rename from Sources/Checkpoint/Limiters/WindowBasedLimiter.swift rename to Sources/Checkpoint/Algorithms/WindowBasedAlgorithm.swift index f401544..fc44e39 100644 --- a/Sources/Checkpoint/Limiters/WindowBasedLimiter.swift +++ b/Sources/Checkpoint/Algorithms/WindowBasedAlgorithm.swift @@ -10,12 +10,12 @@ import Foundation public typealias WindowBasedAction = () throws -> Void -public protocol WindowBasedLimiter: Limiter { +public protocol WindowBasedAlgorithm: Algorithm { func startWindow(havingDuration seconds: Double, performing action: @escaping WindowBasedAction) -> AnyCancellable func resetWindow() async throws } -extension WindowBasedLimiter { +extension WindowBasedAlgorithm { public func startWindow(havingDuration seconds: Double, performing action: @escaping WindowBasedAction) -> AnyCancellable { var cancellable = Timer.publish(every: seconds, on: .main, in: .common) .autoconnect() diff --git a/Sources/Checkpoint/Checkpoint.swift b/Sources/Checkpoint/Checkpoint.swift index 00cd591..e2a53a6 100644 --- a/Sources/Checkpoint/Checkpoint.swift +++ b/Sources/Checkpoint/Checkpoint.swift @@ -9,28 +9,28 @@ import Redis import Vapor final class Checkpoint { - let limiter: any Limiter + let algorithm: any Algorithm - init(using algorithm: some Limiter) { - self.limiter = algorithm + init(using algorithm: some Algorithm) { + self.algorithm = algorithm } } extension Checkpoint: AsyncMiddleware { func respond(to request: Request, chainingTo next: any AsyncResponder) async throws -> Response { - limiter.logging?.info("๐Ÿ‘‰ RateLimitMiddleware request") + algorithm.logging?.info("๐Ÿ‘‰ RateLimitMiddleware request") let response = try await next.respond(to: request) - limiter.logging?.info("๐Ÿ‘ˆ RateLimitMiddleware reponse") + algorithm.logging?.info("๐Ÿ‘ˆ RateLimitMiddleware reponse") do { try await checkRateLimitFor(request: request) response.headers.add(name: "X-App-Version", value: "v1.0.0") - limiter.logging?.info("๐Ÿ’ก Header Setted.") + algorithm.logging?.info("๐Ÿ’ก Header Setted.") } catch let abort as AbortError { throw abort } catch { response.headers.add(name: "X-Rate-Limit", value: "8") - limiter.logging?.info("๐Ÿšจ Header Setted.") + algorithm.logging?.info("๐Ÿšจ Header Setted.") throw Abort(.tooManyRequests, headers: response.headers, reason: HTTPErrorDescription.rateLimitReached) @@ -40,7 +40,7 @@ extension Checkpoint: AsyncMiddleware { } private func checkRateLimitFor(request: Request) async throws { - try await limiter.checkRequest(request) + try await algorithm.checkRequest(request) } } diff --git a/Tests/CheckpointTests/CheckpointTests.swift b/Tests/CheckpointTests/CheckpointTests.swift index c29ca31..0168827 100644 --- a/Tests/CheckpointTests/CheckpointTests.swift +++ b/Tests/CheckpointTests/CheckpointTests.swift @@ -1,4 +1,5 @@ import XCTest +import XCTVapor @testable import Checkpoint final class CheckpointTests: XCTestCase { From 34372d6347b1e409be16fa00c3114f87b371d044 Mon Sep 17 00:00:00 2001 From: fitomad Date: Sun, 23 Jun 2024 10:48:46 +0200 Subject: [PATCH 03/15] Closures to execute custom code at different rate-limit checking operation stage --- Sources/Checkpoint/Checkpoint.swift | 53 +++++++------------ .../{RateLimitScope.swift => Scope.swift} | 4 +- 2 files changed, 20 insertions(+), 37 deletions(-) rename Sources/Checkpoint/Model/{RateLimitScope.swift => Scope.swift} (75%) diff --git a/Sources/Checkpoint/Checkpoint.swift b/Sources/Checkpoint/Checkpoint.swift index e2a53a6..1895063 100644 --- a/Sources/Checkpoint/Checkpoint.swift +++ b/Sources/Checkpoint/Checkpoint.swift @@ -8,29 +8,36 @@ import Redis import Vapor -final class Checkpoint { - let algorithm: any Algorithm +public typealias CheckpointAction = (Request) -> Void +public typealias CheckpointErrorAction = (Request, Error) -> Void + +public final class Checkpoint { + private let algorithm: any Algorithm + + public var willCheck: CheckpointAction? + public var didCheck: CheckpointAction? + public var didFailWithTooManyRequest: CheckpointErrorAction? + public var didFail: CheckpointErrorAction? - init(using algorithm: some Algorithm) { + public init(using algorithm: some Algorithm) { self.algorithm = algorithm } } extension Checkpoint: AsyncMiddleware { - func respond(to request: Request, chainingTo next: any AsyncResponder) async throws -> Response { - algorithm.logging?.info("๐Ÿ‘‰ RateLimitMiddleware request") + public func respond(to request: Request, chainingTo next: any AsyncResponder) async throws -> Response { let response = try await next.respond(to: request) - algorithm.logging?.info("๐Ÿ‘ˆ RateLimitMiddleware reponse") do { + willCheck?(request) try await checkRateLimitFor(request: request) - response.headers.add(name: "X-App-Version", value: "v1.0.0") - algorithm.logging?.info("๐Ÿ’ก Header Setted.") + didCheck?(request) } catch let abort as AbortError { + didFail?(request, abort) throw abort - } catch { - response.headers.add(name: "X-Rate-Limit", value: "8") - algorithm.logging?.info("๐Ÿšจ Header Setted.") + } catch let error { + didFailWithTooManyRequest?(request, error) + throw Abort(.tooManyRequests, headers: response.headers, reason: HTTPErrorDescription.rateLimitReached) @@ -45,32 +52,8 @@ extension Checkpoint: AsyncMiddleware { } extension Checkpoint { - enum Constants { - static let apiKeyHeader = "X-ApiKey" - static let rateLimitDB = "rate-limit" - } - enum HTTPErrorDescription { static let unauthorized = "X-Api-Key header not available in the request" static let rateLimitReached = "You have exceed your ApiKey network requests rate" } } - -extension Checkpoint { - /* - enum Strategy { - case tokenBucket(configuration: TokenBucket.Configuration) - case leakingBucket(configuration: LeakingBucket.Configuration) - case fixedWindowCounter(configuration: FixedWindowCounter.Configuration) - case slidingWindowLog(configuration: SlidingWindowLog.Configuration) - } - */ - -} - -enum Strategy { - case tokenBucket - case leakingBucket - case fixedWindowCounter - case slidingWindowLog -} diff --git a/Sources/Checkpoint/Model/RateLimitScope.swift b/Sources/Checkpoint/Model/Scope.swift similarity index 75% rename from Sources/Checkpoint/Model/RateLimitScope.swift rename to Sources/Checkpoint/Model/Scope.swift index 78ef01f..108aeac 100644 --- a/Sources/Checkpoint/Model/RateLimitScope.swift +++ b/Sources/Checkpoint/Model/Scope.swift @@ -7,8 +7,8 @@ import Foundation -public enum RateLimitScope { +public enum Scope { case api case endpoint - case nonScope + case noScope } From eef47d4189cf1cfc0353977d12b3a464e2a408d2 Mon Sep 17 00:00:00 2001 From: fitomad Date: Sun, 23 Jun 2024 10:49:03 +0200 Subject: [PATCH 04/15] Different refactoring task --- Sources/Checkpoint/Algorithms/Algorithm.swift | 34 ++++++---- .../Checkpoint/Algorithms/LeakingBucket.swift | 4 +- .../Checkpoint/Algorithms/TokenBucket.swift | 2 +- Sources/Checkpoint/Model/Configuration.swift | 63 +++++++++++++++---- Sources/Checkpoint/Model/Field.swift | 2 +- Tests/CheckpointTests/CheckpointTests.swift | 28 +++++++++ 6 files changed, 104 insertions(+), 29 deletions(-) diff --git a/Sources/Checkpoint/Algorithms/Algorithm.swift b/Sources/Checkpoint/Algorithms/Algorithm.swift index 32c50a5..82a3c53 100644 --- a/Sources/Checkpoint/Algorithms/Algorithm.swift +++ b/Sources/Checkpoint/Algorithms/Algorithm.swift @@ -28,37 +28,37 @@ extension Algorithm { switch field { case .header(let key): guard let value = request.headers[key].first else { - throw Abort(.unauthorized, reason: "") + throw Abort(.unauthorized, reason: Self.noFieldMessage) } return value case .queryItem(let key): guard let value = request.query[String.self, at: key] else { - throw Abort(.unauthorized, reason: "") + throw Abort(.unauthorized, reason: Self.noFieldMessage) } return value - case .none: - return Self.none + case .noField: + return Self.fieldDefaultKey } } - func valueFor(scope: RateLimitScope, in request: Request) throws -> String { + func valueFor(scope: Scope, in request: Request) throws -> String { switch scope { case .endpoint: return request.url.path case .api: guard let host = request.url.host else { - throw Abort(.badRequest, reason: "") + throw Abort(.badRequest, reason: Self.hostNotFoundMessage) } return host - case .nonScope: - return Self.nonScope + case .noScope: + return Self.scopeDefaultKey } } - func valueFor(field: Field, in request: Request, inside scope: RateLimitScope) throws -> String { + func valueFor(field: Field, in request: Request, inside scope: Scope) throws -> String { let prefix = try valueFor(field: field, in: request) let suffix = try valueFor(scope: scope, in: request) @@ -69,11 +69,19 @@ extension Algorithm { } extension Algorithm { - static var none: String { - "no-key" + static var fieldDefaultKey: String { + "checkpoint#no.field" } - static var nonScope: String { - "non-scope" + static var scopeDefaultKey: String { + "checkpoint#no.scope" + } + + static var noFieldMessage: String { + "Expected field not found at headers or query parameters" + } + + static var hostNotFoundMessage: String { + "Unable to recover host from request" } } diff --git a/Sources/Checkpoint/Algorithms/LeakingBucket.swift b/Sources/Checkpoint/Algorithms/LeakingBucket.swift index cc2f15e..75b45f8 100644 --- a/Sources/Checkpoint/Algorithms/LeakingBucket.swift +++ b/Sources/Checkpoint/Algorithms/LeakingBucket.swift @@ -42,7 +42,7 @@ public final class LeakingBucket { do { try await storage.set(key, to: 0).get() } catch let redisError { - logging?.error("๐Ÿšจ Problem setting key \(key.rawValue) to value \(configuration.bucketSize)") + logging?.error("๐Ÿšจ Problem setting key \(key.rawValue) to value \(configuration.bucketSize): \(redisError.localizedDescription)") } } } @@ -60,7 +60,7 @@ extension LeakingBucket: WindowBasedAlgorithm { keys.insert(requestKey) let redisKey = RedisKey(requestKey) - let keyExists = await try storage.exists(redisKey).get() + let keyExists = try await storage.exists(redisKey).get() if keyExists == 0 { await preparaStorageFor(key: redisKey) diff --git a/Sources/Checkpoint/Algorithms/TokenBucket.swift b/Sources/Checkpoint/Algorithms/TokenBucket.swift index b20444b..2e726c9 100644 --- a/Sources/Checkpoint/Algorithms/TokenBucket.swift +++ b/Sources/Checkpoint/Algorithms/TokenBucket.swift @@ -57,7 +57,7 @@ extension TokenBucket: WindowBasedAlgorithm { keys.insert(requestKey) let redisKey = RedisKey(requestKey) - let keyExists = await try storage.exists(redisKey).get() + let keyExists = try await storage.exists(redisKey).get() if keyExists == 0 { await preparaStorageFor(key: redisKey) diff --git a/Sources/Checkpoint/Model/Configuration.swift b/Sources/Checkpoint/Model/Configuration.swift index 98779a0..9a027a9 100644 --- a/Sources/Checkpoint/Model/Configuration.swift +++ b/Sources/Checkpoint/Model/Configuration.swift @@ -8,15 +8,24 @@ public protocol Configuration { var appliedField: Field { get } - var scope: RateLimitScope { get } + var scope: Scope { get } } public struct FixedWindowCounterConfiguration: Configuration { - public var requestPerWindow = 5 - public var timeWindowDuration: TimeWindow = .seconds(count: 10) + public private(set) var requestPerWindow: Int + public private(set) var timeWindowDuration: TimeWindow - public var appliedField: Field - public var scope: RateLimitScope + public private(set) var appliedField: Field + public private(set) var scope: Scope +} + +extension FixedWindowCounterConfiguration { + public init(requestPerWindow: Int, timeWindowDuration: TimeWindow, appliedTo field: Field = .noField, inside scope: Scope = .noScope) { + self.requestPerWindow = requestPerWindow + self.timeWindowDuration = timeWindowDuration + self.appliedField = field + self.scope = scope + } } public struct LeakingBucketConfiguration: Configuration { @@ -25,22 +34,52 @@ public struct LeakingBucketConfiguration: Configuration { public var timeWindowDuration: TimeWindow = .seconds(count: 10) public var appliedField: Field - public var scope: RateLimitScope + public var scope: Scope +} + +extension LeakingBucketConfiguration { + public init(bucketSize: Int, removingRate: Int, removingTimeInterval: TimeWindow, appliedTo field: Field = .noField, inside scope: Scope = .noScope) { + self.bucketSize = bucketSize + self.tokenRemovingRate = removingRate + self.timeWindowDuration = removingTimeInterval + self.appliedField = field + self.scope = scope + } } + public struct SlidingWindowLogConfiguration: Configuration { public var requestPerWindow = 10 public var timeWindowDuration: TimeWindow public var appliedField: Field - public var scope: RateLimitScope + public var scope: Scope +} + +extension SlidingWindowLogConfiguration { + public init(requestPerWindow: Int, windowDuration: TimeWindow, appliedTo field: Field = .noField, inside scope: Scope = .noScope) { + self.requestPerWindow = requestPerWindow + self.timeWindowDuration = windowDuration + self.appliedField = field + self.scope = scope + } } public struct TokenBucketConfiguration: Configuration { - public var bucketSize: Int - public var refillTokenRate: Int - public var refillTimeInterval: TimeWindow + public private(set) var bucketSize: Int + public private(set) var refillTokenRate: Int + public private(set) var refillTimeInterval: TimeWindow - public var appliedField: Field - public var scope: RateLimitScope + public private(set) var appliedField: Field + public private(set) var scope: Scope +} + +extension TokenBucketConfiguration { + public init(bucketSize: Int, refillRate: Int, refillTimeInterval: TimeWindow, appliedTo field: Field = .noField, inside scope: Scope = .noScope) { + self.bucketSize = bucketSize + self.refillTokenRate = refillRate + self.refillTimeInterval = refillTimeInterval + self.appliedField = field + self.scope = scope + } } diff --git a/Sources/Checkpoint/Model/Field.swift b/Sources/Checkpoint/Model/Field.swift index 600e943..db53e50 100644 --- a/Sources/Checkpoint/Model/Field.swift +++ b/Sources/Checkpoint/Model/Field.swift @@ -10,5 +10,5 @@ import Foundation public enum Field { case header(key: String) case queryItem(key: String) - case none + case noField } diff --git a/Tests/CheckpointTests/CheckpointTests.swift b/Tests/CheckpointTests/CheckpointTests.swift index 0168827..17eed31 100644 --- a/Tests/CheckpointTests/CheckpointTests.swift +++ b/Tests/CheckpointTests/CheckpointTests.swift @@ -10,4 +10,32 @@ final class CheckpointTests: XCTestCase { // Defining Test Cases and Test Methods // https://developer.apple.com/documentation/xctest/defining_test_cases_and_test_methods } + + func testConfig() { + let tokenbucketAlgorithm = TokenBucket { + TokenBucketConfiguration(bucketSize: 10, + refillRate: 5, + refillTimeInterval: .seconds(count: 2)) + } storage: { + + } + + let checkpoint = Checkpoint(using: tokenbucketAlgorithm) + + checkpoint.willCheck = { (request: Request) in + + } + + checkpoint.didCheck = { (request: Request) in + + } + + checkpoint.didFailWithTooManyRequest = { (request: Request, error: Error) in + + } + + checkpoint.didFail = { (request: Request, error: Error) in + + } + } } From 372ae90d1b98f5c30fb30859b502884f91652a63 Mon Sep 17 00:00:00 2001 From: fitomad Date: Sun, 23 Jun 2024 11:36:01 +0200 Subject: [PATCH 05/15] adding response to error closure --- Sources/Checkpoint/Checkpoint.swift | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/Sources/Checkpoint/Checkpoint.swift b/Sources/Checkpoint/Checkpoint.swift index 1895063..06a32e8 100644 --- a/Sources/Checkpoint/Checkpoint.swift +++ b/Sources/Checkpoint/Checkpoint.swift @@ -9,7 +9,7 @@ import Redis import Vapor public typealias CheckpointAction = (Request) -> Void -public typealias CheckpointErrorAction = (Request, Error) -> Void +public typealias CheckpointErrorAction = (Request, Response, Error) -> Void public final class Checkpoint { private let algorithm: any Algorithm @@ -33,14 +33,14 @@ extension Checkpoint: AsyncMiddleware { try await checkRateLimitFor(request: request) didCheck?(request) } catch let abort as AbortError { - didFail?(request, abort) - throw abort - } catch let error { - didFailWithTooManyRequest?(request, error) + switch abort.status { + case .tooManyRequests: + didFailWithTooManyRequest?(request, response, abort) + default: + didFail?(request, response, abort) + } - throw Abort(.tooManyRequests, - headers: response.headers, - reason: HTTPErrorDescription.rateLimitReached) + throw abort } return response From 116c4e67d898c7b66422e6a012de7f5ff6401936 Mon Sep 17 00:00:00 2001 From: fitomad Date: Sun, 23 Jun 2024 11:43:33 +0200 Subject: [PATCH 06/15] Move to AbortError type --- Sources/Checkpoint/Checkpoint.swift | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Sources/Checkpoint/Checkpoint.swift b/Sources/Checkpoint/Checkpoint.swift index 06a32e8..b0b9495 100644 --- a/Sources/Checkpoint/Checkpoint.swift +++ b/Sources/Checkpoint/Checkpoint.swift @@ -9,7 +9,7 @@ import Redis import Vapor public typealias CheckpointAction = (Request) -> Void -public typealias CheckpointErrorAction = (Request, Response, Error) -> Void +public typealias CheckpointErrorAction = (Request, Response, AbortError) -> Void public final class Checkpoint { private let algorithm: any Algorithm From a1d9d1c968ed38257b0242e9f0f4aee760d7f47b Mon Sep 17 00:00:00 2001 From: fitomad Date: Sun, 23 Jun 2024 11:46:07 +0200 Subject: [PATCH 07/15] inout AbortError parameter --- Sources/Checkpoint/Checkpoint.swift | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/Sources/Checkpoint/Checkpoint.swift b/Sources/Checkpoint/Checkpoint.swift index b0b9495..40657d7 100644 --- a/Sources/Checkpoint/Checkpoint.swift +++ b/Sources/Checkpoint/Checkpoint.swift @@ -9,7 +9,7 @@ import Redis import Vapor public typealias CheckpointAction = (Request) -> Void -public typealias CheckpointErrorAction = (Request, Response, AbortError) -> Void +public typealias CheckpointErrorAction = (Request, Response, inout AbortError) -> Void public final class Checkpoint { private let algorithm: any Algorithm @@ -32,12 +32,12 @@ extension Checkpoint: AsyncMiddleware { willCheck?(request) try await checkRateLimitFor(request: request) didCheck?(request) - } catch let abort as AbortError { + } catch var abort as AbortError { switch abort.status { case .tooManyRequests: - didFailWithTooManyRequest?(request, response, abort) + didFailWithTooManyRequest?(request, response, &abort) default: - didFail?(request, response, abort) + didFail?(request, response, &abort) } throw abort From 9c426769fce36f875e4dfa44308a7426020bb41c Mon Sep 17 00:00:00 2001 From: fitomad Date: Sun, 23 Jun 2024 12:07:35 +0200 Subject: [PATCH 08/15] Adding metadata to error closure --- Sources/Checkpoint/Checkpoint.swift | 41 ++++++++++++++++++++++++----- 1 file changed, 35 insertions(+), 6 deletions(-) diff --git a/Sources/Checkpoint/Checkpoint.swift b/Sources/Checkpoint/Checkpoint.swift index 40657d7..949919d 100644 --- a/Sources/Checkpoint/Checkpoint.swift +++ b/Sources/Checkpoint/Checkpoint.swift @@ -9,7 +9,7 @@ import Redis import Vapor public typealias CheckpointAction = (Request) -> Void -public typealias CheckpointErrorAction = (Request, Response, inout AbortError) -> Void +public typealias CheckpointErrorAction = (Request, Response, Checkpoint.ErrorMetadata) -> Void public final class Checkpoint { private let algorithm: any Algorithm @@ -32,15 +32,23 @@ extension Checkpoint: AsyncMiddleware { willCheck?(request) try await checkRateLimitFor(request: request) didCheck?(request) - } catch var abort as AbortError { + } catch let abort as AbortError { + let errorMetadata = ErrorMetadata() + switch abort.status { case .tooManyRequests: - didFailWithTooManyRequest?(request, response, &abort) + didFailWithTooManyRequest?(request, response, errorMetadata) + + throw Abort(.tooManyRequests, + headers: errorMetadata.httpHeaders, + reason: errorMetadata.reason) default: - didFail?(request, response, &abort) + didFail?(request, response, errorMetadata) + + throw Abort(.badRequest, + headers: errorMetadata.httpHeaders, + reason: errorMetadata.reason) } - - throw abort } return response @@ -51,6 +59,27 @@ extension Checkpoint: AsyncMiddleware { } } +public extension Checkpoint { + final class ErrorMetadata { + public var headers: [String : String]? + public var reason: String? + + var httpHeaders: HTTPHeaders { + var httpHeaders = HTTPHeaders() + + guard let headers else { + return httpHeaders + } + + for (key, content) in headers { + httpHeaders.add(name: key, value: content) + } + + return httpHeaders + } + } +} + extension Checkpoint { enum HTTPErrorDescription { static let unauthorized = "X-Api-Key header not available in the request" From b86301cf5b08c25b4a05d9542f17cd3682ec51a1 Mon Sep 17 00:00:00 2001 From: fitomad Date: Sun, 23 Jun 2024 22:58:14 +0200 Subject: [PATCH 09/15] Fixed error processing rate limit --- .../Checkpoint/Algorithms/FixedWindowCounter.swift | 2 +- Sources/Checkpoint/Algorithms/LeakingBucket.swift | 2 +- Sources/Checkpoint/Algorithms/TokenBucket.swift | 14 ++++++++++---- 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/Sources/Checkpoint/Algorithms/FixedWindowCounter.swift b/Sources/Checkpoint/Algorithms/FixedWindowCounter.swift index a9d8519..4c801d8 100644 --- a/Sources/Checkpoint/Algorithms/FixedWindowCounter.swift +++ b/Sources/Checkpoint/Algorithms/FixedWindowCounter.swift @@ -55,7 +55,7 @@ extension FixedWindowCounter: WindowBasedAlgorithm { let requestCount = try await storage.llen(of: redisKey).get() - if requestCount > configuration.requestPerWindow { + if requestCount >= configuration.requestPerWindow { throw Abort(.tooManyRequests) } } diff --git a/Sources/Checkpoint/Algorithms/LeakingBucket.swift b/Sources/Checkpoint/Algorithms/LeakingBucket.swift index 75b45f8..e6e5d3b 100644 --- a/Sources/Checkpoint/Algorithms/LeakingBucket.swift +++ b/Sources/Checkpoint/Algorithms/LeakingBucket.swift @@ -70,7 +70,7 @@ extension LeakingBucket: WindowBasedAlgorithm { let bucketItemsCount = try await storage.increment(redisKey).get() logging?.info("โŒš๏ธ \(requestKey) = \(bucketItemsCount)") // 2. If buckes is empty, throw an error - if bucketItemsCount >= configuration.bucketSize { + if bucketItemsCount > configuration.bucketSize { throw Abort(.tooManyRequests) } } diff --git a/Sources/Checkpoint/Algorithms/TokenBucket.swift b/Sources/Checkpoint/Algorithms/TokenBucket.swift index 2e726c9..00f354a 100644 --- a/Sources/Checkpoint/Algorithms/TokenBucket.swift +++ b/Sources/Checkpoint/Algorithms/TokenBucket.swift @@ -67,7 +67,7 @@ extension TokenBucket: WindowBasedAlgorithm { let bucketItemsCount = try await storage.decrement(redisKey).get() logging?.info("โŒš๏ธ \(requestKey) = \(bucketItemsCount)") // 2. If buckes is empty, throw an error - if bucketItemsCount <= 0 { + if bucketItemsCount < 0 { throw Abort(.tooManyRequests) } } @@ -79,11 +79,17 @@ extension TokenBucket: WindowBasedAlgorithm { let respValue = try await storage.get(redisKey).get() - var newRefillSize = configuration.refillTokenRate + var newRefillSize = 0 if let currentBucketSize = respValue.int { - let expectedBucketSize = currentBucketSize + configuration.refillTokenRate - newRefillSize = (expectedBucketSize >= configuration.bucketSize) ? 0 : configuration.refillTokenRate + switch currentBucketSize { + case ...0: + newRefillSize -= currentBucketSize + case configuration.bucketSize...: + newRefillSize = configuration.bucketSize - currentBucketSize + default: + newRefillSize = configuration.refillTokenRate + } } try await storage.increment(redisKey, by: newRefillSize).get() From 956b248641e02058bcdee48f43ff67fa2bb70cb9 Mon Sep 17 00:00:00 2001 From: fitomad Date: Sun, 23 Jun 2024 22:58:28 +0200 Subject: [PATCH 10/15] Adding unit tests --- .../CheckpointApiKeyTests.swift | 167 ++++++++++++++++++ Tests/CheckpointTests/CheckpointTests.swift | 141 ++++++++++++++- 2 files changed, 299 insertions(+), 9 deletions(-) create mode 100644 Tests/CheckpointTests/CheckpointApiKeyTests.swift diff --git a/Tests/CheckpointTests/CheckpointApiKeyTests.swift b/Tests/CheckpointTests/CheckpointApiKeyTests.swift new file mode 100644 index 0000000..388e76a --- /dev/null +++ b/Tests/CheckpointTests/CheckpointApiKeyTests.swift @@ -0,0 +1,167 @@ +// +// CheckpointApiKeyTests.swift +// +// +// Created by Adolfo Vera Blasco on 23/6/24. +// + +import Redis +import XCTest +import XCTVapor +@testable import Checkpoint + +final class CheckpointApiKeyTests: XCTestCase { + func testLeakingBucketWithHeader() { + let app = Application(.testing) + defer { app.shutdown() } + + app.get("leaking-bucket") { request -> HTTPStatus in + return .ok + } + + let leakingBucketAlgorithm = LeakingBucket { + LeakingBucketConfiguration(bucketSize: 10, + removingRate: 5, + removingTimeInterval: .minutes(count: 1), + appliedTo: .header(key: "X-ApiKey")) + } storage: { + // Rate limit database in Redis + app.redis("rate").configuration = try? RedisConfiguration(hostname: "localhost", + port: 9090, + database: 0) + + return app.redis("rate") + } + + let checkpoint = Checkpoint(using: leakingBucketAlgorithm) + + app.middleware.use(checkpoint) + + (0...20).forEach { index in + try? app.test(.GET, "leaking-bucket", afterResponse: { testResponse in + app.logger.info("\(index) = \(testResponse.status)") + if index < 10 { + XCTAssertEqual(testResponse.status, .ok) + } else { + XCTAssertEqual(testResponse.status, .tooManyRequests) + } + }) + } + } + + func testTokenBucketWithHeader() throws { + let app = Application(.testing) + defer { app.shutdown() } + + app.get("token-bucket") { request -> HTTPStatus in + return .ok + } + + let tokenbucketAlgorithm = TokenBucket { + TokenBucketConfiguration(bucketSize: 10, + refillRate: 0, + refillTimeInterval: .seconds(count: 20), + appliedTo: .header(key: "X-ApiKey")) + } storage: { + // Rate limit database in Redis + app.redis("rate").configuration = try? RedisConfiguration(hostname: "localhost", + port: 9090, + database: 0) + + return app.redis("rate") + } logging: { + app.logger + } + + let checkpoint = Checkpoint(using: tokenbucketAlgorithm) + + app.middleware.use(checkpoint) + + (0...20).forEach { index in + try? app.test(.GET, "token-bucket", afterResponse: { testResponse in + app.logger.info("\(index) = \(testResponse.status)") + if index < 10 { + XCTAssertEqual(testResponse.status, .ok) + } else { + XCTAssertEqual(testResponse.status, .tooManyRequests) + } + }) + } + } + + func testFixedWindowCounterWithHeader() { + let app = Application(.testing) + defer { app.shutdown() } + + app.get("fixed-window-counter") { request -> HTTPStatus in + return .ok + } + + let fixedWindowAlgorithm = FixedWindowCounter { + FixedWindowCounterConfiguration(requestPerWindow: 10, + timeWindowDuration: .minutes(count: 2), + appliedTo: .header(key: "X-ApiKey")) + } storage: { + // Rate limit database in Redis + app.redis("rate").configuration = try? RedisConfiguration(hostname: "localhost", + port: 9090, + database: 0) + + return app.redis("rate") + } + + + let checkpoint = Checkpoint(using: fixedWindowAlgorithm) + + app.middleware.use(checkpoint) + + (0...20).forEach { index in + try? app.test(.GET, "fixed-window-counter", afterResponse: { testResponse in + app.logger.info("\(index) = \(testResponse.status)") + if index < 10 { + XCTAssertEqual(testResponse.status, .ok) + } else { + XCTAssertEqual(testResponse.status, .tooManyRequests) + } + }) + } + } + + func testSlidingWindowLogWithHeader() { + let app = Application(.testing) + defer { app.shutdown() } + + app.get("sliding-window-log") { request -> HTTPStatus in + return .ok + } + + let slidingWindowLogAlgorith = SlidingWindowLog { + SlidingWindowLogConfiguration(requestPerWindow: 10, + windowDuration: .minutes(count: 2), + appliedTo: .header(key: "X-ApiKey")) + } storage: { + // Rate limit database in Redis + app.redis("rate").configuration = try? RedisConfiguration(hostname: "localhost", + port: 9090, + database: 0) + + return app.redis("rate") + } + + + let checkpoint = Checkpoint(using: slidingWindowLogAlgorith) + + app.middleware.use(checkpoint) + + (0...20).forEach { index in + try? app.test(.GET, "sliding-window-log", afterResponse: { testResponse in + app.logger.info("\(index) = \(testResponse.status)") + if index < 10 { + XCTAssertEqual(testResponse.status, .ok) + } else { + XCTAssertEqual(testResponse.status, .tooManyRequests) + } + }) + } + } +} diff --git a/Tests/CheckpointTests/CheckpointTests.swift b/Tests/CheckpointTests/CheckpointTests.swift index 17eed31..895dde2 100644 --- a/Tests/CheckpointTests/CheckpointTests.swift +++ b/Tests/CheckpointTests/CheckpointTests.swift @@ -1,3 +1,4 @@ +import Redis import XCTest import XCTVapor @testable import Checkpoint @@ -11,31 +12,153 @@ final class CheckpointTests: XCTestCase { // https://developer.apple.com/documentation/xctest/defining_test_cases_and_test_methods } - func testConfig() { + func testLeakingBucket() { + let app = Application(.testing) + defer { app.shutdown() } + + app.get("leaking-bucket") { request -> HTTPStatus in + return .ok + } + + let leakingBucketAlgorithm = LeakingBucket { + LeakingBucketConfiguration(bucketSize: 10, + removingRate: 5, + removingTimeInterval: .minutes(count: 1)) + } storage: { + // Rate limit database in Redis + app.redis("rate").configuration = try? RedisConfiguration(hostname: "localhost", + port: 9090, + database: 0) + + return app.redis("rate") + } + + let checkpoint = Checkpoint(using: leakingBucketAlgorithm) + + app.middleware.use(checkpoint) + + (0...20).forEach { index in + try? app.test(.GET, "leaking-bucket", afterResponse: { testResponse in + app.logger.info("\(index) = \(testResponse.status)") + if index < 10 { + XCTAssertEqual(testResponse.status, .ok) + } else { + XCTAssertEqual(testResponse.status, .tooManyRequests) + } + }) + } + } + + func testTokenBucket() throws { + let app = Application(.testing) + defer { app.shutdown() } + + app.get("token-bucket") { request -> HTTPStatus in + return .ok + } + let tokenbucketAlgorithm = TokenBucket { TokenBucketConfiguration(bucketSize: 10, - refillRate: 5, - refillTimeInterval: .seconds(count: 2)) + refillRate: 0, + refillTimeInterval: .seconds(count: 20)) } storage: { + // Rate limit database in Redis + app.redis("rate").configuration = try? RedisConfiguration(hostname: "localhost", + port: 9090, + database: 0) + return app.redis("rate") + } logging: { + app.logger } let checkpoint = Checkpoint(using: tokenbucketAlgorithm) - checkpoint.willCheck = { (request: Request) in - + app.middleware.use(checkpoint) + + (0...20).forEach { index in + try? app.test(.GET, "token-bucket", afterResponse: { testResponse in + app.logger.info("\(index) = \(testResponse.status)") + if index < 10 { + XCTAssertEqual(testResponse.status, .ok) + } else { + XCTAssertEqual(testResponse.status, .tooManyRequests) + } + }) + } + } + + func testFixedWindowCounter() { + let app = Application(.testing) + defer { app.shutdown() } + + app.get("fixed-window-counter") { request -> HTTPStatus in + return .ok } - checkpoint.didCheck = { (request: Request) in + let fixedWindowAlgorithm = FixedWindowCounter { + FixedWindowCounterConfiguration(requestPerWindow: 10, + timeWindowDuration: .minutes(count: 2)) + } storage: { + // Rate limit database in Redis + app.redis("rate").configuration = try? RedisConfiguration(hostname: "localhost", + port: 9090, + database: 0) + return app.redis("rate") } - checkpoint.didFailWithTooManyRequest = { (request: Request, error: Error) in - + + let checkpoint = Checkpoint(using: fixedWindowAlgorithm) + + app.middleware.use(checkpoint) + + (0...20).forEach { index in + try? app.test(.GET, "fixed-window-counter", afterResponse: { testResponse in + app.logger.info("\(index) = \(testResponse.status)") + if index < 10 { + XCTAssertEqual(testResponse.status, .ok) + } else { + XCTAssertEqual(testResponse.status, .tooManyRequests) + } + }) + } + } + + func testSlidingWindowLog() { + let app = Application(.testing) + defer { app.shutdown() } + + app.get("sliding-window-log") { request -> HTTPStatus in + return .ok } - checkpoint.didFail = { (request: Request, error: Error) in + let slidingWindowLogAlgorith = SlidingWindowLog { + SlidingWindowLogConfiguration(requestPerWindow: 10, + windowDuration: .minutes(count: 2)) + } storage: { + // Rate limit database in Redis + app.redis("rate").configuration = try? RedisConfiguration(hostname: "localhost", + port: 9090, + database: 0) + return app.redis("rate") + } + + + let checkpoint = Checkpoint(using: slidingWindowLogAlgorith) + + app.middleware.use(checkpoint) + + (0...20).forEach { index in + try? app.test(.GET, "sliding-window-log", afterResponse: { testResponse in + app.logger.info("\(index) = \(testResponse.status)") + if index < 10 { + XCTAssertEqual(testResponse.status, .ok) + } else { + XCTAssertEqual(testResponse.status, .tooManyRequests) + } + }) } } } From fcc42dea65e24c775496f5850fb02c79a908b60e Mon Sep 17 00:00:00 2001 From: fitomad Date: Mon, 24 Jun 2024 20:14:04 +0200 Subject: [PATCH 11/15] Remove llen operation due to rpush returns the current length --- Sources/Checkpoint/Algorithms/FixedWindowCounter.swift | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/Sources/Checkpoint/Algorithms/FixedWindowCounter.swift b/Sources/Checkpoint/Algorithms/FixedWindowCounter.swift index 4c801d8..8dc9c31 100644 --- a/Sources/Checkpoint/Algorithms/FixedWindowCounter.swift +++ b/Sources/Checkpoint/Algorithms/FixedWindowCounter.swift @@ -51,11 +51,9 @@ extension FixedWindowCounter: WindowBasedAlgorithm { let redisKey = RedisKey(requestKey) let timestamp = Date.now.timeIntervalSince1970 - storage.rpush([ timestamp ], into: redisKey) + let requestCount = try await storage.rpush([ timestamp ], into: redisKey).get() - let requestCount = try await storage.llen(of: redisKey).get() - - if requestCount >= configuration.requestPerWindow { + if requestCount > configuration.requestPerWindow { throw Abort(.tooManyRequests) } } From 16f4844d77c706700551d464e94ffb92f236b872 Mon Sep 17 00:00:00 2001 From: fitomad Date: Mon, 24 Jun 2024 20:14:27 +0200 Subject: [PATCH 12/15] Unit tests --- Sources/Checkpoint/Algorithms/Algorithm.swift | 8 +- .../CheckpointApiKeyTests.swift | 28 ++- .../CheckpointApiScopeTests.swift | 183 ++++++++++++++++++ .../CheckpointScopeTests.swift | 167 ++++++++++++++++ 4 files changed, 376 insertions(+), 10 deletions(-) create mode 100644 Tests/CheckpointTests/CheckpointApiScopeTests.swift create mode 100644 Tests/CheckpointTests/CheckpointScopeTests.swift diff --git a/Sources/Checkpoint/Algorithms/Algorithm.swift b/Sources/Checkpoint/Algorithms/Algorithm.swift index 82a3c53..6ea4184 100644 --- a/Sources/Checkpoint/Algorithms/Algorithm.swift +++ b/Sources/Checkpoint/Algorithms/Algorithm.swift @@ -62,9 +62,13 @@ extension Algorithm { let prefix = try valueFor(field: field, in: request) let suffix = try valueFor(scope: scope, in: request) - let key = String("\(prefix)\(suffix)".hash) + var hasher = Hasher() + hasher.combine(prefix) + hasher.combine(suffix) - return key + let key = hasher.finalize() + + return String(key) } } diff --git a/Tests/CheckpointTests/CheckpointApiKeyTests.swift b/Tests/CheckpointTests/CheckpointApiKeyTests.swift index 388e76a..725917d 100644 --- a/Tests/CheckpointTests/CheckpointApiKeyTests.swift +++ b/Tests/CheckpointTests/CheckpointApiKeyTests.swift @@ -37,9 +37,12 @@ final class CheckpointApiKeyTests: XCTestCase { app.middleware.use(checkpoint) + var apiKeyHeader = HTTPHeaders() + apiKeyHeader.add(name: "X-ApiKey", value: "fitomad#1") + (0...20).forEach { index in - try? app.test(.GET, "leaking-bucket", afterResponse: { testResponse in - app.logger.info("\(index) = \(testResponse.status)") + try? app.test(.GET, "leaking-bucket", headers: apiKeyHeader, afterResponse: { testResponse in + app.logger.info("leaking-bucket \(index) = \(testResponse.status)") if index < 10 { XCTAssertEqual(testResponse.status, .ok) } else { @@ -77,9 +80,12 @@ final class CheckpointApiKeyTests: XCTestCase { app.middleware.use(checkpoint) + var apiKeyHeader = HTTPHeaders() + apiKeyHeader.add(name: "X-ApiKey", value: "fitomad#2") + (0...20).forEach { index in - try? app.test(.GET, "token-bucket", afterResponse: { testResponse in - app.logger.info("\(index) = \(testResponse.status)") + try? app.test(.GET, "token-bucket", headers: apiKeyHeader, afterResponse: { testResponse in + app.logger.info("token-bucket \(index) = \(testResponse.status)") if index < 10 { XCTAssertEqual(testResponse.status, .ok) } else { @@ -115,9 +121,12 @@ final class CheckpointApiKeyTests: XCTestCase { app.middleware.use(checkpoint) + var apiKeyHeader = HTTPHeaders() + apiKeyHeader.add(name: "X-ApiKey", value: "fitomad#3") + (0...20).forEach { index in - try? app.test(.GET, "fixed-window-counter", afterResponse: { testResponse in - app.logger.info("\(index) = \(testResponse.status)") + try? app.test(.GET, "fixed-window-counter", headers: apiKeyHeader, afterResponse: { testResponse in + app.logger.info("fixed-window-counter \(index) = \(testResponse.status)") if index < 10 { XCTAssertEqual(testResponse.status, .ok) } else { @@ -153,9 +162,12 @@ final class CheckpointApiKeyTests: XCTestCase { app.middleware.use(checkpoint) + var apiKeyHeader = HTTPHeaders() + apiKeyHeader.add(name: "X-ApiKey", value: "fitomad#4") + (0...20).forEach { index in - try? app.test(.GET, "sliding-window-log", afterResponse: { testResponse in - app.logger.info("\(index) = \(testResponse.status)") + try? app.test(.GET, "sliding-window-log", headers: apiKeyHeader, afterResponse: { testResponse in + app.logger.info("sliding-window-log \(index) = \(testResponse.status)") if index < 10 { XCTAssertEqual(testResponse.status, .ok) } else { diff --git a/Tests/CheckpointTests/CheckpointApiScopeTests.swift b/Tests/CheckpointTests/CheckpointApiScopeTests.swift new file mode 100644 index 0000000..9da1efe --- /dev/null +++ b/Tests/CheckpointTests/CheckpointApiScopeTests.swift @@ -0,0 +1,183 @@ +// +// CheckpointApiScopeTests.swift +// +// +// Created by Adolfo Vera Blasco on 24/6/24. +// + +import Redis +import XCTest +import XCTVapor +@testable import Checkpoint + +final class CheckpointApiScoreTests: XCTestCase { + func testLeakingBucketWithScopeApiHeader() { + let app = Application(.testing) + defer { app.shutdown() } + + app.get("leaking-bucket") { request -> HTTPStatus in + return .ok + } + + let leakingBucketAlgorithm = LeakingBucket { + LeakingBucketConfiguration(bucketSize: 10, + removingRate: 5, + removingTimeInterval: .minutes(count: 1), + appliedTo: .header(key: "X-ApiKey"), + inside :.endpoint) + } storage: { + // Rate limit database in Redis + app.redis("rate").configuration = try? RedisConfiguration(hostname: "localhost", + port: 9090, + database: 0) + + return app.redis("rate") + } + + let checkpoint = Checkpoint(using: leakingBucketAlgorithm) + + app.middleware.use(checkpoint) + + var apiKeyHeader = HTTPHeaders() + apiKeyHeader.add(name: "X-ApiKey", value: "fitomad#1") + + (0...20).forEach { index in + try? app.test(.GET, "leaking-bucket", headers: apiKeyHeader, afterResponse: { testResponse in + app.logger.info("leaking-bucket \(index) = \(testResponse.status)") + if index < 10 { + XCTAssertEqual(testResponse.status, .ok) + } else { + XCTAssertEqual(testResponse.status, .tooManyRequests) + } + }) + } + } + + func testTokenBucketWithScopeApiHeader() throws { + let app = Application(.testing) + defer { app.shutdown() } + + app.get("token-bucket") { request -> HTTPStatus in + return .ok + } + + let tokenbucketAlgorithm = TokenBucket { + TokenBucketConfiguration(bucketSize: 10, + refillRate: 0, + refillTimeInterval: .seconds(count: 20), + appliedTo: .header(key: "X-ApiKey"), + inside: .endpoint) + } storage: { + // Rate limit database in Redis + app.redis("rate").configuration = try? RedisConfiguration(hostname: "localhost", + port: 9090, + database: 0) + + return app.redis("rate") + } logging: { + app.logger + } + + let checkpoint = Checkpoint(using: tokenbucketAlgorithm) + + app.middleware.use(checkpoint) + + var apiKeyHeader = HTTPHeaders() + apiKeyHeader.add(name: "X-ApiKey", value: "fitomad#2") + + (0...20).forEach { index in + try? app.test(.GET, "token-bucket", headers: apiKeyHeader, afterResponse: { testResponse in + app.logger.info("token-bucket \(index) = \(testResponse.status)") + if index < 10 { + XCTAssertEqual(testResponse.status, .ok) + } else { + XCTAssertEqual(testResponse.status, .tooManyRequests) + } + }) + } + } + + func testFixedWindowCounterScopeApiWithHeader() { + let app = Application(.testing) + defer { app.shutdown() } + + app.get("fixed-window-counter") { request -> HTTPStatus in + return .ok + } + + let fixedWindowAlgorithm = FixedWindowCounter { + FixedWindowCounterConfiguration(requestPerWindow: 10, + timeWindowDuration: .minutes(count: 2), + appliedTo: .header(key: "X-ApiKey"), + inside: .endpoint) + } storage: { + // Rate limit database in Redis + app.redis("rate").configuration = try? RedisConfiguration(hostname: "localhost", + port: 9090, + database: 0) + + return app.redis("rate") + } + + + let checkpoint = Checkpoint(using: fixedWindowAlgorithm) + + app.middleware.use(checkpoint) + + var apiKeyHeader = HTTPHeaders() + apiKeyHeader.add(name: "X-ApiKey", value: "fitomad#3") + + (0...20).forEach { index in + try? app.test(.GET, "fixed-window-counter", headers: apiKeyHeader, afterResponse: { testResponse in + app.logger.info("fixed-window-counter \(index) = \(testResponse.status)") + if index < 10 { + XCTAssertEqual(testResponse.status, .ok) + } else { + XCTAssertEqual(testResponse.status, .tooManyRequests) + } + }) + } + } + + func testSlidingWindowLogScopeApiWithHeader() { + let app = Application(.testing) + defer { app.shutdown() } + + app.get("sliding-window-log") { request -> HTTPStatus in + return .ok + } + + let slidingWindowLogAlgorith = SlidingWindowLog { + SlidingWindowLogConfiguration(requestPerWindow: 10, + windowDuration: .minutes(count: 2), + appliedTo: .header(key: "X-ApiKey"), + inside: .endpoint) + } storage: { + // Rate limit database in Redis + app.redis("rate").configuration = try? RedisConfiguration(hostname: "localhost", + port: 9090, + database: 0) + + return app.redis("rate") + } + + + let checkpoint = Checkpoint(using: slidingWindowLogAlgorith) + + app.middleware.use(checkpoint) + + var apiKeyHeader = HTTPHeaders() + apiKeyHeader.add(name: "X-ApiKey", value: "fitomad#4") + + (0...20).forEach { index in + try? app.test(.GET, "sliding-window-log", headers: apiKeyHeader, afterResponse: { testResponse in + app.logger.info("sliding-window-log \(index) = \(testResponse.status)") + if index < 10 { + XCTAssertEqual(testResponse.status, .ok) + } else { + XCTAssertEqual(testResponse.status, .tooManyRequests) + } + }) + } + } +} diff --git a/Tests/CheckpointTests/CheckpointScopeTests.swift b/Tests/CheckpointTests/CheckpointScopeTests.swift new file mode 100644 index 0000000..f73fbf4 --- /dev/null +++ b/Tests/CheckpointTests/CheckpointScopeTests.swift @@ -0,0 +1,167 @@ +// +// CheckpointScopeTests.swift +// +// +// Created by Adolfo Vera Blasco on 24/6/24. +// + +import Redis +import XCTest +import XCTVapor +@testable import Checkpoint + +final class CheckpointScopeTests: XCTestCase { + func testLeakingBucketWithScopeHeader() { + let app = Application(.testing) + defer { app.shutdown() } + + app.get("leaking-bucket") { request -> HTTPStatus in + return .ok + } + + let leakingBucketAlgorithm = LeakingBucket { + LeakingBucketConfiguration(bucketSize: 10, + removingRate: 5, + removingTimeInterval: .minutes(count: 1), + inside: .endpoint) + } storage: { + // Rate limit database in Redis + app.redis("rate").configuration = try? RedisConfiguration(hostname: "localhost", + port: 9090, + database: 0) + + return app.redis("rate") + } + + let checkpoint = Checkpoint(using: leakingBucketAlgorithm) + + app.middleware.use(checkpoint) + + (0...20).forEach { index in + try? app.test(.GET, "leaking-bucket", afterResponse: { testResponse in + app.logger.info("leaking-bucket \(index) = \(testResponse.status)") + if index < 10 { + XCTAssertEqual(testResponse.status, .ok) + } else { + XCTAssertEqual(testResponse.status, .tooManyRequests) + } + }) + } + } + + func testTokenBucketWithScopeHeader() throws { + let app = Application(.testing) + defer { app.shutdown() } + + app.get("token-bucket") { request -> HTTPStatus in + return .ok + } + + let tokenbucketAlgorithm = TokenBucket { + TokenBucketConfiguration(bucketSize: 10, + refillRate: 0, + refillTimeInterval: .seconds(count: 20), + inside: .endpoint) + } storage: { + // Rate limit database in Redis + app.redis("rate").configuration = try? RedisConfiguration(hostname: "localhost", + port: 9090, + database: 0) + + return app.redis("rate") + } logging: { + app.logger + } + + let checkpoint = Checkpoint(using: tokenbucketAlgorithm) + + app.middleware.use(checkpoint) + + (0...20).forEach { index in + try? app.test(.GET, "token-bucket", afterResponse: { testResponse in + app.logger.info("token-bucket \(index) = \(testResponse.status)") + if index < 10 { + XCTAssertEqual(testResponse.status, .ok) + } else { + XCTAssertEqual(testResponse.status, .tooManyRequests) + } + }) + } + } + + func testFixedWindowCounterWithScopeHeader() { + let app = Application(.testing) + defer { app.shutdown() } + + app.get("fixed-window-counter") { request -> HTTPStatus in + return .ok + } + + let fixedWindowAlgorithm = FixedWindowCounter { + FixedWindowCounterConfiguration(requestPerWindow: 10, + timeWindowDuration: .minutes(count: 2), + inside: .endpoint) + } storage: { + // Rate limit database in Redis + app.redis("rate").configuration = try? RedisConfiguration(hostname: "localhost", + port: 9090, + database: 0) + + return app.redis("rate") + } + + + let checkpoint = Checkpoint(using: fixedWindowAlgorithm) + + app.middleware.use(checkpoint) + + (0...20).forEach { index in + try? app.test(.GET, "fixed-window-counter", afterResponse: { testResponse in + app.logger.info("fixed-window-counter \(index) = \(testResponse.status)") + if index < 10 { + XCTAssertEqual(testResponse.status, .ok) + } else { + XCTAssertEqual(testResponse.status, .tooManyRequests) + } + }) + } + } + + func testSlidingWindowLogWithScopeHeader() { + let app = Application(.testing) + defer { app.shutdown() } + + app.get("sliding-window-log") { request -> HTTPStatus in + return .ok + } + + let slidingWindowLogAlgorith = SlidingWindowLog { + SlidingWindowLogConfiguration(requestPerWindow: 10, + windowDuration: .minutes(count: 2), + inside: .endpoint) + } storage: { + // Rate limit database in Redis + app.redis("rate").configuration = try? RedisConfiguration(hostname: "localhost", + port: 9090, + database: 0) + + return app.redis("rate") + } + + + let checkpoint = Checkpoint(using: slidingWindowLogAlgorith) + + app.middleware.use(checkpoint) + + (0...20).forEach { index in + try? app.test(.GET, "sliding-window-log", afterResponse: { testResponse in + app.logger.info("sliding-window-log \(index) = \(testResponse.status)") + if index < 10 { + XCTAssertEqual(testResponse.status, .ok) + } else { + XCTAssertEqual(testResponse.status, .tooManyRequests) + } + }) + } + } +} From 54a1757aa8c83f1ddf68af57fdf2ee334c4416fc Mon Sep 17 00:00:00 2001 From: fitomad Date: Mon, 24 Jun 2024 20:21:00 +0200 Subject: [PATCH 13/15] Downgrade to Swift 5.7 --- Package.swift | 7 +++++-- Sources/Checkpoint/Algorithms/FixedWindowCounter.swift | 2 +- Sources/Checkpoint/Algorithms/SlidingWindowLog.swift | 4 ++-- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/Package.swift b/Package.swift index f96e617..9a58ebe 100644 --- a/Package.swift +++ b/Package.swift @@ -1,4 +1,4 @@ -// swift-tools-version: 6.0 +// swift-tools-version: 5.7 // The swift-tools-version declares the minimum version of Swift required to build this package. import PackageDescription @@ -6,7 +6,10 @@ import PackageDescription let package = Package( name: "Checkpoint", platforms: [ - .macOS(.v13) + .macOS(.v10_15), + .iOS(.v13), + .tvOS(.v13), + .watchOS(.v6) ], products: [ // Products define the executables and libraries a package produces, making them visible to other packages. diff --git a/Sources/Checkpoint/Algorithms/FixedWindowCounter.swift b/Sources/Checkpoint/Algorithms/FixedWindowCounter.swift index 8dc9c31..bb6c838 100644 --- a/Sources/Checkpoint/Algorithms/FixedWindowCounter.swift +++ b/Sources/Checkpoint/Algorithms/FixedWindowCounter.swift @@ -49,7 +49,7 @@ extension FixedWindowCounter: WindowBasedAlgorithm { keys.insert(requestKey) let redisKey = RedisKey(requestKey) - let timestamp = Date.now.timeIntervalSince1970 + let timestamp = Date().timeIntervalSince1970 let requestCount = try await storage.rpush([ timestamp ], into: redisKey).get() diff --git a/Sources/Checkpoint/Algorithms/SlidingWindowLog.swift b/Sources/Checkpoint/Algorithms/SlidingWindowLog.swift index f4e9b8c..cb70cb8 100644 --- a/Sources/Checkpoint/Algorithms/SlidingWindowLog.swift +++ b/Sources/Checkpoint/Algorithms/SlidingWindowLog.swift @@ -33,8 +33,8 @@ extension SlidingWindowLog: Algorithm { logging?.info("๐Ÿ’ก ApiKey: \(apiKey)") let redisKey = RedisKey(apiKey) - let requestDate = Date.now - let outdatedRequestLimiteDate = Date.now.addingTimeInterval(-configuration.timeWindowDuration.inSeconds) + let requestDate = Date() + let outdatedRequestLimiteDate = Date().addingTimeInterval(-configuration.timeWindowDuration.inSeconds) // 1. Delete outdated request let topBound: Double = Double(outdatedRequestLimiteDate.timeIntervalSinceReferenceDate) From 90707dab27d092e64ff8ef19950c25e17ef25322 Mon Sep 17 00:00:00 2001 From: fitomad Date: Mon, 24 Jun 2024 20:34:33 +0200 Subject: [PATCH 14/15] Test for modified responses --- README.md | 19 +- .../CheckpointTests/CheckpointResponse.swift | 203 ++++++++++++++++++ 2 files changed, 217 insertions(+), 5 deletions(-) create mode 100644 Tests/CheckpointTests/CheckpointResponse.swift diff --git a/README.md b/README.md index 3f97815..07845a7 100644 --- a/README.md +++ b/README.md @@ -1,14 +1,14 @@ -# Checkpoint +# Checkpoint ๐Ÿ’ง A Rate-Limit middleware implementation for Vapor servers using Redis database. ```swift let tokenBucket = TokenBucket { - TokenBucketConfiguration(bucketSize: 5, + TokenBucketConfiguration(bucketSize: 25, refillRate: 5, - refillTimeInterval: .seconds(count: 30), + refillTimeInterval: .seconds(count: 45), appliedField: .header(key: "X-ApiKey"), - scope: .nonScope) + scope: .endpoint) } storage: { application.redis("rate") } logging: { @@ -18,6 +18,15 @@ let tokenBucket = TokenBucket { let checkpoint = Checkpoint(using: tokenBucket) +// Modify response HTTP header and body response when rate limit exceed +checkpoint.didFailWithTooManyRequest = { (request, response, metadata) in + metadata.headers = [ + "X-RateLimit" : "Failure for request \(request.id)." + ] + + metadata.reason = "Rate limit for your api key exceeded" +} + // ๐Ÿ’ง Vapor Middleware app.middleware.use(checkpoint) ``` @@ -32,4 +41,4 @@ app.middleware.use(checkpoint) ### Sliding Window Log - +## Modify server response diff --git a/Tests/CheckpointTests/CheckpointResponse.swift b/Tests/CheckpointTests/CheckpointResponse.swift new file mode 100644 index 0000000..051506b --- /dev/null +++ b/Tests/CheckpointTests/CheckpointResponse.swift @@ -0,0 +1,203 @@ +// +// CheckpointResponse.swift +// +// +// Created by Adolfo Vera Blasco on 24/6/24. +// + +import Redis +import XCTest +import XCTVapor +@testable import Checkpoint + +final class CheckpointResponse: XCTestCase { + func testLeakingBucketResponse() { + let app = Application(.testing) + defer { app.shutdown() } + + app.get("leaking-bucket") { request -> HTTPStatus in + return .ok + } + + let leakingBucketAlgorithm = LeakingBucket { + LeakingBucketConfiguration(bucketSize: 10, + removingRate: 5, + removingTimeInterval: .minutes(count: 1), + appliedTo: .header(key: "X-ApiKey"), + inside :.endpoint) + } storage: { + // Rate limit database in Redis + app.redis("rate").configuration = try? RedisConfiguration(hostname: "localhost", + port: 9090, + database: 0) + + return app.redis("rate") + } + + let checkpoint = Checkpoint(using: leakingBucketAlgorithm) + + checkpoint.didFailWithTooManyRequest = { (request, response, metadata) in + metadata.headers = [ + "X-RateLimit" : "Failure for request \(request.id)" + ] + } + + app.middleware.use(checkpoint) + + var apiKeyHeader = HTTPHeaders() + apiKeyHeader.add(name: "X-ApiKey", value: "fitomad#5") + + (0...20).forEach { index in + try? app.test(.GET, "leaking-bucket", headers: apiKeyHeader, afterResponse: { testResponse in + if index < 10 { + XCTAssertFalse(testResponse.headers.contains(name: "X-RateLimit")) + } else { + XCTAssertTrue(testResponse.headers.contains(name: "X-RateLimit")) + } + }) + } + } + + func testTokenBucketResponse() throws { + let app = Application(.testing) + defer { app.shutdown() } + + app.get("token-bucket") { request -> HTTPStatus in + return .ok + } + + let tokenbucketAlgorithm = TokenBucket { + TokenBucketConfiguration(bucketSize: 10, + refillRate: 0, + refillTimeInterval: .seconds(count: 20), + appliedTo: .header(key: "X-ApiKey"), + inside: .endpoint) + } storage: { + // Rate limit database in Redis + app.redis("rate").configuration = try? RedisConfiguration(hostname: "localhost", + port: 9090, + database: 0) + + return app.redis("rate") + } logging: { + app.logger + } + + let checkpoint = Checkpoint(using: tokenbucketAlgorithm) + + checkpoint.didFailWithTooManyRequest = { (request, response, metadata) in + metadata.headers = [ + "X-RateLimit" : "Failure for request \(request.id)" + ] + } + + app.middleware.use(checkpoint) + + var apiKeyHeader = HTTPHeaders() + apiKeyHeader.add(name: "X-ApiKey", value: "fitomad#6") + + (0...20).forEach { index in + try? app.test(.GET, "token-bucket", headers: apiKeyHeader, afterResponse: { testResponse in + if index < 10 { + XCTAssertFalse(testResponse.headers.contains(name: "X-RateLimit")) + } else { + XCTAssertTrue(testResponse.headers.contains(name: "X-RateLimit")) + } + }) + } + } + + func testFixedWindowCounterResponse() { + let app = Application(.testing) + defer { app.shutdown() } + + app.get("fixed-window-counter") { request -> HTTPStatus in + return .ok + } + + let fixedWindowAlgorithm = FixedWindowCounter { + FixedWindowCounterConfiguration(requestPerWindow: 10, + timeWindowDuration: .minutes(count: 2), + appliedTo: .header(key: "X-ApiKey"), + inside: .endpoint) + } storage: { + // Rate limit database in Redis + app.redis("rate").configuration = try? RedisConfiguration(hostname: "localhost", + port: 9090, + database: 0) + + return app.redis("rate") + } + + + let checkpoint = Checkpoint(using: fixedWindowAlgorithm) + + checkpoint.didFailWithTooManyRequest = { (request, response, metadata) in + metadata.headers = [ + "X-RateLimit" : "Failure for request \(request.id)" + ] + } + + app.middleware.use(checkpoint) + + var apiKeyHeader = HTTPHeaders() + apiKeyHeader.add(name: "X-ApiKey", value: "fitomad#7") + + (0...20).forEach { index in + try? app.test(.GET, "fixed-window-counter", headers: apiKeyHeader, afterResponse: { testResponse in + if index < 10 { + XCTAssertFalse(testResponse.headers.contains(name: "X-RateLimit")) + } else { + XCTAssertTrue(testResponse.headers.contains(name: "X-RateLimit")) + } + }) + } + } + + func testSlidingWindowLogResponse() { + let app = Application(.testing) + defer { app.shutdown() } + + app.get("sliding-window-log") { request -> HTTPStatus in + return .ok + } + + let slidingWindowLogAlgorith = SlidingWindowLog { + SlidingWindowLogConfiguration(requestPerWindow: 10, + windowDuration: .minutes(count: 2), + appliedTo: .header(key: "X-ApiKey"), + inside: .endpoint) + } storage: { + // Rate limit database in Redis + app.redis("rate").configuration = try? RedisConfiguration(hostname: "localhost", + port: 9090, + database: 0) + + return app.redis("rate") + } + + + let checkpoint = Checkpoint(using: slidingWindowLogAlgorith) + + checkpoint.didFailWithTooManyRequest = { (request, response, metadata) in + metadata.headers = [ + "X-RateLimit" : "Failure for request \(request.id)" + ] + } + + app.middleware.use(checkpoint) + + var apiKeyHeader = HTTPHeaders() + apiKeyHeader.add(name: "X-ApiKey", value: "fitomad#8") + + (0...20).forEach { index in + try? app.test(.GET, "sliding-window-log", headers: apiKeyHeader, afterResponse: { testResponse in + if index < 10 { + XCTAssertFalse(testResponse.headers.contains(name: "X-RateLimit")) + } else { + XCTAssertTrue(testResponse.headers.contains(name: "X-RateLimit")) + } + }) + } + } +} From 3eaa9fbccaf0136714ea2331fce570763acdb9d5 Mon Sep 17 00:00:00 2001 From: fitomad Date: Tue, 25 Jun 2024 22:24:45 +0200 Subject: [PATCH 15/15] Ready for version 0.1.0 --- README.md | 262 +++++++++++++++++- Sources/Checkpoint/Algorithms/Algorithm.swift | 8 + .../Algorithms/FixedWindowCounter.swift | 26 +- .../Checkpoint/Algorithms/LeakingBucket.swift | 5 - .../Algorithms/SlidingWindowLog.swift | 27 +- .../Checkpoint/Algorithms/TokenBucket.swift | 1 - .../Algorithms/WindowBasedAlgorithm.swift | 3 + Sources/Checkpoint/Checkpoint.swift | 15 +- .../Extensions/Checkpoint+Vapor.swift | 14 - Sources/Checkpoint/Model/Configuration.swift | 65 ----- .../FixedWindowCounterConfiguration.swift | 25 ++ .../Model/LeakingBucketConfiguration.swift | 27 ++ .../Model/SlidingWindowLogConfiguration.swift | 25 ++ .../Model/TokenBucketConfiguration.swift | 27 ++ .../CheckpointApiKeyTests.swift | 2 + .../CheckpointApiScopeTests.swift | 6 + 16 files changed, 428 insertions(+), 110 deletions(-) delete mode 100644 Sources/Checkpoint/Extensions/Checkpoint+Vapor.swift create mode 100644 Sources/Checkpoint/Model/FixedWindowCounterConfiguration.swift create mode 100644 Sources/Checkpoint/Model/LeakingBucketConfiguration.swift create mode 100644 Sources/Checkpoint/Model/SlidingWindowLogConfiguration.swift create mode 100644 Sources/Checkpoint/Model/TokenBucketConfiguration.swift diff --git a/README.md b/README.md index 07845a7..2872a5c 100644 --- a/README.md +++ b/README.md @@ -3,6 +3,8 @@ A Rate-Limit middleware implementation for Vapor servers using Redis database. ```swift +... + let tokenBucket = TokenBucket { TokenBucketConfiguration(bucketSize: 25, refillRate: 5, @@ -18,7 +20,7 @@ let tokenBucket = TokenBucket { let checkpoint = Checkpoint(using: tokenBucket) -// Modify response HTTP header and body response when rate limit exceed +// ๐Ÿšจ Modify response HTTP header and body response when rate limit exceed checkpoint.didFailWithTooManyRequest = { (request, response, metadata) in metadata.headers = [ "X-RateLimit" : "Failure for request \(request.id)." @@ -33,12 +35,268 @@ app.middleware.use(checkpoint) ## Supported algorythims -### Tocken Bucket +Currently **Checkpoint** supports 4 rate-limit algorithims. + +### Token Bucket + +The Token Bucket rate-limiting algorithm is a widely-used and flexible approach that controls the rate of requests to a service while allowing for some bursts of traffic. Hereโ€™s an explanation of how it works: + +The configuration for the Token Bucket is setted using the `TokenBucketConfiguration` type + +```swift +let tokenbucketAlgorithm = TokenBucket { + TokenBucketConfiguration(bucketSize: 10, + refillRate: 0, + refillTimeInterval: .seconds(count: 20), + appliedTo: .header(key: "X-ApiKey"), + inside: .endpoint) +} storage: { + // Rate limit database in Redis + app.redis("rate").configuration = try? RedisConfiguration(hostname: "localhost", + port: 9090, + database: 0) + + return app.redis("rate") +} logging: { + app.logger +} +``` + +How the Token Bucket Algorithm Works: + +1. Initialize the Bucket: + +- The bucket has a fixed capacity, which represents the maximum number of tokens it can hold. +- Tokens are added to the bucket at a fixed rate, up to the bucket's capacity. + +2. Handle Incoming Requests: + +- When a new request arrives, check if there are enough tokens in the bucket. +- If there is at least one token, allow the request and remove a token from the bucket. +- If there are no tokens available, deny the request (rate limit exceeded). + +3. Add Tokens: + +- Tokens are added to the bucket at a steady rate, which determines the average rate of allowed requests. +- The bucket never holds more than its fixed capacity of tokens. ### Leaking Bucket +The Leaking Bucket rate-limit algorithm is an effective approach to rate limiting that ensures a smooth, steady flow of requests. It works similarly to a physical bucket with a hole in it, where water (requests) drips out at a constant rate. Hereโ€™s a detailed explanation of how it works: + +The configuration for Leaking Bucket is the `LeakingBucketConfiguration` object + +```swift +let leakingBucketAlgorithm = LeakingBucket { + LeakingBucketConfiguration(bucketSize: 10, + removingRate: 5, + removingTimeInterval: .minutes(count: 1), + appliedTo: .header(key: "X-ApiKey"), + inside :.endpoint) +} storage: { + // Rate limit database in Redis + app.redis("rate").configuration = try? RedisConfiguration(hostname: "localhost", + port: 9090, + database: 0) + + return app.redis("rate") +} logging: { + app.logger +} +``` + +How the Leaking Bucket Algorithm Works: + +1. Initialize the Bucket: + +- The bucket has a fixed capacity, representing the maximum number of requests that can be stored in the bucket at any given time. +- The bucket leaks at a fixed rate, representing the maximum rate at which requests are processed. + +2. Handle Incoming Requests: + +- When a new request arrives, check the current level of the bucket. +- If the bucket is not full (i.e., the number of stored requests is less than the bucket's capacity), add the request to the bucket. +- If the bucket is full, deny the request (rate limit exceeded). + +3. Process Requests: + +- Requests in the bucket are processed (leaked) at a constant rate. +- This ensures a steady flow of requests, preventing sudden bursts. + ### Fixed Window Counter +The Fixed Window Counter rate-limit algorithm is a straightforward and easy-to-implement approach for rate limiting, used to control the number of requests a client can make to a service within a specified time period. Hereโ€™s an explanation of how it works: + +To set the configuration you must use the `FixedWindowCounterConfiguration` type + +```swift +let fixedWindowAlgorithm = FixedWindowCounter { + FixedWindowCounterConfiguration(requestPerWindow: 10, + timeWindowDuration: .minutes(count: 2), + appliedTo: .header(key: "X-ApiKey"), + inside: .endpoint) +} storage: { + // Rate limit database in Redis + app.redis("rate").configuration = try? RedisConfiguration(hostname: "localhost", + port: 9090, + database: 0) + + return app.redis("rate") +} logging: { + app.logger +} +``` + +How the Fixed Window Counter Algorithm Works: + +1. Define a Time Window +Choose a fixed duration (e.g., 1 minute, 1 hour) which will serve as the time window for counting requests. + +2. Initialize a Counter: +Maintain a counter for each client (or each resource being accessed) that tracks the number of requests made within the current time window. + +3. Track Request Timestamps: +Each time a request is made, check the current timestamp and determine which time window it falls into. +Increment the Counter: + +- If the request falls within the current window, increment the counter. +- If the request falls outside the current window, reset the counter and start a new window. + +4. Enforce Limits: + +- If the counter exceeds the predefined limit within the current window, the request is denied (or throttled). +- If the counter is within the limit, the request is allowed. + + + ```swift + + ``` + ### Sliding Window Log +The Sliding Window Log rate-limit algorithm is a more refined approach to rate limiting compared to the Fixed Window Counter. It offers smoother control over request rates by maintaining a log of individual request timestamps, allowing for a more granular and accurate rate-limiting mechanism. Hereโ€™s a detailed explanation of how it works: + +To set the configuration for this rate-limit algorithim use the `` type + +```swift +let slidingWindowLogAlgorith = SlidingWindowLog { + SlidingWindowLogConfiguration(requestPerWindow: 10, + windowDuration: .minutes(count: 2), + appliedTo: .header(key: "X-ApiKey"), + inside: .endpoint) +} storage: { + // Rate limit database in Redis + app.redis("rate").configuration = try? RedisConfiguration(hostname: "localhost", + port: 9090, + database: 0) + + return app.redis("rate") +} logging: { + app.logger +} +``` + +How the Sliding Window Log Algorithm Works: + +1. Define a Time Window: +Choose a time window duration (e.g., 1 minute) within which you want to limit the number of requests. + +2. Log Requests: +Maintain a log (typically a list or queue) for each client that stores the timestamps of each request. + +3. Handle Incoming Requests: +When a new request arrives, do the following: + +- Remove timestamps from the log that fall outside the current time window. +- Check the number of timestamps remaining in the log. +- If the number of requests (timestamps) within the window is below the limit, add the new requestโ€™s timestamp to the log and allow the request. +- If the number of requests meets or exceeds the limit, deny the request. + + ## Modify server response + +Sometimes we need to modify the response sent to the client by adding a custom HTTP header or setting a failure reason text in the JSON payload. + +In that case, you can use one of the closures defined in the `Checkpoint` class, one per Rate-Limit processing stage. + +### Before performing Rate-Limit checking + +This closure is invoked just before the Checkpoint middleware checking operation for a given request will be performed, and receive a Request object as a parameter. + +```swift +public var willCheck: CheckpointAction? +``` + +### After perform Rate-Limit checking + +If Rate-Limit checking goes well, this closure is invoked, and you know that the Request continues to be processed by the Vapor server. + +```swift +public var willCheck: CheckpointAction? +``` + +### Rate-Limit reached +It's sure you want to know when a request reaches the rate limit you set when initializing Checkpoint. + +In this case, Checkpoint will notify a rate-limit reached using the didFailWithTooManyRequest closure. + +```swift +public var didFailWithTooManyRequest: CheckpointErrorAction? +``` + +This closure contains 3 parameter + +- `requests`. It's a [`Request`](https://api.vapor.codes/vapor/documentation/vapor/request) object type representing the user request that reaches the limit. +- `response`. It's the server response ([`Response`](https://api.vapor.codes/vapor/documentation/vapor/response) type) returned by Vapor. +- `metadata`. It's an object designed to set custom HTTP headers and a reason text that will be attached to the object payload returned by the response. + +For example, if you want to add a custom HTTP header and a reason text to inform a user that he reaches the limit you will do something like this + +```swift +// ๐Ÿ‘ฎโ€โ™€๏ธ Modify response HTTP header and body response when rate limit exceed +checkpoint.didFailWithTooManyRequest = { (request, response, metadata) in + metadata.headers = [ + "X-RateLimit" : "Failure for request \(request.id)." + ] + + metadata.reason = "Rate limit for your api key exceeded" +} +``` + +### Error throwed while process a request + +If an error different from an HTTP 429 code (rate-limit) comes from Checkpoint, you will be reported in the following closure + +```swift +// ๐Ÿšจ Modify response HTTP header and body response when error occurs +checkpoint.didFail = { (request, response, abort, metadata) in + metadata.headers = [ + "X-ApiError" : "Error for request \(request.id)." + ] + + metadata.reason = "Error code \(abort.status) for your api key exceeded" +} +``` + +The parameters used in this closure are the same as the ones received in the closure, you can add a custom HTTP header and/or a reason message. + +## Redis + +To work with Checkpoint you must install and configure a Redis database in your system. Thanks to Docker it's really easy to deploy a Redis installation. + +We recommend to install the [**redis-stack-server**](https://hub.docker.com/r/redis/redis-stack-server) image from the Docker Hub. + +## History + +### 0.1.0 + +Alpha version, a *Friends & Family* release ๐Ÿ˜œ + +- Support for Redis Database +- Logging system based on the Vapor `Logger` type +- Four rate-limit algorithims support + - Fixed Window Counter + - Leaking Bucket + - Sliding Window Log + - Token Bucket + diff --git a/Sources/Checkpoint/Algorithms/Algorithm.swift b/Sources/Checkpoint/Algorithms/Algorithm.swift index 6ea4184..80ac9b8 100644 --- a/Sources/Checkpoint/Algorithms/Algorithm.swift +++ b/Sources/Checkpoint/Algorithms/Algorithm.swift @@ -12,14 +12,22 @@ import Vapor public typealias StorageAction = () -> Application.Redis public typealias LoggerAction = () -> Logger +/// Definition for the different Rate-Limit algorithims public protocol Algorithm: Sendable { + /// The configuration type used in a specific algorithim associatedtype ConfigurationType + /// The Redis database used to store the request data var storage: Application.Redis { get } + /// A `Logger` object created on Vapor var logging: Logger? { get } + /// Create a new Rate-Limit algorithim with a given configuration, + /// storage and logging init(configuration: () -> ConfigurationType, storage: StorageAction, logging: LoggerAction?) + /// Performs the algorithim logic to check if a request is valid + /// or reach the rate-limit specified on the algorithim's configuration func checkRequest(_ request: Request) async throws } diff --git a/Sources/Checkpoint/Algorithms/FixedWindowCounter.swift b/Sources/Checkpoint/Algorithms/FixedWindowCounter.swift index bb6c838..bd2cd00 100644 --- a/Sources/Checkpoint/Algorithms/FixedWindowCounter.swift +++ b/Sources/Checkpoint/Algorithms/FixedWindowCounter.swift @@ -10,22 +10,30 @@ import Redis import Vapor /** - Algorithm can be described as follows: - - 1. Timeline is divided into fixed time windows. - 2. Each time window has a counter. - 3. When a request comes in, the counter for the current time window is incremented. - 4. If the counter is greater than the rate limit, the request is rejected. - 5. If the counter is less than the rate limit, the request is accepted. + Fixed Window Counter algorithm presents the workflow described as follows: + + 1. Define a time window has a counter where the store the number of requets for a given time window. + 3. When a user makes a request, the counter for the current time window is incremented by 1. + 4. If the counter is greater than the rate limit, the request is rejected and whe send an HTTP 429 code status. + 5. If the counter is less than the rate limit, the request is accepted. */ public final class FixedWindowCounter { + // Configuration for this rate-limit algorithm private let configuration: FixedWindowCounterConfiguration + // The Redis database where we store the request information public let storage: Application.Redis + // A logger set during Vapor initialization public let logging: Logger? + // The Combine Timer publishers private var cancellable: AnyCancellable? + // Keys stored in a given time window private var keys = Set() + /** + + + */ public init(configuration: () -> FixedWindowCounterConfiguration, storage: StorageAction, logging: LoggerAction? = nil) { self.configuration = configuration() self.storage = storage() @@ -35,12 +43,16 @@ public final class FixedWindowCounter { performing: resetWindow) } + /** + + */ deinit { cancellable?.cancel() } } extension FixedWindowCounter: WindowBasedAlgorithm { + /// public func checkRequest(_ request: Request) async throws { guard let requestKey = try? valueFor(field: configuration.appliedField, in: request, inside: configuration.scope) else { return diff --git a/Sources/Checkpoint/Algorithms/LeakingBucket.swift b/Sources/Checkpoint/Algorithms/LeakingBucket.swift index e6e5d3b..f831761 100644 --- a/Sources/Checkpoint/Algorithms/LeakingBucket.swift +++ b/Sources/Checkpoint/Algorithms/LeakingBucket.swift @@ -48,10 +48,6 @@ public final class LeakingBucket { } extension LeakingBucket: WindowBasedAlgorithm { - var isValidRequest: Bool { - return true - } - public func checkRequest(_ request: Request) async throws { guard let requestKey = try? valueFor(field: configuration.appliedField, in: request, inside: configuration.scope) else { return @@ -68,7 +64,6 @@ extension LeakingBucket: WindowBasedAlgorithm { // 1. New request, remove one token from the bucket let bucketItemsCount = try await storage.increment(redisKey).get() - logging?.info("โŒš๏ธ \(requestKey) = \(bucketItemsCount)") // 2. If buckes is empty, throw an error if bucketItemsCount > configuration.bucketSize { throw Abort(.tooManyRequests) diff --git a/Sources/Checkpoint/Algorithms/SlidingWindowLog.swift b/Sources/Checkpoint/Algorithms/SlidingWindowLog.swift index cb70cb8..566a22f 100644 --- a/Sources/Checkpoint/Algorithms/SlidingWindowLog.swift +++ b/Sources/Checkpoint/Algorithms/SlidingWindowLog.swift @@ -8,11 +8,28 @@ import Redis import Vapor +/// The Sliding Window Log rate-limit algorithim is based on the request count perfomed during a non fixed window time. +/// It works following this workflow: +/// +/// 1. When new request comes in remove all outdated timestamps from cache. By outdated we mean timestamps that are older than window size. +/// 2. Add new timestamp to cache. +/// 3. If number of timestamps in cache is greater than limit reject request and return 429 status code. +/// 4. If lower than limit then accept request and return 200 status code. + public final class SlidingWindowLog { + /// Configuration for the Sliding Window Log private let configuration: SlidingWindowLogConfiguration + /// Redis database where we store the request timestamps public let storage: Application.Redis + /// A Vapor logger object public let logging: Logger? + /// Create a new Sliging Window Log with a given configuration, storage and a logger + /// + /// - Parameters: + /// - configuration: A `SlidingWindowLogConfiguration` object + /// - storage: The Redis database instance created on Vapor + /// - logging: A `Logger` object created on Vapor. public init(configuration: () -> SlidingWindowLogConfiguration, storage: StorageAction, logging: LoggerAction? = nil) { self.configuration = configuration() self.storage = storage() @@ -21,16 +38,11 @@ public final class SlidingWindowLog { } extension SlidingWindowLog: Algorithm { - var isValidRequest: Bool { - return true - } - public func checkRequest(_ request: Request) async throws { guard let apiKey = try? valueFor(field: configuration.appliedField, in: request, inside: configuration.scope) else { throw Abort(.unauthorized, reason: Checkpoint.HTTPErrorDescription.unauthorized) } - logging?.info("๐Ÿ’ก ApiKey: \(apiKey)") let redisKey = RedisKey(apiKey) let requestDate = Date() @@ -39,17 +51,14 @@ extension SlidingWindowLog: Algorithm { // 1. Delete outdated request let topBound: Double = Double(outdatedRequestLimiteDate.timeIntervalSinceReferenceDate) let deletedEntriesCount = try await storage.zremrangebyscore(from: redisKey, withMaximumScoreOf: RedisZScoreBound(floatLiteral: topBound)).get() - logging?.info("๐Ÿ†‘ Deleted \(deletedEntriesCount) items for api key \"\(apiKey)\"") // 2. Add the current request let requestTimeInterval = Double(requestDate.timeIntervalSinceReferenceDate) try await storage.zadd([ (element: requestTimeInterval, score: requestTimeInterval) ], to: redisKey).get() - logging?.info("๐Ÿ’ก Element: \(requestTimeInterval) Score: \(requestTimeInterval)") // 3. Get the number of request for this time window let itemsCount = try await storage.zcount(of: redisKey, withScores: 0.0...requestTimeInterval).get() - logging?.info("๐Ÿ’ก Current items: \(itemsCount)") - + // 4. If request count is greater... if itemsCount > configuration.requestPerWindow { throw Abort(.tooManyRequests) diff --git a/Sources/Checkpoint/Algorithms/TokenBucket.swift b/Sources/Checkpoint/Algorithms/TokenBucket.swift index 00f354a..3aef7c9 100644 --- a/Sources/Checkpoint/Algorithms/TokenBucket.swift +++ b/Sources/Checkpoint/Algorithms/TokenBucket.swift @@ -65,7 +65,6 @@ extension TokenBucket: WindowBasedAlgorithm { // 1. New request, remove one token from the bucket let bucketItemsCount = try await storage.decrement(redisKey).get() - logging?.info("โŒš๏ธ \(requestKey) = \(bucketItemsCount)") // 2. If buckes is empty, throw an error if bucketItemsCount < 0 { throw Abort(.tooManyRequests) diff --git a/Sources/Checkpoint/Algorithms/WindowBasedAlgorithm.swift b/Sources/Checkpoint/Algorithms/WindowBasedAlgorithm.swift index fc44e39..7561532 100644 --- a/Sources/Checkpoint/Algorithms/WindowBasedAlgorithm.swift +++ b/Sources/Checkpoint/Algorithms/WindowBasedAlgorithm.swift @@ -10,8 +10,11 @@ import Foundation public typealias WindowBasedAction = () throws -> Void +/// For those algorithims thar works with fixed time windows. public protocol WindowBasedAlgorithm: Algorithm { + /// Start the timer for a given duration (time window) func startWindow(havingDuration seconds: Double, performing action: @escaping WindowBasedAction) -> AnyCancellable + /// Perfomrs the reset operation when the time windo ends. func resetWindow() async throws } diff --git a/Sources/Checkpoint/Checkpoint.swift b/Sources/Checkpoint/Checkpoint.swift index 949919d..4d70e21 100644 --- a/Sources/Checkpoint/Checkpoint.swift +++ b/Sources/Checkpoint/Checkpoint.swift @@ -8,16 +8,17 @@ import Redis import Vapor -public typealias CheckpointAction = (Request) -> Void -public typealias CheckpointErrorAction = (Request, Response, Checkpoint.ErrorMetadata) -> Void +public typealias CheckpointHandler = (Request) -> Void +public typealias CheckpointRateLimitHandler = (Request, Response, Checkpoint.ErrorMetadata) -> Void +public typealias CheckpointErrorHandler = (Request, Response, AbortError, Checkpoint.ErrorMetadata) -> Void public final class Checkpoint { private let algorithm: any Algorithm - public var willCheck: CheckpointAction? - public var didCheck: CheckpointAction? - public var didFailWithTooManyRequest: CheckpointErrorAction? - public var didFail: CheckpointErrorAction? + public var willCheck: CheckpointHandler? + public var didCheck: CheckpointHandler? + public var didFailWithTooManyRequest: CheckpointRateLimitHandler? + public var didFail: CheckpointErrorHandler? public init(using algorithm: some Algorithm) { self.algorithm = algorithm @@ -43,7 +44,7 @@ extension Checkpoint: AsyncMiddleware { headers: errorMetadata.httpHeaders, reason: errorMetadata.reason) default: - didFail?(request, response, errorMetadata) + didFail?(request, response, abort, errorMetadata) throw Abort(.badRequest, headers: errorMetadata.httpHeaders, diff --git a/Sources/Checkpoint/Extensions/Checkpoint+Vapor.swift b/Sources/Checkpoint/Extensions/Checkpoint+Vapor.swift deleted file mode 100644 index 75660fd..0000000 --- a/Sources/Checkpoint/Extensions/Checkpoint+Vapor.swift +++ /dev/null @@ -1,14 +0,0 @@ -// -// Checkpoint+Vapor.swift -// -// -// Created by Adolfo Vera Blasco on 18/6/24. -// - -import Vapor - -extension Request { - func checkpoint() { - - } -} diff --git a/Sources/Checkpoint/Model/Configuration.swift b/Sources/Checkpoint/Model/Configuration.swift index 9a027a9..57ba452 100644 --- a/Sources/Checkpoint/Model/Configuration.swift +++ b/Sources/Checkpoint/Model/Configuration.swift @@ -11,75 +11,10 @@ public protocol Configuration { var scope: Scope { get } } -public struct FixedWindowCounterConfiguration: Configuration { - public private(set) var requestPerWindow: Int - public private(set) var timeWindowDuration: TimeWindow - - public private(set) var appliedField: Field - public private(set) var scope: Scope -} -extension FixedWindowCounterConfiguration { - public init(requestPerWindow: Int, timeWindowDuration: TimeWindow, appliedTo field: Field = .noField, inside scope: Scope = .noScope) { - self.requestPerWindow = requestPerWindow - self.timeWindowDuration = timeWindowDuration - self.appliedField = field - self.scope = scope - } -} -public struct LeakingBucketConfiguration: Configuration { - public var bucketSize = 10 - public var tokenRemovingRate = 5 - public var timeWindowDuration: TimeWindow = .seconds(count: 10) - - public var appliedField: Field - public var scope: Scope -} -extension LeakingBucketConfiguration { - public init(bucketSize: Int, removingRate: Int, removingTimeInterval: TimeWindow, appliedTo field: Field = .noField, inside scope: Scope = .noScope) { - self.bucketSize = bucketSize - self.tokenRemovingRate = removingRate - self.timeWindowDuration = removingTimeInterval - self.appliedField = field - self.scope = scope - } -} -public struct SlidingWindowLogConfiguration: Configuration { - public var requestPerWindow = 10 - public var timeWindowDuration: TimeWindow - - public var appliedField: Field - public var scope: Scope -} -extension SlidingWindowLogConfiguration { - public init(requestPerWindow: Int, windowDuration: TimeWindow, appliedTo field: Field = .noField, inside scope: Scope = .noScope) { - self.requestPerWindow = requestPerWindow - self.timeWindowDuration = windowDuration - self.appliedField = field - self.scope = scope - } -} - -public struct TokenBucketConfiguration: Configuration { - public private(set) var bucketSize: Int - public private(set) var refillTokenRate: Int - public private(set) var refillTimeInterval: TimeWindow - - public private(set) var appliedField: Field - public private(set) var scope: Scope -} -extension TokenBucketConfiguration { - public init(bucketSize: Int, refillRate: Int, refillTimeInterval: TimeWindow, appliedTo field: Field = .noField, inside scope: Scope = .noScope) { - self.bucketSize = bucketSize - self.refillTokenRate = refillRate - self.refillTimeInterval = refillTimeInterval - self.appliedField = field - self.scope = scope - } -} diff --git a/Sources/Checkpoint/Model/FixedWindowCounterConfiguration.swift b/Sources/Checkpoint/Model/FixedWindowCounterConfiguration.swift new file mode 100644 index 0000000..c1ddb14 --- /dev/null +++ b/Sources/Checkpoint/Model/FixedWindowCounterConfiguration.swift @@ -0,0 +1,25 @@ +// +// FixedWindowCounterConfiguration.swift +// +// +// Created by Adolfo Vera Blasco on 25/6/24. +// + +import Foundation + +public struct FixedWindowCounterConfiguration: Configuration { + public private(set) var requestPerWindow: Int + public private(set) var timeWindowDuration: TimeWindow + + public private(set) var appliedField: Field + public private(set) var scope: Scope +} + +extension FixedWindowCounterConfiguration { + public init(requestPerWindow: Int, timeWindowDuration: TimeWindow, appliedTo field: Field = .noField, inside scope: Scope = .noScope) { + self.requestPerWindow = requestPerWindow + self.timeWindowDuration = timeWindowDuration + self.appliedField = field + self.scope = scope + } +} diff --git a/Sources/Checkpoint/Model/LeakingBucketConfiguration.swift b/Sources/Checkpoint/Model/LeakingBucketConfiguration.swift new file mode 100644 index 0000000..12cde2b --- /dev/null +++ b/Sources/Checkpoint/Model/LeakingBucketConfiguration.swift @@ -0,0 +1,27 @@ +// +// LeakingBucketConfiguration.swift +// +// +// Created by Adolfo Vera Blasco on 25/6/24. +// + +import Foundation + +public struct LeakingBucketConfiguration: Configuration { + public var bucketSize: Int + public var tokenRemovingRate: Int + public var timeWindowDuration: TimeWindow + + public var appliedField: Field + public var scope: Scope +} + +extension LeakingBucketConfiguration { + public init(bucketSize: Int, removingRate: Int, removingTimeInterval: TimeWindow, appliedTo field: Field = .noField, inside scope: Scope = .noScope) { + self.bucketSize = bucketSize + self.tokenRemovingRate = removingRate + self.timeWindowDuration = removingTimeInterval + self.appliedField = field + self.scope = scope + } +} diff --git a/Sources/Checkpoint/Model/SlidingWindowLogConfiguration.swift b/Sources/Checkpoint/Model/SlidingWindowLogConfiguration.swift new file mode 100644 index 0000000..639b16a --- /dev/null +++ b/Sources/Checkpoint/Model/SlidingWindowLogConfiguration.swift @@ -0,0 +1,25 @@ +// +// SlidingWindowLogConfiguration.swift +// +// +// Created by Adolfo Vera Blasco on 25/6/24. +// + +import Foundation + +public struct SlidingWindowLogConfiguration: Configuration { + public var requestPerWindow: Int + public var timeWindowDuration: TimeWindow + + public var appliedField: Field + public var scope: Scope +} + +extension SlidingWindowLogConfiguration { + public init(requestPerWindow: Int, windowDuration: TimeWindow, appliedTo field: Field = .noField, inside scope: Scope = .noScope) { + self.requestPerWindow = requestPerWindow + self.timeWindowDuration = windowDuration + self.appliedField = field + self.scope = scope + } +} diff --git a/Sources/Checkpoint/Model/TokenBucketConfiguration.swift b/Sources/Checkpoint/Model/TokenBucketConfiguration.swift new file mode 100644 index 0000000..6a3d371 --- /dev/null +++ b/Sources/Checkpoint/Model/TokenBucketConfiguration.swift @@ -0,0 +1,27 @@ +// +// TokenBucketConfiguration.swift +// +// +// Created by Adolfo Vera Blasco on 25/6/24. +// + +import Foundation + +public struct TokenBucketConfiguration: Configuration { + public private(set) var bucketSize: Int + public private(set) var refillTokenRate: Int + public private(set) var refillTimeInterval: TimeWindow + + public private(set) var appliedField: Field + public private(set) var scope: Scope +} + +extension TokenBucketConfiguration { + public init(bucketSize: Int, refillRate: Int, refillTimeInterval: TimeWindow, appliedTo field: Field = .noField, inside scope: Scope = .noScope) { + self.bucketSize = bucketSize + self.refillTokenRate = refillRate + self.refillTimeInterval = refillTimeInterval + self.appliedField = field + self.scope = scope + } +} diff --git a/Tests/CheckpointTests/CheckpointApiKeyTests.swift b/Tests/CheckpointTests/CheckpointApiKeyTests.swift index 725917d..a373856 100644 --- a/Tests/CheckpointTests/CheckpointApiKeyTests.swift +++ b/Tests/CheckpointTests/CheckpointApiKeyTests.swift @@ -31,6 +31,8 @@ final class CheckpointApiKeyTests: XCTestCase { database: 0) return app.redis("rate") + } logging: { + app.logger } let checkpoint = Checkpoint(using: leakingBucketAlgorithm) diff --git a/Tests/CheckpointTests/CheckpointApiScopeTests.swift b/Tests/CheckpointTests/CheckpointApiScopeTests.swift index 9da1efe..2ea24c3 100644 --- a/Tests/CheckpointTests/CheckpointApiScopeTests.swift +++ b/Tests/CheckpointTests/CheckpointApiScopeTests.swift @@ -32,6 +32,8 @@ final class CheckpointApiScoreTests: XCTestCase { database: 0) return app.redis("rate") + } logging: { + app.logger } let checkpoint = Checkpoint(using: leakingBucketAlgorithm) @@ -117,6 +119,8 @@ final class CheckpointApiScoreTests: XCTestCase { database: 0) return app.redis("rate") + } logging: { + app.logger } @@ -159,6 +163,8 @@ final class CheckpointApiScoreTests: XCTestCase { database: 0) return app.redis("rate") + } logging: { + app.logger }