Skip to content

Commit

Permalink
replace indexGlobal2Local and numberLocalRowOrCol with global2local a…
Browse files Browse the repository at this point in the history
…nd num_local_rows_cols
  • Loading branch information
mgates3 committed May 30, 2024
1 parent b4a3038 commit d074667
Show file tree
Hide file tree
Showing 6 changed files with 20 additions and 77 deletions.
55 changes: 0 additions & 55 deletions include/slate/BaseMatrix.hh
Original file line number Diff line number Diff line change
Expand Up @@ -3971,61 +3971,6 @@ void BaseMatrix<scalar_t>::releaseRemoteWorkspace(
}
}

//------------------------------------------------------------------------------
// from ScaLAPACK's indxg2l
// todo: where to put utilities like this?
/// Converts a global index into the corresponding local index for a
/// distribution in blocks of size nb dealt out cyclically over num_ranks.
///
/// @param[in] i
///     Global index.
///
/// @param[in] nb
///     Block size.
///
/// @param[in] num_ranks
///     Number of ranks the blocks are cycled over.
///
/// @return nb * (number of complete cycles of nb*num_ranks entries
///     preceding i) + (offset of i within its block).
///
inline int64_t indexGlobal2Local(int64_t i, int64_t nb, int num_ranks)
{
    // One full cycle hands exactly one block of nb entries to every rank.
    const int64_t cycle_len    = nb*num_ranks;
    const int64_t full_cycles  = i / cycle_len;
    // Position of i inside its own block.
    const int64_t block_offset = i % nb;

    return full_cycles*nb + block_offset;
}

//------------------------------------------------------------------------------
// from ScaLAPACK's numroc
/// [internal]
/// Computes the number of rows or columns of a distributed
/// matrix owned by the process indicated by iproc.
///
/// @param[in] n
///     The number of rows/columns in the distributed matrix.
///
/// @param[in] nb
///     Block size, size of the blocks the distributed matrix is split into.
///
/// @param[in] iproc
///     The coordinate of the process whose local array row or
///     column is to be determined.
///
/// @param[in] isrcproc
///     The coordinate of the process that possesses the first
///     row or column of the distributed matrix.
///
/// @param[in] nprocs
///     The total number of processes over which the matrix is distributed.
///
/// @return The number of rows/columns held by process iproc:
///     nb for each whole block it owns, plus the trailing partial
///     block (n % nb) if that block falls on iproc.
///
inline int64_t numberLocalRowOrCol(int64_t n, int64_t nb, int iproc, int isrcproc, int nprocs)
{
    // Distance of iproc from the source process, wrapping around nprocs.
    const int mydist = (nprocs + iproc - isrcproc) % nprocs;

    // Number of whole nb-sized blocks in n. Kept as int64_t: n is 64-bit,
    // so narrowing n / nb to int would give a wrong count when the
    // quotient exceeds INT_MAX.
    const int64_t nblocks = n / nb;

    // Minimum number of rows/cols every process has.
    int64_t numroc = (nblocks / nprocs) * nb;

    // Whole blocks left over after dealing out complete rounds.
    const int64_t extrablks = nblocks % nprocs;

    if (mydist < extrablks) {
        // This process holds one of the extra whole blocks.
        numroc += nb;
    }
    else if (mydist == extrablks) {
        // This process holds the last block, which may be partial.
        numroc += n % nb;
    }
    return numroc;
}

} // namespace slate

