Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduction of timeout in passes #5

Merged
merged 2 commits into from
Jul 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 139 additions & 2 deletions codepropertygraph/src/main/scala/io/shiftleft/passes/CpgPass.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,16 @@ package io.shiftleft.passes

import com.google.protobuf.GeneratedMessageV3
import io.shiftleft.SerializedCpg
import io.shiftleft.codepropertygraph.generated.Cpg
import io.shiftleft.codepropertygraph.generated.{Cpg, DiffGraphBuilder}
import io.shiftleft.utils.StatsLogger
import org.slf4j.{Logger, LoggerFactory, MDC}
import overflowdb.BatchedUpdate

import java.util.concurrent.{TimeUnit, TimeoutException}
import java.util.function.{BiConsumer, Supplier}
import scala.annotation.nowarn
import scala.concurrent.duration.DurationLong
import scala.concurrent.{Await, ExecutionContext, Future}
import scala.concurrent.duration.{Duration, DurationLong}
import scala.util.{Failure, Success, Try}

/* CpgPass
Expand Down Expand Up @@ -55,6 +57,56 @@ abstract class CpgPass(cpg: Cpg, outName: String = "", keyPool: Option[KeyPool]
* methods. This may be better than using the constructor or GC, because e.g. SCPG chains of passes construct
* passes eagerly, and releases them only when the entire chain has run.
* */
/** Pass that builds one diff graph from its parts via `runWithBuilder`, applies it to `cpg.graph` and logs timing
  * and change statistics.
  *
  * @param cpg
  *   graph the resulting diff is applied to
  * @param outName
  *   not read by this class
  * @param keyPool
  *   optional key pool handed to `BatchedUpdate.applyDiff`; `null` is passed when absent
  * @param timeout
  *   forwarded unchanged to [[NewStyleCpgPassBaseWithTimeout]]; defaults to `-1`
  * @tparam T
  *   type of a single part
  */
abstract class ForkJoinParallelCpgPassWithTimeout[T <: AnyRef](
  cpg: Cpg,
  @nowarn outName: String = "",
  keyPool: Option[KeyPool] = None,
  timeout: Long = -1
) extends NewStyleCpgPassBaseWithTimeout[T](timeout) {

  /** Runs the pass, applies the resulting diff to the graph and always emits the closing log line and stats stage.
    *
    * `serializedCpg` only influences the wording of the final log message; nothing is written to it here.
    */
  override def createApplySerializeAndStore(serializedCpg: SerializedCpg, prefix: String = ""): Unit = {
    baseLogger.info(s"Start of pass: $name")
    StatsLogger.initiateNewStage(getClass.getSimpleName, Some(name), getClass.getSuperclass.getSimpleName)

    val startedAt = System.nanoTime()
    // -1 marks "not reached yet" so the final log line still works when the pass fails early
    var partCount     = 0
    var builtAt       = -1L
    var directChanges = -1
    var totalChanges  = -1

    try {
      val builder = Cpg.newDiffGraphBuilder
      partCount = runWithBuilder(builder)
      builtAt = System.nanoTime()
      directChanges = builder.size()

      val applied = overflowdb.BatchedUpdate.applyDiff(cpg.graph, builder, keyPool.getOrElse(null), null)
      totalChanges = applied.transitiveModifications()
    } catch {
      case exc: Exception =>
        baseLogger.error(s"Pass ${name} failed", exc)
        throw exc
    } finally {
      // Nested finally: finish() is part of the reported timings, and the closing log line plus the stats stage
      // must be emitted even if finish() itself throws. Note that runWithBuilder has already called finish() once.
      try finish()
      finally {
        val stoppedAt     = System.nanoTime()
        val elapsedMillis = (stoppedAt - startedAt) * 1e-6
        val mutationPercent =
          if (builtAt != -1) (stoppedAt - builtAt) * 100.0 / (stoppedAt - startedAt + 1) else 0.0
        val transitiveOnly = totalChanges - directChanges
        val storedNote =
          Option(serializedCpg).filterNot(_.isEmpty).fold("")(_ => " Diff serialized and stored.")

        baseLogger.info(
          f"Pass $name completed in ${elapsedMillis}%.0f ms (${mutationPercent}%.0f%% on mutations). ${directChanges}%d + ${transitiveOnly}%d changes committed from ${partCount}%d parts.${storedNote}%s"
        )
        StatsLogger.endLastStage()
      }
    }
  }

}

