Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RateLimiter - consider whole operation execution time #251

Merged
merged 11 commits into from
Dec 17, 2024
108 changes: 108 additions & 0 deletions core/src/main/scala/ox/resilience/DurationRateLimiterAlgorithm.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
package ox.resilience

import java.util.concurrent.Semaphore
import java.util.concurrent.atomic.AtomicInteger
import java.util.concurrent.atomic.AtomicLong
import java.util.concurrent.atomic.AtomicReference
import scala.annotation.tailrec
import scala.collection.immutable.Queue
import scala.concurrent.duration.FiniteDuration

object DurationRateLimiterAlgorithm:
Kamil-Lontkowski marked this conversation as resolved.
Show resolved Hide resolved
/** Fixed window algorithm: allows to run at most `rate` operations in consecutive segments of duration `per`.
  *
  * Permits held by operations that are still executing are not replenished by [[update]], so the whole execution time of an operation
  * counts against the limit.
  */
case class FixedWindow(rate: Int, per: FiniteDuration) extends RateLimiterAlgorithm:
  private val lastUpdate = new AtomicLong(System.nanoTime())
  private val semaphore = new Semaphore(rate)
  // number of permits held by operations that are currently executing
  private val runningOperations = new AtomicInteger(0)

  def acquire(permits: Int): Unit =
    semaphore.acquire(permits)

  def tryAcquire(permits: Int): Boolean =
    semaphore.tryAcquire(permits)

  /** Nanoseconds left until the current window ends, or 0 if it has already elapsed. */
  def getNextUpdate: Long =
    val waitTime = lastUpdate.get() + per.toNanos - System.nanoTime()
    if waitTime > 0 then waitTime else 0L

  /** Starts a new window: replenishes all permits except those held by operations that are still running. */
  def update(): Unit =
    val now = System.nanoTime()
    lastUpdate.set(now)
    // An operation may have acquired its permits but not yet registered itself as running (or the other way round on completion), so the
    // computed count can transiently be negative; `Semaphore.release` rejects negative values, hence the guard.
    val toRelease = rate - semaphore.availablePermits() - runningOperations.get()
    if toRelease > 0 then semaphore.release(toRelease)
  end update

  /** Runs `operation`, counting `permits` as running for as long as it executes.
    *
    * The counter is decremented in a `finally`, so an operation that throws does not leave its permits marked as running forever.
    */
  def runOperation[T](operation: => T, permits: Int): T =
    runningOperations.updateAndGet(_ + permits)
    try operation
    finally
      runningOperations.updateAndGet(current => (current - permits).max(0))
      ()
  end runOperation

end FixedWindow

/** Sliding window algorithm: allows to run at most `rate` operations in the lapse of `per` before current time.
  *
  * The permits of an operation are given back `per` after the operation has finished, whether it completed normally or by throwing.
  */
case class SlidingWindow(rate: Int, per: FiniteDuration) extends RateLimiterAlgorithm:
  // stores the timestamp and the number of permits acquired after finishing running operation
  private val log = new AtomicReference[Queue[(Long, Int)]](Queue[(Long, Int)]())
  private val semaphore = new Semaphore(rate)

  def acquire(permits: Int): Unit =
    semaphore.acquire(permits)

  def tryAcquire(permits: Int): Boolean =
    semaphore.tryAcquire(permits)

  /** Appends a record with the current time and the given number of permits to the log. */
  private def addTimestampToLog(permits: Int): Unit =
    val now = System.nanoTime()
    log.updateAndGet { q =>
      q.enqueue((now, permits))
    }
    ()

  /** Nanoseconds until the oldest log record leaves the window; `per` if the log is empty, 0 if the record is already due. */
  def getNextUpdate: Long =
    log.get().headOption match
      case None =>
        // no logs so no need to update until `per` has passed
        per.toNanos
      case Some(record) =>
        // oldest log provides the new updating point
        val waitTime = record._1 + per.toNanos - System.nanoTime()
        if waitTime > 0 then waitTime else 0L
  end getNextUpdate

  /** Runs `operation` and records its end as the point from which its permits are released after `per` passes.
    *
    * The record is added in a `finally`: without it, an operation that throws would never be logged and its permits would never be
    * released.
    */
  def runOperation[T](operation: => T, permits: Int): T =
    try operation
    finally addTimestampToLog(permits)

  /** Releases the permits of all records older than the window, keeping the remaining ones in the log. */
  def update(): Unit =
    val now = System.nanoTime()
    // retrieving current queue to append it later if some elements were added concurrently
    val q = log.getAndUpdate(_ => Queue[(Long, Int)]())
    // remove records older than window size
    val qUpdated = removeRecords(q, now)
    // merge old records with the ones concurrently added
    log.updateAndGet(qNew =>
      qNew.foldLeft(qUpdated) { case (queue, record) =>
        queue.enqueue(record)
      }
    )
    ()
  end update

  /** Drops records older than `per` from the front of `q`, releasing their permits; returns the rest of the queue. */
  @tailrec
  private def removeRecords(q: Queue[(Long, Int)], now: Long): Queue[(Long, Int)] =
    q.dequeueOption match
      case None => q
      case Some((head, tail)) =>
        if head._1 + per.toNanos < now then
          val (_, permits) = head
          semaphore.release(permits)
          removeRecords(tail, now)
        else q
    end match
  end removeRecords

