diff --git a/velox/dwio/common/SelectiveIntegerColumnReader.h b/velox/dwio/common/SelectiveIntegerColumnReader.h index 6aea8317907c..cd5c45c79eff 100644 --- a/velox/dwio/common/SelectiveIntegerColumnReader.h +++ b/velox/dwio/common/SelectiveIntegerColumnReader.h @@ -179,7 +179,13 @@ void SelectiveIntegerColumnReader::processValueHook( readHelper( &alwaysTrue(), rows, - ExtractToHook>(hook)); + ExtractToHook>(hook)); + break; + case aggregate::AggregationHook::kSumBigintToBigintOverflow: + readHelper( + &alwaysTrue(), + rows, + ExtractToHook>(hook)); break; case aggregate::AggregationHook::kBigintMax: readHelper( diff --git a/velox/exec/AggregationHook.h b/velox/exec/AggregationHook.h index 1b05e3164e1b..e0e87b7e1584 100644 --- a/velox/exec/AggregationHook.h +++ b/velox/exec/AggregationHook.h @@ -35,6 +35,7 @@ class AggregationHook : public ValueHook { static constexpr Kind kFloatMin = 8; static constexpr Kind kDoubleMax = 9; static constexpr Kind kDoubleMin = 10; + static constexpr Kind kSumBigintToBigintOverflow = 11; // Make null behavior known at compile time. This is useful when // templating a column decoding loop with a hook. @@ -97,9 +98,10 @@ class AggregationHook : public ValueHook { }; namespace { -template +template inline void updateSingleValue(TValue& result, TValue value) { if constexpr ( + (std::is_same_v && Overflow) || std::is_same_v || std::is_same_v) { result += value; } else { @@ -108,7 +110,7 @@ inline void updateSingleValue(TValue& result, TValue value) { } } // namespace -template +template class SumHook final : public AggregationHook { public: SumHook( @@ -132,6 +134,9 @@ class SumHook final : public AggregationHook { return kSumIntegerToBigint; } if (std::is_same_v) { + if (Overflow) { + return kSumBigintToBigintOverflow; + } return kSumBigintToBigint; } } @@ -141,7 +146,7 @@ class SumHook final : public AggregationHook { void addValue(vector_size_t row, const void* value) override { auto group = findGroup(row); clearNull(group); - updateSingleValue( + updateSingleValue( *reinterpret_cast(group + offset_), TAggregate(*reinterpret_cast(value))); } diff --git a/velox/functions/lib/aggregates/SumAggregateBase.h b/velox/functions/lib/aggregates/SumAggregateBase.h index 4a5f5a858833..a159ae7f1345 100644 --- a/velox/functions/lib/aggregates/SumAggregateBase.h +++ b/velox/functions/lib/aggregates/SumAggregateBase.h @@ -133,7 +133,7 @@ class SumAggregateBase if (mayPushdown && arg->isLazy()) { BaseAggregate::template pushdown< - facebook::velox::aggregate::SumHook>( + facebook::velox::aggregate::SumHook>( groups, rows, arg); return; } diff --git a/velox/functions/lib/aggregates/tests/SumTestBase.h b/velox/functions/lib/aggregates/tests/SumTestBase.h index 85d044609330..61f467c146ec 100644 --- a/velox/functions/lib/aggregates/tests/SumTestBase.h +++ b/velox/functions/lib/aggregates/tests/SumTestBase.h @@ -22,6 +22,47 @@ namespace facebook::velox::functions::aggregate::test { +template +struct SumRow { + char nulls; + Type sum; +}; + +template +void testHookLimits(bool expectOverflow = false) { + // Pair of . + std::vector> limits = { + {std::numeric_limits::min(), -1}, + {std::numeric_limits::max(), 1}}; + + for (const auto& [limit, overflow] : limits) { + SumRow sumRow; + sumRow.sum = 0; + ResultType expected = 0; + char* row = reinterpret_cast(&sumRow); + uint64_t numNulls = 0; + facebook::velox::aggregate::SumHook hook( + offsetof(SumRow, sum), + offsetof(SumRow, nulls), + 0, + &row, + &numNulls); + + // Adding limit should not overflow. + ASSERT_NO_THROW(hook.addValue(0, &limit)); + expected += limit; + EXPECT_EQ(expected, sumRow.sum); + // Adding overflow based on the ResultType should throw. + if (expectOverflow) { + VELOX_ASSERT_THROW(hook.addValue(0, &overflow), "overflow"); + } else { + ASSERT_NO_THROW(hook.addValue(0, &overflow)); + expected += overflow; + EXPECT_EQ(expected, sumRow.sum); + } + } +} + class SumTestBase : public AggregationTestBase { protected: void SetUp() override { diff --git a/velox/functions/prestosql/aggregates/tests/SumTest.cpp b/velox/functions/prestosql/aggregates/tests/SumTest.cpp index b5cc9b5aad2e..76e0e92e060d 100644 --- a/velox/functions/prestosql/aggregates/tests/SumTest.cpp +++ b/velox/functions/prestosql/aggregates/tests/SumTest.cpp @@ -404,12 +404,6 @@ TEST_F(SumTest, nulls) { "SELECT c0, sum(c1) as sum_c1 FROM tmp GROUP BY 1"); } -template -struct SumRow { - char nulls; - Type sum; -}; - TEST_F(SumTest, hook) { SumRow sumRow; sumRow.nulls = 1; @@ -431,41 +425,6 @@ TEST_F(SumTest, hook) { EXPECT_EQ(value, sumRow.sum); } -template -void testHookLimits(bool expectOverflow = false) { - // Pair of . - std::vector> limits = { - {std::numeric_limits::min(), -1}, - {std::numeric_limits::max(), 1}}; - - for (const auto& [limit, overflow] : limits) { - SumRow sumRow; - sumRow.sum = 0; - ResultType expected = 0; - char* row = reinterpret_cast(&sumRow); - uint64_t numNulls = 0; - aggregate::SumHook hook( - offsetof(SumRow, sum), - offsetof(SumRow, nulls), - 0, - &row, - &numNulls); - - // Adding limit should not overflow. - ASSERT_NO_THROW(hook.addValue(0, &limit)); - expected += limit; - EXPECT_EQ(expected, sumRow.sum); - // Adding overflow based on the ResultType should throw. - if (expectOverflow) { - VELOX_ASSERT_THROW(hook.addValue(0, &overflow), "overflow"); - } else { - ASSERT_NO_THROW(hook.addValue(0, &overflow)); - expected += overflow; - EXPECT_EQ(expected, sumRow.sum); - } - } -} - TEST_F(SumTest, hookLimits) { testHookLimits(); testHookLimits(true); diff --git a/velox/functions/sparksql/aggregates/tests/SumAggregationTest.cpp b/velox/functions/sparksql/aggregates/tests/SumAggregationTest.cpp index 9cf3b67b5ce0..10a088c2db20 100644 --- a/velox/functions/sparksql/aggregates/tests/SumAggregationTest.cpp +++ b/velox/functions/sparksql/aggregates/tests/SumAggregationTest.cpp @@ -34,5 +34,9 @@ TEST_F(SumAggregationTest, overflow) { SumTestBase::testAggregateOverflow("spark_sum"); } +TEST_F(SumAggregationTest, hookLimits) { + testHookLimits(); +} + } // namespace } // namespace facebook::velox::functions::aggregate::sparksql::test