#endif // SLATE_BASE_MATRIX_HH
22 changes: 10 additions & 12 deletions include/slate/BaseTrapezoidMatrix.hh
Original file line number Diff line number Diff line change
Expand Up @@ -222,15 +222,15 @@ BaseTrapezoidMatrix<scalar_t>::BaseTrapezoidMatrix(
int64_t jb = this->tileNb(j);
int64_t jj_local = jj;
if (is_scalapack) {
jj_local = indexGlobal2Local(jj, nb, q);
jj_local = global2local( jj, nb, q );
}

int64_t ii = j*nb;
for (int64_t i = j; i < this->mt(); ++i) { // lower
int64_t ib = this->tileMb(i);
int64_t ii_local = ii;
if (is_scalapack) {
ii_local = indexGlobal2Local(ii, nb, p);
ii_local = global2local( ii, nb, p );
}

if (this->tileIsLocal(i, j)) {
Expand All @@ -248,15 +248,15 @@ BaseTrapezoidMatrix<scalar_t>::BaseTrapezoidMatrix(
int64_t jb = this->tileNb(j);
int64_t jj_local = jj;
if (is_scalapack) {
jj_local = indexGlobal2Local(jj, nb, q);
jj_local = global2local( jj, nb, q );
}

int64_t ii = 0;
for (int64_t i = 0; i <= j && i < this->mt(); ++i) { // upper
int64_t ib = this->tileMb(i);
int64_t ii_local = ii;
if (is_scalapack) {
ii_local = indexGlobal2Local(ii, nb, p);
ii_local = global2local( ii, nb, p );
}

if (this->tileIsLocal(i, j)) {
Expand Down Expand Up @@ -328,17 +328,16 @@ BaseTrapezoidMatrix<scalar_t>::BaseTrapezoidMatrix(
int64_t jj = 0;
for (int64_t j = 0; j < this->nt(); ++j) {
int64_t jb = this->tileNb(j);
int64_t jj_local = indexGlobal2Local(jj, nb, q);
int64_t jj_local = global2local( jj, nb, q );

int64_t ii = j*nb;
for (int64_t i = j; i < this->mt(); ++i) { // lower
int64_t ib = this->tileMb(i);
int64_t ii_local = indexGlobal2Local(ii, nb, p);
int64_t ii_local = global2local( ii, nb, p );

if (this->tileIsLocal(i, j)) {
int dev = this->tileDevice(i, j);
int64_t jj_dev
= indexGlobal2Local(jj_local, nb, num_devices);
int64_t jj_dev = global2local( jj_local, nb, num_devices );
this->tileInsert(
i, j, dev, &Aarray[dev][ii_local + jj_dev*lda], lda);
}
Expand All @@ -351,17 +350,16 @@ BaseTrapezoidMatrix<scalar_t>::BaseTrapezoidMatrix(
int64_t jj = 0;
for (int64_t j = 0; j < this->nt(); ++j) {
int64_t jb = this->tileNb(j);
int64_t jj_local = indexGlobal2Local(jj, nb, q);
int64_t jj_local = global2local( jj, nb, q );

int64_t ii = 0;
for (int64_t i = 0; i <= j && i < this->mt(); ++i) { // upper
int64_t ib = this->tileMb(i);
int64_t ii_local = indexGlobal2Local(ii, nb, p);
int64_t ii_local = global2local( ii, nb, p );

if (this->tileIsLocal(i, j)) {
int dev = this->tileDevice(i, j);
int64_t jj_dev
= indexGlobal2Local(jj_local, nb, num_devices);
int64_t jj_dev = global2local( jj_local, nb, num_devices );
this->tileInsert(
i, j, dev, &Aarray[dev][ii_local + jj_dev*lda], lda);
}
Expand Down
10 changes: 5 additions & 5 deletions include/slate/Matrix.hh
Original file line number Diff line number Diff line change
Expand Up @@ -500,7 +500,7 @@ Matrix<scalar_t>::Matrix(
int64_t jb = this->tileNb(j);
int64_t jj_local = jj;
if (is_scalapack) {
jj_local = indexGlobal2Local(jj, nb, q);
jj_local = global2local( jj, nb, q );
}

int64_t ii = 0;
Expand All @@ -509,7 +509,7 @@ Matrix<scalar_t>::Matrix(
if (this->tileIsLocal(i, j)) {
int64_t ii_local = ii;
if (is_scalapack) {
ii_local = indexGlobal2Local(ii, mb, p);
ii_local = global2local( ii, mb, p );
}

this->tileInsert( i, j, HostNum,
Expand Down Expand Up @@ -545,14 +545,14 @@ Matrix<scalar_t>::Matrix(
int64_t jj = 0;
for (int64_t j = 0; j < this->nt(); ++j) {
int64_t jb = this->tileNb(j);
int64_t jj_local = indexGlobal2Local(jj, nb, q);
int64_t jj_local = global2local( jj, nb, q );
int64_t ii = 0;
for (int64_t i = 0; i < this->mt(); ++i) {
int64_t ib = this->tileMb(i);
if (this->tileIsLocal(i, j)) {
int64_t ii_local = indexGlobal2Local(ii, mb, p);
int64_t ii_local = global2local( ii, mb, p );
int dev = this->tileDevice(i, j);
int64_t jj_dev = indexGlobal2Local(jj_local, nb, num_devices);
int64_t jj_dev = global2local( jj_local, nb, num_devices );
this->tileInsert(i, j, dev,
&Aarray[ dev ][ ii_local + jj_dev*lda ], lda);
}
Expand Down
4 changes: 2 additions & 2 deletions include/slate/internal/util.hh
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ private:
/// The total number of processes over which the distributed
/// matrix is distributed.
///
/// Corresponds to ScaLAPACK's indxg2l.
/// Corresponds to ScaLAPACK's indxl2g.
///
inline int64_t local2global(
int64_t local,
Expand Down Expand Up @@ -202,7 +202,7 @@ inline int64_t local2global(
/// The total number of processes over which the distributed
/// matrix is distributed.
///
/// Corresponds to ScaLAPACK's indxl2g.
/// Corresponds to ScaLAPACK's indxg2l.
/// Dummy arguments iproc, isrcproc are excluded.
///
inline int64_t global2local(
Expand Down
4 changes: 2 additions & 2 deletions src/bdsqr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ void bdsqr(
slate_mpi_call(
MPI_Comm_size(U.mpiComm(), &mpi_size));
myrow = U.mpiRank();
nru = numberLocalRowOrCol(m, mb, myrow, izero, mpi_size);
nru = num_local_rows_cols( m, mb, myrow, izero, mpi_size );
ldu = max( 1, nru );
u1d.resize(ldu*min_mn);
U1d = slate::Matrix<scalar_t>::fromScaLAPACK(
Expand All @@ -91,7 +91,7 @@ void bdsqr(
slate_mpi_call(
MPI_Comm_size(VT.mpiComm(), &mpi_size));
mycol = VT.mpiRank();
ncvt = numberLocalRowOrCol(n, nb, mycol, izero, mpi_size);
ncvt = num_local_rows_cols( n, nb, mycol, izero, mpi_size );
ldvt = max( 1, min_mn );
vt1d.resize(ldvt*ncvt);
VT1d = slate::Matrix<scalar_t>::fromScaLAPACK(
Expand Down
2 changes: 1 addition & 1 deletion src/steqr2.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ void steqr2(
n = Z.n();
nb = Z.tileNb(0);
myrow = Z.mpiRank();
nrc = numberLocalRowOrCol(n, nb, myrow, izero, mpi_size);
nrc = num_local_rows_cols( n, nb, myrow, izero, mpi_size );
ldc = max( 1, nrc );
Q.resize(nrc*n);
Z1d = slate::Matrix<scalar_t>::fromScaLAPACK(
Expand Down

0 comments on commit d074667

Please sign in to comment.