Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DO-NOT-MERGE] Pythonic approach of setting Spark SQL configurations #49297

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 118 additions & 3 deletions python/pyspark/sql/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#

from functools import cached_property
import sys
from typing import Any, Dict, Optional, Union, TYPE_CHECKING
from typing import Any, Dict, Optional, Union, TYPE_CHECKING, List, cast

from pyspark import _NoValue
from pyspark._globals import _NoValueType
from pyspark.errors import PySparkTypeError
from pyspark.errors import PySparkTypeError, SparkNoSuchElementException
from pyspark.logger import PySparkLogger
from pyspark.sql.utils import get_active_spark_context

if TYPE_CHECKING:
from py4j.java_gateway import JavaObject
Expand Down Expand Up @@ -151,6 +153,119 @@ def isModifiable(self, key: str) -> bool:
"""
return self._jconf.isModifiable(key)

@cached_property
def spark(self) -> "RuntimeConfigDictWrapper":
from py4j.java_gateway import JVMView

sc = get_active_spark_context()
jvm = cast(JVMView, sc._jvm)
d = {}
for entry in jvm.PythonSQLUtils.listAllSQLConfigs():
k = entry._1()
default = entry._2()
doc = entry._3()
ver = entry._4()
entry = SQLConfEntry(k, default, doc, ver)
entry.__doc__ = doc # So help function work
d[k] = entry
return RuntimeConfigDictWrapper(self, d, prefix="spark")

def __setitem__(self, key: Any, val: Any) -> None:
if key.startswith("spark."):
self.spark[key[6:]] = val
else:
super().__setattr__(key, val)

def __getitem__(self, item: Any) -> Union["RuntimeConfigDictWrapper", str]:
if item.startswith("spark."):
return self.spark[item[6:]]
else:
return object.__getattribute__(self, item)
HyukjinKwon marked this conversation as resolved.
Show resolved Hide resolved


class SQLConfEntry(str):
def __new__(cls, name: str, value: str, description: str, version: str) -> "SQLConfEntry":
return super().__new__(cls, value)

def __init__(self, name: str, value: str, description: str, version: str):
self._name = name
self._value = value
self._description = description
self._version = version

def desc(self) -> str:
return self._description

def version(self) -> str:
return self._version


class RuntimeConfigDictWrapper:
"""provide attribute-style access to a nested dict"""

_logger = PySparkLogger.getLogger("RuntimeConfigDictWrapper")

def __init__(self, conf: RuntimeConfig, d: Dict[str, SQLConfEntry], prefix: str = ""):
object.__setattr__(self, "d", d)
object.__setattr__(self, "prefix", prefix)
object.__setattr__(self, "_conf", conf)

def __setattr__(self, key: str, val: Any) -> None:
prefix = object.__getattribute__(self, "prefix")
d = object.__getattribute__(self, "d")
if prefix:
prefix += "."
canonical_key = prefix + key

candidates = [
k for k in d.keys() if all(x in k.split(".") for x in canonical_key.split("."))
]
if len(candidates) == 0:
RuntimeConfigDictWrapper._logger.info(
"Setting a configuration '{}' to '{}' (non built-in configuration).".format(
canonical_key, val
)
)
object.__getattribute__(self, "_conf").set(canonical_key, val)

__setitem__ = __setattr__

def __getattr__(self, key: str) -> Union["RuntimeConfigDictWrapper", str]:
prefix = object.__getattribute__(self, "prefix")
d = object.__getattribute__(self, "d")
conf = object.__getattribute__(self, "_conf")
if prefix:
prefix += "."
canonical_key = prefix + key

try:
value = conf.get(canonical_key)
description = "Documentation not found for '{}'.".format(canonical_key)
version = "Version not found for '{}'.".format(canonical_key)
if canonical_key in d:
description = d[canonical_key]._description
version = d[canonical_key]._version

return SQLConfEntry(canonical_key, value, description, version)
except SparkNoSuchElementException:
if not prefix.startswith("_"):
return RuntimeConfigDictWrapper(conf, d, canonical_key)
raise

__getitem__ = __getattr__

def __dir__(self) -> List[str]:
prefix = object.__getattribute__(self, "prefix")
d = object.__getattribute__(self, "d")

if prefix == "":
candidates = d.keys()
offset = 0
else:
candidates = [k for k in d.keys() if all(x in k.split(".") for x in prefix.split("."))]
offset = len(prefix) + 1 # prefix (e.g. "spark.") to trim.
return [c[offset:] for c in candidates]


def _test() -> None:
import os
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,18 +75,18 @@ private[sql] object PythonSQLUtils extends Logging {
groupBy(_.getName).map(v => v._2.head).toArray
}

private def listAllSQLConfigs(): Seq[(String, String, String, String)] = {
def listAllSQLConfigs(): Array[(String, String, String, String)] = {
val conf = new SQLConf()
conf.getAllDefinedConfs
conf.getAllDefinedConfs.toArray
}

def listRuntimeSQLConfigs(): Array[(String, String, String, String)] = {
// Py4J doesn't seem to translate Seq well, so we convert to an Array.
listAllSQLConfigs().filterNot(p => SQLConf.isStaticConfigKey(p._1)).toArray
listAllSQLConfigs().filterNot(p => SQLConf.isStaticConfigKey(p._1))
}

def listStaticSQLConfigs(): Array[(String, String, String, String)] = {
listAllSQLConfigs().filter(p => SQLConf.isStaticConfigKey(p._1)).toArray
listAllSQLConfigs().filter(p => SQLConf.isStaticConfigKey(p._1))
}

def isTimestampNTZPreferred: Boolean =
Expand Down
Loading