Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Template arg list #208

Merged
merged 3 commits into from
Jan 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion GNUmakefile
Original file line number Diff line number Diff line change
Expand Up @@ -1352,7 +1352,7 @@ ${pkg}:

lib: ${slate} ${matgen}

clean: src/clean test/clean unit_test/clean scalapack_api/clean lapack_api/clean include/clean
clean: src/clean matgen/clean test/clean unit_test/clean scalapack_api/clean lapack_api/clean include/clean
rm -f lib/lib* ${dep}
rm -f trace_*.svg

Expand Down
4 changes: 2 additions & 2 deletions src/gbtrf.cc
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ int64_t gbtrf(
// send A(i, k) across row A(i, k+1:nt-1)
bcast_list_A.push_back({i, k, {A.sub(i, i, k+1, j_end-1)}});
}
A.template listBcast(bcast_list_A, layout, tag_k);
A.template listBcast<>( bcast_list_A, layout, tag_k );

// Root broadcasts the pivot to all ranks.
// todo: Panel ranks send the pivots to the right.
Expand Down Expand Up @@ -190,7 +190,7 @@ int64_t gbtrf(
// send A(k, j) across column A(k+1:mt-1, j)
bcast_list_A.push_back({k, j, {A.sub(k+1, i_end-1, j, j)}});
}
A.template listBcast(bcast_list_A, layout, tag_kl1);
A.template listBcast<>( bcast_list_A, layout, tag_kl1 );

// A(k+1:mt-1, kl+1:nt-1) -= A(k+1:mt-1, k) * A(k, kl+1:nt-1)
internal::gemm<Target::HostTask>(
Expand Down
8 changes: 4 additions & 4 deletions src/ge2tb.cc
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ void ge2tb(
bcast_list_T.push_back(
{row, k, {TUlocal.sub(row, row, k+1, A_nt-1)}});
}
TUlocal.template listBcast(bcast_list_T, layout);
TUlocal.template listBcast<>( bcast_list_T, layout );
}

// bcast TUreduce across row for trailing matrix update
Expand All @@ -245,7 +245,7 @@ void ge2tb(
{row, k, {TUreduce.sub(row, row, k+1, A_nt-1)}});
}
}
TUreduce.template listBcast(bcast_list_T, layout);
TUreduce.template listBcast<>( bcast_list_T, layout );
}

int64_t j = k+1;
Expand Down Expand Up @@ -380,7 +380,7 @@ void ge2tb(
bcast_list_T.push_back(
{k, col, {TVlocal.sub(k+1, A_mt-1, col, col)}});
}
TVlocal.template listBcast(bcast_list_T, layout);
TVlocal.template listBcast<>( bcast_list_T, layout );
}

// bcast TVreduce down col for trailing matrix update
Expand All @@ -394,7 +394,7 @@ void ge2tb(
{k, col, {TVreduce.sub(k+1, A_mt-1, col, col)}});
}
}
TVreduce.template listBcast(bcast_list_T, layout);
TVreduce.template listBcast<>( bcast_list_T, layout );
}

int64_t i = k+1;
Expand Down
2 changes: 1 addition & 1 deletion src/gelqf.cc
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ void gelqf(
if (col > k) // exclude the first col of this panel that has no Treduce tile
bcast_list_T.push_back({k, col, {Treduce.sub(k+1, A_mt-1, col, col)}});
}
Treduce.template listBcast(bcast_list_T, layout);
Treduce.template listBcast<>( bcast_list_T, layout );
}
}
}
Expand Down
4 changes: 2 additions & 2 deletions src/gemmA.cc
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ void gemmA(
{A.sub( i, i, 0, A.nt()-1 )}
} );
int tag_0 = 0;
C.template listReduce( reduce_list_C, layout, tag_0 );
C.template listReduce<>( reduce_list_C, layout, tag_0 );
}

