Commit

We like thread locals :)

hvanhovell committed Oct 7, 2024
1 parent d68048b commit 2e1413a
Showing 4 changed files with 213 additions and 198 deletions.
@@ -19,7 +19,7 @@ package org.apache.spark.sql
import java.net.URI
import java.nio.file.{Files, Paths}
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.{AtomicLong, AtomicReference}
import java.util.concurrent.atomic.AtomicLong

import scala.jdk.CollectionConverters._
import scala.reflect.runtime.universe.TypeTag
@@ -494,6 +494,8 @@ class SparkSession private[sql] (
}
}

override private[sql] def isUsable: Boolean = client.isSessionValid

implicit class RichColumn(c: Column) {
def expr: proto.Expression = toExpr(c)
def typedExpr[T](e: Encoder[T]): proto.Expression = toTypedExpr(c, e)
@@ -502,7 +504,9 @@

// The minimal builder needed to create a Spark session.
// TODO: implement all methods mentioned in the scaladoc of [[SparkSession]]
object SparkSession extends api.SparkSessionCompanion with Logging {
object SparkSession extends api.BaseSparkSessionCompanion with Logging {
override private[sql] type Session = SparkSession

private val MAX_CACHED_SESSIONS = 100
private val planIdGenerator = new AtomicLong
private var server: Option[Process] = None
@@ -518,29 +522,6 @@ object SparkSession extends api.SparkSessionCompanion with Logging {
override def load(c: Configuration): SparkSession = create(c)
})

/** The active SparkSession for the current thread. */
private val activeThreadSession = new InheritableThreadLocal[SparkSession]

/** Reference to the root SparkSession. */
private val defaultSession = new AtomicReference[SparkSession]

/**
* Set the (global) default [[SparkSession]], and (thread-local) active [[SparkSession]] when
* they are not set yet or the associated [[SparkConnectClient]] is unusable.
*/
private def setDefaultAndActiveSession(session: SparkSession): Unit = {
val currentDefault = defaultSession.getAcquire
if (currentDefault == null || !currentDefault.client.isSessionValid) {
// Update `defaultSession` if it is null or the contained session is not valid. There is a
// chance that the following `compareAndSet` fails if a new default session has just been set,
// but that does not matter since that event has happened after this method was invoked.
defaultSession.compareAndSet(currentDefault, session)
}
if (getActiveSession.isEmpty) {
setActiveSession(session)
}
}

/**
* Create a new Spark Connect server to connect locally.
*/
@@ -593,17 +574,6 @@ object SparkSession extends api.SparkSessionCompanion with Logging {
new SparkSession(configuration.toSparkConnectClient, planIdGenerator)
}

/**
* Hook called when a session is closed.
*/
private[sql] def onSessionClose(session: SparkSession): Unit = {
sessions.invalidate(session.client.configuration)
defaultSession.compareAndSet(session, null)
if (getActiveSession.contains(session)) {
clearActiveSession()
}
}

/**
* Creates a [[SparkSession.Builder]] for constructing a [[SparkSession]].
*
@@ -750,71 +720,12 @@ object SparkSession extends api.SparkSessionCompanion with Logging {
}
}

/**
* Returns the default SparkSession. If the previously set default SparkSession becomes
* unusable, returns None.
*
* @since 3.5.0
*/
def getDefaultSession: Option[SparkSession] =
Option(defaultSession.get()).filter(_.client.isSessionValid)

/**
* Sets the default SparkSession.
*
* @since 3.5.0
*/
def setDefaultSession(session: SparkSession): Unit = {
defaultSession.set(session)
}

/**
* Clears the default SparkSession.
*
* @since 3.5.0
*/
def clearDefaultSession(): Unit = {
defaultSession.set(null)
}

/**
* Returns the active SparkSession for the current thread. If the previously set active
* SparkSession becomes unusable, returns None.
*
* @since 3.5.0
*/
def getActiveSession: Option[SparkSession] =
Option(activeThreadSession.get()).filter(_.client.isSessionValid)

/**
* Changes the SparkSession that will be returned in this thread and its children when
* SparkSession.getOrCreate() is called. This can be used to ensure that a given thread receives
* an isolated SparkSession.
*
* @since 3.5.0
*/
def setActiveSession(session: SparkSession): Unit = {
activeThreadSession.set(session)
}
/** @inheritdoc */
override def getActiveSession: Option[SparkSession] = super.getActiveSession

/**
* Clears the active SparkSession for current thread.
*
* @since 3.5.0
*/
def clearActiveSession(): Unit = {
activeThreadSession.remove()
}
/** @inheritdoc */
override def getDefaultSession: Option[SparkSession] = super.getDefaultSession

/**
* Returns the currently active SparkSession, otherwise the default one. If there is no default
* SparkSession, throws an exception.
*
* @since 3.5.0
*/
def active: SparkSession = {
getActiveSession
.orElse(getDefaultSession)
.getOrElse(throw new IllegalStateException("No active or default Spark session found"))
}
/** @inheritdoc */
override def active: SparkSession = super.active
}
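
Both before and after this change, the active session lives in an InheritableThreadLocal: a thread spawned while a session is active inherits it, while updates made in the child stay invisible to the parent. A minimal toy sketch of that semantics (stand-alone Scala, not Spark code; the Session class is a placeholder):

object ThreadLocalDemo extends App {
  final case class Session(name: String)

  private val active = new InheritableThreadLocal[Session]

  active.set(Session("parent"))

  val child = new Thread(() => {
    // Inherited from the parent at thread-creation time.
    println(s"child sees: ${active.get}") // Session(parent)
    active.set(Session("child"))
  })
  child.start()
  child.join()

  // The child's update does not propagate back to the parent.
  println(s"parent sees: ${active.get}") // Session(parent)
}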
6 changes: 6 additions & 0 deletions project/MimaExcludes.scala
@@ -189,6 +189,12 @@ object MimaExcludes {
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.expressions.javalang.typed"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.expressions.scalalang.typed"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.expressions.scalalang.typed$"),

// SPARK-49418: Consolidate thread local handling in sql/api
ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.SparkSession.setActiveSession"),
ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.sql.SparkSession.setDefaultSession"),
ProblemFilters.exclude[DirectAbstractMethodProblem]("org.apache.spark.sql.api.SparkSessionCompanion.clearActiveSession"),
ProblemFilters.exclude[DirectAbstractMethodProblem]("org.apache.spark.sql.api.SparkSessionCompanion.clearDefaultSession"),
) ++ loggingExcludes("org.apache.spark.sql.DataFrameReader") ++
loggingExcludes("org.apache.spark.sql.streaming.DataStreamReader") ++
loggingExcludes("org.apache.spark.sql.SparkSession#Builder")
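
These exclusions acknowledge deliberate binary breaks: setActiveSession and setDefaultSession now take the companion's abstract Session member type, and the clear* methods are inherited from the new base class, so MiMa would otherwise flag the commit. For context, a hedged sketch of how such filters are registered with sbt-mima-plugin in a plain build.sbt (illustrative only; Spark's actual list is the MimaExcludes.scala shown above):

// build.sbt fragment (illustrative): whitelist intentional breaks so the
// mimaReportBinaryIssues task does not fail the build.
import com.typesafe.tools.mima.core._

mimaBinaryIssueFilters ++= Seq(
  ProblemFilters.exclude[IncompatibleMethTypeProblem](
    "org.apache.spark.sql.SparkSession.setActiveSession"),
  ProblemFilters.exclude[DirectAbstractMethodProblem](
    "org.apache.spark.sql.api.SparkSessionCompanion.clearActiveSession"))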
166 changes: 165 additions & 1 deletion sql/api/src/main/scala/org/apache/spark/sql/api/SparkSession.scala
@@ -24,7 +24,9 @@ import _root_.java.io.Closeable
import _root_.java.lang
import _root_.java.net.URI
import _root_.java.util
import _root_.java.util.concurrent.atomic.AtomicReference

import org.apache.spark.SparkException
import org.apache.spark.annotation.{DeveloperApi, Experimental, Stable, Unstable}
import org.apache.spark.sql.{Encoder, Row, RuntimeConfig}
import org.apache.spark.sql.types.StructType
@@ -561,9 +563,19 @@ abstract class SparkSession extends Serializable with Closeable {
* @since 2.0.0
*/
def stop(): Unit = close()

/**
* Check to see if the session is still usable.
*
* In Classic this means that the underlying `SparkContext` has not been shut down. In Connect
* this means the connection to the server is still open.
*/
private[sql] def isUsable: Boolean
}

object SparkSession extends SparkSessionCompanion {
type Session = SparkSession

private[this] val companion: SparkSessionCompanion = {
val cls = SparkClassUtils.classForName("org.apache.spark.sql.SparkSession")
val mirror = scala.reflect.runtime.currentMirror
@@ -573,12 +585,97 @@

/** @inheritdoc */
override def builder(): SparkSessionBuilder = companion.builder()

/** @inheritdoc */
override def setActiveSession(session: SparkSession): Unit =
companion.setActiveSession(session.asInstanceOf[companion.Session])

/** @inheritdoc */
override def clearActiveSession(): Unit = companion.clearActiveSession()

/** @inheritdoc */
override def setDefaultSession(session: SparkSession): Unit =
companion.setDefaultSession(session.asInstanceOf[companion.Session])

/** @inheritdoc */
override def clearDefaultSession(): Unit = companion.clearDefaultSession()

/** @inheritdoc */
override def getActiveSession: Option[SparkSession] = companion.getActiveSession

/** @inheritdoc */
override def getDefaultSession: Option[SparkSession] = companion.getDefaultSession
}
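
The shared object SparkSession above cannot reference either concrete module directly, so it resolves the real companion reflectively (the elided body pairs SparkClassUtils.classForName with the runtime mirror). A hypothetical re-creation of that lookup with plain scala-reflect; the object name and the example argument are illustrative:

import scala.reflect.runtime.{currentMirror => mirror}

object CompanionLoader {
  // Resolve a singleton object by its fully qualified name and return its
  // instance; Spark resolves "org.apache.spark.sql.SparkSession" this way
  // and casts the result to SparkSessionCompanion.
  def loadCompanion(fqcn: String): Any = {
    val module = mirror.staticModule(fqcn)
    mirror.reflectModule(module).instance
  }
}

// Usage, resolving a standard-library object purely for illustration:
// CompanionLoader.loadCompanion("scala.Predef")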

/**
* Companion of a [[SparkSession]].
* Interface for a [[SparkSession]] companion. The companion is responsible for building the
* session and for managing the active (thread-local) and default (global) sessions.
*/
private[sql] abstract class SparkSessionCompanion {
private[sql] type Session <: SparkSession

/**
* Changes the SparkSession that will be returned in this thread and its children when
* SparkSession.getOrCreate() is called. This can be used to ensure that a given thread receives
* a SparkSession with an isolated session, instead of the global (first created) context.
*
* @since 2.0.0
*/
def setActiveSession(session: Session): Unit

/**
* Clears the active SparkSession for the current thread. Subsequent calls to getOrCreate will
* return the first created context instead of a thread-local override.
*
* @since 2.0.0
*/
def clearActiveSession(): Unit

/**
* Sets the default SparkSession that is returned by the builder.
*
* @since 2.0.0
*/
def setDefaultSession(session: Session): Unit

/**
* Clears the default SparkSession that is returned by the builder.
*
* @since 2.0.0
*/
def clearDefaultSession(): Unit

/**
* Returns the active SparkSession for the current thread, returned by the builder.
*
* @note
* Returns None when called on executors.
*
* @since 2.2.0
*/
def getActiveSession: Option[Session]

/**
* Returns the default SparkSession that is returned by the builder.
*
* @note
* Returns None when called on executors.
*
* @since 2.2.0
*/
def getDefaultSession: Option[Session]

/**
* Returns the currently active SparkSession, otherwise the default one. If there is no default
* SparkSession, throws an exception.
*
* @since 2.4.0
*/
def active: Session = {
getActiveSession.getOrElse(
getDefaultSession.getOrElse(
throw SparkException.internalError("No active or default Spark session found")))
}

/**
* Creates a [[SparkSessionBuilder]] for constructing a [[SparkSession]].
@@ -588,6 +685,73 @@ private[sql] abstract class SparkSessionCompanion {
def builder(): SparkSessionBuilder
}
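
active fixes the lookup order callers depend on: the thread-local active session wins, the global default is the fallback, and only when both are missing does it throw. A toy model of exactly that chain (stand-alone Scala, not Spark code):

import java.util.concurrent.atomic.AtomicReference

object ResolutionDemo extends App {
  final case class Session(name: String)

  private val activeLocal = new InheritableThreadLocal[Session]
  private val default = new AtomicReference[Session]()

  // Same shape as SparkSessionCompanion.active: active, then default, then fail.
  def active: Session =
    Option(activeLocal.get)
      .orElse(Option(default.get))
      .getOrElse(throw new IllegalStateException(
        "No active or default Spark session found"))

  default.set(Session("default"))
  println(active) // Session(default): falls back to the global default
  activeLocal.set(Session("active"))
  println(active) // Session(active): the thread-local session wins
}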

/**
* Abstract class for [[SparkSession]] companions. This implements active and default session
* management.
*/
private[sql] abstract class BaseSparkSessionCompanion extends SparkSessionCompanion {

/** The active SparkSession for the current thread. */
private val activeThreadSession = new InheritableThreadLocal[Session]

/** Reference to the root SparkSession. */
private val defaultSession = new AtomicReference[Session]

/** @inheritdoc */
def setActiveSession(session: Session): Unit = {
activeThreadSession.set(session)
}

/** @inheritdoc */
def clearActiveSession(): Unit = {
activeThreadSession.remove()
}

/** @inheritdoc */
def setDefaultSession(session: Session): Unit = {
defaultSession.set(session)
}

/** @inheritdoc */
def clearDefaultSession(): Unit = {
defaultSession.set(null.asInstanceOf[Session])
}

/** @inheritdoc */
def getActiveSession: Option[Session] = Option(activeThreadSession.get)

/** @inheritdoc */
def getDefaultSession: Option[Session] = Option(defaultSession.get)

/**
* Sets the (global) default [[SparkSession]] and the (thread-local) active [[SparkSession]]
* when they are not yet set or are no longer usable.
*/
protected def setDefaultAndActiveSession(session: Session): Unit = {
val currentDefault = defaultSession.getAcquire
if (currentDefault == null || !currentDefault.isUsable) {
// Update `defaultSession` if it is null or the contained session is not usable. There is a
// chance that the following `compareAndSet` fails if a new default session has just been set,
// but that does not matter since that event has happened after this method was invoked.
defaultSession.compareAndSet(currentDefault, session)
}
val active = getActiveSession
if (active.isEmpty || !active.get.isUsable) {
setActiveSession(session)
}
}

/**
* When a session is closed, remove it from the active and default session references.
*/
private[sql] def onSessionClose(session: Session): Unit = {
defaultSession.compareAndSet(session, null.asInstanceOf[Session])
if (getActiveSession.contains(session)) {
clearActiveSession()
}
}
}
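
The subtle part of the base class is the default-session swap: setDefaultAndActiveSession reads with getAcquire and writes with compareAndSet, so a racing thread that installs a fresher default between the read and the write simply wins instead of being clobbered. A toy sketch of the same pattern (stand-alone Scala, Java 9+ for getAcquire; the Session type is a placeholder):

import java.util.concurrent.atomic.AtomicReference

object CasDemo {
  final case class Session(name: String, usable: Boolean)

  private val defaultSession = new AtomicReference[Session]()

  // Replace the default only if it is still the stale value we read; if the
  // CAS loses to a concurrent writer, the newer session is kept, which is
  // exactly what setDefaultAndActiveSession wants.
  def setDefaultIfStale(candidate: Session): Unit = {
    val current = defaultSession.getAcquire
    if (current == null || !current.usable) {
      defaultSession.compareAndSet(current, candidate)
    }
  }

  // Mirror of onSessionClose: clear the default only if it is this session.
  def close(session: Session): Unit = {
    defaultSession.compareAndSet(session, null)
  }
}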

/**
* Builder for [[SparkSession]].
*/