From 3bd1d3a8143430f58f2552c4b1acc79c616cce23 Mon Sep 17 00:00:00 2001 From: Filipe Oliveira Date: Wed, 25 Oct 2023 20:55:58 -0300 Subject: [PATCH 1/7] working, not the tests Signed-off-by: Filipe Oliveira --- pandera/api/pyspark/error_handler.py | 1 + pandera/backends/pyspark/container.py | 97 +++++++++++++++++++++++++-- tests/pyspark/test_pyspark_model.py | 79 ++++++++++++++++++++++ 3 files changed, 172 insertions(+), 5 deletions(-) diff --git a/pandera/api/pyspark/error_handler.py b/pandera/api/pyspark/error_handler.py index ed0b7e6e1..f5afaafcf 100644 --- a/pandera/api/pyspark/error_handler.py +++ b/pandera/api/pyspark/error_handler.py @@ -14,6 +14,7 @@ class ErrorCategory(Enum): DATA = "data-failures" SCHEMA = "schema-failures" DTYPE_COERCION = "dtype-coercion-failures" + UNIQUENESS = "uniqueness-failures" class ErrorHandler: diff --git a/pandera/backends/pyspark/container.py b/pandera/backends/pyspark/container.py index 6d2ef2683..6cd030363 100644 --- a/pandera/backends/pyspark/container.py +++ b/pandera/backends/pyspark/container.py @@ -6,7 +6,7 @@ from typing import Any, Dict, List, Optional from pyspark.sql import DataFrame -from pyspark.sql.functions import col +from pyspark.sql.functions import col, count from pandera.api.pyspark.error_handler import ErrorCategory, ErrorHandler from pandera.api.pyspark.types import is_table @@ -71,6 +71,7 @@ def _column_checks( reason_code=exc.reason_code, schema_error=exc, ) + # try to coerce datatypes check_obj = self.coerce_dtype( check_obj, @@ -78,6 +79,18 @@ def _column_checks( error_handler=error_handler, ) + # uniqueness of values + try: + check_obj = self.unique( + check_obj, schema=schema, error_handler=error_handler + ) + except SchemaError as exc: + error_handler.collect_error( + type=ErrorCategory.DATA, + reason_code=exc.reason_code, + schema_error=exc, + ) + return check_obj def validate( @@ -386,8 +399,7 @@ def coerce_dtype( except SchemaErrors as err: for schema_error_dict in err.schema_errors: if not error_handler.lazy: - # raise the first error immediately if not doing lazy - # validation + # raise the first error immediately if not doing lazy validation raise schema_error_dict["error"] error_handler.collect_error( ErrorCategory.DTYPE_COERCION, @@ -490,6 +502,82 @@ def _try_coercion(obj, colname, col_schema): return obj + @validate_scope(scope=ValidationScope.DATA) + def unique( + self, + check_obj: DataFrame, + *, + schema=None, + error_handler: ErrorHandler = None, + ): + """Check uniqueness in the check object.""" + assert schema is not None, "The `schema` argument must be provided." + assert ( + error_handler is not None + ), "The `error_handler` argument must be provided." + + if not schema.unique: + return check_obj + + try: + check_obj = self._check_uniqueness( + check_obj, + schema, + ) + except SchemaError as err: + if not error_handler.lazy: + raise err + error_handler.collect_error( + ErrorCategory.UNIQUENESS, err.reason_code, err + ) + + return check_obj + + def _check_uniqueness( + self, + obj: DataFrame, + schema, + ) -> DataFrame: + """Ensure uniqueness in dataframe columns. + + :param obj: dataframe to check. + :param schema: schema object. + :returns: dataframe checked. 
+ """ + # Use unique definition of columns as first option + # unique_columns = [col.unique for col in schema.columns.values()] + + # Overwrite it, if schemas's Config class has a unique declaration + unique_columns = ( + [schema.unique] + if isinstance(schema.unique, str) + else schema.unique + ) + + duplicates_count = ( + obj.select(*unique_columns) # ignore other cols + .groupby(*unique_columns) + .agg(count("*").alias("pandera_duplicate_counts")) + .filter( + col("pandera_duplicate_counts") > 1 + ) # long name to avoid colisions + .count() + ) + + if duplicates_count > 0: + raise SchemaError( + schema=schema, + data=obj, + message=( + f"Duplicated rows [{duplicates_count}] were found " + f"for columns {unique_columns}" + ), + check="unique", + reason_code=SchemaErrorReason.DUPLICATES, + ) + + return obj + ########## # Checks # ########## @@ -516,8 +604,7 @@ def check_column_names_are_unique(self, check_obj: DataFrame, schema): schema=schema, data=check_obj, message=( - "dataframe contains multiple columns with label(s): " - f"{failed}" + f"dataframe contains multiple columns with label(s): {failed}" ), failure_cases=scalar_failure_case(failed), check="dataframe_column_labels_unique", diff --git a/tests/pyspark/test_pyspark_model.py b/tests/pyspark/test_pyspark_model.py index 80cabe2be..d371fd24f 100644 --- a/tests/pyspark/test_pyspark_model.py +++ b/tests/pyspark/test_pyspark_model.py @@ -1,6 +1,7 @@ """Unit tests for DataFrameModel module.""" # pylint:disable=abstract-method +from contextlib import nullcontext as does_not_raise from typing import Optional from pyspark.sql import DataFrame import pyspark.sql.types as T @@ -223,6 +224,84 @@ class Config: assert PanderaSchema.get_metadata() == expected +@pytest.fixture +def datamodel_unique_single_column() -> pa.DataFrameModel: + """Fixture containing DataFrameModel with optional columns.""" + + class MyDataModel(pa.DataFrameModel): + """Simple DataFrameModel containing a column.""" + + a: T.LongType = pa.Field() + b: T.LongType = pa.Field() + + class Config: + """Config class.""" + + unique = "a" + + return MyDataModel + + +@pytest.fixture +def datamodel_unique_multiple_columns() -> pa.DataFrameModel: + """Fixture containing DataFrameModel with optional columns.""" + + class MyDataModel(pa.DataFrameModel): + """Simple DataFrameModel containing a column.""" + + a: T.LongType = pa.Field() + b: T.LongType = pa.Field() + + class Config: + """Config class.""" + + unique = ["a", "b"] + + return MyDataModel + + +@pytest.mark.parametrize( + "data_model, data, expectation", + [ + ( + datamodel_unique_single_column, + ([1, 4], [2, 5], [3, 6]), + does_not_raise(), + ), + ( + datamodel_unique_multiple_columns, + ([1, 4], [2, 5], [3, 6]), + does_not_raise(), + ), + ( + datamodel_unique_single_column, + ([0, 0], [0, 0], [3, 6]), + pytest.raises(pa.PysparkSchemaError), + ), + ( + datamodel_unique_multiple_columns, + ([0, 0], [0, 0], [3, 6]), + pytest.raises(pa.PysparkSchemaError), + ), + ], +) +def test_dataframe_schema_unique(spark, data_model, data, expectation): + """Test uniqueness checks on pyspark dataframes.""" + print(f"{type(spark)=}") + print(f"{type(data_model)=}") + print(f"{type(data)=}") + print(f"{type(expectation)=}") + + df = spark.createDataFrame(data, ["a", "b"]) + + # assert isinstance(data_model(df), DataFrame) + + with expectation: + df_out = data_model.validate(check_obj=df) + if df_out.pandera.errors: + raise pa.PysparkSchemaError + + def test_dataframe_schema_strict(spark, config_params: PanderaConfig) -> None: """ Checks if 
strict=True whether a schema error is raised because either extra columns are present in the dataframe From e28275c514ff5ae80a89213ca08a199ffc7fd9cb Mon Sep 17 00:00:00 2001 From: Filipe Oliveira Date: Wed, 25 Oct 2023 21:11:35 -0300 Subject: [PATCH 2/7] tests working, missing docs Signed-off-by: Filipe Oliveira --- tests/pyspark/test_pyspark_model.py | 102 +++++++++++++++++----------- 1 file changed, 62 insertions(+), 40 deletions(-) diff --git a/tests/pyspark/test_pyspark_model.py b/tests/pyspark/test_pyspark_model.py index d371fd24f..946467147 100644 --- a/tests/pyspark/test_pyspark_model.py +++ b/tests/pyspark/test_pyspark_model.py @@ -224,81 +224,103 @@ class Config: assert PanderaSchema.get_metadata() == expected -@pytest.fixture -def datamodel_unique_single_column() -> pa.DataFrameModel: - """Fixture containing DataFrameModel with optional columns.""" +# @pytest.fixture +# def datamodel_unique_single_column() -> pa.DataFrameModel: +# """Fixture containing DataFrameModel with optional columns.""" - class MyDataModel(pa.DataFrameModel): - """Simple DataFrameModel containing a column.""" +# class MyDataModel(pa.DataFrameModel): +# """Simple DataFrameModel containing a column.""" - a: T.LongType = pa.Field() - b: T.LongType = pa.Field() +# a: T.LongType = pa.Field() +# b: T.LongType = pa.Field() - class Config: - """Config class.""" +# class Config: +# """Config class.""" - unique = "a" +# unique = "a" - return MyDataModel +# return MyDataModel -@pytest.fixture -def datamodel_unique_multiple_columns() -> pa.DataFrameModel: - """Fixture containing DataFrameModel with optional columns.""" +# @pytest.fixture +# def datamodel_unique_multiple_columns() -> pa.DataFrameModel: +# """Fixture containing DataFrameModel with optional columns.""" - class MyDataModel(pa.DataFrameModel): - """Simple DataFrameModel containing a column.""" +# class MyDataModel(pa.DataFrameModel): +# """Simple DataFrameModel containing a column.""" - a: T.LongType = pa.Field() - b: T.LongType = pa.Field() +# a: T.LongType = pa.Field() +# b: T.LongType = pa.Field() - class Config: - """Config class.""" +# class Config: +# """Config class.""" - unique = ["a", "b"] +# unique = ["a", "b"] - return MyDataModel +# return MyDataModel @pytest.mark.parametrize( - "data_model, data, expectation", + "data, expectation", [ ( - datamodel_unique_single_column, - ([1, 4], [2, 5], [3, 6]), + (), does_not_raise(), ), ( - datamodel_unique_multiple_columns, ([1, 4], [2, 5], [3, 6]), does_not_raise(), ), ( - datamodel_unique_single_column, - ([0, 0], [0, 0], [3, 6]), - pytest.raises(pa.PysparkSchemaError), - ), - ( - datamodel_unique_multiple_columns, ([0, 0], [0, 0], [3, 6]), pytest.raises(pa.PysparkSchemaError), ), ], + ids=["no_data", "non_duplicated_data", "duplicated_data"], ) -def test_dataframe_schema_unique(spark, data_model, data, expectation): +def test_dataframe_schema_unique(spark, data, expectation): """Test uniqueness checks on pyspark dataframes.""" - print(f"{type(spark)=}") - print(f"{type(data_model)=}") - print(f"{type(data)=}") - print(f"{type(expectation)=}") - df = spark.createDataFrame(data, ["a", "b"]) + df = spark.createDataFrame(data, "a: int, b: int") + + # Test `unique` configuration with a single column + class UniqueSingleColumn(pa.DataFrameModel): + """Simple DataFrameModel containing a column.""" + + a: T.IntegerType = pa.Field() + b: T.IntegerType = pa.Field() + + class Config: + """Config class.""" + + unique = "a" + + assert isinstance(UniqueSingleColumn(df), DataFrame) + + with expectation: + df_out = 
UniqueSingleColumn.validate(check_obj=df) + if df_out.pandera.errors: + print(f"{df_out.pandera.errors=}") + raise pa.PysparkSchemaError + + # Test `unique` configuration with multiple columns + class UniqueMultipleColumns(pa.DataFrameModel): + """Simple DataFrameModel containing a column.""" + + a: T.IntegerType = pa.Field() + b: T.IntegerType = pa.Field() + + class Config: + """Config class.""" + + unique = ["a", "b"] - # assert isinstance(data_model(df), DataFrame) + assert isinstance(UniqueMultipleColumns(df), DataFrame) with expectation: - df_out = data_model.validate(check_obj=df) + df_out = UniqueMultipleColumns.validate(check_obj=df) if df_out.pandera.errors: + print(f"{df_out.pandera.errors=}") raise pa.PysparkSchemaError From 7e9aa3d5a101ab8e6185648547bdf95766d78a1c Mon Sep 17 00:00:00 2001 From: Filipe Oliveira Date: Wed, 25 Oct 2023 21:19:48 -0300 Subject: [PATCH 3/7] add suggestion to docs Signed-off-by: Filipe Oliveira --- docs/source/pyspark_sql.rst | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/docs/source/pyspark_sql.rst b/docs/source/pyspark_sql.rst index 22e444f0d..93ea53e83 100644 --- a/docs/source/pyspark_sql.rst +++ b/docs/source/pyspark_sql.rst @@ -343,3 +343,16 @@ We also provided a helper function to extract metadata from a schema as follows: .. note:: This feature is available for ``pyspark.sql`` and ``pandas`` both. + +Unique support +-------------- + +*new in 0.17.3* + +.. warning:: + + The `unique` support in PySpark to define which columns must be tested for + unique values may cause a performance hit during validation, given it's distributed + nature. + + Use with caution. From 1d042973a153c4d675c38b20a8b07094eaef85fd Mon Sep 17 00:00:00 2001 From: Filipe Oliveira Date: Thu, 26 Oct 2023 17:14:42 -0300 Subject: [PATCH 4/7] fix failing test and add specific method for data-scoped validations Signed-off-by: Filipe Oliveira --- docs/source/pyspark_sql.rst | 8 ++--- pandera/api/pyspark/error_handler.py | 1 - pandera/backends/pyspark/container.py | 49 ++++++++++++-------------- pandera/backends/pyspark/decorators.py | 34 ++++++++++++------ tests/pyspark/test_pyspark_config.py | 4 +-- tests/pyspark/test_pyspark_model.py | 38 +------------------- 6 files changed, 53 insertions(+), 81 deletions(-) diff --git a/docs/source/pyspark_sql.rst b/docs/source/pyspark_sql.rst index 93ea53e83..926b95152 100644 --- a/docs/source/pyspark_sql.rst +++ b/docs/source/pyspark_sql.rst @@ -344,15 +344,15 @@ We also provided a helper function to extract metadata from a schema as follows: This feature is available for ``pyspark.sql`` and ``pandas`` both. -Unique support --------------- +`unique` support +---------------- *new in 0.17.3* .. warning:: - The `unique` support in PySpark to define which columns must be tested for - unique values may cause a performance hit during validation, given it's distributed + The `unique` support for PySpark-based validations to define which columns must be + tested for unique values may incur in a performance hit, given Spark's distributed nature. Use with caution. 
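
For reference, below is a minimal usage sketch of the `unique` Config option introduced by this patch series; it is not part of the patch itself. The SparkSession bootstrap, sample data, and column names are illustrative assumptions, while the `DataFrameModel`/`validate` API and the `pandera.errors` accessor mirror the tests added above.

    # Illustrative sketch (assumed setup) of the `unique` Config option.
    import pyspark.sql.types as T
    from pyspark.sql import SparkSession

    import pandera.pyspark as pa


    class ProductSchema(pa.DataFrameModel):
        """Rows must be unique on column "a"."""

        a: T.IntegerType = pa.Field()
        b: T.IntegerType = pa.Field()

        class Config:
            """Declare uniqueness on a single column; a list of names also works."""

            unique = "a"


    spark = SparkSession.builder.getOrCreate()
    # Two rows share the same value in "a", so the uniqueness check fails.
    df = spark.createDataFrame([(0, 1), (0, 2)], "a: int, b: int")

    # Validation collects the duplicate-row failure lazily instead of raising,
    # so it surfaces through the `pandera.errors` accessor on the output frame.
    df_out = ProductSchema.validate(check_obj=df)
    print(df_out.pandera.errors)
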
diff --git a/pandera/api/pyspark/error_handler.py b/pandera/api/pyspark/error_handler.py index f5afaafcf..ed0b7e6e1 100644 --- a/pandera/api/pyspark/error_handler.py +++ b/pandera/api/pyspark/error_handler.py @@ -14,7 +14,6 @@ class ErrorCategory(Enum): DATA = "data-failures" SCHEMA = "schema-failures" DTYPE_COERCION = "dtype-coercion-failures" - UNIQUENESS = "uniqueness-failures" class ErrorHandler: diff --git a/pandera/backends/pyspark/container.py b/pandera/backends/pyspark/container.py index 6cd030363..7968eca8b 100644 --- a/pandera/backends/pyspark/container.py +++ b/pandera/backends/pyspark/container.py @@ -15,7 +15,6 @@ from pandera.backends.pyspark.error_formatters import scalar_failure_case from pandera.config import CONFIG from pandera.errors import ( - ParserError, SchemaDefinitionError, SchemaError, SchemaErrorReason, @@ -31,14 +30,14 @@ def preprocess(self, check_obj: DataFrame, inplace: bool = False): return check_obj @validate_scope(scope=ValidationScope.SCHEMA) - def _column_checks( + def _schema_checks( self, check_obj: DataFrame, schema, column_info: ColumnInfo, error_handler: ErrorHandler, ): - """run the checks related to columns presence, uniqueness and filter column if neccesary""" + """run the checks related to columns presence, strictness and filter column if neccesary""" # check the container metadata, e.g. field names try: @@ -79,6 +78,18 @@ def _column_checks( error_handler=error_handler, ) + return check_obj + + @validate_scope(scope=ValidationScope.DATA) + def _data_checks( + self, + check_obj: DataFrame, + schema, + column_info: ColumnInfo, # pylint: disable=unused-argument + error_handler: ErrorHandler, + ): + """Run the checks related to data validation and uniqueness.""" + # uniqueness of values try: check_obj = self.unique( @@ -128,8 +139,13 @@ def validate( check_obj = check_obj.pandera.add_schema(schema) column_info = self.collect_column_info(check_obj, schema, lazy) - # validate the columns of the dataframe - check_obj = self._column_checks( + # validate the columns (schema) of the dataframe + check_obj = self._schema_checks( + check_obj, schema, column_info, error_handler + ) + + # validate the rows (data) of the dataframe + check_obj = self._data_checks( check_obj, schema, column_info, error_handler ) @@ -429,27 +445,6 @@ def _coerce_dtype( # NOTE: clean up the error handling! error_handler = ErrorHandler(lazy=True) - def _coerce_df_dtype(obj: DataFrame) -> DataFrame: - if schema.dtype is None: - raise ValueError( - "dtype argument is None. 
Must specify this argument " - "to coerce dtype" - ) - - try: - return schema.dtype.try_coerce(obj) - except ParserError as exc: - raise SchemaError( - schema=schema, - data=obj, - message=( - f"Error while coercing '{schema.name}' to type " - f"{schema.dtype}: {exc}\n{exc.failure_cases}" - ), - failure_cases=exc.failure_cases, - check=f"coerce_dtype('{schema.dtype}')", - ) from exc - def _try_coercion(obj, colname, col_schema): try: schema = obj.pandera.schema @@ -528,7 +523,7 @@ def unique( if not error_handler.lazy: raise err error_handler.collect_error( - ErrorCategory.UNIQUENESS, err.reason_code, err + ErrorCategory.DATA, err.reason_code, err ) return check_obj diff --git a/pandera/backends/pyspark/decorators.py b/pandera/backends/pyspark/decorators.py index 2a559c3db..aa93e5dbd 100644 --- a/pandera/backends/pyspark/decorators.py +++ b/pandera/backends/pyspark/decorators.py @@ -81,6 +81,22 @@ def validate_scope(scope: ValidationScope): def _wrapper(func): @functools.wraps(func) def wrapper(self, *args, **kwargs): + def _get_check_obj(): + """ + Get dataframe object passed as arg to the decorated func. + + Returns: + The DataFrame object. + """ + if kwargs: + for value in kwargs.values(): + if isinstance(value, pyspark.sql.DataFrame): + return value + if args: + for value in args: + if isinstance(value, pyspark.sql.DataFrame): + return value + if scope == ValidationScope.SCHEMA: if CONFIG.validation_depth in ( ValidationDepth.SCHEMA_AND_DATA, @@ -89,17 +105,12 @@ def wrapper(self, *args, **kwargs): return func(self, *args, **kwargs) else: warnings.warn( - "Skipping Execution of function as parameters set to DATA_ONLY ", + f"Skipping execution of function {func.__name__} as validation depth is set to DATA_ONLY ", stacklevel=2, ) - if not kwargs: - for value in kwargs.values(): - if isinstance(value, pyspark.sql.DataFrame): - return value - if args: - for value in args: - if isinstance(value, pyspark.sql.DataFrame): - return value + # If the function was skip, return the `check_obj` value anyway, + # if it's present as a kwarg or an arg + return _get_check_obj() elif scope == ValidationScope.DATA: if CONFIG.validation_depth in ( @@ -109,9 +120,12 @@ def wrapper(self, *args, **kwargs): return func(self, *args, **kwargs) else: warnings.warn( - "Skipping Execution of function as parameters set to SCHEMA_ONLY ", + f"Skipping execution of function {func.__name__} as validation depth is set to SCHEMA_ONLY", stacklevel=2, ) + # If the function was skip, return the `check_obj` value anyway, + # if it's present as a kwarg or an arg + return _get_check_obj() return wrapper diff --git a/tests/pyspark/test_pyspark_config.py b/tests/pyspark/test_pyspark_config.py index 60ebfe4d9..6005c0c16 100644 --- a/tests/pyspark/test_pyspark_config.py +++ b/tests/pyspark/test_pyspark_config.py @@ -53,7 +53,7 @@ def test_schema_only(self, spark, sample_spark_schema): CONFIG.validation_enabled = True CONFIG.validation_depth = ValidationDepth.SCHEMA_ONLY - pandra_schema = DataFrameSchema( + pandera_schema = DataFrameSchema( { "product": Column(T.StringType(), Check.str_startswith("B")), "price_val": Column(T.IntegerType()), @@ -67,7 +67,7 @@ def test_schema_only(self, spark, sample_spark_schema): assert CONFIG.dict() == expected input_df = spark_df(spark, self.sample_data, sample_spark_schema) - output_dataframeschema_df = pandra_schema.validate(input_df) + output_dataframeschema_df = pandera_schema.validate(input_df) expected_dataframeschema = { "SCHEMA": { "COLUMN_NOT_IN_DATAFRAME": [ diff --git 
a/tests/pyspark/test_pyspark_model.py b/tests/pyspark/test_pyspark_model.py index 946467147..990b20a82 100644 --- a/tests/pyspark/test_pyspark_model.py +++ b/tests/pyspark/test_pyspark_model.py @@ -224,42 +224,6 @@ class Config: assert PanderaSchema.get_metadata() == expected -# @pytest.fixture -# def datamodel_unique_single_column() -> pa.DataFrameModel: -# """Fixture containing DataFrameModel with optional columns.""" - -# class MyDataModel(pa.DataFrameModel): -# """Simple DataFrameModel containing a column.""" - -# a: T.LongType = pa.Field() -# b: T.LongType = pa.Field() - -# class Config: -# """Config class.""" - -# unique = "a" - -# return MyDataModel - - -# @pytest.fixture -# def datamodel_unique_multiple_columns() -> pa.DataFrameModel: -# """Fixture containing DataFrameModel with optional columns.""" - -# class MyDataModel(pa.DataFrameModel): -# """Simple DataFrameModel containing a column.""" - -# a: T.LongType = pa.Field() -# b: T.LongType = pa.Field() - -# class Config: -# """Config class.""" - -# unique = ["a", "b"] - -# return MyDataModel - - @pytest.mark.parametrize( "data, expectation", [ @@ -276,7 +240,7 @@ class Config: pytest.raises(pa.PysparkSchemaError), ), ], - ids=["no_data", "non_duplicated_data", "duplicated_data"], + ids=["no_data", "unique_data", "duplicated_data"], ) def test_dataframe_schema_unique(spark, data, expectation): """Test uniqueness checks on pyspark dataframes.""" From 35e20cd0960863532bc91ed66c21395562484d24 Mon Sep 17 00:00:00 2001 From: Filipe Oliveira Date: Thu, 26 Oct 2023 18:49:58 -0300 Subject: [PATCH 5/7] fix one code coverage issue Signed-off-by: Filipe Oliveira --- pandera/backends/pyspark/container.py | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/pandera/backends/pyspark/container.py b/pandera/backends/pyspark/container.py index 7968eca8b..c4249cbcb 100644 --- a/pandera/backends/pyspark/container.py +++ b/pandera/backends/pyspark/container.py @@ -91,16 +91,9 @@ def _data_checks( """Run the checks related to data validation and uniqueness.""" # uniqueness of values - try: - check_obj = self.unique( - check_obj, schema=schema, error_handler=error_handler - ) - except SchemaError as exc: - error_handler.collect_error( - type=ErrorCategory.DATA, - reason_code=exc.reason_code, - schema_error=exc, - ) + check_obj = self.unique( + check_obj, schema=schema, error_handler=error_handler + ) return check_obj From 199d2b6b8775f7060c4a070c69bb44f8394856fa Mon Sep 17 00:00:00 2001 From: Filipe Oliveira Date: Fri, 27 Oct 2023 06:11:25 -0300 Subject: [PATCH 6/7] accept suggestions from Kasper Signed-off-by: Filipe Oliveira --- pandera/backends/pyspark/container.py | 5 +---- pandera/backends/pyspark/decorators.py | 8 ++------ 2 files changed, 3 insertions(+), 10 deletions(-) diff --git a/pandera/backends/pyspark/container.py b/pandera/backends/pyspark/container.py index c4249cbcb..96742682b 100644 --- a/pandera/backends/pyspark/container.py +++ b/pandera/backends/pyspark/container.py @@ -532,10 +532,7 @@ def _check_uniqueness( :param schema: schema object. :returns: dataframe checked. 
""" - # Use unique definition of columns as first option - # unique_columns = [col.unique for col in schema.columns.values()] - - # Overwrite it, if schemas's Config class has a unique declaration + # Determine unique columns based on schema's config unique_columns = ( [schema.unique] if isinstance(schema.unique, str) diff --git a/pandera/backends/pyspark/decorators.py b/pandera/backends/pyspark/decorators.py index aa93e5dbd..9c202b4be 100644 --- a/pandera/backends/pyspark/decorators.py +++ b/pandera/backends/pyspark/decorators.py @@ -88,10 +88,6 @@ def _get_check_obj(): Returns: The DataFrame object. """ - if kwargs: - for value in kwargs.values(): - if isinstance(value, pyspark.sql.DataFrame): - return value if args: for value in args: if isinstance(value, pyspark.sql.DataFrame): @@ -109,7 +105,7 @@ def _get_check_obj(): stacklevel=2, ) # If the function was skip, return the `check_obj` value anyway, - # if it's present as a kwarg or an arg + # given that some return value is expected return _get_check_obj() elif scope == ValidationScope.DATA: @@ -124,7 +120,7 @@ def _get_check_obj(): stacklevel=2, ) # If the function was skip, return the `check_obj` value anyway, - # if it's present as a kwarg or an arg + # given that some return value is expected return _get_check_obj() return wrapper From d6eea48e881a265bd8d4a1b907f020347a34220e Mon Sep 17 00:00:00 2001 From: Filipe Oliveira Date: Fri, 27 Oct 2023 14:01:06 -0300 Subject: [PATCH 7/7] add condition and test for invalid column name and flattened the unique functions Signed-off-by: Filipe Oliveira --- pandera/backends/pyspark/container.py | 64 +++++++++++++-------------- tests/pyspark/test_pyspark_model.py | 39 +++++++++++++++- 2 files changed, 70 insertions(+), 33 deletions(-) diff --git a/pandera/backends/pyspark/container.py b/pandera/backends/pyspark/container.py index 96742682b..c162f8039 100644 --- a/pandera/backends/pyspark/container.py +++ b/pandera/backends/pyspark/container.py @@ -91,9 +91,14 @@ def _data_checks( """Run the checks related to data validation and uniqueness.""" # uniqueness of values - check_obj = self.unique( - check_obj, schema=schema, error_handler=error_handler - ) + try: + check_obj = self.unique( + check_obj, schema=schema, error_handler=error_handler + ) + except SchemaError as err: + error_handler.collect_error( + ErrorCategory.DATA, err.reason_code, err + ) return check_obj @@ -213,7 +218,7 @@ def run_checks(self, check_obj: DataFrame, schema, error_handler): check_results = [] for check_index, check in enumerate( schema.checks - ): # schama.checks is null + ): # schema.checks is null try: check_results.append( self.run_check(check_obj, schema, check, check_index) @@ -507,31 +512,6 @@ def unique( if not schema.unique: return check_obj - try: - check_obj = self._check_uniqueness( - check_obj, - schema, - ) - except SchemaError as err: - if not error_handler.lazy: - raise err - error_handler.collect_error( - ErrorCategory.DATA, err.reason_code, err - ) - - return check_obj - - def _check_uniqueness( - self, - obj: DataFrame, - schema, - ) -> DataFrame: - """Ensure uniqueness in dataframe columns. - - :param obj: dataframe to check. - :param schema: schema object. - :returns: dataframe checked. 
- """ # Determine unique columns based on schema's config unique_columns = ( [schema.unique] @@ -539,8 +519,16 @@ def _check_uniqueness( else schema.unique ) + # Check if values belong to the dataframe columns + missing_unique_columns = set(unique_columns) - set(check_obj.columns) + if missing_unique_columns: + raise SchemaDefinitionError( + "Specified `unique` columns are missing in the dataframe: " + f"{list(missing_unique_columns)}" + ) + duplicates_count = ( - obj.select(*unique_columns) # ignore other cols + check_obj.select(*unique_columns) # ignore other cols .groupby(*unique_columns) .agg(count("*").alias("pandera_duplicate_counts")) .filter( @@ -552,7 +540,7 @@ def _check_uniqueness( if duplicates_count > 0: raise SchemaError( schema=schema, - data=obj, + data=check_obj, message=( f"Duplicated rows [{duplicates_count}] were found " f"for columns {unique_columns}" @@ -561,7 +549,19 @@ def _check_uniqueness( reason_code=SchemaErrorReason.DUPLICATES, ) - return obj + return check_obj + + def _check_uniqueness( + self, + obj: DataFrame, + schema, + ) -> DataFrame: + """Ensure uniqueness in dataframe columns. + + :param obj: dataframe to check. + :param schema: schema object. + :returns: dataframe checked. + """ ########## # Checks # diff --git a/tests/pyspark/test_pyspark_model.py b/tests/pyspark/test_pyspark_model.py index 990b20a82..8523d35e4 100644 --- a/tests/pyspark/test_pyspark_model.py +++ b/tests/pyspark/test_pyspark_model.py @@ -14,6 +14,9 @@ from pandera.pyspark import DataFrameModel, DataFrameSchema, Field from tests.pyspark.conftest import spark_df from pandera.api.pyspark.model import docstring_substitution +from pandera.errors import ( + SchemaDefinitionError, +) def test_schema_with_bare_types(): @@ -269,7 +272,7 @@ class Config: # Test `unique` configuration with multiple columns class UniqueMultipleColumns(pa.DataFrameModel): - """Simple DataFrameModel containing a column.""" + """Simple DataFrameModel containing two columns.""" a: T.IntegerType = pa.Field() b: T.IntegerType = pa.Field() @@ -288,6 +291,40 @@ class Config: raise pa.PysparkSchemaError +@pytest.mark.parametrize( + "unique_column_name", + [ + "x", + ["x", "y"], + ["x", ""], + ], + ids=[ + "wrong_column", + "multiple_wrong_columns", + "multiple_wrong_columns_w_empty", + ], +) +def test_dataframe_schema_unique_wrong_column(spark, unique_column_name): + """Test uniqueness checks on pyspark dataframes.""" + + df = spark.createDataFrame(([1, 2],), "a: int, b: int") + + # Test `unique` configuration with a single, wrongly named column + class UniqueMultipleColumns(pa.DataFrameModel): + """Simple DataFrameModel containing two columns.""" + + a: T.IntegerType = pa.Field() + b: T.IntegerType = pa.Field() + + class Config: + """Config class.""" + + unique = unique_column_name + + with pytest.raises(SchemaDefinitionError): + _ = UniqueMultipleColumns.validate(check_obj=df) + + def test_dataframe_schema_strict(spark, config_params: PanderaConfig) -> None: """ Checks if strict=True whether a schema error is raised because either extra columns are present in the dataframe