Skip to content

Commit

Permalink
add function based constructor for band matrix and insertLocalTiles
Browse files Browse the repository at this point in the history
  • Loading branch information
jamabr committed Jan 28, 2025
1 parent 95a481c commit 6757c0a
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 0 deletions.
75 changes: 75 additions & 0 deletions include/slate/BandMatrix.hh
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,17 @@ namespace slate {
template <typename scalar_t>
class BandMatrix: public BaseBandMatrix<scalar_t> {
public:
using ij_tuple = std::tuple<int64_t, int64_t>;

// constructors
BandMatrix();

BandMatrix(int64_t m, int64_t n, int64_t kl, int64_t ku,
std::function<int64_t (int64_t j)>& inTileNb,
std::function<int (ij_tuple ij)>& inTileRank,
std::function<int (ij_tuple ij)>& inTileDevice,
MPI_Comm mpi_comm);

BandMatrix(int64_t m, int64_t n, int64_t kl, int64_t ku,
int64_t nb, int p, int q, MPI_Comm mpi_comm);

Expand Down Expand Up @@ -57,6 +65,8 @@ public:

int64_t upperBandwidth() const;
void upperBandwidth(int64_t ku);

void insertLocalTiles(Target origin=Target::Host);
};

//------------------------------------------------------------------------------
Expand All @@ -66,6 +76,47 @@ BandMatrix<scalar_t>::BandMatrix():
BaseBandMatrix<scalar_t>()
{}

//------------------------------------------------------------------------------
/// Constructor creates an m-by-n band matrix, with no tiles allocated,
/// where tileMb, tileNb, tileRank, tileDevice are given as functions.
/// Tiles can be added with tileInsert().
//
/// @param[in] m
/// Number of rows of the matrix. m >= 0.
///
/// @param[in] n
/// Number of columns of the matrix. n >= 0.
///
/// @param[in] kl
/// Number of subdiagonals within band. kl >= 0.
///
/// @param[in] ku
/// Number of superdiagonals within band. ku >= 0.
///
/// @param[in] inTileNb
/// Function that takes block-col index, returns block-col size.
///
/// @param[in] inTileRank
/// Function that takes tuple of { block-row, block-col } indices,
/// returns MPI rank for that tile.
///
/// @param[in] inTileDevice
/// Function that takes tuple of { block-row, block-col } indices,
/// returns local GPU device ID for that tile.
///
/// @param[in] mpi_comm
/// MPI communicator to distribute matrix across.
///
template <typename scalar_t>
BandMatrix<scalar_t>::BandMatrix(int64_t m, int64_t n, int64_t kl, int64_t ku,
std::function<int64_t (int64_t j)>& inTileNb,
std::function<int (ij_tuple ij)>& inTileRank,
std::function<int (ij_tuple ij)>& inTileDevice,
MPI_Comm mpi_comm)
: BaseBandMatrix<scalar_t>(m, n, kl, ku, inTileNb, inTileRank, inTileDevice,
mpi_comm)
{}

//------------------------------------------------------------------------------
/// Constructor creates an m-by-n band matrix, with no tiles allocated,
/// with fixed nb-by-nb tile size and 2D block cyclic distribution.
Expand Down Expand Up @@ -260,6 +311,30 @@ void BandMatrix<scalar_t>::upperBandwidth(int64_t ku)
this->kl_ = ku;
}

//------------------------------------------------------------------------------
/// Inserts all local tiles into an empty band matrix.
///
/// @param[in] target
/// - if target = Devices, inserts tiles on appropriate GPU devices, or
/// - if target = Host, inserts on tiles on CPU host.
///
template <typename scalar_t>
void BandMatrix<scalar_t>::insertLocalTiles(Target origin)
{
bool on_devices = (origin == Target::Devices);
int64_t mt = this->mt();
int64_t nt = this->nt();
for (int64_t j = 0; j < nt; ++j) {
for (int64_t i = 0; i < mt; ++i) {
if (this->tileIsLocal(i, j) && this->tileIsInBand(i,j)) {
int dev = (on_devices ? this->tileDevice(i, j)
: HostNum);
this->tileInsert(i, j, dev);
}
}
}
}

} // namespace slate

#endif // SLATE_BAND_MATRIX_HH
37 changes: 37 additions & 0 deletions include/slate/BaseBandMatrix.hh
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ public:
void allocateBatchArrays(int64_t batch_size=0, int64_t num_arrays=1);
void reserveDeviceWorkspace();

/// Returns whether any part of tile {i, j} of op(A) is inside band.
bool tileIsInBand(int64_t i, int64_t j) const;
// sub-matrix
Matrix<scalar_t> sub(int64_t i1, int64_t i2,
int64_t j1, int64_t j2);
Expand Down Expand Up @@ -363,6 +365,41 @@ void BaseBandMatrix<scalar_t>::tileUpdateAllOrigin()
}
}

//------------------------------------------------------------------------------
/// Check if any part of the tile {i, j} is in the band.
//
template <typename scalar_t>
bool BaseBandMatrix<scalar_t>::tileIsInBand(int64_t i, int64_t j) const
{
// global index of the first row of the tile -> diagonal index
int64_t diag_gidx_first = 0;
for (int64_t ti = 0; ti < i; ++ti) {
diag_gidx_first += this->tileMb(ti);
}
// column index where band starts (first row of tile)
int64_t bnd_col_sidx_first = std::max(diag_gidx_first - kl_, int64_t {0});
// column index where band stops (last row of tile)
int64_t bnd_col_eidx_last = diag_gidx_first + this->tileMb(i) - 1 + ku_;
// out of matrix?
bnd_col_eidx_last = std::min(bnd_col_eidx_last, this->m());

// global column index for first column of tile
int64_t col_gidx_first = 0;
for (int64_t tj = 0; tj < j; ++tj) {
col_gidx_first += this->tileNb(tj);
}
// global column index for last column of tile
int64_t col_gidx_last = std::min(col_gidx_first + this->tileNb(j) - 1, this->n());

// is last element of first row of tile right to the band start and
// is the first element of the last tile row left of the band stop?
if ((col_gidx_last >= bnd_col_sidx_first) &&
( col_gidx_first <= bnd_col_eidx_last )) {
return true;
}
return false;
}

} // namespace slate

#endif // SLATE_BASE_BAND_MATRIX_HH

0 comments on commit 6757c0a

Please sign in to comment.