abstract class ForkJoinParallelCpgPass[T <: AnyRef](
cpg: Cpg,
@nowarn outName: String = "",
Expand Down Expand Up @@ -168,6 +220,91 @@ abstract class NewStyleCpgPassBase[T <: AnyRef] extends CpgPassBase {
}
}

/** Base class for passes that split their work into parts and can bound the parallel phase by a timeout.
  *
  * The timeout only applies when there is more than one part and `isParallel` is true; a single part and the
  * sequential mode always run to completion. Parts that have not finished when the budget is used up are left out
  * of the result. Their futures are not cancelled and keep running on the global execution context.
  *
  * @param timeout
  *   budget in seconds for awaiting the parallel parts, measured from just before the parts are submitted. Any
  *   negative value (conventionally `-1`) disables the timeout.
  * @tparam T
  *   type of a single part
  */
abstract class NewStyleCpgPassBaseWithTimeout[T <: AnyRef](timeout: Long) extends CpgPassBase {
  type DiffGraphBuilder = overflowdb.BatchedUpdate.DiffGraphBuilder

  // generate Array of parts that can be processed in parallel
  def generateParts(): Array[? <: AnyRef]

  // setup large data structures, acquire external resources
  def init(): Unit = {}

  // release large data structures and external resources
  def finish(): Unit = {}

  // main function: add desired changes to builder
  def runOnPart(builder: DiffGraphBuilder, part: T): Unit

  // Override this to disable parallelism of passes. Useful for debugging.
  def isParallel: Boolean = true

  override def createAndApply(): Unit = createApplySerializeAndStore(null)

  /** Runs `init()`, processes every part into `externalBuilder` and finally calls `finish()`.
    *
    * @return
    *   the number of generated parts (including parts that were dropped because of the timeout)
    */
  override def runWithBuilder(externalBuilder: BatchedUpdate.DiffGraphBuilder): Int = {
    try {
      init()
      val parts  = generateParts()
      val nParts = parts.length
      nParts match {
        case 0                => // nothing to do
        case 1                => runOnPart(externalBuilder, parts(0).asInstanceOf[T])
        case _ if !isParallel => runSequentially(parts, externalBuilder)
        case _                => runInParallel(parts, externalBuilder)
      }
      nParts
    } finally {
      finish()
    }
  }

  // Processes all parts in order on the calling thread. Changes are collected in a scratch builder first so that
  // externalBuilder stays untouched if any part throws.
  private def runSequentially(parts: Array[? <: AnyRef], externalBuilder: BatchedUpdate.DiffGraphBuilder): Unit = {
    val diff = Cpg.newDiffGraphBuilder
    parts.foreach(part => runOnPart(diff, part.asInstanceOf[T]))
    externalBuilder.absorb(diff)
  }

