Skip to content

Commit

Permalink
fix #164 (#168)
Browse files Browse the repository at this point in the history
  • Loading branch information
lbialy authored Jul 9, 2024
1 parent f3332bd commit 3228cc2
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 5 deletions.
5 changes: 4 additions & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ import com.softwaremill.SbtSoftwareMillCommon.commonSmlBuildSettings
import com.softwaremill.Publish.{ossPublishSettings, updateDocs}
import com.softwaremill.UpdateVersionInDocs

Global / cancelable := true

lazy val commonSettings = commonSmlBuildSettings ++ ossPublishSettings ++ Seq(
organization := "com.softwaremill.ox",
scalaVersion := "3.3.3",
Expand Down Expand Up @@ -50,7 +52,8 @@ lazy val core: Project = (project in file("core"))
scalaTest
),
// Check IO usage in core
useRequireIOPlugin
useRequireIOPlugin,
Test / fork := true
)

lazy val plugin: Project = (project in file("plugin"))
Expand Down
21 changes: 17 additions & 4 deletions core/src/main/scala/ox/fork.scala
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ def forkAll[T](fs: Seq[() => T])(using Ox): Fork[Seq[T]] =
val forks = fs.map(f => fork(f()))
new Fork[Seq[T]]:
override def join(): Seq[T] = forks.map(_.join())
override private[ox] def wasInterruptedWith(ie: InterruptedException): Boolean = forks.exists(_.wasInterruptedWith(ie))

/** Starts a fork (logical thread of execution), which is guaranteed to complete before the enclosing [[supervised]], [[supervisedError]] or
* [[unsupervised]] block completes, and which can be cancelled on-demand.
Expand Down Expand Up @@ -177,8 +178,13 @@ def forkCancellable[T](f: => T)(using OxUnsupervised): CancellableFork[T] =
if !started.getAndSet(true)
then result.completeExceptionally(new InterruptedException("fork was cancelled before it started")).discard

override private[ox] def wasInterruptedWith(ie: InterruptedException): Boolean =
result.isCompletedExceptionally && (result.exceptionNow() eq ie)

private def newForkUsingResult[T](result: CompletableFuture[T]): Fork[T] = new Fork[T]:
override def join(): T = unwrapExecutionException(result.get())
override private[ox] def wasInterruptedWith(ie: InterruptedException): Boolean =
result.isCompletedExceptionally && (result.exceptionNow() eq ie)

private[ox] inline def unwrapExecutionException[T](f: => T): T =
try f
Expand Down Expand Up @@ -208,16 +214,23 @@ trait Fork[T]:
def joinEither(): Either[Throwable, T] =
try Right(join())
catch
// normally IE is fatal, but here it was meant to cancel the fork, not the joining parent, hence we catch it
case e: InterruptedException => Left(e)
// normally IE is fatal, but here it could have meant that the fork was cancelled, hence we catch it
// we do discern between the fork and the current thread being cancelled and rethrow if it's us who's getting the axe
case e: InterruptedException => if wasInterruptedWith(e) then Left(e) else throw e
case NonFatal(e) => Left(e)

private[ox] def wasInterruptedWith(ie: InterruptedException): Boolean

object Fork:
/** A dummy pretending to represent a fork which successfully completed with the given value. */
def successful[T](value: T): Fork[T] = () => value
def successful[T](value: T): Fork[T] = new Fork[T]:
override def join(): T = value
override private[ox] def wasInterruptedWith(ie: InterruptedException): Boolean = false

/** A dummy pretending to represent a fork which failed with the given exception. */
def failed[T](e: Throwable): Fork[T] = () => throw e
def failed[T](e: Throwable): Fork[T] = new Fork[T]:
override def join(): T = throw e
override private[ox] def wasInterruptedWith(ie: InterruptedException): Boolean = e eq ie

/** A fork started using [[forkCancellable]], backed by a (virtual) thread. */
trait CancellableFork[T] extends Fork[T]:
Expand Down
27 changes: 27 additions & 0 deletions core/src/test/scala/ox/SupervisedTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -123,4 +123,31 @@ class SupervisedTest extends AnyFlatSpec with Matchers {
trail.add("done")
trail.get shouldBe Vector("b", "a", "done")
}

it should "handle interruption of multiple forks with `joinEither` correctly" in {
val e = intercept[Exception] {
supervised {
def computation(withException: Option[String]): Int = {
withException match
case None => 1
case Some(value) =>
throw new Exception(value)
}

val fork1 = fork:
computation(withException = None)
val fork2 = fork:
computation(withException = Some("Oh no!"))
val fork3 = fork:
computation(withException = Some("Oh well.."))

fork1.joinEither() // 1
fork2.joinEither() // 2
fork3.joinEither() // 3
}
}

e.getMessage should startWith("Oh")
}

}

0 comments on commit 3228cc2

Please sign in to comment.