From a9e42039daca1d328d4793c4b71f0d2f713ebfc9 Mon Sep 17 00:00:00 2001 From: duanmeng Date: Mon, 30 Oct 2023 13:43:00 -0700 Subject: [PATCH] Fail if compare null for complex type in contains (#7273) Summary: In Presto, `contains` applied to arrays of complex types with nested nulls may or may not fail, depending on whether the nulls need to be compared. ```SQL presto> select contains(col0, col1) from (values (array[array[1, 3]], array[2, null])) as tbl(col0, col1); _col0 ------- false (1 row) presto> select contains(col0, col1) from (values (array[array[2, 3]], array[2, null])) as tbl(col0, col1); Query 20231027_015548_00004_whazu failed: contains does not support arrays with elements that are null or contain null presto> select contains(col0, col1) from (values (array[array[2, null]], array[2, 3])) as tbl(col0, col1); Query 20231027_015600_00005_whazu failed: contains does not support arrays with elements that are null or contain null presto> select contains(col0, col1) from (values (array[array[1, null]], array[2, 3])) as tbl(col0, col1); _col0 ------- false (1 row) ``` This PR checks the result of the equality comparison and sets an error (`setError`) when nested nulls of a complex type need to be compared. Pull Request resolved: https://github.com/facebookincubator/velox/pull/7273 Reviewed By: xiaoxmeng, Yuhta Differential Revision: D50793771 Pulled By: mbasmanova fbshipit-source-id: 991d018ea79fe46ec4746c0038160ead64e5cb70 --- velox/docs/functions/presto/array.rst | 6 ++ velox/functions/prestosql/ArrayContains.cpp | 34 +++++++---- .../prestosql/tests/ArrayContainsTest.cpp | 58 +++++++++++++++++++ 3 files changed, 88 insertions(+), 10 deletions(-) diff --git a/velox/docs/functions/presto/array.rst b/velox/docs/functions/presto/array.rst index af71e2e409d5..952e051c3493 100644 --- a/velox/docs/functions/presto/array.rst +++ b/velox/docs/functions/presto/array.rst @@ -218,6 +218,12 @@ Array Functions .. function:: contains(x, element) -> boolean Returns true if the array ``x`` contains the ``element``. 
+ When ``element`` is of complex type, throws if ``x`` or ``element`` contains nested nulls + and these need to be compared to produce a result. :: + + SELECT contains(ARRAY[ARRAY[1, 3]], ARRAY[2, null]); -- false. + SELECT contains(ARRAY[ARRAY[2, 3]], ARRAY[2, null]); -- failed: contains does not support arrays with elements that contain null + SELECT contains(ARRAY[ARRAY[2, null]], ARRAY[2, 1]); -- failed: contains does not support arrays with elements that contain null .. function:: element_at(array(E), index) -> E diff --git a/velox/functions/prestosql/ArrayContains.cpp b/velox/functions/prestosql/ArrayContains.cpp index aeb9b397f9ad..7ca6376300d5 100644 --- a/velox/functions/prestosql/ArrayContains.cpp +++ b/velox/functions/prestosql/ArrayContains.cpp @@ -25,6 +25,7 @@ void applyTyped( DecodedVector& arrayDecoded, DecodedVector& elementsDecoded, DecodedVector& searchDecoded, + exec::EvalCtx& /*context*/, FlatVector& flatResult) { using T = typename TypeTraits::NativeType; @@ -85,6 +86,7 @@ void applyComplexType( DecodedVector& arrayDecoded, DecodedVector& elementsDecoded, DecodedVector& searchDecoded, + exec::EvalCtx& context, FlatVector& flatResult) { auto baseArray = arrayDecoded.base()->as(); auto rawSizes = baseArray->rawSizes(); @@ -96,20 +98,28 @@ void applyComplexType( auto searchBase = searchDecoded.base(); auto searchIndices = searchDecoded.indices(); - rows.applyToSelected([&](auto row) { + context.applyToSelectedNoThrow(rows, [&](auto row) { auto size = rawSizes[indices[row]]; auto offset = rawOffsets[indices[row]]; - bool foundNull = false; - auto searchIndex = searchIndices[row]; + for (auto i = 0; i < size; i++) { if (elementsDecoded.isNullAt(offset + i)) { foundNull = true; - } else if (elementsBase->equalValueAt( - searchBase, elementIndices[offset + i], searchIndex)) { - flatResult.set(row, true); - return; + } else { + std::optional result = elementsBase->equalValueAt( + searchBase, + elementIndices[offset + i], 
searchIndex, + CompareFlags::NullHandlingMode::StopAtNull); + VELOX_USER_CHECK( + result.has_value(), + "contains does not support arrays with elements that contain null"); + if (result.value()) { + flatResult.set(row, true); + return; + } } } @@ -127,9 +137,10 @@ void applyTyped( DecodedVector& arrayDecoded, DecodedVector& elementsDecoded, DecodedVector& searchDecoded, + exec::EvalCtx& context, FlatVector& flatResult) { applyComplexType( - rows, arrayDecoded, elementsDecoded, searchDecoded, flatResult); + rows, arrayDecoded, elementsDecoded, searchDecoded, context, flatResult); } template <> @@ -138,9 +149,10 @@ void applyTyped( DecodedVector& arrayDecoded, DecodedVector& elementsDecoded, DecodedVector& searchDecoded, + exec::EvalCtx& context, FlatVector& flatResult) { applyComplexType( - rows, arrayDecoded, elementsDecoded, searchDecoded, flatResult); + rows, arrayDecoded, elementsDecoded, searchDecoded, context, flatResult); } template <> @@ -149,9 +161,10 @@ void applyTyped( DecodedVector& arrayDecoded, DecodedVector& elementsDecoded, DecodedVector& searchDecoded, + exec::EvalCtx& context, FlatVector& flatResult) { applyComplexType( - rows, arrayDecoded, elementsDecoded, searchDecoded, flatResult); + rows, arrayDecoded, elementsDecoded, searchDecoded, context, flatResult); } class ArrayContainsFunction : public exec::VectorFunction { @@ -191,6 +204,7 @@ class ArrayContainsFunction : public exec::VectorFunction { *arrayHolder.get(), *elementsHolder.get(), *searchHolder.get(), + context, *flatResult); } diff --git a/velox/functions/prestosql/tests/ArrayContainsTest.cpp b/velox/functions/prestosql/tests/ArrayContainsTest.cpp index 7bf8f787cdfc..8e105222b3ba 100644 --- a/velox/functions/prestosql/tests/ArrayContainsTest.cpp +++ b/velox/functions/prestosql/tests/ArrayContainsTest.cpp @@ -15,6 +15,7 @@ */ #include +#include "velox/common/base/tests/GTestUtils.h" #include "velox/functions/prestosql/tests/utils/FunctionBaseTest.h" using namespace facebook::velox; @@ 
-316,4 +317,61 @@ TEST_F(ArrayContainsTest, dictionaryEncodingElements) { testContainsConstantKey(arrayVector, {3, 4, 5}, {true, true}); } +TEST_F(ArrayContainsTest, arrayCheckNulls) { + const auto baseVector = makeArrayVectorFromJson({ + "[1, 1]", + "[2, 2]", + "[3, null]", + "[4, 4]", + "[5, 5]", + "[6, 6]", + }); + const auto data = makeArrayVector({0, 3}, baseVector); + auto contains = [&](const std::string& search) { + const auto searchBase = makeArrayVectorFromJson({search}); + const auto searchConstant = + BaseVector::wrapInConstant(data->size(), 0, searchBase); + const auto result = + evaluate("contains(c0, c1)", makeRowVector({data, searchConstant})); + return result->asFlatVector()->valueAt(0); + }; + + static const std::string kErrorMessage = + "contains does not support arrays with elements that contain null"; + // No null equal. + ASSERT_FALSE(contains("[7, null]")); + // Null equal, [3, null] vs [3, 3]. + VELOX_ASSERT_THROW(contains("[3, 3]"), kErrorMessage); + // Null equal, [6, 6] vs [6, null]. + VELOX_ASSERT_THROW(contains("[6, null]"), kErrorMessage); +} + +TEST_F(ArrayContainsTest, rowCheckNulls) { + const auto baseVector = makeRowVector({ + makeNullableFlatVector({1, 2, 3, 4, 5, 6}), + makeNullableFlatVector({1, 2, std::nullopt, 4, 5, 6}), + }); + const auto data = makeArrayVector({0, 3}, baseVector); + + auto contains = [&](const std::vector>& search) { + const auto searchBase = makeRowVector({ + makeNullableFlatVector({search.at(0)}), + makeNullableFlatVector({search.at(1)}), + }); + const auto searchConstant = + BaseVector::wrapInConstant(data->size(), 0, searchBase); + const auto result = + evaluate("contains(c0, c1)", makeRowVector({data, searchConstant})); + return result->asFlatVector()->valueAt(0); + }; + + static const std::string kErrorMessage = + "contains does not support arrays with elements that contain null"; + // No null equal. + ASSERT_FALSE(contains({7, std::nullopt})); + // Null equal, (3, null) vs (3, 3). 
+ VELOX_ASSERT_THROW(contains({3, 3}), kErrorMessage); + // Null equal, (6, 6) vs (6, null). + VELOX_ASSERT_THROW(contains({6, std::nullopt}), kErrorMessage); +} } // namespace