  // Processes every part in its own Future and absorbs the results of all parts that finish within the timeout.
  private def runInParallel(parts: Array[? <: AnyRef], externalBuilder: BatchedUpdate.DiffGraphBuilder): Unit = {
    implicit val ec: ExecutionContext = ExecutionContext.global

    // System.nanoTime is monotonic, so wall-clock adjustments cannot stretch or shrink the budget.
    // TimeUnit.toNanos saturates instead of overflowing for very large timeouts.
    val startNanos  = System.nanoTime()
    val budgetNanos = TimeUnit.SECONDS.toNanos(timeout)
    def remaining(): Duration =
      if (timeout < 0) Duration.Inf
      else Duration.fromNanos(budgetNanos - (System.nanoTime() - startNanos))

    val futures = parts.map { part =>
      Future {
        val partBuilder = Cpg.newDiffGraphBuilder
        runOnPart(partBuilder, part.asInstanceOf[T])
        partBuilder
      }
    }

    val accumulator = Cpg.newDiffGraphBuilder
    // Once the budget is exhausted `remaining()` is <= 0, so Await returns immediately: parts that are already
    // complete are still absorbed, all others are counted as timed out.
    val timedOutParts = futures.foldLeft(0) { (timedOut, future) =>
      Try(Await.result(future, remaining())) match {
        case Success(partBuilder) =>
          accumulator.absorb(partBuilder)
          timedOut
        case Failure(_: TimeoutException) => timedOut + 1
        case Failure(e)                   => throw e
      }
    }

    if (timedOutParts > 0) {
      baseLogger.warn(
        s"Pass $name hit its timeout of $timeout seconds: $timedOutParts of ${parts.length} parts were dropped"
      )
    }
    externalBuilder.absorb(accumulator)
  }
}

object CpgPassBase {

  /** Logger registered under the `CpgPassBase` class name; private to `CpgPassBase` and this companion. */
  private val baseLogger: Logger = {
    val owner = classOf[CpgPassBase]
    LoggerFactory.getLogger(owner)
  }
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,3 +104,37 @@ class ParallelCpgPassNewTests extends AnyWordSpec with Matchers {
}

}

/** Tests for the timeout handling of `ForkJoinParallelCpgPassWithTimeout`.
  *
  * The outcome must not depend on the number of cores of the machine running the tests: the timeout case uses one
  * part that sleeps longer than the whole timeout budget, so that part can never make it into the result, while all
  * other parts are quick.
  */
class ForkJoinParallelCpgPassNewTests extends AnyWordSpec with Matchers {

  private val PartCount       = 100
  private val QuickPartMillis = 10L
  private val SlowPartMillis  = 3000L
  // the first part, so it is also the first future that is awaited
  private val SlowPart = "1"

  private object Fixture {

    /** Creates an empty Cpg and a pass with `PartCount` parts; every part sleeps for `sleepMillis(part)` and then
      * adds one file node named after the part.
      */
    def apply(
      keyPools: Option[Iterator[KeyPool]] = None,
      timeout: Long = -1,
      sleepMillis: String => Long = _ => QuickPartMillis
    )(f: (Cpg, CpgPassBase) => Unit): Unit = {
      val cpg  = Cpg.empty
      val pool = keyPools.flatMap(_.nextOption())
      class MyPass(cpg: Cpg)
          extends ForkJoinParallelCpgPassWithTimeout[String](cpg, "MyPass", pool, timeout = timeout) {
        override def generateParts(): Array[String] = (1 to PartCount).map(_.toString).toArray

        override def runOnPart(diffGraph: DiffGraphBuilder, part: String): Unit = {
          Thread.sleep(sleepMillis(part))
          diffGraph.addNode(NewFile().name(part))
        }
      }
      val pass = new MyPass(cpg)
      f(cpg, pass)
    }
  }

  // names of all nodes currently in the graph
  private def nodeNames(cpg: Cpg) = cpg.graph.nodes.map(_.property(Properties.Name)).toList

  "ForkJoinParallelPassWithTimeout" should {
    "generate partial result in case of timeout" in Fixture(
      timeout = 1,
      sleepMillis = part => if (part == SlowPart) SlowPartMillis else QuickPartMillis
    ) { (cpg, pass) =>
      pass.createAndApply()
      // the slow part needs 3s but the budget is 1s, so at least that node must be missing
      assert(nodeNames(cpg).size < PartCount)
    }

    "generate complete result without timeout" in Fixture() { (cpg, pass) =>
      pass.createAndApply()
      assert(nodeNames(cpg).size == PartCount)
    }
  }

}
Loading