Skip to content

Commit

Permalink
unit tests BandMatrix: function based costructor, tilIsInBand,
Browse files Browse the repository at this point in the history
insertLocalTiles
  • Loading branch information
jamabr committed Feb 6, 2025
1 parent c20f2ef commit bc3aa54
Showing 1 changed file with 96 additions and 10 deletions.
106 changes: 96 additions & 10 deletions unit_test/test_BandMatrix.cc
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,90 @@ void test_BandMatrix_tileInsert_new()
}
}

//------------------------------------------------------------------------------
/// Tests BandMatrix(), mt, nt, op, insertLocalTiles on host.
void test_BandMatrix_insertLocalTiles()
{
auto A = slate::BandMatrix<double>(m, n, kl, ku, nb, p, q, mpi_comm);

test_assert(A.mt() == ceildiv(m, nb));
test_assert(A.nt() == ceildiv(n, nb));
test_assert(A.op() == blas::Op::NoTrans);
test_assert(A.uplo() == slate::Uplo::General);
// let function insert the local tiles belonging to band
A.insertLocalTiles();

int jj = 0;

// Verify tiles for special case all nb == const.
for (int j = 0; j < A.nt(); ++j) {
// current global row index of digonal
int ii = j*A.tileMb(0);

int istart = std::max((ii - A.upperBandwidth())/A.tileNb(0), int64_t {0});
if ((ii - A.upperBandwidth())>A.m()) {
istart = A.mt();
}

int iend = ceildiv(ii + A.tileNb(j) + A.lowerBandwidth(), A.tileNb(0));
if ((ii + A.tileNb(j) + A.lowerBandwidth() )>A.m()) {
iend = A.mt();
}

for (int i = istart; i < iend; ++i) {
if (A.tileIsLocal(i, j)) {
int ib = std::min( nb, m - ii );
int jb = std::min( nb, n - jj );

auto T = A(i, j);
test_assert( T.mb() == ib );
test_assert( T.nb() == jb );
test_assert( T.op() == slate::Op::NoTrans );
test_assert( T.uplo() == slate::Uplo::General );
}
else {
// outside band/not local, tiles don't exist
test_assert_throw_std( A(i, j) );
}
ii += A.tileMb(i);
}
jj += A.tileNb(j);
}
}

//------------------------------------------------------------------------------
/// Tests BaseBandMatrix(), mt, nt, op, tileIsInBand.
void test_BandMatrix_tileIsInBand()
{
auto A = slate::BandMatrix<double>(m, n, kl, ku, nb, p, q, mpi_comm);

test_assert(A.mt() == ceildiv(m, nb));
test_assert(A.nt() == ceildiv(n, nb));
test_assert(A.op() == blas::Op::NoTrans);
test_assert(A.uplo() == slate::Uplo::General);

// Verify tiles for special case all nb == const.
for (int j = 0; j < A.nt(); ++j) {
// current global row index of digonal
int ii = j*A.tileMb(0);

int istart = std::max((ii - A.upperBandwidth())/A.tileNb(0), int64_t {0});
if ((ii - A.upperBandwidth())>A.m()) {
istart = A.mt();
}

int iend = ceildiv(ii + A.tileNb(j) + A.lowerBandwidth(), A.tileNb(0));
if ((ii + A.tileNb(j) + A.lowerBandwidth() )>A.m()) {
iend = A.mt();
}

for (int i = 0; i < A.mt(); ++i) {
bool in = (i >= istart) && (i < iend);
test_assert( A.tileIsInBand(i, j) == in );
}
}
}

//------------------------------------------------------------------------------
void test_BandMatrix_tileInsert_data()
{
Expand Down Expand Up @@ -503,19 +587,21 @@ void run_tests()
{
if (mpi_rank == 0)
printf("\nConstructors\n");
run_test(test_BandMatrix, "BandMatrix()", mpi_comm);
run_test(test_BandMatrix_empty, "BandMatrix(m, n, ...)", mpi_comm);
run_test(test_BandMatrix_lambda, "BandMatrix(m, n, tileMb, ...)", mpi_comm);
run_test(test_BandMatrix, "BandMatrix()", mpi_comm);
run_test(test_BandMatrix_empty, "BandMatrix(m, n, ...)", mpi_comm);
run_test(test_BandMatrix_lambda, "BandMatrix(m, n, tileMb, ...)", mpi_comm);

if (mpi_rank == 0)
printf("\nMethods\n");
run_test(test_BandMatrix_transpose, "transpose", mpi_comm);
run_test(test_BandMatrix_conj_transpose, "conj_transpose", mpi_comm);
run_test(test_BandMatrix_swap, "swap", mpi_comm);
run_test(test_BandMatrix_tileInsert_new, "BandMatrix::tileInsert(i, j, dev) ", mpi_comm);
run_test(test_BandMatrix_tileInsert_data, "BandMatrix::tileInsert(i, j, dev, data, lda)", mpi_comm);
run_test(test_BandMatrix_sub, "BandMatrix::sub", mpi_comm);
run_test(test_BandMatrix_sub_trans, "BandMatrix::sub(A^T)", mpi_comm);
run_test(test_BandMatrix_transpose, "transpose", mpi_comm);
run_test(test_BandMatrix_conj_transpose, "conj_transpose", mpi_comm);
run_test(test_BandMatrix_swap, "swap", mpi_comm);
run_test(test_BandMatrix_tileInsert_new, "BandMatrix::tileInsert(i, j, dev) ", mpi_comm);
run_test(test_BandMatrix_tileInsert_data, "BandMatrix::tileInsert(i, j, dev, data, lda)", mpi_comm);
run_test(test_BandMatrix_insertLocalTiles, "BandMatrix::insertLocalTiles()", mpi_comm);
run_test(test_BandMatrix_tileIsInBand, "BaseBandMatrix::tileIsInBand()", mpi_comm);
run_test(test_BandMatrix_sub, "BandMatrix::sub", mpi_comm);
run_test(test_BandMatrix_sub_trans, "BandMatrix::sub(A^T)", mpi_comm);
run_test(test_TriangularBandMatrix_gatherAll, "TriangularBandMatrix::gatherAll()", mpi_comm);
}

Expand Down

0 comments on commit bc3aa54

Please sign in to comment.