From dadd62361152670aa6dc94e996348c5e856eac0f Mon Sep 17 00:00:00 2001
From: Hyukjin Kwon <gurwls223@apache.org>
Date: Thu, 26 Dec 2024 15:42:22 +0900
Subject: [PATCH 1/2] POC

---
 python/pyspark/sql/conf.py                    | 121 +++++++++++++++++-
 .../spark/sql/api/python/PythonSQLUtils.scala |   8 +-
 2 files changed, 122 insertions(+), 7 deletions(-)

diff --git a/python/pyspark/sql/conf.py b/python/pyspark/sql/conf.py
index 9a4cc2e7e1628..4be3ba299433e 100644
--- a/python/pyspark/sql/conf.py
+++ b/python/pyspark/sql/conf.py
@@ -14,13 +14,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-
+from functools import cached_property
 import sys
-from typing import Any, Dict, Optional, Union, TYPE_CHECKING
+from typing import Any, Dict, Optional, Union, TYPE_CHECKING, List, cast

 from pyspark import _NoValue
 from pyspark._globals import _NoValueType
-from pyspark.errors import PySparkTypeError
+from pyspark.errors import PySparkTypeError, SparkNoSuchElementException
+from pyspark.logger import PySparkLogger
+from pyspark.sql.utils import get_active_spark_context

 if TYPE_CHECKING:
     from py4j.java_gateway import JavaObject
@@ -151,6 +153,119 @@ def isModifiable(self, key: str) -> bool:
         """
         return self._jconf.isModifiable(key)

+    @cached_property
+    def spark(self) -> "RuntimeConfigDictWrapper":
+        from py4j.java_gateway import JVMView
+
+        sc = get_active_spark_context()
+        jvm = cast(JVMView, sc._jvm)
+        d: Dict[str, SQLConfEntry] = {}
+        for entry in jvm.PythonSQLUtils.listAllSQLConfigs():
+            k = entry._1()
+            default = entry._2()
+            doc = entry._3()
+            ver = entry._4()
+            conf_entry = SQLConfEntry(k, default, doc, ver)
+            conf_entry.__doc__ = doc  # So that the help() function works.
+            d[k] = conf_entry
+        return RuntimeConfigDictWrapper(self, d, prefix="spark")
+
+    def __setitem__(self, key: Any, val: Any) -> None:
+        if key.startswith("spark."):
+            self.spark[key[6:]] = val
+        else:
+            super().__setattr__(key, val)
+
+    def __getitem__(self, item: Any) -> Union["RuntimeConfigDictWrapper", str]:
+        if item.startswith("spark."):
+            return self.spark[item[6:]]
+        else:
+            return object.__getattribute__(self, item)
+
+
+class SQLConfEntry(str):
+    def __new__(cls, name: str, value: str, description: str, version: str) -> "SQLConfEntry":
+        return super().__new__(cls, value)
+
+    def __init__(self, name: str, value: str, description: str, version: str):
+        self._name = name
+        self._value = value
+        self._description = description
+        self._version = version
+
+    def desc(self) -> str:
+        return self._description
+
+    def version(self) -> str:
+        return self._version
+
+
+class RuntimeConfigDictWrapper:
+    """Provides attribute-style access to a nested dict of SQL configurations."""
+
+    _logger = PySparkLogger.getLogger("RuntimeConfigDictWrapper")
+
+    def __init__(self, conf: RuntimeConfig, d: Dict[str, SQLConfEntry], prefix: str = ""):
+        object.__setattr__(self, "d", d)
+        object.__setattr__(self, "prefix", prefix)
+        object.__setattr__(self, "_conf", conf)
+
+    def __setattr__(self, key: str, val: Any) -> None:
+        prefix = object.__getattribute__(self, "prefix")
+        d = object.__getattribute__(self, "d")
+        if prefix:
+            prefix += "."
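+        # Assemble the fully-qualified config name: the wrapper's dotted
+        # prefix (e.g. "spark.sql") plus the attribute being set.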
+        canonical_key = prefix + key
+
+        candidates = [
+            k for k in d.keys() if all(x in k.split(".") for x in canonical_key.split("."))
+        ]
+        if len(candidates) == 0:
+            RuntimeConfigDictWrapper._logger.info(
+                "Setting a configuration '{}' to '{}' (not a built-in configuration).".format(
+                    canonical_key, val
+                )
+            )
+        object.__getattribute__(self, "_conf").set(canonical_key, val)
+
+    __setitem__ = __setattr__
+
+    def __getattr__(self, key: str) -> Union["RuntimeConfigDictWrapper", str]:
+        prefix = object.__getattribute__(self, "prefix")
+        d = object.__getattribute__(self, "d")
+        conf = object.__getattribute__(self, "_conf")
+        if prefix:
+            prefix += "."
+        canonical_key = prefix + key
+
+        try:
+            value = conf.get(canonical_key)
+            description = "Documentation not found for '{}'.".format(canonical_key)
+            version = "Version not found for '{}'.".format(canonical_key)
+            if canonical_key in d:
+                description = d[canonical_key]._description
+                version = d[canonical_key]._version
+
+            return SQLConfEntry(canonical_key, value, description, version)
+        except SparkNoSuchElementException:
+            # Treat an unknown, non-private name as a deeper prefix so that
+            # chained access such as `conf.spark.sql.ansi` keeps working.
+            if not key.startswith("_"):
+                return RuntimeConfigDictWrapper(conf, d, canonical_key)
+            raise
+
+    __getitem__ = __getattr__
+
+    def __dir__(self) -> List[str]:
+        prefix = object.__getattribute__(self, "prefix")
+        d = object.__getattribute__(self, "d")
+
+        if prefix == "":
+            candidates = d.keys()
+            offset = 0
+        else:
+            candidates = [k for k in d.keys() if all(x in k.split(".") for x in prefix.split("."))]
+            offset = len(prefix) + 1  # Trim the prefix and its trailing "." (e.g. "spark.").
+        return [c[offset:] for c in candidates]
+

 def _test() -> None:
     import os
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
index e33fe38b160af..5ebb5564c198d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala
@@ -75,18 +75,18 @@ private[sql] object PythonSQLUtils extends Logging {
     groupBy(_.getName).map(v => v._2.head).toArray
   }

-  private def listAllSQLConfigs(): Seq[(String, String, String, String)] = {
+  def listAllSQLConfigs(): Array[(String, String, String, String)] = {
     val conf = new SQLConf()
-    conf.getAllDefinedConfs
+    conf.getAllDefinedConfs.toArray
   }

   def listRuntimeSQLConfigs(): Array[(String, String, String, String)] = {
     // Py4J doesn't seem to translate Seq well, so we convert to an Array.
-    listAllSQLConfigs().filterNot(p => SQLConf.isStaticConfigKey(p._1)).toArray
+    listAllSQLConfigs().filterNot(p => SQLConf.isStaticConfigKey(p._1))
   }

   def listStaticSQLConfigs(): Array[(String, String, String, String)] = {
-    listAllSQLConfigs().filter(p => SQLConf.isStaticConfigKey(p._1)).toArray
+    listAllSQLConfigs().filter(p => SQLConf.isStaticConfigKey(p._1))
   }

   def isTimestampNTZPreferred: Boolean =

From 87ef983cbdf74ca1c0d545854e411ce834b75a3d Mon Sep 17 00:00:00 2001
From: Hyukjin Kwon <gurwls223@apache.org>
Date: Fri, 27 Dec 2024 09:52:47 +0900
Subject: [PATCH 2/2] fixup

---
 python/pyspark/sql/conf.py | 50 ++++++++++++++++++++++++++------------
 1 file changed, 35 insertions(+), 15 deletions(-)

diff --git a/python/pyspark/sql/conf.py b/python/pyspark/sql/conf.py
index 4be3ba299433e..3c5dc94e8b9e6 100644
--- a/python/pyspark/sql/conf.py
+++ b/python/pyspark/sql/conf.py
@@ -171,16 +171,26 @@ def spark(self) -> "RuntimeConfigDictWrapper":
         return RuntimeConfigDictWrapper(self, d, prefix="spark")

     def __setitem__(self, key: Any, val: Any) -> None:
-        if key.startswith("spark."):
-            self.spark[key[6:]] = val
+        prefix = "spark."
+        if key.startswith(prefix):
+            self.spark[key[len(prefix) :]] = val
         else:
             super().__setattr__(key, val)

     def __getitem__(self, item: Any) -> Union["RuntimeConfigDictWrapper", str]:
-        if item.startswith("spark."):
-            return self.spark[item[6:]]
-        else:
-            return object.__getattribute__(self, item)
+        prefix = "spark."
+        if item.startswith(prefix):
+            return self.spark[item[len(prefix) :]]
+        else:
+            return super().__getattribute__(item)
+
+    def __delitem__(self, item: Any) -> None:
+        prefix = "spark."
+        if item.startswith(prefix):
+            del self.spark[item[len(prefix) :]]
+        else:
+            # Throw the same error as if `__delitem__` did not exist.
+            raise TypeError(
+                "'{}' object does not support item deletion".format(type(self).__name__)
+            )


 class SQLConfEntry(str):
@@ -206,13 +216,13 @@ class RuntimeConfigDictWrapper:
     _logger = PySparkLogger.getLogger("RuntimeConfigDictWrapper")

     def __init__(self, conf: RuntimeConfig, d: Dict[str, SQLConfEntry], prefix: str = ""):
-        object.__setattr__(self, "d", d)
-        object.__setattr__(self, "prefix", prefix)
-        object.__setattr__(self, "_conf", conf)
+        super().__setattr__("d", d)
+        super().__setattr__("prefix", prefix)
+        super().__setattr__("_conf", conf)

     def __setattr__(self, key: str, val: Any) -> None:
-        prefix = object.__getattribute__(self, "prefix")
-        d = object.__getattribute__(self, "d")
+        prefix = super().__getattribute__("prefix")
+        d = super().__getattribute__("d")
         if prefix:
             prefix += "."
         canonical_key = prefix + key
@@ -226,20 +236,22 @@ def __setattr__(self, key: str, val: Any) -> None:
                     canonical_key, val
                 )
             )
-        object.__getattribute__(self, "_conf").set(canonical_key, val)
+        super().__getattribute__("_conf").set(canonical_key, val)

     __setitem__ = __setattr__

-    def __getattr__(self, key: str) -> Union["RuntimeConfigDictWrapper", str]:
-        prefix = object.__getattribute__(self, "prefix")
-        d = object.__getattribute__(self, "d")
-        conf = object.__getattribute__(self, "_conf")
+    def __getattr__(self, key: str) -> Union["RuntimeConfigDictWrapper", Optional[str]]:
+        prefix = super().__getattribute__("prefix")
+        d = super().__getattribute__("d")
+        conf = super().__getattribute__("_conf")
         if prefix:
             prefix += "."
         canonical_key = prefix + key

         try:
             value = conf.get(canonical_key)
+            if value is None:
+                return None
             description = "Documentation not found for '{}'.".format(canonical_key)
             version = "Version not found for '{}'.".format(canonical_key)
             if canonical_key in d:
@@ -254,9 +266,17 @@ def __getattr__(self, key: str) -> Union["RuntimeConfigDictWrapper", str]:

     __getitem__ = __getattr__

+    def __delitem__(self, key: str) -> None:
+        prefix = super().__getattribute__("prefix")
+        conf = super().__getattribute__("_conf")
+        if prefix:
+            prefix += "."
+        canonical_key = prefix + key
+        conf.unset(canonical_key)
+
     def __dir__(self) -> List[str]:
-        prefix = object.__getattribute__(self, "prefix")
-        d = object.__getattribute__(self, "d")
+        prefix = super().__getattribute__("prefix")
+        d = super().__getattribute__("d")

         if prefix == "":
             candidates = d.keys()
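
Usage sketch (reviewer note, not part of the patch): a minimal example of the
dict- and attribute-style access these two commits add to `spark.conf`. It
assumes a running Spark session; `spark.sql.shuffle.partitions` is just an
arbitrary runtime config picked for illustration.

    from pyspark.sql import SparkSession

    spark = SparkSession.builder.getOrCreate()

    # Dict-style access, routed through RuntimeConfig.__setitem__/__getitem__.
    spark.conf["spark.sql.shuffle.partitions"] = "5"
    assert spark.conf["spark.sql.shuffle.partitions"] == "5"

    # Attribute-style access via the cached `spark` wrapper. Intermediate
    # names resolve to nested RuntimeConfigDictWrapper instances; the leaf
    # returns a SQLConfEntry (a str subclass) carrying doc metadata.
    entry = spark.conf.spark.sql.shuffle.partitions
    print(entry, entry.desc(), entry.version())

    # Added by the fixup commit: deletion unsets the config.
    del spark.conf["spark.sql.shuffle.partitions"]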