diff --git a/docs/source/pyspark_sql.rst b/docs/source/pyspark_sql.rst
index 22e444f0d..926b95152 100644
--- a/docs/source/pyspark_sql.rst
+++ b/docs/source/pyspark_sql.rst
@@ -343,3 +343,16 @@ We also provided a helper function to extract metadata from a schema as follows:
 .. note::
 
     This feature is available for ``pyspark.sql`` and ``pandas`` both.
+
+`unique` support
+----------------
+
+*new in 0.17.3*
+
+.. warning::
+
+    The `unique` support for PySpark-based validations to define which columns must be
+    tested for unique values may incur a performance hit, given Spark's distributed
+    nature.
+
+    Use with caution.
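
For context, a minimal usage sketch modeled on the tests added in this PR; the `ProductSchema` model, the column names, and the sample data are illustrative and not part of the change:

    import pyspark.sql.types as T
    from pyspark.sql import SparkSession

    import pandera.pyspark as pa


    class ProductSchema(pa.DataFrameModel):
        """Illustrative model: rows must be unique on ("product", "price")."""

        product: T.StringType = pa.Field()
        price: T.IntegerType = pa.Field()

        class Config:
            """`unique` accepts a single column name or a list of column names."""

            unique = ["product", "price"]


    spark = SparkSession.builder.getOrCreate()
    df = spark.createDataFrame(
        [("steel", 10), ("steel", 10), ("copper", 20)],
        "product: string, price: int",
    )

    # The duplicated ("steel", 10) rows are collected in the error report on
    # the returned dataframe rather than raised as an exception.
    df_out = ProductSchema.validate(check_obj=df)
    print(df_out.pandera.errors)
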
diff --git a/pandera/backends/pyspark/container.py b/pandera/backends/pyspark/container.py
index 6d2ef2683..c162f8039 100644
--- a/pandera/backends/pyspark/container.py
+++ b/pandera/backends/pyspark/container.py
@@ -6,7 +6,7 @@
 from typing import Any, Dict, List, Optional
 
 from pyspark.sql import DataFrame
-from pyspark.sql.functions import col
+from pyspark.sql.functions import col, count
 
 from pandera.api.pyspark.error_handler import ErrorCategory, ErrorHandler
 from pandera.api.pyspark.types import is_table
@@ -15,7 +15,6 @@
 from pandera.backends.pyspark.error_formatters import scalar_failure_case
 from pandera.config import CONFIG
 from pandera.errors import (
-    ParserError,
     SchemaDefinitionError,
     SchemaError,
     SchemaErrorReason,
@@ -31,14 +30,14 @@ def preprocess(self, check_obj: DataFrame, inplace: bool = False):
         return check_obj
 
     @validate_scope(scope=ValidationScope.SCHEMA)
-    def _column_checks(
+    def _schema_checks(
         self,
         check_obj: DataFrame,
         schema,
         column_info: ColumnInfo,
         error_handler: ErrorHandler,
     ):
-        """run the checks related to columns presence, uniqueness and filter column if neccesary"""
+        """Run the checks related to column presence, strictness, and column filtering if necessary."""
         # check the container metadata, e.g. field names
         try:
@@ -71,6 +70,7 @@ def _column_checks(
                 reason_code=exc.reason_code,
                 schema_error=exc,
             )
+
         # try to coerce datatypes
         check_obj = self.coerce_dtype(
             check_obj,
@@ -80,6 +80,28 @@ def _column_checks(
 
         return check_obj
 
+    @validate_scope(scope=ValidationScope.DATA)
+    def _data_checks(
+        self,
+        check_obj: DataFrame,
+        schema,
+        column_info: ColumnInfo,  # pylint: disable=unused-argument
+        error_handler: ErrorHandler,
+    ):
+        """Run the checks related to data validation and uniqueness."""
+
+        # uniqueness of values
+        try:
+            check_obj = self.unique(
+                check_obj, schema=schema, error_handler=error_handler
+            )
+        except SchemaError as err:
+            error_handler.collect_error(
+                ErrorCategory.DATA, err.reason_code, err
+            )
+
+        return check_obj
+
     def validate(
         self,
         check_obj: DataFrame,
@@ -115,8 +137,13 @@ def validate(
         check_obj = check_obj.pandera.add_schema(schema)
         column_info = self.collect_column_info(check_obj, schema, lazy)
 
-        # validate the columns of the dataframe
-        check_obj = self._column_checks(
+        # validate the columns (schema) of the dataframe
+        check_obj = self._schema_checks(
+            check_obj, schema, column_info, error_handler
+        )
+
+        # validate the rows (data) of the dataframe
+        check_obj = self._data_checks(
             check_obj, schema, column_info, error_handler
         )
 
@@ -191,7 +218,7 @@ def run_checks(self, check_obj: DataFrame, schema, error_handler):
         check_results = []
         for check_index, check in enumerate(
             schema.checks
-        ):  # schama.checks is null
+        ):  # schema.checks is null
             try:
                 check_results.append(
                     self.run_check(check_obj, schema, check, check_index)
@@ -386,8 +413,7 @@ def coerce_dtype(
         except SchemaErrors as err:
             for schema_error_dict in err.schema_errors:
                 if not error_handler.lazy:
-                    # raise the first error immediately if not doing lazy
-                    # validation
+                    # raise the first error immediately if not doing lazy validation
                     raise schema_error_dict["error"]
                 error_handler.collect_error(
                     ErrorCategory.DTYPE_COERCION,
@@ -417,27 +443,6 @@ def _coerce_dtype(
         # NOTE: clean up the error handling!
         error_handler = ErrorHandler(lazy=True)
 
-        def _coerce_df_dtype(obj: DataFrame) -> DataFrame:
-            if schema.dtype is None:
-                raise ValueError(
-                    "dtype argument is None. Must specify this argument "
-                    "to coerce dtype"
-                )
-
-            try:
-                return schema.dtype.try_coerce(obj)
-            except ParserError as exc:
-                raise SchemaError(
-                    schema=schema,
-                    data=obj,
-                    message=(
-                        f"Error while coercing '{schema.name}' to type "
-                        f"{schema.dtype}: {exc}\n{exc.failure_cases}"
-                    ),
-                    failure_cases=exc.failure_cases,
-                    check=f"coerce_dtype('{schema.dtype}')",
-                ) from exc
-
         def _try_coercion(obj, colname, col_schema):
             try:
                 schema = obj.pandera.schema
@@ -490,6 +495,74 @@ def _try_coercion(obj, colname, col_schema):
 
         return obj
 
+    @validate_scope(scope=ValidationScope.DATA)
+    def unique(
+        self,
+        check_obj: DataFrame,
+        *,
+        schema=None,
+        error_handler: ErrorHandler = None,
+    ):
+        """Check uniqueness in the check object."""
+        assert schema is not None, "The `schema` argument must be provided."
+        assert (
+            error_handler is not None
+        ), "The `error_handler` argument must be provided."
+
+        if not schema.unique:
+            return check_obj
+
+        # Determine unique columns based on schema's config
+        unique_columns = (
+            [schema.unique]
+            if isinstance(schema.unique, str)
+            else schema.unique
+        )
+
+        # Check if values belong to the dataframe columns
+        missing_unique_columns = set(unique_columns) - set(check_obj.columns)
+        if missing_unique_columns:
+            raise SchemaDefinitionError(
+                "Specified `unique` columns are missing in the dataframe: "
+                f"{list(missing_unique_columns)}"
+            )
+
+        duplicates_count = (
+            check_obj.select(*unique_columns)  # ignore other cols
+            .groupby(*unique_columns)
+            .agg(count("*").alias("pandera_duplicate_counts"))
+            .filter(
+                col("pandera_duplicate_counts") > 1
+            )  # long name to avoid collisions
+            .count()
+        )
+
+        if duplicates_count > 0:
+            raise SchemaError(
+                schema=schema,
+                data=check_obj,
+                message=(
+                    f"Duplicated rows [{duplicates_count}] were found "
+                    f"for columns {unique_columns}"
+                ),
+                check="unique",
+                reason_code=SchemaErrorReason.DUPLICATES,
+            )
+
+        return check_obj
+
+    def _check_uniqueness(
+        self,
+        obj: DataFrame,
+        schema,
+    ) -> DataFrame:
+        """Ensure uniqueness in dataframe columns.
+
+        :param obj: dataframe to check.
+        :param schema: schema object.
+        :returns: dataframe checked.
+        """
+
     ##########
     # Checks #
     ##########
@@ -516,8 +589,7 @@ def check_column_names_are_unique(self, check_obj: DataFrame, schema):
                 schema=schema,
                 data=check_obj,
                 message=(
-                    "dataframe contains multiple columns with label(s): "
-                    f"{failed}"
+                    f"dataframe contains multiple columns with label(s): {failed}"
                 ),
                 failure_cases=scalar_failure_case(failed),
                 check="dataframe_column_labels_unique",
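
The cost flagged in the documentation comes from the aggregation at the heart of the new `unique()` method: it projects to the configured columns, groups, and counts, which triggers a Spark job. A standalone sketch of the same pattern, with illustrative data:

    from pyspark.sql import SparkSession
    from pyspark.sql.functions import col, count

    spark = SparkSession.builder.getOrCreate()
    df = spark.createDataFrame([(1, 4), (1, 4), (3, 6)], "a: int, b: int")

    unique_columns = ["a", "b"]  # in the backend this comes from `schema.unique`

    # Project to the unique columns, count rows per key combination, and keep
    # only the combinations that occur more than once.
    duplicates_count = (
        df.select(*unique_columns)
        .groupby(*unique_columns)
        .agg(count("*").alias("pandera_duplicate_counts"))
        .filter(col("pandera_duplicate_counts") > 1)
        .count()
    )
    print(duplicates_count)  # 1: the key (1, 4) appears more than once
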
diff --git a/pandera/backends/pyspark/decorators.py b/pandera/backends/pyspark/decorators.py
index 2a559c3db..9c202b4be 100644
--- a/pandera/backends/pyspark/decorators.py
+++ b/pandera/backends/pyspark/decorators.py
@@ -81,6 +81,18 @@ def validate_scope(scope: ValidationScope):
     def _wrapper(func):
         @functools.wraps(func)
         def wrapper(self, *args, **kwargs):
+            def _get_check_obj():
+                """
+                Get the dataframe object passed as an argument to the decorated function.
+
+                Returns:
+                    The DataFrame object.
+                """
+                if args:
+                    for value in args:
+                        if isinstance(value, pyspark.sql.DataFrame):
+                            return value
+
             if scope == ValidationScope.SCHEMA:
                 if CONFIG.validation_depth in (
                     ValidationDepth.SCHEMA_AND_DATA,
@@ -89,17 +101,12 @@ def wrapper(self, *args, **kwargs):
                     return func(self, *args, **kwargs)
                 else:
                     warnings.warn(
-                        "Skipping Execution of function as parameters set to DATA_ONLY ",
+                        f"Skipping execution of function {func.__name__} as validation depth is set to DATA_ONLY",
                         stacklevel=2,
                     )
-                    if not kwargs:
-                        for value in kwargs.values():
-                            if isinstance(value, pyspark.sql.DataFrame):
-                                return value
-                    if args:
-                        for value in args:
-                            if isinstance(value, pyspark.sql.DataFrame):
-                                return value
+                    # If the function was skipped, return the `check_obj` value anyway,
+                    # given that some return value is expected
+                    return _get_check_obj()
 
             elif scope == ValidationScope.DATA:
                 if CONFIG.validation_depth in (
@@ -109,9 +116,12 @@ def wrapper(self, *args, **kwargs):
                     return func(self, *args, **kwargs)
                 else:
                     warnings.warn(
-                        "Skipping Execution of function as parameters set to SCHEMA_ONLY ",
+                        f"Skipping execution of function {func.__name__} as validation depth is set to SCHEMA_ONLY",
                         stacklevel=2,
                     )
+                    # If the function was skipped, return the `check_obj` value anyway,
+                    # given that some return value is expected
+                    return _get_check_obj()
 
         return wrapper
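
The decorator change matters because a skipped function could previously fall through and return None, breaking the chained calls in `validate`. A sketch of the skip path, assuming the global config toggles used in the config tests; the model, column name, and data are illustrative:

    import pyspark.sql.types as T
    from pyspark.sql import SparkSession

    import pandera.pyspark as pa
    from pandera.config import CONFIG, ValidationDepth


    class SingleColumnModel(pa.DataFrameModel):
        """Illustrative model with a data-level uniqueness constraint."""

        a: T.IntegerType = pa.Field()

        class Config:
            """Config class."""

            unique = "a"


    spark = SparkSession.builder.getOrCreate()
    df = spark.createDataFrame([(1,), (1,)], "a: int")

    # With SCHEMA_ONLY depth, DATA-scoped functions such as `unique()` are
    # skipped (a warning is emitted) and the decorator now returns the
    # dataframe it received, so validation still yields a usable result.
    CONFIG.validation_depth = ValidationDepth.SCHEMA_ONLY
    validated = SingleColumnModel.validate(check_obj=df)
    print(validated.pandera.errors)  # the duplicated values in "a" are not flagged

    CONFIG.validation_depth = ValidationDepth.SCHEMA_AND_DATA  # restore full depth
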
diff --git a/tests/pyspark/test_pyspark_config.py b/tests/pyspark/test_pyspark_config.py
index 60ebfe4d9..6005c0c16 100644
--- a/tests/pyspark/test_pyspark_config.py
+++ b/tests/pyspark/test_pyspark_config.py
@@ -53,7 +53,7 @@ def test_schema_only(self, spark, sample_spark_schema):
         CONFIG.validation_enabled = True
         CONFIG.validation_depth = ValidationDepth.SCHEMA_ONLY
 
-        pandra_schema = DataFrameSchema(
+        pandera_schema = DataFrameSchema(
             {
                 "product": Column(T.StringType(), Check.str_startswith("B")),
                 "price_val": Column(T.IntegerType()),
@@ -67,7 +67,7 @@ def test_schema_only(self, spark, sample_spark_schema):
         assert CONFIG.dict() == expected
 
         input_df = spark_df(spark, self.sample_data, sample_spark_schema)
-        output_dataframeschema_df = pandra_schema.validate(input_df)
+        output_dataframeschema_df = pandera_schema.validate(input_df)
         expected_dataframeschema = {
             "SCHEMA": {
                 "COLUMN_NOT_IN_DATAFRAME": [
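
The wrong-column test below exercises the `SchemaDefinitionError` branch of `unique()`. A sketch of what that looks like from user code; the model name, column names, and data are illustrative:

    import pyspark.sql.types as T
    from pyspark.sql import SparkSession

    import pandera.pyspark as pa
    from pandera.errors import SchemaDefinitionError


    class MisconfiguredModel(pa.DataFrameModel):
        """Illustrative model whose `unique` config names a missing column."""

        a: T.IntegerType = pa.Field()
        b: T.IntegerType = pa.Field()

        class Config:
            """Config class."""

            unique = ["a", "x"]  # "x" is not a column of the dataframe


    spark = SparkSession.builder.getOrCreate()
    df = spark.createDataFrame([(1, 2)], "a: int, b: int")

    # Unlike duplicate values, which are collected in the error report, a bad
    # `unique` configuration raises immediately.
    try:
        MisconfiguredModel.validate(check_obj=df)
    except SchemaDefinitionError as exc:
        print(exc)
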
diff --git a/tests/pyspark/test_pyspark_model.py b/tests/pyspark/test_pyspark_model.py
index 80cabe2be..8523d35e4 100644
--- a/tests/pyspark/test_pyspark_model.py
+++ b/tests/pyspark/test_pyspark_model.py
@@ -1,6 +1,7 @@
 """Unit tests for DataFrameModel module."""
 
 # pylint:disable=abstract-method
+from contextlib import nullcontext as does_not_raise
 from typing import Optional
 from pyspark.sql import DataFrame
 import pyspark.sql.types as T
@@ -13,6 +14,9 @@
 from pandera.pyspark import DataFrameModel, DataFrameSchema, Field
 from tests.pyspark.conftest import spark_df
 from pandera.api.pyspark.model import docstring_substitution
+from pandera.errors import (
+    SchemaDefinitionError,
+)
 
 
 def test_schema_with_bare_types():
@@ -223,6 +227,104 @@ class Config:
     assert PanderaSchema.get_metadata() == expected
 
 
+@pytest.mark.parametrize(
+    "data, expectation",
+    [
+        (
+            (),
+            does_not_raise(),
+        ),
+        (
+            ([1, 4], [2, 5], [3, 6]),
+            does_not_raise(),
+        ),
+        (
+            ([0, 0], [0, 0], [3, 6]),
+            pytest.raises(pa.PysparkSchemaError),
+        ),
+    ],
+    ids=["no_data", "unique_data", "duplicated_data"],
+)
+def test_dataframe_schema_unique(spark, data, expectation):
+    """Test uniqueness checks on pyspark dataframes."""
+
+    df = spark.createDataFrame(data, "a: int, b: int")
+
+    # Test `unique` configuration with a single column
+    class UniqueSingleColumn(pa.DataFrameModel):
+        """Simple DataFrameModel with `unique` set to a single column."""
+
+        a: T.IntegerType = pa.Field()
+        b: T.IntegerType = pa.Field()
+
+        class Config:
+            """Config class."""
+
+            unique = "a"
+
+    assert isinstance(UniqueSingleColumn(df), DataFrame)
+
+    with expectation:
+        df_out = UniqueSingleColumn.validate(check_obj=df)
+        if df_out.pandera.errors:
+            print(f"{df_out.pandera.errors=}")
+            raise pa.PysparkSchemaError
+
+    # Test `unique` configuration with multiple columns
+    class UniqueMultipleColumns(pa.DataFrameModel):
+        """Simple DataFrameModel containing two columns."""
+
+        a: T.IntegerType = pa.Field()
+        b: T.IntegerType = pa.Field()
+
+        class Config:
+            """Config class."""
+
+            unique = ["a", "b"]
+
+    assert isinstance(UniqueMultipleColumns(df), DataFrame)
+
+    with expectation:
+        df_out = UniqueMultipleColumns.validate(check_obj=df)
+        if df_out.pandera.errors:
+            print(f"{df_out.pandera.errors=}")
+            raise pa.PysparkSchemaError
+
+
+@pytest.mark.parametrize(
+    "unique_column_name",
+    [
+        "x",
+        ["x", "y"],
+        ["x", ""],
+    ],
+    ids=[
+        "wrong_column",
+        "multiple_wrong_columns",
+        "multiple_wrong_columns_w_empty",
+    ],
+)
+def test_dataframe_schema_unique_wrong_column(spark, unique_column_name):
+    """Test `unique` validation with columns missing from the dataframe."""
+
+    df = spark.createDataFrame(([1, 2],), "a: int, b: int")
+
+    # Test `unique` configuration with wrongly named columns
+    class UniqueMultipleColumns(pa.DataFrameModel):
+        """Simple DataFrameModel containing two columns."""
+
+        a: T.IntegerType = pa.Field()
+        b: T.IntegerType = pa.Field()
+
+        class Config:
+            """Config class."""
+
+            unique = unique_column_name
+
+    with pytest.raises(SchemaDefinitionError):
+        _ = UniqueMultipleColumns.validate(check_obj=df)
+
+
 def test_dataframe_schema_strict(spark, config_params: PanderaConfig) -> None:
     """
     Checks if strict=True whether a schema error is raised because either extra columns are present in the dataframe
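
The tests above exercise the `DataFrameModel` route; the object-based API is expected to behave the same way, assuming the PySpark `DataFrameSchema` accepts the same `unique` argument as its pandas counterpart (that constructor is not shown in this diff, so treat it as an assumption). An illustrative sketch:

    import pyspark.sql.types as T
    from pyspark.sql import SparkSession

    from pandera.pyspark import Column, DataFrameSchema

    schema = DataFrameSchema(
        columns={
            "product": Column(T.StringType()),
            "price": Column(T.IntegerType()),
        },
        unique=["product", "price"],  # assumed to mirror the model-level `unique` config
    )

    spark = SparkSession.builder.getOrCreate()
    df = spark.createDataFrame(
        [("steel", 10), ("steel", 10)], "product: string, price: int"
    )

    df_out = schema.validate(df)
    print(df_out.pandera.errors)  # duplicates expected under the DATA section
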