diff --git a/codepropertygraph/src/main/scala/io/shiftleft/passes/CpgPass.scala b/codepropertygraph/src/main/scala/io/shiftleft/passes/CpgPass.scala index 14903c46e..4e6dd4bf8 100644 --- a/codepropertygraph/src/main/scala/io/shiftleft/passes/CpgPass.scala +++ b/codepropertygraph/src/main/scala/io/shiftleft/passes/CpgPass.scala @@ -7,9 +7,11 @@ import io.shiftleft.utils.StatsLogger import org.slf4j.{Logger, LoggerFactory, MDC} import overflowdb.BatchedUpdate +import java.util.concurrent.{TimeUnit, TimeoutException} import java.util.function.{BiConsumer, Supplier} import scala.annotation.nowarn -import scala.concurrent.duration.DurationLong +import scala.concurrent.{Await, ExecutionContext, Future} +import scala.concurrent.duration.{Duration, DurationLong} import scala.util.{Failure, Success, Try} /* CpgPass @@ -55,6 +57,56 @@ abstract class CpgPass(cpg: Cpg, outName: String = "", keyPool: Option[KeyPool] * methods. This may be better than using the constructor or GC, because e.g. SCPG chains of passes construct * passes eagerly, and releases them only when the entire chain has run. * */ +abstract class ForkJoinParallelCpgPassWithTimeout[T <: AnyRef]( + cpg: Cpg, + @nowarn outName: String = "", + keyPool: Option[KeyPool] = None, + timeout: Long = -1 +) extends NewStyleCpgPassBaseWithTimeout[T](timeout) { + + override def createApplySerializeAndStore(serializedCpg: SerializedCpg, prefix: String = ""): Unit = { + baseLogger.info(s"Start of pass: $name") + StatsLogger.initiateNewStage(getClass.getSimpleName, Some(name), getClass.getSuperclass.getSimpleName) + val nanosStart = System.nanoTime() + var nParts = 0 + var nanosBuilt = -1L + var nDiff = -1 + var nDiffT = -1 + try { + val diffGraph = new DiffGraphBuilder + nParts = runWithBuilder(diffGraph) + nanosBuilt = System.nanoTime() + nDiff = diffGraph.size() + + nDiffT = overflowdb.BatchedUpdate + .applyDiff(cpg.graph, diffGraph, keyPool.getOrElse(null), null) + .transitiveModifications() + + } catch { + case exc: Exception => + baseLogger.error(s"Pass ${name} failed", exc) + throw exc + } finally { + try { + finish() + } finally { + // the nested finally is somewhat ugly -- but we promised to clean up with finish(), we want to include finish() + // in the reported timings, and we must have our final log message if finish() throws + val nanosStop = System.nanoTime() + val fracRun = if (nanosBuilt == -1) 0.0 else (nanosStop - nanosBuilt) * 100.0 / (nanosStop - nanosStart + 1) + val serializationString = if (serializedCpg != null && !serializedCpg.isEmpty) { + " Diff serialized and stored." + } else "" + baseLogger.info( + f"Pass $name completed in ${(nanosStop - nanosStart) * 1e-6}%.0f ms (${fracRun}%.0f%% on mutations). ${nDiff}%d + ${nDiffT - nDiff}%d changes committed from ${nParts}%d parts.${serializationString}%s" + ) + StatsLogger.endLastStage() + } + } + } + +} + abstract class ForkJoinParallelCpgPass[T <: AnyRef]( cpg: Cpg, @nowarn outName: String = "", @@ -168,6 +220,91 @@ abstract class NewStyleCpgPassBase[T <: AnyRef] extends CpgPassBase { } } +abstract class NewStyleCpgPassBaseWithTimeout[T <: AnyRef](timeout: Long) extends CpgPassBase { + type DiffGraphBuilder = overflowdb.BatchedUpdate.DiffGraphBuilder + + // generate Array of parts that can be processed in parallel + def generateParts(): Array[? <: AnyRef] + + // setup large data structures, acquire external resources + def init(): Unit = {} + + // release large data structures and external resources + def finish(): Unit = {} + + // main function: add desired changes to builder + def runOnPart(builder: DiffGraphBuilder, part: T): Unit + + // Override this to disable parallelism of passes. Useful for debugging. + def isParallel: Boolean = true + + override def createAndApply(): Unit = createApplySerializeAndStore(null) + + override def runWithBuilder(externalBuilder: BatchedUpdate.DiffGraphBuilder): Int = { + try { + init() + val parts = generateParts() + val nParts = parts.size + nParts match { + case 0 => + case 1 => + runOnPart(externalBuilder, parts(0).asInstanceOf[T]) + case _ => + if (!isParallel) { + val diff = java.util.Arrays + .stream(parts) + .sequential() + .collect( + new Supplier[DiffGraphBuilder] { + override def get(): DiffGraphBuilder = + new DiffGraphBuilder + }, + new BiConsumer[DiffGraphBuilder, AnyRef] { + override def accept(builder: DiffGraphBuilder, part: AnyRef): Unit = + runOnPart(builder, part.asInstanceOf[T]) + }, + new BiConsumer[DiffGraphBuilder, DiffGraphBuilder] { + override def accept(leftBuilder: DiffGraphBuilder, rightBuilder: DiffGraphBuilder): Unit = + leftBuilder.absorb(rightBuilder) + } + ) + externalBuilder.absorb(diff) + } else { + implicit val ec: scala.concurrent.ExecutionContext = scala.concurrent.ExecutionContext.global + val stopAt = System.currentTimeMillis() + timeout * 1000 + var exitByTimeout = false + val diffGraphAccumulator = new DiffGraphBuilder + + val futures = parts.map { part => + val future = Future { + val diffGraphBuilder = new DiffGraphBuilder + runOnPart(diffGraphBuilder, part.asInstanceOf[T]) + diffGraphBuilder + } + future + } + + futures.foreach { future => + val currentTimeInMs = System.currentTimeMillis() + val duration = + if timeout == -1 then Duration.Inf else Duration(stopAt - currentTimeInMs, TimeUnit.MILLISECONDS) + Try(Await.result(future, duration)) match + case Failure(exception: TimeoutException) => + baseLogger.debug(s"Timeout occurred for passed timeout value of ${timeout} seconds") + case Failure(e) => throw e + case Success(diffGraphBuilder) => + diffGraphAccumulator.absorb(diffGraphBuilder) + } + externalBuilder.absorb(diffGraphAccumulator) + } + } + nParts + } finally { + finish() + } + } +} + object CpgPassBase { private val baseLogger: Logger = LoggerFactory.getLogger(classOf[CpgPassBase]) } diff --git a/codepropertygraph/src/test/scala/io/shiftleft/passes/ParallelCpgPassNewTests.scala b/codepropertygraph/src/test/scala/io/shiftleft/passes/ParallelCpgPassNewTests.scala index fa9f6d28a..8be62b501 100644 --- a/codepropertygraph/src/test/scala/io/shiftleft/passes/ParallelCpgPassNewTests.scala +++ b/codepropertygraph/src/test/scala/io/shiftleft/passes/ParallelCpgPassNewTests.scala @@ -104,3 +104,37 @@ class ParallelCpgPassNewTests extends AnyWordSpec with Matchers { } } + +class ForkJoinParallelCpgPassNewTests extends AnyWordSpec with Matchers { + + private object Fixture { + def apply(keyPools: Option[Iterator[KeyPool]] = None, timeout: Long = -1)(f: (Cpg, CpgPassBase) => Unit): Unit = { + val cpg = Cpg.empty + val pool = keyPools.flatMap(_.nextOption()) + class MyPass(cpg: Cpg) + extends ForkJoinParallelCpgPassWithTimeout[String](cpg, "MyPass", pool, timeout = timeout) { + override def generateParts(): Array[String] = Range(1, 101).map(_.toString).toArray + + override def runOnPart(diffGraph: DiffGraphBuilder, part: String): Unit = { + Thread.sleep(1000) + diffGraph.addNode(NewFile().name(part)) + } + } + val pass = new MyPass(cpg) + f(cpg, pass) + } + } + + "ForkJoinParallelPassWithTimeout" should { + "generate partial result in case of timeout" in Fixture(timeout = 2) { (cpg, pass) => + pass.createAndApply() + assert(cpg.graph.nodes.map(_.property(Properties.Name)).toList.size != 100) + } + + "generate complete result without timeout" in Fixture() { (cpg, pass) => + pass.createAndApply() + assert(cpg.graph.nodes.map(_.property(Properties.Name)).toList.size == 100) + } + } + +}