diff --git a/src/cluster/ArborX_MinimumSpanningTree.hpp b/src/cluster/ArborX_MinimumSpanningTree.hpp index e25be7b74..ce07b4072 100644 --- a/src/cluster/ArborX_MinimumSpanningTree.hpp +++ b/src/cluster/ArborX_MinimumSpanningTree.hpp @@ -37,6 +37,8 @@ struct MinimumSpanningTree Kokkos::View edges; Kokkos::View dendrogram_parents; Kokkos::View dendrogram_parent_heights; + Kokkos::View _chain_offsets; + Kokkos::View _chain_levels; template MinimumSpanningTree(ExecutionSpace const &space, Primitives const &primitives, @@ -46,6 +48,8 @@ struct MinimumSpanningTree AccessTraits::size(primitives) - 1) , dendrogram_parents("ArborX::MST::dendrogram_parents", 0) , dendrogram_parent_heights("ArborX::MST::dendrogram_parent_heights", 0) + , _chain_offsets("ArborX::MST::chain_offsets", 0) + , _chain_levels("ArborX::MST::chain_levels", 0) { Kokkos::Profiling::pushRegion("ArborX::MST::MST"); @@ -182,6 +186,8 @@ struct MinimumSpanningTree int num_components = n; [[maybe_unused]] int edges_start = 0; [[maybe_unused]] int edges_end = 0; + std::vector edge_offsets; + edge_offsets.push_back(0); do { Kokkos::Profiling::pushRegion("ArborX::Boruvka_" + @@ -224,6 +230,8 @@ struct MinimumSpanningTree Kokkos::deep_copy(space, num_edges_host, num_edges); space.fence(); + edge_offsets.push_back(num_edges_host); + if constexpr (Mode == BoruvkaMode::HDBSCAN) { Kokkos::parallel_for( @@ -277,7 +285,19 @@ struct MinimumSpanningTree std::make_pair(edges_start, edges_end)), ROOT_CHAIN_VALUE); - computeParents(space, edges, sided_parents, dendrogram_parents); + Kokkos::View edge_hierarchy_offsets( + Kokkos::view_alloc(space, Kokkos::WithoutInitializing, + "ArborX::MST::edge_hierarchy_offsets"), + edge_offsets.size()); + Kokkos::deep_copy( + space, edge_hierarchy_offsets, + Kokkos::View{ + edge_offsets.data(), edge_offsets.size()}); + + computeParentsAndReorderEdges(space, edges, edge_hierarchy_offsets, + sided_parents, dendrogram_parents, + _chain_offsets, _chain_levels); + Kokkos::resize(sided_parents, 0); KokkosExt::reallocWithoutInitializing(space, dendrogram_parent_heights, n - 1); diff --git a/src/cluster/detail/ArborX_BoruvkaHelpers.hpp b/src/cluster/detail/ArborX_BoruvkaHelpers.hpp index ff31da940..3f66fd9f2 100644 --- a/src/cluster/detail/ArborX_BoruvkaHelpers.hpp +++ b/src/cluster/detail/ArborX_BoruvkaHelpers.hpp @@ -16,7 +16,8 @@ #include #include #include -#include +#include +#include #include #include @@ -547,10 +548,14 @@ void assignVertexParents(ExecutionSpace const &space, Labels const &labels, }); } -template -void computeParents(ExecutionSpace const &space, Edges const &edges, - SidedParents const &sided_parents, Parents &parents) +template +void computeParentsAndReorderEdges( + ExecutionSpace const &space, Edges &edges, + EdgeHierarchyOffsets const &edge_hierarchy_offsets, + SidedParents const &sided_parents, Parents &parents, + ChainOffsets &chain_offsets, ChainLevels &chain_levels) { Kokkos::Profiling::ScopedRegion guard("ArborX::MST::compute_edge_parents"); @@ -654,28 +659,77 @@ void computeParents(ExecutionSpace const &space, Edges const &edges, } }); + auto rev_permute = + KokkosExt::cloneWithoutInitializingNorCopying(space, permute); Kokkos::parallel_for( + "ArborX::MST::compute_rev_permute", + Kokkos::RangePolicy(space, 0, num_edges), + KOKKOS_LAMBDA(int const i) { rev_permute(permute(i)) = i; }); + + Kokkos::parallel_for( + "ArborX::MST::update_vertex_parents", + Kokkos::RangePolicy(space, num_edges, parents.size()), + KOKKOS_LAMBDA(int const i) { parents(i) = rev_permute(parents(i)); }); + + KokkosExt::reallocWithoutInitializing(space, chain_offsets, num_edges + 1); + int num_chains; + Kokkos::parallel_scan( "ArborX::MST::compute_parents", Kokkos::RangePolicy(space, 0, num_edges), - KOKKOS_LAMBDA(int const i) { - int e = permute(i); - if (i == num_edges - 1) - { - // The parent of the root node is set to -1 - parents(e) = -1; - } - else if ((keys(i) >> shift) == (keys(i + 1) >> shift)) + KOKKOS_LAMBDA(int const i, int &update, bool final_pass) { + if (i == num_edges - 1 || (keys(i) >> shift) != (keys(i + 1) >> shift)) + ++update; + + if (final_pass) { - // For the edges belonging to the same chain, assign the parent of an - // edge to the edge with the next larger value - parents(e) = permute(i + 1); + if (i == num_edges - 1) + { + // The parent of the root node is set to -1 + parents(i) = -1; + chain_offsets(update) = num_edges; + chain_offsets(0) = 0; + } + else if ((keys(i) >> shift) == (keys(i + 1) >> shift)) + { + // For the edges belonging to the same chain, assign the parent of + // an edge to the edge with the next larger value + parents(i) = i + 1; + } + else + { + // For an edge which points to the root of a chain, assign edge's + // parent to be that root + parents(i) = rev_permute((keys(i) >> shift) / 2); + chain_offsets(update) = i + 1; + } } - else + }, + num_chains); + Kokkos::resize(Kokkos::WithoutInitializing, chain_offsets, num_chains + 1); + + Details::applyPermutation(space, permute, edges); + + Kokkos::resize(Kokkos::WithoutInitializing, chain_levels, num_chains + 1); + int num_levels = 0; + Kokkos::parallel_scan( + "ArborX::MST::compute_chain_levels", + Kokkos::RangePolicy(space, 0, num_chains), + KOKKOS_LAMBDA(int i, int &level, bool final_pass) { + auto upper_bound = [&v = edge_hierarchy_offsets](int x) { + return KokkosExt::upper_bound(v.data(), v.data() + v.size(), x); + }; + + if (i == num_chains - 1 || + upper_bound(permute(chain_offsets(i))) != + upper_bound(permute(chain_offsets(i + 1)))) { - // For an edge which points to the root of a chain, assign edge's - // parent to be that root - parents(e) = (keys(i) >> shift) / 2; + ++level; + if (final_pass) + chain_levels(level) = i + 1; } - }); + }, + num_levels); + ++num_levels; + Kokkos::resize(Kokkos::WithoutInitializing, chain_levels, num_levels); } // Compute upper bound on the shortest edge of each component.