Skip to content

Commit

Permalink
PR review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
khemrajrathore committed Jul 25, 2024
1 parent 7a0b58a commit 90f97b5
Showing 1 changed file with 43 additions and 39 deletions.
82 changes: 43 additions & 39 deletions codepropertygraph/src/main/scala/io/shiftleft/passes/CpgPass.scala
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import java.util.function.{BiConsumer, Supplier}
import scala.annotation.nowarn
import scala.concurrent.{Await, ExecutionContext, Future}
import scala.concurrent.duration.{Duration, DurationLong}
import scala.util.control.Breaks.{break, breakable}
import scala.util.{Failure, Success, Try}

/* CpgPass
Expand Down Expand Up @@ -251,48 +250,53 @@ abstract class NewStyleCpgPassBaseWithTimeout[T <: AnyRef](timeout: Long) extend
case 1 =>
runOnPart(externalBuilder, parts(0).asInstanceOf[T])
case _ =>
val stream =
if (!isParallel)
java.util.Arrays
.stream(parts)
.sequential()
else
java.util.Arrays
.stream(parts)
.parallel()

implicit val ec: scala.concurrent.ExecutionContext = scala.concurrent.ExecutionContext.global
val stopAt = System.currentTimeMillis() + timeout * 1000
var exitByTimeout = false
val diffGraphAccumulator = Cpg.newDiffGraphBuilder
breakable {
stream.forEach { part =>
val currentTimeInMs = System.currentTimeMillis()
if (timeout == -1 || currentTimeInMs < stopAt) {
val future = Future {
val diffGraphBuilder = Cpg.newDiffGraphBuilder
runOnPart(diffGraphBuilder, part.asInstanceOf[T])
diffGraphBuilder
}
val duration =
if timeout == -1 then Duration.Inf else Duration(stopAt - currentTimeInMs, TimeUnit.MILLISECONDS)
Try(Await.result(future, duration)) match {
case Success(diffGraphBuilder: DiffGraphBuilder) =>
synchronized(
diffGraphAccumulator.absorb(diffGraphBuilder)
) // Writing to diffGraph needs to be thread safe
case Failure(exception: TimeoutException) => println(s"Encountered timeout at thread level")
case Failure(e) => throw e
if (!isParallel) {
val diff = java.util.Arrays
.stream(parts)
.sequential()
.collect(
new Supplier[DiffGraphBuilder] {
override def get(): DiffGraphBuilder =
Cpg.newDiffGraphBuilder
},
new BiConsumer[DiffGraphBuilder, AnyRef] {
override def accept(builder: DiffGraphBuilder, part: AnyRef): Unit =
runOnPart(builder, part.asInstanceOf[T])
},
new BiConsumer[DiffGraphBuilder, DiffGraphBuilder] {
override def accept(leftBuilder: DiffGraphBuilder, rightBuilder: DiffGraphBuilder): Unit =
leftBuilder.absorb(rightBuilder)
}
} else {
exitByTimeout = true
break
)
externalBuilder.absorb(diff)
} else {
implicit val ec: scala.concurrent.ExecutionContext = scala.concurrent.ExecutionContext.global
val stopAt = System.currentTimeMillis() + timeout * 1000
var exitByTimeout = false
val diffGraphAccumulator = Cpg.newDiffGraphBuilder

val futures = parts.map { part =>
val future = Future {
val diffGraphBuilder = Cpg.newDiffGraphBuilder
runOnPart(diffGraphBuilder, part.asInstanceOf[T])
diffGraphBuilder
}
future
}

futures.foreach { future =>
val currentTimeInMs = System.currentTimeMillis()
val duration =
if timeout == -1 then Duration.Inf else Duration(stopAt - currentTimeInMs, TimeUnit.MILLISECONDS)
Try(Await.result(future, duration)) match
case Failure(exception: TimeoutException) =>
baseLogger.debug(s"Timeout occurred for passed timeout value of ${timeout} seconds")
case Failure(e) => throw e
case Success(diffGraphBuilder) =>
diffGraphAccumulator.absorb(diffGraphBuilder)
}
externalBuilder.absorb(diffGraphAccumulator)
}
if (exitByTimeout)
println("Timeout exception encountered, continuing with partial result")
externalBuilder.absorb(diffGraphAccumulator)
}
nParts
} finally {
Expand Down

0 comments on commit 90f97b5

Please sign in to comment.