Skip to content

Commit

Permalink
refactor: Only create one native plan for a query on an executor
Browse files Browse the repository at this point in the history
  • Loading branch information
viirya committed Dec 29, 2024
1 parent 5d2c909 commit e32b00f
Show file tree
Hide file tree
Showing 6 changed files with 82 additions and 63 deletions.
31 changes: 14 additions & 17 deletions native/core/src/execution/jni_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,6 @@ struct ExecutionContext {
pub root_op: Option<Arc<SparkPlan>>,
/// The input sources for the DataFusion plan
pub scans: Vec<ScanExec>,
/// The global reference of input sources for the DataFusion plan
pub input_sources: Vec<Arc<GlobalRef>>,
/// The record batch stream to pull results from
pub stream: Option<SendableRecordBatchStream>,
/// The Tokio runtime used for async.
Expand All @@ -99,7 +97,6 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
e: JNIEnv,
_class: JClass,
id: jlong,
iterators: jobjectArray,
serialized_query: jbyteArray,
metrics_node: JObject,
comet_task_memory_manager_obj: JObject,
Expand Down Expand Up @@ -133,15 +130,6 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(

let metrics = Arc::new(jni_new_global_ref!(env, metrics_node)?);

// Get the global references of input sources
let mut input_sources = vec![];
let iter_array = JObjectArray::from_raw(iterators);
let num_inputs = env.get_array_length(&iter_array)?;
for i in 0..num_inputs {
let input_source = env.get_object_array_element(&iter_array, i)?;
let input_source = Arc::new(jni_new_global_ref!(env, input_source)?);
input_sources.push(input_source);
}
let task_memory_manager =
Arc::new(jni_new_global_ref!(env, comet_task_memory_manager_obj)?);

Expand All @@ -163,7 +151,6 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
spark_plan,
root_op: None,
scans: vec![],
input_sources,
stream: None,
runtime,
metrics,
Expand Down Expand Up @@ -302,6 +289,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan(
stage_id: jint,
partition: jint,
exec_context: jlong,
iterators: jobjectArray,
array_addrs: jlongArray,
schema_addrs: jlongArray,
) -> jlong {
Expand All @@ -318,10 +306,19 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan(
let start = Instant::now();
let planner = PhysicalPlanner::new(Arc::clone(&exec_context.session_ctx))
.with_exec_id(exec_context_id);
let (scans, root_op) = planner.create_plan(
&exec_context.spark_plan,
&mut exec_context.input_sources.clone(),
)?;

// Get the global references of input sources
let mut input_sources = vec![];
let iter_array = JObjectArray::from_raw(iterators);
let num_inputs = env.get_array_length(&iter_array)?;
for i in 0..num_inputs {
let input_source = env.get_object_array_element(&iter_array, i)?;
let input_source = Arc::new(jni_new_global_ref!(env, input_source)?);
input_sources.push(input_source);
}

let (scans, root_op) =
planner.create_plan(&exec_context.spark_plan, &mut input_sources)?;
let physical_plan_time = start.elapsed();

exec_context.plan_creation_time += physical_plan_time;
Expand Down
3 changes: 1 addition & 2 deletions native/core/src/execution/operators/scan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,7 @@ use std::{
/// Native.executePlan, it passes in the memory addresses of the input batches.
#[derive(Debug, Clone)]
pub struct ScanExec {
/// The ID of the execution context that owns this subquery. We use this ID to retrieve the JVM
/// environment `JNIEnv` from the execution context.
/// The ID of the execution context that owns this scan.
pub exec_context_id: i64,
/// The input source of scan node. It is a global reference of JVM `CometBatchIterator` object.
pub input_source: Option<Arc<GlobalRef>>,
Expand Down
80 changes: 51 additions & 29 deletions spark/src/main/scala/org/apache/comet/CometExecIterator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -35,50 +35,28 @@ import org.apache.comet.vector.NativeUtil
* `hasNext` can be used to check if it is the end of this iterator (i.e. the native query is
* done).
*
* @param id
* The unique id of the query plan behind this native execution.
* @param inputs
* The input iterators producing sequence of batches of Arrow Arrays.
* @param protobufQueryPlan
* The serialized bytes of Spark execution plan.
* @param numParts
* The number of partitions.
* @param partitionIndex
* The index of the partition.
*/
class CometExecIterator(
val id: Long,
nativePlan: Long,
inputs: Seq[Iterator[ColumnarBatch]],
numOutputCols: Int,
protobufQueryPlan: Array[Byte],
nativeMetrics: CometMetricNode,
numParts: Int,
partitionIndex: Int)
extends Iterator[ColumnarBatch] {
import CometExecIterator._

private val nativeLib = new Native()
private val nativeUtil = new NativeUtil()
private val cometBatchIterators = inputs.map { iterator =>
new CometBatchIterator(iterator, nativeUtil)
}.toArray
private val plan = {
val conf = SparkEnv.get.conf
// Only enable unified memory manager when off-heap mode is enabled. Otherwise,
// we'll use the built-in memory pool from DF, and initializes with `memory_limit`
// and `memory_fraction` below.
nativeLib.createPlan(
id,
cometBatchIterators,
protobufQueryPlan,
nativeMetrics,
new CometTaskMemoryManager(id),
batchSize = COMET_BATCH_SIZE.get(),
use_unified_memory_manager = conf.getBoolean("spark.memory.offHeap.enabled", false),
memory_limit = CometSparkSessionExtensions.getCometMemoryOverhead(conf),
memory_fraction = COMET_EXEC_MEMORY_FRACTION.get(),
debug = COMET_DEBUG_ENABLED.get(),
explain = COMET_EXPLAIN_NATIVE_ENABLED.get(),
workerThreads = COMET_WORKER_THREADS.get(),
blockingThreads = COMET_BLOCKING_THREADS.get())
}

private var nextBatch: Option[ColumnarBatch] = None
private var currentBatch: ColumnarBatch = null
Expand All @@ -91,7 +69,13 @@ class CometExecIterator(
numOutputCols,
(arrayAddrs, schemaAddrs) => {
val ctx = TaskContext.get()
nativeLib.executePlan(ctx.stageId(), partitionIndex, plan, arrayAddrs, schemaAddrs)
nativeLib.executePlan(
ctx.stageId(),
partitionIndex,
nativePlan,
cometBatchIterators,
arrayAddrs,
schemaAddrs)
})
}

Expand Down Expand Up @@ -134,8 +118,6 @@ class CometExecIterator(
currentBatch.close()
currentBatch = null
}
nativeUtil.close()
nativeLib.releasePlan(plan)

// The allocator thinks the exported ArrowArray and ArrowSchema structs are not released,
// so it will report:
Expand All @@ -160,3 +142,43 @@ class CometExecIterator(
}
}
}

/**
 * Companion object holding state shared by all `CometExecIterator` instances in this JVM: a
 * `Native` binding, a `NativeUtil`, and a cache from serialized query plan bytes to the `Long`
 * handle returned by `Native.createPlan`.
 *
 * NOTE(review): `planMap` is keyed by `Array[Byte]`. JVM arrays use reference identity for
 * `equals`/`hashCode`, so `containsKey`/`get`/`remove` only match when the very same array
 * instance is passed again; two arrays with equal contents are distinct keys. Confirm whether
 * content-based lookup was intended.
 */
object CometExecIterator {
  // JNI binding used to create and release native plans.
  val nativeLib = new Native()
  // Not referenced inside this object; exposed for users of the companion.
  val nativeUtil = new NativeUtil()

  // Cache of serialized plan bytes -> handle returned by `nativeLib.createPlan`.
  // All reads and writes in this object happen inside `synchronized` blocks on the object itself.
  val planMap = new java.util.concurrent.ConcurrentHashMap[Array[Byte], Long]()

  /**
   * Returns the native plan handle cached for `protobufQueryPlan`, creating it through
   * `Native.createPlan` and caching it when no entry exists yet.
   *
   * @param id
   *   The id passed to `Native.createPlan` and to the `CometTaskMemoryManager` created for the
   *   plan. It is only used when a new native plan is created.
   * @param protobufQueryPlan
   *   The serialized bytes of the Spark execution plan; also used as the cache key.
   * @param nativeMetrics
   *   The metric node passed to `Native.createPlan`. It is only used when a new native plan is
   *   created; on a cache hit the argument is ignored.
   * @return
   *   The handle returned by `Native.createPlan` for this plan.
   */
  def createPlan(id: Long, protobufQueryPlan: Array[Byte], nativeMetrics: CometMetricNode): Long =
    synchronized {
      if (planMap.containsKey(protobufQueryPlan)) {
        planMap.get(protobufQueryPlan)
      } else {
        val conf = SparkEnv.get.conf

        val plan = nativeLib.createPlan(
          id,
          protobufQueryPlan,
          nativeMetrics,
          new CometTaskMemoryManager(id),
          batchSize = COMET_BATCH_SIZE.get(),
          // The unified memory manager is requested only when Spark off-heap mode is enabled.
          use_unified_memory_manager = conf.getBoolean("spark.memory.offHeap.enabled", false),
          memory_limit = CometSparkSessionExtensions.getCometMemoryOverhead(conf),
          memory_fraction = COMET_EXEC_MEMORY_FRACTION.get(),
          debug = COMET_DEBUG_ENABLED.get(),
          explain = COMET_EXPLAIN_NATIVE_ENABLED.get(),
          workerThreads = COMET_WORKER_THREADS.get(),
          blockingThreads = COMET_BLOCKING_THREADS.get())
        planMap.put(protobufQueryPlan, plan)
        plan
      }
    }

  /**
   * Releases the native plan cached for `protobufQueryPlan`, if any, and removes it from the
   * cache. Does nothing when the key is not present.
   *
   * NOTE(review): there is no reference counting here. If a cached handle is ever handed out to
   * more than one caller, the first `releasePlan` call releases it for all of them; verify
   * against the callers before relying on plan sharing.
   *
   * @param protobufQueryPlan
   *   The same key that was passed to `createPlan`.
   */
  def releasePlan(protobufQueryPlan: Array[Byte]): Unit = synchronized {
    if (planMap.containsKey(protobufQueryPlan)) {
      val plan = planMap.get(protobufQueryPlan)
      nativeLib.releasePlan(plan)
      planMap.remove(protobufQueryPlan)
    }
  }
}
8 changes: 4 additions & 4 deletions spark/src/main/scala/org/apache/comet/Native.scala
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,6 @@ class Native extends NativeBase {
* The id of the query plan.
* @param configMap
* The Java Map object for the configs of native engine.
* @param iterators
* the input iterators to the native query plan. It should be the same number as the number of
* scan nodes in the SparkPlan.
* @param plan
* the bytes of serialized SparkPlan.
* @param metrics
Expand All @@ -46,7 +43,6 @@ class Native extends NativeBase {
// scalastyle:off
@native def createPlan(
id: Long,
iterators: Array[CometBatchIterator],
plan: Array[Byte],
metrics: CometMetricNode,
taskMemoryManager: CometTaskMemoryManager,
Expand All @@ -69,6 +65,9 @@ class Native extends NativeBase {
* the partition ID, for informational purposes
* @param plan
* the address to native query plan.
* @param iterators
* the input iterators to the native query plan. It should be the same number as the number of
* scan nodes in the SparkPlan.
* @param arrayAddrs
* the addresses of Arrow Array structures
* @param schemaAddrs
Expand All @@ -80,6 +79,7 @@ class Native extends NativeBase {
stage: Int,
partition: Int,
plan: Long,
iterators: Array[CometBatchIterator],
arrayAddrs: Array[Long],
schemaAddrs: Array[Long]): Long

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -499,6 +499,7 @@ class CometShuffleWriteProcessor(
// Getting rid of the fake partitionId
val newInputs = inputs.asInstanceOf[Iterator[_ <: Product2[Any, Any]]].map(_._2)

context.taskAttemptId()
val cometIter = CometExec.getCometIterator(
Seq(newInputs.asInstanceOf[Iterator[ColumnarBatch]]),
outputAttributes.length,
Expand Down
22 changes: 11 additions & 11 deletions spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
Original file line number Diff line number Diff line change
Expand Up @@ -129,14 +129,11 @@ object CometExec {
nativePlan.writeTo(outputStream)
outputStream.close()
val bytes = outputStream.toByteArray
new CometExecIterator(
newIterId,
inputs,
numOutputCols,
bytes,
nativeMetrics,
numParts,
partitionIdx)

val planId = CometExec.newIterId
val nativePlanId = CometExecIterator.createPlan(planId, bytes, nativeMetrics)

new CometExecIterator(newIterId, nativePlanId, inputs, numOutputCols, numParts, partitionIdx)
}

/**
Expand Down Expand Up @@ -206,12 +203,14 @@ abstract class CometNativeExec extends CometExec {
inputs: Seq[Iterator[ColumnarBatch]],
numParts: Int,
partitionIndex: Int): CometExecIterator = {
val planId = CometExec.newIterId
val nativePlan = CometExecIterator.createPlan(planId, serializedPlanCopy, nativeMetrics)

val it = new CometExecIterator(
CometExec.newIterId,
planId,
nativePlan,
inputs,
output.length,
serializedPlanCopy,
nativeMetrics,
numParts,
partitionIndex)

Expand All @@ -221,6 +220,7 @@ abstract class CometNativeExec extends CometExec {
context.addTaskCompletionListener[Unit] { _ =>
it.close()
cleanSubqueries(it.id, this)
CometExecIterator.releasePlan(serializedPlanCopy)
}
}

Expand Down

0 comments on commit e32b00f

Please sign in to comment.