feat: Add config to throw exception for duplicate keys in Spark map_concat function #12379

Open: wants to merge 2 commits into base: main
10 changes: 10 additions & 0 deletions velox/core/QueryConfig.h
@@ -333,6 +333,12 @@ class QueryConfig {
static constexpr const char* kSparkLegacyDateFormatter =
"spark.legacy_date_formatter";

/// If a key is found in multiple given maps, by default that key's value in
/// the resulting map comes from the last one of those maps. When true, an
/// exception is thrown on duplicate map keys.
static constexpr const char* kSparkThrowExceptionOnDuplicateMapKeys =
"spark.throw_exception_on_duplicate_map_keys";

/// The number of local parallel table writer operators per task.
static constexpr const char* kTaskWriterCount = "task_writer_count";

@@ -831,6 +837,10 @@ class QueryConfig {
return get<bool>(kSparkLegacyDateFormatter, false);
}

bool sparkThrowExceptionOnDuplicateMapKeys() const {
return get<bool>(kSparkThrowExceptionOnDuplicateMapKeys, false);
Collaborator (author): Yes, Spark uses the EXCEPTION policy by default, while keeping LAST_WIN as the default in Velox is compatible with Presto. In Gluten, we can always set this configuration according to Spark's config.

Contributor: Makes sense.

Contributor: The name indicates it is only used for Spark; is it also used for Presto? If not, aligning with Spark's default value might be more reasonable. Not a big issue, just mentioning it.

Contributor: Or remove "spark." from the config name.

}

bool exprTrackCpuUsage() const {
return get<bool>(kExprTrackCpuUsage, false);
}
5 changes: 5 additions & 0 deletions velox/docs/configs.rst
@@ -887,6 +887,11 @@ Spark-specific Configuration
Joda date formatter performs strict checking of its input and uses different pattern string.
For example, the 2015-07-22 10:00:00 timestamp cannot be parsed if pattern is yyyy-MM-dd because the parser does not consume whole input.
Another example is that the 'W' pattern, which means week in month, is not supported. For more differences, see :issue:`10354`.
* - spark.throw_exception_on_duplicate_map_keys
- bool
- false
- By default, if a key is found in multiple given maps, that key's value in the resulting map comes from the last one of those maps.
If true, an exception is thrown when duplicate keys are found. This configuration is used by the Spark functions `CreateMap`, `MapFromArrays`, `MapFromEntries`, `StringToMap`, `MapConcat`, and `TransformKeys`.
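The two duplicate-key policies can be illustrated with a small standalone sketch. This uses a plain `std::map` and a made-up helper name, not the Velox execution path:

```cpp
#include <cassert>
#include <map>
#include <stdexcept>
#include <string>
#include <vector>

// Illustrative helper (not Velox code): concatenate maps under either policy.
std::map<std::string, int> concatMaps(
    const std::vector<std::map<std::string, int>>& maps,
    bool throwOnDuplicateKeys) {
  std::map<std::string, int> result;
  for (const auto& m : maps) {
    for (const auto& [key, value] : m) {
      if (throwOnDuplicateKeys && result.count(key) > 0) {
        // EXCEPTION policy: any repeated key aborts the operation.
        throw std::runtime_error("Duplicate map key " + key + " was found.");
      }
      // LAST_WIN policy: a later map's value overwrites an earlier one.
      result[key] = value;
    }
  }
  return result;
}
```

With the config unset (false), the last map's value silently wins; with it set, the same input throws.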

Tracing
--------
11 changes: 11 additions & 0 deletions velox/docs/functions/spark/map.rst
@@ -14,6 +14,17 @@ Map Functions

SELECT map(array(1, 2), array(3, 4)); -- {[1, 2] -> [3, 4]}

.. spark:function:: map_concat(map1(K,V), map2(K,V), ..., mapN(K,V)) -> map(K,V)
Returns the union of all the given maps. If a key is found in multiple given maps,
by default that key's value in the resulting map comes from the last one of those maps.
If the configuration `spark.throw_exception_on_duplicate_map_keys` is set to true, an
exception is thrown for duplicate keys. ::

SELECT map_concat(map(1, 'a', 2, 'b'), map(3, 'c')); -- {1 -> 'a', 2 -> 'b', 3 -> 'c'}
SELECT map_concat(map(1, 'a', 2, 'b'), map(3, 'c', 2, 'd')); -- {1 -> 'a', 2 -> 'd', 3 -> 'c'} (LAST_WIN behavior)
SELECT map_concat(map(1, 'a', 2, 'b'), map(3, 'c', 2, 'd')); -- "Duplicate map key 2 was found" (EXCEPTION behavior)

.. spark:function:: map_entries(map(K,V)) -> array(row(K,V))
Returns an array of all entries in the given map. ::
54 changes: 42 additions & 12 deletions velox/functions/lib/MapConcat.cpp
@@ -22,7 +22,7 @@ namespace facebook::velox::functions {
namespace {

// See documentation at https://prestodb.io/docs/current/functions/map.html
template <bool EmptyForNull>
template <bool EmptyForNull, bool AllowSingleArg>
class MapConcatFunction : public exec::VectorFunction {
public:
void apply(
@@ -31,15 +31,19 @@
const TypePtr& outputType,
exec::EvalCtx& context,
VectorPtr& result) const override {
VELOX_CHECK_GE(args.size(), 2);
const TypePtr& mapType = args[0]->type();
VELOX_CHECK_EQ(mapType->kind(), TypeKind::MAP);
for (const VectorPtr& arg : args) {
VELOX_CHECK(mapType->kindEquals(arg->type()));
VELOX_CHECK(
std::all_of(
args.begin(),
args.end(),
[&](const VectorPtr& arg) {
return arg->type()->kind() == TypeKind::MAP;
}),
"All arguments must be of type map.");
const uint64_t numArgs = args.size();
if constexpr (!AllowSingleArg) {
VELOX_CHECK_GE(numArgs, 2);
}
VELOX_CHECK(mapType->kindEquals(outputType));

const uint64_t numArgs = args.size();
exec::DecodedArgs decodedArgs(rows, args, context);
vector_size_t maxSize = 0;
for (int i = 0; i < numArgs; i++) {
@@ -109,6 +113,11 @@
// Check for duplicate keys
SelectivityVector uniqueKeys(offset);
vector_size_t duplicateCnt = 0;
const auto throwExceptionOnDuplicateMapKeys =
context.execCtx()
->queryCtx()
->queryConfig()
.sparkThrowExceptionOnDuplicateMapKeys();
rows.applyToSelected([&](vector_size_t row) {
const int mapOffset = rawOffsets[row];
const int mapSize = rawSizes[row];
@@ -118,6 +127,11 @@
for (vector_size_t i = 1; i < mapSize; i++) {
if (combinedKeys->equalValueAt(
combinedKeys.get(), mapOffset + i, mapOffset + i - 1)) {
if (throwExceptionOnDuplicateMapKeys) {
const auto duplicateKey = combinedKeys->wrappedVector()->toString(
combinedKeys->wrappedIndex(mapOffset + i));
VELOX_USER_FAIL("Duplicate map key {} was found.", duplicateKey);
}
duplicateCnt++;
// "remove" duplicate entry
uniqueKeys.setValid(mapOffset + i - 1, false);
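The pass above works because the combined keys are sorted per row first, so duplicates end up adjacent; under LAST_WIN the earlier entry is masked out so the later map's value survives. A standalone sketch of the same idea, using a plain sorted vector and a bool mask rather than the SelectivityVector API (the helper name is hypothetical):

```cpp
#include <cassert>
#include <stdexcept>
#include <string>
#include <vector>

// Assumes sortedKeys is already sorted so equal keys are adjacent.
// Returns a validity mask: with LAST_WIN, the earlier of two equal
// keys is dropped so the value from the later map survives.
std::vector<bool> markUniqueKeys(
    const std::vector<std::string>& sortedKeys,
    bool throwOnDuplicate) {
  std::vector<bool> valid(sortedKeys.size(), true);
  for (size_t i = 1; i < sortedKeys.size(); ++i) {
    if (sortedKeys[i] == sortedKeys[i - 1]) {
      if (throwOnDuplicate) {
        throw std::runtime_error(
            "Duplicate map key " + sortedKeys[i] + " was found.");
      }
      valid[i - 1] = false; // "remove" the earlier duplicate entry
    }
  }
  return valid;
}
```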
@@ -172,15 +186,31 @@
void registerMapConcatFunction(const std::string& name) {
exec::registerVectorFunction(
name,
MapConcatFunction</*EmptyForNull=*/false>::signatures(),
std::make_unique<MapConcatFunction</*EmptyForNull=*/false>>());
MapConcatFunction</*EmptyForNull=*/false, /*AllowSingleArg=*/false>::
signatures(),
std::make_unique<MapConcatFunction<
/*EmptyForNull=*/false,
/*AllowSingleArg=*/false>>());
}

void registerMapConcatAllowSingleArg(const std::string& name) {
exec::registerVectorFunction(
name,
MapConcatFunction</*EmptyForNull=*/false, /*AllowSingleArg=*/true>::
signatures(),
std::make_unique<MapConcatFunction<
/*EmptyForNull=*/false,
/*AllowSingleArg=*/true>>());
}

void registerMapConcatEmptyNullsFunction(const std::string& name) {
exec::registerVectorFunction(
name,
MapConcatFunction</*EmptyForNull=*/true>::signatures(),
std::make_unique<MapConcatFunction</*EmptyForNull=*/true>>(),
MapConcatFunction</*EmptyForNull=*/true, /*AllowSingleArg=*/false>::
signatures(),
std::make_unique<MapConcatFunction<
/*EmptyForNull=*/true,
/*AllowSingleArg=*/false>>(),
exec::VectorFunctionMetadataBuilder().defaultNullBehavior(false).build());
}
} // namespace facebook::velox::functions
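The registration helpers above stamp out one templated implementation under different compile-time flags. A minimal standalone sketch of that pattern, with a toy registry standing in for `exec::registerVectorFunction` (all names here are illustrative):

```cpp
#include <cassert>
#include <map>
#include <memory>
#include <string>

// Toy stand-in for the vector-function interface.
struct VectorFunction {
  virtual ~VectorFunction() = default;
  virtual bool allowSingleArg() const = 0;
};

// One implementation, specialized at compile time by its flags.
template <bool EmptyForNull, bool AllowSingleArg>
struct MapConcatFn : VectorFunction {
  bool allowSingleArg() const override {
    return AllowSingleArg;
  }
};

// Global name-to-implementation registry.
std::map<std::string, std::unique_ptr<VectorFunction>>& registry() {
  static std::map<std::string, std::unique_ptr<VectorFunction>> r;
  return r;
}

// Presto-style map_concat: at least two arguments required.
void registerMapConcat(const std::string& name) {
  registry()[name] = std::make_unique<MapConcatFn<false, false>>();
}

// Spark-style map_concat: a single map argument is allowed.
void registerMapConcatAllowSingleArg(const std::string& name) {
  registry()[name] = std::make_unique<MapConcatFn<false, true>>();
}
```

Registering the same template under two names with different flags is what lets Presto and Spark share the MapConcat implementation while diverging in argument-count rules.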
2 changes: 2 additions & 0 deletions velox/functions/lib/MapConcat.h
@@ -22,6 +22,8 @@ namespace facebook::velox::functions {

void registerMapConcatFunction(const std::string& name);

void registerMapConcatAllowSingleArg(const std::string& name);

void registerMapConcatEmptyNullsFunction(const std::string& name);

} // namespace facebook::velox::functions
3 changes: 3 additions & 0 deletions velox/functions/sparksql/registration/RegisterMap.cpp
@@ -13,6 +13,7 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "velox/functions/lib/MapConcat.h"
#include "velox/functions/lib/RegistrationHelpers.h"
#include "velox/functions/sparksql/Size.h"

@@ -31,6 +32,8 @@ void registerSparkMapFunctions(const std::string& prefix) {
VELOX_REGISTER_VECTOR_FUNCTION(udf_map_keys, prefix + "map_keys");
VELOX_REGISTER_VECTOR_FUNCTION(udf_map_values, prefix + "map_values");
VELOX_REGISTER_VECTOR_FUNCTION(udf_map_zip_with, prefix + "map_zip_with");

registerMapConcatAllowSingleArg(prefix + "map_concat");
}

namespace sparksql {
1 change: 1 addition & 0 deletions velox/functions/sparksql/tests/CMakeLists.txt
@@ -41,6 +41,7 @@ add_executable(
LeastGreatestTest.cpp
MakeDecimalTest.cpp
MakeTimestampTest.cpp
MapConcatTest.cpp
MapTest.cpp
MaskTest.cpp
MightContainTest.cpp
125 changes: 125 additions & 0 deletions velox/functions/sparksql/tests/MapConcatTest.cpp
@@ -0,0 +1,125 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "velox/common/base/tests/GTestUtils.h"
#include "velox/functions/sparksql/tests/SparkFunctionBaseTest.h"

namespace facebook::velox::functions::sparksql::test {
namespace {

class MapConcatTest : public SparkFunctionBaseTest {
protected:
template <typename TKey, typename TValue>
static std::vector<TKey> mapKeys(const std::map<TKey, TValue>& m) {
std::vector<TKey> keys;
keys.reserve(m.size());
for (const auto& [key, value] : m) {
keys.push_back(key);
}
return keys;
}

template <typename TKey, typename TValue>
static std::vector<TValue> mapValues(const std::map<TKey, TValue>& m) {
std::vector<TValue> values;
values.reserve(m.size());
for (const auto& [key, value] : m) {
values.push_back(value);
}
return values;
}

MapVectorPtr makeMapVector(
vector_size_t size,
const std::map<std::string, int32_t>& m,
std::function<bool(vector_size_t /*row*/)> isNullAt = nullptr) {
std::vector<std::string> keys = mapKeys(m);
std::vector<int32_t> values = mapValues(m);
return vectorMaker_.mapVector<StringView, int32_t>(
size,
[&](vector_size_t /*row*/) { return keys.size(); },
[&](vector_size_t /*mapRow*/, vector_size_t row) {
return StringView(keys[row]);
},
[&](vector_size_t mapRow, vector_size_t row) {
return mapRow % 11 + values[row];
},
isNullAt);
}

template <typename TKey, typename TValue>
std::map<TKey, TValue> concat(
const std::map<TKey, TValue>& a,
const std::map<TKey, TValue>& b) {
std::map<TKey, TValue> result;
result.insert(b.begin(), b.end());
result.insert(a.begin(), a.end());
return result;
}
};
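The fixture's `concat` helper relies on `std::map::insert` skipping keys that already exist: inserting `b`'s entries first means they win all ties, which models LAST_WIN for `map_concat(a, b)`. A self-contained version of the same trick (the function name is made up for illustration):

```cpp
#include <cassert>
#include <map>
#include <string>

// Because std::map::insert never overwrites an existing key, inserting
// the *last* map first gives its values priority on duplicate keys.
std::map<std::string, int> lastWinConcat(
    const std::map<std::string, int>& first,
    const std::map<std::string, int>& last) {
  std::map<std::string, int> result;
  result.insert(last.begin(), last.end()); // wins all key collisions
  result.insert(first.begin(), first.end()); // fills only missing keys
  return result;
}
```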

TEST_F(MapConcatTest, duplicateKeys) {
vector_size_t size = 1'000;

std::map<std::string, int32_t> a = {
{"a1", 1}, {"a2", 2}, {"a3", 3}, {"a4", 4}};
std::map<std::string, int32_t> b = {
{"b1", 1}, {"b2", 2}, {"b3", 3}, {"b4", 4}, {"a2", -1}};
auto aMap = makeMapVector(size, a);
auto bMap = makeMapVector(size, b);

// By default, if a key is found in multiple given maps, that key's value in
// the resulting map comes from the last one of those maps.
std::map<std::string, int32_t> ab = concat(a, b);
auto expectedMap = makeMapVector(size, ab);
auto result =
evaluate<MapVector>("map_concat(c0, c1)", makeRowVector({aMap, bMap}));
velox::test::assertEqualVectors(expectedMap, result);

std::map<std::string, int32_t> ba = concat(b, a);
expectedMap = makeMapVector(size, ba);
result =
evaluate<MapVector>("map_concat(c1, c0)", makeRowVector({aMap, bMap}));
velox::test::assertEqualVectors(expectedMap, result);

result =
evaluate<MapVector>("map_concat(c0, c1)", makeRowVector({aMap, aMap}));
velox::test::assertEqualVectors(aMap, result);

// Throws exception when duplicate keys are found.
queryCtx_->testingOverrideConfigUnsafe({
{core::QueryConfig::kSparkThrowExceptionOnDuplicateMapKeys, "true"},
});
VELOX_ASSERT_USER_THROW(
evaluate<MapVector>("map_concat(c0, c1)", makeRowVector({aMap, bMap})),
"Duplicate map key a2 was found");
}

TEST_F(MapConcatTest, singleArg) {
vector_size_t size = 1'000;

std::map<std::string, int32_t> a = {{"a1", 1}, {"a2", 2}, {"a3", 3}};
auto aMap = makeMapVector(size, a, nullEvery(5));

auto expectedMap =
makeMapVector(size, a, [](vector_size_t row) { return row % 5 == 0; });

auto result = evaluate<MapVector>("map_concat(c0)", makeRowVector({aMap}));
velox::test::assertEqualVectors(expectedMap, result);
}

} // namespace
} // namespace facebook::velox::functions::sparksql::test