Add support for unique validation in PySpark (#1396)
* working, not the tests

Signed-off-by: Filipe Oliveira <[email protected]>

* tests working, missing docs

Signed-off-by: Filipe Oliveira <[email protected]>

* add suggestion to docs

Signed-off-by: Filipe Oliveira <[email protected]>

* fix failing test and add specific method for data-scoped validations

Signed-off-by: Filipe Oliveira <[email protected]>

* fix one code coverage issue

Signed-off-by: Filipe Oliveira <[email protected]>

* accept suggestions from Kasper

Signed-off-by: Filipe Oliveira <[email protected]>

* add condition and test for invalid column name and flattened the unique functions

Signed-off-by: Filipe Oliveira <[email protected]>

---------

Signed-off-by: Filipe Oliveira <[email protected]>
filipeo2-mck authored Oct 31, 2023
1 parent cf6b5e4 commit de0ec5f
Showing 5 changed files with 241 additions and 44 deletions.
13 changes: 13 additions & 0 deletions docs/source/pyspark_sql.rst
@@ -343,3 +343,16 @@ We also provided a helper function to extract metadata from a schema as follows:
.. note::

This feature is available for both ``pyspark.sql`` and ``pandas``.

`unique` support
----------------

*new in 0.17.3*

.. warning::

The `unique` support for PySpark-based validations, which defines the columns that must be
tested for unique values, may incur a performance hit given Spark's distributed
nature.

Use with caution.
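
A minimal usage sketch of the new option (the column names, dtypes, and the ``df`` variable below are illustrative, not taken from the commit):

.. code-block:: python

    import pyspark.sql.types as T
    from pandera.pyspark import Column, DataFrameSchema

    # `unique` accepts a single column name or a list of columns whose
    # combined values must not repeat across rows.
    schema = DataFrameSchema(
        columns={
            "product": Column(T.StringType()),
            "price": Column(T.IntegerType()),
        },
        unique=["product", "price"],
    )

    validated_df = schema.validate(df)  # `df` is an existing PySpark DataFrame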
136 changes: 104 additions & 32 deletions pandera/backends/pyspark/container.py
@@ -6,7 +6,7 @@
from typing import Any, Dict, List, Optional

from pyspark.sql import DataFrame
from pyspark.sql.functions import col
from pyspark.sql.functions import col, count

from pandera.api.pyspark.error_handler import ErrorCategory, ErrorHandler
from pandera.api.pyspark.types import is_table
@@ -15,7 +15,6 @@
from pandera.backends.pyspark.error_formatters import scalar_failure_case
from pandera.config import CONFIG
from pandera.errors import (
ParserError,
SchemaDefinitionError,
SchemaError,
SchemaErrorReason,
@@ -31,14 +30,14 @@ def preprocess(self, check_obj: DataFrame, inplace: bool = False):
return check_obj

@validate_scope(scope=ValidationScope.SCHEMA)
def _column_checks(
def _schema_checks(
self,
check_obj: DataFrame,
schema,
column_info: ColumnInfo,
error_handler: ErrorHandler,
):
"""run the checks related to columns presence, uniqueness and filter column if neccesary"""
"""run the checks related to columns presence, strictness and filter column if neccesary"""

# check the container metadata, e.g. field names
try:
@@ -71,6 +70,7 @@ def _column_checks(
reason_code=exc.reason_code,
schema_error=exc,
)

# try to coerce datatypes
check_obj = self.coerce_dtype(
check_obj,
@@ -80,6 +80,28 @@

return check_obj

@validate_scope(scope=ValidationScope.DATA)
def _data_checks(
self,
check_obj: DataFrame,
schema,
column_info: ColumnInfo, # pylint: disable=unused-argument
error_handler: ErrorHandler,
):
"""Run the checks related to data validation and uniqueness."""

# uniqueness of values
try:
check_obj = self.unique(
check_obj, schema=schema, error_handler=error_handler
)
except SchemaError as err:
error_handler.collect_error(
ErrorCategory.DATA, err.reason_code, err
)

return check_obj

def validate(
self,
check_obj: DataFrame,
@@ -115,8 +137,13 @@ def validate(
check_obj = check_obj.pandera.add_schema(schema)
column_info = self.collect_column_info(check_obj, schema, lazy)

# validate the columns of the dataframe
check_obj = self._column_checks(
# validate the columns (schema) of the dataframe
check_obj = self._schema_checks(
check_obj, schema, column_info, error_handler
)

# validate the rows (data) of the dataframe
check_obj = self._data_checks(
check_obj, schema, column_info, error_handler
)

@@ -191,7 +218,7 @@ def run_checks(self, check_obj: DataFrame, schema, error_handler):
check_results = []
for check_index, check in enumerate(
schema.checks
): # schama.checks is null
): # schema.checks is null
try:
check_results.append(
self.run_check(check_obj, schema, check, check_index)
@@ -386,8 +413,7 @@ def coerce_dtype(
except SchemaErrors as err:
for schema_error_dict in err.schema_errors:
if not error_handler.lazy:
# raise the first error immediately if not doing lazy
# validation
# raise the first error immediately if not doing lazy validation
raise schema_error_dict["error"]
error_handler.collect_error(
ErrorCategory.DTYPE_COERCION,
@@ -417,27 +443,6 @@ def _coerce_dtype(
# NOTE: clean up the error handling!
error_handler = ErrorHandler(lazy=True)

def _coerce_df_dtype(obj: DataFrame) -> DataFrame:
if schema.dtype is None:
raise ValueError(
"dtype argument is None. Must specify this argument "
"to coerce dtype"
)

try:
return schema.dtype.try_coerce(obj)
except ParserError as exc:
raise SchemaError(
schema=schema,
data=obj,
message=(
f"Error while coercing '{schema.name}' to type "
f"{schema.dtype}: {exc}\n{exc.failure_cases}"
),
failure_cases=exc.failure_cases,
check=f"coerce_dtype('{schema.dtype}')",
) from exc

def _try_coercion(obj, colname, col_schema):
try:
schema = obj.pandera.schema
@@ -490,6 +495,74 @@ def _try_coercion(obj, colname, col_schema):

return obj

@validate_scope(scope=ValidationScope.DATA)
def unique(
self,
check_obj: DataFrame,
*,
schema=None,
error_handler: ErrorHandler = None,
):
"""Check uniqueness in the check object."""
assert schema is not None, "The `schema` argument must be provided."
assert (
error_handler is not None
), "The `error_handler` argument must be provided."

if not schema.unique:
return check_obj

# Determine unique columns based on schema's config
unique_columns = (
[schema.unique]
if isinstance(schema.unique, str)
else schema.unique
)

# Check if values belong to the dataframe columns
missing_unique_columns = set(unique_columns) - set(check_obj.columns)
if missing_unique_columns:
raise SchemaDefinitionError(
"Specified `unique` columns are missing in the dataframe: "
f"{list(missing_unique_columns)}"
)

duplicates_count = (
check_obj.select(*unique_columns) # ignore other cols
.groupby(*unique_columns)
.agg(count("*").alias("pandera_duplicate_counts"))
.filter(
col("pandera_duplicate_counts") > 1
) # long name to avoid collisions
.count()
)

if duplicates_count > 0:
raise SchemaError(
schema=schema,
data=check_obj,
message=(
f"Duplicated rows [{duplicates_count}] were found "
f"for columns {unique_columns}"
),
check="unique",
reason_code=SchemaErrorReason.DUPLICATES,
)

return check_obj

def _check_uniqueness(
self,
obj: DataFrame,
schema,
) -> DataFrame:
"""Ensure uniqueness in dataframe columns.
:param obj: dataframe to check.
:param schema: schema object.
:returns: dataframe checked.
"""

##########
# Checks #
##########
@@ -516,8 +589,7 @@ def check_column_names_are_unique(self, check_obj: DataFrame, schema):
schema=schema,
data=check_obj,
message=(
"dataframe contains multiple columns with label(s): "
f"{failed}"
f"dataframe contains multiple columns with label(s): {failed}"
),
failure_cases=scalar_failure_case(failed),
check="dataframe_column_labels_unique",
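For reference, the duplicate detection that the new `unique` method performs reduces to a grouped count in plain PySpark. A standalone sketch of that logic with made-up data (not part of the commit):

from pyspark.sql import SparkSession
from pyspark.sql.functions import col, count

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([("a", 1), ("a", 1), ("b", 2)], ["product", "price"])
unique_columns = ["product", "price"]

# Keep only the unique-constrained columns, group by them, and count how many
# combinations occur more than once -- the same strategy as the backend method above.
duplicates_count = (
    df.select(*unique_columns)
    .groupby(*unique_columns)
    .agg(count("*").alias("pandera_duplicate_counts"))
    .filter(col("pandera_duplicate_counts") > 1)
    .count()
)
print(duplicates_count)  # 1: the ("a", 1) combination appears twice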
30 changes: 20 additions & 10 deletions pandera/backends/pyspark/decorators.py
@@ -81,6 +81,18 @@ def validate_scope(scope: ValidationScope):
def _wrapper(func):
@functools.wraps(func)
def wrapper(self, *args, **kwargs):
def _get_check_obj():
"""
Get dataframe object passed as arg to the decorated func.
Returns:
The DataFrame object.
"""
if args:
for value in args:
if isinstance(value, pyspark.sql.DataFrame):
return value

if scope == ValidationScope.SCHEMA:
if CONFIG.validation_depth in (
ValidationDepth.SCHEMA_AND_DATA,
@@ -89,17 +101,12 @@ def wrapper(self, *args, **kwargs):
return func(self, *args, **kwargs)
else:
warnings.warn(
"Skipping Execution of function as parameters set to DATA_ONLY ",
f"Skipping execution of function {func.__name__} as validation depth is set to DATA_ONLY ",
stacklevel=2,
)
if not kwargs:
for value in kwargs.values():
if isinstance(value, pyspark.sql.DataFrame):
return value
if args:
for value in args:
if isinstance(value, pyspark.sql.DataFrame):
return value
# If the function was skipped, return the `check_obj` value anyway,
# given that some return value is expected
return _get_check_obj()

elif scope == ValidationScope.DATA:
if CONFIG.validation_depth in (
@@ -109,9 +116,12 @@ def wrapper(self, *args, **kwargs):
return func(self, *args, **kwargs)
else:
warnings.warn(
"Skipping Execution of function as parameters set to SCHEMA_ONLY ",
f"Skipping execution of function {func.__name__} as validation depth is set to SCHEMA_ONLY",
stacklevel=2,
)
# If the function was skipped, return the `check_obj` value anyway,
# given that some return value is expected
return _get_check_obj()

return wrapper

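A practical effect of the decorator change: when a scope is skipped, the wrapped method now returns the DataFrame it received instead of None. A sketch of triggering the skip, assuming `CONFIG` and `ValidationDepth` are importable from `pandera.config` (the import path is an assumption; `schema` and `df` are the objects from the earlier sketch):

import warnings

from pandera.config import CONFIG, ValidationDepth  # assumed import path

# With SCHEMA_ONLY depth, DATA-scoped methods such as `unique` warn and hand
# back the original DataFrame unchanged instead of returning None.
CONFIG.validation_depth = ValidationDepth.SCHEMA_ONLY

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    validated_df = schema.validate(df)
    print([str(w.message) for w in caught])  # expect "Skipping execution of function ..." entries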
4 changes: 2 additions & 2 deletions tests/pyspark/test_pyspark_config.py
@@ -53,7 +53,7 @@ def test_schema_only(self, spark, sample_spark_schema):
CONFIG.validation_enabled = True
CONFIG.validation_depth = ValidationDepth.SCHEMA_ONLY

pandra_schema = DataFrameSchema(
pandera_schema = DataFrameSchema(
{
"product": Column(T.StringType(), Check.str_startswith("B")),
"price_val": Column(T.IntegerType()),
@@ -67,7 +67,7 @@ def test_schema_only(self, spark, sample_spark_schema):
assert CONFIG.dict() == expected

input_df = spark_df(spark, self.sample_data, sample_spark_schema)
output_dataframeschema_df = pandra_schema.validate(input_df)
output_dataframeschema_df = pandera_schema.validate(input_df)
expected_dataframeschema = {
"SCHEMA": {
"COLUMN_NOT_IN_DATAFRAME": [
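For completeness, the report this test asserts against can also be inspected on the validated DataFrame itself; a brief sketch assuming pandera's PySpark `df.pandera.errors` accessor and reusing the test's own `pandera_schema` and `input_df`:

# Under SCHEMA_ONLY depth only the "SCHEMA" section of the report is expected
# to be populated; DATA-scoped findings such as `unique` violations are skipped.
df_out = pandera_schema.validate(input_df)
print(dict(df_out.pandera.errors))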