// Clean up workspace
Expand Down Expand Up @@ -180,7 +180,7 @@ void gemmA(
{A.sub( i, i, 0, A.nt()-1 )}
} );
int tag_k = k;
C.template listReduce( reduce_list_C, layout, tag_k );
C.template listReduce<>( reduce_list_C, layout, tag_k );
}

// Clean up workspace
Expand Down
2 changes: 1 addition & 1 deletion src/geqrf.cc
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ void geqrf(
if (row > k) // exclude the first row of this panel that has no Treduce tile
bcast_list_T.push_back({row, k, {Treduce.sub(row, row, k+1, A_nt-1)}});
}
Treduce.template listBcast(bcast_list_T, layout);
Treduce.template listBcast<>( bcast_list_T, layout );
}
}
}
Expand Down
6 changes: 3 additions & 3 deletions src/gerbt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -181,8 +181,8 @@ void gerbt(Matrix<scalar_t>& U_in,
internal::gerbt_bcast_filter_duplicates<scalar_t>(bcast_list_U);
internal::gerbt_bcast_filter_duplicates<scalar_t>(bcast_list_V);

U.template listBcastMT(bcast_list_U, Layout::ColMajor);
V.template listBcastMT(bcast_list_V, Layout::ColMajor);
U.template listBcastMT<>( bcast_list_U, Layout::ColMajor );
V.template listBcastMT<>( bcast_list_V, Layout::ColMajor );

// NB: only tasks created so far are in listBcastMT

Expand Down Expand Up @@ -297,7 +297,7 @@ void gerbt(Matrix<scalar_t>& Uin,

// Bcast random factors
internal::gerbt_bcast_filter_duplicates<scalar_t>(bcast_list);
U.template listBcastMT(bcast_list, Layout::ColMajor);
U.template listBcastMT<>( bcast_list, Layout::ColMajor );

// NB: only tasks created so far are in listBcastMT

Expand Down
8 changes: 4 additions & 4 deletions src/getri.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ void getri(
}

// send W down col A(0:nt-1, k)
W.template tileBcast(
0, 0, A.sub(0, A.nt()-1, k, k), layout);
W.template tileBcast<>(
0, 0, A.sub( 0, A.nt()-1, k, k ), layout );

auto Wkk = TriangularMatrix<scalar_t>(Uplo::Lower, Diag::Unit, W);
internal::trsm<Target::HostTask>(
Expand Down Expand Up @@ -106,7 +106,7 @@ void getri(
// send W(i) down column A(0:nt-1, k+i)
bcast_list_W.push_back({i, 0, {A.sub(0, A.nt()-1, k+i, k+i)}});
}
W.template listBcast(bcast_list_W, layout);
W.template listBcast<>( bcast_list_W, layout );

// A(:, k) -= A(:, k+1:nt-1) * W
internal::gemmA<Target::HostTask>(
Expand All @@ -124,7 +124,7 @@ void getri(
{A.sub(i, i, k+1, A.nt()-1)}
});
}
A.template listReduce(reduce_list_A, layout);
A.template listReduce<>( reduce_list_A, layout );

// Release workspace tiles from gemmA
A.sub(0, A.nt()-1, k, k).releaseRemoteWorkspace();
Expand Down
2 changes: 1 addition & 1 deletion src/he2hb.cc
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ void he2hb(
i } );
}
}
Treduce.template listBcastMT( bcast_list_T, layout );
Treduce.template listBcastMT<>( bcast_list_T, layout );
}

std::vector<int64_t> panel_rank_rows;
Expand Down
2 changes: 1 addition & 1 deletion src/hetrf.cc
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,7 @@ int64_t hetrf(
bcast_list_T.push_back({k+1, k, {A.sub(k+1, A_mt-1, k-1, k-1)}});
// for computing T(j, j)
bcast_list_T.push_back({k+1, k, {A.sub(k+1, k+1, k+1, k+1)}});
T.template listBcast(bcast_list_T, layout, tag);
T.template listBcast<>( bcast_list_T, layout, tag );
}
}
#pragma omp task depend(inout:columnL[k])
Expand Down
2 changes: 1 addition & 1 deletion src/pbtrf.cc
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ int64_t pbtrf(
bcast_list_A.push_back({i, k, {A.sub(i, i, k+1, i),
A.sub(i, ij_end-1, i, i)}});
}
A.template listBcast(bcast_list_A, layout);
A.template listBcast<>( bcast_list_A, layout );
}

