Skip to content

Commit

Permalink
Fail if compare null for complex type in contains (#7273)
Browse files Browse the repository at this point in the history
Summary:
In Presto, `contains` applied to arrays of complex types with nested
nulls may or may not fail, depending on whether the nested nulls need
to be compared to produce the result.

```SQL
presto> select contains(col0, col1) from (values (array[array[1, 3]],
array[2, null])) as tbl(col0, col1);
 _col0
-------
 false
(1 row)

presto> select contains(col0, col1) from (values (array[array[2, 3]],
array[2, null])) as tbl(col0, col1);
Query 20231027_015548_00004_whazu failed: contains does not
support arrays with elements that are null or contain null

presto> select contains(col0, col1) from (values (array[array[2, null]],
array[2, 3])) as tbl(col0, col1);
Query 20231027_015600_00005_whazu failed: contains does not
support arrays with elements that are null or contain null

presto> select contains(col0, col1) from (values (array[array[1, null]],
array[2, 3])) as tbl(col0, col1);
 _col0
-------
 false
(1 row)
```

This PR checks the result of the equality comparison and calls `setError`
if nested nulls of a complex type had to be compared.

Pull Request resolved: #7273

Reviewed By: xiaoxmeng, Yuhta

Differential Revision: D50793771

Pulled By: mbasmanova

fbshipit-source-id: 991d018ea79fe46ec4746c0038160ead64e5cb70
  • Loading branch information
duanmeng authored and facebook-github-bot committed Oct 30, 2023
1 parent b9984b4 commit a9e4203
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 10 deletions.
6 changes: 6 additions & 0 deletions velox/docs/functions/presto/array.rst
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,12 @@ Array Functions
.. function:: contains(x, element) -> boolean

Returns true if the array ``x`` contains the ``element``.
When 'element' is of complex type, throws if 'x' or 'element' contains nested nulls
and these need to be compared to produce a result. ::

SELECT contains(ARRAY[ARRAY[1, 3]], ARRAY[2, null]); -- false.
SELECT contains(ARRAY[ARRAY[2, 3]], ARRAY[2, null]); -- failed: contains does not support arrays with elements that are null or contain null
SELECT contains(ARRAY[ARRAY[2, null]], ARRAY[2, 1]); -- failed: contains does not support arrays with elements that are null or contain null

.. function:: element_at(array(E), index) -> E

Expand Down
34 changes: 24 additions & 10 deletions velox/functions/prestosql/ArrayContains.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ void applyTyped(
DecodedVector& arrayDecoded,
DecodedVector& elementsDecoded,
DecodedVector& searchDecoded,
exec::EvalCtx& /*context*/,
FlatVector<bool>& flatResult) {
using T = typename TypeTraits<kind>::NativeType;

Expand Down Expand Up @@ -85,6 +86,7 @@ void applyComplexType(
DecodedVector& arrayDecoded,
DecodedVector& elementsDecoded,
DecodedVector& searchDecoded,
exec::EvalCtx& context,
FlatVector<bool>& flatResult) {
auto baseArray = arrayDecoded.base()->as<ArrayVector>();
auto rawSizes = baseArray->rawSizes();
Expand All @@ -96,20 +98,28 @@ void applyComplexType(
auto searchBase = searchDecoded.base();
auto searchIndices = searchDecoded.indices();

rows.applyToSelected([&](auto row) {
context.applyToSelectedNoThrow(rows, [&](auto row) {
auto size = rawSizes[indices[row]];
auto offset = rawOffsets[indices[row]];

bool foundNull = false;

auto searchIndex = searchIndices[row];

for (auto i = 0; i < size; i++) {
if (elementsDecoded.isNullAt(offset + i)) {
foundNull = true;
} else if (elementsBase->equalValueAt(
searchBase, elementIndices[offset + i], searchIndex)) {
flatResult.set(row, true);
return;
} else {
std::optional<bool> result = elementsBase->equalValueAt(
searchBase,
elementIndices[offset + i],
searchIndex,
CompareFlags::NullHandlingMode::StopAtNull);
VELOX_USER_CHECK(
result.has_value(),
"contains does not support arrays with elements that contain null");
if (result.value()) {
flatResult.set(row, true);
return;
}
}
}

Expand All @@ -127,9 +137,10 @@ void applyTyped<TypeKind::ARRAY>(
DecodedVector& arrayDecoded,
DecodedVector& elementsDecoded,
DecodedVector& searchDecoded,
exec::EvalCtx& context,
FlatVector<bool>& flatResult) {
applyComplexType(
rows, arrayDecoded, elementsDecoded, searchDecoded, flatResult);
rows, arrayDecoded, elementsDecoded, searchDecoded, context, flatResult);
}

template <>
Expand All @@ -138,9 +149,10 @@ void applyTyped<TypeKind::MAP>(
DecodedVector& arrayDecoded,
DecodedVector& elementsDecoded,
DecodedVector& searchDecoded,
exec::EvalCtx& context,
FlatVector<bool>& flatResult) {
applyComplexType(
rows, arrayDecoded, elementsDecoded, searchDecoded, flatResult);
rows, arrayDecoded, elementsDecoded, searchDecoded, context, flatResult);
}

template <>
Expand All @@ -149,9 +161,10 @@ void applyTyped<TypeKind::ROW>(
DecodedVector& arrayDecoded,
DecodedVector& elementsDecoded,
DecodedVector& searchDecoded,
exec::EvalCtx& context,
FlatVector<bool>& flatResult) {
applyComplexType(
rows, arrayDecoded, elementsDecoded, searchDecoded, flatResult);
rows, arrayDecoded, elementsDecoded, searchDecoded, context, flatResult);
}

class ArrayContainsFunction : public exec::VectorFunction {
Expand Down Expand Up @@ -191,6 +204,7 @@ class ArrayContainsFunction : public exec::VectorFunction {
*arrayHolder.get(),
*elementsHolder.get(),
*searchHolder.get(),
context,
*flatResult);
}

Expand Down
58 changes: 58 additions & 0 deletions velox/functions/prestosql/tests/ArrayContainsTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/

#include <optional>
#include "velox/common/base/tests/GTestUtils.h"
#include "velox/functions/prestosql/tests/utils/FunctionBaseTest.h"

using namespace facebook::velox;
Expand Down Expand Up @@ -316,4 +317,61 @@ TEST_F(ArrayContainsTest, dictionaryEncodingElements) {
testContainsConstantKey(arrayVector, {3, 4, 5}, {true, true});
}

// Verifies that contains() over an array of arrays only raises a user error
// when a nested null actually has to be compared to decide equality.
TEST_F(ArrayContainsTest, arrayCheckNulls) {
  static const std::string kExpectedError =
      "contains does not support arrays with elements that contain null";

  const auto elements = makeArrayVectorFromJson<int32_t>({
      "[1, 1]",
      "[2, 2]",
      "[3, null]",
      "[4, 4]",
      "[5, 5]",
      "[6, 6]",
  });
  // NOTE(review): presumably offsets {0, 3} split the six elements into two
  // rows of three; confirm against makeArrayVector.
  const auto input = makeArrayVector({0, 3}, elements);

  // Runs contains(c0, c1) with 'needle' wrapped as a constant search value
  // and returns the result of the first row only. An error raised while
  // evaluating any row propagates out of this helper.
  const auto firstRowContains = [&](const std::string& needle) {
    const auto needleVector = makeArrayVectorFromJson<int32_t>({needle});
    const auto needleConstant =
        BaseVector::wrapInConstant(input->size(), 0, needleVector);
    const auto arguments = makeRowVector({input, needleConstant});
    const auto outcome = evaluate("contains(c0, c1)", arguments);
    return outcome->asFlatVector<bool>()->valueAt(0);
  };

  // The first values never match, so the nested null is never compared.
  ASSERT_FALSE(firstRowContains("[7, null]"));
  // [3, null] vs [3, 3]: the null on the array side must be compared.
  VELOX_ASSERT_THROW(firstRowContains("[3, 3]"), kExpectedError);
  // [6, 6] vs [6, null]: the null on the search side must be compared.
  VELOX_ASSERT_THROW(firstRowContains("[6, null]"), kExpectedError);
}

// Verifies that contains() over an array of rows only raises a user error
// when a null field actually has to be compared to decide equality.
TEST_F(ArrayContainsTest, rowCheckNulls) {
  static const std::string kExpectedError =
      "contains does not support arrays with elements that contain null";

  // Six two-field rows; the third row has a null second field.
  const auto firstField = makeNullableFlatVector<int32_t>({1, 2, 3, 4, 5, 6});
  const auto secondField =
      makeNullableFlatVector<int32_t>({1, 2, std::nullopt, 4, 5, 6});
  const auto elements = makeRowVector({firstField, secondField});
  // NOTE(review): presumably offsets {0, 3} split the six rows into two
  // arrays of three; confirm against makeArrayVector.
  const auto input = makeArrayVector({0, 3}, elements);

  // Runs contains(c0, c1) with a single constant row built from the two
  // values in 'needle' and returns the result of the first row only. An error
  // raised while evaluating any row propagates out of this helper.
  const auto firstRowContains =
      [&](const std::vector<std::optional<int32_t>>& needle) {
        const auto needleRow = makeRowVector({
            makeNullableFlatVector<int32_t>({needle.at(0)}),
            makeNullableFlatVector<int32_t>({needle.at(1)}),
        });
        const auto needleConstant =
            BaseVector::wrapInConstant(input->size(), 0, needleRow);
        const auto arguments = makeRowVector({input, needleConstant});
        const auto outcome = evaluate("contains(c0, c1)", arguments);
        return outcome->asFlatVector<bool>()->valueAt(0);
      };

  // The first fields never match, so the null field is never compared.
  ASSERT_FALSE(firstRowContains({7, std::nullopt}));
  // (3, null) vs (3, 3): the null on the array side must be compared.
  VELOX_ASSERT_THROW(firstRowContains({3, 3}), kExpectedError);
  // (6, 6) vs (6, null): the null on the search side must be compared.
  VELOX_ASSERT_THROW(firstRowContains({6, std::nullopt}), kExpectedError);
}
} // namespace

0 comments on commit a9e4203

Please sign in to comment.