
Commit

change envvar from PYSPARK_UNPERSIST to PYSPARK_KEEP_CACHE
Signed-off-by: Filipe Oliveira <[email protected]>
filipeo2-mck committed Nov 16, 2023
1 parent 952bef3 commit 08f0a25
Showing 4 changed files with 21 additions and 21 deletions.
4 changes: 2 additions & 2 deletions pandera/backends/pyspark/decorators.py
@@ -144,7 +144,7 @@ def cache_check_obj():
     entrypoint.
 
     The behavior of the resulting decorator depends on the `PANDERA_PYSPARK_CACHING` and
-    `PANDERA_PYSPARK_UNPERSIST` (optional) environment variables.
+    `PANDERA_PYSPARK_KEEP_CACHE` (optional) environment variables.
 
     Usage:
     @cache_check_obj()
@@ -186,7 +186,7 @@ def cached_check_obj():
 
         yield  # Execute the decorated function
 
-        if CONFIG.pyspark_unpersist:
+        if not CONFIG.pyspark_keep_cache:
             # If not cached, `.unpersist()` does nothing
             logger.debug("Unpersisting dataframe...")
             check_obj.unpersist()
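
The rename also flips the flag's polarity. Below is a minimal, hypothetical sketch of the contract the decorator now follows; only `CONFIG.pyspark_cache`, `CONFIG.pyspark_keep_cache`, and the PySpark `cache()`/`unpersist()` calls appear in the diff, while the helper name and structure are illustrative.

from contextlib import contextmanager

from pandera.config import CONFIG


@contextmanager
def _cache_scope(check_obj):
    """Illustrative only: cache the DataFrame for validation, release it unless kept."""
    if CONFIG.pyspark_cache:
        check_obj.cache()          # persist the DataFrame before checks run
    try:
        yield check_obj            # run the decorated validation
    finally:
        if not CONFIG.pyspark_keep_cache:
            check_obj.unpersist()  # no-op if the DataFrame was never cached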
10 changes: 5 additions & 5 deletions pandera/config.py
@@ -21,13 +21,13 @@ class PanderaConfig(BaseModel):
         export PANDERA_VALIDATION_ENABLED=False
         export PANDERA_VALIDATION_DEPTH=DATA_ONLY
         export PANDERA_PYSPARK_CACHE=True
-        export PANDERA_PYSPARK_UNPERSIST=False
+        export PANDERA_PYSPARK_KEEP_CACHE=True
     """
 
     validation_enabled: bool = True
     validation_depth: ValidationDepth = ValidationDepth.SCHEMA_AND_DATA
     pyspark_cache: bool = False
-    pyspark_unpersist: bool = True
+    pyspark_keep_cache: bool = False
 
 
 # this config variable should be accessible globally
@@ -43,8 +43,8 @@ class PanderaConfig(BaseModel):
         "PANDERA_PYSPARK_CACHE",
         False,
     ),
-    pyspark_unpersist=os.environ.get(
-        "PANDERA_PYSPARK_UNPERSIST",
-        True,
+    pyspark_keep_cache=os.environ.get(
+        "PANDERA_PYSPARK_KEEP_CACHE",
+        False,
     ),
 )
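
A short usage sketch of the new variable, assuming pandera has not been imported yet in the process, since `CONFIG` is built from the environment at import time:

import os

# Hypothetical session: cache DataFrames during validation and keep the cache afterwards.
os.environ["PANDERA_PYSPARK_CACHE"] = "True"
os.environ["PANDERA_PYSPARK_KEEP_CACHE"] = "True"

from pandera.config import CONFIG  # read after the variables are exported

print(CONFIG.pyspark_cache)       # True
print(CONFIG.pyspark_keep_cache)  # True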
18 changes: 9 additions & 9 deletions tests/pyspark/test_pyspark_config.py
@@ -43,7 +43,7 @@ class TestSchema(DataFrameModel):
             "validation_enabled": False,
             "validation_depth": ValidationDepth.SCHEMA_AND_DATA,
             "pyspark_cache": False,
-            "pyspark_unpersist": True,
+            "pyspark_keep_cache": False,
         }
 
         assert CONFIG.dict() == expected
@@ -67,7 +67,7 @@ def test_schema_only(self, spark, sample_spark_schema):
             "validation_enabled": True,
             "validation_depth": ValidationDepth.SCHEMA_ONLY,
             "pyspark_cache": False,
-            "pyspark_unpersist": True,
+            "pyspark_keep_cache": False,
         }
         assert CONFIG.dict() == expected
 
@@ -147,7 +147,7 @@ def test_data_only(self, spark, sample_spark_schema):
             "validation_enabled": True,
             "validation_depth": ValidationDepth.DATA_ONLY,
             "pyspark_cache": False,
-            "pyspark_unpersist": True,
+            "pyspark_keep_cache": False,
         }
         assert CONFIG.dict() == expected
 
@@ -234,7 +234,7 @@ def test_schema_and_data(self, spark, sample_spark_schema):
             "validation_enabled": True,
             "validation_depth": ValidationDepth.SCHEMA_AND_DATA,
             "pyspark_cache": False,
-            "pyspark_unpersist": True,
+            "pyspark_keep_cache": False,
         }
         assert CONFIG.dict() == expected
 
@@ -337,23 +337,23 @@ class TestSchema(DataFrameModel):
     )
 
     @pytest.mark.parametrize("cache_enabled", [True, False])
-    @pytest.mark.parametrize("unpersist_enabled", [True, False])
+    @pytest.mark.parametrize("keep_cache_enabled", [True, False])
     # pylint:disable=too-many-locals
     def test_pyspark_cache_settings(
         self,
         cache_enabled,
-        unpersist_enabled,
+        keep_cache_enabled,
     ):
-        """This function validates setter and getters of caching/unpersisting options."""
+        """This function validates setters and getters for cache/keep_cache options."""
        # Set expected properties in Config object
         CONFIG.pyspark_cache = cache_enabled
-        CONFIG.pyspark_unpersist = unpersist_enabled
+        CONFIG.pyspark_keep_cache = keep_cache_enabled
 
         # Evaluate expected Config
         expected = {
             "validation_enabled": True,
             "validation_depth": ValidationDepth.SCHEMA_AND_DATA,
             "pyspark_cache": cache_enabled,
-            "pyspark_unpersist": unpersist_enabled,
+            "pyspark_keep_cache": keep_cache_enabled,
         }
         assert CONFIG.dict() == expected
10 changes: 5 additions & 5 deletions tests/pyspark/test_pyspark_decorators.py
@@ -62,11 +62,11 @@ def func_wo_check_obj(self, message: str):
             _ = instance.func_wo_check_obj("wrong")
 
     @pytest.mark.parametrize(
-        "cache_enabled,unpersist_enabled,"
+        "cache_enabled,keep_cache_enabled,"
         "expected_caching_message,expected_unpersisting_message",
         [
-            (True, True, True, True),
-            (True, False, True, None),
+            (True, True, True, None),
+            (True, False, True, True),
             (False, True, None, None),
             (False, False, None, None),
         ],
@@ -79,15 +79,15 @@ def test_pyspark_cache_settings(
         spark,
         sample_spark_schema,
         cache_enabled,
-        unpersist_enabled,
+        keep_cache_enabled,
         expected_caching_message,
         expected_unpersisting_message,
         caplog,
     ):
         """This function validates that caching/unpersisting works as expected."""
         # Set expected properties in Config object
         CONFIG.pyspark_cache = cache_enabled
-        CONFIG.pyspark_unpersist = unpersist_enabled
+        CONFIG.pyspark_keep_cache = keep_cache_enabled
 
         # Prepare test data
         input_df = spark_df(spark, self.sample_data, sample_spark_schema)
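
Read as a truth table, the new parametrization says the second flag now means keeping the cache, so the unpersist log message is only expected when that flag is False. A hedged restatement, with the column names taken from the test signature:

# (cache_enabled, keep_cache_enabled, expected_caching_message, expected_unpersisting_message)
CASES = [
    (True,  True,  True, None),   # cached and kept  -> no unpersist message
    (True,  False, True, True),   # cached, not kept -> unpersist message expected
    (False, True,  None, None),   # nothing cached   -> neither message
    (False, False, None, None),
]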
