diff --git a/tests/pyspark/test_pyspark_model.py b/tests/pyspark/test_pyspark_model.py
index d371fd24f..946467147 100644
--- a/tests/pyspark/test_pyspark_model.py
+++ b/tests/pyspark/test_pyspark_model.py
@@ -224,81 +224,103 @@ class Config:
     assert PanderaSchema.get_metadata() == expected
 
 
-@pytest.fixture
-def datamodel_unique_single_column() -> pa.DataFrameModel:
-    """Fixture containing DataFrameModel with optional columns."""
+# @pytest.fixture
+# def datamodel_unique_single_column() -> pa.DataFrameModel:
+#     """Fixture containing DataFrameModel with optional columns."""
 
-    class MyDataModel(pa.DataFrameModel):
-        """Simple DataFrameModel containing a column."""
+#     class MyDataModel(pa.DataFrameModel):
+#         """Simple DataFrameModel containing a column."""
 
-        a: T.LongType = pa.Field()
-        b: T.LongType = pa.Field()
+#         a: T.LongType = pa.Field()
+#         b: T.LongType = pa.Field()
 
-        class Config:
-            """Config class."""
+#         class Config:
+#             """Config class."""
 
-            unique = "a"
+#             unique = "a"
 
-    return MyDataModel
+#     return MyDataModel
 
 
-@pytest.fixture
-def datamodel_unique_multiple_columns() -> pa.DataFrameModel:
-    """Fixture containing DataFrameModel with optional columns."""
+# @pytest.fixture
+# def datamodel_unique_multiple_columns() -> pa.DataFrameModel:
+#     """Fixture containing DataFrameModel with optional columns."""
 
-    class MyDataModel(pa.DataFrameModel):
-        """Simple DataFrameModel containing a column."""
+#     class MyDataModel(pa.DataFrameModel):
+#         """Simple DataFrameModel containing a column."""
 
-        a: T.LongType = pa.Field()
-        b: T.LongType = pa.Field()
+#         a: T.LongType = pa.Field()
+#         b: T.LongType = pa.Field()
 
-        class Config:
-            """Config class."""
+#         class Config:
+#             """Config class."""
 
-            unique = ["a", "b"]
+#             unique = ["a", "b"]
 
-    return MyDataModel
+#     return MyDataModel
 
 
 @pytest.mark.parametrize(
-    "data_model, data, expectation",
+    "data, expectation",
     [
         (
-            datamodel_unique_single_column,
-            ([1, 4], [2, 5], [3, 6]),
+            (),
             does_not_raise(),
         ),
         (
-            datamodel_unique_multiple_columns,
             ([1, 4], [2, 5], [3, 6]),
             does_not_raise(),
         ),
         (
-            datamodel_unique_single_column,
-            ([0, 0], [0, 0], [3, 6]),
-            pytest.raises(pa.PysparkSchemaError),
-        ),
-        (
-            datamodel_unique_multiple_columns,
             ([0, 0], [0, 0], [3, 6]),
             pytest.raises(pa.PysparkSchemaError),
         ),
     ],
+    ids=["no_data", "non_duplicated_data", "duplicated_data"],
 )
-def test_dataframe_schema_unique(spark, data_model, data, expectation):
+def test_dataframe_schema_unique(spark, data, expectation):
     """Test uniqueness checks on pyspark dataframes."""
-    print(f"{type(spark)=}")
-    print(f"{type(data_model)=}")
-    print(f"{type(data)=}")
-    print(f"{type(expectation)=}")
 
-    df = spark.createDataFrame(data, ["a", "b"])
+    df = spark.createDataFrame(data, "a: int, b: int")
+
+    # Test `unique` configuration with a single column
+    class UniqueSingleColumn(pa.DataFrameModel):
+        """Simple DataFrameModel containing a column."""
+
+        a: T.IntegerType = pa.Field()
+        b: T.IntegerType = pa.Field()
+
+        class Config:
+            """Config class."""
+
+            unique = "a"
+
+    assert isinstance(UniqueSingleColumn(df), DataFrame)
+
+    with expectation:
+        df_out = UniqueSingleColumn.validate(check_obj=df)
+        if df_out.pandera.errors:
+            print(f"{df_out.pandera.errors=}")
+            raise pa.PysparkSchemaError
+
+    # Test `unique` configuration with multiple columns
+    class UniqueMultipleColumns(pa.DataFrameModel):
+        """Simple DataFrameModel containing a column."""
+
+        a: T.IntegerType = pa.Field()
+        b: T.IntegerType = pa.Field()
+
+        class Config:
+            """Config class."""
+
+            unique = ["a", "b"]
 
-    # assert isinstance(data_model(df), DataFrame)
+    assert isinstance(UniqueMultipleColumns(df), DataFrame)
 
     with expectation:
-        df_out = data_model.validate(check_obj=df)
+        df_out = UniqueMultipleColumns.validate(check_obj=df)
         if df_out.pandera.errors:
+            print(f"{df_out.pandera.errors=}")
             raise pa.PysparkSchemaError
 
 