From 7a0b58a2ecd9d7af9783ed8a99ed3f583508fc23 Mon Sep 17 00:00:00 2001 From: Khemraj Rathore Date: Wed, 24 Jul 2024 14:57:14 +0530 Subject: [PATCH] initial changes --- .../scala/io/shiftleft/passes/CpgPass.scala | 137 +++++++++++++++++- .../passes/ParallelCpgPassNewTests.scala | 34 +++++ 2 files changed, 169 insertions(+), 2 deletions(-) diff --git a/codepropertygraph/src/main/scala/io/shiftleft/passes/CpgPass.scala b/codepropertygraph/src/main/scala/io/shiftleft/passes/CpgPass.scala index e856e455e..e1e311172 100644 --- a/codepropertygraph/src/main/scala/io/shiftleft/passes/CpgPass.scala +++ b/codepropertygraph/src/main/scala/io/shiftleft/passes/CpgPass.scala @@ -2,14 +2,17 @@ package io.shiftleft.passes import com.google.protobuf.GeneratedMessageV3 import io.shiftleft.SerializedCpg -import io.shiftleft.codepropertygraph.generated.Cpg +import io.shiftleft.codepropertygraph.generated.{Cpg, DiffGraphBuilder} import io.shiftleft.utils.StatsLogger import org.slf4j.{Logger, LoggerFactory, MDC} import overflowdb.BatchedUpdate +import java.util.concurrent.{TimeUnit, TimeoutException} import java.util.function.{BiConsumer, Supplier} import scala.annotation.nowarn -import scala.concurrent.duration.DurationLong +import scala.concurrent.{Await, ExecutionContext, Future} +import scala.concurrent.duration.{Duration, DurationLong} +import scala.util.control.Breaks.{break, breakable} import scala.util.{Failure, Success, Try} /* CpgPass @@ -55,6 +58,56 @@ abstract class CpgPass(cpg: Cpg, outName: String = "", keyPool: Option[KeyPool] * methods. This may be better than using the constructor or GC, because e.g. SCPG chains of passes construct * passes eagerly, and releases them only when the entire chain has run. * */ +abstract class ForkJoinParallelCpgPassWithTimeout[T <: AnyRef]( + cpg: Cpg, + @nowarn outName: String = "", + keyPool: Option[KeyPool] = None, + timeout: Long = -1 +) extends NewStyleCpgPassBaseWithTimeout[T](timeout) { + + override def createApplySerializeAndStore(serializedCpg: SerializedCpg, prefix: String = ""): Unit = { + baseLogger.info(s"Start of pass: $name") + StatsLogger.initiateNewStage(getClass.getSimpleName, Some(name), getClass.getSuperclass.getSimpleName) + val nanosStart = System.nanoTime() + var nParts = 0 + var nanosBuilt = -1L + var nDiff = -1 + var nDiffT = -1 + try { + val diffGraph = Cpg.newDiffGraphBuilder + nParts = runWithBuilder(diffGraph) + nanosBuilt = System.nanoTime() + nDiff = diffGraph.size() + + nDiffT = overflowdb.BatchedUpdate + .applyDiff(cpg.graph, diffGraph, keyPool.getOrElse(null), null) + .transitiveModifications() + + } catch { + case exc: Exception => + baseLogger.error(s"Pass ${name} failed", exc) + throw exc + } finally { + try { + finish() + } finally { + // the nested finally is somewhat ugly -- but we promised to clean up with finish(), we want to include finish() + // in the reported timings, and we must have our final log message if finish() throws + val nanosStop = System.nanoTime() + val fracRun = if (nanosBuilt == -1) 0.0 else (nanosStop - nanosBuilt) * 100.0 / (nanosStop - nanosStart + 1) + val serializationString = if (serializedCpg != null && !serializedCpg.isEmpty) { + " Diff serialized and stored." + } else "" + baseLogger.info( + f"Pass $name completed in ${(nanosStop - nanosStart) * 1e-6}%.0f ms (${fracRun}%.0f%% on mutations). ${nDiff}%d + ${nDiffT - nDiff}%d changes committed from ${nParts}%d parts.${serializationString}%s" + ) + StatsLogger.endLastStage() + } + } + } + +} + abstract class ForkJoinParallelCpgPass[T <: AnyRef]( cpg: Cpg, @nowarn outName: String = "", @@ -168,6 +221,86 @@ abstract class NewStyleCpgPassBase[T <: AnyRef] extends CpgPassBase { } } +abstract class NewStyleCpgPassBaseWithTimeout[T <: AnyRef](timeout: Long) extends CpgPassBase { + type DiffGraphBuilder = overflowdb.BatchedUpdate.DiffGraphBuilder + + // generate Array of parts that can be processed in parallel + def generateParts(): Array[? <: AnyRef] + + // setup large data structures, acquire external resources + def init(): Unit = {} + + // release large data structures and external resources + def finish(): Unit = {} + + // main function: add desired changes to builder + def runOnPart(builder: DiffGraphBuilder, part: T): Unit + + // Override this to disable parallelism of passes. Useful for debugging. + def isParallel: Boolean = true + + override def createAndApply(): Unit = createApplySerializeAndStore(null) + + override def runWithBuilder(externalBuilder: BatchedUpdate.DiffGraphBuilder): Int = { + try { + init() + val parts = generateParts() + val nParts = parts.size + nParts match { + case 0 => + case 1 => + runOnPart(externalBuilder, parts(0).asInstanceOf[T]) + case _ => + val stream = + if (!isParallel) + java.util.Arrays + .stream(parts) + .sequential() + else + java.util.Arrays + .stream(parts) + .parallel() + + implicit val ec: scala.concurrent.ExecutionContext = scala.concurrent.ExecutionContext.global + val stopAt = System.currentTimeMillis() + timeout * 1000 + var exitByTimeout = false + val diffGraphAccumulator = Cpg.newDiffGraphBuilder + breakable { + stream.forEach { part => + val currentTimeInMs = System.currentTimeMillis() + if (timeout == -1 || currentTimeInMs < stopAt) { + val future = Future { + val diffGraphBuilder = Cpg.newDiffGraphBuilder + runOnPart(diffGraphBuilder, part.asInstanceOf[T]) + diffGraphBuilder + } + val duration = + if timeout == -1 then Duration.Inf else Duration(stopAt - currentTimeInMs, TimeUnit.MILLISECONDS) + Try(Await.result(future, duration)) match { + case Success(diffGraphBuilder: DiffGraphBuilder) => + synchronized( + diffGraphAccumulator.absorb(diffGraphBuilder) + ) // Writing to diffGraph needs to be thread safe + case Failure(exception: TimeoutException) => println(s"Encountered timeout at thread level") + case Failure(e) => throw e + } + } else { + exitByTimeout = true + break + } + } + } + if (exitByTimeout) + println("Timeout exception encountered, continuing with partial result") + externalBuilder.absorb(diffGraphAccumulator) + } + nParts + } finally { + finish() + } + } +} + object CpgPassBase { private val baseLogger: Logger = LoggerFactory.getLogger(classOf[CpgPassBase]) } diff --git a/codepropertygraph/src/test/scala/io/shiftleft/passes/ParallelCpgPassNewTests.scala b/codepropertygraph/src/test/scala/io/shiftleft/passes/ParallelCpgPassNewTests.scala index 96992544b..3f7a9829f 100644 --- a/codepropertygraph/src/test/scala/io/shiftleft/passes/ParallelCpgPassNewTests.scala +++ b/codepropertygraph/src/test/scala/io/shiftleft/passes/ParallelCpgPassNewTests.scala @@ -104,3 +104,37 @@ class ParallelCpgPassNewTests extends AnyWordSpec with Matchers { } } + +class ForkJoinParallelCpgPassNewTests extends AnyWordSpec with Matchers { + + private object Fixture { + def apply(keyPools: Option[Iterator[KeyPool]] = None, timeout: Long = -1)(f: (Cpg, CpgPassBase) => Unit): Unit = { + val cpg = Cpg.empty + val pool = keyPools.flatMap(_.nextOption()) + class MyPass(cpg: Cpg) + extends ForkJoinParallelCpgPassWithTimeout[String](cpg, "MyPass", pool, timeout = timeout) { + override def generateParts(): Array[String] = Range(1, 101).map(_.toString).toArray + + override def runOnPart(diffGraph: DiffGraphBuilder, part: String): Unit = { + Thread.sleep(1000) + diffGraph.addNode(NewFile().name(part)) + } + } + val pass = new MyPass(cpg) + f(cpg, pass) + } + } + + "ForkJoinParallelPassWithTimeout" should { + "generate partial result in case of timeout" in Fixture(timeout = 2) { (cpg, pass) => + pass.createAndApply() + assert(cpg.graph.nodes.map(_.property(Properties.Name)).toList.size != 100) + } + + "generate complete result without timeout" in Fixture() { (cpg, pass) => + pass.createAndApply() + assert(cpg.graph.nodes.map(_.property(Properties.Name)).toList.size == 100) + } + } + +}