Simplify classic Column handling.
hvanhovell committed Dec 3, 2024
1 parent 6c84f15 commit 6ff1f72
Showing 30 changed files with 181 additions and 133 deletions.
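The pattern repeated across the files below: call sites that used to wrap a Catalyst Expression through the internal ExpressionUtils.column helper now build the public Column directly, relying on the implicits pulled in by org.apache.spark.sql.classic.ClassicConversions._. A minimal before/after sketch; it assumes that import adds a Column.apply(Expression) overload, and expr is a placeholder rather than a name from this commit:

    import org.apache.spark.sql.Column
    import org.apache.spark.sql.catalyst.expressions.Expression

    // Before this commit: the private helper did the wrapping.
    //   import org.apache.spark.sql.internal.ExpressionUtils.column
    //   val c: Column = column(expr)

    // After: construct the Column directly; ClassicConversions presumably
    // supplies the Column.apply(Expression) overload via an implicit.
    import org.apache.spark.sql.classic.ClassicConversions._

    def wrap(expr: Expression): Column = Column(expr)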
11 changes: 6 additions & 5 deletions mllib/src/main/scala/org/apache/spark/ml/stat/Summarizer.scala
@@ -29,8 +29,9 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Expression, ImplicitCastInputTypes}
import org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate
import org.apache.spark.sql.catalyst.trees.BinaryLike
import org.apache.spark.sql.classic.ClassicConversions._
import org.apache.spark.sql.functions.lit
import org.apache.spark.sql.internal.ExpressionUtils.{column, expression}
import org.apache.spark.sql.internal.ExpressionUtils.expression
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils

@@ -249,13 +250,13 @@ private[ml] class SummaryBuilderImpl(
) extends SummaryBuilder {

override def summary(featuresCol: Column, weightCol: Column): Column = {
SummaryBuilderImpl.MetricsAggregate(
Column(SummaryBuilderImpl.MetricsAggregate(
requestedMetrics,
requestedCompMetrics,
featuresCol,
weightCol,
expression(featuresCol),
expression(weightCol),
mutableAggBufferOffset = 0,
inputAggBufferOffset = 0)
inputAggBufferOffset = 0))
}
}

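For Summarizer the conversion also runs the other way: the public Column arguments are unwrapped into Catalyst expressions with ExpressionUtils.expression before they reach the aggregate, and only the finished aggregate is wrapped back into a single Column. A sketch of that shape, with a placeholder agg standing in for SummaryBuilderImpl.MetricsAggregate:

    import org.apache.spark.sql.Column
    import org.apache.spark.sql.catalyst.expressions.Expression
    import org.apache.spark.sql.classic.ClassicConversions._         // adds Column(expr)
    import org.apache.spark.sql.internal.ExpressionUtils.expression   // Column -> Expression

    def summarySketch(featuresCol: Column, weightCol: Column)(
        agg: (Expression, Expression) => Expression): Column =
      Column(agg(expression(featuresCol), expression(weightCol)))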
@@ -43,7 +43,7 @@ import org.apache.spark.connect.proto.WriteStreamOperationStart.TriggerCase
import org.apache.spark.internal.{Logging, LogKeys, MDC}
import org.apache.spark.internal.LogKeys.{DATAFRAME_ID, SESSION_ID}
import org.apache.spark.resource.{ExecutorResourceRequest, ResourceProfile, TaskResourceProfile, TaskResourceRequest}
import org.apache.spark.sql.{Dataset, Encoders, ForeachWriter, Observation, RelationalGroupedDataset, Row, SparkSession}
import org.apache.spark.sql.{Column, Dataset, Encoders, ForeachWriter, Observation, RelationalGroupedDataset, Row, SparkSession}
import org.apache.spark.sql.catalyst.{expressions, AliasIdentifier, FunctionIdentifier, QueryPlanningTracker}
import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, GlobalTempView, LocalTempView, MultiAlias, NameParameterizedQuery, PosParameterizedQuery, UnresolvedAlias, UnresolvedAttribute, UnresolvedDataFrameStar, UnresolvedDeserializer, UnresolvedExtractValue, UnresolvedFunction, UnresolvedRegex, UnresolvedRelation, UnresolvedStar, UnresolvedTableValuedFunction, UnresolvedTranspose}
import org.apache.spark.sql.catalyst.encoders.{encoderFor, AgnosticEncoder, ExpressionEncoder, RowEncoder}
@@ -58,6 +58,7 @@ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes
import org.apache.spark.sql.catalyst.trees.CurrentOrigin
import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CharVarcharUtils}
import org.apache.spark.sql.classic.ClassicConversions._
import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, ForeachWriterPacket, InvalidPlanInput, LiteralValueProtoConverter, StorageLevelProtoConverter, StreamingListenerPacket, UdfPacket}
import org.apache.spark.sql.connect.config.Connect.CONNECT_GRPC_ARROW_MAX_BATCH_SIZE
import org.apache.spark.sql.connect.plugin.SparkConnectPluginRegistry
@@ -77,7 +78,6 @@ import org.apache.spark.sql.execution.streaming.GroupStateImpl.groupStateTimeout
import org.apache.spark.sql.execution.streaming.StreamingQueryWrapper
import org.apache.spark.sql.expressions.{Aggregator, ReduceAggregator, SparkUserDefinedFunction, UserDefinedAggregator, UserDefinedFunction}
import org.apache.spark.sql.internal.{CatalogImpl, MergeIntoWriterImpl, TypedAggUtils}
import org.apache.spark.sql.internal.ExpressionUtils.column
import org.apache.spark.sql.streaming.{GroupStateTimeout, OutputMode, StreamingQuery, StreamingQueryListener, StreamingQueryProgress, Trigger}
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.CaseInsensitiveStringMap
@@ -106,7 +106,7 @@ class SparkConnectPlanner(
@Since("4.0.0")
@DeveloperApi
def session: SparkSession = sessionHolder.session
import sessionHolder.session.RichColumn
import sessionHolder.session.toRichColumn

private[connect] def parser = session.sessionState.sqlParser

@@ -554,7 +554,7 @@ class SparkConnectPlanner(
.ofRows(session, transformRelation(rel.getInput))
.stat
.sampleBy(
col = column(transformExpression(rel.getCol)),
col = Column(transformExpression(rel.getCol)),
fractions = fractions.toMap,
seed = if (rel.hasSeed) rel.getSeed else Utils.random.nextLong)
.logicalPlan
@@ -646,17 +646,17 @@ class SparkConnectPlanner(
val pythonUdf = transformPythonUDF(commonUdf)
val cols =
rel.getGroupingExpressionsList.asScala.toSeq.map(expr =>
column(transformExpression(expr)))
Column(transformExpression(expr)))
val group = Dataset
.ofRows(session, transformRelation(rel.getInput))
.groupBy(cols: _*)

pythonUdf.evalType match {
case PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF =>
group.flatMapGroupsInPandas(column(pythonUdf)).logicalPlan
group.flatMapGroupsInPandas(Column(pythonUdf)).logicalPlan

case PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF =>
group.flatMapGroupsInArrow(column(pythonUdf)).logicalPlan
group.flatMapGroupsInArrow(Column(pythonUdf)).logicalPlan

case _ =>
throw InvalidPlanInput(
@@ -765,10 +765,10 @@ class SparkConnectPlanner(
case proto.CommonInlineUserDefinedFunction.FunctionCase.PYTHON_UDF =>
val inputCols =
rel.getInputGroupingExpressionsList.asScala.toSeq.map(expr =>
column(transformExpression(expr)))
Column(transformExpression(expr)))
val otherCols =
rel.getOtherGroupingExpressionsList.asScala.toSeq.map(expr =>
column(transformExpression(expr)))
Column(transformExpression(expr)))

val input = Dataset
.ofRows(session, transformRelation(rel.getInput))
@@ -783,10 +783,10 @@ class SparkConnectPlanner(

pythonUdf.evalType match {
case PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF =>
input.flatMapCoGroupsInPandas(other, pythonUdf).logicalPlan
input.flatMapCoGroupsInPandas(other, Column(pythonUdf)).logicalPlan

case PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF =>
input.flatMapCoGroupsInArrow(other, pythonUdf).logicalPlan
input.flatMapCoGroupsInArrow(other, Column(pythonUdf)).logicalPlan

case _ =>
throw InvalidPlanInput(
@@ -982,7 +982,7 @@ class SparkConnectPlanner(
private def transformApplyInPandasWithState(rel: proto.ApplyInPandasWithState): LogicalPlan = {
val pythonUdf = transformPythonUDF(rel.getFunc)
val cols =
rel.getGroupingExpressionsList.asScala.toSeq.map(expr => column(transformExpression(expr)))
rel.getGroupingExpressionsList.asScala.toSeq.map(expr => Column(transformExpression(expr)))

val outputSchema = parseSchema(rel.getOutputSchema)

@@ -992,7 +992,7 @@ class SparkConnectPlanner(
.ofRows(session, transformRelation(rel.getInput))
.groupBy(cols: _*)
.applyInPandasWithState(
column(pythonUdf),
Column(pythonUdf),
outputSchema,
stateSchema,
rel.getOutputMode,
@@ -1080,7 +1080,7 @@ class SparkConnectPlanner(
Metadata.empty
}

(alias.getName(0), column(transformExpression(alias.getExpr)), metadata)
(alias.getName(0), Column(transformExpression(alias.getExpr)), metadata)
}.unzip3

Dataset
@@ -1142,7 +1142,7 @@ class SparkConnectPlanner(

private def transformUnpivot(rel: proto.Unpivot): LogicalPlan = {
val ids = rel.getIdsList.asScala.toArray.map { expr =>
column(transformExpression(expr))
Column(transformExpression(expr))
}

if (!rel.hasValues) {
@@ -1155,7 +1155,7 @@ class SparkConnectPlanner(
transformRelation(rel.getInput))
} else {
val values = rel.getValues.getValuesList.asScala.toArray.map { expr =>
column(transformExpression(expr))
Column(transformExpression(expr))
}

Unpivot(
@@ -1184,7 +1184,7 @@ class SparkConnectPlanner(

private def transformCollectMetrics(rel: proto.CollectMetrics, planId: Long): LogicalPlan = {
val metrics = rel.getMetricsList.asScala.toSeq.map { expr =>
column(transformExpression(expr))
Column(transformExpression(expr))
}
val name = rel.getName
val input = transformRelation(rel.getInput)
@@ -2112,10 +2112,10 @@ class SparkConnectPlanner(
private def transformAsOfJoin(rel: proto.AsOfJoin): LogicalPlan = {
val left = Dataset.ofRows(session, transformRelation(rel.getLeft))
val right = Dataset.ofRows(session, transformRelation(rel.getRight))
val leftAsOf = column(transformExpression(rel.getLeftAsOf))
val rightAsOf = column(transformExpression(rel.getRightAsOf))
val leftAsOf = Column(transformExpression(rel.getLeftAsOf))
val rightAsOf = Column(transformExpression(rel.getRightAsOf))
val joinType = rel.getJoinType
val tolerance = if (rel.hasTolerance) column(transformExpression(rel.getTolerance)) else null
val tolerance = if (rel.hasTolerance) Column(transformExpression(rel.getTolerance)) else null
val allowExactMatches = rel.getAllowExactMatches
val direction = rel.getDirection

@@ -2131,7 +2131,7 @@ class SparkConnectPlanner(
allowExactMatches = allowExactMatches,
direction = direction)
} else {
val joinExprs = if (rel.hasJoinExpr) column(transformExpression(rel.getJoinExpr)) else null
val joinExprs = if (rel.hasJoinExpr) Column(transformExpression(rel.getJoinExpr)) else null
left.joinAsOf(
other = right,
leftAsOf = leftAsOf,
@@ -2172,7 +2172,7 @@ class SparkConnectPlanner(
private def transformDrop(rel: proto.Drop): LogicalPlan = {
var output = Dataset.ofRows(session, transformRelation(rel.getInput))
if (rel.getColumnsCount > 0) {
val cols = rel.getColumnsList.asScala.toSeq.map(expr => column(transformExpression(expr)))
val cols = rel.getColumnsList.asScala.toSeq.map(expr => Column(transformExpression(expr)))
output = output.drop(cols.head, cols.tail: _*)
}
if (rel.getColumnNamesCount > 0) {
@@ -2247,7 +2247,7 @@ class SparkConnectPlanner(
rel.getPivot.getValuesList.asScala.toSeq.map(transformLiteral)
} else {
RelationalGroupedDataset
.collectPivotValues(Dataset.ofRows(session, input), column(pivotExpr))
.collectPivotValues(Dataset.ofRows(session, input), Column(pivotExpr))
.map(expressions.Literal.apply)
}
logical.Pivot(
@@ -2574,12 +2574,12 @@ class SparkConnectPlanner(
if (!namedArguments.isEmpty) {
session.sql(
sql.getQuery,
namedArguments.asScala.toMap.transform((_, e) => column(transformExpression(e))),
namedArguments.asScala.toMap.transform((_, e) => Column(transformExpression(e))),
tracker)
} else if (!posArguments.isEmpty) {
session.sql(
sql.getQuery,
posArguments.asScala.map(e => column(transformExpression(e))).toArray,
posArguments.asScala.map(e => Column(transformExpression(e))).toArray,
tracker)
} else if (!args.isEmpty) {
session.sql(
@@ -2830,7 +2830,7 @@ class SparkConnectPlanner(
if (writeOperation.getPartitioningColumnsCount > 0) {
val names = writeOperation.getPartitioningColumnsList.asScala
.map(transformExpression)
.map(column)
.map(Column(_))
.toSeq
w.partitionedBy(names.head, names.tail: _*)
}
@@ -2848,7 +2848,7 @@ class SparkConnectPlanner(
w.create()
}
case proto.WriteOperationV2.Mode.MODE_OVERWRITE =>
w.overwrite(column(transformExpression(writeOperation.getOverwriteCondition)))
w.overwrite(Column(transformExpression(writeOperation.getOverwriteCondition)))
case proto.WriteOperationV2.Mode.MODE_OVERWRITE_PARTITIONS =>
w.overwritePartitions()
case proto.WriteOperationV2.Mode.MODE_APPEND =>
@@ -3410,7 +3410,7 @@ class SparkConnectPlanner(

val sourceDs = Dataset.ofRows(session, transformRelation(cmd.getSourceTablePlan))
val mergeInto = sourceDs
.mergeInto(cmd.getTargetTableName, column(transformExpression(cmd.getMergeCondition)))
.mergeInto(cmd.getTargetTableName, Column(transformExpression(cmd.getMergeCondition)))
.asInstanceOf[MergeIntoWriterImpl[Row]]
mergeInto.matchedActions ++= matchedActions
mergeInto.notMatchedActions ++= notMatchedActions
@@ -34,7 +34,7 @@ import org.apache.spark.sql.types._
@Stable
final class DataFrameNaFunctions private[sql](df: DataFrame)
extends api.DataFrameNaFunctions {
import df.sparkSession.RichColumn
import df.sparkSession.toRichColumn

protected def drop(minNonNulls: Option[Int]): Dataset[Row] = {
drop0(minNonNulls, outputAttributes)
2 changes: 1 addition & 1 deletion sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -225,7 +225,7 @@ class Dataset[T] private[sql](
queryExecution.sparkSession
}

import sparkSession.RichColumn
import sparkSession.toRichColumn

// A globally unique id of this Dataset.
private[sql] val id = Dataset.curId.getAndIncrement()
@@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.util.toPrettySQL
import org.apache.spark.sql.classic.ClassicConversions._
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.internal.ExpressionUtils.{column, generateAlias}
import org.apache.spark.sql.internal.ExpressionUtils.generateAlias
import org.apache.spark.sql.internal.TypedAggUtils.withInputType
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.{NumericType, StructType}
@@ -114,7 +114,7 @@ class RelationalGroupedDataset protected[sql](
namedExpr
}
}
columnExprs.map(column)
columnExprs.map(Column(_))
}

/** @inheritdoc */
@@ -238,7 +238,7 @@ class RelationalGroupedDataset protected[sql](
broadcastVars: Array[Broadcast[Object]],
outputSchema: StructType): DataFrame = {
val groupingNamedExpressions = groupingExprs.map(alias)
val groupingCols = groupingNamedExpressions.map(column)
val groupingCols = groupingNamedExpressions.map(Column(_))
val groupingDataFrame = df.select(groupingCols : _*)
val groupingAttributes = groupingNamedExpressions.map(_.toAttribute)
Dataset.ofRows(
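A small knock-on effect visible in this file and in the planner: the old ExpressionUtils.column was a plain imported function, so it could be passed point-free as .map(column); with that helper gone, call sites spell out the lambda as .map(Column(_)). A sketch over a hypothetical sequence of Catalyst expressions:

    import org.apache.spark.sql.Column
    import org.apache.spark.sql.catalyst.expressions.Expression
    import org.apache.spark.sql.classic.ClassicConversions._

    // Before: exprs.map(column)  -- `column` was the imported helper function.
    def toColumns(exprs: Seq[Expression]): Seq[Column] =
      exprs.map(Column(_))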
26 changes: 7 additions & 19 deletions sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -42,7 +42,7 @@ import org.apache.spark.sql.catalog.Catalog
import org.apache.spark.sql.catalyst._
import org.apache.spark.sql.catalyst.analysis.{NameParameterizedQuery, PosParameterizedQuery, UnresolvedRelation}
import org.apache.spark.sql.catalyst.encoders._
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, NamedExpression}
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression}
import org.apache.spark.sql.catalyst.parser.ParserInterface
import org.apache.spark.sql.catalyst.plans.logical.{CompoundBody, LocalRelation, LogicalPlan, Range}
import org.apache.spark.sql.catalyst.types.DataTypeUtils
@@ -98,7 +98,7 @@ class SparkSession private(
@transient private[sql] val extensions: SparkSessionExtensions,
@transient private[sql] val initialSessionOptions: Map[String, String],
@transient private val parentManagedJobTags: Map[String, String])
extends api.SparkSession with Logging { self =>
extends api.SparkSession with Logging with classic.ColumnConversions { self =>

// The call site where this SparkSession was constructed.
private val creationSite: CallSite = Utils.getCallSite()
@@ -797,23 +797,11 @@ class SparkSession private(
.getOrElse(sparkContext.defaultParallelism)
}

private[sql] object Converter extends ColumnNodeToExpressionConverter with Serializable {
override protected def parser: ParserInterface = sessionState.sqlParser
override protected def conf: SQLConf = sessionState.conf
}

private[sql] def expression(e: Column): Expression = Converter(e.node)

private[sql] implicit class RichColumn(val column: Column) {
/**
* Returns the expression for this column.
*/
def expr: Expression = Converter(column.node)
/**
* Returns the expression for this column either with an existing or auto assigned name.
*/
def named: NamedExpression = ExpressionUtils.toNamed(expr)
}
override protected[sql] val converter: ColumnNodeToExpressionConverter =
new ColumnNodeToExpressionConverter with Serializable {
override protected def parser: ParserInterface = sessionState.sqlParser
override protected def conf: SQLConf = sessionState.conf
}

private[sql] lazy val observationManager = new ObservationManager(this)

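The SparkSession hunk above drops the session-local Converter object and the RichColumn implicit class and instead mixes in classic.ColumnConversions, overriding only the converter. Judging from the members removed here, the mixed-in trait must provide roughly the following; this is an illustrative reconstruction, not the actual org.apache.spark.sql.classic source, and the package locations of the internal types are assumed:

    package org.apache.spark.sql.classic  // assumed location for this sketch

    import org.apache.spark.sql.Column
    import org.apache.spark.sql.catalyst.expressions.{Expression, NamedExpression}
    import org.apache.spark.sql.internal.{ColumnNodeToExpressionConverter, ExpressionUtils}

    trait ColumnConversions {
      // Left abstract; SparkSession overrides it with a parser/conf-aware converter.
      protected[sql] def converter: ColumnNodeToExpressionConverter

      // Convert a public Column into its Catalyst Expression.
      def expression(column: Column): Expression = converter(column.node)

      // The enrichment that call sites now import as sparkSession.toRichColumn.
      implicit def toRichColumn(column: Column): RichColumn = new RichColumn(column)

      class RichColumn(column: Column) {
        // Expression backing this column.
        def expr: Expression = expression(column)
        // Expression with an existing or auto-assigned name.
        def named: NamedExpression = ExpressionUtils.toNamed(expr)
      }
    }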
@@ -33,10 +33,11 @@ import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TableFunctionRe
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.classic.ClassicConversions._
import org.apache.spark.sql.execution.{ExplainMode, QueryExecution}
import org.apache.spark.sql.execution.arrow.ArrowConverters
import org.apache.spark.sql.execution.python.EvaluatePython
import org.apache.spark.sql.internal.ExpressionUtils.{column, expression}
import org.apache.spark.sql.internal.ExpressionUtils.expression
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.util.{MutableURLClassLoader, Utils}
@@ -152,7 +153,8 @@ private[sql] object PythonSQLUtils extends Logging {
Column(internal.LambdaFunction(function.node, arguments))
}

def namedArgumentExpression(name: String, e: Column): Column = NamedArgumentExpression(name, e)
def namedArgumentExpression(name: String, e: Column): Column =
Column(NamedArgumentExpression(name, expression(e)))

@scala.annotation.varargs
def fn(name: String, arguments: Column*): Column = Column.fn(name, arguments: _*)
(The diffs for the remaining changed files were not loaded on this page.)