end SlidingWindow

end DurationRateLimiterAlgorithm
40 changes: 30 additions & 10 deletions core/src/main/scala/ox/resilience/RateLimiter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,25 @@ package ox.resilience
import scala.concurrent.duration.FiniteDuration
import ox.*

import java.util.concurrent.Semaphore
import scala.annotation.tailrec

/** Rate limiter with a customizable algorithm. Operations can be blocked or dropped, when the rate limit is reached. */
/** Rate limiter with a customizable algorithm. Operations can be blocked or dropped when the rate limit is reached. The `operationMode`
 * chosen when creating the rate limiter decides whether the whole execution time of an operation is considered, or just its start.
 */
class RateLimiter private (algorithm: RateLimiterAlgorithm):
/** Runs the operation, blocking if the rate limit is reached, until the rate limiter is replenished. */
def runBlocking[T](operation: => T): T =
algorithm.acquire()
operation
algorithm.runOperation(operation)

/** Runs or drops the operation, if the rate limit is reached.
*
* @return
* `Some` if the operation has been allowed to run, `None` if the operation has been dropped.
*/
def runOrDrop[T](operation: => T): Option[T] =
if algorithm.tryAcquire() then Some(operation)
if algorithm.tryAcquire() then Some(algorithm.runOperation(operation))
else None

end RateLimiter
Expand Down Expand Up @@ -46,11 +49,15 @@ object RateLimiter:
* @param maxOperations
* Maximum number of operations that are allowed to **start** within a time [[window]].
* @param window
* Interval of time between replenishing the rate limiter. THe rate limiter is replenished to allow up to [[maxOperations]] in the next
* Interval of time between replenishing the rate limiter. The rate limiter is replenished to allow up to [[maxOperations]] in the next
* time window.
*/
def fixedWindow(maxOperations: Int, window: FiniteDuration)(using Ox): RateLimiter =
apply(RateLimiterAlgorithm.FixedWindow(maxOperations, window))
def fixedWindow(maxOperations: Int, window: FiniteDuration, operationMode: RateLimiterMode = RateLimiterMode.OperationStart)(using
Kamil-Lontkowski marked this conversation as resolved.
Show resolved Hide resolved
Ox
): RateLimiter =
operationMode match
case RateLimiterMode.OperationStart => apply(RateLimiterAlgorithm.FixedWindow(maxOperations, window))
case RateLimiterMode.OperationDuration => apply(DurationRateLimiterAlgorithm.FixedWindow(maxOperations, window))

