✨ Add keyword arguments to schema's validate() method (#73)
michal-mmm authored and Galileo-Galilei committed Jul 18, 2024
1 parent 8904511 commit 24590ec
Showing 4 changed files with 101 additions and 2 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -2,6 +2,10 @@
 
 ## [Unreleased]
 
+### Added
+
+- :sparkles: Add keyword arguments to schema's validate() method ([#73](https://github.com/Galileo-Galilei/kedro-pandera/issues/73))
+
 ## [0.2.2] - 2024-06-03
 
 ### Added
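
For context, the new behaviour is driven entirely by the dataset's `metadata` entry: the hook reads `metadata["pandera"]["validate_kwargs"]` and forwards it to `schema.validate()`. A minimal sketch of a catalog wired this way (the dataset name, file path, and schema below are illustrative, not taken from this commit):

import pandera as pa
from kedro.io import DataCatalog
from kedro_datasets.pandas import CSVDataset

# Illustrative schema; any pandera schema object is handled the same way.
schema = pa.DataFrameSchema({"sepal_length": pa.Column(float)})

catalog = DataCatalog(
    {
        "example_dataset": CSVDataset(  # hypothetical dataset name and path
            filepath="data/01_raw/iris.csv",
            metadata={
                "pandera": {
                    "schema": schema,
                    # Forwarded verbatim to schema.validate() by the hook below.
                    "validate_kwargs": {"lazy": True},
                }
            },
        ),
    }
)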
4 changes: 3 additions & 1 deletion kedro_pandera/framework/hooks/pandera_hook.py
@@ -62,8 +62,10 @@ def _validate_datasets(
                 and "pandera" in metadata
                 and name not in self._validated_datasets
             ):
+                schema = metadata["pandera"]["schema"]
+                validate_kwargs = metadata["pandera"].get("validate_kwargs", dict())
                 try:
-                    metadata["pandera"]["schema"].validate(data)
+                    schema.validate(data, **validate_kwargs)
                     self._validated_datasets.add(name)
                 except SchemaError as err:
                     self._logger.error(
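
The two-line change above is what makes options like pandera's `lazy` flag reachable from the catalog. As a standalone illustration (made-up schema and data, plain pandas pandera rather than the hook): in the default eager mode, `validate()` raises `SchemaError` on the first failing check, while `lazy=True` collects every failure and raises `SchemaErrors` at the end.

import pandas as pd
import pandera as pa

# Made-up schema and data, purely to show the lazy/eager contrast.
schema = pa.DataFrameSchema({"a": pa.Column(int, pa.Check.ge(0))})
df = pd.DataFrame({"a": [1, -2, -3]})

try:
    schema.validate(df, lazy=True)  # what validate_kwargs={"lazy": True} enables
except pa.errors.SchemaErrors as err:
    print(err.failure_cases)  # all failing rows and checks in one report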
1 change: 1 addition & 0 deletions setup.py
@@ -50,6 +50,7 @@ def _parse_requirements(path, encoding="utf-8"):
     ],
     "test": [
         "ruff>=0.4.0, <0.5.0",
+        "pyspark>=2.2, <4.0",
         "pytest>=7.0.0, <8.0.0",
         "pytest-cov>=4.0.0, <5.0.0",
         "pytest-mock",
94 changes: 93 additions & 1 deletion tests/framework/hooks/test_hook.py
@@ -1,12 +1,18 @@
 from typing import Any, Dict
 
 import pandas as pd
+import pandera.pyspark as ps
+import pyspark.sql.types as T
 import pytest
 from kedro.framework.hooks import _create_hook_manager
 from kedro.framework.hooks.manager import _register_hooks
 from kedro.io import DataCatalog, LambdaDataset
-from kedro.pipeline import node, pipeline
+from kedro.pipeline import Pipeline, node, pipeline
 from kedro.runner import SequentialRunner
 from kedro_datasets.pandas import CSVDataset
 from pandera.errors import SchemaError
 from pandera.io import from_yaml
+from pyspark.sql import SparkSession
 
 from kedro_pandera.framework.hooks.pandera_hook import PanderaHook
 
@@ -142,3 +148,89 @@ def test_no_exception_on_memory_dataset_output():
     )
     assert test_hook_manager.is_registered(test_hook)
     SequentialRunner().run(test_pipeline, test_catalog, hook_manager=test_hook_manager)
+
+
+@pytest.fixture(scope="session")
+def spark_session():
+    return SparkSession.builder.master("local[*]").getOrCreate()
+
+
+class TestPySparkDataframeLazyEvaluation:
+    class IrisCorrectSchema(ps.DataFrameModel):
+        sepal_length: T.DoubleType
+        sepal_width: T.DoubleType
+        petal_length: T.DoubleType
+        petal_width: T.DoubleType
+        species: T.StringType
+
+    class IrisWrongSchema(ps.DataFrameModel):
+        sepal_length: T.StringType
+
+    def create_test_catalog(
+        self, spark_session: SparkSession, schema: ps.DataFrameModel, lazy: bool
+    ) -> DataCatalog:
+        return DataCatalog(
+            {
+                "Input": LambdaDataset(
+                    load=lambda: spark_session.createDataFrame(
+                        pd.read_csv("tests/data/iris.csv")
+                    ),
+                    save=lambda data: None,
+                    metadata={
+                        "pandera": {"schema": schema, "validate_kwargs": {"lazy": lazy}}
+                    },
+                ),
+            }
+        )
+
+    def create_test_pipeline(self) -> Pipeline:
+        return pipeline(
+            [node(func=lambda x: x, inputs="Input", outputs="Output", name="node1")]
+        )
+
+    def run_pipeline(self, test_catalog: DataCatalog) -> Dict[str, Any]:
+        test_hook_manager = _create_hook_manager()
+        test_hook = _get_test_hook()
+        HOOKS = (test_hook,)
+        _register_hooks(test_hook_manager, HOOKS)
+        test_pipeline = self.create_test_pipeline()
+        assert test_hook_manager.is_registered(test_hook)
+        return SequentialRunner().run(
+            test_pipeline, test_catalog, hook_manager=test_hook_manager
+        )
+
+    def test_spark_dataframe_correct_schema_lazy_validation(
+        self, spark_session: SparkSession
+    ):
+        test_catalog = self.create_test_catalog(
+            spark_session, self.IrisCorrectSchema, lazy=True
+        )
+        data = self.run_pipeline(test_catalog)
+        assert len(data["Output"].pandera.errors) == 0
+
+    def test_spark_dataframe_wrong_schema_lazy_validation_raises_no_error(
+        self, spark_session: SparkSession
+    ):
+        test_catalog = self.create_test_catalog(
+            spark_session, self.IrisWrongSchema, lazy=True
+        )
+        data = self.run_pipeline(test_catalog)
+        assert len(data["Output"].pandera.errors) > 0
+
+    def test_spark_dataframe_wrong_schema_eager_validation_raises_error(
+        self, spark_session: SparkSession
+    ):
+        test_catalog = self.create_test_catalog(
+            spark_session, self.IrisWrongSchema, lazy=False
+        )
+        with pytest.raises(SchemaError):
+            self.run_pipeline(test_catalog)
+
+    def test_spark_dataframe_correct_schema_eager_validation_raises_no_error(
+        self, spark_session: SparkSession
+    ):
+        test_catalog = self.create_test_catalog(
+            spark_session, self.IrisCorrectSchema, lazy=False
+        )
+        data = self.run_pipeline(test_catalog)
+        assert len(data["Output"].pandera.errors) == 0
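
These tests lean on a behaviour of pandera's pyspark backend: under lazy validation it does not raise, but records the collected failures on the returned DataFrame's `pandera` accessor. A condensed sketch of that contract outside the hook, assuming `spark_df` is an iris DataFrame loaded as in the tests above:

# Sketch only: reuses the IrisWrongSchema model and a spark_df from the tests above.
validated = IrisWrongSchema.validate(spark_df, lazy=True)
print(dict(validated.pandera.errors))  # non-empty: failures collected, not raised

IrisWrongSchema.validate(spark_df, lazy=False)  # raises SchemaError eagerly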
