
RateLimiter - consider whole operation execution time #251

Merged: 11 commits, Dec 17, 2024
8 changes: 8 additions & 0 deletions README.md
@@ -103,6 +103,14 @@ def computationR: Int = ???
repeat(RepeatConfig.fixedRateForever(100.millis))(computationR)
```

[Rate limit](https://ox.softwaremill.com/latest/utils/rate-limiter.html) computations:

```scala mdoc:compile-only
supervised:
val rateLimiter = RateLimiter.fixedWindowWithStartTime(2, 1.second)
rateLimiter.runBlocking({ /* ... */ })
```
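
A further sketch (illustrative only, assuming the same imports as the other README snippets): the duration-aware variant added in this PR, combined with `runOrDrop`, which returns `None` instead of blocking when the limit is reached:

```scala
supervised:
  // counts an operation against every window it overlaps, not just the one it starts in
  val rateLimiter = RateLimiter.fixedWindowWithDuration(2, 1.second)
  val result = rateLimiter.runOrDrop({ /* ... */ }) // None if the limit was reached
```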

Allocate a [resource](https://ox.softwaremill.com/latest/utils/resources.html) in a scope:

```scala mdoc:compile-only
117 changes: 117 additions & 0 deletions core/src/main/scala/ox/resilience/DurationRateLimiterAlgorithm.scala
@@ -0,0 +1,117 @@
package ox.resilience

import java.util.concurrent.Semaphore
import java.util.concurrent.atomic.AtomicInteger
import java.util.concurrent.atomic.AtomicLong
import java.util.concurrent.atomic.AtomicReference
import scala.annotation.tailrec
import scala.collection.immutable.Queue
import scala.concurrent.duration.FiniteDuration
import ox.discard

/** Algorithms which take into account the entire duration of the operation.
*
* There is no leakyBucket algorithm implemented here (unlike in [[StartTimeRateLimiterAlgorithm]]), because it would effectively
* result in "max number of operations currently running", which can be achieved with a single semaphore.
*/
object DurationRateLimiterAlgorithm:
/** Fixed window algorithm: allows running at most `rate` operations in consecutive segments of duration `per`. Considers the whole
* execution time of an operation. An operation spanning more than one window blocks permits in all windows that it spans.
*/
case class FixedWindow(rate: Int, per: FiniteDuration) extends RateLimiterAlgorithm:
private val lastUpdate = new AtomicLong(System.nanoTime())
private val semaphore = new Semaphore(rate)
private val runningOperations = new AtomicInteger(0)

def acquire(permits: Int): Unit =
semaphore.acquire(permits)

def tryAcquire(permits: Int): Boolean =
semaphore.tryAcquire(permits)

def getNextUpdate: Long =
val waitTime = lastUpdate.get() + per.toNanos - System.nanoTime()
if waitTime > 0 then waitTime else 0L

def update(): Unit =
val now = System.nanoTime()
lastUpdate.set(now)
// We treat an operation still running in a new window the same as a new operation started in this window, so we replenish permits to: rate - runningOperations
semaphore.release(rate - semaphore.availablePermits() - runningOperations.get())
end update

def runOperation[T](operation: => T, permits: Int): T =
runningOperations.updateAndGet(_ + permits)
try operation
finally runningOperations.updateAndGet(_ - permits).discard

end FixedWindow
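
A hedged usage sketch (not part of this diff): the duration-aware fixed window can be obtained either through the dedicated factory method or by passing the algorithm to `RateLimiter.apply`, assuming `apply` accepts an algorithm as the hidden lines of this diff suggest.

```scala
import ox.*
import ox.resilience.*
import scala.concurrent.duration.*

supervised:
  // equivalent ways of obtaining a duration-aware fixed-window limiter
  val viaFactory   = RateLimiter.fixedWindowWithDuration(10, 1.second)
  val viaAlgorithm = RateLimiter(DurationRateLimiterAlgorithm.FixedWindow(10, 1.second))

  // the operation holds one of the 10 permits in every window it overlaps
  viaFactory.runBlocking({ /* ... */ })
```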

/** Sliding window algorithm: allows running at most `rate` operations within the interval of `per` before the current time. Considers the
* whole execution time of an operation. An operation releases its permits once `per` has passed since the operation ended.
*/
case class SlidingWindow(rate: Int, per: FiniteDuration) extends RateLimiterAlgorithm:
// stores the timestamp and the number of permits acquired, recorded after an operation finishes running
private val log = new AtomicReference[Queue[(Long, Int)]](Queue[(Long, Int)]())
private val semaphore = new Semaphore(rate)

def acquire(permits: Int): Unit =
semaphore.acquire(permits)

def tryAcquire(permits: Int): Boolean =
semaphore.tryAcquire(permits)

private def addTimestampToLog(permits: Int): Unit =
val now = System.nanoTime()
log.updateAndGet { q =>
q.enqueue((now, permits))
}
()

def getNextUpdate: Long =
log.get().headOption match
case None =>
// no logs so no need to update until `per` has passed
per.toNanos
case Some(record) =>
// oldest log provides the new updating point
val waitTime = record._1 + per.toNanos - System.nanoTime()
if waitTime > 0 then waitTime else 0L
end getNextUpdate

def runOperation[T](operation: => T, permits: Int): T =
try operation
// The end of the operation is the point from which `per` must pass before the permits are released
finally addTimestampToLog(permits)

def update(): Unit =
val now = System.nanoTime()
// retrieving current queue to append it later if some elements were added concurrently
val q = log.getAndUpdate(_ => Queue[(Long, Int)]())
// remove records older than window size
val qUpdated = removeRecords(q, now)
// merge old records with the ones concurrently added
log.updateAndGet(qNew =>
qNew.foldLeft(qUpdated) { case (queue, record) =>
queue.enqueue(record)
}
)
()
end update

@tailrec
private def removeRecords(q: Queue[(Long, Int)], now: Long): Queue[(Long, Int)] =
q.dequeueOption match
case None => q
case Some((head, tail)) =>
if head._1 + per.toNanos < now then
val (_, permits) = head
semaphore.release(permits)
removeRecords(tail, now)
else q
end match
end removeRecords

end SlidingWindow
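
To make the release timing concrete, a hedged sketch (timings are illustrative, not taken from this PR's tests; imports as in the previous sketch): with the duration-aware sliding window, a permit frees only once `per` has passed since the operation ended.

```scala
supervised:
  val limiter = RateLimiter.slidingWindowWithDuration(1, 1.second)

  limiter.runBlocking(Thread.sleep(300)) // ends at ~0.3s; its permit frees at ~1.3s
  val dropped = limiter.runOrDrop(())    // still within 1s of the previous end: None
  limiter.runBlocking(())                // blocks until ~1.3s, then runs
```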

end DurationRateLimiterAlgorithm
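
The object-level note above says a duration-aware leaky bucket would collapse into "max number of operations currently running"; a minimal sketch of that simpler primitive, using just a semaphore (`MaxConcurrent` is a hypothetical helper, not an API introduced by this PR):

```scala
import java.util.concurrent.Semaphore

// Caps how many operations run at the same time - the behaviour a
// duration-aware leaky bucket would effectively reduce to.
class MaxConcurrent(limit: Int):
  private val semaphore = new Semaphore(limit)

  def run[T](operation: => T): T =
    semaphore.acquire()
    try operation
    finally semaphore.release()
```
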
70 changes: 55 additions & 15 deletions core/src/main/scala/ox/resilience/RateLimiter.scala
@@ -2,23 +2,24 @@ package ox.resilience

import scala.concurrent.duration.FiniteDuration
import ox.*

import scala.annotation.tailrec

/** Rate limiter with a customizable algorithm. Operations can be blocked or dropped, when the rate limit is reached. */
/** Rate limiter with a customizable algorithm. Operations can be blocked or dropped when the rate limit is reached. The rate limiter might
* take into account the start time of the operation, or its entire duration.
*/
class RateLimiter private (algorithm: RateLimiterAlgorithm):
/** Runs the operation, blocking if the rate limit is reached, until the rate limiter is replenished. */
def runBlocking[T](operation: => T): T =
algorithm.acquire()
operation
algorithm.runOperation(operation)

/** Runs or drops the operation, if the rate limit is reached.
/** Runs the operation or drops it, if the rate limit is reached.
*
* @return
* `Some` if the operation has been allowed to run, `None` if the operation has been dropped.
* `Some` if the operation has been run, `None` if the operation has been dropped.
*/
def runOrDrop[T](operation: => T): Option[T] =
if algorithm.tryAcquire() then Some(operation)
if algorithm.tryAcquire() then Some(algorithm.runOperation(operation))
else None

end RateLimiter
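
A hedged sketch of the two run modes (assuming `import ox.*`, `import ox.resilience.*` and `import scala.concurrent.duration.*`; the string results are placeholders):

```scala
supervised:
  val limiter = RateLimiter.fixedWindowWithDuration(1, 1.second)

  val first: String          = limiter.runBlocking("ran") // a permit is available: runs immediately
  val second: Option[String] = limiter.runOrDrop("ran")   // limit reached within this window: None
  val third: String          = limiter.runBlocking("ran") // blocks until the window is replenished
```
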
@@ -39,32 +40,36 @@ object RateLimiter:
new RateLimiter(algorithm)
end apply

/** Creates a rate limiter using a fixed window algorithm.
/** Creates a rate limiter using a fixed window algorithm. Takes into account the start time of the operation only.
*
* Must be run within an [[Ox]] concurrency scope, as a background fork is created, to replenish the rate limiter.
*
* @param maxOperations
* Maximum number of operations that are allowed to **start** within a time [[window]].
* @param window
* Interval of time between replenishing the rate limiter. THe rate limiter is replenished to allow up to [[maxOperations]] in the next
* Interval of time between replenishing the rate limiter. The rate limiter is replenished to allow up to [[maxOperations]] in the next
* time window.
* @see
* [[fixedWindowWithDuration]]
*/
def fixedWindow(maxOperations: Int, window: FiniteDuration)(using Ox): RateLimiter =
apply(RateLimiterAlgorithm.FixedWindow(maxOperations, window))
def fixedWindowWithStartTime(maxOperations: Int, window: FiniteDuration)(using Ox): RateLimiter =
apply(StartTimeRateLimiterAlgorithm.FixedWindow(maxOperations, window))

/** Creates a rate limiter using a sliding window algorithm.
/** Creates a rate limiter using a sliding window algorithm. Takes into account the start time of the operation only.
*
* Must be run within an [[Ox]] concurrency scope, as a background fork is created, to replenish the rate limiter.
*
* @param maxOperations
* Maximum number of operations that are allowed to **start** within any [[window]] of time.
* @param window
* Length of the window.
* @see
* [[slidingWindowWithDuration]]
*/
def slidingWindow(maxOperations: Int, window: FiniteDuration)(using Ox): RateLimiter =
apply(RateLimiterAlgorithm.SlidingWindow(maxOperations, window))
def slidingWindowWithStartTime(maxOperations: Int, window: FiniteDuration)(using Ox): RateLimiter =
apply(StartTimeRateLimiterAlgorithm.SlidingWindow(maxOperations, window))

/** Rate limiter with token/leaky bucket algorithm.
/** Creates a rate limiter with token/leaky bucket algorithm. Takes into account the start time of the operation only.
*
* Must be run within an [[Ox]] concurrency scope, as a background fork is created, to replenish the rate limiter.
*
@@ -74,5 +79,40 @@
* Interval of time between adding a single token to the bucket.
*/
def leakyBucket(maxTokens: Int, refillInterval: FiniteDuration)(using Ox): RateLimiter =
apply(RateLimiterAlgorithm.LeakyBucket(maxTokens, refillInterval))
apply(StartTimeRateLimiterAlgorithm.LeakyBucket(maxTokens, refillInterval))
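
For completeness, a hedged sketch of the (start-time-based) leaky bucket factory above, with the same import assumptions as the earlier sketches:

```scala
supervised:
  // at most 5 stored tokens; one token is added every 100 milliseconds
  val limiter = RateLimiter.leakyBucket(5, 100.millis)
  limiter.runBlocking({ /* ... */ })
```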

/** Creates a rate limiter with a fixed window algorithm.
*
* Takes into account the entire duration of the operation. That is, the instant at which the operation "happens" can be anywhere between
* its start and end. This ensures that the rate limit is always respected, although it might make it more restrictive.
*
* Must be run within an [[Ox]] concurrency scope, as a background fork is created, to replenish the rate limiter.
*
* @param maxOperations
* Maximum number of operations that are allowed to **run** (still finishing from previous windows or newly started) within a time [[window]].
* @param window
* Length of the window.
* @see
* [[fixedWindowWithStartTime]]
*/
def fixedWindowWithDuration(maxOperations: Int, window: FiniteDuration)(using Ox): RateLimiter =
apply(DurationRateLimiterAlgorithm.FixedWindow(maxOperations, window))
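
A hedged sketch contrasting the two fixed-window variants (same import assumptions as above): the duration-aware limiter keeps a long-running operation counted in every window it overlaps, while the start-time variant only counts it in the window where it started.

```scala
supervised:
  val byStartTime = RateLimiter.fixedWindowWithStartTime(2, 1.second)
  val byDuration  = RateLimiter.fixedWindowWithDuration(2, 1.second)

  // counted only in the first window, even though it runs for ~2.5s
  byStartTime.runBlocking(Thread.sleep(2500))

  // occupies one of the 2 permits in every window it overlaps (~3 windows)
  byDuration.runBlocking(Thread.sleep(2500))
```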

/** Creates a rate limiter using a sliding window algorithm.
*
* Takes into account the entire duration of the operation. That is, the instant at which the operation "happens" can be anywhere between
* its start and end. This ensures that the rate limit is always respected, although it might make it more restrictive.
*
* Must be run within an [[Ox]] concurrency scope, as a background fork is created, to replenish the rate limiter.
*
* @param maxOperations
* Maximum number of operations that are allowed to **run** (starting or still finishing) within any [[window]] of time.
* @param window
* Length of the window.
* @see
* [[slidingWindowWithStartTime]]
*/
def slidingWindowWithDuration(maxOperations: Int, window: FiniteDuration)(using Ox): RateLimiter =
apply(DurationRateLimiterAlgorithm.SlidingWindow(maxOperations, window))

end RateLimiter
118 changes: 4 additions & 114 deletions core/src/main/scala/ox/resilience/RateLimiterAlgorithm.scala
@@ -1,12 +1,5 @@
package ox.resilience

import scala.concurrent.duration.FiniteDuration
import scala.collection.immutable.Queue
import java.util.concurrent.atomic.AtomicLong
import java.util.concurrent.atomic.AtomicReference
import java.util.concurrent.Semaphore
import scala.annotation.tailrec

/** Determines the algorithm to use for the rate limiter */
trait RateLimiterAlgorithm:

Expand All @@ -30,113 +23,10 @@ trait RateLimiterAlgorithm:
/** Returns the time in nanoseconds that needs to elapse until the next update. It should not modify internal state. */
def getNextUpdate: Long

end RateLimiterAlgorithm

object RateLimiterAlgorithm:
/** Fixed window algorithm: allows starting at most `rate` operations in consecutive segments of duration `per`. */
case class FixedWindow(rate: Int, per: FiniteDuration) extends RateLimiterAlgorithm:
private val lastUpdate = new AtomicLong(System.nanoTime())
private val semaphore = new Semaphore(rate)

def acquire(permits: Int): Unit =
semaphore.acquire(permits)

def tryAcquire(permits: Int): Boolean =
semaphore.tryAcquire(permits)

def getNextUpdate: Long =
val waitTime = lastUpdate.get() + per.toNanos - System.nanoTime()
if waitTime > 0 then waitTime else 0L

def update(): Unit =
val now = System.nanoTime()
lastUpdate.set(now)
semaphore.release(rate - semaphore.availablePermits())
end update

end FixedWindow

/** Sliding window algorithm: allows starting at most `rate` operations within the interval of `per` before the current time. */
case class SlidingWindow(rate: Int, per: FiniteDuration) extends RateLimiterAlgorithm:
// stores the timestamp and the number of permits acquired after calling acquire or tryAcquire successfully
private val log = new AtomicReference[Queue[(Long, Int)]](Queue[(Long, Int)]())
private val semaphore = new Semaphore(rate)

def acquire(permits: Int): Unit =
semaphore.acquire(permits)
addTimestampToLog(permits)

def tryAcquire(permits: Int): Boolean =
if semaphore.tryAcquire(permits) then
addTimestampToLog(permits)
true
else false

private def addTimestampToLog(permits: Int): Unit =
val now = System.nanoTime()
log.updateAndGet { q =>
q.enqueue((now, permits))
}
()

def getNextUpdate: Long =
log.get().headOption match
case None =>
// no logs so no need to update until `per` has passed
per.toNanos
case Some(record) =>
// oldest log provides the new updating point
val waitTime = record._1 + per.toNanos - System.nanoTime()
if waitTime > 0 then waitTime else 0L
end getNextUpdate

def update(): Unit =
val now = System.nanoTime()
// retrieving current queue to append it later if some elements were added concurrently
val q = log.getAndUpdate(_ => Queue[(Long, Int)]())
// remove records older than window size
val qUpdated = removeRecords(q, now)
// merge old records with the ones concurrently added
val _ = log.updateAndGet(qNew =>
qNew.foldLeft(qUpdated) { case (queue, record) =>
queue.enqueue(record)
}
)
end update

@tailrec
private def removeRecords(q: Queue[(Long, Int)], now: Long): Queue[(Long, Int)] =
q.dequeueOption match
case None => q
case Some((head, tail)) =>
if head._1 + per.toNanos < now then
val (_, permits) = head
semaphore.release(permits)
removeRecords(tail, now)
else q

end SlidingWindow

/** Token/leaky bucket algorithm. It adds a token, allowing a new operation to start, every `per`, up to a maximum of `rate` tokens. */
case class LeakyBucket(rate: Int, per: FiniteDuration) extends RateLimiterAlgorithm:
private val refillInterval = per.toNanos
private val lastRefillTime = new AtomicLong(System.nanoTime())
private val semaphore = new Semaphore(1)

def acquire(permits: Int): Unit =
semaphore.acquire(permits)

def tryAcquire(permits: Int): Boolean =
semaphore.tryAcquire(permits)

def getNextUpdate: Long =
val waitTime = lastRefillTime.get() + refillInterval - System.nanoTime()
if waitTime > 0 then waitTime else 0L
/** Runs the operation, allowing the algorithm to take into account its duration, if needed. */
final def runOperation[T](operation: => T): T = runOperation(operation, 1)

def update(): Unit =
val now = System.nanoTime()
lastRefillTime.set(now)
if semaphore.availablePermits() < rate then semaphore.release()
/** Runs the operation, allowing the algorithm to take into account its duration, if needed. */
def runOperation[T](operation: => T, permits: Int): T

end LeakyBucket
end RateLimiterAlgorithm
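
A hedged sketch of a custom implementation of this trait: a plain cap on concurrently running operations. The method signatures are those visible in this diff; whether the trait has further abstract members hidden by the collapsed lines is an assumption, and `MaxConcurrentAlgorithm` is a hypothetical name.

```scala
import java.util.concurrent.Semaphore
import scala.concurrent.duration.*
import ox.resilience.RateLimiterAlgorithm

// Hypothetical: admits at most `limit` operations at a time, with no windowing.
class MaxConcurrentAlgorithm(limit: Int) extends RateLimiterAlgorithm:
  private val semaphore = new Semaphore(limit)

  def acquire(permits: Int): Unit = semaphore.acquire(permits)
  def tryAcquire(permits: Int): Boolean = semaphore.tryAcquire(permits)

  // release the permits as soon as the operation finishes, so the permits
  // are held for exactly the duration of the operation
  def runOperation[T](operation: => T, permits: Int): T =
    try operation
    finally semaphore.release(permits)

  // nothing is replenished on a schedule
  def update(): Unit = ()
  def getNextUpdate: Long = 1.second.toNanos
```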