Skip to content

Commit

Permalink
Merge pull request #1194 from aprokop/fix_hardcoded_floats_distributed
Browse files Browse the repository at this point in the history
Remove hardcoded float from DistributedTree nearest query
  • Loading branch information
aprokop authored Dec 5, 2024
2 parents c4653c5 + 7c9f902 commit 37adf1a
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 10 deletions.
22 changes: 16 additions & 6 deletions src/distributed/detail/ArborX_DistributedTreeNearest.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include <detail/ArborX_DistributedTreeImpl.hpp>
#include <detail/ArborX_DistributedTreeNearestHelpers.hpp>
#include <detail/ArborX_DistributedTreeUtils.hpp>
#include <detail/ArborX_HappyTreeFriends.hpp>
#include <detail/ArborX_Predicates.hpp>
#include <kokkos_ext/ArborX_KokkosExtKernelStdAlgorithms.hpp>
#include <kokkos_ext/ArborX_KokkosExtStdAlgorithms.hpp>
Expand All @@ -28,11 +29,11 @@
namespace ArborX::Details
{

template <typename Value>
template <typename Value, typename Coordinate>
struct PairValueDistance
{
Value value;
float distance;
Coordinate distance;
};

template <typename ExecutionSpace, typename Tree, typename Predicates,
Expand Down Expand Up @@ -98,12 +99,13 @@ void DistributedTreeImpl::phaseI(ExecutionSpace const &space, Tree const &tree,

auto const &bottom_tree = tree._bottom_tree;
using BottomTree = std::decay_t<decltype(bottom_tree)>;
using Coordinate = typename Distances::value_type;

// Gather distances from every identified rank
Kokkos::View<float *, MemorySpace> distances(prefix + "::distances", 0);
Kokkos::View<Coordinate *, MemorySpace> distances(prefix + "::distances", 0);
forwardQueriesAndCommunicateResults(
comm, space, bottom_tree, predicates,
CallbackWithDistance<BottomTree, DefaultCallback, float, false>(
CallbackWithDistance<BottomTree, DefaultCallback, Coordinate, false>(
space, bottom_tree, DefaultCallback{}),
nearest_ranks, offset, distances);

Expand Down Expand Up @@ -145,10 +147,12 @@ void DistributedTreeImpl::phaseII(ExecutionSpace const &space, Tree const &tree,

auto const &bottom_tree = tree._bottom_tree;
using BottomTree = std::decay_t<decltype(bottom_tree)>;
using Coordinate = typename Distances::value_type;

// NOTE: in principle, we could perform radius searches on the bottom_tree
// rather than nearest predicates.
Kokkos::View<PairValueDistance<typename Values::value_type> *, MemorySpace>
Kokkos::View<PairValueDistance<typename Values::value_type, Coordinate> *,
MemorySpace>
out(prefix + "::pairs_value_distance", 0);
DistributedTree::forwardQueriesAndCommunicateResults(
tree.getComm(), space, bottom_tree, predicates,
Expand Down Expand Up @@ -191,7 +195,13 @@ void DistributedTreeImpl::queryDispatch2RoundImpl(
return;
}

Kokkos::View<float *, typename Tree::memory_space> farthest_distances(
// Set the type for the distances to be that of the distance to a leaf node.
// This may be a higher precision than the one used for internal nodes, but
// it is safer.
using Coordinate = decltype(predicates(0).distance(
HappyTreeFriends::getIndexable(tree._bottom_tree, 0)));

Kokkos::View<Coordinate *, typename Tree::memory_space> farthest_distances(
Kokkos::view_alloc(space, Kokkos::WithoutInitializing,
prefix + "::farthest_distances"),
predicates.size());
Expand Down
10 changes: 6 additions & 4 deletions src/distributed/detail/ArborX_DistributedTreeUtils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -259,15 +259,17 @@ void forwardQueriesAndCommunicateResults(
Kokkos::Profiling::popRegion();
}

template <typename ExecutionSpace, typename MemorySpace, typename Predicates,
template <typename ExecutionSpace, typename Predicates, typename Distances,
typename Values, typename Offset>
void filterResults(ExecutionSpace const &space, Predicates const &queries,
Kokkos::View<float *, MemorySpace> const &distances,
Values &values, Offset &offset)
Distances const &distances, Values &values, Offset &offset)
{
Kokkos::Profiling::ScopedRegion guard(
"ArborX::DistributedTree::filterResults");

static_assert(Kokkos::is_view_v<Distances> && Distances::rank() == 1);

using MemorySpace = typename Distances::memory_space;
using Value = typename Values::value_type;

int const n_queries = queries.size();
Expand All @@ -288,7 +290,7 @@ void filterResults(ExecutionSpace const &space, Predicates const &queries,
Kokkos::View<Value *, MemorySpace> new_values(
Kokkos::view_alloc(space, values.label()), n_truncated_results);

using PairValueDistance = Kokkos::pair<Value, float>;
using PairValueDistance = Kokkos::pair<Value, typename Distances::value_type>;
struct CompareDistance
{
KOKKOS_INLINE_FUNCTION bool operator()(PairValueDistance const &lhs,
Expand Down
63 changes: 63 additions & 0 deletions test/tstDistributedTreeNearest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,69 @@ BOOST_AUTO_TEST_CASE_TEMPLATE(hello_world_nearest, DeviceType,
}
}

// FIXME: Almost identical to hello_world_nearest, but uses double. Testing
// needs refactoring.
BOOST_AUTO_TEST_CASE_TEMPLATE(double_tree, DeviceType, ARBORX_DEVICE_TYPES)
{
  // Builds a distributed tree from double-precision points, runs one nearest
  // query per rank and compares the result with a hand-written reference
  // solution.
  using ExecutionSpace = typename DeviceType::execution_space;

  using Point = ArborX::Point<3, double>;

  MPI_Comm comm = MPI_COMM_WORLD;
  int comm_rank;
  MPI_Comm_rank(comm, &comm_rank);
  int comm_size;
  MPI_Comm_size(comm, &comm_size);

  // Every rank owns n points on the x-axis, spaced 1/n apart and starting at
  // x = comm_rank. Illustration for 4 ranks (one "x---" is 1/n):
  //
  //  [  rank 0       [  rank 1       [  rank 2       [  rank 3       [
  //  x---x---x---x---x---x---x---x---x---x---x---x---x---x---x---x---
  //  ^   ^   ^   ^
  //  0   1   2   3   ^   ^   ^   ^
  //                  0   1   2   3   ^   ^   ^   ^
  //                                  0   1   2   3   ^   ^   ^   ^
  //                                                  0   1   2   3
  int const n = 4;
  std::vector<Point> points(n);
  for (int i = 0; i < n; ++i)
    points[i] = {{static_cast<double>(i) / n + comm_rank, 0., 0.}};

  auto tree = makeDistributedTree<DeviceType>(comm, ExecutionSpace{}, points);

  // Every rank issues a single query located at x = comm_size - 1 - comm_rank.
  // It asks for 3 neighbors, except on the last rank (query at x = 0), which
  // asks for 2. In the illustration for 4 ranks, "<--r-->" marks the query of
  // rank r and the row of "x" directly above it marks its expected results:
  //
  //  0---0---0---0---1---1---1---1---2---2---2---2---3---3---3---3---
  //  |               |               |               |               |
  //  |               |               |           x   x   x           |
  //  |               |           x   x   x        <--0-->            |
  //  |           x   x   x        <--1-->            |               |
  //  x   x        <--2-->            |               |               |
  //  3-->            |               |               |               |
  //  |               |               |               |               |
  Kokkos::View<ArborX::Nearest<Point> *, DeviceType> nearest_queries(
      "Testing::nearest_queries", 1);
  auto nearest_queries_host = Kokkos::create_mirror_view(nearest_queries);
  // The coordinate is computed as a double on purpose: this test targets the
  // double-precision code path, so no float literal is involved.
  nearest_queries_host(0) = ArborX::nearest<Point>(
      {{static_cast<double>(comm_size - 1 - comm_rank), 0., 0.}},
      comm_rank < comm_size - 1 ? 3 : 2);
  deep_copy(nearest_queries, nearest_queries_host);

  // The reference solution refers to point indices 0, 1 and n - 1;
  // presumably this check guards that they are distinct.
  BOOST_TEST(n > 2);
  // Expected (index, rank) pairs: the point at the query location, and its
  // right neighbor on the same rank; all ranks but the last also expect the
  // last point (index n - 1) of the preceding rank. The second argument holds
  // the offsets delimiting the results of the single query.
  ARBORX_TEST_QUERY_TREE(ExecutionSpace{}, tree, nearest_queries,
                         (comm_rank < comm_size - 1
                              ? make_reference_solution<PairIndexRank>(
                                    {{0, comm_size - 1 - comm_rank},
                                     {n - 1, comm_size - 2 - comm_rank},
                                     {1, comm_size - 1 - comm_rank}},
                                    {0, 3})
                              : make_reference_solution<PairIndexRank>(
                                    {{0, comm_size - 1 - comm_rank},
                                     {1, comm_size - 1 - comm_rank}},
                                    {0, 2})));
}

#if 0
BOOST_AUTO_TEST_CASE_TEMPLATE(empty_tree_nearest, DeviceType,
ARBORX_DEVICE_TYPES)
Expand Down

0 comments on commit 37adf1a

Please sign in to comment.