diff --git a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java
index 7b692201dac11..099f33960967a 100644
--- a/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java
+++ b/launcher/src/main/java/org/apache/spark/launcher/AbstractCommandBuilder.java
@@ -63,7 +63,8 @@ abstract class AbstractCommandBuilder {
   /**
    * Indicate if the current app submission has to use Spark Connect.
    */
-  protected boolean isRemote = System.getenv().containsKey("SPARK_REMOTE");
+  protected boolean isRemote = System.getenv().containsKey("SPARK_REMOTE") ||
+    "1".equals(System.getenv().get("SPARK_CONNECT_MODE"));
 
   AbstractCommandBuilder() {
     this.appArgs = new ArrayList<>();
diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java
index c1feb709a93f7..b557bb5bc3722 100644
--- a/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java
+++ b/launcher/src/main/java/org/apache/spark/launcher/SparkSubmitCommandBuilder.java
@@ -384,11 +384,12 @@ private List<String> buildPySparkShellCommand(Map<String, String> env) throws IOException {
     if (remoteStr != null) {
       env.put("SPARK_REMOTE", remoteStr);
      env.put("SPARK_CONNECT_MODE_ENABLED", "1");
-    } else if (conf.getOrDefault(
-        SparkLauncher.SPARK_API_MODE, "classic").toLowerCase(Locale.ROOT).equals("connect") &&
-        masterStr != null) {
-      env.put("SPARK_REMOTE", masterStr);
-      env.put("SPARK_CONNECT_MODE_ENABLED", "1");
+    } else {
+      String defaultApiMode = "1".equals(System.getenv("SPARK_CONNECT_MODE")) ? "connect" : "classic";
+      String apiMode = conf.getOrDefault(SparkLauncher.SPARK_API_MODE, defaultApiMode).toLowerCase(Locale.ROOT);
+      if (apiMode.equals("connect")) {
+        env.put("SPARK_CONNECT_MODE_ENABLED", "1");
+      }
     }
 
     if (!isEmpty(pyOpts)) {
diff --git a/python/pyspark/sql/session.py b/python/pyspark/sql/session.py
index 011262f23a9a7..1a0f32557bfae 100644
--- a/python/pyspark/sql/session.py
+++ b/python/pyspark/sql/session.py
@@ -480,7 +480,11 @@ def getOrCreate(self) -> "SparkSession":
             from pyspark.core.context import SparkContext
 
             with self._lock:
-                is_api_mode_connect = opts.get("spark.api.mode", "classic").lower() == "connect"
+                default_api_mode = "classic"
+                if os.environ.get("SPARK_CONNECT_MODE", "0") == "1":
+                    default_api_mode = "connect"
+
+                is_api_mode_connect = opts.get("spark.api.mode", default_api_mode).lower() == "connect"
 
                 if (
                     "SPARK_CONNECT_MODE_ENABLED" in os.environ
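Taken together, these changes let the SPARK_CONNECT_MODE environment variable flip the default API mode to "connect" across the launcher and PySpark, without requiring spark.api.mode to be set explicitly. A minimal sketch of the resulting PySpark behavior, assuming a build that includes this patch; the usage below is illustrative and not part of the diff itself:

# Minimal sketch, assuming a PySpark build containing this change.
import os

# Opt in to Spark Connect via the new environment variable. With it set,
# the default for "spark.api.mode" becomes "connect" instead of "classic".
os.environ["SPARK_CONNECT_MODE"] = "1"

from pyspark.sql import SparkSession

# With no explicit spark.api.mode, getOrCreate() follows the environment-
# derived default and builds a Spark Connect session.
spark = SparkSession.builder.getOrCreate()

# An explicit configuration still wins over the environment-derived default:
#   SparkSession.builder.config("spark.api.mode", "classic").getOrCreate()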