diff --git a/Package.swift b/Package.swift index 392f715..9a58ebe 100644 --- a/Package.swift +++ b/Package.swift @@ -1,4 +1,4 @@ -// swift-tools-version: 6.0 +// swift-tools-version: 5.7 // The swift-tools-version declares the minimum version of Swift required to build this package. import PackageDescription @@ -6,7 +6,10 @@ import PackageDescription let package = Package( name: "Checkpoint", platforms: [ - .macOS(.v13) + .macOS(.v10_15), + .iOS(.v13), + .tvOS(.v13), + .watchOS(.v6) ], products: [ // Products define the executables and libraries a package produces, making them visible to other packages. @@ -32,7 +35,10 @@ let package = Package( ), .testTarget( name: "CheckpointTests", - dependencies: ["Checkpoint"] + dependencies: [ + "Checkpoint", + .product(name: "XCTVapor", package: "vapor") + ] ), ] ) diff --git a/README.md b/README.md index 3f97815..2872a5c 100644 --- a/README.md +++ b/README.md @@ -1,14 +1,16 @@ -# Checkpoint +# Checkpoint 💧 A Rate-Limit middleware implementation for Vapor servers using Redis database. ```swift +... + let tokenBucket = TokenBucket { - TokenBucketConfiguration(bucketSize: 5, + TokenBucketConfiguration(bucketSize: 25, refillRate: 5, - refillTimeInterval: .seconds(count: 30), + refillTimeInterval: .seconds(count: 45), appliedField: .header(key: "X-ApiKey"), - scope: .nonScope) + scope: .endpoint) } storage: { application.redis("rate") } logging: { @@ -18,18 +20,283 @@ let tokenBucket = TokenBucket { let checkpoint = Checkpoint(using: tokenBucket) +// 🚨 Modify response HTTP header and body response when rate limit exceed +checkpoint.didFailWithTooManyRequest = { (request, response, metadata) in + metadata.headers = [ + "X-RateLimit" : "Failure for request \(request.id)." + ] + + metadata.reason = "Rate limit for your api key exceeded" +} + // 💧 Vapor Middleware app.middleware.use(checkpoint) ``` ## Supported algorythims -### Tocken Bucket +Currently **Checkpoint** supports 4 rate-limit algorithims. + +### Token Bucket + +The Token Bucket rate-limiting algorithm is a widely-used and flexible approach that controls the rate of requests to a service while allowing for some bursts of traffic. Here’s an explanation of how it works: + +The configuration for the Token Bucket is setted using the `TokenBucketConfiguration` type + +```swift +let tokenbucketAlgorithm = TokenBucket { + TokenBucketConfiguration(bucketSize: 10, + refillRate: 0, + refillTimeInterval: .seconds(count: 20), + appliedTo: .header(key: "X-ApiKey"), + inside: .endpoint) +} storage: { + // Rate limit database in Redis + app.redis("rate").configuration = try? RedisConfiguration(hostname: "localhost", + port: 9090, + database: 0) + + return app.redis("rate") +} logging: { + app.logger +} +``` + +How the Token Bucket Algorithm Works: + +1. Initialize the Bucket: + +- The bucket has a fixed capacity, which represents the maximum number of tokens it can hold. +- Tokens are added to the bucket at a fixed rate, up to the bucket's capacity. + +2. Handle Incoming Requests: + +- When a new request arrives, check if there are enough tokens in the bucket. +- If there is at least one token, allow the request and remove a token from the bucket. +- If there are no tokens available, deny the request (rate limit exceeded). + +3. Add Tokens: + +- Tokens are added to the bucket at a steady rate, which determines the average rate of allowed requests. +- The bucket never holds more than its fixed capacity of tokens. ### Leaking Bucket +The Leaking Bucket rate-limit algorithm is an effective approach to rate limiting that ensures a smooth, steady flow of requests. It works similarly to a physical bucket with a hole in it, where water (requests) drips out at a constant rate. Here’s a detailed explanation of how it works: + +The configuration for Leaking Bucket is the `LeakingBucketConfiguration` object + +```swift +let leakingBucketAlgorithm = LeakingBucket { + LeakingBucketConfiguration(bucketSize: 10, + removingRate: 5, + removingTimeInterval: .minutes(count: 1), + appliedTo: .header(key: "X-ApiKey"), + inside :.endpoint) +} storage: { + // Rate limit database in Redis + app.redis("rate").configuration = try? RedisConfiguration(hostname: "localhost", + port: 9090, + database: 0) + + return app.redis("rate") +} logging: { + app.logger +} +``` + +How the Leaking Bucket Algorithm Works: + +1. Initialize the Bucket: + +- The bucket has a fixed capacity, representing the maximum number of requests that can be stored in the bucket at any given time. +- The bucket leaks at a fixed rate, representing the maximum rate at which requests are processed. + +2. Handle Incoming Requests: + +- When a new request arrives, check the current level of the bucket. +- If the bucket is not full (i.e., the number of stored requests is less than the bucket's capacity), add the request to the bucket. +- If the bucket is full, deny the request (rate limit exceeded). + +3. Process Requests: + +- Requests in the bucket are processed (leaked) at a constant rate. +- This ensures a steady flow of requests, preventing sudden bursts. + ### Fixed Window Counter +The Fixed Window Counter rate-limit algorithm is a straightforward and easy-to-implement approach for rate limiting, used to control the number of requests a client can make to a service within a specified time period. Here’s an explanation of how it works: + +To set the configuration you must use the `FixedWindowCounterConfiguration` type + +```swift +let fixedWindowAlgorithm = FixedWindowCounter { + FixedWindowCounterConfiguration(requestPerWindow: 10, + timeWindowDuration: .minutes(count: 2), + appliedTo: .header(key: "X-ApiKey"), + inside: .endpoint) +} storage: { + // Rate limit database in Redis + app.redis("rate").configuration = try? RedisConfiguration(hostname: "localhost", + port: 9090, + database: 0) + + return app.redis("rate") +} logging: { + app.logger +} +``` + +How the Fixed Window Counter Algorithm Works: + +1. Define a Time Window +Choose a fixed duration (e.g., 1 minute, 1 hour) which will serve as the time window for counting requests. + +2. Initialize a Counter: +Maintain a counter for each client (or each resource being accessed) that tracks the number of requests made within the current time window. + +3. Track Request Timestamps: +Each time a request is made, check the current timestamp and determine which time window it falls into. +Increment the Counter: + +- If the request falls within the current window, increment the counter. +- If the request falls outside the current window, reset the counter and start a new window. + +4. Enforce Limits: + +- If the counter exceeds the predefined limit within the current window, the request is denied (or throttled). +- If the counter is within the limit, the request is allowed. + + + ```swift + + ``` + ### Sliding Window Log +The Sliding Window Log rate-limit algorithm is a more refined approach to rate limiting compared to the Fixed Window Counter. It offers smoother control over request rates by maintaining a log of individual request timestamps, allowing for a more granular and accurate rate-limiting mechanism. Here’s a detailed explanation of how it works: + +To set the configuration for this rate-limit algorithim use the `` type + +```swift +let slidingWindowLogAlgorith = SlidingWindowLog { + SlidingWindowLogConfiguration(requestPerWindow: 10, + windowDuration: .minutes(count: 2), + appliedTo: .header(key: "X-ApiKey"), + inside: .endpoint) +} storage: { + // Rate limit database in Redis + app.redis("rate").configuration = try? RedisConfiguration(hostname: "localhost", + port: 9090, + database: 0) + + return app.redis("rate") +} logging: { + app.logger +} +``` + +How the Sliding Window Log Algorithm Works: + +1. Define a Time Window: +Choose a time window duration (e.g., 1 minute) within which you want to limit the number of requests. + +2. Log Requests: +Maintain a log (typically a list or queue) for each client that stores the timestamps of each request. + +3. Handle Incoming Requests: +When a new request arrives, do the following: + +- Remove timestamps from the log that fall outside the current time window. +- Check the number of timestamps remaining in the log. +- If the number of requests (timestamps) within the window is below the limit, add the new request’s timestamp to the log and allow the request. +- If the number of requests meets or exceeds the limit, deny the request. + + +## Modify server response + +Sometimes we need to modify the response sent to the client by adding a custom HTTP header or setting a failure reason text in the JSON payload. + +In that case, you can use one of the closures defined in the `Checkpoint` class, one per Rate-Limit processing stage. + +### Before performing Rate-Limit checking + +This closure is invoked just before the Checkpoint middleware checking operation for a given request will be performed, and receive a Request object as a parameter. + +```swift +public var willCheck: CheckpointAction? +``` + +### After perform Rate-Limit checking + +If Rate-Limit checking goes well, this closure is invoked, and you know that the Request continues to be processed by the Vapor server. + +```swift +public var willCheck: CheckpointAction? +``` + +### Rate-Limit reached +It's sure you want to know when a request reaches the rate limit you set when initializing Checkpoint. + +In this case, Checkpoint will notify a rate-limit reached using the didFailWithTooManyRequest closure. + +```swift +public var didFailWithTooManyRequest: CheckpointErrorAction? +``` + +This closure contains 3 parameter + +- `requests`. It's a [`Request`](https://api.vapor.codes/vapor/documentation/vapor/request) object type representing the user request that reaches the limit. +- `response`. It's the server response ([`Response`](https://api.vapor.codes/vapor/documentation/vapor/response) type) returned by Vapor. +- `metadata`. It's an object designed to set custom HTTP headers and a reason text that will be attached to the object payload returned by the response. + +For example, if you want to add a custom HTTP header and a reason text to inform a user that he reaches the limit you will do something like this + +```swift +// 👮‍♀️ Modify response HTTP header and body response when rate limit exceed +checkpoint.didFailWithTooManyRequest = { (request, response, metadata) in + metadata.headers = [ + "X-RateLimit" : "Failure for request \(request.id)." + ] + + metadata.reason = "Rate limit for your api key exceeded" +} +``` + +### Error throwed while process a request + +If an error different from an HTTP 429 code (rate-limit) comes from Checkpoint, you will be reported in the following closure + +```swift +// 🚨 Modify response HTTP header and body response when error occurs +checkpoint.didFail = { (request, response, abort, metadata) in + metadata.headers = [ + "X-ApiError" : "Error for request \(request.id)." + ] + + metadata.reason = "Error code \(abort.status) for your api key exceeded" +} +``` + +The parameters used in this closure are the same as the ones received in the closure, you can add a custom HTTP header and/or a reason message. + +## Redis + +To work with Checkpoint you must install and configure a Redis database in your system. Thanks to Docker it's really easy to deploy a Redis installation. + +We recommend to install the [**redis-stack-server**](https://hub.docker.com/r/redis/redis-stack-server) image from the Docker Hub. + +## History + +### 0.1.0 + +Alpha version, a *Friends & Family* release 😜 +- Support for Redis Database +- Logging system based on the Vapor `Logger` type +- Four rate-limit algorithims support + - Fixed Window Counter + - Leaking Bucket + - Sliding Window Log + - Token Bucket + diff --git a/Sources/Checkpoint/Algorithms/Algorithm.swift b/Sources/Checkpoint/Algorithms/Algorithm.swift new file mode 100644 index 0000000..80ac9b8 --- /dev/null +++ b/Sources/Checkpoint/Algorithms/Algorithm.swift @@ -0,0 +1,99 @@ +// +// Limiter.swift +// +// +// Created by Adolfo Vera Blasco on 14/6/24. +// + +import Combine +import Redis +import Vapor + +public typealias StorageAction = () -> Application.Redis +public typealias LoggerAction = () -> Logger + +/// Definition for the different Rate-Limit algorithims +public protocol Algorithm: Sendable { + /// The configuration type used in a specific algorithim + associatedtype ConfigurationType + + /// The Redis database used to store the request data + var storage: Application.Redis { get } + /// A `Logger` object created on Vapor + var logging: Logger? { get } + + /// Create a new Rate-Limit algorithim with a given configuration, + /// storage and logging + init(configuration: () -> ConfigurationType, storage: StorageAction, logging: LoggerAction?) + + /// Performs the algorithim logic to check if a request is valid + /// or reach the rate-limit specified on the algorithim's configuration + func checkRequest(_ request: Request) async throws +} + +extension Algorithm { + func valueFor(field: Field, in request: Request) throws -> String { + switch field { + case .header(let key): + guard let value = request.headers[key].first else { + throw Abort(.unauthorized, reason: Self.noFieldMessage) + } + + return value + case .queryItem(let key): + guard let value = request.query[String.self, at: key] else { + throw Abort(.unauthorized, reason: Self.noFieldMessage) + } + + return value + case .noField: + return Self.fieldDefaultKey + } + } + + func valueFor(scope: Scope, in request: Request) throws -> String { + switch scope { + case .endpoint: + return request.url.path + case .api: + guard let host = request.url.host else { + throw Abort(.badRequest, reason: Self.hostNotFoundMessage) + } + + return host + case .noScope: + return Self.scopeDefaultKey + } + } + + func valueFor(field: Field, in request: Request, inside scope: Scope) throws -> String { + let prefix = try valueFor(field: field, in: request) + let suffix = try valueFor(scope: scope, in: request) + + var hasher = Hasher() + hasher.combine(prefix) + hasher.combine(suffix) + + let key = hasher.finalize() + + return String(key) + } +} + +extension Algorithm { + static var fieldDefaultKey: String { + "checkpoint#no.field" + } + + static var scopeDefaultKey: String { + "checkpoint#no.scope" + } + + static var noFieldMessage: String { + "Expected field not found at headers or query parameters" + } + + static var hostNotFoundMessage: String { + "Unable to recover host from request" + } +} diff --git a/Sources/Checkpoint/Algorithms/FixedWindowCounter.swift b/Sources/Checkpoint/Algorithms/FixedWindowCounter.swift new file mode 100644 index 0000000..bd2cd00 --- /dev/null +++ b/Sources/Checkpoint/Algorithms/FixedWindowCounter.swift @@ -0,0 +1,88 @@ +// +// FixedWindowCounter.swift +// +// +// Created by Adolfo Vera Blasco on 15/6/24. +// + +import Combine +import Redis +import Vapor + +/** + Fixed Window Counter algorithm presents the workflow described as follows: + + 1. Define a time window has a counter where the store the number of requets for a given time window. + 3. When a user makes a request, the counter for the current time window is incremented by 1. + 4. If the counter is greater than the rate limit, the request is rejected and whe send an HTTP 429 code status. + 5. If the counter is less than the rate limit, the request is accepted. +*/ +public final class FixedWindowCounter { + // Configuration for this rate-limit algorithm + private let configuration: FixedWindowCounterConfiguration + // The Redis database where we store the request information + public let storage: Application.Redis + // A logger set during Vapor initialization + public let logging: Logger? + + // The Combine Timer publishers + private var cancellable: AnyCancellable? + // Keys stored in a given time window + private var keys = Set() + + /** + + + */ + public init(configuration: () -> FixedWindowCounterConfiguration, storage: StorageAction, logging: LoggerAction? = nil) { + self.configuration = configuration() + self.storage = storage() + self.logging = logging?() + + self.cancellable = startWindow(havingDuration: self.configuration.timeWindowDuration.inSeconds, + performing: resetWindow) + } + + /** + + */ + deinit { + cancellable?.cancel() + } +} + +extension FixedWindowCounter: WindowBasedAlgorithm { + /// + public func checkRequest(_ request: Request) async throws { + guard let requestKey = try? valueFor(field: configuration.appliedField, in: request, inside: configuration.scope) else { + return + } + + keys.insert(requestKey) + + let redisKey = RedisKey(requestKey) + let timestamp = Date().timeIntervalSince1970 + + let requestCount = try await storage.rpush([ timestamp ], into: redisKey).get() + + if requestCount > configuration.requestPerWindow { + throw Abort(.tooManyRequests) + } + } + + public func resetWindow() { + keys.forEach { key in + let redisKey = RedisKey(key) + + Task { + do { + try await storage.delete(redisKey).get() + } catch let redisError { + logging?.error("🚨 Error deleting key \(key): \(redisError.localizedDescription)") + } + } + } + + keys.removeAll() + } +} diff --git a/Sources/Checkpoint/Limiters/LeakingBucket.swift b/Sources/Checkpoint/Algorithms/LeakingBucket.swift similarity index 62% rename from Sources/Checkpoint/Limiters/LeakingBucket.swift rename to Sources/Checkpoint/Algorithms/LeakingBucket.swift index 11c4888..f831761 100644 --- a/Sources/Checkpoint/Limiters/LeakingBucket.swift +++ b/Sources/Checkpoint/Algorithms/LeakingBucket.swift @@ -18,19 +18,15 @@ import Vapor 4. If the bucket is not full, we allow the request and add 1 token from the bucket. 5. Tokens are removed at a fixed rate of r tokens per second. Let’s say 1 token per second. */ -final class LeakingBucket { +public final class LeakingBucket { private let configuration: LeakingBucketConfiguration - let storage: Application.Redis - let logging: Logger? + public let storage: Application.Redis + public let logging: Logger? private var cancellable: AnyCancellable? private var keys = Set() - private var redisKey: RedisKey { - RedisKey("") - } - - init(configuration: () -> LeakingBucketConfiguration, storage: StorageAction, logging: LoggerAction? = nil) { + public init(configuration: () -> LeakingBucketConfiguration, storage: StorageAction, logging: LoggerAction? = nil) { self.configuration = configuration() self.storage = storage() self.logging = logging?() @@ -46,17 +42,13 @@ final class LeakingBucket { do { try await storage.set(key, to: 0).get() } catch let redisError { - logging?.error("🚨 Problem setting key \(key.rawValue) to value \(configuration.bucketSize)") + logging?.error("🚨 Problem setting key \(key.rawValue) to value \(configuration.bucketSize): \(redisError.localizedDescription)") } } } -extension LeakingBucket: WindowBasedLimiter { - var isValidRequest: Bool { - return true - } - - func checkRequest(_ request: Request) async throws { +extension LeakingBucket: WindowBasedAlgorithm { + public func checkRequest(_ request: Request) async throws { guard let requestKey = try? valueFor(field: configuration.appliedField, in: request, inside: configuration.scope) else { return } @@ -64,7 +56,7 @@ extension LeakingBucket: WindowBasedLimiter { keys.insert(requestKey) let redisKey = RedisKey(requestKey) - let keyExists = await try storage.exists(redisKey).get() + let keyExists = try await storage.exists(redisKey).get() if keyExists == 0 { await preparaStorageFor(key: redisKey) @@ -72,24 +64,27 @@ extension LeakingBucket: WindowBasedLimiter { // 1. New request, remove one token from the bucket let bucketItemsCount = try await storage.increment(redisKey).get() - logging?.info("⌚️ \(requestKey) = \(bucketItemsCount)") // 2. If buckes is empty, throw an error - if bucketItemsCount >= configuration.bucketSize { + if bucketItemsCount > configuration.bucketSize { throw Abort(.tooManyRequests) } } - func resetWindow() throws { - Task(priority: .userInitiated) { - let respValue = try await storage.get(redisKey).get() - - var newBucketSize = 0 + public func resetWindow() throws { + keys.forEach { key in + Task(priority: .userInitiated) { + let redisKey = RedisKey(key) + + let respValue = try await storage.get(redisKey).get() - if let currentBucketSize = respValue.int { - newBucketSize = currentBucketSize < configuration.tokenRemovingRate ? 0 : (currentBucketSize - configuration.tokenRemovingRate) + var newBucketSize = 0 + + if let currentBucketSize = respValue.int { + newBucketSize = currentBucketSize < configuration.tokenRemovingRate ? 0 : (currentBucketSize - configuration.tokenRemovingRate) + } + + try await storage.decrement(redisKey, by: newBucketSize).get() } - - try await storage.decrement(redisKey, by: newBucketSize).get() } } } diff --git a/Sources/Checkpoint/Algorithms/SlidingWindowLog.swift b/Sources/Checkpoint/Algorithms/SlidingWindowLog.swift new file mode 100644 index 0000000..566a22f --- /dev/null +++ b/Sources/Checkpoint/Algorithms/SlidingWindowLog.swift @@ -0,0 +1,67 @@ +// +// SlidingWindowLog.swift +// +// +// Created by Adolfo Vera Blasco on 15/6/24. +// + +import Redis +import Vapor + +/// The Sliding Window Log rate-limit algorithim is based on the request count perfomed during a non fixed window time. +/// It works following this workflow: +/// +/// 1. When new request comes in remove all outdated timestamps from cache. By outdated we mean timestamps that are older than window size. +/// 2. Add new timestamp to cache. +/// 3. If number of timestamps in cache is greater than limit reject request and return 429 status code. +/// 4. If lower than limit then accept request and return 200 status code. + +public final class SlidingWindowLog { + /// Configuration for the Sliding Window Log + private let configuration: SlidingWindowLogConfiguration + /// Redis database where we store the request timestamps + public let storage: Application.Redis + /// A Vapor logger object + public let logging: Logger? + + /// Create a new Sliging Window Log with a given configuration, storage and a logger + /// + /// - Parameters: + /// - configuration: A `SlidingWindowLogConfiguration` object + /// - storage: The Redis database instance created on Vapor + /// - logging: A `Logger` object created on Vapor. + public init(configuration: () -> SlidingWindowLogConfiguration, storage: StorageAction, logging: LoggerAction? = nil) { + self.configuration = configuration() + self.storage = storage() + self.logging = logging?() + } +} + +extension SlidingWindowLog: Algorithm { + public func checkRequest(_ request: Request) async throws { + guard let apiKey = try? valueFor(field: configuration.appliedField, in: request, inside: configuration.scope) else { + throw Abort(.unauthorized, reason: Checkpoint.HTTPErrorDescription.unauthorized) + } + + let redisKey = RedisKey(apiKey) + + let requestDate = Date() + let outdatedRequestLimiteDate = Date().addingTimeInterval(-configuration.timeWindowDuration.inSeconds) + + // 1. Delete outdated request + let topBound: Double = Double(outdatedRequestLimiteDate.timeIntervalSinceReferenceDate) + let deletedEntriesCount = try await storage.zremrangebyscore(from: redisKey, withMaximumScoreOf: RedisZScoreBound(floatLiteral: topBound)).get() + + // 2. Add the current request + let requestTimeInterval = Double(requestDate.timeIntervalSinceReferenceDate) + try await storage.zadd([ (element: requestTimeInterval, score: requestTimeInterval) ], to: redisKey).get() + + // 3. Get the number of request for this time window + let itemsCount = try await storage.zcount(of: redisKey, withScores: 0.0...requestTimeInterval).get() + + // 4. If request count is greater... + if itemsCount > configuration.requestPerWindow { + throw Abort(.tooManyRequests) + } + } +} diff --git a/Sources/Checkpoint/Limiters/TokenBucket.swift b/Sources/Checkpoint/Algorithms/TokenBucket.swift similarity index 62% rename from Sources/Checkpoint/Limiters/TokenBucket.swift rename to Sources/Checkpoint/Algorithms/TokenBucket.swift index 0d11449..3aef7c9 100644 --- a/Sources/Checkpoint/Limiters/TokenBucket.swift +++ b/Sources/Checkpoint/Algorithms/TokenBucket.swift @@ -18,15 +18,15 @@ import Vapor • We take 1 token out for each request and if there are enough tokens, then the request is processed. • The request is dropped if there aren’t enough tokens. */ -final class TokenBucket { +public final class TokenBucket { private let configuration: TokenBucketConfiguration - let storage: Application.Redis - let logging: Logger? + public let storage: Application.Redis + public let logging: Logger? private var cancellable: AnyCancellable? private var keys = Set() - init(configuration: () -> TokenBucketConfiguration, storage: StorageAction, logging: LoggerAction? = nil) { + public init(configuration: () -> TokenBucketConfiguration, storage: StorageAction, logging: LoggerAction? = nil) { self.configuration = configuration() self.storage = storage() self.logging = logging?() @@ -48,8 +48,8 @@ final class TokenBucket { } } -extension TokenBucket: WindowBasedLimiter { - func checkRequest(_ request: Request) async throws { +extension TokenBucket: WindowBasedAlgorithm { + public func checkRequest(_ request: Request) async throws { guard let requestKey = try? valueFor(field: configuration.appliedField, in: request, inside: configuration.scope) else { return } @@ -57,7 +57,7 @@ extension TokenBucket: WindowBasedLimiter { keys.insert(requestKey) let redisKey = RedisKey(requestKey) - let keyExists = await try storage.exists(redisKey).get() + let keyExists = try await storage.exists(redisKey).get() if keyExists == 0 { await preparaStorageFor(key: redisKey) @@ -65,21 +65,33 @@ extension TokenBucket: WindowBasedLimiter { // 1. New request, remove one token from the bucket let bucketItemsCount = try await storage.decrement(redisKey).get() - logging?.info("⌚️ \(requestKey) = \(bucketItemsCount)") // 2. If buckes is empty, throw an error - if bucketItemsCount <= 0 { + if bucketItemsCount < 0 { throw Abort(.tooManyRequests) } } - func resetWindow() throws { - let redisKeys = keys.map { RedisKey($0) } - - Task { - do { - try await storage.delete(redisKeys).get() - } catch let redisError { - logging?.error("🚨 Problem deleting keys: \(redisError.localizedDescription)") + public func resetWindow() throws { + keys.forEach { key in + Task(priority: .userInitiated) { + let redisKey = RedisKey(key) + + let respValue = try await storage.get(redisKey).get() + + var newRefillSize = 0 + + if let currentBucketSize = respValue.int { + switch currentBucketSize { + case ...0: + newRefillSize -= currentBucketSize + case configuration.bucketSize...: + newRefillSize = configuration.bucketSize - currentBucketSize + default: + newRefillSize = configuration.refillTokenRate + } + } + + try await storage.increment(redisKey, by: newRefillSize).get() } } } diff --git a/Sources/Checkpoint/Limiters/WindowBasedLimiter.swift b/Sources/Checkpoint/Algorithms/WindowBasedAlgorithm.swift similarity index 56% rename from Sources/Checkpoint/Limiters/WindowBasedLimiter.swift rename to Sources/Checkpoint/Algorithms/WindowBasedAlgorithm.swift index bf96be6..7561532 100644 --- a/Sources/Checkpoint/Limiters/WindowBasedLimiter.swift +++ b/Sources/Checkpoint/Algorithms/WindowBasedAlgorithm.swift @@ -8,15 +8,18 @@ import Combine import Foundation -typealias WindowBasedAction = () throws -> Void +public typealias WindowBasedAction = () throws -> Void -protocol WindowBasedLimiter: Limiter { +/// For those algorithims thar works with fixed time windows. +public protocol WindowBasedAlgorithm: Algorithm { + /// Start the timer for a given duration (time window) func startWindow(havingDuration seconds: Double, performing action: @escaping WindowBasedAction) -> AnyCancellable + /// Perfomrs the reset operation when the time windo ends. func resetWindow() async throws } -extension WindowBasedLimiter { - func startWindow(havingDuration seconds: Double, performing action: @escaping WindowBasedAction) -> AnyCancellable { +extension WindowBasedAlgorithm { + public func startWindow(havingDuration seconds: Double, performing action: @escaping WindowBasedAction) -> AnyCancellable { var cancellable = Timer.publish(every: seconds, on: .main, in: .common) .autoconnect() .sink { _ in diff --git a/Sources/Checkpoint/Checkpoint.swift b/Sources/Checkpoint/Checkpoint.swift index 00cd591..4d70e21 100644 --- a/Sources/Checkpoint/Checkpoint.swift +++ b/Sources/Checkpoint/Checkpoint.swift @@ -8,69 +8,82 @@ import Redis import Vapor -final class Checkpoint { - let limiter: any Limiter +public typealias CheckpointHandler = (Request) -> Void +public typealias CheckpointRateLimitHandler = (Request, Response, Checkpoint.ErrorMetadata) -> Void +public typealias CheckpointErrorHandler = (Request, Response, AbortError, Checkpoint.ErrorMetadata) -> Void + +public final class Checkpoint { + private let algorithm: any Algorithm + + public var willCheck: CheckpointHandler? + public var didCheck: CheckpointHandler? + public var didFailWithTooManyRequest: CheckpointRateLimitHandler? + public var didFail: CheckpointErrorHandler? - init(using algorithm: some Limiter) { - self.limiter = algorithm + public init(using algorithm: some Algorithm) { + self.algorithm = algorithm } } extension Checkpoint: AsyncMiddleware { - func respond(to request: Request, chainingTo next: any AsyncResponder) async throws -> Response { - limiter.logging?.info("👉 RateLimitMiddleware request") + public func respond(to request: Request, chainingTo next: any AsyncResponder) async throws -> Response { let response = try await next.respond(to: request) - limiter.logging?.info("👈 RateLimitMiddleware reponse") do { + willCheck?(request) try await checkRateLimitFor(request: request) - response.headers.add(name: "X-App-Version", value: "v1.0.0") - limiter.logging?.info("💡 Header Setted.") + didCheck?(request) } catch let abort as AbortError { - throw abort - } catch { - response.headers.add(name: "X-Rate-Limit", value: "8") - limiter.logging?.info("🚨 Header Setted.") - throw Abort(.tooManyRequests, - headers: response.headers, - reason: HTTPErrorDescription.rateLimitReached) + let errorMetadata = ErrorMetadata() + + switch abort.status { + case .tooManyRequests: + didFailWithTooManyRequest?(request, response, errorMetadata) + + throw Abort(.tooManyRequests, + headers: errorMetadata.httpHeaders, + reason: errorMetadata.reason) + default: + didFail?(request, response, abort, errorMetadata) + + throw Abort(.badRequest, + headers: errorMetadata.httpHeaders, + reason: errorMetadata.reason) + } } return response } private func checkRateLimitFor(request: Request) async throws { - try await limiter.checkRequest(request) + try await algorithm.checkRequest(request) } } -extension Checkpoint { - enum Constants { - static let apiKeyHeader = "X-ApiKey" - static let rateLimitDB = "rate-limit" - } - - enum HTTPErrorDescription { - static let unauthorized = "X-Api-Key header not available in the request" - static let rateLimitReached = "You have exceed your ApiKey network requests rate" +public extension Checkpoint { + final class ErrorMetadata { + public var headers: [String : String]? + public var reason: String? + + var httpHeaders: HTTPHeaders { + var httpHeaders = HTTPHeaders() + + guard let headers else { + return httpHeaders + } + + for (key, content) in headers { + httpHeaders.add(name: key, value: content) + } + + return httpHeaders + } } } extension Checkpoint { - /* - enum Strategy { - case tokenBucket(configuration: TokenBucket.Configuration) - case leakingBucket(configuration: LeakingBucket.Configuration) - case fixedWindowCounter(configuration: FixedWindowCounter.Configuration) - case slidingWindowLog(configuration: SlidingWindowLog.Configuration) + enum HTTPErrorDescription { + static let unauthorized = "X-Api-Key header not available in the request" + static let rateLimitReached = "You have exceed your ApiKey network requests rate" } - */ - -} - -enum Strategy { - case tokenBucket - case leakingBucket - case fixedWindowCounter - case slidingWindowLog } diff --git a/Sources/Checkpoint/Extensions/Checkpoint+Vapor.swift b/Sources/Checkpoint/Extensions/Checkpoint+Vapor.swift deleted file mode 100644 index 75660fd..0000000 --- a/Sources/Checkpoint/Extensions/Checkpoint+Vapor.swift +++ /dev/null @@ -1,14 +0,0 @@ -// -// Checkpoint+Vapor.swift -// -// -// Created by Adolfo Vera Blasco on 18/6/24. -// - -import Vapor - -extension Request { - func checkpoint() { - - } -} diff --git a/Sources/Checkpoint/Limiters/FixedWindowCounter.swift b/Sources/Checkpoint/Limiters/FixedWindowCounter.swift deleted file mode 100644 index de1c549..0000000 --- a/Sources/Checkpoint/Limiters/FixedWindowCounter.swift +++ /dev/null @@ -1,79 +0,0 @@ -// -// FixedWindowCounter.swift -// -// -// Created by Adolfo Vera Blasco on 15/6/24. -// - -import Combine -import Redis -import Vapor - -/** - Algorithm can be described as follows: - - 1. Timeline is divided into fixed time windows. - 2. Each time window has a counter. - 3. When a request comes in, the counter for the current time window is incremented. - 4. If the counter is greater than the rate limit, the request is rejected. - 5. If the counter is less than the rate limit, the request is accepted. -*/ -final class FixedWindowCounter { - private let configuration: FixedWindowCounterConfiguration - let storage: Application.Redis - let logging: Logger? - - private var cancellable: AnyCancellable? - private var keys = Set() - - - init(configuration: () -> FixedWindowCounterConfiguration, storage: StorageAction, logging: LoggerAction? = nil) { - self.configuration = configuration() - self.storage = storage() - self.logging = logging?() - - self.cancellable = startWindow(havingDuration: self.configuration.timeWindowDuration.inSeconds, - performing: resetWindow) - } - - deinit { - cancellable?.cancel() - } -} - -extension FixedWindowCounter: WindowBasedLimiter { - func checkRequest(_ request: Request) async throws { - guard let requestKey = try? valueFor(field: configuration.appliedField, in: request, inside: configuration.scope) else { - return - } - - keys.insert(requestKey) - - let redisKey = RedisKey(requestKey) - let timestamp = Date.now.timeIntervalSince1970 - - // If window request is full then drop - storage.rpush([ timestamp ], into: redisKey) - let requestCount = try await storage.llen(of: redisKey).get() - - if requestCount > configuration.requestPerWindow { - throw Abort(.tooManyRequests) - } - } - - func resetWindow() { - keys.forEach { key in - let redisKey = RedisKey(key) - - Task { - do { - try await storage.delete(redisKey).get() - } catch let redisError { - logging?.error("🚨 Error deleting key \(key): \(redisError.localizedDescription)") - } - } - } - - keys.removeAll() - } -} diff --git a/Sources/Checkpoint/Limiters/Limiter.swift b/Sources/Checkpoint/Limiters/Limiter.swift deleted file mode 100644 index c346be6..0000000 --- a/Sources/Checkpoint/Limiters/Limiter.swift +++ /dev/null @@ -1,78 +0,0 @@ -// -// Limiter.swift -// -// -// Created by Adolfo Vera Blasco on 14/6/24. -// - -import Combine -import Vapor - -typealias StorageAction = () -> Application.Redis -typealias LoggerAction = () -> Logger - -protocol Limiter: Sendable { - associatedtype ConfigurationType - - var storage: Application.Redis { get } - var logging: Logger? { get } - - init(configuration: () -> ConfigurationType, storage: StorageAction, logging: LoggerAction?) - - func checkRequest(_ request: Request) async throws -} - -extension Limiter { - func valueFor(field: Field, in request: Request) throws -> String { - switch field { - case .header(let key): - guard let value = request.headers[key].first else { - throw Abort(.unauthorized, reason: "") - } - - return value - case .queryItem(let key): - guard let value = request.query[String.self, at: key] else { - throw Abort(.unauthorized, reason: "") - } - - return value - case .none: - return Self.none - } - } - - func valueFor(scope: RateLimitScope, in request: Request) throws -> String { - switch scope { - case .endpoint: - return request.url.path - case .api: - guard let host = request.url.host else { - throw Abort(.badRequest, reason: "") - } - - return host - case .nonScope: - return Self.nonScope - } - } - - func valueFor(field: Field, in request: Request, inside scope: RateLimitScope) throws -> String { - let prefix = try valueFor(field: field, in: request) - let suffix = try valueFor(scope: scope, in: request) - - let key = String("\(prefix)\(suffix)".hash) - - return key - } -} - -extension Limiter { - static var none: String { - "no-key" - } - - static var nonScope: String { - "non-scope" - } -} diff --git a/Sources/Checkpoint/Limiters/SlidingWindowLog.swift b/Sources/Checkpoint/Limiters/SlidingWindowLog.swift deleted file mode 100644 index 272af7f..0000000 --- a/Sources/Checkpoint/Limiters/SlidingWindowLog.swift +++ /dev/null @@ -1,58 +0,0 @@ -// -// SlidingWindowLog.swift -// -// -// Created by Adolfo Vera Blasco on 15/6/24. -// - -import Redis -import Vapor - -final class SlidingWindowLog { - private let configuration: SlidingWindowLogConfiguration - let storage: Application.Redis - let logging: Logger? - - init(configuration: () -> SlidingWindowLogConfiguration, storage: StorageAction, logging: LoggerAction? = nil) { - self.configuration = configuration() - self.storage = storage() - self.logging = logging?() - } -} - -extension SlidingWindowLog: Limiter { - var isValidRequest: Bool { - return true - } - - func checkRequest(_ request: Request) async throws { - guard let apiKey = try? valueFor(field: configuration.appliedField, in: request, inside: configuration.scope) else { - throw Abort(.unauthorized, reason: Checkpoint.HTTPErrorDescription.unauthorized) - } - - logging?.info("💡 ApiKey: \(apiKey)") - let redisKey = RedisKey(apiKey) - - let requestDate = Date.now - let outdatedRequestLimiteDate = Date.now.addingTimeInterval(-configuration.timeWindowDuration.inSeconds) - - // 1. Delete outdated request - let topBound: Double = Double(outdatedRequestLimiteDate.timeIntervalSinceReferenceDate) - let deletedEntriesCount = try await storage.zremrangebyscore(from: redisKey, withMaximumScoreOf: RedisZScoreBound(floatLiteral: topBound)).get() - logging?.info("🆑 Deleted \(deletedEntriesCount) items for api key \"\(apiKey)\"") - - // 2. Add the current request - let requestTimeInterval = Double(requestDate.timeIntervalSinceReferenceDate) - try await storage.zadd([ (element: requestTimeInterval, score: requestTimeInterval) ], to: redisKey).get() - logging?.info("💡 Element: \(requestTimeInterval) Score: \(requestTimeInterval)") - - // 3. Get the number of request for this time window - let itemsCount = try await storage.zcount(of: redisKey, withScores: 0.0...requestTimeInterval).get() - logging?.info("💡 Current items: \(itemsCount)") - - // 4. If request count is greater... - if itemsCount > configuration.requestPerWindow { - throw Abort(.tooManyRequests) - } - } -} diff --git a/Sources/Checkpoint/Model/Configuration.swift b/Sources/Checkpoint/Model/Configuration.swift index ac72665..57ba452 100644 --- a/Sources/Checkpoint/Model/Configuration.swift +++ b/Sources/Checkpoint/Model/Configuration.swift @@ -6,41 +6,15 @@ // -protocol Configuration { +public protocol Configuration { var appliedField: Field { get } - var scope: RateLimitScope { get } + var scope: Scope { get } } -struct FixedWindowCounterConfiguration: Configuration { - var requestPerWindow = 5 - var timeWindowDuration: TimeWindow = .seconds(count: 10) - - var appliedField: Field - var scope: RateLimitScope -} -struct LeakingBucketConfiguration: Configuration { - var bucketSize = 10 - var tokenRemovingRate = 5 - var timeWindowDuration: TimeWindow = .seconds(count: 10) - - var appliedField: Field - var scope: RateLimitScope -} -struct SlidingWindowLogConfiguration: Configuration { - var requestPerWindow = 10 - var timeWindowDuration: TimeWindow - - var appliedField: Field - var scope: RateLimitScope -} -struct TokenBucketConfiguration: Configuration { - var bucketSize: Int - var refillRate: Int - var refillTimeInterval: TimeWindow - - var appliedField: Field - var scope: RateLimitScope -} + + + + diff --git a/Sources/Checkpoint/Model/Field.swift b/Sources/Checkpoint/Model/Field.swift index 74c02e9..db53e50 100644 --- a/Sources/Checkpoint/Model/Field.swift +++ b/Sources/Checkpoint/Model/Field.swift @@ -7,8 +7,8 @@ import Foundation -enum Field { +public enum Field { case header(key: String) case queryItem(key: String) - case none + case noField } diff --git a/Sources/Checkpoint/Model/FixedWindowCounterConfiguration.swift b/Sources/Checkpoint/Model/FixedWindowCounterConfiguration.swift new file mode 100644 index 0000000..c1ddb14 --- /dev/null +++ b/Sources/Checkpoint/Model/FixedWindowCounterConfiguration.swift @@ -0,0 +1,25 @@ +// +// FixedWindowCounterConfiguration.swift +// +// +// Created by Adolfo Vera Blasco on 25/6/24. +// + +import Foundation + +public struct FixedWindowCounterConfiguration: Configuration { + public private(set) var requestPerWindow: Int + public private(set) var timeWindowDuration: TimeWindow + + public private(set) var appliedField: Field + public private(set) var scope: Scope +} + +extension FixedWindowCounterConfiguration { + public init(requestPerWindow: Int, timeWindowDuration: TimeWindow, appliedTo field: Field = .noField, inside scope: Scope = .noScope) { + self.requestPerWindow = requestPerWindow + self.timeWindowDuration = timeWindowDuration + self.appliedField = field + self.scope = scope + } +} diff --git a/Sources/Checkpoint/Model/LeakingBucketConfiguration.swift b/Sources/Checkpoint/Model/LeakingBucketConfiguration.swift new file mode 100644 index 0000000..12cde2b --- /dev/null +++ b/Sources/Checkpoint/Model/LeakingBucketConfiguration.swift @@ -0,0 +1,27 @@ +// +// LeakingBucketConfiguration.swift +// +// +// Created by Adolfo Vera Blasco on 25/6/24. +// + +import Foundation + +public struct LeakingBucketConfiguration: Configuration { + public var bucketSize: Int + public var tokenRemovingRate: Int + public var timeWindowDuration: TimeWindow + + public var appliedField: Field + public var scope: Scope +} + +extension LeakingBucketConfiguration { + public init(bucketSize: Int, removingRate: Int, removingTimeInterval: TimeWindow, appliedTo field: Field = .noField, inside scope: Scope = .noScope) { + self.bucketSize = bucketSize + self.tokenRemovingRate = removingRate + self.timeWindowDuration = removingTimeInterval + self.appliedField = field + self.scope = scope + } +} diff --git a/Sources/Checkpoint/Model/RateLimitScope.swift b/Sources/Checkpoint/Model/Scope.swift similarity index 78% rename from Sources/Checkpoint/Model/RateLimitScope.swift rename to Sources/Checkpoint/Model/Scope.swift index 75a04fe..108aeac 100644 --- a/Sources/Checkpoint/Model/RateLimitScope.swift +++ b/Sources/Checkpoint/Model/Scope.swift @@ -7,8 +7,8 @@ import Foundation -enum RateLimitScope { +public enum Scope { case api case endpoint - case nonScope + case noScope } diff --git a/Sources/Checkpoint/Model/SlidingWindowLogConfiguration.swift b/Sources/Checkpoint/Model/SlidingWindowLogConfiguration.swift new file mode 100644 index 0000000..639b16a --- /dev/null +++ b/Sources/Checkpoint/Model/SlidingWindowLogConfiguration.swift @@ -0,0 +1,25 @@ +// +// SlidingWindowLogConfiguration.swift +// +// +// Created by Adolfo Vera Blasco on 25/6/24. +// + +import Foundation + +public struct SlidingWindowLogConfiguration: Configuration { + public var requestPerWindow: Int + public var timeWindowDuration: TimeWindow + + public var appliedField: Field + public var scope: Scope +} + +extension SlidingWindowLogConfiguration { + public init(requestPerWindow: Int, windowDuration: TimeWindow, appliedTo field: Field = .noField, inside scope: Scope = .noScope) { + self.requestPerWindow = requestPerWindow + self.timeWindowDuration = windowDuration + self.appliedField = field + self.scope = scope + } +} diff --git a/Sources/Checkpoint/Model/TimeWindow.swift b/Sources/Checkpoint/Model/TimeWindow.swift index 0735a7b..dd76770 100644 --- a/Sources/Checkpoint/Model/TimeWindow.swift +++ b/Sources/Checkpoint/Model/TimeWindow.swift @@ -7,7 +7,7 @@ import Foundation -enum TimeWindow { +public enum TimeWindow { case seconds(count: Int = 10) case minutes(count: Int = 1) case hours(count: Int = 1) diff --git a/Sources/Checkpoint/Model/TokenBucketConfiguration.swift b/Sources/Checkpoint/Model/TokenBucketConfiguration.swift new file mode 100644 index 0000000..6a3d371 --- /dev/null +++ b/Sources/Checkpoint/Model/TokenBucketConfiguration.swift @@ -0,0 +1,27 @@ +// +// TokenBucketConfiguration.swift +// +// +// Created by Adolfo Vera Blasco on 25/6/24. +// + +import Foundation + +public struct TokenBucketConfiguration: Configuration { + public private(set) var bucketSize: Int + public private(set) var refillTokenRate: Int + public private(set) var refillTimeInterval: TimeWindow + + public private(set) var appliedField: Field + public private(set) var scope: Scope +} + +extension TokenBucketConfiguration { + public init(bucketSize: Int, refillRate: Int, refillTimeInterval: TimeWindow, appliedTo field: Field = .noField, inside scope: Scope = .noScope) { + self.bucketSize = bucketSize + self.refillTokenRate = refillRate + self.refillTimeInterval = refillTimeInterval + self.appliedField = field + self.scope = scope + } +} diff --git a/Tests/CheckpointTests/CheckpointApiKeyTests.swift b/Tests/CheckpointTests/CheckpointApiKeyTests.swift new file mode 100644 index 0000000..a373856 --- /dev/null +++ b/Tests/CheckpointTests/CheckpointApiKeyTests.swift @@ -0,0 +1,181 @@ +// +// CheckpointApiKeyTests.swift +// +// +// Created by Adolfo Vera Blasco on 23/6/24. +// + +import Redis +import XCTest +import XCTVapor +@testable import Checkpoint + +final class CheckpointApiKeyTests: XCTestCase { + func testLeakingBucketWithHeader() { + let app = Application(.testing) + defer { app.shutdown() } + + app.get("leaking-bucket") { request -> HTTPStatus in + return .ok + } + + let leakingBucketAlgorithm = LeakingBucket { + LeakingBucketConfiguration(bucketSize: 10, + removingRate: 5, + removingTimeInterval: .minutes(count: 1), + appliedTo: .header(key: "X-ApiKey")) + } storage: { + // Rate limit database in Redis + app.redis("rate").configuration = try? RedisConfiguration(hostname: "localhost", + port: 9090, + database: 0) + + return app.redis("rate") + } logging: { + app.logger + } + + let checkpoint = Checkpoint(using: leakingBucketAlgorithm) + + app.middleware.use(checkpoint) + + var apiKeyHeader = HTTPHeaders() + apiKeyHeader.add(name: "X-ApiKey", value: "fitomad#1") + + (0...20).forEach { index in + try? app.test(.GET, "leaking-bucket", headers: apiKeyHeader, afterResponse: { testResponse in + app.logger.info("leaking-bucket \(index) = \(testResponse.status)") + if index < 10 { + XCTAssertEqual(testResponse.status, .ok) + } else { + XCTAssertEqual(testResponse.status, .tooManyRequests) + } + }) + } + } + + func testTokenBucketWithHeader() throws { + let app = Application(.testing) + defer { app.shutdown() } + + app.get("token-bucket") { request -> HTTPStatus in + return .ok + } + + let tokenbucketAlgorithm = TokenBucket { + TokenBucketConfiguration(bucketSize: 10, + refillRate: 0, + refillTimeInterval: .seconds(count: 20), + appliedTo: .header(key: "X-ApiKey")) + } storage: { + // Rate limit database in Redis + app.redis("rate").configuration = try? RedisConfiguration(hostname: "localhost", + port: 9090, + database: 0) + + return app.redis("rate") + } logging: { + app.logger + } + + let checkpoint = Checkpoint(using: tokenbucketAlgorithm) + + app.middleware.use(checkpoint) + + var apiKeyHeader = HTTPHeaders() + apiKeyHeader.add(name: "X-ApiKey", value: "fitomad#2") + + (0...20).forEach { index in + try? app.test(.GET, "token-bucket", headers: apiKeyHeader, afterResponse: { testResponse in + app.logger.info("token-bucket \(index) = \(testResponse.status)") + if index < 10 { + XCTAssertEqual(testResponse.status, .ok) + } else { + XCTAssertEqual(testResponse.status, .tooManyRequests) + } + }) + } + } + + func testFixedWindowCounterWithHeader() { + let app = Application(.testing) + defer { app.shutdown() } + + app.get("fixed-window-counter") { request -> HTTPStatus in + return .ok + } + + let fixedWindowAlgorithm = FixedWindowCounter { + FixedWindowCounterConfiguration(requestPerWindow: 10, + timeWindowDuration: .minutes(count: 2), + appliedTo: .header(key: "X-ApiKey")) + } storage: { + // Rate limit database in Redis + app.redis("rate").configuration = try? RedisConfiguration(hostname: "localhost", + port: 9090, + database: 0) + + return app.redis("rate") + } + + + let checkpoint = Checkpoint(using: fixedWindowAlgorithm) + + app.middleware.use(checkpoint) + + var apiKeyHeader = HTTPHeaders() + apiKeyHeader.add(name: "X-ApiKey", value: "fitomad#3") + + (0...20).forEach { index in + try? app.test(.GET, "fixed-window-counter", headers: apiKeyHeader, afterResponse: { testResponse in + app.logger.info("fixed-window-counter \(index) = \(testResponse.status)") + if index < 10 { + XCTAssertEqual(testResponse.status, .ok) + } else { + XCTAssertEqual(testResponse.status, .tooManyRequests) + } + }) + } + } + + func testSlidingWindowLogWithHeader() { + let app = Application(.testing) + defer { app.shutdown() } + + app.get("sliding-window-log") { request -> HTTPStatus in + return .ok + } + + let slidingWindowLogAlgorith = SlidingWindowLog { + SlidingWindowLogConfiguration(requestPerWindow: 10, + windowDuration: .minutes(count: 2), + appliedTo: .header(key: "X-ApiKey")) + } storage: { + // Rate limit database in Redis + app.redis("rate").configuration = try? RedisConfiguration(hostname: "localhost", + port: 9090, + database: 0) + + return app.redis("rate") + } + + + let checkpoint = Checkpoint(using: slidingWindowLogAlgorith) + + app.middleware.use(checkpoint) + + var apiKeyHeader = HTTPHeaders() + apiKeyHeader.add(name: "X-ApiKey", value: "fitomad#4") + + (0...20).forEach { index in + try? app.test(.GET, "sliding-window-log", headers: apiKeyHeader, afterResponse: { testResponse in + app.logger.info("sliding-window-log \(index) = \(testResponse.status)") + if index < 10 { + XCTAssertEqual(testResponse.status, .ok) + } else { + XCTAssertEqual(testResponse.status, .tooManyRequests) + } + }) + } + } +} diff --git a/Tests/CheckpointTests/CheckpointApiScopeTests.swift b/Tests/CheckpointTests/CheckpointApiScopeTests.swift new file mode 100644 index 0000000..2ea24c3 --- /dev/null +++ b/Tests/CheckpointTests/CheckpointApiScopeTests.swift @@ -0,0 +1,189 @@ +// +// CheckpointApiScopeTests.swift +// +// +// Created by Adolfo Vera Blasco on 24/6/24. +// + +import Redis +import XCTest +import XCTVapor +@testable import Checkpoint + +final class CheckpointApiScoreTests: XCTestCase { + func testLeakingBucketWithScopeApiHeader() { + let app = Application(.testing) + defer { app.shutdown() } + + app.get("leaking-bucket") { request -> HTTPStatus in + return .ok + } + + let leakingBucketAlgorithm = LeakingBucket { + LeakingBucketConfiguration(bucketSize: 10, + removingRate: 5, + removingTimeInterval: .minutes(count: 1), + appliedTo: .header(key: "X-ApiKey"), + inside :.endpoint) + } storage: { + // Rate limit database in Redis + app.redis("rate").configuration = try? RedisConfiguration(hostname: "localhost", + port: 9090, + database: 0) + + return app.redis("rate") + } logging: { + app.logger + } + + let checkpoint = Checkpoint(using: leakingBucketAlgorithm) + + app.middleware.use(checkpoint) + + var apiKeyHeader = HTTPHeaders() + apiKeyHeader.add(name: "X-ApiKey", value: "fitomad#1") + + (0...20).forEach { index in + try? app.test(.GET, "leaking-bucket", headers: apiKeyHeader, afterResponse: { testResponse in + app.logger.info("leaking-bucket \(index) = \(testResponse.status)") + if index < 10 { + XCTAssertEqual(testResponse.status, .ok) + } else { + XCTAssertEqual(testResponse.status, .tooManyRequests) + } + }) + } + } + + func testTokenBucketWithScopeApiHeader() throws { + let app = Application(.testing) + defer { app.shutdown() } + + app.get("token-bucket") { request -> HTTPStatus in + return .ok + } + + let tokenbucketAlgorithm = TokenBucket { + TokenBucketConfiguration(bucketSize: 10, + refillRate: 0, + refillTimeInterval: .seconds(count: 20), + appliedTo: .header(key: "X-ApiKey"), + inside: .endpoint) + } storage: { + // Rate limit database in Redis + app.redis("rate").configuration = try? RedisConfiguration(hostname: "localhost", + port: 9090, + database: 0) + + return app.redis("rate") + } logging: { + app.logger + } + + let checkpoint = Checkpoint(using: tokenbucketAlgorithm) + + app.middleware.use(checkpoint) + + var apiKeyHeader = HTTPHeaders() + apiKeyHeader.add(name: "X-ApiKey", value: "fitomad#2") + + (0...20).forEach { index in + try? app.test(.GET, "token-bucket", headers: apiKeyHeader, afterResponse: { testResponse in + app.logger.info("token-bucket \(index) = \(testResponse.status)") + if index < 10 { + XCTAssertEqual(testResponse.status, .ok) + } else { + XCTAssertEqual(testResponse.status, .tooManyRequests) + } + }) + } + } + + func testFixedWindowCounterScopeApiWithHeader() { + let app = Application(.testing) + defer { app.shutdown() } + + app.get("fixed-window-counter") { request -> HTTPStatus in + return .ok + } + + let fixedWindowAlgorithm = FixedWindowCounter { + FixedWindowCounterConfiguration(requestPerWindow: 10, + timeWindowDuration: .minutes(count: 2), + appliedTo: .header(key: "X-ApiKey"), + inside: .endpoint) + } storage: { + // Rate limit database in Redis + app.redis("rate").configuration = try? RedisConfiguration(hostname: "localhost", + port: 9090, + database: 0) + + return app.redis("rate") + } logging: { + app.logger + } + + + let checkpoint = Checkpoint(using: fixedWindowAlgorithm) + + app.middleware.use(checkpoint) + + var apiKeyHeader = HTTPHeaders() + apiKeyHeader.add(name: "X-ApiKey", value: "fitomad#3") + + (0...20).forEach { index in + try? app.test(.GET, "fixed-window-counter", headers: apiKeyHeader, afterResponse: { testResponse in + app.logger.info("fixed-window-counter \(index) = \(testResponse.status)") + if index < 10 { + XCTAssertEqual(testResponse.status, .ok) + } else { + XCTAssertEqual(testResponse.status, .tooManyRequests) + } + }) + } + } + + func testSlidingWindowLogScopeApiWithHeader() { + let app = Application(.testing) + defer { app.shutdown() } + + app.get("sliding-window-log") { request -> HTTPStatus in + return .ok + } + + let slidingWindowLogAlgorith = SlidingWindowLog { + SlidingWindowLogConfiguration(requestPerWindow: 10, + windowDuration: .minutes(count: 2), + appliedTo: .header(key: "X-ApiKey"), + inside: .endpoint) + } storage: { + // Rate limit database in Redis + app.redis("rate").configuration = try? RedisConfiguration(hostname: "localhost", + port: 9090, + database: 0) + + return app.redis("rate") + } logging: { + app.logger + } + + + let checkpoint = Checkpoint(using: slidingWindowLogAlgorith) + + app.middleware.use(checkpoint) + + var apiKeyHeader = HTTPHeaders() + apiKeyHeader.add(name: "X-ApiKey", value: "fitomad#4") + + (0...20).forEach { index in + try? app.test(.GET, "sliding-window-log", headers: apiKeyHeader, afterResponse: { testResponse in + app.logger.info("sliding-window-log \(index) = \(testResponse.status)") + if index < 10 { + XCTAssertEqual(testResponse.status, .ok) + } else { + XCTAssertEqual(testResponse.status, .tooManyRequests) + } + }) + } + } +} diff --git a/Tests/CheckpointTests/CheckpointResponse.swift b/Tests/CheckpointTests/CheckpointResponse.swift new file mode 100644 index 0000000..051506b --- /dev/null +++ b/Tests/CheckpointTests/CheckpointResponse.swift @@ -0,0 +1,203 @@ +// +// CheckpointResponse.swift +// +// +// Created by Adolfo Vera Blasco on 24/6/24. +// + +import Redis +import XCTest +import XCTVapor +@testable import Checkpoint + +final class CheckpointResponse: XCTestCase { + func testLeakingBucketResponse() { + let app = Application(.testing) + defer { app.shutdown() } + + app.get("leaking-bucket") { request -> HTTPStatus in + return .ok + } + + let leakingBucketAlgorithm = LeakingBucket { + LeakingBucketConfiguration(bucketSize: 10, + removingRate: 5, + removingTimeInterval: .minutes(count: 1), + appliedTo: .header(key: "X-ApiKey"), + inside :.endpoint) + } storage: { + // Rate limit database in Redis + app.redis("rate").configuration = try? RedisConfiguration(hostname: "localhost", + port: 9090, + database: 0) + + return app.redis("rate") + } + + let checkpoint = Checkpoint(using: leakingBucketAlgorithm) + + checkpoint.didFailWithTooManyRequest = { (request, response, metadata) in + metadata.headers = [ + "X-RateLimit" : "Failure for request \(request.id)" + ] + } + + app.middleware.use(checkpoint) + + var apiKeyHeader = HTTPHeaders() + apiKeyHeader.add(name: "X-ApiKey", value: "fitomad#5") + + (0...20).forEach { index in + try? app.test(.GET, "leaking-bucket", headers: apiKeyHeader, afterResponse: { testResponse in + if index < 10 { + XCTAssertFalse(testResponse.headers.contains(name: "X-RateLimit")) + } else { + XCTAssertTrue(testResponse.headers.contains(name: "X-RateLimit")) + } + }) + } + } + + func testTokenBucketResponse() throws { + let app = Application(.testing) + defer { app.shutdown() } + + app.get("token-bucket") { request -> HTTPStatus in + return .ok + } + + let tokenbucketAlgorithm = TokenBucket { + TokenBucketConfiguration(bucketSize: 10, + refillRate: 0, + refillTimeInterval: .seconds(count: 20), + appliedTo: .header(key: "X-ApiKey"), + inside: .endpoint) + } storage: { + // Rate limit database in Redis + app.redis("rate").configuration = try? RedisConfiguration(hostname: "localhost", + port: 9090, + database: 0) + + return app.redis("rate") + } logging: { + app.logger + } + + let checkpoint = Checkpoint(using: tokenbucketAlgorithm) + + checkpoint.didFailWithTooManyRequest = { (request, response, metadata) in + metadata.headers = [ + "X-RateLimit" : "Failure for request \(request.id)" + ] + } + + app.middleware.use(checkpoint) + + var apiKeyHeader = HTTPHeaders() + apiKeyHeader.add(name: "X-ApiKey", value: "fitomad#6") + + (0...20).forEach { index in + try? app.test(.GET, "token-bucket", headers: apiKeyHeader, afterResponse: { testResponse in + if index < 10 { + XCTAssertFalse(testResponse.headers.contains(name: "X-RateLimit")) + } else { + XCTAssertTrue(testResponse.headers.contains(name: "X-RateLimit")) + } + }) + } + } + + func testFixedWindowCounterResponse() { + let app = Application(.testing) + defer { app.shutdown() } + + app.get("fixed-window-counter") { request -> HTTPStatus in + return .ok + } + + let fixedWindowAlgorithm = FixedWindowCounter { + FixedWindowCounterConfiguration(requestPerWindow: 10, + timeWindowDuration: .minutes(count: 2), + appliedTo: .header(key: "X-ApiKey"), + inside: .endpoint) + } storage: { + // Rate limit database in Redis + app.redis("rate").configuration = try? RedisConfiguration(hostname: "localhost", + port: 9090, + database: 0) + + return app.redis("rate") + } + + + let checkpoint = Checkpoint(using: fixedWindowAlgorithm) + + checkpoint.didFailWithTooManyRequest = { (request, response, metadata) in + metadata.headers = [ + "X-RateLimit" : "Failure for request \(request.id)" + ] + } + + app.middleware.use(checkpoint) + + var apiKeyHeader = HTTPHeaders() + apiKeyHeader.add(name: "X-ApiKey", value: "fitomad#7") + + (0...20).forEach { index in + try? app.test(.GET, "fixed-window-counter", headers: apiKeyHeader, afterResponse: { testResponse in + if index < 10 { + XCTAssertFalse(testResponse.headers.contains(name: "X-RateLimit")) + } else { + XCTAssertTrue(testResponse.headers.contains(name: "X-RateLimit")) + } + }) + } + } + + func testSlidingWindowLogResponse() { + let app = Application(.testing) + defer { app.shutdown() } + + app.get("sliding-window-log") { request -> HTTPStatus in + return .ok + } + + let slidingWindowLogAlgorith = SlidingWindowLog { + SlidingWindowLogConfiguration(requestPerWindow: 10, + windowDuration: .minutes(count: 2), + appliedTo: .header(key: "X-ApiKey"), + inside: .endpoint) + } storage: { + // Rate limit database in Redis + app.redis("rate").configuration = try? RedisConfiguration(hostname: "localhost", + port: 9090, + database: 0) + + return app.redis("rate") + } + + + let checkpoint = Checkpoint(using: slidingWindowLogAlgorith) + + checkpoint.didFailWithTooManyRequest = { (request, response, metadata) in + metadata.headers = [ + "X-RateLimit" : "Failure for request \(request.id)" + ] + } + + app.middleware.use(checkpoint) + + var apiKeyHeader = HTTPHeaders() + apiKeyHeader.add(name: "X-ApiKey", value: "fitomad#8") + + (0...20).forEach { index in + try? app.test(.GET, "sliding-window-log", headers: apiKeyHeader, afterResponse: { testResponse in + if index < 10 { + XCTAssertFalse(testResponse.headers.contains(name: "X-RateLimit")) + } else { + XCTAssertTrue(testResponse.headers.contains(name: "X-RateLimit")) + } + }) + } + } +} diff --git a/Tests/CheckpointTests/CheckpointScopeTests.swift b/Tests/CheckpointTests/CheckpointScopeTests.swift new file mode 100644 index 0000000..f73fbf4 --- /dev/null +++ b/Tests/CheckpointTests/CheckpointScopeTests.swift @@ -0,0 +1,167 @@ +// +// CheckpointScopeTests.swift +// +// +// Created by Adolfo Vera Blasco on 24/6/24. +// + +import Redis +import XCTest +import XCTVapor +@testable import Checkpoint + +final class CheckpointScopeTests: XCTestCase { + func testLeakingBucketWithScopeHeader() { + let app = Application(.testing) + defer { app.shutdown() } + + app.get("leaking-bucket") { request -> HTTPStatus in + return .ok + } + + let leakingBucketAlgorithm = LeakingBucket { + LeakingBucketConfiguration(bucketSize: 10, + removingRate: 5, + removingTimeInterval: .minutes(count: 1), + inside: .endpoint) + } storage: { + // Rate limit database in Redis + app.redis("rate").configuration = try? RedisConfiguration(hostname: "localhost", + port: 9090, + database: 0) + + return app.redis("rate") + } + + let checkpoint = Checkpoint(using: leakingBucketAlgorithm) + + app.middleware.use(checkpoint) + + (0...20).forEach { index in + try? app.test(.GET, "leaking-bucket", afterResponse: { testResponse in + app.logger.info("leaking-bucket \(index) = \(testResponse.status)") + if index < 10 { + XCTAssertEqual(testResponse.status, .ok) + } else { + XCTAssertEqual(testResponse.status, .tooManyRequests) + } + }) + } + } + + func testTokenBucketWithScopeHeader() throws { + let app = Application(.testing) + defer { app.shutdown() } + + app.get("token-bucket") { request -> HTTPStatus in + return .ok + } + + let tokenbucketAlgorithm = TokenBucket { + TokenBucketConfiguration(bucketSize: 10, + refillRate: 0, + refillTimeInterval: .seconds(count: 20), + inside: .endpoint) + } storage: { + // Rate limit database in Redis + app.redis("rate").configuration = try? RedisConfiguration(hostname: "localhost", + port: 9090, + database: 0) + + return app.redis("rate") + } logging: { + app.logger + } + + let checkpoint = Checkpoint(using: tokenbucketAlgorithm) + + app.middleware.use(checkpoint) + + (0...20).forEach { index in + try? app.test(.GET, "token-bucket", afterResponse: { testResponse in + app.logger.info("token-bucket \(index) = \(testResponse.status)") + if index < 10 { + XCTAssertEqual(testResponse.status, .ok) + } else { + XCTAssertEqual(testResponse.status, .tooManyRequests) + } + }) + } + } + + func testFixedWindowCounterWithScopeHeader() { + let app = Application(.testing) + defer { app.shutdown() } + + app.get("fixed-window-counter") { request -> HTTPStatus in + return .ok + } + + let fixedWindowAlgorithm = FixedWindowCounter { + FixedWindowCounterConfiguration(requestPerWindow: 10, + timeWindowDuration: .minutes(count: 2), + inside: .endpoint) + } storage: { + // Rate limit database in Redis + app.redis("rate").configuration = try? RedisConfiguration(hostname: "localhost", + port: 9090, + database: 0) + + return app.redis("rate") + } + + + let checkpoint = Checkpoint(using: fixedWindowAlgorithm) + + app.middleware.use(checkpoint) + + (0...20).forEach { index in + try? app.test(.GET, "fixed-window-counter", afterResponse: { testResponse in + app.logger.info("fixed-window-counter \(index) = \(testResponse.status)") + if index < 10 { + XCTAssertEqual(testResponse.status, .ok) + } else { + XCTAssertEqual(testResponse.status, .tooManyRequests) + } + }) + } + } + + func testSlidingWindowLogWithScopeHeader() { + let app = Application(.testing) + defer { app.shutdown() } + + app.get("sliding-window-log") { request -> HTTPStatus in + return .ok + } + + let slidingWindowLogAlgorith = SlidingWindowLog { + SlidingWindowLogConfiguration(requestPerWindow: 10, + windowDuration: .minutes(count: 2), + inside: .endpoint) + } storage: { + // Rate limit database in Redis + app.redis("rate").configuration = try? RedisConfiguration(hostname: "localhost", + port: 9090, + database: 0) + + return app.redis("rate") + } + + + let checkpoint = Checkpoint(using: slidingWindowLogAlgorith) + + app.middleware.use(checkpoint) + + (0...20).forEach { index in + try? app.test(.GET, "sliding-window-log", afterResponse: { testResponse in + app.logger.info("sliding-window-log \(index) = \(testResponse.status)") + if index < 10 { + XCTAssertEqual(testResponse.status, .ok) + } else { + XCTAssertEqual(testResponse.status, .tooManyRequests) + } + }) + } + } +} diff --git a/Tests/CheckpointTests/CheckpointTests.swift b/Tests/CheckpointTests/CheckpointTests.swift index c29ca31..895dde2 100644 --- a/Tests/CheckpointTests/CheckpointTests.swift +++ b/Tests/CheckpointTests/CheckpointTests.swift @@ -1,4 +1,6 @@ +import Redis import XCTest +import XCTVapor @testable import Checkpoint final class CheckpointTests: XCTestCase { @@ -9,4 +11,154 @@ final class CheckpointTests: XCTestCase { // Defining Test Cases and Test Methods // https://developer.apple.com/documentation/xctest/defining_test_cases_and_test_methods } + + func testLeakingBucket() { + let app = Application(.testing) + defer { app.shutdown() } + + app.get("leaking-bucket") { request -> HTTPStatus in + return .ok + } + + let leakingBucketAlgorithm = LeakingBucket { + LeakingBucketConfiguration(bucketSize: 10, + removingRate: 5, + removingTimeInterval: .minutes(count: 1)) + } storage: { + // Rate limit database in Redis + app.redis("rate").configuration = try? RedisConfiguration(hostname: "localhost", + port: 9090, + database: 0) + + return app.redis("rate") + } + + let checkpoint = Checkpoint(using: leakingBucketAlgorithm) + + app.middleware.use(checkpoint) + + (0...20).forEach { index in + try? app.test(.GET, "leaking-bucket", afterResponse: { testResponse in + app.logger.info("\(index) = \(testResponse.status)") + if index < 10 { + XCTAssertEqual(testResponse.status, .ok) + } else { + XCTAssertEqual(testResponse.status, .tooManyRequests) + } + }) + } + } + + func testTokenBucket() throws { + let app = Application(.testing) + defer { app.shutdown() } + + app.get("token-bucket") { request -> HTTPStatus in + return .ok + } + + let tokenbucketAlgorithm = TokenBucket { + TokenBucketConfiguration(bucketSize: 10, + refillRate: 0, + refillTimeInterval: .seconds(count: 20)) + } storage: { + // Rate limit database in Redis + app.redis("rate").configuration = try? RedisConfiguration(hostname: "localhost", + port: 9090, + database: 0) + + return app.redis("rate") + } logging: { + app.logger + } + + let checkpoint = Checkpoint(using: tokenbucketAlgorithm) + + app.middleware.use(checkpoint) + + (0...20).forEach { index in + try? app.test(.GET, "token-bucket", afterResponse: { testResponse in + app.logger.info("\(index) = \(testResponse.status)") + if index < 10 { + XCTAssertEqual(testResponse.status, .ok) + } else { + XCTAssertEqual(testResponse.status, .tooManyRequests) + } + }) + } + } + + func testFixedWindowCounter() { + let app = Application(.testing) + defer { app.shutdown() } + + app.get("fixed-window-counter") { request -> HTTPStatus in + return .ok + } + + let fixedWindowAlgorithm = FixedWindowCounter { + FixedWindowCounterConfiguration(requestPerWindow: 10, + timeWindowDuration: .minutes(count: 2)) + } storage: { + // Rate limit database in Redis + app.redis("rate").configuration = try? RedisConfiguration(hostname: "localhost", + port: 9090, + database: 0) + + return app.redis("rate") + } + + + let checkpoint = Checkpoint(using: fixedWindowAlgorithm) + + app.middleware.use(checkpoint) + + (0...20).forEach { index in + try? app.test(.GET, "fixed-window-counter", afterResponse: { testResponse in + app.logger.info("\(index) = \(testResponse.status)") + if index < 10 { + XCTAssertEqual(testResponse.status, .ok) + } else { + XCTAssertEqual(testResponse.status, .tooManyRequests) + } + }) + } + } + + func testSlidingWindowLog() { + let app = Application(.testing) + defer { app.shutdown() } + + app.get("sliding-window-log") { request -> HTTPStatus in + return .ok + } + + let slidingWindowLogAlgorith = SlidingWindowLog { + SlidingWindowLogConfiguration(requestPerWindow: 10, + windowDuration: .minutes(count: 2)) + } storage: { + // Rate limit database in Redis + app.redis("rate").configuration = try? RedisConfiguration(hostname: "localhost", + port: 9090, + database: 0) + + return app.redis("rate") + } + + + let checkpoint = Checkpoint(using: slidingWindowLogAlgorith) + + app.middleware.use(checkpoint) + + (0...20).forEach { index in + try? app.test(.GET, "sliding-window-log", afterResponse: { testResponse in + app.logger.info("\(index) = \(testResponse.status)") + if index < 10 { + XCTAssertEqual(testResponse.status, .ok) + } else { + XCTAssertEqual(testResponse.status, .tooManyRequests) + } + }) + } + } }