Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add max(varchar, bigint) support #12396

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 99 additions & 20 deletions velox/functions/prestosql/aggregates/MinMaxAggregates.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
#include "velox/exec/AggregationHook.h"
#include "velox/functions/lib/aggregates/MinMaxAggregateBase.h"
#include "velox/functions/lib/aggregates/SimpleNumericAggregate.h"
#include "velox/functions/lib/aggregates/SingleValueAccumulator.h"
#include "velox/functions/lib/aggregates/ValueSet.h"
#include "velox/functions/prestosql/aggregates/AggregateNames.h"
#include "velox/type/FloatingPointUtil.h"

Expand Down Expand Up @@ -97,9 +97,11 @@ struct MinMaxNAccumulator {
}
}

/// Copy all values from 'topValues' into 'rawValues' buffer. The heap remains
/// Copy all values from 'topValues' into 'values'. The heap remains
/// unchanged after the call.
void extractValues(T* rawValues, vector_size_t offset, Compare& comparator) {
void
extractValues(VectorPtr& values, vector_size_t offset, Compare& comparator) {
auto rawValues = values->asFlatVector<T>()->mutableRawValues();
std::sort_heap(heapValues.begin(), heapValues.end(), comparator);
for (int64_t i = heapValues.size() - 1; i >= 0; --i) {
rawValues[offset + i] = heapValues[i];
Expand All @@ -108,6 +110,78 @@ struct MinMaxNAccumulator {
}
};

/// @tparam Compare Type of comparator for T.
template <typename Compare>
struct MinMaxNAccumulator<StringView, Compare> {
int64_t n{0};
using Allocator = StlAllocator<StringView>;
std::vector<StringView, Allocator> heapValues;
ValueSet valueSet;

explicit MinMaxNAccumulator(HashStringAllocator* allocator)
: heapValues{Allocator(allocator)}, valueSet{allocator} {}

int64_t getN() const {
return n;
}

size_t size() const {
return heapValues.size();
}

void checkAndSetN(DecodedVector& decodedN, vector_size_t row) {
// Skip null N.
if (decodedN.isNullAt(row)) {
return;
}

const auto newN = decodedN.valueAt<int64_t>(row);
VELOX_USER_CHECK_GT(
newN, 0, "second argument of max/min must be a positive integer");

VELOX_USER_CHECK_LE(
newN,
10'000,
"second argument of max/min must be less than or equal to 10000");

if (n) {
VELOX_USER_CHECK_EQ(
newN,
n,
"second argument of max/min must be a constant for all rows in a group");
} else {
n = newN;
}
}

void compareAndAdd(StringView value, Compare& comparator) {
if (heapValues.size() < n) {
heapValues.push_back(valueSet.write(value));
std::push_heap(heapValues.begin(), heapValues.end(), comparator);
} else {
const auto& topValue = heapValues.front();
if (comparator(value, topValue)) {
std::pop_heap(heapValues.begin(), heapValues.end(), comparator);
valueSet.free(heapValues.back());
heapValues.back() = valueSet.write(value);
std::push_heap(heapValues.begin(), heapValues.end(), comparator);
}
}
}

/// Copy all values from 'topValues' into 'values'. The heap remains
/// unchanged after the call.
void
extractValues(VectorPtr& values, vector_size_t offset, Compare& comparator) {
auto result = values->asFlatVector<StringView>();
std::sort_heap(heapValues.begin(), heapValues.end(), comparator);
for (int64_t i = heapValues.size() - 1; i >= 0; --i) {
result->set(offset + i, heapValues[i]);
}
std::make_heap(heapValues.begin(), heapValues.end(), comparator);
}
};

template <typename T, typename Compare>
class MinMaxNAggregateBase : public exec::Aggregate {
protected:
Expand Down Expand Up @@ -215,11 +289,9 @@ class MinMaxNAggregateBase : public exec::Aggregate {
auto values = valuesArray->elements();
values->resize(numValues);

auto* rawValues = values->asFlatVector<T>()->mutableRawValues();

auto [rawOffsets, rawSizes] = rawOffsetAndSizes(*valuesArray);

extractValues(groups, numGroups, rawOffsets, rawSizes, rawValues, nullptr);
extractValues(groups, numGroups, rawOffsets, rawSizes, values, nullptr);
}

void extractAccumulators(char** groups, int32_t numGroups, VectorPtr* result)
Expand All @@ -240,11 +312,10 @@ class MinMaxNAggregateBase : public exec::Aggregate {
values->resize(numValues);

auto* rawNs = nVector->as<FlatVector<int64_t>>()->mutableRawValues();
auto* rawValues = values->asFlatVector<T>()->mutableRawValues();

auto [rawOffsets, rawSizes] = rawOffsetAndSizes(*valuesArray);

extractValues(groups, numGroups, rawOffsets, rawSizes, rawValues, rawNs);
extractValues(groups, numGroups, rawOffsets, rawSizes, values, rawNs);
}

void destroyInternal(folly::Range<char**> groups) override {
Expand All @@ -261,7 +332,7 @@ class MinMaxNAggregateBase : public exec::Aggregate {
int32_t numGroups,
vector_size_t* rawOffsets,
vector_size_t* rawSizes,
T* rawValues,
VectorPtr& values,
int64_t* rawNs) {
vector_size_t offset = 0;
for (auto i = 0; i < numGroups; ++i) {
Expand All @@ -277,7 +348,7 @@ class MinMaxNAggregateBase : public exec::Aggregate {
if (rawNs != nullptr) {
rawNs[i] = accumulator->n;
}
accumulator->extractValues(rawValues, offset, comparator_);
accumulator->extractValues(values, offset, comparator_);

offset += size;
}
Expand Down Expand Up @@ -418,7 +489,7 @@ class MaxNAggregate : public MinMaxNAggregateBase<T, GreaterThanComparator<T>> {
: MinMaxNAggregateBase<T, GreaterThanComparator<T>>(resultType) {}
};

template <template <typename T> class TNumericN>
template <template <typename T> typename AggregateN>
exec::AggregateRegistrationResult registerMinMax(
const std::string& name,
bool withCompanionFunctions,
Expand All @@ -432,7 +503,13 @@ exec::AggregateRegistrationResult registerMinMax(
.argumentType("T")
.build());
for (const auto& type :
{"tinyint", "integer", "smallint", "bigint", "real", "double"}) {
{"tinyint",
"integer",
"smallint",
"bigint",
"real",
"double",
"varchar"}) {
// T, bigint -> row(array(T), bigint) -> array(T)
signatures.push_back(
exec::AggregateFunctionSignatureBuilder()
Expand Down Expand Up @@ -474,22 +551,24 @@ exec::AggregateRegistrationResult registerMinMax(

switch (inputType->kind()) {
case TypeKind::TINYINT:
return std::make_unique<TNumericN<int8_t>>(resultType);
return std::make_unique<AggregateN<int8_t>>(resultType);
case TypeKind::SMALLINT:
return std::make_unique<TNumericN<int16_t>>(resultType);
return std::make_unique<AggregateN<int16_t>>(resultType);
case TypeKind::INTEGER:
return std::make_unique<TNumericN<int32_t>>(resultType);
return std::make_unique<AggregateN<int32_t>>(resultType);
case TypeKind::BIGINT:
return std::make_unique<TNumericN<int64_t>>(resultType);
return std::make_unique<AggregateN<int64_t>>(resultType);
case TypeKind::REAL:
return std::make_unique<TNumericN<float>>(resultType);
return std::make_unique<AggregateN<float>>(resultType);
case TypeKind::DOUBLE:
return std::make_unique<TNumericN<double>>(resultType);
return std::make_unique<AggregateN<double>>(resultType);
case TypeKind::TIMESTAMP:
return std::make_unique<TNumericN<Timestamp>>(resultType);
return std::make_unique<AggregateN<Timestamp>>(resultType);
case TypeKind::VARCHAR:
return std::make_unique<AggregateN<StringView>>(resultType);
case TypeKind::HUGEINT:
if (inputType->isLongDecimal()) {
return std::make_unique<TNumericN<int128_t>>(resultType);
return std::make_unique<AggregateN<int128_t>>(resultType);
}
[[fallthrough]];
default:
Expand Down
108 changes: 108 additions & 0 deletions velox/functions/prestosql/aggregates/tests/MinMaxTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1284,6 +1284,114 @@ TEST_F(MinMaxNTest, longdecimal) {
testNumericGroupByDecimal<int128_t>();
}

TEST_F(MinMaxNTest, string) {
auto data = makeRowVector(
{makeFlatVector<std::string>({"1", "2", "3", "4", "abc", "xyz"})});
auto expected = makeRowVector({
makeArrayVector<std::string>({
{"1", "2"},
}),
makeArrayVector<std::string>({
{"1", "2", "3", "4", "abc"},
}),
makeArrayVector<std::string>({
{"xyz", "abc", "4"},
}),
makeArrayVector<std::string>({
{"xyz", "abc", "4", "3", "2", "1"},
}),
});

testAggregations(
{data},
{},
{"min(c0, 2)", "min(c0, 5)", "max(c0, 3)", "max(c0, 7)"},
{expected});

// Add some nulls. Expect these to be ignored.
data = makeRowVector({makeNullableFlatVector<std::string>(
{"1",
std::nullopt,
"2",
"3",
"4",
"abc",
std::nullopt,
"xyz",
std::nullopt})});

testAggregations(
{data},
{},
{"min(c0, 2)", "min(c0, 5)", "max(c0, 3)", "max(c0, 7)"},
{expected});

// Test all null input.
data = makeRowVector({makeNullableFlatVector<std::string>(
{std::nullopt,
std::nullopt,
std::nullopt,
std::nullopt,
std::nullopt,
std::nullopt,
std::nullopt,
std::nullopt,
std::nullopt})});

expected = makeRowVector({
makeAllNullArrayVector(1, data->childAt(0)->type()),
makeAllNullArrayVector(1, data->childAt(0)->type()),
makeAllNullArrayVector(1, data->childAt(0)->type()),
makeAllNullArrayVector(1, data->childAt(0)->type()),
});

testAggregations(
{data},
{},
{"min(c0, 2)", "min(c0, 5)", "max(c0, 3)", "max(c0, 7)"},
{expected});

// Test long string
data = makeRowVector({makeFlatVector<std::string>(
{"hello long string",
"hello long string2",
"hello long string3",
"hello long string a",
"this is a very long string",
"min max test",
"max min test"})});
expected = makeRowVector({
makeArrayVector<std::string>({
{"hello long string", "hello long string a"},
}),
makeArrayVector<std::string>({
{"hello long string",
"hello long string a",
"hello long string2",
"hello long string3",
"max min test"},
}),
makeArrayVector<std::string>({
{"this is a very long string", "min max test", "max min test"},
}),
makeArrayVector<std::string>({
{"this is a very long string",
"min max test",
"max min test",
"hello long string3",
"hello long string2",
"hello long string a",
"hello long string"},
}),
});

testAggregations(
{data},
{},
{"min(c0, 2)", "min(c0, 5)", "max(c0, 3)", "max(c0, 7)"},
{expected});
}

TEST_F(MinMaxNTest, incrementalWindow) {
// SELECT
// c0, c1, c2, c3,
Expand Down
Loading