Skip to content

Commit

Permalink
Fix pushdown overflow in sum(bigint) Spark aggregate function (#7306)
Browse files Browse the repository at this point in the history
Summary:
Fixes #7297

Issue #7211 added support for letting sum(bigint) overflow in the Spark aggregate function, but the aggregation pushdown path was not covered. This patch extends that support so that the Spark sum(bigint) aggregate function also allows overflow during aggregation pushdown.

Pull Request resolved: #7306

Reviewed By: Yuhta

Differential Revision: D50791907

Pulled By: mbasmanova

fbshipit-source-id: 0aab094dfca1dc03e00cc53a9a34c373d851fc76
marin-ma authored and facebook-github-bot committed Oct 30, 2023
1 parent a9e4203 commit 576cc9c
Showing 6 changed files with 61 additions and 46 deletions.
8 changes: 7 additions & 1 deletion velox/dwio/common/SelectiveIntegerColumnReader.h
Original file line number Diff line number Diff line change
@@ -179,7 +179,13 @@ void SelectiveIntegerColumnReader::processValueHook(
readHelper<Reader, velox::common::AlwaysTrue, isDense>(
&alwaysTrue(),
rows,
ExtractToHook<aggregate::SumHook<int64_t, int64_t>>(hook));
ExtractToHook<aggregate::SumHook<int64_t, int64_t, false>>(hook));
break;
case aggregate::AggregationHook::kSumBigintToBigintOverflow:
readHelper<Reader, velox::common::AlwaysTrue, isDense>(
&alwaysTrue(),
rows,
ExtractToHook<aggregate::SumHook<int64_t, int64_t, true>>(hook));
break;
case aggregate::AggregationHook::kBigintMax:
readHelper<Reader, velox::common::AlwaysTrue, isDense>(
11 changes: 8 additions & 3 deletions velox/exec/AggregationHook.h
Original file line number Diff line number Diff line change
@@ -35,6 +35,7 @@ class AggregationHook : public ValueHook {
static constexpr Kind kFloatMin = 8;
static constexpr Kind kDoubleMax = 9;
static constexpr Kind kDoubleMin = 10;
static constexpr Kind kSumBigintToBigintOverflow = 11;

// Make null behavior known at compile time. This is useful when
// templating a column decoding loop with a hook.
@@ -97,9 +98,10 @@ class AggregationHook : public ValueHook {
};

namespace {
template <typename TValue>
template <typename TValue, bool Overflow>
inline void updateSingleValue(TValue& result, TValue value) {
if constexpr (
(std::is_same_v<TValue, int64_t> && Overflow) ||
std::is_same_v<TValue, double> || std::is_same_v<TValue, float>) {
result += value;
} else {
@@ -108,7 +110,7 @@ inline void updateSingleValue(TValue& result, TValue value) {
}
} // namespace

template <typename TValue, typename TAggregate>
template <typename TValue, typename TAggregate, bool Overflow = false>
class SumHook final : public AggregationHook {
public:
SumHook(
@@ -132,6 +134,9 @@ class SumHook final : public AggregationHook {
return kSumIntegerToBigint;
}
if (std::is_same_v<TValue, int64_t>) {
if (Overflow) {
return kSumBigintToBigintOverflow;
}
return kSumBigintToBigint;
}
}
@@ -141,7 +146,7 @@ class SumHook final : public AggregationHook {
void addValue(vector_size_t row, const void* value) override {
auto group = findGroup(row);
clearNull(group);
updateSingleValue(
updateSingleValue<TAggregate, Overflow>(
*reinterpret_cast<TAggregate*>(group + offset_),
TAggregate(*reinterpret_cast<const TValue*>(value)));
}
2 changes: 1 addition & 1 deletion velox/functions/lib/aggregates/SumAggregateBase.h
Original file line number Diff line number Diff line change
@@ -133,7 +133,7 @@ class SumAggregateBase

if (mayPushdown && arg->isLazy()) {
BaseAggregate::template pushdown<
facebook::velox::aggregate::SumHook<TValue, TData>>(
facebook::velox::aggregate::SumHook<TValue, TData, Overflow>>(
groups, rows, arg);
return;
}
41 changes: 41 additions & 0 deletions velox/functions/lib/aggregates/tests/SumTestBase.h
Original file line number Diff line number Diff line change
@@ -22,6 +22,47 @@

namespace facebook::velox::functions::aggregate::test {

/// Minimal stand-in for an aggregation row, used to drive a SumHook directly
/// in tests. The hook is handed the byte offsets of 'nulls' and 'sum' (via
/// offsetof) and a pointer to the row, so the type must stay standard-layout.
///
/// Both members are value-initialized by default so that a default-constructed
/// row never contains indeterminate bytes; callers may still overwrite either
/// field or use aggregate initialization.
template <typename Type>
struct SumRow {
  // Byte reserved for the row's null flag(s).
  char nulls{0};
  // Running sum accumulator.
  Type sum{};
};

/// Exercises SumHook::addValue at the numeric limits of InputType.
///
/// For each extreme of InputType (min(), then max()) a fresh SumRow is set up,
/// the extreme value is added (which must not throw), and then a second value
/// that steps past the extreme (-1 or +1 respectively) is added.
///
/// @tparam InputType Type of the values fed to the hook.
/// @tparam ResultType Type of the accumulator stored in SumRow.
/// @tparam Overflow Forwarded as SumHook's third template argument.
/// @param expectOverflow If true, the second addition is checked with
/// VELOX_ASSERT_THROW(..., "overflow"). Otherwise it must not throw and the
/// accumulator must equal the sum of both values; for integer types the
/// expected value is the sum wrapped modulo 2^64.
template <typename InputType, typename ResultType, bool Overflow = false>
void testHookLimits(bool expectOverflow = false) {
  // Pair of <limit, value to overflow>.
  std::vector<std::pair<InputType, InputType>> limits = {
      {std::numeric_limits<InputType>::min(), -1},
      {std::numeric_limits<InputType>::max(), 1}};

  for (const auto& [limit, overflow] : limits) {
    SumRow<ResultType> sumRow;
    // Initialize every field so the row never holds indeterminate bytes.
    sumRow.nulls = 0;
    sumRow.sum = 0;
    ResultType expected = 0;
    char* row = reinterpret_cast<char*>(&sumRow);
    uint64_t numNulls = 0;
    facebook::velox::aggregate::SumHook<InputType, ResultType, Overflow> hook(
        offsetof(SumRow<ResultType>, sum),
        offsetof(SumRow<ResultType>, nulls),
        0,
        &row,
        &numNulls);

    // Adding limit should not overflow.
    ASSERT_NO_THROW(hook.addValue(0, &limit));
    expected += limit;
    EXPECT_EQ(expected, sumRow.sum);

    if (expectOverflow) {
      // Stepping past the limit must be reported as an overflow error.
      VELOX_ASSERT_THROW(hook.addValue(0, &overflow), "overflow");
    } else {
      // Stepping past the limit must be accepted without throwing.
      ASSERT_NO_THROW(hook.addValue(0, &overflow));
      if constexpr (
          std::numeric_limits<InputType>::is_integer &&
          std::numeric_limits<ResultType>::is_integer) {
        // 'expected += overflow' would be signed overflow (undefined
        // behavior) when the result wraps, e.g. int64 max + 1. Do the
        // addition in uint64_t, where wrap-around is well defined, then
        // convert back (modular on mainstream compilers, guaranteed since
        // C++20).
        expected = static_cast<ResultType>(
            static_cast<uint64_t>(expected) + static_cast<uint64_t>(overflow));
      } else {
        expected += overflow;
      }
      EXPECT_EQ(expected, sumRow.sum);
    }
  }
}

class SumTestBase : public AggregationTestBase {
protected:
void SetUp() override {
41 changes: 0 additions & 41 deletions velox/functions/prestosql/aggregates/tests/SumTest.cpp
Original file line number Diff line number Diff line change
@@ -404,12 +404,6 @@ TEST_F(SumTest, nulls) {
"SELECT c0, sum(c1) as sum_c1 FROM tmp GROUP BY 1");
}

/// Minimal stand-in for an aggregation row, used to drive a SumHook directly
/// in tests. The hook is handed the byte offsets of 'nulls' and 'sum' (via
/// offsetof) and a pointer to the row, so the type must stay standard-layout.
///
/// Both members are value-initialized by default so that a default-constructed
/// row never contains indeterminate bytes; callers may still overwrite either
/// field or use aggregate initialization.
template <typename Type>
struct SumRow {
  // Byte reserved for the row's null flag(s).
  char nulls{0};
  // Running sum accumulator.
  Type sum{};
};

TEST_F(SumTest, hook) {
SumRow<int64_t> sumRow;
sumRow.nulls = 1;
@@ -431,41 +425,6 @@ TEST_F(SumTest, hook) {
EXPECT_EQ(value, sumRow.sum);
}

template <typename InputType, typename ResultType>
void testHookLimits(bool expectOverflow = false) {
// Pair of <limit, value to overflow>.
std::vector<std::pair<InputType, InputType>> limits = {
{std::numeric_limits<InputType>::min(), -1},
{std::numeric_limits<InputType>::max(), 1}};

for (const auto& [limit, overflow] : limits) {
SumRow<ResultType> sumRow;
sumRow.sum = 0;
ResultType expected = 0;
char* row = reinterpret_cast<char*>(&sumRow);
uint64_t numNulls = 0;
aggregate::SumHook<InputType, ResultType> hook(
offsetof(SumRow<ResultType>, sum),
offsetof(SumRow<ResultType>, nulls),
0,
&row,
&numNulls);

// Adding limit should not overflow.
ASSERT_NO_THROW(hook.addValue(0, &limit));
expected += limit;
EXPECT_EQ(expected, sumRow.sum);
// Adding overflow based on the ResultType should throw.
if (expectOverflow) {
VELOX_ASSERT_THROW(hook.addValue(0, &overflow), "overflow");
} else {
ASSERT_NO_THROW(hook.addValue(0, &overflow));
expected += overflow;
EXPECT_EQ(expected, sumRow.sum);
}
}
}

TEST_F(SumTest, hookLimits) {
testHookLimits<int32_t, int64_t>();
testHookLimits<int64_t, int64_t>(true);
Original file line number Diff line number Diff line change
@@ -34,5 +34,9 @@ TEST_F(SumAggregationTest, overflow) {
SumTestBase::testAggregateOverflow<int64_t, int64_t, int64_t>("spark_sum");
}

// Checks the sum hook at the int64 limits with the Overflow template flag
// enabled. testHookLimits is called with its default expectOverflow = false,
// so stepping past int64 min/max must not throw.
TEST_F(SumAggregationTest, hookLimits) {
testHookLimits<int64_t, int64_t, true>();
}

} // namespace
} // namespace facebook::velox::functions::aggregate::sparksql::test

0 comments on commit 576cc9c

Please sign in to comment.