// update trailing submatrix, normal priority
Expand Down
6 changes: 3 additions & 3 deletions src/tbsmPivots.cc
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ void tbsm(
#pragma omp task depend(inout:row[k]) priority(1)
{
// send A(k, k) to ranks owning block row B(k, :)
A.template tileBcast(k, k, B.sub(k, k, 0, nt-1), layout);
A.template tileBcast<>( k, k, B.sub( k, k, 0, nt-1 ), layout );

// solve A(k, k) B(k, :) = B(k, :)
internal::trsm<Target::HostTask>(
Expand Down Expand Up @@ -215,7 +215,7 @@ void tbsm(
#pragma omp task depend(inout:row[k]) priority(1)
{
// send A(k, k) to ranks owning block row B(k, :)
A.template tileBcast(k, k, B.sub(k, k, 0, nt-1), layout);
A.template tileBcast<>( k, k, B.sub( k, k, 0, nt-1 ), layout );

// solve A(k, k) B(k, :) = B(k, :)
internal::trsm<Target::HostTask>(
Expand Down Expand Up @@ -324,7 +324,7 @@ void tbsm(
// solve diagonal block (Akk tile)
{
// send A(k, k) to ranks owning block row B(k, :)
A.template tileBcast(k, k, B.sub(k, k, 0, nt-1), layout);
A.template tileBcast<>( k, k, B.sub( k, k, 0, nt-1 ), layout );

// solve A(k, k) B(k, :) = B(k, :)
internal::trsm<Target::HostTask>(
Expand Down
2 changes: 1 addition & 1 deletion src/trtrm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ void trtrm(
bcast_list_A.push_back({k, j, {A.sub(j, k-1, j, j),
A.sub(j, j, 0, j)}});
}
A.template listBcast(bcast_list_A, layout);
A.template listBcast<>( bcast_list_A, layout );
}

// update trailing submatrix
Expand Down
4 changes: 2 additions & 2 deletions src/unmlq.cc
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ void unmlq(
bcast_list_T.push_back(
{k, j, {C.sub(i0, i1, j0, j1)}});
}
Tlocal.template listBcast(bcast_list_T, layout);
Tlocal.template listBcast<>( bcast_list_T, layout );
}

// Send Treduce(j) across row C(j, 0:nt-1) or col C(0:mt-1, j).
Expand All @@ -173,7 +173,7 @@ void unmlq(
{k, j, {C.sub(i0, i1, j0, j1)}});
}
}
Treduce.template listBcast(bcast_list_T, layout);
Treduce.template listBcast<>( bcast_list_T, layout );
}

Matrix<scalar_t> C_trail, W_trail;
Expand Down
4 changes: 2 additions & 2 deletions src/unmqr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ void unmqr(
bcast_list_T.push_back(
{i, k, {C.sub(i0, i1, j0, j1)}});
}
Tlocal.template listBcast(bcast_list_T, layout);
Tlocal.template listBcast<>( bcast_list_T, layout );
}

// Send Treduce(i) across row C(i, 0:nt-1) or col C(0:mt-1, i).
Expand All @@ -177,7 +177,7 @@ void unmqr(
{i, k, {C.sub(i0, i1, j0, j1)}});
}
}
Treduce.template listBcast(bcast_list_T, layout);
Treduce.template listBcast<>( bcast_list_T, layout );
}

