From 7a20c7a006c10c72d0ad8112067156595ce28d47 Mon Sep 17 00:00:00 2001
From: Niels Bantilan
Date: Tue, 5 Dec 2023 13:02:56 -0500
Subject: [PATCH] update cache dataframe config args, fix tests (#1437)

This PR renames the pandera config arguments introduced in this PR:
https://github.com/unionai-oss/pandera/pull/1414 and makes the names more
generic.

Fixes tests that were broken by the config changes.

Signed-off-by: Niels Bantilan
---
 pandera/backends/pyspark/decorators.py   |  4 ++--
 pandera/config.py                        |  8 ++++----
 tests/core/test_pandas_config.py         |  4 ++++
 tests/pyspark/test_pyspark_config.py     | 26 ++++++++++++------------
 tests/pyspark/test_pyspark_decorators.py | 10 ++++-----
 5 files changed, 28 insertions(+), 24 deletions(-)

diff --git a/pandera/backends/pyspark/decorators.py b/pandera/backends/pyspark/decorators.py
index 9e7320314..3dacc398b 100644
--- a/pandera/backends/pyspark/decorators.py
+++ b/pandera/backends/pyspark/decorators.py
@@ -156,7 +156,7 @@ def _wrapper(func):
         @functools.wraps(func)
         def wrapper(self, *args, **kwargs):
             # Skip if not enabled
-            if CONFIG.pyspark_cache is not True:
+            if CONFIG.cache_dataframe is not True:
                 return func(self, *args, **kwargs)
 
             check_obj: DataFrame = None
@@ -186,7 +186,7 @@ def cached_check_obj():
                 yield
 
                 # Execute the decorated function
-                if not CONFIG.pyspark_keep_cache:
+                if not CONFIG.keep_cached_dataframe:
                     # If not cached, `.unpersist()` does nothing
                     logger.debug("Unpersisting dataframe...")
                     check_obj.unpersist()
diff --git a/pandera/config.py b/pandera/config.py
index 9fccc3269..f92f3e510 100644
--- a/pandera/config.py
+++ b/pandera/config.py
@@ -26,8 +26,8 @@ class PanderaConfig(BaseModel):
 
     validation_enabled: bool = True
     validation_depth: ValidationDepth = ValidationDepth.SCHEMA_AND_DATA
-    pyspark_cache: bool = False
-    pyspark_keep_cache: bool = False
+    cache_dataframe: bool = False
+    keep_cached_dataframe: bool = False
 
 
 # this config variable should be accessible globally
@@ -39,11 +39,11 @@ class PanderaConfig(BaseModel):
     validation_depth=os.environ.get(
         "PANDERA_VALIDATION_DEPTH", ValidationDepth.SCHEMA_AND_DATA
     ),
-    pyspark_cache=os.environ.get(
+    cache_dataframe=os.environ.get(
         "PANDERA_CACHE_DATAFRAME",
         False,
     ),
-    pyspark_keep_cache=os.environ.get(
+    keep_cached_dataframe=os.environ.get(
         "PANDERA_KEEP_CACHED_DATAFRAME",
         False,
     ),
diff --git a/tests/core/test_pandas_config.py b/tests/core/test_pandas_config.py
index f1c542379..59ad2f617 100644
--- a/tests/core/test_pandas_config.py
+++ b/tests/core/test_pandas_config.py
@@ -44,6 +44,8 @@ class TestSchema(DataFrameModel):
             price_val: int = pa.Field()
 
         expected = {
+            "cache_dataframe": False,
+            "keep_cached_dataframe": False,
             "validation_enabled": False,
             "validation_depth": ValidationDepth.SCHEMA_AND_DATA,
         }
@@ -61,6 +63,8 @@ class TestPandasSeriesConfig:
     def test_disable_validation(self, disable_validation):
         """This function validates that a none object is loaded if validation is disabled"""
         expected = {
+            "cache_dataframe": False,
+            "keep_cached_dataframe": False,
             "validation_enabled": False,
             "validation_depth": ValidationDepth.SCHEMA_AND_DATA,
         }
diff --git a/tests/pyspark/test_pyspark_config.py b/tests/pyspark/test_pyspark_config.py
index 745cde4dc..8a01855cb 100644
--- a/tests/pyspark/test_pyspark_config.py
+++ b/tests/pyspark/test_pyspark_config.py
@@ -42,8 +42,8 @@ class TestSchema(DataFrameModel):
         expected = {
             "validation_enabled": False,
             "validation_depth": ValidationDepth.SCHEMA_AND_DATA,
-            "pyspark_cache": False,
-            "pyspark_keep_cache": False,
+            "cache_dataframe": False,
+            "keep_cached_dataframe": False,
         }
 
         assert CONFIG.dict() == expected
@@ -66,8 +66,8 @@ def test_schema_only(self, spark, sample_spark_schema):
         expected = {
             "validation_enabled": True,
             "validation_depth": ValidationDepth.SCHEMA_ONLY,
-            "pyspark_cache": False,
-            "pyspark_keep_cache": False,
+            "cache_dataframe": False,
+            "keep_cached_dataframe": False,
         }
 
         assert CONFIG.dict() == expected
@@ -146,8 +146,8 @@ def test_data_only(self, spark, sample_spark_schema):
         expected = {
             "validation_enabled": True,
             "validation_depth": ValidationDepth.DATA_ONLY,
-            "pyspark_cache": False,
-            "pyspark_keep_cache": False,
+            "cache_dataframe": False,
+            "keep_cached_dataframe": False,
         }
 
         assert CONFIG.dict() == expected
@@ -233,8 +233,8 @@ def test_schema_and_data(self, spark, sample_spark_schema):
         expected = {
             "validation_enabled": True,
             "validation_depth": ValidationDepth.SCHEMA_AND_DATA,
-            "pyspark_cache": False,
-            "pyspark_keep_cache": False,
+            "cache_dataframe": False,
+            "keep_cached_dataframe": False,
         }
 
         assert CONFIG.dict() == expected
@@ -339,21 +339,21 @@ class TestSchema(DataFrameModel):
     @pytest.mark.parametrize("cache_enabled", [True, False])
     @pytest.mark.parametrize("keep_cache_enabled", [True, False])
     # pylint:disable=too-many-locals
-    def test_pyspark_cache_settings(
+    def test_cache_dataframe_settings(
         self,
         cache_enabled,
         keep_cache_enabled,
     ):
         """This function validates setters and getters for cache/keep_cache options."""
         # Set expected properties in Config object
-        CONFIG.pyspark_cache = cache_enabled
-        CONFIG.pyspark_keep_cache = keep_cache_enabled
+        CONFIG.cache_dataframe = cache_enabled
+        CONFIG.keep_cached_dataframe = keep_cache_enabled
 
         # Evaluate expected Config
         expected = {
             "validation_enabled": True,
             "validation_depth": ValidationDepth.SCHEMA_AND_DATA,
-            "pyspark_cache": cache_enabled,
-            "pyspark_keep_cache": keep_cache_enabled,
+            "cache_dataframe": cache_enabled,
+            "keep_cached_dataframe": keep_cache_enabled,
         }
         assert CONFIG.dict() == expected
diff --git a/tests/pyspark/test_pyspark_decorators.py b/tests/pyspark/test_pyspark_decorators.py
index a1dc72fcd..67e5e2b2c 100644
--- a/tests/pyspark/test_pyspark_decorators.py
+++ b/tests/pyspark/test_pyspark_decorators.py
@@ -22,10 +22,10 @@ class TestPanderaDecorators:
 
     sample_data = [("Bread", 9), ("Cutter", 15)]
 
-    def test_pyspark_cache_requirements(self, spark, sample_spark_schema):
+    def test_cache_dataframe_requirements(self, spark, sample_spark_schema):
         """Validates if decorator can only be applied in a proper function."""
         # Set expected properties in Config object
-        CONFIG.pyspark_cache = True
+        CONFIG.cache_dataframe = True
         input_df = spark_df(spark, self.sample_data, sample_spark_schema)
 
         class FakeDataFrameSchemaBackend:
@@ -74,7 +74,7 @@ def func_wo_check_obj(self, message: str):
         )
 
     # pylint:disable=too-many-locals
-    def test_pyspark_cache_settings(
+    def test_cache_dataframe_settings(
        self,
        spark,
        sample_spark_schema,
@@ -86,8 +86,8 @@ def test_cache_dataframe_settings(
    ):
        """This function validates that caching/unpersisting works as expected."""
        # Set expected properties in Config object
-        CONFIG.pyspark_cache = cache_enabled
-        CONFIG.pyspark_keep_cache = keep_cache_enabled
+        CONFIG.cache_dataframe = cache_enabled
+        CONFIG.keep_cached_dataframe = keep_cache_enabled
 
        # Prepare test data
        input_df = spark_df(spark, self.sample_data, sample_spark_schema)