Add support for unique validation in PySpark #1396

Merged: 7 commits merged on Oct 31, 2023

Changes from 4 commits
13 changes: 13 additions & 0 deletions docs/source/pyspark_sql.rst
@@ -343,3 +343,16 @@ We also provided a helper function to extract metadata from a schema as follows:
.. note::

   This feature is available for both ``pyspark.sql`` and ``pandas``.

`unique` support
----------------

*new in 0.17.3*

.. warning::

   The `unique` support for PySpark-based validations, which defines the columns that
   must be tested for unique values, may incur a performance hit, given Spark's
   distributed nature.

   Use with caution.
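
For orientation, a minimal usage sketch based on the tests added in this PR; the model name, column names, and sample data are illustrative, not part of the documentation change itself:

import pyspark.sql.types as T
from pyspark.sql import SparkSession

import pandera.pyspark as pa

spark = SparkSession.builder.getOrCreate()


class ProductSchema(pa.DataFrameModel):
    """Toy model requiring the (a, b) combination to be unique."""

    a: T.IntegerType = pa.Field()
    b: T.IntegerType = pa.Field()

    class Config:
        """A single column name (e.g. unique = "a") is also accepted."""

        unique = ["a", "b"]


df = spark.createDataFrame([(1, 4), (2, 5), (3, 6)], "a: int, b: int")
validated = ProductSchema.validate(check_obj=df)

# PySpark validations collect failures in the `pandera.errors` attribute
# instead of raising; an empty report means the uniqueness check passed.
print(validated.pandera.errors)
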
144 changes: 113 additions & 31 deletions pandera/backends/pyspark/container.py
@@ -6,7 +6,7 @@
from typing import Any, Dict, List, Optional

from pyspark.sql import DataFrame
from pyspark.sql.functions import col
from pyspark.sql.functions import col, count

from pandera.api.pyspark.error_handler import ErrorCategory, ErrorHandler
from pandera.api.pyspark.types import is_table
@@ -15,7 +15,6 @@
from pandera.backends.pyspark.error_formatters import scalar_failure_case
from pandera.config import CONFIG
from pandera.errors import (
ParserError,
SchemaDefinitionError,
SchemaError,
SchemaErrorReason,
@@ -31,14 +30,14 @@ def preprocess(self, check_obj: DataFrame, inplace: bool = False):
return check_obj

@validate_scope(scope=ValidationScope.SCHEMA)
def _column_checks(
def _schema_checks(
self,
check_obj: DataFrame,
schema,
column_info: ColumnInfo,
error_handler: ErrorHandler,
):
"""run the checks related to columns presence, uniqueness and filter column if neccesary"""
"""run the checks related to columns presence, strictness and filter column if neccesary"""

# check the container metadata, e.g. field names
try:
@@ -71,6 +70,7 @@ def _column_checks(
reason_code=exc.reason_code,
schema_error=exc,
)

# try to coerce datatypes
check_obj = self.coerce_dtype(
check_obj,
@@ -80,6 +80,30 @@

return check_obj

@validate_scope(scope=ValidationScope.DATA)
def _data_checks(
Comment on lines +83 to +84

Contributor Author:

A separate function was defined to run DATA-related validations when this validation depth is enabled, given that the existing one (_schema_checks) is designated to SCHEMA-related validations.

Reviewer:

Great @filipeo2-mck! I like how we're catching the SchemaError raised by the unique method, but I think it's a bit weird that we're raising a SchemaError when there's potentially nothing wrong with the schema. @cosmicBboy, does it make sense to introduce a DataError in pandera.errors to make sure we distinguish between different types of errors, or is there any reason we don't want that?

Contributor Author:

Semantically, what does the SchemaError object in pandera.errors stand for?

  • the segregation between the two available validation types (schema vs. data validations)? If this is the meaning, we should probably add data-related DataError exceptions to this namespace, to be specific about the kind of issue being raised.
  • the Pandera objects or components, at a macro level (a schema, some data - the df, the declared checks, the existing columns...)? If this is the meaning, I see no issue with calling it SchemaError.

Collaborator (@cosmicBboy, Oct 30, 2023):

Semantically, SchemaError(s) stand for anything that's wrong with the data or metadata of a validated object. That includes metadata (column names, types, etc.) and data (the actual data values contained in the object).

I think for clarity we should rename ValidationScope.SCHEMA to ValidationScope.METADATA to clarify the difference in pandera (I understand that the term "schema" often refers to what I'm calling metadata here, i.e. columns and their types, but pandera takes a slightly broader view of what a schema is).

> but I think it's a bit weird that we're raising a SchemaError when there's potentially nothing wrong with the schema. @cosmicBboy, does it make sense to introduce a DataError in pandera.errors to make sure we distinguish between different types of errors, or is there any reason we don't want that?

I'm not sure what the concern is here: if the unique method is the function that raises a SchemaError that's caught elsewhere in the validation pipeline, doesn't that mean, by definition, that there's something wrong with the data (under pandera's definition of a "schema")?

self,
check_obj: DataFrame,
schema,
column_info: ColumnInfo, # pylint: disable=unused-argument
error_handler: ErrorHandler,
):
"""Run the checks related to data validation and uniqueness."""

# uniqueness of values
try:
check_obj = self.unique(
check_obj, schema=schema, error_handler=error_handler
)
except SchemaError as exc:
error_handler.collect_error(
type=ErrorCategory.DATA,
reason_code=exc.reason_code,
schema_error=exc,
)

return check_obj

def validate(
self,
check_obj: DataFrame,
@@ -115,8 +139,13 @@ def validate(
check_obj = check_obj.pandera.add_schema(schema)
column_info = self.collect_column_info(check_obj, schema, lazy)

# validate the columns of the dataframe
check_obj = self._column_checks(
# validate the columns (schema) of the dataframe
check_obj = self._schema_checks(
check_obj, schema, column_info, error_handler
)

# validate the rows (data) of the dataframe
check_obj = self._data_checks(
check_obj, schema, column_info, error_handler
)

@@ -386,8 +415,7 @@ def coerce_dtype(
except SchemaErrors as err:
for schema_error_dict in err.schema_errors:
if not error_handler.lazy:
# raise the first error immediately if not doing lazy
# validation
# raise the first error immediately if not doing lazy validation
raise schema_error_dict["error"]
error_handler.collect_error(
ErrorCategory.DTYPE_COERCION,
@@ -417,27 +445,6 @@ def _coerce_dtype(
# NOTE: clean up the error handling!
error_handler = ErrorHandler(lazy=True)

def _coerce_df_dtype(obj: DataFrame) -> DataFrame:
if schema.dtype is None:
raise ValueError(
"dtype argument is None. Must specify this argument "
"to coerce dtype"
)

try:
return schema.dtype.try_coerce(obj)
except ParserError as exc:
raise SchemaError(
schema=schema,
data=obj,
message=(
f"Error while coercing '{schema.name}' to type "
f"{schema.dtype}: {exc}\n{exc.failure_cases}"
),
failure_cases=exc.failure_cases,
check=f"coerce_dtype('{schema.dtype}')",
) from exc

Comment on lines -420 to -440

Contributor Author:

Unlike in the pandas namespace, this function was not being used inside the coercion functions. Removed.

def _try_coercion(obj, colname, col_schema):
try:
schema = obj.pandera.schema
@@ -490,6 +497,82 @@ def _try_coercion(obj, colname, col_schema):

return obj

@validate_scope(scope=ValidationScope.DATA)
def unique(
self,
check_obj: DataFrame,
*,
schema=None,
error_handler: ErrorHandler = None,
):
"""Check uniqueness in the check object."""
assert schema is not None, "The `schema` argument must be provided."
assert (
error_handler is not None
), "The `error_handler` argument must be provided."

if not schema.unique:
return check_obj

try:
check_obj = self._check_uniqueness(
check_obj,
schema,
)
except SchemaError as err:
if not error_handler.lazy:
raise err
error_handler.collect_error(
ErrorCategory.DATA, err.reason_code, err
)

return check_obj

def _check_uniqueness(
self,
obj: DataFrame,
schema,
) -> DataFrame:
"""Ensure uniqueness in dataframe columns.

:param obj: dataframe to check.
:param schema: schema object.
:returns: dataframe checked.
"""
# Use unique definition of columns as first option
# unique_columns = [col.unique for col in schema.columns.values()]

# Overwrite it if the schema's Config class has a unique declaration
unique_columns = (
[schema.unique]
if isinstance(schema.unique, str)
else schema.unique
)

duplicates_count = (
obj.select(*unique_columns) # ignore other cols
.groupby(*unique_columns)
.agg(count("*").alias("pandera_duplicate_counts"))
.filter(
col("pandera_duplicate_counts") > 1
) # long name to avoid collisions
.count()
)

if duplicates_count > 0:
raise SchemaError(
schema=schema,
data=obj,
message=(
f"Duplicated rows [{duplicates_count}] were found "
f"for columns {unique_columns}"
),
check="unique",
reason_code=SchemaErrorReason.DUPLICATES,
)

return obj
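
For reference, a standalone sketch of the same duplicate-detection pattern outside pandera; the DataFrame contents and column names are illustrative. Note that the resulting count is the number of duplicated combinations, not the number of duplicated rows:

from pyspark.sql import SparkSession
from pyspark.sql.functions import col, count

spark = SparkSession.builder.getOrCreate()

df = spark.createDataFrame([(0, 0), (0, 0), (3, 6)], "a: int, b: int")
unique_columns = ["a", "b"]

duplicates_count = (
    df.select(*unique_columns)  # ignore other columns
    .groupby(*unique_columns)  # one group per distinct combination
    .agg(count("*").alias("pandera_duplicate_counts"))
    .filter(col("pandera_duplicate_counts") > 1)  # keep duplicated combinations
    .count()
)

print(duplicates_count)  # 1: the (0, 0) pair appears twice
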

##########
# Checks #
##########
@@ -516,8 +599,7 @@ def check_column_names_are_unique(self, check_obj: DataFrame, schema):
schema=schema,
data=check_obj,
message=(
"dataframe contains multiple columns with label(s): "
f"{failed}"
f"dataframe contains multiple columns with label(s): {failed}"
),
failure_cases=scalar_failure_case(failed),
check="dataframe_column_labels_unique",
34 changes: 24 additions & 10 deletions pandera/backends/pyspark/decorators.py
@@ -81,6 +81,22 @@ def validate_scope(scope: ValidationScope):
def _wrapper(func):
@functools.wraps(func)
def wrapper(self, *args, **kwargs):
def _get_check_obj():
"""
Get dataframe object passed as arg to the decorated func.

Returns:
The DataFrame object.
"""
if kwargs:
for value in kwargs.values():
if isinstance(value, pyspark.sql.DataFrame):
return value
Contributor Author:

Code coverage does not reach this because we are not passing check_obj as a kwarg. We use an arg instead.

Reviewer:

  • Can you confirm whether there's any scenario where a DataFrame might be passed as a kwarg?
  • For our tests related to this decorator, are we ensuring that we pass the pyspark.sql.DataFrame both as a positional and as a keyword argument?
  • If we anticipate that the DataFrame will always be passed as a positional argument and never as a keyword argument, would it make sense to refactor the decorator to remove the kwargs check for simplicity?

Contributor Author:

I wrongly assumed that this decorator was also used in the pandas and other namespaces and wanted to keep the pattern, but checking now, there is no such thing in the other integrations. I'm changing it to remove the kwargs capability, as it's not used anywhere. Thank you for the input.

Collaborator:

The decorator was added to pyspark only. I hope pandas will eventually follow suit...

if args:
for value in args:
if isinstance(value, pyspark.sql.DataFrame):
return value

if scope == ValidationScope.SCHEMA:
if CONFIG.validation_depth in (
ValidationDepth.SCHEMA_AND_DATA,
@@ -89,17 +105,12 @@ def wrapper(self, *args, **kwargs):
return func(self, *args, **kwargs)
else:
warnings.warn(
"Skipping Execution of function as parameters set to DATA_ONLY ",
f"Skipping execution of function {func.__name__} as validation depth is set to DATA_ONLY ",
Contributor Author:

Improving warning messages.

stacklevel=2,
)
if not kwargs:
for value in kwargs.values():
if isinstance(value, pyspark.sql.DataFrame):
return value
if args:
for value in args:
if isinstance(value, pyspark.sql.DataFrame):
return value
# If the function was skipped, return the `check_obj` value anyway,
# if it's present as a kwarg or an arg
return _get_check_obj()

elif scope == ValidationScope.DATA:
if CONFIG.validation_depth in (
@@ -109,9 +120,12 @@
return func(self, *args, **kwargs)
else:
warnings.warn(
"Skipping Execution of function as parameters set to SCHEMA_ONLY ",
f"Skipping execution of function {func.__name__} as validation depth is set to SCHEMA_ONLY",
stacklevel=2,
)
# If the function was skipped, return the `check_obj` value anyway,
# if it's present as a kwarg or an arg
return _get_check_obj()

return wrapper
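
For context on the skip-and-return behaviour above, a short sketch of how validation depth is configured, mirroring tests/pyspark/test_pyspark_config.py (the import location of ValidationDepth is an assumption based on the surrounding code):

# Sketch: with SCHEMA_ONLY depth, functions decorated with
# @validate_scope(scope=ValidationScope.DATA), such as the new `unique` check,
# emit the warning above and return `check_obj` unchanged.
from pandera.config import CONFIG, ValidationDepth

CONFIG.validation_enabled = True
CONFIG.validation_depth = ValidationDepth.SCHEMA_ONLY  # skip DATA-scoped checks
# CONFIG.validation_depth = ValidationDepth.SCHEMA_AND_DATA  # run both scopes
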

4 changes: 2 additions & 2 deletions tests/pyspark/test_pyspark_config.py
@@ -53,7 +53,7 @@ def test_schema_only(self, spark, sample_spark_schema):
CONFIG.validation_enabled = True
CONFIG.validation_depth = ValidationDepth.SCHEMA_ONLY

pandra_schema = DataFrameSchema(
pandera_schema = DataFrameSchema(
Contributor Author:

Typo.

{
"product": Column(T.StringType(), Check.str_startswith("B")),
"price_val": Column(T.IntegerType()),
@@ -67,7 +67,7 @@ def test_schema_only(self, spark, sample_spark_schema):
assert CONFIG.dict() == expected

input_df = spark_df(spark, self.sample_data, sample_spark_schema)
output_dataframeschema_df = pandra_schema.validate(input_df)
output_dataframeschema_df = pandera_schema.validate(input_df)
expected_dataframeschema = {
"SCHEMA": {
"COLUMN_NOT_IN_DATAFRAME": [
65 changes: 65 additions & 0 deletions tests/pyspark/test_pyspark_model.py
@@ -1,6 +1,7 @@
"""Unit tests for DataFrameModel module."""
# pylint:disable=abstract-method

from contextlib import nullcontext as does_not_raise
from typing import Optional
from pyspark.sql import DataFrame
import pyspark.sql.types as T
@@ -223,6 +224,70 @@ class Config:
assert PanderaSchema.get_metadata() == expected


@pytest.mark.parametrize(
"data, expectation",
[
(
(),
does_not_raise(),
),
(
([1, 4], [2, 5], [3, 6]),
does_not_raise(),
),
(
([0, 0], [0, 0], [3, 6]),
pytest.raises(pa.PysparkSchemaError),
),
],
ids=["no_data", "unique_data", "duplicated_data"],
)
def test_dataframe_schema_unique(spark, data, expectation):
"""Test uniqueness checks on pyspark dataframes."""

df = spark.createDataFrame(data, "a: int, b: int")

# Test `unique` configuration with a single column
class UniqueSingleColumn(pa.DataFrameModel):
"""Simple DataFrameModel containing a column."""

a: T.IntegerType = pa.Field()
b: T.IntegerType = pa.Field()

class Config:
"""Config class."""

unique = "a"

assert isinstance(UniqueSingleColumn(df), DataFrame)

with expectation:
df_out = UniqueSingleColumn.validate(check_obj=df)
if df_out.pandera.errors:
print(f"{df_out.pandera.errors=}")
raise pa.PysparkSchemaError

# Test `unique` configuration with multiple columns
class UniqueMultipleColumns(pa.DataFrameModel):
"""Simple DataFrameModel containing a column."""

a: T.IntegerType = pa.Field()
b: T.IntegerType = pa.Field()

class Config:
"""Config class."""

unique = ["a", "b"]

assert isinstance(UniqueMultipleColumns(df), DataFrame)

with expectation:
df_out = UniqueMultipleColumns.validate(check_obj=df)
if df_out.pandera.errors:
print(f"{df_out.pandera.errors=}")
raise pa.PysparkSchemaError
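
The tests above exercise the DataFrameModel route; below is a hedged sketch of the object-based route, assuming the pyspark DataFrameSchema accepts a `unique` argument like its pandas counterpart (column names and data are illustrative, not taken from this PR):

# Hedged sketch: `unique=` on the pyspark DataFrameSchema constructor is an
# assumption mirroring the pandas API.
import pyspark.sql.types as T
from pyspark.sql import SparkSession

from pandera.pyspark import Column, DataFrameSchema

spark = SparkSession.builder.getOrCreate()

schema = DataFrameSchema(
    columns={
        "a": Column(T.IntegerType()),
        "b": Column(T.IntegerType()),
    },
    unique=["a", "b"],  # the (a, b) combination must be unique
)

df = spark.createDataFrame([(1, 4), (1, 4)], "a: int, b: int")
validated = schema.validate(df)
print(validated.pandera.errors)  # expected to report duplicated rows for ["a", "b"]
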


def test_dataframe_schema_strict(spark, config_params: PanderaConfig) -> None:
"""
Checks if strict=True whether a schema error is raised because either extra columns are present in the dataframe