/** Creates a rate limiter using a sliding window algorithm.
*
Expand All @@ -61,10 +68,14 @@ object RateLimiter:
* @param window
* Length of the window.
*/
def slidingWindow(maxOperations: Int, window: FiniteDuration)(using Ox): RateLimiter =
apply(RateLimiterAlgorithm.SlidingWindow(maxOperations, window))
def slidingWindow(maxOperations: Int, window: FiniteDuration, operationMode: RateLimiterMode = RateLimiterMode.OperationStart)(using
Ox
): RateLimiter =
operationMode match
case RateLimiterMode.OperationStart => apply(RateLimiterAlgorithm.SlidingWindow(maxOperations, window))
case RateLimiterMode.OperationDuration => apply(DurationRateLimiterAlgorithm.SlidingWindow(maxOperations, window))

/** Rate limiter with token/leaky bucket algorithm.
/** Creates a rate limiter with token/leaky bucket algorithm.
*
* Must be run within an [[Ox]] concurrency scope, as a background fork is created, to replenish the rate limiter.
*
Expand All @@ -73,6 +84,15 @@ object RateLimiter:
* @param refillInterval
* Interval of time between adding a single token to the bucket.
*/
def leakyBucket(maxTokens: Int, refillInterval: FiniteDuration)(using Ox): RateLimiter =
def leakyBucket(maxTokens: Int, refillInterval: FiniteDuration)(using
Ox
): RateLimiter =
apply(RateLimiterAlgorithm.LeakyBucket(maxTokens, refillInterval))

end RateLimiter

/** Selects which part of an operation a rate limiter takes into account:
  *   - `OperationStart`: only the moment the operation starts
  *   - `OperationDuration`: the whole time the operation is executing
  */
enum RateLimiterMode:
  case OperationStart, OperationDuration
17 changes: 15 additions & 2 deletions core/src/main/scala/ox/resilience/RateLimiterAlgorithm.scala
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,12 @@ trait RateLimiterAlgorithm:
/** Returns the time in nanoseconds that needs to elapse until the next update. It should not modify internal state. */
def getNextUpdate: Long

/** Runs the operation on behalf of a single permit, by delegating to the two-argument `runOperation` with `permits` set to 1. */
final def runOperation[T](operation: => T): T = runOperation(operation, 1)

/** Runs the operation on behalf of `permits` permits and returns its result. Algorithms that do not need the execution time just evaluate
 * the operation and return the result.
 */
def runOperation[T](operation: => T, permits: Int): T

end RateLimiterAlgorithm

object RateLimiterAlgorithm:
Expand All @@ -54,6 +60,8 @@ object RateLimiterAlgorithm:
semaphore.release(rate - semaphore.availablePermits())
end update

// `permits` is unused here: the operation is evaluated as-is, without tracking its execution.
def runOperation[T](operation: => T, permits: Int): T = operation

end FixedWindow

/** Sliding window algorithm: allows to start at most `rate` operations in the lapse of `per` before current time. */
Expand Down Expand Up @@ -97,11 +105,12 @@ object RateLimiterAlgorithm:
// remove records older than window size
val qUpdated = removeRecords(q, now)
// merge old records with the ones concurrently added
val _ = log.updateAndGet(qNew =>
log.updateAndGet(qNew =>
qNew.foldLeft(qUpdated) { case (queue, record) =>
queue.enqueue(record)
}
)
()
end update

@tailrec
Expand All @@ -115,9 +124,11 @@ object RateLimiterAlgorithm:
removeRecords(tail, now)
else q

// `permits` is unused here: the operation is evaluated as-is, without tracking its execution.
def runOperation[T](operation: => T, permits: Int): T = operation

end SlidingWindow

/** Token/leaky bucket algorithm It adds a token to start an new operation each `per` with a maximum number of tokens of `rate`. */
/** Token/leaky bucket algorithm It adds a token to start a new operation each `per` with a maximum number of tokens of `rate`. */
case class LeakyBucket(rate: Int, per: FiniteDuration) extends RateLimiterAlgorithm:
private val refillInterval = per.toNanos
private val lastRefillTime = new AtomicLong(System.nanoTime())
Expand All @@ -138,5 +149,7 @@ object RateLimiterAlgorithm:
lastRefillTime.set(now)
if semaphore.availablePermits() < rate then semaphore.release()

// `permits` is unused here: the operation is evaluated as-is, without tracking its execution.
def runOperation[T](operation: => T, permits: Int): T = operation

end LeakyBucket
end RateLimiterAlgorithm
Loading