Skip to content

Commit

Permalink
Fix in makefile, used impl namespace and minor fixes to the docs in geqrf_qd…
Browse files Browse the repository at this point in the history
…wh_full
  • Loading branch information
Dalal Sukkari committed Feb 2, 2023
1 parent 22e050a commit 3c8f673
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 44 deletions.
2 changes: 1 addition & 1 deletion GNUmakefile
Original file line number Diff line number Diff line change
Expand Up @@ -639,6 +639,7 @@ tester_src += \
test/test_pbsv.cc \
test/test_posv.cc \
test/test_potri.cc \
test/test_qdwh.cc \
test/test_scale.cc \
test/test_scale_row_col.cc \
test/test_set.cc \
Expand Down Expand Up @@ -676,7 +677,6 @@ ifneq ($(have_fortran),)
test/pdlantr.f \
test/pclantr.f \
test/pzlantr.f \
test/test_qdwh.cc \
# End. Add alphabetically, by base name after precision.
endif
endif
Expand Down
80 changes: 37 additions & 43 deletions src/geqrf_qdwh_full.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,7 @@

namespace slate {

// specialization namespace differentiates, e.g.,
// internal::geqrf_qdwh_full from internal::specialization::geqrf
namespace internal {
namespace specialization {
namespace impl {

//------------------------------------------------------------------------------
/// An auxiliary routine to find each rank's first (top-most) row
Expand All @@ -31,8 +28,9 @@ namespace specialization {
/// @ingroup geqrf_qdwh_full_specialization
///
template <typename scalar_t>
void geqrf_compute_first_indices(Matrix<scalar_t>& A_panel, int64_t k,
std::vector< int64_t >& first_indices)
void geqrf_compute_first_indices(
Matrix<scalar_t>& A_panel, int64_t k,
std::vector< int64_t >& first_indices)
{
// Find ranks in this column.
std::set<int> ranks_set;
Expand Down Expand Up @@ -60,13 +58,13 @@ void geqrf_compute_first_indices(Matrix<scalar_t>& A_panel, int64_t k,
///
/// ColMajor layout is assumed
///
/// @ingroup geqrf_specialization
/// @ingroup geqrf_impl
///
template <Target target, typename scalar_t>
void geqrf_qdwh_full(slate::internal::TargetType<target>,
void geqrf_qdwh_full(
Matrix<scalar_t>& A,
TriangularFactors<scalar_t>& T,
int64_t ib, int max_panel_threads, int64_t lookahead)
Options const& opts )
{
using BcastList = typename Matrix<scalar_t>::BcastList;
using device_info_t = lapack::device_info_int;
Expand All @@ -78,7 +76,15 @@ void geqrf_qdwh_full(slate::internal::TargetType<target>,
const int priority_zero = 0;
const int priority_one = 1;
const int life_factor_one = 1;
const bool set_hold = lookahead > 0; // Do tileGetAndHold in the bcast

// Options
int64_t lookahead = get_option<int64_t>( opts, Option::Lookahead, 1 );
int64_t ib = get_option<int64_t>( opts, Option::InnerBlocking, 16 );
int64_t max_panel_threads = std::max(omp_get_max_threads()/2, 1);
max_panel_threads = get_option<int64_t>( opts, Option::MaxPanelThreads,
max_panel_threads );

bool set_hold = lookahead > 0; // Do tileGetAndHold in the bcast

int64_t A_mt = A.mt();
int64_t A_nt = A.nt();
Expand Down Expand Up @@ -364,34 +370,17 @@ void geqrf_qdwh_full(slate::internal::TargetType<target>,
}
}

} // namespace specialization
} // namespace internal

//------------------------------------------------------------------------------
/// Version with target as template parameter.
/// @ingroup geqrf_qdwh_full_specialization
///
template <Target target, typename scalar_t>
void geqrf_qdwh_full(Matrix<scalar_t>& A,
TriangularFactors<scalar_t>& T,
Options const& opts)
{
// Unpack the tuning parameters from opts; each get_option call supplies
// the value to fall back on as its third argument.
// Lookahead: fallback value 1.
int64_t lookahead = get_option<int64_t>( opts, Option::Lookahead, 1 );

// Inner blocking size: fallback value 16.
int64_t ib = get_option<int64_t>( opts, Option::InnerBlocking, 16 );

// Fallback for the panel thread count is half of the OpenMP maximum thread
// count, but never less than 1; it is then passed to get_option as the
// fallback for Option::MaxPanelThreads.
int64_t max_panel_threads = std::max(omp_get_max_threads()/2, 1);
max_panel_threads = get_option<int64_t>( opts, Option::MaxPanelThreads, max_panel_threads );

// Forward to the internal::specialization overload, selecting it with a
// TargetType<target> tag object and passing the unpacked scalar parameters.
internal::specialization::geqrf_qdwh_full(internal::TargetType<target>(),
A, T,
ib, max_panel_threads, lookahead);
}
} // namespace impl

//------------------------------------------------------------------------------
/// Distributed parallel QR factorization.
/// Distributed parallel customized QR factorization.
/// Required for the QR-based iterations in the polar decomposition QDWH.
///
/// Computes a QR factorization of an m-by-n matrix $A$.
/// Computes a QR factorization of an m-by-n matrix $A$, $m \ge 2n$,
/// and takes advantage of the trailing identity matrix structure.
/// A = [ A0 ] full matrix ( m0-by-n, where m0 = m - n)
/// [ A1 ] identity matrix (n-by-n)
/// Avoids doing computations on the zero tiles below the diagonal of $A1$.
/// The factorization has the form
/// \[
/// A = QR,
Expand All @@ -401,15 +390,16 @@ void geqrf_qdwh_full(Matrix<scalar_t>& A,
///
/// Complexity (in real):
/// - for $m \ge n$, $\approx 2 m n^{2} - \frac{2}{3} n^{3}$ flops;
/// - for $m \le n$, $\approx 2 m^{2} n - \frac{2}{3} m^{3}$ flops;
/// - for $m = n$, $\approx \frac{4}{3} n^{3}$ flops.
/// .
//------------------------------------------------------------------------------
/// @tparam scalar_t
/// One of float, double, std::complex<float>, std::complex<double>.
//------------------------------------------------------------------------------
/// @param[in,out] A
/// On entry, the m-by-n matrix $A$.
/// On entry, the m-by-n matrix $A$, $m \ge 2n$,
/// A = [ A0 ] full matrix ( m0-by-n, where m0 = m - n)
/// [ A1 ] identity matrix (n-by-n)
///
/// On exit, the elements on and above the diagonal of the array contain
/// the min(m,n)-by-n upper trapezoidal matrix $R$ (upper triangular
/// if m >= n); the elements below the diagonal represent the unitary
Expand Down Expand Up @@ -437,7 +427,8 @@ void geqrf_qdwh_full(Matrix<scalar_t>& A,
/// @ingroup geqrf_computational
///
template <typename scalar_t>
void geqrf_qdwh_full(Matrix<scalar_t>& A,
void geqrf_qdwh_full(
Matrix<scalar_t>& A,
TriangularFactors<scalar_t>& T,
Options const& opts)
{
Expand All @@ -446,16 +437,19 @@ void geqrf_qdwh_full(Matrix<scalar_t>& A,
switch (target) {
case Target::Host:
case Target::HostTask:
geqrf_qdwh_full<Target::HostTask>(A, T, opts);
impl::geqrf_qdwh_full<Target::HostTask>(A, T, opts);
break;

case Target::HostNest:
geqrf_qdwh_full<Target::HostNest>(A, T, opts);
impl::geqrf_qdwh_full<Target::HostNest>(A, T, opts);
break;

case Target::HostBatch:
geqrf_qdwh_full<Target::HostBatch>(A, T, opts);
impl::geqrf_qdwh_full<Target::HostBatch>(A, T, opts);
break;

case Target::Devices:
geqrf_qdwh_full<Target::Devices>(A, T, opts);
impl::geqrf_qdwh_full<Target::Devices>(A, T, opts);
break;
}
// todo: return value for errors?
Expand Down

0 comments on commit 3c8f673

Please sign in to comment.