feat: Only create one native plan for a query on an executor #1203

Draft · wants to merge 1 commit into base: main
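
Summary of the change: the input iterators move from Native.createPlan to Native.executePlan, so a native plan no longer captures its input sources at creation time. A new CometExecIterator companion object caches native plan addresses in a ConcurrentHashMap keyed by the serialized plan bytes, letting each executor create at most one native plan per query and reuse it across partitions; the plan is released from a task completion listener.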
31 changes: 14 additions & 17 deletions native/core/src/execution/jni_api.rs
@@ -71,8 +71,6 @@ struct ExecutionContext {
pub root_op: Option<Arc<SparkPlan>>,
/// The input sources for the DataFusion plan
pub scans: Vec<ScanExec>,
/// The global reference of input sources for the DataFusion plan
pub input_sources: Vec<Arc<GlobalRef>>,
/// The record batch stream to pull results from
pub stream: Option<SendableRecordBatchStream>,
/// The Tokio runtime used for async.
@@ -99,7 +97,6 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
e: JNIEnv,
_class: JClass,
id: jlong,
iterators: jobjectArray,
serialized_query: jbyteArray,
metrics_node: JObject,
comet_task_memory_manager_obj: JObject,
@@ -133,15 +130,6 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(

let metrics = Arc::new(jni_new_global_ref!(env, metrics_node)?);

// Get the global references of input sources
let mut input_sources = vec![];
let iter_array = JObjectArray::from_raw(iterators);
let num_inputs = env.get_array_length(&iter_array)?;
for i in 0..num_inputs {
let input_source = env.get_object_array_element(&iter_array, i)?;
let input_source = Arc::new(jni_new_global_ref!(env, input_source)?);
input_sources.push(input_source);
}
let task_memory_manager =
Arc::new(jni_new_global_ref!(env, comet_task_memory_manager_obj)?);

@@ -163,7 +151,6 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
spark_plan,
root_op: None,
scans: vec![],
input_sources,
stream: None,
runtime,
metrics,
@@ -302,6 +289,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan(
stage_id: jint,
partition: jint,
exec_context: jlong,
iterators: jobjectArray,
array_addrs: jlongArray,
schema_addrs: jlongArray,
) -> jlong {
@@ -318,10 +306,19 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan(
let start = Instant::now();
let planner = PhysicalPlanner::new(Arc::clone(&exec_context.session_ctx))
.with_exec_id(exec_context_id);
let (scans, root_op) = planner.create_plan(
&exec_context.spark_plan,
&mut exec_context.input_sources.clone(),
)?;

// Get the global references of input sources
let mut input_sources = vec![];
let iter_array = JObjectArray::from_raw(iterators);
let num_inputs = env.get_array_length(&iter_array)?;
for i in 0..num_inputs {
let input_source = env.get_object_array_element(&iter_array, i)?;
let input_source = Arc::new(jni_new_global_ref!(env, input_source)?);
input_sources.push(input_source);
}

let (scans, root_op) =
planner.create_plan(&exec_context.spark_plan, &mut input_sources)?;
let physical_plan_time = start.elapsed();

exec_context.plan_creation_time += physical_plan_time;
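
Net effect in jni_api.rs: the JNI global references to the CometBatchIterator inputs are now created on every executePlan call rather than once in createPlan, so the ExecutionContext no longer pins any particular partition's input sources and a single native plan can serve all partitions.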
3 changes: 1 addition & 2 deletions native/core/src/execution/operators/scan.rs
@@ -57,8 +57,7 @@ use std::{
/// Native.executePlan, it passes in the memory addresses of the input batches.
#[derive(Debug, Clone)]
pub struct ScanExec {
/// The ID of the execution context that owns this subquery. We use this ID to retrieve the JVM
/// environment `JNIEnv` from the execution context.
/// The ID of the execution context that owns this scan.
pub exec_context_id: i64,
/// The input source of scan node. It is a global reference of JVM `CometBatchIterator` object.
pub input_source: Option<Arc<GlobalRef>>,
80 changes: 51 additions & 29 deletions spark/src/main/scala/org/apache/comet/CometExecIterator.scala
@@ -35,50 +35,28 @@ import org.apache.comet.vector.NativeUtil
* `hasNext` can be used to check if it is the end of this iterator (i.e. the native query is
* done).
*
* @param id
* The unique id of the query plan behind this native execution.
* @param inputs
* The input iterators producing sequence of batches of Arrow Arrays.
* @param protobufQueryPlan
* The serialized bytes of Spark execution plan.
* @param numParts
* The number of partitions.
* @param partitionIndex
* The index of the partition.
*/
class CometExecIterator(
val id: Long,
nativePlan: Long,
inputs: Seq[Iterator[ColumnarBatch]],
numOutputCols: Int,
protobufQueryPlan: Array[Byte],
nativeMetrics: CometMetricNode,
numParts: Int,
partitionIndex: Int)
extends Iterator[ColumnarBatch] {
import CometExecIterator._

private val nativeLib = new Native()
private val nativeUtil = new NativeUtil()
private val cometBatchIterators = inputs.map { iterator =>
new CometBatchIterator(iterator, nativeUtil)
}.toArray
private val plan = {
val conf = SparkEnv.get.conf
// Only enable unified memory manager when off-heap mode is enabled. Otherwise,
// we'll use the built-in memory pool from DF, and initialize it with `memory_limit`
// and `memory_fraction` below.
nativeLib.createPlan(
id,
cometBatchIterators,
protobufQueryPlan,
nativeMetrics,
new CometTaskMemoryManager(id),
batchSize = COMET_BATCH_SIZE.get(),
use_unified_memory_manager = conf.getBoolean("spark.memory.offHeap.enabled", false),
memory_limit = CometSparkSessionExtensions.getCometMemoryOverhead(conf),
memory_fraction = COMET_EXEC_MEMORY_FRACTION.get(),
debug = COMET_DEBUG_ENABLED.get(),
explain = COMET_EXPLAIN_NATIVE_ENABLED.get(),
workerThreads = COMET_WORKER_THREADS.get(),
blockingThreads = COMET_BLOCKING_THREADS.get())
}

private var nextBatch: Option[ColumnarBatch] = None
private var currentBatch: ColumnarBatch = null
@@ -91,7 +69,13 @@ class CometExecIterator(
numOutputCols,
(arrayAddrs, schemaAddrs) => {
val ctx = TaskContext.get()
nativeLib.executePlan(ctx.stageId(), partitionIndex, plan, arrayAddrs, schemaAddrs)
nativeLib.executePlan(
ctx.stageId(),
partitionIndex,
nativePlan,
cometBatchIterators,
arrayAddrs,
schemaAddrs)
})
}

@@ -134,8 +118,6 @@ class CometExecIterator(
currentBatch.close()
currentBatch = null
}
nativeUtil.close()
nativeLib.releasePlan(plan)

// The allocator thinks the exported ArrowArray and ArrowSchema structs are not released,
// so it will report:
@@ -160,3 +142,43 @@
}
}
}

object CometExecIterator {
val nativeLib = new Native()
val nativeUtil = new NativeUtil()

val planMap = new java.util.concurrent.ConcurrentHashMap[Array[Byte], Long]()

def createPlan(id: Long, protobufQueryPlan: Array[Byte], nativeMetrics: CometMetricNode): Long =
synchronized {
if (planMap.containsKey(protobufQueryPlan)) {
planMap.get(protobufQueryPlan)
} else {
val conf = SparkEnv.get.conf

val plan = nativeLib.createPlan(
id,
protobufQueryPlan,
nativeMetrics,
new CometTaskMemoryManager(id),
batchSize = COMET_BATCH_SIZE.get(),
use_unified_memory_manager = conf.getBoolean("spark.memory.offHeap.enabled", false),
memory_limit = CometSparkSessionExtensions.getCometMemoryOverhead(conf),
memory_fraction = COMET_EXEC_MEMORY_FRACTION.get(),
debug = COMET_DEBUG_ENABLED.get(),
explain = COMET_EXPLAIN_NATIVE_ENABLED.get(),
workerThreads = COMET_WORKER_THREADS.get(),
blockingThreads = COMET_BLOCKING_THREADS.get())
planMap.put(protobufQueryPlan, plan)
plan
}
}

def releasePlan(protobufQueryPlan: Array[Byte]): Unit = synchronized {
if (planMap.containsKey(protobufQueryPlan)) {
val plan = planMap.get(protobufQueryPlan)
nativeLib.releasePlan(plan)
planMap.remove(protobufQueryPlan)
}
}
}
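
One subtlety in the cache above: java.util.concurrent.ConcurrentHashMap compares keys with equals/hashCode, and Array[Byte] inherits reference equality from Java arrays, so two byte-identical serialized plans only share an entry when the same array instance is reused. A minimal sketch of a content-keyed variant (PlanCache and its members are hypothetical names, not part of this PR):

import java.nio.ByteBuffer
import java.util.concurrent.ConcurrentHashMap

object PlanCache {
  // ByteBuffer.wrap gives value-based equals/hashCode over the wrapped bytes,
  // so byte-identical serialized plans map to the same native plan address.
  private val planMap = new ConcurrentHashMap[ByteBuffer, java.lang.Long]()

  def getOrCreate(planBytes: Array[Byte], create: () => Long): Long =
    planMap.computeIfAbsent(ByteBuffer.wrap(planBytes), _ => create())

  def release(planBytes: Array[Byte], releaseNative: Long => Unit): Unit = {
    val addr = planMap.remove(ByteBuffer.wrap(planBytes))
    if (addr != null) releaseNative(addr.longValue())
  }
}

Since callers below re-serialize the plan into a fresh byte array, a reference-equality key would rarely hit the cache; a content-based key (or reusing one byte array instance per query) is what makes the cache effective.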
8 changes: 4 additions & 4 deletions spark/src/main/scala/org/apache/comet/Native.scala
@@ -30,9 +30,6 @@ class Native extends NativeBase {
* The id of the query plan.
* @param configMap
The Java Map object for the configs of the native engine.
* @param iterators
* the input iterators to the native query plan. It should be the same number as the number of
* scan nodes in the SparkPlan.
* @param plan
* the bytes of serialized SparkPlan.
* @param metrics
@@ -46,7 +43,6 @@
// scalastyle:off
@native def createPlan(
id: Long,
iterators: Array[CometBatchIterator],
plan: Array[Byte],
metrics: CometMetricNode,
taskMemoryManager: CometTaskMemoryManager,
@@ -69,6 +65,9 @@
* the partition ID, for informational purposes
* @param plan
the address of the native query plan.
@param iterators
the input iterators to the native query plan. There should be as many iterators as
there are scan nodes in the SparkPlan.
* @param arrayAddrs
* the addresses of Arrow Array structures
* @param schemaAddrs
@@ -80,6 +79,7 @@
stage: Int,
partition: Int,
plan: Long,
iterators: Array[CometBatchIterator],
arrayAddrs: Array[Long],
schemaAddrs: Array[Long]): Long

@@ -499,6 +499,7 @@ class CometShuffleWriteProcessor(
// Getting rid of the fake partitionId
val newInputs = inputs.asInstanceOf[Iterator[_ <: Product2[Any, Any]]].map(_._2)

context.taskAttemptId()
val cometIter = CometExec.getCometIterator(
Seq(newInputs.asInstanceOf[Iterator[ColumnarBatch]]),
outputAttributes.length,
22 changes: 11 additions & 11 deletions spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
@@ -129,14 +129,11 @@ object CometExec {
nativePlan.writeTo(outputStream)
outputStream.close()
val bytes = outputStream.toByteArray
new CometExecIterator(
newIterId,
inputs,
numOutputCols,
bytes,
nativeMetrics,
numParts,
partitionIdx)

val planId = CometExec.newIterId
val nativePlanId = CometExecIterator.createPlan(planId, bytes, nativeMetrics)

new CometExecIterator(planId, nativePlanId, inputs, numOutputCols, numParts, partitionIdx)
}

/**
@@ -206,12 +203,14 @@ abstract class CometNativeExec extends CometExec {
inputs: Seq[Iterator[ColumnarBatch]],
numParts: Int,
partitionIndex: Int): CometExecIterator = {
val planId = CometExec.newIterId
val nativePlan = CometExecIterator.createPlan(planId, serializedPlanCopy, nativeMetrics)

val it = new CometExecIterator(
CometExec.newIterId,
planId,
nativePlan,
inputs,
output.length,
serializedPlanCopy,
nativeMetrics,
numParts,
partitionIndex)

@@ -221,6 +220,7 @@
context.addTaskCompletionListener[Unit] { _ =>
it.close()
cleanSubqueries(it.id, this)
CometExecIterator.releasePlan(serializedPlanCopy)
}
}
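
One caveat with the listener above: the plan cached by CometExecIterator is shared by every task of the query on this executor, so the first task to complete releases the native plan while sibling tasks may still be executing it. A reference-counted release would defer freeing to the last user; a minimal sketch, reusing the content-based key idea from earlier (all names hypothetical):

import java.nio.ByteBuffer
import java.util.concurrent.ConcurrentHashMap

// Hypothetical: track how many tasks hold each plan and free the native
// plan only when the count drops to zero.
final class CachedPlan(val addr: Long, var refs: Int)

object RefCountedPlanCache {
  private val planMap = new ConcurrentHashMap[ByteBuffer, CachedPlan]()

  def acquire(planBytes: Array[Byte], create: () => Long): Long = planMap.synchronized {
    val entry =
      planMap.computeIfAbsent(ByteBuffer.wrap(planBytes), _ => new CachedPlan(create(), 0))
    entry.refs += 1
    entry.addr
  }

  def release(planBytes: Array[Byte], releaseNative: Long => Unit): Unit = planMap.synchronized {
    val key = ByteBuffer.wrap(planBytes)
    val entry = planMap.get(key)
    if (entry != null) {
      entry.refs -= 1
      if (entry.refs == 0) {
        planMap.remove(key)
        releaseNative(entry.addr)
      }
    }
  }
}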
