Merge pull request #151 from neil-lindquist/remove-tile-life

Remove tile life infrastructure

neil-lindquist authored Dec 20, 2023
2 parents f9ac8c7 + 8ad1a6a commit aaf28f7
Showing 88 changed files with 890 additions and 2,624 deletions.
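
The user-visible change is that the tile life counter and the life_factor arguments disappear from the public API; the old entry points remain as deprecated stubs that forward to the new ones (see the diff below). A minimal caller-side sketch of the migration, not part of this commit, assuming an already MPI-distributed slate::Matrix; the helper name broadcast_diag_tile, the index k, and the tag are illustrative:

#include <slate/slate.hh>

// Hedged sketch: assumes A is distributed over MPI ranks and (k, k) is a
// valid tile index.
template <typename scalar_t>
void broadcast_diag_tile( slate::Matrix<scalar_t>& A, int64_t k, int tag )
{
    // Before: a life count was threaded through each broadcast and then
    // decremented by hand once the workspace tile was no longer needed:
    //   A.tileBcast( k, k, A, slate::Layout::ColMajor, tag, /*life_factor=*/2 );
    //   ... use the tile ...
    //   A.tileTick( k, k );
    //
    // After: the life counter is gone, so the extra argument and the tick
    // are dropped; the deprecated overloads in the diff simply forward here.
    A.tileBcast( k, k, A, slate::Layout::ColMajor, tag );
}
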
98 changes: 44 additions & 54 deletions include/slate/BaseMatrix.hh
@@ -434,23 +434,24 @@ public:
void tileUpdateAllOrigin();

/// Returns life counter of tile {i, j} of op(A).
[[deprecated( "Tile life has been removed. Accessor stubs will be removed 2024-12." )]]
int64_t tileLife(int64_t i, int64_t j) const
{
return storage_->tileLife(globalIndex(i, j));
return 1;
}

/// Set life counter of tile {i, j} of op(A).
[[deprecated( "Tile life has been removed. Accessor stubs will be removed 2024-12." )]]
void tileLife(int64_t i, int64_t j, int64_t life)
{
storage_->tileLife(globalIndex(i, j), life);
}

/// Decrements life counter of workspace tile {i, j} of op(A).
/// Then, if life reaches 0, deletes tile on all devices.
/// For local, non-workspace tiles, does nothing.
[[deprecated( "Tile life has been removed. Accessor stubs will be removed 2024-12." )]]
void tileTick(int64_t i, int64_t j)
{
storage_->tileTick(globalIndex(i, j));
}

/// Returns how many times the tile {i, j} is received
@@ -498,21 +499,41 @@ public:

template <Target target = Target::Host>
void tileBcast(int64_t i, int64_t j, BaseMatrix const& B,
Layout layout, int tag = 0, int64_t life_factor = 1);
Layout layout, int tag = 0);

template <Target target = Target::Host>
[[deprecated( "Tile life has been removed. The 6 argument tileBcast will be removed 2024-12." )]]
void tileBcast(int64_t i, int64_t j, BaseMatrix const& B,
Layout layout, int tag, int64_t life_factor)
{
tileBcast( i, j, B, layout, tag );
}

template <Target target = Target::Host>
void listBcast( BcastList& bcast_list, Layout layout, int tag = 0, bool is_shared = false );

template <Target target = Target::Host>
[[deprecated( "Tile life has been removed. The 5 argument listBcast will be removed 2024-12." )]]
void listBcast(
BcastList& bcast_list, Layout layout,
int tag = 0, int64_t life_factor = 1,
bool is_shared = false);
BcastList& bcast_list, Layout layout, int tag,
int64_t life_factor, bool is_shared = false)
{
listBcast( bcast_list, layout, tag, is_shared );
}

// This variant takes a BcastListTag where each <i,j> tile has
// its own message tag
template <Target target = Target::Host>
void listBcastMT( BcastListTag& bcast_list, Layout layout, bool is_shared = false );

template <Target target = Target::Host>
[[deprecated( "Tile life has been removed. The 4 argument listBcastMT will be removed 2024-12." )]]
void listBcastMT(
BcastListTag& bcast_list, Layout layout,
int64_t life_factor = 1,
bool is_shared = false);
int64_t life_factor, bool is_shared = false)
{
listBcastMT( bcast_list, layout, is_shared );
}

template <Target target = Target::Host>
void listReduce(ReduceList& reduce_list, Layout layout, int tag = 0);
@@ -1783,8 +1804,7 @@ void BaseMatrix<scalar_t>::tileIsend(

//------------------------------------------------------------------------------
/// Receive tile {i, j} of op(A) to the given MPI rank.
/// Tile is allocated as workspace with life = 1 if it doesn't yet exist,
/// or 1 is added to life if it does exist.
/// Tile is allocated as workspace if it doesn't yet exist.
/// Source rank must call tileSend().
/// Data received must be in 'layout' (ColMajor/RowMajor) major.
///
@@ -1813,7 +1833,7 @@ void BaseMatrix<scalar_t>::tileRecv(
int64_t i, int64_t j, int src_rank, Layout layout, int tag)
{
if (src_rank != mpiRank()) {
storage_->tilePrepareToReceive( globalIndex( i, j ), 1, layout );
storage_->tilePrepareToReceive( globalIndex( i, j ), layout );
tileAcquire(i, j, layout);

// Receive data.
@@ -1834,8 +1854,7 @@

//------------------------------------------------------------------------------
/// Receive tile {i, j} of op(A) to the given MPI rank using immediate mode.
/// Tile is allocated as workspace with life = 1 if it doesn't yet exist,
/// or 1 is added to life if it does exist.
/// Tile is allocated as workspace if it doesn't yet exist.
/// Source rank must call tileSend().
/// Data received must be in 'layout' (ColMajor/RowMajor) major.
///
@@ -1864,7 +1883,7 @@ void BaseMatrix<scalar_t>::tileIrecv(
int64_t i, int64_t j, int src_rank, Layout layout, int tag, MPI_Request* request)
{
if (src_rank != mpiRank()) {
storage_->tilePrepareToReceive( globalIndex( i, j ), 1, layout );
storage_->tilePrepareToReceive( globalIndex( i, j ), layout );
tileAcquire(i, j, layout);

// Receive data.
@@ -1904,17 +1923,14 @@ void BaseMatrix<scalar_t>::tileIrecv(
/// @param[in] tag
/// MPI tag, default 0.
///
/// @param[in] life_factor
/// A multiplier for the life count of the broadcasted tile workspace.
///
template <typename scalar_t>
template <Target target>
void BaseMatrix<scalar_t>::tileBcast(
int64_t i, int64_t j, BaseMatrix<scalar_t> const& B, Layout layout, int tag, int64_t life_factor)
int64_t i, int64_t j, BaseMatrix<scalar_t> const& B, Layout layout, int tag)
{
BcastList bcast_list_B;
bcast_list_B.push_back({i, j, {B}});
listBcast<target>(bcast_list_B, layout, tag, life_factor);
listBcast<target>(bcast_list_B, layout, tag);
}

//------------------------------------------------------------------------------
Expand All @@ -1936,8 +1952,6 @@ void BaseMatrix<scalar_t>::tileBcast(
/// @param[in] tag
/// MPI tag, default 0.
///
/// @param[in] life_factor
/// A multiplier for the life count of the broadcasted tile workspace.
/// @param[in] is_shared
/// A flag to get and hold the broadcasted (prefetched) tiles on the
/// devices. This flag prevents any subsequent calls of tileRelease()
@@ -1948,8 +1962,7 @@ void BaseMatrix<scalar_t>::tileBcast(
template <typename scalar_t>
template <Target target>
void BaseMatrix<scalar_t>::listBcast(
BcastList& bcast_list, Layout layout,
int tag, int64_t life_factor, bool is_shared)
BcastList& bcast_list, Layout layout, int tag, bool is_shared )
{
if (target == Target::Devices) {
assert(num_devices() > 0);
@@ -1962,8 +1975,6 @@ void BaseMatrix<scalar_t>::listBcast(
// used for hosting communicated tiles.
// Due to dynamic scheduling, the second communication may occur before the
// first tile has been discarded.
// If that happens, instead of creating the tile, the life of the existing
// tile is increased.
// Also, currently, the message is received to the same buffer.

std::vector< std::set<ij_tuple> > tile_set(num_devices());
@@ -1978,11 +1989,6 @@ void BaseMatrix<scalar_t>::listBcast(
auto j = std::get<1>(bcast);
auto submatrices_list = std::get<2>(bcast);

int64_t life = 0;
for (auto submatrix : submatrices_list) {
life += submatrix.numLocalTiles() * life_factor;
}

// Find the set of participating ranks.
std::set<int> bcast_set;
bcast_set.insert(tileRank(i, j)); // Insert root.
@@ -1993,7 +1999,7 @@
if (bcast_set.find(mpi_rank_) != bcast_set.end()) {

// If receiving the tile.
storage_->tilePrepareToReceive( globalIndex( i, j ), life, layout_ );
storage_->tilePrepareToReceive( globalIndex( i, j ), layout_ );

// Send across MPI ranks.
// Previous used MPI bcast: tileBcastToSet(i, j, bcast_set);
@@ -2073,9 +2079,6 @@ void BaseMatrix<scalar_t>::listBcast(
/// Indicates the Layout (ColMajor/RowMajor) of the broadcasted data.
/// WARNING: must match the layout of the tile in the sender MPI rank.
///
/// @param[in] life_factor
/// A multiplier for the life count of the broadcasted tile workspace.
///
/// @param[in] is_shared
/// A flag to get and hold the broadcasted (prefetched) tiles on the
/// devices. This flag prevents any subsequent calls of tileRelease()
@@ -2086,8 +2089,7 @@ void BaseMatrix<scalar_t>::listBcast(
template <typename scalar_t>
template <Target target>
void BaseMatrix<scalar_t>::listBcastMT(
BcastListTag& bcast_list, Layout layout,
int64_t life_factor, bool is_shared)
BcastListTag& bcast_list, Layout layout, bool is_shared )
{
if (target == Target::Devices) {
assert(num_devices() > 0);
@@ -2100,8 +2102,6 @@ void BaseMatrix<scalar_t>::listBcastMT(
// used for hosting communicated tiles.
// Due to dynamic scheduling, the second communication may occur before the
// first tile has been discarded.
// If that happens, instead of creating the tile, the life of the existing
// tile is increased.
// Also, currently, the message is received to the same buffer.

int mpi_size;
@@ -2115,8 +2115,7 @@

#if defined( SLATE_HAVE_MT_BCAST )
#pragma omp taskloop slate_omp_default_none \
shared( bcast_list ) \
firstprivate(life_factor, layout, mpi_size, is_shared)
shared( bcast_list ) firstprivate( layout, mpi_size, is_shared )
#endif
for (size_t bcastnum = 0; bcastnum < bcast_list.size(); ++bcastnum) {

Expand All @@ -2127,11 +2126,6 @@ void BaseMatrix<scalar_t>::listBcastMT(
auto tagij = std::get<3>(bcast);
int tag = int(tagij) % 32768; // MPI_TAG_UB is at least 32767

int64_t life = 0;
for (auto submatrix : submatrices_list) {
life += submatrix.numLocalTiles() * life_factor;
}

{
trace::Block trace_block(
std::string("listBcast("+std::to_string(i)+","+std::to_string(j)+")").c_str());
@@ -2145,7 +2139,7 @@
// If this rank is in the set.
if (bcast_set.find(mpi_rank_) != bcast_set.end()) {
// If receiving the tile.
storage_->tilePrepareToReceive( globalIndex( i, j ), life, layout_ );
storage_->tilePrepareToReceive( globalIndex( i, j ), layout_ );

// Send across MPI ranks.
// Previous used MPI bcast: tileBcastToSet(i, j, bcast_set);
@@ -2206,7 +2200,6 @@ void BaseMatrix<scalar_t>::listReduce(ReduceList& reduce_list, Layout layout, in
// If not the tile owner.
if (! tileIsLocal(i, j)) {

// todo: should we check its life count before erasing?
// Destroy the tile.
// todo: should it be a tileRelease()?
if (mpi_rank_ != root_rank)
@@ -3822,20 +3815,17 @@ void BaseMatrix<scalar_t>::getLocalDevices(std::set<int>* dev_set) const

//------------------------------------------------------------------------------
/// Returns number of local tiles in this matrix.
/// Used for the lifespan of a temporary tile that updates every tile in
/// the matrix.
///
template <typename scalar_t>
int64_t BaseMatrix<scalar_t>::numLocalTiles() const
{
// Find the tile's lifespan.
int64_t life = 0;
int64_t count = 0;
for (int64_t i = 0; i < mt(); ++i)
for (int64_t j = 0; j < nt(); ++j)
if (tileIsLocal(i, j))
++life;
++count;

return life;
return count;
}

//------------------------------------------------------------------------------
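For reference, a minimal sketch of the new four-argument listBcast, mirroring the tileBcast implementation shown above; this is not part of the commit, and the matrices, tile index, and helper name bcast_one_tile are illustrative assumptions:

#include <slate/slate.hh>

// Hedged sketch: A and B are assumed to be conforming, already-distributed
// slate::Matrix views and (i, j) a valid tile index of A.
template <typename scalar_t>
void bcast_one_tile( slate::Matrix<scalar_t>& A, slate::Matrix<scalar_t>& B,
                     int64_t i, int64_t j, int tag )
{
    typename slate::Matrix<scalar_t>::BcastList bcast_list;
    bcast_list.push_back( { i, j, { B } } );  // send tile (i, j) wherever B has local tiles

    // life_factor is no longer passed; prefetch-and-hold on devices is still
    // requested through the trailing is_shared flag, e.g.
    // A.template listBcast<slate::Target::Devices>( bcast_list, layout, tag, true ).
    A.listBcast( bcast_list, slate::Layout::ColMajor, tag );
}
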
2 changes: 1 addition & 1 deletion include/slate/BaseTrapezoidMatrix.hh
@@ -553,7 +553,7 @@ void swap(BaseTrapezoidMatrix<scalar_t>& A, BaseTrapezoidMatrix<scalar_t>& B)

//------------------------------------------------------------------------------
/// Returns number of local tiles of the matrix on this rank.
// todo: numLocalTiles? use for life as well?
// todo: numLocalTiles?
template <typename scalar_t>
int64_t BaseTrapezoidMatrix<scalar_t>::getMaxHostTiles()
{
11 changes: 8 additions & 3 deletions include/slate/HermitianBandMatrix.hh
@@ -55,7 +55,12 @@ public:
template <typename T>
friend void swap(HermitianBandMatrix<T>& A, HermitianBandMatrix<T>& B);

void gatherAll(std::set<int>& rank_set, int tag = 0, int64_t life_factor = 1);
void gatherAll(std::set<int>& rank_set, int tag = 0);
[[deprecated( "Tile life has been removed. The 3 argument gatherAll will be removed 2024-12." )]]
void gatherAll(std::set<int>& rank_set, int tag, int64_t life_factor) {
gatherAll( rank_set, tag );
}

void he2hbGather(HermitianMatrix<scalar_t>& A);
};

@@ -265,7 +270,7 @@ void swap(HermitianBandMatrix<scalar_t>& A, HermitianBandMatrix<scalar_t>& B)
// avoid if possible.
//
template <typename scalar_t>
void HermitianBandMatrix<scalar_t>::gatherAll(std::set<int>& rank_set, int tag, int64_t life_factor)
void HermitianBandMatrix<scalar_t>::gatherAll(std::set<int>& rank_set, int tag)
{
trace::Block trace_block("slate::gatherAll");

@@ -286,7 +291,7 @@ void HermitianBandMatrix<scalar_t>::gatherAll(std::set<int>& rank_set, int tag,

// If receiving the tile.
this->storage_->tilePrepareToReceive( this->globalIndex( i, j ),
life_factor, this->layout_ );
this->layout_ );

// Send across MPI ranks.
// Previous used MPI bcast: tileBcastToSet(i, j, rank_set);
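A hedged usage sketch of the new two-argument gatherAll, not part of this commit; the matrix, the rank count, and the helper name are assumptions (the same signature change applies to TriangularBandMatrix::gatherAll further below):

#include <set>
#include <slate/HermitianBandMatrix.hh>

// Hedged sketch: T is assumed to be an already-distributed Hermitian band
// matrix; mpi_size is the size of its MPI communicator.
template <typename scalar_t>
void gather_everywhere( slate::HermitianBandMatrix<scalar_t>& T, int mpi_size )
{
    std::set<int> rank_set;
    for (int r = 0; r < mpi_size; ++r)
        rank_set.insert( r );

    // The life_factor argument is gone; the deprecated 3-argument overload
    // above simply forwards here.
    T.gatherAll( rank_set, /*tag=*/0 );
}
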
2 changes: 1 addition & 1 deletion include/slate/Matrix.hh
@@ -678,7 +678,7 @@ void swap(Matrix<scalar_t>& A, Matrix<scalar_t>& B)
//------------------------------------------------------------------------------
/// Returns number of local tiles of the matrix on this rank.
//
// todo: numLocalTiles? use for life as well?
// todo: numLocalTiles?
template <typename scalar_t>
int64_t Matrix<scalar_t>::getMaxHostTiles()
{
11 changes: 8 additions & 3 deletions include/slate/TriangularBandMatrix.hh
@@ -59,7 +59,12 @@ public:
template <typename T>
friend void swap(TriangularBandMatrix<T>& A, TriangularBandMatrix<T>& B);

void gatherAll(std::set<int>& rank_set, int tag = 0, int64_t life_factor = 1);
void gatherAll(std::set<int>& rank_set, int tag = 0);
[[deprecated( "Tile life has been removed. The 3 argument gatherAll will be removed 2024-12." )]]
void gatherAll(std::set<int>& rank_set, int tag, int64_t life_factor) {
gatherAll( rank_set, tag );
}

void ge2tbGather(Matrix<scalar_t>& A);

Diag diag() { return diag_; }
@@ -281,7 +286,7 @@ void swap(TriangularBandMatrix<scalar_t>& A, TriangularBandMatrix<scalar_t>& B)
// avoid if possible.
//
template <typename scalar_t>
void TriangularBandMatrix<scalar_t>::gatherAll(std::set<int>& rank_set, int tag, int64_t life_factor)
void TriangularBandMatrix<scalar_t>::gatherAll(std::set<int>& rank_set, int tag)
{
trace::Block trace_block("slate::gatherAll");

@@ -302,7 +307,7 @@ void TriangularBandMatrix<scalar_t>::gatherAll(std::set<int>& rank_set, int tag,

// If receiving the tile.
this->storage_->tilePrepareToReceive( this->globalIndex( i, j ),
life_factor, this->layout_ );
this->layout_ );

// Send across MPI ranks.
// Previous used MPI bcast: tileBcastToSet(i, j, rank_set);
19 changes: 5 additions & 14 deletions include/slate/c_api/types.h
@@ -32,13 +32,6 @@ const slate_Target slate_Target_HostBatch = 'B'; ///< slate::Target::HostBatch
const slate_Target slate_Target_Devices = 'D'; ///< slate::Target::Devices
// end slate_Target

typedef char slate_TileReleaseStrategy; /* enum */ ///< slate::TileReleaseStrategy
const slate_TileReleaseStrategy slate_TileReleaseStrategy_None = 'N'; ///< slate::TileReleaseStrategy::None
const slate_TileReleaseStrategy slate_TileReleaseStrategy_Internal = 'I'; ///< slate::TileReleaseStrategy::Internal
const slate_TileReleaseStrategy slate_TileReleaseStrategy_Slate = 'S'; ///< slate::TileReleaseStrategy::Slate
const slate_TileReleaseStrategy slate_TileReleaseStrategy_All = 'A'; ///< slate::TileReleaseStrategy::All
// end slate_TileReleaseStrategy

typedef char slate_MethodEig; /* enum */ ///< slate::MethodEig
const slate_MethodEig slate_MethodEig_QR = 'Q'; ///< slate::MethodEig::QR
const slate_MethodEig slate_MethodEig_DC = 'D'; ///< slate::MethodEig::DC
@@ -53,12 +46,11 @@ const slate_Option slate_Option_InnerBlocking = 3; ///< slate::Option::I
const slate_Option slate_Option_MaxPanelThreads = 4; ///< slate::Option::MaxPanelThreads
const slate_Option slate_Option_Tolerance = 5; ///< slate::Option::Tolerance
const slate_Option slate_Option_Target = 6; ///< slate::Option::Target
const slate_Option slate_Option_TileReleaseStrategy = 7; ///< slate::Option::TileReleaseStrategy
const slate_Option slate_Option_HoldLocalWorkspace = 8; ///< slate::Option::HoldLocalWorkspace
const slate_Option slate_Option_Depth = 9; ///< slate::Option::HoldLocalWorkspace
const slate_Option slate_Option_MaxIterations = 10; ///< slate::Option::HoldLocalWorkspace
const slate_Option slate_Option_UseFallbackSolver = 11; ///< slate::Option::HoldLocalWorkspace
const slate_Option slate_Option_PivotThreshold = 12; ///< slate::Option::PivotThreshold
const slate_Option slate_Option_HoldLocalWorkspace = 7; ///< slate::Option::HoldLocalWorkspace
const slate_Option slate_Option_Depth = 8; ///< slate::Option::HoldLocalWorkspace
const slate_Option slate_Option_MaxIterations = 9; ///< slate::Option::HoldLocalWorkspace
const slate_Option slate_Option_UseFallbackSolver = 10; ///< slate::Option::HoldLocalWorkspace
const slate_Option slate_Option_PivotThreshold = 11; ///< slate::Option::PivotThreshold
const slate_Option slate_Option_PrintVerbose = 50; ///< slate::Option::PrintVerbose
const slate_Option slate_Option_PrintEdgeItems = 51; ///< slate::Option::PrintEdgeItems
const slate_Option slate_Option_PrintWidth = 52; ///< slate::Option::PrintWidth
Expand Down Expand Up @@ -86,7 +78,6 @@ typedef union slate_OptionValue {
int64_t max_panel_threads;
double tolerance;
slate_Target target;
slate_TileReleaseStrategy tile_release_strategy;
} slate_OptionValue; ///< slate::OptionValue

typedef struct slate_Options {
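One practical consequence of dropping slate_TileReleaseStrategy is that the integer values of the options that followed it shift down by one, so C API callers must recompile and should rely only on the named constants, never raw integers. A small hedged check, not part of this commit:

#include <cstdio>
#include <slate/c_api/types.h>

int main()
{
    // HoldLocalWorkspace was 8 before this commit and is 7 after it; code
    // that stored or compared hard-coded integers silently breaks, while
    // code using the named constants keeps working after a rebuild.
    std::printf( "HoldLocalWorkspace = %d\n", (int) slate_Option_HoldLocalWorkspace );
    std::printf( "PivotThreshold     = %d\n", (int) slate_Option_PivotThreshold );
    return 0;
}
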