Skip to content

Commit

Permalink
Merge pull request #127 from neil-lindquist/neil/svd-heev-perf
Browse files: browse the repository at this point in the history
Improve performance of ttmqr and friends
  • Loading branch information
neil-lindquist authored Dec 13, 2023
2 parents 979f9d1 + 7cc42e7 commit f2b0cda
Show file tree
Hide file tree
Showing 10 changed files with 378 additions and 293 deletions.
4 changes: 2 additions & 2 deletions src/ge2tb.cc
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ void ge2tb(
bcast_list_V.push_back(
{i, k, {A.sub(i, i, k+1, A_nt-1)}});
}
A.template listBcast(bcast_list_V, layout, 0);
A.template listBcast<target>(bcast_list_V, layout, 0);
}

// bcast TUlocal across row for trailing matrix update
Expand Down Expand Up @@ -379,7 +379,7 @@ void ge2tb(
bcast_list_V.push_back(
{k, j, {A.sub(k+1, A_mt-1, j, j)}});
}
A.template listBcast(bcast_list_V, layout, 0);
A.template listBcast<target>(bcast_list_V, layout, 0);
}

// bcast TVlocal down col for trailing matrix update
Expand Down
34 changes: 13 additions & 21 deletions src/he2hb.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ void he2hb(
const int batch_size_default = 0;
const int num_queues = 10;
const int queue_0 = 0;
const int tag_0 = 0;
// Assumes column major
const Layout layout = Layout::ColMajor;
const LayoutConvert layoutc = LayoutConvert( layout );
Expand Down Expand Up @@ -373,16 +372,13 @@ void he2hb(
}
}
int64_t i0 = panel_rank_rows[ 0 ];
int dev_i0k = -1;
if (target == Target::Devices)
dev_i0k = A.tileDevice(i0, k);

#pragma omp task slate_omp_default_none \
depend( inout:block[ k ] ) \
shared( A, Asave ) \
firstprivate( zero, one, i0, k, dev_i0k, layoutc )
firstprivate( zero, one, i0, k, layoutc )
{
if (A.tileExists( i0, k, dev_i0k)) {
if (A.tileExists( i0, k, AnyDevice )) {
A.tileGetForWriting( i0, k, HostNum, layoutc );
// Save V0 and set upper(V0) to identity, to avoid trmm's.
Asave.tileInsert( i0, k );
Expand Down Expand Up @@ -444,18 +440,17 @@ void he2hb(
Wtmp.tileInsert( i, k );
int tag_i = i;
int tag_i1 = i+1;
W.tileGetForWriting( i, k, HostNum, layoutc );
MPI_Request req;
if (neighbor < mpi_rank) {
W.tileGetForWriting( i, k, HostNum,
layoutc );
W .tileSend( i, k, neighbor, tag_i );
W .tileIsend( i, k, neighbor, tag_i, &req );
Wtmp.tileRecv( i, k, neighbor, layout, tag_i1 );
}
else {
W.tileGetForWriting( i, k, HostNum,
layoutc );
W .tileIsend( i, k, neighbor, tag_i1, &req );
Wtmp.tileRecv( i, k, neighbor, layout, tag_i );
W .tileSend( i, k, neighbor, tag_i1 );
}
MPI_Wait( &req, MPI_STATUS_IGNORE );
auto Wtmp_ik = Wtmp( i, k );
auto W_ik = W( i, k );
blas::axpy( W_ik.nb()*W_ik.nb(),
Expand Down Expand Up @@ -487,10 +482,6 @@ void he2hb(

// 1d. TVAVT = V^H (A V T) = V^H W.
// todo: potentially do gemm+reduce here (block inner-product)
// todo: shouldn't need to set TVAVT = 0 since beta = 0.
// todo: on GPU
TVAVT.tileGetForWriting( 0, 0, HostNum, layoutc );
TVAVT( 0, 0 ).set( zero );
internal::he2hb_gemm<target>(
one, conj_transpose( A.sub( k+1, nt-1, k, k ) ),
W.sub( k+1, nt-1, k, k ),
Expand Down Expand Up @@ -563,9 +554,9 @@ void he2hb(
#pragma omp task slate_omp_default_none \
depend( inout:block[ k ] ) \
shared( A, Asave ) \
firstprivate( i0, k, dev_i0k, layoutc )
firstprivate( i0, k, layoutc )
{
if (A.tileExists( i0, k, dev_i0k )) {
if (A.tileExists( i0, k, AnyDevice )) {
A.tileGetForWriting( i0, k, layoutc );
tile::gecopy( Asave( i0, k ), A( i0, k ) );
Asave.tileErase( i0, k );
Expand All @@ -581,23 +572,24 @@ void he2hb(
depend( inout:block[ nt-1 ] ) \
depend( inout:fetch_trailing[ 0 ] ) \
shared( A ) \
firstprivate( A_panel, Treduce_panel, k, nt, tag_0, opts2 )
firstprivate( A_panel, Treduce_panel, k, nt, opts2 )
{
int tag_base = A.mt()*A.mt();
// Do 2-sided Hermitian update:
// 3. A = Q^H A Q
internal::hettmqr<Target::HostTask>(
Op::ConjTrans,
std::move( A_panel ),
std::move( Treduce_panel ),
A.sub( k+1, nt-1 ),
tag_0, opts2 );
tag_base, opts2 );
}
}

// Release workspace tiles
#pragma omp task slate_omp_default_none \
depend( inout:block[ k ] ) \
shared( A_panel, Tlocal, Treduce ) \
firstprivate( A_panel, Tlocal, Treduce ) \
firstprivate( k, nt, first_indices )
{
// Ensure the origin is up to date, then remove the panel's workspace
Expand Down
21 changes: 19 additions & 2 deletions src/internal/internal_he2hb_gemm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ void he2hb_gemm(
// Assumes column major
const Layout layout = Layout::ColMajor;
const LayoutConvert layoutc = LayoutConvert( layout );
const scalar_t one = 1.0;
const scalar_t zero = 0.0;

assert( A.nt() == B.mt() );

Expand All @@ -67,9 +69,11 @@ void he2hb_gemm(
#pragma omp task slate_omp_default_none \
shared( A, B, C ) \
firstprivate( alpha, beta, panel_rank, i, layoutc, call_tile_tick ) \
firstprivate( one, zero ) \
priority( priority )
{
scalar_t beta_ = beta;

for (int64_t k = 0; k < A.nt(); ++k) {
if (A.tileRank( i, k ) == panel_rank) {
A.tileGetForReading( i, k, layoutc );
Expand All @@ -81,8 +85,18 @@ void he2hb_gemm(
A.tileTick( i, k );
B.tileTick( k, 0 );
}
beta_ = one;
}
}

if (beta_ != one) {
// beta_ was never reset to one in the loop above, so C has not been scaled by beta yet
if (beta_ == zero) {
C( i, 0 ).set( zero );
}
else {
tile::scale( beta_, C( i, 0 ) );
}
beta_ = 1.0;
}
}
}
Expand Down Expand Up @@ -291,7 +305,10 @@ void he2hb_gemm(
// }
// }
}
beta = 1.0;
// Don't discard beta until C has been updated
if (C_tiles_set.size() > 0) {
beta = 1.0;
}
} // for loop (k)
} // pragma
} // device
Expand Down
Loading

0 comments on commit f2b0cda

Please sign in to comment.