diff --git a/codepropertygraph/src/main/scala/io/shiftleft/passes/CpgPass.scala b/codepropertygraph/src/main/scala/io/shiftleft/passes/CpgPass.scala index e1e311172..eb4cfb248 100644 --- a/codepropertygraph/src/main/scala/io/shiftleft/passes/CpgPass.scala +++ b/codepropertygraph/src/main/scala/io/shiftleft/passes/CpgPass.scala @@ -12,7 +12,6 @@ import java.util.function.{BiConsumer, Supplier} import scala.annotation.nowarn import scala.concurrent.{Await, ExecutionContext, Future} import scala.concurrent.duration.{Duration, DurationLong} -import scala.util.control.Breaks.{break, breakable} import scala.util.{Failure, Success, Try} /* CpgPass @@ -251,48 +250,53 @@ abstract class NewStyleCpgPassBaseWithTimeout[T <: AnyRef](timeout: Long) extend case 1 => runOnPart(externalBuilder, parts(0).asInstanceOf[T]) case _ => - val stream = - if (!isParallel) - java.util.Arrays - .stream(parts) - .sequential() - else - java.util.Arrays - .stream(parts) - .parallel() - - implicit val ec: scala.concurrent.ExecutionContext = scala.concurrent.ExecutionContext.global - val stopAt = System.currentTimeMillis() + timeout * 1000 - var exitByTimeout = false - val diffGraphAccumulator = Cpg.newDiffGraphBuilder - breakable { - stream.forEach { part => - val currentTimeInMs = System.currentTimeMillis() - if (timeout == -1 || currentTimeInMs < stopAt) { - val future = Future { - val diffGraphBuilder = Cpg.newDiffGraphBuilder - runOnPart(diffGraphBuilder, part.asInstanceOf[T]) - diffGraphBuilder - } - val duration = - if timeout == -1 then Duration.Inf else Duration(stopAt - currentTimeInMs, TimeUnit.MILLISECONDS) - Try(Await.result(future, duration)) match { - case Success(diffGraphBuilder: DiffGraphBuilder) => - synchronized( - diffGraphAccumulator.absorb(diffGraphBuilder) - ) // Writing to diffGraph needs to be thread safe - case Failure(exception: TimeoutException) => println(s"Encountered timeout at thread level") - case Failure(e) => throw e + if (!isParallel) { + val diff = java.util.Arrays + .stream(parts) + .sequential() + .collect( + new Supplier[DiffGraphBuilder] { + override def get(): DiffGraphBuilder = + Cpg.newDiffGraphBuilder + }, + new BiConsumer[DiffGraphBuilder, AnyRef] { + override def accept(builder: DiffGraphBuilder, part: AnyRef): Unit = + runOnPart(builder, part.asInstanceOf[T]) + }, + new BiConsumer[DiffGraphBuilder, DiffGraphBuilder] { + override def accept(leftBuilder: DiffGraphBuilder, rightBuilder: DiffGraphBuilder): Unit = + leftBuilder.absorb(rightBuilder) } - } else { - exitByTimeout = true - break + ) + externalBuilder.absorb(diff) + } else { + implicit val ec: scala.concurrent.ExecutionContext = scala.concurrent.ExecutionContext.global + val stopAt = System.currentTimeMillis() + timeout * 1000 + var exitByTimeout = false + val diffGraphAccumulator = Cpg.newDiffGraphBuilder + + val futures = parts.map { part => + val future = Future { + val diffGraphBuilder = Cpg.newDiffGraphBuilder + runOnPart(diffGraphBuilder, part.asInstanceOf[T]) + diffGraphBuilder } + future + } + + futures.foreach { future => + val currentTimeInMs = System.currentTimeMillis() + val duration = + if timeout == -1 then Duration.Inf else Duration(stopAt - currentTimeInMs, TimeUnit.MILLISECONDS) + Try(Await.result(future, duration)) match + case Failure(exception: TimeoutException) => + baseLogger.debug(s"Timeout occurred for passed timeout value of ${timeout} seconds") + case Failure(e) => throw e + case Success(diffGraphBuilder) => + diffGraphAccumulator.absorb(diffGraphBuilder) } + externalBuilder.absorb(diffGraphAccumulator) } - if (exitByTimeout) - println("Timeout exception encountered, continuing with partial result") - externalBuilder.absorb(diffGraphAccumulator) } nParts } finally {