diff --git a/native/core/src/execution/jni_api.rs b/native/core/src/execution/jni_api.rs index 09caf5e27..060608d28 100644 --- a/native/core/src/execution/jni_api.rs +++ b/native/core/src/execution/jni_api.rs @@ -71,8 +71,6 @@ struct ExecutionContext { pub root_op: Option>, /// The input sources for the DataFusion plan pub scans: Vec, - /// The global reference of input sources for the DataFusion plan - pub input_sources: Vec>, /// The record batch stream to pull results from pub stream: Option, /// The Tokio runtime used for async. @@ -99,7 +97,6 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan( e: JNIEnv, _class: JClass, id: jlong, - iterators: jobjectArray, serialized_query: jbyteArray, metrics_node: JObject, comet_task_memory_manager_obj: JObject, @@ -133,15 +130,6 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan( let metrics = Arc::new(jni_new_global_ref!(env, metrics_node)?); - // Get the global references of input sources - let mut input_sources = vec![]; - let iter_array = JObjectArray::from_raw(iterators); - let num_inputs = env.get_array_length(&iter_array)?; - for i in 0..num_inputs { - let input_source = env.get_object_array_element(&iter_array, i)?; - let input_source = Arc::new(jni_new_global_ref!(env, input_source)?); - input_sources.push(input_source); - } let task_memory_manager = Arc::new(jni_new_global_ref!(env, comet_task_memory_manager_obj)?); @@ -163,7 +151,6 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan( spark_plan, root_op: None, scans: vec![], - input_sources, stream: None, runtime, metrics, @@ -302,6 +289,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan( stage_id: jint, partition: jint, exec_context: jlong, + iterators: jobjectArray, array_addrs: jlongArray, schema_addrs: jlongArray, ) -> jlong { @@ -318,10 +306,19 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan( let start = Instant::now(); let planner = PhysicalPlanner::new(Arc::clone(&exec_context.session_ctx)) .with_exec_id(exec_context_id); - let (scans, root_op) = planner.create_plan( - &exec_context.spark_plan, - &mut exec_context.input_sources.clone(), - )?; + + // Get the global references of input sources + let mut input_sources = vec![]; + let iter_array = JObjectArray::from_raw(iterators); + let num_inputs = env.get_array_length(&iter_array)?; + for i in 0..num_inputs { + let input_source = env.get_object_array_element(&iter_array, i)?; + let input_source = Arc::new(jni_new_global_ref!(env, input_source)?); + input_sources.push(input_source); + } + + let (scans, root_op) = + planner.create_plan(&exec_context.spark_plan, &mut input_sources)?; let physical_plan_time = start.elapsed(); exec_context.plan_creation_time += physical_plan_time; diff --git a/native/core/src/execution/operators/scan.rs b/native/core/src/execution/operators/scan.rs index 888cd2fdb..d43852e45 100644 --- a/native/core/src/execution/operators/scan.rs +++ b/native/core/src/execution/operators/scan.rs @@ -57,8 +57,7 @@ use std::{ /// Native.executePlan, it passes in the memory addresses of the input batches. #[derive(Debug, Clone)] pub struct ScanExec { - /// The ID of the execution context that owns this subquery. We use this ID to retrieve the JVM - /// environment `JNIEnv` from the execution context. + /// The ID of the execution context that owns this scan. pub exec_context_id: i64, /// The input source of scan node. It is a global reference of JVM `CometBatchIterator` object. pub input_source: Option>, diff --git a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala index 04d930695..f2da0a687 100644 --- a/spark/src/main/scala/org/apache/comet/CometExecIterator.scala +++ b/spark/src/main/scala/org/apache/comet/CometExecIterator.scala @@ -35,10 +35,10 @@ import org.apache.comet.vector.NativeUtil * `hasNext` can be used to check if it is the end of this iterator (i.e. the native query is * done). * + * @param id + * The unique id of the query plan behind this native execution. * @param inputs * The input iterators producing sequence of batches of Arrow Arrays. - * @param protobufQueryPlan - * The serialized bytes of Spark execution plan. * @param numParts * The number of partitions. * @param partitionIndex @@ -46,39 +46,17 @@ import org.apache.comet.vector.NativeUtil */ class CometExecIterator( val id: Long, + nativePlan: Long, inputs: Seq[Iterator[ColumnarBatch]], numOutputCols: Int, - protobufQueryPlan: Array[Byte], - nativeMetrics: CometMetricNode, numParts: Int, partitionIndex: Int) extends Iterator[ColumnarBatch] { + import CometExecIterator._ - private val nativeLib = new Native() - private val nativeUtil = new NativeUtil() private val cometBatchIterators = inputs.map { iterator => new CometBatchIterator(iterator, nativeUtil) }.toArray - private val plan = { - val conf = SparkEnv.get.conf - // Only enable unified memory manager when off-heap mode is enabled. Otherwise, - // we'll use the built-in memory pool from DF, and initializes with `memory_limit` - // and `memory_fraction` below. - nativeLib.createPlan( - id, - cometBatchIterators, - protobufQueryPlan, - nativeMetrics, - new CometTaskMemoryManager(id), - batchSize = COMET_BATCH_SIZE.get(), - use_unified_memory_manager = conf.getBoolean("spark.memory.offHeap.enabled", false), - memory_limit = CometSparkSessionExtensions.getCometMemoryOverhead(conf), - memory_fraction = COMET_EXEC_MEMORY_FRACTION.get(), - debug = COMET_DEBUG_ENABLED.get(), - explain = COMET_EXPLAIN_NATIVE_ENABLED.get(), - workerThreads = COMET_WORKER_THREADS.get(), - blockingThreads = COMET_BLOCKING_THREADS.get()) - } private var nextBatch: Option[ColumnarBatch] = None private var currentBatch: ColumnarBatch = null @@ -91,7 +69,13 @@ class CometExecIterator( numOutputCols, (arrayAddrs, schemaAddrs) => { val ctx = TaskContext.get() - nativeLib.executePlan(ctx.stageId(), partitionIndex, plan, arrayAddrs, schemaAddrs) + nativeLib.executePlan( + ctx.stageId(), + partitionIndex, + nativePlan, + cometBatchIterators, + arrayAddrs, + schemaAddrs) }) } @@ -134,8 +118,6 @@ class CometExecIterator( currentBatch.close() currentBatch = null } - nativeUtil.close() - nativeLib.releasePlan(plan) // The allocator thoughts the exported ArrowArray and ArrowSchema structs are not released, // so it will report: @@ -160,3 +142,43 @@ class CometExecIterator( } } } + +object CometExecIterator { + val nativeLib = new Native() + val nativeUtil = new NativeUtil() + + val planMap = new java.util.concurrent.ConcurrentHashMap[Array[Byte], Long]() + + def createPlan(id: Long, protobufQueryPlan: Array[Byte], nativeMetrics: CometMetricNode): Long = + synchronized { + if (planMap.containsKey(protobufQueryPlan)) { + planMap.get(protobufQueryPlan) + } else { + val conf = SparkEnv.get.conf + + val plan = nativeLib.createPlan( + id, + protobufQueryPlan, + nativeMetrics, + new CometTaskMemoryManager(id), + batchSize = COMET_BATCH_SIZE.get(), + use_unified_memory_manager = conf.getBoolean("spark.memory.offHeap.enabled", false), + memory_limit = CometSparkSessionExtensions.getCometMemoryOverhead(conf), + memory_fraction = COMET_EXEC_MEMORY_FRACTION.get(), + debug = COMET_DEBUG_ENABLED.get(), + explain = COMET_EXPLAIN_NATIVE_ENABLED.get(), + workerThreads = COMET_WORKER_THREADS.get(), + blockingThreads = COMET_BLOCKING_THREADS.get()) + planMap.put(protobufQueryPlan, plan) + plan + } + } + + def releasePlan(protobufQueryPlan: Array[Byte]): Unit = synchronized { + if (planMap.containsKey(protobufQueryPlan)) { + val plan = planMap.get(protobufQueryPlan) + nativeLib.releasePlan(plan) + planMap.remove(protobufQueryPlan) + } + } +} diff --git a/spark/src/main/scala/org/apache/comet/Native.scala b/spark/src/main/scala/org/apache/comet/Native.scala index 083c0f2b5..81d0766de 100644 --- a/spark/src/main/scala/org/apache/comet/Native.scala +++ b/spark/src/main/scala/org/apache/comet/Native.scala @@ -30,9 +30,6 @@ class Native extends NativeBase { * The id of the query plan. * @param configMap * The Java Map object for the configs of native engine. - * @param iterators - * the input iterators to the native query plan. It should be the same number as the number of - * scan nodes in the SparkPlan. * @param plan * the bytes of serialized SparkPlan. * @param metrics @@ -46,7 +43,6 @@ class Native extends NativeBase { // scalastyle:off @native def createPlan( id: Long, - iterators: Array[CometBatchIterator], plan: Array[Byte], metrics: CometMetricNode, taskMemoryManager: CometTaskMemoryManager, @@ -69,6 +65,9 @@ class Native extends NativeBase { * the partition ID, for informational purposes * @param plan * the address to native query plan. + * @param iterators + * the input iterators to the native query plan. It should be the same number as the number of + * scan nodes in the SparkPlan. * @param arrayAddrs * the addresses of Arrow Array structures * @param schemaAddrs @@ -80,6 +79,7 @@ class Native extends NativeBase { stage: Int, partition: Int, plan: Long, + iterators: Array[CometBatchIterator], arrayAddrs: Array[Long], schemaAddrs: Array[Long]): Long diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala index 3a11b8b28..2634639f5 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/CometShuffleExchangeExec.scala @@ -499,6 +499,7 @@ class CometShuffleWriteProcessor( // Getting rid of the fake partitionId val newInputs = inputs.asInstanceOf[Iterator[_ <: Product2[Any, Any]]].map(_._2) + context.taskAttemptId() val cometIter = CometExec.getCometIterator( Seq(newInputs.asInstanceOf[Iterator[ColumnarBatch]]), outputAttributes.length, diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala index c70f7464e..8f9fd0744 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala @@ -129,14 +129,11 @@ object CometExec { nativePlan.writeTo(outputStream) outputStream.close() val bytes = outputStream.toByteArray - new CometExecIterator( - newIterId, - inputs, - numOutputCols, - bytes, - nativeMetrics, - numParts, - partitionIdx) + + val planId = CometExec.newIterId + val nativePlanId = CometExecIterator.createPlan(planId, bytes, nativeMetrics) + + new CometExecIterator(newIterId, nativePlanId, inputs, numOutputCols, numParts, partitionIdx) } /** @@ -206,12 +203,14 @@ abstract class CometNativeExec extends CometExec { inputs: Seq[Iterator[ColumnarBatch]], numParts: Int, partitionIndex: Int): CometExecIterator = { + val planId = CometExec.newIterId + val nativePlan = CometExecIterator.createPlan(planId, serializedPlanCopy, nativeMetrics) + val it = new CometExecIterator( - CometExec.newIterId, + planId, + nativePlan, inputs, output.length, - serializedPlanCopy, - nativeMetrics, numParts, partitionIndex) @@ -221,6 +220,7 @@ abstract class CometNativeExec extends CometExec { context.addTaskCompletionListener[Unit] { _ => it.close() cleanSubqueries(it.id, this) + CometExecIterator.releasePlan(serializedPlanCopy) } }