Add shims to support SparkContext and RDD
hvanhovell committed Sep 10, 2024
1 parent e0a2c74 commit a29a518
Showing 20 changed files with 341 additions and 108 deletions.
5 changes: 5 additions & 0 deletions connector/connect/client/jvm/pom.xml
@@ -45,6 +45,11 @@
<artifactId>spark-sql-api_${scala.binary.version}</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-connect-shims_${scala.binary.version}</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-sketch_${scala.binary.version}</artifactId>
@@ -22,7 +22,9 @@ import java.util.Properties
import scala.jdk.CollectionConverters._

import org.apache.spark.annotation.Stable
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.connect.proto.Parse.ParseFormat
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.connect.common.DataTypeProtoConverter
import org.apache.spark.sql.types.StructType

@@ -139,6 +141,14 @@ class DataFrameReader private[sql] (sparkSession: SparkSession)
def json(jsonDataset: Dataset[String]): DataFrame =
parse(jsonDataset, ParseFormat.PARSE_FORMAT_JSON)

/** @inheritdoc */
override def json(jsonRDD: JavaRDD[String]): Dataset[Row] =
throwRddNotSupportedException()

/** @inheritdoc */
override def json(jsonRDD: RDD[String]): Dataset[Row] =
throwRddNotSupportedException()

/** @inheritdoc */
override def csv(path: String): DataFrame = super.csv(path)

@@ -26,8 +26,10 @@ import scala.util.control.NonFatal

import org.apache.spark.SparkException
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.api.java.function._
import org.apache.spark.connect.proto
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoder
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders._
@@ -1479,4 +1481,10 @@ class Dataset[T] private[sql] (
/** @inheritdoc */
@scala.annotation.varargs
override def agg(expr: Column, exprs: Column*): DataFrame = super.agg(expr, exprs: _*)

/** @inheritdoc */
override def rdd: RDD[T] = throwRddNotSupportedException()

/** @inheritdoc */
override def toJavaRDD: JavaRDD[T] = throwRddNotSupportedException()
}
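
For illustration, a short sketch of how these overrides surface to an application using the Spark Connect Scala client. The connect endpoint URL and object name below are assumed placeholders, not part of this commit:

import org.apache.spark.sql.SparkSession

object RddNotSupportedDemo {
  def main(args: Array[String]): Unit = {
    // Assumed local Spark Connect endpoint; adjust to your environment.
    val spark = SparkSession.builder().remote("sc://localhost:15002").getOrCreate()
    val ds = spark.range(5)
    try {
      // The RDD accessors compile against the shared sql/api interface,
      // but the Connect client implementation rejects them at call time.
      val rdd = ds.rdd
      println(rdd)
    } catch {
      case e: UnsupportedOperationException =>
        println(s"RDD access rejected: ${e.getMessage}")
    }
    spark.stop()
  }
}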
@@ -29,10 +29,13 @@ import com.google.common.cache.{CacheBuilder, CacheLoader}
import io.grpc.ClientInterceptor
import org.apache.arrow.memory.RootAllocator

import org.apache.spark.SparkContext
import org.apache.spark.annotation.{DeveloperApi, Experimental, Since}
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.connect.proto
import org.apache.spark.connect.proto.ExecutePlanResponse
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalog.Catalog
import org.apache.spark.sql.catalyst.{JavaTypeInference, ScalaReflection}
import org.apache.spark.sql.catalyst.encoders.{AgnosticEncoder, RowEncoder}
@@ -84,10 +87,14 @@ class SparkSession private[sql] (

private[sql] val observationRegistry = new ConcurrentHashMap[Long, Observation]()

private[sql] def hijackServerSideSessionIdForTesting(suffix: String) = {
private[sql] def hijackServerSideSessionIdForTesting(suffix: String): Unit = {
client.hijackServerSideSessionIdForTesting(suffix)
}

/** @inheritdoc */
override def sparkContext: SparkContext =
throw new UnsupportedOperationException("sparkContext is not supported in Spark Connect.")

/**
* Runtime configuration interface for Spark.
*
@@ -152,6 +159,30 @@
createDataset(data.asScala.toSeq)
}

/** @inheritdoc */
override def createDataFrame[A <: Product : TypeTag](rdd: RDD[A]): DataFrame =
throwRddNotSupportedException()

/** @inheritdoc */
override def createDataFrame(rowRDD: RDD[Row], schema: StructType): DataFrame =
throwRddNotSupportedException()

/** @inheritdoc */
override def createDataFrame(rowRDD: JavaRDD[Row], schema: StructType): DataFrame =
throwRddNotSupportedException()

/** @inheritdoc */
override def createDataFrame(rdd: RDD[_], beanClass: Class[_]): DataFrame =
throwRddNotSupportedException()

/** @inheritdoc */
override def createDataFrame(rdd: JavaRDD[_], beanClass: Class[_]): DataFrame =
throwRddNotSupportedException()

/** @inheritdoc */
override def createDataset[T: Encoder](data: RDD[T]): Dataset[T] =
throwRddNotSupportedException()

/** @inheritdoc */
@Experimental
def sql(sqlText: String, args: Array[_]): DataFrame = newDataFrame { builder =>
@@ -52,4 +52,7 @@ package object sql {
f(builder)
column(builder.build())
}

private[sql] def throwRddNotSupportedException(): Nothing =
throw new UnsupportedOperationException("RDDs are not supported in Spark Connect.")
}
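
A side note on the helper above: its return type is `Nothing`, which is a subtype of every Scala type, so the single helper can serve as the body of overrides with unrelated result types (`RDD[T]`, `JavaRDD[T]`, `DataFrame`, ...). A standalone sketch of the same pattern, with purely illustrative names:

object NothingReturnTypeSketch {
  private def unsupported(): Nothing =
    throw new UnsupportedOperationException("not supported here")

  // Nothing conforms to any declared return type, so one helper
  // can implement members that return different types.
  def asInt: Int = unsupported()
  def asText: String = unsupported()
}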
1 change: 1 addition & 0 deletions pom.xml
@@ -84,6 +84,7 @@
<module>common/utils</module>
<module>common/variant</module>
<module>common/tags</module>
<module>sql/connect/shims</module>
<module>core</module>
<module>graphx</module>
<module>mllib</module>
6 changes: 6 additions & 0 deletions sql/api/pom.xml
@@ -58,6 +58,12 @@
<artifactId>spark-sketch_${scala.binary.version}</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-connect-shims_${scala.binary.version}</artifactId>
<version>${project.version}</version>
<scope>compile</scope>
</dependency>
<dependency>
<groupId>org.json4s</groupId>
<artifactId>json4s-jackson_${scala.binary.version}</artifactId>
@@ -21,6 +21,8 @@ import scala.jdk.CollectionConverters._
import _root_.java.util

import org.apache.spark.annotation.Stable
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.StringEncoder
import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, SparkCharVarcharUtils}
@@ -308,6 +310,35 @@ abstract class DataFrameReader[DS[U] <: Dataset[U, DS]] {
*/
def json(jsonDataset: DS[String]): DS[Row]

/**
* Loads a `JavaRDD[String]` storing JSON objects (<a href="http://jsonlines.org/">JSON
* Lines text format or newline-delimited JSON</a>) and returns the result as
* a `DataFrame`.
*
* Unless the schema is specified using `schema` function, this function goes through the
* input once to determine the input schema.
*
* @note this method is not supported in Spark Connect.
* @param jsonRDD input RDD with one JSON object per record
* @since 1.4.0
*/
@deprecated("Use json(Dataset[String]) instead.", "2.2.0")
def json(jsonRDD: JavaRDD[String]): DS[Row]

/**
* Loads an `RDD[String]` storing JSON objects (<a href="http://jsonlines.org/">JSON Lines
* text format or newline-delimited JSON</a>) and returns the result as a `DataFrame`.
*
* Unless the schema is specified using `schema` function, this function goes through the
* input once to determine the input schema.
*
* @note this method is not supported in Spark Connect.
* @param jsonRDD input RDD with one JSON object per record
* @since 1.4.0
*/
@deprecated("Use json(Dataset[String]) instead.", "2.2.0")
def json(jsonRDD: RDD[String]): DS[Row]

/**
* Loads a CSV file and returns the result as a `DataFrame`. See the documentation on the other
* overloaded `csv()` method for more details.
29 changes: 29 additions & 0 deletions sql/api/src/main/scala/org/apache/spark/sql/api/Dataset.scala
@@ -22,7 +22,9 @@ import scala.reflect.runtime.universe.TypeTag
import _root_.java.util

import org.apache.spark.annotation.{DeveloperApi, Stable}
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.api.java.function.{FilterFunction, FlatMapFunction, ForeachFunction, ForeachPartitionFunction, MapFunction, MapPartitionsFunction, ReduceFunction}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{functions, AnalysisException, Column, DataFrameWriter, DataFrameWriterV2, Encoder, MergeIntoWriter, Observation, Row, TypedColumn}
import org.apache.spark.sql.internal.{ToScalaUDF, UDFAdaptors}
import org.apache.spark.sql.types.{Metadata, StructType}
@@ -3055,4 +3057,31 @@ abstract class Dataset[T, DS[U] <: Dataset[U, DS]] extends Serializable {
* @since 1.6.0
*/
def write: DataFrameWriter[T]

/**
* Represents the content of the Dataset as an `RDD` of `T`.
*
* @note this method is not supported in Spark Connect.
* @group basic
* @since 1.6.0
*/
def rdd: RDD[T]

/**
* Returns the content of the Dataset as a `JavaRDD` of `T`s.
*
* @note this method is not supported in Spark Connect.
* @group basic
* @since 1.6.0
*/
def toJavaRDD: JavaRDD[T]

/**
* Returns the content of the Dataset as a `JavaRDD` of `T`s.
*
* @note this method is not supported in Spark Connect.
* @group basic
* @since 1.6.0
*/
def javaRDD: JavaRDD[T] = toJavaRDD
}
97 changes: 97 additions & 0 deletions sql/api/src/main/scala/org/apache/spark/sql/api/SparkSession.scala
@@ -25,7 +25,10 @@ import _root_.java.lang
import _root_.java.net.URI
import _root_.java.util

import org.apache.spark.SparkContext
import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Encoder, Row}
import org.apache.spark.sql.types.StructType

@@ -51,6 +54,13 @@ import org.apache.spark.sql.types.StructType
*/
abstract class SparkSession[DS[U] <: Dataset[U, DS]] extends Serializable with Closeable {

/**
* The Spark context associated with this Spark session.
*
* @note this method is not supported in Spark Connect.
*/
def sparkContext: SparkContext

/**
* The version of Spark on which this application is running.
*
@@ -134,6 +144,82 @@ abstract class SparkSession[DS[U] <: Dataset[U, DS]] extends Serializable with C
*/
def createDataFrame(data: util.List[_], beanClass: Class[_]): DS[Row]

/**
* Creates a `DataFrame` from an RDD of Product (e.g. case classes, tuples).
*
* @note this method is not supported in Spark Connect.
* @since 2.0.0
*/
def createDataFrame[A <: Product : TypeTag](rdd: RDD[A]): DS[Row]

/**
* :: DeveloperApi ::
* Creates a `DataFrame` from an `RDD` containing [[Row]]s using the given schema.
* It is important to make sure that the structure of every [[Row]] of the provided RDD matches
* the provided schema. Otherwise, there will be runtime exception.
* Example:
* {{{
* import org.apache.spark.sql._
* import org.apache.spark.sql.types._
* val sparkSession = new org.apache.spark.sql.SparkSession(sc)
*
* val schema =
* StructType(
* StructField("name", StringType, false) ::
* StructField("age", IntegerType, true) :: Nil)
*
* val people =
* sc.textFile("examples/src/main/resources/people.txt").map(
* _.split(",")).map(p => Row(p(0), p(1).trim.toInt))
* val dataFrame = sparkSession.createDataFrame(people, schema)
* dataFrame.printSchema
* // root
* // |-- name: string (nullable = false)
* // |-- age: integer (nullable = true)
*
* dataFrame.createOrReplaceTempView("people")
* sparkSession.sql("select name from people").collect.foreach(println)
* }}}
*
* @note this method is not supported in Spark Connect.
* @since 2.0.0
*/
@DeveloperApi
def createDataFrame(rowRDD: RDD[Row], schema: StructType): DS[Row]

/**
* :: DeveloperApi ::
* Creates a `DataFrame` from a `JavaRDD` containing [[Row]]s using the given schema.
* It is important to make sure that the structure of every [[Row]] of the provided RDD matches
* the provided schema. Otherwise, there will be runtime exception.
*
* @note this method is not supported in Spark Connect.
* @since 2.0.0
*/
@DeveloperApi
def createDataFrame(rowRDD: JavaRDD[Row], schema: StructType): DS[Row]

/**
* Applies a schema to an RDD of Java Beans.
*
* WARNING: Since there is no guaranteed ordering for fields in a Java Bean,
* SELECT * queries will return the columns in an undefined order.
*
* @since 2.0.0
*/
def createDataFrame(rdd: RDD[_], beanClass: Class[_]): DS[Row]

/**
* Applies a schema to an RDD of Java Beans.
*
* WARNING: Since there is no guaranteed ordering for fields in a Java Bean,
* SELECT * queries will return the columns in an undefined order.
*
* @note this method is not supported in Spark Connect.
* @since 2.0.0
*/
def createDataFrame(rdd: JavaRDD[_], beanClass: Class[_]): DS[Row]

/* ------------------------------- *
| Methods for creating DataSets |
* ------------------------------- */
@@ -191,6 +277,17 @@ abstract class SparkSession[DS[U] <: Dataset[U, DS]] extends Serializable with C
*/
def createDataset[T: Encoder](data: util.List[T]): DS[T]

/**
* Creates a [[Dataset]] from an RDD of a given type. This method requires an
* encoder (to convert a JVM object of type `T` to and from the internal Spark SQL representation)
* that is generally created automatically through implicits from a `SparkSession`, or can be
* created explicitly by calling static methods on `Encoders`.
*
* @note this method is not supported in Spark Connect.
* @since 2.0.0
*/
def createDataset[T: Encoder](data: RDD[T]): DS[T]

/**
* Creates a [[Dataset]] with a single `LongType` column named `id`, containing elements in a
* range from 0 to `end` (exclusive) with step value 1.
Expand Down
6 changes: 6 additions & 0 deletions sql/catalyst/pom.xml
@@ -44,6 +44,12 @@
<groupId>org.apache.spark</groupId>
<artifactId>spark-sql-api_${scala.binary.version}</artifactId>
<version>${project.version}</version>
<exclusions>
<exclusion>
<groupId>org.apache.spark</groupId>
<artifactId>spark-connect-shims_${scala.binary.version}</artifactId>
</exclusion>
</exclusions>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
4 changes: 4 additions & 0 deletions sql/connect/server/pom.xml
@@ -52,6 +52,10 @@
<artifactId>spark-connect-common_${scala.binary.version}</artifactId>
<version>${project.version}</version>
<exclusions>
<exclusion>
<groupId>org.apache.spark</groupId>
<artifactId>spark-connect-shims_${scala.binary.version}</artifactId>
</exclusion>
<exclusion>
<groupId>com.google.guava</groupId>
<artifactId>guava</artifactId>
1 change: 1 addition & 0 deletions sql/connect/shims/README.md
@@ -0,0 +1 @@
This module defines shims used by the interface defined in sql/api.
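
The shim sources themselves are not rendered in this excerpt, but a minimal sketch of what such compile-only stubs could look like follows; the exact declarations in sql/connect/shims may differ:

// Hypothetical sketch of compile-only shim declarations. They exist solely so
// that the shared sql/api interfaces can reference SparkContext, RDD and
// JavaRDD without depending on spark-core; server-side modules exclude this
// artifact and use the real classes from core instead.
package org.apache.spark {
  class SparkContext private ()
}

package org.apache.spark.rdd {
  abstract class RDD[T]
}

package org.apache.spark.api.java {
  class JavaRDD[T]
}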