Matrix<scalar_t> C_trail, W_trail;
Expand Down
4 changes: 2 additions & 2 deletions src/work/work_trsm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ void trsm(Side side, scalar_t alpha, TriangularMatrix<scalar_t> A,
#pragma omp task depend(inout:row[k]) priority(1)
{
// send A(k, k) to ranks owning block row B(k, :)
A.template tileBcast(k, k, B.sub(k, k, 0, nt-1), layout);
A.template tileBcast<>( k, k, B.sub( k, k, 0, nt-1 ), layout );

// solve A(k, k) B(k, :) = alpha B(k, :)
internal::trsm<target>(
Expand Down Expand Up @@ -191,7 +191,7 @@ void trsm(Side side, scalar_t alpha, TriangularMatrix<scalar_t> A,
#pragma omp task depend(inout:row[k]) priority(1)
{
// send A(k, k) to ranks owning block row B(k, :)
A.template tileBcast(k, k, B.sub(k, k, 0, nt-1), layout);
A.template tileBcast<>( k, k, B.sub( k, k, 0, nt-1 ), layout );

// solve A(k, k) B(k, :) = alpha B(k, :)
internal::trsm<target>(
Expand Down
64 changes: 33 additions & 31 deletions test/test_steqr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -144,37 +144,39 @@ void test_steqr_work(Params& params, bool run)
//==================================================
real_t tol = params.tol() * 0.5 * std::numeric_limits<real_t>::epsilon();

// Set Zref = Identity, distributed on 1D mpi_size-by-1 grid.
std::vector<scalar_t> Zref_data( ldz*n );
auto Zref = slate::Matrix<scalar_t>::fromScaLAPACK(
n, n, &Zref_data[0], ldz, nb, mpi_size, 1, MPI_COMM_WORLD );
set( zero, one, Zref );

lwork = max( 1, 2*n - 2 );
work.resize( lwork );

//==================================================
// Run ScaLAPACK reference routine.
//==================================================
time = barrier_get_wtime(MPI_COMM_WORLD);

scalapack::steqr2(
jobz, n, &Dref[0], &Eref[0], &Zref_data[0], ldz, nrows,
&work[0], &info );
assert( info == 0 );

params.ref_time() = barrier_get_wtime(MPI_COMM_WORLD) - time;

if (mpi_rank == 0) {
print_vector( "Dref_out", Dref, params );
}
print_matrix( "Zref_out", Zref, params );

// Relative forward error: || D - Dref || / || Dref ||.
real_t Dnorm = blas::nrm2( n, &Dref[0], 1 );
blas::axpy( n, -1.0, &D[0], 1, &Dref[0], 1 );
params.error() = blas::nrm2( n, &Dref[0], 1 ) / Dnorm;
params.okay() = (params.error() <= tol);
#ifdef SLATE_HAVE_SCALAPACK
// Set Zref = Identity, distributed on 1D mpi_size-by-1 grid.
std::vector<scalar_t> Zref_data( ldz*n );
auto Zref = slate::Matrix<scalar_t>::fromScaLAPACK(
n, n, &Zref_data[0], ldz, nb, mpi_size, 1, MPI_COMM_WORLD );
set( zero, one, Zref );

lwork = max( 1, 2*n - 2 );
work.resize( lwork );

//==================================================
// Run ScaLAPACK reference routine.
//==================================================
time = barrier_get_wtime(MPI_COMM_WORLD);

scalapack::steqr2(
jobz, n, &Dref[0], &Eref[0], &Zref_data[0], ldz, nrows,
&work[0], &info );
assert( info == 0 );

params.ref_time() = barrier_get_wtime(MPI_COMM_WORLD) - time;

if (mpi_rank == 0) {
print_vector( "Dref_out", Dref, params );
}
print_matrix( "Zref_out", Zref, params );

// Relative forward error: || D - Dref || / || Dref ||.
real_t Dnorm = blas::nrm2( n, &Dref[0], 1 );
blas::axpy( n, -1.0, &D[0], 1, &Dref[0], 1 );
params.error() = blas::nrm2( n, &Dref[0], 1 ) / Dnorm;
params.okay() = (params.error() <= tol);
#endif

//==================================================
// Test results by checking the orthogonality of Z
Expand Down